diff --git a/swh/web/api/throttling.py b/swh/web/api/throttling.py --- a/swh/web/api/throttling.py +++ b/swh/web/api/throttling.py @@ -12,6 +12,7 @@ import rest_framework from rest_framework.throttling import ScopedRateThrottle +from swh.web.auth.utils import API_SAVE_ORIGIN_PERMISSION from swh.web.config import get_config APIView = TypeVar("APIView", bound="rest_framework.views.APIView") @@ -84,6 +85,14 @@ ] return self.exempted_networks + def get_scope(self, view: APIView): + if not self.scope: + # class based view case + return getattr(view, self.scope_attr, None) + else: + # function based view case + return self.scope + def allow_request(self, request: Request, view: APIView) -> bool: # class based view case if not self.scope: @@ -172,8 +181,12 @@ return (num_requests * self.NUM_REQUESTS_FACTOR, duration) def allow_request(self, request: Request, view: APIView) -> bool: - # no throttling for staff users or users with adequate permission if request.user.is_staff or request.user.has_perm(API_THROTTLING_EXEMPTED_PERM): + # no throttling for staff users or users with adequate permission + return True + scope = self.get_scope(view) + if scope == "save_origin" and request.user.has_perm(API_SAVE_ORIGIN_PERMISSION): + # no throttling on save origin endpoint for users with adequate permission return True return super().allow_request(request, view) diff --git a/swh/web/api/views/origin_save.py b/swh/web/api/views/origin_save.py --- a/swh/web/api/views/origin_save.py +++ b/swh/web/api/views/origin_save.py @@ -5,7 +5,11 @@ from swh.web.api.apidoc import api_doc, format_docstring from swh.web.api.apiurls import api_route -from swh.web.auth.utils import privileged_user +from swh.web.auth.utils import ( + API_SAVE_ORIGIN_PERMISSION, + SWH_AMBASSADOR_PERMISSION, + privileged_user, +) from swh.web.common.origin_save import ( create_save_origin_request, get_savable_visit_types, @@ -100,7 +104,10 @@ sor = create_save_origin_request( visit_type, origin_url, - privileged_user(request), + privileged_user( + request, + permissions=[SWH_AMBASSADOR_PERMISSION, API_SAVE_ORIGIN_PERMISSION], + ), user_id=request.user.id, **data, ) diff --git a/swh/web/auth/utils.py b/swh/web/auth/utils.py --- a/swh/web/auth/utils.py +++ b/swh/web/auth/utils.py @@ -4,15 +4,19 @@ # See top-level LICENSE file for more information from base64 import urlsafe_b64encode +from typing import List from cryptography.fernet import Fernet from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC +from django.http.request import HttpRequest + OIDC_SWH_WEB_CLIENT_ID = "swh-web" SWH_AMBASSADOR_PERMISSION = "swh.ambassador" +API_SAVE_ORIGIN_PERMISSION = "swh.web.api.save_origin" def _get_fernet(password: bytes, salt: bytes) -> Fernet: @@ -72,13 +76,20 @@ return _get_fernet(password, salt).decrypt(data) -def privileged_user(request) -> bool: +def privileged_user(request: HttpRequest, permissions: List[str] = []) -> bool: """Determine whether a user is authenticated and is a privileged one (e.g ambassador). This allows such user to have access to some more actions (e.g. bypass save code now - review, access to 'archives' type...) + review, access to 'archives' type...). + A user is considered as privileged if he is a staff member or has any permission + from those provided as parameters. + Args: + request: Input django HTTP request + permissions: list of permission names to determine if user is privileged or not + Returns: + Whether the user is privileged or not. """ user = request.user return user.is_authenticated and ( - user.is_staff or user.has_perm(SWH_AMBASSADOR_PERMISSION) + user.is_staff or any([user.has_perm(perm) for perm in permissions]) ) diff --git a/swh/web/misc/origin_save.py b/swh/web/misc/origin_save.py --- a/swh/web/misc/origin_save.py +++ b/swh/web/misc/origin_save.py @@ -9,7 +9,7 @@ from django.http import JsonResponse from django.shortcuts import render -from swh.web.auth.utils import privileged_user +from swh.web.auth.utils import SWH_AMBASSADOR_PERMISSION, privileged_user from swh.web.common.models import SaveOriginRequest from swh.web.common.origin_save import ( get_savable_visit_types, @@ -23,7 +23,9 @@ "misc/origin-save.html", { "heading": ("Request the saving of a software origin into the archive"), - "visit_types": get_savable_visit_types(privileged_user(request)), + "visit_types": get_savable_visit_types( + privileged_user(request, permissions=[SWH_AMBASSADOR_PERMISSION]) + ), }, ) diff --git a/swh/web/tests/api/views/test_origin_save.py b/swh/web/tests/api/views/test_origin_save.py --- a/swh/web/tests/api/views/test_origin_save.py +++ b/swh/web/tests/api/views/test_origin_save.py @@ -11,7 +11,7 @@ from django.core.exceptions import ObjectDoesNotExist from django.utils import timezone -from swh.web.auth.utils import SWH_AMBASSADOR_PERMISSION +from swh.web.auth.utils import API_SAVE_ORIGIN_PERMISSION, SWH_AMBASSADOR_PERMISSION from swh.web.common.models import ( SAVE_REQUEST_ACCEPTED, SAVE_REQUEST_PENDING, @@ -34,6 +34,7 @@ check_api_get_responses, check_api_post_response, check_api_post_responses, + create_django_permission, ) pytestmark = pytest.mark.django_db @@ -332,24 +333,30 @@ _origin_url = "https://github.com/python/cpython" -def test_save_requests_rate_limit(api_client, mocker): - create_save_origin_request = mocker.patch( - "swh.web.api.views.origin_save.create_save_origin_request" +def test_save_requests_rate_limit(api_client, swh_scheduler): + + url = reverse( + "api-1-save-origin", + url_args={"visit_type": _visit_type, "origin_url": _origin_url}, + ) + + for _ in range(save_origin_rate_post): + check_api_post_response(api_client, url, status_code=200) + + check_api_post_response(api_client, url, status_code=429) + + +def test_save_requests_no_rate_limit_if_permission( + api_client, regular_user, swh_scheduler +): + + regular_user.user_permissions.add( + create_django_permission(API_SAVE_ORIGIN_PERMISSION) ) - def _save_request_dict(*args, **kwargs): - return { - "id": 1, - "visit_type": _visit_type, - "origin_url": _origin_url, - "save_request_date": datetime.now().isoformat(), - "save_request_status": SAVE_REQUEST_ACCEPTED, - "save_task_status": SAVE_TASK_NOT_YET_SCHEDULED, - "visit_date": None, - "visit_status": None, - } + assert regular_user.has_perm(API_SAVE_ORIGIN_PERMISSION) - create_save_origin_request.side_effect = _save_request_dict + api_client.force_login(regular_user) url = reverse( "api-1-save-origin", @@ -359,7 +366,37 @@ for _ in range(save_origin_rate_post): check_api_post_response(api_client, url, status_code=200) - check_api_post_response(api_client, url, status_code=429) + check_api_post_response(api_client, url, status_code=200) + + +def test_save_request_unknown_repo_with_permission( + api_client, regular_user, mocker, swh_scheduler +): + + regular_user.user_permissions.add( + create_django_permission(API_SAVE_ORIGIN_PERMISSION) + ) + + assert regular_user.has_perm(API_SAVE_ORIGIN_PERMISSION) + + api_client.force_login(regular_user) + + origin_url = "https://unkwownforge.org/user/repo" + check_created_save_request_status( + api_client, + mocker, + origin_url, + expected_request_status=SAVE_REQUEST_ACCEPTED, + expected_task_status=SAVE_TASK_NOT_YET_SCHEDULED, + ) + check_save_request_status( + api_client, + mocker, + swh_scheduler, + origin_url, + expected_request_status=SAVE_REQUEST_ACCEPTED, + expected_task_status=SAVE_TASK_NOT_YET_SCHEDULED, + ) def test_save_request_form_server_error(api_client, mocker):