diff --git a/swh/core/db/pytest_plugin.py b/swh/core/db/pytest_plugin.py
--- a/swh/core/db/pytest_plugin.py
+++ b/swh/core/db/pytest_plugin.py
@@ -6,7 +6,7 @@
 import glob
 import logging
 import subprocess
-from typing import Optional, Set, Union
+from typing import List, Optional, Set, Union
 
 from _pytest.fixtures import FixtureRequest
 import psycopg2
@@ -36,14 +36,17 @@
         port: str,
         db_name: str,
         version: Union[str, float, Version],
-        dump_files: Optional[str] = None,
+        dump_files: Union[None, str, List[str]] = None,
         no_truncate_tables: Set[str] = set(),
     ) -> None:
         super().__init__(user, host, port, db_name, version)
-        if dump_files:
+        if dump_files is None:
+            self.dump_files = []
+        elif isinstance(dump_files, str):
             self.dump_files = sorted(glob.glob(dump_files), key=sortkey)
         else:
-            self.dump_files = []
+            # explicit list: copied, kept in the caller's order, not glob-expanded
+            self.dump_files = list(dump_files)
 
         # do no truncate the following tables
         self.no_truncate_tables = set(no_truncate_tables)
@@ -134,7 +137,7 @@
 def postgresql_fact(
     process_fixture_name: str,
     db_name: Optional[str] = None,
-    dump_files: str = "",
+    dump_files: Union[str, List[str]] = "",
     no_truncate_tables: Set[str] = {"dbversion"},
 ):
     @pytest.fixture
diff --git a/swh/core/db/tests/pytest_plugin/test_pytest_plugin.py b/swh/core/db/tests/pytest_plugin/test_pytest_plugin.py
--- a/swh/core/db/tests/pytest_plugin/test_pytest_plugin.py
+++ b/swh/core/db/tests/pytest_plugin/test_pytest_plugin.py
@@ -3,6 +3,7 @@
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
+import glob
 import os
 
 from swh.core.db import BaseDb
@@ -19,6 +20,13 @@
     no_truncate_tables={"dbversion", "people"},
 )
 
+postgres_fun2 = postgresql_fact(
+    "postgresql_proc",
+    db_name="fun2",
+    dump_files=sorted(glob.glob(f"{SQL_DIR}/*.sql")),  # glob order is arbitrary
+    no_truncate_tables={"dbversion", "people"},
+)
+
 
 def test_smoke_test_fun_db_is_up(postgres_fun):
     """This ensures the db is created and configured according to its dumps files.
@@ -43,6 +51,29 @@
         assert val == 2
 
 
+def test_smoke_test_fun2_db_is_up(postgres_fun2):
+    """This ensures the 'fun2' db is created and configured from dump files
+    passed as an explicit list of paths rather than a glob pattern string.
+    """
+    with BaseDb.connect(postgres_fun2.dsn).cursor() as cur:
+        cur.execute("select count(*) from dbversion")
+        nb_rows = cur.fetchone()[0]
+        assert nb_rows == 5
+
+        cur.execute("select count(*) from fun")
+        nb_rows = cur.fetchone()[0]
+        assert nb_rows == 3
+
+        cur.execute("select count(*) from people")
+        nb_rows = cur.fetchone()[0]
+        assert nb_rows == 2
+
+        # in data, we requested a value already so it starts at 2
+        cur.execute("select nextval('serial')")
+        val = cur.fetchone()[0]
+        assert val == 2
+
+
 def test_smoke_test_fun_db_is_still_up_and_got_reset(postgres_fun):
     """This ensures that within another tests, the 'fun' db is still up, created (and not
     configured again). This time, most of the data has been reset: