diff --git a/requirements-test.txt b/requirements-test.txt
--- a/requirements-test.txt
+++ b/requirements-test.txt
@@ -1,7 +1,8 @@
+grpcio
 pytest
 pytest-rabbitmq
 swh.loader.git >= 0.8
 swh.journal >= 0.8
 swh.storage >= 0.40
-swh.graph >= 0.3.2
+swh.graph[testing] >= 1.0.1
 types-Deprecated

diff --git a/swh/provenance/__init__.py b/swh/provenance/__init__.py
--- a/swh/provenance/__init__.py
+++ b/swh/provenance/__init__.py
@@ -43,13 +43,11 @@
     elif cls == "graph":
         try:
-            from swh.graph.client import RemoteGraphClient
             from swh.storage import get_storage

             from .swhgraph.archive import ArchiveGraph

-            graph = RemoteGraphClient(kwargs.get("url"))
-            return ArchiveGraph(graph, get_storage(**kwargs["storage"]))
+            return ArchiveGraph(kwargs.get("url"), get_storage(**kwargs["storage"]))
         except ModuleNotFoundError:
             raise EnvironmentError(

diff --git a/swh/provenance/swhgraph/archive.py b/swh/provenance/swhgraph/archive.py
--- a/swh/provenance/swhgraph/archive.py
+++ b/swh/provenance/swhgraph/archive.py
@@ -5,7 +5,11 @@

 from typing import Any, Dict, Iterable, Tuple

+from google.protobuf.field_mask_pb2 import FieldMask
+import grpc
+
 from swh.core.statsd import statsd
+from swh.graph.rpc import swhgraph_pb2, swhgraph_pb2_grpc
 from swh.model.model import Sha1Git
 from swh.model.swhids import CoreSWHID, ObjectType
 from swh.storage.interface import StorageInterface
@@ -14,8 +18,10 @@

 class ArchiveGraph:
-    def __init__(self, graph, storage: StorageInterface) -> None:
-        self.graph = graph
+    def __init__(self, url, storage: StorageInterface) -> None:
+        self.graph_url = url
+        self._channel = grpc.insecure_channel(self.graph_url)
+        self._stub = swhgraph_pb2_grpc.TraversalServiceStub(self._channel)

         self.storage = storage  # required by ArchiveInterface

     @statsd.timed(metric=ARCHIVE_DURATION_METRIC, tags={"method": "directory_ls"})
@@ -29,23 +35,46 @@
     def revision_get_some_outbound_edges(
         self, revision_id: Sha1Git
     ) -> Iterable[Tuple[Sha1Git, Sha1Git]]:
-        src = CoreSWHID(object_type=ObjectType.REVISION, object_id=revision_id)
-        request = self.graph.visit_edges(str(src), edges="rev:rev")
-
-        for edge in request:
-            if edge:
-                yield (
-                    CoreSWHID.from_string(edge[0]).object_id,
-                    CoreSWHID.from_string(edge[1]).object_id,
-                )
+        src = str(CoreSWHID(object_type=ObjectType.REVISION, object_id=revision_id))
+        request = self._stub.Traverse(
+            swhgraph_pb2.TraversalRequest(
+                src=[src],
+                edges="rev:rev",
+                max_edges=1000,
+                mask=FieldMask(paths=["swhid", "successor"]),
+            )
+        )
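+        # NOTE: Traverse() returns a server-streaming response: nodes are
+        # fetched lazily, so RPC errors (such as an unknown source SWHID)
+        # only surface while iterating, which is why the error handling
+        # below wraps the loop rather than the call itself.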
+        try:
+            for node in request:
+                obj_id = CoreSWHID.from_string(node.swhid).object_id
+                if node.successor:
+                    for parent in node.successor:
+                        yield (obj_id, CoreSWHID.from_string(parent.swhid).object_id)
+        except grpc.RpcError as e:
+            # a revision missing from the graph export is not an error:
+            # end the iteration instead of propagating the failure
+            if (
+                e.code() == grpc.StatusCode.INVALID_ARGUMENT
+                and "Unknown SWHID" in e.details()
+            ):
+                return
+            raise

     @statsd.timed(metric=ARCHIVE_DURATION_METRIC, tags={"method": "snapshot_get_heads"})
     def snapshot_get_heads(self, id: Sha1Git) -> Iterable[Sha1Git]:
-        src = CoreSWHID(object_type=ObjectType.SNAPSHOT, object_id=id)
-        request = self.graph.visit_nodes(
-            str(src), edges="snp:rev,snp:rel,rel:rev", return_types="rev"
-        )
-
-        yield from (
-            CoreSWHID.from_string(swhid).object_id for swhid in request if swhid
+        src = str(CoreSWHID(object_type=ObjectType.SNAPSHOT, object_id=id))
+        request = self._stub.Traverse(
+            swhgraph_pb2.TraversalRequest(
+                src=[src],
+                edges="snp:rev,snp:rel,rel:rev",
+                return_nodes=swhgraph_pb2.NodeFilter(types="rev"),
+                mask=FieldMask(paths=["swhid"]),
+            )
         )
+        try:
+            yield from (
+                CoreSWHID.from_string(node.swhid).object_id for node in request
+            )
+        except grpc.RpcError as e:
+            if (
+                e.code() == grpc.StatusCode.INVALID_ARGUMENT
+                and "Unknown SWHID" in e.details()
+            ):
+                return
+            raise

diff --git a/swh/provenance/tests/conftest.py b/swh/provenance/tests/conftest.py
--- a/swh/provenance/tests/conftest.py
+++ b/swh/provenance/tests/conftest.py
@@ -3,16 +3,21 @@
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information

+from contextlib import contextmanager
 from datetime import datetime
+import multiprocessing
 from os import path
+from pathlib import Path
 from typing import Any, Dict, Generator, List

 from _pytest.fixtures import SubRequest
+from aiohttp.test_utils import TestClient, TestServer, loop_context
 import msgpack
 import psycopg2.extensions
 import pytest
 from pytest_postgresql.factories import postgresql

+from swh.graph.http_server import make_app
 from swh.journal.serializers import msgpack_ext_hook
 from swh.model.model import BaseModel, TimestampWithTimezone
 from swh.provenance import get_provenance, get_provenance_storage
@@ -160,3 +165,37 @@

 def ts2dt(ts: Dict[str, Any]) -> datetime:
     return TimestampWithTimezone.from_dict(ts).to_datetime()
+
+
+def run_grpc_server(queue, dataset_path):
+    try:
+        config = {"graph": {"path": dataset_path}}
+        with loop_context() as loop:
+            app = make_app(config=config, debug=True, spawn_rpc_port=None)
+            client = TestClient(TestServer(app), loop=loop)
+            loop.run_until_complete(client.start_server())
+            url = client.make_url("/graph/")
+            queue.put((url, app["rpc_url"]))
+            loop.run_forever()
+    except Exception as e:
+        queue.put(e)
+
+
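+# NOTE: the graph server runs in a separate process (not a thread), since
+# run_grpc_server() blocks on its own event loop via run_forever(); any
+# exception raised during startup is shipped back through the queue and
+# re-raised in the test process below.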
+@contextmanager
+def grpc_server(dataset):
+    dataset_path = (
+        Path(__file__).parents[0] / "data/swhgraph" / dataset / "compressed/example"
+    )
+    queue = multiprocessing.Queue()
+    server = multiprocessing.Process(
+        target=run_grpc_server, kwargs={"queue": queue, "dataset_path": dataset_path}
+    )
+    server.start()
+    res = queue.get()
+    if isinstance(res, Exception):
+        raise res
+    grpc_url = res[1]
+    try:
+        yield grpc_url
+    finally:
+        server.terminate()

diff --git a/swh/provenance/tests/data/README.md b/swh/provenance/tests/data/README.md
--- a/swh/provenance/tests/data/README.md
+++ b/swh/provenance/tests/data/README.md
@@ -19,7 +19,10 @@
 - a set of synthetic files, named `synthetic_xxx_(lower|upper)_.txt`, describing
   the expected result in the provenance database if ingested with the flag
   `lower` set or not set, and the `mindepth` value (integer, most
-  often `1` or `2`).
+  often `1` or `2`),
+- a swh-graph compressed dataset (in the `swhgraph/` directory), used for
+  testing the ArchiveGraph backend.
+

 ### Generate datasets files

 For each dataset `xxx`, execute a number of commands:
@@ -29,6 +32,7 @@
     python generate_repo.py -C ${dataset}_repo.yaml $dataset > synthetic_${dataset}_template.txt
     # you may want to edit/update synthetic files from this template, see below
     python generate_storage_from_git.py $dataset
+    python generate_graph_dataset.py --compress $dataset
 done
 ```

diff --git a/swh/provenance/tests/data/generate_graph_dataset.py b/swh/provenance/tests/data/generate_graph_dataset.py
new file mode 100755
--- /dev/null
+++ b/swh/provenance/tests/data/generate_graph_dataset.py
@@ -0,0 +1,56 @@
+#!/usr/bin/env python3
+# Copyright (C) 2022 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+# type: ignore
+
+import argparse
+import logging
+from pathlib import Path
+import shutil
+
+from swh.dataset.exporters.edges import GraphEdgesExporter
+from swh.dataset.exporters.orc import ORCExporter
+from swh.graph.webgraph import compress
+from swh.provenance.tests.conftest import load_repo_data
+
+
+def main():
+    logging.basicConfig(level=logging.INFO)
+
+    parser = argparse.ArgumentParser(description="Generate a test dataset")
+    parser.add_argument(
+        "--compress",
+        action="store_true",
+        default=False,
+        help="Also compress the dataset",
+    )
+    parser.add_argument("--output", help="output directory", default="swhgraph")
+    parser.add_argument("dataset", help="dataset name", nargs="+")
+    args = parser.parse_args()
+
+    for repo in args.dataset:
+        exporters = {"edges": GraphEdgesExporter, "orc": ORCExporter}
+        config = {"test_unique_file_id": "all"}
+        output_path = Path(args.output) / repo
+        data = load_repo_data(repo)
+
+        for name, exporter in exporters.items():
+            if (output_path / name).exists():
+                shutil.rmtree(output_path / name)
+            with exporter(config, output_path / name) as e:
+                for object_type, objs in data.items():
+                    for obj_dict in objs:
+                        e.process_object(object_type, obj_dict)
+
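+        # Of the two exports above, only the ORC one is consumed by the
+        # compress() step below; the edges export is presumably kept for
+        # manual inspection of the generated test graph.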
+        if args.compress:
+            if (output_path / "compressed").exists():
+                shutil.rmtree(output_path / "compressed")
+            compress("example", output_path / "orc", output_path / "compressed")
+
+
+if __name__ == "__main__":
+    main()

diff --git a/swh/provenance/tests/data/swhgraph/cmdbts2/compressed/example-labelled.labelobl b/swh/provenance/tests/data/swhgraph/cmdbts2/compressed/example-labelled.labelobl
new file mode 100644
index 0000000000000000000000000000000000000000..0000000000000000000000000000000000000000
GIT binary patch
(binary data not shown)

diff --git a/swh/provenance/tests/data/swhgraph/cmdbts2/compressed/example.indegree b/swh/provenance/tests/data/swhgraph/cmdbts2/compressed/example.indegree
new file mode 100644
--- /dev/null
+++ b/swh/provenance/tests/data/swhgraph/cmdbts2/compressed/example.indegree
@@ -0,0 +1,6 @@
+3
+44
+11
+8
+2
+2

diff --git a/swh/provenance/tests/data/swhgraph/cmdbts2/compressed/example.labels.count.txt b/swh/provenance/tests/data/swhgraph/cmdbts2/compressed/example.labels.count.txt
new file mode 100644
--- /dev/null
+++ b/swh/provenance/tests/data/swhgraph/cmdbts2/compressed/example.labels.count.txt
@@ -0,0 +1 @@
+36

diff --git a/swh/provenance/tests/data/swhgraph/cmdbts2/compressed/example.labels.csv.zst b/swh/provenance/tests/data/swhgraph/cmdbts2/compressed/example.labels.csv.zst
new file mode 100644
index 0000000000000000000000000000000000000000..0000000000000000000000000000000000000000
GIT binary patch
(binary data not shown)

[...]

diff --git a/swh/provenance/tests/test_archive.py b/swh/provenance/tests/test_archive.py
--- a/swh/provenance/tests/test_archive.py
+++ b/swh/provenance/tests/test_archive.py
[...]
+@pytest.mark.parametrize(
+    "repo",
+    ("cmdbts2", "out-of-order", "with-merges"),
+)
+def test_archive_graph(repo: str, archive: ArchiveInterface) -> None:
+    data = load_repo_data(repo)
+    fill_storage(archive.storage, data)
+
+    with grpc_server(repo) as url:
+        # test against ArchiveGraph
+        archive_graph = ArchiveGraph(url, archive.storage)
+        with pytest.raises(NotImplementedError):
+            check_directory_ls(archive, archive_graph, data)
+        check_revision_get_some_outbound_edges(archive, archive_graph, data)
+        check_snapshot_get_heads(archive, archive_graph, data)
+
+
+@pytest.mark.parametrize(
+    "repo",
+    ("cmdbts2", "out-of-order", "with-merges"),
+)
+def test_archive_multiplexed(repo: str, archive: ArchiveInterface) -> None:
+    # read data/README.md for more details on how these datasets are generated
+    data = load_repo_data(repo)
+    fill_storage(archive.storage, data)

     # test against ArchiveMultiplexer
-    archive_multiplexed = ArchiveMultiplexed(
-        [("noop", ArchiveNoop()), ("graph", archive_graph), ("api", archive_api)]
-    )
-    check_directory_ls(archive, archive_multiplexed, data)
-    check_revision_get_some_outbound_edges(archive, archive_multiplexed, data)
-    check_snapshot_get_heads(archive, archive_multiplexed, data)
+    with grpc_server(repo) as url:
+        archive_api = ArchiveStorage(archive.storage)
+        archive_graph = ArchiveGraph(url, archive.storage)
+        archive_multiplexed = ArchiveMultiplexed(
+            [("noop", ArchiveNoop()), ("graph", archive_graph), ("api", archive_api)]
+        )
+        check_directory_ls(archive, archive_multiplexed, data)
+        check_revision_get_some_outbound_edges(archive, archive_multiplexed, data)
+        check_snapshot_get_heads(archive, archive_multiplexed, data)


 def test_noop_multiplexer():