diff --git a/swh/core/cli/db.py b/swh/core/cli/db.py
--- a/swh/core/cli/db.py
+++ b/swh/core/cli/db.py
@@ -107,7 +107,10 @@
     help="Attempt to create the database.",
     default=False,
 )
-def db_init(module, db_name, create_db):
+@click.option(
+    "--flavor", help="Database flavor.", default="default", show_default=True,
+)
+def db_init(module, db_name, create_db, flavor):
     """Initialize a database for the Software Heritage <module>.
 
     By default, does not attempt to create the database.
@@ -125,7 +128,7 @@
 
     """
 
-    logger.debug("db_init %s dn_name=%s", module, db_name)
+    logger.debug("db_init %s flavor=%s db_name=%s", module, flavor, db_name)
 
     if create_db:
         from swh.core.db.tests.db_testing import pg_createdb
@@ -133,14 +136,16 @@
         # Create the db (or fail silently if already existing)
         pg_createdb(db_name, check=False)
 
-    initialized, dbversion = populate_database_for_package(module, db_name)
+    initialized, dbversion = populate_database_for_package(module, db_name, flavor)
 
     # TODO: Ideally migrate the version from db_version to the latest
     # db version
 
     click.secho(
         "DONE database for {} {} at version {}".format(
-            module, "initialized" if initialized else "exists", dbversion
+            module,
+            f"initialized (as flavor {flavor})" if initialized else "exists",
+            dbversion,
         ),
         fg="green",
         bold=True,
@@ -168,13 +173,16 @@
     return list(sorted(glob.glob(path.join(sqldir, "*.sql")), key=sortkey))
 
 
-def populate_database_for_package(modname: str, conninfo: str) -> Tuple[bool, int]:
+def populate_database_for_package(
+    modname: str, conninfo: str, flavor: str = "default"
+) -> Tuple[bool, int]:
     """Populate the database, pointed at with `conninfo`,
     using the SQL files found in the package `modname`.
 
     Args:
       modname: Name of the module of which we're loading the files
       conninfo: connection info string for the SQL database
+      flavor: the module-specific flavor which we want to initialize the database under
     Returns:
       Tuple with two elements: whether the database has been initialized; the current
       version of the database.
@@ -188,20 +196,38 @@
         return False, current_version
 
     sqlfiles = get_sql_for_package(modname)
-
+    psql_command = [
+        "psql",
+        "--quiet",
+        "--no-psqlrc",
+        "-v",
+        "ON_ERROR_STOP=1",
+        "-d",
+        conninfo,
+    ]
+
+    flavor_set = False
     for sqlfile in sqlfiles:
-        subprocess.check_call(
-            [
-                "psql",
-                "--quiet",
-                "--no-psqlrc",
-                "-v",
-                "ON_ERROR_STOP=1",
-                "-d",
-                conninfo,
-                "-f",
-                sqlfile,
-            ]
+        logger.info("Loading file %s", sqlfile)
+        subprocess.check_call(psql_command + ["-f", sqlfile])
+
+        if sqlfile.endswith("-flavor.sql"):
+            logger.info("Setting database flavor %s", flavor)
+            # Let psql quote the value (:'flavor') instead of splicing it into
+            # the SQL text; -c strings are not subject to psql interpolation.
+            subprocess.run(
+                psql_command + ["-v", f"flavor={flavor}"],
+                input="insert into dbflavor (flavor) values (:'flavor');",
+                universal_newlines=True,
+                check=True,
+            )
+            flavor_set = True
+
+    if not flavor_set and flavor != "default":
+        logger.warning(
+            "Asked for flavor %s, but module %s does not support database flavors",
+            flavor,
+            modname,
         )
 
     current_version = swh_db_version(conninfo)