Changeset View
Changeset View
Standalone View
Standalone View
swh/provenance/tests/utils.py
# Copyright (C) 2022 The Software Heritage developers | # Copyright (C) 2022 The Software Heritage developers | ||||
# See the AUTHORS file at the top-level directory of this distribution | # See the AUTHORS file at the top-level directory of this distribution | ||||
# License: GNU General Public License version 3, or any later version | # License: GNU General Public License version 3, or any later version | ||||
# See top-level LICENSE file for more information | # See top-level LICENSE file for more information | ||||
from contextlib import contextmanager | from contextlib import contextmanager | ||||
from datetime import datetime | from datetime import datetime | ||||
import logging | import logging | ||||
import multiprocessing | |||||
from os import path | from os import path | ||||
from pathlib import Path | from pathlib import Path | ||||
import socket | |||||
import tempfile | import tempfile | ||||
import time | |||||
from typing import Any, Dict, List, Optional | from typing import Any, Dict, List, Optional | ||||
from aiohttp.test_utils import TestClient, TestServer, loop_context | |||||
from click.testing import CliRunner, Result | from click.testing import CliRunner, Result | ||||
import msgpack | import msgpack | ||||
from yaml import safe_dump | from yaml import safe_dump | ||||
from swh.graph.http_rpc_server import make_app | from swh.graph.grpc_server import spawn_java_grpc_server, stop_java_grpc_server | ||||
from swh.journal.serializers import msgpack_ext_hook | from swh.journal.serializers import msgpack_ext_hook | ||||
from swh.model.model import BaseModel, TimestampWithTimezone | from swh.model.model import BaseModel, TimestampWithTimezone | ||||
from swh.provenance.cli import cli | from swh.provenance.cli import cli | ||||
from swh.storage.interface import StorageInterface | from swh.storage.interface import StorageInterface | ||||
from swh.storage.replay import OBJECT_CONVERTERS, OBJECT_FIXERS, process_replay_objects | from swh.storage.replay import OBJECT_CONVERTERS, OBJECT_FIXERS, process_replay_objects | ||||
logger = logging.getLogger(__name__) | |||||
def invoke( | def invoke( | ||||
args: List[str], config: Optional[Dict] = None, catch_exceptions: bool = False | args: List[str], config: Optional[Dict] = None, catch_exceptions: bool = False | ||||
) -> Result: | ) -> Result: | ||||
"""Invoke swh journal subcommands""" | """Invoke swh journal subcommands""" | ||||
runner = CliRunner() | runner = CliRunner() | ||||
with tempfile.NamedTemporaryFile("a", suffix=".yml") as config_fd: | with tempfile.NamedTemporaryFile("a", suffix=".yml") as config_fd: | ||||
if config is not None: | if config is not None: | ||||
▲ Show 20 Lines • Show All 47 Lines • ▼ Show 20 Lines | def objs_from_dict(object_type: str, dict_repr: dict) -> BaseModel: | ||||
obj = OBJECT_CONVERTERS[object_type](dict_repr) | obj = OBJECT_CONVERTERS[object_type](dict_repr) | ||||
return obj | return obj | ||||
def ts2dt(ts: Dict[str, Any]) -> datetime:
    """Convert a timestamp dictionary into a ``datetime``.

    The dictionary is first turned into a ``TimestampWithTimezone`` through
    its ``from_dict`` constructor; the result of that object's
    ``to_datetime()`` is returned.
    """
    timestamp = TimestampWithTimezone.from_dict(ts)
    return timestamp.to_datetime()
def run_grpc_server(queue, dataset_path):
    """Start a graph RPC test server and report its URLs through ``queue``.

    An application is built with ``make_app`` from a "local" graph
    configuration whose GRPC server path is ``dataset_path``, then served
    through an aiohttp ``TestServer``/``TestClient`` pair.

    On success a ``(url, rpc_url)`` tuple is put on ``queue``, where ``url``
    is the client URL for ``/graph/`` and ``rpc_url`` is ``app["rpc_url"]``;
    the event loop is then run forever. If any ``Exception`` is raised along
    the way, the exception object itself is put on ``queue`` instead.

    Args:
        queue: object with a ``put`` method used to hand the result (or the
            failure) back to the caller.
        dataset_path: path passed as the ``grpc_server`` ``path`` setting.
    """
    # NOTE(review): nesting reconstructed from a flattened diff view --
    # confirm against the repository before applying.
    try:
        config = {
            "graph": {
                "cls": "local",
                "grpc_server": {"path": dataset_path},
                "http_rpc_server": {"debug": True},
            }
        }
        with loop_context() as loop:
            app = make_app(config=config)
            client = TestClient(TestServer(app), loop=loop)
            loop.run_until_complete(client.start_server())
            url = client.make_url("/graph/")
            # Publish the addresses before blocking in run_forever().
            queue.put((url, app["rpc_url"]))
            loop.run_forever()
    except Exception as e:
        # Forward the failure to whoever is reading the queue.
        queue.put(e)
@contextmanager
def grpc_server(dataset):
    """Spawn a Java GRPC graph server for a test dataset and yield its address.

    The compressed graph is looked up next to this file, under
    ``data/swhgraph/<dataset>/compressed/example``. Once the server process
    is spawned, its TCP port is polled (up to 50 attempts, 0.1 s apart)
    until a connection succeeds.

    Args:
        dataset: name of the dataset directory under ``data/swhgraph``.

    Yields:
        The server address as a ``"localhost:<port>"`` string.

    Raises:
        EnvironmentError: if no connection to the server port could be
            established after all attempts.

    The server is always stopped with ``stop_java_grpc_server`` on exit,
    including when the readiness check fails.
    """
    dataset_path = (
        Path(__file__).parents[0] / "data/swhgraph" / dataset / "compressed/example"
    )
    server, port = spawn_java_grpc_server(path=dataset_path)
    logger.debug("Spawned GRPC server on port %s", port)
    try:
        logger.debug("Waiting for the TCP socket localhost:%s...", port)
        for _ in range(50):
            # Use a fresh socket per attempt: the state of a socket after a
            # failed connect is unspecified, and the ``with`` block makes
            # sure it is closed whether or not the attempt succeeds.
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
                if sock.connect_ex(("localhost", port)) == 0:
                    break
            time.sleep(0.1)
        else:
            # Format the message eagerly: OSError given two positional
            # arguments would interpret them as (errno, strerror).
            raise EnvironmentError(
                f"Cannot connect to the GRPC server on localhost:{port}"
            )
        logger.debug("Connection to localhost:%s OK", port)
        yield f"localhost:{port}"
    finally:
        stop_java_grpc_server(server)