diff --git a/swh/storage/tests/conftest.py b/swh/storage/tests/conftest.py
--- a/swh/storage/tests/conftest.py
+++ b/swh/storage/tests/conftest.py
@@ -6,8 +6,10 @@
 import glob
 import pytest
 
+from typing import Union
+
 from pytest_postgresql import factories
-from pytest_postgresql.janitor import DatabaseJanitor, psycopg2
+from pytest_postgresql.janitor import DatabaseJanitor, psycopg2, Version
 from os import path, environ
 from hypothesis import settings
 
@@ -72,7 +74,7 @@
 # the postgres_fact factory fixture below is mostly a copy of the code
 # from pytest-postgresql. We need a custom version here to be able to
 # specify our version of the DBJanitor we use.
-def postgresql_fact(process_fixture_name, db_name=None):
+def postgresql_fact(process_fixture_name, db_name=None, dump_files=DUMP_FILES):
     @pytest.fixture
     def postgresql_factory(request):
         """
@@ -95,9 +97,9 @@
         pg_user = proc_fixture.user
         pg_options = proc_fixture.options
         pg_db = db_name or config['dbname']
-
         with SwhDatabaseJanitor(
-                pg_user, pg_host, pg_port, pg_db, proc_fixture.version
+                pg_user, pg_host, pg_port, pg_db, proc_fixture.version,
+                dump_files=dump_files
         ):
             connection = psycopg2.connect(
                 dbname=pg_db,
@@ -121,6 +123,24 @@
 # once, then it truncate the tables. This is needed to have acceptable test
 # performances.
 class SwhDatabaseJanitor(DatabaseJanitor):
+    def __init__(
+        self,
+        user: str,
+        host: str,
+        port: str,
+        db_name: str,
+        version: Union[str, float, Version],
+        dump_files: str = DUMP_FILES
+    ) -> None:
+        """Create a janitor loading the SQL files matching ``dump_files``.
+
+        :param dump_files: glob pattern selecting the SQL files executed by
+            :meth:`db_setup`, in ``sortkey`` order
+        """
+        super().__init__(user, host, port, db_name, version)
+        self.dump_files = sorted(
+            glob.glob(dump_files), key=sortkey)
+
     def db_setup(self):
         with psycopg2.connect(
             dbname=self.db_name,
@@ -129,12 +149,15 @@
             port=self.port,
         ) as cnx:
             with cnx.cursor() as cur:
-                all_dump_files = sorted(
-                    glob.glob(DUMP_FILES), key=sortkey)
-                for fname in all_dump_files:
+                for fname in self.dump_files:
                     with open(fname) as fobj:
-                        sql = fobj.read().replace('concurrently', '')
-                        cur.execute(sql)
+                        # CREATE INDEX CONCURRENTLY cannot run inside a
+                        # transaction block, which psycopg2 opens by default.
+                        sql = fobj.read().replace('concurrently', '').strip()
+                        # psycopg2 refuses to execute an empty query, so skip
+                        # files that are empty or whitespace-only.
+                        if sql:
+                            cur.execute(sql)
                     cnx.commit()
 
     def db_reset(self):