diff --git a/swh/provenance/postgresql/provenancedb_base.py b/swh/provenance/postgresql/provenancedb_base.py
index 4fbf470..9d8acdd 100644
--- a/swh/provenance/postgresql/provenancedb_base.py
+++ b/swh/provenance/postgresql/provenancedb_base.py
@@ -1,285 +1,315 @@
 from datetime import datetime
 import itertools
 import logging
 from typing import Dict, Generator, Iterable, Optional, Set, Tuple

 import psycopg2
 import psycopg2.extras
 from typing_extensions import Literal

 from swh.core.db import BaseDb
 from swh.model.model import Sha1Git

-from ..provenance import ProvenanceResult, RelationType
+from ..provenance import EntityType, ProvenanceResult, RelationType


 class ProvenanceDBBase:
     raise_on_commit: bool = False

     def __init__(self, conn: psycopg2.extensions.connection):
         BaseDb.adapt_conn(conn)
         conn.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT)
         conn.set_session(autocommit=True)
         self.conn = conn
         self.cursor = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
         # XXX: not sure this is the best place to do it!
         sql = "SET timezone TO 'UTC'"
         self.cursor.execute(sql)
         self._flavor: Optional[str] = None

     @property
     def flavor(self) -> str:
         if self._flavor is None:
             sql = "SELECT swh_get_dbflavor() AS flavor"
             self.cursor.execute(sql)
             self._flavor = self.cursor.fetchone()["flavor"]
         assert self._flavor is not None
         return self._flavor

     @property
     def with_path(self) -> bool:
         return self.flavor == "with-path"

     def content_find_first(self, id: Sha1Git) -> Optional[ProvenanceResult]:
         ...

     def content_find_all(
         self, id: Sha1Git, limit: Optional[int] = None
     ) -> Generator[ProvenanceResult, None, None]:
         ...

     def content_set_date(self, dates: Dict[Sha1Git, datetime]) -> bool:
         return self._entity_set_date("content", dates)

     def content_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, datetime]:
         return self._entity_get_date("content", ids)

     def directory_set_date(self, dates: Dict[Sha1Git, datetime]) -> bool:
         return self._entity_set_date("directory", dates)

     def directory_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, datetime]:
         return self._entity_get_date("directory", ids)

+    def entity_get_all(self, entity: EntityType) -> Set[Sha1Git]:
+        sql = f"SELECT sha1 FROM {entity.value}"
+        self.cursor.execute(sql)
+        return {row["sha1"] for row in self.cursor.fetchall()}
+
+    def location_get(self) -> Set[bytes]:
+        sql = "SELECT encode(location.path::bytea, 'escape') AS path FROM location"
+        self.cursor.execute(sql)
+        return {row["path"] for row in self.cursor.fetchall()}
+
     def origin_set_url(self, urls: Dict[Sha1Git, str]) -> bool:
         try:
             if urls:
                 sql = """
                     LOCK TABLE ONLY origin;
                     INSERT INTO origin(sha1, url) VALUES %s
                       ON CONFLICT DO NOTHING
                     """
                 psycopg2.extras.execute_values(self.cursor, sql, urls.items())
             return True
         except:  # noqa: E722
             # Unexpected error occurred, rollback all changes and log message
             logging.exception("Unexpected error")
             if self.raise_on_commit:
                 raise
         return False

     def origin_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, str]:
         urls: Dict[Sha1Git, str] = {}
         sha1s = tuple(ids)
         if sha1s:
             values = ", ".join(itertools.repeat("%s", len(sha1s)))
             sql = f"""
                 SELECT sha1, url
                   FROM origin
                   WHERE sha1 IN ({values})
                 """
             self.cursor.execute(sql, sha1s)
             urls.update((row["sha1"], row["url"]) for row in self.cursor.fetchall())
         return urls

     def revision_set_date(self, dates: Dict[Sha1Git, datetime]) -> bool:
         return self._entity_set_date("revision", dates)

     def revision_set_origin(self, origins: Dict[Sha1Git, Sha1Git]) -> bool:
         try:
             if origins:
                 sql = """
                     LOCK TABLE ONLY revision;
                     INSERT INTO revision(sha1, origin)
                       (SELECT V.rev AS sha1, O.id AS origin
                        FROM (VALUES %s) AS V(rev, org)
                        JOIN origin AS O ON (O.sha1=V.org))
                       ON CONFLICT (sha1) DO
                       UPDATE SET origin=EXCLUDED.origin
                     """
                 psycopg2.extras.execute_values(self.cursor, sql, origins.items())
             return True
         except:  # noqa: E722
             # Unexpected error occurred, rollback all changes and log message
             logging.exception("Unexpected error")
             if self.raise_on_commit:
                 raise
         return False

     def revision_get(
         self, ids: Iterable[Sha1Git]
     ) -> Dict[Sha1Git, Tuple[Optional[datetime], Optional[Sha1Git]]]:
         result: Dict[Sha1Git, Tuple[Optional[datetime], Optional[Sha1Git]]] = {}
         sha1s = tuple(ids)
         if sha1s:
             values = ", ".join(itertools.repeat("%s", len(sha1s)))
             sql = f"""
                 SELECT sha1, date, origin
                   FROM revision
                   WHERE sha1 IN ({values})
                 """
             self.cursor.execute(sql, sha1s)
             result.update(
                 (row["sha1"], (row["date"], row["origin"]))
                 for row in self.cursor.fetchall()
             )
         return result

     def relation_add(
         self,
         relation: RelationType,
         data: Iterable[Tuple[Sha1Git, Sha1Git, Optional[bytes]]],
     ) -> bool:
         try:
             if data:
                 table = relation.value
                 src, *_, dst = table.split("_")

                 if src != "origin":
                     # Origin entries should be inserted previously as they require
                     # extra non-null information
                     srcs = tuple(set((sha1,) for (sha1, _, _) in data))
                     sql = f"""
                         LOCK TABLE ONLY {src};
                         INSERT INTO {src}(sha1) VALUES %s
                           ON CONFLICT DO NOTHING
                         """
                     psycopg2.extras.execute_values(self.cursor, sql, srcs)
                 if dst != "origin":
                     # Origin entries should be inserted previously as they require
                     # extra non-null information
                     dsts = tuple(set((sha1,) for (_, sha1, _) in data))
                     sql = f"""
                         LOCK TABLE ONLY {dst};
                         INSERT INTO {dst}(sha1) VALUES %s
                           ON CONFLICT DO NOTHING
                         """
                     psycopg2.extras.execute_values(self.cursor, sql, dsts)

                 joins = [
                     f"INNER JOIN {src} AS S ON (S.sha1=V.src)",
                     f"INNER JOIN {dst} AS D ON (D.sha1=V.dst)",
                 ]
                 selected = ["S.id", "D.id"]

                 if self._relation_uses_location_table(relation):
                     locations = tuple(set((path,) for (_, _, path) in data))
                     sql = """
                         LOCK TABLE ONLY location;
                         INSERT INTO location(path) VALUES %s
                           ON CONFLICT (path) DO NOTHING
                         """
                     psycopg2.extras.execute_values(self.cursor, sql, locations)

                     joins.append("INNER JOIN location AS L ON (L.path=V.path)")
                     selected.append("L.id")

                 sql = f"""
                     INSERT INTO {table}
                       (SELECT {", ".join(selected)}
                        FROM (VALUES %s) AS V(src, dst, path)
                        {'''
                        '''.join(joins)})
                       ON CONFLICT DO NOTHING
                     """
                 psycopg2.extras.execute_values(self.cursor, sql, data)
             return True
         except:  # noqa: E722
             # Unexpected error occurred, rollback all changes and log message
             logging.exception("Unexpected error")
             if self.raise_on_commit:
                 raise
         return False

     def relation_get(
         self, relation: RelationType, ids: Iterable[Sha1Git], reverse: bool = False
     ) -> Set[Tuple[Sha1Git, Sha1Git, Optional[bytes]]]:
-        result: Set[Tuple[Sha1Git, Sha1Git, Optional[bytes]]] = set()
-        sha1s = tuple(ids)
-        if sha1s:
-            table = relation.value
-            src, *_, dst = table.split("_")
-
-            # TODO: improve this!
-            if src == "revision" and dst == "revision":
-                src_field = "prev"
-                dst_field = "next"
-            else:
-                src_field = src
-                dst_field = dst
-
-            joins = [
-                f"INNER JOIN {src} AS S ON (S.id=R.{src_field})",
-                f"INNER JOIN {dst} AS D ON (D.id=R.{dst_field})",
-            ]
-            selected = ["S.sha1 AS src", "D.sha1 AS dst"]
-            selector = "S.sha1" if not reverse else "D.sha1"
-
-            if self._relation_uses_location_table(relation):
-                joins.append("INNER JOIN location AS L ON (L.id=R.location)")
-                selected.append("L.path AS path")
-            else:
-                selected.append("NULL AS path")
+        return self._relation_get(relation, ids, reverse)

-            sql = f"""
-                SELECT {", ".join(selected)}
-                  FROM {table} AS R
-                  {" ".join(joins)}
-                  WHERE {selector} IN %s
-                """
-            self.cursor.execute(sql, (sha1s,))
-            result.update(
-                (row["src"], row["dst"], row["path"]) for row in self.cursor.fetchall()
-            )
-        return result
+    def relation_get_all(
+        self, relation: RelationType
+    ) -> Set[Tuple[Sha1Git, Sha1Git, Optional[bytes]]]:
+        return self._relation_get(relation, None)

     def _entity_get_date(
         self,
         entity: Literal["content", "directory", "revision"],
         ids: Iterable[Sha1Git],
     ) -> Dict[Sha1Git, datetime]:
         dates: Dict[Sha1Git, datetime] = {}
         sha1s = tuple(ids)
         if sha1s:
             values = ", ".join(itertools.repeat("%s", len(sha1s)))
             sql = f"""
                 SELECT sha1, date
                   FROM {entity}
                   WHERE sha1 IN ({values})
                 """
             self.cursor.execute(sql, sha1s)
             dates.update((row["sha1"], row["date"]) for row in self.cursor.fetchall())
         return dates

     def _entity_set_date(
         self,
         entity: Literal["content", "directory", "revision"],
         data: Dict[Sha1Git, datetime],
     ) -> bool:
         try:
             if data:
                 sql = f"""
                     LOCK TABLE ONLY {entity};
                     INSERT INTO {entity}(sha1, date) VALUES %s
                       ON CONFLICT (sha1) DO
                       UPDATE SET date=LEAST(EXCLUDED.date,{entity}.date)
                     """
                 psycopg2.extras.execute_values(self.cursor, sql, data.items())
             return True
         except:  # noqa: E722
             # Unexpected error occurred, rollback all changes and log message
             logging.exception("Unexpected error")
             if self.raise_on_commit:
                 raise
         return False

+    def _relation_get(
+        self,
+        relation: RelationType,
+        ids: Optional[Iterable[Sha1Git]],
+        reverse: bool = False,
+    ) -> Set[Tuple[Sha1Git, Sha1Git, Optional[bytes]]]:
+        result: Set[Tuple[Sha1Git, Sha1Git, Optional[bytes]]] = set()
+
+        sha1s: Optional[Tuple[Tuple[bytes, ...]]]
+        if ids is not None:
+            sha1s = (tuple(ids),)
+            where = f"WHERE {'S.sha1' if not reverse else 'D.sha1'} IN %s"
+        else:
+            sha1s = None
+            where = ""
+
+        if sha1s is None or sha1s[0]:
+            table = relation.value
+            src, *_, dst = table.split("_")
+
+            # TODO: improve this!
+            if src == "revision" and dst == "revision":
+                src_field = "prev"
+                dst_field = "next"
+            else:
+                src_field = src
+                dst_field = dst
+
+            joins = [
+                f"INNER JOIN {src} AS S ON (S.id=R.{src_field})",
+                f"INNER JOIN {dst} AS D ON (D.id=R.{dst_field})",
+            ]
+            selected = ["S.sha1 AS src", "D.sha1 AS dst"]
+
+            if self._relation_uses_location_table(relation):
+                joins.append("INNER JOIN location AS L ON (L.id=R.location)")
+                selected.append("L.path AS path")
+            else:
+                selected.append("NULL AS path")
+
+            sql = f"""
+                SELECT {", ".join(selected)}
+                  FROM {table} AS R
+                  {" ".join(joins)}
+                  {where}
+                """
+            self.cursor.execute(sql, sha1s)
+            result.update(
+                (row["src"], row["dst"], row["path"]) for row in self.cursor.fetchall()
+            )
+        return result
+
     def _relation_uses_location_table(self, relation: RelationType) -> bool:
         ...
diff --git a/swh/provenance/provenance.py b/swh/provenance/provenance.py
index 13ca58b..b431c2e 100644
--- a/swh/provenance/provenance.py
+++ b/swh/provenance/provenance.py
@@ -1,264 +1,290 @@
 from datetime import datetime
 import enum
 from typing import Dict, Generator, Iterable, Optional, Set, Tuple

 from typing_extensions import Protocol, runtime_checkable

 from swh.model.model import Sha1Git

 from .model import DirectoryEntry, FileEntry, OriginEntry, RevisionEntry


+class EntityType(enum.Enum):
+    CONTENT = "content"
+    DIRECTORY = "directory"
+    REVISION = "revision"
+    ORIGIN = "origin"
+
+
 class RelationType(enum.Enum):
     CNT_EARLY_IN_REV = "content_in_revision"
     CNT_IN_DIR = "content_in_directory"
     DIR_IN_REV = "directory_in_revision"
     REV_IN_ORG = "revision_in_origin"
     REV_BEFORE_REV = "revision_before_revision"


 class ProvenanceResult:
     def __init__(
         self,
         content: Sha1Git,
         revision: Sha1Git,
         date: datetime,
         origin: Optional[str],
         path: bytes,
     ) -> None:
         self.content = content
         self.revision = revision
         self.date = date
         self.origin = origin
         self.path = path


 @runtime_checkable
 class ProvenanceStorageInterface(Protocol):
     raise_on_commit: bool = False

     def content_find_first(self, id: Sha1Git) -> Optional[ProvenanceResult]:
         """Retrieve the first occurrence of the blob identified by `id`."""
         ...

     def content_find_all(
         self, id: Sha1Git, limit: Optional[int] = None
     ) -> Generator[ProvenanceResult, None, None]:
         """Retrieve all the occurrences of the blob identified by `id`."""
         ...

     def content_set_date(self, dates: Dict[Sha1Git, datetime]) -> bool:
         """Associate dates to blobs identified by sha1 ids, as paired in `dates`.

         Return a boolean stating whether the information was successfully stored.
         """
         ...

     def content_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, datetime]:
         """Retrieve the associated date for each blob sha1 in `ids`. If some blob
         has no associated date, it is not present in the resulting dictionary.
         """
         ...

     def directory_set_date(self, dates: Dict[Sha1Git, datetime]) -> bool:
         """Associate dates to directories identified by sha1 ids, as paired in
         `dates`.

         Return a boolean stating whether the information was successfully stored.
         """
         ...

     def directory_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, datetime]:
         """Retrieve the associated date for each directory sha1 in `ids`. If some
         directory has no associated date, it is not present in the resulting
         dictionary.
         """
         ...

+    def entity_get_all(self, entity: EntityType) -> Set[Sha1Git]:
+        """Retrieve all sha1 ids for entities of type `entity` present in the
+        provenance model.
+        """
+        ...
+
+    def location_get(self) -> Set[bytes]:
+        """Retrieve all paths present in the provenance model."""
+        ...
+
     def origin_set_url(self, urls: Dict[Sha1Git, str]) -> bool:
         """Associate urls to origins identified by sha1 ids, as paired in `urls`.

         Return a boolean stating whether the information was successfully stored.
         """
         ...

     def origin_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, str]:
         """Retrieve the associated url for each origin sha1 in `ids`. If some origin
         has no associated url, it is not present in the resulting dictionary.
         """
         ...

     def revision_set_date(self, dates: Dict[Sha1Git, datetime]) -> bool:
         """Associate dates to revisions identified by sha1 ids, as paired in `dates`.

         Return a boolean stating whether the information was successfully stored.
         """
         ...

     def revision_set_origin(self, origins: Dict[Sha1Git, Sha1Git]) -> bool:
         """Associate origins to revisions identified by sha1 ids, as paired in
         `origins` (revision ids are keys and origin ids, values).

         Return a boolean stating whether the information was successfully stored.
         """
         ...

     def revision_get(
         self, ids: Iterable[Sha1Git]
     ) -> Dict[Sha1Git, Tuple[Optional[datetime], Optional[Sha1Git]]]:
         """Retrieve the associated date and origin for each revision sha1 in `ids`.
         If some revision has no associated date nor origin, it is not present in the
         resulting dictionary.
         """
         ...

     def relation_add(
         self,
         relation: RelationType,
         data: Iterable[Tuple[Sha1Git, Sha1Git, Optional[bytes]]],
     ) -> bool:
         """Add entries in the selected `relation`. Each tuple in `data` is of the
         form (`src`, `dst`, `path`), where `src` and `dst` are the sha1 ids of the
         entities being related, and `path` is optional depending on the selected
         `relation`.
         """
         ...

     def relation_get(
         self, relation: RelationType, ids: Iterable[Sha1Git], reverse: bool = False
     ) -> Set[Tuple[Sha1Git, Sha1Git, Optional[bytes]]]:
         """Retrieve all tuples in the selected `relation` whose source entities are
         identified by some sha1 id in `ids`. If `reverse` is set, destination
         entities are matched instead.
         """
         ...

+    def relation_get_all(
+        self, relation: RelationType
+    ) -> Set[Tuple[Sha1Git, Sha1Git, Optional[bytes]]]:
+        """Retrieve all tuples of the form (`src`, `dst`, `path`) present in the
+        provenance model, where `src` and `dst` are the sha1 ids of the entities
+        being related, and `path` is optional depending on the selected `relation`.
+        """
+        ...
+

 @runtime_checkable
 class ProvenanceInterface(Protocol):
     storage: ProvenanceStorageInterface

     def flush(self) -> None:
         """Flush internal cache to the underlying `storage`."""
         ...

     def content_add_to_directory(
         self, directory: DirectoryEntry, blob: FileEntry, prefix: bytes
     ) -> None:
         """Associate `blob` with `directory` in the provenance model. `prefix` is
         the relative path from `directory` to `blob` (excluding `blob`'s name).
         """
         ...

     def content_add_to_revision(
         self, revision: RevisionEntry, blob: FileEntry, prefix: bytes
     ) -> None:
         """Associate `blob` with `revision` in the provenance model. `prefix` is the
         absolute path from `revision`'s root directory to `blob` (excluding `blob`'s
         name).
         """
         ...

     def content_find_first(self, id: Sha1Git) -> Optional[ProvenanceResult]:
         """Retrieve the first occurrence of the blob identified by `id`."""
         ...

     def content_find_all(
         self, id: Sha1Git, limit: Optional[int] = None
     ) -> Generator[ProvenanceResult, None, None]:
         """Retrieve all the occurrences of the blob identified by `id`."""
         ...

     def content_get_early_date(self, blob: FileEntry) -> Optional[datetime]:
         """Retrieve the earliest known date of `blob`."""
         ...

     def content_get_early_dates(
         self, blobs: Iterable[FileEntry]
     ) -> Dict[Sha1Git, datetime]:
         """Retrieve the earliest known date for each blob in `blobs`. If some blob
         has no associated date, it is not present in the resulting dictionary.
         """
         ...

     def content_set_early_date(self, blob: FileEntry, date: datetime) -> None:
         """Associate `date` to `blob` as its earliest known date."""
         ...

     def directory_add_to_revision(
         self, revision: RevisionEntry, directory: DirectoryEntry, path: bytes
     ) -> None:
         """Associate `directory` with `revision` in the provenance model. `path` is
         the absolute path from `revision`'s root directory to `directory` (including
         `directory`'s name).
         """
         ...

     def directory_get_date_in_isochrone_frontier(
         self, directory: DirectoryEntry
     ) -> Optional[datetime]:
         """Retrieve the earliest known date of `directory` as an isochrone frontier
         in the provenance model.
         """
         ...

     def directory_get_dates_in_isochrone_frontier(
         self, dirs: Iterable[DirectoryEntry]
     ) -> Dict[Sha1Git, datetime]:
         """Retrieve the earliest known date for each directory in `dirs` as
         isochrone frontiers in the provenance model. If some directory has no
         associated date, it is not present in the resulting dictionary.
         """
         ...

     def directory_set_date_in_isochrone_frontier(
         self, directory: DirectoryEntry, date: datetime
     ) -> None:
         """Associate `date` to `directory` as its earliest known date as an
         isochrone frontier in the provenance model.
         """
         ...

     def origin_add(self, origin: OriginEntry) -> None:
         """Add `origin` to the provenance model."""
         ...

     def revision_add(self, revision: RevisionEntry) -> None:
         """Add `revision` to the provenance model. This implies storing `revision`'s
         date in the model, thus `revision.date` must be a valid date.
         """
         ...

     def revision_add_before_revision(
         self, head: RevisionEntry, revision: RevisionEntry
     ) -> None:
         """Associate `revision` to `head` as an ancestor of the latter."""
         ...

     def revision_add_to_origin(
         self, origin: OriginEntry, revision: RevisionEntry
     ) -> None:
         """Associate `revision` to `origin` as a head revision of the latter (i.e.
         the target of a snapshot for `origin` in the archive)."""
         ...

     def revision_get_date(self, revision: RevisionEntry) -> Optional[datetime]:
         """Retrieve the date associated to `revision`."""
         ...

     def revision_get_preferred_origin(
         self, revision: RevisionEntry
     ) -> Optional[Sha1Git]:
         """Retrieve the preferred origin associated to `revision`."""
         ...

     def revision_in_history(self, revision: RevisionEntry) -> bool:
         """Check if `revision` is known to be an ancestor of some head revision in
         the provenance model.
         """
         ...

     def revision_set_preferred_origin(
         self, origin: OriginEntry, revision: RevisionEntry
     ) -> None:
         """Associate `origin` as the preferred origin for `revision`."""
         ...

     def revision_visited(self, revision: RevisionEntry) -> bool:
         """Check if `revision` is known to be a head revision for some origin in the
         provenance model.
         """
         ...
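Because `ProvenanceStorageInterface` is a `runtime_checkable` Protocol, the two new methods can be driven generically over any backend that satisfies the interface. A minimal sketch; the `dump_model` helper is illustrative, not part of the interface:

```python
from swh.provenance.provenance import (
    EntityType,
    ProvenanceStorageInterface,
    RelationType,
)


def dump_model(storage: ProvenanceStorageInterface) -> None:
    # runtime_checkable isinstance() checks member presence only, not signatures.
    assert isinstance(storage, ProvenanceStorageInterface)
    # Count every entity and every relation tuple currently in the model.
    for entity in EntityType:
        print(entity.value, len(storage.entity_get_all(entity)))
    for relation in RelationType:
        print(relation.value, len(storage.relation_get_all(relation)))
```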
diff --git a/swh/provenance/tests/test_provenance_heuristics.py b/swh/provenance/tests/test_provenance_heuristics.py
index f047b0c..8927f54 100644
--- a/swh/provenance/tests/test_provenance_heuristics.py
+++ b/swh/provenance/tests/test_provenance_heuristics.py
@@ -1,382 +1,316 @@
 # Copyright (C) 2021 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information

-from datetime import datetime
 from typing import Any, Dict, List, Optional, Set, Tuple

-import psycopg2
 import pytest

 from swh.model.hashutil import hash_to_bytes
-from swh.model.model import Sha1Git
 from swh.provenance.archive import ArchiveInterface
 from swh.provenance.model import RevisionEntry
 from swh.provenance.postgresql.provenancedb_base import ProvenanceDBBase
-from swh.provenance.provenance import ProvenanceInterface
+from swh.provenance.provenance import EntityType, ProvenanceInterface, RelationType
 from swh.provenance.revision import revision_add
 from swh.provenance.tests.conftest import (
     fill_storage,
     get_datafile,
     load_repo_data,
     synthetic_result,
 )
 from swh.provenance.tests.test_provenance_db import ts2dt
 from swh.storage.postgresql.storage import Storage


-def sha1s(cur: psycopg2.extensions.cursor, table: str) -> Set[Sha1Git]:
-    """return the 'sha1' column from the DB 'table' (as hex)
-
-    'cur' is a cursor to the provenance index DB.
-    """
-    cur.execute(f"SELECT sha1 FROM {table}")
-    return set(row["sha1"].hex() for row in cur.fetchall())
-
-
-def locations(cur: psycopg2.extensions.cursor) -> Set[bytes]:
-    """return the 'path' column from the DB location table
-
-    'cur' is a cursor to the provenance index DB.
-    """
-    cur.execute("SELECT encode(location.path::bytea, 'escape') AS path FROM location")
-    return set(row["path"] for row in cur.fetchall())
-
-
-def relations(
-    cur: psycopg2.extensions.cursor, src: str, dst: str
-) -> Set[Tuple[Sha1Git, Sha1Git, bytes]]:
-    """return the triplets ('sha1', 'sha1', 'path') from the DB
-
-    for the relation between 'src' table and 'dst' table
-    (i.e. for C-R, C-D and D-R relations).
-
-    'cur' is a cursor to the provenance index DB.
-    """
-    relation = f"{src}_in_{dst}"
-    cur.execute("SELECT swh_get_dbflavor() AS flavor")
-    with_path = cur.fetchone()["flavor"] == "with-path"
-
-    # note that the columns have the same name as the relations they refer to,
-    # so we can write things like "rel.{dst}=src.id" in the query below
-    if with_path:
-        cur.execute(
-            f"""
-            SELECT encode(src.sha1::bytea, 'hex') AS src,
-                   encode(dst.sha1::bytea, 'hex') AS dst,
-                   encode(location.path::bytea, 'escape') AS path
-            FROM {relation} as relation
-            INNER JOIN {src} AS src ON (relation.{src} = src.id)
-            INNER JOIN {dst} AS dst ON (relation.{dst} = dst.id)
-            INNER JOIN location ON (relation.location = location.id)
-            """
-        )
-    else:
-        cur.execute(
-            f"""
-            SELECT encode(src.sha1::bytea, 'hex') AS src,
-                   encode(dst.sha1::bytea, 'hex') AS dst,
-                   '' AS path
-            FROM {relation} as relation
-            INNER JOIN {src} AS src ON (src.id = relation.{src})
-            INNER JOIN {dst} AS dst ON (dst.id = relation.{dst})
-            """
-        )
-    return set((row["src"], row["dst"], row["path"]) for row in cur.fetchall())
-
-
-def get_timestamp(
-    cur: psycopg2.extensions.cursor, table: str, sha1: Sha1Git
-) -> List[datetime]:
-    """return the date for the 'sha1' from the DB 'table' (as hex)
-
-    'cur' is a cursor to the provenance index DB.
-    """
-    cur.execute(f"SELECT date FROM {table} WHERE sha1=%s", (sha1,))
-    return [row["date"].timestamp() for row in cur.fetchall()]
-
-
 @pytest.mark.parametrize(
     "repo, lower, mindepth",
     (
         ("cmdbts2", True, 1),
         ("cmdbts2", False, 1),
         ("cmdbts2", True, 2),
         ("cmdbts2", False, 2),
         ("out-of-order", True, 1),
     ),
 )
 def test_provenance_heuristics(
     provenance: ProvenanceInterface,
     swh_storage: Storage,
     archive: ArchiveInterface,
     repo: str,
     lower: bool,
     mindepth: int,
 ) -> None:
     # read data/README.md for more details on how these datasets are generated
     data = load_repo_data(repo)
     fill_storage(swh_storage, data)
     syntheticfile = get_datafile(
         f"synthetic_{repo}_{'lower' if lower else 'upper'}_{mindepth}.txt"
     )

     revisions = {rev["id"]: rev for rev in data["revision"]}

     rows: Dict[str, Set[Any]] = {
         "content": set(),
         "content_in_directory": set(),
         "content_in_revision": set(),
         "directory": set(),
         "directory_in_revision": set(),
         "location": set(),
         "revision": set(),
     }

-    assert isinstance(provenance.storage, ProvenanceDBBase)
-    cursor = provenance.storage.cursor
-
-    def maybe_path(path: str) -> str:
+    def maybe_path(path: str) -> Optional[bytes]:
+        assert isinstance(provenance.storage, ProvenanceDBBase)
         if provenance.storage.with_path:
-            return path
-        return ""
+            return path.encode("utf-8")
+        return None

     for synth_rev in synthetic_result(syntheticfile):
         revision = revisions[synth_rev["sha1"]]
         entry = RevisionEntry(
             id=revision["id"],
             date=ts2dt(revision["date"]),
             root=revision["directory"],
         )
         revision_add(provenance, archive, [entry], lower=lower, mindepth=mindepth)

         # each "entry" in the synth file is one new revision
-        rows["revision"].add(synth_rev["sha1"].hex())
-        assert rows["revision"] == sha1s(cursor, "revision"), synth_rev["msg"]
+        rows["revision"].add(synth_rev["sha1"])
+        assert rows["revision"] == provenance.storage.entity_get_all(
+            EntityType.REVISION
+        ), synth_rev["msg"]
         # check the timestamp of the revision
         rev_ts = synth_rev["date"]
-        assert get_timestamp(cursor, "revision", synth_rev["sha1"]) == [
-            rev_ts
-        ], synth_rev["msg"]
+        rev_date, _ = provenance.storage.revision_get([synth_rev["sha1"]])[
+            synth_rev["sha1"]
+        ]
+        assert rev_date is not None and rev_ts == rev_date.timestamp(), synth_rev["msg"]

         # this revision might have added new content objects
-        rows["content"] |= set(x["dst"].hex() for x in synth_rev["R_C"])
-        rows["content"] |= set(x["dst"].hex() for x in synth_rev["D_C"])
-        assert rows["content"] == sha1s(cursor, "content"), synth_rev["msg"]
+        rows["content"] |= set(x["dst"] for x in synth_rev["R_C"])
+        rows["content"] |= set(x["dst"] for x in synth_rev["D_C"])
+        assert rows["content"] == provenance.storage.entity_get_all(
+            EntityType.CONTENT
+        ), synth_rev["msg"]

         # check for R-C (direct) entries
         # these are added directly in the content_in_revision table
         rows["content_in_revision"] |= set(
-            (x["dst"].hex(), x["src"].hex(), maybe_path(x["path"]))
-            for x in synth_rev["R_C"]
+            (x["dst"], x["src"], maybe_path(x["path"])) for x in synth_rev["R_C"]
         )
-        assert rows["content_in_revision"] == relations(
-            cursor, "content", "revision"
+        assert rows["content_in_revision"] == provenance.storage.relation_get_all(
+            RelationType.CNT_EARLY_IN_REV
         ), synth_rev["msg"]
         # check timestamps
         for rc in synth_rev["R_C"]:
-            assert get_timestamp(cursor, "content", rc["dst"]) == [
+            assert (
                 rev_ts + rc["rel_ts"]
-            ], synth_rev["msg"]
+                == provenance.storage.content_get([rc["dst"]])[rc["dst"]].timestamp()
+            ), synth_rev["msg"]

         # check directories
         # each directory stored in the provenance index is an entry
         # in the "directory" table...
-        rows["directory"] |= set(x["dst"].hex() for x in synth_rev["R_D"])
-        assert rows["directory"] == sha1s(cursor, "directory"), synth_rev["msg"]
+        rows["directory"] |= set(x["dst"] for x in synth_rev["R_D"])
+        assert rows["directory"] == provenance.storage.entity_get_all(
+            EntityType.DIRECTORY
+        ), synth_rev["msg"]

         # ... + a number of rows in the "directory_in_revision" table...
         # check for R-D entries
         rows["directory_in_revision"] |= set(
-            (x["dst"].hex(), x["src"].hex(), maybe_path(x["path"]))
-            for x in synth_rev["R_D"]
+            (x["dst"], x["src"], maybe_path(x["path"])) for x in synth_rev["R_D"]
         )
-        assert rows["directory_in_revision"] == relations(
-            cursor, "directory", "revision"
+        assert rows["directory_in_revision"] == provenance.storage.relation_get_all(
+            RelationType.DIR_IN_REV
         ), synth_rev["msg"]
         # check timestamps
         for rd in synth_rev["R_D"]:
-            assert get_timestamp(cursor, "directory", rd["dst"]) == [
+            assert (
                 rev_ts + rd["rel_ts"]
-            ], synth_rev["msg"]
+                == provenance.storage.directory_get([rd["dst"]])[rd["dst"]].timestamp()
+            ), synth_rev["msg"]

         # ... + a number of rows in the "content_in_directory" table
         #     for content of the directory.
         # check for D-C entries
         rows["content_in_directory"] |= set(
-            (x["dst"].hex(), x["src"].hex(), maybe_path(x["path"]))
-            for x in synth_rev["D_C"]
+            (x["dst"], x["src"], maybe_path(x["path"])) for x in synth_rev["D_C"]
         )
-        assert rows["content_in_directory"] == relations(
-            cursor, "content", "directory"
+        assert rows["content_in_directory"] == provenance.storage.relation_get_all(
+            RelationType.CNT_IN_DIR
         ), synth_rev["msg"]
         # check timestamps
         for dc in synth_rev["D_C"]:
-            assert get_timestamp(cursor, "content", dc["dst"]) == [
+            assert (
                 rev_ts + dc["rel_ts"]
-            ], synth_rev["msg"]
+                == provenance.storage.content_get([dc["dst"]])[dc["dst"]].timestamp()
+            ), synth_rev["msg"]

+        assert isinstance(provenance.storage, ProvenanceDBBase)
         if provenance.storage.with_path:
             # check for location entries
             rows["location"] |= set(x["path"] for x in synth_rev["R_C"])
             rows["location"] |= set(x["path"] for x in synth_rev["D_C"])
             rows["location"] |= set(x["path"] for x in synth_rev["R_D"])
-            assert rows["location"] == locations(cursor), synth_rev["msg"]
+            assert rows["location"] == provenance.storage.location_get(), synth_rev[
+                "msg"
+            ]


 @pytest.mark.parametrize(
     "repo, lower, mindepth",
     (
         ("cmdbts2", True, 1),
         ("cmdbts2", False, 1),
         ("cmdbts2", True, 2),
         ("cmdbts2", False, 2),
         ("out-of-order", True, 1),
     ),
 )
 def test_provenance_heuristics_content_find_all(
     provenance: ProvenanceInterface,
     swh_storage: Storage,
     archive: ArchiveInterface,
     repo: str,
     lower: bool,
     mindepth: int,
 ) -> None:
     # read data/README.md for more details on how these datasets are generated
     data = load_repo_data(repo)
     fill_storage(swh_storage, data)
     revisions = [
         RevisionEntry(
             id=revision["id"],
             date=ts2dt(revision["date"]),
             root=revision["directory"],
         )
         for revision in data["revision"]
     ]

     def maybe_path(path: str) -> str:
         assert isinstance(provenance.storage, ProvenanceDBBase)
         if provenance.storage.with_path:
             return path
         return ""

     # XXX adding all revisions at once should be working just fine, but it does not...
     # revision_add(provenance, archive, revisions, lower=lower, mindepth=mindepth)
     # ...so add revisions one at a time for now
     for revision in revisions:
         revision_add(provenance, archive, [revision], lower=lower, mindepth=mindepth)

     syntheticfile = get_datafile(
         f"synthetic_{repo}_{'lower' if lower else 'upper'}_{mindepth}.txt"
     )
     expected_occurrences: Dict[str, List[Tuple[str, float, Optional[str], str]]] = {}
     for synth_rev in synthetic_result(syntheticfile):
         rev_id = synth_rev["sha1"].hex()
         rev_ts = synth_rev["date"]

         for rc in synth_rev["R_C"]:
             expected_occurrences.setdefault(rc["dst"].hex(), []).append(
                 (rev_id, rev_ts, None, maybe_path(rc["path"]))
             )
         for dc in synth_rev["D_C"]:
             assert dc["prefix"] is not None  # to please mypy
             expected_occurrences.setdefault(dc["dst"].hex(), []).append(
                 (rev_id, rev_ts, None, maybe_path(dc["prefix"] + "/" + dc["path"]))
             )

     assert isinstance(provenance.storage, ProvenanceDBBase)
     for content_id, results in expected_occurrences.items():
         expected = [(content_id, *result) for result in results]
         db_occurrences = [
             (
                 occur.content.hex(),
                 occur.revision.hex(),
                 occur.date.timestamp(),
                 occur.origin,
                 occur.path.decode(),
             )
             for occur in provenance.content_find_all(hash_to_bytes(content_id))
         ]
         if provenance.storage.with_path:
             # this is not true if the db stores no path, because the same content
             # that appears several times in a given revision may be reported
             # only once by content_find_all()
             assert len(db_occurrences) == len(expected)
         assert set(db_occurrences) == set(expected)


 @pytest.mark.parametrize(
     "repo, lower, mindepth",
     (
         ("cmdbts2", True, 1),
         ("cmdbts2", False, 1),
         ("cmdbts2", True, 2),
         ("cmdbts2", False, 2),
         ("out-of-order", True, 1),
     ),
 )
 def test_provenance_heuristics_content_find_first(
     provenance: ProvenanceInterface,
     swh_storage: Storage,
     archive: ArchiveInterface,
     repo: str,
     lower: bool,
     mindepth: int,
 ) -> None:
     # read data/README.md for more details on how these datasets are generated
     data = load_repo_data(repo)
     fill_storage(swh_storage, data)
     revisions = [
         RevisionEntry(
             id=revision["id"],
             date=ts2dt(revision["date"]),
             root=revision["directory"],
         )
         for revision in data["revision"]
     ]

     # XXX adding all revisions at once should be working just fine, but it does not...
     # revision_add(provenance, archive, revisions, lower=lower, mindepth=mindepth)
     # ...so add revisions one at a time for now
     for revision in revisions:
         revision_add(provenance, archive, [revision], lower=lower, mindepth=mindepth)

     syntheticfile = get_datafile(
         f"synthetic_{repo}_{'lower' if lower else 'upper'}_{mindepth}.txt"
     )
     expected_first: Dict[str, Tuple[str, float, List[str]]] = {}
     # dict mapping blob ids to (rev_id, rev_ts, [path, ...]) tuples; the third
     # element is a list because a content can be added at several places in a
     # single revision, in which case the result of content_find_first() is one
     # of those paths, but we have no guarantee which one it will return.
     for synth_rev in synthetic_result(syntheticfile):
         rev_id = synth_rev["sha1"].hex()
         rev_ts = synth_rev["date"]

         for rc in synth_rev["R_C"]:
             sha1 = rc["dst"].hex()
             if sha1 not in expected_first:
                 assert rc["rel_ts"] == 0
                 expected_first[sha1] = (rev_id, rev_ts, [rc["path"]])
             else:
                 if rev_ts == expected_first[sha1][1]:
                     expected_first[sha1][2].append(rc["path"])
                 elif rev_ts < expected_first[sha1][1]:
                     expected_first[sha1] = (rev_id, rev_ts, [rc["path"]])

         for dc in synth_rev["D_C"]:
             sha1 = dc["dst"].hex()
             assert sha1 in expected_first
             # nothing to do here, this content cannot be a "first seen file"

     assert isinstance(provenance.storage, ProvenanceDBBase)
     for content_id, (rev_id, ts, paths) in expected_first.items():
         occur = provenance.content_find_first(hash_to_bytes(content_id))
         assert occur is not None
         assert occur.content.hex() == content_id
         assert occur.revision.hex() == rev_id
         assert occur.date.timestamp() == ts
         assert occur.origin is None
         if provenance.storage.with_path:
             assert occur.path.decode() in paths
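The rewritten assertions all follow one shape: accumulate the expected rows per synthetic revision, then compare against the storage API rather than a raw cursor. In isolation, with `storage`, `synth_rev` and `maybe_path` standing in for the fixtures above (the helper name is illustrative), the per-revision pattern looks like this:

```python
from swh.provenance.provenance import RelationType


def check_r_c_rows(storage, synth_rev, maybe_path) -> None:
    # Expected (dst, src, path) triples contributed by this synthetic revision;
    # each must appear in the stored content-in-revision relation.
    expected = {
        (x["dst"], x["src"], maybe_path(x["path"])) for x in synth_rev["R_C"]
    }
    assert expected <= storage.relation_get_all(
        RelationType.CNT_EARLY_IN_REV
    ), synth_rev["msg"]
```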