diff --git a/swh/core/cli/db.py b/swh/core/cli/db.py
--- a/swh/core/cli/db.py
+++ b/swh/core/cli/db.py
@@ -6,6 +6,8 @@
 
 import logging
 from os import path, environ
+from typing import Optional, Tuple
+
 import warnings
 
 import click
@@ -75,34 +77,29 @@
     import subprocess
 
     for modname, cfg in ctx.obj["config"].items():
-        if cfg.get("cls") == "local" and cfg.get("args"):
+        if cfg.get("cls") == "local" and (cfg.get("args") or {}).get("db"):
             try:
-                sqlfiles = get_sql_for_package(modname)
+                initialized, dbversion = populate_database_for_package(
+                    modname, cfg["args"]["db"]
+                )
             except click.BadParameter:
                 logger.info(
                     "Failed to load/find sql initialization files for %s", modname
                 )
+                # nothing was loaded for this module, so there is no status to report
+                continue
 
-            if sqlfiles:
-                conninfo = cfg["args"]["db"]
-                for sqlfile in sqlfiles:
-                    subprocess.check_call(
-                        [
-                            "psql",
-                            "--quiet",
-                            "--no-psqlrc",
-                            "-v",
-                            "ON_ERROR_STOP=1",
-                            "-d",
-                            conninfo,
-                            "-f",
-                            sqlfile,
-                        ]
-                    )
+            click.secho(
+                "DONE database for {} {} at version {}".format(
+                    modname, "initialized" if initialized else "exists", dbversion
+                ),
+                fg="green",
+                bold=True,
+            )
 
 
 @click.command(context_settings=CONTEXT_SETTINGS)
-@click.argument("module", nargs=-1, required=True)
+@click.argument("module", required=True)
 @click.option(
     "--db-name",
     "-d",
@@ -117,7 +114,7 @@
     default=False,
 )
 def db_init(module, db_name, create_db):
-    """Initialise a database for the Software Heritage <module>.  By
+    """Initialize a database for the Software Heritage <module>.  By
     default, does not attempt to create the database.
 
     Example:
@@ -133,39 +130,24 @@
       PGPORT=5434 swh db-init indexer
 
     """
-    # put import statements here so we can keep startup time of the main swh
-    # command as short as possible
-    from swh.core.db.tests.db_testing import (
-        pg_createdb,
-        pg_restore,
-        DB_DUMP_TYPES,
-        swh_db_version,
-    )
 
-    logger.debug("db_init %s dn_name=%s", module, db_name)
-    dump_files = []
-
-    for modname in module:
-        dump_files.extend(get_sql_for_package(modname))
+    logger.debug("db_init %s db_name=%s", module, db_name)
 
     if create_db:
+        from swh.core.db.tests.db_testing import pg_createdb
+
         # Create the db (or fail silently if already existing)
         pg_createdb(db_name, check=False)
 
-    # Try to retrieve the db version if any
-    db_version = swh_db_version(db_name)
-    if not db_version:  # Initialize the db
-        dump_files = [(x, DB_DUMP_TYPES[path.splitext(x)[1]]) for x in dump_files]
-        for dump, dtype in dump_files:
-            click.secho("Loading {}".format(dump), fg="yellow")
-            pg_restore(db_name, dump, dtype)
-        db_version = swh_db_version(db_name)
+    initialized, dbversion = populate_database_for_package(module, db_name)
 
     # TODO: Ideally migrate the version from db_version to the latest
     # db version
 
     click.secho(
-        "DONE database is {} version {}".format(db_name, db_version),
+        "DONE database for {} {} at version {}".format(
+            module, "initialized" if initialized else "exists", dbversion
+        ),
         fg="green",
         bold=True,
     )
@@ -189,3 +171,55 @@
             "Module {} does not provide a db schema " "(no sql/ dir)".format(modname)
         )
     return list(sorted(glob.glob(path.join(sqldir, "*.sql")), key=sortkey))
+
+
+def populate_database_for_package(
+    modname: str, conninfo: str
+) -> Tuple[bool, Optional[int]]:
+    """Populate the database, pointed at with `conninfo`, using the SQL files found in
+    the package `modname`.
+
+    The SQL files are only loaded when the database does not report a version yet;
+    a database that already has one is left untouched.
+
+    Args:
+        modname: Name of the module of which we're loading the files
+        conninfo: connection info string for the SQL database
+
+    Returns:
+        Tuple with two elements: whether the database has been initialized; the current
+        version of the database (None if it could not be determined).
+
+    Raises:
+        click.BadParameter: if no SQL files can be found for `modname` (only looked
+          up when the database still needs to be initialized).
+    """
+    import subprocess
+    from swh.core.db.tests.db_testing import swh_db_version
+
+    current_version = swh_db_version(conninfo)
+    if current_version is not None:
+        # Already initialized: never replay the SQL files over an existing schema.
+        return False, current_version
+
+    sqlfiles = get_sql_for_package(modname)
+
+    for sqlfile in sqlfiles:
+        # With ON_ERROR_STOP=1, psql exits non-zero on the first SQL error, so
+        # check_call raises CalledProcessError instead of silently carrying on.
+        subprocess.check_call(
+            [
+                "psql",
+                "--quiet",
+                "--no-psqlrc",
+                "-v",
+                "ON_ERROR_STOP=1",
+                "-d",
+                conninfo,
+                "-f",
+                sqlfile,
+            ]
+        )
+
+    current_version = swh_db_version(conninfo)
+    return True, current_version
diff --git a/swh/core/db/tests/db_testing.py b/swh/core/db/tests/db_testing.py
--- a/swh/core/db/tests/db_testing.py
+++ b/swh/core/db/tests/db_testing.py
@@ -42,7 +42,11 @@
 
     try:
         r = subprocess.run(
-            cmd, check=True, stdout=subprocess.PIPE, universal_newlines=True
+            cmd,
+            check=True,
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+            universal_newlines=True,
         )
         result = int(r.stdout.strip())
     except Exception:  # db not initialized