diff --git a/swh/core/db/db_utils.py b/swh/core/db/db_utils.py
--- a/swh/core/db/db_utils.py
+++ b/swh/core/db/db_utils.py
@@ -3,6 +3,7 @@
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
+from contextlib import contextmanager
 from datetime import datetime, timezone
 import functools
 from importlib import import_module
@@ -11,7 +12,7 @@
 import pathlib
 import re
 import subprocess
-from typing import Collection, Dict, List, Optional, Tuple, Union, cast
+from typing import Collection, Dict, Iterator, List, Optional, Tuple, Union, cast
 
 import psycopg2
 import psycopg2.errors
@@ -58,28 +59,42 @@
     return value
 
 
-def connect_to_conninfo(db_or_conninfo: Union[str, pgconnection]) -> pgconnection:
-    """Connect to the database passed in argument
+@contextmanager
+def connect_to_conninfo(
+    db_or_conninfo: Union[str, pgconnection]
+) -> Iterator[Optional[pgconnection]]:
+    """Connect to the database passed as argument.
 
     Args:
         db_or_conninfo: A database connection, or a database connection info string
 
-    Returns:
-        a connected database handle
+    Yields:
+        a connected database handle, or None if the connection failed
+        (e.g. the database does not exist). A connection opened here is
+        closed when the block exits; a connection passed in is left open.
 
-    Raises:
-        psycopg2.Error if the database doesn't exist
     """
     if isinstance(db_or_conninfo, pgconnection):
-        return db_or_conninfo
-
-    if "=" not in db_or_conninfo and "//" not in db_or_conninfo:
-        # Database name
-        db_or_conninfo = f"dbname={db_or_conninfo}"
-
-    db = psycopg2.connect(db_or_conninfo)
+        yield db_or_conninfo
+    else:
+        if "=" not in db_or_conninfo and "//" not in db_or_conninfo:
+            # Database name
+            db_or_conninfo = f"dbname={db_or_conninfo}"
 
-    return db
+        db: Optional[pgconnection] = None
+        try:
+            db = psycopg2.connect(db_or_conninfo)
+        except psycopg2.Error:
+            # Database not initialized (or unreachable): yield None below
+            logger.exception("Failed to connect to `%s`", db_or_conninfo)
+        try:
+            # Yield exactly once, even on failure: a @contextmanager
+            # generator that never yields makes `with` raise RuntimeError.
+            yield db
+        finally:
+            # Only close a connection opened here, never a caller's.
+            if db is not None:
+                db.close()
 
 
 def swh_db_version(db_or_conninfo: Union[str, pgconnection]) -> Optional[int]:
@@ -94,22 +109,18 @@
         Either the version of the database, or None if it couldn't be detected
     """
     try:
-        db = connect_to_conninfo(db_or_conninfo)
-    except psycopg2.Error:
-        logger.exception("Failed to connect to `%s`", db_or_conninfo)
-        # Database not initialized
-        return None
-
-    try:
-        with db.cursor() as c:
-            query = "select version from dbversion order by dbversion desc limit 1"
-            try:
-                c.execute(query)
-                result = c.fetchone()
-                if result:
-                    return result[0]
-            except psycopg2.errors.UndefinedTable:
+        with connect_to_conninfo(db_or_conninfo) as db:
+            if not db:
                 return None
+            with db.cursor() as c:
+                query = "select version from dbversion order by dbversion desc limit 1"
+                try:
+                    c.execute(query)
+                    result = c.fetchone()
+                    if result:
+                        return result[0]
+                except psycopg2.errors.UndefinedTable:
+                    return None
     except Exception:
         logger.exception("Could not get version from `%s`", db_or_conninfo)
         return None
@@ -129,23 +140,19 @@
         Either the version of the database, or None if it couldn't be detected
     """
     try:
-        db = connect_to_conninfo(db_or_conninfo)
-    except psycopg2.Error:
-        logger.exception("Failed to connect to `%s`", db_or_conninfo)
-        # Database not initialized
-        return None
-
-    try:
-        with db.cursor() as c:
-            query = (
-                "select version, release, description "
-                "from dbversion order by dbversion desc"
-            )
-            try:
-                c.execute(query)
-                return cast(List[Tuple[int, datetime, str]], c.fetchall())
-            except psycopg2.errors.UndefinedTable:
+        with connect_to_conninfo(db_or_conninfo) as db:
+            if not db:
                 return None
+            with db.cursor() as c:
+                query = (
+                    "select version, release, description "
+                    "from dbversion order by dbversion desc"
+                )
+                try:
+                    c.execute(query)
+                    return cast(List[Tuple[int, datetime, str]], c.fetchall())
+                except psycopg2.errors.UndefinedTable:
+                    return None
     except Exception:
         logger.exception("Could not get versions from `%s`", db_or_conninfo)
         return None
@@ -238,22 +245,18 @@
         Either the module of the database, or None if it couldn't be detected
     """
     try:
