diff --git a/swh/provenance/interface.py b/swh/provenance/interface.py index 4dcce81..7a026f8 100644 --- a/swh/provenance/interface.py +++ b/swh/provenance/interface.py @@ -1,330 +1,330 @@ # Copyright (C) 2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from dataclasses import dataclass from datetime import datetime import enum from typing import Dict, Generator, Iterable, Optional, Set, Union from typing_extensions import Protocol, runtime_checkable from swh.core.api import remote_api_endpoint from swh.model.model import Sha1Git from .model import DirectoryEntry, FileEntry, OriginEntry, RevisionEntry class EntityType(enum.Enum): CONTENT = "content" DIRECTORY = "directory" REVISION = "revision" ORIGIN = "origin" class RelationType(enum.Enum): CNT_EARLY_IN_REV = "content_in_revision" CNT_IN_DIR = "content_in_directory" DIR_IN_REV = "directory_in_revision" REV_IN_ORG = "revision_in_origin" REV_BEFORE_REV = "revision_before_revision" @dataclass(eq=True, frozen=True) class ProvenanceResult: content: Sha1Git revision: Sha1Git date: datetime origin: Optional[str] path: bytes @dataclass(eq=True, frozen=True) class RevisionData: """Object representing the data associated to a revision in the provenance model, where `date` is the optional date of the revision (specifying it acknowledges that the revision was already processed by the revision-content algorithm); and `origin` identifies the preferred origin for the revision, if any. """ date: Optional[datetime] origin: Optional[Sha1Git] @dataclass(eq=True, frozen=True) class RelationData: """Object representing a relation entry in the provenance model, where `src` and `dst` are the sha1 ids of the entities being related, and `path` is optional depending on the relation being represented. """ dst: Sha1Git path: Optional[bytes] @runtime_checkable class ProvenanceStorageInterface(Protocol): @remote_api_endpoint("content_add") def content_add( - self, cnts: Union[Iterable[Sha1Git], Dict[Sha1Git, datetime]] + self, cnts: Union[Iterable[Sha1Git], Dict[Sha1Git, Optional[datetime]]] ) -> bool: """Add blobs identified by sha1 ids, with an optional associated date (as paired in `cnts`) to the provenance storage. Return a boolean stating whether the information was successfully stored. """ ... @remote_api_endpoint("content_find_first") def content_find_first(self, id: Sha1Git) -> Optional[ProvenanceResult]: """Retrieve the first occurrence of the blob identified by `id`.""" ... @remote_api_endpoint("content_find_all") def content_find_all( self, id: Sha1Git, limit: Optional[int] = None ) -> Generator[ProvenanceResult, None, None]: """Retrieve all the occurrences of the blob identified by `id`.""" ... @remote_api_endpoint("content_get") def content_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, datetime]: """Retrieve the associated date for each blob sha1 in `ids`. If some blob has no associated date, it is not present in the resulting dictionary. """ ... @remote_api_endpoint("directory_add") def directory_add( - self, dirs: Union[Iterable[Sha1Git], Dict[Sha1Git, datetime]] + self, dirs: Union[Iterable[Sha1Git], Dict[Sha1Git, Optional[datetime]]] ) -> bool: """Add directories identified by sha1 ids, with an optional associated date (as paired in `dirs`) to the provenance storage. Return a boolean stating if the information was successfully stored. """ ... @remote_api_endpoint("directory_get") def directory_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, datetime]: """Retrieve the associated date for each directory sha1 in `ids`. If some directory has no associated date, it is not present in the resulting dictionary. """ ... @remote_api_endpoint("entity_get_all") def entity_get_all(self, entity: EntityType) -> Set[Sha1Git]: """Retrieve all sha1 ids for entities of type `entity` present in the provenance model. """ ... @remote_api_endpoint("location_add") def location_add(self, paths: Iterable[bytes]) -> bool: """Register the given `paths` in the storage.""" ... @remote_api_endpoint("location_get_all") def location_get_all(self) -> Set[bytes]: """Retrieve all paths present in the provenance model.""" ... @remote_api_endpoint("origin_add") def origin_add(self, orgs: Dict[Sha1Git, str]) -> bool: """Add origins identified by sha1 ids, with their corresponding url (as paired in `orgs`) to the provenance storage. Return a boolean stating if the information was successfully stored. """ ... @remote_api_endpoint("origin_get") def origin_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, str]: """Retrieve the associated url for each origin sha1 in `ids`.""" ... @remote_api_endpoint("revision_add") def revision_add( self, revs: Union[Iterable[Sha1Git], Dict[Sha1Git, RevisionData]] ) -> bool: """Add revisions identified by sha1 ids, with optional associated date or origin (as paired in `revs`) to the provenance storage. Return a boolean stating if the information was successfully stored. """ ... @remote_api_endpoint("revision_get") def revision_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, RevisionData]: """Retrieve the associated date and origin for each revision sha1 in `ids`. If some revision has no associated date nor origin, it is not present in the resulting dictionary. """ ... @remote_api_endpoint("relation_add") def relation_add( self, relation: RelationType, data: Dict[Sha1Git, Set[RelationData]] ) -> bool: """Add entries in the selected `relation`. This method assumes all entities being related are already registered in the storage. See `content_add`, `directory_add`, `origin_add`, and `revision_add`. """ ... @remote_api_endpoint("relation_get") def relation_get( self, relation: RelationType, ids: Iterable[Sha1Git], reverse: bool = False ) -> Dict[Sha1Git, Set[RelationData]]: """Retrieve all entries in the selected `relation` whose source entities are identified by some sha1 id in `ids`. If `reverse` is set, destination entities are matched instead. """ ... @remote_api_endpoint("relation_get_all") def relation_get_all( self, relation: RelationType ) -> Dict[Sha1Git, Set[RelationData]]: """Retrieve all entries in the selected `relation` that are present in the provenance model. """ ... @remote_api_endpoint("with_path") def with_path(self) -> bool: ... @runtime_checkable class ProvenanceInterface(Protocol): storage: ProvenanceStorageInterface def flush(self) -> None: """Flush internal cache to the underlying `storage`.""" ... def content_add_to_directory( self, directory: DirectoryEntry, blob: FileEntry, prefix: bytes ) -> None: """Associate `blob` with `directory` in the provenance model. `prefix` is the relative path from `directory` to `blob` (excluding `blob`'s name). """ ... def content_add_to_revision( self, revision: RevisionEntry, blob: FileEntry, prefix: bytes ) -> None: """Associate `blob` with `revision` in the provenance model. `prefix` is the absolute path from `revision`'s root directory to `blob` (excluding `blob`'s name). """ ... def content_find_first(self, id: Sha1Git) -> Optional[ProvenanceResult]: """Retrieve the first occurrence of the blob identified by `id`.""" ... def content_find_all( self, id: Sha1Git, limit: Optional[int] = None ) -> Generator[ProvenanceResult, None, None]: """Retrieve all the occurrences of the blob identified by `id`.""" ... def content_get_early_date(self, blob: FileEntry) -> Optional[datetime]: """Retrieve the earliest known date of `blob`.""" ... def content_get_early_dates( self, blobs: Iterable[FileEntry] ) -> Dict[Sha1Git, datetime]: """Retrieve the earliest known date for each blob in `blobs`. If some blob has no associated date, it is not present in the resulting dictionary. """ ... def content_set_early_date(self, blob: FileEntry, date: datetime) -> None: """Associate `date` to `blob` as it's earliest known date.""" ... def directory_add_to_revision( self, revision: RevisionEntry, directory: DirectoryEntry, path: bytes ) -> None: """Associate `directory` with `revision` in the provenance model. `path` is the absolute path from `revision`'s root directory to `directory` (including `directory`'s name). """ ... def directory_get_date_in_isochrone_frontier( self, directory: DirectoryEntry ) -> Optional[datetime]: """Retrieve the earliest known date of `directory` as an isochrone frontier in the provenance model. """ ... def directory_get_dates_in_isochrone_frontier( self, dirs: Iterable[DirectoryEntry] ) -> Dict[Sha1Git, datetime]: """Retrieve the earliest known date for each directory in `dirs` as isochrone frontiers provenance model. If some directory has no associated date, it is not present in the resulting dictionary. """ ... def directory_set_date_in_isochrone_frontier( self, directory: DirectoryEntry, date: datetime ) -> None: """Associate `date` to `directory` as it's earliest known date as an isochrone frontier in the provenance model. """ ... def origin_add(self, origin: OriginEntry) -> None: """Add `origin` to the provenance model.""" ... def revision_add(self, revision: RevisionEntry) -> None: """Add `revision` to the provenance model. This implies storing `revision`'s date in the model, thus `revision.date` must be a valid date. """ ... def revision_add_before_revision( self, head: RevisionEntry, revision: RevisionEntry ) -> None: """Associate `revision` to `head` as an ancestor of the latter.""" ... def revision_add_to_origin( self, origin: OriginEntry, revision: RevisionEntry ) -> None: """Associate `revision` to `origin` as a head revision of the latter (ie. the target of an snapshot for `origin` in the archive).""" ... def revision_get_date(self, revision: RevisionEntry) -> Optional[datetime]: """Retrieve the date associated to `revision`.""" ... def revision_get_preferred_origin( self, revision: RevisionEntry ) -> Optional[Sha1Git]: """Retrieve the preferred origin associated to `revision`.""" ... def revision_in_history(self, revision: RevisionEntry) -> bool: """Check if `revision` is known to be an ancestor of some head revision in the provenance model. """ ... def revision_set_preferred_origin( self, origin: OriginEntry, revision: RevisionEntry ) -> None: """Associate `origin` as the preferred origin for `revision`.""" ... def revision_visited(self, revision: RevisionEntry) -> bool: """Check if `revision` is known to be a head revision for some origin in the provenance model. """ ... diff --git a/swh/provenance/mongo/backend.py b/swh/provenance/mongo/backend.py index f963746..664fa45 100644 --- a/swh/provenance/mongo/backend.py +++ b/swh/provenance/mongo/backend.py @@ -1,460 +1,460 @@ # Copyright (C) 2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from datetime import datetime, timezone import os from typing import Any, Dict, Generator, Iterable, List, Optional, Set, Union from bson import ObjectId import pymongo.database from swh.model.model import Sha1Git from ..interface import ( EntityType, ProvenanceResult, RelationData, RelationType, RevisionData, ) class ProvenanceStorageMongoDb: def __init__(self, db: pymongo.database.Database): self.db = db def content_add( - self, cnts: Union[Iterable[Sha1Git], Dict[Sha1Git, datetime]] + self, cnts: Union[Iterable[Sha1Git], Dict[Sha1Git, Optional[datetime]]] ) -> bool: data = cnts if isinstance(cnts, dict) else dict.fromkeys(cnts) existing = { x["sha1"]: x for x in self.db.content.find( {"sha1": {"$in": list(data)}}, {"sha1": 1, "ts": 1, "_id": 1} ) } for sha1, date in data.items(): ts = datetime.timestamp(date) if date is not None else None if sha1 in existing: cnt = existing[sha1] if ts is not None and (cnt["ts"] is None or ts < cnt["ts"]): self.db.content.update_one( {"_id": cnt["_id"]}, {"$set": {"ts": ts}} ) else: self.db.content.insert_one( { "sha1": sha1, "ts": ts, "revision": {}, "directory": {}, } ) return True def content_find_first(self, id: Sha1Git) -> Optional[ProvenanceResult]: # get all the revisions # iterate and find the earliest content = self.db.content.find_one({"sha1": id}) if not content: return None occurs = [] for revision in self.db.revision.find( {"_id": {"$in": [ObjectId(obj_id) for obj_id in content["revision"]]}} ): if revision["preferred"] is not None: origin = self.db.origin.find_one({"sha1": revision["preferred"]}) else: origin = {"url": None} for path in content["revision"][str(revision["_id"])]: occurs.append( ProvenanceResult( content=id, revision=revision["sha1"], date=datetime.fromtimestamp(revision["ts"], timezone.utc), origin=origin["url"], path=path, ) ) return sorted(occurs, key=lambda x: (x.date, x.revision, x.origin, x.path))[0] def content_find_all( self, id: Sha1Git, limit: Optional[int] = None ) -> Generator[ProvenanceResult, None, None]: content = self.db.content.find_one({"sha1": id}) if not content: return None occurs = [] for revision in self.db.revision.find( {"_id": {"$in": [ObjectId(obj_id) for obj_id in content["revision"]]}} ): if revision["preferred"] is not None: origin = self.db.origin.find_one({"sha1": revision["preferred"]}) else: origin = {"url": None} for path in content["revision"][str(revision["_id"])]: occurs.append( ProvenanceResult( content=id, revision=revision["sha1"], date=datetime.fromtimestamp(revision["ts"], timezone.utc), origin=origin["url"], path=path, ) ) for directory in self.db.directory.find( {"_id": {"$in": [ObjectId(obj_id) for obj_id in content["directory"]]}} ): for revision in self.db.revision.find( {"_id": {"$in": [ObjectId(obj_id) for obj_id in directory["revision"]]}} ): if revision["preferred"] is not None: origin = self.db.origin.find_one({"sha1": revision["preferred"]}) else: origin = {"url": None} for suffix in content["directory"][str(directory["_id"])]: for prefix in directory["revision"][str(revision["_id"])]: path = ( os.path.join(prefix, suffix) if prefix not in [b".", b""] else suffix ) occurs.append( ProvenanceResult( content=id, revision=revision["sha1"], date=datetime.fromtimestamp( revision["ts"], timezone.utc ), origin=origin["url"], path=path, ) ) yield from sorted(occurs, key=lambda x: (x.date, x.revision, x.origin, x.path)) def content_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, datetime]: return { x["sha1"]: datetime.fromtimestamp(x["ts"], timezone.utc) for x in self.db.content.find( {"sha1": {"$in": list(ids)}, "ts": {"$ne": None}}, {"sha1": 1, "ts": 1, "_id": 0}, ) } def directory_add( - self, dirs: Union[Iterable[Sha1Git], Dict[Sha1Git, datetime]] + self, dirs: Union[Iterable[Sha1Git], Dict[Sha1Git, Optional[datetime]]] ) -> bool: data = dirs if isinstance(dirs, dict) else dict.fromkeys(dirs) existing = { x["sha1"]: x for x in self.db.directory.find( {"sha1": {"$in": list(data)}}, {"sha1": 1, "ts": 1, "_id": 1} ) } for sha1, date in data.items(): ts = datetime.timestamp(date) if date is not None else None if sha1 in existing: dir = existing[sha1] if ts is not None and (dir["ts"] is None or ts < dir["ts"]): self.db.directory.update_one( {"_id": dir["_id"]}, {"$set": {"ts": ts}} ) else: self.db.directory.insert_one({"sha1": sha1, "ts": ts, "revision": {}}) return True def directory_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, datetime]: return { x["sha1"]: datetime.fromtimestamp(x["ts"], timezone.utc) for x in self.db.directory.find( {"sha1": {"$in": list(ids)}, "ts": {"$ne": None}}, {"sha1": 1, "ts": 1, "_id": 0}, ) } def entity_get_all(self, entity: EntityType) -> Set[Sha1Git]: return { x["sha1"] for x in self.db.get_collection(entity.value).find( {}, {"sha1": 1, "_id": 0} ) } def location_add(self, paths: Iterable[bytes]) -> bool: # TODO: implement this methods if path are to be stored in a separate collection return True def location_get_all(self) -> Set[bytes]: contents = self.db.content.find({}, {"revision": 1, "_id": 0, "directory": 1}) paths: List[Iterable[bytes]] = [] for content in contents: paths.extend(value for _, value in content["revision"].items()) paths.extend(value for _, value in content["directory"].items()) dirs = self.db.directory.find({}, {"revision": 1, "_id": 0}) for each_dir in dirs: paths.extend(value for _, value in each_dir["revision"].items()) return set(sum(paths, [])) def origin_add(self, orgs: Dict[Sha1Git, str]) -> bool: existing = { x["sha1"]: x for x in self.db.origin.find( {"sha1": {"$in": list(orgs)}}, {"sha1": 1, "url": 1, "_id": 1} ) } for sha1, url in orgs.items(): if sha1 not in existing: # add new origin self.db.origin.insert_one({"sha1": sha1, "url": url}) return True def origin_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, str]: return { x["sha1"]: x["url"] for x in self.db.origin.find( {"sha1": {"$in": list(ids)}}, {"sha1": 1, "url": 1, "_id": 0} ) } def revision_add( self, revs: Union[Iterable[Sha1Git], Dict[Sha1Git, RevisionData]] ) -> bool: data = ( revs if isinstance(revs, dict) else dict.fromkeys(revs, RevisionData(date=None, origin=None)) ) existing = { x["sha1"]: x for x in self.db.revision.find( {"sha1": {"$in": list(data)}}, {"sha1": 1, "ts": 1, "preferred": 1, "_id": 1}, ) } for sha1, info in data.items(): ts = datetime.timestamp(info.date) if info.date is not None else None preferred = info.origin if sha1 in existing: rev = existing[sha1] if ts is None or (rev["ts"] is not None and ts >= rev["ts"]): ts = rev["ts"] if preferred is None: preferred = rev["preferred"] if ts != rev["ts"] or preferred != rev["preferred"]: self.db.revision.update_one( {"_id": rev["_id"]}, {"$set": {"ts": ts, "preferred": preferred}}, ) else: self.db.revision.insert_one( { "sha1": sha1, "preferred": preferred, "origin": [], "revision": [], "ts": ts, } ) return True def revision_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, RevisionData]: return { x["sha1"]: RevisionData( date=datetime.fromtimestamp(x["ts"], timezone.utc) if x["ts"] else None, origin=x["preferred"], ) for x in self.db.revision.find( { "sha1": {"$in": list(ids)}, "$or": [{"preferred": {"$ne": None}}, {"ts": {"$ne": None}}], }, {"sha1": 1, "preferred": 1, "ts": 1, "_id": 0}, ) } def relation_add( self, relation: RelationType, data: Dict[Sha1Git, Set[RelationData]] ) -> bool: src_relation, *_, dst_relation = relation.value.split("_") dst_objs = { x["sha1"]: x["_id"] for x in self.db.get_collection(dst_relation).find( { "sha1": { "$in": list({rel.dst for rels in data.values() for rel in rels}) } }, {"_id": 1, "sha1": 1}, ) } denorm: Dict[Sha1Git, Any] = {} for src, rels in data.items(): for rel in rels: if src_relation != "revision": denorm.setdefault(src, {}).setdefault( str(dst_objs[rel.dst]), [] ).append(rel.path) else: denorm.setdefault(src, []).append(dst_objs[rel.dst]) src_objs = { x["sha1"]: x for x in self.db.get_collection(src_relation).find( {"sha1": {"$in": list(denorm.keys())}} ) } for sha1, dsts in denorm.items(): # update if src_relation != "revision": k = { obj_id: list(set(paths + dsts.get(obj_id, []))) for obj_id, paths in src_objs[sha1][dst_relation].items() } self.db.get_collection(src_relation).update_one( {"_id": src_objs[sha1]["_id"]}, {"$set": {dst_relation: dict(dsts, **k)}}, ) else: self.db.get_collection(src_relation).update_one( {"_id": src_objs[sha1]["_id"]}, { "$set": { dst_relation: list(set(src_objs[sha1][dst_relation] + dsts)) } }, ) return True def relation_get( self, relation: RelationType, ids: Iterable[Sha1Git], reverse: bool = False ) -> Dict[Sha1Git, Set[RelationData]]: src, *_, dst = relation.value.split("_") sha1s = set(ids) if not reverse: empty: Union[Dict[str, bytes], List[str]] = {} if src != "revision" else [] src_objs = { x["sha1"]: x[dst] for x in self.db.get_collection(src).find( {"sha1": {"$in": list(sha1s)}, dst: {"$ne": empty}}, {"_id": 0, "sha1": 1, dst: 1}, ) } dst_ids = list( {ObjectId(obj_id) for _, value in src_objs.items() for obj_id in value} ) dst_objs = { x["sha1"]: x["_id"] for x in self.db.get_collection(dst).find( {"_id": {"$in": dst_ids}}, {"_id": 1, "sha1": 1} ) } if src != "revision": return { src_sha1: { RelationData(dst=dst_sha1, path=path) for dst_sha1, dst_obj_id in dst_objs.items() for dst_obj_str, paths in denorm.items() for path in paths if dst_obj_id == ObjectId(dst_obj_str) } for src_sha1, denorm in src_objs.items() } else: return { src_sha1: { RelationData(dst=dst_sha1, path=None) for dst_sha1, dst_obj_id in dst_objs.items() for dst_obj_ref in denorm if dst_obj_id == dst_obj_ref } for src_sha1, denorm in src_objs.items() } else: dst_objs = { x["sha1"]: x["_id"] for x in self.db.get_collection(dst).find( {"sha1": {"$in": list(sha1s)}}, {"_id": 1, "sha1": 1} ) } src_objs = { x["sha1"]: x[dst] for x in self.db.get_collection(src).find( {}, {"_id": 0, "sha1": 1, dst: 1} ) } result: Dict[Sha1Git, Set[RelationData]] = {} if src != "revision": for dst_sha1, dst_obj_id in dst_objs.items(): for src_sha1, denorm in src_objs.items(): for dst_obj_str, paths in denorm.items(): if dst_obj_id == ObjectId(dst_obj_str): result.setdefault(src_sha1, set()).update( RelationData(dst=dst_sha1, path=path) for path in paths ) else: for dst_sha1, dst_obj_id in dst_objs.items(): for src_sha1, denorm in src_objs.items(): if dst_obj_id in { ObjectId(dst_obj_str) for dst_obj_str in denorm }: result.setdefault(src_sha1, set()).add( RelationData(dst=dst_sha1, path=None) ) return result def relation_get_all( self, relation: RelationType ) -> Dict[Sha1Git, Set[RelationData]]: src, *_, dst = relation.value.split("_") empty: Union[Dict[str, bytes], List[str]] = {} if src != "revision" else [] src_objs = { x["sha1"]: x[dst] for x in self.db.get_collection(src).find( {dst: {"$ne": empty}}, {"_id": 0, "sha1": 1, dst: 1} ) } dst_ids = list( {ObjectId(obj_id) for _, value in src_objs.items() for obj_id in value} ) dst_objs = { x["_id"]: x["sha1"] for x in self.db.get_collection(dst).find( {"_id": {"$in": dst_ids}}, {"_id": 1, "sha1": 1} ) } if src != "revision": return { src_sha1: { RelationData(dst=dst_sha1, path=path) for dst_obj_id, dst_sha1 in dst_objs.items() for dst_obj_str, paths in denorm.items() for path in paths if dst_obj_id == ObjectId(dst_obj_str) } for src_sha1, denorm in src_objs.items() } else: return { src_sha1: { RelationData(dst=dst_sha1, path=None) for dst_obj_id, dst_sha1 in dst_objs.items() for dst_obj_ref in denorm if dst_obj_id == dst_obj_ref } for src_sha1, denorm in src_objs.items() } def with_path(self) -> bool: return True diff --git a/swh/provenance/postgresql/provenance.py b/swh/provenance/postgresql/provenance.py index b4c4ded..358af8f 100644 --- a/swh/provenance/postgresql/provenance.py +++ b/swh/provenance/postgresql/provenance.py @@ -1,324 +1,322 @@ # Copyright (C) 2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from contextlib import contextmanager from datetime import datetime import itertools import logging from typing import Dict, Generator, Iterable, List, Optional, Set, Union import psycopg2.extensions import psycopg2.extras from typing_extensions import Literal from swh.core.db import BaseDb from swh.model.model import Sha1Git from ..interface import ( EntityType, ProvenanceResult, RelationData, RelationType, RevisionData, ) LOGGER = logging.getLogger(__name__) class ProvenanceStoragePostgreSql: def __init__( self, conn: psycopg2.extensions.connection, raise_on_commit: bool = False ) -> None: BaseDb.adapt_conn(conn) self.conn = conn with self.transaction() as cursor: cursor.execute("SET timezone TO 'UTC'") self._flavor: Optional[str] = None self.raise_on_commit = raise_on_commit @contextmanager def transaction( self, readonly: bool = False ) -> Generator[psycopg2.extensions.cursor, None, None]: self.conn.set_session(readonly=readonly) with self.conn: with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur: yield cur @property def flavor(self) -> str: if self._flavor is None: with self.transaction(readonly=True) as cursor: cursor.execute("SELECT swh_get_dbflavor() AS flavor") self._flavor = cursor.fetchone()["flavor"] assert self._flavor is not None return self._flavor @property def denormalized(self) -> bool: return "denormalized" in self.flavor def content_add( - self, cnts: Union[Iterable[Sha1Git], Dict[Sha1Git, datetime]] + self, cnts: Union[Iterable[Sha1Git], Dict[Sha1Git, Optional[datetime]]] ) -> bool: return self._entity_set_date("content", cnts) def content_find_first(self, id: Sha1Git) -> Optional[ProvenanceResult]: sql = "SELECT * FROM swh_provenance_content_find_first(%s)" with self.transaction(readonly=True) as cursor: cursor.execute(query=sql, vars=(id,)) row = cursor.fetchone() return ProvenanceResult(**row) if row is not None else None def content_find_all( self, id: Sha1Git, limit: Optional[int] = None ) -> Generator[ProvenanceResult, None, None]: sql = "SELECT * FROM swh_provenance_content_find_all(%s, %s)" with self.transaction(readonly=True) as cursor: cursor.execute(query=sql, vars=(id, limit)) yield from (ProvenanceResult(**row) for row in cursor) def content_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, datetime]: return self._entity_get_date("content", ids) def directory_add( - self, dirs: Union[Iterable[Sha1Git], Dict[Sha1Git, datetime]] + self, dirs: Union[Iterable[Sha1Git], Dict[Sha1Git, Optional[datetime]]] ) -> bool: return self._entity_set_date("directory", dirs) def directory_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, datetime]: return self._entity_get_date("directory", ids) def entity_get_all(self, entity: EntityType) -> Set[Sha1Git]: with self.transaction(readonly=True) as cursor: cursor.execute(f"SELECT sha1 FROM {entity.value}") return {row["sha1"] for row in cursor} def location_add(self, paths: Iterable[bytes]) -> bool: if not self.with_path(): return True try: values = [(path,) for path in paths] if values: sql = """ INSERT INTO location(path) VALUES %s ON CONFLICT DO NOTHING """ with self.transaction() as cursor: psycopg2.extras.execute_values(cursor, sql, argslist=values) return True except: # noqa: E722 # Unexpected error occurred, rollback all changes and log message LOGGER.exception("Unexpected error") if self.raise_on_commit: raise return False def location_get_all(self) -> Set[bytes]: with self.transaction(readonly=True) as cursor: cursor.execute("SELECT location.path AS path FROM location") return {row["path"] for row in cursor} def origin_add(self, orgs: Dict[Sha1Git, str]) -> bool: try: if orgs: sql = """ INSERT INTO origin(sha1, url) VALUES %s ON CONFLICT DO NOTHING """ with self.transaction() as cursor: psycopg2.extras.execute_values( cur=cursor, sql=sql, argslist=orgs.items() ) return True except: # noqa: E722 # Unexpected error occurred, rollback all changes and log message LOGGER.exception("Unexpected error") if self.raise_on_commit: raise return False def origin_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, str]: urls: Dict[Sha1Git, str] = {} sha1s = tuple(ids) if sha1s: # TODO: consider splitting this query in several ones if sha1s is too big! values = ", ".join(itertools.repeat("%s", len(sha1s))) sql = f""" SELECT sha1, url FROM origin WHERE sha1 IN ({values}) """ with self.transaction(readonly=True) as cursor: cursor.execute(query=sql, vars=sha1s) urls.update((row["sha1"], row["url"]) for row in cursor) return urls def revision_add( self, revs: Union[Iterable[Sha1Git], Dict[Sha1Git, RevisionData]] ) -> bool: if isinstance(revs, dict): data = [(sha1, rev.date, rev.origin) for sha1, rev in revs.items()] else: data = [(sha1, None, None) for sha1 in revs] try: if data: sql = """ INSERT INTO revision(sha1, date, origin) (SELECT V.rev AS sha1, V.date::timestamptz AS date, O.id AS origin FROM (VALUES %s) AS V(rev, date, org) LEFT JOIN origin AS O ON (O.sha1=V.org::sha1_git)) ON CONFLICT (sha1) DO UPDATE SET date=LEAST(EXCLUDED.date, revision.date), origin=COALESCE(EXCLUDED.origin, revision.origin) """ with self.transaction() as cursor: psycopg2.extras.execute_values(cur=cursor, sql=sql, argslist=data) return True except: # noqa: E722 # Unexpected error occurred, rollback all changes and log message LOGGER.exception("Unexpected error") if self.raise_on_commit: raise return False def revision_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, RevisionData]: result: Dict[Sha1Git, RevisionData] = {} sha1s = tuple(ids) if sha1s: # TODO: consider splitting this query in several ones if sha1s is too big! values = ", ".join(itertools.repeat("%s", len(sha1s))) sql = f""" SELECT R.sha1, R.date, O.sha1 AS origin FROM revision AS R LEFT JOIN origin AS O ON (O.id=R.origin) WHERE R.sha1 IN ({values}) AND (R.date is not NULL OR O.sha1 is not NULL) """ with self.transaction(readonly=True) as cursor: cursor.execute(query=sql, vars=sha1s) result.update( (row["sha1"], RevisionData(date=row["date"], origin=row["origin"])) for row in cursor ) return result def relation_add( self, relation: RelationType, data: Dict[Sha1Git, Set[RelationData]] ) -> bool: - rows = [ - (src, rel.dst, rel.path) for src, dsts in data.items() for rel in dsts - ] + rows = [(src, rel.dst, rel.path) for src, dsts in data.items() for rel in dsts] try: if rows: rel_table = relation.value src_table, *_, dst_table = rel_table.split("_") # Put the next three queries in a manual single transaction: # they use the same temp table with self.transaction() as cursor: cursor.execute("SELECT swh_mktemp_relation_add()") psycopg2.extras.execute_values( cur=cursor, sql="INSERT INTO tmp_relation_add(src, dst, path) VALUES %s", argslist=rows, ) sql = "SELECT swh_provenance_relation_add_from_temp(%s, %s, %s)" cursor.execute(query=sql, vars=(rel_table, src_table, dst_table)) return True except: # noqa: E722 # Unexpected error occurred, rollback all changes and log message LOGGER.exception("Unexpected error") if self.raise_on_commit: raise return False def relation_get( self, relation: RelationType, ids: Iterable[Sha1Git], reverse: bool = False ) -> Dict[Sha1Git, Set[RelationData]]: return self._relation_get(relation, ids, reverse) def relation_get_all( self, relation: RelationType ) -> Dict[Sha1Git, Set[RelationData]]: return self._relation_get(relation, None) def _entity_get_date( self, entity: Literal["content", "directory", "revision"], ids: Iterable[Sha1Git], ) -> Dict[Sha1Git, datetime]: dates: Dict[Sha1Git, datetime] = {} sha1s = tuple(ids) if sha1s: # TODO: consider splitting this query in several ones if sha1s is too big! values = ", ".join(itertools.repeat("%s", len(sha1s))) sql = f""" SELECT sha1, date FROM {entity} WHERE sha1 IN ({values}) AND date IS NOT NULL """ with self.transaction(readonly=True) as cursor: cursor.execute(query=sql, vars=sha1s) dates.update((row["sha1"], row["date"]) for row in cursor) return dates def _entity_set_date( self, entity: Literal["content", "directory"], - dates: Union[Iterable[Sha1Git], Dict[Sha1Git, datetime]], + dates: Union[Iterable[Sha1Git], Dict[Sha1Git, Optional[datetime]]], ) -> bool: data = dates if isinstance(dates, dict) else dict.fromkeys(dates) try: if data: sql = f""" INSERT INTO {entity}(sha1, date) VALUES %s ON CONFLICT (sha1) DO UPDATE SET date=LEAST(EXCLUDED.date,{entity}.date) """ with self.transaction() as cursor: psycopg2.extras.execute_values(cursor, sql, argslist=data.items()) return True except: # noqa: E722 # Unexpected error occurred, rollback all changes and log message LOGGER.exception("Unexpected error") if self.raise_on_commit: raise return False def _relation_get( self, relation: RelationType, ids: Optional[Iterable[Sha1Git]], reverse: bool = False, ) -> Dict[Sha1Git, Set[RelationData]]: result: Dict[Sha1Git, Set[RelationData]] = {} sha1s: List[Sha1Git] if ids is not None: sha1s = list(ids) filter = "filter-src" if not reverse else "filter-dst" else: sha1s = [] filter = "no-filter" if filter == "no-filter" or sha1s: rel_table = relation.value src_table, *_, dst_table = rel_table.split("_") sql = "SELECT * FROM swh_provenance_relation_get(%s, %s, %s, %s, %s)" with self.transaction(readonly=True) as cursor: cursor.execute( query=sql, vars=(rel_table, src_table, dst_table, filter, sha1s) ) for row in cursor: src = row.pop("src") result.setdefault(src, set()).add(RelationData(**row)) return result def with_path(self) -> bool: return "with-path" in self.flavor