diff --git a/swh/web/api/views/graph.py b/swh/web/api/views/graph.py index e91b7903..e3117a55 100644 --- a/swh/web/api/views/graph.py +++ b/swh/web/api/views/graph.py @@ -1,144 +1,145 @@ # Copyright (C) 2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU Affero General Public License version 3, or any later version # See top-level LICENSE file for more information from distutils.util import strtobool import json from typing import Dict import requests from rest_framework.decorators import renderer_classes from rest_framework.request import Request from rest_framework.response import Response from swh.model.identifiers import ORIGIN, parse_swhid from swh.web.api.apidoc import api_doc from swh.web.api.apiurls import api_route from swh.web.api.renderers import PlainTextRenderer from swh.web.common import archive -from swh.web.config import get_config +from swh.web.config import SWH_WEB_INTERNAL_SERVER_NAME, get_config API_GRAPH_PERM = "swh.web.api.graph" def _resolve_origin_swhid(swhid: str, origin_urls: Dict[str, str]) -> str: """ Resolve origin url from its swhid sha1 representation. """ parsed_swhid = parse_swhid(swhid) if parsed_swhid.object_type == ORIGIN: if parsed_swhid.object_id in origin_urls: return origin_urls[parsed_swhid.object_id] else: origin_info = list( archive.lookup_origins_by_sha1s([parsed_swhid.object_id]) )[0] assert origin_info is not None origin_urls[parsed_swhid.object_id] = origin_info["url"] return origin_info["url"] else: return swhid def _resolve_origin_swhids_in_graph_response(response: requests.Response) -> str: """ Resolve origin urls from their swhid sha1 representations in graph service responses. """ content_type = response.headers["Content-Type"] origin_urls: Dict[str, str] = {} if content_type == "application/x-ndjson": processed_response = [] for line in response.text.split("\n")[:-1]: swhids = json.loads(line) processed_line = [] for swhid in swhids: processed_line.append(_resolve_origin_swhid(swhid, origin_urls)) processed_response.append(json.dumps(processed_line)) return "\n".join(processed_response) + "\n" elif content_type == "text/plain": processed_response = [] for line in response.text.split("\n")[:-1]: processed_line = [] swhids = line.split(" ") for swhid in swhids: processed_line.append(_resolve_origin_swhid(swhid, origin_urls)) processed_response.append(" ".join(processed_line)) return "\n".join(processed_response) + "\n" return response.text @api_route(r"/graph/", "api-1-graph-doc") @api_doc("/graph/") def api_graph(request: Request) -> None: """ .. http:get:: /api/1/graph/(graph_query)/ Provide fast access to the graph representation of the Software Heritage archive. That endpoint acts as a proxy for the `Software Heritage Graph service `_. It provides fast access to the `graph representation `_ of the Software Heritage archive. The full documentation of the available Graph REST API can be found `here `_. .. warning:: That endpoint is not publicly available and requires authentication and special user permission in order to be able to request it. :param string graph_query: query to forward to the Software Heritage Graph archive (see its `documentation `_) :query boolean resolve_origins: extra parameter defined by that proxy enabling to resolve origin urls from their sha1 representations :statuscode 200: no error :statuscode 400: an invalid graph query has been provided :statuscode 404: provided graph node cannot be found **Examples:** .. parsed-literal:: :swh_web_api:`graph/leaves/swh:1:dir:432d1b21c1256f7408a07c577b6974bbdbcc1323/` :swh_web_api:`graph/neighbors/swh:1:rev:f39d7d78b70e0f39facb1e4fab77ad3df5c52a35/` :swh_web_api:`graph/randomwalk/swh:1:cnt:94a9ed024d3859793618152ea559a168bbcbb5e2/ori?direction=backward` :swh_web_api:`graph/randomwalk/swh:1:cnt:94a9ed024d3859793618152ea559a168bbcbb5e2/ori?direction=backward&limit=-2` :swh_web_api:`graph/visit/nodes/swh:1:snp:40f9f177b8ab0b7b3d70ee14bbc8b214e2b2dcfc?direction=backward&resolve_origins=true` :swh_web_api:`graph/visit/edges/swh:1:snp:40f9f177b8ab0b7b3d70ee14bbc8b214e2b2dcfc?direction=backward&resolve_origins=true` :swh_web_api:`graph/visit/paths/swh:1:dir:644dd466d8ad527ea3a609bfd588a3244e6dafcb?direction=backward&resolve_origins=true` """ return None @api_route(r"/graph/(?P.+)/", "api-1-graph") @renderer_classes([PlainTextRenderer]) def api_graph_proxy(request: Request, graph_query: str) -> Response: - if not bool(request.user and request.user.is_authenticated): - return Response("Authentication credentials were not provided.", status=401) - if not request.user.has_perm(API_GRAPH_PERM): - return Response( - "You do not have permission to perform this action.", status=403 - ) + if request.get_host() != SWH_WEB_INTERNAL_SERVER_NAME: + if not bool(request.user and request.user.is_authenticated): + return Response("Authentication credentials were not provided.", status=401) + if not request.user.has_perm(API_GRAPH_PERM): + return Response( + "You do not have permission to perform this action.", status=403 + ) graph_query_url = get_config()["graph"]["server_url"] graph_query_url += graph_query if request.GET: graph_query_url += "?" + request.GET.urlencode(safe="/;:") response = requests.get(graph_query_url) response_text = response.text resolve_origins = strtobool(request.GET.get("resolve_origins", "false")) if response.status_code == 200 and resolve_origins: response_text = _resolve_origin_swhids_in_graph_response(response) return Response( response_text, status=response.status_code, content_type=response.headers["Content-Type"], ) diff --git a/swh/web/config.py b/swh/web/config.py index db4e8939..a044d08f 100644 --- a/swh/web/config.py +++ b/swh/web/config.py @@ -1,168 +1,170 @@ # Copyright (C) 2017-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU Affero General Public License version 3, or any later version # See top-level LICENSE file for more information import os from typing import Any, Dict from swh.core import config from swh.indexer.storage import get_indexer_storage from swh.scheduler import get_scheduler from swh.search import get_search from swh.storage import get_storage from swh.vault import get_vault from swh.web import settings +SWH_WEB_INTERNAL_SERVER_NAME = "archive.internal.softwareheritage.org" + SETTINGS_DIR = os.path.dirname(settings.__file__) DEFAULT_CONFIG = { "allowed_hosts": ("list", []), "search": ( "dict", {"cls": "remote", "url": "http://127.0.0.1:5010/", "timeout": 10,}, ), "storage": ( "dict", {"cls": "remote", "url": "http://127.0.0.1:5002/", "timeout": 10,}, ), "indexer_storage": ( "dict", {"cls": "remote", "url": "http://127.0.0.1:5007/", "timeout": 1,}, ), "log_dir": ("string", "/tmp/swh/log"), "debug": ("bool", False), "serve_assets": ("bool", False), "host": ("string", "127.0.0.1"), "port": ("int", 5004), "secret_key": ("string", "development key"), # do not display code highlighting for content > 1MB "content_display_max_size": ("int", 5 * 1024 * 1024), "snapshot_content_max_size": ("int", 1000), "throttling": ( "dict", { "cache_uri": None, # production: memcached as cache (127.0.0.1:11211) # development: in-memory cache so None "scopes": { "swh_api": { "limiter_rate": {"default": "120/h"}, "exempted_networks": ["127.0.0.0/8"], }, "swh_api_origin_search": { "limiter_rate": {"default": "10/m"}, "exempted_networks": ["127.0.0.0/8"], }, "swh_vault_cooking": { "limiter_rate": {"default": "120/h", "GET": "60/m"}, "exempted_networks": ["127.0.0.0/8"], }, "swh_save_origin": { "limiter_rate": {"default": "120/h", "POST": "10/h"}, "exempted_networks": ["127.0.0.0/8"], }, "swh_api_origin_visit_latest": { "limiter_rate": {"default": "700/m"}, "exempted_networks": ["127.0.0.0/8"], }, }, }, ), "vault": ("dict", {"cls": "remote", "args": {"url": "http://127.0.0.1:5005/",}}), "scheduler": ("dict", {"cls": "remote", "url": "http://127.0.0.1:5008/"}), "development_db": ("string", os.path.join(SETTINGS_DIR, "db.sqlite3")), "test_db": ("string", os.path.join(SETTINGS_DIR, "testdb.sqlite3")), "production_db": ("string", "/var/lib/swh/web.sqlite3"), "deposit": ( "dict", { "private_api_url": "https://deposit.softwareheritage.org/1/private/", "private_api_user": "swhworker", "private_api_password": "", }, ), "coverage_count_origins": ("bool", False), "e2e_tests_mode": ("bool", False), "es_workers_index_url": ("string", ""), "history_counters_url": ( "string", "https://stats.export.softwareheritage.org/history_counters.json", ), "client_config": ("dict", {}), "keycloak": ("dict", {"server_url": "", "realm_name": ""}), "graph": ( "dict", {"server_url": "http://graph.internal.softwareheritage.org:5009/graph/"}, ), } swhweb_config = {} # type: Dict[str, Any] def get_config(config_file="web/web"): """Read the configuration file `config_file`. If an environment variable SWH_CONFIG_FILENAME is defined, this takes precedence over the config_file parameter. In any case, update the app with parameters (secret_key, conf) and return the parsed configuration as a dict. If no configuration file is provided, return a default configuration. """ if not swhweb_config: config_filename = os.environ.get("SWH_CONFIG_FILENAME") if config_filename: config_file = config_filename cfg = config.load_named_config(config_file, DEFAULT_CONFIG) swhweb_config.update(cfg) config.prepare_folders(swhweb_config, "log_dir") if swhweb_config.get("search"): swhweb_config["search"] = get_search(**swhweb_config["search"]) else: swhweb_config["search"] = None swhweb_config["storage"] = get_storage(**swhweb_config["storage"]) swhweb_config["vault"] = get_vault(**swhweb_config["vault"]) swhweb_config["indexer_storage"] = get_indexer_storage( **swhweb_config["indexer_storage"] ) swhweb_config["scheduler"] = get_scheduler(**swhweb_config["scheduler"]) return swhweb_config def search(): """Return the current application's search. """ return get_config()["search"] def storage(): """Return the current application's storage. """ return get_config()["storage"] def vault(): """Return the current application's vault. """ return get_config()["vault"] def indexer_storage(): """Return the current application's indexer storage. """ return get_config()["indexer_storage"] def scheduler(): """Return the current application's scheduler. """ return get_config()["scheduler"] diff --git a/swh/web/settings/tests.py b/swh/web/settings/tests.py index b4e50124..f4c6e31c 100644 --- a/swh/web/settings/tests.py +++ b/swh/web/settings/tests.py @@ -1,109 +1,109 @@ # Copyright (C) 2017-2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU Affero General Public License version 3, or any later version # See top-level LICENSE file for more information """ Django tests settings for swh-web. """ import os import sys -from swh.web.config import get_config +from swh.web.config import SWH_WEB_INTERNAL_SERVER_NAME, get_config scope1_limiter_rate = 3 scope1_limiter_rate_post = 1 scope2_limiter_rate = 5 scope2_limiter_rate_post = 2 scope3_limiter_rate = 1 scope3_limiter_rate_post = 1 save_origin_rate_post = 10 swh_web_config = get_config() swh_web_config.update( { # disable django debug mode when running cypress tests "debug": "pytest" in sys.argv[0] or "PYTEST_XDIST_WORKER" in os.environ, "secret_key": "test", "history_counters_url": "", "throttling": { "cache_uri": None, "scopes": { "swh_api": { "limiter_rate": {"default": "60/min"}, "exempted_networks": ["127.0.0.0/8"], }, "swh_api_origin_search": { "limiter_rate": {"default": "100/min"}, "exempted_networks": ["127.0.0.0/8"], }, "swh_api_origin_visit_latest": { "limiter_rate": {"default": "6000/min"}, "exempted_networks": ["127.0.0.0/8"], }, "swh_vault_cooking": { "limiter_rate": {"default": "120/h", "GET": "60/m"}, "exempted_networks": ["127.0.0.0/8"], }, "swh_save_origin": { "limiter_rate": { "default": "120/h", "POST": "%s/h" % save_origin_rate_post, } }, "scope1": { "limiter_rate": { "default": "%s/min" % scope1_limiter_rate, "POST": "%s/min" % scope1_limiter_rate_post, } }, "scope2": { "limiter_rate": { "default": "%s/min" % scope2_limiter_rate, "POST": "%s/min" % scope2_limiter_rate_post, } }, "scope3": { "limiter_rate": { "default": "%s/min" % scope3_limiter_rate, "POST": "%s/min" % scope3_limiter_rate_post, }, "exempted_networks": ["127.0.0.0/8"], }, }, }, "keycloak": { "server_url": "http://localhost:8080/auth", "realm_name": "SoftwareHeritage", }, } ) from .common import * # noqa from .common import ALLOWED_HOSTS, LOGGING # noqa, isort: skip DATABASES = { "default": { "ENGINE": "django.db.backends.sqlite3", "NAME": swh_web_config["test_db"], } } # when not running unit tests, make the webapp fetch data from memory storages if "pytest" not in sys.argv[0] and "PYTEST_XDIST_WORKER" not in os.environ: swh_web_config.update({"debug": True, "e2e_tests_mode": True}) from swh.web.tests.data import get_tests_data, override_storages test_data = get_tests_data() override_storages( test_data["storage"], test_data["idx_storage"], test_data["search"] ) else: - ALLOWED_HOSTS += ["testserver"] + ALLOWED_HOSTS += ["testserver", SWH_WEB_INTERNAL_SERVER_NAME] # Silent DEBUG output when running unit tests LOGGING["handlers"]["console"]["level"] = "INFO" # type: ignore diff --git a/swh/web/tests/api/views/test_graph.py b/swh/web/tests/api/views/test_graph.py index c71d3d4b..a6fba6ea 100644 --- a/swh/web/tests/api/views/test_graph.py +++ b/swh/web/tests/api/views/test_graph.py @@ -1,232 +1,245 @@ # Copyright (C) 2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU Affero General Public License version 3, or any later version # See top-level LICENSE file for more information import hashlib import json import textwrap from hypothesis import given from swh.model.identifiers import ORIGIN, SNAPSHOT, swhid from swh.web.api.views.graph import API_GRAPH_PERM from swh.web.common.utils import reverse -from swh.web.config import get_config +from swh.web.config import SWH_WEB_INTERNAL_SERVER_NAME, get_config from swh.web.tests.auth.keycloak_mock import mock_keycloak from swh.web.tests.auth.sample_data import oidc_profile from swh.web.tests.strategies import origin from swh.web.tests.utils import check_http_get_response +def test_graph_endpoint_no_authentication_for_vpn_users(api_client, requests_mock): + graph_query = "stats" + url = reverse("api-1-graph", url_args={"graph_query": graph_query}) + requests_mock.get( + get_config()["graph"]["server_url"] + graph_query, + json={}, + headers={"Content-Type": "application/json"}, + ) + check_http_get_response( + api_client, url, status_code=200, server_name=SWH_WEB_INTERNAL_SERVER_NAME + ) + + def test_graph_endpoint_needs_authentication(api_client): url = reverse("api-1-graph", url_args={"graph_query": "stats"}) check_http_get_response(api_client, url, status_code=401) def _authenticate_graph_user(api_client, mocker): mock_keycloak(mocker, user_permissions=[API_GRAPH_PERM]) api_client.credentials(HTTP_AUTHORIZATION=f"Bearer {oidc_profile['refresh_token']}") def test_graph_endpoint_needs_permission(api_client, mocker, requests_mock): graph_query = "stats" url = reverse("api-1-graph", url_args={"graph_query": graph_query}) api_client.credentials(HTTP_AUTHORIZATION=f"Bearer {oidc_profile['refresh_token']}") mock_keycloak(mocker, user_permissions=[]) check_http_get_response(api_client, url, status_code=403) _authenticate_graph_user(api_client, mocker) requests_mock.get( get_config()["graph"]["server_url"] + graph_query, json={}, headers={"Content-Type": "application/json"}, ) check_http_get_response(api_client, url, status_code=200) def test_graph_text_plain_response(api_client, mocker, requests_mock): _authenticate_graph_user(api_client, mocker) graph_query = "leaves/swh:1:dir:432d1b21c1256f7408a07c577b6974bbdbcc1323" response_text = textwrap.dedent( """\ swh:1:cnt:1d3dace0a825b0535c37c53ed669ef817e9c1b47 swh:1:cnt:6d5b280f4e33589ae967a7912a587dd5cb8dedaa swh:1:cnt:91bef238bf01356a550d416d14bb464c576ac6f4 swh:1:cnt:58a8b925a463b87d49639fda282b8f836546e396 swh:1:cnt:fd32ee0a87e16ccc853dfbeb7018674f9ce008c0 swh:1:cnt:ab7c39871872589a4fc9e249ebc927fb1042c90d swh:1:cnt:93073c02bf3869845977527de16af4d54765838d swh:1:cnt:4251f795b52c54c447a97c9fe904d8b1f993b1e0 swh:1:cnt:c6e7055424332006d07876ffeba684e7e284b383 swh:1:cnt:8459d8867dc3b15ef7ae9683e21cccc9ab2ec887 swh:1:cnt:5f9981d52202815aa947f85b9dfa191b66f51138 swh:1:cnt:00a685ec51bcdf398c15d588ecdedb611dbbab4b swh:1:cnt:e1cf1ea335106a0197a2f92f7804046425a7d3eb swh:1:cnt:07069b38087f88ec192d2c9aff75a502476fd17d swh:1:cnt:f045ee845c7f14d903a2c035b2691a7c400c01f0 """ ) requests_mock.get( get_config()["graph"]["server_url"] + graph_query, text=response_text, headers={"Content-Type": "text/plain"}, ) url = reverse("api-1-graph", url_args={"graph_query": graph_query}) resp = check_http_get_response( api_client, url, status_code=200, content_type="text/plain" ) assert resp.content == response_text.encode() _response_json = { "counts": {"nodes": 17075708289, "edges": 196236587976}, "ratios": { "compression": 0.16, "bits_per_node": 58.828, "bits_per_edge": 5.119, "avg_locality": 2184278529.729, }, "indegree": {"min": 0, "max": 263180117, "avg": 11.4921492364925}, "outdegree": {"min": 0, "max": 1033207, "avg": 11.4921492364925}, } def test_graph_json_response(api_client, mocker, requests_mock): _authenticate_graph_user(api_client, mocker) graph_query = "stats" requests_mock.get( get_config()["graph"]["server_url"] + graph_query, json=_response_json, headers={"Content-Type": "application/json"}, ) url = reverse("api-1-graph", url_args={"graph_query": graph_query}) resp = check_http_get_response(api_client, url, status_code=200) assert resp.content_type == "application/json" assert resp.content == json.dumps(_response_json).encode() def test_graph_ndjson_response(api_client, mocker, requests_mock): _authenticate_graph_user(api_client, mocker) graph_query = "visit/paths/swh:1:dir:644dd466d8ad527ea3a609bfd588a3244e6dafcb" response_ndjson = textwrap.dedent( """\ ["swh:1:dir:644dd466d8ad527ea3a609bfd588a3244e6dafcb",\ "swh:1:cnt:acfb7cabd63b368a03a9df87670ece1488c8bce0"] ["swh:1:dir:644dd466d8ad527ea3a609bfd588a3244e6dafcb",\ "swh:1:cnt:2a0837708151d76edf28fdbb90dc3eabc676cff3"] ["swh:1:dir:644dd466d8ad527ea3a609bfd588a3244e6dafcb",\ "swh:1:cnt:eaf025ad54b94b2fdda26af75594cfae3491ec75"] """ ) requests_mock.get( get_config()["graph"]["server_url"] + graph_query, text=response_ndjson, headers={"Content-Type": "application/x-ndjson"}, ) url = reverse("api-1-graph", url_args={"graph_query": graph_query}) resp = check_http_get_response(api_client, url, status_code=200) assert resp.content_type == "application/x-ndjson" assert resp.content == response_ndjson.encode() @given(origin()) def test_graph_response_resolve_origins( archive_data, api_client, mocker, requests_mock, origin ): hasher = hashlib.sha1() hasher.update(origin["url"].encode()) origin_sha1 = hasher.hexdigest() origin_swhid = str(swhid(ORIGIN, origin_sha1)) snapshot = archive_data.snapshot_get_latest(origin["url"])["id"] snapshot_swhid = str(swhid(SNAPSHOT, snapshot)) _authenticate_graph_user(api_client, mocker) for graph_query, response_text, content_type in ( ( f"visit/nodes/{snapshot_swhid}", f"{snapshot_swhid}\n{origin_swhid}\n", "text/plain", ), ( f"visit/edges/{snapshot_swhid}", f"{snapshot_swhid} {origin_swhid}\n", "text/plain", ), ( f"visit/paths/{snapshot_swhid}", f'["{snapshot_swhid}", "{origin_swhid}"]\n', "application/x-ndjson", ), ): # set two lines response to check resolved origins cache response_text = response_text + response_text requests_mock.get( get_config()["graph"]["server_url"] + graph_query, text=response_text, headers={"Content-Type": content_type}, ) url = reverse( "api-1-graph", url_args={"graph_query": graph_query}, query_params={"direction": "backward"}, ) resp = check_http_get_response(api_client, url, status_code=200) assert resp.content_type == content_type assert resp.content == response_text.encode() url = reverse( "api-1-graph", url_args={"graph_query": graph_query}, query_params={"direction": "backward", "resolve_origins": "true"}, ) resp = check_http_get_response(api_client, url, status_code=200) assert resp.content_type == content_type assert ( resp.content == response_text.replace(origin_swhid, origin["url"]).encode() ) def test_graph_response_resolve_origins_nothing_to_do( api_client, mocker, requests_mock ): _authenticate_graph_user(api_client, mocker) graph_query = "stats" requests_mock.get( get_config()["graph"]["server_url"] + graph_query, json=_response_json, headers={"Content-Type": "application/json"}, ) url = reverse( "api-1-graph", url_args={"graph_query": graph_query}, query_params={"resolve_origins": "true"}, ) resp = check_http_get_response(api_client, url, status_code=200) assert resp.content_type == "application/json" assert resp.content == json.dumps(_response_json).encode() diff --git a/swh/web/tests/utils.py b/swh/web/tests/utils.py index 96852d99..4f6c9458 100644 --- a/swh/web/tests/utils.py +++ b/swh/web/tests/utils.py @@ -1,196 +1,202 @@ # Copyright (C) 2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU Affero General Public License version 3, or any later version # See top-level LICENSE file for more information from typing import Any, Dict, Optional, cast from django.http import HttpResponse from django.test.client import Client from rest_framework.response import Response from rest_framework.test import APIClient from swh.web.tests.django_asserts import assert_template_used def _assert_http_response( response: HttpResponse, status_code: int, content_type: str ) -> HttpResponse: if isinstance(response, Response): drf_response = cast(Response, response) error_context = ( drf_response.data.pop("traceback") if isinstance(drf_response.data, dict) and "traceback" in drf_response.data else drf_response.data ) else: error_context = getattr(response, "traceback", response.content) assert response.status_code == status_code, error_context if content_type != "*/*": assert response["Content-Type"].startswith(content_type) return response def check_http_get_response( client: Client, url: str, status_code: int, content_type: str = "*/*", http_origin: Optional[str] = None, + server_name: Optional[str] = None, ) -> HttpResponse: """Helper function to check HTTP response for a GET request. Args: client: Django test client url: URL to check response status_code: expected HTTP status code content_type: expected response content type http_origin: optional HTTP_ORIGIN header value Returns: The HTTP response """ return _assert_http_response( - response=client.get(url, HTTP_ACCEPT=content_type, HTTP_ORIGIN=http_origin), + response=client.get( + url, + HTTP_ACCEPT=content_type, + HTTP_ORIGIN=http_origin, + SERVER_NAME=server_name if server_name else "testserver", + ), status_code=status_code, content_type=content_type, ) def check_http_post_response( client: Client, url: str, status_code: int, content_type: str = "*/*", data: Optional[Dict[str, Any]] = None, ) -> HttpResponse: """Helper function to check HTTP response for a POST request. Args: client: Django test client url: URL to check response status_code: expected HTTP status code content_type: expected response content type data: optional POST data Returns: The HTTP response """ return _assert_http_response( response=client.post( url, data=data, content_type="application/json", HTTP_ACCEPT=content_type, ), status_code=status_code, content_type=content_type, ) def check_api_get_responses( api_client: APIClient, url: str, status_code: int ) -> Response: """Helper function to check Web API responses for GET requests for all accepted content types (JSON, YAML, HTML). Args: api_client: DRF test client url: Web API URL to check responses status_code: expected HTTP status code Returns: The Web API JSON response """ # check JSON response response_json = check_http_get_response( api_client, url, status_code, content_type="application/json" ) # check HTML response (API Web UI) check_http_get_response(api_client, url, status_code, content_type="text/html") # check YAML response check_http_get_response( api_client, url, status_code, content_type="application/yaml" ) return cast(Response, response_json) def check_api_post_response( api_client: APIClient, url: str, status_code: int, content_type: str = "*/*", data: Optional[Dict[str, Any]] = None, ) -> HttpResponse: """Helper function to check Web API response for a POST request for all accepted content types. Args: api_client: DRF test client url: Web API URL to check response status_code: expected HTTP status code Returns: The HTTP response """ return _assert_http_response( response=api_client.post( url, data=data, format="json", HTTP_ACCEPT=content_type, ), status_code=status_code, content_type=content_type, ) def check_api_post_responses( api_client: APIClient, url: str, status_code: int, data: Optional[Dict[str, Any]] = None, ) -> Response: """Helper function to check Web API responses for POST requests for all accepted content types (JSON, YAML). Args: api_client: DRF test client url: Web API URL to check responses status_code: expected HTTP status code Returns: The Web API JSON response """ # check JSON response response_json = check_api_post_response( api_client, url, status_code, content_type="application/json", data=data ) # check YAML response check_api_post_response( api_client, url, status_code, content_type="application/yaml", data=data ) return cast(Response, response_json) def check_html_get_response( client: Client, url: str, status_code: int, template_used: Optional[str] = None ) -> HttpResponse: """Helper function to check HTML responses for a GET request. Args: client: Django test client url: URL to check responses status_code: expected HTTP status code template_used: optional used Django template to check Returns: The HTML response """ response = check_http_get_response( client, url, status_code, content_type="text/html" ) if template_used is not None: assert_template_used(response, template_used) return response