diff --git a/swh/web/api/throttling.py b/swh/web/api/throttling.py
--- a/swh/web/api/throttling.py
+++ b/swh/web/api/throttling.py
@@ -75,6 +75,9 @@
         return self.exempted_networks
 
     def allow_request(self, request: Request, view: APIView) -> bool:
+        # staff users are exempted from rate limiting, whatever the scope
+        if request.user.is_authenticated and request.user.is_staff:
+            return True
 
         # class based view case
         if not self.scope:
diff --git a/swh/web/tests/api/test_throttling.py b/swh/web/tests/api/test_throttling.py
--- a/swh/web/tests/api/test_throttling.py
+++ b/swh/web/tests/api/test_throttling.py
@@ -1,23 +1,24 @@
-# Copyright (C) 2017-2019 The Software Heritage developers
+# Copyright (C) 2017-2020 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU Affero General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
-from swh.web.settings.tests import (
-    scope1_limiter_rate, scope1_limiter_rate_post,
-    scope2_limiter_rate, scope2_limiter_rate_post,
-    scope3_limiter_rate, scope3_limiter_rate_post
-)
+import pytest
 
 from django.conf.urls import url
+from django.contrib.auth.models import User
 from django.test.utils import override_settings
 
 from rest_framework.views import APIView
 from rest_framework.response import Response
 from rest_framework.decorators import api_view
-
 
 from swh.web.api.throttling import SwhWebRateThrottle, throttle_scope
+from swh.web.settings.tests import (
+    scope1_limiter_rate, scope1_limiter_rate_post,
+    scope2_limiter_rate, scope2_limiter_rate_post,
+    scope3_limiter_rate, scope3_limiter_rate_post
+)
 
 
 class MockViewScope1(APIView):
@@ -146,3 +147,47 @@
     for _ in range(scope3_limiter_rate_post+1):
         response = api_client.post('/scope3_func')
         check_response(response, 200)
+
+
+@override_settings(ROOT_URLCONF=__name__)
+@pytest.mark.django_db
+def test_staff_users_are_not_rate_limited(api_client):
+    """Ensure staff users are never throttled, for GET and POST requests."""
+    staff_user = User.objects.create_user(
+        username='johndoe', password='', is_staff=True)
+
+    api_client.force_login(staff_user)
+
+    for _ in range(scope2_limiter_rate+1):
+        response = api_client.get('/scope2_func')
+        check_response(response, 200)
+
+    for _ in range(scope2_limiter_rate_post+1):
+        response = api_client.post('/scope2_func')
+        check_response(response, 200)
+
+
+@override_settings(ROOT_URLCONF=__name__)
+@pytest.mark.django_db
+def test_non_staff_users_are_rate_limited(api_client):
+    """Ensure authenticated non-staff users are still rate limited."""
+    user = User.objects.create_user(
+        username='johndoe', password='', is_staff=False)
+
+    api_client.force_login(user)
+
+    for i in range(scope2_limiter_rate):
+        response = api_client.get('/scope2_func')
+        check_response(response, 200, scope2_limiter_rate,
+                       scope2_limiter_rate - i - 1)
+
+    response = api_client.get('/scope2_func')
+    check_response(response, 429, scope2_limiter_rate, 0)
+
+    for i in range(scope2_limiter_rate_post):
+        response = api_client.post('/scope2_func')
+        check_response(response, 200, scope2_limiter_rate_post,
+                       scope2_limiter_rate_post - i - 1)
+
+    response = api_client.post('/scope2_func')
+    check_response(response, 429, scope2_limiter_rate_post, 0)