diff --git a/pytest.ini b/pytest.ini
--- a/pytest.ini
+++ b/pytest.ini
@@ -2,3 +2,7 @@
 norecursedirs = docs .*
 
 postgresql_postgres_options = -N 500
+
+markers =
+    kafka
+    grpc
diff --git a/swh/provenance/tests/test_archive_interface.py b/swh/provenance/tests/test_archive_interface.py
--- a/swh/provenance/tests/test_archive_interface.py
+++ b/swh/provenance/tests/test_archive_interface.py
@@ -229,6 +229,7 @@
     check_snapshot_get_heads(archive, archive_direct, data)
 
 
+@pytest.mark.grpc
 @pytest.mark.parametrize(
     "repo",
     ("cmdbts2", "out-of-order", "with-merges"),
@@ -246,6 +247,7 @@
     check_snapshot_get_heads(archive, archive_graph, data)
 
 
+@pytest.mark.grpc
 @pytest.mark.parametrize(
     "repo",
     ("cmdbts2", "out-of-order", "with-merges"),
diff --git a/swh/provenance/tests/utils.py b/swh/provenance/tests/utils.py
--- a/swh/provenance/tests/utils.py
+++ b/swh/provenance/tests/utils.py
@@ -7,24 +7,26 @@
 from contextlib import contextmanager
 from datetime import datetime
 import logging
-import multiprocessing
 from os import path
 from pathlib import Path
+import socket
 import tempfile
+import time
 from typing import Any, Dict, List, Optional
 
-from aiohttp.test_utils import TestClient, TestServer, loop_context
 from click.testing import CliRunner, Result
 import msgpack
 from yaml import safe_dump
 
-from swh.graph.http_rpc_server import make_app
+from swh.graph.grpc_server import spawn_java_grpc_server, stop_java_grpc_server
 from swh.journal.serializers import msgpack_ext_hook
 from swh.model.model import BaseModel, TimestampWithTimezone
 from swh.provenance.cli import cli
 from swh.storage.interface import StorageInterface
 from swh.storage.replay import OBJECT_CONVERTERS, OBJECT_FIXERS, process_replay_objects
 
+logger = logging.getLogger(__name__)
+
 
 def invoke(
     args: List[str], config: Optional[Dict] = None, catch_exceptions: bool = False
@@ -88,41 +90,30 @@
     return TimestampWithTimezone.from_dict(ts).to_datetime()
 
 
-def run_grpc_server(queue, dataset_path):
-    try:
-        config = {
-            "graph": {
-                "cls": "local",
-                "grpc_server": {"path": dataset_path},
-                "http_rpc_server": {"debug": True},
-            }
-        }
-        with loop_context() as loop:
-            app = make_app(config=config)
-            client = TestClient(TestServer(app), loop=loop)
-            loop.run_until_complete(client.start_server())
-            url = client.make_url("/graph/")
-            queue.put((url, app["rpc_url"]))
-            loop.run_forever()
-    except Exception as e:
-        queue.put(e)
-
-
 @contextmanager
 def grpc_server(dataset):
+    """Run a Java gRPC server over the compressed test graph ``dataset``.
+
+    Polls the server's port (50 attempts, 0.1 s apart) until it accepts TCP
+    connections, then yields its address as ``"localhost:<port>"``. The
+    server is stopped on exit, including when the wait gives up.
+    """
     dataset_path = (
         Path(__file__).parents[0] / "data/swhgraph" / dataset / "compressed/example"
     )
-    queue = multiprocessing.Queue()
-    server = multiprocessing.Process(
-        target=run_grpc_server, kwargs={"queue": queue, "dataset_path": dataset_path}
-    )
-    server.start()
-    res = queue.get()
-    if isinstance(res, Exception):
-        raise res
-    grpc_url = res[1]
+    server, port = spawn_java_grpc_server(path=dataset_path)
+    logger.debug("Spawned GRPC server on port %s", port)
     try:
-        yield grpc_url
+        logger.debug("Waiting for the TCP socket localhost:%s...", port)
+        for _ in range(50):
+            # Fresh socket per attempt: after a failed connect() the socket
+            # state is unspecified (POSIX), so it must not be reused.
+            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
+                if sock.connect_ex(("localhost", port)) == 0:
+                    break
+            time.sleep(0.1)
+        else:
+            raise EnvironmentError(
+                f"Cannot connect to the GRPC server on localhost:{port}"
+            )
+        logger.debug("Connection to localhost:%s OK", port)
+        yield f"localhost:{port}"
     finally:
-        server.terminate()
+        stop_java_grpc_server(server)