db_utils: return pathlib.Path objects from get_sql_for_package

Replace glob/os.path string handling of SQL script files with pathlib.
execute_sqlfiles() accepts both str and pathlib.Path entries, so existing
callers passing plain string paths keep working.
populate_database_for_package() still orders files by basename (with the
.sql extension), consistent with the ordering used by get_sql_for_package().

diff --git a/swh/core/db/db_utils.py b/swh/core/db/db_utils.py
--- a/swh/core/db/db_utils.py
+++ b/swh/core/db/db_utils.py
@@ -5,10 +5,10 @@
 
 from datetime import datetime, timezone
 import functools
-import glob
 from importlib import import_module
 import logging
 from os import path
+import pathlib
 import re
 import subprocess
 from typing import Collection, Dict, List, Optional, Tuple, Union
@@ -181,7 +181,7 @@
     sqlfiles = [
         fname
         for fname in get_sql_for_package(modname, upgrade=True)
-        if db_version < int(path.splitext(path.basename(fname))[0]) <= to_version
+        if db_version < int(fname.stem) <= to_version
     ]
 
     for sqlfile in sqlfiles:
@@ -293,7 +293,9 @@
         return None
 
     sqlfiles = [
-        fname for fname in get_sql_for_package("swh.core.db") if "dbmodule" in fname
+        fname
+        for fname in get_sql_for_package("swh.core.db")
+        if "dbmodule" in fname.stem
     ]
     execute_sqlfiles(sqlfiles, db_or_conninfo)
 
@@ -496,7 +498,7 @@
     return m
 
 
-def get_sql_for_package(modname: str, upgrade: bool = False) -> List[str]:
+def get_sql_for_package(modname: str, upgrade: bool = False) -> List[pathlib.Path]:
     """Return the (sorted) list of sql script files for the given swh module
 
     If upgrade is True, return the list of available migration scripts,
@@ -505,14 +507,15 @@
     m = import_swhmodule(modname)
     if m is None:
         raise ValueError(f"Module {modname} cannot be loaded")
-    sqldir = path.join(path.dirname(m.__file__), "sql")
+
+    sqldir = pathlib.Path(m.__file__).parent / "sql"
     if upgrade:
-        sqldir += "/upgrades"
-    if not path.isdir(sqldir):
+        sqldir /= "upgrades"
+    if not sqldir.is_dir():
         raise ValueError(
-            "Module {} does not provide a db schema " "(no sql/ dir)".format(modname)
+            "Module {} does not provide a db schema (no sql/ dir)".format(modname)
         )
-    return sorted(glob.glob(path.join(sqldir, "*.sql")), key=sortkey)
+    return sorted(sqldir.glob("*.sql"), key=lambda x: sortkey(x.name))
 
 
 def populate_database_for_package(
@@ -541,8 +544,8 @@
         return sortkey(path.basename(key))
 
     sqlfiles = get_sql_for_package(modname) + get_sql_for_package("swh.core.db")
     sqlfiles = sorted(sqlfiles, key=globalsortkey)
-    sqlfiles = [fname for fname in sqlfiles if "-superuser-" not in fname]
+    sqlfiles = [fpath for fpath in sqlfiles if "-superuser-" not in fpath.stem]
     execute_sqlfiles(sqlfiles, conninfo, flavor)
 
     # populate the dbmodule table
@@ -581,7 +584,7 @@
     """
 
     sqlfiles = get_sql_for_package(modname)
-    sqlfiles = [fname for fname in sqlfiles if "-superuser-" in fname]
+    sqlfiles = [fname for fname in sqlfiles if "-superuser-" in fname.stem]
     execute_sqlfiles(sqlfiles, conninfo)
 
 
@@ -621,7 +624,9 @@
 
 
 def execute_sqlfiles(
-    sqlfiles: Collection[str], conninfo: str, flavor: Optional[str] = None
+    sqlfiles: Collection[Union[str, pathlib.Path]],
+    conninfo: str,
+    flavor: Optional[str] = None,
 ):
     """Execute a list of SQL files on the database pointed at with ``conninfo``.
 
@@ -643,9 +648,14 @@
     flavor_set = False
-    for sqlfile in sqlfiles:
+    # callers may pass str or pathlib.Path entries: normalize them to Path
+    for sqlfile in map(pathlib.Path, sqlfiles):
         logger.debug(f"execute SQL file {sqlfile} dbname={conninfo}")
-        subprocess.check_call(psql_command + ["-f", sqlfile])
+        subprocess.check_call(psql_command + ["-f", str(sqlfile)])
 
-        if flavor is not None and not flavor_set and sqlfile.endswith("-flavor.sql"):
+        if (
+            flavor is not None
+            and not flavor_set
+            and sqlfile.name.endswith("-flavor.sql")
+        ):
             logger.debug("Setting database flavor %s", flavor)
             query = f"insert into dbflavor (flavor) values ('{flavor}')"
             subprocess.check_call(psql_command + ["-c", query])
diff --git a/swh/core/db/tests/conftest.py b/swh/core/db/tests/conftest.py
--- a/swh/core/db/tests/conftest.py
+++ b/swh/core/db/tests/conftest.py
@@ -3,8 +3,8 @@
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
-import glob
 import os
+import pathlib
 
 from click.testing import CliRunner
 from hypothesis import HealthCheck
@@ -45,12 +45,10 @@
 
     def get_sql_for_package_mock(modname, upgrade=False):
         if modname.startswith("test."):
-            sqldir = modname.split(".", 1)[1]
+            sqldir = pathlib.Path(datadir) / modname.split(".", 1)[1]
             if upgrade:
-                sqldir += "/upgrades"
-            return sorted(
-                glob.glob(os.path.join(datadir, sqldir, "*.sql")), key=sortkey
-            )
+                sqldir /= "upgrades"
+            return sorted(sqldir.glob("*.sql"), key=lambda x: sortkey(x.name))
         return get_sql_for_package(modname)
 
     mock_sql_files = mocker.patch(