diff --git a/swh/provenance/__init__.py b/swh/provenance/__init__.py
--- a/swh/provenance/__init__.py
+++ b/swh/provenance/__init__.py
@@ -44,11 +44,12 @@
     elif cls == "graph":
         try:
             from swh.graph.client import RemoteGraphClient
+            from swh.storage import get_storage
 
             from .swhgraph.archive import ArchiveGraph
 
             graph = RemoteGraphClient(kwargs.get("url"))
-            return ArchiveGraph(graph, get_storage(**kwargs["storage"]))
+            return ArchiveGraph(graph, get_storage(cls="memory"))
 
         except ModuleNotFoundError:
             raise EnvironmentError(
diff --git a/swh/provenance/swhgraph/archive.py b/swh/provenance/swhgraph/archive.py
--- a/swh/provenance/swhgraph/archive.py
+++ b/swh/provenance/swhgraph/archive.py
@@ -3,7 +3,7 @@
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
-from typing import Any, Dict, Iterable
+from typing import Any, Dict, Iterable, Set
 
 from swh.core.statsd import statsd
 from swh.model.model import Sha1Git
@@ -17,6 +17,7 @@
     def __init__(self, graph, storage: StorageInterface) -> None:
         self.graph = graph
         self.storage = storage  # required by ArchiveInterface
+        self.parents: Dict[Sha1Git, Set[Sha1Git]] = {}
 
     @statsd.timed(metric=ARCHIVE_DURATION_METRIC, tags={"method": "directory_ls"})
     def directory_ls(self, id: Sha1Git, minsize: int = 0) -> Iterable[Dict[str, Any]]:
@@ -26,10 +27,24 @@
         metric=ARCHIVE_DURATION_METRIC, tags={"method": "revision_get_parents"}
     )
     def revision_get_parents(self, id: Sha1Git) -> Iterable[Sha1Git]:
-        src = CoreSWHID(object_type=ObjectType.REVISION, object_id=id)
-        request = self.graph.neighbors(str(src), edges="rev:rev", return_types="rev")
-
-        yield from (CoreSWHID.from_string(swhid).object_id for swhid in request)
+        if id not in self.parents:
+            # Cache miss: reset, then load all rev:rev edges reachable from id.
+            self.parents = {}
+
+            src = CoreSWHID(object_type=ObjectType.REVISION, object_id=id)
+            edges = {
+                (
+                    CoreSWHID.from_string(child).object_id,
+                    CoreSWHID.from_string(parent).object_id,
+                )
+                for child, parent in self.graph.visit_edges(str(src), edges="rev:rev")
+            }
+            for child, parent in edges:
+                self.parents.setdefault(child, set()).add(parent)
+                self.parents.setdefault(parent, set())
+
+        # A revision without parents has no rev:rev edge, so it may be missing.
+        yield from self.parents.setdefault(id, set())
 
     @statsd.timed(metric=ARCHIVE_DURATION_METRIC, tags={"method": "snapshot_get_heads"})
     def snapshot_get_heads(self, id: Sha1Git) -> Iterable[Sha1Git]:
@@ -38,4 +53,6 @@
             str(src), edges="snp:rev,snp:rel,rel:rev", return_types="rev"
         )
 
-        yield from (CoreSWHID.from_string(swhid).object_id for swhid in request)
+        yield from (
+            CoreSWHID.from_string(swhid).object_id for swhid in request if swhid
+        )
diff --git a/swh/provenance/tests/test_archive_interface.py b/swh/provenance/tests/test_archive_interface.py
--- a/swh/provenance/tests/test_archive_interface.py
+++ b/swh/provenance/tests/test_archive_interface.py
@@ -31,7 +31,7 @@
 from swh.provenance.postgresql.archive import ArchivePostgreSQL
 from swh.provenance.storage.archive import ArchiveStorage
 from swh.provenance.swhgraph.archive import ArchiveGraph
-from swh.provenance.tests.conftest import fill_storage, load_repo_data
+from swh.provenance.tests.conftest import fill_storage, load_repo_data, objs_from_dict
 from swh.storage.postgresql.storage import Storage
 
 
@@ -88,16 +88,6 @@
     raise ValueError
 
 
-def data_to_model(data: Dict[str, List[dict]]) -> Dict[str, List[BaseModel]]:
-    model: Dict[str, List[BaseModel]] = {}
-    for object_type, objects in data.items():
-        for object in objects:
-            model.setdefault(object_type, []).append(
-                get_object_class(object_type).from_dict(object)
-            )
-    return model
-
-
 def add_link(
     edges: Set[
         Tuple[
@@ -129,17 +119,20 @@
         ]
     ] = set()
 
-    model = data_to_model(data)
+    objects = {
+        objtype: [objs_from_dict(objtype, d) for d in dicts]
+        for objtype, dicts in data.items()
+    }
 
-    for origin in model["origin"]:
+    for origin in objects["origin"]:
         assert isinstance(origin, Origin)
         nodes.add(origin.swhid())
-        for status in model["origin_visit_status"]:
+        for status in objects["origin_visit_status"]:
             assert isinstance(status, OriginVisitStatus)
             if status.origin == origin.url and status.snapshot is not None:
                 add_link(edges, origin, status.snapshot, ExtendedObjectType.SNAPSHOT)
 
-    for snapshot in model["snapshot"]:
+    for snapshot in objects["snapshot"]:
         assert isinstance(snapshot, Snapshot)
         nodes.add(snapshot.swhid())
         for branch in snapshot.branches.values():
@@ -152,7 +145,7 @@
                 )
                 add_link(edges, snapshot, branch.target, target_type)
 
-    for revision in model["revision"]:
+    for revision in objects["revision"]:
         assert isinstance(revision, Revision)
         nodes.add(revision.swhid())
         # root directory
@@ -161,7 +154,7 @@
         for parent in revision.parents:
             add_link(edges, revision, parent, ExtendedObjectType.REVISION)
 
-    for directory in model["directory"]:
+    for directory in objects["directory"]:
         assert isinstance(directory, Directory)
         nodes.add(directory.swhid())
         for entry in directory.entries:
@@ -174,7 +167,7 @@
                 target_type = ExtendedObjectType.REVISION
             add_link(edges, directory, entry.target, target_type)
 
-    for content in model["content"]:
+    for content in objects["content"]:
         assert isinstance(content, Content)
         nodes.add(content.swhid())
 