diff --git a/PKG-INFO b/PKG-INFO index f08d69f..8b8f7a4 100644 --- a/PKG-INFO +++ b/PKG-INFO @@ -1,39 +1,39 @@ Metadata-Version: 2.1 Name: swh.core -Version: 2.14.0 +Version: 2.14.1 Summary: Software Heritage core utilities Home-page: https://forge.softwareheritage.org/diffusion/DCORE/ Author: Software Heritage developers Author-email: swh-devel@inria.fr Project-URL: Bug Reports, https://forge.softwareheritage.org/maniphest Project-URL: Funding, https://www.softwareheritage.org/donate Project-URL: Source, https://forge.softwareheritage.org/source/swh-core Project-URL: Documentation, https://docs.softwareheritage.org/devel/swh-core/ Classifier: Programming Language :: Python :: 3 Classifier: Intended Audience :: Developers Classifier: License :: OSI Approved :: GNU General Public License v3 (GPLv3) Classifier: Operating System :: OS Independent Classifier: Development Status :: 5 - Production/Stable Requires-Python: >=3.7 Description-Content-Type: text/x-rst Provides-Extra: testing-core Provides-Extra: logging Provides-Extra: db Provides-Extra: http Provides-Extra: github Provides-Extra: testing License-File: LICENSE License-File: AUTHORS Software Heritage - Core foundations ==================================== Low-level utilities and helpers used by almost all other modules in the stack. core library for swh's modules: - config parser - serialization - logging mechanism - database connection - http-based RPC client/server diff --git a/swh.core.egg-info/PKG-INFO b/swh.core.egg-info/PKG-INFO index f08d69f..8b8f7a4 100644 --- a/swh.core.egg-info/PKG-INFO +++ b/swh.core.egg-info/PKG-INFO @@ -1,39 +1,39 @@ Metadata-Version: 2.1 Name: swh.core -Version: 2.14.0 +Version: 2.14.1 Summary: Software Heritage core utilities Home-page: https://forge.softwareheritage.org/diffusion/DCORE/ Author: Software Heritage developers Author-email: swh-devel@inria.fr Project-URL: Bug Reports, https://forge.softwareheritage.org/maniphest Project-URL: Funding, https://www.softwareheritage.org/donate Project-URL: Source, https://forge.softwareheritage.org/source/swh-core Project-URL: Documentation, https://docs.softwareheritage.org/devel/swh-core/ Classifier: Programming Language :: Python :: 3 Classifier: Intended Audience :: Developers Classifier: License :: OSI Approved :: GNU General Public License v3 (GPLv3) Classifier: Operating System :: OS Independent Classifier: Development Status :: 5 - Production/Stable Requires-Python: >=3.7 Description-Content-Type: text/x-rst Provides-Extra: testing-core Provides-Extra: logging Provides-Extra: db Provides-Extra: http Provides-Extra: github Provides-Extra: testing License-File: LICENSE License-File: AUTHORS Software Heritage - Core foundations ==================================== Low-level utilities and helpers used by almost all other modules in the stack. core library for swh's modules: - config parser - serialization - logging mechanism - database connection - http-based RPC client/server diff --git a/swh/core/db/db_utils.py b/swh/core/db/db_utils.py index e325434..b738fb8 100644 --- a/swh/core/db/db_utils.py +++ b/swh/core/db/db_utils.py @@ -1,702 +1,702 @@ # Copyright (C) 2015-2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from contextlib import contextmanager from datetime import datetime, timezone import functools from importlib import import_module import logging from os import path import pathlib import re import subprocess from typing import Collection, Dict, Iterator, List, Optional, Tuple, Union, cast import psycopg2 import psycopg2.errors import psycopg2.extensions from psycopg2.extensions import connection as pgconnection from psycopg2.extensions import encodings as pgencodings from psycopg2.extensions import make_dsn from psycopg2.extensions import parse_dsn as _parse_dsn from swh.core.utils import numfile_sortkey as sortkey logger = logging.getLogger(__name__) def now(): return datetime.now(tz=timezone.utc) def stored_procedure(stored_proc): """decorator to execute remote stored procedure, specified as argument Generally, the body of the decorated function should be empty. If it is not, the stored procedure will be executed first; the function body then. """ def wrap(meth): @functools.wraps(meth) def _meth(self, *args, **kwargs): cur = kwargs.get("cur", None) self._cursor(cur).execute("SELECT %s()" % stored_proc) meth(self, *args, **kwargs) return _meth return wrap def jsonize(value): """Convert a value to a psycopg2 JSON object if necessary""" if isinstance(value, dict): return psycopg2.extras.Json(value) return value @contextmanager def connect_to_conninfo( db_or_conninfo: Union[str, pgconnection] ) -> Iterator[pgconnection]: """Connect to the database passed as argument. Args: db_or_conninfo: A database connection, or a database connection info string Returns: a connected database handle or None if the database is not initialized """ if isinstance(db_or_conninfo, pgconnection): yield db_or_conninfo else: if "=" not in db_or_conninfo and "//" not in db_or_conninfo: # Database name db_or_conninfo = f"dbname={db_or_conninfo}" - try: - db = psycopg2.connect(db_or_conninfo) - except psycopg2.Error: - logger.exception("Failed to connect to `%s`", db_or_conninfo) - else: - yield db + try: + db = psycopg2.connect(db_or_conninfo) + except psycopg2.Error: + logger.exception("Failed to connect to `%s`", db_or_conninfo) + else: + yield db def swh_db_version(db_or_conninfo: Union[str, pgconnection]) -> Optional[int]: """Retrieve the swh version of the database. If the database is not initialized, this logs a warning and returns None. Args: db_or_conninfo: A database connection, or a database connection info string Returns: Either the version of the database, or None if it couldn't be detected """ try: with connect_to_conninfo(db_or_conninfo) as db: if not db: return None with db.cursor() as c: query = "select version from dbversion order by dbversion desc limit 1" try: c.execute(query) result = c.fetchone() if result: return result[0] except psycopg2.errors.UndefinedTable: return None except Exception: logger.exception("Could not get version from `%s`", db_or_conninfo) return None def swh_db_versions( db_or_conninfo: Union[str, pgconnection] ) -> Optional[List[Tuple[int, datetime, str]]]: """Retrieve the swh version history of the database. If the database is not initialized, this logs a warning and returns None. Args: db_or_conninfo: A database connection, or a database connection info string Returns: Either the version of the database, or None if it couldn't be detected """ try: with connect_to_conninfo(db_or_conninfo) as db: if not db: return None with db.cursor() as c: query = ( "select version, release, description " "from dbversion order by dbversion desc" ) try: c.execute(query) return cast(List[Tuple[int, datetime, str]], c.fetchall()) except psycopg2.errors.UndefinedTable: return None except Exception: logger.exception("Could not get versions from `%s`", db_or_conninfo) return None def swh_db_upgrade( conninfo: str, modname: str, to_version: Optional[int] = None ) -> int: """Upgrade the database at `conninfo` for module `modname` This will run migration scripts found in the `sql/upgrades` subdirectory of the module `modname`. By default, this will upgrade to the latest declared version. Args: conninfo: A database connection, or a database connection info string modname: datastore module the database stores content for to_version: if given, update the database to this version rather than the latest """ if to_version is None: to_version = 99999999 db_module, db_version, db_flavor = get_database_info(conninfo) if db_version is None: raise ValueError("Unable to retrieve the current version of the database") if db_module is None: raise ValueError("Unable to retrieve the module of the database") if db_module != modname: raise ValueError( "The stored module of the database is different than the given one" ) sqlfiles = [ fname for fname in get_sql_for_package(modname, upgrade=True) if db_version < int(fname.stem) <= to_version ] if not sqlfiles: return db_version for sqlfile in sqlfiles: new_version = int(path.splitext(path.basename(sqlfile))[0]) logger.info("Executing migration script '%s'", sqlfile) if db_version is not None and (new_version - db_version) > 1: logger.error( f"There are missing migration steps between {db_version} and " f"{new_version}. It might be expected but it most unlikely is not. " "Will stop here." ) return db_version execute_sqlfiles([sqlfile], conninfo, db_flavor) # check if the db version has been updated by the upgrade script db_version = swh_db_version(conninfo) assert db_version is not None if db_version == new_version: # nothing to do, upgrade script did the job pass elif db_version == new_version - 1: # it has not (new style), so do it swh_set_db_version( conninfo, new_version, desc=f"Upgraded to version {new_version} using {sqlfile}", ) db_version = swh_db_version(conninfo) else: # upgrade script did it wrong logger.error( f"The upgrade script {sqlfile} did not update the dbversion table " f"consistently ({db_version} vs. expected {new_version}). " "Will stop migration here. Please check your migration scripts." ) return db_version return new_version def swh_db_module(db_or_conninfo: Union[str, pgconnection]) -> Optional[str]: """Retrieve the swh module used to create the database. If the database is not initialized, this logs a warning and returns None. Args: db_or_conninfo: A database connection, or a database connection info string Returns: Either the module of the database, or None if it couldn't be detected """ try: with connect_to_conninfo(db_or_conninfo) as db: if not db: return None with db.cursor() as c: query = "select dbmodule from dbmodule limit 1" try: c.execute(query) resp = c.fetchone() if resp: return resp[0] except psycopg2.errors.UndefinedTable: return None except Exception: logger.exception("Could not get module from `%s`", db_or_conninfo) return None def swh_set_db_module( db_or_conninfo: Union[str, pgconnection], module: str, force=False ) -> None: """Set the swh module used to create the database. Fails if the dbmodule is already set or the table does not exist. Args: db_or_conninfo: A database connection, or a database connection info string module: the swh module to register (without the leading 'swh.') """ update = False if module.startswith("swh."): module = module[4:] current_module = swh_db_module(db_or_conninfo) if current_module is not None: if current_module == module: logger.warning("The database module is already set to %s", module) return if not force: raise ValueError( "The database module is already set to a value %s " "different than given %s", current_module, module, ) # force is True update = True with connect_to_conninfo(db_or_conninfo) as db: if not db: return None sqlfiles = [ fname for fname in get_sql_for_package("swh.core.db") if "dbmodule" in fname.stem ] execute_sqlfiles(sqlfiles, db_or_conninfo) with db.cursor() as c: if update: query = "update dbmodule set dbmodule = %s" else: query = "insert into dbmodule(dbmodule) values (%s)" c.execute(query, (module,)) db.commit() def swh_set_db_version( db_or_conninfo: Union[str, pgconnection], version: int, ts: Optional[datetime] = None, desc: str = "Work in progress", ) -> None: """Set the version of the database. Fails if the dbversion table does not exists. Args: db_or_conninfo: A database connection, or a database connection info string version: the version to add """ if ts is None: ts = now() with connect_to_conninfo(db_or_conninfo) as db: if not db: return None with db.cursor() as c: query = ( "insert into dbversion(version, release, description) " "values (%s, %s, %s)" ) c.execute(query, (version, ts, desc)) db.commit() def swh_db_flavor(db_or_conninfo: Union[str, pgconnection]) -> Optional[str]: """Retrieve the swh flavor of the database. If the database is not initialized, or the database doesn't support flavors, this returns None. Args: db_or_conninfo: A database connection, or a database connection info string Returns: The flavor of the database, or None if it could not be detected. """ try: with connect_to_conninfo(db_or_conninfo) as db: if not db: return None with db.cursor() as c: query = "select swh_get_dbflavor()" try: c.execute(query) result = c.fetchone() assert result is not None # to keep mypy happy return result[0] except psycopg2.errors.UndefinedFunction: # function not found: no flavor return None except Exception: logger.exception("Could not get flavor from `%s`", db_or_conninfo) return None # The following code has been imported from psycopg2, version 2.7.4, # https://github.com/psycopg/psycopg2/tree/5afb2ce803debea9533e293eef73c92ffce95bcd # and modified by Software Heritage. # # Original file: lib/extras.py # # psycopg2 is free software: you can redistribute it and/or modify it under the # terms of the GNU Lesser General Public License as published by the Free # Software Foundation, either version 3 of the License, or (at your option) any # later version. def _paginate(seq, page_size): """Consume an iterable and return it in chunks. Every chunk is at most `page_size`. Never return an empty chunk. """ page = [] it = iter(seq) while 1: try: for i in range(page_size): page.append(next(it)) yield page page = [] except StopIteration: if page: yield page return def _split_sql(sql): """Split *sql* on a single ``%s`` placeholder. Split on the %s, perform %% replacement and return pre, post lists of snippets. """ curr = pre = [] post = [] tokens = re.split(rb"(%.)", sql) for token in tokens: if len(token) != 2 or token[:1] != b"%": curr.append(token) continue if token[1:] == b"s": if curr is pre: curr = post else: raise ValueError("the query contains more than one '%s' placeholder") elif token[1:] == b"%": curr.append(b"%") else: raise ValueError( "unsupported format character: '%s'" % token[1:].decode("ascii", "replace") ) if curr is pre: raise ValueError("the query doesn't contain any '%s' placeholder") return pre, post def execute_values_generator(cur, sql, argslist, template=None, page_size=100): """Execute a statement using SQL ``VALUES`` with a sequence of parameters. Rows returned by the query are returned through a generator. You need to consume the generator for the queries to be executed! :param cur: the cursor to use to execute the query. :param sql: the query to execute. It must contain a single ``%s`` placeholder, which will be replaced by a `VALUES list`__. Example: ``"INSERT INTO mytable (id, f1, f2) VALUES %s"``. :param argslist: sequence of sequences or dictionaries with the arguments to send to the query. The type and content must be consistent with *template*. :param template: the snippet to merge to every item in *argslist* to compose the query. - If the *argslist* items are sequences it should contain positional placeholders (e.g. ``"(%s, %s, %s)"``, or ``"(%s, %s, 42)``" if there are constants value...). - If the *argslist* items are mappings it should contain named placeholders (e.g. ``"(%(id)s, %(f1)s, 42)"``). If not specified, assume the arguments are sequence and use a simple positional template (i.e. ``(%s, %s, ...)``), with the number of placeholders sniffed by the first element in *argslist*. :param page_size: maximum number of *argslist* items to include in every statement. If there are more items the function will execute more than one statement. :param yield_from_cur: Whether to yield results from the cursor in this function directly. .. __: https://www.postgresql.org/docs/current/static/queries-values.html After the execution of the function the `cursor.rowcount` property will **not** contain a total result. """ # we can't just use sql % vals because vals is bytes: if sql is bytes # there will be some decoding error because of stupid codec used, and Py3 # doesn't implement % on bytes. if not isinstance(sql, bytes): sql = sql.encode(pgencodings[cur.connection.encoding]) pre, post = _split_sql(sql) for page in _paginate(argslist, page_size=page_size): if template is None: template = b"(" + b",".join([b"%s"] * len(page[0])) + b")" parts = pre[:] for args in page: parts.append(cur.mogrify(template, args)) parts.append(b",") parts[-1:] = post cur.execute(b"".join(parts)) yield from cur def import_swhmodule(modname): if not modname.startswith("swh."): modname = f"swh.{modname}" try: m = import_module(modname) except ImportError as exc: logger.error(f"Could not load the {modname} module: {exc}") return None return m def get_sql_for_package(modname: str, upgrade: bool = False) -> List[pathlib.Path]: """Return the (sorted) list of sql script files for the given swh module If upgrade is True, return the list of available migration scripts, otherwise, return the list of initialization scripts. """ m = import_swhmodule(modname) if m is None: raise ValueError(f"Module {modname} cannot be loaded") sqldir = pathlib.Path(m.__file__).parent / "sql" if upgrade: sqldir /= "upgrades" if not sqldir.is_dir(): raise ValueError( "Module {} does not provide a db schema (no sql/ dir)".format(modname) ) return sorted(sqldir.glob("*.sql"), key=lambda x: sortkey(x.name)) def populate_database_for_package( modname: str, conninfo: str, flavor: Optional[str] = None ) -> Tuple[bool, Optional[int], Optional[str]]: """Populate the database, pointed at with ``conninfo``, using the SQL files found in the package ``modname``. Also fill the 'dbmodule' table with the given ``modname``. Args: modname: Name of the module of which we're loading the files conninfo: connection info string for the SQL database flavor: the module-specific flavor which we want to initialize the database under Returns: Tuple with three elements: whether the database has been initialized; the current version of the database; if it exists, the flavor of the database. """ current_version = swh_db_version(conninfo) if current_version is not None: dbflavor = swh_db_flavor(conninfo) return False, current_version, dbflavor def globalsortkey(key): "like sortkey but only on basenames" return sortkey(path.basename(key)) sqlfiles = get_sql_for_package(modname) + get_sql_for_package("swh.core.db") sqlfiles = sorted(sqlfiles, key=lambda x: sortkey(x.stem)) sqlfiles = [fpath for fpath in sqlfiles if "-superuser-" not in fpath.stem] execute_sqlfiles(sqlfiles, conninfo, flavor) # populate the dbmodule table swh_set_db_module(conninfo, modname) current_db_version = swh_db_version(conninfo) dbflavor = swh_db_flavor(conninfo) return True, current_db_version, dbflavor def initialize_database_for_module( modname: str, version: int, flavor: Optional[str] = None, **kwargs ): """Helper function to initialize and populate a database for the given module This aims at helping the usage of pytest_postgresql for swh.core.db based datastores. Typical usage will be (here for swh.storage):: from pytest_postgresql import factories storage_postgresql_proc = factories.postgresql_proc( load=[partial(initialize_database_for_module, modname="storage", version=42)] ) storage_postgresql = factories.postgresql("storage_postgresql_proc") """ conninfo = psycopg2.connect(**kwargs).dsn init_admin_extensions(modname, conninfo) populate_database_for_package(modname, conninfo, flavor) try: swh_set_db_version(conninfo, version) except psycopg2.errors.UniqueViolation: logger.warn( "Version already set by db init scripts. " f"This generally means the swh.{modname} package needs to be " "updated for swh.core>=1.2" ) def get_database_info( conninfo: str, ) -> Tuple[Optional[str], Optional[int], Optional[str]]: """Get version, flavor and module of the db""" dbmodule = swh_db_module(conninfo) dbversion = swh_db_version(conninfo) dbflavor = None if dbversion is not None: dbflavor = swh_db_flavor(conninfo) return (dbmodule, dbversion, dbflavor) def parse_dsn_or_dbname(dsn_or_dbname: str) -> Dict[str, str]: """Parse a psycopg2 dsn, falling back to supporting plain database names as well""" try: return _parse_dsn(dsn_or_dbname) except psycopg2.ProgrammingError: # psycopg2 failed to parse the DSN; it's probably a database name, # handle it as such return _parse_dsn(f"dbname={dsn_or_dbname}") def init_admin_extensions(modname: str, conninfo: str) -> None: """The remaining initialization process -- running -superuser- SQL files -- is done using the given conninfo, thus connecting to the newly created database """ sqlfiles = get_sql_for_package(modname) sqlfiles = [fname for fname in sqlfiles if "-superuser-" in fname.stem] execute_sqlfiles(sqlfiles, conninfo) def create_database_for_package( modname: str, conninfo: str, template: str = "template1" ): """Create the database pointed at with ``conninfo``, and initialize it using -superuser- SQL files found in the package ``modname``. Args: modname: Name of the module of which we're loading the files conninfo: connection info string or plain database name for the SQL database template: the name of the database to connect to and use as template to create the new database """ # Use the given conninfo string, but with dbname replaced by the template dbname # for the database creation step creation_dsn = parse_dsn_or_dbname(conninfo) dbname = creation_dsn["dbname"] creation_dsn["dbname"] = template logger.debug("db_create dbname=%s (from %s)", dbname, template) subprocess.check_call( [ "psql", "--quiet", "--no-psqlrc", "-v", "ON_ERROR_STOP=1", "-d", make_dsn(**creation_dsn), "-c", f'CREATE DATABASE "{dbname}"', ] ) init_admin_extensions(modname, conninfo) def execute_sqlfiles( sqlfiles: Collection[pathlib.Path], db_or_conninfo: Union[str, pgconnection], flavor: Optional[str] = None, ): """Execute a list of SQL files on the database pointed at with ``db_or_conninfo``. Args: sqlfiles: List of SQL files to execute db_or_conninfo: A database connection, or a database connection info string flavor: the database flavor to initialize """ if isinstance(db_or_conninfo, str): conninfo = db_or_conninfo else: conninfo = db_or_conninfo.dsn psql_command = [ "psql", "--quiet", "--no-psqlrc", "-v", "ON_ERROR_STOP=1", "-d", conninfo, ] flavor_set = False for sqlfile in sqlfiles: logger.debug(f"execute SQL file {sqlfile} dbname={conninfo}") subprocess.check_call(psql_command + ["-f", str(sqlfile)]) if ( flavor is not None and not flavor_set and sqlfile.name.endswith("-flavor.sql") ): logger.debug("Setting database flavor %s", flavor) query = f"insert into dbflavor (flavor) values ('{flavor}')" subprocess.check_call(psql_command + ["-c", query]) flavor_set = True if flavor is not None and not flavor_set: logger.warn( "Asked for flavor %s, but module does not support database flavors", flavor, ) # Grant read-access to guest user on all tables of the schema (if possible) with connect_to_conninfo(db_or_conninfo) as db: try: with db.cursor() as c: query = "grant select on all tables in schema public to guest" c.execute(query) except Exception: logger.warning("Grant read-only access to guest user failed. Skipping.") diff --git a/swh/core/github/tests/test_pytest_plugin.py b/swh/core/github/tests/test_pytest_plugin.py index 57aa7e3..cb49816 100644 --- a/swh/core/github/tests/test_pytest_plugin.py +++ b/swh/core/github/tests/test_pytest_plugin.py @@ -1,50 +1,52 @@ # Copyright (C) 2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information -import pytest import time + +import pytest + from swh.core.github.pytest_plugin import fake_time_sleep, fake_time_time @pytest.mark.parametrize("duration", [10, 20, -1]) def test_fake_time_sleep(duration): if duration < 0: with pytest.raises(ValueError, match="negative"): fake_time_sleep(duration, []) else: sleep_calls = [] fake_time_sleep(duration, sleep_calls) assert duration in sleep_calls def test_fake_time_time(): assert fake_time_time() == 0 def test_monkeypatch_sleep_calls(monkeypatch_sleep_calls): sleeps = [10, 20, 30] for sleep in sleeps: # This adds the sleep number inside the monkeypatch_sleep_calls fixture time.sleep(sleep) assert sleep in monkeypatch_sleep_calls assert len(monkeypatch_sleep_calls) == len(sleeps) # This mocks time but adds nothing to the same fixture time.time() assert len(monkeypatch_sleep_calls) == len(sleeps) def test_num_before_ratelimit(num_before_ratelimit): assert num_before_ratelimit == 0 def test_ratelimit_reset(ratelimit_reset): assert ratelimit_reset is None def test_num_ratelimit(num_ratelimit): assert num_ratelimit is None diff --git a/swh/core/pytest_plugin.py b/swh/core/pytest_plugin.py index dc44478..c5d1b21 100644 --- a/swh/core/pytest_plugin.py +++ b/swh/core/pytest_plugin.py @@ -1,369 +1,407 @@ # Copyright (C) 2019-2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from collections import deque from functools import partial import logging from os import path import re from typing import Dict, List, Optional from urllib.parse import unquote, urlparse import pytest import requests from requests.adapters import BaseAdapter from requests.structures import CaseInsensitiveDict from requests.utils import get_encoding_from_headers +import sentry_sdk logger = logging.getLogger(__name__) # Check get_local_factory function # Maximum number of iteration checks to generate requests responses MAX_VISIT_FILES = 10 def get_response_cb( request: requests.Request, context, datadir, ignore_urls: List[str] = [], visits: Optional[Dict] = None, ): """Mount point callback to fetch on disk the request's content. The request urls provided are url decoded first to resolve the associated file on disk. This is meant to be used as 'body' argument of the requests_mock.get() method. It will look for files on the local filesystem based on the requested URL, using the following rules: - files are searched in the datadir/ directory - the local file name is the path part of the URL with path hierarchy markers (aka '/') replaced by '_' Eg. if you use the requests_mock fixture in your test file as: requests_mock.get('https?://nowhere.com', body=get_response_cb) # or even requests_mock.get(re.compile('https?://'), body=get_response_cb) then a call requests.get like: requests.get('https://nowhere.com/path/to/resource?a=b&c=d') will look the content of the response in: datadir/https_nowhere.com/path_to_resource,a=b,c=d or a call requests.get like: requests.get('http://nowhere.com/path/to/resource?a=b&c=d') will look the content of the response in: datadir/http_nowhere.com/path_to_resource,a=b,c=d Args: request: Object requests context (requests.Context): Object holding response metadata information (status_code, headers, etc...) datadir: Data files path ignore_urls: urls whose status response should be 404 even if the local file exists visits: Dict of url, number of visits. If None, disable multi visit support (default) Returns: Optional[FileDescriptor] on disk file to read from the test context """ logger.debug("get_response_cb(%s, %s)", request, context) logger.debug("url: %s", request.url) logger.debug("ignore_urls: %s", ignore_urls) unquoted_url = unquote(request.url) if unquoted_url in ignore_urls: context.status_code = 404 return None url = urlparse(unquoted_url) # http://pypi.org ~> http_pypi.org # https://files.pythonhosted.org ~> https_files.pythonhosted.org dirname = "%s_%s" % (url.scheme, url.hostname) # url.path: pypi//json -> local file: pypi__json filename = url.path[1:] if filename.endswith("/"): filename = filename[:-1] filename = filename.replace("/", "_") if url.query: filename += "," + url.query.replace("&", ",") filepath = path.join(datadir, dirname, filename) if visits is not None: visit = visits.get(url, 0) visits[url] = visit + 1 if visit: filepath = filepath + "_visit%s" % visit if not path.isfile(filepath): logger.debug("not found filepath: %s", filepath) context.status_code = 404 return None fd = open(filepath, "rb") context.headers["content-length"] = str(path.getsize(filepath)) return fd @pytest.fixture def datadir(request: pytest.FixtureRequest) -> str: """By default, returns the test directory's data directory. This can be overridden on a per file tree basis. Add an override definition in the local conftest, for example:: import pytest from os import path @pytest.fixture def datadir(): return path.join(path.abspath(path.dirname(__file__)), 'resources') """ # pytest >= 7 renamed FixtureRequest fspath attribute to path path_ = request.path if hasattr(request, "path") else request.fspath # type: ignore return path.join(path.dirname(str(path_)), "data") def requests_mock_datadir_factory( ignore_urls: List[str] = [], has_multi_visit: bool = False ): """This factory generates fixtures which allow to look for files on the local filesystem based on the requested URL, using the following rules: - files are searched in the data/ directory - the local file name is the path part of the URL with path hierarchy markers (aka '/') replaced by '_' Multiple implementations are possible, for example: ``requests_mock_datadir_factory([])`` This computes the file name from the query and always returns the same result. ``requests_mock_datadir_factory(has_multi_visit=True)`` This computes the file name from the query and returns the content of the filename the first time, the next call returning the content of files suffixed with _visit1 and so on and so forth. If the file is not found, returns a 404. ``requests_mock_datadir_factory(ignore_urls=['url1', 'url2'])`` This will ignore any files corresponding to url1 and url2, always returning 404. Args: ignore_urls: List of urls to always returns 404 (whether file exists or not) has_multi_visit: Activate or not the multiple visits behavior """ @pytest.fixture def requests_mock_datadir(requests_mock, datadir): if not has_multi_visit: cb = partial(get_response_cb, ignore_urls=ignore_urls, datadir=datadir) requests_mock.get(re.compile("https?://"), body=cb) else: visits = {} requests_mock.get( re.compile("https?://"), body=partial( get_response_cb, ignore_urls=ignore_urls, visits=visits, datadir=datadir, ), ) return requests_mock return requests_mock_datadir # Default `requests_mock_datadir` implementation requests_mock_datadir = requests_mock_datadir_factory() """ Instance of :py:func:`requests_mock_datadir_factory`, with the default arguments. """ # Implementation for multiple visits behavior: # - first time, it checks for a file named `filename` # - second time, it checks for a file named `filename`_visit1 # etc... requests_mock_datadir_visits = requests_mock_datadir_factory(has_multi_visit=True) """ Instance of :py:func:`requests_mock_datadir_factory`, with the default arguments, but `has_multi_visit=True`. """ @pytest.fixture def swh_rpc_client(swh_rpc_client_class, swh_rpc_adapter): """This fixture generates an RPCClient instance that uses the class generated by the rpc_client_class fixture as backend. Since it uses the swh_rpc_adapter, HTTP queries will be intercepted and routed directly to the current Flask app (as provided by the `app` fixture). So this stack of fixtures allows to test the RPCClient -> RPCServerApp communication path using a real RPCClient instance and a real Flask (RPCServerApp) app instance. To use this fixture: - ensure an `app` fixture exists and generate a Flask application, - implement an `swh_rpc_client_class` fixtures that returns the RPCClient-based class to use as client side for the tests, - implement your tests using this `swh_rpc_client` fixture. See swh/core/api/tests/test_rpc_client_server.py for an example of usage. """ url = "mock://example.com" cli = swh_rpc_client_class(url=url) # we need to clear the list of existing adapters here so we ensure we # have one and only one adapter which is then used for all the requests. cli.session.adapters.clear() cli.session.mount("mock://", swh_rpc_adapter) return cli @pytest.fixture def swh_rpc_adapter(app): """Fixture that generates a requests.Adapter instance that can be used to test client/servers code based on swh.core.api classes. See swh/core/api/tests/test_rpc_client_server.py for an example of usage. """ client = app.test_client() yield RPCTestAdapter(client) class RPCTestAdapter(BaseAdapter): def __init__(self, client): self._client = client def build_response(self, req, resp): response = requests.Response() # Fallback to None if there's no status_code, for whatever reason. response.status_code = resp.status_code # Make headers case-insensitive. response.headers = CaseInsensitiveDict(getattr(resp, "headers", {})) # Set encoding. response.encoding = get_encoding_from_headers(response.headers) response.raw = resp response.reason = response.raw.status if isinstance(req.url, bytes): response.url = req.url.decode("utf-8") else: response.url = req.url # Give the Response some context. response.request = req response.connection = self response._content = resp.data return response def send(self, request, **kw): """ Overrides ``requests.adapters.BaseAdapter.send`` """ resp = self._client.open( request.url, method=request.method, headers=request.headers.items(), data=request.body, ) return self.build_response(request, resp) @pytest.fixture def flask_app_client(app): with app.test_client() as client: yield client # stolen from pytest-flask, required to have url_for() working within tests # using flask_app_client fixture. @pytest.fixture(autouse=True) def _push_request_context(request: pytest.FixtureRequest): """During tests execution request context has been pushed, e.g. `url_for`, `session`, etc. can be used in tests as is:: def test_app(app, client): assert client.get(url_for('myview')).status_code == 200 """ if "app" not in request.fixturenames: return app = request.getfixturevalue("app") ctx = app.test_request_context() ctx.push() def teardown(): ctx.pop() request.addfinalizer(teardown) class FakeSocket(object): """A fake socket for testing.""" def __init__(self): self.payloads = deque() def send(self, payload): assert type(payload) == bytes self.payloads.append(payload) def recv(self): try: return self.payloads.popleft().decode("utf-8") except IndexError: return None def close(self): pass def __repr__(self): return str(self.payloads) @pytest.fixture def statsd(): """Simple fixture giving a Statsd instance suitable for tests The Statsd instance uses a FakeSocket as `.socket` attribute in which one can get the accumulated statsd messages in a deque in `.socket.payloads`. """ from swh.core.statsd import Statsd statsd = Statsd() statsd._socket = FakeSocket() yield statsd + + +@pytest.fixture +def monkeypatch_sentry_transport(): + # Inspired by + # https://github.com/getsentry/sentry-python/blob/1.5.9/tests/conftest.py#L168-L184 + + initialized = False + + def setup_sentry_transport_monkeypatch(*a, **kw): + nonlocal initialized + assert not initialized, "already initialized" + initialized = True + hub = sentry_sdk.Hub.current + client = sentry_sdk.Client(*a, **kw) + hub.bind_client(client) + client.transport = TestTransport() + + class TestTransport: + def __init__(self): + self.events = [] + self.envelopes = [] + + def capture_event(self, event): + self.events.append(event) + + def capture_envelope(self, envelope): + self.envelopes.append(envelope) + + with sentry_sdk.Hub(None): + yield setup_sentry_transport_monkeypatch + + +@pytest.fixture +def sentry_events(monkeypatch_sentry_transport): + monkeypatch_sentry_transport() + return sentry_sdk.Hub.current.client.transport.events diff --git a/swh/core/tests/test_sentry.py b/swh/core/tests/test_sentry.py index fd26d93..f97e025 100644 --- a/swh/core/tests/test_sentry.py +++ b/swh/core/tests/test_sentry.py @@ -1,103 +1,135 @@ # Copyright (C) 2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import logging import pytest -from sentry_sdk import capture_message +from sentry_sdk import capture_exception, capture_message, set_tag from swh.core.sentry import init_sentry, override_with_bool_envvar @pytest.mark.parametrize( "envvalue,retval", ( ("y", True), ("n", False), ("0", False), ("true", True), ("FaLsE", False), ("1", True), ), ) def test_override_with_bool_envvar(monkeypatch, envvalue: str, retval: bool): """Test if the override_with_bool_envvar function returns appropriate results""" envvar = "OVERRIDE_WITH_BOOL_ENVVAR" monkeypatch.setenv(envvar, envvalue) for default in (True, False): assert override_with_bool_envvar(envvar, default) == retval def test_override_with_bool_envvar_logging(monkeypatch, caplog): envvar = "OVERRIDE_WITH_BOOL_ENVVAR" monkeypatch.setenv(envvar, "not a boolean env value") for default in (True, False): caplog.clear() assert override_with_bool_envvar(envvar, default) == default assert len(caplog.records) == 1 assert ( "OVERRIDE_WITH_BOOL_ENVVAR='not a boolean env value'" in caplog.records[0].getMessage() ) assert f"using default value {default}" in caplog.records[0].getMessage() assert caplog.records[0].levelname == "WARNING" def test_sentry(): reports = [] init_sentry("http://example.org", extra_kwargs={"transport": reports.append}) capture_message("Something went wrong") logging.error("Stupid error") assert len(reports) == 2 assert reports[0]["message"] == "Something went wrong" assert reports[1]["logentry"]["message"] == "Stupid error" def test_sentry_no_logging(): reports = [] init_sentry( "http://example.org", disable_logging_events=True, extra_kwargs={"transport": reports.append}, ) capture_message("Something went wrong") logging.error("Stupid error") assert len(reports) == 1 assert reports[0]["message"] == "Something went wrong" def test_sentry_no_logging_from_venv(monkeypatch): monkeypatch.setenv("SWH_SENTRY_DISABLE_LOGGING_EVENTS", "True") reports = [] init_sentry( "http://example.org", extra_kwargs={"transport": reports.append}, ) capture_message("Something went wrong") logging.error("Stupid error") assert len(reports) == 1 assert reports[0]["message"] == "Something went wrong" def test_sentry_logging_from_venv(monkeypatch): monkeypatch.setenv("SWH_SENTRY_DISABLE_LOGGING_EVENTS", "false") reports = [] init_sentry( "http://example.org", extra_kwargs={"transport": reports.append}, ) capture_message("Something went wrong") logging.error("Stupid error") assert len(reports) == 2 + + +def test_sentry_events_fixture_capture_message(sentry_events): + message = "Something went wrong" + capture_message(message) + assert sentry_events + assert "message" in sentry_events[0] + assert sentry_events[0]["message"] == message + + +def test_sentry_events_fixture_capture_exception(sentry_events): + message = "Invalid value" + exception = ValueError(message) + capture_exception(exception) + assert sentry_events + assert "exception" in sentry_events[0] + assert "values" in sentry_events[0]["exception"] + exception_data = sentry_events[0]["exception"]["values"] + assert exception_data + assert exception_data[0].get("type") == type(exception).__name__ + assert exception_data[0].get("value") == message + + +def test_sentry_events_fixture_set_tag(sentry_events): + tag_name = "swh.test" + tag_value = "test" + set_tag(tag_name, tag_value) + message = "Something went wrong" + capture_message(message) + assert sentry_events + assert "tags" in sentry_events[0] + sentry_events[0]["tags"] == {tag_name: tag_value}