You have reached the Software Heritage graph API server.
See its API documentation for more information.
""", ) class GraphView(aiohttp.web.View): """Base class for views working on the graph, with utility functions""" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.rpc_client: TraversalServiceStub = self.request.app["rpc_client"] def get_direction(self): """Validate HTTP query parameter `direction`""" s = self.request.query.get("direction", "forward") if s not in ("forward", "backward"): raise aiohttp.web.HTTPBadRequest(text=f"invalid direction: {s}") return s.upper() def get_edges(self): """Validate HTTP query parameter `edges`, i.e., edge restrictions""" s = self.request.query.get("edges", "*") if any( [ node_type != "*" and node_type not in EXTENDED_SWHID_TYPES for edge in s.split(":") for node_type in edge.split(",", maxsplit=1) ] ): raise aiohttp.web.HTTPBadRequest(text=f"invalid edge restriction: {s}") return s def get_return_types(self): """Validate HTTP query parameter 'return types', i.e, a set of types which we will filter the query results with""" s = self.request.query.get("return_types", "*") if any( node_type != "*" and node_type not in EXTENDED_SWHID_TYPES for node_type in s.split(",") ): raise aiohttp.web.HTTPBadRequest( text=f"invalid type for filtering res: {s}" ) # if the user puts a star, # then we filter nothing, we don't need the other information if "*" in s: return "*" else: return s def get_traversal(self): """Validate HTTP query parameter `traversal`, i.e., visit order""" s = self.request.query.get("traversal", "dfs") if s not in ("bfs", "dfs"): raise aiohttp.web.HTTPBadRequest(text=f"invalid traversal order: {s}") return s def get_limit(self): """Validate HTTP query parameter `limit`, i.e., number of results""" s = self.request.query.get("limit", "0") try: return int(s) except ValueError: raise aiohttp.web.HTTPBadRequest(text=f"invalid limit value: {s}") def get_max_edges(self): """Validate HTTP query parameter 'max_edges', i.e., the limit of the number of edges that can be visited""" s = self.request.query.get("max_edges", "0") try: return int(s) except ValueError: raise aiohttp.web.HTTPBadRequest(text=f"invalid max_edges value: {s}") async def check_swhid(self, swhid): """Validate that the given SWHID exists in the graph""" try: await self.rpc_client.GetNode( GetNodeRequest(swhid=swhid, mask=FieldMask(paths=["swhid"])) ) except grpc.aio.AioRpcError as e: if e.code() == grpc.StatusCode.INVALID_ARGUMENT: raise aiohttp.web.HTTPBadRequest(text=str(e.details())) class StreamingGraphView(GraphView): """Base class for views streaming their response line by line.""" content_type = "text/plain" @asynccontextmanager async def response_streamer(self, *args, **kwargs): """Context manager to prepare then close a StreamResponse""" response = aiohttp.web.StreamResponse(*args, **kwargs) response.content_type = self.content_type await response.prepare(self.request) yield response await response.write_eof() async def get(self): await self.prepare_response() async with self.response_streamer() as self.response_stream: self._buf = [] try: await self.stream_response() finally: await self._flush_buffer() return self.response_stream async def prepare_response(self): """This can be overridden with some setup to be run before the response actually starts streaming. """ pass async def stream_response(self): """Override this to perform the response streaming. Implementations of this should await self.stream_line(line) to write each line. 
""" raise NotImplementedError async def stream_line(self, line): """Write a line in the response stream.""" self._buf.append(line) if len(self._buf) > 100: await self._flush_buffer() async def _flush_buffer(self): await self.response_stream.write("\n".join(self._buf).encode() + b"\n") self._buf = [] class StatsView(GraphView): """View showing some statistics on the graph""" async def get(self): res = await self.rpc_client.Stats(StatsRequest()) stats = json_format.MessageToDict( res, including_default_value_fields=True, preserving_proto_field_name=True ) # Int64 fields are serialized as strings by default. for descriptor in res.DESCRIPTOR.fields: if descriptor.type == descriptor.TYPE_INT64: try: stats[descriptor.name] = int(stats[descriptor.name]) except KeyError: pass json_body = json.dumps(stats, indent=4, sort_keys=True) return aiohttp.web.Response(body=json_body, content_type="application/json") class SimpleTraversalView(StreamingGraphView): """Base class for views of simple traversals""" async def prepare_response(self): src = self.request.match_info["src"] self.traversal_request = TraversalRequest( src=[src], edges=self.get_edges(), direction=self.get_direction(), return_nodes=NodeFilter(types=self.get_return_types()), mask=FieldMask(paths=["swhid"]), ) if self.get_max_edges(): self.traversal_request.max_edges = self.get_max_edges() await self.check_swhid(src) self.configure_request() + self.nodes_stream = self.rpc_client.Traverse(self.traversal_request) + + # Force gRPC to query the server and fetch the first nodes; so errors + # are raised early, so we can return HTTP 503 before HTTP 200 + await self.nodes_stream.wait_for_connection() def configure_request(self): pass async def stream_response(self): - async for node in self.rpc_client.Traverse(self.traversal_request): + async for node in self.nodes_stream: await self.stream_line(node.swhid) class LeavesView(SimpleTraversalView): def configure_request(self): self.traversal_request.return_nodes.max_traversal_successors = 0 class NeighborsView(SimpleTraversalView): def configure_request(self): self.traversal_request.min_depth = 1 self.traversal_request.max_depth = 1 class VisitNodesView(SimpleTraversalView): pass class VisitEdgesView(SimpleTraversalView): def configure_request(self): self.traversal_request.mask.paths.extend(["successor", "successor.swhid"]) # self.traversal_request.return_fields.successor = True async def stream_response(self): async for node in self.rpc_client.Traverse(self.traversal_request): for succ in node.successor: await self.stream_line(node.swhid + " " + succ.swhid) class CountView(GraphView): """Base class for counting views.""" count_type: Optional[str] = None async def get(self): src = self.request.match_info["src"] self.traversal_request = TraversalRequest( src=[src], edges=self.get_edges(), direction=self.get_direction(), return_nodes=NodeFilter(types=self.get_return_types()), mask=FieldMask(paths=["swhid"]), ) if self.get_max_edges(): self.traversal_request.max_edges = self.get_max_edges() self.configure_request() res = await self.rpc_client.CountNodes(self.traversal_request) return aiohttp.web.Response( body=str(res.count), content_type="application/json" ) def configure_request(self): pass class CountNeighborsView(CountView): def configure_request(self): self.traversal_request.min_depth = 1 self.traversal_request.max_depth = 1 class CountLeavesView(CountView): def configure_request(self): self.traversal_request.return_nodes.max_traversal_successors = 0 class CountVisitNodesView(CountView): pass def 


def make_app(config=None, rpc_url=None, spawn_rpc_port=50091, **kwargs):
    app = GraphServerApp(**kwargs)

    if rpc_url is None:
        app["local_server"], port = spawn_java_rpc_server(config, port=spawn_rpc_port)
        rpc_url = f"localhost:{port}"

    app.add_routes(
        [
            aiohttp.web.get("/", index),
            aiohttp.web.get("/graph", index),
            aiohttp.web.view("/graph/stats", StatsView),
            aiohttp.web.view("/graph/leaves/{src}", LeavesView),
            aiohttp.web.view("/graph/neighbors/{src}", NeighborsView),
            aiohttp.web.view("/graph/visit/nodes/{src}", VisitNodesView),
            aiohttp.web.view("/graph/visit/edges/{src}", VisitEdgesView),
            aiohttp.web.view("/graph/neighbors/count/{src}", CountNeighborsView),
            aiohttp.web.view("/graph/leaves/count/{src}", CountLeavesView),
            aiohttp.web.view("/graph/visit/nodes/count/{src}", CountVisitNodesView),
        ]
    )

    app["rpc_url"] = rpc_url
    return app


def make_app_from_configfile():
    """Load configuration and then build application to run"""
    config_file = os.environ.get("SWH_CONFIG_FILENAME")
    config = config_read(config_file)
    return make_app(config=config)

diff --git a/swh/graph/rpc_server.py b/swh/graph/rpc_server.py
index 540fc5d..f6e1b4b 100644
--- a/swh/graph/rpc_server.py
+++ b/swh/graph/rpc_server.py
@@ -1,47 +1,48 @@
# Copyright (C) 2021 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information

"""
A simple tool to start the swh-graph GRPC server in Java.
"""

import logging
+import shlex
import subprocess

import aiohttp.test_utils
import aiohttp.web

from swh.graph.config import check_config


def spawn_java_rpc_server(config, port=None):
    if port is None:
        port = aiohttp.test_utils.unused_port()
    config = check_config(config or {})
    cmd = [
        "java",
        "-cp",
        config["classpath"],
        *config["java_tool_options"].split(),
        "org.softwareheritage.graph.rpc.GraphServer",
        "--port",
        str(port),
        str(config["graph"]["path"]),
    ]
    print(cmd)
    # XXX: shlex.join() is in 3.8
    # logging.info("Starting RPC server: %s", shlex.join(cmd))
-    logging.info("Starting RPC server: %s", str(cmd))
+    logging.info("Starting RPC server: %s", " ".join(shlex.quote(x) for x in cmd))
    server = subprocess.Popen(cmd)
    return server, port


def stop_java_rpc_server(server: subprocess.Popen, timeout: int = 15):
    server.terminate()
    try:
        server.wait(timeout=timeout)
    except subprocess.TimeoutExpired:
        logging.warning("Server did not terminate, sending kill signal...")
        server.kill()
diff --git a/swh/graph/tests/conftest.py b/swh/graph/tests/conftest.py
index 3d86602..6e832af 100644
--- a/swh/graph/tests/conftest.py
+++ b/swh/graph/tests/conftest.py
@@ -1,70 +1,107 @@
-# Copyright (C) 2019-2021 The Software Heritage developers
+# Copyright (C) 2019-2022 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information

import multiprocessing
from pathlib import Path
import subprocess

from aiohttp.test_utils import TestClient, TestServer, loop_context
+import grpc
import pytest

from swh.graph.http_client import RemoteGraphClient
from swh.graph.http_naive_client import NaiveClient
+from swh.graph.rpc.swhgraph_pb2_grpc import TraversalServiceStub

SWH_GRAPH_TESTS_ROOT = Path(__file__).parents[0]
TEST_GRAPH_PATH = SWH_GRAPH_TESTS_ROOT / "dataset/compressed/example"


class GraphServerProcess(multiprocessing.Process):
-    def __init__(self, q, *args, **kwargs):
-        self.q = q
+    def __init__(self, *args, **kwargs):
+        self.q = multiprocessing.Queue()
        super().__init__(*args, **kwargs)

    def run(self):
        # Lazy import to allow debian packaging
        from swh.graph.http_server import make_app

        try:
            config = {"graph": {"path": TEST_GRAPH_PATH}}
            with loop_context() as loop:
                app = make_app(config=config, debug=True, spawn_rpc_port=None)
                client = TestClient(TestServer(app), loop=loop)
                loop.run_until_complete(client.start_server())
                url = client.make_url("/graph/")
-                self.q.put(url)
+                self.q.put(
+                    {
+                        "server_url": url,
+                        "rpc_url": app["rpc_url"],
+                        "pid": app["local_server"].pid,
+                    }
+                )
                loop.run_forever()
        except Exception as e:
            self.q.put(e)
+
+    def start(self, *args, **kwargs):
+        super().start()
+        self.result = self.q.get()
+
+
+@pytest.fixture(scope="module")
+def graph_grpc_server_process():
+    server = GraphServerProcess()
+
+    yield server
+
+    server.kill()
+
+
+@pytest.fixture(scope="module")
+def graph_grpc_server(graph_grpc_server_process):
+    server = graph_grpc_server_process
+    server.start()
+    if isinstance(server.result, Exception):
+        raise server.result
+    grpc_url = server.result["rpc_url"]
+    yield grpc_url
+    server.kill()
+
+
+@pytest.fixture(scope="module")
+def graph_grpc_stub(graph_grpc_server):
+    with grpc.insecure_channel(graph_grpc_server) as channel:
+        stub = TraversalServiceStub(channel)
+        yield stub
+

@pytest.fixture(scope="module", params=["remote", "naive"])
def graph_client(request):
    if request.param == "remote":
-        queue = multiprocessing.Queue()
-        server = GraphServerProcess(queue)
+        server = request.getfixturevalue("graph_grpc_server_process")
        server.start()
-        res = queue.get()
-        if isinstance(res, Exception):
-            raise res
-        yield RemoteGraphClient(str(res))
-        server.terminate()
+        if isinstance(server.result, Exception):
+            raise server.result
+        yield RemoteGraphClient(str(server.result["server_url"]))
+        server.kill()
    else:

        def zstdcat(*files):
            p = subprocess.run(["zstdcat", *files], stdout=subprocess.PIPE)
            return p.stdout.decode()

        edges_dataset = SWH_GRAPH_TESTS_ROOT / "dataset/edges"
        edge_files = edges_dataset.glob("*/*.edges.csv.zst")
        node_files = edges_dataset.glob("*/*.nodes.csv.zst")

        nodes = set(zstdcat(*node_files).strip().split("\n"))
        edge_lines = [line.split() for line in zstdcat(*edge_files).strip().split("\n")]
        edges = [(src, dst) for src, dst, *_ in edge_lines]
        for src, dst in edges:
            nodes.add(src)
            nodes.add(dst)

        yield NaiveClient(nodes=list(nodes), edges=edges)
diff --git a/swh/graph/tests/test_grpc.py b/swh/graph/tests/test_grpc.py
new file mode 100644
index 0000000..2cef192
--- /dev/null
+++ b/swh/graph/tests/test_grpc.py
@@ -0,0 +1,129 @@
+# Copyright (c) 2022 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+import hashlib
+
+from google.protobuf.field_mask_pb2 import FieldMask
+
+from swh.graph.rpc.swhgraph_pb2 import (
+    GraphDirection,
+    NodeFilter,
+    StatsRequest,
+    TraversalRequest,
+)
+
+TEST_ORIGIN_ID = "swh:1:ori:{}".format(
+    hashlib.sha1(b"https://example.com/swh/graph").hexdigest()
+)
+
+
+def test_stats(graph_grpc_stub):
+    stats = graph_grpc_stub.Stats(StatsRequest())
+    assert stats.num_nodes == 21
+    assert stats.num_edges == 23
+    assert isinstance(stats.compression_ratio, float)
+    assert isinstance(stats.bits_per_node, float)
+    assert isinstance(stats.bits_per_edge, float)
+    assert isinstance(stats.avg_locality, float)
+    assert stats.indegree_min == 0
+    assert stats.indegree_max == 3
+    assert isinstance(stats.indegree_avg, float)
+    assert stats.outdegree_min == 0
+    assert stats.outdegree_max == 3
+    assert isinstance(stats.outdegree_avg, float)
+
+
+def test_leaves(graph_grpc_stub):
+    request = graph_grpc_stub.Traverse(
+        TraversalRequest(
+            src=[TEST_ORIGIN_ID],
+            mask=FieldMask(paths=["swhid"]),
+            return_nodes=NodeFilter(types="cnt"),
+        )
+    )
+    actual = [node.swhid for node in request]
+    expected = [
+        "swh:1:cnt:0000000000000000000000000000000000000001",
+        "swh:1:cnt:0000000000000000000000000000000000000004",
+        "swh:1:cnt:0000000000000000000000000000000000000005",
+        "swh:1:cnt:0000000000000000000000000000000000000007",
+    ]
+    assert set(actual) == set(expected)
+
+
+def test_neighbors(graph_grpc_stub):
+    request = graph_grpc_stub.Traverse(
+        TraversalRequest(
+            src=["swh:1:rev:0000000000000000000000000000000000000009"],
+            direction=GraphDirection.BACKWARD,
+            mask=FieldMask(paths=["swhid"]),
+            min_depth=1,
+            max_depth=1,
+        )
+    )
+    actual = [node.swhid for node in request]
+    expected = [
+        "swh:1:snp:0000000000000000000000000000000000000020",
+        "swh:1:rel:0000000000000000000000000000000000000010",
+        "swh:1:rev:0000000000000000000000000000000000000013",
+    ]
+    assert set(actual) == set(expected)
+
+
+def test_visit_nodes(graph_grpc_stub):
+    request = graph_grpc_stub.Traverse(
+        TraversalRequest(
+            src=["swh:1:rel:0000000000000000000000000000000000000010"],
+            mask=FieldMask(paths=["swhid"]),
+            edges="rel:rev,rev:rev",
+        )
+    )
+    actual = [node.swhid for node in request]
+    expected = [
+        "swh:1:rel:0000000000000000000000000000000000000010",
+        "swh:1:rev:0000000000000000000000000000000000000009",
+        "swh:1:rev:0000000000000000000000000000000000000003",
+    ]
+    assert set(actual) == set(expected)
+
+
+def test_visit_nodes_filtered(graph_grpc_stub):
+    request = graph_grpc_stub.Traverse(
+        TraversalRequest(
+            src=["swh:1:rel:0000000000000000000000000000000000000010"],
+            mask=FieldMask(paths=["swhid"]),
+            return_nodes=NodeFilter(types="dir"),
+        )
+    )
+    actual = [node.swhid for node in request]
+    expected = [
+        "swh:1:dir:0000000000000000000000000000000000000002",
+        "swh:1:dir:0000000000000000000000000000000000000008",
+        "swh:1:dir:0000000000000000000000000000000000000006",
+    ]
+    assert set(actual) == set(expected)
+
+
+def test_visit_nodes_filtered_star(graph_grpc_stub):
+    request = graph_grpc_stub.Traverse(
+        TraversalRequest(
+            src=["swh:1:rel:0000000000000000000000000000000000000010"],
+            mask=FieldMask(paths=["swhid"]),
+        )
+    )
+    actual = [node.swhid for node in request]
+    expected = [
+        "swh:1:rel:0000000000000000000000000000000000000010",
+        "swh:1:rev:0000000000000000000000000000000000000009",
+        "swh:1:rev:0000000000000000000000000000000000000003",
+        "swh:1:dir:0000000000000000000000000000000000000002",
+        "swh:1:cnt:0000000000000000000000000000000000000001",
+        "swh:1:dir:0000000000000000000000000000000000000008",
+        "swh:1:cnt:0000000000000000000000000000000000000007",
+        "swh:1:dir:0000000000000000000000000000000000000006",
+        "swh:1:cnt:0000000000000000000000000000000000000004",
+        "swh:1:cnt:0000000000000000000000000000000000000005",
+    ]
+    assert set(actual) == set(expected)
diff --git a/swh/graph/tests/test_http_server_down.py b/swh/graph/tests/test_http_server_down.py
new file mode 100644
index 0000000..d6cb3fb
--- /dev/null
+++ b/swh/graph/tests/test_http_server_down.py
@@ -0,0 +1,38 @@
+# Copyright (C) 2022 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+import os
+import signal
+
+import pytest
+
+from swh.core.api import TransientRemoteException
+from swh.graph.http_client import RemoteGraphClient
+from swh.graph.http_naive_client import NaiveClient
+
+from .test_http_client import TEST_ORIGIN_ID
+
+
+def test_leaves(graph_client, graph_grpc_server_process):
+    if isinstance(graph_client, RemoteGraphClient):
+        pass
+    elif isinstance(graph_client, NaiveClient):
+        pytest.skip("test irrelevant for naive graph client")
+    else:
+        assert False, f"unexpected graph_client class: {graph_client.__class__}"
+
+    list(graph_client.leaves(TEST_ORIGIN_ID))
+
+    server = graph_grpc_server_process
+    pid = server.result["pid"]
+    os.kill(pid, signal.SIGKILL)
+    try:
+        os.waitpid(pid, os.WNOHANG)
+    except ChildProcessError:
+        pass
+
+    it = graph_client.leaves(TEST_ORIGIN_ID)
+    with pytest.raises(TransientRemoteException, match="failed to connect"):
+        list(it)
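
Usage note: the gRPC TraversalService exercised by these tests can also be queried directly with the synchronous stub, the same way the graph_grpc_stub fixture does. The following is a minimal sketch, not part of the patch; it assumes a GraphServer is already listening on localhost:50091 (the default spawn_rpc_port of make_app). In real code the address should come from app["rpc_url"] or from spawn_java_rpc_server().

# Sketch only: direct gRPC queries against a running GraphServer.
# The address "localhost:50091" below is an assumption; substitute the
# rpc_url actually exposed by make_app() / spawn_java_rpc_server().
import hashlib

import grpc
from google.protobuf.field_mask_pb2 import FieldMask

from swh.graph.rpc.swhgraph_pb2 import NodeFilter, StatsRequest, TraversalRequest
from swh.graph.rpc.swhgraph_pb2_grpc import TraversalServiceStub

# Same test origin SWHID construction as in test_grpc.py.
TEST_ORIGIN_ID = "swh:1:ori:{}".format(
    hashlib.sha1(b"https://example.com/swh/graph").hexdigest()
)

with grpc.insecure_channel("localhost:50091") as channel:
    stub = TraversalServiceStub(channel)

    # Graph-wide statistics, as checked by test_stats().
    stats = stub.Stats(StatsRequest())
    print(stats.num_nodes, stats.num_edges)

    # Stream the content leaves reachable from the test origin, as in test_leaves().
    response = stub.Traverse(
        TraversalRequest(
            src=[TEST_ORIGIN_ID],
            mask=FieldMask(paths=["swhid"]),
            return_nodes=NodeFilter(types="cnt"),
        )
    )
    for node in response:
        print(node.swhid)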