diff --git a/swh/graph/tests/conftest.py b/swh/graph/tests/conftest.py index 3d86602..a91c6e1 100644 --- a/swh/graph/tests/conftest.py +++ b/swh/graph/tests/conftest.py @@ -1,70 +1,92 @@ # Copyright (C) 2019-2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import multiprocessing from pathlib import Path import subprocess from aiohttp.test_utils import TestClient, TestServer, loop_context +import grpc import pytest from swh.graph.http_client import RemoteGraphClient from swh.graph.http_naive_client import NaiveClient +from swh.graph.rpc.swhgraph_pb2_grpc import TraversalServiceStub SWH_GRAPH_TESTS_ROOT = Path(__file__).parents[0] TEST_GRAPH_PATH = SWH_GRAPH_TESTS_ROOT / "dataset/compressed/example" class GraphServerProcess(multiprocessing.Process): def __init__(self, q, *args, **kwargs): self.q = q super().__init__(*args, **kwargs) def run(self): # Lazy import to allow debian packaging from swh.graph.http_server import make_app try: config = {"graph": {"path": TEST_GRAPH_PATH}} with loop_context() as loop: app = make_app(config=config, debug=True, spawn_rpc_port=None) client = TestClient(TestServer(app), loop=loop) loop.run_until_complete(client.start_server()) url = client.make_url("/graph/") - self.q.put(url) + self.q.put((url, app["rpc_url"])) loop.run_forever() except Exception as e: self.q.put(e) +@pytest.fixture(scope="module") +def graph_grpc_server(): + queue = multiprocessing.Queue() + server = GraphServerProcess(queue) + server.start() + res = queue.get() + if isinstance(res, Exception): + raise res + grpc_url = res[1] + yield grpc_url + server.terminate() + + +@pytest.fixture(scope="module") +def graph_grpc_stub(graph_grpc_server): + with grpc.insecure_channel(graph_grpc_server) as channel: + stub = TraversalServiceStub(channel) + yield stub + + @pytest.fixture(scope="module", params=["remote", "naive"]) def graph_client(request): if request.param == "remote": queue = multiprocessing.Queue() server = GraphServerProcess(queue) server.start() res = queue.get() if isinstance(res, Exception): raise res - yield RemoteGraphClient(str(res)) + yield RemoteGraphClient(str(res[0])) server.terminate() else: def zstdcat(*files): p = subprocess.run(["zstdcat", *files], stdout=subprocess.PIPE) return p.stdout.decode() edges_dataset = SWH_GRAPH_TESTS_ROOT / "dataset/edges" edge_files = edges_dataset.glob("*/*.edges.csv.zst") node_files = edges_dataset.glob("*/*.nodes.csv.zst") nodes = set(zstdcat(*node_files).strip().split("\n")) edge_lines = [line.split() for line in zstdcat(*edge_files).strip().split("\n")] edges = [(src, dst) for src, dst, *_ in edge_lines] for src, dst in edges: nodes.add(src) nodes.add(dst) yield NaiveClient(nodes=list(nodes), edges=edges) diff --git a/swh/graph/tests/test_grpc.py b/swh/graph/tests/test_grpc.py new file mode 100644 index 0000000..2cef192 --- /dev/null +++ b/swh/graph/tests/test_grpc.py @@ -0,0 +1,129 @@ +# Copyright (c) 2022 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + +import hashlib + +from google.protobuf.field_mask_pb2 import FieldMask + +from swh.graph.rpc.swhgraph_pb2 import ( + GraphDirection, + NodeFilter, + StatsRequest, + TraversalRequest, +) + +TEST_ORIGIN_ID = "swh:1:ori:{}".format( + hashlib.sha1(b"https://example.com/swh/graph").hexdigest() +) + + +def test_stats(graph_grpc_stub): + stats = graph_grpc_stub.Stats(StatsRequest()) + assert stats.num_nodes == 21 + assert stats.num_edges == 23 + assert isinstance(stats.compression_ratio, float) + assert isinstance(stats.bits_per_node, float) + assert isinstance(stats.bits_per_edge, float) + assert isinstance(stats.avg_locality, float) + assert stats.indegree_min == 0 + assert stats.indegree_max == 3 + assert isinstance(stats.indegree_avg, float) + assert stats.outdegree_min == 0 + assert stats.outdegree_max == 3 + assert isinstance(stats.outdegree_avg, float) + + +def test_leaves(graph_grpc_stub): + request = graph_grpc_stub.Traverse( + TraversalRequest( + src=[TEST_ORIGIN_ID], + mask=FieldMask(paths=["swhid"]), + return_nodes=NodeFilter(types="cnt"), + ) + ) + actual = [node.swhid for node in request] + expected = [ + "swh:1:cnt:0000000000000000000000000000000000000001", + "swh:1:cnt:0000000000000000000000000000000000000004", + "swh:1:cnt:0000000000000000000000000000000000000005", + "swh:1:cnt:0000000000000000000000000000000000000007", + ] + assert set(actual) == set(expected) + + +def test_neighbors(graph_grpc_stub): + request = graph_grpc_stub.Traverse( + TraversalRequest( + src=["swh:1:rev:0000000000000000000000000000000000000009"], + direction=GraphDirection.BACKWARD, + mask=FieldMask(paths=["swhid"]), + min_depth=1, + max_depth=1, + ) + ) + actual = [node.swhid for node in request] + expected = [ + "swh:1:snp:0000000000000000000000000000000000000020", + "swh:1:rel:0000000000000000000000000000000000000010", + "swh:1:rev:0000000000000000000000000000000000000013", + ] + assert set(actual) == set(expected) + + +def test_visit_nodes(graph_grpc_stub): + request = graph_grpc_stub.Traverse( + TraversalRequest( + src=["swh:1:rel:0000000000000000000000000000000000000010"], + mask=FieldMask(paths=["swhid"]), + edges="rel:rev,rev:rev", + ) + ) + actual = [node.swhid for node in request] + expected = [ + "swh:1:rel:0000000000000000000000000000000000000010", + "swh:1:rev:0000000000000000000000000000000000000009", + "swh:1:rev:0000000000000000000000000000000000000003", + ] + assert set(actual) == set(expected) + + +def test_visit_nodes_filtered(graph_grpc_stub): + request = graph_grpc_stub.Traverse( + TraversalRequest( + src=["swh:1:rel:0000000000000000000000000000000000000010"], + mask=FieldMask(paths=["swhid"]), + return_nodes=NodeFilter(types="dir"), + ) + ) + actual = [node.swhid for node in request] + expected = [ + "swh:1:dir:0000000000000000000000000000000000000002", + "swh:1:dir:0000000000000000000000000000000000000008", + "swh:1:dir:0000000000000000000000000000000000000006", + ] + assert set(actual) == set(expected) + + +def test_visit_nodes_filtered_star(graph_grpc_stub): + request = graph_grpc_stub.Traverse( + TraversalRequest( + src=["swh:1:rel:0000000000000000000000000000000000000010"], + mask=FieldMask(paths=["swhid"]), + ) + ) + actual = [node.swhid for node in request] + expected = [ + "swh:1:rel:0000000000000000000000000000000000000010", + "swh:1:rev:0000000000000000000000000000000000000009", + "swh:1:rev:0000000000000000000000000000000000000003", + "swh:1:dir:0000000000000000000000000000000000000002", + "swh:1:cnt:0000000000000000000000000000000000000001", + "swh:1:dir:0000000000000000000000000000000000000008", + "swh:1:cnt:0000000000000000000000000000000000000007", + "swh:1:dir:0000000000000000000000000000000000000006", + "swh:1:cnt:0000000000000000000000000000000000000004", + "swh:1:cnt:0000000000000000000000000000000000000005", + ] + assert set(actual) == set(expected)