diff --git a/swh/core/cli/db.py b/swh/core/cli/db.py --- a/swh/core/cli/db.py +++ b/swh/core/cli/db.py @@ -136,8 +136,11 @@ @click.option( "--flavor", help="Database flavor.", default=None, ) +@click.option( + "--initial-version", help="Database initial version.", default=1, show_default=True +) @click.pass_context -def db_init(ctx, module, dbname, flavor): +def db_init(ctx, module, dbname, flavor, initial_version): """Initialize a database for the Software Heritage . Example:: @@ -155,8 +158,14 @@ swh db init --flavor read_replica -d swh-storage storage """ - from swh.core.db.db_utils import populate_database_for_package + from swh.core.db.db_utils import ( + get_database_info, + import_swhmodule, + populate_database_for_package, + swh_set_db_version, + ) + cfg = None if dbname is None: # use the db cnx from the config file; the expected config entry is the # given module name @@ -174,6 +183,39 @@ initialized, dbversion, dbflavor = populate_database_for_package( module, dbname, flavor ) + if dbversion is None: + if cfg is not None: + # db version has not been populated by sql init scripts (new style), + # let's do it; instantiate the data source to retrieve the current + # (expected) db version + datastore_factory = getattr(import_swhmodule(module), "get_datastore", None) + if datastore_factory: + datastore = datastore_factory(**cfg) + try: + code_version = datastore.get_current_version() + except AttributeError: + logger.warning( + "Datastore %s does not implement the " + "'get_current_version()' method", + datastore, + ) + else: + logger.info( + f"Initializing database version to {code_version} " + f"from the {module} datastore" + ) + swh_set_db_version(dbname, code_version, desc="DB initialization") + + dbversion = get_database_info(dbname)[1] + if dbversion is None: + logger.info( + f"Initializing database version to {initial_version} " + f"from the command line option --initial-version" + ) + swh_set_db_version(dbname, initial_version, desc="DB 
def now():
    """Return the current time as a timezone-aware UTC datetime."""
    return datetime.now(tz=timezone.utc)


def swh_set_db_version(
    db_or_conninfo: Union[str, pgconnection],
    version: int,
    ts: Optional[datetime] = None,
    desc: str = "Work in progress",
) -> None:
    """Record a new version of the database in the ``dbversion`` table.

    Fails if the dbversion table does not exist.

    If the connection to the database cannot be established, the error is
    logged and nothing is written.

    Args:
        db_or_conninfo: A database connection, or a database connection info string
        version: the version to add
        ts: timestamp stored in the ``release`` column; defaults to the
            current UTC time
        desc: text stored in the ``description`` column

    Raises:
        psycopg2.Error: if the insertion fails (for instance when the
            dbversion table is missing); the pending transaction is rolled
            back before the error is propagated.
    """
    try:
        db = connect_to_conninfo(db_or_conninfo)
    except psycopg2.Error:
        logger.exception("Failed to connect to `%s`", db_or_conninfo)
        # Database not initialized
        return None
    if ts is None:
        ts = now()
    query = (
        "insert into dbversion(version, release, description) values (%s, %s, %s)"
    )
    try:
        with db.cursor() as c:
            c.execute(query, (version, ts, desc))
        db.commit()
    except psycopg2.Error:
        # The connection may have been supplied by the caller: do not hand it
        # back stuck in an aborted transaction.
        db.rollback()
        raise
