diff --git a/swh/core/cli/db.py b/swh/core/cli/db.py
--- a/swh/core/cli/db.py
+++ b/swh/core/cli/db.py
@@ -136,8 +136,11 @@
 @click.option(
     "--flavor", help="Database flavor.", default=None,
 )
+@click.option(
+    "--initial-version", help="Database initial version.", default=1, show_default=True
+)
 @click.pass_context
-def db_init(ctx, module, dbname, flavor):
+def db_init(ctx, module, dbname, flavor, initial_version):
     """Initialize a database for the Software Heritage <module>.
 
     The database connection string comes from the configuration file (see
@@ -160,8 +163,14 @@
     '--db-name' option, but this usage is about to be deprecated.
 
     """
-    from swh.core.db.db_utils import populate_database_for_package
+    from swh.core.db.db_utils import (
+        get_database_info,
+        import_swhmodule,
+        populate_database_for_package,
+        swh_set_db_version,
+    )
 
+    cfg = None
     if dbname is None:
         # use the db cnx from the config file; the expected config entry is the
         # given module name
@@ -179,6 +188,42 @@
     initialized, dbversion, dbflavor = populate_database_for_package(
         module, dbname, flavor
     )
+    if dbversion is None:
+        if cfg is not None:
+            # db version has not been populated by sql init scripts (new style),
+            # let's do it; instantiate the data source to retrieve the current
+            # (expected) db version
+            datastore_factory = getattr(import_swhmodule(module), "get_datastore", None)
+            if datastore_factory:
+                datastore = datastore_factory(**cfg)
+                try:
+                    get_current_version = datastore.get_current_version
+                except AttributeError:
+                    logger.warning(
+                        "Datastore %s does not implement the "
+                        "'get_current_version()' method",
+                        datastore,
+                    )
+                else:
+                    code_version = get_current_version()
+                    logger.info(
+                        "Initializing database version to %s from the %s datastore",
+                        code_version,
+                        module,
+                    )
+                    swh_set_db_version(dbname, code_version, desc="DB initialization")
+
+        dbversion = get_database_info(dbname)[1]
+        if dbversion is None:
+            logger.info(
+                "Initializing database version to %s "
+                "from the command line option --initial-version",
+                initial_version,
+            )
+            swh_set_db_version(dbname, initial_version, desc="DB initialization")
+
+    dbversion = get_database_info(dbname)[1]
+    assert dbversion is not None
 
     # TODO: Ideally migrate the version from db_version to the latest
     # db version
diff --git a/swh/core/db/db_utils.py b/swh/core/db/db_utils.py
--- a/swh/core/db/db_utils.py
+++ b/swh/core/db/db_utils.py
@@ -3,7 +3,7 @@
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
-from datetime import datetime
+from datetime import datetime, timezone
 import functools
 import glob
 from importlib import import_module
@@ -25,6 +25,10 @@
 logger = logging.getLogger(__name__)
 
 
+def now():
+    return datetime.now(tz=timezone.utc)
+
+
 def stored_procedure(stored_proc):
     """decorator to execute remote stored procedure, specified as argument
 
@@ -204,6 +208,42 @@
     db.commit()
 
 
+def swh_set_db_version(
+    db_or_conninfo: Union[str, pgconnection],
+    version: int,
+    ts: Optional[datetime] = None,
+    desc: str = "Work in progress",
+) -> None:
+    """Set the version of the database.
+
+    Fails if the dbversion table does not exist. If the connection to the
+    database cannot be established, the error is logged and nothing is recorded.
+
+    Args:
+        db_or_conninfo: A database connection, or a database connection info string
+        version: the version to add
+        ts: the release timestamp to record (defaults to the current UTC time)
+        desc: the description to record along with this version
+    """
+    try:
+        db = connect_to_conninfo(db_or_conninfo)
+    except psycopg2.Error:
+        logger.exception("Failed to connect to `%s`", db_or_conninfo)
+        # Database not initialized
+        return None
+    if ts is None:
+        ts = now()
+    query = "insert into dbversion(version, release, description) values (%s, %s, %s)"
+    try:
+        with db.cursor() as c:
+            c.execute(query, (version, ts, desc))
+        db.commit()
+    finally:
+        # only close the connection if it was opened here, from a conninfo string
+        if isinstance(db_or_conninfo, str):
+            db.close()
+
+
 def swh_db_flavor(db_or_conninfo: Union[str, pgconnection]) -> Optional[str]:
     """Retrieve the swh flavor of the database.
 