-        db = connect_to_conninfo(db_or_conninfo)
-    except psycopg2.Error:
-        logger.exception("Failed to connect to `%s`", db_or_conninfo)
-        # Database not initialized
-        return None
-
-    try:
-        with db.cursor() as c:
-            query = "select dbmodule from dbmodule limit 1"
-            try:
-                c.execute(query)
-                resp = c.fetchone()
-                if resp:
-                    return resp[0]
-            except psycopg2.errors.UndefinedTable:
+        with connect_to_conninfo(db_or_conninfo) as db:
+            if not db:
                 return None
+            with db.cursor() as c:
+                query = "select dbmodule from dbmodule limit 1"
+                try:
+                    c.execute(query)
+                    resp = c.fetchone()
+                    if resp:
+                        return resp[0]
+                except psycopg2.errors.UndefinedTable:
+                    return None
     except Exception:
         logger.exception("Could not get module from `%s`", db_or_conninfo)
         return None
@@ -289,27 +292,25 @@
             )
         # force is True
         update = True
-    try:
-        db = connect_to_conninfo(db_or_conninfo)
-    except psycopg2.Error:
-        logger.exception("Failed to connect to `%s`", db_or_conninfo)
-        # Database not initialized
-        return None
 
-    sqlfiles = [
-        fname
-        for fname in get_sql_for_package("swh.core.db")
-        if "dbmodule" in fname.stem
-    ]
-    execute_sqlfiles(sqlfiles, db_or_conninfo)
+    with connect_to_conninfo(db_or_conninfo) as db:
+        if not db:
+            return None
 
-    with db.cursor() as c:
-        if update:
-            query = "update dbmodule set dbmodule = %s"
-        else:
-            query = "insert into dbmodule(dbmodule) values (%s)"
-        c.execute(query, (module,))
-    db.commit()
+        sqlfiles = [
+            fname
+            for fname in get_sql_for_package("swh.core.db")
+            if "dbmodule" in fname.stem
+        ]
+        execute_sqlfiles(sqlfiles, db_or_conninfo)
+
+        with db.cursor() as c:
+            if update:
+                query = "update dbmodule set dbmodule = %s"
+            else:
+                query = "insert into dbmodule(dbmodule) values (%s)"
+            c.execute(query, (module,))
+        db.commit()
 
 
 def swh_set_db_version(
@@ -326,20 +327,19 @@
         db_or_conninfo: A database connection, or a database connection info string
         version: the version to add
     """
-    try:
-        db = connect_to_conninfo(db_or_conninfo)
-    except psycopg2.Error:
-        logger.exception("Failed to connect to `%s`", db_or_conninfo)
-        # Database not initialized
-        return None
     if ts is None:
         ts = now()
-    with db.cursor() as c:
-        query = (
-            "insert into dbversion(version, release, description) values (%s, %s, %s)"
-        )
-        c.execute(query, (version, ts, desc))
-    db.commit()
+
+    with connect_to_conninfo(db_or_conninfo) as db:
+        if not db:
+            return None
+        with db.cursor() as c:
+            query = (
+                "insert into dbversion(version, release, description) "
+                "values (%s, %s, %s)"
+            )
+            c.execute(query, (version, ts, desc))
+        db.commit()
 
 
 def swh_db_flavor(db_or_conninfo: Union[str, pgconnection]) -> Optional[str]:
@@ -355,23 +355,19 @@
         The flavor of the database, or None if it could not be detected.
     """
     try:
-        db = connect_to_conninfo(db_or_conninfo)
-    except psycopg2.Error:
-        logger.exception("Failed to connect to `%s`", db_or_conninfo)
-        # Database not initialized
-        return None
-
-    try:
-        with db.cursor() as c:
-            query = "select swh_get_dbflavor()"
-            try:
-                c.execute(query)
-                result = c.fetchone()
-                assert result is not None  # to keep mypy happy
-                return result[0]
-            except psycopg2.errors.UndefinedFunction:
-                # function not found: no flavor
+        with connect_to_conninfo(db_or_conninfo) as db:
+            if not db:
                 return None
+            with db.cursor() as c:
+                query = "select swh_get_dbflavor()"
+                try:
+                    c.execute(query)
+                    result = c.fetchone()
+                    assert result is not None  # to keep mypy happy
+                    return result[0]
+                except psycopg2.errors.UndefinedFunction:
+                    # function not found: no flavor
+                    return None
     except Exception:
         logger.exception("Could not get flavor from `%s`", db_or_conninfo)
         return None