diff --git a/swh/web/api/throttling.py b/swh/web/api/throttling.py --- a/swh/web/api/throttling.py +++ b/swh/web/api/throttling.py @@ -1,17 +1,25 @@ -# Copyright (C) 2017-2018 The Software Heritage developers +# Copyright (C) 2017-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU Affero General Public License version 3, or any later version # See top-level LICENSE file for more information -import ipaddress +from ipaddress import ( + ip_address, ip_network, IPv4Network, IPv6Network +) +from typing import Callable, List, TypeVar, Union from django.core.exceptions import ImproperlyConfigured +import rest_framework.request  # loads the submodule aliased as Request below from rest_framework.throttling import ScopedRateThrottle import sentry_sdk from swh.web.config import get_config +APIView = TypeVar('APIView', bound='rest_framework.views.APIView') +Request = rest_framework.request.Request + + class SwhWebRateThrottle(ScopedRateThrottle): """Custom request rate limiter for DRF enabling to exempt specific networks specified in swh-web configuration. 
@@ -54,18 +62,19 @@ super().__init__() self.exempted_networks = None - def get_exempted_networks(self, scope_name): + def get_exempted_networks(self, scope_name: str + ) -> List[Union[IPv4Network, IPv6Network]]: if not self.exempted_networks: scopes = get_config()['throttling']['scopes'] scope = scopes.get(scope_name) if scope: networks = scope.get('exempted_networks') if networks: - self.exempted_networks = [ipaddress.ip_network(network) + self.exempted_networks = [ip_network(network) for network in networks] return self.exempted_networks - def allow_request(self, request, view): + def allow_request(self, request: Request, view: APIView) -> bool: # class based view case if not self.scope: @@ -74,7 +83,8 @@ if default_scope is not None: # check if there is a specific rate limiting associated # to the request type - request_scope = default_scope + '_' + request.method.lower() + assert request.method is not None + request_scope = f'{default_scope}_{request.method.lower()}' setattr(view, self.scope_attr, request_scope) try: request_allowed = super().allow_request(request, view) @@ -109,7 +119,7 @@ exempted_ip = False if exempted_networks: - remote_address = ipaddress.ip_address(self.get_ident(request)) + remote_address = ip_address(self.get_ident(request)) exempted_ip = any(remote_address in network for network in exempted_networks) request_allowed = exempted_ip or request_allowed @@ -121,12 +131,14 @@ hit_count = len(self.history) request.META['RateLimit-Limit'] = self.num_requests request.META['RateLimit-Remaining'] = self.num_requests - hit_count - request.META['RateLimit-Reset'] = int(self.now + self.wait()) + wait = self.wait() + if wait is not None: + request.META['RateLimit-Reset'] = int(self.now + wait) return request_allowed -def throttle_scope(scope): +def throttle_scope(scope: str) -> Callable[..., APIView]: """Decorator that allows the throttle scope of a DRF function based view to be set:: @@ -136,7 +148,7 @@ ... 
""" - def decorator(func): + def decorator(func: APIView) -> APIView: SwhScopeRateThrottle = type( 'CustomScopeRateThrottle', (SwhWebRateThrottle,),