diff --git a/requirements-test.txt b/requirements-test.txt index 1742819..356d01f 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -1,6 +1,7 @@ pytest < 7.0.0 # v7.0.0 removed _pytest.tmpdir.TempdirFactory, which is used by some of the pytest plugins we use pytest-mongodb pytest-rabbitmq swh.loader.git >= 0.8 swh.journal >= 0.8 swh.storage >= 0.40 +swh.graph >= 0.3.2 diff --git a/swh/provenance/__init__.py b/swh/provenance/__init__.py index 11da51f..506ca14 100644 --- a/swh/provenance/__init__.py +++ b/swh/provenance/__init__.py @@ -1,103 +1,118 @@ # Copyright (C) 2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from __future__ import annotations from typing import TYPE_CHECKING import warnings if TYPE_CHECKING: from .archive import ArchiveInterface from .interface import ProvenanceInterface, ProvenanceStorageInterface def get_archive(cls: str, **kwargs) -> ArchiveInterface: """Get an archive object of class ``cls`` with arguments ``args``. Args: - cls: archive's class, either 'api' or 'direct' + cls: archive's class, either 'api', 'direct' or 'graph' args: dictionary of arguments passed to the archive class constructor Returns: an instance of archive object (either using swh.storage API or direct queries to the archive's database) Raises: :cls:`ValueError` if passed an unknown archive class. """ if cls == "api": from swh.storage import get_storage from .storage.archive import ArchiveStorage return ArchiveStorage(get_storage(**kwargs["storage"])) + elif cls == "direct": from swh.core.db import BaseDb from .postgresql.archive import ArchivePostgreSQL return ArchivePostgreSQL(BaseDb.connect(**kwargs["db"]).conn) + + elif cls == "graph": + try: + from swh.graph.client import RemoteGraphClient + + from .swhgraph.archive import ArchiveGraph + + graph = RemoteGraphClient(kwargs.get("url")) + return ArchiveGraph(graph, get_storage(**kwargs["storage"])) + + except ModuleNotFoundError: + raise EnvironmentError( + "Graph configuration required but module is not installed." + ) else: raise ValueError def get_provenance(**kwargs) -> ProvenanceInterface: """Get an provenance object with arguments ``args``. Args: args: dictionary of arguments to retrieve a swh.provenance.storage class (see :func:`get_provenance_storage` for details) Returns: an instance of provenance object """ from .provenance import Provenance return Provenance(get_provenance_storage(**kwargs)) def get_provenance_storage(cls: str, **kwargs) -> ProvenanceStorageInterface: """Get an archive object of class ``cls`` with arguments ``args``. Args: cls: storage's class, only 'local' is currently supported args: dictionary of arguments passed to the storage class constructor Returns: an instance of storage object Raises: :cls:`ValueError` if passed an unknown archive class. """ if cls in ["local", "postgresql"]: from .postgresql.provenance import ProvenanceStoragePostgreSql if cls == "local": warnings.warn( '"local" class is deprecated for provenance storage, please ' 'use "postgresql" class instead.', DeprecationWarning, ) raise_on_commit = kwargs.get("raise_on_commit", False) return ProvenanceStoragePostgreSql( raise_on_commit=raise_on_commit, **kwargs["db"] ) elif cls == "mongodb": from .mongo.backend import ProvenanceStorageMongoDb engine = kwargs.get("engine", "pymongo") return ProvenanceStorageMongoDb(engine=engine, **kwargs["db"]) elif cls == "rabbitmq": from .api.client import ProvenanceStorageRabbitMQClient rmq_storage = ProvenanceStorageRabbitMQClient(**kwargs) if TYPE_CHECKING: assert isinstance(rmq_storage, ProvenanceStorageInterface) return rmq_storage raise ValueError diff --git a/swh/provenance/swhgraph/__init__.py b/swh/provenance/swhgraph/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/swh/provenance/swhgraph/archive.py b/swh/provenance/swhgraph/archive.py new file mode 100644 index 0000000..5424015 --- /dev/null +++ b/swh/provenance/swhgraph/archive.py @@ -0,0 +1,41 @@ +# Copyright (C) 2022 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + +from typing import Any, Dict, Iterable + +from swh.core.statsd import statsd +from swh.model.model import Sha1Git +from swh.model.swhids import CoreSWHID, ObjectType +from swh.storage.interface import StorageInterface + +ARCHIVE_DURATION_METRIC = "swh_provenance_archive_graph_duration_seconds" + + +class ArchiveGraph: + def __init__(self, graph, storage: StorageInterface) -> None: + self.graph = graph + self.storage = storage # required by ArchiveInterface + + @statsd.timed(metric=ARCHIVE_DURATION_METRIC, tags={"method": "directory_ls"}) + def directory_ls(self, id: Sha1Git, minsize: int = 0) -> Iterable[Dict[str, Any]]: + raise NotImplementedError + + @statsd.timed( + metric=ARCHIVE_DURATION_METRIC, tags={"method": "revision_get_parents"} + ) + def revision_get_parents(self, id: Sha1Git) -> Iterable[Sha1Git]: + src = CoreSWHID(object_type=ObjectType.REVISION, object_id=id) + request = self.graph.neighbors(str(src), edges="rev:rev", return_types="rev") + + yield from (CoreSWHID.from_string(swhid).object_id for swhid in request) + + @statsd.timed(metric=ARCHIVE_DURATION_METRIC, tags={"method": "snapshot_get_heads"}) + def snapshot_get_heads(self, id: Sha1Git) -> Iterable[Sha1Git]: + src = CoreSWHID(object_type=ObjectType.SNAPSHOT, object_id=id) + request = self.graph.visit_nodes( + str(src), edges="snp:rev,snp:rel,rel:rev", return_types="rev" + ) + + yield from (CoreSWHID.from_string(swhid).object_id for swhid in request) diff --git a/swh/provenance/tests/test_archive_interface.py b/swh/provenance/tests/test_archive_interface.py index 9c4b21f..6d95fb0 100644 --- a/swh/provenance/tests/test_archive_interface.py +++ b/swh/provenance/tests/test_archive_interface.py @@ -1,61 +1,216 @@ # Copyright (C) 2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from collections import Counter from operator import itemgetter from typing import Counter as TCounter +from typing import Dict, List, Set, Tuple, Type, Union import pytest from swh.core.db import BaseDb -from swh.model.model import Sha1Git +from swh.graph.naive_client import NaiveClient +from swh.model.model import ( + BaseModel, + Content, + Directory, + DirectoryEntry, + Origin, + OriginVisit, + OriginVisitStatus, + Revision, + Sha1Git, + Snapshot, + SnapshotBranch, + TargetType, +) +from swh.model.swhids import CoreSWHID, ExtendedObjectType, ExtendedSWHID +from swh.provenance.archive import ArchiveInterface from swh.provenance.postgresql.archive import ArchivePostgreSQL from swh.provenance.storage.archive import ArchiveStorage +from swh.provenance.swhgraph.archive import ArchiveGraph from swh.provenance.tests.conftest import fill_storage, load_repo_data -from swh.storage.interface import StorageInterface from swh.storage.postgresql.storage import Storage +def check_directory_ls( + reference: ArchiveInterface, archive: ArchiveInterface, data: Dict[str, List[dict]] +) -> None: + for directory in data["directory"]: + entries_ref = sorted( + reference.directory_ls(directory["id"]), key=itemgetter("name") + ) + entries = sorted(archive.directory_ls(directory["id"]), key=itemgetter("name")) + assert entries_ref == entries + + +def check_revision_get_parents( + reference: ArchiveInterface, archive: ArchiveInterface, data: Dict[str, List[dict]] +) -> None: + for revision in data["revision"]: + parents_ref: TCounter[Sha1Git] = Counter( + reference.revision_get_parents(revision["id"]) + ) + parents: TCounter[Sha1Git] = Counter( + archive.revision_get_parents(revision["id"]) + ) + assert parents_ref == parents + + +def check_snapshot_get_heads( + reference: ArchiveInterface, archive: ArchiveInterface, data: Dict[str, List[dict]] +) -> None: + for snapshot in data["snapshot"]: + heads_ref: TCounter[Sha1Git] = Counter( + reference.snapshot_get_heads(snapshot["id"]) + ) + heads: TCounter[Sha1Git] = Counter(archive.snapshot_get_heads(snapshot["id"])) + assert heads_ref == heads + + +def get_object_class(object_type: str) -> Type[BaseModel]: + if object_type == "origin": + return Origin + elif object_type == "origin_visit": + return OriginVisit + elif object_type == "origin_visit_status": + return OriginVisitStatus + elif object_type == "content": + return Content + elif object_type == "directory": + return Directory + elif object_type == "revision": + return Revision + elif object_type == "snapshot": + return Snapshot + raise ValueError + + +def data_to_model(data: Dict[str, List[dict]]) -> Dict[str, List[BaseModel]]: + model: Dict[str, List[BaseModel]] = {} + for object_type, objects in data.items(): + for object in objects: + model.setdefault(object_type, []).append( + get_object_class(object_type).from_dict(object) + ) + return model + + +def add_link( + edges: Set[ + Tuple[ + Union[CoreSWHID, ExtendedSWHID, str], Union[CoreSWHID, ExtendedSWHID, str] + ] + ], + src_obj: Union[Origin, Snapshot, Revision, Directory, Content], + dst_id: bytes, + dst_type: ExtendedObjectType, +) -> None: + swhid = ExtendedSWHID(object_type=dst_type, object_id=dst_id) + edges.add((src_obj.swhid(), swhid)) + + +def get_graph_data( + data: Dict[str, List[dict]] +) -> Tuple[ + List[Union[CoreSWHID, ExtendedSWHID, str]], + List[ + Tuple[ + Union[CoreSWHID, ExtendedSWHID, str], Union[CoreSWHID, ExtendedSWHID, str] + ] + ], +]: + nodes: Set[Union[CoreSWHID, ExtendedSWHID, str]] = set() + edges: Set[ + Tuple[ + Union[CoreSWHID, ExtendedSWHID, str], Union[CoreSWHID, ExtendedSWHID, str] + ] + ] = set() + + model = data_to_model(data) + + for origin in model["origin"]: + assert isinstance(origin, Origin) + nodes.add(origin.swhid()) + for status in model["origin_visit_status"]: + assert isinstance(status, OriginVisitStatus) + if status.origin == origin.url and status.snapshot is not None: + add_link(edges, origin, status.snapshot, ExtendedObjectType.SNAPSHOT) + + for snapshot in model["snapshot"]: + assert isinstance(snapshot, Snapshot) + nodes.add(snapshot.swhid()) + for branch in snapshot.branches.values(): + assert isinstance(branch, SnapshotBranch) + if branch.target_type in [TargetType.RELEASE, TargetType.REVISION]: + target_type = ( + ExtendedObjectType.RELEASE + if branch.target_type == TargetType.RELEASE + else ExtendedObjectType.REVISION + ) + add_link(edges, snapshot, branch.target, target_type) + + for revision in model["revision"]: + assert isinstance(revision, Revision) + nodes.add(revision.swhid()) + # root directory + add_link(edges, revision, revision.directory, ExtendedObjectType.DIRECTORY) + # parent + for parent in revision.parents: + add_link(edges, revision, parent, ExtendedObjectType.REVISION) + + for directory in model["directory"]: + assert isinstance(directory, Directory) + nodes.add(directory.swhid()) + for entry in directory.entries: + assert isinstance(entry, DirectoryEntry) + if entry.type == "file": + target_type = ExtendedObjectType.CONTENT + elif entry.type == "dir": + target_type = ExtendedObjectType.DIRECTORY + elif entry.type == "rev": + target_type = ExtendedObjectType.REVISION + add_link(edges, directory, entry.target, target_type) + + for content in model["content"]: + assert isinstance(content, Content) + nodes.add(content.swhid()) + + return list(nodes), list(edges) + + @pytest.mark.parametrize( "repo", ("cmdbts2", "out-of-order", "with-merges"), ) -def test_archive_interface(repo: str, swh_storage: StorageInterface) -> None: - archive_api = ArchiveStorage(swh_storage) - assert isinstance(swh_storage, Storage) - dsn = swh_storage.get_db().conn.dsn +def test_archive_interface(repo: str, archive: ArchiveInterface) -> None: + # read data/README.md for more details on how these datasets are generated + data = load_repo_data(repo) + fill_storage(archive.storage, data) + + # test against ArchiveStorage + archive_api = ArchiveStorage(archive.storage) + check_directory_ls(archive, archive_api, data) + check_revision_get_parents(archive, archive_api, data) + check_snapshot_get_heads(archive, archive_api, data) + + # test against ArchivePostgreSQL + assert isinstance(archive.storage, Storage) + dsn = archive.storage.get_db().conn.dsn with BaseDb.connect(dsn).conn as conn: BaseDb.adapt_conn(conn) archive_direct = ArchivePostgreSQL(conn) - # read data/README.md for more details on how these datasets are generated - data = load_repo_data(repo) - fill_storage(swh_storage, data) + check_directory_ls(archive, archive_direct, data) + check_revision_get_parents(archive, archive_direct, data) + check_snapshot_get_heads(archive, archive_direct, data) - for directory in data["directory"]: - entries_api = sorted( - archive_api.directory_ls(directory["id"]), key=itemgetter("name") - ) - entries_direct = sorted( - archive_direct.directory_ls(directory["id"]), key=itemgetter("name") - ) - assert entries_api == entries_direct - - for revision in data["revision"]: - parents_api: TCounter[Sha1Git] = Counter( - archive_api.revision_get_parents(revision["id"]) - ) - parents_direct: TCounter[Sha1Git] = Counter( - archive_direct.revision_get_parents(revision["id"]) - ) - assert parents_api == parents_direct - - for snapshot in data["snapshot"]: - heads_api: TCounter[Sha1Git] = Counter( - archive_api.snapshot_get_heads(snapshot["id"]) - ) - heads_direct: TCounter[Sha1Git] = Counter( - archive_direct.snapshot_get_heads(snapshot["id"]) - ) - assert heads_api == heads_direct + # test against ArchiveGraph + nodes, edges = get_graph_data(data) + graph = NaiveClient(nodes=nodes, edges=edges) + archive_graph = ArchiveGraph(graph, archive.storage) + with pytest.raises(NotImplementedError): + check_directory_ls(archive, archive_graph, data) + check_revision_get_parents(archive, archive_graph, data) + check_snapshot_get_heads(archive, archive_graph, data)