diff --git a/swh/provenance/__init__.py b/swh/provenance/__init__.py --- a/swh/provenance/__init__.py +++ b/swh/provenance/__init__.py @@ -1,7 +1,5 @@ from typing import TYPE_CHECKING -from .postgresql.db_utils import connect - if TYPE_CHECKING: from .archive import ArchiveInterface from .provenance import ProvenanceInterface, ProvenanceStorageInterface @@ -15,9 +13,11 @@ return ArchiveStorage(get_storage(**kwargs["storage"])) elif cls == "direct": + from swh.core.db import BaseDb + from .postgresql.archive import ArchivePostgreSQL - return ArchivePostgreSQL(connect(kwargs["db"])) + return ArchivePostgreSQL(BaseDb.connect(**kwargs["db"]).conn) else: raise NotImplementedError @@ -30,9 +30,11 @@ def get_provenance_storage(cls: str, **kwargs) -> "ProvenanceStorageInterface": if cls == "local": + from swh.core.db import BaseDb + from .postgresql.provenancedb_base import ProvenanceDBBase - conn = connect(kwargs["db"]) + conn = BaseDb.connect(**kwargs["db"]).conn flavor = ProvenanceDBBase(conn).flavor if flavor == "with-path": from .postgresql.provenancedb_with_path import ProvenanceWithPathDB diff --git a/swh/provenance/postgresql/db_utils.py b/swh/provenance/postgresql/db_utils.py deleted file mode 100644 --- a/swh/provenance/postgresql/db_utils.py +++ /dev/null @@ -1,61 +0,0 @@ -from configparser import ConfigParser -import io - -import psycopg2 - - -def config(filename: str, section: str): - # create a parser - parser = ConfigParser() - # read config file - parser.read(filename) - - # get section, default to postgresql - db = {} - if parser.has_section(section): - params = parser.items(section) - for param in params: - db[param[0]] = param[1] - else: - raise Exception(f"Section {section} not found in the {filename} file") - - return db - - -def typecast_bytea(value, cur): - if value is not None: - data = psycopg2.BINARY(value, cur) - return data.tobytes() - - -def adapt_conn(conn): - """Makes psycopg2 use 'bytes' to decode bytea instead of - 'memoryview', for this connection.""" - t_bytes = psycopg2.extensions.new_type((17,), "bytea", typecast_bytea) - psycopg2.extensions.register_type(t_bytes, conn) - - t_bytes_array = psycopg2.extensions.new_array_type((1001,), "bytea[]", t_bytes) - psycopg2.extensions.register_type(t_bytes_array, conn) - - -def connect(params: dict): - """Connect to the PostgreSQL database server""" - conn = None - - try: - # connect to the PostgreSQL server - conn = psycopg2.connect(**params) - adapt_conn(conn) - - except (Exception, psycopg2.DatabaseError) as error: - print(error) - - return conn - - -def execute_sql(conn: psycopg2.extensions.connection, filename: str): - with io.open(filename) as file: - cur = conn.cursor() - cur.execute(file.read()) - cur.close() - conn.commit() diff --git a/swh/provenance/postgresql/provenancedb_base.py b/swh/provenance/postgresql/provenancedb_base.py --- a/swh/provenance/postgresql/provenancedb_base.py +++ b/swh/provenance/postgresql/provenancedb_base.py @@ -7,6 +7,7 @@ import psycopg2.extras from typing_extensions import Literal +from swh.core.db import BaseDb from swh.model.model import Sha1Git from ..provenance import ProvenanceResult, RelationType @@ -16,6 +17,7 @@ raise_on_commit: bool = False def __init__(self, conn: psycopg2.extensions.connection): + BaseDb.adapt_conn(conn) conn.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT) conn.set_session(autocommit=True) self.conn = conn diff --git a/swh/provenance/tests/test_archive_interface.py b/swh/provenance/tests/test_archive_interface.py --- a/swh/provenance/tests/test_archive_interface.py +++ b/swh/provenance/tests/test_archive_interface.py @@ -6,11 +6,10 @@ from collections import Counter from operator import itemgetter -import psycopg2 import pytest +from swh.core.db import BaseDb from swh.provenance.postgresql.archive import ArchivePostgreSQL -from swh.provenance.postgresql.db_utils import adapt_conn from swh.provenance.storage.archive import ArchiveStorage from swh.provenance.tests.conftest import fill_storage, load_repo_data @@ -22,8 +21,8 @@ def test_archive_interface(repo, swh_storage): archive_api = ArchiveStorage(swh_storage) dsn = swh_storage.get_db().conn.dsn - with psycopg2.connect(dsn) as conn: - adapt_conn(conn) + with BaseDb.connect(dsn).conn as conn: + BaseDb.adapt_conn(conn) archive_direct = ArchivePostgreSQL(conn) # read data/README.md for more details on how these datasets are generated data = load_repo_data(repo) diff --git a/swh/provenance/tests/test_cli.py b/swh/provenance/tests/test_cli.py --- a/swh/provenance/tests/test_cli.py +++ b/swh/provenance/tests/test_cli.py @@ -4,11 +4,11 @@ # See top-level LICENSE file for more information from click.testing import CliRunner -import psycopg2 import pytest from swh.core.cli import swh as swhmain import swh.core.cli.db # noqa ; ensure cli is loaded +from swh.core.db import BaseDb import swh.provenance.cli # noqa ; ensure cli is loaded @@ -71,7 +71,7 @@ assert f"(flavor {flavor})" in result.output db_params["dbname"] = dbname - cnx = psycopg2.connect(**db_params) + cnx = BaseDb.connect(**db_params).conn # check the DB looks OK (check for db_flavor and expected tables) with cnx.cursor() as cur: cur.execute("select swh_get_dbflavor()")