Page Menu · Home · Software Heritage

D7477.diff
No One · Temporary

D7477.diff

diff --git a/swh/provenance/__init__.py b/swh/provenance/__init__.py
--- a/swh/provenance/__init__.py
+++ b/swh/provenance/__init__.py
@@ -44,11 +44,12 @@
elif cls == "graph":
try:
from swh.graph.client import RemoteGraphClient
+ from swh.storage import get_storage
from .swhgraph.archive import ArchiveGraph
graph = RemoteGraphClient(kwargs.get("url"))
- return ArchiveGraph(graph, get_storage(**kwargs["storage"]))
+ return ArchiveGraph(graph, get_storage(cls="memory"))
except ModuleNotFoundError:
raise EnvironmentError(
diff --git a/swh/provenance/swhgraph/archive.py b/swh/provenance/swhgraph/archive.py
--- a/swh/provenance/swhgraph/archive.py
+++ b/swh/provenance/swhgraph/archive.py
@@ -3,7 +3,7 @@
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
-from typing import Any, Dict, Iterable
+from typing import Any, Dict, Iterable, Set
from swh.core.statsd import statsd
from swh.model.model import Sha1Git
@@ -17,6 +17,7 @@
def __init__(self, graph, storage: StorageInterface) -> None:
self.graph = graph
self.storage = storage # required by ArchiveInterface
+ self.parents: Dict[Sha1Git, Set[Sha1Git]] = {}
@statsd.timed(metric=ARCHIVE_DURATION_METRIC, tags={"method": "directory_ls"})
def directory_ls(self, id: Sha1Git, minsize: int = 0) -> Iterable[Dict[str, Any]]:
@@ -26,10 +27,22 @@
metric=ARCHIVE_DURATION_METRIC, tags={"method": "revision_get_parents"}
)
def revision_get_parents(self, id: Sha1Git) -> Iterable[Sha1Git]:
- src = CoreSWHID(object_type=ObjectType.REVISION, object_id=id)
- request = self.graph.neighbors(str(src), edges="rev:rev", return_types="rev")
-
- yield from (CoreSWHID.from_string(swhid).object_id for swhid in request)
+ if id not in self.parents:
+ self.parents = {}
+
+ src = CoreSWHID(object_type=ObjectType.REVISION, object_id=id)
+ edges = {
+ (
+ CoreSWHID.from_string(child).object_id,
+ CoreSWHID.from_string(parent).object_id,
+ )
+ for child, parent in self.graph.visit_edges(str(src), edges="rev:rev")
+ }
+ for child, parent in edges:
+ self.parents.setdefault(child, set()).add(parent)
+ self.parents.setdefault(parent, set())
+
+ yield from self.parents[id]
@statsd.timed(metric=ARCHIVE_DURATION_METRIC, tags={"method": "snapshot_get_heads"})
def snapshot_get_heads(self, id: Sha1Git) -> Iterable[Sha1Git]:
@@ -38,4 +51,6 @@
str(src), edges="snp:rev,snp:rel,rel:rev", return_types="rev"
)
- yield from (CoreSWHID.from_string(swhid).object_id for swhid in request)
+ yield from (
+ CoreSWHID.from_string(swhid).object_id for swhid in request if swhid
+ )
diff --git a/swh/provenance/tests/test_archive_interface.py b/swh/provenance/tests/test_archive_interface.py
--- a/swh/provenance/tests/test_archive_interface.py
+++ b/swh/provenance/tests/test_archive_interface.py
@@ -31,7 +31,7 @@
from swh.provenance.postgresql.archive import ArchivePostgreSQL
from swh.provenance.storage.archive import ArchiveStorage
from swh.provenance.swhgraph.archive import ArchiveGraph
-from swh.provenance.tests.conftest import fill_storage, load_repo_data
+from swh.provenance.tests.conftest import fill_storage, load_repo_data, objs_from_dict
from swh.storage.postgresql.storage import Storage
@@ -88,16 +88,6 @@
raise ValueError
-def data_to_model(data: Dict[str, List[dict]]) -> Dict[str, List[BaseModel]]:
- model: Dict[str, List[BaseModel]] = {}
- for object_type, objects in data.items():
- for object in objects:
- model.setdefault(object_type, []).append(
- get_object_class(object_type).from_dict(object)
- )
- return model
-
-
def add_link(
edges: Set[
Tuple[
@@ -129,17 +119,20 @@
]
] = set()
- model = data_to_model(data)
+ objects = {
+ objtype: [objs_from_dict(objtype, d) for d in dicts]
+ for objtype, dicts in data.items()
+ }
- for origin in model["origin"]:
+ for origin in objects["origin"]:
assert isinstance(origin, Origin)
nodes.add(origin.swhid())
- for status in model["origin_visit_status"]:
+ for status in objects["origin_visit_status"]:
assert isinstance(status, OriginVisitStatus)
if status.origin == origin.url and status.snapshot is not None:
add_link(edges, origin, status.snapshot, ExtendedObjectType.SNAPSHOT)
- for snapshot in model["snapshot"]:
+ for snapshot in objects["snapshot"]:
assert isinstance(snapshot, Snapshot)
nodes.add(snapshot.swhid())
for branch in snapshot.branches.values():
@@ -152,7 +145,7 @@
)
add_link(edges, snapshot, branch.target, target_type)
- for revision in model["revision"]:
+ for revision in objects["revision"]:
assert isinstance(revision, Revision)
nodes.add(revision.swhid())
# root directory
@@ -161,7 +154,7 @@
for parent in revision.parents:
add_link(edges, revision, parent, ExtendedObjectType.REVISION)
- for directory in model["directory"]:
+ for directory in objects["directory"]:
assert isinstance(directory, Directory)
nodes.add(directory.swhid())
for entry in directory.entries:
@@ -174,7 +167,7 @@
target_type = ExtendedObjectType.REVISION
add_link(edges, directory, entry.target, target_type)
- for content in model["content"]:
+ for content in objects["content"]:
assert isinstance(content, Content)
nodes.add(content.swhid())

File Metadata

Mime Type
text/plain
Expires
Mon, Apr 14, 7:00 AM (17 h, 17 m ago)
Storage Engine
blob
Storage Format
Raw Data
Storage Handle
3221223

Event Timeline