diff --git a/swh/web/api/throttling.py b/swh/web/api/throttling.py index cd34d026..ed449924 100644 --- a/swh/web/api/throttling.py +++ b/swh/web/api/throttling.py @@ -1,169 +1,202 @@ # Copyright (C) 2017-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU Affero General Public License version 3, or any later version # See top-level LICENSE file for more information from ipaddress import IPv4Network, IPv6Network, ip_address, ip_network from typing import Callable, List, TypeVar, Union import sentry_sdk from django.core.exceptions import ImproperlyConfigured import rest_framework from rest_framework.throttling import ScopedRateThrottle from swh.web.config import get_config APIView = TypeVar("APIView", bound="rest_framework.views.APIView") Request = rest_framework.request.Request API_THROTTLING_EXEMPTED_PERM = "swh.web.api.throttling_exempted" class SwhWebRateThrottle(ScopedRateThrottle): - """Custom request rate limiter for DRF enabling to exempt - specific networks specified in swh-web configuration. + """Custom DRF request rate limiter for anonymous users Requests are grouped into scopes. It enables to apply different requests rate limiting based on the scope name but also the input HTTP request types. To associate a scope to requests, one must add a 'throttle_scope' attribute when using a class based view, or call the 'throttle_scope' decorator when using a function based view. By default, requests do not have an associated scope and are not rate limited. Rate limiting can also be configured according to the type of the input HTTP requests for fine grained tuning. For instance, the following YAML configuration section sets a rate of: - 1 per minute for POST requests - 60 per minute for other request types for the 'swh_api' scope while exempting those coming from the 127.0.0.0/8 ip network. .. 
code-block:: yaml throttling: scopes: swh_api: limiter_rate: default: 60/m POST: 1/m exempted_networks: - 127.0.0.0/8 """ scope = None def __init__(self): super().__init__() self.exempted_networks = None self.num_requests = 0 self.duration = 0 + def get_cache_key(self, request, view): + # do not handle throttling if user is authenticated + if request.user.is_authenticated: + return None + else: + return super().get_cache_key(request, view) + def get_exempted_networks( self, scope_name: str ) -> List[Union[IPv4Network, IPv6Network]]: if not self.exempted_networks: scopes = get_config()["throttling"]["scopes"] scope = scopes.get(scope_name) if scope: networks = scope.get("exempted_networks") if networks: self.exempted_networks = [ ip_network(network) for network in networks ] return self.exempted_networks def allow_request(self, request: Request, view: APIView) -> bool: - # no throttling for staff users - if request.user.is_authenticated and ( - request.user.is_staff or request.user.has_perm(API_THROTTLING_EXEMPTED_PERM) - ): - return True # class based view case if not self.scope: default_scope = getattr(view, self.scope_attr, None) request_allowed = None if default_scope is not None: # check if there is a specific rate limiting associated # to the request type assert request.method is not None request_scope = f"{default_scope}_{request.method.lower()}" setattr(view, self.scope_attr, request_scope) try: request_allowed = super().allow_request(request, view) # use default rate limiting otherwise except ImproperlyConfigured as exc: sentry_sdk.capture_exception(exc) setattr(view, self.scope_attr, default_scope) if request_allowed is None: request_allowed = super().allow_request(request, view) # function based view case else: default_scope = self.scope # check if there is a specific rate limiting associated # to the request type self.scope = default_scope + "_" + request.method.lower() try: self.rate = self.get_rate() # use default rate limiting otherwise except 
ImproperlyConfigured: self.scope = default_scope self.rate = self.get_rate() self.num_requests, self.duration = self.parse_rate(self.rate) request_allowed = super(ScopedRateThrottle, self).allow_request( request, view ) self.scope = default_scope exempted_networks = self.get_exempted_networks(default_scope) exempted_ip = False if exempted_networks: remote_address = ip_address(self.get_ident(request)) exempted_ip = any( remote_address in network for network in exempted_networks ) request_allowed = exempted_ip or request_allowed # set throttling related data in the request metadata # in order for the ThrottlingHeadersMiddleware to # add X-RateLimit-* headers in the HTTP response if not exempted_ip and hasattr(self, "history"): hit_count = len(self.history) request.META["RateLimit-Limit"] = self.num_requests request.META["RateLimit-Remaining"] = self.num_requests - hit_count wait = self.wait() if wait is not None: request.META["RateLimit-Reset"] = int(self.now + wait) return request_allowed +class SwhWebUserRateThrottle(SwhWebRateThrottle): + """Custom DRF request rate limiter for authenticated users + + It has the same behavior as :class:`swh.web.api.throttling.SwhWebRateThrottle` + except the number of allowed requests for each throttle scope is increased by a + 10x factor.
+ """ + + NUM_REQUESTS_FACTOR = 10 + + def get_cache_key(self, request, view): + # do not handle throttling if user is not authenticated + if request.user.is_authenticated: + return super(SwhWebRateThrottle, self).get_cache_key(request, view) + else: + return None + + def parse_rate(self, rate): + # increase number of allowed requests + num_requests, duration = super().parse_rate(rate) + return (num_requests * self.NUM_REQUESTS_FACTOR, duration) + + def allow_request(self, request: Request, view: APIView) -> bool: + # no throttling for staff users or users with adequate permission + if request.user.is_staff or request.user.has_perm(API_THROTTLING_EXEMPTED_PERM): + return True + return super().allow_request(request, view) + + def throttle_scope(scope: str) -> Callable[..., APIView]: """Decorator that allows the throttle scope of a DRF function based view to be set:: @api_view(['GET', ]) @throttle_scope('scope') def view(request): ... """ def decorator(func: APIView) -> APIView: SwhScopeRateThrottle = type( - "CustomScopeRateThrottle", (SwhWebRateThrottle,), {"scope": scope} + "SwhWebScopeRateThrottle", (SwhWebRateThrottle,), {"scope": scope} + ) + SwhScopeUserRateThrottle = type( + "SwhWebScopeUserRateThrottle", (SwhWebUserRateThrottle,), {"scope": scope}, ) - func.throttle_classes = (SwhScopeRateThrottle,) + func.throttle_classes = (SwhScopeRateThrottle, SwhScopeUserRateThrottle) return func return decorator diff --git a/swh/web/settings/common.py b/swh/web/settings/common.py index d61949b6..ac6e4878 100644 --- a/swh/web/settings/common.py +++ b/swh/web/settings/common.py @@ -1,288 +1,291 @@ # Copyright (C) 2017-2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU Affero General Public License version 3, or any later version # See top-level LICENSE file for more information """ Django common settings for swh-web. 
""" import os import sys from typing import Any, Dict from swh.web.auth.utils import OIDC_SWH_WEB_CLIENT_ID from swh.web.config import get_config swh_web_config = get_config() # Build paths inside the project like this: os.path.join(BASE_DIR, ...) PROJECT_DIR = os.path.dirname(os.path.abspath(__file__)) # Quick-start development settings - unsuitable for production # See https://docs.djangoproject.com/en/1.11/howto/deployment/checklist/ # SECURITY WARNING: keep the secret key used in production secret! SECRET_KEY = swh_web_config["secret_key"] # SECURITY WARNING: don't run with debug turned on in production! DEBUG = swh_web_config["debug"] DEBUG_PROPAGATE_EXCEPTIONS = swh_web_config["debug"] ALLOWED_HOSTS = ["127.0.0.1", "localhost"] + swh_web_config["allowed_hosts"] # Application definition INSTALLED_APPS = [ "django.contrib.admin", "django.contrib.auth", "django.contrib.contenttypes", "django.contrib.sessions", "django.contrib.messages", "django.contrib.staticfiles", "rest_framework", "swh.web.common", "swh.web.api", "swh.web.auth", "swh.web.browse", "webpack_loader", "django_js_reverse", "corsheaders", ] MIDDLEWARE = [ "django.middleware.security.SecurityMiddleware", "django.contrib.sessions.middleware.SessionMiddleware", "corsheaders.middleware.CorsMiddleware", "django.middleware.common.CommonMiddleware", "django.middleware.csrf.CsrfViewMiddleware", "django.contrib.auth.middleware.AuthenticationMiddleware", "swh.auth.django.middlewares.OIDCSessionExpiredMiddleware", "django.contrib.messages.middleware.MessageMiddleware", "django.middleware.clickjacking.XFrameOptionsMiddleware", "swh.web.common.middlewares.ThrottlingHeadersMiddleware", "swh.web.common.middlewares.ExceptionMiddleware", ] # Compress all assets (static ones and dynamically generated html) # served by django in a local development environment context. # In a production environment, assets compression will be directly # handled by web servers like apache or nginx. 
if swh_web_config["serve_assets"]: MIDDLEWARE.insert(0, "django.middleware.gzip.GZipMiddleware") ROOT_URLCONF = "swh.web.urls" TEMPLATES = [ { "BACKEND": "django.template.backends.django.DjangoTemplates", "DIRS": [os.path.join(PROJECT_DIR, "../templates")], "APP_DIRS": True, "OPTIONS": { "context_processors": [ "django.template.context_processors.debug", "django.template.context_processors.request", "django.contrib.auth.context_processors.auth", "django.contrib.messages.context_processors.messages", "swh.web.common.utils.context_processor", ], "libraries": {"swh_templatetags": "swh.web.common.swh_templatetags",}, }, }, ] DATABASES = { "default": { "ENGINE": "django.db.backends.sqlite3", "NAME": swh_web_config.get("development_db", ""), } } # Password validation # https://docs.djangoproject.com/en/1.11/ref/settings/#auth-password-validators AUTH_PASSWORD_VALIDATORS = [ { "NAME": "django.contrib.auth.password_validation.UserAttributeSimilarityValidator", # noqa }, {"NAME": "django.contrib.auth.password_validation.MinimumLengthValidator",}, {"NAME": "django.contrib.auth.password_validation.CommonPasswordValidator",}, {"NAME": "django.contrib.auth.password_validation.NumericPasswordValidator",}, ] # Internationalization # https://docs.djangoproject.com/en/1.11/topics/i18n/ LANGUAGE_CODE = "en-us" TIME_ZONE = "UTC" USE_I18N = True USE_L10N = True USE_TZ = True # Static files (CSS, JavaScript, Images) # https://docs.djangoproject.com/en/1.11/howto/static-files/ STATIC_URL = "/static/" # static folder location when swh-web has been installed with pip STATIC_DIR = os.path.join(sys.prefix, "share/swh/web/static") if not os.path.exists(STATIC_DIR): # static folder location when developping swh-web STATIC_DIR = os.path.join(PROJECT_DIR, "../../../static") STATICFILES_DIRS = [STATIC_DIR] INTERNAL_IPS = ["127.0.0.1"] throttle_rates = {} http_requests = ["GET", "HEAD", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"] throttling = swh_web_config["throttling"] for limiter_scope, 
limiter_conf in throttling["scopes"].items(): if "default" in limiter_conf["limiter_rate"]: throttle_rates[limiter_scope] = limiter_conf["limiter_rate"]["default"] # for backward compatibility else: throttle_rates[limiter_scope] = limiter_conf["limiter_rate"] # register sub scopes specific for HTTP request types for http_request in http_requests: if http_request in limiter_conf["limiter_rate"]: throttle_rates[limiter_scope + "_" + http_request.lower()] = limiter_conf[ "limiter_rate" ][http_request] REST_FRAMEWORK: Dict[str, Any] = { "DEFAULT_RENDERER_CLASSES": ( "rest_framework.renderers.JSONRenderer", "swh.web.api.renderers.YAMLRenderer", "rest_framework.renderers.TemplateHTMLRenderer", ), - "DEFAULT_THROTTLE_CLASSES": ("swh.web.api.throttling.SwhWebRateThrottle",), + "DEFAULT_THROTTLE_CLASSES": ( + "swh.web.api.throttling.SwhWebRateThrottle", + "swh.web.api.throttling.SwhWebUserRateThrottle", + ), "DEFAULT_THROTTLE_RATES": throttle_rates, "DEFAULT_AUTHENTICATION_CLASSES": [ "rest_framework.authentication.SessionAuthentication", "swh.auth.django.backends.OIDCBearerTokenAuthentication", ], "EXCEPTION_HANDLER": "swh.web.api.apiresponse.error_response_handler", } LOGGING = { "version": 1, "disable_existing_loggers": False, "filters": { "require_debug_false": {"()": "django.utils.log.RequireDebugFalse",}, "require_debug_true": {"()": "django.utils.log.RequireDebugTrue",}, }, "formatters": { "request": { "format": "[%(asctime)s] [%(levelname)s] %(request)s %(status_code)s", "datefmt": "%d/%b/%Y %H:%M:%S", }, "simple": { "format": "[%(asctime)s] [%(levelname)s] %(message)s", "datefmt": "%d/%b/%Y %H:%M:%S", }, "verbose": { "format": ( "[%(asctime)s] [%(levelname)s] %(name)s.%(funcName)s:%(lineno)s " "- %(message)s" ), "datefmt": "%d/%b/%Y %H:%M:%S", }, }, "handlers": { "console": { "level": "DEBUG", "filters": ["require_debug_true"], "class": "logging.StreamHandler", "formatter": "simple", }, "file": { "level": "WARNING", "filters": ["require_debug_false"], "class": 
"logging.FileHandler", "filename": os.path.join(swh_web_config["log_dir"], "swh-web.log"), "formatter": "simple", }, "file_request": { "level": "WARNING", "filters": ["require_debug_false"], "class": "logging.FileHandler", "filename": os.path.join(swh_web_config["log_dir"], "swh-web.log"), "formatter": "request", }, "console_verbose": { "level": "DEBUG", "filters": ["require_debug_true"], "class": "logging.StreamHandler", "formatter": "verbose", }, "file_verbose": { "level": "WARNING", "filters": ["require_debug_false"], "class": "logging.FileHandler", "filename": os.path.join(swh_web_config["log_dir"], "swh-web.log"), "formatter": "verbose", }, "null": {"class": "logging.NullHandler",}, }, "loggers": { "": { "handlers": ["console_verbose", "file_verbose"], "level": "DEBUG" if DEBUG else "WARNING", }, "django": { "handlers": ["console"], "level": "DEBUG" if DEBUG else "WARNING", "propagate": False, }, "django.request": { "handlers": ["file_request"], "level": "DEBUG" if DEBUG else "WARNING", "propagate": False, }, "django.db.backends": {"handlers": ["null"], "propagate": False}, "django.utils.autoreload": {"level": "INFO",}, }, } WEBPACK_LOADER = { "DEFAULT": { "CACHE": False, "BUNDLE_DIR_NAME": "./", "STATS_FILE": os.path.join(STATIC_DIR, "webpack-stats.json"), "POLL_INTERVAL": 0.1, "TIMEOUT": None, "IGNORE": [".+\\.hot-update.js", ".+\\.map"], } } LOGIN_URL = "/admin/login/" LOGIN_REDIRECT_URL = "admin" SESSION_ENGINE = "django.contrib.sessions.backends.cache" CACHES = { "default": {"BACKEND": "django.core.cache.backends.locmem.LocMemCache"}, } JS_REVERSE_JS_MINIFY = False CORS_ORIGIN_ALLOW_ALL = True CORS_URLS_REGEX = r"^/(badge|api)/.*$" AUTHENTICATION_BACKENDS = [ "django.contrib.auth.backends.ModelBackend", "swh.auth.django.backends.OIDCAuthorizationCodePKCEBackend", ] SWH_AUTH_SERVER_URL = swh_web_config["keycloak"]["server_url"] SWH_AUTH_REALM_NAME = swh_web_config["keycloak"]["realm_name"] SWH_AUTH_CLIENT_ID = OIDC_SWH_WEB_CLIENT_ID 
SWH_AUTH_SESSION_EXPIRED_REDIRECT_VIEW = "logout" diff --git a/swh/web/tests/api/test_throttling.py b/swh/web/tests/api/test_throttling.py index 79818390..3a8c3a8d 100644 --- a/swh/web/tests/api/test_throttling.py +++ b/swh/web/tests/api/test_throttling.py @@ -1,224 +1,239 @@ # Copyright (C) 2017-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU Affero General Public License version 3, or any later version # See top-level LICENSE file for more information import pytest from django.conf.urls import url from django.contrib.auth.models import Permission, User from django.contrib.contenttypes.models import ContentType from django.test.utils import override_settings from rest_framework.decorators import api_view from rest_framework.response import Response from rest_framework.views import APIView from swh.web.api.throttling import ( API_THROTTLING_EXEMPTED_PERM, SwhWebRateThrottle, + SwhWebUserRateThrottle, throttle_scope, ) from swh.web.settings.tests import ( scope1_limiter_rate, scope1_limiter_rate_post, scope2_limiter_rate, scope2_limiter_rate_post, scope3_limiter_rate, scope3_limiter_rate_post, ) from swh.web.urls import urlpatterns class MockViewScope1(APIView): throttle_classes = (SwhWebRateThrottle,) throttle_scope = "scope1" def get(self, request): return Response("foo_get") def post(self, request): return Response("foo_post") @api_view(["GET", "POST"]) @throttle_scope("scope2") def mock_view_scope2(request): if request.method == "GET": return Response("bar_get") elif request.method == "POST": return Response("bar_post") class MockViewScope3(APIView): throttle_classes = (SwhWebRateThrottle,) throttle_scope = "scope3" def get(self, request): return Response("foo_get") def post(self, request): return Response("foo_post") @api_view(["GET", "POST"]) @throttle_scope("scope3") def mock_view_scope3(request): if request.method == "GET": return Response("bar_get") elif request.method == "POST": 
return Response("bar_post") urlpatterns += [ url(r"^scope1_class$", MockViewScope1.as_view()), url(r"^scope2_func$", mock_view_scope2), url(r"^scope3_class$", MockViewScope3.as_view()), url(r"^scope3_func$", mock_view_scope3), ] def check_response(response, status_code, limit=None, remaining=None): assert response.status_code == status_code if limit is not None: assert response["X-RateLimit-Limit"] == str(limit) else: assert "X-RateLimit-Limit" not in response if remaining is not None: assert response["X-RateLimit-Remaining"] == str(remaining) else: assert "X-RateLimit-Remaining" not in response @override_settings(ROOT_URLCONF=__name__) def test_scope1_requests_are_throttled(api_client): """ Ensure request rate is limited in scope1 """ for i in range(scope1_limiter_rate): response = api_client.get("/scope1_class") check_response(response, 200, scope1_limiter_rate, scope1_limiter_rate - i - 1) response = api_client.get("/scope1_class") check_response(response, 429, scope1_limiter_rate, 0) for i in range(scope1_limiter_rate_post): response = api_client.post("/scope1_class") check_response( response, 200, scope1_limiter_rate_post, scope1_limiter_rate_post - i - 1 ) response = api_client.post("/scope1_class") check_response(response, 429, scope1_limiter_rate_post, 0) @override_settings(ROOT_URLCONF=__name__) def test_scope2_requests_are_throttled(api_client): """ Ensure request rate is limited in scope2 """ for i in range(scope2_limiter_rate): response = api_client.get("/scope2_func") check_response(response, 200, scope2_limiter_rate, scope2_limiter_rate - i - 1) response = api_client.get("/scope2_func") check_response(response, 429, scope2_limiter_rate, 0) for i in range(scope2_limiter_rate_post): response = api_client.post("/scope2_func") check_response( response, 200, scope2_limiter_rate_post, scope2_limiter_rate_post - i - 1 ) response = api_client.post("/scope2_func") check_response(response, 429, scope2_limiter_rate_post, 0) 
@override_settings(ROOT_URLCONF=__name__) def test_scope3_requests_are_throttled_exempted(api_client): """ Ensure request rate is not limited in scope3 as requests coming from localhost are exempted from rate limit. """ for _ in range(scope3_limiter_rate + 1): response = api_client.get("/scope3_class") check_response(response, 200) for _ in range(scope3_limiter_rate_post + 1): response = api_client.post("/scope3_class") check_response(response, 200) for _ in range(scope3_limiter_rate + 1): response = api_client.get("/scope3_func") check_response(response, 200) for _ in range(scope3_limiter_rate_post + 1): response = api_client.post("/scope3_func") check_response(response, 200) @override_settings(ROOT_URLCONF=__name__) @pytest.mark.django_db def test_staff_users_are_not_rate_limited(api_client): staff_user = User.objects.create_user( username="johndoe", password="", is_staff=True ) api_client.force_login(staff_user) for _ in range(scope2_limiter_rate + 1): response = api_client.get("/scope2_func") check_response(response, 200) for _ in range(scope2_limiter_rate_post + 1): response = api_client.post("/scope2_func") check_response(response, 200) @override_settings(ROOT_URLCONF=__name__) @pytest.mark.django_db def test_non_staff_users_are_rate_limited(api_client): + user = User.objects.create_user(username="johndoe", password="", is_staff=False) api_client.force_login(user) - for i in range(scope2_limiter_rate): + scope2_limiter_rate_user = ( + scope2_limiter_rate * SwhWebUserRateThrottle.NUM_REQUESTS_FACTOR + ) + + for i in range(scope2_limiter_rate_user): response = api_client.get("/scope2_func") - check_response(response, 200, scope2_limiter_rate, scope2_limiter_rate - i - 1) + check_response( + response, 200, scope2_limiter_rate_user, scope2_limiter_rate_user - i - 1 + ) response = api_client.get("/scope2_func") - check_response(response, 429, scope2_limiter_rate, 0) + check_response(response, 429, scope2_limiter_rate_user, 0) - for i in 
range(scope2_limiter_rate_post): + scope2_limiter_rate_post_user = ( + scope2_limiter_rate_post * SwhWebUserRateThrottle.NUM_REQUESTS_FACTOR + ) + + for i in range(scope2_limiter_rate_post_user): response = api_client.post("/scope2_func") check_response( - response, 200, scope2_limiter_rate_post, scope2_limiter_rate_post - i - 1 + response, + 200, + scope2_limiter_rate_post_user, + scope2_limiter_rate_post_user - i - 1, ) response = api_client.post("/scope2_func") - check_response(response, 429, scope2_limiter_rate_post, 0) + check_response(response, 429, scope2_limiter_rate_post_user, 0) @override_settings(ROOT_URLCONF=__name__) @pytest.mark.django_db def test_users_with_throttling_exempted_perm_are_not_rate_limited(api_client): user = User.objects.create_user(username="johndoe", password="") perm_splitted = API_THROTTLING_EXEMPTED_PERM.split(".") app_label = ".".join(perm_splitted[:-1]) perm_name = perm_splitted[-1] content_type = ContentType.objects.create(app_label=app_label, model="dummy") permission = Permission.objects.create( codename=perm_name, name=perm_name, content_type=content_type, ) user.user_permissions.add(permission) assert user.has_perm(API_THROTTLING_EXEMPTED_PERM) api_client.force_login(user) for _ in range(scope2_limiter_rate + 1): response = api_client.get("/scope2_func") check_response(response, 200) for _ in range(scope2_limiter_rate_post + 1): response = api_client.post("/scope2_func") check_response(response, 200)