diff --git a/swh/core/cli/db.py b/swh/core/cli/db.py index 553162d..1e094d8 100755 --- a/swh/core/cli/db.py +++ b/swh/core/cli/db.py @@ -1,325 +1,409 @@ #!/usr/bin/env python3 # Copyright (C) 2018-2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import logging from os import environ import warnings import click from swh.core.cli import CONTEXT_SETTINGS from swh.core.cli import swh as swh_cli_group warnings.filterwarnings("ignore") # noqa prevent psycopg from telling us sh*t logger = logging.getLogger(__name__) @swh_cli_group.group(name="db", context_settings=CONTEXT_SETTINGS) @click.option( "--config-file", "-C", default=None, type=click.Path(exists=True, dir_okay=False), help="Configuration file.", ) @click.pass_context def db(ctx, config_file): """Software Heritage database generic tools.""" from swh.core.config import read as config_read ctx.ensure_object(dict) if config_file is None: config_file = environ.get("SWH_CONFIG_FILENAME") cfg = config_read(config_file) ctx.obj["config"] = cfg @db.command(name="create", context_settings=CONTEXT_SETTINGS) @click.argument("module", required=True) @click.option( "--dbname", "--db-name", "-d", help="Database name.", default="softwareheritage-dev", show_default=True, ) @click.option( "--template", "-T", help="Template database from which to build this database.", default="template1", show_default=True, ) def db_create(module, dbname, template): """Create a database for the Software Heritage . and potentially execute superuser-level initialization steps. Example:: swh db create -d swh-test storage If you want to specify non-default postgresql connection parameters, please provide them using standard environment variables or by the mean of a properly crafted libpq connection URI. See psql(1) man page (section ENVIRONMENTS) for details. Note: this command requires a postgresql connection with superuser permissions. Example:: PGPORT=5434 swh db create indexer swh db create -d postgresql://superuser:passwd@pghost:5433/swh-storage storage """ from swh.core.db.db_utils import create_database_for_package logger.debug("db_create %s dn_name=%s", module, dbname) create_database_for_package(module, dbname, template) @db.command(name="init-admin", context_settings=CONTEXT_SETTINGS) @click.argument("module", required=True) @click.option( "--dbname", "--db-name", "-d", help="Database name.", default="softwareheritage-dev", show_default=True, ) def db_init_admin(module: str, dbname: str) -> None: """Execute superuser-level initialization steps (e.g pg extensions, admin functions, ...) Example:: PGPASSWORD=... swh db init-admin -d swh-test scheduler If you want to specify non-default postgresql connection parameters, please provide them using standard environment variables or by the mean of a properly crafted libpq connection URI. See psql(1) man page (section ENVIRONMENTS) for details. Note: this command requires a postgresql connection with superuser permissions (e.g postgres, swh-admin, ...) 
Example:: PGPORT=5434 swh db init-admin scheduler swh db init-admin -d postgresql://superuser:passwd@pghost:5433/swh-scheduler \ scheduler """ from swh.core.db.db_utils import init_admin_extensions logger.debug("db_init_admin %s dbname=%s", module, dbname) init_admin_extensions(module, dbname) @db.command(name="init", context_settings=CONTEXT_SETTINGS) @click.argument("module", required=True) @click.option( "--dbname", "--db-name", "-d", help="Database name or connection URI.", default=None, show_default=False, ) @click.option( "--flavor", help="Database flavor.", default=None, ) @click.option( "--initial-version", help="Database initial version.", default=1, show_default=True ) @click.pass_context def db_init(ctx, module, dbname, flavor, initial_version): """Initialize a database for the Software Heritage . The database connection string comes from the configuration file (see option ``--config-file`` in ``swh db --help``) in the section named after the MODULE argument. Example:: $ cat conf.yml storage: cls: postgresql db: postgresql://user:passwd@pghost:5433/swh-storage objstorage: cls: memory $ swh db -C conf.yml init storage # or $ SWH_CONFIG_FILENAME=conf.yml swh db init storage Note that the connection string can also be passed directly using the '--db-name' option, but this usage is about to be deprecated. """ from swh.core.db.db_utils import ( get_database_info, import_swhmodule, populate_database_for_package, swh_set_db_version, ) cfg = None if dbname is None: # use the db cnx from the config file; the expected config entry is the # given module name cfg = ctx.obj["config"].get(module, {}) dbname = get_dburl_from_config(cfg) if not dbname: raise click.BadParameter( "Missing the postgresql connection configuration. Either fix your " "configuration file or use the --dbname option." 
) logger.debug("db_init %s flavor=%s dbname=%s", module, flavor, dbname) initialized, dbversion, dbflavor = populate_database_for_package( module, dbname, flavor ) if dbversion is None: if cfg is not None: # db version has not been populated by sql init scripts (new style), # let's do it; instantiate the data source to retrieve the current # (expected) db version datastore_factory = getattr(import_swhmodule(module), "get_datastore", None) if datastore_factory: datastore = datastore_factory(**cfg) try: get_current_version = datastore.get_current_version except AttributeError: logger.warning( "Datastore %s does not implement the " "'get_current_version()' method", datastore, ) else: code_version = get_current_version() logger.info( "Initializing database version to %s from the %s datastore", code_version, module, ) swh_set_db_version(dbname, code_version, desc="DB initialization") dbversion = get_database_info(dbname)[1] if dbversion is None: logger.info( "Initializing database version to %s " "from the command line option --initial-version", initial_version, ) swh_set_db_version(dbname, initial_version, desc="DB initialization") dbversion = get_database_info(dbname)[1] assert dbversion is not None # TODO: Ideally migrate the version from db_version to the latest # db version click.secho( "DONE database for {} {}{} at version {}".format( module, "initialized" if initialized else "exists", f" (flavor {dbflavor})" if dbflavor is not None else "", dbversion, ), fg="green", bold=True, ) if flavor is not None and dbflavor != flavor: click.secho( f"WARNING requested flavor '{flavor}' != recorded flavor '{dbflavor}'", fg="red", bold=True, ) @db.command(name="version", context_settings=CONTEXT_SETTINGS) @click.argument("module", required=True) @click.option( "--all/--no-all", "show_all", help="Show version history.", default=False, show_default=True, ) @click.pass_context def db_version(ctx, module, show_all): """Print the database version for the Software Heritage. Example:: swh db version -d swh-test """ from swh.core.db.db_utils import get_database_info, import_swhmodule # use the db cnx from the config file; the expected config entry is the # given module name cfg = ctx.obj["config"].get(module, {}) dbname = get_dburl_from_config(cfg) if not dbname: raise click.BadParameter( "Missing the postgresql connection configuration. Either fix your " "configuration file or use the --dbname option." 
) logger.debug("db_version dbname=%s", dbname) db_module, db_version, db_flavor = get_database_info(dbname) if db_module is None: click.secho( "WARNING the database does not have a dbmodule table.", fg="red", bold=True ) db_module = module assert db_module == module, f"{db_module} (in the db) != {module} (given)" click.secho(f"module: {db_module}", fg="green", bold=True) if db_flavor is not None: click.secho(f"flavor: {db_flavor}", fg="green", bold=True) # instantiate the data source to retrieve the current (expected) db version datastore_factory = getattr(import_swhmodule(db_module), "get_datastore", None) if datastore_factory: datastore = datastore_factory(**cfg) code_version = datastore.get_current_version() click.secho( f"current code version: {code_version}", fg="green" if code_version == db_version else "red", bold=True, ) if not show_all: click.secho(f"version: {db_version}", fg="green", bold=True) else: from swh.core.db.db_utils import swh_db_versions versions = swh_db_versions(dbname) for version, tstamp, desc in versions: click.echo(f"{version} [{tstamp}] {desc}") +@db.command(name="upgrade", context_settings=CONTEXT_SETTINGS) +@click.argument("module", required=True) +@click.option( + "--to-version", + type=int, + help="Upgrade up to version VERSION", + metavar="VERSION", + default=None, +) +@click.pass_context +def db_upgrade(ctx, module, to_version): + """Upgrade the database for given module (to a given version if specified). + + Examples:: + + swh db upgrade storage + swg db upgrade scheduler --to-version=10 + + """ + from swh.core.db.db_utils import ( + get_database_info, + import_swhmodule, + swh_db_upgrade, + swh_set_db_module, + ) + + # use the db cnx from the config file; the expected config entry is the + # given module name + cfg = ctx.obj["config"].get(module, {}) + dbname = get_dburl_from_config(cfg) + + if not dbname: + raise click.BadParameter( + "Missing the postgresql connection configuration. Either fix your " + "configuration file or use the --dbname option." + ) + + logger.debug("db_version dbname=%s", dbname) + + db_module, db_version, db_flavor = get_database_info(dbname) + if db_module is None: + click.secho( + "Warning: the database does not have a dbmodule table.", + fg="yellow", + bold=True, + ) + if not click.confirm( + f"Write the module information ({module}) in the database?", default=True + ): + raise click.BadParameter("Migration aborted.") + swh_set_db_module(dbname, module) + db_module = module + + if db_module != module: + raise click.BadParameter( + f"Error: the given module ({module}) does not match the value " + f"stored in the database ({db_module})." 
+ ) + + # instantiate the data source to retrieve the current (expected) db version + datastore_factory = getattr(import_swhmodule(db_module), "get_datastore", None) + if not datastore_factory: + raise click.UsageError( + "You cannot use this command on old-style datastore backend {db_module}" + ) + datastore = datastore_factory(**cfg) + ds_version = datastore.get_current_version() + if to_version is None: + to_version = ds_version + if to_version > ds_version: + raise click.UsageError( + f"The target version {to_version} is larger than the current version " + f"{ds_version} of the datastore backend {db_module}" + ) + + new_db_version = swh_db_upgrade(dbname, module, to_version) + click.secho(f"Migration to version {new_db_version} done", fg="green") + if new_db_version < ds_version: + click.secho( + f"Warning: migration was not complete: the current version is {ds_version}", + fg="yellow", + ) + + def get_dburl_from_config(cfg): if cfg.get("cls") != "postgresql": raise click.BadParameter( "Configuration cls must be set to 'postgresql' for this command." ) if "args" in cfg: # for bw compat cfg = cfg["args"] return cfg.get("db") diff --git a/swh/core/db/db_utils.py b/swh/core/db/db_utils.py index 2289f2e..74aeca5 100644 --- a/swh/core/db/db_utils.py +++ b/swh/core/db/db_utils.py @@ -1,552 +1,657 @@ # Copyright (C) 2015-2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from datetime import datetime, timezone import functools import glob from importlib import import_module import logging from os import path import re import subprocess from typing import Collection, Dict, List, Optional, Tuple, Union import psycopg2 import psycopg2.extensions from psycopg2.extensions import connection as pgconnection from psycopg2.extensions import encodings as pgencodings from psycopg2.extensions import make_dsn from psycopg2.extensions import parse_dsn as _parse_dsn from swh.core.utils import numfile_sortkey as sortkey logger = logging.getLogger(__name__) def now(): return datetime.now(tz=timezone.utc) def stored_procedure(stored_proc): """decorator to execute remote stored procedure, specified as argument Generally, the body of the decorated function should be empty. If it is not, the stored procedure will be executed first; the function body then. """ def wrap(meth): @functools.wraps(meth) def _meth(self, *args, **kwargs): cur = kwargs.get("cur", None) self._cursor(cur).execute("SELECT %s()" % stored_proc) meth(self, *args, **kwargs) return _meth return wrap def jsonize(value): """Convert a value to a psycopg2 JSON object if necessary""" if isinstance(value, dict): return psycopg2.extras.Json(value) return value def connect_to_conninfo(db_or_conninfo: Union[str, pgconnection]) -> pgconnection: """Connect to the database passed in argument Args: db_or_conninfo: A database connection, or a database connection info string Returns: a connected database handle Raises: psycopg2.Error if the database doesn't exist """ if isinstance(db_or_conninfo, pgconnection): return db_or_conninfo if "=" not in db_or_conninfo and "//" not in db_or_conninfo: # Database name db_or_conninfo = f"dbname={db_or_conninfo}" db = psycopg2.connect(db_or_conninfo) return db def swh_db_version(db_or_conninfo: Union[str, pgconnection]) -> Optional[int]: """Retrieve the swh version of the database. 
    If the database is not initialized, this logs a warning and returns None.

    Args:
        db_or_conninfo: A database connection, or a database connection info string

    Returns:
        Either the version of the database, or None if it couldn't be detected

    """
    try:
        db = connect_to_conninfo(db_or_conninfo)
    except psycopg2.Error:
        logger.exception("Failed to connect to `%s`", db_or_conninfo)
        # Database not initialized
        return None

    try:
        with db.cursor() as c:
            query = "select version from dbversion order by dbversion desc limit 1"
            try:
                c.execute(query)
                result = c.fetchone()
                if result:
                    return result[0]
            except psycopg2.errors.UndefinedTable:
                return None
    except Exception:
        logger.exception("Could not get version from `%s`", db_or_conninfo)
        return None


def swh_db_versions(
    db_or_conninfo: Union[str, pgconnection]
) -> Optional[List[Tuple[int, datetime, str]]]:
    """Retrieve the swh version history of the database.

    If the database is not initialized, this logs a warning and returns None.

    Args:
        db_or_conninfo: A database connection, or a database connection info string

    Returns:
        Either the version of the database, or None if it couldn't be detected

    """
    try:
        db = connect_to_conninfo(db_or_conninfo)
    except psycopg2.Error:
        logger.exception("Failed to connect to `%s`", db_or_conninfo)
        # Database not initialized
        return None

    try:
        with db.cursor() as c:
            query = (
                "select version, release, description "
                "from dbversion order by dbversion desc"
            )
            try:
                c.execute(query)
                return c.fetchall()
            except psycopg2.errors.UndefinedTable:
                return None
    except Exception:
        logger.exception("Could not get versions from `%s`", db_or_conninfo)
        return None


+def swh_db_upgrade(
+    conninfo: str, modname: str, to_version: Optional[int] = None
+) -> int:
+    """Upgrade the database at `conninfo` for module `modname`
+
+    This will run migration scripts found in the `sql/upgrades` subdirectory of
+    the module `modname`. By default, this will upgrade to the latest declared version.
+
+    Args:
+        conninfo: A database connection, or a database connection info string
+        modname: datastore module the database stores content for
+        to_version: if given, update the database to this version rather than the latest
+
+    """
+
+    if to_version is None:
+        to_version = 99999999
+
+    db_module, db_version, db_flavor = get_database_info(conninfo)
+    if db_version is None:
+        raise ValueError("Unable to retrieve the current version of the database")
+    if db_module is None:
+        raise ValueError("Unable to retrieve the module of the database")
+    if db_module != modname:
+        raise ValueError(
+            "The stored module of the database is different from the given one"
+        )
+
+    sqlfiles = [
+        fname
+        for fname in get_sql_for_package(modname, upgrade=True)
+        if db_version < int(path.splitext(path.basename(fname))[0]) <= to_version
+    ]
+
+    for sqlfile in sqlfiles:
+        new_version = int(path.splitext(path.basename(sqlfile))[0])
+        logger.info(f"Executing migration script {sqlfile}")
+        if db_version is not None and (new_version - db_version) > 1:
+            logger.error(
+                f"There are missing migration steps between {db_version} and "
+                f"{new_version}. It might be expected, but most likely it is not. "
+                "Will stop here."
+            )
+            return db_version
+
+        execute_sqlfiles([sqlfile], conninfo, db_flavor)
+
+        # check if the db version has been updated by the upgrade script
+        db_version = swh_db_version(conninfo)
+        assert db_version is not None
+        if db_version == new_version:
+            # nothing to do, upgrade script did the job
+            pass
+        elif db_version == new_version - 1:
+            # it has not (new style), so do it
+            swh_set_db_version(
+                conninfo,
+                new_version,
+                desc=f"Upgraded to version {new_version} using {sqlfile}",
+            )
+            db_version = swh_db_version(conninfo)
+        else:
+            # upgrade script did it wrong
+            logger.error(
+                f"The upgrade script {sqlfile} did not update the dbversion table "
+                f"consistently ({db_version} vs. expected {new_version}). "
+                "Will stop migration here. Please check your migration scripts."
+            )
+            return db_version
+    return new_version
+
+
def swh_db_module(db_or_conninfo: Union[str, pgconnection]) -> Optional[str]:
    """Retrieve the swh module used to create the database.

    If the database is not initialized, this logs a warning and returns None.

    Args:
        db_or_conninfo: A database connection, or a database connection info string

    Returns:
        Either the module of the database, or None if it couldn't be detected
    """
    try:
        db = connect_to_conninfo(db_or_conninfo)
    except psycopg2.Error:
        logger.exception("Failed to connect to `%s`", db_or_conninfo)
        # Database not initialized
        return None

    try:
        with db.cursor() as c:
            query = "select dbmodule from dbmodule limit 1"
            try:
                c.execute(query)
                resp = c.fetchone()
                if resp:
                    return resp[0]
            except psycopg2.errors.UndefinedTable:
                return None
    except Exception:
        logger.exception("Could not get module from `%s`", db_or_conninfo)
        return None


-def swh_set_db_module(db_or_conninfo: Union[str, pgconnection], module: str) -> None:
+def swh_set_db_module(
+    db_or_conninfo: Union[str, pgconnection], module: str, force=False
+) -> None:
    """Set the swh module used to create the database.

    Fails if the dbmodule is already set or the table does not exist.

    Args:
        db_or_conninfo: A database connection, or a database connection info string
        module: the swh module to register (without the leading 'swh.')
    """
+    update = False
    if module.startswith("swh."):
        module = module[4:]

+    current_module = swh_db_module(db_or_conninfo)
+    if current_module is not None:
+        if current_module == module:
+            logger.warning("The database module is already set to %s", module)
+            return
+
+        if not force:
+            raise ValueError(
+                f"The database module is already set to {current_module}, "
+                f"which is different from the given module {module}"
+            )
+        # force is True
+        update = True

    try:
        db = connect_to_conninfo(db_or_conninfo)
    except psycopg2.Error:
        logger.exception("Failed to connect to `%s`", db_or_conninfo)
        # Database not initialized
        return None

+    sqlfiles = [
+        fname for fname in get_sql_for_package("swh.core.db") if "dbmodule" in fname
+    ]
+    execute_sqlfiles(sqlfiles, db_or_conninfo)
+
    with db.cursor() as c:
-        query = "insert into dbmodule(dbmodule) values (%s)"
+        if update:
+            query = "update dbmodule set dbmodule = %s"
+        else:
+            query = "insert into dbmodule(dbmodule) values (%s)"
        c.execute(query, (module,))
        db.commit()


def swh_set_db_version(
    db_or_conninfo: Union[str, pgconnection],
    version: int,
    ts: Optional[datetime] = None,
    desc: str = "Work in progress",
) -> None:
    """Set the version of the database.

    Fails if the dbversion table does not exists.
