diff --git a/swh/core/db/pytest_plugin.py b/swh/core/db/pytest_plugin.py index 07edf27..8f5280e 100644 --- a/swh/core/db/pytest_plugin.py +++ b/swh/core/db/pytest_plugin.py @@ -1,174 +1,175 @@ # Copyright (C) 2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import glob import logging import subprocess from typing import Optional, Set, Union +from _pytest.fixtures import FixtureRequest import psycopg2 import pytest from pytest_postgresql import factories from pytest_postgresql.janitor import DatabaseJanitor, Version from swh.core.utils import numfile_sortkey as sortkey logger = logging.getLogger(__name__) class SWHDatabaseJanitor(DatabaseJanitor): """SWH database janitor implementation with a a different setup/teardown policy than than the stock one. Instead of dropping, creating and initializing the database for each test, it creates and initializes the db once, then truncates the tables (and sequences) in between tests. This is needed to have acceptable test performances. """ def __init__( self, user: str, host: str, port: str, db_name: str, version: Union[str, float, Version], dump_files: Optional[str] = None, no_truncate_tables: Set[str] = set(), ) -> None: super().__init__(user, host, port, db_name, version) if dump_files: self.dump_files = sorted(glob.glob(dump_files), key=sortkey) else: self.dump_files = [] # do no truncate the following tables self.no_truncate_tables = set(no_truncate_tables) def db_setup(self): conninfo = ( f"host={self.host} user={self.user} port={self.port} dbname={self.db_name}" ) for fname in self.dump_files: subprocess.check_call( [ "psql", "--quiet", "--no-psqlrc", "-v", "ON_ERROR_STOP=1", "-d", conninfo, "-f", fname, ] ) def db_reset(self): """Truncate tables (all but self.no_truncate_tables set) and sequences """ with psycopg2.connect( dbname=self.db_name, user=self.user, host=self.host, port=self.port, ) as cnx: with cnx.cursor() as cur: cur.execute( "SELECT table_name FROM information_schema.tables " "WHERE table_schema = %s", ("public",), ) all_tables = set(table for (table,) in cur.fetchall()) tables_to_truncate = all_tables - self.no_truncate_tables for table in tables_to_truncate: cur.execute("TRUNCATE TABLE %s CASCADE" % table) cur.execute( "SELECT sequence_name FROM information_schema.sequences " "WHERE sequence_schema = %s", ("public",), ) seqs = set(seq for (seq,) in cur.fetchall()) for seq in seqs: cur.execute("ALTER SEQUENCE %s RESTART;" % seq) cnx.commit() def init(self): """Initialize db. Create the db if it does not exist. Reset it if it exists.""" with self.cursor() as cur: cur.execute( "SELECT COUNT(1) FROM pg_database WHERE datname=%s;", (self.db_name,) ) db_exists = cur.fetchone()[0] == 1 if db_exists: cur.execute( "UPDATE pg_database SET datallowconn=true WHERE datname = %s;", (self.db_name,), ) self.db_reset() return # initialize the inexistent db with self.cursor() as cur: cur.execute('CREATE DATABASE "{}";'.format(self.db_name)) self.db_setup() def drop(self): """The original DatabaseJanitor implementation prevents new connections from happening, destroys current opened connections and finally drops the database. We actually do not want to drop the db so we instead do nothing and resets (truncate most tables and sequences) the db instead, in order to have some acceptable performance. """ pass # the postgres_fact factory fixture below is mostly a copy of the code # from pytest-postgresql. We need a custom version here to be able to # specify our version of the DBJanitor we use. def postgresql_fact( process_fixture_name: str, db_name: Optional[str] = None, dump_files: str = "", no_truncate_tables: Set[str] = {"dbversion"}, ): @pytest.fixture - def postgresql_factory(request): + def postgresql_factory(request: FixtureRequest): """Fixture factory for PostgreSQL. :param FixtureRequest request: fixture request object :rtype: psycopg2.connection :returns: postgresql client """ config = factories.get_config(request) proc_fixture = request.getfixturevalue(process_fixture_name) pg_host = proc_fixture.host pg_port = proc_fixture.port pg_user = proc_fixture.user pg_options = proc_fixture.options pg_db = db_name or config["dbname"] with SWHDatabaseJanitor( pg_user, pg_host, pg_port, pg_db, proc_fixture.version, dump_files=dump_files, no_truncate_tables=no_truncate_tables, ): connection = psycopg2.connect( dbname=pg_db, user=pg_user, host=pg_host, port=pg_port, options=pg_options, ) yield connection connection.close() return postgresql_factory diff --git a/swh/core/pytest_plugin.py b/swh/core/pytest_plugin.py index 7a45019..c9da42c 100644 --- a/swh/core/pytest_plugin.py +++ b/swh/core/pytest_plugin.py @@ -1,317 +1,318 @@ # Copyright (C) 2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from functools import partial import logging from os import path import re from typing import Dict, List, Optional from urllib.parse import unquote, urlparse +from _pytest.fixtures import FixtureRequest import pytest import requests from requests.adapters import BaseAdapter from requests.structures import CaseInsensitiveDict from requests.utils import get_encoding_from_headers logger = logging.getLogger(__name__) # Check get_local_factory function # Maximum number of iteration checks to generate requests responses MAX_VISIT_FILES = 10 def get_response_cb( request: requests.Request, context, datadir, ignore_urls: List[str] = [], visits: Optional[Dict] = None, ): """Mount point callback to fetch on disk the request's content. The request urls provided are url decoded first to resolve the associated file on disk. This is meant to be used as 'body' argument of the requests_mock.get() method. It will look for files on the local filesystem based on the requested URL, using the following rules: - files are searched in the datadir/ directory - the local file name is the path part of the URL with path hierarchy markers (aka '/') replaced by '_' Eg. if you use the requests_mock fixture in your test file as: requests_mock.get('https?://nowhere.com', body=get_response_cb) # or even requests_mock.get(re.compile('https?://'), body=get_response_cb) then a call requests.get like: requests.get('https://nowhere.com/path/to/resource?a=b&c=d') will look the content of the response in: datadir/https_nowhere.com/path_to_resource,a=b,c=d or a call requests.get like: requests.get('http://nowhere.com/path/to/resource?a=b&c=d') will look the content of the response in: datadir/http_nowhere.com/path_to_resource,a=b,c=d Args: request: Object requests context (requests.Context): Object holding response metadata information (status_code, headers, etc...) datadir: Data files path ignore_urls: urls whose status response should be 404 even if the local file exists visits: Dict of url, number of visits. If None, disable multi visit support (default) Returns: Optional[FileDescriptor] on disk file to read from the test context """ logger.debug("get_response_cb(%s, %s)", request, context) logger.debug("url: %s", request.url) logger.debug("ignore_urls: %s", ignore_urls) unquoted_url = unquote(request.url) if unquoted_url in ignore_urls: context.status_code = 404 return None url = urlparse(unquoted_url) # http://pypi.org ~> http_pypi.org # https://files.pythonhosted.org ~> https_files.pythonhosted.org dirname = "%s_%s" % (url.scheme, url.hostname) # url.path: pypi//json -> local file: pypi__json filename = url.path[1:] if filename.endswith("/"): filename = filename[:-1] filename = filename.replace("/", "_") if url.query: filename += "," + url.query.replace("&", ",") filepath = path.join(datadir, dirname, filename) if visits is not None: visit = visits.get(url, 0) visits[url] = visit + 1 if visit: filepath = filepath + "_visit%s" % visit if not path.isfile(filepath): logger.debug("not found filepath: %s", filepath) context.status_code = 404 return None fd = open(filepath, "rb") context.headers["content-length"] = str(path.getsize(filepath)) return fd @pytest.fixture -def datadir(request): +def datadir(request: FixtureRequest) -> str: """By default, returns the test directory's data directory. This can be overridden on a per file tree basis. Add an override definition in the local conftest, for example:: import pytest from os import path @pytest.fixture def datadir(): return path.join(path.abspath(path.dirname(__file__)), 'resources') """ return path.join(path.dirname(str(request.fspath)), "data") def requests_mock_datadir_factory( ignore_urls: List[str] = [], has_multi_visit: bool = False ): """This factory generates fixture which allow to look for files on the local filesystem based on the requested URL, using the following rules: - files are searched in the datadir/ directory - the local file name is the path part of the URL with path hierarchy markers (aka '/') replaced by '_' Multiple implementations are possible, for example: - requests_mock_datadir_factory([]): This computes the file name from the query and always returns the same result. - requests_mock_datadir_factory(has_multi_visit=True): This computes the file name from the query and returns the content of the filename the first time, the next call returning the content of files suffixed with _visit1 and so on and so forth. If the file is not found, returns a 404. - requests_mock_datadir_factory(ignore_urls=['url1', 'url2']): This will ignore any files corresponding to url1 and url2, always returning 404. Args: ignore_urls: List of urls to always returns 404 (whether file exists or not) has_multi_visit: Activate or not the multiple visits behavior """ @pytest.fixture def requests_mock_datadir(requests_mock, datadir): if not has_multi_visit: cb = partial(get_response_cb, ignore_urls=ignore_urls, datadir=datadir) requests_mock.get(re.compile("https?://"), body=cb) else: visits = {} requests_mock.get( re.compile("https?://"), body=partial( get_response_cb, ignore_urls=ignore_urls, visits=visits, datadir=datadir, ), ) return requests_mock return requests_mock_datadir # Default `requests_mock_datadir` implementation requests_mock_datadir = requests_mock_datadir_factory([]) # Implementation for multiple visits behavior: # - first time, it checks for a file named `filename` # - second time, it checks for a file named `filename`_visit1 # etc... requests_mock_datadir_visits = requests_mock_datadir_factory(has_multi_visit=True) @pytest.fixture def swh_rpc_client(swh_rpc_client_class, swh_rpc_adapter): """This fixture generates an RPCClient instance that uses the class generated by the rpc_client_class fixture as backend. Since it uses the swh_rpc_adapter, HTTP queries will be intercepted and routed directly to the current Flask app (as provided by the `app` fixture). So this stack of fixtures allows to test the RPCClient -> RPCServerApp communication path using a real RPCClient instance and a real Flask (RPCServerApp) app instance. To use this fixture: - ensure an `app` fixture exists and generate a Flask application, - implement an `swh_rpc_client_class` fixtures that returns the RPCClient-based class to use as client side for the tests, - implement your tests using this `swh_rpc_client` fixture. See swh/core/api/tests/test_rpc_client_server.py for an example of usage. """ url = "mock://example.com" cli = swh_rpc_client_class(url=url) # we need to clear the list of existing adapters here so we ensure we # have one and only one adapter which is then used for all the requests. cli.session.adapters.clear() cli.session.mount("mock://", swh_rpc_adapter) return cli @pytest.fixture def swh_rpc_adapter(app): """Fixture that generates a requests.Adapter instance that can be used to test client/servers code based on swh.core.api classes. See swh/core/api/tests/test_rpc_client_server.py for an example of usage. """ with app.test_client() as client: yield RPCTestAdapter(client) class RPCTestAdapter(BaseAdapter): def __init__(self, client): self._client = client def build_response(self, req, resp): response = requests.Response() # Fallback to None if there's no status_code, for whatever reason. response.status_code = resp.status_code # Make headers case-insensitive. response.headers = CaseInsensitiveDict(getattr(resp, "headers", {})) # Set encoding. response.encoding = get_encoding_from_headers(response.headers) response.raw = resp response.reason = response.raw.status if isinstance(req.url, bytes): response.url = req.url.decode("utf-8") else: response.url = req.url # Give the Response some context. response.request = req response.connection = self response._content = resp.data return response def send(self, request, **kw): """ Overrides ``requests.adapters.BaseAdapter.send`` """ resp = self._client.open( request.url, method=request.method, headers=request.headers.items(), data=request.body, ) return self.build_response(request, resp) @pytest.fixture def flask_app_client(app): with app.test_client() as client: yield client # stolen from pytest-flask, required to have url_for() working within tests # using flask_app_client fixture. @pytest.fixture(autouse=True) -def _push_request_context(request): +def _push_request_context(request: FixtureRequest): """During tests execution request context has been pushed, e.g. `url_for`, `session`, etc. can be used in tests as is:: def test_app(app, client): assert client.get(url_for('myview')).status_code == 200 """ if "app" not in request.fixturenames: return app = request.getfixturevalue("app") ctx = app.test_request_context() ctx.push() def teardown(): ctx.pop() request.addfinalizer(teardown) diff --git a/swh/core/tests/test_config.py b/swh/core/tests/test_config.py index df310ce..8d02b2c 100644 --- a/swh/core/tests/test_config.py +++ b/swh/core/tests/test_config.py @@ -1,364 +1,364 @@ # Copyright (C) 2015-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import os import shutil import pkg_resources.extern.packaging.version import pytest import yaml from swh.core import config pytest_v = pkg_resources.get_distribution("pytest").parsed_version if pytest_v < pkg_resources.extern.packaging.version.parse("3.9"): @pytest.fixture - def tmp_path(request): + def tmp_path(): import pathlib import tempfile with tempfile.TemporaryDirectory() as tmpdir: yield pathlib.Path(tmpdir) default_conf = { "a": ("int", 2), "b": ("string", "default-string"), "c": ("bool", True), "d": ("int", 10), "e": ("int", None), "f": ("bool", None), "g": ("string", None), "h": ("bool", True), "i": ("bool", True), "ls": ("list[str]", ["a", "b", "c"]), "li": ("list[int]", [42, 43]), } other_default_conf = { "a": ("int", 3), } full_default_conf = default_conf.copy() full_default_conf["a"] = other_default_conf["a"] parsed_default_conf = {key: value for key, (type, value) in default_conf.items()} parsed_conffile = { "a": 1, "b": "this is a string", "c": True, "d": 10, "e": None, "f": None, "g": None, "h": False, "i": True, "ls": ["list", "of", "strings"], "li": [1, 2, 3, 4], } @pytest.fixture def swh_config(tmp_path): # create a temporary folder conffile = tmp_path / "config.yml" conf_contents = """ a: 1 b: this is a string c: true h: false ls: list, of, strings li: 1, 2, 3, 4 """ conffile.open("w").write(conf_contents) return conffile @pytest.fixture def swh_config_unreadable(swh_config): # Create an unreadable, proper configuration file os.chmod(str(swh_config), 0o000) yield swh_config # Make the broken perms file readable again to be able to remove them os.chmod(str(swh_config), 0o644) @pytest.fixture def swh_config_unreadable_dir(swh_config): # Create a proper configuration file in an unreadable directory perms_broken_dir = swh_config.parent / "unreadabledir" perms_broken_dir.mkdir() shutil.move(str(swh_config), str(perms_broken_dir)) os.chmod(str(perms_broken_dir), 0o000) yield perms_broken_dir / swh_config.name # Make the broken perms items readable again to be able to remove them os.chmod(str(perms_broken_dir), 0o755) @pytest.fixture def swh_config_empty(tmp_path): # create a temporary folder conffile = tmp_path / "config.yml" conffile.touch() return conffile def test_read(swh_config): # when res = config.read(str(swh_config), default_conf) # then assert res == parsed_conffile def test_read_no_default_conf(swh_config): """If no default config if provided to read, this should directly parse the config file yaml """ config_path = str(swh_config) actual_config = config.read(config_path) with open(config_path) as f: expected_config = yaml.safe_load(f) assert actual_config == expected_config def test_read_empty_file(): # when res = config.read(None, default_conf) # then assert res == parsed_default_conf def test_support_non_existing_conffile(tmp_path): # when res = config.read(str(tmp_path / "void.yml"), default_conf) # then assert res == parsed_default_conf def test_support_empty_conffile(swh_config_empty): # when res = config.read(str(swh_config_empty), default_conf) # then assert res == parsed_default_conf def test_raise_on_broken_directory_perms(swh_config_unreadable_dir): with pytest.raises(PermissionError): config.read(str(swh_config_unreadable_dir), default_conf) def test_raise_on_broken_file_perms(swh_config_unreadable): with pytest.raises(PermissionError): config.read(str(swh_config_unreadable), default_conf) def test_merge_default_configs(): # when res = config.merge_default_configs(default_conf, other_default_conf) # then assert res == full_default_conf def test_priority_read_nonexist_conf(swh_config): noexist = str(swh_config.parent / "void.yml") # when res = config.priority_read([noexist, str(swh_config)], default_conf) # then assert res == parsed_conffile def test_priority_read_conf_nonexist_empty(swh_config): noexist = swh_config.parent / "void.yml" empty = swh_config.parent / "empty.yml" empty.touch() # when res = config.priority_read( [str(p) for p in (swh_config, noexist, empty)], default_conf ) # then assert res == parsed_conffile def test_priority_read_empty_conf_nonexist(swh_config): noexist = swh_config.parent / "void.yml" empty = swh_config.parent / "empty.yml" empty.touch() # when res = config.priority_read( [str(p) for p in (empty, swh_config, noexist)], default_conf ) # then assert res == parsed_default_conf def test_swh_config_paths(): res = config.swh_config_paths("foo/bar.yml") assert res == [ "~/.config/swh/foo/bar.yml", "~/.swh/foo/bar.yml", "/etc/softwareheritage/foo/bar.yml", ] def test_prepare_folder(tmp_path): # given conf = { "path1": str(tmp_path / "path1"), "path2": str(tmp_path / "path2" / "depth1"), } # the folders does not exists assert not os.path.exists(conf["path1"]), "path1 should not exist." assert not os.path.exists(conf["path2"]), "path2 should not exist." # when config.prepare_folders(conf, "path1") # path1 exists but not path2 assert os.path.exists(conf["path1"]), "path1 should now exist!" assert not os.path.exists(conf["path2"]), "path2 should not exist." # path1 already exists, skips it but creates path2 config.prepare_folders(conf, "path1", "path2") assert os.path.exists(conf["path1"]), "path1 should still exist!" assert os.path.exists(conf["path2"]), "path2 should now exist." def test_merge_config(): cfg_a = { "a": 42, "b": [1, 2, 3], "c": None, "d": {"gheez": 27}, "e": { "ea": "Mr. Bungle", "eb": None, "ec": [11, 12, 13], "ed": {"eda": "Secret Chief 3", "edb": "Faith No More"}, "ee": 451, }, "f": "Janis", } cfg_b = { "a": 43, "b": [41, 42, 43], "c": "Tom Waits", "d": None, "e": { "ea": "Igorrr", "ec": [51, 52], "ed": {"edb": "Sleepytime Gorilla Museum", "edc": "Nils Peter Molvaer"}, }, "g": "Hüsker Dü", } # merge A, B cfg_m = config.merge_configs(cfg_a, cfg_b) assert cfg_m == { "a": 43, # b takes precedence "b": [41, 42, 43], # b takes precedence "c": "Tom Waits", # b takes precedence "d": None, # b['d'] takes precedence (explicit None) "e": { "ea": "Igorrr", # a takes precedence "eb": None, # only in a "ec": [51, 52], # b takes precedence "ed": { "eda": "Secret Chief 3", # only in a "edb": "Sleepytime Gorilla Museum", # b takes precedence "edc": "Nils Peter Molvaer", }, # only defined in b "ee": 451, }, "f": "Janis", # only defined in a "g": "Hüsker Dü", # only defined in b } # merge B, A cfg_m = config.merge_configs(cfg_b, cfg_a) assert cfg_m == { "a": 42, # a takes precedence "b": [1, 2, 3], # a takes precedence "c": None, # a takes precedence "d": {"gheez": 27}, # a takes precedence "e": { "ea": "Mr. Bungle", # a takes precedence "eb": None, # only defined in a "ec": [11, 12, 13], # a takes precedence "ed": { "eda": "Secret Chief 3", # only in a "edb": "Faith No More", # a takes precedence "edc": "Nils Peter Molvaer", }, # only in b "ee": 451, }, "f": "Janis", # only in a "g": "Hüsker Dü", # only in b } def test_merge_config_type_error(): for v in (1, "str", None): with pytest.raises(TypeError): config.merge_configs(v, {}) with pytest.raises(TypeError): config.merge_configs({}, v) for v in (1, "str"): with pytest.raises(TypeError): config.merge_configs({"a": v}, {"a": {}}) with pytest.raises(TypeError): config.merge_configs({"a": {}}, {"a": v}) def test_load_from_envvar_no_environment_var_swh_config_filename_set(): """Without SWH_CONFIG_FILENAME set, load_from_envvar raises""" with pytest.raises(AssertionError, match="SWH_CONFIG_FILENAME environment"): config.load_from_envvar() def test_load_from_envvar_no_default_config(swh_config, monkeypatch): config_path = str(swh_config) monkeypatch.setenv("SWH_CONFIG_FILENAME", config_path) actual_config = config.load_from_envvar() expected_config = config.read(config_path) assert actual_config == expected_config def test_load_from_envvar_with_default_config(swh_config, monkeypatch): default_config = { "number": 666, "something-cool": ["something", "cool"], } config_path = str(swh_config) monkeypatch.setenv("SWH_CONFIG_FILENAME", config_path) actual_config = config.load_from_envvar(default_config) expected_config = config.read(config_path) expected_config.update( {"number": 666, "something-cool": ["something", "cool"],} ) assert actual_config == expected_config