Page Menu | Home | Software Heritage

D5972.diff
No One | Temporary

D5972.diff

diff --git a/swh/provenance/postgresql/provenancedb_base.py b/swh/provenance/postgresql/provenancedb_base.py
--- a/swh/provenance/postgresql/provenancedb_base.py
+++ b/swh/provenance/postgresql/provenancedb_base.py
@@ -10,7 +10,7 @@
from swh.core.db import BaseDb
from swh.model.model import Sha1Git
-from ..provenance import ProvenanceResult, RelationType
+from ..provenance import EntityType, ProvenanceResult, RelationType
class ProvenanceDBBase:
@@ -60,6 +60,16 @@
def directory_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, datetime]:
return self._entity_get_date("directory", ids)
+ def entity_get_all(self, entity: EntityType) -> Set[Sha1Git]:
+ sql = f"SELECT sha1 FROM {entity.value}"
+ self.cursor.execute(sql)
+ return {row["sha1"] for row in self.cursor.fetchall()}
+
+ def location_get(self) -> Set[bytes]:
+ sql = "SELECT encode(location.path::bytea, 'escape') AS path FROM location"
+ self.cursor.execute(sql)
+ return {row["path"] for row in self.cursor.fetchall()}
+
def origin_set_url(self, urls: Dict[Sha1Git, str]) -> bool:
try:
if urls:
@@ -202,44 +212,12 @@
def relation_get(
self, relation: RelationType, ids: Iterable[Sha1Git], reverse: bool = False
) -> Set[Tuple[Sha1Git, Sha1Git, Optional[bytes]]]:
- result: Set[Tuple[Sha1Git, Sha1Git, Optional[bytes]]] = set()
- sha1s = tuple(ids)
- if sha1s:
- table = relation.value
- src, *_, dst = table.split("_")
-
- # TODO: improve this!
- if src == "revision" and dst == "revision":
- src_field = "prev"
- dst_field = "next"
- else:
- src_field = src
- dst_field = dst
-
- joins = [
- f"INNER JOIN {src} AS S ON (S.id=R.{src_field})",
- f"INNER JOIN {dst} AS D ON (D.id=R.{dst_field})",
- ]
- selected = ["S.sha1 AS src", "D.sha1 AS dst"]
- selector = "S.sha1" if not reverse else "D.sha1"
-
- if self._relation_uses_location_table(relation):
- joins.append("INNER JOIN location AS L ON (L.id=R.location)")
- selected.append("L.path AS path")
- else:
- selected.append("NULL AS path")
+ return self._relation_get(relation, ids, reverse)
- sql = f"""
- SELECT {", ".join(selected)}
- FROM {table} AS R
- {" ".join(joins)}
- WHERE {selector} IN %s
- """
- self.cursor.execute(sql, (sha1s,))
- result.update(
- (row["src"], row["dst"], row["path"]) for row in self.cursor.fetchall()
- )
- return result
+ def relation_get_all(
+ self, relation: RelationType
+ ) -> Set[Tuple[Sha1Git, Sha1Git, Optional[bytes]]]:
+ return self._relation_get(relation, None)
def _entity_get_date(
self,
@@ -281,5 +259,57 @@
raise
return False
+ def _relation_get(
+ self,
+ relation: RelationType,
+ ids: Optional[Iterable[Sha1Git]],
+ reverse: bool = False,
+ ) -> Set[Tuple[Sha1Git, Sha1Git, Optional[bytes]]]:
+ result: Set[Tuple[Sha1Git, Sha1Git, Optional[bytes]]] = set()
+
+ sha1s: Optional[Tuple[Tuple[bytes, ...]]]
+ if ids is not None:
+ sha1s = (tuple(ids),)
+ where = f"WHERE {'S.sha1' if not reverse else 'D.sha1'} IN %s"
+ else:
+ sha1s = None
+ where = ""
+
+ if sha1s is None or sha1s[0]:
+ table = relation.value
+ src, *_, dst = table.split("_")
+
+ # TODO: improve this!
+ if src == "revision" and dst == "revision":
+ src_field = "prev"
+ dst_field = "next"
+ else:
+ src_field = src
+ dst_field = dst
+
+ joins = [
+ f"INNER JOIN {src} AS S ON (S.id=R.{src_field})",
+ f"INNER JOIN {dst} AS D ON (D.id=R.{dst_field})",
+ ]
+ selected = ["S.sha1 AS src", "D.sha1 AS dst"]
+
+ if self._relation_uses_location_table(relation):
+ joins.append("INNER JOIN location AS L ON (L.id=R.location)")
+ selected.append("L.path AS path")
+ else:
+ selected.append("NULL AS path")
+
+ sql = f"""
+ SELECT {", ".join(selected)}
+ FROM {table} AS R
+ {" ".join(joins)}
+ {where}
+ """
+ self.cursor.execute(sql, sha1s)
+ result.update(
+ (row["src"], row["dst"], row["path"]) for row in self.cursor.fetchall()
+ )
+ return result
+
def _relation_uses_location_table(self, relation: RelationType) -> bool:
...
diff --git a/swh/provenance/provenance.py b/swh/provenance/provenance.py
--- a/swh/provenance/provenance.py
+++ b/swh/provenance/provenance.py
@@ -9,6 +9,13 @@
from .model import DirectoryEntry, FileEntry, OriginEntry, RevisionEntry
+class EntityType(enum.Enum):
+ CONTENT = "content"
+ DIRECTORY = "directory"
+ REVISION = "revision"
+ ORIGIN = "origin"
+
+
class RelationType(enum.Enum):
CNT_EARLY_IN_REV = "content_in_revision"
CNT_IN_DIR = "content_in_directory"
@@ -72,6 +79,16 @@
"""
...
+ def entity_get_all(self, entity: EntityType) -> Set[Sha1Git]:
+ """Retrieve all sha1 ids for entities of type `entity` present in the provenance
+ model.
+ """
+ ...
+
+ def location_get(self) -> Set[bytes]:
+ """Retrieve all paths present in the provenance model."""
+ ...
+
def origin_set_url(self, urls: Dict[Sha1Git, str]) -> bool:
"""Associate urls to origins identified by sha1 ids, as paired in `urls`. Return
a boolean stating whether the information was successfully stored.
@@ -126,6 +143,15 @@
"""
...
+ def relation_get_all(
+ self, relation: RelationType
+ ) -> Set[Tuple[Sha1Git, Sha1Git, Optional[bytes]]]:
+ """Retrieve all tuples of the form (`src`, `dst`, `path`) present in the
+ provenance model, where `src` and `dst` are the sha1 ids of the entities being
+ related, and `path` is optional depending on the selected `relation`.
+ """
+ ...
+
@runtime_checkable
class ProvenanceInterface(Protocol):
diff --git a/swh/provenance/tests/test_provenance_heuristics.py b/swh/provenance/tests/test_provenance_heuristics.py
--- a/swh/provenance/tests/test_provenance_heuristics.py
+++ b/swh/provenance/tests/test_provenance_heuristics.py
@@ -3,18 +3,15 @@
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
-from datetime import datetime
from typing import Any, Dict, List, Optional, Set, Tuple
-import psycopg2
import pytest
from swh.model.hashutil import hash_to_bytes
-from swh.model.model import Sha1Git
from swh.provenance.archive import ArchiveInterface
from swh.provenance.model import RevisionEntry
from swh.provenance.postgresql.provenancedb_base import ProvenanceDBBase
-from swh.provenance.provenance import ProvenanceInterface
+from swh.provenance.provenance import EntityType, ProvenanceInterface, RelationType
from swh.provenance.revision import revision_add
from swh.provenance.tests.conftest import (
fill_storage,
@@ -26,77 +23,6 @@
from swh.storage.postgresql.storage import Storage
-def sha1s(cur: psycopg2.extensions.cursor, table: str) -> Set[Sha1Git]:
- """return the 'sha1' column from the DB 'table' (as hex)
-
- 'cur' is a cursor to the provenance index DB.
- """
- cur.execute(f"SELECT sha1 FROM {table}")
- return set(row["sha1"].hex() for row in cur.fetchall())
-
-
-def locations(cur: psycopg2.extensions.cursor) -> Set[bytes]:
- """return the 'path' column from the DB location table
-
- 'cur' is a cursor to the provenance index DB.
- """
- cur.execute("SELECT encode(location.path::bytea, 'escape') AS path FROM location")
- return set(row["path"] for row in cur.fetchall())
-
-
-def relations(
- cur: psycopg2.extensions.cursor, src: str, dst: str
-) -> Set[Tuple[Sha1Git, Sha1Git, bytes]]:
- """return the triplets ('sha1', 'sha1', 'path') from the DB
-
- for the relation between 'src' table and 'dst' table
- (i.e. for C-R, C-D and D-R relations).
-
- 'cur' is a cursor to the provenance index DB.
- """
- relation = f"{src}_in_{dst}"
- cur.execute("SELECT swh_get_dbflavor() AS flavor")
- with_path = cur.fetchone()["flavor"] == "with-path"
-
- # note that the columns have the same name as the relations they refer to,
- # so we can write things like "rel.{dst}=src.id" in the query below
- if with_path:
- cur.execute(
- f"""
- SELECT encode(src.sha1::bytea, 'hex') AS src,
- encode(dst.sha1::bytea, 'hex') AS dst,
- encode(location.path::bytea, 'escape') AS path
- FROM {relation} as relation
- INNER JOIN {src} AS src ON (relation.{src} = src.id)
- INNER JOIN {dst} AS dst ON (relation.{dst} = dst.id)
- INNER JOIN location ON (relation.location = location.id)
- """
- )
- else:
- cur.execute(
- f"""
- SELECT encode(src.sha1::bytea, 'hex') AS src,
- encode(dst.sha1::bytea, 'hex') AS dst,
- '' AS path
- FROM {relation} as relation
- INNER JOIN {src} AS src ON (src.id = relation.{src})
- INNER JOIN {dst} AS dst ON (dst.id = relation.{dst})
- """
- )
- return set((row["src"], row["dst"], row["path"]) for row in cur.fetchall())
-
-
-def get_timestamp(
- cur: psycopg2.extensions.cursor, table: str, sha1: Sha1Git
-) -> List[datetime]:
- """return the date for the 'sha1' from the DB 'table' (as hex)
-
- 'cur' is a cursor to the provenance index DB.
- """
- cur.execute(f"SELECT date FROM {table} WHERE sha1=%s", (sha1,))
- return [row["date"].timestamp() for row in cur.fetchall()]
-
-
@pytest.mark.parametrize(
"repo, lower, mindepth",
(
@@ -133,14 +59,12 @@
"location": set(),
"revision": set(),
}
- assert isinstance(provenance.storage, ProvenanceDBBase)
- cursor = provenance.storage.cursor
- def maybe_path(path: str) -> str:
+ def maybe_path(path: str) -> Optional[bytes]:
assert isinstance(provenance.storage, ProvenanceDBBase)
if provenance.storage.with_path:
- return path
- return ""
+ return path.encode("utf-8")
+ return None
for synth_rev in synthetic_result(syntheticfile):
revision = revisions[synth_rev["sha1"]]
@@ -152,77 +76,87 @@
revision_add(provenance, archive, [entry], lower=lower, mindepth=mindepth)
# each "entry" in the synth file is one new revision
- rows["revision"].add(synth_rev["sha1"].hex())
- assert rows["revision"] == sha1s(cursor, "revision"), synth_rev["msg"]
+ rows["revision"].add(synth_rev["sha1"])
+ assert rows["revision"] == provenance.storage.entity_get_all(
+ EntityType.REVISION
+ ), synth_rev["msg"]
# check the timestamp of the revision
rev_ts = synth_rev["date"]
- assert get_timestamp(cursor, "revision", synth_rev["sha1"]) == [
- rev_ts
- ], synth_rev["msg"]
+ rev_date, _ = provenance.storage.revision_get([synth_rev["sha1"]])[
+ synth_rev["sha1"]
+ ]
+ assert rev_date is not None and rev_ts == rev_date.timestamp(), synth_rev["msg"]
# this revision might have added new content objects
- rows["content"] |= set(x["dst"].hex() for x in synth_rev["R_C"])
- rows["content"] |= set(x["dst"].hex() for x in synth_rev["D_C"])
- assert rows["content"] == sha1s(cursor, "content"), synth_rev["msg"]
+ rows["content"] |= set(x["dst"] for x in synth_rev["R_C"])
+ rows["content"] |= set(x["dst"] for x in synth_rev["D_C"])
+ assert rows["content"] == provenance.storage.entity_get_all(
+ EntityType.CONTENT
+ ), synth_rev["msg"]
# check for R-C (direct) entries
# these are added directly in the content_early_in_rev table
rows["content_in_revision"] |= set(
- (x["dst"].hex(), x["src"].hex(), maybe_path(x["path"]))
- for x in synth_rev["R_C"]
+ (x["dst"], x["src"], maybe_path(x["path"])) for x in synth_rev["R_C"]
)
- assert rows["content_in_revision"] == relations(
- cursor, "content", "revision"
+ assert rows["content_in_revision"] == provenance.storage.relation_get_all(
+ RelationType.CNT_EARLY_IN_REV
), synth_rev["msg"]
# check timestamps
for rc in synth_rev["R_C"]:
- assert get_timestamp(cursor, "content", rc["dst"]) == [
+ assert (
rev_ts + rc["rel_ts"]
- ], synth_rev["msg"]
+ == provenance.storage.content_get([rc["dst"]])[rc["dst"]].timestamp()
+ ), synth_rev["msg"]
# check directories
# each directory stored in the provenance index is an entry
# in the "directory" table...
- rows["directory"] |= set(x["dst"].hex() for x in synth_rev["R_D"])
- assert rows["directory"] == sha1s(cursor, "directory"), synth_rev["msg"]
+ rows["directory"] |= set(x["dst"] for x in synth_rev["R_D"])
+ assert rows["directory"] == provenance.storage.entity_get_all(
+ EntityType.DIRECTORY
+ ), synth_rev["msg"]
# ... + a number of rows in the "directory_in_rev" table...
# check for R-D entries
rows["directory_in_revision"] |= set(
- (x["dst"].hex(), x["src"].hex(), maybe_path(x["path"]))
- for x in synth_rev["R_D"]
+ (x["dst"], x["src"], maybe_path(x["path"])) for x in synth_rev["R_D"]
)
- assert rows["directory_in_revision"] == relations(
- cursor, "directory", "revision"
+ assert rows["directory_in_revision"] == provenance.storage.relation_get_all(
+ RelationType.DIR_IN_REV
), synth_rev["msg"]
# check timestamps
for rd in synth_rev["R_D"]:
- assert get_timestamp(cursor, "directory", rd["dst"]) == [
+ assert (
rev_ts + rd["rel_ts"]
- ], synth_rev["msg"]
+ == provenance.storage.directory_get([rd["dst"]])[rd["dst"]].timestamp()
+ ), synth_rev["msg"]
# ... + a number of rows in the "content_in_dir" table
# for content of the directory.
# check for D-C entries
rows["content_in_directory"] |= set(
- (x["dst"].hex(), x["src"].hex(), maybe_path(x["path"]))
- for x in synth_rev["D_C"]
+ (x["dst"], x["src"], maybe_path(x["path"])) for x in synth_rev["D_C"]
)
- assert rows["content_in_directory"] == relations(
- cursor, "content", "directory"
+ assert rows["content_in_directory"] == provenance.storage.relation_get_all(
+ RelationType.CNT_IN_DIR
), synth_rev["msg"]
# check timestamps
for dc in synth_rev["D_C"]:
- assert get_timestamp(cursor, "content", dc["dst"]) == [
+ assert (
rev_ts + dc["rel_ts"]
- ], synth_rev["msg"]
+ == provenance.storage.content_get([dc["dst"]])[dc["dst"]].timestamp()
+ ), synth_rev["msg"]
+ assert isinstance(provenance.storage, ProvenanceDBBase)
if provenance.storage.with_path:
# check for location entries
rows["location"] |= set(x["path"] for x in synth_rev["R_C"])
rows["location"] |= set(x["path"] for x in synth_rev["D_C"])
rows["location"] |= set(x["path"] for x in synth_rev["R_D"])
- assert rows["location"] == locations(cursor), synth_rev["msg"]
+ assert rows["location"] == provenance.storage.location_get(), synth_rev[
+ "msg"
+ ]
@pytest.mark.parametrize(

File Metadata

Mime Type
text/plain
Expires
Thu, Jan 30, 1:01 PM (9 h, 26 m ago)
Storage Engine
blob
Storage Format
Raw Data
Storage Handle
3219934

Event Timeline