diff --git a/swh/provenance/interface.py b/swh/provenance/interface.py --- a/swh/provenance/interface.py +++ b/swh/provenance/interface.py @@ -59,7 +59,6 @@ depending on the relation being represented. """ - src: Sha1Git dst: Sha1Git path: Optional[bytes] @@ -162,7 +161,7 @@ @remote_api_endpoint("relation_add") def relation_add( - self, relation: RelationType, data: Iterable[RelationData] + self, relation: RelationType, data: Dict[Sha1Git, Set[RelationData]] ) -> bool: """Add entries in the selected `relation`. This method assumes all entities being related are already registered in the storage. See `content_add`, @@ -173,7 +172,7 @@ @remote_api_endpoint("relation_get") def relation_get( self, relation: RelationType, ids: Iterable[Sha1Git], reverse: bool = False - ) -> Set[RelationData]: + ) -> Dict[Sha1Git, Set[RelationData]]: """Retrieve all entries in the selected `relation` whose source entities are identified by some sha1 id in `ids`. If `reverse` is set, destination entities are matched instead. @@ -181,7 +180,9 @@ ... @remote_api_endpoint("relation_get_all") - def relation_get_all(self, relation: RelationType) -> Set[RelationData]: + def relation_get_all( + self, relation: RelationType + ) -> Dict[Sha1Git, Set[RelationData]]: """Retrieve all entries in the selected `relation` that are present in the provenance model. """ diff --git a/swh/provenance/mongo/backend.py b/swh/provenance/mongo/backend.py --- a/swh/provenance/mongo/backend.py +++ b/swh/provenance/mongo/backend.py @@ -281,31 +281,36 @@ } def relation_add( - self, relation: RelationType, data: Iterable[RelationData] + self, relation: RelationType, data: Dict[Sha1Git, Set[RelationData]] ) -> bool: src_relation, *_, dst_relation = relation.value.split("_") - set_data = set(data) dst_objs = { x["sha1"]: x["_id"] for x in self.db.get_collection(dst_relation).find( - {"sha1": {"$in": [x.dst for x in set_data]}}, {"_id": 1, "sha1": 1} + { + "sha1": { + "$in": list({rel.dst for rels in data.values() for rel in rels}) + } + }, + {"_id": 1, "sha1": 1}, ) } denorm: Dict[Sha1Git, Any] = {} - for each in set_data: - if src_relation != "revision": - denorm.setdefault(each.src, {}).setdefault( - str(dst_objs[each.dst]), [] - ).append(each.path) - else: - denorm.setdefault(each.src, []).append(dst_objs[each.dst]) + for src, rels in data.items(): + for rel in rels: + if src_relation != "revision": + denorm.setdefault(src, {}).setdefault( + str(dst_objs[rel.dst]), [] + ).append(rel.path) + else: + denorm.setdefault(src, []).append(dst_objs[rel.dst]) src_objs = { x["sha1"]: x for x in self.db.get_collection(src_relation).find( - {"sha1": {"$in": list(denorm)}} + {"sha1": {"$in": list(denorm.keys())}} ) } @@ -333,14 +338,16 @@ def relation_get( self, relation: RelationType, ids: Iterable[Sha1Git], reverse: bool = False - ) -> Set[RelationData]: + ) -> Dict[Sha1Git, Set[RelationData]]: src, *_, dst = relation.value.split("_") sha1s = set(ids) if not reverse: + empty: Union[Dict[str, bytes], List[str]] = {} if src != "revision" else [] src_objs = { x["sha1"]: x[dst] for x in self.db.get_collection(src).find( - {"sha1": {"$in": list(sha1s)}}, {"_id": 0, "sha1": 1, dst: 1} + {"sha1": {"$in": list(sha1s)}, dst: {"$ne": empty}}, + {"_id": 0, "sha1": 1, dst: 1}, ) } dst_ids = list( @@ -354,20 +361,24 @@ } if src != "revision": return { - RelationData(src=src_sha1, dst=dst_sha1, path=path) + src_sha1: { + RelationData(dst=dst_sha1, path=path) + for dst_sha1, dst_obj_id in dst_objs.items() + for dst_obj_str, paths in denorm.items() + for path in paths + if dst_obj_id == ObjectId(dst_obj_str) + } for src_sha1, denorm in src_objs.items() - for dst_sha1, dst_obj_id in dst_objs.items() - for dst_obj_str, paths in denorm.items() - for path in paths - if dst_obj_id == ObjectId(dst_obj_str) } else: return { - RelationData(src=src_sha1, dst=dst_sha1, path=None) + src_sha1: { + RelationData(dst=dst_sha1, path=None) + for dst_sha1, dst_obj_id in dst_objs.items() + for dst_obj_ref in denorm + if dst_obj_id == dst_obj_ref + } for src_sha1, denorm in src_objs.items() - for dst_sha1, dst_obj_id in dst_objs.items() - for dst_obj_ref in denorm - if dst_obj_id == dst_obj_ref } else: dst_objs = { @@ -382,61 +393,67 @@ {}, {"_id": 0, "sha1": 1, dst: 1} ) } + result: Dict[Sha1Git, Set[RelationData]] = {} if src != "revision": - return { - RelationData(src=src_sha1, dst=dst_sha1, path=path) - for src_sha1, denorm in src_objs.items() - for dst_sha1, dst_obj_id in dst_objs.items() - for dst_obj_str, paths in denorm.items() - for path in paths - if dst_obj_id == ObjectId(dst_obj_str) - } + for dst_sha1, dst_obj_id in dst_objs.items(): + for src_sha1, denorm in src_objs.items(): + for dst_obj_str, paths in denorm.items(): + if dst_obj_id == ObjectId(dst_obj_str): + result.setdefault(src_sha1, set()).update( + RelationData(dst=dst_sha1, path=path) + for path in paths + ) else: - return { - RelationData(src=src_sha1, dst=dst_sha1, path=None) - for src_sha1, denorm in src_objs.items() - for dst_sha1, dst_obj_id in dst_objs.items() - for dst_obj_ref in denorm - if dst_obj_id == dst_obj_ref - } + for dst_sha1, dst_obj_id in dst_objs.items(): + for src_sha1, denorm in src_objs.items(): + if dst_obj_id in { + ObjectId(dst_obj_str) for dst_obj_str in denorm + }: + result.setdefault(src_sha1, set()).add( + RelationData(dst=dst_sha1, path=None) + ) + return result - def relation_get_all(self, relation: RelationType) -> Set[RelationData]: + def relation_get_all( + self, relation: RelationType + ) -> Dict[Sha1Git, Set[RelationData]]: src, *_, dst = relation.value.split("_") + empty: Union[Dict[str, bytes], List[str]] = {} if src != "revision" else [] src_objs = { x["sha1"]: x[dst] - for x in self.db.get_collection(src).find({}, {"_id": 0, "sha1": 1, dst: 1}) + for x in self.db.get_collection(src).find( + {dst: {"$ne": empty}}, {"_id": 0, "sha1": 1, dst: 1} + ) } dst_ids = list( {ObjectId(obj_id) for _, value in src_objs.items() for obj_id in value} ) + dst_objs = { + x["_id"]: x["sha1"] + for x in self.db.get_collection(dst).find( + {"_id": {"$in": dst_ids}}, {"_id": 1, "sha1": 1} + ) + } if src != "revision": - dst_objs = { - x["_id"]: x["sha1"] - for x in self.db.get_collection(dst).find( - {"_id": {"$in": dst_ids}}, {"_id": 1, "sha1": 1} - ) - } return { - RelationData(src=src_sha1, dst=dst_sha1, path=path) + src_sha1: { + RelationData(dst=dst_sha1, path=path) + for dst_obj_id, dst_sha1 in dst_objs.items() + for dst_obj_str, paths in denorm.items() + for path in paths + if dst_obj_id == ObjectId(dst_obj_str) + } for src_sha1, denorm in src_objs.items() - for dst_obj_id, dst_sha1 in dst_objs.items() - for dst_obj_str, paths in denorm.items() - for path in paths - if dst_obj_id == ObjectId(dst_obj_str) } else: - dst_objs = { - x["_id"]: x["sha1"] - for x in self.db.get_collection(dst).find( - {"_id": {"$in": dst_ids}}, {"_id": 1, "sha1": 1} - ) - } return { - RelationData(src=src_sha1, dst=dst_sha1, path=None) + src_sha1: { + RelationData(dst=dst_sha1, path=None) + for dst_obj_id, dst_sha1 in dst_objs.items() + for dst_obj_ref in denorm + if dst_obj_id == dst_obj_ref + } for src_sha1, denorm in src_objs.items() - for dst_obj_id, dst_sha1 in dst_objs.items() - for dst_obj_ref in denorm - if dst_obj_id == dst_obj_ref } def with_path(self) -> bool: diff --git a/swh/provenance/postgresql/provenance.py b/swh/provenance/postgresql/provenance.py --- a/swh/provenance/postgresql/provenance.py +++ b/swh/provenance/postgresql/provenance.py @@ -207,9 +207,11 @@ return result def relation_add( - self, relation: RelationType, data: Iterable[RelationData] + self, relation: RelationType, data: Dict[Sha1Git, Set[RelationData]] ) -> bool: - rows = [(rel.src, rel.dst, rel.path) for rel in data] + rows = [ + (src, rel.dst, rel.path) for src, dsts in data.items() for rel in dsts + ] try: if rows: rel_table = relation.value @@ -236,10 +238,12 @@ def relation_get( self, relation: RelationType, ids: Iterable[Sha1Git], reverse: bool = False - ) -> Set[RelationData]: + ) -> Dict[Sha1Git, Set[RelationData]]: return self._relation_get(relation, ids, reverse) - def relation_get_all(self, relation: RelationType) -> Set[RelationData]: + def relation_get_all( + self, relation: RelationType + ) -> Dict[Sha1Git, Set[RelationData]]: return self._relation_get(relation, None) def _entity_get_date( @@ -291,8 +295,8 @@ relation: RelationType, ids: Optional[Iterable[Sha1Git]], reverse: bool = False, - ) -> Set[RelationData]: - result: Set[RelationData] = set() + ) -> Dict[Sha1Git, Set[RelationData]]: + result: Dict[Sha1Git, Set[RelationData]] = {} sha1s: List[Sha1Git] if ids is not None: @@ -311,7 +315,9 @@ cursor.execute( query=sql, vars=(rel_table, src_table, dst_table, filter, sha1s) ) - result.update(RelationData(**row) for row in cursor) + for row in cursor: + src = row.pop("src") + result.setdefault(src, set()).add(RelationData(**row)) return result def with_path(self) -> bool: diff --git a/swh/provenance/provenance.py b/swh/provenance/provenance.py --- a/swh/provenance/provenance.py +++ b/swh/provenance/provenance.py @@ -126,12 +126,11 @@ # For this layer, relations need to be inserted first so that, in case of # failure, reprocessing the input does not generated an inconsistent database. if self.cache["content_in_revision"]: + cnt_in_rev: Dict[Sha1Git, Set[RelationData]] = {} + for src, dst, path in self.cache["content_in_revision"]: + cnt_in_rev.setdefault(src, set()).add(RelationData(dst=dst, path=path)) while not self.storage.relation_add( - RelationType.CNT_EARLY_IN_REV, - ( - RelationData(src=src, dst=dst, path=path) - for src, dst, path in self.cache["content_in_revision"] - ), + RelationType.CNT_EARLY_IN_REV, cnt_in_rev ): LOGGER.warning( "Unable to write %s rows to the storage. Retrying...", @@ -139,26 +138,20 @@ ) if self.cache["content_in_directory"]: - while not self.storage.relation_add( - RelationType.CNT_IN_DIR, - ( - RelationData(src=src, dst=dst, path=path) - for src, dst, path in self.cache["content_in_directory"] - ), - ): + cnt_in_dir: Dict[Sha1Git, Set[RelationData]] = {} + for src, dst, path in self.cache["content_in_directory"]: + cnt_in_dir.setdefault(src, set()).add(RelationData(dst=dst, path=path)) + while not self.storage.relation_add(RelationType.CNT_IN_DIR, cnt_in_dir): LOGGER.warning( "Unable to write %s rows to the storage. Retrying...", RelationType.CNT_IN_DIR, ) if self.cache["directory_in_revision"]: - while not self.storage.relation_add( - RelationType.DIR_IN_REV, - ( - RelationData(src=src, dst=dst, path=path) - for src, dst, path in self.cache["directory_in_revision"] - ), - ): + dir_in_rev: Dict[Sha1Git, Set[RelationData]] = {} + for src, dst, path in self.cache["directory_in_revision"]: + dir_in_rev.setdefault(src, set()).add(RelationData(dst=dst, path=path)) + while not self.storage.relation_add(RelationType.DIR_IN_REV, dir_in_rev): LOGGER.warning( "Unable to write %s rows to the storage. Retrying...", RelationType.DIR_IN_REV, @@ -233,18 +226,14 @@ ) # Second, flat models for revisions' histories (ie. revision-before-revision). - data: Iterable[RelationData] = sum( - [ - [ - RelationData(src=prev, dst=next, path=None) - for next in self.cache["revision_before_revision"][prev] - ] - for prev in self.cache["revision_before_revision"] - ], - [], - ) - if data: - while not self.storage.relation_add(RelationType.REV_BEFORE_REV, data): + if self.cache["revision_before_revision"]: + rev_before_rev = { + src: {RelationData(dst=dst, path=None) for dst in dsts} + for src, dsts in self.cache["revision_before_revision"].items() + } + while not self.storage.relation_add( + RelationType.REV_BEFORE_REV, rev_before_rev + ): LOGGER.warning( "Unable to write %s rows to the storage. Retrying...", RelationType.REV_BEFORE_REV, @@ -254,12 +243,11 @@ # their histories were already added. This is to guarantee consistent results if # something needs to be reprocessed due to a failure: already inserted heads # won't get reprocessed in such a case. - data = ( - RelationData(src=rev, dst=org, path=None) - for rev, org in self.cache["revision_in_origin"] - ) - if data: - while not self.storage.relation_add(RelationType.REV_IN_ORG, data): + if self.cache["revision_in_origin"]: + rev_in_org: Dict[Sha1Git, Set[RelationData]] = {} + for src, dst in self.cache["revision_in_origin"]: + rev_in_org.setdefault(src, set()).add(RelationData(dst=dst, path=None)) + while not self.storage.relation_add(RelationType.REV_IN_ORG, rev_in_org): LOGGER.warning( "Unable to write %s rows to the storage. Retrying...", RelationType.REV_IN_ORG, diff --git a/swh/provenance/tests/test_origin_revision_layer.py b/swh/provenance/tests/test_origin_revision_layer.py --- a/swh/provenance/tests/test_origin_revision_layer.py +++ b/swh/provenance/tests/test_origin_revision_layer.py @@ -173,8 +173,11 @@ (x["dst"], x["src"], None) for x in synth_org["O_R"] ) assert rows["revision_in_origin"] == { - (rel.src, rel.dst, rel.path) - for rel in provenance.storage.relation_get_all(RelationType.REV_IN_ORG) + (src, rel.dst, rel.path) + for src, rels in provenance.storage.relation_get_all( + RelationType.REV_IN_ORG + ).items() + for rel in rels }, synth_org["snap"] # check for R-R entries @@ -183,8 +186,9 @@ (x["dst"], x["src"], None) for x in synth_org["R_R"] ) assert rows["revision_before_revision"] == { - (rel.src, rel.dst, rel.path) - for rel in provenance.storage.relation_get_all( + (src, rel.dst, rel.path) + for src, rels in provenance.storage.relation_get_all( RelationType.REV_BEFORE_REV - ) + ).items() + for rel in rels }, synth_org["snap"] diff --git a/swh/provenance/tests/test_provenance_storage.py b/swh/provenance/tests/test_provenance_storage.py --- a/swh/provenance/tests/test_provenance_storage.py +++ b/swh/provenance/tests/test_provenance_storage.py @@ -6,7 +6,7 @@ from datetime import datetime, timezone import inspect import os -from typing import Any, Dict, Iterable, Optional, Set +from typing import Any, Dict, Iterable, Optional, Set, Tuple from swh.model.hashutil import hash_to_bytes from swh.model.identifiers import origin_identifier @@ -176,9 +176,12 @@ ref: Sha1Git, dir: Dict[str, Any], prefix: bytes = b"", -) -> Iterable[RelationData]: +) -> Iterable[Tuple[Sha1Git, RelationData]]: content = { - RelationData(entry["target"], ref, os.path.join(prefix, entry["name"])) + ( + entry["target"], + RelationData(dst=ref, path=os.path.join(prefix, entry["name"])), + ) for entry in dir["entries"] if entry["type"] == "file" } @@ -209,49 +212,62 @@ def relation_add_and_compare_result( - storage: ProvenanceStorageInterface, relation: RelationType, data: Set[RelationData] + storage: ProvenanceStorageInterface, + relation: RelationType, + data: Dict[Sha1Git, Set[RelationData]], ) -> None: # Source, destinations and locations must be added in advance. src, *_, dst = relation.value.split("_") + srcs = {sha1 for sha1 in data} if src != "origin": - assert entity_add(storage, EntityType(src), {entry.src for entry in data}) + assert entity_add(storage, EntityType(src), srcs) + dsts = {rel.dst for rels in data.values() for rel in rels} if dst != "origin": - assert entity_add(storage, EntityType(dst), {entry.dst for entry in data}) + assert entity_add(storage, EntityType(dst), dsts) if storage.with_path(): assert storage.location_add( - {entry.path for entry in data if entry.path is not None} + {rel.path for rels in data.values() for rel in rels if rel.path is not None} ) assert data assert storage.relation_add(relation, data) - for row in data: - assert relation_compare_result( - storage.relation_get(relation, [row.src]), - {entry for entry in data if entry.src == row.src}, + for src_sha1 in srcs: + relation_compare_result( + storage.relation_get(relation, [src_sha1]), + {src_sha1: data[src_sha1]}, storage.with_path(), ) - assert relation_compare_result( - storage.relation_get( - relation, - [row.dst], - reverse=True, - ), - {entry for entry in data if entry.dst == row.dst}, + for dst_sha1 in dsts: + relation_compare_result( + storage.relation_get(relation, [dst_sha1], reverse=True), + { + src_sha1: { + RelationData(dst=dst_sha1, path=rel.path) + for rel in rels + if dst_sha1 == rel.dst + } + for src_sha1, rels in data.items() + if dst_sha1 in {rel.dst for rel in rels} + }, storage.with_path(), ) - - assert relation_compare_result( + relation_compare_result( storage.relation_get_all(relation), data, storage.with_path() ) def relation_compare_result( - computed: Set[RelationData], expected: Set[RelationData], with_path: bool -) -> bool: - return { - RelationData(row.src, row.dst, row.path if with_path else None) - for row in expected + computed: Dict[Sha1Git, Set[RelationData]], + expected: Dict[Sha1Git, Set[RelationData]], + with_path: bool, +) -> None: + assert { + src_sha1: { + RelationData(dst=rel.dst, path=rel.path if with_path else None) + for rel in rels + } + for src_sha1, rels in expected.items() } == computed @@ -265,50 +281,39 @@ # Test content-in-revision relation. # Create flat models of every root directory for the revisions in the dataset. - cnt_in_rev: Set[RelationData] = set() + cnt_in_rev: Dict[Sha1Git, Set[RelationData]] = {} for rev in data["revision"]: root = next( subdir for subdir in data["directory"] if subdir["id"] == rev["directory"] ) - cnt_in_rev.update(dircontent(data, rev["id"], root)) + for cnt, rel in dircontent(data, rev["id"], root): + cnt_in_rev.setdefault(cnt, set()).add(rel) relation_add_and_compare_result( provenance_storage, RelationType.CNT_EARLY_IN_REV, cnt_in_rev ) # Test content-in-directory relation. # Create flat models for every directory in the dataset. - cnt_in_dir: Set[RelationData] = set() + cnt_in_dir: Dict[Sha1Git, Set[RelationData]] = {} for dir in data["directory"]: - cnt_in_dir.update(dircontent(data, dir["id"], dir)) + for cnt, rel in dircontent(data, dir["id"], dir): + cnt_in_dir.setdefault(cnt, set()).add(rel) relation_add_and_compare_result( provenance_storage, RelationType.CNT_IN_DIR, cnt_in_dir ) # Test content-in-directory relation. # Add root directories to their correspondent revision in the dataset. - dir_in_rev = { - RelationData(rev["directory"], rev["id"], b".") for rev in data["revision"] - } + dir_in_rev: Dict[Sha1Git, Set[RelationData]] = {} + for rev in data["revision"]: + dir_in_rev.setdefault(rev["directory"], set()).add( + RelationData(dst=rev["id"], path=b".") + ) relation_add_and_compare_result( provenance_storage, RelationType.DIR_IN_REV, dir_in_rev ) # Test revision-in-origin relation. - # Add all revisions that are head of some snapshot branch to the corresponding - # origin. - rev_in_org = { - RelationData( - branch["target"], - hash_to_bytes(origin_identifier({"url": status["origin"]})), - None, - ) - for status in data["origin_visit_status"] - if status["snapshot"] is not None - for snapshot in data["snapshot"] - if snapshot["id"] == status["snapshot"] - for _, branch in snapshot["branches"].items() - if branch["target_type"] == "revision" - } # Origins must be inserted in advance (cannot be done by `entity_add` inside # `relation_add_and_compare_result`). orgs = { @@ -316,18 +321,35 @@ for origin in data["origin"] } assert provenance_storage.origin_add(orgs) - + # Add all revisions that are head of some snapshot branch to the corresponding + # origin. + rev_in_org: Dict[Sha1Git, Set[RelationData]] = {} + for status in data["origin_visit_status"]: + if status["snapshot"] is not None: + for snapshot in data["snapshot"]: + if snapshot["id"] == status["snapshot"]: + for branch in snapshot["branches"].values(): + if branch["target_type"] == "revision": + rev_in_org.setdefault(branch["target"], set()).add( + RelationData( + dst=hash_to_bytes( + origin_identifier({"url": status["origin"]}) + ), + path=None, + ) + ) relation_add_and_compare_result( provenance_storage, RelationType.REV_IN_ORG, rev_in_org ) # Test revision-before-revision relation. # For each revision in the data set add an entry for each parent to the relation. - rev_before_rev = { - RelationData(parent, rev["id"], None) - for rev in data["revision"] - for parent in rev["parents"] - } + rev_before_rev: Dict[Sha1Git, Set[RelationData]] = {} + for rev in data["revision"]: + for parent in rev["parents"]: + rev_before_rev.setdefault(parent, set()).add( + RelationData(dst=rev["id"], path=None) + ) relation_add_and_compare_result( provenance_storage, RelationType.REV_BEFORE_REV, rev_before_rev ) diff --git a/swh/provenance/tests/test_revision_content_layer.py b/swh/provenance/tests/test_revision_content_layer.py --- a/swh/provenance/tests/test_revision_content_layer.py +++ b/swh/provenance/tests/test_revision_content_layer.py @@ -224,10 +224,11 @@ (x["dst"], x["src"], maybe_path(x["path"])) for x in synth_rev["R_C"] ) assert rows["content_in_revision"] == { - (rel.src, rel.dst, rel.path) - for rel in provenance.storage.relation_get_all( + (src, rel.dst, rel.path) + for src, rels in provenance.storage.relation_get_all( RelationType.CNT_EARLY_IN_REV - ) + ).items() + for rel in rels }, synth_rev["msg"] # check timestamps for rc in synth_rev["R_C"]: @@ -250,8 +251,11 @@ (x["dst"], x["src"], maybe_path(x["path"])) for x in synth_rev["R_D"] ) assert rows["directory_in_revision"] == { - (rel.src, rel.dst, rel.path) - for rel in provenance.storage.relation_get_all(RelationType.DIR_IN_REV) + (src, rel.dst, rel.path) + for src, rels in provenance.storage.relation_get_all( + RelationType.DIR_IN_REV + ).items() + for rel in rels }, synth_rev["msg"] # check timestamps for rd in synth_rev["R_D"]: @@ -267,8 +271,11 @@ (x["dst"], x["src"], maybe_path(x["path"])) for x in synth_rev["D_C"] ) assert rows["content_in_directory"] == { - (rel.src, rel.dst, rel.path) - for rel in provenance.storage.relation_get_all(RelationType.CNT_IN_DIR) + (src, rel.dst, rel.path) + for src, rels in provenance.storage.relation_get_all( + RelationType.CNT_IN_DIR + ).items() + for rel in rels }, synth_rev["msg"] # check timestamps for dc in synth_rev["D_C"]: