diff --git a/swh/web/api/views/graph.py b/swh/web/api/views/graph.py
--- a/swh/web/api/views/graph.py
+++ b/swh/web/api/views/graph.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2020-2021 The Software Heritage developers
+# Copyright (C) 2020-2022 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU Affero General Public License version 3, or any later version
 # See top-level LICENSE file for more information
@@ -6,9 +6,11 @@
 from distutils.util import strtobool
 import json
 from typing import Dict, Iterator, Union
+from urllib.parse import unquote, urlparse, urlunparse
 
 import requests
 
+from django.http import QueryDict
 from django.http.response import StreamingHttpResponse
 from rest_framework.decorators import renderer_classes
 from rest_framework.renderers import JSONRenderer
@@ -136,9 +138,23 @@
             "You do not have permission to perform this action.", status=403
         )
     graph_query_url = get_config()["graph"]["server_url"]
+
+    # graph_query may embed a percent-encoded query string (e.g. "stats%3Fa=1"),
+    # decode it so that its parameters can be parsed below
+    graph_query = unquote(graph_query)
     graph_query_url += graph_query
-    if request.GET:
-        graph_query_url += "?" + request.GET.urlencode(safe="/;:")
+
+    # merge the parameters embedded in graph_query with those of the request;
+    # QueryDict.update appends values, so a duplicated key keeps both of them
+    parsed_url = urlparse(graph_query_url)
+    query_dict = QueryDict(parsed_url.query, mutable=True)
+    query_dict.update(request.GET)
+
+    if query_dict:
+        graph_query_url = urlunparse(
+            parsed_url._replace(query=query_dict.urlencode(safe="/;:"))
+        )
+
     response = requests.get(graph_query_url, stream=True)
 
     if response.status_code != 200:
diff --git a/swh/web/tests/api/views/test_graph.py b/swh/web/tests/api/views/test_graph.py
--- a/swh/web/tests/api/views/test_graph.py
+++ b/swh/web/tests/api/views/test_graph.py
@@ -1,10 +1,14 @@
-# Copyright (C) 2021 The Software Heritage developers
+# Copyright (C) 2021-2022 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU Affero General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
 import hashlib
+import re
 import textwrap
+from urllib.parse import unquote, urlparse
+
+import pytest
 
 from django.http.response import StreamingHttpResponse
 
@@ -290,3 +294,49 @@
     resp = check_http_get_response(api_client, url, status_code=404)
     assert resp.content_type == content_type
     assert resp.content == f'"{error_message}"'.encode()
+
+
+@pytest.mark.parametrize(
+    "graph_query, query_params, expected_graph_query_params",
+    [
+        ("stats", {}, ""),
+        ("stats", {"resolve_origins": "true"}, "resolve_origins=true"),
+        ("stats?a=1", {}, "a=1"),
+        ("stats%3Fb=2", {}, "b=2"),
+        ("stats?a=1", {"resolve_origins": "true"}, "a=1&resolve_origins=true"),
+        ("stats%3Fb=2", {"resolve_origins": "true"}, "b=2&resolve_origins=true"),
+        ("stats/?a=1", {"a": "2"}, "a=1&a=2"),
+        ("stats/%3Fa=1", {"a": "2"}, "a=1&a=2"),
+    ],
+)
+def test_graph_query_params(
+    api_client,
+    keycloak_oidc,
+    requests_mock,
+    graph_query,
+    query_params,
+    expected_graph_query_params,
+):
+    """Check how query parameters are forwarded to the graph service.
+
+    Parameters embedded in the graph query, percent-encoded or not, must be
+    merged with those of the request in the URL sent to the graph service.
+    """
+    _authenticate_graph_user(api_client, keycloak_oidc)
+
+    requests_mock.get(
+        re.compile(get_config()["graph"]["server_url"]),
+        json=_response_json,
+        headers={"Content-Type": "application/json"},
+    )
+
+    url = reverse(
+        "api-1-graph", url_args={"graph_query": graph_query}, query_params=query_params,
+    )
+
+    check_http_get_response(api_client, url, status_code=200)
+
+    url = requests_mock.request_history[0].url
+    parsed_url = urlparse(url)
+    assert parsed_url.path == f"/graph/{unquote(graph_query).split('?')[0]}"
+    assert parsed_url.query == expected_graph_query_params