diff --git a/Makefile.local b/Makefile.local index cc21b3f..1f1890f 100644 --- a/Makefile.local +++ b/Makefile.local @@ -1 +1 @@ -TEST_DIRS := ./swh/core/api/tests ./swh/core/db/tests ./swh/core/tests +TEST_DIRS := ./swh/core/api/tests ./swh/core/db/tests ./swh/core/tests ./swh/core/github/tests diff --git a/PKG-INFO b/PKG-INFO index 51d7b6a..9097bca 100644 --- a/PKG-INFO +++ b/PKG-INFO @@ -1,42 +1,39 @@ Metadata-Version: 2.1 Name: swh.core -Version: 2.5.0 +Version: 2.6.0 Summary: Software Heritage core utilities Home-page: https://forge.softwareheritage.org/diffusion/DCORE/ Author: Software Heritage developers Author-email: swh-devel@inria.fr -License: UNKNOWN Project-URL: Bug Reports, https://forge.softwareheritage.org/maniphest Project-URL: Funding, https://www.softwareheritage.org/donate Project-URL: Source, https://forge.softwareheritage.org/source/swh-core Project-URL: Documentation, https://docs.softwareheritage.org/devel/swh-core/ -Platform: UNKNOWN Classifier: Programming Language :: Python :: 3 Classifier: Intended Audience :: Developers Classifier: License :: OSI Approved :: GNU General Public License v3 (GPLv3) Classifier: Operating System :: OS Independent Classifier: Development Status :: 5 - Production/Stable Requires-Python: >=3.7 Description-Content-Type: text/x-rst Provides-Extra: testing-core Provides-Extra: logging Provides-Extra: db Provides-Extra: http +Provides-Extra: github Provides-Extra: testing License-File: LICENSE License-File: AUTHORS Software Heritage - Core foundations ==================================== Low-level utilities and helpers used by almost all other modules in the stack. core library for swh's modules: - config parser - serialization - logging mechanism - database connection - http-based RPC client/server - - diff --git a/conftest.py b/conftest.py index 5f2c429..73c6d7d 100644 --- a/conftest.py +++ b/conftest.py @@ -1,20 +1,22 @@ from hypothesis import settings import pytest # define tests profile. 
Full documentation is at: # https://hypothesis.readthedocs.io/en/latest/settings.html#settings-profiles settings.register_profile("fast", max_examples=5, deadline=5000) settings.register_profile("slow", max_examples=20, deadline=5000) +pytest_plugins = ["swh.core.github.pytest_plugin"] + @pytest.fixture def swhmain(): """Yield an instance of the main `swh` click command that cleans the added subcommands up on teardown.""" from swh.core.cli import swh as _swhmain commands = _swhmain.commands.copy() aliases = _swhmain.aliases.copy() yield _swhmain _swhmain.commands = commands _swhmain.aliases = aliases diff --git a/requirements-github.txt b/requirements-github.txt new file mode 100644 index 0000000..4e56914 --- /dev/null +++ b/requirements-github.txt @@ -0,0 +1,3 @@ +# requirements for swh.core.github +requests +tenacity diff --git a/requirements-http.txt b/requirements-http.txt index 1a0bb79..1cd7eff 100644 --- a/requirements-http.txt +++ b/requirements-http.txt @@ -1,8 +1,9 @@ # requirements for swh.core.api aiohttp aiohttp_utils >= 3.1.1 blinker # dependency of sentry-sdk[flask] flask iso8601 msgpack >= 1.0.0 requests + diff --git a/setup.py b/setup.py index 699e1b8..ab80819 100755 --- a/setup.py +++ b/setup.py @@ -1,88 +1,89 @@ #!/usr/bin/env python3 # Copyright (C) 2015-2018 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from io import open import os from os import path from setuptools import find_packages, setup here = path.abspath(path.dirname(__file__)) # Get the long description from the README file with open(path.join(here, "README.rst"), encoding="utf-8") as f: long_description = f.read() def parse_requirements(*names): requirements = [] for name in names: if name: reqf = "requirements-%s.txt" % name else: reqf = "requirements.txt" if not os.path.exists(reqf): return requirements 
with open(reqf) as f: for line in f.readlines(): line = line.strip() if not line or line.startswith("#"): continue requirements.append(line) return requirements setup( name="swh.core", description="Software Heritage core utilities", long_description=long_description, long_description_content_type="text/x-rst", python_requires=">=3.7", author="Software Heritage developers", author_email="swh-devel@inria.fr", url="https://forge.softwareheritage.org/diffusion/DCORE/", packages=find_packages(), py_modules=["pytest_swh_core"], scripts=[], install_requires=parse_requirements(None, "swh"), setup_requires=["setuptools-scm"], use_scm_version=True, extras_require={ "testing-core": parse_requirements("test"), "logging": parse_requirements("logging"), "db": parse_requirements("db", "db-pytestplugin"), "http": parse_requirements("http"), + "github": parse_requirements("github"), # kitchen sink, please do not use "testing": parse_requirements( "test", "db", "db-pytestplugin", "http", "logging" ), }, include_package_data=True, entry_points=""" [console_scripts] swh=swh.core.cli:main swh-db-init=swh.core.cli.db:db_init [swh.cli.subcommands] db=swh.core.cli.db [pytest11] pytest_swh_core = swh.core.pytest_plugin """, classifiers=[ "Programming Language :: Python :: 3", "Intended Audience :: Developers", "License :: OSI Approved :: GNU General Public License v3 (GPLv3)", "Operating System :: OS Independent", "Development Status :: 5 - Production/Stable", ], project_urls={ "Bug Reports": "https://forge.softwareheritage.org/maniphest", "Funding": "https://www.softwareheritage.org/donate", "Source": "https://forge.softwareheritage.org/source/swh-core", "Documentation": "https://docs.softwareheritage.org/devel/swh-core/", }, ) diff --git a/swh.core.egg-info/PKG-INFO b/swh.core.egg-info/PKG-INFO index 51d7b6a..9097bca 100644 --- a/swh.core.egg-info/PKG-INFO +++ b/swh.core.egg-info/PKG-INFO @@ -1,42 +1,39 @@ Metadata-Version: 2.1 Name: swh.core -Version: 2.5.0 +Version: 2.6.0 Summary: 
Software Heritage core utilities Home-page: https://forge.softwareheritage.org/diffusion/DCORE/ Author: Software Heritage developers Author-email: swh-devel@inria.fr -License: UNKNOWN Project-URL: Bug Reports, https://forge.softwareheritage.org/maniphest Project-URL: Funding, https://www.softwareheritage.org/donate Project-URL: Source, https://forge.softwareheritage.org/source/swh-core Project-URL: Documentation, https://docs.softwareheritage.org/devel/swh-core/ -Platform: UNKNOWN Classifier: Programming Language :: Python :: 3 Classifier: Intended Audience :: Developers Classifier: License :: OSI Approved :: GNU General Public License v3 (GPLv3) Classifier: Operating System :: OS Independent Classifier: Development Status :: 5 - Production/Stable Requires-Python: >=3.7 Description-Content-Type: text/x-rst Provides-Extra: testing-core Provides-Extra: logging Provides-Extra: db Provides-Extra: http +Provides-Extra: github Provides-Extra: testing License-File: LICENSE License-File: AUTHORS Software Heritage - Core foundations ==================================== Low-level utilities and helpers used by almost all other modules in the stack. 
core library for swh's modules: - config parser - serialization - logging mechanism - database connection - http-based RPC client/server - - diff --git a/swh.core.egg-info/SOURCES.txt b/swh.core.egg-info/SOURCES.txt index 26f9945..ff3124c 100644 --- a/swh.core.egg-info/SOURCES.txt +++ b/swh.core.egg-info/SOURCES.txt @@ -1,133 +1,141 @@ .git-blame-ignore-revs .gitignore .pre-commit-config.yaml AUTHORS CODE_OF_CONDUCT.md CONTRIBUTORS LICENSE MANIFEST.in Makefile Makefile.local README.rst conftest.py mypy.ini pyproject.toml pytest.ini requirements-db-pytestplugin.txt requirements-db.txt +requirements-github.txt requirements-http.txt requirements-logging.txt requirements-swh.txt requirements-test.txt requirements.txt setup.cfg setup.py tox.ini docs/.gitignore docs/Makefile docs/README.rst docs/cli.rst docs/conf.py docs/db.rst docs/index.rst docs/_static/.placeholder docs/_templates/.placeholder swh/__init__.py +swh/__main__.py swh.core.egg-info/PKG-INFO swh.core.egg-info/SOURCES.txt swh.core.egg-info/dependency_links.txt swh.core.egg-info/entry_points.txt swh.core.egg-info/requires.txt swh.core.egg-info/top_level.txt swh/core/__init__.py swh/core/api_async.py swh/core/collections.py swh/core/config.py swh/core/logger.py swh/core/py.typed swh/core/pytest_plugin.py swh/core/sentry.py swh/core/statsd.py swh/core/tarball.py swh/core/utils.py swh/core/api/__init__.py swh/core/api/asynchronous.py swh/core/api/classes.py swh/core/api/gunicorn_config.py swh/core/api/negotiation.py swh/core/api/serializers.py swh/core/api/tests/__init__.py swh/core/api/tests/conftest.py swh/core/api/tests/server_testing.py swh/core/api/tests/test_async.py swh/core/api/tests/test_classes.py swh/core/api/tests/test_gunicorn.py swh/core/api/tests/test_init.py swh/core/api/tests/test_rpc_client.py swh/core/api/tests/test_rpc_client_server.py swh/core/api/tests/test_rpc_server.py swh/core/api/tests/test_rpc_server_asynchronous.py swh/core/api/tests/test_serializers.py swh/core/cli/__init__.py 
swh/core/cli/db.py swh/core/db/__init__.py swh/core/db/common.py swh/core/db/db_utils.py swh/core/db/pytest_plugin.py swh/core/db/sql/35-dbversion.sql swh/core/db/sql/36-dbmodule.sql swh/core/db/tests/__init__.py swh/core/db/tests/conftest.py swh/core/db/tests/test_cli.py swh/core/db/tests/test_db.py swh/core/db/tests/test_db_utils.py swh/core/db/tests/data/cli/sql/0-superuser-init.sql swh/core/db/tests/data/cli/sql/30-schema.sql swh/core/db/tests/data/cli/sql/40-funcs.sql swh/core/db/tests/data/cli/sql/50-data.sql swh/core/db/tests/data/cli_new/sql/0-superuser-init.sql swh/core/db/tests/data/cli_new/sql/30-schema.sql swh/core/db/tests/data/cli_new/sql/40-funcs.sql swh/core/db/tests/data/cli_new/sql/50-data.sql swh/core/db/tests/data/cli_new/sql/upgrades/001.sql swh/core/db/tests/data/cli_new/sql/upgrades/002.sql swh/core/db/tests/data/cli_new/sql/upgrades/003.sql swh/core/db/tests/data/cli_new/sql/upgrades/004.sql swh/core/db/tests/data/cli_new/sql/upgrades/005.sql swh/core/db/tests/data/cli_new/sql/upgrades/006.sql swh/core/db/tests/pytest_plugin/__init__.py swh/core/db/tests/pytest_plugin/test_pytest_plugin.py swh/core/db/tests/pytest_plugin/data/0-schema.sql swh/core/db/tests/pytest_plugin/data/1-data.sql +swh/core/github/__init__.py +swh/core/github/pytest_plugin.py +swh/core/github/utils.py +swh/core/github/tests/__init__.py +swh/core/github/tests/test_github_utils.py +swh/core/github/tests/test_pytest_plugin.py swh/core/tests/__init__.py swh/core/tests/test_cli.py swh/core/tests/test_collections.py swh/core/tests/test_config.py swh/core/tests/test_logger.py swh/core/tests/test_pytest_plugin.py swh/core/tests/test_sentry.py swh/core/tests/test_statsd.py swh/core/tests/test_tarball.py swh/core/tests/test_utils.py swh/core/tests/data/archives/groff-1.02.tar.Z swh/core/tests/data/archives/hello.tar swh/core/tests/data/archives/hello.tar.bz2 swh/core/tests/data/archives/hello.tar.gz swh/core/tests/data/archives/hello.tar.lz 
swh/core/tests/data/archives/hello.tar.x swh/core/tests/data/archives/hello.tbz swh/core/tests/data/archives/hello.tbz2 swh/core/tests/data/archives/hello.zip swh/core/tests/data/archives/msk316src.zip swh/core/tests/data/archives/tokei-12.1.2.crate swh/core/tests/data/http_example.com/something.json swh/core/tests/data/https_example.com/file.json swh/core/tests/data/https_example.com/file.json,name=doe,firstname=jane swh/core/tests/data/https_example.com/file.json_visit1 swh/core/tests/data/https_example.com/other.json swh/core/tests/data/https_forge.s.o/api_diffusion,attachments[uris]=1 swh/core/tests/data/https_www.reference.com/web,q=What+Is+an+Example+of+a+URL?,qo=contentPageRelatedSearch,o=600605,l=dir,sga=1 swh/core/tests/fixture/__init__.py swh/core/tests/fixture/conftest.py swh/core/tests/fixture/test_pytest_plugin.py swh/core/tests/fixture/data/https_example.com/file.json \ No newline at end of file diff --git a/swh.core.egg-info/requires.txt b/swh.core.egg-info/requires.txt index 012e295..88de467 100644 --- a/swh.core.egg-info/requires.txt +++ b/swh.core.egg-info/requires.txt @@ -1,59 +1,63 @@ click deprecated python-magic pyyaml sentry-sdk [db] psycopg2 typing-extensions pytest-postgresql<4.0.0,>=3 +[github] +requests +tenacity + [http] aiohttp aiohttp_utils>=3.1.1 blinker flask iso8601 msgpack>=1.0.0 requests [logging] systemd-python [testing] hypothesis>=3.11.0 pytest pytest-mock pytz requests-mock types-click types-flask types-psycopg2 types-pytz types-pyyaml types-requests psycopg2 typing-extensions pytest-postgresql<4.0.0,>=3 aiohttp aiohttp_utils>=3.1.1 blinker flask iso8601 msgpack>=1.0.0 requests systemd-python [testing-core] hypothesis>=3.11.0 pytest pytest-mock pytz requests-mock types-click types-flask types-psycopg2 types-pytz types-pyyaml types-requests diff --git a/swh/__main__.py b/swh/__main__.py new file mode 100644 index 0000000..6f04565 --- /dev/null +++ b/swh/__main__.py @@ -0,0 +1,5 @@ +# provides the main swh cli entry point from 
standard 'python -m swh' +if __name__ == "__main__": + from swh.core.cli import main + + main() diff --git a/swh/core/db/db_utils.py b/swh/core/db/db_utils.py index 0e71cdc..d78f9ec 100644 --- a/swh/core/db/db_utils.py +++ b/swh/core/db/db_utils.py @@ -1,678 +1,664 @@ # Copyright (C) 2015-2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information +from contextlib import contextmanager from datetime import datetime, timezone import functools from importlib import import_module import logging from os import path import pathlib import re import subprocess -from typing import Collection, Dict, List, Optional, Tuple, Union, cast +from typing import Collection, Dict, Iterator, List, Optional, Tuple, Union, cast import psycopg2 import psycopg2.errors import psycopg2.extensions from psycopg2.extensions import connection as pgconnection from psycopg2.extensions import encodings as pgencodings from psycopg2.extensions import make_dsn from psycopg2.extensions import parse_dsn as _parse_dsn from swh.core.utils import numfile_sortkey as sortkey logger = logging.getLogger(__name__) def now(): return datetime.now(tz=timezone.utc) def stored_procedure(stored_proc): """decorator to execute remote stored procedure, specified as argument Generally, the body of the decorated function should be empty. If it is not, the stored procedure will be executed first; the function body then. 
""" def wrap(meth): @functools.wraps(meth) def _meth(self, *args, **kwargs): cur = kwargs.get("cur", None) self._cursor(cur).execute("SELECT %s()" % stored_proc) meth(self, *args, **kwargs) return _meth return wrap def jsonize(value): """Convert a value to a psycopg2 JSON object if necessary""" if isinstance(value, dict): return psycopg2.extras.Json(value) return value -def connect_to_conninfo(db_or_conninfo: Union[str, pgconnection]) -> pgconnection: - """Connect to the database passed in argument +@contextmanager +def connect_to_conninfo( + db_or_conninfo: Union[str, pgconnection] +) -> Iterator[pgconnection]: + """Connect to the database passed as argument. Args: db_or_conninfo: A database connection, or a database connection info string Returns: - a connected database handle + a connected database handle or None if the database is not initialized - Raises: - psycopg2.Error if the database doesn't exist """ if isinstance(db_or_conninfo, pgconnection): - return db_or_conninfo - - if "=" not in db_or_conninfo and "//" not in db_or_conninfo: - # Database name - db_or_conninfo = f"dbname={db_or_conninfo}" - - db = psycopg2.connect(db_or_conninfo) + yield db_or_conninfo + else: + if "=" not in db_or_conninfo and "//" not in db_or_conninfo: + # Database name + db_or_conninfo = f"dbname={db_or_conninfo}" - return db + try: + db = psycopg2.connect(db_or_conninfo) + except psycopg2.Error: + logger.exception("Failed to connect to `%s`", db_or_conninfo) + else: + yield db def swh_db_version(db_or_conninfo: Union[str, pgconnection]) -> Optional[int]: """Retrieve the swh version of the database. If the database is not initialized, this logs a warning and returns None. 
Args: db_or_conninfo: A database connection, or a database connection info string Returns: Either the version of the database, or None if it couldn't be detected """ try: - db = connect_to_conninfo(db_or_conninfo) - except psycopg2.Error: - logger.exception("Failed to connect to `%s`", db_or_conninfo) - # Database not initialized - return None - - try: - with db.cursor() as c: - query = "select version from dbversion order by dbversion desc limit 1" - try: - c.execute(query) - result = c.fetchone() - if result: - return result[0] - except psycopg2.errors.UndefinedTable: + with connect_to_conninfo(db_or_conninfo) as db: + if not db: return None + with db.cursor() as c: + query = "select version from dbversion order by dbversion desc limit 1" + try: + c.execute(query) + result = c.fetchone() + if result: + return result[0] + except psycopg2.errors.UndefinedTable: + return None except Exception: logger.exception("Could not get version from `%s`", db_or_conninfo) return None def swh_db_versions( db_or_conninfo: Union[str, pgconnection] ) -> Optional[List[Tuple[int, datetime, str]]]: """Retrieve the swh version history of the database. If the database is not initialized, this logs a warning and returns None. 
Args: db_or_conninfo: A database connection, or a database connection info string Returns: Either the version of the database, or None if it couldn't be detected """ try: - db = connect_to_conninfo(db_or_conninfo) - except psycopg2.Error: - logger.exception("Failed to connect to `%s`", db_or_conninfo) - # Database not initialized - return None - - try: - with db.cursor() as c: - query = ( - "select version, release, description " - "from dbversion order by dbversion desc" - ) - try: - c.execute(query) - return cast(List[Tuple[int, datetime, str]], c.fetchall()) - except psycopg2.errors.UndefinedTable: + with connect_to_conninfo(db_or_conninfo) as db: + if not db: return None + with db.cursor() as c: + query = ( + "select version, release, description " + "from dbversion order by dbversion desc" + ) + try: + c.execute(query) + return cast(List[Tuple[int, datetime, str]], c.fetchall()) + except psycopg2.errors.UndefinedTable: + return None except Exception: logger.exception("Could not get versions from `%s`", db_or_conninfo) return None def swh_db_upgrade( conninfo: str, modname: str, to_version: Optional[int] = None ) -> int: """Upgrade the database at `conninfo` for module `modname` This will run migration scripts found in the `sql/upgrades` subdirectory of the module `modname`. By default, this will upgrade to the latest declared version. 
Args: conninfo: A database connection, or a database connection info string modname: datastore module the database stores content for to_version: if given, update the database to this version rather than the latest """ if to_version is None: to_version = 99999999 db_module, db_version, db_flavor = get_database_info(conninfo) if db_version is None: raise ValueError("Unable to retrieve the current version of the database") if db_module is None: raise ValueError("Unable to retrieve the module of the database") if db_module != modname: raise ValueError( "The stored module of the database is different than the given one" ) sqlfiles = [ fname for fname in get_sql_for_package(modname, upgrade=True) if db_version < int(fname.stem) <= to_version ] if not sqlfiles: return db_version for sqlfile in sqlfiles: new_version = int(path.splitext(path.basename(sqlfile))[0]) logger.info("Executing migration script '%s'", sqlfile) if db_version is not None and (new_version - db_version) > 1: logger.error( f"There are missing migration steps between {db_version} and " f"{new_version}. It might be expected but it most unlikely is not. " "Will stop here." ) return db_version execute_sqlfiles([sqlfile], conninfo, db_flavor) # check if the db version has been updated by the upgrade script db_version = swh_db_version(conninfo) assert db_version is not None if db_version == new_version: # nothing to do, upgrade script did the job pass elif db_version == new_version - 1: # it has not (new style), so do it swh_set_db_version( conninfo, new_version, desc=f"Upgraded to version {new_version} using {sqlfile}", ) db_version = swh_db_version(conninfo) else: # upgrade script did it wrong logger.error( f"The upgrade script {sqlfile} did not update the dbversion table " f"consistently ({db_version} vs. expected {new_version}). " "Will stop migration here. Please check your migration scripts." 
) return db_version return new_version def swh_db_module(db_or_conninfo: Union[str, pgconnection]) -> Optional[str]: """Retrieve the swh module used to create the database. If the database is not initialized, this logs a warning and returns None. Args: db_or_conninfo: A database connection, or a database connection info string Returns: Either the module of the database, or None if it couldn't be detected """ try: - db = connect_to_conninfo(db_or_conninfo) - except psycopg2.Error: - logger.exception("Failed to connect to `%s`", db_or_conninfo) - # Database not initialized - return None - - try: - with db.cursor() as c: - query = "select dbmodule from dbmodule limit 1" - try: - c.execute(query) - resp = c.fetchone() - if resp: - return resp[0] - except psycopg2.errors.UndefinedTable: + with connect_to_conninfo(db_or_conninfo) as db: + if not db: return None + with db.cursor() as c: + query = "select dbmodule from dbmodule limit 1" + try: + c.execute(query) + resp = c.fetchone() + if resp: + return resp[0] + except psycopg2.errors.UndefinedTable: + return None except Exception: logger.exception("Could not get module from `%s`", db_or_conninfo) return None def swh_set_db_module( db_or_conninfo: Union[str, pgconnection], module: str, force=False ) -> None: """Set the swh module used to create the database. Fails if the dbmodule is already set or the table does not exist. 
Args: db_or_conninfo: A database connection, or a database connection info string module: the swh module to register (without the leading 'swh.') """ update = False if module.startswith("swh."): module = module[4:] current_module = swh_db_module(db_or_conninfo) if current_module is not None: if current_module == module: logger.warning("The database module is already set to %s", module) return if not force: raise ValueError( "The database module is already set to a value %s " "different than given %s", current_module, module, ) # force is True update = True - try: - db = connect_to_conninfo(db_or_conninfo) - except psycopg2.Error: - logger.exception("Failed to connect to `%s`", db_or_conninfo) - # Database not initialized - return None - sqlfiles = [ - fname - for fname in get_sql_for_package("swh.core.db") - if "dbmodule" in fname.stem - ] - execute_sqlfiles(sqlfiles, db_or_conninfo) + with connect_to_conninfo(db_or_conninfo) as db: + if not db: + return None - with db.cursor() as c: - if update: - query = "update dbmodule set dbmodule = %s" - else: - query = "insert into dbmodule(dbmodule) values (%s)" - c.execute(query, (module,)) - db.commit() + sqlfiles = [ + fname + for fname in get_sql_for_package("swh.core.db") + if "dbmodule" in fname.stem + ] + execute_sqlfiles(sqlfiles, db_or_conninfo) + + with db.cursor() as c: + if update: + query = "update dbmodule set dbmodule = %s" + else: + query = "insert into dbmodule(dbmodule) values (%s)" + c.execute(query, (module,)) + db.commit() def swh_set_db_version( db_or_conninfo: Union[str, pgconnection], version: int, ts: Optional[datetime] = None, desc: str = "Work in progress", ) -> None: """Set the version of the database. Fails if the dbversion table does not exists. 
Args: db_or_conninfo: A database connection, or a database connection info string version: the version to add """ - try: - db = connect_to_conninfo(db_or_conninfo) - except psycopg2.Error: - logger.exception("Failed to connect to `%s`", db_or_conninfo) - # Database not initialized - return None if ts is None: ts = now() - with db.cursor() as c: - query = ( - "insert into dbversion(version, release, description) values (%s, %s, %s)" - ) - c.execute(query, (version, ts, desc)) - db.commit() + + with connect_to_conninfo(db_or_conninfo) as db: + if not db: + return None + with db.cursor() as c: + query = ( + "insert into dbversion(version, release, description) " + "values (%s, %s, %s)" + ) + c.execute(query, (version, ts, desc)) + db.commit() def swh_db_flavor(db_or_conninfo: Union[str, pgconnection]) -> Optional[str]: """Retrieve the swh flavor of the database. If the database is not initialized, or the database doesn't support flavors, this returns None. Args: db_or_conninfo: A database connection, or a database connection info string Returns: The flavor of the database, or None if it could not be detected. 
""" try: - db = connect_to_conninfo(db_or_conninfo) - except psycopg2.Error: - logger.exception("Failed to connect to `%s`", db_or_conninfo) - # Database not initialized - return None - - try: - with db.cursor() as c: - query = "select swh_get_dbflavor()" - try: - c.execute(query) - result = c.fetchone() - assert result is not None # to keep mypy happy - return result[0] - except psycopg2.errors.UndefinedFunction: - # function not found: no flavor + with connect_to_conninfo(db_or_conninfo) as db: + if not db: return None + with db.cursor() as c: + query = "select swh_get_dbflavor()" + try: + c.execute(query) + result = c.fetchone() + assert result is not None # to keep mypy happy + return result[0] + except psycopg2.errors.UndefinedFunction: + # function not found: no flavor + return None except Exception: logger.exception("Could not get flavor from `%s`", db_or_conninfo) return None # The following code has been imported from psycopg2, version 2.7.4, # https://github.com/psycopg/psycopg2/tree/5afb2ce803debea9533e293eef73c92ffce95bcd # and modified by Software Heritage. # # Original file: lib/extras.py # # psycopg2 is free software: you can redistribute it and/or modify it under the # terms of the GNU Lesser General Public License as published by the Free # Software Foundation, either version 3 of the License, or (at your option) any # later version. def _paginate(seq, page_size): """Consume an iterable and return it in chunks. Every chunk is at most `page_size`. Never return an empty chunk. """ page = [] it = iter(seq) while 1: try: for i in range(page_size): page.append(next(it)) yield page page = [] except StopIteration: if page: yield page return def _split_sql(sql): """Split *sql* on a single ``%s`` placeholder. Split on the %s, perform %% replacement and return pre, post lists of snippets. 
""" curr = pre = [] post = [] tokens = re.split(rb"(%.)", sql) for token in tokens: if len(token) != 2 or token[:1] != b"%": curr.append(token) continue if token[1:] == b"s": if curr is pre: curr = post else: raise ValueError("the query contains more than one '%s' placeholder") elif token[1:] == b"%": curr.append(b"%") else: raise ValueError( "unsupported format character: '%s'" % token[1:].decode("ascii", "replace") ) if curr is pre: raise ValueError("the query doesn't contain any '%s' placeholder") return pre, post def execute_values_generator(cur, sql, argslist, template=None, page_size=100): """Execute a statement using SQL ``VALUES`` with a sequence of parameters. Rows returned by the query are returned through a generator. You need to consume the generator for the queries to be executed! :param cur: the cursor to use to execute the query. :param sql: the query to execute. It must contain a single ``%s`` placeholder, which will be replaced by a `VALUES list`__. Example: ``"INSERT INTO mytable (id, f1, f2) VALUES %s"``. :param argslist: sequence of sequences or dictionaries with the arguments to send to the query. The type and content must be consistent with *template*. :param template: the snippet to merge to every item in *argslist* to compose the query. - If the *argslist* items are sequences it should contain positional placeholders (e.g. ``"(%s, %s, %s)"``, or ``"(%s, %s, 42)``" if there are constants value...). - If the *argslist* items are mappings it should contain named placeholders (e.g. ``"(%(id)s, %(f1)s, 42)"``). If not specified, assume the arguments are sequence and use a simple positional template (i.e. ``(%s, %s, ...)``), with the number of placeholders sniffed by the first element in *argslist*. :param page_size: maximum number of *argslist* items to include in every statement. If there are more items the function will execute more than one statement. :param yield_from_cur: Whether to yield results from the cursor in this function directly. .. 
__: https://www.postgresql.org/docs/current/static/queries-values.html After the execution of the function the `cursor.rowcount` property will **not** contain a total result. """ # we can't just use sql % vals because vals is bytes: if sql is bytes # there will be some decoding error because of stupid codec used, and Py3 # doesn't implement % on bytes. if not isinstance(sql, bytes): sql = sql.encode(pgencodings[cur.connection.encoding]) pre, post = _split_sql(sql) for page in _paginate(argslist, page_size=page_size): if template is None: template = b"(" + b",".join([b"%s"] * len(page[0])) + b")" parts = pre[:] for args in page: parts.append(cur.mogrify(template, args)) parts.append(b",") parts[-1:] = post cur.execute(b"".join(parts)) yield from cur def import_swhmodule(modname): if not modname.startswith("swh."): modname = f"swh.{modname}" try: m = import_module(modname) except ImportError as exc: logger.error(f"Could not load the {modname} module: {exc}") return None return m def get_sql_for_package(modname: str, upgrade: bool = False) -> List[pathlib.Path]: """Return the (sorted) list of sql script files for the given swh module If upgrade is True, return the list of available migration scripts, otherwise, return the list of initialization scripts. """ m = import_swhmodule(modname) if m is None: raise ValueError(f"Module {modname} cannot be loaded") sqldir = pathlib.Path(m.__file__).parent / "sql" if upgrade: sqldir /= "upgrades" if not sqldir.is_dir(): raise ValueError( "Module {} does not provide a db schema (no sql/ dir)".format(modname) ) return sorted(sqldir.glob("*.sql"), key=lambda x: sortkey(x.name)) def populate_database_for_package( modname: str, conninfo: str, flavor: Optional[str] = None ) -> Tuple[bool, Optional[int], Optional[str]]: """Populate the database, pointed at with ``conninfo``, using the SQL files found in the package ``modname``. Also fill the 'dbmodule' table with the given ``modname``. 
Args: modname: Name of the module of which we're loading the files conninfo: connection info string for the SQL database flavor: the module-specific flavor which we want to initialize the database under Returns: Tuple with three elements: whether the database has been initialized; the current version of the database; if it exists, the flavor of the database. """ current_version = swh_db_version(conninfo) if current_version is not None: dbflavor = swh_db_flavor(conninfo) return False, current_version, dbflavor def globalsortkey(key): "like sortkey but only on basenames" return sortkey(path.basename(key)) sqlfiles = get_sql_for_package(modname) + get_sql_for_package("swh.core.db") sqlfiles = sorted(sqlfiles, key=lambda x: sortkey(x.stem)) sqlfiles = [fpath for fpath in sqlfiles if "-superuser-" not in fpath.stem] execute_sqlfiles(sqlfiles, conninfo, flavor) # populate the dbmodule table swh_set_db_module(conninfo, modname) current_db_version = swh_db_version(conninfo) dbflavor = swh_db_flavor(conninfo) return True, current_db_version, dbflavor def get_database_info( conninfo: str, ) -> Tuple[Optional[str], Optional[int], Optional[str]]: """Get version, flavor and module of the db""" dbmodule = swh_db_module(conninfo) dbversion = swh_db_version(conninfo) dbflavor = None if dbversion is not None: dbflavor = swh_db_flavor(conninfo) return (dbmodule, dbversion, dbflavor) def parse_dsn_or_dbname(dsn_or_dbname: str) -> Dict[str, str]: """Parse a psycopg2 dsn, falling back to supporting plain database names as well""" try: return _parse_dsn(dsn_or_dbname) except psycopg2.ProgrammingError: # psycopg2 failed to parse the DSN; it's probably a database name, # handle it as such return _parse_dsn(f"dbname={dsn_or_dbname}") def init_admin_extensions(modname: str, conninfo: str) -> None: """The remaining initialization process -- running -superuser- SQL files -- is done using the given conninfo, thus connecting to the newly created database """ sqlfiles = 
get_sql_for_package(modname) sqlfiles = [fname for fname in sqlfiles if "-superuser-" in fname.stem] execute_sqlfiles(sqlfiles, conninfo) def create_database_for_package( modname: str, conninfo: str, template: str = "template1" ): """Create the database pointed at with ``conninfo``, and initialize it using -superuser- SQL files found in the package ``modname``. Args: modname: Name of the module of which we're loading the files conninfo: connection info string or plain database name for the SQL database template: the name of the database to connect to and use as template to create the new database """ # Use the given conninfo string, but with dbname replaced by the template dbname # for the database creation step creation_dsn = parse_dsn_or_dbname(conninfo) dbname = creation_dsn["dbname"] creation_dsn["dbname"] = template logger.debug("db_create dbname=%s (from %s)", dbname, template) subprocess.check_call( [ "psql", "--quiet", "--no-psqlrc", "-v", "ON_ERROR_STOP=1", "-d", make_dsn(**creation_dsn), "-c", f'CREATE DATABASE "{dbname}"', ] ) init_admin_extensions(modname, conninfo) def execute_sqlfiles( sqlfiles: Collection[pathlib.Path], db_or_conninfo: Union[str, pgconnection], flavor: Optional[str] = None, ): """Execute a list of SQL files on the database pointed at with ``db_or_conninfo``. 
Args: sqlfiles: List of SQL files to execute db_or_conninfo: A database connection, or a database connection info string flavor: the database flavor to initialize """ if isinstance(db_or_conninfo, str): conninfo = db_or_conninfo else: conninfo = db_or_conninfo.dsn psql_command = [ "psql", "--quiet", "--no-psqlrc", "-v", "ON_ERROR_STOP=1", "-d", conninfo, ] flavor_set = False for sqlfile in sqlfiles: logger.debug(f"execute SQL file {sqlfile} dbname={conninfo}") subprocess.check_call(psql_command + ["-f", str(sqlfile)]) if ( flavor is not None and not flavor_set and sqlfile.name.endswith("-flavor.sql") ): logger.debug("Setting database flavor %s", flavor) query = f"insert into dbflavor (flavor) values ('{flavor}')" subprocess.check_call(psql_command + ["-c", query]) flavor_set = True if flavor is not None and not flavor_set: logger.warn( "Asked for flavor %s, but module does not support database flavors", flavor, ) diff --git a/swh/core/db/tests/test_db.py b/swh/core/db/tests/test_db.py index 13b12cd..726f1a1 100644 --- a/swh/core/db/tests/test_db.py +++ b/swh/core/db/tests/test_db.py @@ -1,466 +1,466 @@ # Copyright (C) 2019-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from dataclasses import dataclass import datetime from enum import IntEnum import inspect from string import printable from typing import Any from unittest.mock import MagicMock, Mock import uuid from hypothesis import given, settings, strategies from hypothesis.extra.pytz import timezones import psycopg2 import pytest from typing_extensions import Protocol from swh.core.db import BaseDb from swh.core.db.common import db_transaction, db_transaction_generator from swh.core.db.pytest_plugin import postgresql_fact from swh.core.db.tests.conftest import function_scoped_fixture_check # workaround mypy bug 
https://github.com/python/mypy/issues/5485 class Converter(Protocol): def __call__(self, x: Any) -> Any: ... @dataclass class Field: name: str """Column name""" pg_type: str """Type of the PostgreSQL column""" example: Any """Example value for the static tests""" strategy: strategies.SearchStrategy """Hypothesis strategy to generate these values""" in_wrapper: Converter = lambda x: x """Wrapper to convert this data type for the static tests""" out_converter: Converter = lambda x: x """Converter from the raw PostgreSQL column value to this data type""" # Limit PostgreSQL integer values pg_int = strategies.integers(-2147483648, +2147483647) pg_text = strategies.text( alphabet=strategies.characters( blacklist_categories=["Cs"], # surrogates blacklist_characters=[ "\x00", # pgsql does not support the null codepoint "\r", # pgsql normalizes those ], ), ) pg_bytea = strategies.binary() def pg_bytea_a(min_size: int, max_size: int) -> strategies.SearchStrategy: """Generate a PostgreSQL bytea[]""" return strategies.lists(pg_bytea, min_size=min_size, max_size=max_size) def pg_bytea_a_a(min_size: int, max_size: int) -> strategies.SearchStrategy: """Generate a PostgreSQL bytea[][]. The inner lists must all have the same size.""" return strategies.integers(min_value=max(1, min_size), max_value=max_size).flatmap( lambda n: strategies.lists( pg_bytea_a(min_size=n, max_size=n), min_size=min_size, max_size=max_size ) ) def pg_tstz() -> strategies.SearchStrategy: """Generate values that fit in a PostgreSQL timestamptz. Notes: We're forbidding old datetimes, because until 1956, many timezones had seconds in their "UTC offsets" (see ), which is not representable by PostgreSQL. 
""" min_value = datetime.datetime(1960, 1, 1, 0, 0, 0) return strategies.datetimes(min_value=min_value, timezones=timezones()) def pg_jsonb(min_size: int, max_size: int) -> strategies.SearchStrategy: """Generate values representable as a PostgreSQL jsonb object (dict).""" return strategies.dictionaries( strategies.text(printable), strategies.recursive( # should use floats() instead of integers(), but PostgreSQL # coerces large integers into floats, making the tests fail. We # only store ints in our generated data anyway. strategies.none() | strategies.booleans() | strategies.integers(-2147483648, +2147483647) | strategies.text(printable), lambda children: strategies.lists(children, max_size=max_size) | strategies.dictionaries( strategies.text(printable), children, max_size=max_size ), ), min_size=min_size, max_size=max_size, ) def tuple_2d_to_list_2d(v): """Convert a 2D tuple to a 2D list""" return [list(inner) for inner in v] def list_2d_to_tuple_2d(v): """Convert a 2D list to a 2D tuple""" return tuple(tuple(inner) for inner in v) class TestIntEnum(IntEnum): foo = 1 bar = 2 def now(): return datetime.datetime.now(tz=datetime.timezone.utc) FIELDS = ( Field("i", "int", 1, pg_int), Field("txt", "text", "foo", pg_text), Field("bytes", "bytea", b"bar", strategies.binary()), Field( "bytes_array", "bytea[]", [b"baz1", b"baz2"], pg_bytea_a(min_size=0, max_size=5), ), Field( "bytes_tuple", "bytea[]", (b"baz1", b"baz2"), pg_bytea_a(min_size=0, max_size=5).map(tuple), in_wrapper=list, out_converter=tuple, ), Field( "bytes_2d", "bytea[][]", [[b"quux1"], [b"quux2"]], pg_bytea_a_a(min_size=0, max_size=5), ), Field( "bytes_2d_tuple", "bytea[][]", ((b"quux1",), (b"quux2",)), pg_bytea_a_a(min_size=0, max_size=5).map(list_2d_to_tuple_2d), in_wrapper=tuple_2d_to_list_2d, out_converter=list_2d_to_tuple_2d, ), Field( "ts", "timestamptz", now(), pg_tstz(), ), Field( "dict", "jsonb", {"str": "bar", "int": 1, "list": ["a", "b"], "nested": {"a": "b"}}, pg_jsonb(min_size=0, max_size=5), 
in_wrapper=psycopg2.extras.Json, ), Field( "intenum", "int", TestIntEnum.foo, strategies.sampled_from(TestIntEnum), in_wrapper=int, out_converter=lambda x: TestIntEnum(x), # lambda needed by mypy ), Field("uuid", "uuid", uuid.uuid4(), strategies.uuids()), Field( "text_list", "text[]", # All the funky corner cases ["null", "NULL", None, "\\", "\t", "\n", "\r", " ", "'", ",", '"', "{", "}"], strategies.lists(pg_text, min_size=0, max_size=5), ), Field( "tstz_list", "timestamptz[]", [now(), now() + datetime.timedelta(days=1)], strategies.lists(pg_tstz(), min_size=0, max_size=5), ), Field( "tstz_range", "tstzrange", psycopg2.extras.DateTimeTZRange( lower=now(), upper=now() + datetime.timedelta(days=1), bounds="[)", ), strategies.tuples( # generate two sorted timestamptzs for use as bounds strategies.tuples(pg_tstz(), pg_tstz()).map(sorted), # and a set of bounds strategies.sampled_from(["[]", "()", "[)", "(]"]), ).map( # and build the actual DateTimeTZRange object from these args lambda args: psycopg2.extras.DateTimeTZRange( lower=args[0][0], upper=args[0][1], bounds=args[1], ) ), ), ) INIT_SQL = "create table test_table (%s)" % ", ".join( f"{field.name} {field.pg_type}" for field in FIELDS ) COLUMNS = tuple(field.name for field in FIELDS) INSERT_SQL = "insert into test_table (%s) values (%s)" % ( ", ".join(COLUMNS), ", ".join("%s" for i in range(len(COLUMNS))), ) STATIC_ROW_IN = tuple(field.in_wrapper(field.example) for field in FIELDS) EXPECTED_ROW_OUT = tuple(field.example for field in FIELDS) db_rows = strategies.lists(strategies.tuples(*(field.strategy for field in FIELDS))) def convert_lines(cur): return [ tuple(field.out_converter(x) for x, field in zip(line, FIELDS)) for line in cur ] test_db = postgresql_fact("postgresql_proc", dbname="test-db2") @pytest.fixture def db_with_data(test_db, request): """Fixture to initialize a db with some data out of the "INIT_SQL above""" db = BaseDb.connect(test_db.dsn) with db.cursor() as cur: 
psycopg2.extras.register_default_jsonb(cur) cur.execute(INIT_SQL) yield db db.conn.rollback() db.conn.close() @pytest.mark.db def test_db_connect(db_with_data): with db_with_data.cursor() as cur: psycopg2.extras.register_default_jsonb(cur) cur.execute(INSERT_SQL, STATIC_ROW_IN) cur.execute("select * from test_table;") output = convert_lines(cur) assert len(output) == 1 assert EXPECTED_ROW_OUT == output[0] def test_db_initialized(db_with_data): with db_with_data.cursor() as cur: psycopg2.extras.register_default_jsonb(cur) cur.execute(INSERT_SQL, STATIC_ROW_IN) cur.execute("select * from test_table;") output = convert_lines(cur) assert len(output) == 1 assert EXPECTED_ROW_OUT == output[0] def test_db_copy_to_static(db_with_data): items = [{field.name: field.example for field in FIELDS}] db_with_data.copy_to(items, "test_table", COLUMNS) with db_with_data.cursor() as cur: cur.execute("select * from test_table;") output = convert_lines(cur) assert len(output) == 1 assert EXPECTED_ROW_OUT == output[0] -@settings(suppress_health_check=function_scoped_fixture_check) +@settings(suppress_health_check=function_scoped_fixture_check, max_examples=5) @given(db_rows) def test_db_copy_to(db_with_data, data): items = [dict(zip(COLUMNS, item)) for item in data] with db_with_data.cursor() as cur: cur.execute("TRUNCATE TABLE test_table CASCADE") db_with_data.copy_to(items, "test_table", COLUMNS) with db_with_data.cursor() as cur: cur.execute("select * from test_table;") converted_lines = convert_lines(cur) assert converted_lines == data def test_db_copy_to_thread_exception(db_with_data): data = [(2**65, "foo", b"bar")] items = [dict(zip(COLUMNS, item)) for item in data] with pytest.raises(psycopg2.errors.NumericValueOutOfRange): db_with_data.copy_to(items, "test_table", COLUMNS) def test_db_transaction(mocker): expected_cur = object() called = False class Storage: @db_transaction() def endpoint(self, cur=None, db=None): nonlocal called called = True assert cur is expected_cur storage 
= Storage() # 'with storage.get_db().transaction() as cur:' should cause # 'cur' to be 'expected_cur' db_mock = Mock() db_mock.transaction.return_value = MagicMock() db_mock.transaction.return_value.__enter__.return_value = expected_cur mocker.patch.object(storage, "get_db", return_value=db_mock, create=True) put_db_mock = mocker.patch.object(storage, "put_db", create=True) storage.endpoint() assert called put_db_mock.assert_called_once_with(db_mock) def test_db_transaction__with_generator(): with pytest.raises(ValueError, match="generator"): class Storage: @db_transaction() def endpoint(self, cur=None, db=None): yield None def test_db_transaction_signature(): """Checks db_transaction removes the 'cur' and 'db' arguments.""" def f(self, foo, *, bar=None): pass expected_sig = inspect.signature(f) @db_transaction() def g(self, foo, *, bar=None, db=None, cur=None): pass actual_sig = inspect.signature(g) assert actual_sig == expected_sig def test_db_transaction_generator(mocker): expected_cur = object() called = False class Storage: @db_transaction_generator() def endpoint(self, cur=None, db=None): nonlocal called called = True assert cur is expected_cur yield None storage = Storage() # 'with storage.get_db().transaction() as cur:' should cause # 'cur' to be 'expected_cur' db_mock = Mock() db_mock.transaction.return_value = MagicMock() db_mock.transaction.return_value.__enter__.return_value = expected_cur mocker.patch.object(storage, "get_db", return_value=db_mock, create=True) put_db_mock = mocker.patch.object(storage, "put_db", create=True) list(storage.endpoint()) assert called put_db_mock.assert_called_once_with(db_mock) def test_db_transaction_generator__with_nongenerator(): with pytest.raises(ValueError, match="generator"): class Storage: @db_transaction_generator() def endpoint(self, cur=None, db=None): pass def test_db_transaction_generator_signature(): """Checks db_transaction removes the 'cur' and 'db' arguments.""" def f(self, foo, *, bar=None): pass 
expected_sig = inspect.signature(f) @db_transaction_generator() def g(self, foo, *, bar=None, db=None, cur=None): yield None actual_sig = inspect.signature(g) assert actual_sig == expected_sig @pytest.mark.parametrize( "query_options", (None, {"something": 42, "statement_timeout": 200}) ) @pytest.mark.parametrize("use_generator", (True, False)) def test_db_transaction_query_options(mocker, use_generator, query_options): class Storage: @db_transaction(statement_timeout=100) def endpoint(self, cur=None, db=None): return [None] @db_transaction_generator(statement_timeout=100) def gen_endpoint(self, cur=None, db=None): yield None storage = Storage() # mockers mocked_apply = mocker.patch("swh.core.db.common.apply_options") # 'with storage.get_db().transaction() as cur:' should cause # 'cur' to be 'expected_cur' expected_cur = object() db_mock = MagicMock() db_mock.transaction.return_value.__enter__.return_value = expected_cur mocker.patch.object(storage, "get_db", return_value=db_mock, create=True) mocker.patch.object(storage, "put_db", create=True) if query_options: storage.query_options = { "endpoint": query_options, "gen_endpoint": query_options, } if use_generator: list(storage.gen_endpoint()) else: list(storage.endpoint()) mocked_apply.assert_called_once_with( expected_cur, query_options if query_options is not None else {"statement_timeout": 100}, ) diff --git a/swh/core/github/__init__.py b/swh/core/github/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/swh/core/github/pytest_plugin.py b/swh/core/github/pytest_plugin.py new file mode 100644 index 0000000..20c5e80 --- /dev/null +++ b/swh/core/github/pytest_plugin.py @@ -0,0 +1,184 @@ +# Copyright (C) 2020-2022 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + +import time +from typing import Dict, Iterator, List, 
Optional, Union + +import pytest +import requests_mock + +HTTP_GITHUB_API_URL = "https://api.github.com/repositories" + + +def fake_time_sleep(duration: float, sleep_calls: Optional[List[float]] = None): + """Record calls to time.sleep in the sleep_calls list.""" + if duration < 0: + raise ValueError("Can't sleep for a negative amount of time!") + if sleep_calls is not None: + sleep_calls.append(duration) + + +def fake_time_time(): + """Return 0 when running time.time()""" + return 0 + + +@pytest.fixture +def monkeypatch_sleep_calls(monkeypatch) -> Iterator[List[float]]: + """Monkeypatch `time.time` and `time.sleep`. Returns a list cumulating the arguments + passed to time.sleep().""" + sleeps: List[float] = [] + monkeypatch.setattr(time, "sleep", lambda d: fake_time_sleep(d, sleeps)) + monkeypatch.setattr(time, "time", fake_time_time) + yield sleeps + + +@pytest.fixture() +def num_before_ratelimit() -> int: + """Number of successful requests before the ratelimit hits""" + return 0 + + +@pytest.fixture() +def num_ratelimit() -> Optional[int]: + """Number of rate-limited requests; None means infinity""" + return None + + +@pytest.fixture() +def ratelimit_reset() -> Optional[int]: + """Value of the X-Ratelimit-Reset header on ratelimited responses""" + return None + + +def github_ratelimit_callback( + request: requests_mock.request._RequestObjectProxy, + context: requests_mock.response._Context, + ratelimit_reset: Optional[int], +) -> Dict[str, str]: + """Return a rate-limited GitHub API response.""" + # Check request headers + assert request.headers["Accept"] == "application/vnd.github.v3+json" + assert request.headers["User-Agent"] is not None + if "Authorization" in request.headers: + context.status_code = 429 + else: + context.status_code = 403 + + if ratelimit_reset is not None: + context.headers["X-Ratelimit-Reset"] = str(ratelimit_reset) + + return { + "message": "API rate limit exceeded for .", + "documentation_url": 
"https://developer.github.com/v3/#rate-limiting", + } + + +def github_repo(i: int) -> Dict[str, Union[int, str]]: + """Basic repository information returned by the GitHub API""" + + repo: Dict[str, Union[int, str]] = { + "id": i, + "html_url": f"https://github.com/origin/{i}", + } + + # Set the pushed_at date on one of the origins + if i == 4321: + repo["pushed_at"] = "2018-11-08T13:16:24Z" + + return repo + + +def github_response_callback( + request: requests_mock.request._RequestObjectProxy, + context: requests_mock.response._Context, + page_size: int = 1000, + origin_count: int = 10000, +) -> List[Dict[str, Union[str, int]]]: + """Return minimal GitHub API responses for the common case where the loader + hasn't been rate-limited""" + # Check request headers + assert request.headers["Accept"] == "application/vnd.github.v3+json" + assert request.headers["User-Agent"] is not None + + # Check request parameters: per_page == 1000, since = last_repo_id + assert "per_page" in request.qs + assert request.qs["per_page"] == [str(page_size)] + assert "since" in request.qs + + since = int(request.qs["since"][0]) + + next_page = since + page_size + if next_page < origin_count: + # the first id for the next page is within our origin count; add a Link + # header to the response + next_url = f"{HTTP_GITHUB_API_URL}?per_page={page_size}&since={next_page}" + context.headers["Link"] = f"<{next_url}>; rel=next" + + return [github_repo(i) for i in range(since + 1, min(next_page, origin_count) + 1)] + + +@pytest.fixture() +def requests_ratelimited( + num_before_ratelimit: int, + num_ratelimit: Optional[int], + ratelimit_reset: Optional[int], +) -> Iterator[requests_mock.Mocker]: + """Mock requests to the GitHub API, returning a rate-limiting status code after + `num_before_ratelimit` requests. + + GitHub does inconsistent rate-limiting: + + - Anonymous requests return a 403 status code + - Authenticated requests return a 429 status code, with an X-Ratelimit-Reset header. 
+ + This fixture takes multiple arguments (which can be overridden with a + :func:`pytest.mark.parametrize` parameter): + + - num_before_ratelimit: the global number of requests until the ratelimit triggers + - num_ratelimit: the number of requests that return a rate-limited response. + - ratelimit_reset: the timestamp returned in X-Ratelimit-Reset if the request is + authenticated. + + The default values set in the previous fixtures make all requests return a rate + limit response. + + """ + current_request = 0 + + def response_callback(request, context): + nonlocal current_request + current_request += 1 + if num_before_ratelimit < current_request and ( + num_ratelimit is None + or current_request < num_before_ratelimit + num_ratelimit + 1 + ): + return github_ratelimit_callback(request, context, ratelimit_reset) + else: + return github_response_callback(request, context) + + with requests_mock.Mocker() as mock: + mock.get(HTTP_GITHUB_API_URL, json=response_callback) + yield mock + + +@pytest.fixture +def github_credentials() -> List[Dict[str, str]]: + """Return a static list of GitHub credentials""" + return sorted( + [{"username": f"swh{i:d}", "token": f"token-{i:d}"} for i in range(3)] + + [ + {"username": f"swh-legacy{i:d}", "password": f"token-legacy-{i:d}"} + for i in range(3) + ], + key=lambda c: c["username"], + ) + + +@pytest.fixture +def all_tokens(github_credentials) -> List[str]: + """Return the list of tokens matching the static credential""" + + return [t.get("token", t.get("password")) for t in github_credentials] diff --git a/swh/core/github/tests/__init__.py b/swh/core/github/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/swh/core/github/tests/test_github_utils.py b/swh/core/github/tests/test_github_utils.py new file mode 100644 index 0000000..da8bf7b --- /dev/null +++ b/swh/core/github/tests/test_github_utils.py @@ -0,0 +1,160 @@ +# Copyright (C) 2022 The Software Heritage developers +# See the AUTHORS file at the 
top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + +import logging + +import pytest + +from swh.core.github.pytest_plugin import HTTP_GITHUB_API_URL +from swh.core.github.utils import ( + GitHubSession, + _sanitize_github_url, + _url_github_api, + _url_github_html, + get_canonical_github_origin_url, +) + +KNOWN_GH_REPO = "https://github.com/user/repo" + + +@pytest.mark.parametrize( + "user_repo, expected_url", + [ + ("user/repo.git", KNOWN_GH_REPO), + ("user/repo.git/", KNOWN_GH_REPO), + ("user/repo/", KNOWN_GH_REPO), + ("user/repo", KNOWN_GH_REPO), + ("user/repo/.git", KNOWN_GH_REPO), + # edge cases + ("https://github.com/unknown-page", None), # unknown gh origin returns None + ("user/repo/with/some/deps", None), # url kind is not dealt with for now + ], +) +def test_get_canonical_github_origin_url(user_repo, expected_url, requests_mock): + """It should return a canonical github origin when it exists, None otherwise""" + html_url = _url_github_html(user_repo) + api_url = _url_github_api(_sanitize_github_url(user_repo)) + + if expected_url is not None: + status_code = 200 + response = {"html_url": _sanitize_github_url(html_url)} + else: + status_code = 404 + response = {} + + requests_mock.get(api_url, [{"status_code": status_code, "json": response}]) + + assert get_canonical_github_origin_url(html_url) == expected_url + + +def test_get_canonical_github_origin_url_not_gh_origin(): + """It should return the input url when that origin is not a github one""" + url = "https://example.org" + assert get_canonical_github_origin_url(url) == url + + +def test_github_session_anonymous_session(): + user_agent = ("GitHub Session Test",) + github_session = GitHubSession( + user_agent=user_agent, + ) + assert github_session.anonymous is True + + actual_headers = github_session.session.headers + assert actual_headers["Accept"] == "application/vnd.github.v3+json" + 
assert actual_headers["User-Agent"] == user_agent + + +@pytest.mark.parametrize( + "num_ratelimit", [1] # return a single rate-limit response, then continue +) +def test_github_session_ratelimit_once_recovery( + caplog, + requests_ratelimited, + num_ratelimit, + monkeypatch_sleep_calls, + github_credentials, +): + """GitHubSession should recover from hitting the rate-limit once""" + caplog.set_level(logging.DEBUG, "swh.core.github.utils") + + github_session = GitHubSession( + user_agent="GitHub Session Test", credentials=github_credentials + ) + + res = github_session.request(f"{HTTP_GITHUB_API_URL}?per_page=1000&since=10") + assert res.status_code == 200 + + token_users = [] + for record in caplog.records: + if "Using authentication token" in record.message: + token_users.append(record.args[0]) + + # check that we used one more token than we saw rate limited requests + assert len(token_users) == 1 + num_ratelimit + + # check that we slept for one second between our token uses + assert monkeypatch_sleep_calls == [1] + + +def test_github_session_authenticated_credentials( + caplog, github_credentials, all_tokens +): + """GitHubSession should have Authorization headers set in authenticated mode""" + caplog.set_level(logging.DEBUG, "swh.core.github.utils") + + github_session = GitHubSession( + "GitHub Session Test", credentials=github_credentials + ) + + assert github_session.anonymous is False + assert github_session.token_index == 0 + assert ( + sorted(github_session.credentials, key=lambda t: t["username"]) + == github_credentials + ) + assert github_session.session.headers["Authorization"] in [ + f"token {t}" for t in all_tokens + ] + + +@pytest.mark.parametrize( + # Do 5 successful requests, return 6 ratelimits (to exhaust the credentials) with a + # set value for X-Ratelimit-Reset, then resume listing successfully. 
+ "num_before_ratelimit, num_ratelimit, ratelimit_reset", + [(5, 6, 123456)], +) +def test_github_session_ratelimit_reset_sleep( + caplog, + requests_ratelimited, + monkeypatch_sleep_calls, + num_before_ratelimit, + num_ratelimit, + ratelimit_reset, + github_credentials, +): + """GitHubSession should handle rate-limit with authentication tokens.""" + caplog.set_level(logging.DEBUG, "swh.core.github.utils") + + github_session = GitHubSession( + user_agent="GitHub Session Test", credentials=github_credentials + ) + + for _ in range(num_ratelimit): + github_session.request(f"{HTTP_GITHUB_API_URL}?per_page=1000&since=10") + + # We sleep 1 second every time we change credentials, then we sleep until + # ratelimit_reset + 1 + expected_sleep_calls = len(github_credentials) * [1] + [ratelimit_reset + 1] + assert monkeypatch_sleep_calls == expected_sleep_calls + + found_exhaustion_message = False + for record in caplog.records: + if record.levelname == "INFO": + if "Rate limits exhausted for all tokens" in record.message: + found_exhaustion_message = True + break + + assert found_exhaustion_message is True diff --git a/swh/core/github/tests/test_pytest_plugin.py b/swh/core/github/tests/test_pytest_plugin.py new file mode 100644 index 0000000..57aa7e3 --- /dev/null +++ b/swh/core/github/tests/test_pytest_plugin.py @@ -0,0 +1,50 @@ +# Copyright (C) 2022 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + +import pytest +import time +from swh.core.github.pytest_plugin import fake_time_sleep, fake_time_time + + +@pytest.mark.parametrize("duration", [10, 20, -1]) +def test_fake_time_sleep(duration): + + if duration < 0: + with pytest.raises(ValueError, match="negative"): + fake_time_sleep(duration, []) + else: + sleep_calls = [] + fake_time_sleep(duration, sleep_calls) + assert duration in sleep_calls + + 
+def test_fake_time_time(): + assert fake_time_time() == 0 + + +def test_monkeypatch_sleep_calls(monkeypatch_sleep_calls): + + sleeps = [10, 20, 30] + for sleep in sleeps: + # This adds the sleep number inside the monkeypatch_sleep_calls fixture + time.sleep(sleep) + assert sleep in monkeypatch_sleep_calls + + assert len(monkeypatch_sleep_calls) == len(sleeps) + # This mocks time but adds nothing to the same fixture + time.time() + assert len(monkeypatch_sleep_calls) == len(sleeps) + + +def test_num_before_ratelimit(num_before_ratelimit): + assert num_before_ratelimit == 0 + + +def test_ratelimit_reset(ratelimit_reset): + assert ratelimit_reset is None + + +def test_num_ratelimit(num_ratelimit): + assert num_ratelimit is None diff --git a/swh/core/github/utils.py b/swh/core/github/utils.py new file mode 100644 index 0000000..c21c8bb --- /dev/null +++ b/swh/core/github/utils.py @@ -0,0 +1,214 @@ +# Copyright (C) 2020-2022 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + + +import logging +import random +import re +import time +from typing import Dict, List, Optional + +import requests +from tenacity import ( + retry, + retry_any, + retry_if_exception_type, + retry_if_result, + wait_exponential, +) + +GITHUB_PATTERN = re.compile(r"https?://github.com/(?P.*)") + + +logger = logging.getLogger(__name__) + + +def _url_github_html(user_repo: str) -> str: + """Given the user repo, returns the expected github html url.""" + return f"https://github.com/{user_repo}" + + +def _url_github_api(user_repo: str) -> str: + """Given the user_repo, returns the expected github api url.""" + return f"https://api.github.com/repos/{user_repo}" + + +def _sanitize_github_url(url: str) -> str: + """Sanitize github url.""" + return url.lower().rstrip("/").rstrip(".git").rstrip("/") + + +def 
get_canonical_github_origin_url(url: str) -> Optional[str]: + """Retrieve canonical github url out of an url if any or None otherwise. + + This triggers an anonymous http request to the github api url to determine the + canonical repository url. + + """ + url_ = url.lower() + + match = GITHUB_PATTERN.match(url_) + if not match: + return url + + user_repo = _sanitize_github_url(match.groupdict()["user_repo"]) + response = requests.get(_url_github_api(user_repo)) + if response.status_code != 200: + return None + data = response.json() + return data["html_url"] + + +class RateLimited(Exception): + def __init__(self, response): + self.reset_time: Optional[int] + + # Figure out how long we need to sleep because of that rate limit + ratelimit_reset = response.headers.get("X-Ratelimit-Reset") + retry_after = response.headers.get("Retry-After") + if ratelimit_reset is not None: + self.reset_time = int(ratelimit_reset) + elif retry_after is not None: + self.reset_time = int(time.time()) + int(retry_after) + 1 + else: + logger.warning( + "Received a rate-limit-like status code %s, but no rate-limit " + "headers set. 
class MissingRateLimitReset(Exception):
    """Raised by :meth:`GitHubSession.request` when a full round of attempts
    was rate-limited and none of the responses gave a reset time to wait for."""


class GitHubSession:
    """Manages a :class:`requests.Session` with (optionally) multiple credentials,
    and cycles through them when reaching rate-limits."""

    # Shuffled copy of the credentials passed to the constructor; stays None
    # when no (or empty) credentials were given, i.e. in anonymous mode.
    credentials: Optional[List[Dict[str, str]]] = None

    def __init__(
        self, user_agent: str, credentials: Optional[List[Dict[str, str]]] = None
    ) -> None:
        """Initialize a requests session with the proper headers for requests to
        GitHub.

        Args:
            user_agent: value sent in the ``User-Agent`` header of every request
            credentials: optional list of dicts, each with a ``username`` key
              and either a ``password`` or a ``token`` key. The list is copied
              and shuffled, so the caller's list is left untouched.
        """
        if credentials:
            # Work on a shuffled copy so the starting token differs between
            # sessions and the caller's list is not reordered.
            creds = credentials.copy()
            random.shuffle(creds)
            self.credentials = creds

        self.session = requests.Session()

        self.session.headers.update(
            {"Accept": "application/vnd.github.v3+json", "User-Agent": user_agent}
        )

        self.anonymous = not self.credentials

        if self.anonymous:
            logger.warning("No tokens set in configuration, using anonymous mode")

        # -1 so that the first call to set_next_session_token() selects index 0
        self.token_index = -1
        self.current_user: Optional[str] = None

        if not self.anonymous:
            # Initialize the first token value in the session headers
            self.set_next_session_token()

    def set_next_session_token(self) -> None:
        """Update the current authentication token with the next one in line.

        Advances ``token_index`` (wrapping around at the end of the credentials
        list), records the matching username in ``current_user`` and sets the
        session's ``Authorization`` header accordingly.
        """
        # Internal invariant: this is only meaningful in authenticated mode.
        assert self.credentials

        self.token_index = (self.token_index + 1) % len(self.credentials)

        auth = self.credentials[self.token_index]

        self.current_user = auth["username"]
        logger.debug("Using authentication token for user %s", self.current_user)

        # The secret may be stored under either key; "password" takes precedence.
        token = auth["password"] if "password" in auth else auth["token"]

        self.session.headers.update({"Authorization": f"token {token}"})

    @retry(
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_any(
            # ChunkedEncodingErrors happen when the TLS connection gets reset, e.g.
            # when running the lister on a connection with high latency
            retry_if_exception_type(requests.exceptions.ChunkedEncodingError),
            # 502 status codes happen for a Server Error, sometimes
            retry_if_result(lambda r: r.status_code == 502),
        ),
    )
    def _request(self, url: str) -> requests.Response:
        """Send a single GET request for ``url`` with the current session headers.

        The call is retried (with exponential wait) on chunked-encoding errors
        and on 502 responses; no stop condition is configured on the retry.

        Raises:
            RateLimited: if the response status is 429, or 403 while anonymous.
        """
        response = self.session.get(url)

        if (
            # GitHub returns inconsistent status codes between unauthenticated
            # rate limit and authenticated rate limits. Handle both.
            response.status_code == 429
            or (self.anonymous and response.status_code == 403)
        ):
            raise RateLimited(response)

        return response

    def request(self, url: str) -> requests.Response:
        """Repeatedly request the given URL, cycling through credentials and
        sleeping if necessary, until a response that is not a rate limit is
        obtained.

        Args:
            url: the URL to GET

        Returns:
            the first response for which :meth:`_request` did not raise
            :exc:`RateLimited` (its status code is not otherwise checked here).

        Raises:
            MissingRateLimitReset: if a whole round of attempts was rate-limited
              and none of the responses provided a reset time.
        """
        # The inner for loop handles rate limiting: it returns as soon as a
        # request goes through, trying each credential at most once per round.
        #
        # If all tokens are rate-limited, we sleep until the latest known reset
        # time, then go through another iteration of the outer while loop,
        # attempting to get data from the same URL again.

        while True:
            max_attempts = len(self.credentials) if self.credentials else 1
            reset_times: Dict[int, int] = {}  # token index -> time
            for _ in range(max_attempts):
                try:
                    return self._request(url)
                except RateLimited as e:
                    reset_info = "(unknown reset)"
                    if e.reset_time is not None:
                        reset_times[self.token_index] = e.reset_time
                        reset_info = "(resetting in %ss)" % (e.reset_time - time.time())

                    if not self.anonymous:
                        logger.info(
                            "Rate limit exhausted for current user %s %s",
                            self.current_user,
                            reset_info,
                        )
                        # Use next token in line
                        self.set_next_session_token()
                        # Wait one second to avoid triggering GitHub's abuse rate limits
                        time.sleep(1)

            # All tokens have been rate-limited. What do we do?

            if not reset_times:
                logger.warning(
                    "No X-Ratelimit-Reset value found in responses for any token; "
                    "Giving up."
                )
                raise MissingRateLimitReset()

            # Clamp to zero: the reset times may already be in the past by the
            # time every token has been tried (each attempt can pause for a
            # second), and time.sleep() raises ValueError on negative values.
            sleep_time = max(0.0, max(reset_times.values()) - time.time() + 1)
            logger.info(
                "Rate limits exhausted for all tokens. Sleeping for %f seconds.",
                sleep_time,
            )
            time.sleep(sleep_time)
commands = make -I ../.tox/sphinx/src/swh-docs/swh/ -C docs # build documentation only inside swh-environment using local state # of swh-docs package [testenv:sphinx-dev] whitelist_externals = make usedevelop = true extras = testing-core logging db http + github deps = # install swh-docs in develop mode -e ../swh-docs setenv = SWH_PACKAGE_DOC_TOX_BUILD = 1 # turn warnings into errors SPHINXOPTS = -W commands = make -I ../.tox/sphinx-dev/src/swh-docs/swh/ -C docs