diff --git a/swh/core/cli/db.py b/swh/core/cli/db.py index 51c3d81..37a1d35 100755 --- a/swh/core/cli/db.py +++ b/swh/core/cli/db.py @@ -1,244 +1,278 @@ #!/usr/bin/env python3 # Copyright (C) 2018-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import logging from os import environ, path -from typing import Collection, Tuple +from typing import Collection, Optional, Tuple import warnings import click from swh.core.cli import CONTEXT_SETTINGS from swh.core.cli import swh as swh_cli_group warnings.filterwarnings("ignore") # noqa prevent psycopg from telling us sh*t logger = logging.getLogger(__name__) @swh_cli_group.group(name="db", context_settings=CONTEXT_SETTINGS) @click.option( "--config-file", "-C", default=None, type=click.Path(exists=True, dir_okay=False), help="Configuration file.", ) @click.pass_context def db(ctx, config_file): """Software Heritage database generic tools.""" from swh.core.config import read as config_read ctx.ensure_object(dict) if config_file is None: config_file = environ.get("SWH_CONFIG_FILENAME") cfg = config_read(config_file) ctx.obj["config"] = cfg @db.command(name="create", context_settings=CONTEXT_SETTINGS) @click.argument("module", required=True) @click.option( "--db-name", "-d", help="Database name.", default="softwareheritage-dev", show_default=True, ) @click.option( "--template", "-T", help="Template database from which to build this database.", default="template1", show_default=True, ) def db_create(module, db_name, template): """Create a database for the Software Heritage . and potentially execute superuser-level initialization steps. Example: swh db create -d swh-test storage If you want to specify non-default postgresql connection parameters, please provide them using standard environment variables or by the mean of a properly crafted libpq connection URI. See psql(1) man page (section ENVIRONMENTS) for details. Note: this command requires a postgresql connection with superuser permissions. Example: PGPORT=5434 swh db create indexer swh db create -d postgresql://superuser:passwd@pghost:5433/swh-storage storage """ logger.debug("db_create %s dn_name=%s", module, db_name) create_database_for_package(module, db_name, template) @db.command(name="init", context_settings=CONTEXT_SETTINGS) @click.argument("module", required=True) @click.option( "--db-name", "-d", help="Database name.", default="softwareheritage-dev", show_default=True, ) -def db_init(module, db_name): +@click.option( + "--flavor", help="Database flavor.", default=None, +) +def db_init(module, db_name, flavor): """Initialize a database for the Software Heritage . Example: swh db init -d swh-test storage If you want to specify non-default postgresql connection parameters, please provide them using standard environment variables. See psql(1) man page (section ENVIRONMENTS) for details. - Example: + Examples: - PGPORT=5434 swh db-init indexer + PGPORT=5434 swh db init indexer swh db init -d postgresql://user:passwd@pghost:5433/swh-storage storage + swh db init --flavor read_replica -d swh-storage storage """ - logger.debug("db_init %s dn_name=%s", module, db_name) + logger.debug("db_init %s flavor=%s dn_name=%s", module, flavor, db_name) - initialized, dbversion = populate_database_for_package(module, db_name) + initialized, dbversion, dbflavor = populate_database_for_package( + module, db_name, flavor + ) # TODO: Ideally migrate the version from db_version to the latest # db version click.secho( - "DONE database for {} {} at version {}".format( - module, "initialized" if initialized else "exists", dbversion + "DONE database for {} {}{} at version {}".format( + module, + "initialized" if initialized else "exists", + f" (flavor {dbflavor})" if dbflavor is not None else "", + dbversion, ), fg="green", bold=True, ) + if flavor is not None and dbflavor != flavor: + click.secho( + f"WARNING requested flavor '{flavor}' != recorded flavor '{dbflavor}'", + fg="red", + bold=True, + ) + def get_sql_for_package(modname): import glob from importlib import import_module from swh.core.utils import numfile_sortkey as sortkey if not modname.startswith("swh."): modname = "swh.{}".format(modname) try: m = import_module(modname) except ImportError: raise click.BadParameter("Unable to load module {}".format(modname)) sqldir = path.join(path.dirname(m.__file__), "sql") if not path.isdir(sqldir): raise click.BadParameter( "Module {} does not provide a db schema " "(no sql/ dir)".format(modname) ) return sorted(glob.glob(path.join(sqldir, "*.sql")), key=sortkey) -def populate_database_for_package(modname: str, conninfo: str) -> Tuple[bool, int]: +def populate_database_for_package( + modname: str, conninfo: str, flavor: Optional[str] = None +) -> Tuple[bool, int, Optional[str]]: """Populate the database, pointed at with `conninfo`, using the SQL files found in the package `modname`. Args: modname: Name of the module of which we're loading the files conninfo: connection info string for the SQL database + flavor: the module-specific flavor which we want to initialize the database under Returns: - Tuple with two elements: whether the database has been initialized; the current - version of the database. + Tuple with three elements: whether the database has been initialized; the current + version of the database; if it exists, the flavor of the database. """ - from swh.core.db.db_utils import swh_db_version + from swh.core.db.db_utils import swh_db_flavor, swh_db_version current_version = swh_db_version(conninfo) if current_version is not None: - return False, current_version + dbflavor = swh_db_flavor(conninfo) + return False, current_version, dbflavor sqlfiles = get_sql_for_package(modname) sqlfiles = [fname for fname in sqlfiles if "-superuser-" not in fname] - execute_sqlfiles(sqlfiles, conninfo) + execute_sqlfiles(sqlfiles, conninfo, flavor) current_version = swh_db_version(conninfo) assert current_version is not None - return True, current_version + dbflavor = swh_db_flavor(conninfo) + return True, current_version, dbflavor def create_database_for_package( modname: str, conninfo: str, template: str = "template1" ): """Create the database pointed at with `conninfo`, and initialize it using -superuser- SQL files found in the package `modname`. Args: modname: Name of the module of which we're loading the files conninfo: connection info string for the SQL database template: the name of the database to connect to and use as template to create the new database """ import subprocess from psycopg2.extensions import make_dsn, parse_dsn # Use the given conninfo but with dbname replaced by the template dbname # for the database creation step creation_dsn = parse_dsn(conninfo) db_name = creation_dsn["dbname"] creation_dsn["dbname"] = template logger.debug("db_create db_name=%s (from %s)", db_name, template) subprocess.check_call( [ "psql", "--quiet", "--no-psqlrc", "-v", "ON_ERROR_STOP=1", "-d", make_dsn(**creation_dsn), "-c", f"CREATE DATABASE {db_name}", ] ) # the remaining initialization process -- running -superuser- SQL files -- # is done using the given conninfo, thus connecting to the newly created # database sqlfiles = get_sql_for_package(modname) sqlfiles = [fname for fname in sqlfiles if "-superuser-" in fname] execute_sqlfiles(sqlfiles, conninfo) -def execute_sqlfiles(sqlfiles: Collection[str], conninfo: str): +def execute_sqlfiles( + sqlfiles: Collection[str], conninfo: str, flavor: Optional[str] = None +): """Execute a list of SQL files on the database pointed at with `conninfo`. Args: sqlfiles: List of SQL files to execute conninfo: connection info string for the SQL database + flavor: the database flavor to initialize """ import subprocess + psql_command = [ + "psql", + "--quiet", + "--no-psqlrc", + "-v", + "ON_ERROR_STOP=1", + "-d", + conninfo, + ] + + flavor_set = False for sqlfile in sqlfiles: logger.debug(f"execute SQL file {sqlfile} db_name={conninfo}") - subprocess.check_call( - [ - "psql", - "--quiet", - "--no-psqlrc", - "-v", - "ON_ERROR_STOP=1", - "-d", - conninfo, - "-f", - sqlfile, - ] + subprocess.check_call(psql_command + ["-f", sqlfile]) + + if flavor is not None and not flavor_set and sqlfile.endswith("-flavor.sql"): + logger.debug("Setting database flavor %s", flavor) + query = f"insert into dbflavor (flavor) values ('{flavor}')" + subprocess.check_call(psql_command + ["-c", query]) + flavor_set = True + + if flavor is not None and not flavor_set: + logger.warn( + "Asked for flavor %s, but module does not support database flavors", flavor, ) diff --git a/swh/core/db/db_utils.py b/swh/core/db/db_utils.py index 5b7e64d..77ddb59 100644 --- a/swh/core/db/db_utils.py +++ b/swh/core/db/db_utils.py @@ -1,198 +1,252 @@ # Copyright (C) 2015-2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import functools import logging import re from typing import Optional, Union import psycopg2 import psycopg2.extensions logger = logging.getLogger(__name__) def stored_procedure(stored_proc): """decorator to execute remote stored procedure, specified as argument Generally, the body of the decorated function should be empty. If it is not, the stored procedure will be executed first; the function body then. """ def wrap(meth): @functools.wraps(meth) def _meth(self, *args, **kwargs): cur = kwargs.get("cur", None) self._cursor(cur).execute("SELECT %s()" % stored_proc) meth(self, *args, **kwargs) return _meth return wrap def jsonize(value): """Convert a value to a psycopg2 JSON object if necessary""" if isinstance(value, dict): return psycopg2.extras.Json(value) return value -def swh_db_version( +def connect_to_conninfo( db_or_conninfo: Union[str, psycopg2.extensions.connection] -) -> Optional[int]: - """Retrieve the swh version if any. In case of the db not initialized, - this returns None. Otherwise, this returns the db's version. +) -> psycopg2.extensions.connection: + """Connect to the database passed in argument Args: db_or_conninfo: A database connection, or a database connection info string Returns: - Optional[Int]: Either the db's version or None + a connected database handle + Raises: + psycopg2.Error if the database doesn't exist """ - if isinstance(db_or_conninfo, psycopg2.extensions.connection): - db = db_or_conninfo - else: - try: - if "=" not in db_or_conninfo: - # Database name - db_or_conninfo = f"dbname={db_or_conninfo}" - db = psycopg2.connect(db_or_conninfo) - except psycopg2.Error: - logger.exception("Failed to connect to `%s`", db_or_conninfo) - # Database not initialized - return None + return db_or_conninfo + + if "=" not in db_or_conninfo and "//" not in db_or_conninfo: + # Database name + db_or_conninfo = f"dbname={db_or_conninfo}" + + db = psycopg2.connect(db_or_conninfo) + + return db + + +def swh_db_version( + db_or_conninfo: Union[str, psycopg2.extensions.connection] +) -> Optional[int]: + """Retrieve the swh version of the database. + + If the database is not initialized, this logs a warning and returns None. + + Args: + db_or_conninfo: A database connection, or a database connection info string + + Returns: + Either the version of the database, or None if it couldn't be detected + """ + try: + db = connect_to_conninfo(db_or_conninfo) + except psycopg2.Error: + logger.exception("Failed to connect to `%s`", db_or_conninfo) + # Database not initialized + return None try: with db.cursor() as c: query = "select version from dbversion order by dbversion desc limit 1" try: c.execute(query) return c.fetchone()[0] except psycopg2.errors.UndefinedTable: return None except Exception: logger.exception("Could not get version from `%s`", db_or_conninfo) return None +def swh_db_flavor( + db_or_conninfo: Union[str, psycopg2.extensions.connection] +) -> Optional[str]: + """Retrieve the swh flavor of the database. + + If the database is not initialized, or the database doesn't support + flavors, this returns None. + + Args: + db_or_conninfo: A database connection, or a database connection info string + + Returns: + The flavor of the database, or None if it could not be detected. + """ + try: + db = connect_to_conninfo(db_or_conninfo) + except psycopg2.Error: + logger.exception("Failed to connect to `%s`", db_or_conninfo) + # Database not initialized + return None + + try: + with db.cursor() as c: + query = "select swh_get_dbflavor()" + try: + c.execute(query) + return c.fetchone()[0] + except psycopg2.errors.UndefinedFunction: + # function not found: no flavor + return None + except Exception: + logger.exception("Could not get flavor from `%s`", db_or_conninfo) + return None + + # The following code has been imported from psycopg2, version 2.7.4, # https://github.com/psycopg/psycopg2/tree/5afb2ce803debea9533e293eef73c92ffce95bcd # and modified by Software Heritage. # # Original file: lib/extras.py # # psycopg2 is free software: you can redistribute it and/or modify it under the # terms of the GNU Lesser General Public License as published by the Free # Software Foundation, either version 3 of the License, or (at your option) any # later version. def _paginate(seq, page_size): """Consume an iterable and return it in chunks. Every chunk is at most `page_size`. Never return an empty chunk. """ page = [] it = iter(seq) while 1: try: for i in range(page_size): page.append(next(it)) yield page page = [] except StopIteration: if page: yield page return def _split_sql(sql): """Split *sql* on a single ``%s`` placeholder. Split on the %s, perform %% replacement and return pre, post lists of snippets. """ curr = pre = [] post = [] tokens = re.split(br"(%.)", sql) for token in tokens: if len(token) != 2 or token[:1] != b"%": curr.append(token) continue if token[1:] == b"s": if curr is pre: curr = post else: raise ValueError("the query contains more than one '%s' placeholder") elif token[1:] == b"%": curr.append(b"%") else: raise ValueError( "unsupported format character: '%s'" % token[1:].decode("ascii", "replace") ) if curr is pre: raise ValueError("the query doesn't contain any '%s' placeholder") return pre, post def execute_values_generator(cur, sql, argslist, template=None, page_size=100): """Execute a statement using SQL ``VALUES`` with a sequence of parameters. Rows returned by the query are returned through a generator. You need to consume the generator for the queries to be executed! :param cur: the cursor to use to execute the query. :param sql: the query to execute. It must contain a single ``%s`` placeholder, which will be replaced by a `VALUES list`__. Example: ``"INSERT INTO mytable (id, f1, f2) VALUES %s"``. :param argslist: sequence of sequences or dictionaries with the arguments to send to the query. The type and content must be consistent with *template*. :param template: the snippet to merge to every item in *argslist* to compose the query. - If the *argslist* items are sequences it should contain positional placeholders (e.g. ``"(%s, %s, %s)"``, or ``"(%s, %s, 42)``" if there are constants value...). - If the *argslist* items are mappings it should contain named placeholders (e.g. ``"(%(id)s, %(f1)s, 42)"``). If not specified, assume the arguments are sequence and use a simple positional template (i.e. ``(%s, %s, ...)``), with the number of placeholders sniffed by the first element in *argslist*. :param page_size: maximum number of *argslist* items to include in every statement. If there are more items the function will execute more than one statement. :param yield_from_cur: Whether to yield results from the cursor in this function directly. .. __: https://www.postgresql.org/docs/current/static/queries-values.html After the execution of the function the `cursor.rowcount` property will **not** contain a total result. """ # we can't just use sql % vals because vals is bytes: if sql is bytes # there will be some decoding error because of stupid codec used, and Py3 # doesn't implement % on bytes. if not isinstance(sql, bytes): sql = sql.encode(psycopg2.extensions.encodings[cur.connection.encoding]) pre, post = _split_sql(sql) for page in _paginate(argslist, page_size=page_size): if template is None: template = b"(" + b",".join([b"%s"] * len(page[0])) + b")" parts = pre[:] for args in page: parts.append(cur.mogrify(template, args)) parts.append(b",") parts[-1:] = post cur.execute(b"".join(parts)) yield from cur