diff --git a/swh/web/api/views/graph.py b/swh/web/api/views/graph.py
--- a/swh/web/api/views/graph.py
+++ b/swh/web/api/views/graph.py
@@ -18,7 +18,7 @@
 from swh.web.api.apiurls import api_route
 from swh.web.api.renderers import PlainTextRenderer
 from swh.web.common import archive
-from swh.web.config import get_config
+from swh.web.config import SWH_WEB_INTERNAL_SERVER_NAME, get_config
 
 
 API_GRAPH_PERM = "swh.web.api.graph"
@@ -122,12 +122,16 @@
 @api_route(r"/graph/(?P<graph_query>.+)/", "api-1-graph")
 @renderer_classes([PlainTextRenderer])
 def api_graph_proxy(request: Request, graph_query: str) -> Response:
-    if not bool(request.user and request.user.is_authenticated):
-        return Response("Authentication credentials were not provided.", status=401)
-    if not request.user.has_perm(API_GRAPH_PERM):
-        return Response(
-            "You do not have permission to perform this action.", status=403
-        )
+    # NOTE(review): get_host() is built from the client-supplied Host header
+    # (only checked against ALLOWED_HOSTS), so this bypass assumes the front
+    # proxy never forwards that host name for requests from outside the VPN.
+    if request.get_host() != SWH_WEB_INTERNAL_SERVER_NAME:
+        if not bool(request.user and request.user.is_authenticated):
+            return Response("Authentication credentials were not provided.", status=401)
+        if not request.user.has_perm(API_GRAPH_PERM):
+            return Response(
+                "You do not have permission to perform this action.", status=403
+            )
     graph_query_url = get_config()["graph"]["server_url"]
     graph_query_url += graph_query
     if request.GET:
diff --git a/swh/web/config.py b/swh/web/config.py
--- a/swh/web/config.py
+++ b/swh/web/config.py
@@ -14,6 +14,9 @@
 from swh.vault import get_vault
 from swh.web import settings
 
+# Requests to this host skip auth and permission checks in the graph API proxy.
+SWH_WEB_INTERNAL_SERVER_NAME = "archive.internal.softwareheritage.org"
+
 SETTINGS_DIR = os.path.dirname(settings.__file__)
 
 DEFAULT_CONFIG = {
diff --git a/swh/web/settings/tests.py b/swh/web/settings/tests.py
--- a/swh/web/settings/tests.py
+++ b/swh/web/settings/tests.py
@@ -10,7 +10,7 @@
 import os
 import sys
 
-from swh.web.config import get_config
+from swh.web.config import SWH_WEB_INTERNAL_SERVER_NAME, get_config
 
 scope1_limiter_rate = 3
 scope1_limiter_rate_post = 1
@@ -103,7 +103,7 @@
         test_data["storage"], test_data["idx_storage"], test_data["search"]
     )
 else:
-    ALLOWED_HOSTS += ["testserver"]
+    ALLOWED_HOSTS += ["testserver", SWH_WEB_INTERNAL_SERVER_NAME]
 
     # Silent DEBUG output when running unit tests
     LOGGING["handlers"]["console"]["level"] = "INFO"  # type: ignore
diff --git a/swh/web/tests/api/views/test_graph.py b/swh/web/tests/api/views/test_graph.py
--- a/swh/web/tests/api/views/test_graph.py
+++ b/swh/web/tests/api/views/test_graph.py
@@ -12,13 +12,26 @@
 from swh.model.identifiers import ORIGIN, SNAPSHOT, swhid
 from swh.web.api.views.graph import API_GRAPH_PERM
 from swh.web.common.utils import reverse
-from swh.web.config import get_config
+from swh.web.config import SWH_WEB_INTERNAL_SERVER_NAME, get_config
 from swh.web.tests.auth.keycloak_mock import mock_keycloak
 from swh.web.tests.auth.sample_data import oidc_profile
 from swh.web.tests.strategies import origin
 from swh.web.tests.utils import check_http_get_response
 
 
+def test_graph_endpoint_no_authentication_for_vpn_users(api_client, requests_mock):
+    graph_query = "stats"
+    url = reverse("api-1-graph", url_args={"graph_query": graph_query})
+    requests_mock.get(
+        get_config()["graph"]["server_url"] + graph_query,
+        json={},
+        headers={"Content-Type": "application/json"},
+    )
+    check_http_get_response(
+        api_client, url, status_code=200, server_name=SWH_WEB_INTERNAL_SERVER_NAME
+    )
+
+
 def test_graph_endpoint_needs_authentication(api_client):
     url = reverse("api-1-graph", url_args={"graph_query": "stats"})
     check_http_get_response(api_client, url, status_code=401)
diff --git a/swh/web/tests/utils.py b/swh/web/tests/utils.py
--- a/swh/web/tests/utils.py
+++ b/swh/web/tests/utils.py
@@ -39,6 +39,7 @@
     status_code: int,
     content_type: str = "*/*",
     http_origin: Optional[str] = None,
+    server_name: Optional[str] = None,
 ) -> HttpResponse:
     """Helper function to check HTTP response for a GET request.
 
@@ -53,7 +54,12 @@
         The HTTP response
     """
     return _assert_http_response(
-        response=client.get(url, HTTP_ACCEPT=content_type, HTTP_ORIGIN=http_origin),
+        response=client.get(
+            url,
+            HTTP_ACCEPT=content_type,
+            HTTP_ORIGIN=http_origin,
+            SERVER_NAME=server_name or "testserver",
+        ),
         status_code=status_code,
         content_type=content_type,
     )