diff --git a/PKG-INFO b/PKG-INFO index c633be4..862c2ca 100644 --- a/PKG-INFO +++ b/PKG-INFO @@ -1,39 +1,39 @@ Metadata-Version: 2.1 Name: swh.core -Version: 2.9.0 +Version: 2.10 Summary: Software Heritage core utilities Home-page: https://forge.softwareheritage.org/diffusion/DCORE/ Author: Software Heritage developers Author-email: swh-devel@inria.fr Project-URL: Bug Reports, https://forge.softwareheritage.org/maniphest Project-URL: Funding, https://www.softwareheritage.org/donate Project-URL: Source, https://forge.softwareheritage.org/source/swh-core Project-URL: Documentation, https://docs.softwareheritage.org/devel/swh-core/ Classifier: Programming Language :: Python :: 3 Classifier: Intended Audience :: Developers Classifier: License :: OSI Approved :: GNU General Public License v3 (GPLv3) Classifier: Operating System :: OS Independent Classifier: Development Status :: 5 - Production/Stable Requires-Python: >=3.7 Description-Content-Type: text/x-rst Provides-Extra: testing-core Provides-Extra: logging Provides-Extra: db Provides-Extra: http Provides-Extra: github Provides-Extra: testing License-File: LICENSE License-File: AUTHORS Software Heritage - Core foundations ==================================== Low-level utilities and helpers used by almost all other modules in the stack. core library for swh's modules: - config parser - serialization - logging mechanism - database connection - http-based RPC client/server diff --git a/docs/db.rst b/docs/db.rst index b17a9cb..930f121 100644 --- a/docs/db.rst +++ b/docs/db.rst @@ -1,176 +1,176 @@ .. _swh-core-db: Common database utilities ========================= The ``swh.core.db`` module offers a set of common (postgresql) database handling utilities and features for other swh packages implementing a `datastore`, i.e. a service responsible for providing a data store via a common interface which can use a postgresql database as backend. Examples are :mod:`swh.storage` or :mod:`swh.scheduler`. Most of the time, this database-based data storage facility will depend on a data schema (which may or may not be based on :mod:`swh.model`) and provide a unified interface based on a Python class to abstract access to this datastore. Some packages may implement only a postgresql backend, some may provide more backends. This :mod:`swh.core.db` module only deals with the postgresql part and provides common features and tooling to manage the database lifecycle in a consistent and unified way among all the :mod:`swh` packages. It comes with a few command line tools to manage the specific :mod:`swh` package database. As such, most of the database management cli commands require a configuration file holding the database connection information. For example, for the :mod:`swh.storage` package, one will be able to create, initialize and upgrade the postgresql database using simple commands. To create the database and perform superuser initialization steps (see below): .. code-block:: bash $ swh db create storage --dbname=postgresql://superuser:passwd@localhost:5433/test-storage If the database already exists but lacks superuser level initialization steps, you may use: .. code-block:: bash $ swh db init-admin storage --dbname=postgresql://superuser:passwd@localhost:5433/test-storage Then, assuming the following ``config.yml`` file exists: .. code-block:: yaml storage: cls: postgresql db: host=localhost port=5433 dbname=test-storage user=normal-user password=pwd objstorage: cls: memory then you can run: ..
code-block:: bash $ swh db --config-file=config.yml init storage DONE database for storage initialized (flavor default) at version 182 Note: you can define the ``SWH_CONFIG_FILENAME`` environment variable instead of using the ``--config-file`` command line option. or check the actual data model version of this database: .. code-block:: bash $ swh db --config-file=config.yml version storage module: storage flavor: default version: 182 as well as the migration history for the database: .. code-block:: bash $ swh db --config-file=config.yml version --all storage module: storage flavor: default 182 [2022-02-11 15:08:31.806070+01:00] Work In Progress 181 [2022-02-11 14:06:27.435010+01:00] Work In Progress The database migration is done using the ``swh db upgrade`` command. Implementation of a swh.core.db datastore ----------------------------------------- To use this database management tooling, in a :mod:`swh` package, the following conditions are expected: - the package should provide an ``sql`` directory in its root namespace providing initialization sql scripts. Scripts should be named like ``nn-xxx.sql`` and are executed in order according to the ``nn`` integer value. Scripts having ``-superuser-`` in their name will be executed by the ``init-admin`` tool and are expected to require superuser access level, whereas scripts without ``-superuser-`` in their name will be executed by the ``swh db init`` command and are expected to require write access level (with no need for superuser access level). - the package should provide a ``sql/upgrades`` directory with SQL migration scripts in its root namespace. Script names are expected to be of the form ``nnn.sql`` where ``nnn`` is the version to which this script migrates a database from version ``nnn - 1``. - the initialization and migration scripts should neither create nor fill the metadata related tables (``dbversion`` and ``dbmodule``). - the package should provide a ``get_datastore`` function in its root namespace returning an instance of the datastore object. Normally, this datastore object uses ``swh.core.db.BaseDb`` to interact with the actual database. -- The datastore object should provide a ``get_current_version()`` method - returning the database version expected by the code. +- The datastore object should provide a ``current_version`` attribute returning the + database version expected by the code. See existing ``swh`` packages like ``swh.storage`` or ``swh.scheduler`` for usage examples. Writing tests ------------- The ``swh.core.db.pytest_plugin`` provides a few helper tools to write unit tests for postgresql based datastores. By default, when using these fixtures, a postgresql server will be started (by the pytest_postgresql fixture) and a template database will be created using the ``postgresql_proc`` fixture factory provided by ``pytest_postgresql``. Then a dedicated fixture must be declared to use the ``postgresql_proc`` fixture generated by the fixture factory function. This template database will then be used to create a new database for each test using this dedicated fixture. In order to help the database initialization process and make it consistent with the database initialization tools from the ``swh db`` cli, an ``initialize_database_for_module()`` function is provided to be used with the fixture factory described above. Typically, writing tests for a ``swh`` package ``swh.example`` would look like: ..
code-block:: python from functools import partial from pytest_postgresql import factories from swh.core.db.pytest_plugin import postgresql_fact from swh.core.db.pytest_plugin import initialize_database_for_module example_postgresql_proc = factories.postgresql_proc( dbname="example", load=[partial(initialize_database_for_module, modname="example", version=1)] ) postgresql_example = postgresql_fact("example_postgresql_proc") def test_example(postgresql_example): with postgresql_example.cursor() as c: c.execute("select version from dbversion limit 1") assert c.fetchone()[0] == 1 Note: most of the time, you will want to put the scaffolding part of the code above in a ``conftest.py`` file. The ``load`` argument of the ``factories.postgresql_proc`` will be used to initialize the template database that will be used to create a new database for each test, while the ``load`` argument of the ``postgresql_fact`` fixture will be executed before each test (in the database created from the template database and dedicated to the test being executed). diff --git a/swh.core.egg-info/PKG-INFO b/swh.core.egg-info/PKG-INFO index c633be4..862c2ca 100644 --- a/swh.core.egg-info/PKG-INFO +++ b/swh.core.egg-info/PKG-INFO @@ -1,39 +1,39 @@ Metadata-Version: 2.1 Name: swh.core -Version: 2.9.0 +Version: 2.10 Summary: Software Heritage core utilities Home-page: https://forge.softwareheritage.org/diffusion/DCORE/ Author: Software Heritage developers Author-email: swh-devel@inria.fr Project-URL: Bug Reports, https://forge.softwareheritage.org/maniphest Project-URL: Funding, https://www.softwareheritage.org/donate Project-URL: Source, https://forge.softwareheritage.org/source/swh-core Project-URL: Documentation, https://docs.softwareheritage.org/devel/swh-core/ Classifier: Programming Language :: Python :: 3 Classifier: Intended Audience :: Developers Classifier: License :: OSI Approved :: GNU General Public License v3 (GPLv3) Classifier: Operating System :: OS Independent Classifier: Development Status :: 5 - Production/Stable Requires-Python: >=3.7 Description-Content-Type: text/x-rst Provides-Extra: testing-core Provides-Extra: logging Provides-Extra: db Provides-Extra: http Provides-Extra: github Provides-Extra: testing License-File: LICENSE License-File: AUTHORS Software Heritage - Core foundations ==================================== Low-level utilities and helpers used by almost all other modules in the stack.
core library for swh's modules: - config parser - serialization - logging mechanism - database connection - http-based RPC client/server diff --git a/swh/core/db/db_utils.py b/swh/core/db/db_utils.py index d78f9ec..d28a0b8 100644 --- a/swh/core/db/db_utils.py +++ b/swh/core/db/db_utils.py @@ -1,664 +1,691 @@ # Copyright (C) 2015-2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from contextlib import contextmanager from datetime import datetime, timezone import functools from importlib import import_module import logging from os import path import pathlib import re import subprocess from typing import Collection, Dict, Iterator, List, Optional, Tuple, Union, cast import psycopg2 import psycopg2.errors import psycopg2.extensions import psycopg2.extras from psycopg2.extensions import connection as pgconnection from psycopg2.extensions import encodings as pgencodings from psycopg2.extensions import make_dsn from psycopg2.extensions import parse_dsn as _parse_dsn from swh.core.utils import numfile_sortkey as sortkey logger = logging.getLogger(__name__) def now(): return datetime.now(tz=timezone.utc) def stored_procedure(stored_proc): """decorator to execute remote stored procedure, specified as argument Generally, the body of the decorated function should be empty. If it is not, the stored procedure will be executed first, then the function body. """ def wrap(meth): @functools.wraps(meth) def _meth(self, *args, **kwargs): cur = kwargs.get("cur", None) self._cursor(cur).execute("SELECT %s()" % stored_proc) meth(self, *args, **kwargs) return _meth return wrap def jsonize(value): """Convert a value to a psycopg2 JSON object if necessary""" if isinstance(value, dict): return psycopg2.extras.Json(value) return value @contextmanager def connect_to_conninfo( db_or_conninfo: Union[str, pgconnection] ) -> Iterator[pgconnection]: """Connect to the database passed as argument. Args: db_or_conninfo: A database connection, or a database connection info string Yields: a connected database handle or None if the database is not initialized """ if isinstance(db_or_conninfo, pgconnection): yield db_or_conninfo else: if "=" not in db_or_conninfo and "//" not in db_or_conninfo: # Database name db_or_conninfo = f"dbname={db_or_conninfo}" try: db = psycopg2.connect(db_or_conninfo) except psycopg2.Error: logger.exception("Failed to connect to `%s`", db_or_conninfo) else: yield db def swh_db_version(db_or_conninfo: Union[str, pgconnection]) -> Optional[int]: """Retrieve the swh version of the database. If the database is not initialized, this logs a warning and returns None. Args: db_or_conninfo: A database connection, or a database connection info string Returns: Either the version of the database, or None if it couldn't be detected """ try: with connect_to_conninfo(db_or_conninfo) as db: if not db: return None with db.cursor() as c: query = "select version from dbversion order by dbversion desc limit 1" try: c.execute(query) result = c.fetchone() if result: return result[0] except psycopg2.errors.UndefinedTable: return None except Exception: logger.exception("Could not get version from `%s`", db_or_conninfo) return None def swh_db_versions( db_or_conninfo: Union[str, pgconnection] ) -> Optional[List[Tuple[int, datetime, str]]]: """Retrieve the swh version history of the database. If the database is not initialized, this logs a warning and returns None.
Args: db_or_conninfo: A database connection, or a database connection info string Returns: Either the version history of the database, or None if it couldn't be detected """ try: with connect_to_conninfo(db_or_conninfo) as db: if not db: return None with db.cursor() as c: query = ( "select version, release, description " "from dbversion order by dbversion desc" ) try: c.execute(query) return cast(List[Tuple[int, datetime, str]], c.fetchall()) except psycopg2.errors.UndefinedTable: return None except Exception: logger.exception("Could not get versions from `%s`", db_or_conninfo) return None def swh_db_upgrade( conninfo: str, modname: str, to_version: Optional[int] = None ) -> int: """Upgrade the database at `conninfo` for module `modname` This will run migration scripts found in the `sql/upgrades` subdirectory of the module `modname`. By default, this will upgrade to the latest declared version. Args: conninfo: A database connection info string modname: datastore module the database stores content for to_version: if given, update the database to this version rather than the latest """ if to_version is None: to_version = 99999999 db_module, db_version, db_flavor = get_database_info(conninfo) if db_version is None: raise ValueError("Unable to retrieve the current version of the database") if db_module is None: raise ValueError("Unable to retrieve the module of the database") if db_module != modname: raise ValueError( "The stored module of the database is different from the given one" ) sqlfiles = [ fname for fname in get_sql_for_package(modname, upgrade=True) if db_version < int(fname.stem) <= to_version ] if not sqlfiles: return db_version for sqlfile in sqlfiles: new_version = int(path.splitext(path.basename(sqlfile))[0]) logger.info("Executing migration script '%s'", sqlfile) if db_version is not None and (new_version - db_version) > 1: logger.error( f"There are missing migration steps between {db_version} and " f"{new_version}. This might be expected, but most likely it is not. " "Will stop here." ) return db_version execute_sqlfiles([sqlfile], conninfo, db_flavor) # check if the db version has been updated by the upgrade script db_version = swh_db_version(conninfo) assert db_version is not None if db_version == new_version: # nothing to do, upgrade script did the job pass elif db_version == new_version - 1: # the upgrade script did not update dbversion (new style), so do it now swh_set_db_version( conninfo, new_version, desc=f"Upgraded to version {new_version} using {sqlfile}", ) db_version = swh_db_version(conninfo) else: # upgrade script did it wrong logger.error( f"The upgrade script {sqlfile} did not update the dbversion table " f"consistently ({db_version} vs. expected {new_version}). " "Will stop migration here. Please check your migration scripts." ) return db_version return new_version def swh_db_module(db_or_conninfo: Union[str, pgconnection]) -> Optional[str]: """Retrieve the swh module used to create the database. If the database is not initialized, this logs a warning and returns None.
Args: db_or_conninfo: A database connection, or a database connection info string Returns: Either the module of the database, or None if it couldn't be detected """ try: with connect_to_conninfo(db_or_conninfo) as db: if not db: return None with db.cursor() as c: query = "select dbmodule from dbmodule limit 1" try: c.execute(query) resp = c.fetchone() if resp: return resp[0] except psycopg2.errors.UndefinedTable: return None except Exception: logger.exception("Could not get module from `%s`", db_or_conninfo) return None def swh_set_db_module( db_or_conninfo: Union[str, pgconnection], module: str, force=False ) -> None: """Set the swh module used to create the database. Fails if the dbmodule is already set to a different value (unless ``force`` is True), or if the table does not exist. Args: db_or_conninfo: A database connection, or a database connection info string module: the swh module to register (without the leading 'swh.') """ update = False if module.startswith("swh."): module = module[4:] current_module = swh_db_module(db_or_conninfo) if current_module is not None: if current_module == module: logger.warning("The database module is already set to %s", module) return if not force: raise ValueError( f"The database module is already set to a value {current_module!r} " f"different from the given {module!r}" ) # force is True update = True with connect_to_conninfo(db_or_conninfo) as db: if not db: return None sqlfiles = [ fname for fname in get_sql_for_package("swh.core.db") if "dbmodule" in fname.stem ] execute_sqlfiles(sqlfiles, db_or_conninfo) with db.cursor() as c: if update: query = "update dbmodule set dbmodule = %s" else: query = "insert into dbmodule(dbmodule) values (%s)" c.execute(query, (module,)) db.commit() def swh_set_db_version( db_or_conninfo: Union[str, pgconnection], version: int, ts: Optional[datetime] = None, desc: str = "Work in progress", ) -> None: """Set the version of the database. Fails if the dbversion table does not exist. Args: db_or_conninfo: A database connection, or a database connection info string version: the version to add """ if ts is None: ts = now() with connect_to_conninfo(db_or_conninfo) as db: if not db: return None with db.cursor() as c: query = ( "insert into dbversion(version, release, description) " "values (%s, %s, %s)" ) c.execute(query, (version, ts, desc)) db.commit() def swh_db_flavor(db_or_conninfo: Union[str, pgconnection]) -> Optional[str]: """Retrieve the swh flavor of the database. If the database is not initialized, or the database doesn't support flavors, this returns None. Args: db_or_conninfo: A database connection, or a database connection info string Returns: The flavor of the database, or None if it could not be detected. """ try: with connect_to_conninfo(db_or_conninfo) as db: if not db: return None with db.cursor() as c: query = "select swh_get_dbflavor()" try: c.execute(query) result = c.fetchone() assert result is not None # to keep mypy happy return result[0] except psycopg2.errors.UndefinedFunction: # function not found: no flavor return None except Exception: logger.exception("Could not get flavor from `%s`", db_or_conninfo) return None # The following code has been imported from psycopg2, version 2.7.4, # https://github.com/psycopg/psycopg2/tree/5afb2ce803debea9533e293eef73c92ffce95bcd # and modified by Software Heritage.
# # Original file: lib/extras.py # # psycopg2 is free software: you can redistribute it and/or modify it under the # terms of the GNU Lesser General Public License as published by the Free # Software Foundation, either version 3 of the License, or (at your option) any # later version. def _paginate(seq, page_size): """Consume an iterable and return it in chunks. Every chunk is at most `page_size`. Never return an empty chunk. """ page = [] it = iter(seq) while 1: try: for i in range(page_size): page.append(next(it)) yield page page = [] except StopIteration: if page: yield page return def _split_sql(sql): """Split *sql* on a single ``%s`` placeholder. Split on the %s, perform %% replacement and return pre, post lists of snippets. """ curr = pre = [] post = [] tokens = re.split(rb"(%.)", sql) for token in tokens: if len(token) != 2 or token[:1] != b"%": curr.append(token) continue if token[1:] == b"s": if curr is pre: curr = post else: raise ValueError("the query contains more than one '%s' placeholder") elif token[1:] == b"%": curr.append(b"%") else: raise ValueError( "unsupported format character: '%s'" % token[1:].decode("ascii", "replace") ) if curr is pre: raise ValueError("the query doesn't contain any '%s' placeholder") return pre, post def execute_values_generator(cur, sql, argslist, template=None, page_size=100): """Execute a statement using SQL ``VALUES`` with a sequence of parameters. Rows returned by the query are returned through a generator. You need to consume the generator for the queries to be executed! :param cur: the cursor to use to execute the query. :param sql: the query to execute. It must contain a single ``%s`` placeholder, which will be replaced by a `VALUES list`__. Example: ``"INSERT INTO mytable (id, f1, f2) VALUES %s"``. :param argslist: sequence of sequences or dictionaries with the arguments to send to the query. The type and content must be consistent with *template*. :param template: the snippet to merge to every item in *argslist* to compose the query. - If the *argslist* items are sequences it should contain positional placeholders (e.g. ``"(%s, %s, %s)"``, or ``"(%s, %s, 42)"`` if there are constant values...). - If the *argslist* items are mappings it should contain named placeholders (e.g. ``"(%(id)s, %(f1)s, 42)"``). If not specified, assume the arguments are sequences and use a simple positional template (i.e. ``(%s, %s, ...)``), with the number of placeholders sniffed by the first element in *argslist*. :param page_size: maximum number of *argslist* items to include in every statement. If there are more items the function will execute more than one statement. .. __: https://www.postgresql.org/docs/current/static/queries-values.html After the execution of the function the `cursor.rowcount` property will **not** contain a total result. """ # we can't just use sql % vals because vals is bytes: if sql is bytes # there will be some decoding error because of the codec used, and Py3 # doesn't implement % on bytes.
if not isinstance(sql, bytes): sql = sql.encode(pgencodings[cur.connection.encoding]) pre, post = _split_sql(sql) for page in _paginate(argslist, page_size=page_size): if template is None: template = b"(" + b",".join([b"%s"] * len(page[0])) + b")" parts = pre[:] for args in page: parts.append(cur.mogrify(template, args)) parts.append(b",") parts[-1:] = post cur.execute(b"".join(parts)) yield from cur def import_swhmodule(modname): if not modname.startswith("swh."): modname = f"swh.{modname}" try: m = import_module(modname) except ImportError as exc: logger.error(f"Could not load the {modname} module: {exc}") return None return m def get_sql_for_package(modname: str, upgrade: bool = False) -> List[pathlib.Path]: """Return the (sorted) list of sql script files for the given swh module If upgrade is True, return the list of available migration scripts, otherwise, return the list of initialization scripts. """ m = import_swhmodule(modname) if m is None: raise ValueError(f"Module {modname} cannot be loaded") sqldir = pathlib.Path(m.__file__).parent / "sql" if upgrade: sqldir /= "upgrades" if not sqldir.is_dir(): raise ValueError( "Module {} does not provide a db schema (no sql/ dir)".format(modname) ) return sorted(sqldir.glob("*.sql"), key=lambda x: sortkey(x.name)) def populate_database_for_package( modname: str, conninfo: str, flavor: Optional[str] = None ) -> Tuple[bool, Optional[int], Optional[str]]: """Populate the database, pointed at with ``conninfo``, using the SQL files found in the package ``modname``. Also fill the 'dbmodule' table with the given ``modname``. Args: modname: Name of the module of which we're loading the files conninfo: connection info string for the SQL database flavor: the module-specific flavor which we want to initialize the database under Returns: Tuple with three elements: whether the database has been initialized; the current version of the database; if it exists, the flavor of the database. """ current_version = swh_db_version(conninfo) if current_version is not None: dbflavor = swh_db_flavor(conninfo) return False, current_version, dbflavor def globalsortkey(key): "like sortkey but only on basenames" return sortkey(path.basename(key)) sqlfiles = get_sql_for_package(modname) + get_sql_for_package("swh.core.db") sqlfiles = sorted(sqlfiles, key=lambda x: sortkey(x.stem)) sqlfiles = [fpath for fpath in sqlfiles if "-superuser-" not in fpath.stem] execute_sqlfiles(sqlfiles, conninfo, flavor) # populate the dbmodule table swh_set_db_module(conninfo, modname) current_db_version = swh_db_version(conninfo) dbflavor = swh_db_flavor(conninfo) return True, current_db_version, dbflavor +def initialize_database_for_module(modname: str, version: int, **kwargs): + """Helper function to initialize and populate a database for the given module + + This aims at easing the use of pytest_postgresql for swh.core.db based datastores. + Typical usage will be (here for swh.storage):: + + from functools import partial + + from pytest_postgresql import factories + + storage_postgresql_proc = factories.postgresql_proc( + load=[partial(initialize_database_for_module, modname="storage", version=42)] + ) + storage_postgresql = factories.postgresql("storage_postgresql_proc") + + """ + conninfo = psycopg2.connect(**kwargs).dsn + init_admin_extensions(modname, conninfo) + populate_database_for_package(modname, conninfo) + try: + swh_set_db_version(conninfo, version) + except psycopg2.errors.UniqueViolation: + logger.warning( + "Version already set by db init scripts. 
" + f"This generally means the swh.{modname} package needs to be " + "updated for swh.core>=1.2" + ) + + def get_database_info( conninfo: str, ) -> Tuple[Optional[str], Optional[int], Optional[str]]: """Get version, flavor and module of the db""" dbmodule = swh_db_module(conninfo) dbversion = swh_db_version(conninfo) dbflavor = None if dbversion is not None: dbflavor = swh_db_flavor(conninfo) return (dbmodule, dbversion, dbflavor) def parse_dsn_or_dbname(dsn_or_dbname: str) -> Dict[str, str]: """Parse a psycopg2 dsn, falling back to supporting plain database names as well""" try: return _parse_dsn(dsn_or_dbname) except psycopg2.ProgrammingError: # psycopg2 failed to parse the DSN; it's probably a database name, # handle it as such return _parse_dsn(f"dbname={dsn_or_dbname}") def init_admin_extensions(modname: str, conninfo: str) -> None: """The remaining initialization process -- running -superuser- SQL files -- is done using the given conninfo, thus connecting to the newly created database """ sqlfiles = get_sql_for_package(modname) sqlfiles = [fname for fname in sqlfiles if "-superuser-" in fname.stem] execute_sqlfiles(sqlfiles, conninfo) def create_database_for_package( modname: str, conninfo: str, template: str = "template1" ): """Create the database pointed at with ``conninfo``, and initialize it using -superuser- SQL files found in the package ``modname``. Args: modname: Name of the module of which we're loading the files conninfo: connection info string or plain database name for the SQL database template: the name of the database to connect to and use as template to create the new database """ # Use the given conninfo string, but with dbname replaced by the template dbname # for the database creation step creation_dsn = parse_dsn_or_dbname(conninfo) dbname = creation_dsn["dbname"] creation_dsn["dbname"] = template logger.debug("db_create dbname=%s (from %s)", dbname, template) subprocess.check_call( [ "psql", "--quiet", "--no-psqlrc", "-v", "ON_ERROR_STOP=1", "-d", make_dsn(**creation_dsn), "-c", f'CREATE DATABASE "{dbname}"', ] ) init_admin_extensions(modname, conninfo) def execute_sqlfiles( sqlfiles: Collection[pathlib.Path], db_or_conninfo: Union[str, pgconnection], flavor: Optional[str] = None, ): """Execute a list of SQL files on the database pointed at with ``db_or_conninfo``. 
Args: sqlfiles: List of SQL files to execute db_or_conninfo: A database connection, or a database connection info string flavor: the database flavor to initialize """ if isinstance(db_or_conninfo, str): conninfo = db_or_conninfo else: conninfo = db_or_conninfo.dsn psql_command = [ "psql", "--quiet", "--no-psqlrc", "-v", "ON_ERROR_STOP=1", "-d", conninfo, ] flavor_set = False for sqlfile in sqlfiles: logger.debug(f"execute SQL file {sqlfile} dbname={conninfo}") subprocess.check_call(psql_command + ["-f", str(sqlfile)]) if ( flavor is not None and not flavor_set and sqlfile.name.endswith("-flavor.sql") ): logger.debug("Setting database flavor %s", flavor) query = f"insert into dbflavor (flavor) values ('{flavor}')" subprocess.check_call(psql_command + ["-c", query]) flavor_set = True if flavor is not None and not flavor_set: logger.warning( "Asked for flavor %s, but module does not support database flavors", flavor, ) diff --git a/swh/core/db/pytest_plugin.py b/swh/core/db/pytest_plugin.py index e12a0f3..23b0609 100644 --- a/swh/core/db/pytest_plugin.py +++ b/swh/core/db/pytest_plugin.py @@ -1,281 +1,276 @@ # Copyright (C) 2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import glob from importlib import import_module import logging import subprocess from typing import Callable, Iterable, Iterator, List, Optional, Sequence, Set, Union +import warnings from _pytest.fixtures import FixtureRequest +from deprecated import deprecated import psycopg2 import pytest from pytest_postgresql.compat import check_for_psycopg2, connection from pytest_postgresql.executor import PostgreSQLExecutor from pytest_postgresql.executor_noop import NoopExecutor from pytest_postgresql.janitor import DatabaseJanitor -from swh.core.db.db_utils import ( - init_admin_extensions, - populate_database_for_package, - swh_set_db_version, -) +from swh.core.db.db_utils import initialize_database_for_module from swh.core.utils import basename_sortkey # to keep mypy happy regardless of pytest-postgresql version try: _pytest_pgsql_get_config_module = import_module("pytest_postgresql.config") except ImportError: # pytest_postgresql < 3.0.0 _pytest_pgsql_get_config_module = import_module("pytest_postgresql.factories") _pytest_postgresql_get_config = getattr(_pytest_pgsql_get_config_module, "get_config") logger = logging.getLogger(__name__) +initialize_database_for_module = deprecated( + version="2.10", + reason="Use swh.core.db.db_utils.initialize_database_for_module instead.", +)(initialize_database_for_module) + +warnings.warn( + "This pytest plugin is deprecated, it should not be used any more.", + category=DeprecationWarning, +) + class SWHDatabaseJanitor(DatabaseJanitor): """SWH database janitor implementation with a different setup/teardown policy than the stock one. Instead of dropping, creating and initializing the database for each test, it creates and initializes the db once, then truncates the tables (and sequences) in between tests. This is needed to have acceptable test performance.
""" def __init__( self, user: str, host: str, port: int, dbname: str, version: Union[str, float], password: Optional[str] = None, isolation_level: Optional[int] = None, connection_timeout: int = 60, dump_files: Optional[Union[str, Sequence[str]]] = None, no_truncate_tables: Set[str] = set(), no_db_drop: bool = False, ) -> None: super().__init__(user, host, port, dbname, version) # do no truncate the following tables self.no_truncate_tables = set(no_truncate_tables) self.no_db_drop = no_db_drop self.dump_files = dump_files def psql_exec(self, fname: str) -> None: conninfo = ( f"host={self.host} user={self.user} port={self.port} dbname={self.dbname}" ) subprocess.check_call( [ "psql", "--quiet", "--no-psqlrc", "-v", "ON_ERROR_STOP=1", "-d", conninfo, "-f", fname, ] ) def db_reset(self) -> None: """Truncate tables (all but self.no_truncate_tables set) and sequences""" with psycopg2.connect( dbname=self.dbname, user=self.user, host=self.host, port=self.port, ) as cnx: with cnx.cursor() as cur: cur.execute( "SELECT table_name FROM information_schema.tables " "WHERE table_schema = %s", ("public",), ) all_tables = set(table for (table,) in cur.fetchall()) tables_to_truncate = all_tables - self.no_truncate_tables for table in tables_to_truncate: cur.execute("TRUNCATE TABLE %s CASCADE" % table) cur.execute( "SELECT sequence_name FROM information_schema.sequences " "WHERE sequence_schema = %s", ("public",), ) seqs = set(seq for (seq,) in cur.fetchall()) for seq in seqs: cur.execute("ALTER SEQUENCE %s RESTART;" % seq) cnx.commit() def _db_exists(self, cur, dbname): cur.execute( "SELECT EXISTS " "(SELECT datname FROM pg_catalog.pg_database WHERE datname= %s);", (dbname,), ) row = cur.fetchone() return (row is not None) and row[0] def init(self) -> None: """Create database in postgresql out of a template it if it exists, bare creation otherwise.""" template_name = f"{self.dbname}_tmpl" logger.debug("Initialize DB %s", self.dbname) with self.cursor() as cur: tmpl_exists = self._db_exists(cur, template_name) db_exists = self._db_exists(cur, self.dbname) if not db_exists: if tmpl_exists: logger.debug( "Create %s from template %s", self.dbname, template_name ) cur.execute( f'CREATE DATABASE "{self.dbname}" TEMPLATE "{template_name}";' ) else: logger.debug("Create %s from scratch", self.dbname) cur.execute(f'CREATE DATABASE "{self.dbname}";') if self.dump_files: logger.warning( "Using dump_files on the postgresql_fact fixture " "is deprecated. See swh.core documentation for more " "details." ) for dump_file in gen_dump_files(self.dump_files): logger.info(f"Loading {dump_file}") self.psql_exec(dump_file) else: logger.debug("Reset %s", self.dbname) self.db_reset() def drop(self) -> None: """Drop database in postgresql.""" if self.no_db_drop: with self.cursor() as cur: self._terminate_connection(cur, self.dbname) else: super().drop() # the postgres_fact factory fixture below is mostly a copy of the code # from pytest-postgresql. We need a custom version here to be able to # specify our version of the DBJanitor we use. 
+@deprecated(version="2.10", reason="Use stock pytest_postgresql factory instead") def postgresql_fact( process_fixture_name: str, dbname: Optional[str] = None, load: Optional[Sequence[Union[Callable, str]]] = None, isolation_level: Optional[int] = None, modname: Optional[str] = None, dump_files: Optional[Union[str, List[str]]] = None, no_truncate_tables: Set[str] = {"dbversion"}, no_db_drop: bool = False, ) -> Callable[[FixtureRequest], Iterator[connection]]: """ Return connection fixture factory for PostgreSQL. :param process_fixture_name: name of the process fixture :param dbname: database name :param load: SQL, function or function import paths to automatically load into our test database :param isolation_level: optional postgresql isolation level defaults to server's default :param modname: (swh) module name for which the database is created :dump_files: (deprecated, use load instead) list of sql script files to execute after the database has been created :no_truncate_tables: list of table not to truncate between tests (only used when no_db_drop is True) :no_db_drop: if True, keep the database between tests; in which case, the database is reset (see SWHDatabaseJanitor.db_reset()) by truncating most of the tables. Note that this makes de facto tests (potentially) interdependent, use with extra caution. :returns: function which makes a connection to postgresql """ @pytest.fixture def postgresql_factory(request: FixtureRequest) -> Iterator[connection]: """ Fixture factory for PostgreSQL. :param request: fixture request object :returns: postgresql client """ check_for_psycopg2() proc_fixture: Union[PostgreSQLExecutor, NoopExecutor] = request.getfixturevalue( process_fixture_name ) pg_host = proc_fixture.host pg_port = proc_fixture.port pg_user = proc_fixture.user pg_password = proc_fixture.password pg_options = proc_fixture.options pg_db = dbname or proc_fixture.dbname pg_load = load or [] assert pg_db is not None with SWHDatabaseJanitor( pg_user, pg_host, pg_port, pg_db, proc_fixture.version, pg_password, isolation_level=isolation_level, dump_files=dump_files, no_truncate_tables=no_truncate_tables, no_db_drop=no_db_drop, ) as janitor: db_connection: connection = psycopg2.connect( dbname=pg_db, user=pg_user, password=pg_password, host=pg_host, port=pg_port, options=pg_options, ) for load_element in pg_load: janitor.load(load_element) try: yield db_connection finally: db_connection.close() return postgresql_factory -def initialize_database_for_module(modname, version, **kwargs): - conninfo = psycopg2.connect(**kwargs).dsn - init_admin_extensions(modname, conninfo) - populate_database_for_package(modname, conninfo) - try: - swh_set_db_version(conninfo, version) - except psycopg2.errors.UniqueViolation: - logger.warn( - "Version already set by db init scripts. 
" - "This generally means the swh.{modname} package needs to be " - "updated for swh.core>=1.2" - ) - - def gen_dump_files(dump_files: Union[str, Iterable[str]]) -> Iterator[str]: """Generate files potentially resolving glob patterns if any""" if isinstance(dump_files, str): dump_files = [dump_files] for dump_file in dump_files: if glob.has_magic(dump_file): # if the dump_file is a glob pattern one, resolve it yield from ( fname for fname in sorted(glob.glob(dump_file), key=basename_sortkey) ) else: # otherwise, just return the filename yield dump_file diff --git a/swh/core/db/tests/test_db.py b/swh/core/db/tests/test_db.py index 726f1a1..2d12707 100644 --- a/swh/core/db/tests/test_db.py +++ b/swh/core/db/tests/test_db.py @@ -1,466 +1,466 @@ # Copyright (C) 2019-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from dataclasses import dataclass import datetime from enum import IntEnum import inspect from string import printable from typing import Any from unittest.mock import MagicMock, Mock import uuid from hypothesis import given, settings, strategies from hypothesis.extra.pytz import timezones import psycopg2 import pytest +from pytest_postgresql import factories from typing_extensions import Protocol from swh.core.db import BaseDb from swh.core.db.common import db_transaction, db_transaction_generator -from swh.core.db.pytest_plugin import postgresql_fact from swh.core.db.tests.conftest import function_scoped_fixture_check # workaround mypy bug https://github.com/python/mypy/issues/5485 class Converter(Protocol): def __call__(self, x: Any) -> Any: ... @dataclass class Field: name: str """Column name""" pg_type: str """Type of the PostgreSQL column""" example: Any """Example value for the static tests""" strategy: strategies.SearchStrategy """Hypothesis strategy to generate these values""" in_wrapper: Converter = lambda x: x """Wrapper to convert this data type for the static tests""" out_converter: Converter = lambda x: x """Converter from the raw PostgreSQL column value to this data type""" # Limit PostgreSQL integer values pg_int = strategies.integers(-2147483648, +2147483647) pg_text = strategies.text( alphabet=strategies.characters( blacklist_categories=["Cs"], # surrogates blacklist_characters=[ "\x00", # pgsql does not support the null codepoint "\r", # pgsql normalizes those ], ), ) pg_bytea = strategies.binary() def pg_bytea_a(min_size: int, max_size: int) -> strategies.SearchStrategy: """Generate a PostgreSQL bytea[]""" return strategies.lists(pg_bytea, min_size=min_size, max_size=max_size) def pg_bytea_a_a(min_size: int, max_size: int) -> strategies.SearchStrategy: """Generate a PostgreSQL bytea[][]. The inner lists must all have the same size.""" return strategies.integers(min_value=max(1, min_size), max_value=max_size).flatmap( lambda n: strategies.lists( pg_bytea_a(min_size=n, max_size=n), min_size=min_size, max_size=max_size ) ) def pg_tstz() -> strategies.SearchStrategy: """Generate values that fit in a PostgreSQL timestamptz. Notes: We're forbidding old datetimes, because until 1956, many timezones had seconds in their "UTC offsets" (see ), which is not representable by PostgreSQL. 
""" min_value = datetime.datetime(1960, 1, 1, 0, 0, 0) return strategies.datetimes(min_value=min_value, timezones=timezones()) def pg_jsonb(min_size: int, max_size: int) -> strategies.SearchStrategy: """Generate values representable as a PostgreSQL jsonb object (dict).""" return strategies.dictionaries( strategies.text(printable), strategies.recursive( # should use floats() instead of integers(), but PostgreSQL # coerces large integers into floats, making the tests fail. We # only store ints in our generated data anyway. strategies.none() | strategies.booleans() | strategies.integers(-2147483648, +2147483647) | strategies.text(printable), lambda children: strategies.lists(children, max_size=max_size) | strategies.dictionaries( strategies.text(printable), children, max_size=max_size ), ), min_size=min_size, max_size=max_size, ) def tuple_2d_to_list_2d(v): """Convert a 2D tuple to a 2D list""" return [list(inner) for inner in v] def list_2d_to_tuple_2d(v): """Convert a 2D list to a 2D tuple""" return tuple(tuple(inner) for inner in v) class TestIntEnum(IntEnum): foo = 1 bar = 2 def now(): return datetime.datetime.now(tz=datetime.timezone.utc) FIELDS = ( Field("i", "int", 1, pg_int), Field("txt", "text", "foo", pg_text), Field("bytes", "bytea", b"bar", strategies.binary()), Field( "bytes_array", "bytea[]", [b"baz1", b"baz2"], pg_bytea_a(min_size=0, max_size=5), ), Field( "bytes_tuple", "bytea[]", (b"baz1", b"baz2"), pg_bytea_a(min_size=0, max_size=5).map(tuple), in_wrapper=list, out_converter=tuple, ), Field( "bytes_2d", "bytea[][]", [[b"quux1"], [b"quux2"]], pg_bytea_a_a(min_size=0, max_size=5), ), Field( "bytes_2d_tuple", "bytea[][]", ((b"quux1",), (b"quux2",)), pg_bytea_a_a(min_size=0, max_size=5).map(list_2d_to_tuple_2d), in_wrapper=tuple_2d_to_list_2d, out_converter=list_2d_to_tuple_2d, ), Field( "ts", "timestamptz", now(), pg_tstz(), ), Field( "dict", "jsonb", {"str": "bar", "int": 1, "list": ["a", "b"], "nested": {"a": "b"}}, pg_jsonb(min_size=0, max_size=5), in_wrapper=psycopg2.extras.Json, ), Field( "intenum", "int", TestIntEnum.foo, strategies.sampled_from(TestIntEnum), in_wrapper=int, out_converter=lambda x: TestIntEnum(x), # lambda needed by mypy ), Field("uuid", "uuid", uuid.uuid4(), strategies.uuids()), Field( "text_list", "text[]", # All the funky corner cases ["null", "NULL", None, "\\", "\t", "\n", "\r", " ", "'", ",", '"', "{", "}"], strategies.lists(pg_text, min_size=0, max_size=5), ), Field( "tstz_list", "timestamptz[]", [now(), now() + datetime.timedelta(days=1)], strategies.lists(pg_tstz(), min_size=0, max_size=5), ), Field( "tstz_range", "tstzrange", psycopg2.extras.DateTimeTZRange( lower=now(), upper=now() + datetime.timedelta(days=1), bounds="[)", ), strategies.tuples( # generate two sorted timestamptzs for use as bounds strategies.tuples(pg_tstz(), pg_tstz()).map(sorted), # and a set of bounds strategies.sampled_from(["[]", "()", "[)", "(]"]), ).map( # and build the actual DateTimeTZRange object from these args lambda args: psycopg2.extras.DateTimeTZRange( lower=args[0][0], upper=args[0][1], bounds=args[1], ) ), ), ) INIT_SQL = "create table test_table (%s)" % ", ".join( f"{field.name} {field.pg_type}" for field in FIELDS ) COLUMNS = tuple(field.name for field in FIELDS) INSERT_SQL = "insert into test_table (%s) values (%s)" % ( ", ".join(COLUMNS), ", ".join("%s" for i in range(len(COLUMNS))), ) STATIC_ROW_IN = tuple(field.in_wrapper(field.example) for field in FIELDS) EXPECTED_ROW_OUT = tuple(field.example for field in FIELDS) db_rows = 
strategies.lists(strategies.tuples(*(field.strategy for field in FIELDS))) def convert_lines(cur): return [ tuple(field.out_converter(x) for x, field in zip(line, FIELDS)) for line in cur ] -test_db = postgresql_fact("postgresql_proc", dbname="test-db2") +test_db = factories.postgresql("postgresql_proc", dbname="test-db2") @pytest.fixture def db_with_data(test_db, request): """Fixture to initialize a db with some data from the INIT_SQL above""" db = BaseDb.connect(test_db.dsn) with db.cursor() as cur: psycopg2.extras.register_default_jsonb(cur) cur.execute(INIT_SQL) yield db db.conn.rollback() db.conn.close() @pytest.mark.db def test_db_connect(db_with_data): with db_with_data.cursor() as cur: psycopg2.extras.register_default_jsonb(cur) cur.execute(INSERT_SQL, STATIC_ROW_IN) cur.execute("select * from test_table;") output = convert_lines(cur) assert len(output) == 1 assert EXPECTED_ROW_OUT == output[0] def test_db_initialized(db_with_data): with db_with_data.cursor() as cur: psycopg2.extras.register_default_jsonb(cur) cur.execute(INSERT_SQL, STATIC_ROW_IN) cur.execute("select * from test_table;") output = convert_lines(cur) assert len(output) == 1 assert EXPECTED_ROW_OUT == output[0] def test_db_copy_to_static(db_with_data): items = [{field.name: field.example for field in FIELDS}] db_with_data.copy_to(items, "test_table", COLUMNS) with db_with_data.cursor() as cur: cur.execute("select * from test_table;") output = convert_lines(cur) assert len(output) == 1 assert EXPECTED_ROW_OUT == output[0] @settings(suppress_health_check=function_scoped_fixture_check, max_examples=5) @given(db_rows) def test_db_copy_to(db_with_data, data): items = [dict(zip(COLUMNS, item)) for item in data] with db_with_data.cursor() as cur: cur.execute("TRUNCATE TABLE test_table CASCADE") db_with_data.copy_to(items, "test_table", COLUMNS) with db_with_data.cursor() as cur: cur.execute("select * from test_table;") converted_lines = convert_lines(cur) assert converted_lines == data def test_db_copy_to_thread_exception(db_with_data): data = [(2**65, "foo", b"bar")] items = [dict(zip(COLUMNS, item)) for item in data] with pytest.raises(psycopg2.errors.NumericValueOutOfRange): db_with_data.copy_to(items, "test_table", COLUMNS) def test_db_transaction(mocker): expected_cur = object() called = False class Storage: @db_transaction() def endpoint(self, cur=None, db=None): nonlocal called called = True assert cur is expected_cur storage = Storage() # 'with storage.get_db().transaction() as cur:' should cause # 'cur' to be 'expected_cur' db_mock = Mock() db_mock.transaction.return_value = MagicMock() db_mock.transaction.return_value.__enter__.return_value = expected_cur mocker.patch.object(storage, "get_db", return_value=db_mock, create=True) put_db_mock = mocker.patch.object(storage, "put_db", create=True) storage.endpoint() assert called put_db_mock.assert_called_once_with(db_mock) def test_db_transaction__with_generator(): with pytest.raises(ValueError, match="generator"): class Storage: @db_transaction() def endpoint(self, cur=None, db=None): yield None def test_db_transaction_signature(): """Checks db_transaction removes the 'cur' and 'db' arguments.""" def f(self, foo, *, bar=None): pass expected_sig = inspect.signature(f) @db_transaction() def g(self, foo, *, bar=None, db=None, cur=None): pass actual_sig = inspect.signature(g) assert actual_sig == expected_sig def test_db_transaction_generator(mocker): expected_cur = object() called = False class Storage: @db_transaction_generator() def endpoint(self, cur=None, 
db=None): nonlocal called called = True assert cur is expected_cur yield None storage = Storage() # 'with storage.get_db().transaction() as cur:' should cause # 'cur' to be 'expected_cur' db_mock = Mock() db_mock.transaction.return_value = MagicMock() db_mock.transaction.return_value.__enter__.return_value = expected_cur mocker.patch.object(storage, "get_db", return_value=db_mock, create=True) put_db_mock = mocker.patch.object(storage, "put_db", create=True) list(storage.endpoint()) assert called put_db_mock.assert_called_once_with(db_mock) def test_db_transaction_generator__with_nongenerator(): with pytest.raises(ValueError, match="generator"): class Storage: @db_transaction_generator() def endpoint(self, cur=None, db=None): pass def test_db_transaction_generator_signature(): """Checks db_transaction removes the 'cur' and 'db' arguments.""" def f(self, foo, *, bar=None): pass expected_sig = inspect.signature(f) @db_transaction_generator() def g(self, foo, *, bar=None, db=None, cur=None): yield None actual_sig = inspect.signature(g) assert actual_sig == expected_sig @pytest.mark.parametrize( "query_options", (None, {"something": 42, "statement_timeout": 200}) ) @pytest.mark.parametrize("use_generator", (True, False)) def test_db_transaction_query_options(mocker, use_generator, query_options): class Storage: @db_transaction(statement_timeout=100) def endpoint(self, cur=None, db=None): return [None] @db_transaction_generator(statement_timeout=100) def gen_endpoint(self, cur=None, db=None): yield None storage = Storage() # mockers mocked_apply = mocker.patch("swh.core.db.common.apply_options") # 'with storage.get_db().transaction() as cur:' should cause # 'cur' to be 'expected_cur' expected_cur = object() db_mock = MagicMock() db_mock.transaction.return_value.__enter__.return_value = expected_cur mocker.patch.object(storage, "get_db", return_value=db_mock, create=True) mocker.patch.object(storage, "put_db", create=True) if query_options: storage.query_options = { "endpoint": query_options, "gen_endpoint": query_options, } if use_generator: list(storage.gen_endpoint()) else: list(storage.endpoint()) mocked_apply.assert_called_once_with( expected_cur, query_options if query_options is not None else {"statement_timeout": 100}, ) diff --git a/swh/core/github/tests/test_github_utils.py b/swh/core/github/tests/test_github_utils.py index d9d940c..c7b7087 100644 --- a/swh/core/github/tests/test_github_utils.py +++ b/swh/core/github/tests/test_github_utils.py @@ -1,199 +1,205 @@ # Copyright (C) 2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import logging import pytest from swh.core.github.pytest_plugin import HTTP_GITHUB_API_URL from swh.core.github.utils import ( GitHubSession, _sanitize_github_url, _url_github_api, get_canonical_github_origin_url, ) KNOWN_GH_REPO = "https://github.com/user/repo" -def _url_github_html(user_repo: str, protocol: str = "https") -> str: - """Given the user repo, returns the expected github html url.""" - return f"{protocol}://github.com/{user_repo}" - - @pytest.mark.parametrize( "user_repo, expected_url", [ ("user/repo.git", KNOWN_GH_REPO), ("user/repo.git/", KNOWN_GH_REPO), ("user/repo/", KNOWN_GH_REPO), ("user/repo", KNOWN_GH_REPO), ("user/repo/.git", KNOWN_GH_REPO), ("unknown/page", None), # unknown gh origin returns None ("user/with/deps", None), # url kind is not dealt with ], ) def 
test_get_canonical_github_origin_url( user_repo, expected_url, requests_mock, github_credentials ): """It should return a canonical github origin when it exists, None otherwise""" - for protocol in ["https", "git", "http"]: - html_input_url = _url_github_html(user_repo, protocol=protocol) - html_url = _url_github_html(user_repo) - api_url = _url_github_api(_sanitize_github_url(user_repo)) - - if expected_url is not None: - status_code = 200 - response = {"html_url": _sanitize_github_url(html_url)} - else: - status_code = 404 - response = {} - - requests_mock.get(api_url, [{"status_code": status_code, "json": response}]) - - # anonymous - assert get_canonical_github_origin_url(html_input_url) == expected_url - - # with credentials - assert ( - get_canonical_github_origin_url( - html_input_url, credentials=github_credentials + for separator in ["/", ":"]: + for prefix in [ + "http://", + "https://", + "git://", + "ssh://", + "//", + "git@", + "ssh://git@", + "https://${env.GITHUB_TOKEN_USR}:${env.GITHUB_TOKEN_PSW}@", + "[fetch=]git@", + ]: + html_input_url = f"{prefix}github.com{separator}{user_repo}" + html_url = f"https://github.com/{user_repo}" + api_url = _url_github_api(_sanitize_github_url(user_repo)) + + if expected_url is not None: + status_code = 200 + response = {"html_url": _sanitize_github_url(html_url)} + else: + status_code = 404 + response = {} + + requests_mock.get(api_url, [{"status_code": status_code, "json": response}]) + + # anonymous + assert get_canonical_github_origin_url(html_input_url) == expected_url + + # with credentials + assert ( + get_canonical_github_origin_url( + html_input_url, credentials=github_credentials + ) + == expected_url + ) + + # anonymous + assert ( + GitHubSession( + user_agent="GitHub Session Test", + ).get_canonical_url(html_input_url) + == expected_url + ) + + # with credentials + assert ( + GitHubSession( + user_agent="GitHub Session Test", credentials=github_credentials + ).get_canonical_url(html_input_url) + == expected_url ) - == expected_url - ) - - # anonymous - assert ( - GitHubSession( - user_agent="GitHub Session Test", - ).get_canonical_url(html_input_url) - == expected_url - ) - - # with credentials - assert ( - GitHubSession( - user_agent="GitHub Session Test", credentials=github_credentials - ).get_canonical_url(html_input_url) - == expected_url - ) def test_get_canonical_github_origin_url_not_gh_origin(): """It should return the input url when that origin is not a github one""" url = "https://example.org" assert get_canonical_github_origin_url(url) == url assert ( GitHubSession( user_agent="GitHub Session Test", ).get_canonical_url(url) == url ) def test_github_session_anonymous_session(): user_agent = "GitHub Session Test" github_session = GitHubSession( user_agent=user_agent, ) assert github_session.anonymous is True actual_headers = github_session.session.headers assert actual_headers["Accept"] == "application/vnd.github.v3+json" assert actual_headers["User-Agent"] == user_agent @pytest.mark.parametrize( "num_ratelimit", [1] # return a single rate-limit response, then continue ) def test_github_session_ratelimit_once_recovery( caplog, requests_ratelimited, num_ratelimit, monkeypatch_sleep_calls, github_credentials, ): """GitHubSession should recover from hitting the rate-limit once""" caplog.set_level(logging.DEBUG, "swh.core.github.utils") github_session = GitHubSession( user_agent="GitHub Session Test", credentials=github_credentials ) res = github_session.request(f"{HTTP_GITHUB_API_URL}?per_page=1000&since=10") assert 
res.status_code == 200 token_users = [] for record in caplog.records: if "Using authentication token" in record.message: token_users.append(record.args[0]) # check that we used one more token than we saw rate limited requests assert len(token_users) == 1 + num_ratelimit # check that we slept for one second between our token uses assert monkeypatch_sleep_calls == [1] def test_github_session_authenticated_credentials( caplog, github_credentials, all_tokens ): """GitHubSession should have Authorization headers set in authenticated mode""" caplog.set_level(logging.DEBUG, "swh.core.github.utils") github_session = GitHubSession( "GitHub Session Test", credentials=github_credentials ) assert github_session.anonymous is False assert github_session.token_index == 0 assert ( sorted(github_session.credentials, key=lambda t: t["username"]) == github_credentials ) assert github_session.session.headers["Authorization"] in [ f"token {t}" for t in all_tokens ] @pytest.mark.parametrize( # Do 5 successful requests, return 6 ratelimits (to exhaust the credentials) with a # set value for X-Ratelimit-Reset, then resume listing successfully. "num_before_ratelimit, num_ratelimit, ratelimit_reset", [(5, 6, 123456)], ) def test_github_session_ratelimit_reset_sleep( caplog, requests_ratelimited, monkeypatch_sleep_calls, num_before_ratelimit, num_ratelimit, ratelimit_reset, github_credentials, ): """GitHubSession should handle rate-limit with authentication tokens.""" caplog.set_level(logging.DEBUG, "swh.core.github.utils") github_session = GitHubSession( user_agent="GitHub Session Test", credentials=github_credentials ) for _ in range(num_ratelimit): github_session.request(f"{HTTP_GITHUB_API_URL}?per_page=1000&since=10") # We sleep 1 second every time we change credentials, then we sleep until # ratelimit_reset + 1 expected_sleep_calls = len(github_credentials) * [1] + [ratelimit_reset + 1] assert monkeypatch_sleep_calls == expected_sleep_calls found_exhaustion_message = False for record in caplog.records: if record.levelname == "INFO": if "Rate limits exhausted for all tokens" in record.message: found_exhaustion_message = True break assert found_exhaustion_message is True diff --git a/swh/core/github/utils.py b/swh/core/github/utils.py index 867b2e4..80ffa2b 100644 --- a/swh/core/github/utils.py +++ b/swh/core/github/utils.py @@ -1,225 +1,227 @@ # Copyright (C) 2020-2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import logging import random import re import time from typing import Dict, List, Optional import requests from tenacity import ( retry, retry_any, retry_if_exception_type, retry_if_result, wait_exponential, ) -GITHUB_PATTERN = re.compile(r"(git|https?)://github.com/(?P<user_repo>.*)") +GITHUB_PATTERN = re.compile( + r"(//|git://|git@|git//|https?://|ssh://|.*@)github.com[/:](?P<user_repo>.*)" + ) logger = logging.getLogger(__name__) def _url_github_api(user_repo: str) -> str: """Given the user_repo, returns the expected github api url.""" return f"https://api.github.com/repos/{user_repo}" def _sanitize_github_url(url: str) -> str: """Sanitize github url.""" return url.lower().rstrip("/").rstrip(".git").rstrip("/") def get_canonical_github_origin_url( url: str, credentials: Optional[List[Dict[str, str]]] = None ) -> Optional[str]: """Retrieve the canonical github url for ``url`` if any, or None otherwise (the url is returned unchanged when it does not look like a github origin). 
This triggers an http request to the github api url to determine the canonical repository url (if no credentials are provided, the http request is anonymous. Either way that request can be rate-limited by github.) """ return GitHubSession( user_agent="SWH core library", credentials=credentials ).get_canonical_url(url) class RateLimited(Exception): def __init__(self, response): self.reset_time: Optional[int] # Figure out how long we need to sleep because of that rate limit ratelimit_reset = response.headers.get("X-Ratelimit-Reset") retry_after = response.headers.get("Retry-After") if ratelimit_reset is not None: self.reset_time = int(ratelimit_reset) elif retry_after is not None: self.reset_time = int(time.time()) + int(retry_after) + 1 else: logger.warning( "Received a rate-limit-like status code %s, but no rate-limit " "headers set. Response content: %s", response.status_code, response.content, ) self.reset_time = None self.response = response class MissingRateLimitReset(Exception): pass class GitHubSession: """Manages a :class:`requests.Session` with (optionally) multiple credentials, and cycles through them when reaching rate-limits.""" credentials: Optional[List[Dict[str, str]]] = None def __init__( self, user_agent: str, credentials: Optional[List[Dict[str, str]]] = None ) -> None: """Initialize a requests session with the proper headers for requests to GitHub.""" if credentials: creds = credentials.copy() random.shuffle(creds) self.credentials = creds self.session = requests.Session() self.session.headers.update( {"Accept": "application/vnd.github.v3+json", "User-Agent": user_agent} ) self.anonymous = not self.credentials if self.anonymous: logger.warning("No tokens set in configuration, using anonymous mode") self.token_index = -1 self.current_user: Optional[str] = None if not self.anonymous: # Initialize the first token value in the session headers self.set_next_session_token() def set_next_session_token(self) -> None: """Update the current authentication token with the next one in line.""" assert self.credentials self.token_index = (self.token_index + 1) % len(self.credentials) auth = self.credentials[self.token_index] self.current_user = auth["username"] logger.debug("Using authentication token for user %s", self.current_user) if "password" in auth: token = auth["password"] else: token = auth["token"] self.session.headers.update({"Authorization": f"token {token}"}) @retry( wait=wait_exponential(multiplier=1, min=4, max=10), retry=retry_any( # ChunkedEncodingErrors happen when the TLS connection gets reset, e.g. # when running the lister on a connection with high latency retry_if_exception_type(requests.exceptions.ChunkedEncodingError), # 502 status codes happen for a Server Error, sometimes retry_if_result(lambda r: r.status_code == 502), ), ) def _request(self, url: str) -> requests.Response: response = self.session.get(url) if ( # GitHub returns inconsistent status codes between unauthenticated # rate limit and authenticated rate limits. Handle both. response.status_code == 429 or (self.anonymous and response.status_code == 403) ): raise RateLimited(response) return response def request(self, url) -> requests.Response: """Repeatedly requests the given URL, cycling through credentials and sleeping if necessary, until it either gets a successful response or raises :exc:`MissingRateLimitReset` """ # The following for/else loop handles rate limiting; if successful, # it provides the rest of the function with a `response` object. 
# # If all tokens are rate-limited, we sleep until the reset time, # then `continue` into another iteration of the outer while loop, # attempting to get data from the same URL again. while True: max_attempts = len(self.credentials) if self.credentials else 1 reset_times: Dict[int, int] = {} # token index -> time for attempt in range(max_attempts): try: return self._request(url) except RateLimited as e: reset_info = "(unknown reset)" if e.reset_time is not None: reset_times[self.token_index] = e.reset_time reset_info = "(resetting in %ss)" % (e.reset_time - time.time()) if not self.anonymous: logger.info( "Rate limit exhausted for current user %s %s", self.current_user, reset_info, ) # Use next token in line self.set_next_session_token() # Wait one second to avoid triggering GitHub's abuse rate limits time.sleep(1) # All tokens have been rate-limited. What do we do? if not reset_times: logger.warning( "No X-Ratelimit-Reset value found in responses for any token; " "Giving up." ) raise MissingRateLimitReset() sleep_time = max(reset_times.values()) - time.time() + 1 logger.info( "Rate limits exhausted for all tokens. Sleeping for %f seconds.", sleep_time, ) time.sleep(sleep_time) def get_canonical_url(self, url: str) -> Optional[str]: """Retrieve the canonical github url for ``url`` if any, or None otherwise. This triggers an http request to the github api url to determine the canonical repository url. Returns: The canonical url if any; ``url`` unchanged when it does not look like a github origin; None otherwise. """ url_ = url.lower() match = GITHUB_PATTERN.match(url_) if not match: return url user_repo = _sanitize_github_url(match.groupdict()["user_repo"]) response = self.request(_url_github_api(user_repo)) if response.status_code != 200: return None data = response.json() return data["html_url"]
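Migration sketch: the deprecations in this diff point users at the stock ``pytest_postgresql`` factories combined with the relocated ``swh.core.db.db_utils.initialize_database_for_module`` helper. The following minimal sketch mirrors the docstring example from ``db_utils.py``; the ``swh.example`` module name, its schema version ``1`` and the fixture names are illustrative assumptions, not part of the API.

.. code-block:: python

    # Scaffolding, typically placed in conftest.py; assumes a hypothetical
    # swh.example package whose sql/ init scripts bring the schema to version 1.
    from functools import partial

    from pytest_postgresql import factories

    from swh.core.db.db_utils import initialize_database_for_module

    # The `load` callables run once, against the template database; each
    # test then gets a fresh database cloned from that template.
    example_postgresql_proc = factories.postgresql_proc(
        dbname="example",
        load=[partial(initialize_database_for_module, modname="example", version=1)],
    )

    # Stock pytest_postgresql connection fixture, replacing the deprecated
    # swh.core.db.pytest_plugin.postgresql_fact factory.
    postgresql_example = factories.postgresql("example_postgresql_proc")


    # In a test module:
    def test_example(postgresql_example):
        with postgresql_example.cursor() as c:
            c.execute("select version from dbversion limit 1")
            assert c.fetchone()[0] == 1

Loading the schema on the process fixture means the initialization scripts run only once per session and each test merely clones the template database, which is the same performance rationale the deprecated ``SWHDatabaseJanitor`` implemented by hand.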