diff --git a/requirements-swh.txt b/requirements-swh.txt
index 840225b..1865990 100644
--- a/requirements-swh.txt
+++ b/requirements-swh.txt
@@ -1,6 +1,6 @@
 # Add here internal Software Heritage dependencies, one per line.
 swh.core[db,http] >= 0.14
 swh.model >= 2.6.1
 swh.storage
-swh.graph
+swh.graph >= 2.0.0
 swh.journal
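The version floor on swh.graph is what makes the rest of this patch possible: the gRPC stubs imported below only ship from swh.graph 2.0.0 onward. As a minimal sketch (assuming the `packaging` library is available), an entry point that wants to fail fast rather than die on a late ImportError could check the installed version explicitly:

    from importlib.metadata import version  # Python >= 3.8

    from packaging.version import Version

    # Refuse to run against a pre-2.0 swh.graph, whose gRPC stubs
    # live under a different module path (swh.graph.rpc).
    if Version(version("swh.graph")) < Version("2.0.0"):
        raise RuntimeError("swh.graph >= 2.0.0 is required")
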
diff --git a/swh/provenance/swhgraph/archive.py b/swh/provenance/swhgraph/archive.py
index e0cce34..3a638fa 100644
--- a/swh/provenance/swhgraph/archive.py
+++ b/swh/provenance/swhgraph/archive.py
@@ -1,80 +1,80 @@
 # Copyright (C) 2022 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information

 from typing import Any, Dict, Iterable, Tuple

 from google.protobuf.field_mask_pb2 import FieldMask
 import grpc

 from swh.core.statsd import statsd
-from swh.graph.rpc import swhgraph_pb2, swhgraph_pb2_grpc
+from swh.graph.grpc import swhgraph_pb2, swhgraph_pb2_grpc
 from swh.model.model import Sha1Git
 from swh.model.swhids import CoreSWHID, ObjectType
 from swh.storage.interface import StorageInterface

 ARCHIVE_DURATION_METRIC = "swh_provenance_archive_graph_duration_seconds"


 class ArchiveGraph:
     def __init__(self, url, storage: StorageInterface) -> None:
         self.graph_url = url
         self._channel = grpc.insecure_channel(self.graph_url)
         self._stub = swhgraph_pb2_grpc.TraversalServiceStub(self._channel)
         self.storage = storage  # required by ArchiveInterface

     @statsd.timed(metric=ARCHIVE_DURATION_METRIC, tags={"method": "directory_ls"})
     def directory_ls(self, id: Sha1Git, minsize: int = 0) -> Iterable[Dict[str, Any]]:
         raise NotImplementedError

     @statsd.timed(
         metric=ARCHIVE_DURATION_METRIC,
         tags={"method": "revision_get_some_outbound_edges"},
     )
     def revision_get_some_outbound_edges(
         self, revision_id: Sha1Git
     ) -> Iterable[Tuple[Sha1Git, Sha1Git]]:
         src = str(CoreSWHID(object_type=ObjectType.REVISION, object_id=revision_id))
         request = self._stub.Traverse(
             swhgraph_pb2.TraversalRequest(
                 src=[src],
                 edges="rev:rev",
                 max_edges=1000,
                 mask=FieldMask(paths=["swhid", "successor"]),
             )
         )
         try:
             for node in request:
                 obj_id = CoreSWHID.from_string(node.swhid).object_id
                 if node.successor:
                     for parent in node.successor:
                         yield (obj_id, CoreSWHID.from_string(parent.swhid).object_id)
         except grpc.RpcError as e:
             if (
                 e.code() == grpc.StatusCode.INVALID_ARGUMENT
                 and "Unknown SWHID" in e.details()
             ):
                 pass
             raise

     @statsd.timed(metric=ARCHIVE_DURATION_METRIC, tags={"method": "snapshot_get_heads"})
     def snapshot_get_heads(self, id: Sha1Git) -> Iterable[Sha1Git]:
         src = str(CoreSWHID(object_type=ObjectType.SNAPSHOT, object_id=id))
         request = self._stub.Traverse(
             swhgraph_pb2.TraversalRequest(
                 src=[src],
                 edges="snp:rev,snp:rel,rel:rev",
                 return_nodes=swhgraph_pb2.NodeFilter(types="rev"),
                 mask=FieldMask(paths=["swhid"]),
             )
         )
         try:
             yield from (CoreSWHID.from_string(node.swhid).object_id for node in request)
         except grpc.RpcError as e:
             if (
                 e.code() == grpc.StatusCode.INVALID_ARGUMENT
                 and "Unknown SWHID" in e.details()
             ):
                 pass
             raise
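The only change to this module is the import path: swh.graph 2.x renamed the module holding the generated protobuf stubs from swh.graph.rpc to swh.graph.grpc, and the client code (channel, TraversalServiceStub, TraversalRequest) is otherwise untouched. Downstream code that must straddle both major versions could use a fallback import; this is a hedging sketch, not something this patch itself does:

    try:
        # swh.graph >= 2.0.0: generated stubs live under swh.graph.grpc
        from swh.graph.grpc import swhgraph_pb2, swhgraph_pb2_grpc
    except ImportError:
        # swh.graph 1.x: the same stubs under the old swh.graph.rpc path
        from swh.graph.rpc import swhgraph_pb2, swhgraph_pb2_grpc
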
ArchiveInterface object""" return ArchiveStorage(swh_storage) def fill_storage(storage: StorageInterface, data: Dict[str, List[dict]]) -> None: objects = { objtype: [objs_from_dict(objtype, d) for d in dicts] for objtype, dicts in data.items() } process_replay_objects(objects, storage=storage) def get_datafile(fname: str) -> str: return path.join(path.dirname(__file__), "data", fname) # TODO: this should return Dict[str, List[BaseModel]] directly, but it requires # refactoring several tests def load_repo_data(repo: str) -> Dict[str, List[dict]]: data: Dict[str, List[dict]] = {} with open(get_datafile(f"{repo}.msgpack"), "rb") as fobj: unpacker = msgpack.Unpacker( fobj, raw=False, ext_hook=msgpack_ext_hook, strict_map_key=False, timestamp=3, # convert Timestamp in datetime objects (tz UTC) ) for msg in unpacker: if len(msg) == 2: # old format objtype, objd = msg else: # now we should have a triplet (type, key, value) objtype, _, objd = msg data.setdefault(objtype, []).append(objd) return data def objs_from_dict(object_type: str, dict_repr: dict) -> BaseModel: if object_type in OBJECT_FIXERS: dict_repr = OBJECT_FIXERS[object_type](dict_repr) obj = OBJECT_CONVERTERS[object_type](dict_repr) return obj def ts2dt(ts: Dict[str, Any]) -> datetime: return TimestampWithTimezone.from_dict(ts).to_datetime() def run_grpc_server(queue, dataset_path): try: - config = {"graph": {"path": dataset_path}} + config = { + "graph": { + "cls": "local", + "grpc_server": {"path": dataset_path}, + "http_rpc_server": {"debug": True}, + } + } with loop_context() as loop: - app = make_app(config=config, debug=True, spawn_rpc_port=None) + app = make_app(config=config) client = TestClient(TestServer(app), loop=loop) loop.run_until_complete(client.start_server()) url = client.make_url("/graph/") queue.put((url, app["rpc_url"])) loop.run_forever() except Exception as e: queue.put(e) @contextmanager def grpc_server(dataset): dataset_path = ( Path(__file__).parents[0] / "data/swhgraph" / dataset / "compressed/example" ) queue = multiprocessing.Queue() server = multiprocessing.Process( target=run_grpc_server, kwargs={"queue": queue, "dataset_path": dataset_path} ) server.start() res = queue.get() if isinstance(res, Exception): raise res grpc_url = res[1] try: yield grpc_url finally: server.terminate()
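Besides the module rename (http_server became http_rpc_server), the notable change here is the config schema: make_app no longer takes debug or spawn_rpc_port keyword arguments; the graph section instead declares cls: local and nests per-server settings under grpc_server and http_rpc_server. For context, tests consume the grpc_server context manager roughly as in the sketch below; the dataset name and the all-zeros snapshot id are illustrative placeholders, not values taken from this patch:

    from swh.provenance.swhgraph.archive import ArchiveGraph
    from swh.provenance.tests.conftest import grpc_server

    def test_snapshot_get_heads(swh_storage):
        # "cmdbts2" stands in for one of the compressed datasets
        # shipped under swh/provenance/tests/data/swhgraph
        with grpc_server("cmdbts2") as url:
            archive = ArchiveGraph(url, storage=swh_storage)
            # a real test would use a snapshot id present in the dataset;
            # this all-zeros value is only a stand-in for the example
            snapshot_id = bytes.fromhex("00" * 20)
            heads = list(archive.snapshot_get_heads(snapshot_id))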