diff --git a/requirements-db-pytestplugin.txt b/requirements-db-pytestplugin.txt --- a/requirements-db-pytestplugin.txt +++ b/requirements-db-pytestplugin.txt @@ -1,2 +1,2 @@ # requirements for swh.core.db.pytest_plugin -pytest-postgresql < 4.0.0 # version 4.0 depends on psycopg 3. https://github.com/ClearcodeHQ/pytest-postgresql/blob/main/CHANGES.rst#400 +pytest-postgresql >=3, < 4.0.0 # version 4.0 depends on psycopg 3. https://github.com/ClearcodeHQ/pytest-postgresql/blob/main/CHANGES.rst#400 diff --git a/swh/core/db/pytest_plugin.py b/swh/core/db/pytest_plugin.py --- a/swh/core/db/pytest_plugin.py +++ b/swh/core/db/pytest_plugin.py @@ -7,14 +7,22 @@ from importlib import import_module import logging import subprocess -from typing import List, Optional, Set, Union +from typing import Callable, Iterable, Iterator, List, Optional, Sequence, Set, Union from _pytest.fixtures import FixtureRequest import psycopg2 import pytest +from pytest_postgresql.compat import check_for_psycopg2, connection +from pytest_postgresql.executor import PostgreSQLExecutor +from pytest_postgresql.executor_noop import NoopExecutor from pytest_postgresql.janitor import DatabaseJanitor -from swh.core.utils import numfile_sortkey as sortkey +from swh.core.db.db_utils import ( + init_admin_extensions, + populate_database_for_package, + swh_set_db_version, +) +from swh.core.utils import basename_sortkey # to keep mypy happy regardless pytest-postgresql version try: @@ -43,46 +51,42 @@ self, user: str, host: str, - port: str, + port: int, dbname: str, version: Union[str, float], - dump_files: Union[None, str, List[str]] = None, + password: Optional[str] = None, + isolation_level: Optional[int] = None, + connection_timeout: int = 60, + dump_files: Optional[Union[str, Sequence[str]]] = None, no_truncate_tables: Set[str] = set(), + no_db_drop: bool = False, ) -> None: super().__init__(user, host, port, dbname, version) - if not hasattr(self, "dbname") and hasattr(self, "db_name"): - # 
pytest_postgresql < 3.0.0 - self.dbname = getattr(self, "db_name") - if dump_files is None: - self.dump_files = [] - elif isinstance(dump_files, str): - self.dump_files = sorted(glob.glob(dump_files), key=sortkey) - else: - self.dump_files = dump_files # do no truncate the following tables self.no_truncate_tables = set(no_truncate_tables) + self.no_db_drop = no_db_drop + self.dump_files = dump_files - def db_setup(self): + def psql_exec(self, fname: str) -> None: conninfo = ( f"host={self.host} user={self.user} port={self.port} dbname={self.dbname}" ) - for fname in self.dump_files: - subprocess.check_call( - [ - "psql", - "--quiet", - "--no-psqlrc", - "-v", - "ON_ERROR_STOP=1", - "-d", - conninfo, - "-f", - fname, - ] - ) + subprocess.check_call( + [ + "psql", + "--quiet", + "--no-psqlrc", + "-v", + "ON_ERROR_STOP=1", + "-d", + conninfo, + "-f", + fname, + ] + ) - def db_reset(self): + def db_reset(self) -> None: """Truncate tables (all but self.no_truncate_tables set) and sequences """ @@ -111,36 +115,54 @@ cur.execute("ALTER SEQUENCE %s RESTART;" % seq) cnx.commit() - def init(self): - """Initialize db. Create the db if it does not exist. 
Reset it if it exists.""" - with self.cursor() as cur: - cur.execute( - "SELECT COUNT(1) FROM pg_database WHERE datname=%s;", (self.dbname,) - ) - db_exists = cur.fetchone()[0] == 1 - if db_exists: - cur.execute( - "UPDATE pg_database SET datallowconn=true WHERE datname = %s;", - (self.dbname,), - ) - self.db_reset() - return + def _db_exists(self, cur, dbname): + cur.execute( + "SELECT EXISTS " + "(SELECT datname FROM pg_catalog.pg_database WHERE datname= %s);", + (dbname,), + ) + row = cur.fetchone() + return (row is not None) and row[0] - # initialize the inexistent db + def init(self) -> None: + """Create database in postgresql out of a template it if it exists, bare + creation otherwise.""" + template_name = f"{self.dbname}_tmpl" + logger.debug("Initialize DB %s", self.dbname) with self.cursor() as cur: - cur.execute('CREATE DATABASE "{}";'.format(self.dbname)) - self.db_setup() - - def drop(self): - """The original DatabaseJanitor implementation prevents new connections from happening, - destroys current opened connections and finally drops the database. - - We actually do not want to drop the db so we instead do nothing and resets - (truncate most tables and sequences) the db instead, in order to have some - acceptable performance. + tmpl_exists = self._db_exists(cur, template_name) + db_exists = self._db_exists(cur, self.dbname) + if not db_exists: + if tmpl_exists: + logger.debug( + "Create %s from template %s", self.dbname, template_name + ) + cur.execute( + f'CREATE DATABASE "{self.dbname}" TEMPLATE "{template_name}";' + ) + else: + logger.debug("Create %s from scratch", self.dbname) + cur.execute(f'CREATE DATABASE "{self.dbname}";') + if self.dump_files: + logger.warning( + "Using dump_files on the postgresql_fact fixture " + "is deprecated. See swh.core documentation for more " + "details." 
+ ) + for dump_file in gen_dump_files(self.dump_files): + logger.info(f"Loading {dump_file}") + self.psql_exec(dump_file) + else: + logger.debug("Reset %s", self.dbname) + self.db_reset() - """ - pass + def drop(self) -> None: + """Drop database in postgresql.""" + if self.no_db_drop: + with self.cursor() as cur: + self._terminate_connection(cur, self.dbname) + else: + super().drop() # the postgres_fact factory fixture below is mostly a copy of the code @@ -149,42 +171,112 @@ def postgresql_fact( process_fixture_name: str, dbname: Optional[str] = None, - dump_files: Union[str, List[str]] = "", + load: Optional[Sequence[Union[Callable, str]]] = None, + isolation_level: Optional[int] = None, + modname: Optional[str] = None, + dump_files: Optional[Union[str, List[str]]] = None, no_truncate_tables: Set[str] = {"dbversion"}, -): + no_db_drop: bool = False, +) -> Callable[[FixtureRequest], Iterator[connection]]: + """ + Return connection fixture factory for PostgreSQL. + + :param process_fixture_name: name of the process fixture + :param dbname: database name + :param load: SQL, function or function import paths to automatically load + into our test database + :param isolation_level: optional postgresql isolation level + defaults to server's default + :param modname: (swh) module name for which the database is created + :dump_files: (deprecated, use load instead) list of sql script files to + execute after the database has been created + :no_truncate_tables: list of table not to truncate between tests (only used + when no_db_drop is True) + :no_db_drop: if True, keep the database between tests; in which case, the + database is reset (see SWHDatabaseJanitor.db_reset()) by truncating + most of the tables. Note that this makes de facto tests (potentially) + interdependent, use with extra caution. + :returns: function which makes a connection to postgresql + """ + @pytest.fixture - def postgresql_factory(request: FixtureRequest): - """Fixture factory for PostgreSQL. 
+ def postgresql_factory(request: FixtureRequest) -> Iterator[connection]: + """ + Fixture factory for PostgreSQL. - :param FixtureRequest request: fixture request object - :rtype: psycopg2.connection + :param request: fixture request object :returns: postgresql client """ - config = _pytest_postgresql_get_config(request) - proc_fixture = request.getfixturevalue(process_fixture_name) + check_for_psycopg2() + proc_fixture: Union[PostgreSQLExecutor, NoopExecutor] = request.getfixturevalue( + process_fixture_name + ) pg_host = proc_fixture.host pg_port = proc_fixture.port pg_user = proc_fixture.user + pg_password = proc_fixture.password pg_options = proc_fixture.options - pg_db = dbname or config["dbname"] + pg_db = dbname or proc_fixture.dbname + pg_load = load or [] + assert pg_db is not None + with SWHDatabaseJanitor( pg_user, pg_host, pg_port, pg_db, proc_fixture.version, + pg_password, + isolation_level=isolation_level, dump_files=dump_files, no_truncate_tables=no_truncate_tables, - ): - connection = psycopg2.connect( + no_db_drop=no_db_drop, + ) as janitor: + db_connection: connection = psycopg2.connect( dbname=pg_db, user=pg_user, + password=pg_password, host=pg_host, port=pg_port, options=pg_options, ) - yield connection - connection.close() + for load_element in pg_load: + janitor.load(load_element) + try: + yield db_connection + finally: + db_connection.close() return postgresql_factory + + +def initialize_database_for_module(modname, version, **kwargs): + conninfo = psycopg2.connect(**kwargs).dsn + init_admin_extensions(modname, conninfo) + populate_database_for_package(modname, conninfo) + try: + swh_set_db_version(conninfo, version) + except psycopg2.errors.UniqueViolation: + logger.warning( + "Version already set by db init scripts. 
" + "This generally means the swh.{modname} package needs to be " + "updated for swh.core>=1.2" + ) + + +def gen_dump_files(dump_files: Union[str, Iterable[str]]) -> Iterator[str]: + """Generate files potentially resolving glob patterns if any + + """ + if isinstance(dump_files, str): + dump_files = [dump_files] + for dump_file in dump_files: + if glob.has_magic(dump_file): + # if the dump_file is a glob pattern one, resolve it + yield from ( + fname for fname in sorted(glob.glob(dump_file), key=basename_sortkey) + ) + else: + # otherwise, just return the filename + yield dump_file diff --git a/swh/core/db/tests/pytest_plugin/test_pytest_plugin.py b/swh/core/db/tests/pytest_plugin/test_pytest_plugin.py --- a/swh/core/db/tests/pytest_plugin/test_pytest_plugin.py +++ b/swh/core/db/tests/pytest_plugin/test_pytest_plugin.py @@ -6,25 +6,31 @@ import glob import os +from pytest_postgresql import factories + from swh.core.db import BaseDb -from swh.core.db.pytest_plugin import postgresql_fact +from swh.core.db.pytest_plugin import gen_dump_files, postgresql_fact SQL_DIR = os.path.join(os.path.dirname(__file__), "data") +test_postgresql_proc = factories.postgresql_proc( + dbname="fun", + load=sorted(glob.glob(f"{SQL_DIR}/*.sql")), # type: ignore[arg-type] + # type ignored because load is typed as Optional[List[...]] instead of an + # Optional[Sequence[...]] in pytest_postgresql<4 +) # db with special policy for tables dbversion and people postgres_fun = postgresql_fact( - "postgresql_proc", - dbname="fun", - dump_files=f"{SQL_DIR}/*.sql", - no_truncate_tables={"dbversion", "people"}, + "test_postgresql_proc", no_db_drop=True, no_truncate_tables={"dbversion", "people"}, ) postgres_fun2 = postgresql_fact( - "postgresql_proc", + "test_postgresql_proc", dbname="fun2", - dump_files=sorted(glob.glob(f"{SQL_DIR}/*.sql")), + load=sorted(glob.glob(f"{SQL_DIR}/*.sql")), no_truncate_tables={"dbversion", "people"}, + no_db_drop=True, ) @@ -109,9 +115,15 @@ dbname="people", 
dump_files=f"{SQL_DIR}/*.sql", no_truncate_tables=set(), + no_db_drop=True, ) +def test_gen_dump_files(): + files = [os.path.basename(fn) for fn in gen_dump_files(f"{SQL_DIR}/*.sql")] + assert files == ["0-schema.sql", "1-data.sql"] + + def test_smoke_test_people_db_up(postgres_people): """'people' db is up and configured diff --git a/swh/core/tests/test_utils.py b/swh/core/tests/test_utils.py --- a/swh/core/tests/test_utils.py +++ b/swh/core/tests/test_utils.py @@ -131,3 +131,8 @@ assert utils.numfile_sortkey("1.sql") == (1, ".sql") assert utils.numfile_sortkey("1") == (1, "") assert utils.numfile_sortkey("toto-01.sql") == (999999, "toto-01.sql") + + +def test_basename_sotkey(): + assert utils.basename_sortkey("00-xxx.sql") == (0, "-xxx.sql") + assert utils.basename_sortkey("path/to/00-xxx.sql") == (0, "-xxx.sql") diff --git a/swh/core/utils.py b/swh/core/utils.py --- a/swh/core/utils.py +++ b/swh/core/utils.py @@ -130,3 +130,8 @@ assert m is not None num, rem = m.groups() return (int(num) if num else 999999, rem) + + +def basename_sortkey(fname: str) -> Tuple[int, str]: + "like numfile_sortkey but on basenames" + return numfile_sortkey(os.path.basename(fname))