diff --git a/swh/provenance/interface.py b/swh/provenance/interface.py index e1e20b6..9950edc 100644 --- a/swh/provenance/interface.py +++ b/swh/provenance/interface.py @@ -1,322 +1,329 @@ # Copyright (C) 2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from dataclasses import dataclass from datetime import datetime import enum -from typing import Dict, Generator, Iterable, Optional, Set +from typing import Dict, Generator, Iterable, Optional, Set, Union from typing_extensions import Protocol, runtime_checkable from swh.core.api import remote_api_endpoint from swh.model.model import Sha1Git from .model import DirectoryEntry, FileEntry, OriginEntry, RevisionEntry class EntityType(enum.Enum): CONTENT = "content" DIRECTORY = "directory" REVISION = "revision" ORIGIN = "origin" class RelationType(enum.Enum): CNT_EARLY_IN_REV = "content_in_revision" CNT_IN_DIR = "content_in_directory" DIR_IN_REV = "directory_in_revision" REV_IN_ORG = "revision_in_origin" REV_BEFORE_REV = "revision_before_revision" @dataclass(eq=True, frozen=True) class ProvenanceResult: content: Sha1Git revision: Sha1Git date: datetime origin: Optional[str] path: bytes @dataclass(eq=True, frozen=True) class RevisionData: """Object representing the data associated to a revision in the provenance model, where `date` is the optional date of the revision (specifying it acknowledges that the revision was already processed by the revision-content algorithm); and `origin` identifies the preferred origin for the revision, if any. """ date: Optional[datetime] origin: Optional[Sha1Git] @dataclass(eq=True, frozen=True) class RelationData: """Object representing a relation entry in the provenance model, where `src` and `dst` are the sha1 ids of the entities being related, and `path` is optional depending on the relation being represented. """ src: Sha1Git dst: Sha1Git path: Optional[bytes] @runtime_checkable class ProvenanceStorageInterface(Protocol): + @remote_api_endpoint("content_add") + def content_add( + self, cnts: Union[Iterable[Sha1Git], Dict[Sha1Git, datetime]] + ) -> bool: + """Add blobs identified by sha1 ids, with an optional associated date (as paired + in `cnts`) to the provenance storage. Return a boolean stating whether the + information was successfully stored. + """ + ... + @remote_api_endpoint("content_find_first") def content_find_first(self, id: Sha1Git) -> Optional[ProvenanceResult]: """Retrieve the first occurrence of the blob identified by `id`.""" ... @remote_api_endpoint("content_find_all") def content_find_all( self, id: Sha1Git, limit: Optional[int] = None ) -> Generator[ProvenanceResult, None, None]: """Retrieve all the occurrences of the blob identified by `id`.""" ... - @remote_api_endpoint("content_set_date") - def content_set_date(self, dates: Dict[Sha1Git, datetime]) -> bool: - """Associate dates to blobs identified by sha1 ids, as paired in `dates`. Return - a boolean stating whether the information was successfully stored. - """ - ... - @remote_api_endpoint("content_get") def content_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, datetime]: """Retrieve the associated date for each blob sha1 in `ids`. If some blob has no associated date, it is not present in the resulting dictionary. """ ... 
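Note: a minimal usage sketch of the new `content_add` endpoint above, showing both accepted argument forms; the `storage` object (implementing `ProvenanceStorageInterface`) and the sha1 values `blob_a`/`blob_b` are illustrative assumptions. Per the backend implementations further below in this diff, an already-stored date is only ever replaced by an earlier one:

    from datetime import datetime, timezone

    # Form 1: register blobs without dates (Iterable[Sha1Git]); dates can be
    # attached by a later call.
    assert storage.content_add([blob_a, blob_b])

    # Form 2: register blobs together with their earliest known dates
    # (Dict[Sha1Git, datetime]).
    assert storage.content_add({blob_a: datetime(2021, 6, 1, tzinfo=timezone.utc)})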
- @remote_api_endpoint("directory_set_date") - def directory_set_date(self, dates: Dict[Sha1Git, datetime]) -> bool: - """Associate dates to directories identified by sha1 ids, as paired in - `dates`. Return a boolean stating whether the information was successfully - stored. + @remote_api_endpoint("directory_add") + def directory_add( + self, dirs: Union[Iterable[Sha1Git], Dict[Sha1Git, datetime]] + ) -> bool: + """Add directories identified by sha1 ids, with an optional associated date (as + paired in `dirs`) to the provenance storage. Return a boolean stating if the + information was successfully stored. """ ... @remote_api_endpoint("directory_get") def directory_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, datetime]: """Retrieve the associated date for each directory sha1 in `ids`. If some directory has no associated date, it is not present in the resulting dictionary. """ ... @remote_api_endpoint("entity_get_all") def entity_get_all(self, entity: EntityType) -> Set[Sha1Git]: """Retrieve all sha1 ids for entities of type `entity` present in the provenance model. """ ... - @remote_api_endpoint("location_get") - def location_get(self) -> Set[bytes]: + @remote_api_endpoint("location_add") + def location_add(self, paths: Iterable[bytes]) -> bool: + """Register the given `paths` in the storage.""" + ... + + @remote_api_endpoint("location_get_all") + def location_get_all(self) -> Set[bytes]: """Retrieve all paths present in the provenance model.""" ... - @remote_api_endpoint("origin_set_url") - def origin_set_url(self, urls: Dict[Sha1Git, str]) -> bool: - """Associate urls to origins identified by sha1 ids, as paired in `urls`. Return - a boolean stating whether the information was successfully stored. + @remote_api_endpoint("origin_add") + def origin_add(self, orgs: Dict[Sha1Git, str]) -> bool: + """Add origins identified by sha1 ids, with their corresponding url (as paired + in `orgs`) to the provenance storage. Return a boolean stating if the + information was successfully stored. """ ... @remote_api_endpoint("origin_get") def origin_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, str]: - """Retrieve the associated url for each origin sha1 in `ids`. If some origin has - no associated date, it is not present in the resulting dictionary. - """ - ... - - @remote_api_endpoint("revision_set_date") - def revision_set_date(self, dates: Dict[Sha1Git, datetime]) -> bool: - """Associate dates to revisions identified by sha1 ids, as paired in `dates`. - Return a boolean stating whether the information was successfully stored. - """ + """Retrieve the associated url for each origin sha1 in `ids`.""" ... - @remote_api_endpoint("revision_set_origin") - def revision_set_origin(self, origins: Dict[Sha1Git, Sha1Git]) -> bool: - """Associate origins to revisions identified by sha1 ids, as paired in - `origins` (revision ids are keys and origin ids, values). Return a boolean - stating whether the information was successfully stored. + @remote_api_endpoint("revision_add") + def revision_add( + self, revs: Union[Iterable[Sha1Git], Dict[Sha1Git, RevisionData]] + ) -> bool: + """Add revisions identified by sha1 ids, with optional associated date or origin + (as paired in `revs`) to the provenance storage. Return a boolean stating if the + information was successfully stored. """ ... @remote_api_endpoint("revision_get") def revision_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, RevisionData]: """Retrieve the associated date and origin for each revision sha1 in `ids`. 
If some revision has no associated date nor origin, it is not present in the resulting dictionary. """ ... @remote_api_endpoint("relation_add") def relation_add( self, relation: RelationType, data: Iterable[RelationData] ) -> bool: - """Add entries in the selected `relation`.""" + """Add entries in the selected `relation`. This method assumes all entities + being related are already registered in the storage. See `content_add`, + `directory_add`, `origin_add`, and `revision_add`. + """ ... @remote_api_endpoint("relation_get") def relation_get( self, relation: RelationType, ids: Iterable[Sha1Git], reverse: bool = False ) -> Set[RelationData]: """Retrieve all entries in the selected `relation` whose source entities are identified by some sha1 id in `ids`. If `reverse` is set, destination entities are matched instead. """ ... @remote_api_endpoint("relation_get_all") def relation_get_all(self, relation: RelationType) -> Set[RelationData]: """Retrieve all entries in the selected `relation` that are present in the provenance model. """ ... @remote_api_endpoint("with_path") def with_path(self) -> bool: ... @runtime_checkable class ProvenanceInterface(Protocol): storage: ProvenanceStorageInterface def flush(self) -> None: """Flush internal cache to the underlying `storage`.""" ... def content_add_to_directory( self, directory: DirectoryEntry, blob: FileEntry, prefix: bytes ) -> None: """Associate `blob` with `directory` in the provenance model. `prefix` is the relative path from `directory` to `blob` (excluding `blob`'s name). """ ... def content_add_to_revision( self, revision: RevisionEntry, blob: FileEntry, prefix: bytes ) -> None: """Associate `blob` with `revision` in the provenance model. `prefix` is the absolute path from `revision`'s root directory to `blob` (excluding `blob`'s name). """ ... def content_find_first(self, id: Sha1Git) -> Optional[ProvenanceResult]: """Retrieve the first occurrence of the blob identified by `id`.""" ... def content_find_all( self, id: Sha1Git, limit: Optional[int] = None ) -> Generator[ProvenanceResult, None, None]: """Retrieve all the occurrences of the blob identified by `id`.""" ... def content_get_early_date(self, blob: FileEntry) -> Optional[datetime]: """Retrieve the earliest known date of `blob`.""" ... def content_get_early_dates( self, blobs: Iterable[FileEntry] ) -> Dict[Sha1Git, datetime]: """Retrieve the earliest known date for each blob in `blobs`. If some blob has no associated date, it is not present in the resulting dictionary. """ ... def content_set_early_date(self, blob: FileEntry, date: datetime) -> None: """Associate `date` to `blob` as it's earliest known date.""" ... def directory_add_to_revision( self, revision: RevisionEntry, directory: DirectoryEntry, path: bytes ) -> None: """Associate `directory` with `revision` in the provenance model. `path` is the absolute path from `revision`'s root directory to `directory` (including `directory`'s name). """ ... def directory_get_date_in_isochrone_frontier( self, directory: DirectoryEntry ) -> Optional[datetime]: """Retrieve the earliest known date of `directory` as an isochrone frontier in the provenance model. """ ... def directory_get_dates_in_isochrone_frontier( self, dirs: Iterable[DirectoryEntry] ) -> Dict[Sha1Git, datetime]: """Retrieve the earliest known date for each directory in `dirs` as isochrone frontiers provenance model. If some directory has no associated date, it is not present in the resulting dictionary. """ ... 
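Note: `relation_add` above now assumes both endpoints of every relation are already registered, so callers are expected to follow the insertion order used by `Provenance.flush` further below in this diff. A minimal sketch, where `storage` and the sha1/path values are illustrative assumptions:

    # 1. Register the entities on either side of the relation.
    storage.content_add([blob_sha1])
    storage.revision_add([rev_sha1])
    # 2. Register the location paths used by the relation (a no-op for
    #    backends without path support).
    storage.location_add([b"src/main.c"])
    # 3. Only then record the relation itself.
    storage.relation_add(
        RelationType.CNT_EARLY_IN_REV,
        [RelationData(src=blob_sha1, dst=rev_sha1, path=b"src/main.c")],
    )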
def directory_set_date_in_isochrone_frontier( self, directory: DirectoryEntry, date: datetime ) -> None: """Associate `date` to `directory` as it's earliest known date as an isochrone frontier in the provenance model. """ ... def origin_add(self, origin: OriginEntry) -> None: """Add `origin` to the provenance model.""" ... def revision_add(self, revision: RevisionEntry) -> None: """Add `revision` to the provenance model. This implies storing `revision`'s date in the model, thus `revision.date` must be a valid date. """ ... def revision_add_before_revision( self, head: RevisionEntry, revision: RevisionEntry ) -> None: """Associate `revision` to `head` as an ancestor of the latter.""" ... def revision_add_to_origin( self, origin: OriginEntry, revision: RevisionEntry ) -> None: """Associate `revision` to `origin` as a head revision of the latter (ie. the target of an snapshot for `origin` in the archive).""" ... def revision_get_date(self, revision: RevisionEntry) -> Optional[datetime]: """Retrieve the date associated to `revision`.""" ... def revision_get_preferred_origin( self, revision: RevisionEntry ) -> Optional[Sha1Git]: """Retrieve the preferred origin associated to `revision`.""" ... def revision_in_history(self, revision: RevisionEntry) -> bool: """Check if `revision` is known to be an ancestor of some head revision in the provenance model. """ ... def revision_set_preferred_origin( self, origin: OriginEntry, revision: RevisionEntry ) -> None: """Associate `origin` as the preferred origin for `revision`.""" ... def revision_visited(self, revision: RevisionEntry) -> bool: """Check if `revision` is known to be a head revision for some origin in the provenance model. """ ... diff --git a/swh/provenance/mongo/backend.py b/swh/provenance/mongo/backend.py index 3998669..981d38a 100644 --- a/swh/provenance/mongo/backend.py +++ b/swh/provenance/mongo/backend.py @@ -1,488 +1,443 @@ # Copyright (C) 2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from datetime import datetime, timezone import os -from typing import Any, Dict, Generator, Iterable, List, Optional, Set +from typing import Any, Dict, Generator, Iterable, List, Optional, Set, Union from bson import ObjectId import pymongo.database from swh.model.model import Sha1Git from ..interface import ( EntityType, ProvenanceResult, RelationData, RelationType, RevisionData, ) class ProvenanceStorageMongoDb: def __init__(self, db: pymongo.database.Database): self.db = db + def content_add( + self, cnts: Union[Iterable[Sha1Git], Dict[Sha1Git, datetime]] + ) -> bool: + data = cnts if isinstance(cnts, dict) else dict.fromkeys(cnts) + existing = { + x["sha1"]: x + for x in self.db.content.find( + {"sha1": {"$in": list(data)}}, {"sha1": 1, "ts": 1, "_id": 1} + ) + } + for sha1, date in data.items(): + ts = datetime.timestamp(date) if date is not None else None + if sha1 in existing: + cnt = existing[sha1] + if ts is not None and (cnt["ts"] is None or ts < cnt["ts"]): + self.db.content.update_one( + {"_id": cnt["_id"]}, {"$set": {"ts": ts}} + ) + else: + self.db.content.insert_one( + { + "sha1": sha1, + "ts": ts, + "revision": {}, + "directory": {}, + } + ) + return True + def content_find_first(self, id: Sha1Git) -> Optional[ProvenanceResult]: # get all the revisions # iterate and find the earliest content = self.db.content.find_one({"sha1": id}) if not content: return 
None occurs = [] for revision in self.db.revision.find( {"_id": {"$in": [ObjectId(obj_id) for obj_id in content["revision"]]}} ): - origin = self.db.origin.find_one({"sha1": revision["preferred"]}) - assert origin is not None + if revision["preferred"] is not None: + origin = self.db.origin.find_one({"sha1": revision["preferred"]}) + else: + origin = {"url": None} for path in content["revision"][str(revision["_id"])]: occurs.append( ProvenanceResult( content=id, revision=revision["sha1"], date=datetime.fromtimestamp(revision["ts"], timezone.utc), origin=origin["url"], path=path, ) ) return sorted(occurs, key=lambda x: (x.date, x.revision, x.origin, x.path))[0] def content_find_all( self, id: Sha1Git, limit: Optional[int] = None ) -> Generator[ProvenanceResult, None, None]: content = self.db.content.find_one({"sha1": id}) if not content: return None occurs = [] for revision in self.db.revision.find( {"_id": {"$in": [ObjectId(obj_id) for obj_id in content["revision"]]}} ): - origin = self.db.origin.find_one({"sha1": revision["preferred"]}) - assert origin is not None + if revision["preferred"] is not None: + origin = self.db.origin.find_one({"sha1": revision["preferred"]}) + else: + origin = {"url": None} for path in content["revision"][str(revision["_id"])]: occurs.append( ProvenanceResult( content=id, revision=revision["sha1"], date=datetime.fromtimestamp(revision["ts"], timezone.utc), origin=origin["url"], path=path, ) ) for directory in self.db.directory.find( {"_id": {"$in": [ObjectId(obj_id) for obj_id in content["directory"]]}} ): for revision in self.db.revision.find( {"_id": {"$in": [ObjectId(obj_id) for obj_id in directory["revision"]]}} ): - origin = self.db.origin.find_one({"sha1": revision["preferred"]}) - assert origin is not None + if revision["preferred"] is not None: + origin = self.db.origin.find_one({"sha1": revision["preferred"]}) + else: + origin = {"url": None} for suffix in content["directory"][str(directory["_id"])]: for prefix in directory["revision"][str(revision["_id"])]: path = ( os.path.join(prefix, suffix) if prefix not in [b".", b""] else suffix ) occurs.append( ProvenanceResult( content=id, revision=revision["sha1"], date=datetime.fromtimestamp( revision["ts"], timezone.utc ), origin=origin["url"], path=path, ) ) yield from sorted(occurs, key=lambda x: (x.date, x.revision, x.origin, x.path)) def content_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, datetime]: return { x["sha1"]: datetime.fromtimestamp(x["ts"], timezone.utc) for x in self.db.content.find( {"sha1": {"$in": list(ids)}, "ts": {"$ne": None}}, {"sha1": 1, "ts": 1, "_id": 0}, ) } - def content_set_date(self, dates: Dict[Sha1Git, datetime]) -> bool: - # get all the docuemtns with the id, add date, add missing records - cnts = { - x["sha1"]: x - for x in self.db.content.find( - {"sha1": {"$in": list(dates)}}, {"sha1": 1, "ts": 1, "_id": 1} - ) - } - - for sha1, date in dates.items(): - ts = datetime.timestamp(date) - if sha1 in cnts: - # update - if cnts[sha1]["ts"] is None or ts < cnts[sha1]["ts"]: - self.db.content.update_one( - {"_id": cnts[sha1]["_id"]}, {"$set": {"ts": ts}} - ) - else: - # add new content - self.db.content.insert_one( - { - "sha1": sha1, - "ts": ts, - "revision": {}, - "directory": {}, - } - ) - return True - - def directory_set_date(self, dates: Dict[Sha1Git, datetime]) -> bool: - dirs = { + def directory_add( + self, dirs: Union[Iterable[Sha1Git], Dict[Sha1Git, datetime]] + ) -> bool: + data = dirs if isinstance(dirs, dict) else dict.fromkeys(dirs) + existing = { x["sha1"]: 
x for x in self.db.directory.find( - {"sha1": {"$in": list(dates)}}, {"sha1": 1, "ts": 1, "_id": 1} + {"sha1": {"$in": list(data)}}, {"sha1": 1, "ts": 1, "_id": 1} ) } - for sha1, date in dates.items(): - ts = datetime.timestamp(date) - if sha1 in dirs: - # update - if dirs[sha1]["ts"] is None or ts < dirs[sha1]["ts"]: + for sha1, date in data.items(): + ts = datetime.timestamp(date) if date is not None else None + if sha1 in existing: + dir = existing[sha1] + if ts is not None and (dir["ts"] is None or ts < dir["ts"]): self.db.directory.update_one( - {"_id": dirs[sha1]["_id"]}, {"$set": {"ts": ts}} + {"_id": dir["_id"]}, {"$set": {"ts": ts}} ) else: - # add new dir self.db.directory.insert_one({"sha1": sha1, "ts": ts, "revision": {}}) return True def directory_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, datetime]: return { x["sha1"]: datetime.fromtimestamp(x["ts"], timezone.utc) for x in self.db.directory.find( {"sha1": {"$in": list(ids)}, "ts": {"$ne": None}}, {"sha1": 1, "ts": 1, "_id": 0}, ) } def entity_get_all(self, entity: EntityType) -> Set[Sha1Git]: return { x["sha1"] for x in self.db.get_collection(entity.value).find( {}, {"sha1": 1, "_id": 0} ) } - def location_get(self) -> Set[bytes]: + def location_add(self, paths: Iterable[bytes]) -> bool: + # TODO: implement this method if paths are to be stored in a separate collection + return True + + def location_get_all(self) -> Set[bytes]: contents = self.db.content.find({}, {"revision": 1, "_id": 0, "directory": 1}) paths: List[Iterable[bytes]] = [] for content in contents: paths.extend(value for _, value in content["revision"].items()) paths.extend(value for _, value in content["directory"].items()) dirs = self.db.directory.find({}, {"revision": 1, "_id": 0}) for each_dir in dirs: paths.extend(value for _, value in each_dir["revision"].items()) return set(sum(paths, [])) - def origin_set_url(self, urls: Dict[Sha1Git, str]) -> bool: - origins = { + def origin_add(self, orgs: Dict[Sha1Git, str]) -> bool: + existing = { x["sha1"]: x for x in self.db.origin.find( - {"sha1": {"$in": list(urls)}}, {"sha1": 1, "url": 1, "_id": 1} + {"sha1": {"$in": list(orgs)}}, {"sha1": 1, "url": 1, "_id": 1} ) } - for sha1, url in urls.items(): - if sha1 not in origins: + for sha1, url in orgs.items(): + if sha1 not in existing: # add new origin self.db.origin.insert_one({"sha1": sha1, "url": url}) return True def origin_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, str]: return { x["sha1"]: x["url"] for x in self.db.origin.find( {"sha1": {"$in": list(ids)}}, {"sha1": 1, "url": 1, "_id": 0} ) } - def revision_set_date(self, dates: Dict[Sha1Git, datetime]) -> bool: - revs = { + def revision_add( + self, revs: Union[Iterable[Sha1Git], Dict[Sha1Git, RevisionData]] + ) -> bool: + data = ( + revs + if isinstance(revs, dict) + else dict.fromkeys(revs, RevisionData(date=None, origin=None)) + ) + existing = { x["sha1"]: x for x in self.db.revision.find( - {"sha1": {"$in": list(dates)}}, {"sha1": 1, "ts": 1, "_id": 1} + {"sha1": {"$in": list(data)}}, + {"sha1": 1, "ts": 1, "preferred": 1, "_id": 1}, ) } - for sha1, date in dates.items(): - ts = datetime.timestamp(date) - if sha1 in revs: - # update - if revs[sha1]["ts"] is None or ts < revs[sha1]["ts"]: + for sha1, info in data.items(): + ts = datetime.timestamp(info.date) if info.date is not None else None + preferred = info.origin + if sha1 in existing: + rev = existing[sha1] + if ts is None or (rev["ts"] is not None and ts >= rev["ts"]): + ts = rev["ts"] + if preferred is None: + preferred =
rev["preferred"] + if ts != rev["ts"] or preferred != rev["preferred"]: self.db.revision.update_one( - {"_id": revs[sha1]["_id"]}, {"$set": {"ts": ts}} + {"_id": rev["_id"]}, + {"$set": {"ts": ts, "preferred": preferred}}, ) else: - # add new rev self.db.revision.insert_one( { "sha1": sha1, - "preferred": None, + "preferred": preferred, "origin": [], "revision": [], "ts": ts, } ) return True - def revision_set_origin(self, origins: Dict[Sha1Git, Sha1Git]) -> bool: - revs = { - x["sha1"]: x - for x in self.db.revision.find( - {"sha1": {"$in": list(origins)}}, {"sha1": 1, "preferred": 1, "_id": 1} - ) - } - for sha1, origin in origins.items(): - if sha1 in revs: - self.db.revision.update_one( - {"_id": revs[sha1]["_id"]}, {"$set": {"preferred": origin}} - ) - else: - # add new rev - self.db.revision.insert_one( - { - "sha1": sha1, - "preferred": origin, - "origin": [], - "revision": [], - "ts": None, - } - ) - return True - def revision_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, RevisionData]: return { x["sha1"]: RevisionData( date=datetime.fromtimestamp(x["ts"], timezone.utc) if x["ts"] else None, origin=x["preferred"], ) for x in self.db.revision.find( - {"sha1": {"$in": list(ids)}}, + { + "sha1": {"$in": list(ids)}, + "$or": [{"preferred": {"$ne": None}}, {"ts": {"$ne": None}}], + }, {"sha1": 1, "preferred": 1, "ts": 1, "_id": 0}, ) } def relation_add( self, relation: RelationType, data: Iterable[RelationData] ) -> bool: src_relation, *_, dst_relation = relation.value.split("_") set_data = set(data) - dst_sha1s = {x.dst for x in set_data} - if dst_relation in ["content", "directory", "revision"]: - dst_obj: Dict[str, Any] = {"ts": None} - if dst_relation == "content": - dst_obj["revision"] = {} - dst_obj["directory"] = {} - if dst_relation == "directory": - dst_obj["revision"] = {} - if dst_relation == "revision": - dst_obj["preferred"] = None - dst_obj["origin"] = [] - dst_obj["revision"] = [] - - existing = { - x["sha1"] - for x in self.db.get_collection(dst_relation).find( - {"sha1": {"$in": list(dst_sha1s)}}, {"_id": 0, "sha1": 1} - ) - } - - for sha1 in dst_sha1s: - if sha1 not in existing: - self.db.get_collection(dst_relation).insert_one( - dict(dst_obj, **{"sha1": sha1}) - ) - elif dst_relation == "origin": - # TODO, check origins are already in the DB - # if not, algo has something wrong (algo inserts it initially) - pass - dst_objs = { x["sha1"]: x["_id"] for x in self.db.get_collection(dst_relation).find( - {"sha1": {"$in": list(dst_sha1s)}}, {"_id": 1, "sha1": 1} + {"sha1": {"$in": [x.dst for x in set_data]}}, {"_id": 1, "sha1": 1} ) } denorm: Dict[Sha1Git, Any] = {} for each in set_data: if src_relation != "revision": denorm.setdefault(each.src, {}).setdefault( str(dst_objs[each.dst]), [] ).append(each.path) else: denorm.setdefault(each.src, []).append(dst_objs[each.dst]) src_objs = { x["sha1"]: x for x in self.db.get_collection(src_relation).find( {"sha1": {"$in": list(denorm)}} ) } for sha1, dsts in denorm.items(): - if sha1 in src_objs: - # update - if src_relation != "revision": - k = { - obj_id: list(set(paths + dsts.get(obj_id, []))) - for obj_id, paths in src_objs[sha1][dst_relation].items() - } - self.db.get_collection(src_relation).update_one( - {"_id": src_objs[sha1]["_id"]}, - {"$set": {dst_relation: dict(dsts, **k)}}, - ) - else: - self.db.get_collection(src_relation).update_one( - {"_id": src_objs[sha1]["_id"]}, - { - "$set": { - dst_relation: list( - set(src_objs[sha1][dst_relation] + dsts) - ) - } - }, - ) + # update + if src_relation != "revision": + k 
= { + obj_id: list(set(paths + dsts.get(obj_id, []))) + for obj_id, paths in src_objs[sha1][dst_relation].items() + } + self.db.get_collection(src_relation).update_one( + {"_id": src_objs[sha1]["_id"]}, + {"$set": {dst_relation: dict(dsts, **k)}}, + ) else: - # add new rev - src_obj: Dict[str, Any] = {"ts": None} - if src_relation == "content": - src_obj["revision"] = {} - src_obj["directory"] = {} - if src_relation == "directory": - src_obj["revision"] = {} - if src_relation == "revision": - src_obj["preferred"] = None - src_obj["origin"] = [] - src_obj["revision"] = [] - self.db.get_collection(src_relation).insert_one( - dict(src_obj, **{"sha1": sha1, dst_relation: dsts}) + self.db.get_collection(src_relation).update_one( + {"_id": src_objs[sha1]["_id"]}, + { + "$set": { + dst_relation: list(set(src_objs[sha1][dst_relation] + dsts)) + } + }, ) return True def relation_get( self, relation: RelationType, ids: Iterable[Sha1Git], reverse: bool = False ) -> Set[RelationData]: src, *_, dst = relation.value.split("_") sha1s = set(ids) if not reverse: src_objs = { x["sha1"]: x[dst] for x in self.db.get_collection(src).find( {"sha1": {"$in": list(sha1s)}}, {"_id": 0, "sha1": 1, dst: 1} ) } dst_ids = list( {ObjectId(obj_id) for _, value in src_objs.items() for obj_id in value} ) dst_objs = { x["sha1"]: x["_id"] for x in self.db.get_collection(dst).find( {"_id": {"$in": dst_ids}}, {"_id": 1, "sha1": 1} ) } if src != "revision": return { RelationData(src=src_sha1, dst=dst_sha1, path=path) for src_sha1, denorm in src_objs.items() for dst_sha1, dst_obj_id in dst_objs.items() for dst_obj_str, paths in denorm.items() for path in paths if dst_obj_id == ObjectId(dst_obj_str) } else: return { RelationData(src=src_sha1, dst=dst_sha1, path=None) for src_sha1, denorm in src_objs.items() for dst_sha1, dst_obj_id in dst_objs.items() for dst_obj_ref in denorm if dst_obj_id == dst_obj_ref } else: dst_objs = { x["sha1"]: x["_id"] for x in self.db.get_collection(dst).find( {"sha1": {"$in": list(sha1s)}}, {"_id": 1, "sha1": 1} ) } src_objs = { x["sha1"]: x[dst] for x in self.db.get_collection(src).find( {}, {"_id": 0, "sha1": 1, dst: 1} ) } if src != "revision": return { RelationData(src=src_sha1, dst=dst_sha1, path=path) for src_sha1, denorm in src_objs.items() for dst_sha1, dst_obj_id in dst_objs.items() for dst_obj_str, paths in denorm.items() for path in paths if dst_obj_id == ObjectId(dst_obj_str) } else: return { RelationData(src=src_sha1, dst=dst_sha1, path=None) for src_sha1, denorm in src_objs.items() for dst_sha1, dst_obj_id in dst_objs.items() for dst_obj_ref in denorm if dst_obj_id == dst_obj_ref } def relation_get_all(self, relation: RelationType) -> Set[RelationData]: src, *_, dst = relation.value.split("_") src_objs = { x["sha1"]: x[dst] for x in self.db.get_collection(src).find({}, {"_id": 0, "sha1": 1, dst: 1}) } dst_ids = list( {ObjectId(obj_id) for _, value in src_objs.items() for obj_id in value} ) if src != "revision": dst_objs = { x["_id"]: x["sha1"] for x in self.db.get_collection(dst).find( {"_id": {"$in": dst_ids}}, {"_id": 1, "sha1": 1} ) } return { RelationData(src=src_sha1, dst=dst_sha1, path=path) for src_sha1, denorm in src_objs.items() for dst_obj_id, dst_sha1 in dst_objs.items() for dst_obj_str, paths in denorm.items() for path in paths if dst_obj_id == ObjectId(dst_obj_str) } else: dst_objs = { x["_id"]: x["sha1"] for x in self.db.get_collection(dst).find( {"_id": {"$in": dst_ids}}, {"_id": 1, "sha1": 1} ) } return { RelationData(src=src_sha1, dst=dst_sha1, path=None) for src_sha1, 
denorm in src_objs.items() for dst_obj_id, dst_sha1 in dst_objs.items() for dst_obj_ref in denorm if dst_obj_id == dst_obj_ref } def with_path(self) -> bool: return True diff --git a/swh/provenance/postgresql/provenance.py b/swh/provenance/postgresql/provenance.py index 1d80ffd..985a495 100644 --- a/swh/provenance/postgresql/provenance.py +++ b/swh/provenance/postgresql/provenance.py @@ -1,315 +1,318 @@ # Copyright (C) 2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from contextlib import contextmanager from datetime import datetime import itertools import logging -from typing import Dict, Generator, Iterable, List, Optional, Set +from typing import Dict, Generator, Iterable, List, Optional, Set, Union import psycopg2.extensions import psycopg2.extras from typing_extensions import Literal from swh.core.db import BaseDb from swh.model.model import Sha1Git from ..interface import ( EntityType, ProvenanceResult, RelationData, RelationType, RevisionData, ) LOGGER = logging.getLogger(__name__) class ProvenanceStoragePostgreSql: def __init__( self, conn: psycopg2.extensions.connection, raise_on_commit: bool = False ) -> None: BaseDb.adapt_conn(conn) self.conn = conn with self.transaction() as cursor: cursor.execute("SET timezone TO 'UTC'") self._flavor: Optional[str] = None self.raise_on_commit = raise_on_commit @contextmanager def transaction( self, readonly: bool = False ) -> Generator[psycopg2.extensions.cursor, None, None]: self.conn.set_session(readonly=readonly) with self.conn: with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur: yield cur @property def flavor(self) -> str: if self._flavor is None: with self.transaction(readonly=True) as cursor: cursor.execute("SELECT swh_get_dbflavor() AS flavor") self._flavor = cursor.fetchone()["flavor"] assert self._flavor is not None return self._flavor @property def denormalized(self) -> bool: return "denormalized" in self.flavor + def content_add( + self, cnts: Union[Iterable[Sha1Git], Dict[Sha1Git, datetime]] + ) -> bool: + return self._entity_set_date("content", cnts) + def content_find_first(self, id: Sha1Git) -> Optional[ProvenanceResult]: sql = "SELECT * FROM swh_provenance_content_find_first(%s)" with self.transaction(readonly=True) as cursor: cursor.execute(query=sql, vars=(id,)) row = cursor.fetchone() return ProvenanceResult(**row) if row is not None else None def content_find_all( self, id: Sha1Git, limit: Optional[int] = None ) -> Generator[ProvenanceResult, None, None]: sql = "SELECT * FROM swh_provenance_content_find_all(%s, %s)" with self.transaction(readonly=True) as cursor: cursor.execute(query=sql, vars=(id, limit)) yield from (ProvenanceResult(**row) for row in cursor) - def content_set_date(self, dates: Dict[Sha1Git, datetime]) -> bool: - return self._entity_set_date("content", dates) - def content_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, datetime]: return self._entity_get_date("content", ids) - def directory_set_date(self, dates: Dict[Sha1Git, datetime]) -> bool: - return self._entity_set_date("directory", dates) + def directory_add( + self, dirs: Union[Iterable[Sha1Git], Dict[Sha1Git, datetime]] + ) -> bool: + return self._entity_set_date("directory", dirs) def directory_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, datetime]: return self._entity_get_date("directory", ids) def entity_get_all(self, entity: 
EntityType) -> Set[Sha1Git]: with self.transaction(readonly=True) as cursor: cursor.execute(f"SELECT sha1 FROM {entity.value}") return {row["sha1"] for row in cursor} - def location_get(self) -> Set[bytes]: + def location_add(self, paths: Iterable[bytes]) -> bool: + if not self.with_path(): + return True + try: + values = [(path,) for path in paths] + if values: + sql = """ + INSERT INTO location(path) VALUES %s + ON CONFLICT DO NOTHING + """ + with self.transaction() as cursor: + psycopg2.extras.execute_values(cursor, sql, argslist=values) + return True + except: # noqa: E722 + # Unexpected error occurred, rollback all changes and log message + LOGGER.exception("Unexpected error") + if self.raise_on_commit: + raise + return False + + def location_get_all(self) -> Set[bytes]: with self.transaction(readonly=True) as cursor: cursor.execute("SELECT location.path AS path FROM location") return {row["path"] for row in cursor} - def origin_set_url(self, urls: Dict[Sha1Git, str]) -> bool: + def origin_add(self, orgs: Dict[Sha1Git, str]) -> bool: try: - if urls: + if orgs: sql = """ INSERT INTO origin(sha1, url) VALUES %s ON CONFLICT DO NOTHING """ with self.transaction() as cursor: psycopg2.extras.execute_values( - cur=cursor, sql=sql, argslist=urls.items() + cur=cursor, sql=sql, argslist=orgs.items() ) return True except: # noqa: E722 # Unexpected error occurred, rollback all changes and log message LOGGER.exception("Unexpected error") if self.raise_on_commit: raise return False def origin_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, str]: urls: Dict[Sha1Git, str] = {} sha1s = tuple(ids) if sha1s: # TODO: consider splitting this query in several ones if sha1s is too big! values = ", ".join(itertools.repeat("%s", len(sha1s))) sql = f""" SELECT sha1, url FROM origin WHERE sha1 IN ({values}) """ with self.transaction(readonly=True) as cursor: cursor.execute(query=sql, vars=sha1s) urls.update((row["sha1"], row["url"]) for row in cursor) return urls - def revision_set_date(self, dates: Dict[Sha1Git, datetime]) -> bool: - return self._entity_set_date("revision", dates) - - def revision_set_origin(self, origins: Dict[Sha1Git, Sha1Git]) -> bool: + def revision_add( + self, revs: Union[Iterable[Sha1Git], Dict[Sha1Git, RevisionData]] + ) -> bool: + if isinstance(revs, dict): + data = [(sha1, rev.date, rev.origin) for sha1, rev in revs.items()] + else: + data = [(sha1, None, None) for sha1 in revs] try: - if origins: + if data: sql = """ - INSERT INTO revision(sha1, origin) - (SELECT V.rev AS sha1, O.id AS origin - FROM (VALUES %s) AS V(rev, org) - JOIN origin AS O ON (O.sha1=V.org)) + INSERT INTO revision(sha1, date, origin) + (SELECT V.rev AS sha1, V.date::timestamptz AS date, O.id AS origin + FROM (VALUES %s) AS V(rev, date, org) + LEFT JOIN origin AS O ON (O.sha1=V.org::sha1_git)) ON CONFLICT (sha1) DO - UPDATE SET origin=EXCLUDED.origin + UPDATE SET + date=LEAST(EXCLUDED.date, revision.date), + origin=COALESCE(EXCLUDED.origin, revision.origin) """ with self.transaction() as cursor: - psycopg2.extras.execute_values( - cur=cursor, sql=sql, argslist=origins.items() - ) + psycopg2.extras.execute_values(cur=cursor, sql=sql, argslist=data) return True except: # noqa: E722 # Unexpected error occurred, rollback all changes and log message LOGGER.exception("Unexpected error") if self.raise_on_commit: raise return False def revision_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, RevisionData]: result: Dict[Sha1Git, RevisionData] = {} sha1s = tuple(ids) if sha1s: # TODO: consider splitting this query 
in several ones if sha1s is too big! values = ", ".join(itertools.repeat("%s", len(sha1s))) sql = f""" SELECT R.sha1, R.date, O.sha1 AS origin FROM revision AS R LEFT JOIN origin AS O ON (O.id=R.origin) WHERE R.sha1 IN ({values}) + AND (R.date is not NULL OR O.sha1 is not NULL) """ with self.transaction(readonly=True) as cursor: cursor.execute(query=sql, vars=sha1s) result.update( (row["sha1"], RevisionData(date=row["date"], origin=row["origin"])) for row in cursor ) return result def relation_add( self, relation: RelationType, data: Iterable[RelationData] ) -> bool: + rows = [(rel.src, rel.dst, rel.path) for rel in data] try: - rows = [(rel.src, rel.dst, rel.path) for rel in data] if rows: rel_table = relation.value src_table, *_, dst_table = rel_table.split("_") - if src_table != "origin": - # Origin entries should be inserted previously as they require extra - # non-null information - srcs = tuple(set((sha1,) for (sha1, _, _) in rows)) - sql = f""" - INSERT INTO {src_table}(sha1) VALUES %s - ON CONFLICT DO NOTHING - """ - with self.transaction() as cursor: - psycopg2.extras.execute_values( - cur=cursor, sql=sql, argslist=srcs - ) - - if dst_table != "origin": - # Origin entries should be inserted previously as they require extra - # non-null information - dsts = tuple(set((sha1,) for (_, sha1, _) in rows)) - sql = f""" - INSERT INTO {dst_table}(sha1) VALUES %s - ON CONFLICT DO NOTHING - """ - with self.transaction() as cursor: - psycopg2.extras.execute_values( - cur=cursor, sql=sql, argslist=dsts - ) - # Put the next three queries in a manual single transaction: # they use the same temp table with self.transaction() as cursor: cursor.execute("SELECT swh_mktemp_relation_add()") psycopg2.extras.execute_values( cur=cursor, sql="INSERT INTO tmp_relation_add(src, dst, path) VALUES %s", argslist=rows, ) sql = "SELECT swh_provenance_relation_add_from_temp(%s, %s, %s)" cursor.execute(query=sql, vars=(rel_table, src_table, dst_table)) return True except: # noqa: E722 # Unexpected error occurred, rollback all changes and log message LOGGER.exception("Unexpected error") if self.raise_on_commit: raise return False def relation_get( self, relation: RelationType, ids: Iterable[Sha1Git], reverse: bool = False ) -> Set[RelationData]: return self._relation_get(relation, ids, reverse) def relation_get_all(self, relation: RelationType) -> Set[RelationData]: return self._relation_get(relation, None) def _entity_get_date( self, entity: Literal["content", "directory", "revision"], ids: Iterable[Sha1Git], ) -> Dict[Sha1Git, datetime]: dates: Dict[Sha1Git, datetime] = {} sha1s = tuple(ids) if sha1s: # TODO: consider splitting this query in several ones if sha1s is too big! 
values = ", ".join(itertools.repeat("%s", len(sha1s))) sql = f""" SELECT sha1, date FROM {entity} WHERE sha1 IN ({values}) AND date IS NOT NULL """ with self.transaction(readonly=True) as cursor: cursor.execute(query=sql, vars=sha1s) dates.update((row["sha1"], row["date"]) for row in cursor) return dates def _entity_set_date( self, - entity: Literal["content", "directory", "revision"], - data: Dict[Sha1Git, datetime], + entity: Literal["content", "directory"], + dates: Union[Iterable[Sha1Git], Dict[Sha1Git, datetime]], ) -> bool: + data = dates if isinstance(dates, dict) else dict.fromkeys(dates) try: if data: sql = f""" INSERT INTO {entity}(sha1, date) VALUES %s ON CONFLICT (sha1) DO UPDATE SET date=LEAST(EXCLUDED.date,{entity}.date) """ with self.transaction() as cursor: psycopg2.extras.execute_values(cursor, sql, argslist=data.items()) return True except: # noqa: E722 # Unexpected error occurred, rollback all changes and log message LOGGER.exception("Unexpected error") if self.raise_on_commit: raise return False def _relation_get( self, relation: RelationType, ids: Optional[Iterable[Sha1Git]], reverse: bool = False, ) -> Set[RelationData]: result: Set[RelationData] = set() sha1s: List[Sha1Git] if ids is not None: sha1s = list(ids) filter = "filter-src" if not reverse else "filter-dst" else: sha1s = [] filter = "no-filter" if filter == "no-filter" or sha1s: rel_table = relation.value src_table, *_, dst_table = rel_table.split("_") sql = "SELECT * FROM swh_provenance_relation_get(%s, %s, %s, %s, %s)" with self.transaction(readonly=True) as cursor: cursor.execute( query=sql, vars=(rel_table, src_table, dst_table, filter, sha1s) ) result.update(RelationData(**row) for row in cursor) return result def with_path(self) -> bool: return "with-path" in self.flavor diff --git a/swh/provenance/provenance.py b/swh/provenance/provenance.py index 3a78209..79f4b71 100644 --- a/swh/provenance/provenance.py +++ b/swh/provenance/provenance.py @@ -1,354 +1,404 @@ # Copyright (C) 2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from datetime import datetime import logging import os from typing import Dict, Generator, Iterable, Optional, Set, Tuple from typing_extensions import Literal, TypedDict from swh.model.model import Sha1Git from .interface import ( ProvenanceResult, ProvenanceStorageInterface, RelationData, RelationType, + RevisionData, ) from .model import DirectoryEntry, FileEntry, OriginEntry, RevisionEntry LOGGER = logging.getLogger(__name__) class DatetimeCache(TypedDict): data: Dict[Sha1Git, Optional[datetime]] added: Set[Sha1Git] class OriginCache(TypedDict): data: Dict[Sha1Git, str] added: Set[Sha1Git] class RevisionCache(TypedDict): data: Dict[Sha1Git, Sha1Git] added: Set[Sha1Git] class ProvenanceCache(TypedDict): content: DatetimeCache directory: DatetimeCache revision: DatetimeCache # below are insertion caches only content_in_revision: Set[Tuple[Sha1Git, Sha1Git, bytes]] content_in_directory: Set[Tuple[Sha1Git, Sha1Git, bytes]] directory_in_revision: Set[Tuple[Sha1Git, Sha1Git, bytes]] # these two are for the origin layer origin: OriginCache revision_origin: RevisionCache revision_before_revision: Dict[Sha1Git, Set[Sha1Git]] revision_in_origin: Set[Tuple[Sha1Git, Sha1Git]] def new_cache() -> ProvenanceCache: return ProvenanceCache( content=DatetimeCache(data={}, added=set()), directory=DatetimeCache(data={}, 
added=set()), revision=DatetimeCache(data={}, added=set()), content_in_revision=set(), content_in_directory=set(), directory_in_revision=set(), origin=OriginCache(data={}, added=set()), revision_origin=RevisionCache(data={}, added=set()), revision_before_revision={}, revision_in_origin=set(), ) class Provenance: def __init__(self, storage: ProvenanceStorageInterface) -> None: self.storage = storage self.cache = new_cache() def clear_caches(self) -> None: self.cache = new_cache() def flush(self) -> None: # Revision-content layer insertions ############################################ + # Entities and locations should be registered first, as `relation_add` assumes + # all the entities being related are already present in the storage. + cnts = { + src + for src, _, _ in self.cache["content_in_revision"] + | self.cache["content_in_directory"] + } + if cnts: + while not self.storage.content_add(cnts): + LOGGER.warning( + "Unable to write content entities to the storage. Retrying..." + ) + + dirs = {dst for _, dst, _ in self.cache["content_in_directory"]} + if dirs: + while not self.storage.directory_add(dirs): + LOGGER.warning( + "Unable to write directory entities to the storage. Retrying..." + ) + + revs = { + dst + for _, dst, _ in self.cache["content_in_revision"] + | self.cache["directory_in_revision"] + } + if revs: + while not self.storage.revision_add(revs): + LOGGER.warning( + "Unable to write revision entities to the storage. Retrying..." + ) + + paths = { + path + for _, _, path in self.cache["content_in_revision"] + | self.cache["content_in_directory"] + | self.cache["directory_in_revision"] + } + if paths: + while not self.storage.location_add(paths): + LOGGER.warning( + "Unable to write location entities to the storage. Retrying..." + ) + # For this layer, relations need to be inserted first so that, in case of # failure, reprocessing the input does not generated an inconsistent database. if self.cache["content_in_revision"]: while not self.storage.relation_add( RelationType.CNT_EARLY_IN_REV, ( RelationData(src=src, dst=dst, path=path) for src, dst, path in self.cache["content_in_revision"] ), ): LOGGER.warning( "Unable to write %s rows to the storage. Retrying...", RelationType.CNT_EARLY_IN_REV, ) if self.cache["content_in_directory"]: while not self.storage.relation_add( RelationType.CNT_IN_DIR, ( RelationData(src=src, dst=dst, path=path) for src, dst, path in self.cache["content_in_directory"] ), ): LOGGER.warning( "Unable to write %s rows to the storage. Retrying...", RelationType.CNT_IN_DIR, ) if self.cache["directory_in_revision"]: while not self.storage.relation_add( RelationType.DIR_IN_REV, ( RelationData(src=src, dst=dst, path=path) for src, dst, path in self.cache["directory_in_revision"] ), ): LOGGER.warning( "Unable to write %s rows to the storage. Retrying...", RelationType.DIR_IN_REV, ) # After relations, dates for the entities can be safely set, acknowledging that # these entities won't need to be reprocessed in case of failure. - dates = { + cnt_dates = { sha1: date for sha1, date in self.cache["content"]["data"].items() if sha1 in self.cache["content"]["added"] and date is not None } - if dates: - while not self.storage.content_set_date(dates): + if cnt_dates: + while not self.storage.content_add(cnt_dates): LOGGER.warning( "Unable to write content dates to the storage. Retrying..."
) - dates = { + dir_dates = { sha1: date for sha1, date in self.cache["directory"]["data"].items() if sha1 in self.cache["directory"]["added"] and date is not None } - if dates: - while not self.storage.directory_set_date(dates): + if dir_dates: + while not self.storage.directory_add(dir_dates): LOGGER.warning( "Unable to write directory dates to the storage. Retrying..." ) - dates = { - sha1: date + rev_dates = { + sha1: RevisionData(date=date, origin=None) for sha1, date in self.cache["revision"]["data"].items() if sha1 in self.cache["revision"]["added"] and date is not None } - if dates: - while not self.storage.revision_set_date(dates): + if rev_dates: + while not self.storage.revision_add(rev_dates): LOGGER.warning( "Unable to write revision dates to the storage. Retrying..." ) # Origin-revision layer insertions ############################################# - # Origins urls should be inserted first so that internal ids' resolution works - # properly. + # Origins and revisions should be inserted first so that internal ids' + # resolution works properly. urls = { sha1: url for sha1, url in self.cache["origin"]["data"].items() if sha1 in self.cache["origin"]["added"] } if urls: - while not self.storage.origin_set_url(urls): + while not self.storage.origin_add(urls): LOGGER.warning( "Unable to write origins urls to the storage. Retrying..." ) + rev_orgs = { + # Destinations in this relation should match origins in the next one + **{ + src: RevisionData(date=None, origin=None) + for src in self.cache["revision_before_revision"] + }, + **{ + # This relation comes second so that non-None origins take precedence + src: RevisionData(date=None, origin=org) + for src, org in self.cache["revision_in_origin"] + }, + } + if rev_orgs: + while not self.storage.revision_add(rev_orgs): + LOGGER.warning( + "Unable to write revision entities to the storage. Retrying..." + ) + # Second, flat models for revisions' histories (ie. revision-before-revision). data: Iterable[RelationData] = sum( [ [ RelationData(src=prev, dst=next, path=None) for next in self.cache["revision_before_revision"][prev] ] for prev in self.cache["revision_before_revision"] ], [], ) if data: while not self.storage.relation_add(RelationType.REV_BEFORE_REV, data): LOGGER.warning( "Unable to write %s rows to the storage. Retrying...", RelationType.REV_BEFORE_REV, ) # Heads (ie. revision-in-origin entries) should be inserted once flat models for # their histories were already added. This is to guarantee consistent results if # something needs to be reprocessed due to a failure: already inserted heads # won't get reprocessed in such a case. data = ( RelationData(src=rev, dst=org, path=None) for rev, org in self.cache["revision_in_origin"] ) if data: while not self.storage.relation_add(RelationType.REV_IN_ORG, data): LOGGER.warning( "Unable to write %s rows to the storage. Retrying...", RelationType.REV_IN_ORG, ) - # Finally, preferred origins for the visited revisions are set (this step can be - # reordered if required). - origins = { - sha1: self.cache["revision_origin"]["data"][sha1] - for sha1 in self.cache["revision_origin"]["added"] - } - if origins: - while not self.storage.revision_set_origin(origins): - LOGGER.warning( - "Unable to write preferred origins to the storage. Retrying..." 
- ) - # clear local cache ############################################################ self.clear_caches() def content_add_to_directory( self, directory: DirectoryEntry, blob: FileEntry, prefix: bytes ) -> None: self.cache["content_in_directory"].add( (blob.id, directory.id, normalize(os.path.join(prefix, blob.name))) ) def content_add_to_revision( self, revision: RevisionEntry, blob: FileEntry, prefix: bytes ) -> None: self.cache["content_in_revision"].add( (blob.id, revision.id, normalize(os.path.join(prefix, blob.name))) ) def content_find_first(self, id: Sha1Git) -> Optional[ProvenanceResult]: return self.storage.content_find_first(id) def content_find_all( self, id: Sha1Git, limit: Optional[int] = None ) -> Generator[ProvenanceResult, None, None]: yield from self.storage.content_find_all(id, limit=limit) def content_get_early_date(self, blob: FileEntry) -> Optional[datetime]: return self.get_dates("content", [blob.id]).get(blob.id) def content_get_early_dates( self, blobs: Iterable[FileEntry] ) -> Dict[Sha1Git, datetime]: return self.get_dates("content", [blob.id for blob in blobs]) def content_set_early_date(self, blob: FileEntry, date: datetime) -> None: self.cache["content"]["data"][blob.id] = date self.cache["content"]["added"].add(blob.id) def directory_add_to_revision( self, revision: RevisionEntry, directory: DirectoryEntry, path: bytes ) -> None: self.cache["directory_in_revision"].add( (directory.id, revision.id, normalize(path)) ) def directory_get_date_in_isochrone_frontier( self, directory: DirectoryEntry ) -> Optional[datetime]: return self.get_dates("directory", [directory.id]).get(directory.id) def directory_get_dates_in_isochrone_frontier( self, dirs: Iterable[DirectoryEntry] ) -> Dict[Sha1Git, datetime]: return self.get_dates("directory", [directory.id for directory in dirs]) def directory_set_date_in_isochrone_frontier( self, directory: DirectoryEntry, date: datetime ) -> None: self.cache["directory"]["data"][directory.id] = date self.cache["directory"]["added"].add(directory.id) def get_dates( self, entity: Literal["content", "directory", "revision"], ids: Iterable[Sha1Git], ) -> Dict[Sha1Git, datetime]: cache = self.cache[entity] missing_ids = set(id for id in ids if id not in cache) if missing_ids: if entity == "revision": updated = { id: rev.date for id, rev in self.storage.revision_get(missing_ids).items() } else: updated = getattr(self.storage, f"{entity}_get")(missing_ids) cache["data"].update(updated) dates: Dict[Sha1Git, datetime] = {} for sha1 in ids: date = cache["data"].setdefault(sha1, None) if date is not None: dates[sha1] = date return dates def origin_add(self, origin: OriginEntry) -> None: self.cache["origin"]["data"][origin.id] = origin.url self.cache["origin"]["added"].add(origin.id) def revision_add(self, revision: RevisionEntry) -> None: self.cache["revision"]["data"][revision.id] = revision.date self.cache["revision"]["added"].add(revision.id) def revision_add_before_revision( self, head: RevisionEntry, revision: RevisionEntry ) -> None: self.cache["revision_before_revision"].setdefault(revision.id, set()).add( head.id ) def revision_add_to_origin( self, origin: OriginEntry, revision: RevisionEntry ) -> None: self.cache["revision_in_origin"].add((revision.id, origin.id)) def revision_get_date(self, revision: RevisionEntry) -> Optional[datetime]: return self.get_dates("revision", [revision.id]).get(revision.id) def revision_get_preferred_origin( self, revision: RevisionEntry ) -> Optional[Sha1Git]: cache = self.cache["revision_origin"]["data"] if 
revision.id not in cache: ret = self.storage.revision_get([revision.id]) if revision.id in ret: origin = ret[revision.id].origin if origin is not None: cache[revision.id] = origin return cache.get(revision.id) def revision_in_history(self, revision: RevisionEntry) -> bool: return revision.id in self.cache["revision_before_revision"] or bool( self.storage.relation_get(RelationType.REV_BEFORE_REV, [revision.id]) ) def revision_set_preferred_origin( self, origin: OriginEntry, revision: RevisionEntry ) -> None: self.cache["revision_origin"]["data"][revision.id] = origin.id self.cache["revision_origin"]["added"].add(revision.id) def revision_visited(self, revision: RevisionEntry) -> bool: return revision.id in dict(self.cache["revision_in_origin"]) or bool( self.storage.relation_get(RelationType.REV_IN_ORG, [revision.id]) ) def normalize(path: bytes) -> bytes: return path[2:] if path.startswith(bytes("." + os.path.sep, "utf-8")) else path diff --git a/swh/provenance/sql/40-funcs.sql b/swh/provenance/sql/40-funcs.sql index d6dffa4..1bb2436 100644 --- a/swh/provenance/sql/40-funcs.sql +++ b/swh/provenance/sql/40-funcs.sql @@ -1,711 +1,701 @@ -- psql variables to get the current database flavor select position('denormalized' in swh_get_dbflavor()::text) = 0 as dbflavor_norm \gset select position('with-path' in swh_get_dbflavor()::text) != 0 as dbflavor_with_path \gset create or replace function swh_mktemp_relation_add() returns void language sql as $$ create temp table tmp_relation_add ( src sha1_git not null, dst sha1_git not null, path unix_path ) on commit drop $$; \if :dbflavor_norm \if :dbflavor_with_path -- -- with path and normalized -- create or replace function swh_provenance_content_find_first(content_id sha1_git) returns table ( content sha1_git, revision sha1_git, date timestamptz, origin text, path unix_path ) language sql stable as $$ select C.sha1 as content, R.sha1 as revision, R.date as date, O.url as origin, L.path as path from content as C inner join content_in_revision as CR on (CR.content = C.id) inner join location as L on (L.id = CR.location) inner join revision as R on (R.id = CR.revision) left join origin as O on (O.id = R.origin) where C.sha1 = content_id order by date, revision, origin, path asc limit 1 $$; create or replace function swh_provenance_content_find_all(content_id sha1_git, early_cut int) returns table ( content sha1_git, revision sha1_git, date timestamptz, origin text, path unix_path ) language sql stable as $$ (select C.sha1 as content, R.sha1 as revision, R.date as date, O.url as origin, L.path as path from content as C inner join content_in_revision as CR on (CR.content = C.id) inner join location as L on (L.id = CR.location) inner join revision as R on (R.id = CR.revision) left join origin as O on (O.id = R.origin) where C.sha1 = content_id) union (select C.sha1 as content, R.sha1 as revision, R.date as date, O.url as origin, case DL.path when '' then CL.path when '.' 
then CL.path else (DL.path || '/' || CL.path)::unix_path end as path from content as C inner join content_in_directory as CD on (CD.content = C.id) inner join directory_in_revision as DR on (DR.directory = CD.directory) inner join revision as R on (R.id = DR.revision) inner join location as CL on (CL.id = CD.location) inner join location as DL on (DL.id = DR.location) left join origin as O on (O.id = R.origin) where C.sha1 = content_id) order by date, revision, origin, path limit early_cut $$; create or replace function swh_provenance_relation_add_from_temp( rel_table regclass, src_table regclass, dst_table regclass ) returns void language plpgsql volatile as $$ declare select_fields text; join_location text; begin if src_table in ('content'::regclass, 'directory'::regclass) then - insert into location(path) - select V.path - from tmp_relation_add as V - on conflict (path) do nothing; - select_fields := 'D.id, L.id'; join_location := 'inner join location as L on (L.path = V.path)'; else select_fields := 'D.id'; join_location := ''; end if; execute format( 'insert into %s select S.id, ' || select_fields || ' from tmp_relation_add as V inner join %s as S on (S.sha1 = V.src) inner join %s as D on (D.sha1 = V.dst) ' || join_location || ' on conflict do nothing', rel_table, src_table, dst_table ); end; $$; create or replace function swh_provenance_relation_get( rel_table regclass, src_table regclass, dst_table regclass, filter rel_flt, sha1s sha1_git[] ) returns table ( src sha1_git, dst sha1_git, path unix_path ) language plpgsql stable as $$ declare src_field text; dst_field text; join_location text; proj_location text; filter_result text; begin if rel_table = 'revision_before_revision'::regclass then src_field := 'prev'; dst_field := 'next'; else src_field := src_table::text; dst_field := dst_table::text; end if; if src_table in ('content'::regclass, 'directory'::regclass) then join_location := 'inner join location as L on (L.id = R.location)'; proj_location := 'L.path'; else join_location := ''; proj_location := 'NULL::unix_path'; end if; case filter when 'filter-src'::rel_flt then filter_result := 'where S.sha1 = any($1)'; when 'filter-dst'::rel_flt then filter_result := 'where D.sha1 = any($1)'; else filter_result := ''; end case; return query execute format( 'select S.sha1 as src, D.sha1 as dst, ' || proj_location || ' as path from %s as R inner join %s as S on (S.id = R.' || src_field || ') inner join %s as D on (D.id = R.' 
|| dst_field || ') ' || join_location || ' ' || filter_result, rel_table, src_table, dst_table ) using sha1s; end; $$; \else -- -- without path and normalized -- create or replace function swh_provenance_content_find_first(content_id sha1_git) returns table ( content sha1_git, revision sha1_git, date timestamptz, origin text, path unix_path ) language sql stable as $$ select C.sha1 as content, R.sha1 as revision, R.date as date, O.url as origin, '\x'::unix_path as path from content as C inner join content_in_revision as CR on (CR.content = C.id) inner join revision as R on (R.id = CR.revision) left join origin as O on (O.id = R.origin) where C.sha1 = content_id order by date, revision, origin asc limit 1 $$; create or replace function swh_provenance_content_find_all(content_id sha1_git, early_cut int) returns table ( content sha1_git, revision sha1_git, date timestamptz, origin text, path unix_path ) language sql stable as $$ (select C.sha1 as content, R.sha1 as revision, R.date as date, O.url as origin, '\x'::unix_path as path from content as C inner join content_in_revision as CR on (CR.content = C.id) inner join revision as R on (R.id = CR.revision) left join origin as O on (O.id = R.origin) where C.sha1 = content_id) union (select C.sha1 as content, R.sha1 as revision, R.date as date, O.url as origin, '\x'::unix_path as path from content as C inner join content_in_directory as CD on (CD.content = C.id) inner join directory_in_revision as DR on (DR.directory = CD.directory) inner join revision as R on (R.id = DR.revision) left join origin as O on (O.id = R.origin) where C.sha1 = content_id) order by date, revision, origin, path limit early_cut $$; create or replace function swh_provenance_relation_add_from_temp( rel_table regclass, src_table regclass, dst_table regclass ) returns void language plpgsql volatile as $$ begin execute format( 'insert into %s select S.id, D.id from tmp_relation_add as V inner join %s as S on (S.sha1 = V.src) inner join %s as D on (D.sha1 = V.dst) on conflict do nothing', rel_table, src_table, dst_table ); end; $$; create or replace function swh_provenance_relation_get( rel_table regclass, src_table regclass, dst_table regclass, filter rel_flt, sha1s sha1_git[] ) returns table ( src sha1_git, dst sha1_git, path unix_path ) language plpgsql stable as $$ declare src_field text; dst_field text; filter_result text; begin if rel_table = 'revision_before_revision'::regclass then src_field := 'prev'; dst_field := 'next'; else src_field := src_table::text; dst_field := dst_table::text; end if; case filter when 'filter-src'::rel_flt then filter_result := 'where S.sha1 = any($1)'; when 'filter-dst'::rel_flt then filter_result := 'where D.sha1 = any($1)'; else filter_result := ''; end case; return query execute format( 'select S.sha1 as src, D.sha1 as dst, NULL::unix_path as path from %s as R inner join %s as S on (S.id = R.' || src_field || ') inner join %s as D on (D.id = R.' 
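-- (in the without-path flavor there is no location table to join; the path
-- column is projected as NULL::unix_path in the literal above)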
|| dst_field || ') ' || filter_result, rel_table, src_table, dst_table ) using sha1s; end; $$; -- :dbflavor_with_path \endif -- :dbflavor_norm \else \if :dbflavor_with_path -- -- with path and denormalized -- create or replace function swh_provenance_content_find_first(content_id sha1_git) returns table ( content sha1_git, revision sha1_git, date timestamptz, origin text, path unix_path ) language sql stable as $$ select CL.sha1 as content, R.sha1 as revision, R.date as date, O.url as origin, L.path as path from ( select C.sha1 as sha1, unnest(CR.revision) as revision from content_in_revision as CR inner join content as C on (C.id = CR.content) where C.sha1 = content_id ) as CL inner join revision as R on (R.id = (CL.revision).id) inner join location as L on (L.id = (CL.revision).loc) left join origin as O on (O.id = R.origin) order by date, revision, origin, path asc limit 1 $$; create or replace function swh_provenance_content_find_all(content_id sha1_git, early_cut int) returns table ( content sha1_git, revision sha1_git, date timestamptz, origin text, path unix_path ) language sql stable as $$ (with cntrev as ( select C.sha1 as sha1, unnest(CR.revision) as revision from content_in_revision as CR inner join content as C on (C.id = CR.content) where C.sha1 = content_id) select CR.sha1 as content, R.sha1 as revision, R.date as date, O.url as origin, L.path as path from cntrev as CR inner join revision as R on (R.id = (CR.revision).id) inner join location as L on (L.id = (CR.revision).loc) left join origin as O on (O.id = R.origin)) union (with cntdir as ( select C.sha1 as sha1, unnest(CD.directory) as directory from content as C inner join content_in_directory as CD on (CD.content = C.id) where C.sha1 = content_id), cntrev as ( select CD.sha1 as sha1, L.path as path, unnest(DR.revision) as revision from cntdir as CD inner join directory_in_revision as DR on (DR.directory = (CD.directory).id) inner join location as L on (L.id = (CD.directory).loc)) select CR.sha1 as content, R.sha1 as revision, R.date as date, O.url as origin, case DL.path when '' then CR.path when '.' then CR.path else (DL.path || '/' || CR.path)::unix_path end as path from cntrev as CR inner join revision as R on (R.id = (CR.revision).id) inner join location as DL on (DL.id = (CR.revision).loc) left join origin as O on (O.id = R.origin)) order by date, revision, origin, path limit early_cut $$; create or replace function swh_provenance_relation_add_from_temp( rel_table regclass, src_table regclass, dst_table regclass ) returns void language plpgsql volatile as $$ declare select_fields text; join_location text; group_entries text; on_conflict text; begin if src_table in ('content'::regclass, 'directory'::regclass) then - insert into location(path) - select V.path - from tmp_relation_add as V - on conflict (path) do nothing; - select_fields := 'array_agg((D.id, L.id)::rel_dst)'; join_location := 'inner join location as L on (L.path = V.path)'; group_entries := 'group by S.id'; on_conflict := format(' (%s) do update set %s=array( select distinct unnest( %s.' || dst_table::text || ' || excluded.' 
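-- (on conflict, the pre-existing destination array is merged with the incoming
-- one and deduplicated via select distinct unnest, instead of dropping the row)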
|| dst_table::text || ' ) )', src_table, dst_table, rel_table, rel_table, rel_table ); else select_fields := 'D.id'; join_location := ''; group_entries := ''; on_conflict := 'do nothing'; end if; execute format( 'insert into %s select S.id, ' || select_fields || ' from tmp_relation_add as V inner join %s as S on (S.sha1 = V.src) inner join %s as D on (D.sha1 = V.dst) ' || join_location || ' ' || group_entries || ' on conflict ' || on_conflict, rel_table, src_table, dst_table ); end; $$; create or replace function swh_provenance_relation_get( rel_table regclass, src_table regclass, dst_table regclass, filter rel_flt, sha1s sha1_git[] ) returns table ( src sha1_git, dst sha1_git, path unix_path ) language plpgsql stable as $$ declare src_field text; dst_field text; proj_dst_id text; proj_unnested text; proj_location text; join_location text; filter_inner_result text; filter_outer_result text; begin if rel_table = 'revision_before_revision'::regclass then src_field := 'prev'; dst_field := 'next'; else src_field := src_table::text; dst_field := dst_table::text; end if; if src_table in ('content'::regclass, 'directory'::regclass) then proj_unnested := 'unnest(R.' || dst_field || ') as dst'; proj_dst_id := '(CL.dst).id'; join_location := 'inner join location as L on (L.id = (CL.dst).loc)'; proj_location := 'L.path'; else proj_unnested := 'R.' || dst_field || ' as dst'; proj_dst_id := 'CL.dst'; join_location := ''; proj_location := 'NULL::unix_path'; end if; case filter when 'filter-src'::rel_flt then filter_inner_result := 'where S.sha1 = any($1)'; filter_outer_result := ''; when 'filter-dst'::rel_flt then filter_inner_result := ''; filter_outer_result := 'where D.sha1 = any($1)'; else filter_inner_result := ''; filter_outer_result := ''; end case; return query execute format( 'select CL.src, D.sha1 as dst, ' || proj_location || ' as path from (select S.sha1 as src, ' || proj_unnested || ' from %s as R inner join %s as S on (S.id = R.' 
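-- (the inner subquery unnests the denormalized destination array so that each
-- (src, dst) pair can be joined against the destination table and filtered)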
|| src_field || ') ' || filter_inner_result || ') as CL inner join %s as D on (D.id = ' || proj_dst_id || ') ' || join_location || ' ' || filter_outer_result, rel_table, src_table, dst_table ) using sha1s; end; $$; \else -- -- without path and denormalized -- create or replace function swh_provenance_content_find_first(content_id sha1_git) returns table ( content sha1_git, revision sha1_git, date timestamptz, origin text, path unix_path ) language sql stable as $$ select CL.sha1 as content, R.sha1 as revision, R.date as date, O.url as origin, '\x'::unix_path as path from ( select C.sha1, unnest(revision) as revision from content_in_revision as CR inner join content as C on (C.id = CR.content) where C.sha1=content_id ) as CL inner join revision as R on (R.id = CL.revision) left join origin as O on (O.id = R.origin) order by date, revision, origin, path asc limit 1 $$; create or replace function swh_provenance_content_find_all(content_id sha1_git, early_cut int) returns table ( content sha1_git, revision sha1_git, date timestamptz, origin text, path unix_path ) language sql stable as $$ (with cntrev as ( select C.sha1 as sha1, unnest(CR.revision) as revision from content_in_revision as CR inner join content as C on (C.id = CR.content) where C.sha1 = content_id) select CR.sha1 as content, R.sha1 as revision, R.date as date, O.url as origin, '\x'::unix_path as path from cntrev as CR inner join revision as R on (R.id = CR.revision) left join origin as O on (O.id = R.origin)) union (with cntdir as ( select C.sha1 as sha1, unnest(CD.directory) as directory from content as C inner join content_in_directory as CD on (CD.content = C.id) where C.sha1 = content_id), cntrev as ( select CD.sha1 as sha1, unnest(DR.revision) as revision from cntdir as CD inner join directory_in_revision as DR on (DR.directory = CD.directory)) select CR.sha1 as content, R.sha1 as revision, R.date as date, O.url as origin, '\x'::unix_path as path from cntrev as CR inner join revision as R on (R.id = CR.revision) left join origin as O on (O.id = R.origin)) order by date, revision, origin, path limit early_cut $$; create or replace function swh_provenance_relation_add_from_temp( rel_table regclass, src_table regclass, dst_table regclass ) returns void language plpgsql volatile as $$ declare select_fields text; group_entries text; on_conflict text; begin if src_table in ('content'::regclass, 'directory'::regclass) then select_fields := 'array_agg(D.id)'; group_entries := 'group by S.id'; on_conflict := format(' (%s) do update set %s=array( select distinct unnest( %s.' || dst_table::text || ' || excluded.' 
|| dst_table::text || ' ) )', src_table, dst_table, rel_table, rel_table ); else select_fields := 'D.id'; group_entries := ''; on_conflict := 'do nothing'; end if; execute format( 'insert into %s select S.id, ' || select_fields || ' from tmp_relation_add as V inner join %s as S on (S.sha1 = V.src) inner join %s as D on (D.sha1 = V.dst) ' || group_entries || ' on conflict ' || on_conflict, rel_table, src_table, dst_table ); end; $$; create or replace function swh_provenance_relation_get( rel_table regclass, src_table regclass, dst_table regclass, filter rel_flt, sha1s sha1_git[] ) returns table ( src sha1_git, dst sha1_git, path unix_path ) language plpgsql stable as $$ declare src_field text; dst_field text; proj_unnested text; filter_inner_result text; filter_outer_result text; begin if rel_table = 'revision_before_revision'::regclass then src_field := 'prev'; dst_field := 'next'; else src_field := src_table::text; dst_field := dst_table::text; end if; if src_table in ('content'::regclass, 'directory'::regclass) then proj_unnested := 'unnest(R.' || dst_field || ') as dst'; else proj_unnested := 'R.' || dst_field || ' as dst'; end if; case filter when 'filter-src'::rel_flt then filter_inner_result := 'where S.sha1 = any($1)'; filter_outer_result := ''; when 'filter-dst'::rel_flt then filter_inner_result := ''; filter_outer_result := 'where D.sha1 = any($1)'; else filter_inner_result := ''; filter_outer_result := ''; end case; return query execute format( 'select CL.src, D.sha1 as dst, NULL::unix_path as path from (select S.sha1 as src, ' || proj_unnested || ' from %s as R inner join %s as S on (S.id = R.' || src_field || ') ' || filter_inner_result || ') as CL inner join %s as D on (D.id = CL.dst) ' || filter_outer_result, rel_table, src_table, dst_table ) using sha1s; end; $$; \endif -- :dbflavor_with_path \endif -- :dbflavor_norm diff --git a/swh/provenance/tests/test_provenance_storage.py b/swh/provenance/tests/test_provenance_storage.py index 32a8629..841cea6 100644 --- a/swh/provenance/tests/test_provenance_storage.py +++ b/swh/provenance/tests/test_provenance_storage.py @@ -1,349 +1,460 @@ # Copyright (C) 2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information -from datetime import datetime +from datetime import datetime, timezone import inspect import os from typing import Any, Dict, Iterable, Optional, Set -import pytest - from swh.model.hashutil import hash_to_bytes from swh.model.identifiers import origin_identifier from swh.model.model import Sha1Git +from swh.provenance.archive import ArchiveInterface from swh.provenance.interface import ( EntityType, ProvenanceInterface, ProvenanceResult, ProvenanceStorageInterface, RelationData, RelationType, + RevisionData, ) -from swh.provenance.tests.conftest import load_repo_data, ts2dt +from swh.provenance.model import OriginEntry, RevisionEntry +from swh.provenance.mongo.backend import ProvenanceStorageMongoDb +from swh.provenance.origin import origin_add +from swh.provenance.provenance import Provenance +from swh.provenance.revision import revision_add +from swh.provenance.tests.conftest import fill_storage, load_repo_data, ts2dt -def relation_add_and_compare_result( - relation: RelationType, - data: Set[RelationData], - refstorage: ProvenanceStorageInterface, - storage: ProvenanceStorageInterface, - with_path: bool = True, +def test_provenance_storage_content( + 
provenance_storage: ProvenanceStorageInterface, ) -> None: - assert data - assert refstorage.relation_add(relation, data) == storage.relation_add( - relation, data - ) + """Tests content methods for every `ProvenanceStorageInterface` implementation.""" - assert relation_compare_result( - refstorage.relation_get(relation, (reldata.src for reldata in data)), - storage.relation_get(relation, (reldata.src for reldata in data)), - with_path, - ) - assert relation_compare_result( - refstorage.relation_get( - relation, - (reldata.dst for reldata in data), - reverse=True, - ), - storage.relation_get( - relation, - (reldata.dst for reldata in data), - reverse=True, - ), - with_path, + # Read data/README.md for more details on how these datasets are generated. + data = load_repo_data("cmdbts2") + + # Add all content present in the current repo to the storage, just assigning their + # creation dates. Then check that the returned results when querying are the same. + cnts = {cnt["sha1_git"] for idx, cnt in enumerate(data["content"]) if idx % 2 == 0} + cnt_dates = { + cnt["sha1_git"]: cnt["ctime"] + for idx, cnt in enumerate(data["content"]) + if idx % 2 == 1 + } + assert cnts or cnt_dates + assert provenance_storage.content_add(cnts) + assert provenance_storage.content_add(cnt_dates) + assert provenance_storage.content_get(set(cnt_dates.keys())) == cnt_dates + assert provenance_storage.entity_get_all(EntityType.CONTENT) == cnts | set( + cnt_dates.keys() ) - assert relation_compare_result( - refstorage.relation_get_all(relation), - storage.relation_get_all(relation), - with_path, + + +def test_provenance_storage_directory( + provenance_storage: ProvenanceStorageInterface, +) -> None: + """Tests directory methods for every `ProvenanceStorageInterface` implementation.""" + + # Read data/README.md for more details on how these datasets are generated. + data = load_repo_data("cmdbts2") + + # Of all directories present in the current repo, only assign a date to those + # containing blobs (picking the max date among the available ones). Then check that + # the returned results when querying are the same. + def getmaxdate( + directory: Dict[str, Any], contents: Iterable[Dict[str, Any]] + ) -> Optional[datetime]: + dates = [ + content["ctime"] + for entry in directory["entries"] + for content in contents + if entry["type"] == "file" and entry["target"] == content["sha1_git"] + ] + return max(dates) if dates else None + + dirs = { + dir["id"] + for dir in data["directory"] + if getmaxdate(dir, data["content"]) is None + } + dir_dates = { + dir["id"]: getmaxdate(dir, data["content"]) + for dir in data["directory"] + if getmaxdate(dir, data["content"]) is not None + } + assert dirs + assert provenance_storage.directory_add(dirs) + assert provenance_storage.directory_add(dir_dates) + assert provenance_storage.directory_get(set(dir_dates.keys())) == dir_dates + assert provenance_storage.entity_get_all(EntityType.DIRECTORY) == dirs | set( + dir_dates.keys() ) -def relation_compare_result( - expected: Set[RelationData], computed: Set[RelationData], with_path: bool -) -> bool: - return { - RelationData(reldata.src, reldata.dst, reldata.path if with_path else None) - for reldata in expected - } == computed +def test_provenance_storage_location( + provenance_storage: ProvenanceStorageInterface, +) -> None: + """Tests location methods for every `ProvenanceStorageInterface` implementation.""" + + # Read data/README.md for more details on how these datasets are generated. 
+ data = load_repo_data("cmdbts2") + + # Add all names of entries present in the directories of the current repo as paths + # to the storage. Then check that the returned results when querying are the same. + paths = {entry["name"] for dir in data["directory"] for entry in dir["entries"]} + assert provenance_storage.location_add(paths) + + if isinstance(provenance_storage, ProvenanceStorageMongoDb): + # TODO: remove this when `location_add` is properly implemented for MongoDb. + return + + if provenance_storage.with_path(): + assert provenance_storage.location_get_all() == paths + else: + assert provenance_storage.location_get_all() == set() + + +def test_provenance_storage_origin( + provenance_storage: ProvenanceStorageInterface, +) -> None: + """Tests origin methods for every `ProvenanceStorageInterface` implementation.""" + + # Read data/README.md for more details on how these datasets are generated. + data = load_repo_data("cmdbts2") + + # Test origin methods. + # Add all origins present in the current repo to the storage. Then check that the + # returned results when querying are the same. + orgs = {hash_to_bytes(origin_identifier(org)): org["url"] for org in data["origin"]} + assert orgs + assert provenance_storage.origin_add(orgs) + assert provenance_storage.origin_get(set(orgs.keys())) == orgs + assert provenance_storage.entity_get_all(EntityType.ORIGIN) == set(orgs.keys()) + + +def test_provenance_storage_revision( + provenance_storage: ProvenanceStorageInterface, +) -> None: + """Tests revision methods for every `ProvenanceStorageInterface` implementation.""" + + # Read data/README.md for more details on how these datasets are generated. + data = load_repo_data("cmdbts2") + + # Test revision methods. + # Add all revisions present in the current repo to the storage, assigning their + # dates and an arbitrary origin to each one. Then check that the returned results + # when querying are the same. + origin = next(iter(data["origin"])) + origin_sha1 = hash_to_bytes(origin_identifier(origin)) + # Origin must be inserted in advance. 
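+    # (presumably because backends resolve the origin sha1 to an internal row
+    # when storing revisions, so the url mapping must already exist)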
+ assert provenance_storage.origin_add({origin_sha1: origin["url"]}) + + revs = {rev["id"] for idx, rev in enumerate(data["revision"]) if idx % 6 == 0} + rev_data = { + rev["id"]: RevisionData( + date=ts2dt(rev["date"]) if idx % 2 != 0 else None, + origin=origin_sha1 if idx % 3 != 0 else None, + ) + for idx, rev in enumerate(data["revision"]) + if idx % 6 != 0 + } + assert revs + assert provenance_storage.revision_add(revs) + assert provenance_storage.revision_add(rev_data) + assert provenance_storage.revision_get(set(rev_data.keys())) == rev_data + assert provenance_storage.entity_get_all(EntityType.REVISION) == revs | set( + rev_data.keys() + ) def dircontent( data: Dict[str, Any], ref: Sha1Git, dir: Dict[str, Any], prefix: bytes = b"", ) -> Iterable[RelationData]: content = { RelationData(entry["target"], ref, os.path.join(prefix, entry["name"])) for entry in dir["entries"] if entry["type"] == "file" } for entry in dir["entries"]: if entry["type"] == "dir": child = next( subdir for subdir in data["directory"] if subdir["id"] == entry["target"] ) content.update( dircontent(data, ref, child, os.path.join(prefix, entry["name"])) ) return content -@pytest.mark.parametrize( - "repo", - ("cmdbts2", "out-of-order", "with-merges"), -) -def test_provenance_storage( - provenance: ProvenanceInterface, - provenance_storage: ProvenanceStorageInterface, - repo: str, +def entity_add( + storage: ProvenanceStorageInterface, entity: EntityType, ids: Set[Sha1Git] +) -> bool: + if entity == EntityType.CONTENT: + return storage.content_add({sha1: None for sha1 in ids}) + elif entity == EntityType.DIRECTORY: + return storage.directory_add({sha1: None for sha1 in ids}) + else: # entity == EntityType.REVISION: + return storage.revision_add( + {sha1: RevisionData(date=None, origin=None) for sha1 in ids} + ) + + +def relation_add_and_compare_result( + storage: ProvenanceStorageInterface, relation: RelationType, data: Set[RelationData] ) -> None: - """Tests every ProvenanceStorageInterface implementation against the one provided - for provenance.storage.""" - # Read data/README.md for more details on how these datasets are generated. - data = load_repo_data(repo) + # Source, destinations and locations must be added in advance. + src, *_, dst = relation.value.split("_") + if src != "origin": + assert entity_add(storage, EntityType(src), {entry.src for entry in data}) + if dst != "origin": + assert entity_add(storage, EntityType(dst), {entry.dst for entry in data}) + if storage.with_path(): + assert storage.location_add( + {entry.path for entry in data if entry.path is not None} + ) - # Assuming provenance.storage has the 'with-path' flavor. - assert provenance.storage.with_path() + assert data + assert storage.relation_add(relation, data) - # Test origin methods. - # Add all origins present in the current repo to both storages. Then check that the - # inserted data is the same in both cases. 
- org_urls = { - hash_to_bytes(origin_identifier(org)): org["url"] for org in data["origin"] - } - assert org_urls - assert provenance.storage.origin_set_url( - org_urls - ) == provenance_storage.origin_set_url(org_urls) + for row in data: + assert relation_compare_result( + storage.relation_get(relation, [row.src]), + {entry for entry in data if entry.src == row.src}, + storage.with_path(), + ) + assert relation_compare_result( + storage.relation_get( + relation, + [row.dst], + reverse=True, + ), + {entry for entry in data if entry.dst == row.dst}, + storage.with_path(), + ) - assert provenance.storage.origin_get(org_urls) == provenance_storage.origin_get( - org_urls + assert relation_compare_result( + storage.relation_get_all(relation), data, storage.with_path() ) - assert provenance.storage.entity_get_all( - EntityType.ORIGIN - ) == provenance_storage.entity_get_all(EntityType.ORIGIN) + + +def relation_compare_result( + computed: Set[RelationData], expected: Set[RelationData], with_path: bool +) -> bool: + return { + RelationData(row.src, row.dst, row.path if with_path else None) + for row in expected + } == computed + + +def test_provenance_storage_relation( + provenance_storage: ProvenanceStorageInterface, +) -> None: + """Tests relation methods for every `ProvenanceStorageInterface` implementation.""" + + # Read data/README.md for more details on how these datasets are generated. + data = load_repo_data("cmdbts2") # Test content-in-revision relation. # Create flat models of every root directory for the revisions in the dataset. cnt_in_rev: Set[RelationData] = set() for rev in data["revision"]: root = next( subdir for subdir in data["directory"] if subdir["id"] == rev["directory"] ) cnt_in_rev.update(dircontent(data, rev["id"], root)) - relation_add_and_compare_result( - RelationType.CNT_EARLY_IN_REV, - cnt_in_rev, - provenance.storage, - provenance_storage, - provenance_storage.with_path(), + provenance_storage, RelationType.CNT_EARLY_IN_REV, cnt_in_rev ) # Test content-in-directory relation. # Create flat models for every directory in the dataset. cnt_in_dir: Set[RelationData] = set() for dir in data["directory"]: cnt_in_dir.update(dircontent(data, dir["id"], dir)) - relation_add_and_compare_result( - RelationType.CNT_IN_DIR, - cnt_in_dir, - provenance.storage, - provenance_storage, - provenance_storage.with_path(), + provenance_storage, RelationType.CNT_IN_DIR, cnt_in_dir ) # Test directory-in-revision relation. # Add root directories to their corresponding revision in the dataset. dir_in_rev = { RelationData(rev["directory"], rev["id"], b".") for rev in data["revision"] } - relation_add_and_compare_result( - RelationType.DIR_IN_REV, - dir_in_rev, - provenance.storage, - provenance_storage, - provenance_storage.with_path(), + provenance_storage, RelationType.DIR_IN_REV, dir_in_rev ) # Test revision-in-origin relation. # Add all revisions that are head of some snapshot branch to the corresponding # origin. rev_in_org = { RelationData( branch["target"], hash_to_bytes(origin_identifier({"url": status["origin"]})), None, ) for status in data["origin_visit_status"] if status["snapshot"] is not None for snapshot in data["snapshot"] if snapshot["id"] == status["snapshot"] for _, branch in snapshot["branches"].items() if branch["target_type"] == "revision" } + # Origins must be inserted in advance (cannot be done by `entity_add` inside + # `relation_add_and_compare_result`).
+ orgs = { + hash_to_bytes(origin_identifier(origin)): origin["url"] + for origin in data["origin"] + } + assert provenance_storage.origin_add(orgs) relation_add_and_compare_result( - RelationType.REV_IN_ORG, - rev_in_org, - provenance.storage, - provenance_storage, + provenance_storage, RelationType.REV_IN_ORG, rev_in_org ) # Test revision-before-revision relation. # For each revision in the data set add an entry for each parent to the relation. rev_before_rev = { RelationData(parent, rev["id"], None) for rev in data["revision"] for parent in rev["parents"] } - relation_add_and_compare_result( - RelationType.REV_BEFORE_REV, - rev_before_rev, - provenance.storage, - provenance_storage, + provenance_storage, RelationType.REV_BEFORE_REV, rev_before_rev ) - # Test content methods. - # Add all content present in the current repo to both storages, just assigning their - # creation dates. Then check that the inserted content is the same in both cases. - cnt_dates = {cnt["sha1_git"]: cnt["ctime"] for cnt in data["content"]} - assert cnt_dates - assert provenance.storage.content_set_date( - cnt_dates - ) == provenance_storage.content_set_date(cnt_dates) - - assert provenance.storage.content_get(cnt_dates) == provenance_storage.content_get( - cnt_dates - ) - assert provenance.storage.entity_get_all( - EntityType.CONTENT - ) == provenance_storage.entity_get_all(EntityType.CONTENT) - # Test directory methods. - # Of all directories present in the current repo, only assign a date to those - # containing blobs (picking the max date among the available ones). Then check that - # the inserted data is the same in both storages. - def getmaxdate( - dir: Dict[str, Any], cnt_dates: Dict[Sha1Git, datetime] - ) -> Optional[datetime]: - dates = [ - cnt_dates[entry["target"]] - for entry in dir["entries"] - if entry["type"] == "file" - ] - return max(dates) if dates else None - - dir_dates = {dir["id"]: getmaxdate(dir, cnt_dates) for dir in data["directory"]} - assert dir_dates - assert provenance.storage.directory_set_date( - {sha1: date for sha1, date in dir_dates.items() if date is not None} - ) == provenance_storage.directory_set_date( - {sha1: date for sha1, date in dir_dates.items() if date is not None} - ) - assert provenance.storage.directory_get( - dir_dates - ) == provenance_storage.directory_get(dir_dates) - assert provenance.storage.entity_get_all( - EntityType.DIRECTORY - ) == provenance_storage.entity_get_all(EntityType.DIRECTORY) +def test_provenance_storage_find( + archive: ArchiveInterface, + provenance: ProvenanceInterface, + provenance_storage: ProvenanceStorageInterface, +) -> None: + """Tests `content_find_first` and `content_find_all` methods for every + `ProvenanceStorageInterface` implementation. + """ - # Test revision methods. - # Add all revisions present in the current repo to both storages, assigning their - # dataes and an arbitrary origin to each one. Then check that the inserted data is - # the same in both cases. 
- rev_dates = {rev["id"]: ts2dt(rev["date"]) for rev in data["revision"]} - assert rev_dates - assert provenance.storage.revision_set_date( - rev_dates - ) == provenance_storage.revision_set_date(rev_dates) - - rev_origins = { - rev["id"]: next(iter(org_urls)) # any arbitrary origin will do - for rev in data["revision"] - } - assert rev_origins - assert provenance.storage.revision_set_origin( - rev_origins - ) == provenance_storage.revision_set_origin(rev_origins) - - assert provenance.storage.revision_get( - rev_dates - ) == provenance_storage.revision_get(rev_dates) - assert provenance.storage.entity_get_all( - EntityType.REVISION - ) == provenance_storage.entity_get_all(EntityType.REVISION) - - # Test location_get. - if provenance_storage.with_path(): - assert provenance.storage.location_get() == provenance_storage.location_get() + # Read data/README.md for more details on how these datasets are generated. + data = load_repo_data("cmdbts2") + fill_storage(archive.storage, data) - # Test content_find_first and content_find_all. + # Test content_find_first and content_find_all, first only executing the + # revision-content algorithm, then adding the origin-revision layer. def adapt_result( result: Optional[ProvenanceResult], with_path: bool ) -> Optional[ProvenanceResult]: if result is not None: return ProvenanceResult( result.content, result.revision, result.date, result.origin, result.path if with_path else b"", ) return result - for cnt in cnt_dates: + # Execute the revision-content algorithm on both storages. + revisions = [ + RevisionEntry(id=rev["id"], date=ts2dt(rev["date"]), root=rev["directory"]) + for rev in data["revision"] + ] + revision_add(provenance, archive, revisions) + revision_add(Provenance(provenance_storage), archive, revisions) + + assert adapt_result( + ProvenanceResult( + content=hash_to_bytes("20329687bb9c1231a7e05afe86160343ad49b494"), + revision=hash_to_bytes("c0d8929936631ecbcf9147be6b8aa13b13b014e4"), + date=datetime.fromtimestamp(1000000000.0, timezone.utc), + origin=None, + path=b"A/B/C/a", + ), + provenance_storage.with_path(), + ) == provenance_storage.content_find_first( + hash_to_bytes("20329687bb9c1231a7e05afe86160343ad49b494") + ) + + for cnt in {cnt["sha1_git"] for cnt in data["content"]}: assert adapt_result( provenance.storage.content_find_first(cnt), provenance_storage.with_path() ) == provenance_storage.content_find_first(cnt) + assert { + adapt_result(occur, provenance_storage.with_path()) + for occur in provenance.storage.content_find_all(cnt) + } == set(provenance_storage.content_find_all(cnt)) + + # Execute the origin-revision algorithm on both storages. 
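+    # (this fills the revision-in-origin layer, so content_find_first should now
+    # report an origin url for the blob queried above)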
+ origins = [ + OriginEntry(url=sta["origin"], snapshot=sta["snapshot"]) + for sta in data["origin_visit_status"] + if sta["snapshot"] is not None + ] + origin_add(provenance, archive, origins) + origin_add(Provenance(provenance_storage), archive, origins) + + assert adapt_result( + ProvenanceResult( + content=hash_to_bytes("20329687bb9c1231a7e05afe86160343ad49b494"), + revision=hash_to_bytes("c0d8929936631ecbcf9147be6b8aa13b13b014e4"), + date=datetime.fromtimestamp(1000000000.0, timezone.utc), + origin="https://cmdbts2", + path=b"A/B/C/a", + ), + provenance_storage.with_path(), + ) == provenance_storage.content_find_first( + hash_to_bytes("20329687bb9c1231a7e05afe86160343ad49b494") + ) + for cnt in {cnt["sha1_git"] for cnt in data["content"]}: + assert adapt_result( + provenance.storage.content_find_first(cnt), provenance_storage.with_path() + ) == provenance_storage.content_find_first(cnt) assert { adapt_result(occur, provenance_storage.with_path()) for occur in provenance.storage.content_find_all(cnt) } == set(provenance_storage.content_find_all(cnt)) -def test_types(provenance: ProvenanceInterface) -> None: +def test_types(provenance_storage: ProvenanceStorageInterface) -> None: """Checks all methods of ProvenanceStorageInterface are implemented by this backend, and that they have the same signature.""" # Create an instance of the protocol (which cannot be instantiated # directly, so this creates a subclass, then instantiates it) interface = type("_", (ProvenanceStorageInterface,), {})() assert "content_find_first" in dir(interface) missing_methods = [] for meth_name in dir(interface): if meth_name.startswith("_"): continue interface_meth = getattr(interface, meth_name) try: - concrete_meth = getattr(provenance.storage, meth_name) + concrete_meth = getattr(provenance_storage, meth_name) except AttributeError: if not getattr(interface_meth, "deprecated_endpoint", False): # The backend is missing a (non-deprecated) endpoint missing_methods.append(meth_name) continue expected_signature = inspect.signature(interface_meth) actual_signature = inspect.signature(concrete_meth) assert expected_signature == actual_signature, meth_name assert missing_methods == [] # If all the assertions above succeed, then this one should too. # But there's no harm in double-checking. # And we could replace the assertions above by this one, but unlike # the assertions above, it doesn't explain what is missing.
- assert isinstance(provenance.storage, ProvenanceStorageInterface) + assert isinstance(provenance_storage, ProvenanceStorageInterface) diff --git a/swh/provenance/tests/test_revision_content_layer.py b/swh/provenance/tests/test_revision_content_layer.py index 4d59114..36c8aa4 100644 --- a/swh/provenance/tests/test_revision_content_layer.py +++ b/swh/provenance/tests/test_revision_content_layer.py @@ -1,447 +1,447 @@ # Copyright (C) 2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import re from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple import pytest from typing_extensions import TypedDict from swh.model.hashutil import hash_to_bytes from swh.model.model import Sha1Git from swh.provenance.archive import ArchiveInterface from swh.provenance.interface import EntityType, ProvenanceInterface, RelationType from swh.provenance.model import RevisionEntry from swh.provenance.revision import revision_add from swh.provenance.tests.conftest import ( fill_storage, get_datafile, load_repo_data, ts2dt, ) class SynthRelation(TypedDict): prefix: Optional[str] path: str src: Sha1Git dst: Sha1Git rel_ts: float class SynthRevision(TypedDict): sha1: Sha1Git date: float msg: str R_C: List[SynthRelation] R_D: List[SynthRelation] D_C: List[SynthRelation] def synthetic_revision_content_result(filename: str) -> Iterator[SynthRevision]: """Generates dict representations of synthetic revisions found in the synthetic file (from the data/ directory) given as argument of the generator. Generated SynthRevision (typed dict) with the following elements: "sha1": (Sha1Git) sha1 of the revision, "date": (float) timestamp of the revision, "msg": (str) commit message of the revision, "R_C": (list) new R---C relations added by this revision "R_D": (list) new R-D relations added by this revision "D_C": (list) new D-C relations added by this revision Each relation above is a SynthRelation typed dict with: "path": (str) location "src": (Sha1Git) sha1 of the source of the relation "dst": (Sha1Git) sha1 of the destination of the relation "rel_ts": (float) timestamp of the target of the relation (related to the timestamp of the revision) """ with open(get_datafile(filename), "r") as fobj: yield from _parse_synthetic_revision_content_file(fobj) def _parse_synthetic_revision_content_file( fobj: Iterable[str], ) -> Iterator[SynthRevision]: """Read a 'synthetic' file and generate a dict representation of the synthetic revision for each revision listed in the synthetic file. 
""" regs = [ "(?PR[0-9]{2,4})?", "(?P[^| ]*)", "([+] )?(?P[^| +]*?)[/]?", "(?P[RDC]) (?P[0-9a-f]{40})", "(?P-?[0-9]+(.[0-9]+)?)", ] regex = re.compile("^ *" + r" *[|] *".join(regs) + r" *(#.*)?$") current_rev: List[dict] = [] for m in (regex.match(line) for line in fobj): if m: d = m.groupdict() if d["revname"]: if current_rev: yield _mk_synth_rev(current_rev) current_rev.clear() current_rev.append(d) if current_rev: yield _mk_synth_rev(current_rev) def _mk_synth_rev(synth_rev: List[Dict[str, str]]) -> SynthRevision: assert synth_rev[0]["type"] == "R" rev = SynthRevision( sha1=hash_to_bytes(synth_rev[0]["sha1"]), date=float(synth_rev[0]["ts"]), msg=synth_rev[0]["revname"], R_C=[], R_D=[], D_C=[], ) current_path = None # path of the last R-D relation we parsed, used a prefix for next D-C # relations for row in synth_rev[1:]: if row["reltype"] == "R---C": assert row["type"] == "C" rev["R_C"].append( SynthRelation( prefix=None, path=row["path"], src=rev["sha1"], dst=hash_to_bytes(row["sha1"]), rel_ts=float(row["ts"]), ) ) current_path = None elif row["reltype"] == "R-D": assert row["type"] == "D" rev["R_D"].append( SynthRelation( prefix=None, path=row["path"], src=rev["sha1"], dst=hash_to_bytes(row["sha1"]), rel_ts=float(row["ts"]), ) ) current_path = row["path"] elif row["reltype"] == "D-C": assert row["type"] == "C" rev["D_C"].append( SynthRelation( prefix=current_path, path=row["path"], src=rev["R_D"][-1]["dst"], dst=hash_to_bytes(row["sha1"]), rel_ts=float(row["ts"]), ) ) return rev @pytest.mark.parametrize( "repo, lower, mindepth", ( ("cmdbts2", True, 1), ("cmdbts2", False, 1), ("cmdbts2", True, 2), ("cmdbts2", False, 2), ("out-of-order", True, 1), ), ) def test_revision_content_result( provenance: ProvenanceInterface, archive: ArchiveInterface, repo: str, lower: bool, mindepth: int, ) -> None: # read data/README.md for more details on how these datasets are generated data = load_repo_data(repo) fill_storage(archive.storage, data) syntheticfile = get_datafile( f"synthetic_{repo}_{'lower' if lower else 'upper'}_{mindepth}.txt" ) revisions = {rev["id"]: rev for rev in data["revision"]} rows: Dict[str, Set[Any]] = { "content": set(), "content_in_directory": set(), "content_in_revision": set(), "directory": set(), "directory_in_revision": set(), "location": set(), "revision": set(), } def maybe_path(path: str) -> Optional[bytes]: if provenance.storage.with_path(): return path.encode("utf-8") return None for synth_rev in synthetic_revision_content_result(syntheticfile): revision = revisions[synth_rev["sha1"]] entry = RevisionEntry( id=revision["id"], date=ts2dt(revision["date"]), root=revision["directory"], ) revision_add(provenance, archive, [entry], lower=lower, mindepth=mindepth) # each "entry" in the synth file is one new revision rows["revision"].add(synth_rev["sha1"]) assert rows["revision"] == provenance.storage.entity_get_all( EntityType.REVISION ), synth_rev["msg"] # check the timestamp of the revision rev_ts = synth_rev["date"] rev_data = provenance.storage.revision_get([synth_rev["sha1"]])[ synth_rev["sha1"] ] assert ( rev_data.date is not None and rev_ts == rev_data.date.timestamp() ), synth_rev["msg"] # this revision might have added new content objects rows["content"] |= set(x["dst"] for x in synth_rev["R_C"]) rows["content"] |= set(x["dst"] for x in synth_rev["D_C"]) assert rows["content"] == provenance.storage.entity_get_all( EntityType.CONTENT ), synth_rev["msg"] # check for R-C (direct) entries # these are added directly in the content_early_in_rev table 
rows["content_in_revision"] |= set( (x["dst"], x["src"], maybe_path(x["path"])) for x in synth_rev["R_C"] ) assert rows["content_in_revision"] == { (rel.src, rel.dst, rel.path) for rel in provenance.storage.relation_get_all( RelationType.CNT_EARLY_IN_REV ) }, synth_rev["msg"] # check timestamps for rc in synth_rev["R_C"]: assert ( rev_ts + rc["rel_ts"] == provenance.storage.content_get([rc["dst"]])[rc["dst"]].timestamp() ), synth_rev["msg"] # check directories # each directory stored in the provenance index is an entry # in the "directory" table... rows["directory"] |= set(x["dst"] for x in synth_rev["R_D"]) assert rows["directory"] == provenance.storage.entity_get_all( EntityType.DIRECTORY ), synth_rev["msg"] # ... + a number of rows in the "directory_in_rev" table... # check for R-D entries rows["directory_in_revision"] |= set( (x["dst"], x["src"], maybe_path(x["path"])) for x in synth_rev["R_D"] ) assert rows["directory_in_revision"] == { (rel.src, rel.dst, rel.path) for rel in provenance.storage.relation_get_all(RelationType.DIR_IN_REV) }, synth_rev["msg"] # check timestamps for rd in synth_rev["R_D"]: assert ( rev_ts + rd["rel_ts"] == provenance.storage.directory_get([rd["dst"]])[rd["dst"]].timestamp() ), synth_rev["msg"] # ... + a number of rows in the "content_in_dir" table # for content of the directory. # check for D-C entries rows["content_in_directory"] |= set( (x["dst"], x["src"], maybe_path(x["path"])) for x in synth_rev["D_C"] ) assert rows["content_in_directory"] == { (rel.src, rel.dst, rel.path) for rel in provenance.storage.relation_get_all(RelationType.CNT_IN_DIR) }, synth_rev["msg"] # check timestamps for dc in synth_rev["D_C"]: assert ( rev_ts + dc["rel_ts"] == provenance.storage.content_get([dc["dst"]])[dc["dst"]].timestamp() ), synth_rev["msg"] if provenance.storage.with_path(): # check for location entries rows["location"] |= set(x["path"].encode() for x in synth_rev["R_C"]) rows["location"] |= set(x["path"].encode() for x in synth_rev["D_C"]) rows["location"] |= set(x["path"].encode() for x in synth_rev["R_D"]) - assert rows["location"] == provenance.storage.location_get(), synth_rev[ + assert rows["location"] == provenance.storage.location_get_all(), synth_rev[ "msg" ] @pytest.mark.parametrize( "repo, lower, mindepth", ( ("cmdbts2", True, 1), ("cmdbts2", False, 1), ("cmdbts2", True, 2), ("cmdbts2", False, 2), ("out-of-order", True, 1), ), ) @pytest.mark.parametrize("batch", (True, False)) def test_provenance_heuristics_content_find_all( provenance: ProvenanceInterface, archive: ArchiveInterface, repo: str, lower: bool, mindepth: int, batch: bool, ) -> None: # read data/README.md for more details on how these datasets are generated data = load_repo_data(repo) fill_storage(archive.storage, data) revisions = [ RevisionEntry( id=revision["id"], date=ts2dt(revision["date"]), root=revision["directory"], ) for revision in data["revision"] ] def maybe_path(path: str) -> str: if provenance.storage.with_path(): return path return "" if batch: revision_add(provenance, archive, revisions, lower=lower, mindepth=mindepth) else: for revision in revisions: revision_add( provenance, archive, [revision], lower=lower, mindepth=mindepth ) syntheticfile = get_datafile( f"synthetic_{repo}_{'lower' if lower else 'upper'}_{mindepth}.txt" ) expected_occurrences: Dict[str, List[Tuple[str, float, Optional[str], str]]] = {} for synth_rev in synthetic_revision_content_result(syntheticfile): rev_id = synth_rev["sha1"].hex() rev_ts = synth_rev["date"] for rc in synth_rev["R_C"]: 
expected_occurrences.setdefault(rc["dst"].hex(), []).append( (rev_id, rev_ts, None, maybe_path(rc["path"])) ) for dc in synth_rev["D_C"]: assert dc["prefix"] is not None # to please mypy expected_occurrences.setdefault(dc["dst"].hex(), []).append( (rev_id, rev_ts, None, maybe_path(dc["prefix"] + "/" + dc["path"])) ) for content_id, results in expected_occurrences.items(): expected = [(content_id, *result) for result in results] db_occurrences = [ ( occur.content.hex(), occur.revision.hex(), occur.date.timestamp(), occur.origin, occur.path.decode(), ) for occur in provenance.content_find_all(hash_to_bytes(content_id)) ] if provenance.storage.with_path(): # this is not true if the db stores no path, because a same content # that appears several times in a given revision may be reported # only once by content_find_all() assert len(db_occurrences) == len(expected) assert set(db_occurrences) == set(expected) @pytest.mark.parametrize( "repo, lower, mindepth", ( ("cmdbts2", True, 1), ("cmdbts2", False, 1), ("cmdbts2", True, 2), ("cmdbts2", False, 2), ("out-of-order", True, 1), ), ) @pytest.mark.parametrize("batch", (True, False)) def test_provenance_heuristics_content_find_first( provenance: ProvenanceInterface, archive: ArchiveInterface, repo: str, lower: bool, mindepth: int, batch: bool, ) -> None: # read data/README.md for more details on how these datasets are generated data = load_repo_data(repo) fill_storage(archive.storage, data) revisions = [ RevisionEntry( id=revision["id"], date=ts2dt(revision["date"]), root=revision["directory"], ) for revision in data["revision"] ] if batch: revision_add(provenance, archive, revisions, lower=lower, mindepth=mindepth) else: for revision in revisions: revision_add( provenance, archive, [revision], lower=lower, mindepth=mindepth ) syntheticfile = get_datafile( f"synthetic_{repo}_{'lower' if lower else 'upper'}_{mindepth}.txt" ) expected_first: Dict[str, Tuple[str, float, List[str]]] = {} # dict mapping each blob_id to a tuple (rev_id, rev_ts, [path, ...]); the path # element is a list because a content can be added at several places in a single # revision, in which case the result of content_find_first() is one of # those paths, but we have no guarantee which one it will return. for synth_rev in synthetic_revision_content_result(syntheticfile): rev_id = synth_rev["sha1"].hex() rev_ts = synth_rev["date"] for rc in synth_rev["R_C"]: sha1 = rc["dst"].hex() if sha1 not in expected_first: assert rc["rel_ts"] == 0 expected_first[sha1] = (rev_id, rev_ts, [rc["path"]]) else: if rev_ts == expected_first[sha1][1]: expected_first[sha1][2].append(rc["path"]) elif rev_ts < expected_first[sha1][1]: expected_first[sha1] = (rev_id, rev_ts, [rc["path"]]) for dc in synth_rev["D_C"]: sha1 = dc["dst"].hex() assert sha1 in expected_first # nothing to do there, this content cannot be a "first seen file" for content_id, (rev_id, ts, paths) in expected_first.items(): occur = provenance.content_find_first(hash_to_bytes(content_id)) assert occur is not None assert occur.content.hex() == content_id assert occur.revision.hex() == rev_id assert occur.date.timestamp() == ts assert occur.origin is None if provenance.storage.with_path(): assert occur.path.decode() in paths
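For reference, a minimal sketch of how the reworked `*_add` endpoints are meant to be driven: the `storage` instance and the sha1 literals below are hypothetical placeholders, while the methods and value types are the ones exercised by the tests above.

    from datetime import datetime, timezone

    from swh.provenance.interface import RevisionData

    cnt = bytes.fromhex("20329687bb9c1231a7e05afe86160343ad49b494")
    rev = bytes.fromhex("c0d8929936631ecbcf9147be6b8aa13b13b014e4")

    # content_add (and directory_add) accept either bare sha1 ids or an
    # id-to-date mapping
    storage.content_add([cnt])                              # register without a date
    storage.content_add({cnt: datetime.now(timezone.utc)})  # or attach one

    # revision_add pairs each id with an optional date and origin via RevisionData
    storage.revision_add({rev: RevisionData(date=None, origin=None)})

    # locations are now registered explicitly through location_add, rather than
    # implicitly by swh_provenance_relation_add_from_temp
    storage.location_add([b"A/B/C/a"])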