diff --git a/swh/web/api/throttling.py b/swh/web/api/throttling.py
--- a/swh/web/api/throttling.py
+++ b/swh/web/api/throttling.py
@@ -17,6 +17,10 @@
 APIView = TypeVar("APIView", bound="rest_framework.views.APIView")
 Request = rest_framework.request.Request
 
+# Django permission ("<app_label>.<codename>") whose holders are exempted
+# from Web API rate limiting, in the same way as staff users.
+API_THROTTLING_EXEMPTED_PERM = "swh.web.api.throttling_exempted"
+
 
 class SwhWebRateThrottle(ScopedRateThrottle):
     """Custom request rate limiter for DRF enabling to exempt
@@ -76,7 +80,10 @@
 
     def allow_request(self, request: Request, view: APIView) -> bool:
         # no throttling for staff users
-        if request.user.is_authenticated and request.user.is_staff:
+        # nor for users granted the API_THROTTLING_EXEMPTED_PERM permission
+        if request.user.is_authenticated and (
+            request.user.is_staff or request.user.has_perm(API_THROTTLING_EXEMPTED_PERM)
+        ):
             return True
         # class based view case
         if not self.scope:
diff --git a/swh/web/tests/api/test_throttling.py b/swh/web/tests/api/test_throttling.py
--- a/swh/web/tests/api/test_throttling.py
+++ b/swh/web/tests/api/test_throttling.py
@@ -6,14 +6,19 @@
 import pytest
 
 from django.conf.urls import url
-from django.contrib.auth.models import User
+from django.contrib.auth.models import Permission, User
+from django.contrib.contenttypes.models import ContentType
 from django.test.utils import override_settings
 
 from rest_framework.views import APIView
 from rest_framework.response import Response
 from rest_framework.decorators import api_view
 
-from swh.web.api.throttling import SwhWebRateThrottle, throttle_scope
+from swh.web.api.throttling import (
+    API_THROTTLING_EXEMPTED_PERM,
+    SwhWebRateThrottle,
+    throttle_scope,
+)
 from swh.web.settings.tests import (
     scope1_limiter_rate,
     scope1_limiter_rate_post,
@@ -192,3 +197,29 @@
 
     response = api_client.post("/scope2_func")
     check_response(response, 429, scope2_limiter_rate_post, 0)
+
+
+@override_settings(ROOT_URLCONF=__name__)
+@pytest.mark.django_db
+def test_users_with_throttling_exempted_perm_are_not_rate_limited(api_client):
+    """Non-staff users granted API_THROTTLING_EXEMPTED_PERM are not rate limited."""
+    user = User.objects.create_user(username="johndoe", password="")
+    # the permission string has the "<app_label>.<codename>" format
+    app_label, perm_name = API_THROTTLING_EXEMPTED_PERM.rsplit(".", 1)
+    content_type = ContentType.objects.create(app_label=app_label, model="dummy")
+    permission = Permission.objects.create(
+        codename=perm_name, name=perm_name, content_type=content_type,
+    )
+    user.user_permissions.add(permission)
+
+    assert user.has_perm(API_THROTTLING_EXEMPTED_PERM)
+
+    api_client.force_login(user)
+
+    for _ in range(scope2_limiter_rate + 1):
+        response = api_client.get("/scope2_func")
+        check_response(response, 200)
+
+    for _ in range(scope2_limiter_rate_post + 1):
+        response = api_client.post("/scope2_func")
+        check_response(response, 200)