Changeset View
Changeset View
Standalone View
Standalone View
swh/web/api/throttling.py
# Copyright (C) 2017-2020 The Software Heritage developers | # Copyright (C) 2017-2020 The Software Heritage developers | ||||
# See the AUTHORS file at the top-level directory of this distribution | # See the AUTHORS file at the top-level directory of this distribution | ||||
# License: GNU Affero General Public License version 3, or any later version | # License: GNU Affero General Public License version 3, or any later version | ||||
# See top-level LICENSE file for more information | # See top-level LICENSE file for more information | ||||
from ipaddress import IPv4Network, IPv6Network, ip_address, ip_network | from ipaddress import IPv4Network, IPv6Network, ip_address, ip_network | ||||
from typing import Callable, List, TypeVar, Union | from typing import Callable, List, TypeVar, Union | ||||
import sentry_sdk | import sentry_sdk | ||||
from django.core.exceptions import ImproperlyConfigured | from django.core.exceptions import ImproperlyConfigured | ||||
import rest_framework | import rest_framework | ||||
from rest_framework.throttling import ScopedRateThrottle | from rest_framework.throttling import ScopedRateThrottle | ||||
from swh.web.auth.utils import API_SAVE_ORIGIN_PERMISSION | |||||
from swh.web.config import get_config | from swh.web.config import get_config | ||||
APIView = TypeVar("APIView", bound="rest_framework.views.APIView") | APIView = TypeVar("APIView", bound="rest_framework.views.APIView") | ||||
Request = rest_framework.request.Request | Request = rest_framework.request.Request | ||||
API_THROTTLING_EXEMPTED_PERM = "swh.web.api.throttling_exempted" | API_THROTTLING_EXEMPTED_PERM = "swh.web.api.throttling_exempted" | ||||
▲ Show 20 Lines • Show All 56 Lines • ▼ Show 20 Lines | ) -> List[Union[IPv4Network, IPv6Network]]: | ||||
if scope: | if scope: | ||||
networks = scope.get("exempted_networks") | networks = scope.get("exempted_networks") | ||||
if networks: | if networks: | ||||
self.exempted_networks = [ | self.exempted_networks = [ | ||||
ip_network(network) for network in networks | ip_network(network) for network in networks | ||||
] | ] | ||||
return self.exempted_networks | return self.exempted_networks | ||||
def get_scope(self, view: APIView): | |||||
if not self.scope: | |||||
# class based view case | |||||
return getattr(view, self.scope_attr, None) | |||||
else: | |||||
# function based view case | |||||
return self.scope | |||||
def allow_request(self, request: Request, view: APIView) -> bool: | def allow_request(self, request: Request, view: APIView) -> bool: | ||||
# class based view case | # class based view case | ||||
if not self.scope: | if not self.scope: | ||||
default_scope = getattr(view, self.scope_attr, None) | default_scope = getattr(view, self.scope_attr, None) | ||||
request_allowed = None | request_allowed = None | ||||
if default_scope is not None: | if default_scope is not None: | ||||
# check if there is a specific rate limiting associated | # check if there is a specific rate limiting associated | ||||
▲ Show 20 Lines • Show All 72 Lines • ▼ Show 20 Lines | def get_cache_key(self, request, view): | ||||
return None | return None | ||||
def parse_rate(self, rate): | def parse_rate(self, rate): | ||||
# increase number of allowed requests | # increase number of allowed requests | ||||
num_requests, duration = super().parse_rate(rate) | num_requests, duration = super().parse_rate(rate) | ||||
return (num_requests * self.NUM_REQUESTS_FACTOR, duration) | return (num_requests * self.NUM_REQUESTS_FACTOR, duration) | ||||
def allow_request(self, request: Request, view: APIView) -> bool: | def allow_request(self, request: Request, view: APIView) -> bool: | ||||
# no throttling for staff users or users with adequate permission | |||||
if request.user.is_staff or request.user.has_perm(API_THROTTLING_EXEMPTED_PERM): | if request.user.is_staff or request.user.has_perm(API_THROTTLING_EXEMPTED_PERM): | ||||
# no throttling for staff users or users with adequate permission | |||||
return True | |||||
scope = self.get_scope(view) | |||||
if scope == "save_origin" and request.user.has_perm(API_SAVE_ORIGIN_PERMISSION): | |||||
# no throttling on save origin endpoint for users with adequate permission | |||||
return True | return True | ||||
return super().allow_request(request, view) | return super().allow_request(request, view) | ||||
def throttle_scope(scope: str) -> Callable[..., APIView]: | def throttle_scope(scope: str) -> Callable[..., APIView]: | ||||
"""Decorator that allows the throttle scope of a DRF | """Decorator that allows the throttle scope of a DRF | ||||
function based view to be set:: | function based view to be set:: | ||||
Show All 18 Lines |