diff --git a/swh/core/db/pytest_plugin.py b/swh/core/db/pytest_plugin.py --- a/swh/core/db/pytest_plugin.py +++ b/swh/core/db/pytest_plugin.py @@ -4,6 +4,7 @@ # See top-level LICENSE file for more information import glob +from importlib import import_module import logging import subprocess from typing import List, Optional, Set, Union @@ -11,15 +12,19 @@ from _pytest.fixtures import FixtureRequest import psycopg2 import pytest -from pytest_postgresql.janitor import DatabaseJanitor, Version +from pytest_postgresql.janitor import DatabaseJanitor +from swh.core.utils import numfile_sortkey as sortkey + +# import dynamically to keep mypy happy regardless of the pytest-postgresql version try: - from pytest_postgresql.config import get_config as _pytest_postgresql_get_config + _pytest_pgsql_get_config_module = import_module("pytest_postgresql.config") except ImportError: # pytest_postgresql < 3.0.0 - from pytest_postgresql.factories import get_config as _pytest_postgresql_get_config + _pytest_pgsql_get_config_module = import_module("pytest_postgresql.factories") + +_pytest_postgresql_get_config = getattr(_pytest_pgsql_get_config_module, "get_config") -from swh.core.utils import numfile_sortkey as sortkey logger = logging.getLogger(__name__) @@ -40,14 +45,14 @@ host: str, port: str, dbname: str, - version: Union[str, float, Version], + version: Union[str, float], dump_files: Union[None, str, List[str]] = None, no_truncate_tables: Set[str] = set(), ) -> None: super().__init__(user, host, port, dbname, version) - if not hasattr(self, "dbname"): + if not hasattr(self, "dbname") and hasattr(self, "db_name"): # pytest_postgresql < 3.0.0 - self.dbname = self.db_name + self.dbname = getattr(self, "db_name") if dump_files is None: self.dump_files = [] elif isinstance(dump_files, str):