diff --git a/swh/web/common/middlewares.py b/swh/web/common/middlewares.py index f8ede89c..19cc8f9d 100644 --- a/swh/web/common/middlewares.py +++ b/swh/web/common/middlewares.py @@ -1,46 +1,66 @@ # Copyright (C) 2018 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU Affero General Public License version 3, or any later version # See top-level LICENSE file for more information from bs4 import BeautifulSoup from htmlmin import minify class HtmlPrettifyMiddleware(object): """ Django middleware for prettifying generated HTML in development mode. """ def __init__(self, get_response): self.get_response = get_response def __call__(self, request): response = self.get_response(request) if 'text/html' in response.get('Content-Type', ''): response.content = \ BeautifulSoup(response.content, 'lxml').prettify() return response class HtmlMinifyMiddleware(object): """ Django middleware for minifying generated HTML in production mode. """ def __init__(self, get_response=None): self.get_response = get_response def __call__(self, request): response = self.get_response(request) if 'text/html' in response.get('Content-Type', ''): try: minified_html = minify(response.content.decode('utf-8'), convert_charrefs=False) response.content = minified_html.encode('utf-8') except Exception: pass return response + + +class ThrottlingHeadersMiddleware(object): + """ + Django middleware for inserting rate limiting related + headers in HTTP response. + """ + + def __init__(self, get_response=None): + self.get_response = get_response + + def __call__(self, request): + resp = self.get_response(request) + if 'RateLimit-Limit' in request.META: + resp['X-RateLimit-Limit'] = request.META['RateLimit-Limit'] + if 'RateLimit-Remaining' in request.META: + resp['X-RateLimit-Remaining'] = request.META['RateLimit-Remaining'] + if 'RateLimit-Reset' in request.META: + resp['X-RateLimit-Reset'] = request.META['RateLimit-Reset'] + return resp diff --git a/swh/web/common/throttling.py b/swh/web/common/throttling.py index 95d3b90f..93d443c5 100644 --- a/swh/web/common/throttling.py +++ b/swh/web/common/throttling.py @@ -1,130 +1,141 @@ # Copyright (C) 2017-2018 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU Affero General Public License version 3, or any later version # See top-level LICENSE file for more information import ipaddress from rest_framework.throttling import ScopedRateThrottle from swh.web.config import get_config class SwhWebRateThrottle(ScopedRateThrottle): """Custom request rate limiter for DRF enabling to exempt specific networks specified in swh-web configuration. Requests are grouped into scopes. It enables to apply different requests rate limiting based on the scope name but also the input HTTP request types. To associate a scope to requests, one must add a 'throttle_scope' attribute when using a class based view, or call the 'throttle_scope' decorator when using a function based view. By default, requests do not have an associated scope and are not rate limited. Rate limiting can also be configured according to the type of the input HTTP requests for fine grained tuning. For instance, the following YAML configuration section sets a rate of: - 1 per minute for POST requests - 60 per minute for other request types for the 'swh_api' scope while exempting those coming from the 127.0.0.0/8 ip network. .. code-block:: yaml throttling: scopes: swh_api: limiter_rate: default: 60/m POST: 1/m exempted_networks: - 127.0.0.0/8 """ scope = None def __init__(self): super().__init__() self.exempted_networks = None def get_exempted_networks(self, scope_name): if not self.exempted_networks: scopes = get_config()['throttling']['scopes'] scope = scopes.get(scope_name) if scope: networks = scope.get('exempted_networks') if networks: self.exempted_networks = [ipaddress.ip_network(network) for network in networks] return self.exempted_networks def allow_request(self, request, view): # class based view case if not self.scope: default_scope = getattr(view, self.scope_attr, None) # check if there is a specific rate limiting associated # to the request type try: request_scope = default_scope + '_' + request.method.lower() setattr(view, self.scope_attr, request_scope) request_allowed = \ super(SwhWebRateThrottle, self).allow_request(request, view) # noqa setattr(view, self.scope_attr, default_scope) # use default rate limiting otherwise except Exception: setattr(view, self.scope_attr, default_scope) request_allowed = \ super(SwhWebRateThrottle, self).allow_request(request, view) # noqa # function based view case else: default_scope = self.scope # check if there is a specific rate limiting associated # to the request type try: self.scope = default_scope + '_' + request.method.lower() self.rate = self.get_rate() # use default rate limiting otherwise except Exception: self.scope = default_scope self.rate = self.get_rate() self.num_requests, self.duration = self.parse_rate(self.rate) + request_allowed = \ super(ScopedRateThrottle, self).allow_request(request, view) self.scope = default_scope exempted_networks = self.get_exempted_networks(default_scope) + exempted_ip = False if exempted_networks: remote_address = ipaddress.ip_address(self.get_ident(request)) - return any(remote_address in network - for network in exempted_networks) or \ - request_allowed + exempted_ip = any(remote_address in network + for network in exempted_networks) + request_allowed = exempted_ip or request_allowed + + # set throttling related data in the request metadata + # in order for the ThrottlingHeadersMiddleware to + # add X-RateLimit-* headers in the HTTP response + if not exempted_ip and hasattr(self, 'history'): + hit_count = len(self.history) + request.META['RateLimit-Limit'] = self.num_requests + request.META['RateLimit-Remaining'] = self.num_requests - hit_count + request.META['RateLimit-Reset'] = int(self.now + self.wait()) return request_allowed def throttle_scope(scope): """Decorator that allows the throttle scope of a DRF function based view to be set:: @api_view(['GET', ]) @throttle_scope('scope') def view(request): ... """ def decorator(func): SwhScopeRateThrottle = type( 'CustomScopeRateThrottle', (SwhWebRateThrottle,), {'scope': scope} ) func.throttle_classes = (SwhScopeRateThrottle, ) return func return decorator diff --git a/swh/web/settings/common.py b/swh/web/settings/common.py index bccd6465..d1fc0b12 100644 --- a/swh/web/settings/common.py +++ b/swh/web/settings/common.py @@ -1,213 +1,213 @@ # Copyright (C) 2017-2018 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU Affero General Public License version 3, or any later version # See top-level LICENSE file for more information """ Django common settings for swh-web. """ import os from swh.web.config import get_config swh_web_config = get_config() # Build paths inside the project like this: os.path.join(BASE_DIR, ...) PROJECT_DIR = os.path.dirname(os.path.abspath(__file__)) # Quick-start development settings - unsuitable for production # See https://docs.djangoproject.com/en/1.11/howto/deployment/checklist/ # SECURITY WARNING: keep the secret key used in production secret! SECRET_KEY = swh_web_config['secret_key'] # SECURITY WARNING: don't run with debug turned on in production! DEBUG = swh_web_config['debug'] DEBUG_PROPAGATE_EXCEPTIONS = swh_web_config['debug'] ALLOWED_HOSTS = ['127.0.0.1', 'localhost'] + swh_web_config['allowed_hosts'] # Application definition INSTALLED_APPS = [ 'django.contrib.admin', 'django.contrib.auth', 'django.contrib.contenttypes', 'django.contrib.sessions', 'django.contrib.messages', 'django.contrib.staticfiles', 'rest_framework', 'swh.web.api', 'swh.web.browse', 'webpack_loader', 'django_js_reverse' ] MIDDLEWARE = [ 'django.middleware.security.SecurityMiddleware', 'django.contrib.sessions.middleware.SessionMiddleware', 'django.middleware.common.CommonMiddleware', 'django.middleware.csrf.CsrfViewMiddleware', 'django.contrib.auth.middleware.AuthenticationMiddleware', 'django.contrib.messages.middleware.MessageMiddleware', - 'django.middleware.clickjacking.XFrameOptionsMiddleware' + 'django.middleware.clickjacking.XFrameOptionsMiddleware', + 'swh.web.common.middlewares.ThrottlingHeadersMiddleware' ] # Compress all assets (static ones and dynamically generated html) # served by django in a local development environement context. # In a production environment, assets compression will be directly # handled by web servers like apache or nginx. if swh_web_config['debug']: MIDDLEWARE.insert(0, 'django.middleware.gzip.GZipMiddleware') ROOT_URLCONF = 'swh.web.urls' TEMPLATES = [ { 'BACKEND': 'django.template.backends.django.DjangoTemplates', 'DIRS': [os.path.join(PROJECT_DIR, "../templates")], 'APP_DIRS': True, 'OPTIONS': { 'context_processors': [ 'django.template.context_processors.debug', 'django.template.context_processors.request', 'django.contrib.auth.context_processors.auth', 'django.contrib.messages.context_processors.messages', ], 'libraries': { 'swh_templatetags': 'swh.web.common.swh_templatetags', }, }, }, ] WSGI_APPLICATION = 'swh.web.wsgi.application' # Database # https://docs.djangoproject.com/en/1.11/ref/settings/#databases DATABASES = { 'default': { 'ENGINE': 'django.db.backends.sqlite3', 'NAME': os.path.join(PROJECT_DIR, 'db.sqlite3'), } } # Password validation # https://docs.djangoproject.com/en/1.11/ref/settings/#auth-password-validators AUTH_PASSWORD_VALIDATORS = [ { 'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator', # noqa }, { 'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator', # noqa }, { 'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator', # noqa }, { 'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator', # noqa }, ] # Internationalization # https://docs.djangoproject.com/en/1.11/topics/i18n/ LANGUAGE_CODE = 'en-us' TIME_ZONE = 'UTC' USE_I18N = True USE_L10N = True USE_TZ = True # Static files (CSS, JavaScript, Images) # https://docs.djangoproject.com/en/1.11/howto/static-files/ STATIC_URL = '/static/' STATICFILES_DIRS = [ os.path.join(PROJECT_DIR, "../static") ] INTERNAL_IPS = ['127.0.0.1'] throttle_rates = {} http_requests = ['GET', 'HEAD', 'POST', 'PUT', 'DELETE', 'OPTIONS', 'PATCH'] throttling = swh_web_config['throttling'] for limiter_scope, limiter_conf in throttling['scopes'].items(): if 'default' in limiter_conf['limiter_rate']: throttle_rates[limiter_scope] = limiter_conf['limiter_rate']['default'] # for backward compatibility else: throttle_rates[limiter_scope] = limiter_conf['limiter_rate'] # register sub scopes specific for HTTP request types for http_request in http_requests: if http_request in limiter_conf['limiter_rate']: throttle_rates[limiter_scope + '_' + http_request.lower()] = \ limiter_conf['limiter_rate'][http_request] - REST_FRAMEWORK = { 'DEFAULT_RENDERER_CLASSES': ( 'rest_framework.renderers.JSONRenderer', 'swh.web.api.renderers.YAMLRenderer', 'rest_framework.renderers.TemplateHTMLRenderer' ), 'DEFAULT_THROTTLE_CLASSES': ( 'swh.web.common.throttling.SwhWebRateThrottle', ), 'DEFAULT_THROTTLE_RATES': throttle_rates } LOGGING = { 'version': 1, 'disable_existing_loggers': False, 'filters': { 'require_debug_false': { '()': 'django.utils.log.RequireDebugFalse', }, 'require_debug_true': { '()': 'django.utils.log.RequireDebugTrue', }, }, 'handlers': { 'console': { 'level': 'DEBUG', 'filters': ['require_debug_true'], 'class': 'logging.StreamHandler', }, 'file': { 'level': 'INFO', 'filters': ['require_debug_false'], 'class': 'logging.FileHandler', 'filename': os.path.join(swh_web_config['log_dir'], 'swh-web.log'), }, }, 'loggers': { 'django': { 'handlers': ['console', 'file'], 'level': 'DEBUG' if DEBUG else 'INFO', 'propagate': True, } }, } WEBPACK_LOADER = { 'DEFAULT': { 'CACHE': not DEBUG, 'BUNDLE_DIR_NAME': './', 'STATS_FILE': os.path.join(PROJECT_DIR, '../static/webpack-stats.json'), # noqa 'POLL_INTERVAL': 0.1, 'TIMEOUT': None, 'IGNORE': ['.+\.hot-update.js', '.+\.map'] } } diff --git a/swh/web/tests/browse/views/test_origin.py b/swh/web/tests/browse/views/test_origin.py index 413c12f8..2efdc84f 100644 --- a/swh/web/tests/browse/views/test_origin.py +++ b/swh/web/tests/browse/views/test_origin.py @@ -1,833 +1,833 @@ # Copyright (C) 2017-2018 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU Affero General Public License version 3, or any later version # See top-level LICENSE file for more information # flake8: noqa from unittest.mock import patch from nose.tools import istest, nottest from django.test import TestCase from django.utils.html import escape from swh.web.common.exc import NotFoundExc from swh.web.common.utils import ( reverse, gen_path_info, format_utc_iso_date, parse_timestamp, get_swh_persistent_id ) from swh.web.tests.testbase import SWHWebTestBase from .data.origin_test_data import ( origin_info_test_data, origin_visits_test_data, stub_content_origin_info, stub_content_origin_visit_id, stub_content_origin_visit_unix_ts, stub_content_origin_visit_iso_date, stub_content_origin_branch, stub_content_origin_visits, stub_content_origin_snapshot, stub_origin_info, stub_visit_id, stub_origin_visits, stub_origin_snapshot, stub_origin_root_directory_entries, stub_origin_master_branch, stub_origin_root_directory_sha1, stub_origin_sub_directory_path, stub_origin_sub_directory_entries, stub_visit_unix_ts, stub_visit_iso_date ) from .data.content_test_data import ( stub_content_root_dir, stub_content_text_data, stub_content_text_path ) stub_origin_info_no_type = dict(stub_origin_info) stub_origin_info_no_type['type'] = None class SwhBrowseOriginTest(SWHWebTestBase, TestCase): @patch('swh.web.browse.views.origin.get_origin_info') @patch('swh.web.browse.views.origin.get_origin_visits') @patch('swh.web.browse.views.origin.service') @istest def origin_visits_browse(self, mock_service, mock_get_origin_visits, mock_get_origin_info): mock_service.lookup_origin.return_value = origin_info_test_data mock_get_origin_info.return_value = origin_info_test_data mock_get_origin_visits.return_value = origin_visits_test_data url = reverse('browse-origin-visits', kwargs={'origin_type': origin_info_test_data['type'], 'origin_url': origin_info_test_data['url']}) resp = self.client.get(url) self.assertEquals(resp.status_code, 200) self.assertTemplateUsed('origin-visits.html') url = reverse('browse-origin-visits', kwargs={'origin_url': origin_info_test_data['url']}) resp = self.client.get(url) self.assertEquals(resp.status_code, 200) self.assertTemplateUsed('origin-visits.html') @nottest def origin_content_view_test(self, origin_info, origin_visits, origin_branches, origin_releases, origin_branch, root_dir_sha1, content_sha1, content_sha1_git, content_path, content_data, content_language, visit_id=None, timestamp=None): url_args = {'origin_type': origin_info['type'], 'origin_url': origin_info['url'], 'path': content_path} if not visit_id: visit_id = origin_visits[-1]['visit'] query_params = {} if timestamp: url_args['timestamp'] = timestamp if visit_id: query_params['visit_id'] = visit_id url = reverse('browse-origin-content', kwargs=url_args, query_params=query_params) resp = self.client.get(url) self.assertEquals(resp.status_code, 200) self.assertTemplateUsed('content.html') self.assertContains(resp, '' % content_language) self.assertContains(resp, escape(content_data)) split_path = content_path.split('/') filename = split_path[-1] path = content_path.replace(filename, '')[:-1] path_info = gen_path_info(path) del url_args['path'] if timestamp: url_args['timestamp'] = \ format_utc_iso_date(parse_timestamp(timestamp).isoformat(), '%Y-%m-%dT%H:%M:%S') root_dir_url = reverse('browse-origin-directory', kwargs=url_args, query_params=query_params) self.assertContains(resp, '
  • ', count=len(path_info)+1) self.assertContains(resp, '%s' % (root_dir_url, root_dir_sha1[:7])) for p in path_info: url_args['path'] = p['path'] dir_url = reverse('browse-origin-directory', kwargs=url_args, query_params=query_params) self.assertContains(resp, '%s' % (dir_url, p['name'])) self.assertContains(resp, '
  • %s
  • ' % filename) query_string = 'sha1_git:' + content_sha1 url_raw = reverse('browse-content-raw', kwargs={'query_string': query_string}, query_params={'filename': filename}) self.assertContains(resp, url_raw) del url_args['path'] origin_branches_url = \ reverse('browse-origin-branches', kwargs=url_args, query_params=query_params) self.assertContains(resp, 'Branches (%s)' % (origin_branches_url, len(origin_branches))) origin_releases_url = \ reverse('browse-origin-releases', kwargs=url_args, query_params=query_params) self.assertContains(resp, 'Releases (%s)' % (origin_releases_url, len(origin_releases))) self.assertContains(resp, '
  • ', count=len(origin_branches)) url_args['path'] = content_path for branch in origin_branches: query_params['branch'] = branch['name'] root_dir_branch_url = \ reverse('browse-origin-content', kwargs=url_args, query_params=query_params) self.assertContains(resp, '' % root_dir_branch_url) self.assertContains(resp, '
  • ', count=len(origin_releases)) query_params['branch'] = None for release in origin_releases: query_params['release'] = release['name'] root_dir_release_url = \ reverse('browse-origin-content', kwargs=url_args, query_params=query_params) self.assertContains(resp, '' % root_dir_release_url) del url_args['origin_type'] url = reverse('browse-origin-content', kwargs=url_args, query_params=query_params) resp = self.client.get(url) self.assertEquals(resp.status_code, 200) self.assertTemplateUsed('content.html') swh_cnt_id = get_swh_persistent_id('content', content_sha1_git) swh_cnt_id_url = reverse('browse-swh-id', kwargs={'swh_id': swh_cnt_id}) self.assertContains(resp, swh_cnt_id) self.assertContains(resp, swh_cnt_id_url) @patch('swh.web.browse.utils.get_origin_visits') @patch('swh.web.browse.utils.get_origin_visit_snapshot') @patch('swh.web.browse.views.utils.snapshot_context.service') @patch('swh.web.browse.utils.service') @patch('swh.web.browse.views.utils.snapshot_context.request_content') @istest def origin_content_view(self, mock_request_content, mock_utils_service, mock_service, mock_get_origin_visit_snapshot, mock_get_origin_visits): stub_content_text_sha1 = stub_content_text_data['checksums']['sha1'] stub_content_text_sha1_git = stub_content_text_data['checksums']['sha1_git'] mock_get_origin_visits.return_value = stub_content_origin_visits mock_get_origin_visit_snapshot.return_value = stub_content_origin_snapshot mock_service.lookup_directory_with_path.return_value = \ {'target': stub_content_text_sha1} mock_request_content.return_value = stub_content_text_data mock_utils_service.lookup_origin.return_value = stub_content_origin_info self.origin_content_view_test(stub_content_origin_info, stub_content_origin_visits, stub_content_origin_snapshot[0], stub_content_origin_snapshot[1], stub_content_origin_branch, stub_content_root_dir, stub_content_text_sha1, stub_content_text_sha1_git, stub_content_text_path, stub_content_text_data['raw_data'], 'cpp') self.origin_content_view_test(stub_content_origin_info, stub_content_origin_visits, stub_content_origin_snapshot[0], stub_content_origin_snapshot[1], stub_content_origin_branch, stub_content_root_dir, stub_content_text_sha1, stub_content_text_sha1_git, stub_content_text_path, stub_content_text_data['raw_data'], 'cpp', visit_id=stub_content_origin_visit_id) self.origin_content_view_test(stub_content_origin_info, stub_content_origin_visits, stub_content_origin_snapshot[0], stub_content_origin_snapshot[1], stub_content_origin_branch, stub_content_root_dir, stub_content_text_sha1, stub_content_text_sha1_git, stub_content_text_path, stub_content_text_data['raw_data'], 'cpp', timestamp=stub_content_origin_visit_unix_ts) self.origin_content_view_test(stub_content_origin_info, stub_content_origin_visits, stub_content_origin_snapshot[0], stub_content_origin_snapshot[1], stub_content_origin_branch, stub_content_root_dir, stub_content_text_sha1, stub_content_text_sha1_git, stub_content_text_path, stub_content_text_data['raw_data'], 'cpp', timestamp=stub_content_origin_visit_iso_date) @nottest def origin_directory_view(self, origin_info, origin_visits, origin_branches, origin_releases, origin_branch, root_directory_sha1, directory_entries, visit_id=None, timestamp=None, path=None): dirs = [e for e in directory_entries if e['type'] == 'dir'] files = [e for e in directory_entries if e['type'] == 'file'] if not visit_id: visit_id = origin_visits[-1]['visit'] url_args = {'origin_type': origin_info['type'], 'origin_url': origin_info['url']} query_params = {} if timestamp: url_args['timestamp'] = timestamp else: query_params['visit_id'] = visit_id if path: url_args['path'] = path url = reverse('browse-origin-directory', kwargs=url_args, query_params=query_params) resp = self.client.get(url) self.assertEquals(resp.status_code, 200) self.assertTemplateUsed('directory.html') self.assertEquals(resp.status_code, 200) self.assertTemplateUsed('directory.html') self.assertContains(resp, '', count=len(dirs)) self.assertContains(resp, '', count=len(files)) if timestamp: url_args['timestamp'] = \ format_utc_iso_date(parse_timestamp(timestamp).isoformat(), '%Y-%m-%dT%H:%M:%S') for d in dirs: dir_path = d['name'] if path: dir_path = "%s/%s" % (path, d['name']) dir_url_args = dict(url_args) dir_url_args['path'] = dir_path dir_url = reverse('browse-origin-directory', kwargs=dir_url_args, query_params=query_params) self.assertContains(resp, dir_url) for f in files: file_path = f['name'] if path: file_path = "%s/%s" % (path, f['name']) file_url_args = dict(url_args) file_url_args['path'] = file_path file_url = reverse('browse-origin-content', kwargs=file_url_args, query_params=query_params) self.assertContains(resp, file_url) if 'path' in url_args: del url_args['path'] root_dir_branch_url = \ reverse('browse-origin-directory', kwargs=url_args, query_params=query_params) nb_bc_paths = 1 if path: nb_bc_paths = len(path.split('/')) + 1 self.assertContains(resp, '
  • ', count=nb_bc_paths) self.assertContains(resp, '%s' % (root_dir_branch_url, root_directory_sha1[:7])) origin_branches_url = \ reverse('browse-origin-branches', kwargs=url_args, query_params=query_params) self.assertContains(resp, 'Branches (%s)' % (origin_branches_url, len(origin_branches))) origin_releases_url = \ reverse('browse-origin-releases', kwargs=url_args, query_params=query_params) self.assertContains(resp, 'Releases (%s)' % (origin_releases_url, len(origin_releases))) if path: url_args['path'] = path self.assertContains(resp, '
  • ', count=len(origin_branches)) for branch in origin_branches: query_params['branch'] = branch['name'] root_dir_branch_url = \ reverse('browse-origin-directory', kwargs=url_args, query_params=query_params) self.assertContains(resp, '' % root_dir_branch_url) self.assertContains(resp, '
  • ', count=len(origin_releases)) query_params['branch'] = None for release in origin_releases: query_params['release'] = release['name'] root_dir_release_url = \ reverse('browse-origin-directory', kwargs=url_args, query_params=query_params) self.assertContains(resp, '' % root_dir_release_url) self.assertContains(resp, 'vault-cook-directory') self.assertContains(resp, 'vault-cook-revision') swh_dir_id = get_swh_persistent_id('directory', directory_entries[0]['dir_id']) # noqa swh_dir_id_url = reverse('browse-swh-id', kwargs={'swh_id': swh_dir_id}) self.assertContains(resp, swh_dir_id) self.assertContains(resp, swh_dir_id_url) @patch('swh.web.browse.utils.get_origin_visits') @patch('swh.web.browse.utils.get_origin_visit_snapshot') @patch('swh.web.browse.utils.service') @patch('swh.web.browse.views.origin.service') @istest - def test_origin_root_directory_view(self, mock_origin_service, + def origin_root_directory_view(self, mock_origin_service, mock_utils_service, mock_get_origin_visit_snapshot, mock_get_origin_visits): mock_get_origin_visits.return_value = stub_origin_visits mock_get_origin_visit_snapshot.return_value = stub_origin_snapshot mock_utils_service.lookup_directory.return_value = \ stub_origin_root_directory_entries mock_utils_service.lookup_origin.return_value = stub_origin_info self.origin_directory_view(stub_origin_info, stub_origin_visits, stub_origin_snapshot[0], stub_origin_snapshot[1], stub_origin_master_branch, stub_origin_root_directory_sha1, stub_origin_root_directory_entries) self.origin_directory_view(stub_origin_info, stub_origin_visits, stub_origin_snapshot[0], stub_origin_snapshot[1], stub_origin_master_branch, stub_origin_root_directory_sha1, stub_origin_root_directory_entries, visit_id=stub_visit_id) self.origin_directory_view(stub_origin_info, stub_origin_visits, stub_origin_snapshot[0], stub_origin_snapshot[1], stub_origin_master_branch, stub_origin_root_directory_sha1, stub_origin_root_directory_entries, timestamp=stub_visit_unix_ts) self.origin_directory_view(stub_origin_info, stub_origin_visits, stub_origin_snapshot[0], stub_origin_snapshot[1], stub_origin_master_branch, stub_origin_root_directory_sha1, stub_origin_root_directory_entries, timestamp=stub_visit_iso_date) self.origin_directory_view(stub_origin_info_no_type, stub_origin_visits, stub_origin_snapshot[0], stub_origin_snapshot[1], stub_origin_master_branch, stub_origin_root_directory_sha1, stub_origin_root_directory_entries) self.origin_directory_view(stub_origin_info_no_type, stub_origin_visits, stub_origin_snapshot[0], stub_origin_snapshot[1], stub_origin_master_branch, stub_origin_root_directory_sha1, stub_origin_root_directory_entries, visit_id=stub_visit_id) self.origin_directory_view(stub_origin_info_no_type, stub_origin_visits, stub_origin_snapshot[0], stub_origin_snapshot[1], stub_origin_master_branch, stub_origin_root_directory_sha1, stub_origin_root_directory_entries, timestamp=stub_visit_unix_ts) self.origin_directory_view(stub_origin_info_no_type, stub_origin_visits, stub_origin_snapshot[0], stub_origin_snapshot[1], stub_origin_master_branch, stub_origin_root_directory_sha1, stub_origin_root_directory_entries, timestamp=stub_visit_iso_date) @patch('swh.web.browse.utils.get_origin_visits') @patch('swh.web.browse.utils.get_origin_visit_snapshot') @patch('swh.web.browse.utils.service') @patch('swh.web.browse.views.utils.snapshot_context.service') @istest def origin_sub_directory_view(self, mock_origin_service, mock_utils_service, mock_get_origin_visit_snapshot, mock_get_origin_visits): mock_get_origin_visits.return_value = stub_origin_visits mock_get_origin_visit_snapshot.return_value = stub_origin_snapshot mock_utils_service.lookup_directory.return_value = \ stub_origin_sub_directory_entries mock_origin_service.lookup_directory_with_path.return_value = \ {'target': stub_origin_sub_directory_entries[0]['dir_id'], 'type' : 'dir'} mock_utils_service.lookup_origin.return_value = stub_origin_info self.origin_directory_view(stub_origin_info, stub_origin_visits, stub_origin_snapshot[0], stub_origin_snapshot[1], stub_origin_master_branch, stub_origin_root_directory_sha1, stub_origin_sub_directory_entries, path=stub_origin_sub_directory_path) self.origin_directory_view(stub_origin_info, stub_origin_visits, stub_origin_snapshot[0], stub_origin_snapshot[1], stub_origin_master_branch, stub_origin_root_directory_sha1, stub_origin_sub_directory_entries, visit_id=stub_visit_id, path=stub_origin_sub_directory_path) self.origin_directory_view(stub_origin_info, stub_origin_visits, stub_origin_snapshot[0], stub_origin_snapshot[1], stub_origin_master_branch, stub_origin_root_directory_sha1, stub_origin_sub_directory_entries, timestamp=stub_visit_unix_ts, path=stub_origin_sub_directory_path) self.origin_directory_view(stub_origin_info, stub_origin_visits, stub_origin_snapshot[0], stub_origin_snapshot[1], stub_origin_master_branch, stub_origin_root_directory_sha1, stub_origin_sub_directory_entries, timestamp=stub_visit_iso_date, path=stub_origin_sub_directory_path) self.origin_directory_view(stub_origin_info_no_type, stub_origin_visits, stub_origin_snapshot[0], stub_origin_snapshot[1], stub_origin_master_branch, stub_origin_root_directory_sha1, stub_origin_sub_directory_entries, path=stub_origin_sub_directory_path) self.origin_directory_view(stub_origin_info_no_type, stub_origin_visits, stub_origin_snapshot[0], stub_origin_snapshot[1], stub_origin_master_branch, stub_origin_root_directory_sha1, stub_origin_sub_directory_entries, visit_id=stub_visit_id, path=stub_origin_sub_directory_path) self.origin_directory_view(stub_origin_info_no_type, stub_origin_visits, stub_origin_snapshot[0], stub_origin_snapshot[1], stub_origin_master_branch, stub_origin_root_directory_sha1, stub_origin_sub_directory_entries, timestamp=stub_visit_unix_ts, path=stub_origin_sub_directory_path) self.origin_directory_view(stub_origin_info_no_type, stub_origin_visits, stub_origin_snapshot[0], stub_origin_snapshot[1], stub_origin_master_branch, stub_origin_root_directory_sha1, stub_origin_sub_directory_entries, timestamp=stub_visit_iso_date, path=stub_origin_sub_directory_path) @patch('swh.web.browse.views.utils.snapshot_context.request_content') @patch('swh.web.browse.utils.get_origin_visits') @patch('swh.web.browse.utils.get_origin_visit_snapshot') @patch('swh.web.browse.utils.service') @patch('swh.web.browse.views.origin.service') @patch('swh.web.browse.views.utils.snapshot_context.service') @patch('swh.web.browse.views.origin.get_origin_info') @istest - def test_origin_request_errors(self, mock_get_origin_info, + def origin_request_errors(self, mock_get_origin_info, mock_snapshot_service, mock_origin_service, mock_utils_service, mock_get_origin_visit_snapshot, mock_get_origin_visits, mock_request_content): mock_get_origin_info.side_effect = \ NotFoundExc('origin not found') url = reverse('browse-origin-visits', kwargs={'origin_type': 'foo', 'origin_url': 'bar'}) resp = self.client.get(url) self.assertEquals(resp.status_code, 404) self.assertTemplateUsed('error.html') self.assertContains(resp, 'origin not found', status_code=404) mock_utils_service.lookup_origin.side_effect = None mock_utils_service.lookup_origin.return_value = origin_info_test_data mock_get_origin_visits.return_value = [] url = reverse('browse-origin-directory', kwargs={'origin_type': 'foo', 'origin_url': 'bar'}) resp = self.client.get(url) self.assertEquals(resp.status_code, 404) self.assertTemplateUsed('error.html') self.assertContains(resp, "No SWH visit", status_code=404) mock_get_origin_visits.return_value = stub_origin_visits mock_get_origin_visit_snapshot.side_effect = \ NotFoundExc('visit not found') url = reverse('browse-origin-directory', kwargs={'origin_type': 'foo', 'origin_url': 'bar'}, query_params={'visit_id': len(stub_origin_visits)+1}) resp = self.client.get(url) self.assertEquals(resp.status_code, 404) self.assertTemplateUsed('error.html') self.assertRegex(resp.content.decode('utf-8'), 'Visit.*not found') mock_get_origin_visits.return_value = stub_origin_visits mock_get_origin_visit_snapshot.side_effect = None mock_get_origin_visit_snapshot.return_value = ([], []) url = reverse('browse-origin-directory', kwargs={'origin_type': 'foo', 'origin_url': 'bar'}) resp = self.client.get(url) self.assertEquals(resp.status_code, 404) self.assertTemplateUsed('error.html') self.assertRegex(resp.content.decode('utf-8'), 'Origin.*has an empty list of branches') mock_get_origin_visit_snapshot.return_value = stub_origin_snapshot mock_utils_service.lookup_directory.side_effect = \ NotFoundExc('Directory not found') url = reverse('browse-origin-directory', kwargs={'origin_type': 'foo', 'origin_url': 'bar'}) resp = self.client.get(url) self.assertEquals(resp.status_code, 404) self.assertTemplateUsed('error.html') self.assertContains(resp, 'Directory not found', status_code=404) mock_origin_service.lookup_origin.side_effect = None mock_origin_service.lookup_origin.return_value = origin_info_test_data mock_get_origin_visits.return_value = [] url = reverse('browse-origin-content', kwargs={'origin_type': 'foo', 'origin_url': 'bar', 'path': 'foo'}) resp = self.client.get(url) self.assertEquals(resp.status_code, 404) self.assertTemplateUsed('error.html') self.assertContains(resp, "No SWH visit", status_code=404) mock_get_origin_visits.return_value = stub_origin_visits mock_get_origin_visit_snapshot.side_effect = \ NotFoundExc('visit not found') url = reverse('browse-origin-content', kwargs={'origin_type': 'foo', 'origin_url': 'bar', 'path': 'foo'}, query_params={'visit_id': len(stub_origin_visits)+1}) resp = self.client.get(url) self.assertEquals(resp.status_code, 404) self.assertTemplateUsed('error.html') self.assertRegex(resp.content.decode('utf-8'), 'Visit.*not found') mock_get_origin_visits.return_value = stub_origin_visits mock_get_origin_visit_snapshot.side_effect = None mock_get_origin_visit_snapshot.return_value = ([], []) url = reverse('browse-origin-content', kwargs={'origin_type': 'foo', 'origin_url': 'bar', 'path': 'baz'}) resp = self.client.get(url) self.assertEquals(resp.status_code, 404) self.assertTemplateUsed('error.html') self.assertRegex(resp.content.decode('utf-8'), 'Origin.*has an empty list of branches') mock_get_origin_visit_snapshot.return_value = stub_origin_snapshot mock_snapshot_service.lookup_directory_with_path.return_value = \ {'target': stub_content_text_data['checksums']['sha1']} mock_request_content.side_effect = \ NotFoundExc('Content not found') url = reverse('browse-origin-content', kwargs={'origin_type': 'foo', 'origin_url': 'bar', 'path': 'baz'}) resp = self.client.get(url) self.assertEquals(resp.status_code, 404) self.assertTemplateUsed('error.html') self.assertContains(resp, 'Content not found', status_code=404) @nottest def origin_branches_test(self, origin_info, origin_snapshot): url_args = {'origin_type': origin_info['type'], 'origin_url': origin_info['url']} url = reverse('browse-origin-branches', kwargs=url_args) resp = self.client.get(url) self.assertEquals(resp.status_code, 200) self.assertTemplateUsed('branches.html') origin_branches = origin_snapshot[0] origin_releases = origin_snapshot[1] origin_branches_url = \ reverse('browse-origin-branches', kwargs=url_args) self.assertContains(resp, 'Branches (%s)' % (origin_branches_url, len(origin_branches))) origin_releases_url = \ reverse('browse-origin-releases', kwargs=url_args) self.assertContains(resp, 'Releases (%s)' % (origin_releases_url, len(origin_releases))) self.assertContains(resp, '', count=len(origin_branches)) for branch in origin_branches: browse_branch_url = reverse('browse-origin-directory', kwargs={'origin_type': origin_info['type'], 'origin_url': origin_info['url']}, query_params={'branch': branch['name']}) self.assertContains(resp, '%s' % (escape(browse_branch_url), branch['name'])) browse_revision_url = reverse('browse-revision', kwargs={'sha1_git': branch['revision']}, query_params={'origin_type': origin_info['type'], 'origin': origin_info['url']}) self.assertContains(resp, '%s' % (escape(browse_revision_url), branch['revision'][:7])) @patch('swh.web.browse.utils.get_origin_visits') @patch('swh.web.browse.utils.get_origin_visit_snapshot') @patch('swh.web.browse.utils.service') @patch('swh.web.browse.views.origin.service') @istest def origin_branches(self, mock_origin_service, mock_utils_service, mock_get_origin_visit_snapshot, mock_get_origin_visits): mock_get_origin_visits.return_value = stub_origin_visits mock_get_origin_visit_snapshot.return_value = stub_origin_snapshot mock_utils_service.lookup_origin.return_value = stub_origin_info self.origin_branches_test(stub_origin_info, stub_origin_snapshot) self.origin_branches_test(stub_origin_info_no_type, stub_origin_snapshot) @nottest def origin_releases_test(self, origin_info, origin_snapshot): url_args = {'origin_type': origin_info['type'], 'origin_url': origin_info['url']} url = reverse('browse-origin-releases', kwargs=url_args) resp = self.client.get(url) self.assertEquals(resp.status_code, 200) self.assertTemplateUsed('releases.html') origin_branches = origin_snapshot[0] origin_releases = origin_snapshot[1] origin_branches_url = \ reverse('browse-origin-branches', kwargs=url_args) self.assertContains(resp, 'Branches (%s)' % (origin_branches_url, len(origin_branches))) origin_releases_url = \ reverse('browse-origin-releases', kwargs=url_args) self.assertContains(resp, 'Releases (%s)' % (origin_releases_url, len(origin_releases))) self.assertContains(resp, '', count=len(origin_releases)) for release in origin_releases: browse_release_url = reverse('browse-release', kwargs={'sha1_git': release['id']}, query_params={'origin_type': origin_info['type'], 'origin': origin_info['url']}) self.assertContains(resp, '%s' % (escape(browse_release_url), release['name'])) @patch('swh.web.browse.utils.get_origin_visits') @patch('swh.web.browse.utils.get_origin_visit_snapshot') @patch('swh.web.browse.utils.service') @patch('swh.web.browse.views.origin.service') @istest def origin_releases(self, mock_origin_service, mock_utils_service, mock_get_origin_visit_snapshot, mock_get_origin_visits): mock_get_origin_visits.return_value = stub_origin_visits mock_get_origin_visit_snapshot.return_value = stub_origin_snapshot mock_utils_service.lookup_origin.return_value = stub_origin_info self.origin_releases_test(stub_origin_info, stub_origin_snapshot) self.origin_releases_test(stub_origin_info_no_type, stub_origin_snapshot) diff --git a/swh/web/tests/common/test_throttling.py b/swh/web/tests/common/test_throttling.py index a4175cba..04608b89 100644 --- a/swh/web/tests/common/test_throttling.py +++ b/swh/web/tests/common/test_throttling.py @@ -1,137 +1,161 @@ # Copyright (C) 2017-2018 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU Affero General Public License version 3, or any later version # See top-level LICENSE file for more information from swh.web.settings.tests import ( scope1_limiter_rate, scope1_limiter_rate_post, scope2_limiter_rate, scope2_limiter_rate_post, scope3_limiter_rate, scope3_limiter_rate_post ) from django.test import TestCase +from django.conf.urls import url from django.core.cache import cache +from django.test.utils import override_settings from rest_framework.views import APIView from rest_framework.response import Response from rest_framework.test import APIRequestFactory from rest_framework.decorators import api_view -from nose.tools import istest +from nose.tools import istest, nottest from swh.web.common.throttling import ( SwhWebRateThrottle, throttle_scope ) class MockViewScope1(APIView): throttle_classes = (SwhWebRateThrottle,) throttle_scope = 'scope1' def get(self, request): return Response('foo_get') def post(self, request): return Response('foo_post') @api_view(['GET', 'POST']) @throttle_scope('scope2') def mock_view_scope2(request): if request.method == 'GET': return Response('bar_get') elif request.method == 'POST': return Response('bar_post') class MockViewScope3(APIView): throttle_classes = (SwhWebRateThrottle,) throttle_scope = 'scope3' def get(self, request): return Response('foo_get') def post(self, request): return Response('foo_post') @api_view(['GET', 'POST']) @throttle_scope('scope3') def mock_view_scope3(request): if request.method == 'GET': return Response('bar_get') elif request.method == 'POST': return Response('bar_post') +urlpatterns = [ + url(r'^scope1_class$', MockViewScope1.as_view()), + url(r'^scope2_func$', mock_view_scope2), + url(r'^scope3_class$', MockViewScope3.as_view()), + url(r'^scope3_func$', mock_view_scope3) +] + + +@override_settings(ROOT_URLCONF=__name__) class ThrottlingTests(TestCase): def setUp(self): """ Reset the cache so that no throttles will be active """ cache.clear() self.factory = APIRequestFactory() + @nottest + def check_response(self, response, status_code, + limit=None, remaining=None): + assert response.status_code == status_code + if limit is not None: + assert response['X-RateLimit-Limit'] == str(limit) + else: + assert 'X-RateLimit-Limit' not in response + if remaining is not None: + assert response['X-RateLimit-Remaining'] == str(remaining) + else: + assert 'X-RateLimit-Remaining' not in response + @istest def scope1_requests_are_throttled(self): """ Ensure request rate is limited in scope1 """ - request = self.factory.get('/') - for _ in range(scope1_limiter_rate): - response = MockViewScope1.as_view()(request) - assert response.status_code == 200 - response = MockViewScope1.as_view()(request) - assert response.status_code == 429 - - request = self.factory.post('/') - for _ in range(scope1_limiter_rate_post): - response = MockViewScope1.as_view()(request) - assert response.status_code == 200 - response = MockViewScope1.as_view()(request) - assert response.status_code == 429 + for i in range(scope1_limiter_rate): + response = self.client.get('/scope1_class') + self.check_response(response, 200, scope1_limiter_rate, + scope1_limiter_rate - i - 1) + + response = self.client.get('/scope1_class') + self.check_response(response, 429, scope1_limiter_rate, 0) + + for i in range(scope1_limiter_rate_post): + response = self.client.post('/scope1_class') + self.check_response(response, 200, scope1_limiter_rate_post, + scope1_limiter_rate_post - i - 1) + + response = self.client.post('/scope1_class') + self.check_response(response, 429, scope1_limiter_rate_post, 0) @istest def scope2_requests_are_throttled(self): """ Ensure request rate is limited in scope2 """ - request = self.factory.get('/') - for _ in range(scope2_limiter_rate): - response = mock_view_scope2(request) - assert response.status_code == 200 - response = mock_view_scope2(request) - assert response.status_code == 429 - - request = self.factory.post('/') - for _ in range(scope2_limiter_rate_post): - response = mock_view_scope2(request) - assert response.status_code == 200 - response = mock_view_scope2(request) - assert response.status_code == 429 + for i in range(scope2_limiter_rate): + response = self.client.get('/scope2_func') + self.check_response(response, 200, scope2_limiter_rate, + scope2_limiter_rate - i - 1) + + response = self.client.get('/scope2_func') + self.check_response(response, 429, scope2_limiter_rate, 0) + + for i in range(scope2_limiter_rate_post): + response = self.client.post('/scope2_func') + self.check_response(response, 200, scope2_limiter_rate_post, + scope2_limiter_rate_post - i - 1) + + response = self.client.post('/scope2_func') + self.check_response(response, 429, scope2_limiter_rate_post, 0) @istest def scope3_requests_are_throttled_exempted(self): """ Ensure request rate is not limited in scope3 as requests coming from localhost are exempted from rate limit. """ - request = self.factory.get('/') for _ in range(scope3_limiter_rate+1): - response = MockViewScope3.as_view()(request) - assert response.status_code == 200 + response = self.client.get('/scope3_class') + self.check_response(response, 200) - request = self.factory.post('/') for _ in range(scope3_limiter_rate_post+1): - response = MockViewScope3.as_view()(request) - assert response.status_code == 200 + response = self.client.post('/scope3_class') + self.check_response(response, 200) - request = self.factory.get('/') for _ in range(scope3_limiter_rate+1): - response = mock_view_scope3(request) - assert response.status_code == 200 + response = self.client.get('/scope3_func') + self.check_response(response, 200) - request = self.factory.post('/') for _ in range(scope3_limiter_rate_post+1): - response = mock_view_scope3(request) - assert response.status_code == 200 + response = self.client.post('/scope3_func') + self.check_response(response, 200)