diff --git a/swh/core/cli/db.py b/swh/core/cli/db.py
--- a/swh/core/cli/db.py
+++ b/swh/core/cli/db.py
@@ -181,7 +181,7 @@
     """
     import subprocess
 
-    from swh.core.db.tests.db_testing import swh_db_version
+    from swh.core.db.db_utils import swh_db_version
 
     current_version = swh_db_version(conninfo)
     if current_version is not None:
@@ -205,4 +205,5 @@
         )
 
     current_version = swh_db_version(conninfo)
+    assert current_version is not None
     return True, current_version
diff --git a/swh/core/db/db_utils.py b/swh/core/db/db_utils.py
--- a/swh/core/db/db_utils.py
+++ b/swh/core/db/db_utils.py
@@ -4,10 +4,16 @@
 # See top-level LICENSE file for more information
 
 import functools
+import logging
 import re
+from typing import Optional, Union
 
+import psycopg2
+import psycopg2.errors
 import psycopg2.extensions
 
+logger = logging.getLogger(__name__)
+
 
 def stored_procedure(stored_proc):
     """decorator to execute remote stored procedure, specified as argument
@@ -37,6 +43,61 @@
     return value
 
 
+def swh_db_version(
+    db_or_conninfo: Union[str, psycopg2.extensions.connection]
+) -> Optional[int]:
+    """Retrieve the swh version of the database, if any.
+
+    Args:
+        db_or_conninfo: A database connection, or a database connection info
+            string. A bare database name (containing no ``=`` and not starting
+            with a ``postgresql://`` or ``postgres://`` URI prefix) is also
+            accepted and turned into ``dbname=<name>``.
+
+    Returns:
+        Optional[int]: The latest version recorded in the ``dbversion`` table,
+        or None if the database cannot be reached, the table does not exist or
+        is empty, or the version cannot be read for any other reason.
+
+    """
+    if isinstance(db_or_conninfo, psycopg2.extensions.connection):
+        db = db_or_conninfo
+        # The caller owns this connection: leave it open.
+        owns_connection = False
+    else:
+        if "=" not in db_or_conninfo and not db_or_conninfo.startswith(
+            ("postgresql://", "postgres://")
+        ):
+            # Bare database name (libpq URIs may legitimately contain no "=")
+            db_or_conninfo = f"dbname={db_or_conninfo}"
+        try:
+            db = psycopg2.connect(db_or_conninfo)
+        except psycopg2.Error:
+            logger.exception("Failed to connect to `%s`", db_or_conninfo)
+            # Database unreachable, or not created yet
+            return None
+        # We opened this connection ourselves, so the `finally` below closes it.
+        owns_connection = True
+
+    try:
+        with db.cursor() as c:
+            query = "select version from dbversion order by dbversion desc limit 1"
+            try:
+                c.execute(query)
+            except psycopg2.errors.UndefinedTable:
+                # No dbversion table: the schema has not been initialized
+                return None
+            row = c.fetchone()
+            # fetchone() returns None when the dbversion table is empty
+            return row[0] if row is not None else None
+    except Exception:
+        logger.exception("Could not get version from `%s`", db_or_conninfo)
+        return None
+    finally:
+        if owns_connection:
+            db.close()
+
+
 # The following code has been imported from psycopg2, version 2.7.4,
 # https://github.com/psycopg/psycopg2/tree/5afb2ce803debea9533e293eef73c92ffce95bcd
 # and modified by Software Heritage.
diff --git a/swh/core/db/tests/db_testing.py b/swh/core/db/tests/db_testing.py
--- a/swh/core/db/tests/db_testing.py
+++ b/swh/core/db/tests/db_testing.py
@@ -15,43 +15,6 @@
 DB_DUMP_TYPES = {".sql": "psql", ".dump": "pg_dump"}  # type: Dict[str, str]
 
 
-def swh_db_version(dbname_or_service):
-    """Retrieve the swh version if any. In case of the db not initialized,
-    this returns None. Otherwise, this returns the db's version.
-
-    Args:
-        dbname_or_service (str): The db's name or service
-
-    Returns:
-        Optional[Int]: Either the db's version or None
-
-    """
-    query = "select version from dbversion order by dbversion desc limit 1"
-    cmd = [
-        "psql",
-        "--tuples-only",
-        "--no-psqlrc",
-        "--quiet",
-        "-v",
-        "ON_ERROR_STOP=1",
-        "--command=%s" % query,
-        dbname_or_service,
-    ]
-
-    try:
-        r = subprocess.run(
-            cmd,
-            check=True,
-            stdout=subprocess.PIPE,
-            stderr=subprocess.PIPE,
-            universal_newlines=True,
-        )
-        result = int(r.stdout.strip())
-    except Exception:  # db not initialized
-        result = None
-    return result
-
-
 def pg_restore(dbname, dumpfile, dumptype="pg_dump"):
     """
     Args: