diff --git a/swh/web/common/throttling.py b/swh/web/common/throttling.py index 4156d48d..c27ee40a 100644 --- a/swh/web/common/throttling.py +++ b/swh/web/common/throttling.py @@ -1,92 +1,129 @@ # Copyright (C) 2017-2018 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import ipaddress from rest_framework.throttling import ScopedRateThrottle from swh.web.config import get_config class SwhWebRateThrottle(ScopedRateThrottle): """Custom request rate limiter for DRF enabling to exempt specific networks specified in swh-web configuration. Requests are grouped into scopes. It enables to apply different - requests rate limiting based on the scope name. + requests rate limiting based on the scope name but also the + input HTTP request types. To associate a scope to requests, one must add a 'throttle_scope' attribute when using a class based view, or call the 'throttle_scope' decorator when using a function based view. By default, requests do not have an associated scope and are not rate limited. - For instance, the following YAML configuration section sets a rate of - 60 requests per minute for the 'swh_api' scope while exempting those - comming from the 127.0.0.0/8 ip network + Rate limiting can also be configured according to the type + of the input HTTP requests for fine grained tuning. + + For instance, the following YAML configuration section sets a rate of: + - 1 per minute for POST requests + - 60 per minute for other request types + for the 'swh_api' scope while exempting those comming from the + 127.0.0.0/8 ip network. .. code-block:: yaml throttling: scopes: swh_api: - limiter_rate: 60/min + limiter_rate: + default: 60/m + POST: 1/m exempted_networks: - 127.0.0.0/8 """ scope = None def __init__(self): super().__init__() self.exempted_networks = None - scopes = get_config()['throttling']['scopes'] - scope = scopes.get(self.scope) - if scope: - networks = scope.get('exempted_networks') - if networks: - self.exempted_networks = [ipaddress.ip_network(network) - for network in networks] + + def get_exempted_networks(self, scope_name): + if not self.exempted_networks: + scopes = get_config()['throttling']['scopes'] + scope = scopes.get(scope_name) + if scope: + networks = scope.get('exempted_networks') + if networks: + self.exempted_networks = [ipaddress.ip_network(network) + for network in networks] + return self.exempted_networks def allow_request(self, request, view): # class based view case if not self.scope: - request_allowed = \ - super(SwhWebRateThrottle, self).allow_request(request, view) + default_scope = getattr(view, self.scope_attr, None) + # check if there is a specific rate limiting associated + # to the request type + try: + request_scope = default_scope + '_' + request.method.lower() + setattr(view, self.scope_attr, request_scope) + request_allowed = \ + super(SwhWebRateThrottle, self).allow_request(request, view) # noqa + setattr(view, self.scope_attr, default_scope) + # use default rate limiting otherwise + except: + setattr(view, self.scope_attr, default_scope) + request_allowed = \ + super(SwhWebRateThrottle, self).allow_request(request, view) # noqa + # function based view case else: - self.rate = self.get_rate() + default_scope = self.scope + # check if there is a specific rate limiting associated + # to the request type + try: + self.scope = default_scope + '_' + request.method.lower() + self.rate = self.get_rate() + # use default rate limiting otherwise + except: + self.scope = default_scope + self.rate = self.get_rate() self.num_requests, self.duration = self.parse_rate(self.rate) request_allowed = \ super(ScopedRateThrottle, self).allow_request(request, view) + self.scope = default_scope + + exempted_networks = self.get_exempted_networks(default_scope) - if self.exempted_networks: + if exempted_networks: remote_address = ipaddress.ip_address(self.get_ident(request)) return any(remote_address in network - for network in self.exempted_networks) or \ + for network in exempted_networks) or \ request_allowed return request_allowed def throttle_scope(scope): """Decorator that allows the throttle scope of a DRF function based view to be set:: @api_view(['GET', ]) @throttle_scope('scope') def view(request): ... """ def decorator(func): SwhScopeRateThrottle = type( 'CustomScopeRateThrottle', (SwhWebRateThrottle,), {'scope': scope} ) func.throttle_classes = (SwhScopeRateThrottle, ) return func return decorator diff --git a/swh/web/settings/common.py b/swh/web/settings/common.py index 85065fde..c41d28af 100644 --- a/swh/web/settings/common.py +++ b/swh/web/settings/common.py @@ -1,189 +1,201 @@ # Copyright (C) 2017-2018 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information """ Django settings for swhweb project. Generated by 'django-admin startproject' using Django 1.11.3. For more information on this file, see https://docs.djangoproject.com/en/1.11/topics/settings/ For the full list of settings and their values, see https://docs.djangoproject.com/en/1.11/ref/settings/ """ import os from swh.web.config import get_config swh_web_config = get_config() # Build paths inside the project like this: os.path.join(BASE_DIR, ...) PROJECT_DIR = os.path.dirname(os.path.abspath(__file__)) # Quick-start development settings - unsuitable for production # See https://docs.djangoproject.com/en/1.11/howto/deployment/checklist/ # SECURITY WARNING: keep the secret key used in production secret! SECRET_KEY = swh_web_config['secret_key'] # SECURITY WARNING: don't run with debug turned on in production! DEBUG = swh_web_config['debug'] DEBUG_PROPAGATE_EXCEPTIONS = swh_web_config['debug'] ALLOWED_HOSTS = ['127.0.0.1', 'localhost'] + swh_web_config['allowed_hosts'] # Application definition INSTALLED_APPS = [ 'django.contrib.admin', 'django.contrib.auth', 'django.contrib.contenttypes', 'django.contrib.sessions', 'django.contrib.messages', 'django.contrib.staticfiles', 'rest_framework', 'swh.web.api', 'swh.web.browse' ] MIDDLEWARE = [ 'django.middleware.security.SecurityMiddleware', 'django.contrib.sessions.middleware.SessionMiddleware', 'django.middleware.common.CommonMiddleware', 'django.middleware.csrf.CsrfViewMiddleware', 'django.contrib.auth.middleware.AuthenticationMiddleware', 'django.contrib.messages.middleware.MessageMiddleware', 'django.middleware.clickjacking.XFrameOptionsMiddleware' ] ROOT_URLCONF = 'swh.web.urls' TEMPLATES = [ { 'BACKEND': 'django.template.backends.django.DjangoTemplates', 'DIRS': [os.path.join(PROJECT_DIR, "../templates")], 'APP_DIRS': True, 'OPTIONS': { 'context_processors': [ 'django.template.context_processors.debug', 'django.template.context_processors.request', 'django.contrib.auth.context_processors.auth', 'django.contrib.messages.context_processors.messages', ], 'libraries': { 'swh_templatetags': 'swh.web.common.swh_templatetags', }, }, }, ] WSGI_APPLICATION = 'swh.web.wsgi.application' # Database # https://docs.djangoproject.com/en/1.11/ref/settings/#databases DATABASES = { 'default': { 'ENGINE': 'django.db.backends.sqlite3', 'NAME': os.path.join(PROJECT_DIR, 'db.sqlite3'), } } # Password validation # https://docs.djangoproject.com/en/1.11/ref/settings/#auth-password-validators AUTH_PASSWORD_VALIDATORS = [ { 'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator', # noqa }, { 'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator', # noqa }, { 'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator', # noqa }, { 'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator', # noqa }, ] # Internationalization # https://docs.djangoproject.com/en/1.11/topics/i18n/ LANGUAGE_CODE = 'en-us' TIME_ZONE = 'UTC' USE_I18N = True USE_L10N = True USE_TZ = True # Static files (CSS, JavaScript, Images) # https://docs.djangoproject.com/en/1.11/howto/static-files/ STATIC_URL = '/static/' STATICFILES_DIRS = [ os.path.join(PROJECT_DIR, "../static") ] INTERNAL_IPS = ['127.0.0.1'] throttle_rates = {} +http_requests = ['GET', 'HEAD', 'POST', 'PUT', 'DELETE', 'OPTIONS', 'PATCH'] + throttling = swh_web_config['throttling'] for limiter_scope, limiter_conf in throttling['scopes'].items(): - throttle_rates[limiter_scope] = limiter_conf['limiter_rate'] + if 'default' in limiter_conf['limiter_rate']: + throttle_rates[limiter_scope] = limiter_conf['limiter_rate']['default'] + # for backward compatibility + else: + throttle_rates[limiter_scope] = limiter_conf['limiter_rate'] + # register sub scopes specific for HTTP request types + for http_request in http_requests: + if http_request in limiter_conf['limiter_rate']: + throttle_rates[limiter_scope + '_' + http_request.lower()] = \ + limiter_conf['limiter_rate'][http_request] + REST_FRAMEWORK = { 'DEFAULT_RENDERER_CLASSES': ( 'rest_framework.renderers.JSONRenderer', 'swh.web.api.renderers.YAMLRenderer', 'rest_framework.renderers.TemplateHTMLRenderer' ), 'DEFAULT_THROTTLE_CLASSES': ( 'swh.web.common.throttling.SwhWebRateThrottle', ), 'DEFAULT_THROTTLE_RATES': throttle_rates } LOGGING = { 'version': 1, 'disable_existing_loggers': False, 'filters': { 'require_debug_false': { '()': 'django.utils.log.RequireDebugFalse', }, 'require_debug_true': { '()': 'django.utils.log.RequireDebugTrue', }, }, 'handlers': { 'console': { 'level': 'DEBUG', 'filters': ['require_debug_true'], 'class': 'logging.StreamHandler', }, 'file': { 'level': 'INFO', 'filters': ['require_debug_false'], 'class': 'logging.FileHandler', 'filename': os.path.join(swh_web_config['log_dir'], 'swh-web.log'), }, }, 'loggers': { 'django': { 'handlers': ['console', 'file'], 'level': 'DEBUG' if DEBUG else 'INFO', 'propagate': True, } }, } diff --git a/swh/web/settings/tests.py b/swh/web/settings/tests.py index 22eaa9d9..18dc8c38 100644 --- a/swh/web/settings/tests.py +++ b/swh/web/settings/tests.py @@ -1,38 +1,56 @@ # Copyright (C) 2017-2018 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information # flake8: noqa from swh.web.config import get_config scope1_limiter_rate = 3 +scope1_limiter_rate_post = 1 scope2_limiter_rate = 5 +scope2_limiter_rate_post = 2 +scope3_limiter_rate = 1 +scope3_limiter_rate_post = 1 swh_web_config = get_config() swh_web_config.update({ 'debug': True, 'secret_key': 'test', 'throttling': { 'cache_uri': None, 'scopes': { 'swh_api': { - 'limiter_rate': '60/min', + 'limiter_rate': { + 'default': '60/min' + }, 'exempted_networks': ['127.0.0.0/8'] }, 'scope1': { - 'limiter_rate': '%s/min' % scope1_limiter_rate + 'limiter_rate': { + 'default': '%s/min' % scope1_limiter_rate, + 'POST': '%s/min' % scope1_limiter_rate_post, + } }, 'scope2': { - 'limiter_rate': '%s/min' % scope2_limiter_rate, + 'limiter_rate': { + 'default': '%s/min' % scope2_limiter_rate, + 'POST': '%s/min' % scope2_limiter_rate_post + } + }, + 'scope3': { + 'limiter_rate': { + 'default': '%s/min' % scope3_limiter_rate, + 'POST': '%s/min' % scope3_limiter_rate_post + }, 'exempted_networks': ['127.0.0.0/8'] } } } }) from .common import * ALLOWED_HOSTS += ['testserver'] # noqa \ No newline at end of file diff --git a/swh/web/tests/common/test_throttling.py b/swh/web/tests/common/test_throttling.py index a902ef19..b5a764ea 100644 --- a/swh/web/tests/common/test_throttling.py +++ b/swh/web/tests/common/test_throttling.py @@ -1,66 +1,137 @@ # Copyright (C) 2017-2018 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from swh.web.settings.tests import ( - scope1_limiter_rate, scope2_limiter_rate + scope1_limiter_rate, scope1_limiter_rate_post, + scope2_limiter_rate, scope2_limiter_rate_post, + scope3_limiter_rate, scope3_limiter_rate_post ) from django.test import TestCase from django.core.cache import cache from rest_framework.views import APIView from rest_framework.response import Response from rest_framework.test import APIRequestFactory from rest_framework.decorators import api_view from nose.tools import istest from swh.web.common.throttling import ( SwhWebRateThrottle, throttle_scope ) -class MockView(APIView): +class MockViewScope1(APIView): throttle_classes = (SwhWebRateThrottle,) throttle_scope = 'scope1' def get(self, request): - return Response('foo') + return Response('foo_get') + def post(self, request): + return Response('foo_post') -@api_view(['GET', ]) + +@api_view(['GET', 'POST']) @throttle_scope('scope2') -def mock_view(request): - return Response('bar') +def mock_view_scope2(request): + if request.method == 'GET': + return Response('bar_get') + elif request.method == 'POST': + return Response('bar_post') + + +class MockViewScope3(APIView): + throttle_classes = (SwhWebRateThrottle,) + throttle_scope = 'scope3' + + def get(self, request): + return Response('foo_get') + + def post(self, request): + return Response('foo_post') + + +@api_view(['GET', 'POST']) +@throttle_scope('scope3') +def mock_view_scope3(request): + if request.method == 'GET': + return Response('bar_get') + elif request.method == 'POST': + return Response('bar_post') class ThrottlingTests(TestCase): def setUp(self): """ Reset the cache so that no throttles will be active """ cache.clear() self.factory = APIRequestFactory() @istest def scope1_requests_are_throttled(self): """ Ensure request rate is limited in scope1 """ request = self.factory.get('/') - for dummy in range(scope1_limiter_rate+1): - response = MockView.as_view()(request) + for _ in range(scope1_limiter_rate): + response = MockViewScope1.as_view()(request) + assert response.status_code == 200 + response = MockViewScope1.as_view()(request) + assert response.status_code == 429 + + request = self.factory.post('/') + for _ in range(scope1_limiter_rate_post): + response = MockViewScope1.as_view()(request) + assert response.status_code == 200 + response = MockViewScope1.as_view()(request) assert response.status_code == 429 @istest def scope2_requests_are_throttled(self): """ - Ensure request rate is not limited in scope2 as + Ensure request rate is limited in scope2 + """ + request = self.factory.get('/') + for _ in range(scope2_limiter_rate): + response = mock_view_scope2(request) + assert response.status_code == 200 + response = mock_view_scope2(request) + assert response.status_code == 429 + + request = self.factory.post('/') + for _ in range(scope2_limiter_rate_post): + response = mock_view_scope2(request) + assert response.status_code == 200 + response = mock_view_scope2(request) + assert response.status_code == 429 + + @istest + def scope3_requests_are_throttled_exempted(self): + """ + Ensure request rate is not limited in scope3 as requests coming from localhost are exempted from rate limit. """ request = self.factory.get('/') - for dummy in range(scope2_limiter_rate+1): - response = mock_view(request) - assert response.status_code == 200 + for _ in range(scope3_limiter_rate+1): + response = MockViewScope3.as_view()(request) + assert response.status_code == 200 + + request = self.factory.post('/') + for _ in range(scope3_limiter_rate_post+1): + response = MockViewScope3.as_view()(request) + assert response.status_code == 200 + + request = self.factory.get('/') + for _ in range(scope3_limiter_rate+1): + response = mock_view_scope3(request) + assert response.status_code == 200 + + request = self.factory.post('/') + for _ in range(scope3_limiter_rate_post+1): + response = mock_view_scope3(request) + assert response.status_code == 200