http_server: return HTTP 503 when the gRPC backend is unreachable

* Add an aiohttp middleware turning grpc.aio.AioRpcError into
  HTTPServiceUnavailable, with the error serialized as plain text.
* Open the Traverse stream and wait for its connection before the response
  is streamed, so that connection failures surface before HTTP 200 is sent.
* Tests: expose the server process as a module-scoped fixture, report the
  pid of the local gRPC server, and check that killing it makes the remote
  client raise TransientRemoteException.

diff --git a/swh/graph/http_server.py b/swh/graph/http_server.py
--- a/swh/graph/http_server.py
+++ b/swh/graph/http_server.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2019-2020 The Software Heritage developers
+# Copyright (C) 2019-2022 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
@@ -41,9 +41,27 @@
 RANDOM_RETRIES = 10  # TODO make this configurable via rpc-serve configuration
 
 
+async def _aiorpcerror_middleware(app, handler):
+    async def middleware_handler(request):
+        try:
+            return await handler(request)
+        except grpc.aio.AioRpcError as e:
+            # The default error handler of the RPC framework tries to serialize this
+            # with msgpack, which for some unknown reason causes it to raise
+            # ValueError("recursion limit exceeded") with a lot of context, causing
+            # Sentry to be flooded with gigabytes of logs (160KB per event, with
+            # potentially hundreds of thousands of events per day).
+            # Instead, we simply serialize the exception to a string.
+            # https://sentry.softwareheritage.org/share/issue/d6d4db971e4b47728a6c1dd06cb9b8a5/
+            raise aiohttp.web.HTTPServiceUnavailable(text=str(e))
+
+    return middleware_handler
+
+
 class GraphServerApp(RPCServerApp):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+    def __init__(self, *args, middlewares=(), **kwargs):
+        middlewares = (_aiorpcerror_middleware, *middlewares)
+        super().__init__(*args, middlewares=middlewares, **kwargs)
         self.on_startup.append(self._start)
         self.on_shutdown.append(self._stop)
 
@@ -240,12 +258,17 @@
         self.traversal_request.max_edges = self.get_max_edges()
         await self.check_swhid(src)
         self.configure_request()
+        self.nodes_stream = self.rpc_client.Traverse(self.traversal_request)
+
+        # Wait until the gRPC call is connected to the server, so connection errors
+        # are raised here and we can return HTTP 503 before committing to HTTP 200
+        await self.nodes_stream.wait_for_connection()
 
     def configure_request(self):
         pass
 
     async def stream_response(self):
-        async for node in self.rpc_client.Traverse(self.traversal_request):
+        async for node in self.nodes_stream:
             await self.stream_line(node.swhid)
 
 
diff --git a/swh/graph/tests/conftest.py b/swh/graph/tests/conftest.py
--- a/swh/graph/tests/conftest.py
+++ b/swh/graph/tests/conftest.py
@@ -35,7 +35,13 @@
                 client = TestClient(TestServer(app), loop=loop)
                 loop.run_until_complete(client.start_server())
                 url = client.make_url("/graph/")
-                self.q.put({"server_url": url, "rpc_url": app["rpc_url"]})
+                self.q.put(
+                    {
+                        "server_url": url,
+                        "rpc_url": app["rpc_url"],
+                        "pid": app["local_server"].pid,
+                    }
+                )
                 loop.run_forever()
         except Exception as e:
             self.q.put(e)
@@ -46,8 +52,17 @@
 
 
 @pytest.fixture(scope="module")
-def graph_grpc_server():
+def graph_grpc_server_process():
     server = GraphServerProcess()
+
+    yield server
+
+    server.kill()
+
+
+@pytest.fixture(scope="module")
+def graph_grpc_server(graph_grpc_server_process):
+    server = graph_grpc_server_process
     server.start()
     if isinstance(server.result, Exception):
         raise server.result
@@ -66,7 +81,7 @@
 @pytest.fixture(scope="module", params=["remote", "naive"])
 def graph_client(request):
     if request.param == "remote":
-        server = GraphServerProcess()
+        server = request.getfixturevalue("graph_grpc_server_process")
         server.start()
         if isinstance(server.result, Exception):
             raise server.result
diff --git a/swh/graph/tests/test_http_server_down.py b/swh/graph/tests/test_http_server_down.py
new file mode 100644
--- /dev/null
+++ b/swh/graph/tests/test_http_server_down.py
@@ -0,0 +1,38 @@
+# Copyright (C) 2022 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+import os
+import signal
+
+import pytest
+
+from swh.core.api import TransientRemoteException
+from swh.graph.http_client import RemoteGraphClient
+from swh.graph.http_naive_client import NaiveClient
+
+from .test_http_client import TEST_ORIGIN_ID
+
+
+def test_leaves(graph_client, graph_grpc_server_process):
+    if isinstance(graph_client, RemoteGraphClient):
+        pass
+    elif isinstance(graph_client, NaiveClient):
+        pytest.skip("test irrelevant for naive graph client")
+    else:
+        assert False, f"unexpected graph_client class: {graph_client.__class__}"
+
+    list(graph_client.leaves(TEST_ORIGIN_ID))
+
+    server = graph_grpc_server_process
+    pid = server.result["pid"]
+    os.kill(pid, signal.SIGKILL)
+    try:
+        os.waitpid(pid, os.WNOHANG)
+    except ChildProcessError:
+        pass
+
+    it = graph_client.leaves(TEST_ORIGIN_ID)
+    with pytest.raises(TransientRemoteException, match="failed to connect"):
+        list(it)