diff --git a/swh/core/db/tests/data/cli/30-schema.sql b/swh/core/db/tests/data/cli/30-schema.sql --- a/swh/core/db/tests/data/cli/30-schema.sql +++ b/swh/core/db/tests/data/cli/30-schema.sql @@ -1,5 +1,5 @@ -- schema version table which won't get truncated -create table if not exists dbversion ( +create table dbversion ( version int primary key, release timestamptz, description text diff --git a/swh/core/db/tests/data/cli/50-data.sql b/swh/core/db/tests/data/cli/50-data.sql --- a/swh/core/db/tests/data/cli/50-data.sql +++ b/swh/core/db/tests/data/cli/50-data.sql @@ -1,5 +1,5 @@ insert into dbversion(version, release, description) -values (1, '2016-02-22 15:56:28.358587+00', 'Work In Progress'); +values (10, '2016-02-22 15:56:28.358587+00', 'Work In Progress'); insert into origin(url, hash) values ('https://forge.softwareheritage.org', hash_sha1('https://forge.softwareheritage.org')); diff --git a/swh/core/db/tests/test_cli.py b/swh/core/db/tests/test_cli.py --- a/swh/core/db/tests/test_cli.py +++ b/swh/core/db/tests/test_cli.py @@ -4,20 +4,14 @@ # See top-level LICENSE file for more information import copy -from datetime import datetime, timezone import pytest from swh.core.cli.db import db as swhdb from swh.core.db import BaseDb -from swh.core.db.pytest_plugin import postgresql_fact from swh.core.tests.test_cli import assert_section_contains -def now(): - return datetime.now(tz=timezone.utc) - - def test_cli_swh_help(swhmain, cli_runner): swhmain.add_command(swhdb) result = cli_runner.invoke(swhmain, ["-h"]) @@ -55,19 +49,13 @@ assert_section_contains(result.output, section, snippet) -# We do not want the truncate behavior for those tests -test_db = postgresql_fact( - "postgresql_proc", dbname="clidb", no_truncate_tables={"dbversion", "origin"} -) - - @pytest.fixture -def swh_db_cli(cli_runner, monkeypatch, test_db): +def swh_db_cli(cli_runner, monkeypatch, postgresql): """This initializes a cli_runner and sets the correct environment variable expected by the cli 
to run appropriately (when not specifying the --dbname flag) """ - db_params = test_db.get_dsn_parameters() + db_params = postgresql.get_dsn_parameters() monkeypatch.setenv("PGHOST", db_params["host"]) monkeypatch.setenv("PGUSER", db_params["user"]) monkeypatch.setenv("PGPORT", db_params["port"]) @@ -87,13 +75,13 @@ return "postgresql://{user}@{host}:{port}/{dbname}".format(**params) -def test_cli_swh_db_create_and_init_db(cli_runner, test_db, mock_package_sql): +def test_cli_swh_db_create_and_init_db(cli_runner, postgresql, mock_package_sql): """Create a db then initializing it should be ok """ module_name = "test.cli" - conninfo = craft_conninfo(test_db, "new-db") + conninfo = craft_conninfo(postgresql, "new-db") # This creates the db and installs the necessary admin extensions result = cli_runner.invoke(swhdb, ["create", module_name, "--dbname", conninfo]) assert result.exit_code == 0, f"Unexpected output: {result.output}" @@ -112,13 +100,13 @@ def test_cli_swh_db_initialization_fail_without_creation_first( - cli_runner, test_db, mock_package_sql + cli_runner, postgresql, mock_package_sql ): """Init command on an inexisting db cannot work """ module_name = "test.cli" # it's mocked here - conninfo = craft_conninfo(test_db, "inexisting-db") + conninfo = craft_conninfo(postgresql, "inexisting-db") result = cli_runner.invoke(swhdb, ["init", module_name, "--dbname", conninfo]) # Fails because we cannot connect to an inexisting db @@ -126,7 +114,7 @@ def test_cli_swh_db_initialization_fail_without_extension( - cli_runner, test_db, mock_package_sql + cli_runner, postgresql, mock_package_sql ): """Init command cannot work without privileged extension. 
@@ -134,7 +122,7 @@ """ module_name = "test.cli" # it's mocked here - conninfo = craft_conninfo(test_db) + conninfo = craft_conninfo(postgresql) result = cli_runner.invoke(swhdb, ["init", module_name, "--dbname", conninfo]) # Fails as the function `public.digest` is not installed, init-admin calls is needed @@ -143,13 +131,13 @@ def test_cli_swh_db_initialization_works_with_flags( - cli_runner, test_db, mock_package_sql + cli_runner, postgresql, mock_package_sql ): """Init commands with carefully crafted libpq conninfo works """ module_name = "test.cli" # it's mocked here - conninfo = craft_conninfo(test_db) + conninfo = craft_conninfo(postgresql) result = cli_runner.invoke(swhdb, ["init-admin", module_name, "--dbname", conninfo]) assert result.exit_code == 0, f"Unexpected output: {result.output}" @@ -159,13 +147,13 @@ assert result.exit_code == 0, f"Unexpected output: {result.output}" # the origin values in the scripts uses a hash function (which implementation wise # uses a function from the pgcrypt extension, init-admin calls installs it) - with BaseDb.connect(test_db.dsn).cursor() as cur: + with BaseDb.connect(postgresql.dsn).cursor() as cur: cur.execute("select * from origin") origins = cur.fetchall() assert len(origins) == 1 -def test_cli_swh_db_initialization_with_env(swh_db_cli, mock_package_sql, test_db): +def test_cli_swh_db_initialization_with_env(swh_db_cli, mock_package_sql, postgresql): """Init commands with standard environment variables works """ @@ -183,13 +171,13 @@ assert result.exit_code == 0, f"Unexpected output: {result.output}" # the origin values in the scripts uses a hash function (which implementation wise # uses a function from the pgcrypt extension, init-admin calls installs it) - with BaseDb.connect(test_db.dsn).cursor() as cur: + with BaseDb.connect(postgresql.dsn).cursor() as cur: cur.execute("select * from origin") origins = cur.fetchall() assert len(origins) == 1 -def test_cli_swh_db_initialization_idempotent(swh_db_cli, 
def test_cli_swh_db_create_and_init_db_new_api(
    cli_runner, postgresql, mock_package_sql, mocker, tmp_path
):
    """Create a db then initializing it should be ok for a "new style" datastore

    The module lookup is mocked so that any ``test.*`` module exposes a
    ``get_datastore()`` factory whose datastore reports version 42 through
    ``get_current_version()``. The ``init`` command is run from a
    configuration file rather than with ``--dbname``.

    """
    import traceback
    from unittest.mock import MagicMock

    # Grab the real implementation *before* patching, so the mock can delegate
    # to it for modules outside of the "test." namespace.
    from swh.core.db.db_utils import import_swhmodule

    module_name = "test.cli_new"

    def import_swhmodule_mock(modname):
        if modname.startswith("test."):

            def get_datastore(cls, **kw):
                # XXX probably not the best way of doing this...
                return MagicMock(get_current_version=lambda: 42)

            return MagicMock(name=modname, get_datastore=get_datastore)

        return import_swhmodule(modname)

    mocker.patch("swh.core.db.db_utils.import_swhmodule", import_swhmodule_mock)
    conninfo = craft_conninfo(postgresql)

    # This initializes the schema and data
    cfgfile = tmp_path / "config.yml"
    cfgfile.write_text(
        f"""
{module_name}:
  cls: postgresql
  db: {conninfo}
