diff --git a/swh/provenance/postgresql/provenancedb_base.py b/swh/provenance/postgresql/provenancedb_base.py --- a/swh/provenance/postgresql/provenancedb_base.py +++ b/swh/provenance/postgresql/provenancedb_base.py @@ -10,7 +10,7 @@ from swh.core.db import BaseDb from swh.model.model import Sha1Git -from ..provenance import ProvenanceResult, RelationType +from ..provenance import EntityType, ProvenanceResult, RelationType class ProvenanceDBBase: @@ -60,6 +60,16 @@ def directory_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, datetime]: return self._entity_get_date("directory", ids) + def entity_get_all(self, entity: EntityType) -> Set[Sha1Git]: + sql = f"SELECT sha1 FROM {entity.value}" + self.cursor.execute(sql) + return {row["sha1"] for row in self.cursor.fetchall()} + + def location_get(self) -> Set[bytes]: + sql = "SELECT encode(location.path::bytea, 'escape') AS path FROM location" + self.cursor.execute(sql) + return {row["path"] for row in self.cursor.fetchall()} + def origin_set_url(self, urls: Dict[Sha1Git, str]) -> bool: try: if urls: @@ -202,44 +212,12 @@ def relation_get( self, relation: RelationType, ids: Iterable[Sha1Git], reverse: bool = False ) -> Set[Tuple[Sha1Git, Sha1Git, Optional[bytes]]]: - result: Set[Tuple[Sha1Git, Sha1Git, Optional[bytes]]] = set() - sha1s = tuple(ids) - if sha1s: - table = relation.value - src, *_, dst = table.split("_") - - # TODO: improve this! - if src == "revision" and dst == "revision": - src_field = "prev" - dst_field = "next" - else: - src_field = src - dst_field = dst - - joins = [ - f"INNER JOIN {src} AS S ON (S.id=R.{src_field})", - f"INNER JOIN {dst} AS D ON (D.id=R.{dst_field})", - ] - selected = ["S.sha1 AS src", "D.sha1 AS dst"] - selector = "S.sha1" if not reverse else "D.sha1" - - if self._relation_uses_location_table(relation): - joins.append("INNER JOIN location AS L ON (L.id=R.location)") - selected.append("L.path AS path") - else: - selected.append("NULL AS path") + return self._relation_get(relation, ids, reverse) - sql = f""" - SELECT {", ".join(selected)} - FROM {table} AS R - {" ".join(joins)} - WHERE {selector} IN %s - """ - self.cursor.execute(sql, (sha1s,)) - result.update( - (row["src"], row["dst"], row["path"]) for row in self.cursor.fetchall() - ) - return result + def relation_get_all( + self, relation: RelationType + ) -> Set[Tuple[Sha1Git, Sha1Git, Optional[bytes]]]: + return self._relation_get(relation, None) def _entity_get_date( self, @@ -281,5 +259,57 @@ raise return False + def _relation_get( + self, + relation: RelationType, + ids: Optional[Iterable[Sha1Git]], + reverse: bool = False, + ) -> Set[Tuple[Sha1Git, Sha1Git, Optional[bytes]]]: + result: Set[Tuple[Sha1Git, Sha1Git, Optional[bytes]]] = set() + + sha1s: Optional[Tuple[Tuple[bytes, ...]]] + if ids is not None: + sha1s = (tuple(ids),) + where = f"WHERE {'S.sha1' if not reverse else 'D.sha1'} IN %s" + else: + sha1s = None + where = "" + + if sha1s is None or sha1s[0]: + table = relation.value + src, *_, dst = table.split("_") + + # TODO: improve this! + if src == "revision" and dst == "revision": + src_field = "prev" + dst_field = "next" + else: + src_field = src + dst_field = dst + + joins = [ + f"INNER JOIN {src} AS S ON (S.id=R.{src_field})", + f"INNER JOIN {dst} AS D ON (D.id=R.{dst_field})", + ] + selected = ["S.sha1 AS src", "D.sha1 AS dst"] + + if self._relation_uses_location_table(relation): + joins.append("INNER JOIN location AS L ON (L.id=R.location)") + selected.append("L.path AS path") + else: + selected.append("NULL AS path") + + sql = f""" + SELECT {", ".join(selected)} + FROM {table} AS R + {" ".join(joins)} + {where} + """ + self.cursor.execute(sql, sha1s) + result.update( + (row["src"], row["dst"], row["path"]) for row in self.cursor.fetchall() + ) + return result + def _relation_uses_location_table(self, relation: RelationType) -> bool: ... diff --git a/swh/provenance/provenance.py b/swh/provenance/provenance.py --- a/swh/provenance/provenance.py +++ b/swh/provenance/provenance.py @@ -9,6 +9,13 @@ from .model import DirectoryEntry, FileEntry, OriginEntry, RevisionEntry +class EntityType(enum.Enum): + CONTENT = "content" + DIRECTORY = "directory" + REVISION = "revision" + ORIGIN = "origin" + + class RelationType(enum.Enum): CNT_EARLY_IN_REV = "content_in_revision" CNT_IN_DIR = "content_in_directory" @@ -72,6 +79,16 @@ """ ... + def entity_get_all(self, entity: EntityType) -> Set[Sha1Git]: + """Retrieve all sha1 ids for entities of type `entity` present in the provenance + model. + """ + ... + + def location_get(self) -> Set[bytes]: + """Retrieve all paths present in the provenance model.""" + ... + def origin_set_url(self, urls: Dict[Sha1Git, str]) -> bool: """Associate urls to origins identified by sha1 ids, as paired in `urls`. Return a boolean stating whether the information was successfully stored. @@ -126,6 +143,15 @@ """ ... + def relation_get_all( + self, relation: RelationType + ) -> Set[Tuple[Sha1Git, Sha1Git, Optional[bytes]]]: + """Retrieve all tuples of the form (`src`, `dst`, `path`) present in the + provenance model, where `src` and `dst` are the sha1 ids of the entities being + related, and `path` is optional depending on the selected `relation`. + """ + ... + @runtime_checkable class ProvenanceInterface(Protocol): diff --git a/swh/provenance/tests/test_provenance_heuristics.py b/swh/provenance/tests/test_provenance_heuristics.py --- a/swh/provenance/tests/test_provenance_heuristics.py +++ b/swh/provenance/tests/test_provenance_heuristics.py @@ -3,18 +3,15 @@ # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information -from datetime import datetime from typing import Any, Dict, List, Optional, Set, Tuple -import psycopg2 import pytest from swh.model.hashutil import hash_to_bytes -from swh.model.model import Sha1Git from swh.provenance.archive import ArchiveInterface from swh.provenance.model import RevisionEntry from swh.provenance.postgresql.provenancedb_base import ProvenanceDBBase -from swh.provenance.provenance import ProvenanceInterface +from swh.provenance.provenance import EntityType, ProvenanceInterface, RelationType from swh.provenance.revision import revision_add from swh.provenance.tests.conftest import ( fill_storage, @@ -26,77 +23,6 @@ from swh.storage.postgresql.storage import Storage -def sha1s(cur: psycopg2.extensions.cursor, table: str) -> Set[Sha1Git]: - """return the 'sha1' column from the DB 'table' (as hex) - - 'cur' is a cursor to the provenance index DB. - """ - cur.execute(f"SELECT sha1 FROM {table}") - return set(row["sha1"].hex() for row in cur.fetchall()) - - -def locations(cur: psycopg2.extensions.cursor) -> Set[bytes]: - """return the 'path' column from the DB location table - - 'cur' is a cursor to the provenance index DB. - """ - cur.execute("SELECT encode(location.path::bytea, 'escape') AS path FROM location") - return set(row["path"] for row in cur.fetchall()) - - -def relations( - cur: psycopg2.extensions.cursor, src: str, dst: str -) -> Set[Tuple[Sha1Git, Sha1Git, bytes]]: - """return the triplets ('sha1', 'sha1', 'path') from the DB - - for the relation between 'src' table and 'dst' table - (i.e. for C-R, C-D and D-R relations). - - 'cur' is a cursor to the provenance index DB. - """ - relation = f"{src}_in_{dst}" - cur.execute("SELECT swh_get_dbflavor() AS flavor") - with_path = cur.fetchone()["flavor"] == "with-path" - - # note that the columns have the same name as the relations they refer to, - # so we can write things like "rel.{dst}=src.id" in the query below - if with_path: - cur.execute( - f""" - SELECT encode(src.sha1::bytea, 'hex') AS src, - encode(dst.sha1::bytea, 'hex') AS dst, - encode(location.path::bytea, 'escape') AS path - FROM {relation} as relation - INNER JOIN {src} AS src ON (relation.{src} = src.id) - INNER JOIN {dst} AS dst ON (relation.{dst} = dst.id) - INNER JOIN location ON (relation.location = location.id) - """ - ) - else: - cur.execute( - f""" - SELECT encode(src.sha1::bytea, 'hex') AS src, - encode(dst.sha1::bytea, 'hex') AS dst, - '' AS path - FROM {relation} as relation - INNER JOIN {src} AS src ON (src.id = relation.{src}) - INNER JOIN {dst} AS dst ON (dst.id = relation.{dst}) - """ - ) - return set((row["src"], row["dst"], row["path"]) for row in cur.fetchall()) - - -def get_timestamp( - cur: psycopg2.extensions.cursor, table: str, sha1: Sha1Git -) -> List[datetime]: - """return the date for the 'sha1' from the DB 'table' (as hex) - - 'cur' is a cursor to the provenance index DB. - """ - cur.execute(f"SELECT date FROM {table} WHERE sha1=%s", (sha1,)) - return [row["date"].timestamp() for row in cur.fetchall()] - - @pytest.mark.parametrize( "repo, lower, mindepth", ( @@ -133,14 +59,12 @@ "location": set(), "revision": set(), } - assert isinstance(provenance.storage, ProvenanceDBBase) - cursor = provenance.storage.cursor - def maybe_path(path: str) -> str: + def maybe_path(path: str) -> Optional[bytes]: assert isinstance(provenance.storage, ProvenanceDBBase) if provenance.storage.with_path: - return path - return "" + return path.encode("utf-8") + return None for synth_rev in synthetic_result(syntheticfile): revision = revisions[synth_rev["sha1"]] @@ -152,77 +76,87 @@ revision_add(provenance, archive, [entry], lower=lower, mindepth=mindepth) # each "entry" in the synth file is one new revision - rows["revision"].add(synth_rev["sha1"].hex()) - assert rows["revision"] == sha1s(cursor, "revision"), synth_rev["msg"] + rows["revision"].add(synth_rev["sha1"]) + assert rows["revision"] == provenance.storage.entity_get_all( + EntityType.REVISION + ), synth_rev["msg"] # check the timestamp of the revision rev_ts = synth_rev["date"] - assert get_timestamp(cursor, "revision", synth_rev["sha1"]) == [ - rev_ts - ], synth_rev["msg"] + rev_date, _ = provenance.storage.revision_get([synth_rev["sha1"]])[ + synth_rev["sha1"] + ] + assert rev_date is not None and rev_ts == rev_date.timestamp(), synth_rev["msg"] # this revision might have added new content objects - rows["content"] |= set(x["dst"].hex() for x in synth_rev["R_C"]) - rows["content"] |= set(x["dst"].hex() for x in synth_rev["D_C"]) - assert rows["content"] == sha1s(cursor, "content"), synth_rev["msg"] + rows["content"] |= set(x["dst"] for x in synth_rev["R_C"]) + rows["content"] |= set(x["dst"] for x in synth_rev["D_C"]) + assert rows["content"] == provenance.storage.entity_get_all( + EntityType.CONTENT + ), synth_rev["msg"] # check for R-C (direct) entries # these are added directly in the content_early_in_rev table rows["content_in_revision"] |= set( - (x["dst"].hex(), x["src"].hex(), maybe_path(x["path"])) - for x in synth_rev["R_C"] + (x["dst"], x["src"], maybe_path(x["path"])) for x in synth_rev["R_C"] ) - assert rows["content_in_revision"] == relations( - cursor, "content", "revision" + assert rows["content_in_revision"] == provenance.storage.relation_get_all( + RelationType.CNT_EARLY_IN_REV ), synth_rev["msg"] # check timestamps for rc in synth_rev["R_C"]: - assert get_timestamp(cursor, "content", rc["dst"]) == [ + assert ( rev_ts + rc["rel_ts"] - ], synth_rev["msg"] + == provenance.storage.content_get([rc["dst"]])[rc["dst"]].timestamp() + ), synth_rev["msg"] # check directories # each directory stored in the provenance index is an entry # in the "directory" table... - rows["directory"] |= set(x["dst"].hex() for x in synth_rev["R_D"]) - assert rows["directory"] == sha1s(cursor, "directory"), synth_rev["msg"] + rows["directory"] |= set(x["dst"] for x in synth_rev["R_D"]) + assert rows["directory"] == provenance.storage.entity_get_all( + EntityType.DIRECTORY + ), synth_rev["msg"] # ... + a number of rows in the "directory_in_rev" table... # check for R-D entries rows["directory_in_revision"] |= set( - (x["dst"].hex(), x["src"].hex(), maybe_path(x["path"])) - for x in synth_rev["R_D"] + (x["dst"], x["src"], maybe_path(x["path"])) for x in synth_rev["R_D"] ) - assert rows["directory_in_revision"] == relations( - cursor, "directory", "revision" + assert rows["directory_in_revision"] == provenance.storage.relation_get_all( + RelationType.DIR_IN_REV ), synth_rev["msg"] # check timestamps for rd in synth_rev["R_D"]: - assert get_timestamp(cursor, "directory", rd["dst"]) == [ + assert ( rev_ts + rd["rel_ts"] - ], synth_rev["msg"] + == provenance.storage.directory_get([rd["dst"]])[rd["dst"]].timestamp() + ), synth_rev["msg"] # ... + a number of rows in the "content_in_dir" table # for content of the directory. # check for D-C entries rows["content_in_directory"] |= set( - (x["dst"].hex(), x["src"].hex(), maybe_path(x["path"])) - for x in synth_rev["D_C"] + (x["dst"], x["src"], maybe_path(x["path"])) for x in synth_rev["D_C"] ) - assert rows["content_in_directory"] == relations( - cursor, "content", "directory" + assert rows["content_in_directory"] == provenance.storage.relation_get_all( + RelationType.CNT_IN_DIR ), synth_rev["msg"] # check timestamps for dc in synth_rev["D_C"]: - assert get_timestamp(cursor, "content", dc["dst"]) == [ + assert ( rev_ts + dc["rel_ts"] - ], synth_rev["msg"] + == provenance.storage.content_get([dc["dst"]])[dc["dst"]].timestamp() + ), synth_rev["msg"] + assert isinstance(provenance.storage, ProvenanceDBBase) if provenance.storage.with_path: # check for location entries rows["location"] |= set(x["path"] for x in synth_rev["R_C"]) rows["location"] |= set(x["path"] for x in synth_rev["D_C"]) rows["location"] |= set(x["path"] for x in synth_rev["R_D"]) - assert rows["location"] == locations(cursor), synth_rev["msg"] + assert rows["location"] == provenance.storage.location_get(), synth_rev[ + "msg" + ] @pytest.mark.parametrize(