diff --git a/swh/core/cli/db.py b/swh/core/cli/db.py
--- a/swh/core/cli/db.py
+++ b/swh/core/cli/db.py
@@ -6,7 +6,7 @@
 
 import logging
 from os import environ, path
-from typing import Collection, Tuple
+from typing import Collection, Dict, Optional, Tuple
 import warnings
 
 import click
@@ -92,7 +92,10 @@
     default="softwareheritage-dev",
     show_default=True,
 )
-def db_init(module, db_name):
+@click.option(
+    "--flavor", help="Database flavor.", default=None,
+)
+def db_init(module, db_name, flavor):
     """Initialize a database for the Software Heritage <module>.
 
     Example:
@@ -103,24 +106,30 @@
     please provide them using standard environment variables.
     See psql(1) man page (section ENVIRONMENTS) for details.
 
-    Example:
+    Examples:
 
-      PGPORT=5434 swh db-init indexer
+      PGPORT=5434 swh db init indexer
       swh db init -d postgresql://user:passwd@pghost:5433/swh-storage storage
+      swh db init --flavor read_replica -d swh-storage storage
 
     """
 
-    logger.debug("db_init %s dn_name=%s", module, db_name)
+    logger.debug("db_init %s flavor=%s db_name=%s", module, flavor, db_name)
 
-    initialized, dbversion = populate_database_for_package(module, db_name)
+    initialized, dbversion = populate_database_for_package(module, db_name, flavor)
 
     # TODO: Ideally migrate the version from db_version to the latest
     # db version
 
+    if not initialized:
+        status = "exists"
+    elif flavor is None:
+        status = "initialized"
+    else:
+        status = f"initialized (as flavor {flavor})"
+
     click.secho(
-        "DONE database for {} {} at version {}".format(
-            module, "initialized" if initialized else "exists", dbversion
-        ),
+        "DONE database for {} {} at version {}".format(module, status, dbversion),
         fg="green",
         bold=True,
     )
@@ -147,13 +156,16 @@
     return sorted(glob.glob(path.join(sqldir, "*.sql")), key=sortkey)
 
 
-def populate_database_for_package(modname: str, conninfo: str) -> Tuple[bool, int]:
+def populate_database_for_package(
+    modname: str, conninfo: str, flavor: Optional[str] = None
+) -> Tuple[bool, int]:
     """Populate the database, pointed at with `conninfo`, using the SQL files found in
     the package `modname`.
 
     Args:
       modname: Name of the module of which we're loading the files
       conninfo: connection info string for the SQL database
+      flavor: the module-specific flavor which we want to initialize the database under
     Returns:
       Tuple with two elements: whether the database has been initialized; the current
       version of the database.
@@ -166,13 +178,36 @@
 
     sqlfiles = get_sql_for_package(modname)
     sqlfiles = [fname for fname in sqlfiles if "-superuser-" not in fname]
-    execute_sqlfiles(sqlfiles, conninfo)
+    execute_sqlfiles(sqlfiles, conninfo, flavor)
 
     current_version = swh_db_version(conninfo)
     assert current_version is not None
     return True, current_version
 
 
+def parse_dsn_or_dbname(dsn_or_dbname: str) -> Dict[str, str]:
+    """Parse a psycopg2 dsn, falling back to supporting plain database names as well.
+
+    Args:
+      dsn_or_dbname: a libpq connection string (``key=value`` pairs or URI), or a
+        plain database name
+
+    Returns:
+      the parsed connection parameters; a plain database name ``name`` is returned
+      as ``{"dbname": name}``
+    """
+    import psycopg2
+    from psycopg2.extensions import parse_dsn as _parse_dsn
+
+    try:
+        return _parse_dsn(dsn_or_dbname)
+    except psycopg2.ProgrammingError:
+        # psycopg2 failed to parse the DSN; it's probably a plain database name.
+        # Don't re-parse it as "dbname=<name>": that would break again on names
+        # containing spaces or backslashes, so build the mapping directly.
+        return {"dbname": dsn_or_dbname}
+
+
 def create_database_for_package(
     modname: str, conninfo: str, template: str = "template1"
 ):
@@ -181,18 +216,18 @@
 
     Args:
       modname: Name of the module of which we're loading the files
-      conninfo: connection info string for the SQL database
+      conninfo: connection info string or plain database name for the SQL database
       template: the name of the database to connect to and use as template to create
         the new database
 
     """
     import subprocess
 
-    from psycopg2.extensions import make_dsn, parse_dsn
+    from psycopg2.extensions import make_dsn
 
-    # Use the given conninfo but with dbname replaced by the template dbname
+    # Use the given conninfo string, but with dbname replaced by the template dbname
     # for the database creation step
-    creation_dsn = parse_dsn(conninfo)
+    creation_dsn = parse_dsn_or_dbname(conninfo)
     db_name = creation_dsn["dbname"]
     creation_dsn["dbname"] = template
     logger.debug("db_create db_name=%s (from %s)", db_name, template)
@@ -206,7 +241,8 @@
             "-d",
             make_dsn(**creation_dsn),
             "-c",
-            f"CREATE DATABASE {db_name}",
+            # quote the name as an SQL identifier, doubling embedded double quotes
+            'CREATE DATABASE "{}"'.format(db_name.replace('"', '""')),
         ]
     )
 
@@ -218,27 +254,50 @@
     execute_sqlfiles(sqlfiles, conninfo)
 
 
-def execute_sqlfiles(sqlfiles: Collection[str], conninfo: str):
+def execute_sqlfiles(
+    sqlfiles: Collection[str], conninfo: str, flavor: Optional[str] = None
+):
     """Execute a list of SQL files on the database pointed at with `conninfo`.
 
     Args:
       sqlfiles: List of SQL files to execute
       conninfo: connection info string for the SQL database
+      flavor: the database flavor to initialize; it is recorded in the ``dbflavor``
+        table right after the first SQL file whose name ends with ``-flavor.sql``
     """
     import subprocess
 
+    psql_command = [
+        "psql",
+        "--quiet",
+        "--no-psqlrc",
+        "-v",
+        "ON_ERROR_STOP=1",
+        "-d",
+        conninfo,
+    ]
+
+    flavor_set = False
     for sqlfile in sqlfiles:
         logger.debug(f"execute SQL file {sqlfile} db_name={conninfo}")
-        subprocess.check_call(
-            [
-                "psql",
-                "--quiet",
-                "--no-psqlrc",
-                "-v",
-                "ON_ERROR_STOP=1",
-                "-d",
-                conninfo,
-                "-f",
-                sqlfile,
-            ]
+        subprocess.check_call(psql_command + ["-f", sqlfile])
+
+        if flavor is not None and not flavor_set and sqlfile.endswith("-flavor.sql"):
+            logger.debug("Setting database flavor %s", flavor)
+            # Let psql quote the flavor as a SQL literal (:'flavor') rather than
+            # formatting it into the query string; psql does not expand variables in
+            # -c commands, so the statement is fed on stdin instead.
+            query = "insert into dbflavor (flavor) values (:'flavor')"
+            subprocess.run(
+                psql_command + ["-v", f"flavor={flavor}"],
+                input=query,
+                universal_newlines=True,
+                check=True,
+            )
+            flavor_set = True
+
+    if flavor is not None and not flavor_set:
+        logger.warning(
+            "Asked for flavor %s, but module does not support database flavors",
+            flavor,
         )