diff --git a/swh/provenance/interface.py b/swh/provenance/interface.py index 9950edc..4dcce81 100644 --- a/swh/provenance/interface.py +++ b/swh/provenance/interface.py @@ -1,329 +1,330 @@ # Copyright (C) 2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from dataclasses import dataclass from datetime import datetime import enum from typing import Dict, Generator, Iterable, Optional, Set, Union from typing_extensions import Protocol, runtime_checkable from swh.core.api import remote_api_endpoint from swh.model.model import Sha1Git from .model import DirectoryEntry, FileEntry, OriginEntry, RevisionEntry class EntityType(enum.Enum): CONTENT = "content" DIRECTORY = "directory" REVISION = "revision" ORIGIN = "origin" class RelationType(enum.Enum): CNT_EARLY_IN_REV = "content_in_revision" CNT_IN_DIR = "content_in_directory" DIR_IN_REV = "directory_in_revision" REV_IN_ORG = "revision_in_origin" REV_BEFORE_REV = "revision_before_revision" @dataclass(eq=True, frozen=True) class ProvenanceResult: content: Sha1Git revision: Sha1Git date: datetime origin: Optional[str] path: bytes @dataclass(eq=True, frozen=True) class RevisionData: """Object representing the data associated to a revision in the provenance model, where `date` is the optional date of the revision (specifying it acknowledges that the revision was already processed by the revision-content algorithm); and `origin` identifies the preferred origin for the revision, if any. """ date: Optional[datetime] origin: Optional[Sha1Git] @dataclass(eq=True, frozen=True) class RelationData: """Object representing a relation entry in the provenance model, where `src` and `dst` are the sha1 ids of the entities being related, and `path` is optional depending on the relation being represented. """ - src: Sha1Git dst: Sha1Git path: Optional[bytes] @runtime_checkable class ProvenanceStorageInterface(Protocol): @remote_api_endpoint("content_add") def content_add( self, cnts: Union[Iterable[Sha1Git], Dict[Sha1Git, datetime]] ) -> bool: """Add blobs identified by sha1 ids, with an optional associated date (as paired in `cnts`) to the provenance storage. Return a boolean stating whether the information was successfully stored. """ ... @remote_api_endpoint("content_find_first") def content_find_first(self, id: Sha1Git) -> Optional[ProvenanceResult]: """Retrieve the first occurrence of the blob identified by `id`.""" ... @remote_api_endpoint("content_find_all") def content_find_all( self, id: Sha1Git, limit: Optional[int] = None ) -> Generator[ProvenanceResult, None, None]: """Retrieve all the occurrences of the blob identified by `id`.""" ... @remote_api_endpoint("content_get") def content_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, datetime]: """Retrieve the associated date for each blob sha1 in `ids`. If some blob has no associated date, it is not present in the resulting dictionary. """ ... @remote_api_endpoint("directory_add") def directory_add( self, dirs: Union[Iterable[Sha1Git], Dict[Sha1Git, datetime]] ) -> bool: """Add directories identified by sha1 ids, with an optional associated date (as paired in `dirs`) to the provenance storage. Return a boolean stating if the information was successfully stored. """ ... 
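The `Union` signatures above mean callers can pass either bare sha1 ids (no date known yet) or a sha1-to-date mapping. A minimal sketch of both call shapes, assuming `storage` implements `ProvenanceStorageInterface` and the `cnt_sha1_*` variables are hypothetical sha1 values:

```python
from datetime import datetime, timezone

# Register two blobs whose dates are still unknown.
storage.content_add([cnt_sha1_a, cnt_sha1_b])

# Later, attach a date to one of them; implementations keep the earliest
# date seen per sha1 (see the LEAST/min-timestamp logic further below).
storage.content_add({cnt_sha1_a: datetime(2021, 6, 1, tzinfo=timezone.utc)})
```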
@remote_api_endpoint("directory_get") def directory_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, datetime]: """Retrieve the associated date for each directory sha1 in `ids`. If some directory has no associated date, it is not present in the resulting dictionary. """ ... @remote_api_endpoint("entity_get_all") def entity_get_all(self, entity: EntityType) -> Set[Sha1Git]: """Retrieve all sha1 ids for entities of type `entity` present in the provenance model. """ ... @remote_api_endpoint("location_add") def location_add(self, paths: Iterable[bytes]) -> bool: """Register the given `paths` in the storage.""" ... @remote_api_endpoint("location_get_all") def location_get_all(self) -> Set[bytes]: """Retrieve all paths present in the provenance model.""" ... @remote_api_endpoint("origin_add") def origin_add(self, orgs: Dict[Sha1Git, str]) -> bool: """Add origins identified by sha1 ids, with their corresponding url (as paired in `orgs`) to the provenance storage. Return a boolean stating if the information was successfully stored. """ ... @remote_api_endpoint("origin_get") def origin_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, str]: """Retrieve the associated url for each origin sha1 in `ids`.""" ... @remote_api_endpoint("revision_add") def revision_add( self, revs: Union[Iterable[Sha1Git], Dict[Sha1Git, RevisionData]] ) -> bool: """Add revisions identified by sha1 ids, with optional associated date or origin (as paired in `revs`) to the provenance storage. Return a boolean stating if the information was successfully stored. """ ... @remote_api_endpoint("revision_get") def revision_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, RevisionData]: """Retrieve the associated date and origin for each revision sha1 in `ids`. If some revision has no associated date nor origin, it is not present in the resulting dictionary. """ ... @remote_api_endpoint("relation_add") def relation_add( - self, relation: RelationType, data: Iterable[RelationData] + self, relation: RelationType, data: Dict[Sha1Git, Set[RelationData]] ) -> bool: """Add entries in the selected `relation`. This method assumes all entities being related are already registered in the storage. See `content_add`, `directory_add`, `origin_add`, and `revision_add`. """ ... @remote_api_endpoint("relation_get") def relation_get( self, relation: RelationType, ids: Iterable[Sha1Git], reverse: bool = False - ) -> Set[RelationData]: + ) -> Dict[Sha1Git, Set[RelationData]]: """Retrieve all entries in the selected `relation` whose source entities are identified by some sha1 id in `ids`. If `reverse` is set, destination entities are matched instead. """ ... @remote_api_endpoint("relation_get_all") - def relation_get_all(self, relation: RelationType) -> Set[RelationData]: + def relation_get_all( + self, relation: RelationType + ) -> Dict[Sha1Git, Set[RelationData]]: """Retrieve all entries in the selected `relation` that are present in the provenance model. """ ... @remote_api_endpoint("with_path") def with_path(self) -> bool: ... @runtime_checkable class ProvenanceInterface(Protocol): storage: ProvenanceStorageInterface def flush(self) -> None: """Flush internal cache to the underlying `storage`.""" ... def content_add_to_directory( self, directory: DirectoryEntry, blob: FileEntry, prefix: bytes ) -> None: """Associate `blob` with `directory` in the provenance model. `prefix` is the relative path from `directory` to `blob` (excluding `blob`'s name). """ ... 
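This is the core interface change in this diff: `RelationData` loses its `src` field, and relation payloads and results are grouped into a `Dict[Sha1Git, Set[RelationData]]` keyed by source sha1. A sketch of the new call shape (sha1 values hypothetical; as the docstring requires, the related entities are assumed to be registered beforehand):

```python
data = {
    cnt_sha1: {
        RelationData(dst=rev_sha1, path=b"src/main.c"),
        RelationData(dst=other_rev_sha1, path=b"lib/main.c"),
    },
}
storage.relation_add(RelationType.CNT_EARLY_IN_REV, data)

# Queries now come back in the same dict-of-sets shape, keyed by source:
rels = storage.relation_get(RelationType.CNT_EARLY_IN_REV, [cnt_sha1])
for rel in rels.get(cnt_sha1, set()):
    print(rel.dst, rel.path)
```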
def content_add_to_revision( self, revision: RevisionEntry, blob: FileEntry, prefix: bytes ) -> None: """Associate `blob` with `revision` in the provenance model. `prefix` is the absolute path from `revision`'s root directory to `blob` (excluding `blob`'s name). """ ... def content_find_first(self, id: Sha1Git) -> Optional[ProvenanceResult]: """Retrieve the first occurrence of the blob identified by `id`.""" ... def content_find_all( self, id: Sha1Git, limit: Optional[int] = None ) -> Generator[ProvenanceResult, None, None]: """Retrieve all the occurrences of the blob identified by `id`.""" ... def content_get_early_date(self, blob: FileEntry) -> Optional[datetime]: """Retrieve the earliest known date of `blob`.""" ... def content_get_early_dates( self, blobs: Iterable[FileEntry] ) -> Dict[Sha1Git, datetime]: """Retrieve the earliest known date for each blob in `blobs`. If some blob has no associated date, it is not present in the resulting dictionary. """ ... def content_set_early_date(self, blob: FileEntry, date: datetime) -> None: """Associate `date` to `blob` as its earliest known date.""" ... def directory_add_to_revision( self, revision: RevisionEntry, directory: DirectoryEntry, path: bytes ) -> None: """Associate `directory` with `revision` in the provenance model. `path` is the absolute path from `revision`'s root directory to `directory` (including `directory`'s name). """ ... def directory_get_date_in_isochrone_frontier( self, directory: DirectoryEntry ) -> Optional[datetime]: """Retrieve the earliest known date of `directory` as an isochrone frontier in the provenance model. """ ... def directory_get_dates_in_isochrone_frontier( self, dirs: Iterable[DirectoryEntry] ) -> Dict[Sha1Git, datetime]: """Retrieve the earliest known date for each directory in `dirs` as isochrone frontiers in the provenance model. If some directory has no associated date, it is not present in the resulting dictionary. """ ... def directory_set_date_in_isochrone_frontier( self, directory: DirectoryEntry, date: datetime ) -> None: """Associate `date` to `directory` as its earliest known date as an isochrone frontier in the provenance model. """ ... def origin_add(self, origin: OriginEntry) -> None: """Add `origin` to the provenance model.""" ... def revision_add(self, revision: RevisionEntry) -> None: """Add `revision` to the provenance model. This implies storing `revision`'s date in the model, thus `revision.date` must be a valid date. """ ... def revision_add_before_revision( self, head: RevisionEntry, revision: RevisionEntry ) -> None: """Associate `revision` to `head` as an ancestor of the latter.""" ... def revision_add_to_origin( self, origin: OriginEntry, revision: RevisionEntry ) -> None: """Associate `revision` to `origin` as a head revision of the latter (ie. the target of a snapshot for `origin` in the archive).""" ... def revision_get_date(self, revision: RevisionEntry) -> Optional[datetime]: """Retrieve the date associated with `revision`.""" ... def revision_get_preferred_origin( self, revision: RevisionEntry ) -> Optional[Sha1Git]: """Retrieve the preferred origin associated with `revision`.""" ... def revision_in_history(self, revision: RevisionEntry) -> bool: """Check if `revision` is known to be an ancestor of some head revision in the provenance model. """ ... def revision_set_preferred_origin( self, origin: OriginEntry, revision: RevisionEntry ) -> None: """Associate `origin` as the preferred origin for `revision`.""" ... 
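`ProvenanceInterface` implementations (see `Provenance` in `swh/provenance/provenance.py` below) buffer all of these mutators in an in-memory cache and only persist on `flush`. A sketch of a typical round trip, assuming `provenance`, `revision`, and `blob` are already-built objects:

```python
# revision: RevisionEntry with a valid date; blob: FileEntry
provenance.revision_add(revision)                         # cache the revision date
provenance.content_add_to_revision(revision, blob, b".")  # blob sits at the root
provenance.content_set_early_date(blob, revision.date)    # earliest known date
provenance.flush()                                        # write-through to storage
```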
def revision_visited(self, revision: RevisionEntry) -> bool: """Check if `revision` is known to be a head revision for some origin in the provenance model. """ ... diff --git a/swh/provenance/mongo/backend.py b/swh/provenance/mongo/backend.py index 981d38a..f963746 100644 --- a/swh/provenance/mongo/backend.py +++ b/swh/provenance/mongo/backend.py @@ -1,443 +1,460 @@ # Copyright (C) 2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from datetime import datetime, timezone import os from typing import Any, Dict, Generator, Iterable, List, Optional, Set, Union from bson import ObjectId import pymongo.database from swh.model.model import Sha1Git from ..interface import ( EntityType, ProvenanceResult, RelationData, RelationType, RevisionData, ) class ProvenanceStorageMongoDb: def __init__(self, db: pymongo.database.Database): self.db = db def content_add( self, cnts: Union[Iterable[Sha1Git], Dict[Sha1Git, datetime]] ) -> bool: data = cnts if isinstance(cnts, dict) else dict.fromkeys(cnts) existing = { x["sha1"]: x for x in self.db.content.find( {"sha1": {"$in": list(data)}}, {"sha1": 1, "ts": 1, "_id": 1} ) } for sha1, date in data.items(): ts = datetime.timestamp(date) if date is not None else None if sha1 in existing: cnt = existing[sha1] if ts is not None and (cnt["ts"] is None or ts < cnt["ts"]): self.db.content.update_one( {"_id": cnt["_id"]}, {"$set": {"ts": ts}} ) else: self.db.content.insert_one( { "sha1": sha1, "ts": ts, "revision": {}, "directory": {}, } ) return True def content_find_first(self, id: Sha1Git) -> Optional[ProvenanceResult]: # get all the revisions # iterate and find the earliest content = self.db.content.find_one({"sha1": id}) if not content: return None occurs = [] for revision in self.db.revision.find( {"_id": {"$in": [ObjectId(obj_id) for obj_id in content["revision"]]}} ): if revision["preferred"] is not None: origin = self.db.origin.find_one({"sha1": revision["preferred"]}) else: origin = {"url": None} for path in content["revision"][str(revision["_id"])]: occurs.append( ProvenanceResult( content=id, revision=revision["sha1"], date=datetime.fromtimestamp(revision["ts"], timezone.utc), origin=origin["url"], path=path, ) ) return sorted(occurs, key=lambda x: (x.date, x.revision, x.origin, x.path))[0] def content_find_all( self, id: Sha1Git, limit: Optional[int] = None ) -> Generator[ProvenanceResult, None, None]: content = self.db.content.find_one({"sha1": id}) if not content: return None occurs = [] for revision in self.db.revision.find( {"_id": {"$in": [ObjectId(obj_id) for obj_id in content["revision"]]}} ): if revision["preferred"] is not None: origin = self.db.origin.find_one({"sha1": revision["preferred"]}) else: origin = {"url": None} for path in content["revision"][str(revision["_id"])]: occurs.append( ProvenanceResult( content=id, revision=revision["sha1"], date=datetime.fromtimestamp(revision["ts"], timezone.utc), origin=origin["url"], path=path, ) ) for directory in self.db.directory.find( {"_id": {"$in": [ObjectId(obj_id) for obj_id in content["directory"]]}} ): for revision in self.db.revision.find( {"_id": {"$in": [ObjectId(obj_id) for obj_id in directory["revision"]]}} ): if revision["preferred"] is not None: origin = self.db.origin.find_one({"sha1": revision["preferred"]}) else: origin = {"url": None} for suffix in content["directory"][str(directory["_id"])]: for prefix in 
directory["revision"][str(revision["_id"])]: path = ( os.path.join(prefix, suffix) if prefix not in [b".", b""] else suffix ) occurs.append( ProvenanceResult( content=id, revision=revision["sha1"], date=datetime.fromtimestamp( revision["ts"], timezone.utc ), origin=origin["url"], path=path, ) ) yield from sorted(occurs, key=lambda x: (x.date, x.revision, x.origin, x.path)) def content_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, datetime]: return { x["sha1"]: datetime.fromtimestamp(x["ts"], timezone.utc) for x in self.db.content.find( {"sha1": {"$in": list(ids)}, "ts": {"$ne": None}}, {"sha1": 1, "ts": 1, "_id": 0}, ) } def directory_add( self, dirs: Union[Iterable[Sha1Git], Dict[Sha1Git, datetime]] ) -> bool: data = dirs if isinstance(dirs, dict) else dict.fromkeys(dirs) existing = { x["sha1"]: x for x in self.db.directory.find( {"sha1": {"$in": list(data)}}, {"sha1": 1, "ts": 1, "_id": 1} ) } for sha1, date in data.items(): ts = datetime.timestamp(date) if date is not None else None if sha1 in existing: dir = existing[sha1] if ts is not None and (dir["ts"] is None or ts < dir["ts"]): self.db.directory.update_one( {"_id": dir["_id"]}, {"$set": {"ts": ts}} ) else: self.db.directory.insert_one({"sha1": sha1, "ts": ts, "revision": {}}) return True def directory_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, datetime]: return { x["sha1"]: datetime.fromtimestamp(x["ts"], timezone.utc) for x in self.db.directory.find( {"sha1": {"$in": list(ids)}, "ts": {"$ne": None}}, {"sha1": 1, "ts": 1, "_id": 0}, ) } def entity_get_all(self, entity: EntityType) -> Set[Sha1Git]: return { x["sha1"] for x in self.db.get_collection(entity.value).find( {}, {"sha1": 1, "_id": 0} ) } def location_add(self, paths: Iterable[bytes]) -> bool: # TODO: implement this methods if path are to be stored in a separate collection return True def location_get_all(self) -> Set[bytes]: contents = self.db.content.find({}, {"revision": 1, "_id": 0, "directory": 1}) paths: List[Iterable[bytes]] = [] for content in contents: paths.extend(value for _, value in content["revision"].items()) paths.extend(value for _, value in content["directory"].items()) dirs = self.db.directory.find({}, {"revision": 1, "_id": 0}) for each_dir in dirs: paths.extend(value for _, value in each_dir["revision"].items()) return set(sum(paths, [])) def origin_add(self, orgs: Dict[Sha1Git, str]) -> bool: existing = { x["sha1"]: x for x in self.db.origin.find( {"sha1": {"$in": list(orgs)}}, {"sha1": 1, "url": 1, "_id": 1} ) } for sha1, url in orgs.items(): if sha1 not in existing: # add new origin self.db.origin.insert_one({"sha1": sha1, "url": url}) return True def origin_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, str]: return { x["sha1"]: x["url"] for x in self.db.origin.find( {"sha1": {"$in": list(ids)}}, {"sha1": 1, "url": 1, "_id": 0} ) } def revision_add( self, revs: Union[Iterable[Sha1Git], Dict[Sha1Git, RevisionData]] ) -> bool: data = ( revs if isinstance(revs, dict) else dict.fromkeys(revs, RevisionData(date=None, origin=None)) ) existing = { x["sha1"]: x for x in self.db.revision.find( {"sha1": {"$in": list(data)}}, {"sha1": 1, "ts": 1, "preferred": 1, "_id": 1}, ) } for sha1, info in data.items(): ts = datetime.timestamp(info.date) if info.date is not None else None preferred = info.origin if sha1 in existing: rev = existing[sha1] if ts is None or (rev["ts"] is not None and ts >= rev["ts"]): ts = rev["ts"] if preferred is None: preferred = rev["preferred"] if ts != rev["ts"] or preferred != rev["preferred"]: 
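# Only hit the database when the resolved timestamp or preferred origin actually changed.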
self.db.revision.update_one( {"_id": rev["_id"]}, {"$set": {"ts": ts, "preferred": preferred}}, ) else: self.db.revision.insert_one( { "sha1": sha1, "preferred": preferred, "origin": [], "revision": [], "ts": ts, } ) return True def revision_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, RevisionData]: return { x["sha1"]: RevisionData( date=datetime.fromtimestamp(x["ts"], timezone.utc) if x["ts"] else None, origin=x["preferred"], ) for x in self.db.revision.find( { "sha1": {"$in": list(ids)}, "$or": [{"preferred": {"$ne": None}}, {"ts": {"$ne": None}}], }, {"sha1": 1, "preferred": 1, "ts": 1, "_id": 0}, ) } def relation_add( - self, relation: RelationType, data: Iterable[RelationData] + self, relation: RelationType, data: Dict[Sha1Git, Set[RelationData]] ) -> bool: src_relation, *_, dst_relation = relation.value.split("_") - set_data = set(data) dst_objs = { x["sha1"]: x["_id"] for x in self.db.get_collection(dst_relation).find( - {"sha1": {"$in": [x.dst for x in set_data]}}, {"_id": 1, "sha1": 1} + { + "sha1": { + "$in": list({rel.dst for rels in data.values() for rel in rels}) + } + }, + {"_id": 1, "sha1": 1}, ) } denorm: Dict[Sha1Git, Any] = {} - for each in set_data: - if src_relation != "revision": - denorm.setdefault(each.src, {}).setdefault( - str(dst_objs[each.dst]), [] - ).append(each.path) - else: - denorm.setdefault(each.src, []).append(dst_objs[each.dst]) + for src, rels in data.items(): + for rel in rels: + if src_relation != "revision": + denorm.setdefault(src, {}).setdefault( + str(dst_objs[rel.dst]), [] + ).append(rel.path) + else: + denorm.setdefault(src, []).append(dst_objs[rel.dst]) src_objs = { x["sha1"]: x for x in self.db.get_collection(src_relation).find( - {"sha1": {"$in": list(denorm)}} + {"sha1": {"$in": list(denorm.keys())}} ) } for sha1, dsts in denorm.items(): # update if src_relation != "revision": k = { obj_id: list(set(paths + dsts.get(obj_id, []))) for obj_id, paths in src_objs[sha1][dst_relation].items() } self.db.get_collection(src_relation).update_one( {"_id": src_objs[sha1]["_id"]}, {"$set": {dst_relation: dict(dsts, **k)}}, ) else: self.db.get_collection(src_relation).update_one( {"_id": src_objs[sha1]["_id"]}, { "$set": { dst_relation: list(set(src_objs[sha1][dst_relation] + dsts)) } }, ) return True def relation_get( self, relation: RelationType, ids: Iterable[Sha1Git], reverse: bool = False - ) -> Set[RelationData]: + ) -> Dict[Sha1Git, Set[RelationData]]: src, *_, dst = relation.value.split("_") sha1s = set(ids) if not reverse: + empty: Union[Dict[str, bytes], List[str]] = {} if src != "revision" else [] src_objs = { x["sha1"]: x[dst] for x in self.db.get_collection(src).find( - {"sha1": {"$in": list(sha1s)}}, {"_id": 0, "sha1": 1, dst: 1} + {"sha1": {"$in": list(sha1s)}, dst: {"$ne": empty}}, + {"_id": 0, "sha1": 1, dst: 1}, ) } dst_ids = list( {ObjectId(obj_id) for _, value in src_objs.items() for obj_id in value} ) dst_objs = { x["sha1"]: x["_id"] for x in self.db.get_collection(dst).find( {"_id": {"$in": dst_ids}}, {"_id": 1, "sha1": 1} ) } if src != "revision": return { - RelationData(src=src_sha1, dst=dst_sha1, path=path) + src_sha1: { + RelationData(dst=dst_sha1, path=path) + for dst_sha1, dst_obj_id in dst_objs.items() + for dst_obj_str, paths in denorm.items() + for path in paths + if dst_obj_id == ObjectId(dst_obj_str) + } for src_sha1, denorm in src_objs.items() - for dst_sha1, dst_obj_id in dst_objs.items() - for dst_obj_str, paths in denorm.items() - for path in paths - if dst_obj_id == ObjectId(dst_obj_str) } else: return { - 
RelationData(src=src_sha1, dst=dst_sha1, path=None) + src_sha1: { + RelationData(dst=dst_sha1, path=None) + for dst_sha1, dst_obj_id in dst_objs.items() + for dst_obj_ref in denorm + if dst_obj_id == dst_obj_ref + } for src_sha1, denorm in src_objs.items() - for dst_sha1, dst_obj_id in dst_objs.items() - for dst_obj_ref in denorm - if dst_obj_id == dst_obj_ref } else: dst_objs = { x["sha1"]: x["_id"] for x in self.db.get_collection(dst).find( {"sha1": {"$in": list(sha1s)}}, {"_id": 1, "sha1": 1} ) } src_objs = { x["sha1"]: x[dst] for x in self.db.get_collection(src).find( {}, {"_id": 0, "sha1": 1, dst: 1} ) } + result: Dict[Sha1Git, Set[RelationData]] = {} if src != "revision": - return { - RelationData(src=src_sha1, dst=dst_sha1, path=path) - for src_sha1, denorm in src_objs.items() - for dst_sha1, dst_obj_id in dst_objs.items() - for dst_obj_str, paths in denorm.items() - for path in paths - if dst_obj_id == ObjectId(dst_obj_str) - } + for dst_sha1, dst_obj_id in dst_objs.items(): + for src_sha1, denorm in src_objs.items(): + for dst_obj_str, paths in denorm.items(): + if dst_obj_id == ObjectId(dst_obj_str): + result.setdefault(src_sha1, set()).update( + RelationData(dst=dst_sha1, path=path) + for path in paths + ) else: - return { - RelationData(src=src_sha1, dst=dst_sha1, path=None) - for src_sha1, denorm in src_objs.items() - for dst_sha1, dst_obj_id in dst_objs.items() - for dst_obj_ref in denorm - if dst_obj_id == dst_obj_ref - } + for dst_sha1, dst_obj_id in dst_objs.items(): + for src_sha1, denorm in src_objs.items(): + if dst_obj_id in { + ObjectId(dst_obj_str) for dst_obj_str in denorm + }: + result.setdefault(src_sha1, set()).add( + RelationData(dst=dst_sha1, path=None) + ) + return result - def relation_get_all(self, relation: RelationType) -> Set[RelationData]: + def relation_get_all( + self, relation: RelationType + ) -> Dict[Sha1Git, Set[RelationData]]: src, *_, dst = relation.value.split("_") + empty: Union[Dict[str, bytes], List[str]] = {} if src != "revision" else [] src_objs = { x["sha1"]: x[dst] - for x in self.db.get_collection(src).find({}, {"_id": 0, "sha1": 1, dst: 1}) + for x in self.db.get_collection(src).find( + {dst: {"$ne": empty}}, {"_id": 0, "sha1": 1, dst: 1} + ) } dst_ids = list( {ObjectId(obj_id) for _, value in src_objs.items() for obj_id in value} ) + dst_objs = { + x["_id"]: x["sha1"] + for x in self.db.get_collection(dst).find( + {"_id": {"$in": dst_ids}}, {"_id": 1, "sha1": 1} + ) + } if src != "revision": - dst_objs = { - x["_id"]: x["sha1"] - for x in self.db.get_collection(dst).find( - {"_id": {"$in": dst_ids}}, {"_id": 1, "sha1": 1} - ) - } return { - RelationData(src=src_sha1, dst=dst_sha1, path=path) + src_sha1: { + RelationData(dst=dst_sha1, path=path) + for dst_obj_id, dst_sha1 in dst_objs.items() + for dst_obj_str, paths in denorm.items() + for path in paths + if dst_obj_id == ObjectId(dst_obj_str) + } for src_sha1, denorm in src_objs.items() - for dst_obj_id, dst_sha1 in dst_objs.items() - for dst_obj_str, paths in denorm.items() - for path in paths - if dst_obj_id == ObjectId(dst_obj_str) } else: - dst_objs = { - x["_id"]: x["sha1"] - for x in self.db.get_collection(dst).find( - {"_id": {"$in": dst_ids}}, {"_id": 1, "sha1": 1} - ) - } return { - RelationData(src=src_sha1, dst=dst_sha1, path=None) + src_sha1: { + RelationData(dst=dst_sha1, path=None) + for dst_obj_id, dst_sha1 in dst_objs.items() + for dst_obj_ref in denorm + if dst_obj_id == dst_obj_ref + } for src_sha1, denorm in src_objs.items() - for dst_obj_id, dst_sha1 in 
dst_objs.items() - for dst_obj_ref in denorm - if dst_obj_id == dst_obj_ref } def with_path(self) -> bool: return True diff --git a/swh/provenance/postgresql/provenance.py b/swh/provenance/postgresql/provenance.py index 985a495..b4c4ded 100644 --- a/swh/provenance/postgresql/provenance.py +++ b/swh/provenance/postgresql/provenance.py @@ -1,318 +1,324 @@ # Copyright (C) 2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from contextlib import contextmanager from datetime import datetime import itertools import logging from typing import Dict, Generator, Iterable, List, Optional, Set, Union import psycopg2.extensions import psycopg2.extras from typing_extensions import Literal from swh.core.db import BaseDb from swh.model.model import Sha1Git from ..interface import ( EntityType, ProvenanceResult, RelationData, RelationType, RevisionData, ) LOGGER = logging.getLogger(__name__) class ProvenanceStoragePostgreSql: def __init__( self, conn: psycopg2.extensions.connection, raise_on_commit: bool = False ) -> None: BaseDb.adapt_conn(conn) self.conn = conn with self.transaction() as cursor: cursor.execute("SET timezone TO 'UTC'") self._flavor: Optional[str] = None self.raise_on_commit = raise_on_commit @contextmanager def transaction( self, readonly: bool = False ) -> Generator[psycopg2.extensions.cursor, None, None]: self.conn.set_session(readonly=readonly) with self.conn: with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur: yield cur @property def flavor(self) -> str: if self._flavor is None: with self.transaction(readonly=True) as cursor: cursor.execute("SELECT swh_get_dbflavor() AS flavor") self._flavor = cursor.fetchone()["flavor"] assert self._flavor is not None return self._flavor @property def denormalized(self) -> bool: return "denormalized" in self.flavor def content_add( self, cnts: Union[Iterable[Sha1Git], Dict[Sha1Git, datetime]] ) -> bool: return self._entity_set_date("content", cnts) def content_find_first(self, id: Sha1Git) -> Optional[ProvenanceResult]: sql = "SELECT * FROM swh_provenance_content_find_first(%s)" with self.transaction(readonly=True) as cursor: cursor.execute(query=sql, vars=(id,)) row = cursor.fetchone() return ProvenanceResult(**row) if row is not None else None def content_find_all( self, id: Sha1Git, limit: Optional[int] = None ) -> Generator[ProvenanceResult, None, None]: sql = "SELECT * FROM swh_provenance_content_find_all(%s, %s)" with self.transaction(readonly=True) as cursor: cursor.execute(query=sql, vars=(id, limit)) yield from (ProvenanceResult(**row) for row in cursor) def content_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, datetime]: return self._entity_get_date("content", ids) def directory_add( self, dirs: Union[Iterable[Sha1Git], Dict[Sha1Git, datetime]] ) -> bool: return self._entity_set_date("directory", dirs) def directory_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, datetime]: return self._entity_get_date("directory", ids) def entity_get_all(self, entity: EntityType) -> Set[Sha1Git]: with self.transaction(readonly=True) as cursor: cursor.execute(f"SELECT sha1 FROM {entity.value}") return {row["sha1"] for row in cursor} def location_add(self, paths: Iterable[bytes]) -> bool: if not self.with_path(): return True try: values = [(path,) for path in paths] if values: sql = """ INSERT INTO location(path) VALUES %s ON CONFLICT DO NOTHING """ 
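# execute_values below expands the single VALUES %s placeholder with every row of argslist in one round trip.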
with self.transaction() as cursor: psycopg2.extras.execute_values(cursor, sql, argslist=values) return True except: # noqa: E722 # Unexpected error occurred, rollback all changes and log message LOGGER.exception("Unexpected error") if self.raise_on_commit: raise return False def location_get_all(self) -> Set[bytes]: with self.transaction(readonly=True) as cursor: cursor.execute("SELECT location.path AS path FROM location") return {row["path"] for row in cursor} def origin_add(self, orgs: Dict[Sha1Git, str]) -> bool: try: if orgs: sql = """ INSERT INTO origin(sha1, url) VALUES %s ON CONFLICT DO NOTHING """ with self.transaction() as cursor: psycopg2.extras.execute_values( cur=cursor, sql=sql, argslist=orgs.items() ) return True except: # noqa: E722 # Unexpected error occurred, rollback all changes and log message LOGGER.exception("Unexpected error") if self.raise_on_commit: raise return False def origin_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, str]: urls: Dict[Sha1Git, str] = {} sha1s = tuple(ids) if sha1s: # TODO: consider splitting this query in several ones if sha1s is too big! values = ", ".join(itertools.repeat("%s", len(sha1s))) sql = f""" SELECT sha1, url FROM origin WHERE sha1 IN ({values}) """ with self.transaction(readonly=True) as cursor: cursor.execute(query=sql, vars=sha1s) urls.update((row["sha1"], row["url"]) for row in cursor) return urls def revision_add( self, revs: Union[Iterable[Sha1Git], Dict[Sha1Git, RevisionData]] ) -> bool: if isinstance(revs, dict): data = [(sha1, rev.date, rev.origin) for sha1, rev in revs.items()] else: data = [(sha1, None, None) for sha1 in revs] try: if data: sql = """ INSERT INTO revision(sha1, date, origin) (SELECT V.rev AS sha1, V.date::timestamptz AS date, O.id AS origin FROM (VALUES %s) AS V(rev, date, org) LEFT JOIN origin AS O ON (O.sha1=V.org::sha1_git)) ON CONFLICT (sha1) DO UPDATE SET date=LEAST(EXCLUDED.date, revision.date), origin=COALESCE(EXCLUDED.origin, revision.origin) """ with self.transaction() as cursor: psycopg2.extras.execute_values(cur=cursor, sql=sql, argslist=data) return True except: # noqa: E722 # Unexpected error occurred, rollback all changes and log message LOGGER.exception("Unexpected error") if self.raise_on_commit: raise return False def revision_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, RevisionData]: result: Dict[Sha1Git, RevisionData] = {} sha1s = tuple(ids) if sha1s: # TODO: consider splitting this query in several ones if sha1s is too big! 
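# Build one "%s" placeholder per sha1 for the IN clause of the query below.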
values = ", ".join(itertools.repeat("%s", len(sha1s))) sql = f""" SELECT R.sha1, R.date, O.sha1 AS origin FROM revision AS R LEFT JOIN origin AS O ON (O.id=R.origin) WHERE R.sha1 IN ({values}) AND (R.date is not NULL OR O.sha1 is not NULL) """ with self.transaction(readonly=True) as cursor: cursor.execute(query=sql, vars=sha1s) result.update( (row["sha1"], RevisionData(date=row["date"], origin=row["origin"])) for row in cursor ) return result def relation_add( - self, relation: RelationType, data: Iterable[RelationData] + self, relation: RelationType, data: Dict[Sha1Git, Set[RelationData]] ) -> bool: - rows = [(rel.src, rel.dst, rel.path) for rel in data] + rows = [ + (src, rel.dst, rel.path) for src, dsts in data.items() for rel in dsts + ] try: if rows: rel_table = relation.value src_table, *_, dst_table = rel_table.split("_") # Put the next three queries in a manual single transaction: # they use the same temp table with self.transaction() as cursor: cursor.execute("SELECT swh_mktemp_relation_add()") psycopg2.extras.execute_values( cur=cursor, sql="INSERT INTO tmp_relation_add(src, dst, path) VALUES %s", argslist=rows, ) sql = "SELECT swh_provenance_relation_add_from_temp(%s, %s, %s)" cursor.execute(query=sql, vars=(rel_table, src_table, dst_table)) return True except: # noqa: E722 # Unexpected error occurred, rollback all changes and log message LOGGER.exception("Unexpected error") if self.raise_on_commit: raise return False def relation_get( self, relation: RelationType, ids: Iterable[Sha1Git], reverse: bool = False - ) -> Set[RelationData]: + ) -> Dict[Sha1Git, Set[RelationData]]: return self._relation_get(relation, ids, reverse) - def relation_get_all(self, relation: RelationType) -> Set[RelationData]: + def relation_get_all( + self, relation: RelationType + ) -> Dict[Sha1Git, Set[RelationData]]: return self._relation_get(relation, None) def _entity_get_date( self, entity: Literal["content", "directory", "revision"], ids: Iterable[Sha1Git], ) -> Dict[Sha1Git, datetime]: dates: Dict[Sha1Git, datetime] = {} sha1s = tuple(ids) if sha1s: # TODO: consider splitting this query in several ones if sha1s is too big! 
values = ", ".join(itertools.repeat("%s", len(sha1s))) sql = f""" SELECT sha1, date FROM {entity} WHERE sha1 IN ({values}) AND date IS NOT NULL """ with self.transaction(readonly=True) as cursor: cursor.execute(query=sql, vars=sha1s) dates.update((row["sha1"], row["date"]) for row in cursor) return dates def _entity_set_date( self, entity: Literal["content", "directory"], dates: Union[Iterable[Sha1Git], Dict[Sha1Git, datetime]], ) -> bool: data = dates if isinstance(dates, dict) else dict.fromkeys(dates) try: if data: sql = f""" INSERT INTO {entity}(sha1, date) VALUES %s ON CONFLICT (sha1) DO UPDATE SET date=LEAST(EXCLUDED.date,{entity}.date) """ with self.transaction() as cursor: psycopg2.extras.execute_values(cursor, sql, argslist=data.items()) return True except: # noqa: E722 # Unexpected error occurred, rollback all changes and log message LOGGER.exception("Unexpected error") if self.raise_on_commit: raise return False def _relation_get( self, relation: RelationType, ids: Optional[Iterable[Sha1Git]], reverse: bool = False, - ) -> Set[RelationData]: - result: Set[RelationData] = set() + ) -> Dict[Sha1Git, Set[RelationData]]: + result: Dict[Sha1Git, Set[RelationData]] = {} sha1s: List[Sha1Git] if ids is not None: sha1s = list(ids) filter = "filter-src" if not reverse else "filter-dst" else: sha1s = [] filter = "no-filter" if filter == "no-filter" or sha1s: rel_table = relation.value src_table, *_, dst_table = rel_table.split("_") sql = "SELECT * FROM swh_provenance_relation_get(%s, %s, %s, %s, %s)" with self.transaction(readonly=True) as cursor: cursor.execute( query=sql, vars=(rel_table, src_table, dst_table, filter, sha1s) ) - result.update(RelationData(**row) for row in cursor) + for row in cursor: + src = row.pop("src") + result.setdefault(src, set()).add(RelationData(**row)) return result def with_path(self) -> bool: return "with-path" in self.flavor diff --git a/swh/provenance/provenance.py b/swh/provenance/provenance.py index 79f4b71..006419a 100644 --- a/swh/provenance/provenance.py +++ b/swh/provenance/provenance.py @@ -1,404 +1,392 @@ # Copyright (C) 2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from datetime import datetime import logging import os from typing import Dict, Generator, Iterable, Optional, Set, Tuple from typing_extensions import Literal, TypedDict from swh.model.model import Sha1Git from .interface import ( ProvenanceResult, ProvenanceStorageInterface, RelationData, RelationType, RevisionData, ) from .model import DirectoryEntry, FileEntry, OriginEntry, RevisionEntry LOGGER = logging.getLogger(__name__) class DatetimeCache(TypedDict): data: Dict[Sha1Git, Optional[datetime]] added: Set[Sha1Git] class OriginCache(TypedDict): data: Dict[Sha1Git, str] added: Set[Sha1Git] class RevisionCache(TypedDict): data: Dict[Sha1Git, Sha1Git] added: Set[Sha1Git] class ProvenanceCache(TypedDict): content: DatetimeCache directory: DatetimeCache revision: DatetimeCache # below are insertion caches only content_in_revision: Set[Tuple[Sha1Git, Sha1Git, bytes]] content_in_directory: Set[Tuple[Sha1Git, Sha1Git, bytes]] directory_in_revision: Set[Tuple[Sha1Git, Sha1Git, bytes]] # these two are for the origin layer origin: OriginCache revision_origin: RevisionCache revision_before_revision: Dict[Sha1Git, Set[Sha1Git]] revision_in_origin: Set[Tuple[Sha1Git, Sha1Git]] def new_cache() -> ProvenanceCache: 
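"""Return a ProvenanceCache with every sub-cache initialized and empty."""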
return ProvenanceCache( content=DatetimeCache(data={}, added=set()), directory=DatetimeCache(data={}, added=set()), revision=DatetimeCache(data={}, added=set()), content_in_revision=set(), content_in_directory=set(), directory_in_revision=set(), origin=OriginCache(data={}, added=set()), revision_origin=RevisionCache(data={}, added=set()), revision_before_revision={}, revision_in_origin=set(), ) class Provenance: def __init__(self, storage: ProvenanceStorageInterface) -> None: self.storage = storage self.cache = new_cache() def clear_caches(self) -> None: self.cache = new_cache() def flush(self) -> None: # Revision-content layer insertions ############################################ # After relations, dates for the entities can be safely set, acknowledging that # these entities won't need to be reprocessed in case of failure. cnts = { src for src, _, _ in self.cache["content_in_revision"] | self.cache["content_in_directory"] } if cnts: while not self.storage.content_add(cnts): LOGGER.warning( "Unable to write content entities to the storage. Retrying..." ) dirs = {dst for _, dst, _ in self.cache["content_in_directory"]} if dirs: while not self.storage.directory_add(dirs): LOGGER.warning( "Unable to write directory entities to the storage. Retrying..." ) revs = { dst for _, dst, _ in self.cache["content_in_revision"] | self.cache["directory_in_revision"] } if revs: while not self.storage.revision_add(revs): LOGGER.warning( "Unable to write revision entities to the storage. Retrying..." ) paths = { path for _, _, path in self.cache["content_in_revision"] | self.cache["content_in_directory"] | self.cache["directory_in_revision"] } if paths: while not self.storage.location_add(paths): LOGGER.warning( "Unable to write locations entities to the storage. Retrying..." ) # For this layer, relations need to be inserted first so that, in case of # failure, reprocessing the input does not generated an inconsistent database. if self.cache["content_in_revision"]: + cnt_in_rev: Dict[Sha1Git, Set[RelationData]] = {} + for src, dst, path in self.cache["content_in_revision"]: + cnt_in_rev.setdefault(src, set()).add(RelationData(dst=dst, path=path)) while not self.storage.relation_add( - RelationType.CNT_EARLY_IN_REV, - ( - RelationData(src=src, dst=dst, path=path) - for src, dst, path in self.cache["content_in_revision"] - ), + RelationType.CNT_EARLY_IN_REV, cnt_in_rev ): LOGGER.warning( "Unable to write %s rows to the storage. Retrying...", RelationType.CNT_EARLY_IN_REV, ) if self.cache["content_in_directory"]: - while not self.storage.relation_add( - RelationType.CNT_IN_DIR, - ( - RelationData(src=src, dst=dst, path=path) - for src, dst, path in self.cache["content_in_directory"] - ), - ): + cnt_in_dir: Dict[Sha1Git, Set[RelationData]] = {} + for src, dst, path in self.cache["content_in_directory"]: + cnt_in_dir.setdefault(src, set()).add(RelationData(dst=dst, path=path)) + while not self.storage.relation_add(RelationType.CNT_IN_DIR, cnt_in_dir): LOGGER.warning( "Unable to write %s rows to the storage. 
Retrying...", RelationType.CNT_IN_DIR, ) if self.cache["directory_in_revision"]: - while not self.storage.relation_add( - RelationType.DIR_IN_REV, - ( - RelationData(src=src, dst=dst, path=path) - for src, dst, path in self.cache["directory_in_revision"] - ), - ): + dir_in_rev: Dict[Sha1Git, Set[RelationData]] = {} + for src, dst, path in self.cache["directory_in_revision"]: + dir_in_rev.setdefault(src, set()).add(RelationData(dst=dst, path=path)) + while not self.storage.relation_add(RelationType.DIR_IN_REV, dir_in_rev): LOGGER.warning( "Unable to write %s rows to the storage. Retrying...", RelationType.DIR_IN_REV, ) # After relations, dates for the entities can be safely set, acknowledging that # these entities won't need to be reprocessed in case of failure. cnt_dates = { sha1: date for sha1, date in self.cache["content"]["data"].items() if sha1 in self.cache["content"]["added"] and date is not None } if cnt_dates: while not self.storage.content_add(cnt_dates): LOGGER.warning( "Unable to write content dates to the storage. Retrying..." ) dir_dates = { sha1: date for sha1, date in self.cache["directory"]["data"].items() if sha1 in self.cache["directory"]["added"] and date is not None } if dir_dates: while not self.storage.directory_add(dir_dates): LOGGER.warning( "Unable to write directory dates to the storage. Retrying..." ) rev_dates = { sha1: RevisionData(date=date, origin=None) for sha1, date in self.cache["revision"]["data"].items() if sha1 in self.cache["revision"]["added"] and date is not None } if rev_dates: while not self.storage.revision_add(rev_dates): LOGGER.warning( "Unable to write revision dates to the storage. Retrying..." ) # Origin-revision layer insertions ############################################# # Origins and revisions should be inserted first so that internal ids' # resolution works properly. urls = { sha1: url for sha1, url in self.cache["origin"]["data"].items() if sha1 in self.cache["origin"]["added"] } if urls: while not self.storage.origin_add(urls): LOGGER.warning( "Unable to write origins urls to the storage. Retrying..." ) rev_orgs = { # Destinations in this relation should match origins in the next one **{ src: RevisionData(date=None, origin=None) for src in self.cache["revision_before_revision"] }, **{ # This relation comes second so that non-None origins take precedence src: RevisionData(date=None, origin=org) for src, org in self.cache["revision_in_origin"] }, } if rev_orgs: while not self.storage.revision_add(rev_orgs): LOGGER.warning( "Unable to write revision entities to the storage. Retrying..." ) # Second, flat models for revisions' histories (ie. revision-before-revision). - data: Iterable[RelationData] = sum( - [ - [ - RelationData(src=prev, dst=next, path=None) - for next in self.cache["revision_before_revision"][prev] - ] - for prev in self.cache["revision_before_revision"] - ], - [], - ) - if data: - while not self.storage.relation_add(RelationType.REV_BEFORE_REV, data): + if self.cache["revision_before_revision"]: + rev_before_rev = { + src: {RelationData(dst=dst, path=None) for dst in dsts} + for src, dsts in self.cache["revision_before_revision"].items() + } + while not self.storage.relation_add( + RelationType.REV_BEFORE_REV, rev_before_rev + ): LOGGER.warning( "Unable to write %s rows to the storage. Retrying...", RelationType.REV_BEFORE_REV, ) # Heads (ie. revision-in-origin entries) should be inserted once flat models for # their histories were already added. 
This is to guarantee consistent results if # something needs to be reprocessed due to a failure: already inserted heads # won't get reprocessed in such a case. - data = ( - RelationData(src=rev, dst=org, path=None) - for rev, org in self.cache["revision_in_origin"] - ) - if data: - while not self.storage.relation_add(RelationType.REV_IN_ORG, data): + if self.cache["revision_in_origin"]: + rev_in_org: Dict[Sha1Git, Set[RelationData]] = {} + for src, dst in self.cache["revision_in_origin"]: + rev_in_org.setdefault(src, set()).add(RelationData(dst=dst, path=None)) + while not self.storage.relation_add(RelationType.REV_IN_ORG, rev_in_org): LOGGER.warning( "Unable to write %s rows to the storage. Retrying...", RelationType.REV_IN_ORG, ) # clear local cache ############################################################ self.clear_caches() def content_add_to_directory( self, directory: DirectoryEntry, blob: FileEntry, prefix: bytes ) -> None: self.cache["content_in_directory"].add( (blob.id, directory.id, normalize(os.path.join(prefix, blob.name))) ) def content_add_to_revision( self, revision: RevisionEntry, blob: FileEntry, prefix: bytes ) -> None: self.cache["content_in_revision"].add( (blob.id, revision.id, normalize(os.path.join(prefix, blob.name))) ) def content_find_first(self, id: Sha1Git) -> Optional[ProvenanceResult]: return self.storage.content_find_first(id) def content_find_all( self, id: Sha1Git, limit: Optional[int] = None ) -> Generator[ProvenanceResult, None, None]: yield from self.storage.content_find_all(id, limit=limit) def content_get_early_date(self, blob: FileEntry) -> Optional[datetime]: return self.get_dates("content", [blob.id]).get(blob.id) def content_get_early_dates( self, blobs: Iterable[FileEntry] ) -> Dict[Sha1Git, datetime]: return self.get_dates("content", [blob.id for blob in blobs]) def content_set_early_date(self, blob: FileEntry, date: datetime) -> None: self.cache["content"]["data"][blob.id] = date self.cache["content"]["added"].add(blob.id) def directory_add_to_revision( self, revision: RevisionEntry, directory: DirectoryEntry, path: bytes ) -> None: self.cache["directory_in_revision"].add( (directory.id, revision.id, normalize(path)) ) def directory_get_date_in_isochrone_frontier( self, directory: DirectoryEntry ) -> Optional[datetime]: return self.get_dates("directory", [directory.id]).get(directory.id) def directory_get_dates_in_isochrone_frontier( self, dirs: Iterable[DirectoryEntry] ) -> Dict[Sha1Git, datetime]: return self.get_dates("directory", [directory.id for directory in dirs]) def directory_set_date_in_isochrone_frontier( self, directory: DirectoryEntry, date: datetime ) -> None: self.cache["directory"]["data"][directory.id] = date self.cache["directory"]["added"].add(directory.id) def get_dates( self, entity: Literal["content", "directory", "revision"], ids: Iterable[Sha1Git], ) -> Dict[Sha1Git, datetime]: cache = self.cache[entity] missing_ids = set(id for id in ids if id not in cache) if missing_ids: if entity == "revision": updated = { id: rev.date for id, rev in self.storage.revision_get(missing_ids).items() } else: updated = getattr(self.storage, f"{entity}_get")(missing_ids) cache["data"].update(updated) dates: Dict[Sha1Git, datetime] = {} for sha1 in ids: date = cache["data"].setdefault(sha1, None) if date is not None: dates[sha1] = date return dates def origin_add(self, origin: OriginEntry) -> None: self.cache["origin"]["data"][origin.id] = origin.url self.cache["origin"]["added"].add(origin.id) def revision_add(self, revision: 
RevisionEntry) -> None: self.cache["revision"]["data"][revision.id] = revision.date self.cache["revision"]["added"].add(revision.id) def revision_add_before_revision( self, head: RevisionEntry, revision: RevisionEntry ) -> None: self.cache["revision_before_revision"].setdefault(revision.id, set()).add( head.id ) def revision_add_to_origin( self, origin: OriginEntry, revision: RevisionEntry ) -> None: self.cache["revision_in_origin"].add((revision.id, origin.id)) def revision_get_date(self, revision: RevisionEntry) -> Optional[datetime]: return self.get_dates("revision", [revision.id]).get(revision.id) def revision_get_preferred_origin( self, revision: RevisionEntry ) -> Optional[Sha1Git]: cache = self.cache["revision_origin"]["data"] if revision.id not in cache: ret = self.storage.revision_get([revision.id]) if revision.id in ret: origin = ret[revision.id].origin if origin is not None: cache[revision.id] = origin return cache.get(revision.id) def revision_in_history(self, revision: RevisionEntry) -> bool: return revision.id in self.cache["revision_before_revision"] or bool( self.storage.relation_get(RelationType.REV_BEFORE_REV, [revision.id]) ) def revision_set_preferred_origin( self, origin: OriginEntry, revision: RevisionEntry ) -> None: self.cache["revision_origin"]["data"][revision.id] = origin.id self.cache["revision_origin"]["added"].add(revision.id) def revision_visited(self, revision: RevisionEntry) -> bool: return revision.id in dict(self.cache["revision_in_origin"]) or bool( self.storage.relation_get(RelationType.REV_IN_ORG, [revision.id]) ) def normalize(path: bytes) -> bytes: return path[2:] if path.startswith(bytes("." + os.path.sep, "utf-8")) else path diff --git a/swh/provenance/tests/test_origin_revision_layer.py b/swh/provenance/tests/test_origin_revision_layer.py index 78c7634..1bbdf41 100644 --- a/swh/provenance/tests/test_origin_revision_layer.py +++ b/swh/provenance/tests/test_origin_revision_layer.py @@ -1,190 +1,194 @@ # Copyright (C) 2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import re from typing import Any, Dict, Iterable, Iterator, List, Set import pytest from typing_extensions import TypedDict from swh.model.hashutil import hash_to_bytes from swh.model.model import Sha1Git from swh.provenance.archive import ArchiveInterface from swh.provenance.interface import EntityType, ProvenanceInterface, RelationType from swh.provenance.model import OriginEntry from swh.provenance.origin import origin_add from swh.provenance.tests.conftest import fill_storage, get_datafile, load_repo_data class SynthRelation(TypedDict): src: Sha1Git dst: Sha1Git name: str class SynthOrigin(TypedDict): sha1: Sha1Git url: str snap: Sha1Git O_R: List[SynthRelation] R_R: List[SynthRelation] def synthetic_origin_revision_result(filename: str) -> Iterator[SynthOrigin]: """Generates dict representations of synthetic origin visits found in the synthetic file (from the data/ directory) given as argument of the generator. 
Generated SynthOrigin (typed dict) with the following elements: "sha1": (Sha1Git) sha1 of the origin, "url": (str) url of the origin, "snap": (Sha1Git) sha1 of the visit's snapshot, "O_R": (list) new O-R relations added by this origin visit "R_R": (list) new R-R relations added by this origin visit Each relation above is a SynthRelation typed dict with: "src": (Sha1Git) sha1 of the source of the relation "dst": (Sha1Git) sha1 of the destination of the relation """ with open(get_datafile(filename), "r") as fobj: yield from _parse_synthetic_origin_revision_file(fobj) def _parse_synthetic_origin_revision_file(fobj: Iterable[str]) -> Iterator[SynthOrigin]: """Read a 'synthetic' file and generate a dict representation of the synthetic origin visit for each snapshot listed in the synthetic file. """ regs = [ "(?P<url>[^ ]+)?", "(?P<reltype>[^| ]*)", "(?P<revname>R[0-9]{2,4})?", "(?P<type>[ORS]) (?P<sha1>[0-9a-f]{40})", ] regex = re.compile("^ *" + r" *[|] *".join(regs) + r" *(#.*)?$") current_org: List[dict] = [] for m in (regex.match(line) for line in fobj): if m: d = m.groupdict() if d["url"]: if current_org: yield _mk_synth_org(current_org) current_org.clear() current_org.append(d) if current_org: yield _mk_synth_org(current_org) def _mk_synth_org(synth_org: List[Dict[str, str]]) -> SynthOrigin: assert synth_org[0]["type"] == "O" assert synth_org[1]["type"] == "S" org = SynthOrigin( sha1=hash_to_bytes(synth_org[0]["sha1"]), url=synth_org[0]["url"], snap=hash_to_bytes(synth_org[1]["sha1"]), O_R=[], R_R=[], ) for row in synth_org[2:]: if row["reltype"] == "O-R": assert row["type"] == "R" org["O_R"].append( SynthRelation( src=org["sha1"], dst=hash_to_bytes(row["sha1"]), name=row["revname"], ) ) elif row["reltype"] == "R-R": assert row["type"] == "R" org["R_R"].append( SynthRelation( src=org["O_R"][-1]["dst"], dst=hash_to_bytes(row["sha1"]), name=row["revname"], ) ) return org @pytest.mark.parametrize( "repo, visit", (("with-merges", "visits-01"),), ) def test_origin_revision_layer( provenance: ProvenanceInterface, archive: ArchiveInterface, repo: str, visit: str, ) -> None: # read data/README.md for more details on how these datasets are generated data = load_repo_data(repo) fill_storage(archive.storage, data) syntheticfile = get_datafile(f"origin-revision_{repo}_{visit}.txt") origins = [ {"url": status["origin"], "snap": status["snapshot"]} for status in data["origin_visit_status"] if status["snapshot"] is not None ] rows: Dict[str, Set[Any]] = { "origin": set(), "revision_in_origin": set(), "revision_before_revision": set(), "revision": set(), } for synth_org in synthetic_origin_revision_result(syntheticfile): for origin in ( org for org in origins if org["url"] == synth_org["url"] and org["snap"] == synth_org["snap"] ): entry = OriginEntry(url=origin["url"], snapshot=origin["snap"]) origin_add(provenance, archive, [entry]) # each "entry" in the synth file is one new origin visit rows["origin"].add(synth_org["sha1"]) assert rows["origin"] == provenance.storage.entity_get_all( EntityType.ORIGIN ), synth_org["url"] # check the url of the origin assert ( provenance.storage.origin_get([synth_org["sha1"]])[synth_org["sha1"]] == synth_org["url"] ), synth_org["snap"] # this origin visit might have added new revision objects rows["revision"] |= set(x["dst"] for x in synth_org["O_R"]) rows["revision"] |= set(x["dst"] for x in synth_org["R_R"]) assert rows["revision"] == provenance.storage.entity_get_all( EntityType.REVISION ), synth_org["snap"] # check for O-R (head) entries # these are added in the revision_in_origin relation 
rows["revision_in_origin"] |= set( (x["dst"], x["src"], None) for x in synth_org["O_R"] ) assert rows["revision_in_origin"] == { - (rel.src, rel.dst, rel.path) - for rel in provenance.storage.relation_get_all(RelationType.REV_IN_ORG) + (src, rel.dst, rel.path) + for src, rels in provenance.storage.relation_get_all( + RelationType.REV_IN_ORG + ).items() + for rel in rels }, synth_org["snap"] # check for R-R entries # these are added in the revision_before_revision relation rows["revision_before_revision"] |= set( (x["dst"], x["src"], None) for x in synth_org["R_R"] ) assert rows["revision_before_revision"] == { - (rel.src, rel.dst, rel.path) - for rel in provenance.storage.relation_get_all( + (src, rel.dst, rel.path) + for src, rels in provenance.storage.relation_get_all( RelationType.REV_BEFORE_REV - ) + ).items() + for rel in rels }, synth_org["snap"] diff --git a/swh/provenance/tests/test_provenance_storage.py b/swh/provenance/tests/test_provenance_storage.py index 841cea6..d959cac 100644 --- a/swh/provenance/tests/test_provenance_storage.py +++ b/swh/provenance/tests/test_provenance_storage.py @@ -1,460 +1,482 @@ # Copyright (C) 2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from datetime import datetime, timezone import inspect import os -from typing import Any, Dict, Iterable, Optional, Set +from typing import Any, Dict, Iterable, Optional, Set, Tuple from swh.model.hashutil import hash_to_bytes from swh.model.identifiers import origin_identifier from swh.model.model import Sha1Git from swh.provenance.archive import ArchiveInterface from swh.provenance.interface import ( EntityType, ProvenanceInterface, ProvenanceResult, ProvenanceStorageInterface, RelationData, RelationType, RevisionData, ) from swh.provenance.model import OriginEntry, RevisionEntry from swh.provenance.mongo.backend import ProvenanceStorageMongoDb from swh.provenance.origin import origin_add from swh.provenance.provenance import Provenance from swh.provenance.revision import revision_add from swh.provenance.tests.conftest import fill_storage, load_repo_data, ts2dt def test_provenance_storage_content( provenance_storage: ProvenanceStorageInterface, ) -> None: """Tests content methods for every `ProvenanceStorageInterface` implementation.""" # Read data/README.md for more details on how these datasets are generated. data = load_repo_data("cmdbts2") # Add all content present in the current repo to the storage, just assigning their # creation dates. Then check that the returned results when querying are the same. cnts = {cnt["sha1_git"] for idx, cnt in enumerate(data["content"]) if idx % 2 == 0} cnt_dates = { cnt["sha1_git"]: cnt["ctime"] for idx, cnt in enumerate(data["content"]) if idx % 2 == 1 } assert cnts or cnt_dates assert provenance_storage.content_add(cnts) assert provenance_storage.content_add(cnt_dates) assert provenance_storage.content_get(set(cnt_dates.keys())) == cnt_dates assert provenance_storage.entity_get_all(EntityType.CONTENT) == cnts | set( cnt_dates.keys() ) def test_provenance_storage_directory( provenance_storage: ProvenanceStorageInterface, ) -> None: """Tests directory methods for every `ProvenanceStorageInterface` implementation.""" # Read data/README.md for more details on how these datasets are generated. 
data = load_repo_data("cmdbts2") # Of all directories present in the current repo, only assign a date to those # containing blobs (picking the max date among the available ones). Then check that # the returned results when querying are the same. def getmaxdate( directory: Dict[str, Any], contents: Iterable[Dict[str, Any]] ) -> Optional[datetime]: dates = [ content["ctime"] for entry in directory["entries"] for content in contents if entry["type"] == "file" and entry["target"] == content["sha1_git"] ] return max(dates) if dates else None dirs = { dir["id"] for dir in data["directory"] if getmaxdate(dir, data["content"]) is None } dir_dates = { dir["id"]: getmaxdate(dir, data["content"]) for dir in data["directory"] if getmaxdate(dir, data["content"]) is not None } assert dirs assert provenance_storage.directory_add(dirs) assert provenance_storage.directory_add(dir_dates) assert provenance_storage.directory_get(set(dir_dates.keys())) == dir_dates assert provenance_storage.entity_get_all(EntityType.DIRECTORY) == dirs | set( dir_dates.keys() ) def test_provenance_storage_location( provenance_storage: ProvenanceStorageInterface, ) -> None: """Tests location methods for every `ProvenanceStorageInterface` implementation.""" # Read data/README.md for more details on how these datasets are generated. data = load_repo_data("cmdbts2") # Add all names of entries present in the directories of the current repo as paths # to the storage. Then check that the returned results when querying are the same. paths = {entry["name"] for dir in data["directory"] for entry in dir["entries"]} assert provenance_storage.location_add(paths) if isinstance(provenance_storage, ProvenanceStorageMongoDb): # TODO: remove this when `location_add` is properly implemented for MongoDb. return if provenance_storage.with_path(): assert provenance_storage.location_get_all() == paths else: assert provenance_storage.location_get_all() == set() def test_provenance_storage_origin( provenance_storage: ProvenanceStorageInterface, ) -> None: """Tests origin methods for every `ProvenanceStorageInterface` implementation.""" # Read data/README.md for more details on how these datasets are generated. data = load_repo_data("cmdbts2") # Test origin methods. # Add all origins present in the current repo to the storage. Then check that the # returned results when querying are the same. orgs = {hash_to_bytes(origin_identifier(org)): org["url"] for org in data["origin"]} assert orgs assert provenance_storage.origin_add(orgs) assert provenance_storage.origin_get(set(orgs.keys())) == orgs assert provenance_storage.entity_get_all(EntityType.ORIGIN) == set(orgs.keys()) def test_provenance_storage_revision( provenance_storage: ProvenanceStorageInterface, ) -> None: """Tests revision methods for every `ProvenanceStorageInterface` implementation.""" # Read data/README.md for more details on how these datasets are generated. data = load_repo_data("cmdbts2") # Test revision methods. # Add all revisions present in the current repo to the storage, assigning their # dates and an arbitrary origin to each one. Then check that the returned results # when querying are the same. origin = next(iter(data["origin"])) origin_sha1 = hash_to_bytes(origin_identifier(origin)) # Origin must be inserted in advance. 
 def test_provenance_storage_revision(
     provenance_storage: ProvenanceStorageInterface,
 ) -> None:
     """Tests revision methods for every `ProvenanceStorageInterface` implementation."""

     # Read data/README.md for more details on how these datasets are generated.
     data = load_repo_data("cmdbts2")

     # Test revision methods.
     # Add all revisions present in the current repo to the storage, assigning their
     # dates and an arbitrary origin to each one. Then check that the returned results
     # when querying are the same.
     origin = next(iter(data["origin"]))
     origin_sha1 = hash_to_bytes(origin_identifier(origin))
     # Origin must be inserted in advance.
     assert provenance_storage.origin_add({origin_sha1: origin["url"]})

     revs = {rev["id"] for idx, rev in enumerate(data["revision"]) if idx % 6 == 0}
     rev_data = {
         rev["id"]: RevisionData(
             date=ts2dt(rev["date"]) if idx % 2 != 0 else None,
             origin=origin_sha1 if idx % 3 != 0 else None,
         )
         for idx, rev in enumerate(data["revision"])
         if idx % 6 != 0
     }
     assert revs
     assert provenance_storage.revision_add(revs)
     assert provenance_storage.revision_add(rev_data)
     assert provenance_storage.revision_get(set(rev_data.keys())) == rev_data
     assert provenance_storage.entity_get_all(EntityType.REVISION) == revs | set(
         rev_data.keys()
     )


 def dircontent(
     data: Dict[str, Any],
     ref: Sha1Git,
     dir: Dict[str, Any],
     prefix: bytes = b"",
-) -> Iterable[RelationData]:
+) -> Iterable[Tuple[Sha1Git, RelationData]]:
     content = {
-        RelationData(entry["target"], ref, os.path.join(prefix, entry["name"]))
+        (
+            entry["target"],
+            RelationData(dst=ref, path=os.path.join(prefix, entry["name"])),
+        )
         for entry in dir["entries"]
         if entry["type"] == "file"
     }
     for entry in dir["entries"]:
         if entry["type"] == "dir":
             child = next(
                 subdir
                 for subdir in data["directory"]
                 if subdir["id"] == entry["target"]
             )
             content.update(
                 dircontent(data, ref, child, os.path.join(prefix, entry["name"]))
             )
     return content


 def entity_add(
     storage: ProvenanceStorageInterface, entity: EntityType, ids: Set[Sha1Git]
 ) -> bool:
     if entity == EntityType.CONTENT:
         return storage.content_add({sha1: None for sha1 in ids})
     elif entity == EntityType.DIRECTORY:
         return storage.directory_add({sha1: None for sha1 in ids})
     else:  # entity == EntityType.REVISION:
         return storage.revision_add(
             {sha1: RevisionData(date=None, origin=None) for sha1 in ids}
         )
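The helpers below work with the new `Dict[Sha1Git, Set[RelationData]]` shape, where the source sha1 keys the mapping and each `RelationData` carries only the destination and path. A minimal sketch of grouping `dircontent`'s `(src, RelationData)` pairs into that shape:

    from typing import Dict, Iterable, Set, Tuple

    from swh.model.model import Sha1Git
    from swh.provenance.interface import RelationData


    def group_by_source(
        pairs: Iterable[Tuple[Sha1Git, RelationData]]
    ) -> Dict[Sha1Git, Set[RelationData]]:
        # Accumulate RelationData entries under their source sha1: the shape
        # that relation_add and relation_get now exchange.
        grouped: Dict[Sha1Git, Set[RelationData]] = {}
        for src, rel in pairs:
            grouped.setdefault(src, set()).add(rel)
        return grouped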
src, *_, dst = relation.value.split("_") + srcs = {sha1 for sha1 in data} if src != "origin": - assert entity_add(storage, EntityType(src), {entry.src for entry in data}) + assert entity_add(storage, EntityType(src), srcs) + dsts = {rel.dst for rels in data.values() for rel in rels} if dst != "origin": - assert entity_add(storage, EntityType(dst), {entry.dst for entry in data}) + assert entity_add(storage, EntityType(dst), dsts) if storage.with_path(): assert storage.location_add( - {entry.path for entry in data if entry.path is not None} + {rel.path for rels in data.values() for rel in rels if rel.path is not None} ) assert data assert storage.relation_add(relation, data) - for row in data: - assert relation_compare_result( - storage.relation_get(relation, [row.src]), - {entry for entry in data if entry.src == row.src}, + for src_sha1 in srcs: + relation_compare_result( + storage.relation_get(relation, [src_sha1]), + {src_sha1: data[src_sha1]}, storage.with_path(), ) - assert relation_compare_result( - storage.relation_get( - relation, - [row.dst], - reverse=True, - ), - {entry for entry in data if entry.dst == row.dst}, + for dst_sha1 in dsts: + relation_compare_result( + storage.relation_get(relation, [dst_sha1], reverse=True), + { + src_sha1: { + RelationData(dst=dst_sha1, path=rel.path) + for rel in rels + if dst_sha1 == rel.dst + } + for src_sha1, rels in data.items() + if dst_sha1 in {rel.dst for rel in rels} + }, storage.with_path(), ) - - assert relation_compare_result( + relation_compare_result( storage.relation_get_all(relation), data, storage.with_path() ) def relation_compare_result( - computed: Set[RelationData], expected: Set[RelationData], with_path: bool -) -> bool: - return { - RelationData(row.src, row.dst, row.path if with_path else None) - for row in expected + computed: Dict[Sha1Git, Set[RelationData]], + expected: Dict[Sha1Git, Set[RelationData]], + with_path: bool, +) -> None: + assert { + src_sha1: { + RelationData(dst=rel.dst, path=rel.path if with_path else None) + for rel in rels + } + for src_sha1, rels in expected.items() } == computed def test_provenance_storage_relation( provenance_storage: ProvenanceStorageInterface, ) -> None: """Tests relation methods for every `ProvenanceStorageInterface` implementation.""" # Read data/README.md for more details on how these datasets are generated. data = load_repo_data("cmdbts2") # Test content-in-revision relation. # Create flat models of every root directory for the revisions in the dataset. - cnt_in_rev: Set[RelationData] = set() + cnt_in_rev: Dict[Sha1Git, Set[RelationData]] = {} for rev in data["revision"]: root = next( subdir for subdir in data["directory"] if subdir["id"] == rev["directory"] ) - cnt_in_rev.update(dircontent(data, rev["id"], root)) + for cnt, rel in dircontent(data, rev["id"], root): + cnt_in_rev.setdefault(cnt, set()).add(rel) relation_add_and_compare_result( provenance_storage, RelationType.CNT_EARLY_IN_REV, cnt_in_rev ) # Test content-in-directory relation. # Create flat models for every directory in the dataset. - cnt_in_dir: Set[RelationData] = set() + cnt_in_dir: Dict[Sha1Git, Set[RelationData]] = {} for dir in data["directory"]: - cnt_in_dir.update(dircontent(data, dir["id"], dir)) + for cnt, rel in dircontent(data, dir["id"], dir): + cnt_in_dir.setdefault(cnt, set()).add(rel) relation_add_and_compare_result( provenance_storage, RelationType.CNT_IN_DIR, cnt_in_dir ) # Test content-in-directory relation. # Add root directories to their correspondent revision in the dataset. 
 def test_provenance_storage_relation(
     provenance_storage: ProvenanceStorageInterface,
 ) -> None:
     """Tests relation methods for every `ProvenanceStorageInterface` implementation."""

     # Read data/README.md for more details on how these datasets are generated.
     data = load_repo_data("cmdbts2")

     # Test content-in-revision relation.
     # Create flat models of every root directory for the revisions in the dataset.
-    cnt_in_rev: Set[RelationData] = set()
+    cnt_in_rev: Dict[Sha1Git, Set[RelationData]] = {}
     for rev in data["revision"]:
         root = next(
             subdir for subdir in data["directory"] if subdir["id"] == rev["directory"]
         )
-        cnt_in_rev.update(dircontent(data, rev["id"], root))
+        for cnt, rel in dircontent(data, rev["id"], root):
+            cnt_in_rev.setdefault(cnt, set()).add(rel)

     relation_add_and_compare_result(
         provenance_storage, RelationType.CNT_EARLY_IN_REV, cnt_in_rev
     )

     # Test content-in-directory relation.
     # Create flat models for every directory in the dataset.
-    cnt_in_dir: Set[RelationData] = set()
+    cnt_in_dir: Dict[Sha1Git, Set[RelationData]] = {}
     for dir in data["directory"]:
-        cnt_in_dir.update(dircontent(data, dir["id"], dir))
+        for cnt, rel in dircontent(data, dir["id"], dir):
+            cnt_in_dir.setdefault(cnt, set()).add(rel)

     relation_add_and_compare_result(
         provenance_storage, RelationType.CNT_IN_DIR, cnt_in_dir
     )

     # Test directory-in-revision relation.
     # Add root directories to their corresponding revision in the dataset.
-    dir_in_rev = {
-        RelationData(rev["directory"], rev["id"], b".") for rev in data["revision"]
-    }
+    dir_in_rev: Dict[Sha1Git, Set[RelationData]] = {}
+    for rev in data["revision"]:
+        dir_in_rev.setdefault(rev["directory"], set()).add(
+            RelationData(dst=rev["id"], path=b".")
+        )

     relation_add_and_compare_result(
         provenance_storage, RelationType.DIR_IN_REV, dir_in_rev
     )

     # Test revision-in-origin relation.
-    # Add all revisions that are head of some snapshot branch to the corresponding
-    # origin.
-    rev_in_org = {
-        RelationData(
-            branch["target"],
-            hash_to_bytes(origin_identifier({"url": status["origin"]})),
-            None,
-        )
-        for status in data["origin_visit_status"]
-        if status["snapshot"] is not None
-        for snapshot in data["snapshot"]
-        if snapshot["id"] == status["snapshot"]
-        for _, branch in snapshot["branches"].items()
-        if branch["target_type"] == "revision"
-    }
     # Origins must be inserted in advance (cannot be done by `entity_add` inside
     # `relation_add_and_compare_result`).
     orgs = {
         hash_to_bytes(origin_identifier(origin)): origin["url"]
         for origin in data["origin"]
     }
     assert provenance_storage.origin_add(orgs)
-
+    # Add all revisions that are head of some snapshot branch to the corresponding
+    # origin.
+    rev_in_org: Dict[Sha1Git, Set[RelationData]] = {}
+    for status in data["origin_visit_status"]:
+        if status["snapshot"] is not None:
+            for snapshot in data["snapshot"]:
+                if snapshot["id"] == status["snapshot"]:
+                    for branch in snapshot["branches"].values():
+                        if branch["target_type"] == "revision":
+                            rev_in_org.setdefault(branch["target"], set()).add(
+                                RelationData(
+                                    dst=hash_to_bytes(
+                                        origin_identifier({"url": status["origin"]})
+                                    ),
+                                    path=None,
+                                )
+                            )
     relation_add_and_compare_result(
         provenance_storage, RelationType.REV_IN_ORG, rev_in_org
     )

     # Test revision-before-revision relation.
     # For each revision in the data set add an entry for each parent to the relation.
-    rev_before_rev = {
-        RelationData(parent, rev["id"], None)
-        for rev in data["revision"]
-        for parent in rev["parents"]
-    }
+    rev_before_rev: Dict[Sha1Git, Set[RelationData]] = {}
+    for rev in data["revision"]:
+        for parent in rev["parents"]:
+            rev_before_rev.setdefault(parent, set()).add(
+                RelationData(dst=rev["id"], path=None)
+            )
     relation_add_and_compare_result(
         provenance_storage, RelationType.REV_BEFORE_REV, rev_before_rev
     )


 def test_provenance_storage_find(
     archive: ArchiveInterface,
     provenance: ProvenanceInterface,
     provenance_storage: ProvenanceStorageInterface,
 ) -> None:
     """Tests `content_find_first` and `content_find_all` methods for every
     `ProvenanceStorageInterface` implementation.
     """

     # Read data/README.md for more details on how these datasets are generated.
     data = load_repo_data("cmdbts2")
     fill_storage(archive.storage, data)

     # Test content_find_first and content_find_all, first only executing the
     # revision-content algorithm, then adding the origin-revision layer.
     def adapt_result(
         result: Optional[ProvenanceResult], with_path: bool
     ) -> Optional[ProvenanceResult]:
         if result is not None:
             return ProvenanceResult(
                 result.content,
                 result.revision,
                 result.date,
                 result.origin,
                 result.path if with_path else b"",
             )
         return result
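    # Illustrative sketch of what adapt_result does for a backend without path
    # support: the path is blanked so results from both backends compare equal
    # (the values below are hypothetical):
    #
    #     res = ProvenanceResult(
    #         content=hash_to_bytes("00" * 20),
    #         revision=hash_to_bytes("11" * 20),
    #         date=datetime.fromtimestamp(1000000000.0, timezone.utc),
    #         origin=None,
    #         path=b"A/B/C/a",
    #     )
    #     assert adapt_result(res, with_path=False).path == b""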
     # Execute the revision-content algorithm on both storages.
     revisions = [
         RevisionEntry(id=rev["id"], date=ts2dt(rev["date"]), root=rev["directory"])
         for rev in data["revision"]
     ]
     revision_add(provenance, archive, revisions)
     revision_add(Provenance(provenance_storage), archive, revisions)

     assert adapt_result(
         ProvenanceResult(
             content=hash_to_bytes("20329687bb9c1231a7e05afe86160343ad49b494"),
             revision=hash_to_bytes("c0d8929936631ecbcf9147be6b8aa13b13b014e4"),
             date=datetime.fromtimestamp(1000000000.0, timezone.utc),
             origin=None,
             path=b"A/B/C/a",
         ),
         provenance_storage.with_path(),
     ) == provenance_storage.content_find_first(
         hash_to_bytes("20329687bb9c1231a7e05afe86160343ad49b494")
     )

     for cnt in {cnt["sha1_git"] for cnt in data["content"]}:
         assert adapt_result(
             provenance.storage.content_find_first(cnt), provenance_storage.with_path()
         ) == provenance_storage.content_find_first(cnt)
         assert {
             adapt_result(occur, provenance_storage.with_path())
             for occur in provenance.storage.content_find_all(cnt)
         } == set(provenance_storage.content_find_all(cnt))

     # Execute the origin-revision algorithm on both storages.
     origins = [
         OriginEntry(url=sta["origin"], snapshot=sta["snapshot"])
         for sta in data["origin_visit_status"]
         if sta["snapshot"] is not None
     ]
     origin_add(provenance, archive, origins)
     origin_add(Provenance(provenance_storage), archive, origins)

     assert adapt_result(
         ProvenanceResult(
             content=hash_to_bytes("20329687bb9c1231a7e05afe86160343ad49b494"),
             revision=hash_to_bytes("c0d8929936631ecbcf9147be6b8aa13b13b014e4"),
             date=datetime.fromtimestamp(1000000000.0, timezone.utc),
             origin="https://cmdbts2",
             path=b"A/B/C/a",
         ),
         provenance_storage.with_path(),
     ) == provenance_storage.content_find_first(
         hash_to_bytes("20329687bb9c1231a7e05afe86160343ad49b494")
     )

     for cnt in {cnt["sha1_git"] for cnt in data["content"]}:
         assert adapt_result(
             provenance.storage.content_find_first(cnt), provenance_storage.with_path()
         ) == provenance_storage.content_find_first(cnt)
         assert {
             adapt_result(occur, provenance_storage.with_path())
             for occur in provenance.storage.content_find_all(cnt)
         } == set(provenance_storage.content_find_all(cnt))
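test_types below checks protocol conformance by comparing method signatures; the same technique, reduced to a generic sketch (it assumes `backend` defines every public name of the protocol):

    import inspect


    def signatures_match(protocol_cls: type, backend: object) -> bool:
        # Protocols cannot be instantiated directly, so instantiate a trivial
        # subclass and compare each public method's signature with the backend's.
        probe = type("_", (protocol_cls,), {})()
        for name in dir(probe):
            if name.startswith("_"):
                continue
            meth = getattr(probe, name)
            if not callable(meth):
                continue
            if inspect.signature(meth) != inspect.signature(getattr(backend, name)):
                return False
        return True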
 def test_types(provenance_storage: ProvenanceStorageInterface) -> None:
     """Checks all methods of ProvenanceStorageInterface are implemented by
     this backend, and that they have the same signature."""
     # Create an instance of the protocol (which cannot be instantiated
     # directly, so this creates a subclass, then instantiates it)
     interface = type("_", (ProvenanceStorageInterface,), {})()

     assert "content_find_first" in dir(interface)

     missing_methods = []

     for meth_name in dir(interface):
         if meth_name.startswith("_"):
             continue
         interface_meth = getattr(interface, meth_name)
         try:
             concrete_meth = getattr(provenance_storage, meth_name)
         except AttributeError:
             if not getattr(interface_meth, "deprecated_endpoint", False):
                 # The backend is missing a (non-deprecated) endpoint
                 missing_methods.append(meth_name)
             continue

         expected_signature = inspect.signature(interface_meth)
         actual_signature = inspect.signature(concrete_meth)

         assert expected_signature == actual_signature, meth_name

     assert missing_methods == []

     # If all the assertions above succeed, then this one should too.
     # But there's no harm in double-checking.
     # And we could replace the assertions above by this one, but unlike
     # the assertions above, it doesn't explain what is missing.
     assert isinstance(provenance_storage, ProvenanceStorageInterface)
diff --git a/swh/provenance/tests/test_revision_content_layer.py b/swh/provenance/tests/test_revision_content_layer.py
index 36c8aa4..574c9fd 100644
--- a/swh/provenance/tests/test_revision_content_layer.py
+++ b/swh/provenance/tests/test_revision_content_layer.py
@@ -1,447 +1,454 @@
 # Copyright (C) 2021 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information

 import re
 from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple

 import pytest
 from typing_extensions import TypedDict

 from swh.model.hashutil import hash_to_bytes
 from swh.model.model import Sha1Git
 from swh.provenance.archive import ArchiveInterface
 from swh.provenance.interface import EntityType, ProvenanceInterface, RelationType
 from swh.provenance.model import RevisionEntry
 from swh.provenance.revision import revision_add
 from swh.provenance.tests.conftest import (
     fill_storage,
     get_datafile,
     load_repo_data,
     ts2dt,
 )


 class SynthRelation(TypedDict):
     prefix: Optional[str]
     path: str
     src: Sha1Git
     dst: Sha1Git
     rel_ts: float


 class SynthRevision(TypedDict):
     sha1: Sha1Git
     date: float
     msg: str
     R_C: List[SynthRelation]
     R_D: List[SynthRelation]
     D_C: List[SynthRelation]


 def synthetic_revision_content_result(filename: str) -> Iterator[SynthRevision]:
     """Generates dict representations of synthetic revisions found in the synthetic
     file (from the data/ directory) given as argument of the generator.

     Generated SynthRevision (typed dict) with the following elements:

       "sha1": (Sha1Git) sha1 of the revision,
       "date": (float) timestamp of the revision,
       "msg": (str) commit message of the revision,
       "R_C": (list) new R---C relations added by this revision
       "R_D": (list) new R-D relations added by this revision
       "D_C": (list) new D-C relations added by this revision

     Each relation above is a SynthRelation typed dict with:

       "path": (str) location
       "src": (Sha1Git) sha1 of the source of the relation
       "dst": (Sha1Git) sha1 of the destination of the relation
       "rel_ts": (float) timestamp of the target of the relation
                 (related to the timestamp of the revision)

     """
     with open(get_datafile(filename), "r") as fobj:
         yield from _parse_synthetic_revision_content_file(fobj)
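A hypothetical row in the format the parser below expects, and the named groups its regex extracts (the sha1 is made up):

    import re

    # Same pattern as _parse_synthetic_revision_content_file builds below.
    regs = [
        "(?P<revname>R[0-9]{2,4})?",
        "(?P<reltype>[^| ]*)",
        "([+] )?(?P<path>[^| +]*?)[/]?",
        "(?P<type>[RDC]) (?P<sha1>[0-9a-f]{40})",
        "(?P<ts>-?[0-9]+(.[0-9]+)?)",
    ]
    regex = re.compile("^ *" + r" *[|] *".join(regs) + r" *(#.*)?$")

    line = " | R---C | + a/b/c | C " + "aa" * 20 + " | 1.0"  # hypothetical row
    m = regex.match(line)
    assert m is not None
    assert m.group("reltype") == "R---C"
    assert m.group("path") == "a/b/c"
    assert m.group("type") == "C" and m.group("ts") == "1.0"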
""" regs = [ "(?PR[0-9]{2,4})?", "(?P[^| ]*)", "([+] )?(?P[^| +]*?)[/]?", "(?P[RDC]) (?P[0-9a-f]{40})", "(?P-?[0-9]+(.[0-9]+)?)", ] regex = re.compile("^ *" + r" *[|] *".join(regs) + r" *(#.*)?$") current_rev: List[dict] = [] for m in (regex.match(line) for line in fobj): if m: d = m.groupdict() if d["revname"]: if current_rev: yield _mk_synth_rev(current_rev) current_rev.clear() current_rev.append(d) if current_rev: yield _mk_synth_rev(current_rev) def _mk_synth_rev(synth_rev: List[Dict[str, str]]) -> SynthRevision: assert synth_rev[0]["type"] == "R" rev = SynthRevision( sha1=hash_to_bytes(synth_rev[0]["sha1"]), date=float(synth_rev[0]["ts"]), msg=synth_rev[0]["revname"], R_C=[], R_D=[], D_C=[], ) current_path = None # path of the last R-D relation we parsed, used a prefix for next D-C # relations for row in synth_rev[1:]: if row["reltype"] == "R---C": assert row["type"] == "C" rev["R_C"].append( SynthRelation( prefix=None, path=row["path"], src=rev["sha1"], dst=hash_to_bytes(row["sha1"]), rel_ts=float(row["ts"]), ) ) current_path = None elif row["reltype"] == "R-D": assert row["type"] == "D" rev["R_D"].append( SynthRelation( prefix=None, path=row["path"], src=rev["sha1"], dst=hash_to_bytes(row["sha1"]), rel_ts=float(row["ts"]), ) ) current_path = row["path"] elif row["reltype"] == "D-C": assert row["type"] == "C" rev["D_C"].append( SynthRelation( prefix=current_path, path=row["path"], src=rev["R_D"][-1]["dst"], dst=hash_to_bytes(row["sha1"]), rel_ts=float(row["ts"]), ) ) return rev @pytest.mark.parametrize( "repo, lower, mindepth", ( ("cmdbts2", True, 1), ("cmdbts2", False, 1), ("cmdbts2", True, 2), ("cmdbts2", False, 2), ("out-of-order", True, 1), ), ) def test_revision_content_result( provenance: ProvenanceInterface, archive: ArchiveInterface, repo: str, lower: bool, mindepth: int, ) -> None: # read data/README.md for more details on how these datasets are generated data = load_repo_data(repo) fill_storage(archive.storage, data) syntheticfile = get_datafile( f"synthetic_{repo}_{'lower' if lower else 'upper'}_{mindepth}.txt" ) revisions = {rev["id"]: rev for rev in data["revision"]} rows: Dict[str, Set[Any]] = { "content": set(), "content_in_directory": set(), "content_in_revision": set(), "directory": set(), "directory_in_revision": set(), "location": set(), "revision": set(), } def maybe_path(path: str) -> Optional[bytes]: if provenance.storage.with_path(): return path.encode("utf-8") return None for synth_rev in synthetic_revision_content_result(syntheticfile): revision = revisions[synth_rev["sha1"]] entry = RevisionEntry( id=revision["id"], date=ts2dt(revision["date"]), root=revision["directory"], ) revision_add(provenance, archive, [entry], lower=lower, mindepth=mindepth) # each "entry" in the synth file is one new revision rows["revision"].add(synth_rev["sha1"]) assert rows["revision"] == provenance.storage.entity_get_all( EntityType.REVISION ), synth_rev["msg"] # check the timestamp of the revision rev_ts = synth_rev["date"] rev_data = provenance.storage.revision_get([synth_rev["sha1"]])[ synth_rev["sha1"] ] assert ( rev_data.date is not None and rev_ts == rev_data.date.timestamp() ), synth_rev["msg"] # this revision might have added new content objects rows["content"] |= set(x["dst"] for x in synth_rev["R_C"]) rows["content"] |= set(x["dst"] for x in synth_rev["D_C"]) assert rows["content"] == provenance.storage.entity_get_all( EntityType.CONTENT ), synth_rev["msg"] # check for R-C (direct) entries # these are added directly in the content_early_in_rev table 
rows["content_in_revision"] |= set( (x["dst"], x["src"], maybe_path(x["path"])) for x in synth_rev["R_C"] ) assert rows["content_in_revision"] == { - (rel.src, rel.dst, rel.path) - for rel in provenance.storage.relation_get_all( + (src, rel.dst, rel.path) + for src, rels in provenance.storage.relation_get_all( RelationType.CNT_EARLY_IN_REV - ) + ).items() + for rel in rels }, synth_rev["msg"] # check timestamps for rc in synth_rev["R_C"]: assert ( rev_ts + rc["rel_ts"] == provenance.storage.content_get([rc["dst"]])[rc["dst"]].timestamp() ), synth_rev["msg"] # check directories # each directory stored in the provenance index is an entry # in the "directory" table... rows["directory"] |= set(x["dst"] for x in synth_rev["R_D"]) assert rows["directory"] == provenance.storage.entity_get_all( EntityType.DIRECTORY ), synth_rev["msg"] # ... + a number of rows in the "directory_in_rev" table... # check for R-D entries rows["directory_in_revision"] |= set( (x["dst"], x["src"], maybe_path(x["path"])) for x in synth_rev["R_D"] ) assert rows["directory_in_revision"] == { - (rel.src, rel.dst, rel.path) - for rel in provenance.storage.relation_get_all(RelationType.DIR_IN_REV) + (src, rel.dst, rel.path) + for src, rels in provenance.storage.relation_get_all( + RelationType.DIR_IN_REV + ).items() + for rel in rels }, synth_rev["msg"] # check timestamps for rd in synth_rev["R_D"]: assert ( rev_ts + rd["rel_ts"] == provenance.storage.directory_get([rd["dst"]])[rd["dst"]].timestamp() ), synth_rev["msg"] # ... + a number of rows in the "content_in_dir" table # for content of the directory. # check for D-C entries rows["content_in_directory"] |= set( (x["dst"], x["src"], maybe_path(x["path"])) for x in synth_rev["D_C"] ) assert rows["content_in_directory"] == { - (rel.src, rel.dst, rel.path) - for rel in provenance.storage.relation_get_all(RelationType.CNT_IN_DIR) + (src, rel.dst, rel.path) + for src, rels in provenance.storage.relation_get_all( + RelationType.CNT_IN_DIR + ).items() + for rel in rels }, synth_rev["msg"] # check timestamps for dc in synth_rev["D_C"]: assert ( rev_ts + dc["rel_ts"] == provenance.storage.content_get([dc["dst"]])[dc["dst"]].timestamp() ), synth_rev["msg"] if provenance.storage.with_path(): # check for location entries rows["location"] |= set(x["path"].encode() for x in synth_rev["R_C"]) rows["location"] |= set(x["path"].encode() for x in synth_rev["D_C"]) rows["location"] |= set(x["path"].encode() for x in synth_rev["R_D"]) assert rows["location"] == provenance.storage.location_get_all(), synth_rev[ "msg" ] @pytest.mark.parametrize( "repo, lower, mindepth", ( ("cmdbts2", True, 1), ("cmdbts2", False, 1), ("cmdbts2", True, 2), ("cmdbts2", False, 2), ("out-of-order", True, 1), ), ) @pytest.mark.parametrize("batch", (True, False)) def test_provenance_heuristics_content_find_all( provenance: ProvenanceInterface, archive: ArchiveInterface, repo: str, lower: bool, mindepth: int, batch: bool, ) -> None: # read data/README.md for more details on how these datasets are generated data = load_repo_data(repo) fill_storage(archive.storage, data) revisions = [ RevisionEntry( id=revision["id"], date=ts2dt(revision["date"]), root=revision["directory"], ) for revision in data["revision"] ] def maybe_path(path: str) -> str: if provenance.storage.with_path(): return path return "" if batch: revision_add(provenance, archive, revisions, lower=lower, mindepth=mindepth) else: for revision in revisions: revision_add( provenance, archive, [revision], lower=lower, mindepth=mindepth ) syntheticfile = 
     syntheticfile = get_datafile(
         f"synthetic_{repo}_{'lower' if lower else 'upper'}_{mindepth}.txt"
     )
     expected_occurrences: Dict[str, List[Tuple[str, float, Optional[str], str]]] = {}
     for synth_rev in synthetic_revision_content_result(syntheticfile):
         rev_id = synth_rev["sha1"].hex()
         rev_ts = synth_rev["date"]

         for rc in synth_rev["R_C"]:
             expected_occurrences.setdefault(rc["dst"].hex(), []).append(
                 (rev_id, rev_ts, None, maybe_path(rc["path"]))
             )
         for dc in synth_rev["D_C"]:
             assert dc["prefix"] is not None  # to please mypy
             expected_occurrences.setdefault(dc["dst"].hex(), []).append(
                 (rev_id, rev_ts, None, maybe_path(dc["prefix"] + "/" + dc["path"]))
             )

     for content_id, results in expected_occurrences.items():
         expected = [(content_id, *result) for result in results]
         db_occurrences = [
             (
                 occur.content.hex(),
                 occur.revision.hex(),
                 occur.date.timestamp(),
                 occur.origin,
                 occur.path.decode(),
             )
             for occur in provenance.content_find_all(hash_to_bytes(content_id))
         ]
         if provenance.storage.with_path():
             # this is not true if the db stores no path, because a same content
             # that appears several times in a given revision may be reported
             # only once by content_find_all()
             assert len(db_occurrences) == len(expected)
         assert set(db_occurrences) == set(expected)


 @pytest.mark.parametrize(
     "repo, lower, mindepth",
     (
         ("cmdbts2", True, 1),
         ("cmdbts2", False, 1),
         ("cmdbts2", True, 2),
         ("cmdbts2", False, 2),
         ("out-of-order", True, 1),
     ),
 )
 @pytest.mark.parametrize("batch", (True, False))
 def test_provenance_heuristics_content_find_first(
     provenance: ProvenanceInterface,
     archive: ArchiveInterface,
     repo: str,
     lower: bool,
     mindepth: int,
     batch: bool,
 ) -> None:
     # read data/README.md for more details on how these datasets are generated
     data = load_repo_data(repo)
     fill_storage(archive.storage, data)
     revisions = [
         RevisionEntry(
             id=revision["id"],
             date=ts2dt(revision["date"]),
             root=revision["directory"],
         )
         for revision in data["revision"]
     ]

     if batch:
         revision_add(provenance, archive, revisions, lower=lower, mindepth=mindepth)
     else:
         for revision in revisions:
             revision_add(
                 provenance, archive, [revision], lower=lower, mindepth=mindepth
             )

     syntheticfile = get_datafile(
         f"synthetic_{repo}_{'lower' if lower else 'upper'}_{mindepth}.txt"
     )

     expected_first: Dict[str, Tuple[str, float, List[str]]] = {}
     # dict mapping blob ids to tuples (rev_id, rev_ts, [path, ...]); the third
     # element is a list because a content can be added at several places in a
     # single revision, in which case the result of content_find_first() is one
     # of those paths, but we have no guarantee which one it will return.
     for synth_rev in synthetic_revision_content_result(syntheticfile):
         rev_id = synth_rev["sha1"].hex()
         rev_ts = synth_rev["date"]

         for rc in synth_rev["R_C"]:
             sha1 = rc["dst"].hex()
             if sha1 not in expected_first:
                 assert rc["rel_ts"] == 0
                 expected_first[sha1] = (rev_id, rev_ts, [rc["path"]])
             else:
                 if rev_ts == expected_first[sha1][1]:
                     expected_first[sha1][2].append(rc["path"])
                 elif rev_ts < expected_first[sha1][1]:
                     expected_first[sha1] = (rev_id, rev_ts, [rc["path"]])

         for dc in synth_rev["D_C"]:
             sha1 = dc["dst"].hex()
             assert sha1 in expected_first
             # nothing to do there, this content cannot be a "first seen file"

     for content_id, (rev_id, ts, paths) in expected_first.items():
         occur = provenance.content_find_first(hash_to_bytes(content_id))
         assert occur is not None
         assert occur.content.hex() == content_id
         assert occur.revision.hex() == rev_id
         assert occur.date.timestamp() == ts
         assert occur.origin is None
         if provenance.storage.with_path():
             assert occur.path.decode() in paths
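Finally, a short sketch of consuming the find APIs these heuristics tests validate; `provenance` is assumed to be a `ProvenanceInterface` instance, and the sha1 is the one from the cmdbts2 dataset used above:

    from swh.model.hashutil import hash_to_bytes

    blob = hash_to_bytes("20329687bb9c1231a7e05afe86160343ad49b494")
    occur = provenance.content_find_first(blob)  # earliest occurrence, or None
    if occur is not None:
        print(occur.revision.hex(), occur.date.isoformat(), occur.path.decode())
    for occur in provenance.content_find_all(blob, limit=5):  # all occurrences
        print(occur.revision.hex(), occur.path.decode())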