diff --git a/swh/provenance/mongo/backend.py b/swh/provenance/mongo/backend.py index 0f4cbb5..82693cc 100644 --- a/swh/provenance/mongo/backend.py +++ b/swh/provenance/mongo/backend.py @@ -1,489 +1,511 @@ # Copyright (C) 2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from __future__ import annotations from datetime import datetime, timezone import os from types import TracebackType from typing import Any, Dict, Generator, Iterable, List, Optional, Set, Type, Union from bson import ObjectId import mongomock import pymongo +from swh.core.statsd import statsd from swh.model.model import Sha1Git from ..interface import ( EntityType, ProvenanceResult, ProvenanceStorageInterface, RelationData, RelationType, RevisionData, ) +STORAGE_DURATION_METRIC = "swh_provenance_storage_mongodb_duration_seconds" + class ProvenanceStorageMongoDb: def __init__(self, engine: str, **kwargs): self.engine = engine self.dbname = kwargs.pop("dbname") self.conn_args = kwargs def __enter__(self) -> ProvenanceStorageInterface: self.open() return self def __exit__( self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType], ) -> None: self.close() + @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "close"}) def close(self) -> None: self.db.client.close() + @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "content_add"}) def content_add( self, cnts: Union[Iterable[Sha1Git], Dict[Sha1Git, Optional[datetime]]] ) -> bool: data = cnts if isinstance(cnts, dict) else dict.fromkeys(cnts) existing = { x["sha1"]: x for x in self.db.content.find( {"sha1": {"$in": list(data)}}, {"sha1": 1, "ts": 1, "_id": 1} ) } for sha1, date in data.items(): ts = datetime.timestamp(date) if date is not None else None if sha1 in existing: cnt = existing[sha1] if ts is not None and (cnt["ts"] is None or ts < cnt["ts"]): self.db.content.update_one( {"_id": cnt["_id"]}, {"$set": {"ts": ts}} ) else: self.db.content.insert_one( { "sha1": sha1, "ts": ts, "revision": {}, "directory": {}, } ) return True + @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "content_find_first"}) def content_find_first(self, id: Sha1Git) -> Optional[ProvenanceResult]: # get all the revisions # iterate and find the earliest content = self.db.content.find_one({"sha1": id}) if not content: return None occurs = [] for revision in self.db.revision.find( {"_id": {"$in": [ObjectId(obj_id) for obj_id in content["revision"]]}} ): if revision["preferred"] is not None: origin = self.db.origin.find_one({"sha1": revision["preferred"]}) else: origin = {"url": None} for path in content["revision"][str(revision["_id"])]: occurs.append( ProvenanceResult( content=id, revision=revision["sha1"], date=datetime.fromtimestamp(revision["ts"], timezone.utc), origin=origin["url"], path=path, ) ) return sorted(occurs, key=lambda x: (x.date, x.revision, x.origin, x.path))[0] + @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "content_find_all"}) def content_find_all( self, id: Sha1Git, limit: Optional[int] = None ) -> Generator[ProvenanceResult, None, None]: content = self.db.content.find_one({"sha1": id}) if not content: return None occurs = [] for revision in self.db.revision.find( {"_id": {"$in": [ObjectId(obj_id) for obj_id in content["revision"]]}} ): if revision["preferred"] is not None: origin = self.db.origin.find_one({"sha1": revision["preferred"]}) else: origin = {"url": None} for path in content["revision"][str(revision["_id"])]: occurs.append( ProvenanceResult( content=id, revision=revision["sha1"], date=datetime.fromtimestamp(revision["ts"], timezone.utc), origin=origin["url"], path=path, ) ) for directory in self.db.directory.find( {"_id": {"$in": [ObjectId(obj_id) for obj_id in content["directory"]]}} ): for revision in self.db.revision.find( {"_id": {"$in": [ObjectId(obj_id) for obj_id in directory["revision"]]}} ): if revision["preferred"] is not None: origin = self.db.origin.find_one({"sha1": revision["preferred"]}) else: origin = {"url": None} for suffix in content["directory"][str(directory["_id"])]: for prefix in directory["revision"][str(revision["_id"])]: path = ( os.path.join(prefix, suffix) if prefix not in [b".", b""] else suffix ) occurs.append( ProvenanceResult( content=id, revision=revision["sha1"], date=datetime.fromtimestamp( revision["ts"], timezone.utc ), origin=origin["url"], path=path, ) ) yield from sorted(occurs, key=lambda x: (x.date, x.revision, x.origin, x.path)) + @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "content_get"}) def content_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, datetime]: return { x["sha1"]: datetime.fromtimestamp(x["ts"], timezone.utc) for x in self.db.content.find( {"sha1": {"$in": list(ids)}, "ts": {"$ne": None}}, {"sha1": 1, "ts": 1, "_id": 0}, ) } + @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "directory_add"}) def directory_add( self, dirs: Union[Iterable[Sha1Git], Dict[Sha1Git, Optional[datetime]]] ) -> bool: data = dirs if isinstance(dirs, dict) else dict.fromkeys(dirs) existing = { x["sha1"]: x for x in self.db.directory.find( {"sha1": {"$in": list(data)}}, {"sha1": 1, "ts": 1, "_id": 1} ) } for sha1, date in data.items(): ts = datetime.timestamp(date) if date is not None else None if sha1 in existing: dir = existing[sha1] if ts is not None and (dir["ts"] is None or ts < dir["ts"]): self.db.directory.update_one( {"_id": dir["_id"]}, {"$set": {"ts": ts}} ) else: self.db.directory.insert_one({"sha1": sha1, "ts": ts, "revision": {}}) return True + @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "directory_get"}) def directory_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, datetime]: return { x["sha1"]: datetime.fromtimestamp(x["ts"], timezone.utc) for x in self.db.directory.find( {"sha1": {"$in": list(ids)}, "ts": {"$ne": None}}, {"sha1": 1, "ts": 1, "_id": 0}, ) } + @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "entity_get_all"}) def entity_get_all(self, entity: EntityType) -> Set[Sha1Git]: return { x["sha1"] for x in self.db.get_collection(entity.value).find( {}, {"sha1": 1, "_id": 0} ) } + @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "location_add"}) def location_add(self, paths: Iterable[bytes]) -> bool: # TODO: implement this methods if path are to be stored in a separate collection return True + @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "location_get_all"}) def location_get_all(self) -> Set[bytes]: contents = self.db.content.find({}, {"revision": 1, "_id": 0, "directory": 1}) paths: List[Iterable[bytes]] = [] for content in contents: paths.extend(value for _, value in content["revision"].items()) paths.extend(value for _, value in content["directory"].items()) dirs = self.db.directory.find({}, {"revision": 1, "_id": 0}) for each_dir in dirs: paths.extend(value for _, value in each_dir["revision"].items()) return set(sum(paths, [])) + @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "open"}) def open(self) -> None: if self.engine == "mongomock": self.db = mongomock.MongoClient(**self.conn_args).get_database(self.dbname) else: # assume real MongoDB server by default self.db = pymongo.MongoClient(**self.conn_args).get_database(self.dbname) + @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "origin_add"}) def origin_add(self, orgs: Dict[Sha1Git, str]) -> bool: existing = { x["sha1"]: x for x in self.db.origin.find( {"sha1": {"$in": list(orgs)}}, {"sha1": 1, "url": 1, "_id": 1} ) } for sha1, url in orgs.items(): if sha1 not in existing: # add new origin self.db.origin.insert_one({"sha1": sha1, "url": url}) return True + @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "origin_get"}) def origin_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, str]: return { x["sha1"]: x["url"] for x in self.db.origin.find( {"sha1": {"$in": list(ids)}}, {"sha1": 1, "url": 1, "_id": 0} ) } + @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "revision_add"}) def revision_add( self, revs: Union[Iterable[Sha1Git], Dict[Sha1Git, RevisionData]] ) -> bool: data = ( revs if isinstance(revs, dict) else dict.fromkeys(revs, RevisionData(date=None, origin=None)) ) existing = { x["sha1"]: x for x in self.db.revision.find( {"sha1": {"$in": list(data)}}, {"sha1": 1, "ts": 1, "preferred": 1, "_id": 1}, ) } for sha1, info in data.items(): ts = datetime.timestamp(info.date) if info.date is not None else None preferred = info.origin if sha1 in existing: rev = existing[sha1] if ts is None or (rev["ts"] is not None and ts >= rev["ts"]): ts = rev["ts"] if preferred is None: preferred = rev["preferred"] if ts != rev["ts"] or preferred != rev["preferred"]: self.db.revision.update_one( {"_id": rev["_id"]}, {"$set": {"ts": ts, "preferred": preferred}}, ) else: self.db.revision.insert_one( { "sha1": sha1, "preferred": preferred, "origin": [], "revision": [], "ts": ts, } ) return True + @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "revision_get"}) def revision_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, RevisionData]: return { x["sha1"]: RevisionData( date=datetime.fromtimestamp(x["ts"], timezone.utc) if x["ts"] else None, origin=x["preferred"], ) for x in self.db.revision.find( { "sha1": {"$in": list(ids)}, "$or": [{"preferred": {"$ne": None}}, {"ts": {"$ne": None}}], }, {"sha1": 1, "preferred": 1, "ts": 1, "_id": 0}, ) } + @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "relation_add"}) def relation_add( self, relation: RelationType, data: Dict[Sha1Git, Set[RelationData]] ) -> bool: src_relation, *_, dst_relation = relation.value.split("_") dst_objs = { x["sha1"]: x["_id"] for x in self.db.get_collection(dst_relation).find( { "sha1": { "$in": list({rel.dst for rels in data.values() for rel in rels}) } }, {"_id": 1, "sha1": 1}, ) } denorm: Dict[Sha1Git, Any] = {} for src, rels in data.items(): for rel in rels: if src_relation != "revision": denorm.setdefault(src, {}).setdefault( str(dst_objs[rel.dst]), [] ).append(rel.path) else: denorm.setdefault(src, []).append(dst_objs[rel.dst]) src_objs = { x["sha1"]: x for x in self.db.get_collection(src_relation).find( {"sha1": {"$in": list(denorm.keys())}} ) } for sha1, dsts in denorm.items(): # update if src_relation != "revision": k = { obj_id: list(set(paths + dsts.get(obj_id, []))) for obj_id, paths in src_objs[sha1][dst_relation].items() } self.db.get_collection(src_relation).update_one( {"_id": src_objs[sha1]["_id"]}, {"$set": {dst_relation: dict(dsts, **k)}}, ) else: self.db.get_collection(src_relation).update_one( {"_id": src_objs[sha1]["_id"]}, { "$set": { dst_relation: list(set(src_objs[sha1][dst_relation] + dsts)) } }, ) return True + @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "relation_get"}) def relation_get( self, relation: RelationType, ids: Iterable[Sha1Git], reverse: bool = False ) -> Dict[Sha1Git, Set[RelationData]]: src, *_, dst = relation.value.split("_") sha1s = set(ids) if not reverse: empty: Union[Dict[str, bytes], List[str]] = {} if src != "revision" else [] src_objs = { x["sha1"]: x[dst] for x in self.db.get_collection(src).find( {"sha1": {"$in": list(sha1s)}, dst: {"$ne": empty}}, {"_id": 0, "sha1": 1, dst: 1}, ) } dst_ids = list( {ObjectId(obj_id) for _, value in src_objs.items() for obj_id in value} ) dst_objs = { x["sha1"]: x["_id"] for x in self.db.get_collection(dst).find( {"_id": {"$in": dst_ids}}, {"_id": 1, "sha1": 1} ) } if src != "revision": return { src_sha1: { RelationData(dst=dst_sha1, path=path) for dst_sha1, dst_obj_id in dst_objs.items() for dst_obj_str, paths in denorm.items() for path in paths if dst_obj_id == ObjectId(dst_obj_str) } for src_sha1, denorm in src_objs.items() } else: return { src_sha1: { RelationData(dst=dst_sha1, path=None) for dst_sha1, dst_obj_id in dst_objs.items() for dst_obj_ref in denorm if dst_obj_id == dst_obj_ref } for src_sha1, denorm in src_objs.items() } else: dst_objs = { x["sha1"]: x["_id"] for x in self.db.get_collection(dst).find( {"sha1": {"$in": list(sha1s)}}, {"_id": 1, "sha1": 1} ) } src_objs = { x["sha1"]: x[dst] for x in self.db.get_collection(src).find( {}, {"_id": 0, "sha1": 1, dst: 1} ) } result: Dict[Sha1Git, Set[RelationData]] = {} if src != "revision": for dst_sha1, dst_obj_id in dst_objs.items(): for src_sha1, denorm in src_objs.items(): for dst_obj_str, paths in denorm.items(): if dst_obj_id == ObjectId(dst_obj_str): result.setdefault(src_sha1, set()).update( RelationData(dst=dst_sha1, path=path) for path in paths ) else: for dst_sha1, dst_obj_id in dst_objs.items(): for src_sha1, denorm in src_objs.items(): if dst_obj_id in { ObjectId(dst_obj_str) for dst_obj_str in denorm }: result.setdefault(src_sha1, set()).add( RelationData(dst=dst_sha1, path=None) ) return result + @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "relation_get_all"}) def relation_get_all( self, relation: RelationType ) -> Dict[Sha1Git, Set[RelationData]]: src, *_, dst = relation.value.split("_") empty: Union[Dict[str, bytes], List[str]] = {} if src != "revision" else [] src_objs = { x["sha1"]: x[dst] for x in self.db.get_collection(src).find( {dst: {"$ne": empty}}, {"_id": 0, "sha1": 1, dst: 1} ) } dst_ids = list( {ObjectId(obj_id) for _, value in src_objs.items() for obj_id in value} ) dst_objs = { x["_id"]: x["sha1"] for x in self.db.get_collection(dst).find( {"_id": {"$in": dst_ids}}, {"_id": 1, "sha1": 1} ) } if src != "revision": return { src_sha1: { RelationData(dst=dst_sha1, path=path) for dst_obj_id, dst_sha1 in dst_objs.items() for dst_obj_str, paths in denorm.items() for path in paths if dst_obj_id == ObjectId(dst_obj_str) } for src_sha1, denorm in src_objs.items() } else: return { src_sha1: { RelationData(dst=dst_sha1, path=None) for dst_obj_id, dst_sha1 in dst_objs.items() for dst_obj_ref in denorm if dst_obj_id == dst_obj_ref } for src_sha1, denorm in src_objs.items() } + @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "with_path"}) def with_path(self) -> bool: return True diff --git a/swh/provenance/postgresql/provenance.py b/swh/provenance/postgresql/provenance.py index dd25bc1..e85b742 100644 --- a/swh/provenance/postgresql/provenance.py +++ b/swh/provenance/postgresql/provenance.py @@ -1,342 +1,364 @@ # Copyright (C) 2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from __future__ import annotations from contextlib import contextmanager from datetime import datetime import itertools import logging from types import TracebackType from typing import Dict, Generator, Iterable, List, Optional, Set, Type, Union import psycopg2.extensions import psycopg2.extras from typing_extensions import Literal from swh.core.db import BaseDb +from swh.core.statsd import statsd from swh.model.model import Sha1Git from ..interface import ( EntityType, ProvenanceResult, ProvenanceStorageInterface, RelationData, RelationType, RevisionData, ) LOGGER = logging.getLogger(__name__) +STORAGE_DURATION_METRIC = "swh_provenance_storage_postgresql_duration_seconds" + class ProvenanceStoragePostgreSql: def __init__(self, raise_on_commit: bool = False, **kwargs) -> None: self.conn_args = kwargs self._flavor: Optional[str] = None self.raise_on_commit = raise_on_commit def __enter__(self) -> ProvenanceStorageInterface: self.open() return self def __exit__( self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType], ) -> None: self.close() @contextmanager def transaction( self, readonly: bool = False ) -> Generator[psycopg2.extensions.cursor, None, None]: self.conn.set_session(readonly=readonly) with self.conn: with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur: yield cur @property def flavor(self) -> str: if self._flavor is None: with self.transaction(readonly=True) as cursor: cursor.execute("SELECT swh_get_dbflavor() AS flavor") self._flavor = cursor.fetchone()["flavor"] assert self._flavor is not None return self._flavor @property def denormalized(self) -> bool: return "denormalized" in self.flavor + @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "close"}) def close(self) -> None: self.conn.close() + @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "content_add"}) def content_add( self, cnts: Union[Iterable[Sha1Git], Dict[Sha1Git, Optional[datetime]]] ) -> bool: return self._entity_set_date("content", cnts) + @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "content_find_first"}) def content_find_first(self, id: Sha1Git) -> Optional[ProvenanceResult]: sql = "SELECT * FROM swh_provenance_content_find_first(%s)" with self.transaction(readonly=True) as cursor: cursor.execute(query=sql, vars=(id,)) row = cursor.fetchone() return ProvenanceResult(**row) if row is not None else None + @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "content_find_all"}) def content_find_all( self, id: Sha1Git, limit: Optional[int] = None ) -> Generator[ProvenanceResult, None, None]: sql = "SELECT * FROM swh_provenance_content_find_all(%s, %s)" with self.transaction(readonly=True) as cursor: cursor.execute(query=sql, vars=(id, limit)) yield from (ProvenanceResult(**row) for row in cursor) + @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "content_get"}) def content_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, datetime]: return self._entity_get_date("content", ids) + @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "directory_add"}) def directory_add( self, dirs: Union[Iterable[Sha1Git], Dict[Sha1Git, Optional[datetime]]] ) -> bool: return self._entity_set_date("directory", dirs) + @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "directory_get"}) def directory_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, datetime]: return self._entity_get_date("directory", ids) + @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "entity_get_all"}) def entity_get_all(self, entity: EntityType) -> Set[Sha1Git]: with self.transaction(readonly=True) as cursor: cursor.execute(f"SELECT sha1 FROM {entity.value}") return {row["sha1"] for row in cursor} + @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "location_add"}) def location_add(self, paths: Iterable[bytes]) -> bool: if not self.with_path(): return True try: values = [(path,) for path in paths] if values: sql = """ INSERT INTO location(path) VALUES %s ON CONFLICT DO NOTHING """ with self.transaction() as cursor: psycopg2.extras.execute_values(cursor, sql, argslist=values) return True except: # noqa: E722 # Unexpected error occurred, rollback all changes and log message LOGGER.exception("Unexpected error") if self.raise_on_commit: raise return False + @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "location_get_all"}) def location_get_all(self) -> Set[bytes]: with self.transaction(readonly=True) as cursor: cursor.execute("SELECT location.path AS path FROM location") return {row["path"] for row in cursor} + @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "origin_add"}) def origin_add(self, orgs: Dict[Sha1Git, str]) -> bool: try: if orgs: sql = """ INSERT INTO origin(sha1, url) VALUES %s ON CONFLICT DO NOTHING """ with self.transaction() as cursor: psycopg2.extras.execute_values( cur=cursor, sql=sql, argslist=orgs.items() ) return True except: # noqa: E722 # Unexpected error occurred, rollback all changes and log message LOGGER.exception("Unexpected error") if self.raise_on_commit: raise return False + @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "open"}) def open(self) -> None: self.conn = BaseDb.connect(**self.conn_args).conn BaseDb.adapt_conn(self.conn) with self.transaction() as cursor: cursor.execute("SET timezone TO 'UTC'") + @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "origin_get"}) def origin_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, str]: urls: Dict[Sha1Git, str] = {} sha1s = tuple(ids) if sha1s: # TODO: consider splitting this query in several ones if sha1s is too big! values = ", ".join(itertools.repeat("%s", len(sha1s))) sql = f""" SELECT sha1, url FROM origin WHERE sha1 IN ({values}) """ with self.transaction(readonly=True) as cursor: cursor.execute(query=sql, vars=sha1s) urls.update((row["sha1"], row["url"]) for row in cursor) return urls + @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "revision_add"}) def revision_add( self, revs: Union[Iterable[Sha1Git], Dict[Sha1Git, RevisionData]] ) -> bool: if isinstance(revs, dict): data = [(sha1, rev.date, rev.origin) for sha1, rev in revs.items()] else: data = [(sha1, None, None) for sha1 in revs] try: if data: sql = """ INSERT INTO revision(sha1, date, origin) (SELECT V.rev AS sha1, V.date::timestamptz AS date, O.id AS origin FROM (VALUES %s) AS V(rev, date, org) LEFT JOIN origin AS O ON (O.sha1=V.org::sha1_git)) ON CONFLICT (sha1) DO UPDATE SET date=LEAST(EXCLUDED.date, revision.date), origin=COALESCE(EXCLUDED.origin, revision.origin) """ with self.transaction() as cursor: psycopg2.extras.execute_values(cur=cursor, sql=sql, argslist=data) return True except: # noqa: E722 # Unexpected error occurred, rollback all changes and log message LOGGER.exception("Unexpected error") if self.raise_on_commit: raise return False + @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "revision_get"}) def revision_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, RevisionData]: result: Dict[Sha1Git, RevisionData] = {} sha1s = tuple(ids) if sha1s: # TODO: consider splitting this query in several ones if sha1s is too big! values = ", ".join(itertools.repeat("%s", len(sha1s))) sql = f""" SELECT R.sha1, R.date, O.sha1 AS origin FROM revision AS R LEFT JOIN origin AS O ON (O.id=R.origin) WHERE R.sha1 IN ({values}) AND (R.date is not NULL OR O.sha1 is not NULL) """ with self.transaction(readonly=True) as cursor: cursor.execute(query=sql, vars=sha1s) result.update( (row["sha1"], RevisionData(date=row["date"], origin=row["origin"])) for row in cursor ) return result + @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "relation_add"}) def relation_add( self, relation: RelationType, data: Dict[Sha1Git, Set[RelationData]] ) -> bool: rows = [(src, rel.dst, rel.path) for src, dsts in data.items() for rel in dsts] try: if rows: rel_table = relation.value src_table, *_, dst_table = rel_table.split("_") # Put the next three queries in a manual single transaction: # they use the same temp table with self.transaction() as cursor: cursor.execute("SELECT swh_mktemp_relation_add()") psycopg2.extras.execute_values( cur=cursor, sql="INSERT INTO tmp_relation_add(src, dst, path) VALUES %s", argslist=rows, ) sql = "SELECT swh_provenance_relation_add_from_temp(%s, %s, %s)" cursor.execute(query=sql, vars=(rel_table, src_table, dst_table)) return True except: # noqa: E722 # Unexpected error occurred, rollback all changes and log message LOGGER.exception("Unexpected error") if self.raise_on_commit: raise return False + @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "relation_get"}) def relation_get( self, relation: RelationType, ids: Iterable[Sha1Git], reverse: bool = False ) -> Dict[Sha1Git, Set[RelationData]]: return self._relation_get(relation, ids, reverse) + @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "relation_get_all"}) def relation_get_all( self, relation: RelationType ) -> Dict[Sha1Git, Set[RelationData]]: return self._relation_get(relation, None) def _entity_get_date( self, entity: Literal["content", "directory", "revision"], ids: Iterable[Sha1Git], ) -> Dict[Sha1Git, datetime]: dates: Dict[Sha1Git, datetime] = {} sha1s = tuple(ids) if sha1s: # TODO: consider splitting this query in several ones if sha1s is too big! values = ", ".join(itertools.repeat("%s", len(sha1s))) sql = f""" SELECT sha1, date FROM {entity} WHERE sha1 IN ({values}) AND date IS NOT NULL """ with self.transaction(readonly=True) as cursor: cursor.execute(query=sql, vars=sha1s) dates.update((row["sha1"], row["date"]) for row in cursor) return dates def _entity_set_date( self, entity: Literal["content", "directory"], dates: Union[Iterable[Sha1Git], Dict[Sha1Git, Optional[datetime]]], ) -> bool: data = dates if isinstance(dates, dict) else dict.fromkeys(dates) try: if data: sql = f""" INSERT INTO {entity}(sha1, date) VALUES %s ON CONFLICT (sha1) DO UPDATE SET date=LEAST(EXCLUDED.date,{entity}.date) """ with self.transaction() as cursor: psycopg2.extras.execute_values(cursor, sql, argslist=data.items()) return True except: # noqa: E722 # Unexpected error occurred, rollback all changes and log message LOGGER.exception("Unexpected error") if self.raise_on_commit: raise return False def _relation_get( self, relation: RelationType, ids: Optional[Iterable[Sha1Git]], reverse: bool = False, ) -> Dict[Sha1Git, Set[RelationData]]: result: Dict[Sha1Git, Set[RelationData]] = {} sha1s: List[Sha1Git] if ids is not None: sha1s = list(ids) filter = "filter-src" if not reverse else "filter-dst" else: sha1s = [] filter = "no-filter" if filter == "no-filter" or sha1s: rel_table = relation.value src_table, *_, dst_table = rel_table.split("_") sql = "SELECT * FROM swh_provenance_relation_get(%s, %s, %s, %s, %s)" with self.transaction(readonly=True) as cursor: cursor.execute( query=sql, vars=(rel_table, src_table, dst_table, filter, sha1s) ) for row in cursor: src = row.pop("src") result.setdefault(src, set()).add(RelationData(**row)) return result + @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "with_path"}) def with_path(self) -> bool: return "with-path" in self.flavor