diff --git a/swh/provenance/__init__.py b/swh/provenance/__init__.py
index 945d8f3..8fa2a90 100644
--- a/swh/provenance/__init__.py
+++ b/swh/provenance/__init__.py
@@ -1,46 +1,48 @@
 from typing import TYPE_CHECKING

-from .postgresql.db_utils import connect
-
 if TYPE_CHECKING:
     from .archive import ArchiveInterface
     from .provenance import ProvenanceInterface, ProvenanceStorageInterface


 def get_archive(cls: str, **kwargs) -> "ArchiveInterface":
     if cls == "api":
         from swh.storage import get_storage

         from .storage.archive import ArchiveStorage

         return ArchiveStorage(get_storage(**kwargs["storage"]))
     elif cls == "direct":
+        from swh.core.db import BaseDb
+
         from .postgresql.archive import ArchivePostgreSQL

-        return ArchivePostgreSQL(connect(kwargs["db"]))
+        return ArchivePostgreSQL(BaseDb.connect(**kwargs["db"]).conn)
     else:
         raise NotImplementedError


 def get_provenance(**kwargs) -> "ProvenanceInterface":
     from .backend import ProvenanceBackend

     return ProvenanceBackend(get_provenance_storage(**kwargs))


 def get_provenance_storage(cls: str, **kwargs) -> "ProvenanceStorageInterface":
     if cls == "local":
+        from swh.core.db import BaseDb
+
         from .postgresql.provenancedb_base import ProvenanceDBBase

-        conn = connect(kwargs["db"])
+        conn = BaseDb.connect(**kwargs["db"]).conn
         flavor = ProvenanceDBBase(conn).flavor
         if flavor == "with-path":
             from .postgresql.provenancedb_with_path import ProvenanceWithPathDB

             return ProvenanceWithPathDB(conn)
         else:
             from .postgresql.provenancedb_without_path import ProvenanceWithoutPathDB

             return ProvenanceWithoutPathDB(conn)
     else:
         raise NotImplementedError
diff --git a/swh/provenance/postgresql/db_utils.py b/swh/provenance/postgresql/db_utils.py
deleted file mode 100644
index 61bace5..0000000
--- a/swh/provenance/postgresql/db_utils.py
+++ /dev/null
@@ -1,61 +0,0 @@
-from configparser import ConfigParser
-import io
-
-import psycopg2
-
-
-def config(filename: str, section: str):
-    # create a parser
-    parser = ConfigParser()
-    # read config file
-    parser.read(filename)
-
-    # get section, default to postgresql
-    db = {}
-    if parser.has_section(section):
-        params = parser.items(section)
-        for param in params:
-            db[param[0]] = param[1]
-    else:
-        raise Exception(f"Section {section} not found in the {filename} file")
-
-    return db
-
-
-def typecast_bytea(value, cur):
-    if value is not None:
-        data = psycopg2.BINARY(value, cur)
-        return data.tobytes()
-
-
-def adapt_conn(conn):
-    """Makes psycopg2 use 'bytes' to decode bytea instead of
-    'memoryview', for this connection."""
-    t_bytes = psycopg2.extensions.new_type((17,), "bytea", typecast_bytea)
-    psycopg2.extensions.register_type(t_bytes, conn)
-
-    t_bytes_array = psycopg2.extensions.new_array_type((1001,), "bytea[]", t_bytes)
-    psycopg2.extensions.register_type(t_bytes_array, conn)
-
-
-def connect(params: dict):
-    """Connect to the PostgreSQL database server"""
-    conn = None
-
-    try:
-        # connect to the PostgreSQL server
-        conn = psycopg2.connect(**params)
-        adapt_conn(conn)
-
-    except (Exception, psycopg2.DatabaseError) as error:
-        print(error)
-
-    return conn
-
-
-def execute_sql(conn: psycopg2.extensions.connection, filename: str):
-    with io.open(filename) as file:
-        cur = conn.cursor()
-        cur.execute(file.read())
-        cur.close()
-        conn.commit()
diff --git a/swh/provenance/postgresql/provenancedb_base.py b/swh/provenance/postgresql/provenancedb_base.py
index 30b970c..4fbf470 100644
--- a/swh/provenance/postgresql/provenancedb_base.py
+++ b/swh/provenance/postgresql/provenancedb_base.py
@@ -1,283 +1,285 @@
 from datetime import datetime
 import itertools
 import logging
 from typing import Dict, Generator, Iterable, Optional, Set, Tuple

 import psycopg2
 import psycopg2.extras
 from typing_extensions import Literal

+from swh.core.db import BaseDb
 from swh.model.model import Sha1Git

 from ..provenance import ProvenanceResult, RelationType


 class ProvenanceDBBase:
     raise_on_commit: bool = False

     def __init__(self, conn: psycopg2.extensions.connection):
+        BaseDb.adapt_conn(conn)
         conn.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT)
         conn.set_session(autocommit=True)
         self.conn = conn
         self.cursor = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
         # XXX: not sure this is the best place to do it!
         sql = "SET timezone TO 'UTC'"
         self.cursor.execute(sql)
         self._flavor: Optional[str] = None

     @property
     def flavor(self) -> str:
         if self._flavor is None:
             sql = "SELECT swh_get_dbflavor() AS flavor"
             self.cursor.execute(sql)
             self._flavor = self.cursor.fetchone()["flavor"]
         assert self._flavor is not None
         return self._flavor

     @property
     def with_path(self) -> bool:
         return self.flavor == "with-path"

     def content_find_first(self, id: Sha1Git) -> Optional[ProvenanceResult]:
         ...

     def content_find_all(
         self, id: Sha1Git, limit: Optional[int] = None
     ) -> Generator[ProvenanceResult, None, None]:
         ...

     def content_set_date(self, dates: Dict[Sha1Git, datetime]) -> bool:
         return self._entity_set_date("content", dates)

     def content_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, datetime]:
         return self._entity_get_date("content", ids)

     def directory_set_date(self, dates: Dict[Sha1Git, datetime]) -> bool:
         return self._entity_set_date("directory", dates)

     def directory_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, datetime]:
         return self._entity_get_date("directory", ids)

     def origin_set_url(self, urls: Dict[Sha1Git, str]) -> bool:
         try:
             if urls:
                 sql = """
                     LOCK TABLE ONLY origin;
                     INSERT INTO origin(sha1, url) VALUES %s
                       ON CONFLICT DO NOTHING
                     """
                 psycopg2.extras.execute_values(self.cursor, sql, urls.items())
             return True
         except:  # noqa: E722
             # Unexpected error occurred, rollback all changes and log message
             logging.exception("Unexpected error")
             if self.raise_on_commit:
                 raise
         return False

     def origin_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, str]:
         urls: Dict[Sha1Git, str] = {}
         sha1s = tuple(ids)
         if sha1s:
             values = ", ".join(itertools.repeat("%s", len(sha1s)))
             sql = f"""
                 SELECT sha1, url
                   FROM origin
                   WHERE sha1 IN ({values})
                 """
             self.cursor.execute(sql, sha1s)
             urls.update((row["sha1"], row["url"]) for row in self.cursor.fetchall())
         return urls

     def revision_set_date(self, dates: Dict[Sha1Git, datetime]) -> bool:
         return self._entity_set_date("revision", dates)

     def revision_set_origin(self, origins: Dict[Sha1Git, Sha1Git]) -> bool:
         try:
             if origins:
                 sql = """
                     LOCK TABLE ONLY revision;
                     INSERT INTO revision(sha1, origin)
                       (SELECT V.rev AS sha1, O.id AS origin
                        FROM (VALUES %s) AS V(rev, org)
                        JOIN origin AS O ON (O.sha1=V.org))
                       ON CONFLICT (sha1) DO
                       UPDATE SET origin=EXCLUDED.origin
                     """
                 psycopg2.extras.execute_values(self.cursor, sql, origins.items())
             return True
         except:  # noqa: E722
             # Unexpected error occurred, rollback all changes and log message
             logging.exception("Unexpected error")
             if self.raise_on_commit:
                 raise
         return False

     def revision_get(
         self, ids: Iterable[Sha1Git]
     ) -> Dict[Sha1Git, Tuple[Optional[datetime], Optional[Sha1Git]]]:
         result: Dict[Sha1Git, Tuple[Optional[datetime], Optional[Sha1Git]]] = {}
         sha1s = tuple(ids)
         if sha1s:
             values = ", ".join(itertools.repeat("%s", len(sha1s)))
             sql = f"""
                 SELECT sha1, date, origin
                   FROM revision
                   WHERE sha1 IN ({values})
                 """
             self.cursor.execute(sql, sha1s)
             result.update(
                 (row["sha1"], (row["date"], row["origin"]))
                 for row in self.cursor.fetchall()
             )
         return result

     def relation_add(
         self,
         relation: RelationType,
         data: Iterable[Tuple[Sha1Git, Sha1Git, Optional[bytes]]],
     ) -> bool:
         try:
             if data:
                 table = relation.value
                 src, *_, dst = table.split("_")

                 if src != "origin":
                     # Origin entries should be inserted previously as they require extra
                     # non-null information
                     srcs = tuple(set((sha1,) for (sha1, _, _) in data))
                     sql = f"""
                         LOCK TABLE ONLY {src};
                         INSERT INTO {src}(sha1) VALUES %s
                           ON CONFLICT DO NOTHING
                         """
                     psycopg2.extras.execute_values(self.cursor, sql, srcs)
                 if dst != "origin":
                     # Origin entries should be inserted previously as they require extra
                     # non-null information
                     dsts = tuple(set((sha1,) for (_, sha1, _) in data))
                     sql = f"""
                         LOCK TABLE ONLY {dst};
                         INSERT INTO {dst}(sha1) VALUES %s
                           ON CONFLICT DO NOTHING
                         """
                     psycopg2.extras.execute_values(self.cursor, sql, dsts)

                 joins = [
                     f"INNER JOIN {src} AS S ON (S.sha1=V.src)",
                     f"INNER JOIN {dst} AS D ON (D.sha1=V.dst)",
                 ]
                 selected = ["S.id", "D.id"]

                 if self._relation_uses_location_table(relation):
                     locations = tuple(set((path,) for (_, _, path) in data))
                     sql = """
                         LOCK TABLE ONLY location;
                         INSERT INTO location(path) VALUES %s
                           ON CONFLICT (path) DO NOTHING
                         """
                     psycopg2.extras.execute_values(self.cursor, sql, locations)

                     joins.append("INNER JOIN location AS L ON (L.path=V.path)")
                     selected.append("L.id")

                 sql = f"""
                     INSERT INTO {table}
                       (SELECT {", ".join(selected)}
                        FROM (VALUES %s) AS V(src, dst, path)
                        {'''
                        '''.join(joins)})
                       ON CONFLICT DO NOTHING
                     """
                 psycopg2.extras.execute_values(self.cursor, sql, data)
             return True
         except:  # noqa: E722
             # Unexpected error occurred, rollback all changes and log message
             logging.exception("Unexpected error")
             if self.raise_on_commit:
                 raise
         return False

     def relation_get(
         self, relation: RelationType, ids: Iterable[Sha1Git], reverse: bool = False
     ) -> Set[Tuple[Sha1Git, Sha1Git, Optional[bytes]]]:
         result: Set[Tuple[Sha1Git, Sha1Git, Optional[bytes]]] = set()
         sha1s = tuple(ids)
         if sha1s:
             table = relation.value
             src, *_, dst = table.split("_")

             # TODO: improve this!
             if src == "revision" and dst == "revision":
                 src_field = "prev"
                 dst_field = "next"
             else:
                 src_field = src
                 dst_field = dst

             joins = [
                 f"INNER JOIN {src} AS S ON (S.id=R.{src_field})",
                 f"INNER JOIN {dst} AS D ON (D.id=R.{dst_field})",
             ]
             selected = ["S.sha1 AS src", "D.sha1 AS dst"]
             selector = "S.sha1" if not reverse else "D.sha1"

             if self._relation_uses_location_table(relation):
                 joins.append("INNER JOIN location AS L ON (L.id=R.location)")
                 selected.append("L.path AS path")
             else:
                 selected.append("NULL AS path")

             sql = f"""
                 SELECT {", ".join(selected)}
                   FROM {table} AS R
                   {" ".join(joins)}
                   WHERE {selector} IN %s
                 """
             self.cursor.execute(sql, (sha1s,))
             result.update(
                 (row["src"], row["dst"], row["path"])
                 for row in self.cursor.fetchall()
             )
         return result

     def _entity_get_date(
         self,
         entity: Literal["content", "directory", "revision"],
         ids: Iterable[Sha1Git],
     ) -> Dict[Sha1Git, datetime]:
         dates: Dict[Sha1Git, datetime] = {}
         sha1s = tuple(ids)
         if sha1s:
             values = ", ".join(itertools.repeat("%s", len(sha1s)))
             sql = f"""
                 SELECT sha1, date
                   FROM {entity}
                   WHERE sha1 IN ({values})
                 """
             self.cursor.execute(sql, sha1s)
             dates.update((row["sha1"], row["date"]) for row in self.cursor.fetchall())
         return dates

     def _entity_set_date(
         self,
         entity: Literal["content", "directory", "revision"],
         data: Dict[Sha1Git, datetime],
     ) -> bool:
         try:
             if data:
                 sql = f"""
                     LOCK TABLE ONLY {entity};
                     INSERT INTO {entity}(sha1, date) VALUES %s
                       ON CONFLICT (sha1) DO
                       UPDATE SET date=LEAST(EXCLUDED.date,{entity}.date)
                     """
                 psycopg2.extras.execute_values(self.cursor, sql, data.items())
             return True
         except:  # noqa: E722
             # Unexpected error occurred, rollback all changes and log message
             logging.exception("Unexpected error")
             if self.raise_on_commit:
                 raise
         return False

     def _relation_uses_location_table(self, relation: RelationType) -> bool:
         ...
diff --git a/swh/provenance/tests/test_archive_interface.py b/swh/provenance/tests/test_archive_interface.py
index 7c8bbd8..53775d2 100644
--- a/swh/provenance/tests/test_archive_interface.py
+++ b/swh/provenance/tests/test_archive_interface.py
@@ -1,51 +1,50 @@
 # Copyright (C) 2021  The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information

 from collections import Counter
 from operator import itemgetter

-import psycopg2
 import pytest

+from swh.core.db import BaseDb
 from swh.provenance.postgresql.archive import ArchivePostgreSQL
-from swh.provenance.postgresql.db_utils import adapt_conn
 from swh.provenance.storage.archive import ArchiveStorage
 from swh.provenance.tests.conftest import fill_storage, load_repo_data


 @pytest.mark.parametrize(
     "repo",
     ("cmdbts2", "out-of-order", "with-merges"),
 )
 def test_archive_interface(repo, swh_storage):
     archive_api = ArchiveStorage(swh_storage)
     dsn = swh_storage.get_db().conn.dsn
-    with psycopg2.connect(dsn) as conn:
-        adapt_conn(conn)
+    with BaseDb.connect(dsn).conn as conn:
+        BaseDb.adapt_conn(conn)
         archive_direct = ArchivePostgreSQL(conn)
         # read data/README.md for more details on how these datasets are generated
         data = load_repo_data(repo)
         fill_storage(swh_storage, data)
         for directory in data["directory"]:
             entries_api = sorted(
                 archive_api.directory_ls(directory["id"]), key=itemgetter("name")
             )
             entries_direct = sorted(
                 archive_direct.directory_ls(directory["id"]), key=itemgetter("name")
             )
             assert entries_api == entries_direct

         for revision in data["revision"]:
             parents_api = Counter(archive_api.revision_get_parents(revision["id"]))
             parents_direct = Counter(
                 archive_direct.revision_get_parents(revision["id"])
             )
             assert parents_api == parents_direct

         for snapshot in data["snapshot"]:
             heads_api = Counter(archive_api.snapshot_get_heads(snapshot["id"]))
             heads_direct = Counter(archive_direct.snapshot_get_heads(snapshot["id"]))
             assert heads_api == heads_direct
diff --git a/swh/provenance/tests/test_cli.py b/swh/provenance/tests/test_cli.py
index 744fbed..51ebefe 100644
--- a/swh/provenance/tests/test_cli.py
+++ b/swh/provenance/tests/test_cli.py
@@ -1,97 +1,97 @@
 # Copyright (C) 2021  The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information

 from click.testing import CliRunner
-import psycopg2
 import pytest

 from swh.core.cli import swh as swhmain
 import swh.core.cli.db  # noqa ; ensure cli is loaded
+from swh.core.db import BaseDb
 import swh.provenance.cli  # noqa ; ensure cli is loaded


 def test_cli_swh_db_help():
     # swhmain.add_command(provenance_cli)
     result = CliRunner().invoke(swhmain, ["provenance", "-h"])
     assert result.exit_code == 0
     assert "Commands:" in result.output
     commands = result.output.split("Commands:")[1]
     for command in (
         "find-all",
         "find-first",
         "iter-origins",
         "iter-revisions",
     ):
         assert f" {command} " in commands


 TABLES = {
     "dbflavor",
     "dbversion",
     "content",
     "content_in_revision",
     "content_in_directory",
     "directory",
     "directory_in_revision",
     "location",
     "origin",
     "revision",
     "revision_before_revision",
     "revision_in_origin",
 }


 @pytest.mark.parametrize(
     "flavor, dbtables", (("with-path", TABLES | {"location"}), ("without-path", TABLES))
 )
 def test_cli_db_create_and_init_db_with_flavor(
     monkeypatch, postgresql, flavor, dbtables
 ):
     """Test that 'swh db init provenance' works with flavors for both
     with-path and without-path flavors"""
     dbname = f"{flavor}-db"

     # DB creation using 'swh db create'
     db_params = postgresql.get_dsn_parameters()
     monkeypatch.setenv("PGHOST", db_params["host"])
     monkeypatch.setenv("PGUSER", db_params["user"])
     monkeypatch.setenv("PGPORT", db_params["port"])
     result = CliRunner().invoke(swhmain, ["db", "create", "-d", dbname, "provenance"])
     assert result.exit_code == 0, result.output

     # DB init using 'swh db init'
     result = CliRunner().invoke(
         swhmain, ["db", "init", "-d", dbname, "--flavor", flavor, "provenance"]
     )
     assert result.exit_code == 0, result.output
     assert f"(flavor {flavor})" in result.output

     db_params["dbname"] = dbname
-    cnx = psycopg2.connect(**db_params)
+    cnx = BaseDb.connect(**db_params).conn
     # check the DB looks OK (check for db_flavor and expected tables)
     with cnx.cursor() as cur:
         cur.execute("select swh_get_dbflavor()")
         assert cur.fetchone() == (flavor,)

         cur.execute(
             "select table_name from information_schema.tables "
             "where table_schema = 'public' "
             f"and table_catalog = '{dbname}'"
         )
         tables = set(x for (x,) in cur.fetchall())
         assert tables == dbtables


 def test_cli_init_db_default_flavor(postgresql):
     "Test that 'swh db init provenance' defaults to a with-path flavored DB"
     dbname = postgresql.dsn
     result = CliRunner().invoke(swhmain, ["db", "init", "-d", dbname, "provenance"])
     assert result.exit_code == 0, result.output

     with postgresql.cursor() as cur:
         cur.execute("select swh_get_dbflavor()")
         assert cur.fetchone() == ("with-path",)
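
Reviewer note (not part of the patch): a minimal sketch of the connection pattern this change standardizes on, mirroring the updated test_archive_interface above. It assumes a reachable PostgreSQL instance; the DSN string "dbname=softwareheritage host=localhost" is only an illustrative placeholder.

    from swh.core.db import BaseDb

    from swh.provenance.postgresql.archive import ArchivePostgreSQL

    # BaseDb.connect accepts a DSN string or keyword parameters and wraps a
    # psycopg2 connection, exposed as the `.conn` attribute.
    db = BaseDb.connect("dbname=softwareheritage host=localhost")
    # Register the bytea -> bytes typecasters formerly set up by db_utils.adapt_conn.
    BaseDb.adapt_conn(db.conn)
    archive = ArchivePostgreSQL(db.conn)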