diff --git a/swh/web/api/throttling.py b/swh/web/api/throttling.py
--- a/swh/web/api/throttling.py
+++ b/swh/web/api/throttling.py
@@ -21,8 +21,7 @@
 
 
 class SwhWebRateThrottle(ScopedRateThrottle):
-    """Custom request rate limiter for DRF enabling to exempt
-    specific networks specified in swh-web configuration.
+    """Custom DRF request rate limiter for anonymous users
 
     Requests are grouped into scopes. It enables to apply different
     requests rate limiting based on the scope name but also the
@@ -64,6 +63,13 @@
         self.num_requests = 0
         self.duration = 0
 
+    def get_cache_key(self, request, view):
+        # do not handle throttling if user is authenticated
+        if request.user.is_authenticated:
+            return None
+        else:
+            return super().get_cache_key(request, view)
+
     def get_exempted_networks(
         self, scope_name: str
     ) -> List[Union[IPv4Network, IPv6Network]]:
@@ -79,11 +85,6 @@
         return self.exempted_networks
 
     def allow_request(self, request: Request, view: APIView) -> bool:
-        # no throttling for staff users
-        if request.user.is_authenticated and (
-            request.user.is_staff or request.user.has_perm(API_THROTTLING_EXEMPTED_PERM)
-        ):
-            return True
 
         # class based view case
         if not self.scope:
@@ -148,6 +149,38 @@
         return request_allowed
 
 
+class SwhWebUserRateThrottle(SwhWebRateThrottle):
+    """Custom DRF request rate limiter for authenticated users
+
+    It has the same behavior as :class:`swh.web.api.throttling.SwhWebRateThrottle`
+    except the number of allowed requests for each throttle scope is increased by a
+    10x factor.
+    """
+
+    NUM_REQUESTS_FACTOR = 10
+
+    def get_cache_key(self, request, view):
+        # do not handle throttling if user is not authenticated
+        if request.user.is_authenticated:
+            # bypass SwhWebRateThrottle.get_cache_key (anonymous users only)
+            return super(SwhWebRateThrottle, self).get_cache_key(request, view)
+        else:
+            return None
+
+    def parse_rate(self, rate):
+        # increase number of allowed requests (a None rate means no throttling)
+        num_requests, duration = super().parse_rate(rate)
+        if num_requests is not None:
+            num_requests *= self.NUM_REQUESTS_FACTOR
+        return (num_requests, duration)
+
+    def allow_request(self, request: Request, view: APIView) -> bool:
+        # no throttling for staff users or users with adequate permission
+        if request.user.is_staff or request.user.has_perm(API_THROTTLING_EXEMPTED_PERM):
+            return True
+        return super().allow_request(request, view)
+
+
 def throttle_scope(scope: str) -> Callable[..., APIView]:
     """Decorator that allows the throttle scope of a DRF
     function based view to be set::
@@ -161,9 +194,12 @@
 
     def decorator(func: APIView) -> APIView:
         SwhScopeRateThrottle = type(
-            "CustomScopeRateThrottle", (SwhWebRateThrottle,), {"scope": scope}
+            "SwhWebScopeRateThrottle", (SwhWebRateThrottle,), {"scope": scope}
+        )
+        SwhScopeUserRateThrottle = type(
+            "SwhWebScopeUserRateThrottle", (SwhWebUserRateThrottle,), {"scope": scope},
         )
-        func.throttle_classes = (SwhScopeRateThrottle,)
+        func.throttle_classes = (SwhScopeRateThrottle, SwhScopeUserRateThrottle)
         return func
 
     return decorator
diff --git a/swh/web/settings/common.py b/swh/web/settings/common.py
--- a/swh/web/settings/common.py
+++ b/swh/web/settings/common.py
@@ -163,7 +163,10 @@
         "swh.web.api.renderers.YAMLRenderer",
         "rest_framework.renderers.TemplateHTMLRenderer",
     ),
-    "DEFAULT_THROTTLE_CLASSES": ("swh.web.api.throttling.SwhWebRateThrottle",),
+    "DEFAULT_THROTTLE_CLASSES": (
+        "swh.web.api.throttling.SwhWebRateThrottle",
+        "swh.web.api.throttling.SwhWebUserRateThrottle",
+    ),
     "DEFAULT_THROTTLE_RATES": throttle_rates,
     "DEFAULT_AUTHENTICATION_CLASSES": [
         "rest_framework.authentication.SessionAuthentication",
diff --git a/swh/web/tests/api/test_throttling.py b/swh/web/tests/api/test_throttling.py
--- a/swh/web/tests/api/test_throttling.py
+++ b/swh/web/tests/api/test_throttling.py
@@ -16,6 +16,7 @@
 from swh.web.api.throttling import (
     API_THROTTLING_EXEMPTED_PERM,
     SwhWebRateThrottle,
+    SwhWebUserRateThrottle,
     throttle_scope,
 )
 from swh.web.settings.tests import (
@@ -177,25 +178,39 @@
 @override_settings(ROOT_URLCONF=__name__)
 @pytest.mark.django_db
 def test_non_staff_users_are_rate_limited(api_client):
+
     user = User.objects.create_user(username="johndoe", password="", is_staff=False)
 
     api_client.force_login(user)
 
-    for i in range(scope2_limiter_rate):
+    scope2_limiter_rate_user = (
+        scope2_limiter_rate * SwhWebUserRateThrottle.NUM_REQUESTS_FACTOR
+    )
+
+    for i in range(scope2_limiter_rate_user):
         response = api_client.get("/scope2_func")
-        check_response(response, 200, scope2_limiter_rate, scope2_limiter_rate - i - 1)
+        check_response(
+            response, 200, scope2_limiter_rate_user, scope2_limiter_rate_user - i - 1
+        )
 
     response = api_client.get("/scope2_func")
-    check_response(response, 429, scope2_limiter_rate, 0)
+    check_response(response, 429, scope2_limiter_rate_user, 0)
 
-    for i in range(scope2_limiter_rate_post):
+    scope2_limiter_rate_post_user = (
+        scope2_limiter_rate_post * SwhWebUserRateThrottle.NUM_REQUESTS_FACTOR
+    )
+
+    for i in range(scope2_limiter_rate_post_user):
         response = api_client.post("/scope2_func")
         check_response(
-            response, 200, scope2_limiter_rate_post, scope2_limiter_rate_post - i - 1
+            response,
+            200,
+            scope2_limiter_rate_post_user,
+            scope2_limiter_rate_post_user - i - 1,
         )
 
     response = api_client.post("/scope2_func")
-    check_response(response, 429, scope2_limiter_rate_post, 0)
+    check_response(response, 429, scope2_limiter_rate_post_user, 0)
 
 
 @override_settings(ROOT_URLCONF=__name__)