diff --git a/swh/web/tests/api/test_throttling.py b/swh/web/tests/api/test_throttling.py index 3a8c3a8d..3e0c66f4 100644 --- a/swh/web/tests/api/test_throttling.py +++ b/swh/web/tests/api/test_throttling.py @@ -1,239 +1,232 @@ -# Copyright (C) 2017-2020 The Software Heritage developers +# Copyright (C) 2017-2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU Affero General Public License version 3, or any later version # See top-level LICENSE file for more information import pytest from django.conf.urls import url -from django.contrib.auth.models import Permission, User -from django.contrib.contenttypes.models import ContentType +from django.contrib.auth.models import User from django.test.utils import override_settings from rest_framework.decorators import api_view from rest_framework.response import Response from rest_framework.views import APIView from swh.web.api.throttling import ( API_THROTTLING_EXEMPTED_PERM, SwhWebRateThrottle, SwhWebUserRateThrottle, throttle_scope, ) from swh.web.settings.tests import ( scope1_limiter_rate, scope1_limiter_rate_post, scope2_limiter_rate, scope2_limiter_rate_post, scope3_limiter_rate, scope3_limiter_rate_post, ) +from swh.web.tests.utils import create_django_permission from swh.web.urls import urlpatterns class MockViewScope1(APIView): throttle_classes = (SwhWebRateThrottle,) throttle_scope = "scope1" def get(self, request): return Response("foo_get") def post(self, request): return Response("foo_post") @api_view(["GET", "POST"]) @throttle_scope("scope2") def mock_view_scope2(request): if request.method == "GET": return Response("bar_get") elif request.method == "POST": return Response("bar_post") class MockViewScope3(APIView): throttle_classes = (SwhWebRateThrottle,) throttle_scope = "scope3" def get(self, request): return Response("foo_get") def post(self, request): return Response("foo_post") @api_view(["GET", "POST"]) @throttle_scope("scope3") def mock_view_scope3(request): if request.method == "GET": return Response("bar_get") elif request.method == "POST": return Response("bar_post") urlpatterns += [ url(r"^scope1_class$", MockViewScope1.as_view()), url(r"^scope2_func$", mock_view_scope2), url(r"^scope3_class$", MockViewScope3.as_view()), url(r"^scope3_func$", mock_view_scope3), ] def check_response(response, status_code, limit=None, remaining=None): assert response.status_code == status_code if limit is not None: assert response["X-RateLimit-Limit"] == str(limit) else: assert "X-RateLimit-Limit" not in response if remaining is not None: assert response["X-RateLimit-Remaining"] == str(remaining) else: assert "X-RateLimit-Remaining" not in response @override_settings(ROOT_URLCONF=__name__) def test_scope1_requests_are_throttled(api_client): """ Ensure request rate is limited in scope1 """ for i in range(scope1_limiter_rate): response = api_client.get("/scope1_class") check_response(response, 200, scope1_limiter_rate, scope1_limiter_rate - i - 1) response = api_client.get("/scope1_class") check_response(response, 429, scope1_limiter_rate, 0) for i in range(scope1_limiter_rate_post): response = api_client.post("/scope1_class") check_response( response, 200, scope1_limiter_rate_post, scope1_limiter_rate_post - i - 1 ) response = api_client.post("/scope1_class") check_response(response, 429, scope1_limiter_rate_post, 0) @override_settings(ROOT_URLCONF=__name__) def test_scope2_requests_are_throttled(api_client): """ Ensure request rate is limited in scope2 """ for i in range(scope2_limiter_rate): response = api_client.get("/scope2_func") check_response(response, 200, scope2_limiter_rate, scope2_limiter_rate - i - 1) response = api_client.get("/scope2_func") check_response(response, 429, scope2_limiter_rate, 0) for i in range(scope2_limiter_rate_post): response = api_client.post("/scope2_func") check_response( response, 200, scope2_limiter_rate_post, scope2_limiter_rate_post - i - 1 ) response = api_client.post("/scope2_func") check_response(response, 429, scope2_limiter_rate_post, 0) @override_settings(ROOT_URLCONF=__name__) def test_scope3_requests_are_throttled_exempted(api_client): """ Ensure request rate is not limited in scope3 as requests coming from localhost are exempted from rate limit. """ for _ in range(scope3_limiter_rate + 1): response = api_client.get("/scope3_class") check_response(response, 200) for _ in range(scope3_limiter_rate_post + 1): response = api_client.post("/scope3_class") check_response(response, 200) for _ in range(scope3_limiter_rate + 1): response = api_client.get("/scope3_func") check_response(response, 200) for _ in range(scope3_limiter_rate_post + 1): response = api_client.post("/scope3_func") check_response(response, 200) @override_settings(ROOT_URLCONF=__name__) @pytest.mark.django_db def test_staff_users_are_not_rate_limited(api_client): staff_user = User.objects.create_user( username="johndoe", password="", is_staff=True ) api_client.force_login(staff_user) for _ in range(scope2_limiter_rate + 1): response = api_client.get("/scope2_func") check_response(response, 200) for _ in range(scope2_limiter_rate_post + 1): response = api_client.post("/scope2_func") check_response(response, 200) @override_settings(ROOT_URLCONF=__name__) @pytest.mark.django_db def test_non_staff_users_are_rate_limited(api_client): user = User.objects.create_user(username="johndoe", password="", is_staff=False) api_client.force_login(user) scope2_limiter_rate_user = ( scope2_limiter_rate * SwhWebUserRateThrottle.NUM_REQUESTS_FACTOR ) for i in range(scope2_limiter_rate_user): response = api_client.get("/scope2_func") check_response( response, 200, scope2_limiter_rate_user, scope2_limiter_rate_user - i - 1 ) response = api_client.get("/scope2_func") check_response(response, 429, scope2_limiter_rate_user, 0) scope2_limiter_rate_post_user = ( scope2_limiter_rate_post * SwhWebUserRateThrottle.NUM_REQUESTS_FACTOR ) for i in range(scope2_limiter_rate_post_user): response = api_client.post("/scope2_func") check_response( response, 200, scope2_limiter_rate_post_user, scope2_limiter_rate_post_user - i - 1, ) response = api_client.post("/scope2_func") check_response(response, 429, scope2_limiter_rate_post_user, 0) @override_settings(ROOT_URLCONF=__name__) @pytest.mark.django_db def test_users_with_throttling_exempted_perm_are_not_rate_limited(api_client): user = User.objects.create_user(username="johndoe", password="") - perm_splitted = API_THROTTLING_EXEMPTED_PERM.split(".") - app_label = ".".join(perm_splitted[:-1]) - perm_name = perm_splitted[-1] - content_type = ContentType.objects.create(app_label=app_label, model="dummy") - permission = Permission.objects.create( - codename=perm_name, name=perm_name, content_type=content_type, - ) - user.user_permissions.add(permission) + user.user_permissions.add(create_django_permission(API_THROTTLING_EXEMPTED_PERM)) assert user.has_perm(API_THROTTLING_EXEMPTED_PERM) api_client.force_login(user) for _ in range(scope2_limiter_rate + 1): response = api_client.get("/scope2_func") check_response(response, 200) for _ in range(scope2_limiter_rate_post + 1): response = api_client.post("/scope2_func") check_response(response, 200) diff --git a/swh/web/tests/create_test_users.py b/swh/web/tests/create_test_users.py index dfdd24d6..f92d3604 100644 --- a/swh/web/tests/create_test_users.py +++ b/swh/web/tests/create_test_users.py @@ -1,16 +1,29 @@ # Copyright (C) 2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU Affero General Public License version 3, or any later version # See top-level LICENSE file for more information +from typing import Dict, List, Tuple from django.contrib.auth import get_user_model +from swh.web.auth.utils import SWH_AMBASSADOR_PERMISSION +from swh.web.tests.utils import create_django_permission + User = get_user_model() -username = "user" -password = "user" -email = "user@swh-web.org" -if not User.objects.filter(username=username).exists(): - User.objects.create_user(username, email, password) +users: Dict[str, Tuple[str, str, List[str]]] = { + "user": ("user", "user@swh-web.org", []), + "ambassador": ("ambassador", "ambassador@swh-web.org", [SWH_AMBASSADOR_PERMISSION]), +} + +for username, (password, email, permissions) in users.items(): + if not User.objects.filter(username=username).exists(): + user = User.objects.create_user(username, email, password) + if permissions: + for perm_name in permissions: + permission = create_django_permission(perm_name) + user.user_permissions.add(permission) + + user.save() diff --git a/swh/web/tests/test_create_users.py b/swh/web/tests/test_create_users.py new file mode 100644 index 00000000..75694ef5 --- /dev/null +++ b/swh/web/tests/test_create_users.py @@ -0,0 +1,16 @@ +# Copyright (C) 2021 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU Affero General Public License version 3, or any later version +# See top-level LICENSE file for more information + + +def test_create_users_test_users_exist(db): + from .create_test_users import User, users + + for username, (_, _, permissions) in users.items(): + + user = User.objects.filter(username=username).get() + assert user is not None + + for permission in permissions: + assert user.has_perm(permission) diff --git a/swh/web/tests/utils.py b/swh/web/tests/utils.py index d3d63a37..e744cb2b 100644 --- a/swh/web/tests/utils.py +++ b/swh/web/tests/utils.py @@ -1,209 +1,231 @@ -# Copyright (C) 2020 The Software Heritage developers +# Copyright (C) 2020-2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU Affero General Public License version 3, or any later version # See top-level LICENSE file for more information from typing import Any, Dict, Optional, cast +from django.contrib.auth.models import Permission +from django.contrib.contenttypes.models import ContentType from django.http import HttpResponse, StreamingHttpResponse from django.test.client import Client from rest_framework.response import Response from rest_framework.test import APIClient from swh.web.tests.django_asserts import assert_template_used def _assert_http_response( response: HttpResponse, status_code: int, content_type: str ) -> HttpResponse: if isinstance(response, Response): drf_response = cast(Response, response) error_context = ( drf_response.data.pop("traceback") if isinstance(drf_response.data, dict) and "traceback" in drf_response.data else drf_response.data ) elif isinstance(response, StreamingHttpResponse): error_context = getattr(response, "traceback", response.streaming_content) else: error_context = getattr(response, "traceback", response.content) assert response.status_code == status_code, error_context if content_type != "*/*": assert response["Content-Type"].startswith(content_type) return response def check_http_get_response( client: Client, url: str, status_code: int, content_type: str = "*/*", http_origin: Optional[str] = None, server_name: Optional[str] = None, ) -> HttpResponse: """Helper function to check HTTP response for a GET request. Args: client: Django test client url: URL to check response status_code: expected HTTP status code content_type: expected response content type http_origin: optional HTTP_ORIGIN header value Returns: The HTTP response """ return _assert_http_response( response=client.get( url, HTTP_ACCEPT=content_type, HTTP_ORIGIN=http_origin, SERVER_NAME=server_name if server_name else "testserver", ), status_code=status_code, content_type=content_type, ) def check_http_post_response( client: Client, url: str, status_code: int, content_type: str = "*/*", data: Optional[Dict[str, Any]] = None, http_origin: Optional[str] = None, ) -> HttpResponse: """Helper function to check HTTP response for a POST request. Args: client: Django test client url: URL to check response status_code: expected HTTP status code content_type: expected response content type data: optional POST data Returns: The HTTP response """ return _assert_http_response( response=client.post( url, data=data, content_type="application/json", HTTP_ACCEPT=content_type, HTTP_ORIGIN=http_origin, ), status_code=status_code, content_type=content_type, ) def check_api_get_responses( api_client: APIClient, url: str, status_code: int ) -> Response: """Helper function to check Web API responses for GET requests for all accepted content types (JSON, YAML, HTML). Args: api_client: DRF test client url: Web API URL to check responses status_code: expected HTTP status code Returns: The Web API JSON response """ # check JSON response response_json = check_http_get_response( api_client, url, status_code, content_type="application/json" ) # check HTML response (API Web UI) check_http_get_response(api_client, url, status_code, content_type="text/html") # check YAML response check_http_get_response( api_client, url, status_code, content_type="application/yaml" ) return cast(Response, response_json) def check_api_post_response( api_client: APIClient, url: str, status_code: int, content_type: str = "*/*", data: Optional[Dict[str, Any]] = None, ) -> HttpResponse: """Helper function to check Web API response for a POST request for all accepted content types. Args: api_client: DRF test client url: Web API URL to check response status_code: expected HTTP status code Returns: The HTTP response """ return _assert_http_response( response=api_client.post( url, data=data, format="json", HTTP_ACCEPT=content_type, ), status_code=status_code, content_type=content_type, ) def check_api_post_responses( api_client: APIClient, url: str, status_code: int, data: Optional[Dict[str, Any]] = None, ) -> Response: """Helper function to check Web API responses for POST requests for all accepted content types (JSON, YAML). Args: api_client: DRF test client url: Web API URL to check responses status_code: expected HTTP status code Returns: The Web API JSON response """ # check JSON response response_json = check_api_post_response( api_client, url, status_code, content_type="application/json", data=data ) # check YAML response check_api_post_response( api_client, url, status_code, content_type="application/yaml", data=data ) return cast(Response, response_json) def check_html_get_response( client: Client, url: str, status_code: int, template_used: Optional[str] = None ) -> HttpResponse: """Helper function to check HTML responses for a GET request. Args: client: Django test client url: URL to check responses status_code: expected HTTP status code template_used: optional used Django template to check Returns: The HTML response """ response = check_http_get_response( client, url, status_code, content_type="text/html" ) if template_used is not None: assert_template_used(response, template_used) return response + + +def create_django_permission(perm_name: str) -> Permission: + """Create permission out of a permission name string + + Args: + perm_name: Permission name (e.g. swh.web.api.throttling_exempted, + swh.ambassador, ...) + + Returns: + The persisted permission + + """ + perm_splitted = perm_name.split(".") + app_label = ".".join(perm_splitted[:-1]) + perm_name = perm_splitted[-1] + content_type = ContentType.objects.create(app_label=app_label, model="dummy") + return Permission.objects.create( + codename=perm_name, name=perm_name, content_type=content_type, + )