diff --git a/MANIFEST.in b/MANIFEST.in index 9199b48..483b7ad 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,8 +1,9 @@ include Makefile include conftest.py include requirements*.txt include version.txt +recursive-include swh/core/db/sql *.sql recursive-include swh py.typed recursive-include swh/core/db/tests/data/ * recursive-include swh/core/tests/data/ * recursive-include swh/core/tests/fixture/data/ * diff --git a/swh/core/cli/db.py b/swh/core/cli/db.py index 41c8172..f062e96 100755 --- a/swh/core/cli/db.py +++ b/swh/core/cli/db.py @@ -1,213 +1,280 @@ #!/usr/bin/env python3 -# Copyright (C) 2018-2020 The Software Heritage developers +# Copyright (C) 2018-2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import logging from os import environ import warnings import click from swh.core.cli import CONTEXT_SETTINGS from swh.core.cli import swh as swh_cli_group warnings.filterwarnings("ignore") # noqa prevent psycopg from telling us sh*t logger = logging.getLogger(__name__) @swh_cli_group.group(name="db", context_settings=CONTEXT_SETTINGS) @click.option( "--config-file", "-C", default=None, type=click.Path(exists=True, dir_okay=False), help="Configuration file.", ) @click.pass_context def db(ctx, config_file): """Software Heritage database generic tools.""" from swh.core.config import read as config_read ctx.ensure_object(dict) if config_file is None: config_file = environ.get("SWH_CONFIG_FILENAME") cfg = config_read(config_file) ctx.obj["config"] = cfg @db.command(name="create", context_settings=CONTEXT_SETTINGS) @click.argument("module", required=True) @click.option( "--dbname", "--db-name", "-d", help="Database name.", default="softwareheritage-dev", show_default=True, ) @click.option( "--template", "-T", help="Template database from which to build this database.", default="template1", show_default=True, ) def db_create(module, dbname, template): """Create a database for the Software Heritage . and potentially execute superuser-level initialization steps. Example:: swh db create -d swh-test storage If you want to specify non-default postgresql connection parameters, please provide them using standard environment variables or by the mean of a properly crafted libpq connection URI. See psql(1) man page (section ENVIRONMENTS) for details. Note: this command requires a postgresql connection with superuser permissions. Example:: PGPORT=5434 swh db create indexer swh db create -d postgresql://superuser:passwd@pghost:5433/swh-storage storage """ from swh.core.db.db_utils import create_database_for_package logger.debug("db_create %s dn_name=%s", module, dbname) create_database_for_package(module, dbname, template) @db.command(name="init-admin", context_settings=CONTEXT_SETTINGS) @click.argument("module", required=True) @click.option( "--dbname", "--db-name", "-d", help="Database name.", default="softwareheritage-dev", show_default=True, ) def db_init_admin(module: str, dbname: str) -> None: """Execute superuser-level initialization steps (e.g pg extensions, admin functions, ...) Example:: PGPASSWORD=... swh db init-admin -d swh-test scheduler If you want to specify non-default postgresql connection parameters, please provide them using standard environment variables or by the mean of a properly crafted libpq connection URI. See psql(1) man page (section ENVIRONMENTS) for details. Note: this command requires a postgresql connection with superuser permissions (e.g postgres, swh-admin, ...) Example:: PGPORT=5434 swh db init-admin scheduler swh db init-admin -d postgresql://superuser:passwd@pghost:5433/swh-scheduler \ scheduler """ from swh.core.db.db_utils import init_admin_extensions logger.debug("db_init_admin %s dbname=%s", module, dbname) init_admin_extensions(module, dbname) @db.command(name="init", context_settings=CONTEXT_SETTINGS) @click.argument("module", required=True) @click.option( "--dbname", "--db-name", "-d", help="Database name or connection URI.", default=None, show_default=False, ) @click.option( "--flavor", help="Database flavor.", default=None, ) @click.pass_context def db_init(ctx, module, dbname, flavor): """Initialize a database for the Software Heritage . The database connection string comes from the configuration file (see option ``--config-file`` in ``swh db --help``) in the section named after the MODULE argument. Example:: $ cat conf.yml storage: cls: postgresql db: postgresql://user:passwd@pghost:5433/swh-storage objstorage: cls: memory $ swh db -C conf.yml init storage # or $ SWH_CONFIG_FILENAME=conf.yml swh db init storage Note that the connection string can also be passed directly using the '--db-name' option, but this usage is about to be deprecated. """ from swh.core.db.db_utils import populate_database_for_package if dbname is None: # use the db cnx from the config file; the expected config entry is the # given module name cfg = ctx.obj["config"].get(module, {}) dbname = get_dburl_from_config(cfg) if not dbname: raise click.BadParameter( "Missing the postgresql connection configuration. Either fix your " "configuration file or use the --dbname option." ) logger.debug("db_init %s flavor=%s dbname=%s", module, flavor, dbname) initialized, dbversion, dbflavor = populate_database_for_package( module, dbname, flavor ) # TODO: Ideally migrate the version from db_version to the latest # db version click.secho( "DONE database for {} {}{} at version {}".format( module, "initialized" if initialized else "exists", f" (flavor {dbflavor})" if dbflavor is not None else "", dbversion, ), fg="green", bold=True, ) if flavor is not None and dbflavor != flavor: click.secho( f"WARNING requested flavor '{flavor}' != recorded flavor '{dbflavor}'", fg="red", bold=True, ) +@db.command(name="version", context_settings=CONTEXT_SETTINGS) +@click.argument("module", required=True) +@click.option( + "--all/--no-all", + "show_all", + help="Show version history.", + default=False, + show_default=True, +) +@click.pass_context +def db_version(ctx, module, show_all): + """Print the database version for the Software Heritage. + + Example:: + + swh db version -d swh-test + + """ + from swh.core.db.db_utils import get_database_info, import_swhmodule + + # use the db cnx from the config file; the expected config entry is the + # given module name + cfg = ctx.obj["config"].get(module, {}) + dbname = get_dburl_from_config(cfg) + + if not dbname: + raise click.BadParameter( + "Missing the postgresql connection configuration. Either fix your " + "configuration file or use the --dbname option." + ) + + logger.debug("db_version dbname=%s", dbname) + + db_module, db_version, db_flavor = get_database_info(dbname) + if db_module is None: + click.secho( + "WARNING the database does not have a dbmodule table.", fg="red", bold=True + ) + db_module = module + assert db_module == module, f"{db_module} (in the db) != {module} (given)" + + click.secho(f"module: {db_module}", fg="green", bold=True) + + if db_flavor is not None: + click.secho(f"flavor: {db_flavor}", fg="green", bold=True) + + # instantiate the data source to retrieve the current (expected) db version + datastore_factory = getattr(import_swhmodule(db_module), "get_datastore", None) + if datastore_factory: + datastore = datastore_factory(**cfg) + code_version = datastore.get_current_version() + click.secho( + f"current code version: {code_version}", + fg="green" if code_version == db_version else "red", + bold=True, + ) + + if not show_all: + click.secho(f"version: {db_version}", fg="green", bold=True) + else: + from swh.core.db.db_utils import swh_db_versions + + versions = swh_db_versions(dbname) + for version, tstamp, desc in versions: + click.echo(f"{version} [{tstamp}] {desc}") + + def get_dburl_from_config(cfg): if cfg.get("cls") != "postgresql": raise click.BadParameter( "Configuration cls must be set to 'postgresql' for this command." ) if "args" in cfg: # for bw compat cfg = cfg["args"] return cfg.get("db") diff --git a/swh/core/db/db_utils.py b/swh/core/db/db_utils.py index 381c6ea..ed7edd4 100644 --- a/swh/core/db/db_utils.py +++ b/swh/core/db/db_utils.py @@ -1,402 +1,518 @@ -# Copyright (C) 2015-2020 The Software Heritage developers +# Copyright (C) 2015-2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information +from datetime import datetime import functools import glob from importlib import import_module import logging from os import path import re import subprocess -from typing import Collection, Dict, Optional, Tuple, Union +from typing import Collection, Dict, List, Optional, Tuple, Union import psycopg2 import psycopg2.extensions from psycopg2.extensions import connection as pgconnection from psycopg2.extensions import encodings as pgencodings from psycopg2.extensions import make_dsn from psycopg2.extensions import parse_dsn as _parse_dsn from swh.core.utils import numfile_sortkey as sortkey logger = logging.getLogger(__name__) def stored_procedure(stored_proc): """decorator to execute remote stored procedure, specified as argument Generally, the body of the decorated function should be empty. If it is not, the stored procedure will be executed first; the function body then. """ def wrap(meth): @functools.wraps(meth) def _meth(self, *args, **kwargs): cur = kwargs.get("cur", None) self._cursor(cur).execute("SELECT %s()" % stored_proc) meth(self, *args, **kwargs) return _meth return wrap def jsonize(value): """Convert a value to a psycopg2 JSON object if necessary""" if isinstance(value, dict): return psycopg2.extras.Json(value) return value def connect_to_conninfo(db_or_conninfo: Union[str, pgconnection]) -> pgconnection: """Connect to the database passed in argument Args: db_or_conninfo: A database connection, or a database connection info string Returns: a connected database handle Raises: psycopg2.Error if the database doesn't exist """ if isinstance(db_or_conninfo, pgconnection): return db_or_conninfo if "=" not in db_or_conninfo and "//" not in db_or_conninfo: # Database name db_or_conninfo = f"dbname={db_or_conninfo}" db = psycopg2.connect(db_or_conninfo) return db def swh_db_version(db_or_conninfo: Union[str, pgconnection]) -> Optional[int]: """Retrieve the swh version of the database. If the database is not initialized, this logs a warning and returns None. Args: db_or_conninfo: A database connection, or a database connection info string Returns: Either the version of the database, or None if it couldn't be detected """ try: db = connect_to_conninfo(db_or_conninfo) except psycopg2.Error: logger.exception("Failed to connect to `%s`", db_or_conninfo) # Database not initialized return None try: with db.cursor() as c: query = "select version from dbversion order by dbversion desc limit 1" try: c.execute(query) - return c.fetchone()[0] + result = c.fetchone() + if result: + return result[0] except psycopg2.errors.UndefinedTable: return None except Exception: logger.exception("Could not get version from `%s`", db_or_conninfo) + return None + + +def swh_db_versions( + db_or_conninfo: Union[str, pgconnection] +) -> Optional[List[Tuple[int, datetime, str]]]: + """Retrieve the swh version history of the database. + + If the database is not initialized, this logs a warning and returns None. + + Args: + db_or_conninfo: A database connection, or a database connection info string + + Returns: + Either the version of the database, or None if it couldn't be detected + """ + try: + db = connect_to_conninfo(db_or_conninfo) + except psycopg2.Error: + logger.exception("Failed to connect to `%s`", db_or_conninfo) + # Database not initialized return None + try: + with db.cursor() as c: + query = ( + "select version, release, description " + "from dbversion order by dbversion desc" + ) + try: + c.execute(query) + return c.fetchall() + except psycopg2.errors.UndefinedTable: + return None + except Exception: + logger.exception("Could not get versions from `%s`", db_or_conninfo) + return None + + +def swh_db_module(db_or_conninfo: Union[str, pgconnection]) -> Optional[str]: + """Retrieve the swh module used to create the database. + + If the database is not initialized, this logs a warning and returns None. + + Args: + db_or_conninfo: A database connection, or a database connection info string + + Returns: + Either the module of the database, or None if it couldn't be detected + """ + try: + db = connect_to_conninfo(db_or_conninfo) + except psycopg2.Error: + logger.exception("Failed to connect to `%s`", db_or_conninfo) + # Database not initialized + return None + + try: + with db.cursor() as c: + query = "select dbmodule from dbmodule limit 1" + try: + c.execute(query) + resp = c.fetchone() + if resp: + return resp[0] + except psycopg2.errors.UndefinedTable: + return None + except Exception: + logger.exception("Could not get module from `%s`", db_or_conninfo) + return None + + +def swh_set_db_module(db_or_conninfo: Union[str, pgconnection], module: str) -> None: + """Set the swh module used to create the database. + + Fails if the dbmodule is already set or the table does not exist. + + Args: + db_or_conninfo: A database connection, or a database connection info string + module: the swh module to register (without the leading 'swh.') + """ + if module.startswith("swh."): + module = module[4:] + + try: + db = connect_to_conninfo(db_or_conninfo) + except psycopg2.Error: + logger.exception("Failed to connect to `%s`", db_or_conninfo) + # Database not initialized + return None + + with db.cursor() as c: + query = "insert into dbmodule(dbmodule) values (%s)" + c.execute(query, (module,)) + db.commit() + def swh_db_flavor(db_or_conninfo: Union[str, pgconnection]) -> Optional[str]: """Retrieve the swh flavor of the database. If the database is not initialized, or the database doesn't support flavors, this returns None. Args: db_or_conninfo: A database connection, or a database connection info string Returns: The flavor of the database, or None if it could not be detected. """ try: db = connect_to_conninfo(db_or_conninfo) except psycopg2.Error: logger.exception("Failed to connect to `%s`", db_or_conninfo) # Database not initialized return None try: with db.cursor() as c: query = "select swh_get_dbflavor()" try: c.execute(query) return c.fetchone()[0] except psycopg2.errors.UndefinedFunction: # function not found: no flavor return None except Exception: logger.exception("Could not get flavor from `%s`", db_or_conninfo) return None # The following code has been imported from psycopg2, version 2.7.4, # https://github.com/psycopg/psycopg2/tree/5afb2ce803debea9533e293eef73c92ffce95bcd # and modified by Software Heritage. # # Original file: lib/extras.py # # psycopg2 is free software: you can redistribute it and/or modify it under the # terms of the GNU Lesser General Public License as published by the Free # Software Foundation, either version 3 of the License, or (at your option) any # later version. def _paginate(seq, page_size): """Consume an iterable and return it in chunks. Every chunk is at most `page_size`. Never return an empty chunk. """ page = [] it = iter(seq) while 1: try: for i in range(page_size): page.append(next(it)) yield page page = [] except StopIteration: if page: yield page return def _split_sql(sql): """Split *sql* on a single ``%s`` placeholder. Split on the %s, perform %% replacement and return pre, post lists of snippets. """ curr = pre = [] post = [] tokens = re.split(br"(%.)", sql) for token in tokens: if len(token) != 2 or token[:1] != b"%": curr.append(token) continue if token[1:] == b"s": if curr is pre: curr = post else: raise ValueError("the query contains more than one '%s' placeholder") elif token[1:] == b"%": curr.append(b"%") else: raise ValueError( "unsupported format character: '%s'" % token[1:].decode("ascii", "replace") ) if curr is pre: raise ValueError("the query doesn't contain any '%s' placeholder") return pre, post def execute_values_generator(cur, sql, argslist, template=None, page_size=100): """Execute a statement using SQL ``VALUES`` with a sequence of parameters. Rows returned by the query are returned through a generator. You need to consume the generator for the queries to be executed! :param cur: the cursor to use to execute the query. :param sql: the query to execute. It must contain a single ``%s`` placeholder, which will be replaced by a `VALUES list`__. Example: ``"INSERT INTO mytable (id, f1, f2) VALUES %s"``. :param argslist: sequence of sequences or dictionaries with the arguments to send to the query. The type and content must be consistent with *template*. :param template: the snippet to merge to every item in *argslist* to compose the query. - If the *argslist* items are sequences it should contain positional placeholders (e.g. ``"(%s, %s, %s)"``, or ``"(%s, %s, 42)``" if there are constants value...). - If the *argslist* items are mappings it should contain named placeholders (e.g. ``"(%(id)s, %(f1)s, 42)"``). If not specified, assume the arguments are sequence and use a simple positional template (i.e. ``(%s, %s, ...)``), with the number of placeholders sniffed by the first element in *argslist*. :param page_size: maximum number of *argslist* items to include in every statement. If there are more items the function will execute more than one statement. :param yield_from_cur: Whether to yield results from the cursor in this function directly. .. __: https://www.postgresql.org/docs/current/static/queries-values.html After the execution of the function the `cursor.rowcount` property will **not** contain a total result. """ # we can't just use sql % vals because vals is bytes: if sql is bytes # there will be some decoding error because of stupid codec used, and Py3 # doesn't implement % on bytes. if not isinstance(sql, bytes): sql = sql.encode(pgencodings[cur.connection.encoding]) pre, post = _split_sql(sql) for page in _paginate(argslist, page_size=page_size): if template is None: template = b"(" + b",".join([b"%s"] * len(page[0])) + b")" parts = pre[:] for args in page: parts.append(cur.mogrify(template, args)) parts.append(b",") parts[-1:] = post cur.execute(b"".join(parts)) yield from cur def import_swhmodule(modname): if not modname.startswith("swh."): modname = f"swh.{modname}" try: m = import_module(modname) except ImportError as exc: logger.error(f"Could not load the {modname} module: {exc}") return None return m def get_sql_for_package(modname): m = import_swhmodule(modname) if m is None: raise ValueError(f"Module {modname} cannot be loaded") sqldir = path.join(path.dirname(m.__file__), "sql") if not path.isdir(sqldir): raise ValueError( "Module {} does not provide a db schema " "(no sql/ dir)".format(modname) ) return sorted(glob.glob(path.join(sqldir, "*.sql")), key=sortkey) def populate_database_for_package( modname: str, conninfo: str, flavor: Optional[str] = None -) -> Tuple[bool, int, Optional[str]]: +) -> Tuple[bool, Optional[int], Optional[str]]: """Populate the database, pointed at with ``conninfo``, using the SQL files found in the package ``modname``. Also fill the 'dbmodule' table with the given ``modname``. Args: modname: Name of the module of which we're loading the files conninfo: connection info string for the SQL database flavor: the module-specific flavor which we want to initialize the database under Returns: Tuple with three elements: whether the database has been initialized; the current version of the database; if it exists, the flavor of the database. """ current_version = swh_db_version(conninfo) if current_version is not None: dbflavor = swh_db_flavor(conninfo) return False, current_version, dbflavor - sqlfiles = get_sql_for_package(modname) + def globalsortkey(key): + "like sortkey but only on basenames" + return sortkey(path.basename(key)) + + sqlfiles = get_sql_for_package(modname) + get_sql_for_package("swh.core.db") + sqlfiles = sorted(sqlfiles, key=globalsortkey) sqlfiles = [fname for fname in sqlfiles if "-superuser-" not in fname] execute_sqlfiles(sqlfiles, conninfo, flavor) - current_version = swh_db_version(conninfo) - assert current_version is not None + # populate the dbmodule table + swh_set_db_module(conninfo, modname) + + current_db_version = swh_db_version(conninfo) dbflavor = swh_db_flavor(conninfo) - return True, current_version, dbflavor + return True, current_db_version, dbflavor + + +def get_database_info( + conninfo: str, +) -> Tuple[Optional[str], Optional[int], Optional[str]]: + """Get version, flavor and module of the db""" + dbmodule = swh_db_module(conninfo) + dbversion = swh_db_version(conninfo) + dbflavor = None + if dbversion is not None: + dbflavor = swh_db_flavor(conninfo) + return (dbmodule, dbversion, dbflavor) def parse_dsn_or_dbname(dsn_or_dbname: str) -> Dict[str, str]: """Parse a psycopg2 dsn, falling back to supporting plain database names as well""" try: return _parse_dsn(dsn_or_dbname) except psycopg2.ProgrammingError: # psycopg2 failed to parse the DSN; it's probably a database name, # handle it as such return _parse_dsn(f"dbname={dsn_or_dbname}") def init_admin_extensions(modname: str, conninfo: str) -> None: """The remaining initialization process -- running -superuser- SQL files -- is done using the given conninfo, thus connecting to the newly created database """ sqlfiles = get_sql_for_package(modname) sqlfiles = [fname for fname in sqlfiles if "-superuser-" in fname] execute_sqlfiles(sqlfiles, conninfo) def create_database_for_package( modname: str, conninfo: str, template: str = "template1" ): """Create the database pointed at with ``conninfo``, and initialize it using -superuser- SQL files found in the package ``modname``. Args: modname: Name of the module of which we're loading the files conninfo: connection info string or plain database name for the SQL database template: the name of the database to connect to and use as template to create the new database """ # Use the given conninfo string, but with dbname replaced by the template dbname # for the database creation step creation_dsn = parse_dsn_or_dbname(conninfo) dbname = creation_dsn["dbname"] creation_dsn["dbname"] = template logger.debug("db_create dbname=%s (from %s)", dbname, template) subprocess.check_call( [ "psql", "--quiet", "--no-psqlrc", "-v", "ON_ERROR_STOP=1", "-d", make_dsn(**creation_dsn), "-c", f'CREATE DATABASE "{dbname}"', ] ) init_admin_extensions(modname, conninfo) def execute_sqlfiles( sqlfiles: Collection[str], conninfo: str, flavor: Optional[str] = None ): """Execute a list of SQL files on the database pointed at with ``conninfo``. Args: sqlfiles: List of SQL files to execute conninfo: connection info string for the SQL database flavor: the database flavor to initialize """ psql_command = [ "psql", "--quiet", "--no-psqlrc", "-v", "ON_ERROR_STOP=1", "-d", conninfo, ] flavor_set = False for sqlfile in sqlfiles: logger.debug(f"execute SQL file {sqlfile} dbname={conninfo}") subprocess.check_call(psql_command + ["-f", sqlfile]) if flavor is not None and not flavor_set and sqlfile.endswith("-flavor.sql"): logger.debug("Setting database flavor %s", flavor) query = f"insert into dbflavor (flavor) values ('{flavor}')" subprocess.check_call(psql_command + ["-c", query]) flavor_set = True if flavor is not None and not flavor_set: logger.warn( "Asked for flavor %s, but module does not support database flavors", flavor, ) diff --git a/swh/core/db/sql/35-dbmetadata.sql b/swh/core/db/sql/35-dbmetadata.sql new file mode 100644 index 0000000..3e49760 --- /dev/null +++ b/swh/core/db/sql/35-dbmetadata.sql @@ -0,0 +1,28 @@ +-- common metadata/context structures +-- +-- we use a 35- prefix for this to make it executed after db schema initialisation +-- sql scripts, which are normally 30- prefixed, so that it remains compatible +-- with packages that have not yet migrated to swh.core 1.2 + +-- schema versions +create table if not exists dbversion +( + version int primary key, + release timestamptz, + description text +); + +comment on table dbversion is 'Details of current db version'; +comment on column dbversion.version is 'SQL schema version'; +comment on column dbversion.release is 'Version deployment timestamp'; +comment on column dbversion.description is 'Release description'; + +-- swh module this db is storing data for +create table if not exists dbmodule ( + dbmodule text, + single_row char(1) primary key default 'x', + check (single_row = 'x') +); +comment on table dbmodule is 'Database module storage'; +comment on column dbmodule.dbmodule is 'Database (swh) module currently deployed'; +comment on column dbmodule.single_row is 'Bogus column to force the table to have a single row'; diff --git a/swh/core/db/tests/conftest.py b/swh/core/db/tests/conftest.py index 7be81e3..d87f8d3 100644 --- a/swh/core/db/tests/conftest.py +++ b/swh/core/db/tests/conftest.py @@ -1,12 +1,57 @@ +# Copyright (C) 2019-2022 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + +import glob import os +from click.testing import CliRunner from hypothesis import HealthCheck +import pytest + +from swh.core.db.db_utils import get_sql_for_package +from swh.core.utils import numfile_sortkey as sortkey os.environ["LC_ALL"] = "C.UTF-8" # we use getattr here to keep mypy happy regardless hypothesis version function_scoped_fixture_check = ( [getattr(HealthCheck, "function_scoped_fixture")] if hasattr(HealthCheck, "function_scoped_fixture") else [] ) + + +@pytest.fixture +def cli_runner(): + return CliRunner() + + +@pytest.fixture() +def mock_package_sql(mocker, datadir): + """This bypasses the module manipulation to only returns the data test files. + + For a given module `test.mod`, look for sql files in the directory `data/mod/*.sql`. + + Typical usage:: + + def test_xxx(cli_runner, mock_package_sql): + conninfo = craft_conninfo(test_db, "new-db") + module_name = "test.cli" + # the command below will use sql scripts from swh/core/db/tests/data/cli/*.sql + cli_runner.invoke(swhdb, ["init", module_name, "--dbname", conninfo]) + """ + + def get_sql_for_package_mock(modname): + if modname.startswith("test."): + sqldir = modname.split(".", 1)[1] + return sorted( + glob.glob(os.path.join(datadir, sqldir, "*.sql")), key=sortkey + ) + return get_sql_for_package(modname) + + mock_sql_files = mocker.patch( + "swh.core.db.db_utils.get_sql_for_package", get_sql_for_package_mock + ) + return mock_sql_files diff --git a/swh/core/db/tests/test_cli.py b/swh/core/db/tests/test_cli.py index 312f2a2..a944608 100644 --- a/swh/core/db/tests/test_cli.py +++ b/swh/core/db/tests/test_cli.py @@ -1,257 +1,224 @@ -# Copyright (C) 2019-2020 The Software Heritage developers +# Copyright (C) 2019-2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import copy -import glob -from os import path +from datetime import datetime, timezone -from click.testing import CliRunner import pytest from swh.core.cli.db import db as swhdb from swh.core.db import BaseDb from swh.core.db.pytest_plugin import postgresql_fact from swh.core.tests.test_cli import assert_section_contains -@pytest.fixture -def cli_runner(): - return CliRunner() +def now(): + return datetime.now(tz=timezone.utc) def test_cli_swh_help(swhmain, cli_runner): swhmain.add_command(swhdb) result = cli_runner.invoke(swhmain, ["-h"]) assert result.exit_code == 0 assert_section_contains( result.output, "Commands", "db Software Heritage database generic tools." ) help_db_snippets = ( ( "Usage", ( "Usage: swh db [OPTIONS] COMMAND [ARGS]...", "Software Heritage database generic tools.", ), ), ( "Commands", ( "create Create a database for the Software Heritage .", "init Initialize a database for the Software Heritage .", "init-admin Execute superuser-level initialization steps", ), ), ) def test_cli_swh_db_help(swhmain, cli_runner): swhmain.add_command(swhdb) result = cli_runner.invoke(swhmain, ["db", "-h"]) assert result.exit_code == 0 for section, snippets in help_db_snippets: for snippet in snippets: assert_section_contains(result.output, section, snippet) -@pytest.fixture() -def mock_package_sql(mocker, datadir): - """This bypasses the module manipulation to only returns the data test files. - - For a given module `test.mod`, look for sql files in the directory `data/mod/*.sql`. - - Typical usage:: - - def test_xxx(cli_runner, mock_package_sql): - conninfo = craft_conninfo(test_db, "new-db") - module_name = "test.cli" - # the command below will use sql scripts from swh/core/db/tests/data/cli/*.sql - cli_runner.invoke(swhdb, ["init", module_name, "--dbname", conninfo]) - """ - from swh.core.db.db_utils import get_sql_for_package - from swh.core.utils import numfile_sortkey as sortkey - - def get_sql_for_package_mock(modname): - if modname.startswith("test."): - sqldir = modname.split(".", 1)[1] - return sorted(glob.glob(path.join(datadir, sqldir, "*.sql")), key=sortkey) - return get_sql_for_package(modname) - - mock_sql_files = mocker.patch( - "swh.core.db.db_utils.get_sql_for_package", get_sql_for_package_mock - ) - - return mock_sql_files - - # We do not want the truncate behavior for those tests test_db = postgresql_fact( "postgresql_proc", dbname="clidb", no_truncate_tables={"dbversion", "origin"} ) @pytest.fixture def swh_db_cli(cli_runner, monkeypatch, test_db): """This initializes a cli_runner and sets the correct environment variable expected by the cli to run appropriately (when not specifying the --dbname flag) """ db_params = test_db.get_dsn_parameters() monkeypatch.setenv("PGHOST", db_params["host"]) monkeypatch.setenv("PGUSER", db_params["user"]) monkeypatch.setenv("PGPORT", db_params["port"]) return cli_runner, db_params def craft_conninfo(test_db, dbname=None) -> str: """Craft conninfo string out of the test_db object. This also allows to override the dbname.""" db_params = test_db.get_dsn_parameters() if dbname: params = copy.deepcopy(db_params) params["dbname"] = dbname else: params = db_params return "postgresql://{user}@{host}:{port}/{dbname}".format(**params) def test_cli_swh_db_create_and_init_db(cli_runner, test_db, mock_package_sql): """Create a db then initializing it should be ok """ module_name = "test.cli" conninfo = craft_conninfo(test_db, "new-db") # This creates the db and installs the necessary admin extensions result = cli_runner.invoke(swhdb, ["create", module_name, "--dbname", conninfo]) assert result.exit_code == 0, f"Unexpected output: {result.output}" # This initializes the schema and data result = cli_runner.invoke(swhdb, ["init", module_name, "--dbname", conninfo]) assert result.exit_code == 0, f"Unexpected output: {result.output}" # the origin value in the scripts uses a hash function (which implementation wise # uses a function from the pgcrypt extension, installed during db creation step) with BaseDb.connect(conninfo).cursor() as cur: cur.execute("select * from origin") origins = cur.fetchall() assert len(origins) == 1 def test_cli_swh_db_initialization_fail_without_creation_first( cli_runner, test_db, mock_package_sql ): """Init command on an inexisting db cannot work """ module_name = "test.cli" # it's mocked here conninfo = craft_conninfo(test_db, "inexisting-db") result = cli_runner.invoke(swhdb, ["init", module_name, "--dbname", conninfo]) # Fails because we cannot connect to an inexisting db assert result.exit_code == 1, f"Unexpected output: {result.output}" def test_cli_swh_db_initialization_fail_without_extension( cli_runner, test_db, mock_package_sql ): """Init command cannot work without privileged extension. In this test, the schema needs privileged extension to work. """ module_name = "test.cli" # it's mocked here conninfo = craft_conninfo(test_db) result = cli_runner.invoke(swhdb, ["init", module_name, "--dbname", conninfo]) # Fails as the function `public.digest` is not installed, init-admin calls is needed # first (the next tests show such behavior) assert result.exit_code == 1, f"Unexpected output: {result.output}" def test_cli_swh_db_initialization_works_with_flags( cli_runner, test_db, mock_package_sql ): """Init commands with carefully crafted libpq conninfo works """ module_name = "test.cli" # it's mocked here conninfo = craft_conninfo(test_db) result = cli_runner.invoke(swhdb, ["init-admin", module_name, "--dbname", conninfo]) assert result.exit_code == 0, f"Unexpected output: {result.output}" result = cli_runner.invoke(swhdb, ["init", module_name, "--dbname", conninfo]) assert result.exit_code == 0, f"Unexpected output: {result.output}" # the origin values in the scripts uses a hash function (which implementation wise # uses a function from the pgcrypt extension, init-admin calls installs it) with BaseDb.connect(test_db.dsn).cursor() as cur: cur.execute("select * from origin") origins = cur.fetchall() assert len(origins) == 1 def test_cli_swh_db_initialization_with_env(swh_db_cli, mock_package_sql, test_db): """Init commands with standard environment variables works """ module_name = "test.cli" # it's mocked here cli_runner, db_params = swh_db_cli result = cli_runner.invoke( swhdb, ["init-admin", module_name, "--dbname", db_params["dbname"]] ) assert result.exit_code == 0, f"Unexpected output: {result.output}" result = cli_runner.invoke( swhdb, ["init", module_name, "--dbname", db_params["dbname"]] ) assert result.exit_code == 0, f"Unexpected output: {result.output}" # the origin values in the scripts uses a hash function (which implementation wise # uses a function from the pgcrypt extension, init-admin calls installs it) with BaseDb.connect(test_db.dsn).cursor() as cur: cur.execute("select * from origin") origins = cur.fetchall() assert len(origins) == 1 def test_cli_swh_db_initialization_idempotent(swh_db_cli, mock_package_sql, test_db): """Multiple runs of the init commands are idempotent """ module_name = "test.cli" # mocked cli_runner, db_params = swh_db_cli result = cli_runner.invoke( swhdb, ["init-admin", module_name, "--dbname", db_params["dbname"]] ) assert result.exit_code == 0, f"Unexpected output: {result.output}" result = cli_runner.invoke( swhdb, ["init", module_name, "--dbname", db_params["dbname"]] ) assert result.exit_code == 0, f"Unexpected output: {result.output}" result = cli_runner.invoke( swhdb, ["init-admin", module_name, "--dbname", db_params["dbname"]] ) assert result.exit_code == 0, f"Unexpected output: {result.output}" result = cli_runner.invoke( swhdb, ["init", module_name, "--dbname", db_params["dbname"]] ) assert result.exit_code == 0, f"Unexpected output: {result.output}" # the origin values in the scripts uses a hash function (which implementation wise # uses a function from the pgcrypt extension, init-admin calls installs it) with BaseDb.connect(test_db.dsn).cursor() as cur: cur.execute("select * from origin") origins = cur.fetchall() assert len(origins) == 1 diff --git a/swh/core/db/tests/test_db_utils.py b/swh/core/db/tests/test_db_utils.py new file mode 100644 index 0000000..70c228a --- /dev/null +++ b/swh/core/db/tests/test_db_utils.py @@ -0,0 +1,83 @@ +# Copyright (C) 2022 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + +from datetime import datetime, timedelta + +import pytest + +from swh.core.cli.db import db as swhdb +from swh.core.db import BaseDb +from swh.core.db.db_utils import get_database_info, swh_db_module, swh_db_versions + +from .test_cli import craft_conninfo, now + + +@pytest.mark.parametrize("module", ["test.cli", "test.cli_new"]) +def test_db_utils_versions(cli_runner, postgresql, mock_package_sql, module): + """Check get_database_info, swh_db_versions and swh_db_module work ok + + This test checks db versions for both a db with "new style" set of sql init + scripts (i.e. the dbversion table is not created in these scripts, but by + the populate_database_for_package() function directly, via the 'swh db + init' command) and an "old style" set (dbversion created in the scripts)S. + + """ + conninfo = craft_conninfo(postgresql) + result = cli_runner.invoke(swhdb, ["init-admin", module, "--dbname", conninfo]) + assert result.exit_code == 0, f"Unexpected output: {result.output}" + result = cli_runner.invoke(swhdb, ["init", module, "--dbname", conninfo]) + assert result.exit_code == 0, f"Unexpected output: {result.output}" + + # the dbversion and dbmodule tables exists and are populated + dbmodule, dbversion, dbflavor = get_database_info(conninfo) + assert dbmodule == module + if module == "test.cli": + # old style: backend init script set the db version + assert dbversion == 1 + else: + # new style: they do not (but we do not have support for this in swh.core yet) + assert dbversion is None + assert dbflavor is None + + # check also the swh_db_module() function + assert swh_db_module(conninfo) == module + + # check also the swh_db_versions() function + versions = swh_db_versions(conninfo) + if module == "test.cli": + assert len(versions) == 1 + assert versions[0][0] == 1 + assert versions[0][1] == datetime.fromisoformat( + "2016-02-22T15:56:28.358587+00:00" + ) + assert versions[0][2] == "Work In Progress" + else: + assert not versions + # add a few versions in dbversion + cnx = BaseDb.connect(conninfo) + with cnx.transaction() as cur: + if module == "test.cli_new": + # add version 1 to make it simpler for checks below + cur.execute( + "insert into dbversion(version, release, description) " + "values(1, NOW(), 'Wotk in progress')" + ) + cur.executemany( + "insert into dbversion(version, release, description) values (%s, %s, %s)", + [(i, now(), f"Upgrade to version {i}") for i in range(2, 6)], + ) + + dbmodule, dbversion, dbflavor = get_database_info(conninfo) + assert dbmodule == module + assert dbversion == 5 + assert dbflavor is None + + versions = swh_db_versions(conninfo) + assert len(versions) == 5 + for i, (version, ts, desc) in enumerate(versions): + assert version == (5 - i) # these are in reverse order + if version > 1: + assert desc == f"Upgrade to version {version}" + assert (now() - ts) < timedelta(seconds=1)