diff --git a/swh/core/db/tests/data/cli/50-data.sql b/swh/core/db/tests/data/cli/50-data.sql
--- a/swh/core/db/tests/data/cli/50-data.sql
+++ b/swh/core/db/tests/data/cli/50-data.sql
@@ -1,5 +1,5 @@
 insert into dbversion(version, release, description)
-values (1, '2016-02-22 15:56:28.358587+00', 'Work In Progress');
+values (10, '2016-02-22 15:56:28.358587+00', 'Work In Progress');
 
 insert into origin(url, hash)
 values ('https://forge.softwareheritage.org', hash_sha1('https://forge.softwareheritage.org'));
diff --git a/swh/core/db/tests/data/cli_new/0-superuser-init.sql b/swh/core/db/tests/data/cli_new/0-superuser-init.sql
new file mode 100644
--- /dev/null
+++ b/swh/core/db/tests/data/cli_new/0-superuser-init.sql
@@ -0,0 +1 @@
+create extension if not exists pgcrypto;
diff --git a/swh/core/db/tests/data/cli_new/30-schema.sql b/swh/core/db/tests/data/cli_new/30-schema.sql
new file mode 100644
--- /dev/null
+++ b/swh/core/db/tests/data/cli_new/30-schema.sql
@@ -0,0 +1,6 @@
+-- origin table
+create table if not exists origin (
+  id bigserial not null,
+  url text not null,
+  hash text not null
+);
diff --git a/swh/core/db/tests/data/cli_new/40-funcs.sql b/swh/core/db/tests/data/cli_new/40-funcs.sql
new file mode 100644
--- /dev/null
+++ b/swh/core/db/tests/data/cli_new/40-funcs.sql
@@ -0,0 +1,6 @@
+create or replace function hash_sha1(text)
+    returns text
+    language sql strict immutable
+as $$
+    select encode(public.digest($1, 'sha1'), 'hex')
+$$;
diff --git a/swh/core/db/tests/data/cli/50-data.sql b/swh/core/db/tests/data/cli_new/50-data.sql
copy from swh/core/db/tests/data/cli/50-data.sql
copy to swh/core/db/tests/data/cli_new/50-data.sql
--- a/swh/core/db/tests/data/cli/50-data.sql
+++ b/swh/core/db/tests/data/cli_new/50-data.sql
@@ -1,5 +1,2 @@
-insert into dbversion(version, release, description)
-values (1, '2016-02-22 15:56:28.358587+00', 'Work In Progress');
-
 insert into origin(url, hash)
 values ('https://forge.softwareheritage.org', hash_sha1('https://forge.softwareheritage.org'));
diff --git a/swh/core/db/tests/test_cli.py b/swh/core/db/tests/test_cli.py
--- a/swh/core/db/tests/test_cli.py
+++ b/swh/core/db/tests/test_cli.py
@@ -4,19 +4,18 @@
 # See top-level LICENSE file for more information
 
 import copy
-from datetime import datetime, timezone
+import traceback
+from unittest.mock import MagicMock
 
 import pytest
+import yaml
 
 from swh.core.cli.db import db as swhdb
 from swh.core.db import BaseDb
+from swh.core.db.db_utils import get_database_info, import_swhmodule
 from swh.core.tests.test_cli import assert_section_contains
 
 
-def now():
-    return datetime.now(tz=timezone.utc)
-
-
 def test_cli_swh_help(swhmain, cli_runner):
     swhmain.add_command(swhdb)
     result = cli_runner.invoke(swhmain, ["-h"])
@@ -215,3 +214,49 @@
         cur.execute("select * from origin")
         origins = cur.fetchall()
         assert len(origins) == 1
