diff --git a/MANIFEST.in b/MANIFEST.in --- a/MANIFEST.in +++ b/MANIFEST.in @@ -2,6 +2,8 @@ include conftest.py include requirements*.txt include version.txt +recursive-include swh/core/db/sql *.sql recursive-include swh py.typed +recursive-include swh/core/db/tests/data/ * recursive-include swh/core/tests/data/ * recursive-include swh/core/tests/fixture/data/ * diff --git a/swh/core/cli/db.py b/swh/core/cli/db.py old mode 100644 new mode 100755 --- a/swh/core/cli/db.py +++ b/swh/core/cli/db.py @@ -197,6 +197,74 @@ ) +@db.command(name="version", context_settings=CONTEXT_SETTINGS) +@click.argument("module", required=True) +@click.option( + "--all/--no-all", + "show_all", + help="Show version history.", + default=False, + show_default=True, +) +@click.pass_context +def db_version(ctx, module, show_all): + """Print the module, flavor and version of the database configured for MODULE. + + Example:: + + swh db version storage + + """ + from swh.core.db.db_utils import get_database_info, import_swhmodule + + # use the db cnx from the config file; the expected config entry is the + # given module name + cfg = ctx.obj["config"].get(module, {}) + dbname = get_dburl_from_config(cfg) + + if not dbname: + raise click.BadParameter( + "Missing the postgresql connection configuration. Check the " + f"'{module}' entry of your configuration file."
+ ) + + logger.debug("db_version dbname=%s", dbname) + + db_module, db_version, db_flavor = get_database_info(dbname) + if db_module is None: + click.secho( + "WARNING the database does not have a dbmodule table.", fg="red", bold=True + ) + db_module = module + if db_module != module: + raise click.BadParameter(f"{db_module} (in the db) != {module} (given)") + + click.secho(f"module: {db_module}", fg="green", bold=True) + + if db_flavor is not None: + click.secho(f"flavor: {db_flavor}", fg="green", bold=True) + + # instantiate the data source to retrieve the current (expected) db version + datastore_factory = getattr(import_swhmodule(db_module), "get_datastore", None) + if datastore_factory: + datastore = datastore_factory(**cfg) + code_version = datastore.get_current_version() + click.secho( + f"current code version: {code_version}", + fg="green" if code_version == db_version else "red", + bold=True, + ) + + if not show_all: + click.secho(f"version: {db_version}", fg="green", bold=True) + else: + from swh.core.db.db_utils import swh_db_versions + + versions = swh_db_versions(dbname) + for version, tstamp, desc in versions or []: + click.echo(f"{version} [{tstamp}] {desc}") + + def get_dburl_from_config(cfg): if cfg.get("cls") != "postgresql": raise click.BadParameter( diff --git a/swh/core/db/db_utils.py b/swh/core/db/db_utils.py --- a/swh/core/db/db_utils.py +++ b/swh/core/db/db_utils.py @@ -3,13 +3,14 @@ # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information +from datetime import datetime import functools import glob from importlib import import_module import logging from os import path import re -from typing import Collection, Dict, Optional, Tuple, Union +from typing import Collection, Dict, List, Optional, Tuple, Union import psycopg2 import psycopg2.extensions @@ -98,13 +99,109 @@ query = "select version from dbversion order by dbversion desc limit 1" try: c.execute(query) - return c.fetchone()[0] + result = c.fetchone() + 
if result: + return result[0] except psycopg2.errors.UndefinedTable: return None except Exception: logger.exception("Could not get version from `%s`", db_or_conninfo) + return None + + +def swh_db_versions( + db_or_conninfo: Union[str, pgconnection] +) -> Optional[List[Tuple[int, datetime, str]]]: + """Retrieve the swh version history of the database. + + If the database is not initialized, this logs a warning and returns None. + + Args: + db_or_conninfo: A database connection, or a database connection info string + + Returns: + Either the (version, release, description) rows, highest version first, or None + """ + try: + db = connect_to_conninfo(db_or_conninfo) + except psycopg2.Error: + logger.exception("Failed to connect to `%s`", db_or_conninfo) + # Database not initialized return None + try: + with db.cursor() as c: + query = ( + "select version, release, description " + "from dbversion order by dbversion desc" + ) + try: + c.execute(query) + return c.fetchall() + except psycopg2.errors.UndefinedTable: + return None + except Exception: + logger.exception("Could not get versions from `%s`", db_or_conninfo) + return None + + +def swh_db_module(db_or_conninfo: Union[str, pgconnection]) -> Optional[str]: + """Retrieve the swh module used to create the database. + + If the database is not initialized, this logs a warning and returns None.
+ + Args: + db_or_conninfo: A database connection, or a database connection info string + + Returns: + Either the module of the database, or None if it couldn't be detected + """ + try: + db = connect_to_conninfo(db_or_conninfo) + except psycopg2.Error: + logger.exception("Failed to connect to `%s`", db_or_conninfo) + # Database not initialized + return None + + try: + with db.cursor() as c: + query = "select dbmodule from dbmodule limit 1" + try: + c.execute(query) + resp = c.fetchone() + if resp: + return resp[0] + except psycopg2.errors.UndefinedTable: + return None + except Exception: + logger.exception("Could not get module from `%s`", db_or_conninfo) + return None + + +def swh_set_db_module(db_or_conninfo: Union[str, pgconnection], module: str) -> None: + """Set the swh module used to create the database. + + Fails if the dbmodule is already set or the table does not exist. + + Args: + db_or_conninfo: A database connection, or a database connection info string + module: the swh module to register (a leading 'swh.' is stripped) + """ + if module.startswith("swh."): + module = module[4:] + + try: + db = connect_to_conninfo(db_or_conninfo) + except psycopg2.Error: + logger.exception("Failed to connect to `%s`", db_or_conninfo) + # Database not initialized + return None + + with db.cursor() as c: + query = "insert into dbmodule(dbmodule) values (%s)" + c.execute(query, (module,)) + db.commit() + def swh_db_flavor(db_or_conninfo: Union[str, pgconnection]) -> Optional[str]: """Retrieve the swh flavor of the database. @@ -280,7 +377,7 @@ def populate_database_for_package( modname: str, conninfo: str, flavor: Optional[str] = None -) -> Tuple[bool, int, Optional[str]]: +) -> Tuple[bool, Optional[int], Optional[str]]: """Populate the database, pointed at with ``conninfo``, using the SQL files found in the package ``modname``. Also fill the 'dbmodule' table with the given ``modname``.
@@ -299,14 +396,33 @@ dbflavor = swh_db_flavor(conninfo) return False, current_version, dbflavor - sqlfiles = get_sql_for_package(modname) + def globalsortkey(key): + "like sortkey but only on basenames" + return sortkey(path.basename(key)) + + sqlfiles = get_sql_for_package(modname) + get_sql_for_package("swh.core.db") + sqlfiles = sorted(sqlfiles, key=globalsortkey) sqlfiles = [fname for fname in sqlfiles if "-superuser-" not in fname] execute_sqlfiles(sqlfiles, conninfo, flavor) - current_version = swh_db_version(conninfo) - assert current_version is not None + # populate the dbmodule table + swh_set_db_module(conninfo, modname) + + current_db_version = swh_db_version(conninfo) dbflavor = swh_db_flavor(conninfo) - return True, current_version, dbflavor + return True, current_db_version, dbflavor + + +def get_database_info( + conninfo: str, +) -> Tuple[Optional[str], Optional[int], Optional[str]]: + """Get the (module, version, flavor) of the db.""" + dbmodule = swh_db_module(conninfo) + dbversion = swh_db_version(conninfo) + dbflavor = None + if dbversion is not None: + dbflavor = swh_db_flavor(conninfo) + return (dbmodule, dbversion, dbflavor) def parse_dsn_or_dbname(dsn_or_dbname: str) -> Dict[str, str]: diff --git a/swh/core/db/sql/35-dbmetadata.sql b/swh/core/db/sql/35-dbmetadata.sql new file mode 100644 --- /dev/null +++ b/swh/core/db/sql/35-dbmetadata.sql @@ -0,0 +1,28 @@ +-- common metadata/context structures +-- +-- we use a 35- prefix for this so that it is executed after the db schema initialisation +-- sql scripts, which are normally 30- prefixed, so that it remains compatible +-- with packages that have not yet migrated to swh.core 1.2 + +-- schema versions +create table if not exists dbversion +( + version int primary key, + release timestamptz, + description text +); + +comment on table dbversion is 'Details of current db version'; +comment on column dbversion.version is 'SQL schema version'; +comment on column dbversion.release is 'Version deployment
timestamp'; +comment on column dbversion.description is 'Release description'; + +-- swh module this db is storing data for +create table if not exists dbmodule ( + dbmodule text, + single_row char(1) primary key default 'x', + check (single_row = 'x') +); +comment on table dbmodule is 'Database module storage'; +comment on column dbmodule.dbmodule is 'Database (swh) module currently deployed'; +comment on column dbmodule.single_row is 'Bogus column to force the table to have a single row'; diff --git a/swh/core/db/tests/conftest.py b/swh/core/db/tests/conftest.py --- a/swh/core/db/tests/conftest.py +++ b/swh/core/db/tests/conftest.py @@ -1,6 +1,12 @@ +import glob import os +from click.testing import CliRunner from hypothesis import HealthCheck +import pytest + +from swh.core.db.db_utils import get_sql_for_package +from swh.core.utils import numfile_sortkey as sortkey os.environ["LC_ALL"] = "C.UTF-8" @@ -10,3 +16,37 @@ if hasattr(HealthCheck, "function_scoped_fixture") else [] ) + + +@pytest.fixture +def cli_runner(): + return CliRunner() + + +@pytest.fixture() +def mock_package_sql(mocker, datadir): + """Return the test data sql files for `test.*` modules, the real ones otherwise. + + For a given module `test.mod`, look for sql files in the directory `data/mod/*.sql`.
+ + Typical usage:: + + def test_xxx(cli_runner, mock_package_sql): + conninfo = craft_conninfo(test_db, "new-db") + module_name = "test.cli" + # the command below will use sql scripts from swh/core/db/tests/data/cli/*.sql + cli_runner.invoke(swhdb, ["init", module_name, "--dbname", conninfo]) + """ + + def get_sql_for_package_mock(modname): + if modname.startswith("test."): + sqldir = modname.split(".", 1)[1] + return sorted( + glob.glob(os.path.join(datadir, sqldir, "*.sql")), key=sortkey + ) + return get_sql_for_package(modname) + + mock_sql_files = mocker.patch( + "swh.core.db.db_utils.get_sql_for_package", get_sql_for_package_mock + ) + return mock_sql_files diff --git a/swh/core/db/tests/data/cli/1-schema.sql b/swh/core/db/tests/data/cli/30-schema.sql rename from swh/core/db/tests/data/cli/1-schema.sql rename to swh/core/db/tests/data/cli/30-schema.sql diff --git a/swh/core/db/tests/data/cli/3-func.sql b/swh/core/db/tests/data/cli/40-funcs.sql rename from swh/core/db/tests/data/cli/3-func.sql rename to swh/core/db/tests/data/cli/40-funcs.sql diff --git a/swh/core/db/tests/data/cli/4-data.sql b/swh/core/db/tests/data/cli/50-data.sql rename from swh/core/db/tests/data/cli/4-data.sql rename to swh/core/db/tests/data/cli/50-data.sql diff --git a/swh/core/db/tests/test_cli.py b/swh/core/db/tests/test_cli.py --- a/swh/core/db/tests/test_cli.py +++ b/swh/core/db/tests/test_cli.py @@ -4,10 +4,8 @@ # See top-level LICENSE file for more information import copy -import glob -from os import path +from datetime import datetime, timezone -from click.testing import CliRunner import pytest from swh.core.cli.db import db as swhdb @@ -16,9 +14,8 @@ from swh.core.tests.test_cli import assert_section_contains -@pytest.fixture -def cli_runner(): - return CliRunner() +def now(): + return datetime.now(tz=timezone.utc) def test_cli_swh_help(swhmain, cli_runner): @@ -58,19 +55,6 @@ assert_section_contains(result.output, section, snippet) -@pytest.fixture() -def
mock_package_sql(mocker, datadir): - """This bypasses the module manipulation to only returns the data test files. - - """ - from swh.core.utils import numfile_sortkey as sortkey - - mock_sql_files = mocker.patch("swh.core.db.db_utils.get_sql_for_package") - sql_files = sorted(glob.glob(path.join(datadir, "cli", "*.sql")), key=sortkey) - mock_sql_files.return_value = sql_files - return mock_sql_files - - # We do not want the truncate behavior for those tests test_db = postgresql_fact( "postgresql_proc", dbname="clidb", no_truncate_tables={"dbversion", "origin"} @@ -107,7 +91,7 @@ """Create a db then initializing it should be ok """ - module_name = "something" + module_name = "test.cli" conninfo = craft_conninfo(test_db, "new-db") # This creates the db and installs the necessary admin extensions @@ -133,7 +117,7 @@ """Init command on an inexisting db cannot work """ - module_name = "anything" # it's mocked here + module_name = "test.cli" # it's mocked here conninfo = craft_conninfo(test_db, "inexisting-db") result = cli_runner.invoke(swhdb, ["init", module_name, "--dbname", conninfo]) @@ -149,7 +133,7 @@ In this test, the schema needs privileged extension to work.
""" - module_name = "anything" # it's mocked here + module_name = "test.cli" # it's mocked here conninfo = craft_conninfo(test_db) result = cli_runner.invoke(swhdb, ["init", module_name, "--dbname", conninfo]) @@ -164,7 +148,7 @@ """Init commands with carefully crafted libpq conninfo works """ - module_name = "anything" # it's mocked here + module_name = "test.cli" # it's mocked here conninfo = craft_conninfo(test_db) result = cli_runner.invoke(swhdb, ["init-admin", module_name, "--dbname", conninfo]) @@ -185,7 +169,7 @@ """Init commands with standard environment variables works """ - module_name = "anything" # it's mocked here + module_name = "test.cli" # it's mocked here cli_runner, db_params = swh_db_cli result = cli_runner.invoke( swhdb, ["init-admin", module_name, "--dbname", db_params["dbname"]] @@ -209,7 +193,7 @@ """Multiple runs of the init commands are idempotent """ - module_name = "anything" # mocked + module_name = "test.cli" # mocked cli_runner, db_params = swh_db_cli result = cli_runner.invoke( diff --git a/swh/core/db/tests/test_db_utils.py b/swh/core/db/tests/test_db_utils.py new file mode 100644 --- /dev/null +++ b/swh/core/db/tests/test_db_utils.py @@ -0,0 +1,77 @@ +from datetime import datetime, timedelta + +import pytest + +from swh.core.cli.db import db as swhdb +from swh.core.db import BaseDb +from swh.core.db.db_utils import get_database_info, swh_db_module, swh_db_versions + +from .test_cli import craft_conninfo, now + + +@pytest.mark.parametrize("module", ["test.cli", "test.cli_new"]) +def test_db_utils_versions(cli_runner, postgresql, mock_package_sql, module): + """Check get_database_info, swh_db_versions and swh_db_module work ok + + This test checks both a db with "new style" set of sql init scripts (i.e. the + dbversion table is not created in these scripts, but by the + populate_database_for_package() function directly, via the 'swh db init' + command) and an "old style" set (dbversion created in the scripts).
+ """ + conninfo = craft_conninfo(postgresql) + result = cli_runner.invoke(swhdb, ["init-admin", module, "--dbname", conninfo]) + assert result.exit_code == 0, f"Unexpected output: {result.output}" + result = cli_runner.invoke(swhdb, ["init", module, "--dbname", conninfo]) + assert result.exit_code == 0, f"Unexpected output: {result.output}" + + # the dbversion and dbmodule tables exist and are populated + dbmodule, dbversion, dbflavor = get_database_info(conninfo) + assert dbmodule == module + if module == "test.cli": + # old style: backend init script set the db version + assert dbversion == 1 + else: + # new style: they do not (but we do not have support for this in swh.core yet) + assert dbversion is None + assert dbflavor is None + + # check also the swh_db_module() function + assert swh_db_module(conninfo) == module + + # check also the swh_db_versions() function + versions = swh_db_versions(conninfo) + if module == "test.cli": + assert len(versions) == 1 + assert versions[0][0] == 1 + assert versions[0][1] == datetime.fromisoformat( + "2016-02-22T15:56:28.358587+00:00" + ) + assert versions[0][2] == "Work In Progress" + else: + assert not versions + # add a few versions in dbversion + cnx = BaseDb.connect(conninfo) + with cnx.transaction() as cur: + if module == "test.cli_new": + # add version 1 to make it simpler for checks below + cur.execute( + "insert into dbversion(version, release, description) " + "values(1, NOW(), 'Wotk in progress')" + ) + cur.executemany( + "insert into dbversion(version, release, description) values (%s, %s, %s)", + [(i, now(), f"Upgrade to version {i}") for i in range(2, 6)], + ) + + dbmodule, dbversion, dbflavor = get_database_info(conninfo) + assert dbmodule == module + assert dbversion == 5 + assert dbflavor is None + + versions = swh_db_versions(conninfo) + assert len(versions) == 5 + for i, (version, ts, desc) in enumerate(versions): + assert version == (5 - i) # these are in reverse order + if version > 1: + assert desc ==
f"Upgrade to version {version}" + assert (now() - ts) < timedelta(seconds=1)