diff --git a/swh/graph/tests/conftest.py b/swh/graph/tests/conftest.py index a91c6e1..f3f1306 100644 --- a/swh/graph/tests/conftest.py +++ b/swh/graph/tests/conftest.py @@ -1,92 +1,92 @@ -# Copyright (C) 2019-2021 The Software Heritage developers +# Copyright (C) 2019-2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import multiprocessing from pathlib import Path import subprocess from aiohttp.test_utils import TestClient, TestServer, loop_context import grpc import pytest from swh.graph.http_client import RemoteGraphClient from swh.graph.http_naive_client import NaiveClient from swh.graph.rpc.swhgraph_pb2_grpc import TraversalServiceStub SWH_GRAPH_TESTS_ROOT = Path(__file__).parents[0] TEST_GRAPH_PATH = SWH_GRAPH_TESTS_ROOT / "dataset/compressed/example" class GraphServerProcess(multiprocessing.Process): - def __init__(self, q, *args, **kwargs): - self.q = q + def __init__(self, *args, **kwargs): + self.q = multiprocessing.Queue() super().__init__(*args, **kwargs) def run(self): # Lazy import to allow debian packaging from swh.graph.http_server import make_app try: config = {"graph": {"path": TEST_GRAPH_PATH}} with loop_context() as loop: app = make_app(config=config, debug=True, spawn_rpc_port=None) client = TestClient(TestServer(app), loop=loop) loop.run_until_complete(client.start_server()) url = client.make_url("/graph/") - self.q.put((url, app["rpc_url"])) + self.q.put({"server_url": url, "rpc_url": app["rpc_url"]}) loop.run_forever() except Exception as e: self.q.put(e) + def start(self, *args, **kwargs): + super().start() + self.result = self.q.get() + @pytest.fixture(scope="module") def graph_grpc_server(): - queue = multiprocessing.Queue() - server = GraphServerProcess(queue) + server = GraphServerProcess() server.start() - res = queue.get() - if isinstance(res, Exception): - raise res - grpc_url = res[1] + if isinstance(server.result, Exception): + raise server.result + grpc_url = server.result["rpc_url"] yield grpc_url server.terminate() @pytest.fixture(scope="module") def graph_grpc_stub(graph_grpc_server): with grpc.insecure_channel(graph_grpc_server) as channel: stub = TraversalServiceStub(channel) yield stub @pytest.fixture(scope="module", params=["remote", "naive"]) def graph_client(request): if request.param == "remote": - queue = multiprocessing.Queue() - server = GraphServerProcess(queue) + server = GraphServerProcess() server.start() - res = queue.get() - if isinstance(res, Exception): - raise res - yield RemoteGraphClient(str(res[0])) + if isinstance(server.result, Exception): + raise server.result + yield RemoteGraphClient(str(server.result["server_url"])) server.terminate() else: def zstdcat(*files): p = subprocess.run(["zstdcat", *files], stdout=subprocess.PIPE) return p.stdout.decode() edges_dataset = SWH_GRAPH_TESTS_ROOT / "dataset/edges" edge_files = edges_dataset.glob("*/*.edges.csv.zst") node_files = edges_dataset.glob("*/*.nodes.csv.zst") nodes = set(zstdcat(*node_files).strip().split("\n")) edge_lines = [line.split() for line in zstdcat(*edge_files).strip().split("\n")] edges = [(src, dst) for src, dst, *_ in edge_lines] for src, dst in edges: nodes.add(src) nodes.add(dst) yield NaiveClient(nodes=list(nodes), edges=edges)