+
+
+def test_cli_swh_db_create_and_init_db_new_api(
+    cli_runner, postgresql, mock_package_sql, mocker, tmp_path
+):
+    """Create a db then initializing it should be ok for a "new style" datastore
+
+    """
+    module_name = "test.cli_new"
+
+    def import_swhmodule_mock(modname):
+        if modname.startswith("test."):
+
+            def get_datastore(cls, **kw):
+                # XXX probably not the best way of doing this...
+                return MagicMock(get_current_version=lambda: 42)
+
+            return MagicMock(name=modname, get_datastore=get_datastore)
+
+        return import_swhmodule(modname)
+
+    mocker.patch("swh.core.db.db_utils.import_swhmodule", import_swhmodule_mock)
+    conninfo = craft_conninfo(postgresql)
+
+    # This initializes the schema and data
+    cfgfile = tmp_path / "config.yml"
+    cfgfile.write_text(yaml.dump({module_name: {"cls": "postgresql", "db": conninfo}}))
+    result = cli_runner.invoke(swhdb, ["init-admin", module_name, "--dbname", conninfo])
+    assert result.exit_code == 0, f"Unexpected output: {result.output}"
+    result = cli_runner.invoke(swhdb, ["-C", str(cfgfile), "init", module_name])
+    assert result.exit_code == 0, (
+        f"Unexpected output: {result.output}\n"
+        + "".join(traceback.format_exception(*result.exc_info))
+    )
+
+    # the origin value in the scripts uses a hash function (which implementation wise
+    # uses a function from the pgcrypt extension, installed during db creation step)
+    with BaseDb.connect(conninfo).cursor() as cur:
+        cur.execute("select * from origin")
+        origins = cur.fetchall()
+        assert len(origins) == 1
+
+    # the db version must come from the (mocked) datastore's get_current_version()
+    dbmodule, dbversion, _ = get_database_info(conninfo)
+    assert dbmodule == module_name
+    assert dbversion == 42
diff --git a/swh/core/db/tests/test_db_utils.py b/swh/core/db/tests/test_db_utils.py
--- a/swh/core/db/tests/test_db_utils.py
+++ b/swh/core/db/tests/test_db_utils.py
@@ -9,9 +9,9 @@
 
 from swh.core.cli.db import db as swhdb
 from swh.core.db import BaseDb
-from swh.core.db.db_utils import get_database_info, swh_db_module, swh_db_versions
+from swh.core.db.db_utils import get_database_info, now, swh_db_module, swh_db_versions
 
-from .test_cli import craft_conninfo, now
+from .test_cli import craft_conninfo
 
 
 @pytest.mark.parametrize("module", ["test.cli", "test.cli_new"])
@@ -27,57 +27,51 @@
     conninfo = craft_conninfo(postgresql)
     result = cli_runner.invoke(swhdb, ["init-admin", module, "--dbname", conninfo])
     assert result.exit_code == 0, f"Unexpected output: {result.output}"
-    result = cli_runner.invoke(swhdb, ["init", module, "--dbname", conninfo])
+    result = cli_runner.invoke(
+        swhdb, ["init", module, "--dbname", conninfo, "--initial-version", "10"]
+    )
     assert result.exit_code == 0, f"Unexpected output: {result.output}"
 
+    # check the swh_db_module() function
+    assert swh_db_module(conninfo) == module
+
     # the dbversion and dbmodule tables exists and are populated
     dbmodule, dbversion, dbflavor = get_database_info(conninfo)
     assert dbmodule == module
-    if module == "test.cli":
-        # old style: backend init script set the db version
-        assert dbversion == 1
-    else:
-        # new style: they do not (but we do not have support for this in swh.core yet)
-        assert dbversion is None
+    assert dbversion == 10
     assert dbflavor is None
 
-    # check also the swh_db_module() function
-    assert swh_db_module(conninfo) == module
     # check also the swh_db_versions() function
     versions = swh_db_versions(conninfo)
+    assert len(versions) == 1
+    assert versions[0][0] == 10
     if module == "test.cli":
-        assert len(versions) == 1
-        assert versions[0][0] == 1
         assert versions[0][1] == datetime.fromisoformat(
             "2016-02-22T15:56:28.358587+00:00"
         )
         assert versions[0][2] == "Work In Progress"
     else:
-        assert not versions
+        # new scheme but with no datastore (so no version support from there)
+        assert versions[0][2] == "DB initialization"
+
     # add a few versions in dbversion
     cnx = BaseDb.connect(conninfo)
     with cnx.transaction() as cur:
-        if module == "test.cli_new":
-            # add version 1 to make it simpler for checks below
-            cur.execute(
-                "insert into dbversion(version, release, description) "
-                "values(1, NOW(), 'Wotk in progress')"
-            )
         cur.executemany(
             "insert into dbversion(version, release, description) values (%s, %s, %s)",
-            [(i, now(), f"Upgrade to version {i}") for i in range(2, 6)],
+            [(i, now(), f"Upgrade to version {i}") for i in range(11, 15)],
         )
 
     dbmodule, dbversion, dbflavor = get_database_info(conninfo)
     assert dbmodule == module
-    assert dbversion == 5
+    assert dbversion == 14
     assert dbflavor is None
 
     versions = swh_db_versions(conninfo)
     assert len(versions) == 5
     for i, (version, ts, desc) in enumerate(versions):
-        assert version == (5 - i)  # these are in reverse order
-        if version > 1:
+        assert version == (14 - i)  # these are in reverse order
+        if version > 10:
             assert desc == f"Upgrade to version {version}"
             assert (now() - ts) < timedelta(seconds=1)
 