Args: db_or_conninfo: A database connection, or a database connection info string version: the version to add """ try: db = connect_to_conninfo(db_or_conninfo) except psycopg2.Error: logger.exception("Failed to connect to `%s`", db_or_conninfo) # Database not initialized return None if ts is None: ts = now() with db.cursor() as c: query = ( "insert into dbversion(version, release, description) values (%s, %s, %s)" ) c.execute(query, (version, ts, desc)) db.commit() def swh_db_flavor(db_or_conninfo: Union[str, pgconnection]) -> Optional[str]: """Retrieve the swh flavor of the database. If the database is not initialized, or the database doesn't support flavors, this returns None. Args: db_or_conninfo: A database connection, or a database connection info string Returns: The flavor of the database, or None if it could not be detected. """ try: db = connect_to_conninfo(db_or_conninfo) except psycopg2.Error: logger.exception("Failed to connect to `%s`", db_or_conninfo) # Database not initialized return None try: with db.cursor() as c: query = "select swh_get_dbflavor()" try: c.execute(query) return c.fetchone()[0] except psycopg2.errors.UndefinedFunction: # function not found: no flavor return None except Exception: logger.exception("Could not get flavor from `%s`", db_or_conninfo) return None # The following code has been imported from psycopg2, version 2.7.4, # https://github.com/psycopg/psycopg2/tree/5afb2ce803debea9533e293eef73c92ffce95bcd # and modified by Software Heritage. # # Original file: lib/extras.py # # psycopg2 is free software: you can redistribute it and/or modify it under the # terms of the GNU Lesser General Public License as published by the Free # Software Foundation, either version 3 of the License, or (at your option) any # later version. def _paginate(seq, page_size): """Consume an iterable and return it in chunks. Every chunk is at most `page_size`. Never return an empty chunk. """ page = [] it = iter(seq) while 1: try: for i in range(page_size): page.append(next(it)) yield page page = [] except StopIteration: if page: yield page return def _split_sql(sql): """Split *sql* on a single ``%s`` placeholder. Split on the %s, perform %% replacement and return pre, post lists of snippets. """ curr = pre = [] post = [] tokens = re.split(br"(%.)", sql) for token in tokens: if len(token) != 2 or token[:1] != b"%": curr.append(token) continue if token[1:] == b"s": if curr is pre: curr = post else: raise ValueError("the query contains more than one '%s' placeholder") elif token[1:] == b"%": curr.append(b"%") else: raise ValueError( "unsupported format character: '%s'" % token[1:].decode("ascii", "replace") ) if curr is pre: raise ValueError("the query doesn't contain any '%s' placeholder") return pre, post def execute_values_generator(cur, sql, argslist, template=None, page_size=100): """Execute a statement using SQL ``VALUES`` with a sequence of parameters. Rows returned by the query are returned through a generator. You need to consume the generator for the queries to be executed! :param cur: the cursor to use to execute the query. :param sql: the query to execute. It must contain a single ``%s`` placeholder, which will be replaced by a `VALUES list`__. Example: ``"INSERT INTO mytable (id, f1, f2) VALUES %s"``. :param argslist: sequence of sequences or dictionaries with the arguments to send to the query. The type and content must be consistent with *template*. :param template: the snippet to merge to every item in *argslist* to compose the query. 
- If the *argslist* items are sequences it should contain positional placeholders (e.g. ``"(%s, %s, %s)"``, or ``"(%s, %s, 42)``" if there are constants value...). - If the *argslist* items are mappings it should contain named placeholders (e.g. ``"(%(id)s, %(f1)s, 42)"``). If not specified, assume the arguments are sequence and use a simple positional template (i.e. ``(%s, %s, ...)``), with the number of placeholders sniffed by the first element in *argslist*. :param page_size: maximum number of *argslist* items to include in every statement. If there are more items the function will execute more than one statement. :param yield_from_cur: Whether to yield results from the cursor in this function directly. .. __: https://www.postgresql.org/docs/current/static/queries-values.html After the execution of the function the `cursor.rowcount` property will **not** contain a total result. """ # we can't just use sql % vals because vals is bytes: if sql is bytes # there will be some decoding error because of stupid codec used, and Py3 # doesn't implement % on bytes. if not isinstance(sql, bytes): sql = sql.encode(pgencodings[cur.connection.encoding]) pre, post = _split_sql(sql) for page in _paginate(argslist, page_size=page_size): if template is None: template = b"(" + b",".join([b"%s"] * len(page[0])) + b")" parts = pre[:] for args in page: parts.append(cur.mogrify(template, args)) parts.append(b",") parts[-1:] = post cur.execute(b"".join(parts)) yield from cur def import_swhmodule(modname): if not modname.startswith("swh."): modname = f"swh.{modname}" try: m = import_module(modname) except ImportError as exc: logger.error(f"Could not load the {modname} module: {exc}") return None return m -def get_sql_for_package(modname): +def get_sql_for_package(modname: str, upgrade: bool = False) -> List[str]: + """Return the (sorted) list of sql script files for the given swh module + + If upgrade is True, return the list of available migration scripts, + otherwise, return the list of initialization scripts. + """ m = import_swhmodule(modname) if m is None: raise ValueError(f"Module {modname} cannot be loaded") sqldir = path.join(path.dirname(m.__file__), "sql") + if upgrade: + sqldir += "/upgrades" if not path.isdir(sqldir): raise ValueError( "Module {} does not provide a db schema " "(no sql/ dir)".format(modname) ) return sorted(glob.glob(path.join(sqldir, "*.sql")), key=sortkey) def populate_database_for_package( modname: str, conninfo: str, flavor: Optional[str] = None ) -> Tuple[bool, Optional[int], Optional[str]]: """Populate the database, pointed at with ``conninfo``, using the SQL files found in the package ``modname``. Also fill the 'dbmodule' table with the given ``modname``. Args: modname: Name of the module of which we're loading the files conninfo: connection info string for the SQL database flavor: the module-specific flavor which we want to initialize the database under Returns: Tuple with three elements: whether the database has been initialized; the current version of the database; if it exists, the flavor of the database. 
""" current_version = swh_db_version(conninfo) if current_version is not None: dbflavor = swh_db_flavor(conninfo) return False, current_version, dbflavor def globalsortkey(key): "like sortkey but only on basenames" return sortkey(path.basename(key)) sqlfiles = get_sql_for_package(modname) + get_sql_for_package("swh.core.db") sqlfiles = sorted(sqlfiles, key=globalsortkey) sqlfiles = [fname for fname in sqlfiles if "-superuser-" not in fname] execute_sqlfiles(sqlfiles, conninfo, flavor) # populate the dbmodule table swh_set_db_module(conninfo, modname) current_db_version = swh_db_version(conninfo) dbflavor = swh_db_flavor(conninfo) return True, current_db_version, dbflavor def get_database_info( conninfo: str, ) -> Tuple[Optional[str], Optional[int], Optional[str]]: """Get version, flavor and module of the db""" dbmodule = swh_db_module(conninfo) dbversion = swh_db_version(conninfo) dbflavor = None if dbversion is not None: dbflavor = swh_db_flavor(conninfo) return (dbmodule, dbversion, dbflavor) def parse_dsn_or_dbname(dsn_or_dbname: str) -> Dict[str, str]: """Parse a psycopg2 dsn, falling back to supporting plain database names as well""" try: return _parse_dsn(dsn_or_dbname) except psycopg2.ProgrammingError: # psycopg2 failed to parse the DSN; it's probably a database name, # handle it as such return _parse_dsn(f"dbname={dsn_or_dbname}") def init_admin_extensions(modname: str, conninfo: str) -> None: """The remaining initialization process -- running -superuser- SQL files -- is done using the given conninfo, thus connecting to the newly created database """ sqlfiles = get_sql_for_package(modname) sqlfiles = [fname for fname in sqlfiles if "-superuser-" in fname] execute_sqlfiles(sqlfiles, conninfo) def create_database_for_package( modname: str, conninfo: str, template: str = "template1" ): """Create the database pointed at with ``conninfo``, and initialize it using -superuser- SQL files found in the package ``modname``. Args: modname: Name of the module of which we're loading the files conninfo: connection info string or plain database name for the SQL database template: the name of the database to connect to and use as template to create the new database """ # Use the given conninfo string, but with dbname replaced by the template dbname # for the database creation step creation_dsn = parse_dsn_or_dbname(conninfo) dbname = creation_dsn["dbname"] creation_dsn["dbname"] = template logger.debug("db_create dbname=%s (from %s)", dbname, template) subprocess.check_call( [ "psql", "--quiet", "--no-psqlrc", "-v", "ON_ERROR_STOP=1", "-d", make_dsn(**creation_dsn), "-c", f'CREATE DATABASE "{dbname}"', ] ) init_admin_extensions(modname, conninfo) def execute_sqlfiles( sqlfiles: Collection[str], conninfo: str, flavor: Optional[str] = None ): """Execute a list of SQL files on the database pointed at with ``conninfo``. 
Args: sqlfiles: List of SQL files to execute conninfo: connection info string for the SQL database flavor: the database flavor to initialize """ psql_command = [ "psql", "--quiet", "--no-psqlrc", "-v", "ON_ERROR_STOP=1", "-d", conninfo, ] flavor_set = False for sqlfile in sqlfiles: logger.debug(f"execute SQL file {sqlfile} dbname={conninfo}") subprocess.check_call(psql_command + ["-f", sqlfile]) if flavor is not None and not flavor_set and sqlfile.endswith("-flavor.sql"): logger.debug("Setting database flavor %s", flavor) query = f"insert into dbflavor (flavor) values ('{flavor}')" subprocess.check_call(psql_command + ["-c", query]) flavor_set = True if flavor is not None and not flavor_set: logger.warn( "Asked for flavor %s, but module does not support database flavors", flavor, ) diff --git a/swh/core/db/sql/35-dbmetadata.sql b/swh/core/db/sql/35-dbversion.sql similarity index 61% copy from swh/core/db/sql/35-dbmetadata.sql copy to swh/core/db/sql/35-dbversion.sql index 3e49760..ee85ac7 100644 --- a/swh/core/db/sql/35-dbmetadata.sql +++ b/swh/core/db/sql/35-dbversion.sql @@ -1,28 +1,18 @@ -- common metadata/context structures -- -- we use a 35- prefix for this to make it executed after db schema initialisation -- sql scripts, which are normally 30- prefixed, so that it remains compatible -- with packages that have not yet migrated to swh.core 1.2 -- schema versions create table if not exists dbversion ( version int primary key, release timestamptz, description text ); comment on table dbversion is 'Details of current db version'; comment on column dbversion.version is 'SQL schema version'; comment on column dbversion.release is 'Version deployment timestamp'; comment on column dbversion.description is 'Release description'; - --- swh module this db is storing data for -create table if not exists dbmodule ( - dbmodule text, - single_row char(1) primary key default 'x', - check (single_row = 'x') -); -comment on table dbmodule is 'Database module storage'; -comment on column dbmodule.dbmodule is 'Database (swh) module currently deployed'; -comment on column dbmodule.single_row is 'Bogus column to force the table to have a single row'; diff --git a/swh/core/db/sql/35-dbmetadata.sql b/swh/core/db/sql/36-dbmodule.sql similarity index 57% rename from swh/core/db/sql/35-dbmetadata.sql rename to swh/core/db/sql/36-dbmodule.sql index 3e49760..ae9a670 100644 --- a/swh/core/db/sql/35-dbmetadata.sql +++ b/swh/core/db/sql/36-dbmodule.sql @@ -1,28 +1,15 @@ -- common metadata/context structures -- --- we use a 35- prefix for this to make it executed after db schema initialisation +-- we use a 3x- prefix for this to make it executed after db schema initialisation -- sql scripts, which are normally 30- prefixed, so that it remains compatible -- with packages that have not yet migrated to swh.core 1.2 --- schema versions -create table if not exists dbversion -( - version int primary key, - release timestamptz, - description text -); - -comment on table dbversion is 'Details of current db version'; -comment on column dbversion.version is 'SQL schema version'; -comment on column dbversion.release is 'Version deployment timestamp'; -comment on column dbversion.description is 'Release description'; - -- swh module this db is storing data for create table if not exists dbmodule ( dbmodule text, single_row char(1) primary key default 'x', check (single_row = 'x') ); comment on table dbmodule is 'Database module storage'; comment on column dbmodule.dbmodule is 'Database (swh) module currently deployed'; comment 
on column dbmodule.single_row is 'Bogus column to force the table to have a single row'; diff --git a/swh/core/db/tests/conftest.py b/swh/core/db/tests/conftest.py index d87f8d3..342bc90 100644 --- a/swh/core/db/tests/conftest.py +++ b/swh/core/db/tests/conftest.py @@ -1,57 +1,59 @@ # Copyright (C) 2019-2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import glob import os from click.testing import CliRunner from hypothesis import HealthCheck import pytest from swh.core.db.db_utils import get_sql_for_package from swh.core.utils import numfile_sortkey as sortkey os.environ["LC_ALL"] = "C.UTF-8" # we use getattr here to keep mypy happy regardless hypothesis version function_scoped_fixture_check = ( [getattr(HealthCheck, "function_scoped_fixture")] if hasattr(HealthCheck, "function_scoped_fixture") else [] ) @pytest.fixture def cli_runner(): return CliRunner() @pytest.fixture() def mock_package_sql(mocker, datadir): """This bypasses the module manipulation to only returns the data test files. For a given module `test.mod`, look for sql files in the directory `data/mod/*.sql`. Typical usage:: def test_xxx(cli_runner, mock_package_sql): conninfo = craft_conninfo(test_db, "new-db") module_name = "test.cli" # the command below will use sql scripts from swh/core/db/tests/data/cli/*.sql cli_runner.invoke(swhdb, ["init", module_name, "--dbname", conninfo]) """ - def get_sql_for_package_mock(modname): + def get_sql_for_package_mock(modname, upgrade=False): if modname.startswith("test."): sqldir = modname.split(".", 1)[1] + if upgrade: + sqldir += "/upgrades" return sorted( glob.glob(os.path.join(datadir, sqldir, "*.sql")), key=sortkey ) return get_sql_for_package(modname) mock_sql_files = mocker.patch( "swh.core.db.db_utils.get_sql_for_package", get_sql_for_package_mock ) return mock_sql_files diff --git a/swh/core/db/tests/data/cli_new/upgrades/001.sql b/swh/core/db/tests/data/cli_new/upgrades/001.sql new file mode 100644 index 0000000..d914414 --- /dev/null +++ b/swh/core/db/tests/data/cli_new/upgrades/001.sql @@ -0,0 +1,5 @@ +-- this script should never be executed by an upgrade procedure (because +-- version 1 is set by 'swh db init') + +insert into origin(url, hash) +values ('this should never be executed', hash_sha1('')); diff --git a/swh/core/db/tests/data/cli_new/upgrades/002.sql b/swh/core/db/tests/data/cli_new/upgrades/002.sql new file mode 100644 index 0000000..5f12b9e --- /dev/null +++ b/swh/core/db/tests/data/cli_new/upgrades/002.sql @@ -0,0 +1,4 @@ +-- + +insert into origin(url, hash) +values ('version002', hash_sha1('version002')); diff --git a/swh/core/db/tests/data/cli_new/upgrades/003.sql b/swh/core/db/tests/data/cli_new/upgrades/003.sql new file mode 100644 index 0000000..87ac9e1 --- /dev/null +++ b/swh/core/db/tests/data/cli_new/upgrades/003.sql @@ -0,0 +1,4 @@ +-- + +insert into origin(url, hash) +values ('version003', hash_sha1('version003')); diff --git a/swh/core/db/tests/data/cli_new/upgrades/004.sql b/swh/core/db/tests/data/cli_new/upgrades/004.sql new file mode 100644 index 0000000..d1f03da --- /dev/null +++ b/swh/core/db/tests/data/cli_new/upgrades/004.sql @@ -0,0 +1,4 @@ +-- + +insert into origin(url, hash) +values ('version004', hash_sha1('version004')); diff --git a/swh/core/db/tests/data/cli_new/upgrades/005.sql b/swh/core/db/tests/data/cli_new/upgrades/005.sql new file mode 100644 index 
0000000..8d0db9e --- /dev/null +++ b/swh/core/db/tests/data/cli_new/upgrades/005.sql @@ -0,0 +1,4 @@ +-- + +insert into origin(url, hash) +values ('version005', hash_sha1('version005')); diff --git a/swh/core/db/tests/data/cli_new/upgrades/006.sql b/swh/core/db/tests/data/cli_new/upgrades/006.sql new file mode 100644 index 0000000..115b59f --- /dev/null +++ b/swh/core/db/tests/data/cli_new/upgrades/006.sql @@ -0,0 +1,7 @@ +-- + +insert into origin(url, hash) +values ('version006', hash_sha1('version006')); + +insert into dbversion(version, release, description) +values (6, 'NOW()', 'Updated version from upgrade script'); diff --git a/swh/core/db/tests/test_cli.py b/swh/core/db/tests/test_cli.py index 35663f9..bbea68c 100644 --- a/swh/core/db/tests/test_cli.py +++ b/swh/core/db/tests/test_cli.py @@ -1,258 +1,356 @@ # Copyright (C) 2019-2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import copy +import traceback from unittest.mock import MagicMock import pytest import yaml from swh.core.cli.db import db as swhdb from swh.core.db import BaseDb -from swh.core.db.db_utils import import_swhmodule +from swh.core.db.db_utils import import_swhmodule, swh_db_module from swh.core.tests.test_cli import assert_section_contains def test_cli_swh_help(swhmain, cli_runner): swhmain.add_command(swhdb) result = cli_runner.invoke(swhmain, ["-h"]) assert result.exit_code == 0 assert_section_contains( result.output, "Commands", "db Software Heritage database generic tools." ) help_db_snippets = ( ( "Usage", ( "Usage: swh db [OPTIONS] COMMAND [ARGS]...", "Software Heritage database generic tools.", ), ), ( "Commands", ( "create Create a database for the Software Heritage .", "init Initialize a database for the Software Heritage .", "init-admin Execute superuser-level initialization steps", ), ), ) def test_cli_swh_db_help(swhmain, cli_runner): swhmain.add_command(swhdb) result = cli_runner.invoke(swhmain, ["db", "-h"]) assert result.exit_code == 0 for section, snippets in help_db_snippets: for snippet in snippets: assert_section_contains(result.output, section, snippet) @pytest.fixture def swh_db_cli(cli_runner, monkeypatch, postgresql): """This initializes a cli_runner and sets the correct environment variable expected by the cli to run appropriately (when not specifying the --dbname flag) """ db_params = postgresql.get_dsn_parameters() monkeypatch.setenv("PGHOST", db_params["host"]) monkeypatch.setenv("PGUSER", db_params["user"]) monkeypatch.setenv("PGPORT", db_params["port"]) return cli_runner, db_params def craft_conninfo(test_db, dbname=None) -> str: """Craft conninfo string out of the test_db object. 
This also allows to override the dbname.""" db_params = test_db.get_dsn_parameters() if dbname: params = copy.deepcopy(db_params) params["dbname"] = dbname else: params = db_params return "postgresql://{user}@{host}:{port}/{dbname}".format(**params) def test_cli_swh_db_create_and_init_db(cli_runner, postgresql, mock_package_sql): """Create a db then initializing it should be ok """ module_name = "test.cli" conninfo = craft_conninfo(postgresql, "new-db") # This creates the db and installs the necessary admin extensions result = cli_runner.invoke(swhdb, ["create", module_name, "--dbname", conninfo]) assert result.exit_code == 0, f"Unexpected output: {result.output}" # This initializes the schema and data result = cli_runner.invoke(swhdb, ["init", module_name, "--dbname", conninfo]) assert result.exit_code == 0, f"Unexpected output: {result.output}" # the origin value in the scripts uses a hash function (which implementation wise # uses a function from the pgcrypt extension, installed during db creation step) with BaseDb.connect(conninfo).cursor() as cur: cur.execute("select * from origin") origins = cur.fetchall() assert len(origins) == 1 def test_cli_swh_db_initialization_fail_without_creation_first( cli_runner, postgresql, mock_package_sql ): """Init command on an inexisting db cannot work """ module_name = "test.cli" # it's mocked here conninfo = craft_conninfo(postgresql, "inexisting-db") result = cli_runner.invoke(swhdb, ["init", module_name, "--dbname", conninfo]) # Fails because we cannot connect to an inexisting db assert result.exit_code == 1, f"Unexpected output: {result.output}" def test_cli_swh_db_initialization_fail_without_extension( cli_runner, postgresql, mock_package_sql ): """Init command cannot work without privileged extension. In this test, the schema needs privileged extension to work. 
""" module_name = "test.cli" # it's mocked here conninfo = craft_conninfo(postgresql) result = cli_runner.invoke(swhdb, ["init", module_name, "--dbname", conninfo]) # Fails as the function `public.digest` is not installed, init-admin calls is needed # first (the next tests show such behavior) assert result.exit_code == 1, f"Unexpected output: {result.output}" def test_cli_swh_db_initialization_works_with_flags( cli_runner, postgresql, mock_package_sql ): """Init commands with carefully crafted libpq conninfo works """ module_name = "test.cli" # it's mocked here conninfo = craft_conninfo(postgresql) result = cli_runner.invoke(swhdb, ["init-admin", module_name, "--dbname", conninfo]) assert result.exit_code == 0, f"Unexpected output: {result.output}" result = cli_runner.invoke(swhdb, ["init", module_name, "--dbname", conninfo]) assert result.exit_code == 0, f"Unexpected output: {result.output}" # the origin values in the scripts uses a hash function (which implementation wise # uses a function from the pgcrypt extension, init-admin calls installs it) with BaseDb.connect(postgresql.dsn).cursor() as cur: cur.execute("select * from origin") origins = cur.fetchall() assert len(origins) == 1 def test_cli_swh_db_initialization_with_env(swh_db_cli, mock_package_sql, postgresql): """Init commands with standard environment variables works """ module_name = "test.cli" # it's mocked here cli_runner, db_params = swh_db_cli result = cli_runner.invoke( swhdb, ["init-admin", module_name, "--dbname", db_params["dbname"]] ) assert result.exit_code == 0, f"Unexpected output: {result.output}" result = cli_runner.invoke( swhdb, ["init", module_name, "--dbname", db_params["dbname"]] ) assert result.exit_code == 0, f"Unexpected output: {result.output}" # the origin values in the scripts uses a hash function (which implementation wise # uses a function from the pgcrypt extension, init-admin calls installs it) with BaseDb.connect(postgresql.dsn).cursor() as cur: cur.execute("select * from origin") origins = cur.fetchall() assert len(origins) == 1 def test_cli_swh_db_initialization_idempotent(swh_db_cli, mock_package_sql, postgresql): """Multiple runs of the init commands are idempotent """ module_name = "test.cli" # mocked cli_runner, db_params = swh_db_cli result = cli_runner.invoke( swhdb, ["init-admin", module_name, "--dbname", db_params["dbname"]] ) assert result.exit_code == 0, f"Unexpected output: {result.output}" result = cli_runner.invoke( swhdb, ["init", module_name, "--dbname", db_params["dbname"]] ) assert result.exit_code == 0, f"Unexpected output: {result.output}" result = cli_runner.invoke( swhdb, ["init-admin", module_name, "--dbname", db_params["dbname"]] ) assert result.exit_code == 0, f"Unexpected output: {result.output}" result = cli_runner.invoke( swhdb, ["init", module_name, "--dbname", db_params["dbname"]] ) assert result.exit_code == 0, f"Unexpected output: {result.output}" # the origin values in the scripts uses a hash function (which implementation wise # uses a function from the pgcrypt extension, init-admin calls installs it) with BaseDb.connect(postgresql.dsn).cursor() as cur: cur.execute("select * from origin") origins = cur.fetchall() assert len(origins) == 1 def test_cli_swh_db_create_and_init_db_new_api( cli_runner, postgresql, mock_package_sql, mocker, tmp_path ): """Create a db then initializing it should be ok for a "new style" datastore """ module_name = "test.cli_new" def import_swhmodule_mock(modname): if modname.startswith("test."): def get_datastore(cls, **kw): # XXX probably 
not the best way of doing this... return MagicMock(get_current_version=lambda: 42) return MagicMock(name=modname, get_datastore=get_datastore) return import_swhmodule(modname) mocker.patch("swh.core.db.db_utils.import_swhmodule", import_swhmodule_mock) conninfo = craft_conninfo(postgresql) # This initializes the schema and data cfgfile = tmp_path / "config.yml" cfgfile.write_text(yaml.dump({module_name: {"cls": "postgresql", "db": conninfo}})) result = cli_runner.invoke(swhdb, ["init-admin", module_name, "--dbname", conninfo]) assert result.exit_code == 0, f"Unexpected output: {result.output}" result = cli_runner.invoke(swhdb, ["-C", cfgfile, "init", module_name]) - import traceback - assert ( result.exit_code == 0 ), f"Unexpected output: {traceback.print_tb(result.exc_info[2])}" # the origin value in the scripts uses a hash function (which implementation wise # uses a function from the pgcrypt extension, installed during db creation step) with BaseDb.connect(conninfo).cursor() as cur: cur.execute("select * from origin") origins = cur.fetchall() assert len(origins) == 1 + + +def test_cli_swh_db_upgrade_new_api( + cli_runner, postgresql, mock_package_sql, mocker, tmp_path +): + """Upgrade scenario for a "new style" datastore + + """ + module_name = "test.cli_new" + + from unittest.mock import MagicMock + + from swh.core.db.db_utils import import_swhmodule, swh_db_version + + # the `current_version` variable is the version that will be returned by + # any call to `get_current_version()` in this test session, thanks to the + # local mocked version of import_swhmodule() below. + current_version = 1 + + def import_swhmodule_mock(modname): + if modname.startswith("test."): + + def get_datastore(cls, **kw): + # XXX probably not the best way of doing this... 
+ return MagicMock(get_current_version=lambda: current_version) + + return MagicMock(name=modname, get_datastore=get_datastore) + + return import_swhmodule(modname) + + mocker.patch("swh.core.db.db_utils.import_swhmodule", import_swhmodule_mock) + conninfo = craft_conninfo(postgresql) + + # This initializes the schema and data + cfgfile = tmp_path / "config.yml" + cfgfile.write_text(yaml.dump({module_name: {"cls": "postgresql", "db": conninfo}})) + result = cli_runner.invoke(swhdb, ["init-admin", module_name, "--dbname", conninfo]) + assert result.exit_code == 0, f"Unexpected output: {result.output}" + result = cli_runner.invoke(swhdb, ["-C", cfgfile, "init", module_name]) + + assert ( + result.exit_code == 0 + ), f"Unexpected output: {traceback.print_tb(result.exc_info[2])}" + + assert swh_db_version(conninfo) == 1 + + # the upgrade should not do anything because the datastore does advertise + # version 1 + result = cli_runner.invoke(swhdb, ["-C", cfgfile, "upgrade", module_name]) + assert swh_db_version(conninfo) == 1 + + # advertise current version as 3, a simple upgrade should get us there, but + # no further + current_version = 3 + result = cli_runner.invoke(swhdb, ["-C", cfgfile, "upgrade", module_name]) + assert swh_db_version(conninfo) == 3 + + # an attempt to go further should not do anything + result = cli_runner.invoke( + swhdb, ["-C", cfgfile, "upgrade", module_name, "--to-version", 5] + ) + assert swh_db_version(conninfo) == 3 + # an attempt to go lower should not do anything + result = cli_runner.invoke( + swhdb, ["-C", cfgfile, "upgrade", module_name, "--to-version", 2] + ) + assert swh_db_version(conninfo) == 3 + + # advertise current version as 6, an upgrade with --to-version 4 should + # stick to the given version 4 and no further + current_version = 6 + result = cli_runner.invoke( + swhdb, ["-C", cfgfile, "upgrade", module_name, "--to-version", 4] + ) + assert swh_db_version(conninfo) == 4 + assert "migration was not complete" in result.output + + # attempt to upgrade to a newer version than current code version fails + result = cli_runner.invoke( + swhdb, + ["-C", cfgfile, "upgrade", module_name, "--to-version", current_version + 1], + ) + assert result.exit_code != 0 + assert swh_db_version(conninfo) == 4 + + cnx = BaseDb.connect(conninfo) + with cnx.transaction() as cur: + cur.execute("drop table dbmodule") + assert swh_db_module(conninfo) is None + + # db migration should recreate the missing dbmodule table + result = cli_runner.invoke(swhdb, ["-C", cfgfile, "upgrade", module_name]) + assert result.exit_code == 0 + assert "Warning: the database does not have a dbmodule table." in result.output + assert ( + "Write the module information (test.cli_new) in the database? 
[Y/n]" + in result.output + ) + assert swh_db_module(conninfo) == module_name diff --git a/swh/core/db/tests/test_db_utils.py b/swh/core/db/tests/test_db_utils.py index ed033e5..4f0600e 100644 --- a/swh/core/db/tests/test_db_utils.py +++ b/swh/core/db/tests/test_db_utils.py @@ -1,78 +1,174 @@ # Copyright (C) 2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from datetime import datetime, timedelta +from os import path import pytest from swh.core.cli.db import db as swhdb from swh.core.db import BaseDb -from swh.core.db.db_utils import get_database_info, now, swh_db_module, swh_db_versions +from swh.core.db.db_utils import ( + get_database_info, + now, + swh_db_module, + swh_db_upgrade, + swh_db_version, + swh_db_versions, + swh_set_db_module, +) from .test_cli import craft_conninfo @pytest.mark.parametrize("module", ["test.cli", "test.cli_new"]) def test_db_utils_versions(cli_runner, postgresql, mock_package_sql, module): """Check get_database_info, swh_db_versions and swh_db_module work ok This test checks db versions for both a db with "new style" set of sql init scripts (i.e. the dbversion table is not created in these scripts, but by the populate_database_for_package() function directly, via the 'swh db init' command) and an "old style" set (dbversion created in the scripts)S. """ conninfo = craft_conninfo(postgresql) result = cli_runner.invoke(swhdb, ["init-admin", module, "--dbname", conninfo]) assert result.exit_code == 0, f"Unexpected output: {result.output}" result = cli_runner.invoke( swhdb, ["init", module, "--dbname", conninfo, "--initial-version", 10] ) assert result.exit_code == 0, f"Unexpected output: {result.output}" # check the swh_db_module() function assert swh_db_module(conninfo) == module # the dbversion and dbmodule tables exists and are populated dbmodule, dbversion, dbflavor = get_database_info(conninfo) # check also the swh_db_versions() function versions = swh_db_versions(conninfo) assert dbmodule == module assert dbversion == 10 assert dbflavor is None # check also the swh_db_versions() function versions = swh_db_versions(conninfo) assert len(versions) == 1 assert versions[0][0] == 10 if module == "test.cli": assert versions[0][1] == datetime.fromisoformat( "2016-02-22T15:56:28.358587+00:00" ) assert versions[0][2] == "Work In Progress" else: # new scheme but with no datastore (so no version support from there) assert versions[0][2] == "DB initialization" # add a few versions in dbversion cnx = BaseDb.connect(conninfo) with cnx.transaction() as cur: cur.executemany( "insert into dbversion(version, release, description) values (%s, %s, %s)", [(i, now(), f"Upgrade to version {i}") for i in range(11, 15)], ) dbmodule, dbversion, dbflavor = get_database_info(conninfo) assert dbmodule == module assert dbversion == 14 assert dbflavor is None versions = swh_db_versions(conninfo) assert len(versions) == 5 for i, (version, ts, desc) in enumerate(versions): assert version == (14 - i) # these are in reverse order if version > 10: assert desc == f"Upgrade to version {version}" assert (now() - ts) < timedelta(seconds=1) + + +@pytest.mark.parametrize("module", ["test.cli_new"]) +def test_db_utils_upgrade(cli_runner, postgresql, mock_package_sql, module, datadir): + """Check swh_db_upgrade + + """ + conninfo = craft_conninfo(postgresql) + result = cli_runner.invoke(swhdb, ["init-admin", module, 
"--dbname", conninfo]) + assert result.exit_code == 0, f"Unexpected output: {result.output}" + result = cli_runner.invoke(swhdb, ["init", module, "--dbname", conninfo]) + assert result.exit_code == 0, f"Unexpected output: {result.output}" + + assert swh_db_version(conninfo) == 1 + new_version = swh_db_upgrade(conninfo, module) + assert new_version == 6 + assert swh_db_version(conninfo) == 6 + + versions = swh_db_versions(conninfo) + # get rid of dates to ease checking + versions = [(v[0], v[2]) for v in versions] + assert versions[-1] == (1, "DB initialization") + sqlbasedir = path.join(datadir, module.split(".", 1)[1], "upgrades") + + assert versions[1:-1] == [ + (i, f"Upgraded to version {i} using {sqlbasedir}/{i:03d}.sql") + for i in range(5, 1, -1) + ] + assert versions[0] == (6, "Updated version from upgrade script") + + cnx = BaseDb.connect(conninfo) + with cnx.transaction() as cur: + cur.execute("select url from origin where url like 'version%'") + result = cur.fetchall() + assert result == [("version%03d" % i,) for i in range(2, 7)] + cur.execute( + "select url from origin where url = 'this should never be executed'" + ) + result = cur.fetchall() + assert not result + + +@pytest.mark.parametrize("module", ["test.cli_new"]) +def test_db_utils_swh_db_upgrade_sanity_checks( + cli_runner, postgresql, mock_package_sql, module, datadir +): + """Check swh_db_upgrade + + """ + conninfo = craft_conninfo(postgresql) + result = cli_runner.invoke(swhdb, ["init-admin", module, "--dbname", conninfo]) + assert result.exit_code == 0, f"Unexpected output: {result.output}" + result = cli_runner.invoke(swhdb, ["init", module, "--dbname", conninfo]) + assert result.exit_code == 0, f"Unexpected output: {result.output}" + + cnx = BaseDb.connect(conninfo) + with cnx.transaction() as cur: + cur.execute("drop table dbmodule") + + # try to upgrade with a unset module + with pytest.raises(ValueError): + swh_db_upgrade(conninfo, module) + + # check the dbmodule is unset + assert swh_db_module(conninfo) is None + + # set the stored module to something else + swh_set_db_module(conninfo, f"{module}2") + assert swh_db_module(conninfo) == f"{module}2" + + # try to upgrade with a different module + with pytest.raises(ValueError): + swh_db_upgrade(conninfo, module) + + # revert to the proper module in the db + swh_set_db_module(conninfo, module, force=True) + assert swh_db_module(conninfo) == module + # trying again is a noop + swh_set_db_module(conninfo, module) + assert swh_db_module(conninfo) == module + + # drop the dbversion table + with cnx.transaction() as cur: + cur.execute("drop table dbversion") + # an upgrade should fail due to missing stored version + with pytest.raises(ValueError): + swh_db_upgrade(conninfo, module)