diff --git a/swh/web/api/apiurls.py b/swh/web/api/apiurls.py --- a/swh/web/api/apiurls.py +++ b/swh/web/api/apiurls.py @@ -6,7 +6,7 @@ import functools from typing import Dict, List, Optional -from django.http import HttpResponse +from django.http.response import HttpResponseBase from django.utils.cache import add_never_cache_headers from rest_framework.decorators import api_view @@ -99,7 +99,7 @@ doc_data = response["doc_data"] response = response["data"] # check if HTTP response needs to be created - if not isinstance(response, HttpResponse): + if not isinstance(response, HttpResponseBase): api_response = make_api_response( request, data=response, doc_data=doc_data ) diff --git a/swh/web/api/views/graph.py b/swh/web/api/views/graph.py --- a/swh/web/api/views/graph.py +++ b/swh/web/api/views/graph.py @@ -5,10 +5,11 @@ from distutils.util import strtobool import json -from typing import Dict +from typing import Dict, Iterator, Union import requests +from django.http.response import StreamingHttpResponse from rest_framework.decorators import renderer_classes from rest_framework.request import Request from rest_framework.response import Response @@ -42,7 +43,9 @@ return swhid -def _resolve_origin_swhids_in_graph_response(response: requests.Response) -> str: +def _resolve_origin_swhids_in_graph_response( + response: requests.Response, +) -> Iterator[bytes]: """ Resolve origin urls from their swhid sha1 representations in graph service responses. @@ -50,24 +53,22 @@ content_type = response.headers["Content-Type"] origin_urls: Dict[str, str] = {} if content_type == "application/x-ndjson": - processed_response = [] - for line in response.text.split("\n")[:-1]: - swhids = json.loads(line) + for line in response.iter_lines(): + swhids = json.loads(line.decode("utf-8")) processed_line = [] for swhid in swhids: processed_line.append(_resolve_origin_swhid(swhid, origin_urls)) - processed_response.append(json.dumps(processed_line)) - return "\n".join(processed_response) + "\n" + yield (json.dumps(processed_line) + "\n").encode() elif content_type == "text/plain": - processed_response = [] - for line in response.text.split("\n")[:-1]: + for line in response.iter_lines(): processed_line = [] - swhids = line.split(" ") + swhids = line.decode("utf-8").split(" ") for swhid in swhids: processed_line.append(_resolve_origin_swhid(swhid, origin_urls)) - processed_response.append(" ".join(processed_line)) - return "\n".join(processed_response) + "\n" - return response.text + yield (" ".join(processed_line) + "\n").encode() + else: + for line in response.iter_lines(): + yield line + b"\n" @api_route(r"/graph/", "api-1-graph-doc") @@ -121,7 +122,9 @@ @api_route(r"/graph/(?P.+)/", "api-1-graph") @renderer_classes([PlainTextRenderer]) -def api_graph_proxy(request: Request, graph_query: str) -> Response: +def api_graph_proxy( + request: Request, graph_query: str +) -> Union[Response, StreamingHttpResponse]: if request.get_host() != SWH_WEB_INTERNAL_SERVER_NAME: if not bool(request.user and request.user.is_authenticated): return Response("Authentication credentials were not provided.", status=401) @@ -133,13 +136,23 @@ graph_query_url += graph_query if request.GET: graph_query_url += "?" + request.GET.urlencode(safe="/;:") - response = requests.get(graph_query_url) - response_text = response.text - resolve_origins = strtobool(request.GET.get("resolve_origins", "false")) - if response.status_code == 200 and resolve_origins: - response_text = _resolve_origin_swhids_in_graph_response(response) - return Response( - response_text, - status=response.status_code, - content_type=response.headers["Content-Type"], - ) + response = requests.get(graph_query_url, stream=True) + # graph stats and counter endpoint responses are not streamed + if response.headers.get("Transfer-Encoding") != "chunked": + return Response( + response.text, + status=response.status_code, + content_type=response.headers["Content-Type"], + ) + # other endpoint responses are streamed + else: + resolve_origins = strtobool(request.GET.get("resolve_origins", "false")) + if response.status_code == 200 and resolve_origins: + response_stream = _resolve_origin_swhids_in_graph_response(response) + else: + response_stream = map(lambda line: line + b"\n", response.iter_lines()) + return StreamingHttpResponse( + response_stream, + status=response.status_code, + content_type=response.headers["Content-Type"], + ) diff --git a/swh/web/tests/api/views/test_graph.py b/swh/web/tests/api/views/test_graph.py --- a/swh/web/tests/api/views/test_graph.py +++ b/swh/web/tests/api/views/test_graph.py @@ -9,6 +9,8 @@ from hypothesis import given +from django.http.response import StreamingHttpResponse + from swh.model.identifiers import ORIGIN, SNAPSHOT, swhid from swh.web.api.views.graph import API_GRAPH_PERM from swh.web.common.utils import reverse @@ -87,7 +89,7 @@ requests_mock.get( get_config()["graph"]["server_url"] + graph_query, text=response_text, - headers={"Content-Type": "text/plain"}, + headers={"Content-Type": "text/plain", "Transfer-Encoding": "chunked"}, ) url = reverse("api-1-graph", url_args={"graph_query": graph_query}) @@ -95,7 +97,8 @@ resp = check_http_get_response( api_client, url, status_code=200, content_type="text/plain" ) - assert resp.content == response_text.encode() + assert isinstance(resp, StreamingHttpResponse) + assert b"".join(resp.streaming_content) == response_text.encode() _response_json = { @@ -148,14 +151,18 @@ requests_mock.get( get_config()["graph"]["server_url"] + graph_query, text=response_ndjson, - headers={"Content-Type": "application/x-ndjson"}, + headers={ + "Content-Type": "application/x-ndjson", + "Transfer-Encoding": "chunked", + }, ) url = reverse("api-1-graph", url_args={"graph_query": graph_query}) resp = check_http_get_response(api_client, url, status_code=200) - assert resp.content_type == "application/x-ndjson" - assert resp.content == response_ndjson.encode() + assert isinstance(resp, StreamingHttpResponse) + assert resp["Content-Type"] == "application/x-ndjson" + assert b"".join(resp.streaming_content) == response_ndjson.encode() @given(origin()) @@ -195,7 +202,7 @@ requests_mock.get( get_config()["graph"]["server_url"] + graph_query, text=response_text, - headers={"Content-Type": content_type}, + headers={"Content-Type": content_type, "Transfer-Encoding": "chunked"}, ) url = reverse( @@ -205,8 +212,9 @@ ) resp = check_http_get_response(api_client, url, status_code=200) - assert resp.content_type == content_type - assert resp.content == response_text.encode() + assert isinstance(resp, StreamingHttpResponse) + assert resp["Content-Type"] == content_type + assert b"".join(resp.streaming_content) == response_text.encode() url = reverse( "api-1-graph", @@ -215,9 +223,11 @@ ) resp = check_http_get_response(api_client, url, status_code=200) - assert resp.content_type == content_type + assert isinstance(resp, StreamingHttpResponse) + assert resp["Content-Type"] == content_type assert ( - resp.content == response_text.replace(origin_swhid, origin["url"]).encode() + b"".join(resp.streaming_content) + == response_text.replace(origin_swhid, origin["url"]).encode() ) diff --git a/swh/web/tests/utils.py b/swh/web/tests/utils.py --- a/swh/web/tests/utils.py +++ b/swh/web/tests/utils.py @@ -5,7 +5,7 @@ from typing import Any, Dict, Optional, cast -from django.http import HttpResponse +from django.http import HttpResponse, StreamingHttpResponse from django.test.client import Client from rest_framework.response import Response from rest_framework.test import APIClient @@ -24,6 +24,8 @@ if isinstance(drf_response.data, dict) and "traceback" in drf_response.data else drf_response.data ) + elif isinstance(response, StreamingHttpResponse): + error_context = getattr(response, "traceback", response.streaming_content) else: error_context = getattr(response, "traceback", response.content)