"""
    )
    result = cli_runner.invoke(swhdb, ["init-admin", module_name, "--dbname", conninfo])
    assert result.exit_code == 0, f"Unexpected output: {result.output}"
    # Command-line arguments are strings: pass the config path as such
    result = cli_runner.invoke(swhdb, ["-C", str(cfgfile), "init", module_name])

    # traceback.print_tb() returns None (it only writes to stderr), so build the
    # failure message from the formatted exception instead.
    failure_details = result.output
    if result.exc_info is not None:
        failure_details += "".join(traceback.format_exception(*result.exc_info))
    assert result.exit_code == 0, f"Unexpected output: {failure_details}"

    # the origin value in the scripts uses a hash function (which implementation wise
    # uses a function from the pgcrypt extension, installed during db creation step)
    with BaseDb.connect(conninfo).cursor() as cur:
        cur.execute("select * from origin")
        origins = cur.fetchall()
        assert len(origins) == 1
["init", module, "--dbname", conninfo, "--initial-version", 10] + ) assert result.exit_code == 0, f"Unexpected output: {result.output}" + # check the swh_db_module() function + assert swh_db_module(conninfo) == module + # the dbversion and dbmodule tables exists and are populated dbmodule, dbversion, dbflavor = get_database_info(conninfo) + # check also the swh_db_versions() function + versions = swh_db_versions(conninfo) + nversions = len(versions) + assert dbmodule == module - if module == "test.cli": - # old style: backend init script set the db version - assert dbversion == 1 - else: - # new style: they do not (but we do not have support for this in swh.core yet) - assert dbversion is None + assert dbversion == 10 assert dbflavor is None - - # check also the swh_db_module() function - assert swh_db_module(conninfo) == module - # check also the swh_db_versions() function versions = swh_db_versions(conninfo) + assert len(versions) == 1 + assert versions[0][0] == 10 if module == "test.cli": - assert len(versions) == 1 - assert versions[0][0] == 1 - assert versions[0][1] == datetime.fromisoformat("2016-02-22T15:56:28.358587+00:00") + assert versions[0][1] == datetime.fromisoformat( + "2016-02-22T15:56:28.358587+00:00" + ) assert versions[0][2] == "Work In Progress" else: - assert not versions + # new scheme but with no datastore (so no version support from there) + assert versions[0][2] == "DB initialization" + # add a few versions in dbversion cnx = BaseDb.connect(conninfo) with cnx.transaction() as cur: - if module == "test.cli_new": - # add version 1 to make it simpler for checkes below - cur.execute( - "insert into dbversion(version, release, description) values(1, NOW(), 'Wotk in progress')" - ) cur.executemany( "insert into dbversion(version, release, description) values (%s, %s, %s)", - [(i, now(), f"Upgrade to version {i}") for i in range(2, 6)], + [(i, now(), f"Upgrade to version {i}") for i in range(11, 15)], ) dbmodule, dbversion, dbflavor = 
get_database_info(conninfo) assert dbmodule == module - assert dbversion == 5 + assert dbversion == 14 assert dbflavor is None versions = swh_db_versions(conninfo) assert len(versions) == 5 for i, (version, ts, desc) in enumerate(versions): - assert version == (5 - i) # these are in reverse order - if version > 1: + assert version == (14 - i) # these are in reverse order + if version > 10: assert desc == f"Upgrade to version {version}" assert (now() - ts) < timedelta(seconds=1)