diff --git a/Makefile.local b/Makefile.local index 35f062f..cc21b3f 100644 --- a/Makefile.local +++ b/Makefile.local @@ -1 +1 @@ -TEST_DIRS := ./swh/core/tests +TEST_DIRS := ./swh/core/api/tests ./swh/core/db/tests ./swh/core/tests diff --git a/PKG-INFO b/PKG-INFO index 83b6a25..1f73769 100644 --- a/PKG-INFO +++ b/PKG-INFO @@ -1,93 +1,93 @@ Metadata-Version: 2.1 Name: swh.core -Version: 0.8.0 +Version: 0.9.0 Summary: Software Heritage core utilities Home-page: https://forge.softwareheritage.org/diffusion/DCORE/ Author: Software Heritage developers Author-email: swh-devel@inria.fr License: UNKNOWN Project-URL: Bug Reports, https://forge.softwareheritage.org/maniphest Project-URL: Funding, https://www.softwareheritage.org/donate Project-URL: Source, https://forge.softwareheritage.org/source/swh-core Project-URL: Documentation, https://docs.softwareheritage.org/devel/swh-core/ Description: swh-core ======== core library for swh's modules: - config parser - hash computations - serialization - logging mechanism - database connection - http-based RPC client/server Development ----------- We strongly recommend you to use a [virtualenv][1] if you want to run tests or hack the code. To set up your development environment: ``` (swh) user@host:~/swh-environment/swh-core$ pip install -e .[testing] ``` This will install every Python package needed to run this package's tests. Unit tests can be executed using [pytest][2] or [tox][3]. ``` (swh) user@host:~/swh-environment/swh-core$ pytest ============================== test session starts ============================== platform linux -- Python 3.7.3, pytest-3.10.1, py-1.8.0, pluggy-0.12.0 hypothesis profile 'default' -> database=DirectoryBasedExampleDatabase('/home/ddouard/src/swh-environment/swh-core/.hypothesis/examples') rootdir: /home/ddouard/src/swh-environment/swh-core, inifile: pytest.ini plugins: requests-mock-1.6.0, hypothesis-4.26.4, celery-4.3.0, postgresql-1.4.1 collected 89 items swh/core/api/tests/test_api.py .. [ 2%] swh/core/api/tests/test_async.py .... [ 6%] swh/core/api/tests/test_serializers.py ..... [ 12%] swh/core/db/tests/test_db.py .... [ 16%] swh/core/tests/test_cli.py ...... [ 23%] swh/core/tests/test_config.py .............. [ 39%] swh/core/tests/test_statsd.py ........................................... [ 87%] .... [ 92%] swh/core/tests/test_utils.py ....... [100%] ===================== 89 passed, 9 warnings in 6.94 seconds ===================== ``` Note: this git repository uses [pre-commit][4] hooks to ensure better and more consistent code. It should already be installed in your virtualenv (if not, just type `pip install pre-commit`). Make sure to activate it in your local copy of the git repository: ``` (swh) user@host:~/swh-environment/swh-core$ pre-commit install pre-commit installed at .git/hooks/pre-commit ``` Please read the [developer setup manual][5] for more information on how to hack on Software Heritage. [1]: https://virtualenv.pypa.io [2]: https://docs.pytest.org [3]: https://tox.readthedocs.io [4]: https://pre-commit.com [5]: https://docs.softwareheritage.org/devel/developer-setup.html Platform: UNKNOWN Classifier: Programming Language :: Python :: 3 Classifier: Intended Audience :: Developers Classifier: License :: OSI Approved :: GNU General Public License v3 (GPLv3) Classifier: Operating System :: OS Independent Classifier: Development Status :: 5 - Production/Stable Requires-Python: >=3.7 Description-Content-Type: text/markdown Provides-Extra: testing-core Provides-Extra: logging Provides-Extra: db Provides-Extra: testing-db Provides-Extra: http Provides-Extra: testing diff --git a/requirements-db-pytestplugin.txt b/requirements-db-pytestplugin.txt new file mode 100644 index 0000000..54e57b4 --- /dev/null +++ b/requirements-db-pytestplugin.txt @@ -0,0 +1,2 @@ +# requirements for swh.core.db.pytest_plugin +pytest-postgresql diff --git a/requirements-db.txt b/requirements-db.txt index d0f0975..921e04d 100644 --- a/requirements-db.txt +++ b/requirements-db.txt @@ -1,4 +1,3 @@ # requirements for swh.core.db psycopg2 typing-extensions -pytest-postgresql diff --git a/setup.py b/setup.py index 54518ee..4f0555e 100755 --- a/setup.py +++ b/setup.py @@ -1,87 +1,89 @@ #!/usr/bin/env python3 # Copyright (C) 2015-2018 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from io import open import os from os import path from setuptools import find_packages, setup here = path.abspath(path.dirname(__file__)) # Get the long description from the README file with open(path.join(here, "README.md"), encoding="utf-8") as f: long_description = f.read() def parse_requirements(*names): requirements = [] for name in names: if name: reqf = "requirements-%s.txt" % name else: reqf = "requirements.txt" if not os.path.exists(reqf): return requirements with open(reqf) as f: for line in f.readlines(): line = line.strip() if not line or line.startswith("#"): continue requirements.append(line) return requirements setup( name="swh.core", description="Software Heritage core utilities", long_description=long_description, long_description_content_type="text/markdown", python_requires=">=3.7", author="Software Heritage developers", author_email="swh-devel@inria.fr", url="https://forge.softwareheritage.org/diffusion/DCORE/", packages=find_packages(), py_modules=["pytest_swh_core"], scripts=[], install_requires=parse_requirements(None, "swh"), setup_requires=["setuptools-scm"], use_scm_version=True, extras_require={ "testing-core": parse_requirements("test"), "logging": parse_requirements("logging"), - "db": parse_requirements("db"), + "db": parse_requirements("db", "db-pytestplugin"), "testing-db": parse_requirements("test-db"), "http": parse_requirements("http"), # kitchen sink, please do not use - "testing": parse_requirements("test", "test-db", "db", "http", "logging"), + "testing": parse_requirements( + "test", "test-db", "db", "db-pytestplugin", "http", "logging" + ), }, include_package_data=True, entry_points=""" [console_scripts] swh=swh.core.cli:main swh-db-init=swh.core.cli.db:db_init [swh.cli.subcommands] db=swh.core.cli.db [pytest11] pytest_swh_core = swh.core.pytest_plugin """, classifiers=[ "Programming Language :: Python :: 3", "Intended Audience :: Developers", "License :: OSI Approved :: GNU General Public License v3 (GPLv3)", "Operating System :: OS Independent", "Development Status :: 5 - Production/Stable", ], project_urls={ "Bug Reports": "https://forge.softwareheritage.org/maniphest", "Funding": "https://www.softwareheritage.org/donate", "Source": "https://forge.softwareheritage.org/source/swh-core", "Documentation": "https://docs.softwareheritage.org/devel/swh-core/", }, ) diff --git a/swh.core.egg-info/PKG-INFO b/swh.core.egg-info/PKG-INFO index 83b6a25..1f73769 100644 --- a/swh.core.egg-info/PKG-INFO +++ b/swh.core.egg-info/PKG-INFO @@ -1,93 +1,93 @@ Metadata-Version: 2.1 Name: swh.core -Version: 0.8.0 +Version: 0.9.0 Summary: Software Heritage core utilities Home-page: https://forge.softwareheritage.org/diffusion/DCORE/ Author: Software Heritage developers Author-email: swh-devel@inria.fr License: UNKNOWN Project-URL: Bug Reports, https://forge.softwareheritage.org/maniphest Project-URL: Funding, https://www.softwareheritage.org/donate Project-URL: Source, https://forge.softwareheritage.org/source/swh-core Project-URL: Documentation, https://docs.softwareheritage.org/devel/swh-core/ Description: swh-core ======== core library for swh's modules: - config parser - hash computations - serialization - logging mechanism - database connection - http-based RPC client/server Development ----------- We strongly recommend you to use a [virtualenv][1] if you want to run tests or hack the code. To set up your development environment: ``` (swh) user@host:~/swh-environment/swh-core$ pip install -e .[testing] ``` This will install every Python package needed to run this package's tests. Unit tests can be executed using [pytest][2] or [tox][3]. ``` (swh) user@host:~/swh-environment/swh-core$ pytest ============================== test session starts ============================== platform linux -- Python 3.7.3, pytest-3.10.1, py-1.8.0, pluggy-0.12.0 hypothesis profile 'default' -> database=DirectoryBasedExampleDatabase('/home/ddouard/src/swh-environment/swh-core/.hypothesis/examples') rootdir: /home/ddouard/src/swh-environment/swh-core, inifile: pytest.ini plugins: requests-mock-1.6.0, hypothesis-4.26.4, celery-4.3.0, postgresql-1.4.1 collected 89 items swh/core/api/tests/test_api.py .. [ 2%] swh/core/api/tests/test_async.py .... [ 6%] swh/core/api/tests/test_serializers.py ..... [ 12%] swh/core/db/tests/test_db.py .... [ 16%] swh/core/tests/test_cli.py ...... [ 23%] swh/core/tests/test_config.py .............. [ 39%] swh/core/tests/test_statsd.py ........................................... [ 87%] .... [ 92%] swh/core/tests/test_utils.py ....... [100%] ===================== 89 passed, 9 warnings in 6.94 seconds ===================== ``` Note: this git repository uses [pre-commit][4] hooks to ensure better and more consistent code. It should already be installed in your virtualenv (if not, just type `pip install pre-commit`). Make sure to activate it in your local copy of the git repository: ``` (swh) user@host:~/swh-environment/swh-core$ pre-commit install pre-commit installed at .git/hooks/pre-commit ``` Please read the [developer setup manual][5] for more information on how to hack on Software Heritage. [1]: https://virtualenv.pypa.io [2]: https://docs.pytest.org [3]: https://tox.readthedocs.io [4]: https://pre-commit.com [5]: https://docs.softwareheritage.org/devel/developer-setup.html Platform: UNKNOWN Classifier: Programming Language :: Python :: 3 Classifier: Intended Audience :: Developers Classifier: License :: OSI Approved :: GNU General Public License v3 (GPLv3) Classifier: Operating System :: OS Independent Classifier: Development Status :: 5 - Production/Stable Requires-Python: >=3.7 Description-Content-Type: text/markdown Provides-Extra: testing-core Provides-Extra: logging Provides-Extra: db Provides-Extra: testing-db Provides-Extra: http Provides-Extra: testing diff --git a/swh.core.egg-info/SOURCES.txt b/swh.core.egg-info/SOURCES.txt index a3abb0b..e7736df 100644 --- a/swh.core.egg-info/SOURCES.txt +++ b/swh.core.egg-info/SOURCES.txt @@ -1,109 +1,115 @@ .gitignore .pre-commit-config.yaml AUTHORS CODE_OF_CONDUCT.md CONTRIBUTORS LICENSE MANIFEST.in Makefile Makefile.local README.md conftest.py mypy.ini pyproject.toml pytest.ini +requirements-db-pytestplugin.txt requirements-db.txt requirements-http.txt requirements-logging.txt requirements-swh.txt requirements-test-db.txt requirements-test.txt requirements.txt setup.cfg setup.py tox.ini docs/.gitignore docs/Makefile docs/cli.rst docs/conf.py docs/index.rst docs/_static/.placeholder docs/_templates/.placeholder swh/__init__.py swh.core.egg-info/PKG-INFO swh.core.egg-info/SOURCES.txt swh.core.egg-info/dependency_links.txt swh.core.egg-info/entry_points.txt swh.core.egg-info/requires.txt swh.core.egg-info/top_level.txt swh/core/__init__.py swh/core/api_async.py swh/core/collections.py swh/core/config.py swh/core/logger.py swh/core/py.typed swh/core/pytest_plugin.py swh/core/sentry.py swh/core/statsd.py swh/core/tarball.py swh/core/utils.py swh/core/api/__init__.py swh/core/api/asynchronous.py swh/core/api/classes.py swh/core/api/gunicorn_config.py swh/core/api/negotiation.py swh/core/api/serializers.py swh/core/api/tests/__init__.py swh/core/api/tests/conftest.py swh/core/api/tests/server_testing.py swh/core/api/tests/test_async.py swh/core/api/tests/test_classes.py swh/core/api/tests/test_gunicorn.py swh/core/api/tests/test_init.py swh/core/api/tests/test_rpc_client.py swh/core/api/tests/test_rpc_client_server.py swh/core/api/tests/test_rpc_server.py swh/core/api/tests/test_rpc_server_asynchronous.py swh/core/api/tests/test_serializers.py swh/core/cli/__init__.py swh/core/cli/db.py swh/core/db/__init__.py swh/core/db/common.py swh/core/db/db_utils.py swh/core/db/pytest_plugin.py swh/core/db/tests/__init__.py swh/core/db/tests/conftest.py swh/core/db/tests/db_testing.py swh/core/db/tests/test_cli.py swh/core/db/tests/test_db.py -swh/core/db/tests/test_db_utils.py -swh/core/db/tests/data/0-schema.sql -swh/core/db/tests/data/1-data.sql +swh/core/db/tests/data/cli/0-superuser-init.sql +swh/core/db/tests/data/cli/1-schema.sql +swh/core/db/tests/data/cli/3-func.sql +swh/core/db/tests/data/cli/4-data.sql +swh/core/db/tests/pytest_plugin/__init__.py +swh/core/db/tests/pytest_plugin/test_pytest_plugin.py +swh/core/db/tests/pytest_plugin/data/0-schema.sql +swh/core/db/tests/pytest_plugin/data/1-data.sql swh/core/sql/log-schema.sql swh/core/tests/__init__.py swh/core/tests/test_cli.py swh/core/tests/test_collections.py swh/core/tests/test_config.py swh/core/tests/test_logger.py swh/core/tests/test_pytest_plugin.py swh/core/tests/test_statsd.py swh/core/tests/test_tarball.py swh/core/tests/test_utils.py swh/core/tests/data/archives/groff-1.02.tar.Z swh/core/tests/data/archives/hello.tar swh/core/tests/data/archives/hello.tar.bz2 swh/core/tests/data/archives/hello.tar.gz swh/core/tests/data/archives/hello.tar.lz swh/core/tests/data/archives/hello.tar.x swh/core/tests/data/archives/hello.zip swh/core/tests/data/http_example.com/something.json swh/core/tests/data/https_example.com/file.json swh/core/tests/data/https_example.com/file.json,name=doe,firstname=jane swh/core/tests/data/https_example.com/file.json_visit1 swh/core/tests/data/https_example.com/other.json swh/core/tests/data/https_forge.s.o/api_diffusion,attachments[uris]=1 swh/core/tests/data/https_www.reference.com/web,q=What+Is+an+Example+of+a+URL?,qo=contentPageRelatedSearch,o=600605,l=dir,sga=1 swh/core/tests/fixture/__init__.py swh/core/tests/fixture/conftest.py swh/core/tests/fixture/test_pytest_plugin.py swh/core/tests/fixture/data/https_example.com/file.json \ No newline at end of file diff --git a/swh/core/api/__init__.py b/swh/core/api/__init__.py index 8ffc3b0..832a2ee 100644 --- a/swh/core/api/__init__.py +++ b/swh/core/api/__init__.py @@ -1,456 +1,465 @@ # Copyright (C) 2015-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from collections import abc import functools import inspect import logging import pickle from typing import ( Any, Callable, ClassVar, Dict, List, Optional, Tuple, Type, TypeVar, Union, ) from flask import Flask, Request, Response, abort, request import requests from werkzeug.exceptions import HTTPException from .negotiation import Formatter as FormatterBase from .negotiation import Negotiator as NegotiatorBase from .negotiation import negotiate as _negotiate from .serializers import ( exception_to_dict, json_dumps, json_loads, msgpack_dumps, msgpack_loads, ) from .serializers import decode_response from .serializers import encode_data_client as encode_data logger = logging.getLogger(__name__) # support for content negotiation class Negotiator(NegotiatorBase): def best_mimetype(self): return request.accept_mimetypes.best_match( self.accept_mimetypes, "application/json" ) def _abort(self, status_code, err=None): return abort(status_code, err) def negotiate(formatter_cls, *args, **kwargs): return _negotiate(Negotiator, formatter_cls, *args, **kwargs) class Formatter(FormatterBase): def _make_response(self, body, content_type): return Response(body, content_type=content_type) def configure(self, extra_encoders=None): self.extra_encoders = extra_encoders class JSONFormatter(Formatter): format = "json" mimetypes = ["application/json"] def render(self, obj): return json_dumps(obj, extra_encoders=self.extra_encoders) class MsgpackFormatter(Formatter): format = "msgpack" mimetypes = ["application/x-msgpack"] def render(self, obj): return msgpack_dumps(obj, extra_encoders=self.extra_encoders) # base API classes class RemoteException(Exception): """raised when remote returned an out-of-band failure notification, e.g., as a HTTP status code or serialized exception Attributes: response: HTTP response corresponding to the failure """ def __init__( self, payload: Optional[Any] = None, response: Optional[requests.Response] = None, ): if payload is not None: super().__init__(payload) else: super().__init__() self.response = response def __str__(self): if ( self.args and isinstance(self.args[0], dict) and "type" in self.args[0] and "args" in self.args[0] ): return ( f"' ) else: return super().__str__() F = TypeVar("F", bound=Callable) def remote_api_endpoint(path: str, method: str = "POST") -> Callable[[F], F]: def dec(f: F) -> F: f._endpoint_path = path # type: ignore f._method = method # type: ignore return f return dec class APIError(Exception): """API Error""" def __str__(self): return "An unexpected error occurred in the backend: {}".format(self.args) class MetaRPCClient(type): """Metaclass for RPCClient, which adds a method for each endpoint of the database it is designed to access. See for example :class:`swh.indexer.storage.api.client.RemoteStorage`""" def __new__(cls, name, bases, attributes): # For each method wrapped with @remote_api_endpoint in an API backend # (eg. :class:`swh.indexer.storage.IndexerStorage`), add a new # method in RemoteStorage, with the same documentation. # # Note that, despite the usage of decorator magic (eg. functools.wrap), # this never actually calls an IndexerStorage method. backend_class = attributes.get("backend_class", None) for base in bases: if backend_class is not None: break backend_class = getattr(base, "backend_class", None) if backend_class: for (meth_name, meth) in backend_class.__dict__.items(): if hasattr(meth, "_endpoint_path"): cls.__add_endpoint(meth_name, meth, attributes) return super().__new__(cls, name, bases, attributes) @staticmethod def __add_endpoint(meth_name, meth, attributes): wrapped_meth = inspect.unwrap(meth) @functools.wraps(meth) # Copy signature and doc def meth_(*args, **kwargs): # Match arguments and parameters post_data = inspect.getcallargs(wrapped_meth, *args, **kwargs) # Remove arguments that should not be passed self = post_data.pop("self") post_data.pop("cur", None) post_data.pop("db", None) # Send the request. return self.post(meth._endpoint_path, post_data) if meth_name not in attributes: attributes[meth_name] = meth_ class RPCClient(metaclass=MetaRPCClient): """Proxy to an internal SWH RPC """ backend_class = None # type: ClassVar[Optional[type]] """For each method of `backend_class` decorated with :func:`remote_api_endpoint`, a method with the same prototype and docstring will be added to this class. Calls to this new method will be translated into HTTP requests to a remote server. This backend class will never be instantiated, it only serves as a template.""" api_exception = APIError # type: ClassVar[Type[Exception]] """The exception class to raise in case of communication error with the server.""" reraise_exceptions: ClassVar[List[Type[Exception]]] = [] """On server errors, if any of the exception classes in this list has the same name as the error name, then the exception will be instantiated and raised instead of a generic RemoteException.""" extra_type_encoders: List[Tuple[type, str, Callable]] = [] """Value of `extra_encoders` passed to `json_dumps` or `msgpack_dumps` to be able to serialize more object types.""" extra_type_decoders: Dict[str, Callable] = {} """Value of `extra_decoders` passed to `json_loads` or `msgpack_loads` to be able to deserialize more object types.""" def __init__( self, url, api_exception=None, timeout=None, chunk_size=4096, reraise_exceptions=None, **kwargs, ): if api_exception: self.api_exception = api_exception if reraise_exceptions: self.reraise_exceptions = reraise_exceptions base_url = url if url.endswith("/") else url + "/" self.url = base_url self.session = requests.Session() adapter = requests.adapters.HTTPAdapter( max_retries=kwargs.get("max_retries", 3), pool_connections=kwargs.get("pool_connections", 20), pool_maxsize=kwargs.get("pool_maxsize", 100), ) self.session.mount(self.url, adapter) self.timeout = timeout self.chunk_size = chunk_size def _url(self, endpoint): return "%s%s" % (self.url, endpoint) def raw_verb(self, verb, endpoint, **opts): if "chunk_size" in opts: # if the chunk_size argument has been passed, consider the user # also wants stream=True, otherwise, what's the point. opts["stream"] = True if self.timeout and "timeout" not in opts: opts["timeout"] = self.timeout try: return getattr(self.session, verb)(self._url(endpoint), **opts) except requests.exceptions.ConnectionError as e: raise self.api_exception(e) def post(self, endpoint, data, **opts): if isinstance(data, (abc.Iterator, abc.Generator)): data = (self._encode_data(x) for x in data) else: data = self._encode_data(data) chunk_size = opts.pop("chunk_size", self.chunk_size) response = self.raw_verb( "post", endpoint, data=data, headers={ "content-type": "application/x-msgpack", "accept": "application/x-msgpack", }, **opts, ) if opts.get("stream") or response.headers.get("transfer-encoding") == "chunked": self.raise_for_status(response) return response.iter_content(chunk_size) else: return self._decode_response(response) def _encode_data(self, data): return encode_data(data, extra_encoders=self.extra_type_encoders) post_stream = post def get(self, endpoint, **opts): chunk_size = opts.pop("chunk_size", self.chunk_size) response = self.raw_verb( "get", endpoint, headers={"accept": "application/x-msgpack"}, **opts ) if opts.get("stream") or response.headers.get("transfer-encoding") == "chunked": self.raise_for_status(response) return response.iter_content(chunk_size) else: return self._decode_response(response) def get_stream(self, endpoint, **opts): return self.get(endpoint, stream=True, **opts) def raise_for_status(self, response) -> None: """check response HTTP status code and raise an exception if it denotes an error; do nothing otherwise """ status_code = response.status_code status_class = response.status_code // 100 if status_code == 404: raise RemoteException(payload="404 not found", response=response) exception = None # TODO: only old servers send pickled error; stop trying to unpickle # after they are all upgraded try: if status_class == 4: data = self._decode_response(response, check_status=False) if isinstance(data, dict): for exc_type in self.reraise_exceptions: - if exc_type.__name__ == data["exception"]["type"]: - exception = exc_type(*data["exception"]["args"]) + if exc_type.__name__ == data["type"]: + exception = exc_type(*data["args"]) break else: - exception = RemoteException( - payload=data["exception"], response=response - ) + # old dict encoded exception schema + # TODO: Remove that code once all servers are using new schema + if "exception" in data: + exception = RemoteException( + payload=data["exception"], response=response + ) + else: + exception = RemoteException(payload=data, response=response) else: exception = pickle.loads(data) elif status_class == 5: data = self._decode_response(response, check_status=False) if "exception_pickled" in data: exception = pickle.loads(data["exception_pickled"]) - else: + # old dict encoded exception schema + # TODO: Remove that code once all servers are using new schema + elif "exception" in data: exception = RemoteException( payload=data["exception"], response=response ) + else: + exception = RemoteException(payload=data, response=response) except (TypeError, pickle.UnpicklingError): raise RemoteException(payload=data, response=response) if exception: raise exception from None if status_class != 2: raise RemoteException( payload=f"API HTTP error: {status_code} {response.content}", response=response, ) def _decode_response(self, response, check_status=True): if check_status: self.raise_for_status(response) return decode_response(response, extra_decoders=self.extra_type_decoders) def __repr__(self): return "<{} url={}>".format(self.__class__.__name__, self.url) class BytesRequest(Request): """Request with proper escaping of arbitrary byte sequences.""" encoding = "utf-8" encoding_errors = "surrogateescape" ENCODERS: Dict[str, Callable[[Any], Union[bytes, str]]] = { "application/x-msgpack": msgpack_dumps, "application/json": json_dumps, } def encode_data_server( data, content_type="application/x-msgpack", extra_type_encoders=None ): encoded_data = ENCODERS[content_type](data, extra_encoders=extra_type_encoders) return Response(encoded_data, mimetype=content_type,) def decode_request(request, extra_decoders=None): content_type = request.mimetype data = request.get_data() if not data: return {} if content_type == "application/x-msgpack": r = msgpack_loads(data, extra_decoders=extra_decoders) elif content_type == "application/json": # XXX this .decode() is needed for py35. # Should not be needed any more with py37 r = json_loads(data.decode("utf-8"), extra_decoders=extra_decoders) else: raise ValueError("Wrong content type `%s` for API request" % content_type) return r def error_handler(exception, encoder, status_code=500): logging.exception(exception) response = encoder(exception_to_dict(exception)) if isinstance(exception, HTTPException): response.status_code = exception.code else: # TODO: differentiate between server errors and client errors response.status_code = status_code return response class RPCServerApp(Flask): """For each endpoint of the given `backend_class`, tells app.route to call a function that decodes the request and sends it to the backend object provided by the factory. :param Any backend_class: The class of the backend, which will be analyzed to look for API endpoints. :param Optional[Callable[[], backend_class]] backend_factory: A function with no argument that returns an instance of `backend_class`. If unset, defaults to calling `backend_class` constructor directly. """ request_class = BytesRequest extra_type_encoders: List[Tuple[type, str, Callable]] = [] """Value of `extra_encoders` passed to `json_dumps` or `msgpack_dumps` to be able to serialize more object types.""" extra_type_decoders: Dict[str, Callable] = {} """Value of `extra_decoders` passed to `json_loads` or `msgpack_loads` to be able to deserialize more object types.""" def __init__(self, *args, backend_class=None, backend_factory=None, **kwargs): super().__init__(*args, **kwargs) if backend_class is None and backend_factory is not None: raise ValueError( "backend_factory should only be provided if backend_class is" ) self.backend_class = backend_class if backend_class is not None: backend_factory = backend_factory or backend_class for (meth_name, meth) in backend_class.__dict__.items(): if hasattr(meth, "_endpoint_path"): self.__add_endpoint(meth_name, meth, backend_factory) def __add_endpoint(self, meth_name, meth, backend_factory): from flask import request @self.route("/" + meth._endpoint_path, methods=["POST"]) @negotiate(MsgpackFormatter, extra_encoders=self.extra_type_encoders) @negotiate(JSONFormatter, extra_encoders=self.extra_type_encoders) @functools.wraps(meth) # Copy signature and doc def _f(): # Call the actual code obj_meth = getattr(backend_factory(), meth_name) kw = decode_request(request, extra_decoders=self.extra_type_decoders) return obj_meth(**kw) diff --git a/swh/core/api/serializers.py b/swh/core/api/serializers.py index 3b6c139..16298c2 100644 --- a/swh/core/api/serializers.py +++ b/swh/core/api/serializers.py @@ -1,293 +1,300 @@ # Copyright (C) 2015-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import base64 import datetime from enum import Enum import json import traceback import types from typing import Any, Dict, Tuple, Union from uuid import UUID import arrow import iso8601 import msgpack from requests import Response from swh.core.api.classes import PagedResult def encode_datetime(dt: datetime.datetime) -> str: """Wrapper of datetime.datetime.isoformat() that forbids naive datetimes.""" if dt.tzinfo is None: raise ValueError(f"{dt} is a naive datetime.") return dt.isoformat() def _encode_paged_result(obj: PagedResult) -> Dict[str, Any]: """Serialize PagedResult to a Dict.""" return { "results": obj.results, "next_page_token": obj.next_page_token, } def _decode_paged_result(obj: Dict[str, Any]) -> PagedResult: """Deserialize Dict into PagedResult""" return PagedResult(results=obj["results"], next_page_token=obj["next_page_token"],) +def exception_to_dict(exception: Exception) -> Dict[str, Any]: + tb = traceback.format_exception(None, exception, exception.__traceback__) + exc_type = type(exception) + return { + "type": exc_type.__name__, + "module": exc_type.__module__, + "args": exception.args, + "message": str(exception), + "traceback": tb, + } + + +def dict_to_exception(exc_dict: Dict[str, Any]) -> Exception: + temp = __import__(exc_dict["module"], fromlist=[exc_dict["type"]]) + return getattr(temp, exc_dict["type"])(*exc_dict["args"]) + + ENCODERS = [ (arrow.Arrow, "arrow", arrow.Arrow.isoformat), (datetime.datetime, "datetime", encode_datetime), ( datetime.timedelta, "timedelta", lambda o: { "days": o.days, "seconds": o.seconds, "microseconds": o.microseconds, }, ), (UUID, "uuid", str), (PagedResult, "paged_result", _encode_paged_result), # Only for JSON: (bytes, "bytes", lambda o: base64.b85encode(o).decode("ascii")), + (Exception, "exception", exception_to_dict), ] DECODERS = { "arrow": arrow.get, "datetime": lambda d: iso8601.parse_date(d, default_timezone=None), "timedelta": lambda d: datetime.timedelta(**d), "uuid": UUID, "paged_result": _decode_paged_result, # Only for JSON: "bytes": base64.b85decode, + "exception": dict_to_exception, } class MsgpackExtTypeCodes(Enum): LONG_INT = 1 def encode_data_client(data: Any, extra_encoders=None) -> bytes: try: return msgpack_dumps(data, extra_encoders=extra_encoders) except OverflowError as e: raise ValueError("Limits were reached. Please, check your input.\n" + str(e)) def decode_response(response: Response, extra_decoders=None) -> Any: content_type = response.headers["content-type"] if content_type.startswith("application/x-msgpack"): r = msgpack_loads(response.content, extra_decoders=extra_decoders) elif content_type.startswith("application/json"): r = json_loads(response.text, extra_decoders=extra_decoders) elif content_type.startswith("text/"): r = response.text else: raise ValueError("Wrong content type `%s` for API response" % content_type) return r class SWHJSONEncoder(json.JSONEncoder): """JSON encoder for data structures generated by Software Heritage. This JSON encoder extends the default Python JSON encoder and adds awareness for the following specific types: - bytes (get encoded as a Base85 string); - datetime.datetime (get encoded as an ISO8601 string). Non-standard types get encoded as a a dictionary with two keys: - swhtype with value 'bytes' or 'datetime'; - d containing the encoded value. SWHJSONEncoder also encodes arbitrary iterables as a list (allowing serialization of generators). Caveats: Limitations in the JSONEncoder extension mechanism prevent us from "escaping" dictionaries that only contain the swhtype and d keys, and therefore arbitrary data structures can't be round-tripped through SWHJSONEncoder and SWHJSONDecoder. """ def __init__(self, extra_encoders=None, **kwargs): super().__init__(**kwargs) self.encoders = ENCODERS if extra_encoders: self.encoders += extra_encoders def default(self, o: Any) -> Union[Dict[str, Union[Dict[str, int], str]], list]: for (type_, type_name, encoder) in self.encoders: if isinstance(o, type_): return { "swhtype": type_name, "d": encoder(o), } try: return super().default(o) except TypeError as e: try: iterable = iter(o) except TypeError: raise e from None else: return list(iterable) class SWHJSONDecoder(json.JSONDecoder): """JSON decoder for data structures encoded with SWHJSONEncoder. This JSON decoder extends the default Python JSON decoder, allowing the decoding of: - bytes (encoded as a Base85 string); - datetime.datetime (encoded as an ISO8601 string). Non-standard types must be encoded as a a dictionary with exactly two keys: - swhtype with value 'bytes' or 'datetime'; - d containing the encoded value. To limit the impact our encoding, if the swhtype key doesn't contain a known value, the dictionary is decoded as-is. """ def __init__(self, extra_decoders=None, **kwargs): super().__init__(**kwargs) self.decoders = DECODERS if extra_decoders: self.decoders = {**self.decoders, **extra_decoders} def decode_data(self, o: Any) -> Any: if isinstance(o, dict): if set(o.keys()) == {"d", "swhtype"}: if o["swhtype"] == "bytes": return base64.b85decode(o["d"]) decoder = self.decoders.get(o["swhtype"]) if decoder: return decoder(self.decode_data(o["d"])) return {key: self.decode_data(value) for key, value in o.items()} if isinstance(o, list): return [self.decode_data(value) for value in o] else: return o def raw_decode(self, s: str, idx: int = 0) -> Tuple[Any, int]: data, index = super().raw_decode(s, idx) return self.decode_data(data), index def json_dumps(data: Any, extra_encoders=None) -> str: return json.dumps(data, cls=SWHJSONEncoder, extra_encoders=extra_encoders) def json_loads(data: str, extra_decoders=None) -> Any: return json.loads(data, cls=SWHJSONDecoder, extra_decoders=extra_decoders) def msgpack_dumps(data: Any, extra_encoders=None) -> bytes: """Write data as a msgpack stream""" encoders = ENCODERS if extra_encoders: encoders += extra_encoders def encode_types(obj): if isinstance(obj, int): # integer overflowed while packing. Handle it as an extended type length, rem = divmod(obj.bit_length(), 8) if rem: length += 1 return msgpack.ExtType( MsgpackExtTypeCodes.LONG_INT.value, int.to_bytes(obj, length, "big") ) if isinstance(obj, types.GeneratorType): return list(obj) for (type_, type_name, encoder) in encoders: if isinstance(obj, type_): return { b"swhtype": type_name, b"d": encoder(obj), } return obj return msgpack.packb(data, use_bin_type=True, default=encode_types) def msgpack_loads(data: bytes, extra_decoders=None) -> Any: """Read data as a msgpack stream. .. Caution:: This function is used by swh.journal to decode the contents of the journal. This function **must** be kept backwards-compatible. """ decoders = DECODERS if extra_decoders: decoders = {**decoders, **extra_decoders} def ext_hook(code, data): if code == MsgpackExtTypeCodes.LONG_INT.value: return int.from_bytes(data, "big") raise ValueError("Unknown msgpack extended code %s" % code) def decode_types(obj): # Support for current encodings if set(obj.keys()) == {b"d", b"swhtype"}: decoder = decoders.get(obj[b"swhtype"]) if decoder: return decoder(obj[b"d"]) # Support for legacy encodings if b"__datetime__" in obj and obj[b"__datetime__"]: return iso8601.parse_date(obj[b"s"], default_timezone=None) if b"__uuid__" in obj and obj[b"__uuid__"]: return UUID(obj[b"s"]) if b"__timedelta__" in obj and obj[b"__timedelta__"]: return datetime.timedelta(**obj[b"s"]) if b"__arrow__" in obj and obj[b"__arrow__"]: return arrow.get(obj[b"s"]) # Fallthrough return obj try: try: return msgpack.unpackb( data, raw=False, object_hook=decode_types, ext_hook=ext_hook, strict_map_key=False, ) except TypeError: # msgpack < 0.6.0 return msgpack.unpackb( data, raw=False, object_hook=decode_types, ext_hook=ext_hook ) except TypeError: # msgpack < 0.5.2 return msgpack.unpackb( data, encoding="utf-8", object_hook=decode_types, ext_hook=ext_hook ) - - -def exception_to_dict(exception): - tb = traceback.format_exception(None, exception, exception.__traceback__) - return { - "exception": { - "type": type(exception).__name__, - "args": exception.args, - "message": str(exception), - "traceback": tb, - } - } diff --git a/swh/core/api/tests/test_async.py b/swh/core/api/tests/test_async.py index dafbb81..1fad189 100644 --- a/swh/core/api/tests/test_async.py +++ b/swh/core/api/tests/test_async.py @@ -1,256 +1,256 @@ # Copyright (C) 2019-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import datetime import json import msgpack import pytest from swh.core.api.asynchronous import ( Response, RPCServerApp, decode_data, decode_request, encode_msgpack, ) from swh.core.api.serializers import SWHJSONEncoder, json_dumps, msgpack_dumps pytest_plugins = ["aiohttp.pytest_plugin", "pytester"] class TestServerException(Exception): pass class TestClientError(Exception): pass async def root(request): return Response("toor") STRUCT = { "txt": "something stupid", # 'date': datetime.date(2019, 6, 9), # not supported "datetime": datetime.datetime(2019, 6, 9, 10, 12, tzinfo=datetime.timezone.utc), "timedelta": datetime.timedelta(days=-2, hours=3), "int": 42, "float": 3.14, "subdata": { "int": 42, "datetime": datetime.datetime( 2019, 6, 10, 11, 12, tzinfo=datetime.timezone.utc ), }, "list": [ 42, datetime.datetime(2019, 9, 10, 11, 12, tzinfo=datetime.timezone.utc), "ok", ], } async def struct(request): return Response(STRUCT) async def echo(request): data = await decode_request(request) return Response(data) async def server_exception(request): raise TestServerException() async def client_error(request): raise TestClientError() async def echo_no_nego(request): # let the content negotiation handle the serialization for us... data = await decode_request(request) ret = encode_msgpack(data) return ret def check_mimetype(src, dst): src = src.split(";")[0].strip() dst = dst.split(";")[0].strip() assert src == dst @pytest.fixture def async_app(): app = RPCServerApp() app.client_exception_classes = (TestClientError,) app.router.add_route("GET", "/", root) app.router.add_route("GET", "/struct", struct) app.router.add_route("POST", "/echo", echo) app.router.add_route("GET", "/server_exception", server_exception) app.router.add_route("GET", "/client_error", client_error) app.router.add_route("POST", "/echo-no-nego", echo_no_nego) return app @pytest.fixture def cli(async_app, aiohttp_client, loop): return loop.run_until_complete(aiohttp_client(async_app)) async def test_get_simple(cli) -> None: resp = await cli.get("/") assert resp.status == 200 check_mimetype(resp.headers["Content-Type"], "application/x-msgpack") data = await resp.read() value = msgpack.unpackb(data, raw=False) assert value == "toor" async def test_get_server_exception(cli) -> None: resp = await cli.get("/server_exception") assert resp.status == 500 data = await resp.read() data = msgpack.unpackb(data, raw=False) - assert data["exception"]["type"] == "TestServerException" + assert data["type"] == "TestServerException" async def test_get_client_error(cli) -> None: resp = await cli.get("/client_error") assert resp.status == 400 data = await resp.read() data = msgpack.unpackb(data, raw=False) - assert data["exception"]["type"] == "TestClientError" + assert data["type"] == "TestClientError" async def test_get_simple_nego(cli) -> None: for ctype in ("x-msgpack", "json"): resp = await cli.get("/", headers={"Accept": "application/%s" % ctype}) assert resp.status == 200 check_mimetype(resp.headers["Content-Type"], "application/%s" % ctype) assert (await decode_request(resp)) == "toor" async def test_get_struct(cli) -> None: """Test returned structured from a simple GET data is OK""" resp = await cli.get("/struct") assert resp.status == 200 check_mimetype(resp.headers["Content-Type"], "application/x-msgpack") assert (await decode_request(resp)) == STRUCT async def test_get_struct_nego(cli) -> None: """Test returned structured from a simple GET data is OK""" for ctype in ("x-msgpack", "json"): resp = await cli.get("/struct", headers={"Accept": "application/%s" % ctype}) assert resp.status == 200 check_mimetype(resp.headers["Content-Type"], "application/%s" % ctype) assert (await decode_request(resp)) == STRUCT async def test_post_struct_msgpack(cli) -> None: """Test that msgpack encoded posted struct data is returned as is""" # simple struct resp = await cli.post( "/echo", headers={"Content-Type": "application/x-msgpack"}, data=msgpack_dumps({"toto": 42}), ) assert resp.status == 200 check_mimetype(resp.headers["Content-Type"], "application/x-msgpack") assert (await decode_request(resp)) == {"toto": 42} # complex struct resp = await cli.post( "/echo", headers={"Content-Type": "application/x-msgpack"}, data=msgpack_dumps(STRUCT), ) assert resp.status == 200 check_mimetype(resp.headers["Content-Type"], "application/x-msgpack") assert (await decode_request(resp)) == STRUCT async def test_post_struct_json(cli) -> None: """Test that json encoded posted struct data is returned as is""" resp = await cli.post( "/echo", headers={"Content-Type": "application/json"}, data=json.dumps({"toto": 42}, cls=SWHJSONEncoder), ) assert resp.status == 200 check_mimetype(resp.headers["Content-Type"], "application/x-msgpack") assert (await decode_request(resp)) == {"toto": 42} resp = await cli.post( "/echo", headers={"Content-Type": "application/json"}, data=json.dumps(STRUCT, cls=SWHJSONEncoder), ) assert resp.status == 200 check_mimetype(resp.headers["Content-Type"], "application/x-msgpack") # assert resp.headers['Content-Type'] == 'application/x-msgpack' assert (await decode_request(resp)) == STRUCT async def test_post_struct_nego(cli) -> None: """Test that json encoded posted struct data is returned as is using content negotiation (accept json or msgpack). """ for ctype in ("x-msgpack", "json"): resp = await cli.post( "/echo", headers={ "Content-Type": "application/json", "Accept": "application/%s" % ctype, }, data=json.dumps(STRUCT, cls=SWHJSONEncoder), ) assert resp.status == 200 check_mimetype(resp.headers["Content-Type"], "application/%s" % ctype) assert (await decode_request(resp)) == STRUCT async def test_post_struct_no_nego(cli) -> None: """Test that json encoded posted struct data is returned as msgpack when using non-negotiation-compatible handlers. """ for ctype in ("x-msgpack", "json"): resp = await cli.post( "/echo-no-nego", headers={ "Content-Type": "application/json", "Accept": "application/%s" % ctype, }, data=json.dumps(STRUCT, cls=SWHJSONEncoder), ) assert resp.status == 200 check_mimetype(resp.headers["Content-Type"], "application/x-msgpack") assert (await decode_request(resp)) == STRUCT def test_async_decode_data_failure(): with pytest.raises(ValueError, match="Wrong content type"): decode_data("some-data", "unknown-content-type") @pytest.mark.parametrize("data", [None, "", {}, []]) def test_async_decode_data_empty_cases(data): assert decode_data(data, "unknown-content-type") == {} @pytest.mark.parametrize( "data,content_type,encode_data_fn", [ ({"a": 1}, "application/json", json_dumps), ({"a": 1}, "application/x-msgpack", msgpack_dumps), ], ) def test_async_decode_data_nominal(data, content_type, encode_data_fn): actual_data = decode_data(encode_data_fn(data), content_type) assert actual_data == data diff --git a/swh/core/api/tests/test_rpc_client_server.py b/swh/core/api/tests/test_rpc_client_server.py index 81b0afa..0d4e70a 100644 --- a/swh/core/api/tests/test_rpc_client_server.py +++ b/swh/core/api/tests/test_rpc_client_server.py @@ -1,117 +1,130 @@ # Copyright (C) 2018-2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import pytest from swh.core.api import ( RemoteException, RPCClient, RPCServerApp, encode_data_server, error_handler, remote_api_endpoint, ) # this class is used on the server part class RPCTest: @remote_api_endpoint("endpoint_url") def endpoint(self, test_data, db=None, cur=None): assert test_data == "spam" return "egg" @remote_api_endpoint("path/to/endpoint") def something(self, data, db=None, cur=None): return data @remote_api_endpoint("raises_typeerror") def raise_typeerror(self): raise TypeError("Did I pass through?") + @remote_api_endpoint("raise_exception_exc_arg") + def raise_exception_exc_arg(self): + raise Exception(Exception("error")) + # this class is used on the client part. We cannot inherit from RPCTest # because the automagic metaclass based code that generates the RPCClient # proxy class from this does not handle inheritance properly. # We do add an endpoint on the client side that has no implementation # server-side to test this very situation (in should generate a 404) class RPCTest2: @remote_api_endpoint("endpoint_url") def endpoint(self, test_data, db=None, cur=None): assert test_data == "spam" return "egg" @remote_api_endpoint("path/to/endpoint") def something(self, data, db=None, cur=None): return data @remote_api_endpoint("not_on_server") def not_on_server(self, db=None, cur=None): return "ok" @remote_api_endpoint("raises_typeerror") def raise_typeerror(self): return "data" class RPCTestClient(RPCClient): backend_class = RPCTest2 @pytest.fixture def app(): # This fixture is used by the 'swh_rpc_adapter' fixture # which is defined in swh/core/pytest_plugin.py application = RPCServerApp("testapp", backend_class=RPCTest) @application.errorhandler(Exception) def my_error_handler(exception): return error_handler(exception, encode_data_server) return application @pytest.fixture def swh_rpc_client_class(): # This fixture is used by the 'swh_rpc_client' fixture # which is defined in swh/core/pytest_plugin.py return RPCTestClient def test_api_client_endpoint_missing(swh_rpc_client): with pytest.raises(AttributeError): swh_rpc_client.missing(data="whatever") def test_api_server_endpoint_missing(swh_rpc_client): # A 'missing' endpoint (server-side) should raise an exception # due to a 404, since at the end, we do a GET/POST an inexistent URL with pytest.raises(Exception, match="404 not found"): swh_rpc_client.not_on_server() def test_api_endpoint_kwargs(swh_rpc_client): res = swh_rpc_client.something(data="whatever") assert res == "whatever" res = swh_rpc_client.endpoint(test_data="spam") assert res == "egg" def test_api_endpoint_args(swh_rpc_client): res = swh_rpc_client.something("whatever") assert res == "whatever" res = swh_rpc_client.endpoint("spam") assert res == "egg" def test_api_typeerror(swh_rpc_client): with pytest.raises(RemoteException) as exc_info: swh_rpc_client.raise_typeerror() assert exc_info.value.args[0]["type"] == "TypeError" assert exc_info.value.args[0]["args"] == ["Did I pass through?"] assert ( str(exc_info.value) == "" ) + + +def test_api_raise_exception_exc_arg(swh_rpc_client): + with pytest.raises(RemoteException) as exc_info: + swh_rpc_client.post("raise_exception_exc_arg", data={}) + + assert exc_info.value.args[0]["type"] == "Exception" + assert type(exc_info.value.args[0]["args"][0]) == Exception + assert str(exc_info.value.args[0]["args"][0]) == "error" diff --git a/swh/core/api/tests/test_serializers.py b/swh/core/api/tests/test_serializers.py index 9d7c261..f242541 100644 --- a/swh/core/api/tests/test_serializers.py +++ b/swh/core/api/tests/test_serializers.py @@ -1,250 +1,271 @@ # Copyright (C) 2015-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import datetime import json from typing import Any, Callable, List, Tuple, Union from uuid import UUID import arrow from arrow import Arrow import pytest import requests +from requests.exceptions import ConnectionError from swh.core.api.classes import PagedResult from swh.core.api.serializers import ( SWHJSONDecoder, SWHJSONEncoder, decode_response, msgpack_dumps, msgpack_loads, ) class ExtraType: def __init__(self, arg1, arg2): self.arg1 = arg1 self.arg2 = arg2 def __repr__(self): return f"ExtraType({self.arg1}, {self.arg2})" def __eq__(self, other): return isinstance(other, ExtraType) and (self.arg1, self.arg2) == ( other.arg1, other.arg2, ) extra_encoders: List[Tuple[type, str, Callable[..., Any]]] = [ (ExtraType, "extratype", lambda o: (o.arg1, o.arg2)) ] extra_decoders = { "extratype": lambda o: ExtraType(*o), } TZ = datetime.timezone(datetime.timedelta(minutes=118)) DATA_BYTES = b"123456789\x99\xaf\xff\x00\x12" ENCODED_DATA_BYTES = {"swhtype": "bytes", "d": "F)}kWH8wXmIhn8j01^"} DATA_DATETIME = datetime.datetime(2015, 3, 4, 18, 25, 13, 1234, tzinfo=TZ,) ENCODED_DATA_DATETIME = { "swhtype": "datetime", "d": "2015-03-04T18:25:13.001234+01:58", } DATA_TIMEDELTA = datetime.timedelta(64) ENCODED_DATA_TIMEDELTA = { "swhtype": "timedelta", "d": {"days": 64, "seconds": 0, "microseconds": 0}, } DATA_ARROW = arrow.get("2018-04-25T16:17:53.533672+00:00") ENCODED_DATA_ARROW = {"swhtype": "arrow", "d": "2018-04-25T16:17:53.533672+00:00"} DATA_UUID = UUID("cdd8f804-9db6-40c3-93ab-5955d3836234") ENCODED_DATA_UUID = {"swhtype": "uuid", "d": "cdd8f804-9db6-40c3-93ab-5955d3836234"} # For test demonstration purposes TestPagedResultStr = PagedResult[ Union[UUID, datetime.datetime, datetime.timedelta], str ] DATA_PAGED_RESULT = TestPagedResultStr( results=[DATA_UUID, DATA_DATETIME, DATA_TIMEDELTA], next_page_token="10", ) ENCODED_DATA_PAGED_RESULT = { "d": { "results": [ENCODED_DATA_UUID, ENCODED_DATA_DATETIME, ENCODED_DATA_TIMEDELTA,], "next_page_token": "10", }, "swhtype": "paged_result", } TestPagedResultTuple = PagedResult[Union[str, bytes, Arrow], List[Union[str, UUID]]] DATA_PAGED_RESULT2 = TestPagedResultTuple( results=["data0", DATA_BYTES, DATA_ARROW], next_page_token=["10", DATA_UUID], ) ENCODED_DATA_PAGED_RESULT2 = { "d": { "results": ["data0", ENCODED_DATA_BYTES, ENCODED_DATA_ARROW,], "next_page_token": ["10", ENCODED_DATA_UUID], }, "swhtype": "paged_result", } DATA = { "bytes": DATA_BYTES, "datetime_tz": DATA_DATETIME, "datetime_utc": datetime.datetime( 2015, 3, 4, 18, 25, 13, 1234, tzinfo=datetime.timezone.utc ), "datetime_delta": DATA_TIMEDELTA, "arrow_date": DATA_ARROW, "swhtype": "fake", "swh_dict": {"swhtype": 42, "d": "test"}, "random_dict": {"swhtype": 43}, "uuid": DATA_UUID, "paged-result": DATA_PAGED_RESULT, "paged-result2": DATA_PAGED_RESULT2, } ENCODED_DATA = { "bytes": ENCODED_DATA_BYTES, "datetime_tz": ENCODED_DATA_DATETIME, "datetime_utc": {"swhtype": "datetime", "d": "2015-03-04T18:25:13.001234+00:00",}, "datetime_delta": ENCODED_DATA_TIMEDELTA, "arrow_date": ENCODED_DATA_ARROW, "swhtype": "fake", "swh_dict": {"swhtype": 42, "d": "test"}, "random_dict": {"swhtype": 43}, "uuid": ENCODED_DATA_UUID, "paged-result": ENCODED_DATA_PAGED_RESULT, "paged-result2": ENCODED_DATA_PAGED_RESULT2, } def test_serializers_round_trip_json(): json_data = json.dumps(DATA, cls=SWHJSONEncoder) actual_data = json.loads(json_data, cls=SWHJSONDecoder) assert actual_data == DATA def test_serializers_round_trip_json_extra_types(): expected_original_data = [ExtraType("baz", DATA), "qux"] data = json.dumps( expected_original_data, cls=SWHJSONEncoder, extra_encoders=extra_encoders ) actual_data = json.loads(data, cls=SWHJSONDecoder, extra_decoders=extra_decoders) assert actual_data == expected_original_data +def test_exception_serializer_round_trip_json(): + error_message = "unreachable host" + json_data = json.dumps( + {"exception": ConnectionError(error_message)}, cls=SWHJSONEncoder + ) + actual_data = json.loads(json_data, cls=SWHJSONDecoder) + assert "exception" in actual_data + assert type(actual_data["exception"]) == ConnectionError + assert str(actual_data["exception"]) == error_message + + def test_serializers_encode_swh_json(): json_str = json.dumps(DATA, cls=SWHJSONEncoder) actual_data = json.loads(json_str) assert actual_data == ENCODED_DATA def test_serializers_round_trip_msgpack(): expected_original_data = { **DATA, "none_dict_key": {None: 42}, "long_int_is_loooong": 10000000000000000000000000000000, } data = msgpack_dumps(expected_original_data) actual_data = msgpack_loads(data) assert actual_data == expected_original_data def test_serializers_round_trip_msgpack_extra_types(): original_data = [ExtraType("baz", DATA), "qux"] data = msgpack_dumps(original_data, extra_encoders=extra_encoders) actual_data = msgpack_loads(data, extra_decoders=extra_decoders) assert actual_data == original_data +def test_exception_serializer_round_trip_msgpack(): + error_message = "unreachable host" + data = msgpack_dumps({"exception": ConnectionError(error_message)}) + actual_data = msgpack_loads(data) + assert "exception" in actual_data + assert type(actual_data["exception"]) == ConnectionError + assert str(actual_data["exception"]) == error_message + + def test_serializers_generator_json(): data = json.dumps((i for i in range(5)), cls=SWHJSONEncoder) assert json.loads(data, cls=SWHJSONDecoder) == [i for i in range(5)] def test_serializers_generator_msgpack(): data = msgpack_dumps((i for i in range(5))) assert msgpack_loads(data) == [i for i in range(5)] def test_serializers_decode_response_json(requests_mock): requests_mock.get( "https://example.org/test/data", json=ENCODED_DATA, headers={"content-type": "application/json"}, ) response = requests.get("https://example.org/test/data") assert decode_response(response) == DATA def test_serializers_decode_legacy_msgpack(): legacy_msgpack = { "bytes": b"\xc4\x0e123456789\x99\xaf\xff\x00\x12", "datetime_tz": ( b"\x82\xc4\x0c__datetime__\xc3\xc4\x01s\xd9 " b"2015-03-04T18:25:13.001234+01:58" ), "datetime_utc": ( b"\x82\xc4\x0c__datetime__\xc3\xc4\x01s\xd9 " b"2015-03-04T18:25:13.001234+00:00" ), "datetime_delta": ( b"\x82\xc4\r__timedelta__\xc3\xc4\x01s\x83\xa4" b"days@\xa7seconds\x00\xacmicroseconds\x00" ), "arrow_date": ( b"\x82\xc4\t__arrow__\xc3\xc4\x01s\xd9 2018-04-25T16:17:53.533672+00:00" ), "swhtype": b"\xa4fake", "swh_dict": b"\x82\xa7swhtype*\xa1d\xa4test", "random_dict": b"\x81\xa7swhtype+", "uuid": ( b"\x82\xc4\x08__uuid__\xc3\xc4\x01s\xd9$" b"cdd8f804-9db6-40c3-93ab-5955d3836234" ), } for k, v in legacy_msgpack.items(): assert msgpack_loads(v) == DATA[k] def test_serializers_encode_native_datetime(): dt = datetime.datetime(2015, 1, 1, 12, 4, 42, 231455) with pytest.raises(ValueError, match="naive datetime"): msgpack_dumps(dt) def test_serializers_decode_naive_datetime(): expected_dt = datetime.datetime(2015, 1, 1, 12, 4, 42, 231455) # Current encoding assert ( msgpack_loads( b"\x82\xc4\x07swhtype\xa8datetime\xc4\x01d\xba" b"2015-01-01T12:04:42.231455" ) == expected_dt ) # Legacy encoding assert ( msgpack_loads( b"\x82\xc4\x0c__datetime__\xc3\xc4\x01s\xba2015-01-01T12:04:42.231455" ) == expected_dt ) diff --git a/swh/core/db/pytest_plugin.py b/swh/core/db/pytest_plugin.py index 07edf27..8f5280e 100644 --- a/swh/core/db/pytest_plugin.py +++ b/swh/core/db/pytest_plugin.py @@ -1,174 +1,175 @@ # Copyright (C) 2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import glob import logging import subprocess from typing import Optional, Set, Union +from _pytest.fixtures import FixtureRequest import psycopg2 import pytest from pytest_postgresql import factories from pytest_postgresql.janitor import DatabaseJanitor, Version from swh.core.utils import numfile_sortkey as sortkey logger = logging.getLogger(__name__) class SWHDatabaseJanitor(DatabaseJanitor): """SWH database janitor implementation with a a different setup/teardown policy than than the stock one. Instead of dropping, creating and initializing the database for each test, it creates and initializes the db once, then truncates the tables (and sequences) in between tests. This is needed to have acceptable test performances. """ def __init__( self, user: str, host: str, port: str, db_name: str, version: Union[str, float, Version], dump_files: Optional[str] = None, no_truncate_tables: Set[str] = set(), ) -> None: super().__init__(user, host, port, db_name, version) if dump_files: self.dump_files = sorted(glob.glob(dump_files), key=sortkey) else: self.dump_files = [] # do no truncate the following tables self.no_truncate_tables = set(no_truncate_tables) def db_setup(self): conninfo = ( f"host={self.host} user={self.user} port={self.port} dbname={self.db_name}" ) for fname in self.dump_files: subprocess.check_call( [ "psql", "--quiet", "--no-psqlrc", "-v", "ON_ERROR_STOP=1", "-d", conninfo, "-f", fname, ] ) def db_reset(self): """Truncate tables (all but self.no_truncate_tables set) and sequences """ with psycopg2.connect( dbname=self.db_name, user=self.user, host=self.host, port=self.port, ) as cnx: with cnx.cursor() as cur: cur.execute( "SELECT table_name FROM information_schema.tables " "WHERE table_schema = %s", ("public",), ) all_tables = set(table for (table,) in cur.fetchall()) tables_to_truncate = all_tables - self.no_truncate_tables for table in tables_to_truncate: cur.execute("TRUNCATE TABLE %s CASCADE" % table) cur.execute( "SELECT sequence_name FROM information_schema.sequences " "WHERE sequence_schema = %s", ("public",), ) seqs = set(seq for (seq,) in cur.fetchall()) for seq in seqs: cur.execute("ALTER SEQUENCE %s RESTART;" % seq) cnx.commit() def init(self): """Initialize db. Create the db if it does not exist. Reset it if it exists.""" with self.cursor() as cur: cur.execute( "SELECT COUNT(1) FROM pg_database WHERE datname=%s;", (self.db_name,) ) db_exists = cur.fetchone()[0] == 1 if db_exists: cur.execute( "UPDATE pg_database SET datallowconn=true WHERE datname = %s;", (self.db_name,), ) self.db_reset() return # initialize the inexistent db with self.cursor() as cur: cur.execute('CREATE DATABASE "{}";'.format(self.db_name)) self.db_setup() def drop(self): """The original DatabaseJanitor implementation prevents new connections from happening, destroys current opened connections and finally drops the database. We actually do not want to drop the db so we instead do nothing and resets (truncate most tables and sequences) the db instead, in order to have some acceptable performance. """ pass # the postgres_fact factory fixture below is mostly a copy of the code # from pytest-postgresql. We need a custom version here to be able to # specify our version of the DBJanitor we use. def postgresql_fact( process_fixture_name: str, db_name: Optional[str] = None, dump_files: str = "", no_truncate_tables: Set[str] = {"dbversion"}, ): @pytest.fixture - def postgresql_factory(request): + def postgresql_factory(request: FixtureRequest): """Fixture factory for PostgreSQL. :param FixtureRequest request: fixture request object :rtype: psycopg2.connection :returns: postgresql client """ config = factories.get_config(request) proc_fixture = request.getfixturevalue(process_fixture_name) pg_host = proc_fixture.host pg_port = proc_fixture.port pg_user = proc_fixture.user pg_options = proc_fixture.options pg_db = db_name or config["dbname"] with SWHDatabaseJanitor( pg_user, pg_host, pg_port, pg_db, proc_fixture.version, dump_files=dump_files, no_truncate_tables=no_truncate_tables, ): connection = psycopg2.connect( dbname=pg_db, user=pg_user, host=pg_host, port=pg_port, options=pg_options, ) yield connection connection.close() return postgresql_factory diff --git a/swh/core/db/tests/data/cli/0-superuser-init.sql b/swh/core/db/tests/data/cli/0-superuser-init.sql new file mode 100644 index 0000000..480018c --- /dev/null +++ b/swh/core/db/tests/data/cli/0-superuser-init.sql @@ -0,0 +1 @@ +create extension if not exists pgcrypto; diff --git a/swh/core/db/tests/data/cli/1-schema.sql b/swh/core/db/tests/data/cli/1-schema.sql new file mode 100644 index 0000000..a5f6d2c --- /dev/null +++ b/swh/core/db/tests/data/cli/1-schema.sql @@ -0,0 +1,13 @@ +-- schema version table which won't get truncated +create table if not exists dbversion ( + version int primary key, + release timestamptz, + description text +); + +-- origin table +create table if not exists origin ( + id bigserial not null, + url text not null, + hash text not null +); diff --git a/swh/core/db/tests/data/cli/3-func.sql b/swh/core/db/tests/data/cli/3-func.sql new file mode 100644 index 0000000..d4dd410 --- /dev/null +++ b/swh/core/db/tests/data/cli/3-func.sql @@ -0,0 +1,6 @@ +create or replace function hash_sha1(text) + returns text + language sql strict immutable +as $$ + select encode(public.digest($1, 'sha1'), 'hex') +$$; diff --git a/swh/core/db/tests/data/cli/4-data.sql b/swh/core/db/tests/data/cli/4-data.sql new file mode 100644 index 0000000..ed29fa1 --- /dev/null +++ b/swh/core/db/tests/data/cli/4-data.sql @@ -0,0 +1,5 @@ +insert into dbversion(version, release, description) +values (1, '2016-02-22 15:56:28.358587+00', 'Work In Progress'); + +insert into origin(url, hash) +values ('https://forge.softwareheritage.org', hash_sha1('https://forge.softwareheritage.org')); diff --git a/swh/core/db/tests/pytest_plugin/__init__.py b/swh/core/db/tests/pytest_plugin/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/swh/core/db/tests/data/0-schema.sql b/swh/core/db/tests/pytest_plugin/data/0-schema.sql similarity index 100% rename from swh/core/db/tests/data/0-schema.sql rename to swh/core/db/tests/pytest_plugin/data/0-schema.sql diff --git a/swh/core/db/tests/data/1-data.sql b/swh/core/db/tests/pytest_plugin/data/1-data.sql similarity index 100% rename from swh/core/db/tests/data/1-data.sql rename to swh/core/db/tests/pytest_plugin/data/1-data.sql diff --git a/swh/core/db/tests/test_db_utils.py b/swh/core/db/tests/pytest_plugin/test_pytest_plugin.py similarity index 100% rename from swh/core/db/tests/test_db_utils.py rename to swh/core/db/tests/pytest_plugin/test_pytest_plugin.py diff --git a/swh/core/db/tests/test_cli.py b/swh/core/db/tests/test_cli.py index 236d260..067524a 100644 --- a/swh/core/db/tests/test_cli.py +++ b/swh/core/db/tests/test_cli.py @@ -1,59 +1,257 @@ -# +# Copyright (C) 2019-2020 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + +import copy +import glob +from os import path from click.testing import CliRunner +import pytest from swh.core.cli.db import db as swhdb +from swh.core.db import BaseDb +from swh.core.db.pytest_plugin import postgresql_fact + + +@pytest.fixture +def cli_runner(): + return CliRunner() + help_msg = """Usage: swh [OPTIONS] COMMAND [ARGS]... Command line interface for Software Heritage. Options: -l, --log-level [NOTSET|DEBUG|INFO|WARNING|ERROR|CRITICAL] Log level (defaults to INFO). --log-config FILENAME Python yaml logging configuration file. --sentry-dsn TEXT DSN of the Sentry instance to report to -h, --help Show this message and exit. Notes: If both options are present, --log-level will override the root logger configuration set in --log-config. The --log-config YAML must conform to the logging.config.dictConfig schema documented at https://docs.python.org/3/library/logging.config.html. Commands: db Software Heritage database generic tools. """ -def test_swh_help(swhmain): +def test_cli_swh_help(swhmain, cli_runner): swhmain.add_command(swhdb) - runner = CliRunner() - result = runner.invoke(swhmain, ["-h"]) + result = cli_runner.invoke(swhmain, ["-h"]) assert result.exit_code == 0 assert result.output == help_msg help_db_msg = """Usage: swh db [OPTIONS] COMMAND [ARGS]... Software Heritage database generic tools. Options: -C, --config-file FILE Configuration file. -h, --help Show this message and exit. Commands: create Create a database for the Software Heritage . init Initialize a database for the Software Heritage . init-admin Execute superuser-level initialization steps (e.g pg extensions,... """ -def test_swh_db_help(swhmain): +def test_cli_swh_db_help(swhmain, cli_runner): swhmain.add_command(swhdb) - runner = CliRunner() - result = runner.invoke(swhmain, ["db", "-h"]) + result = cli_runner.invoke(swhmain, ["db", "-h"]) assert result.exit_code == 0 assert result.output == help_db_msg + + +@pytest.fixture() +def mock_package_sql(mocker, datadir): + """This bypasses the module manipulation to only returns the data test files. + + """ + from swh.core.utils import numfile_sortkey as sortkey + + mock_sql_files = mocker.patch("swh.core.cli.db.get_sql_for_package") + sql_files = sorted(glob.glob(path.join(datadir, "cli", "*.sql")), key=sortkey) + mock_sql_files.return_value = sql_files + return mock_sql_files + + +# We do not want the truncate behavior for those tests +test_db = postgresql_fact( + "postgresql_proc", db_name="clidb", no_truncate_tables={"dbversion", "origin"} +) + + +@pytest.fixture +def swh_db_cli(cli_runner, monkeypatch, test_db): + """This initializes a cli_runner and sets the correct environment variable expected by + the cli to run appropriately (when not specifying the --db-name flag) + + """ + db_params = test_db.get_dsn_parameters() + monkeypatch.setenv("PGHOST", db_params["host"]) + monkeypatch.setenv("PGUSER", db_params["user"]) + monkeypatch.setenv("PGPORT", db_params["port"]) + + return cli_runner, db_params + + +def craft_conninfo(test_db, dbname=None) -> str: + """Craft conninfo string out of the test_db object. This also allows to override the + dbname.""" + db_params = test_db.get_dsn_parameters() + if dbname: + params = copy.deepcopy(db_params) + params["dbname"] = dbname + else: + params = db_params + return "postgresql://{user}@{host}:{port}/{dbname}".format(**params) + + +def test_cli_swh_db_create_and_init_db(cli_runner, test_db, mock_package_sql): + """Create a db then initializing it should be ok + + """ + module_name = "something" + + conninfo = craft_conninfo(test_db, "new-db") + # This creates the db and installs the necessary admin extensions + result = cli_runner.invoke(swhdb, ["create", module_name, "--db-name", conninfo]) + assert result.exit_code == 0, f"Unexpected output: {result.output}" + + # This initializes the schema and data + result = cli_runner.invoke(swhdb, ["init", module_name, "--db-name", conninfo]) + + assert result.exit_code == 0, f"Unexpected output: {result.output}" + + # the origin value in the scripts uses a hash function (which implementation wise + # uses a function from the pgcrypt extension, installed during db creation step) + with BaseDb.connect(conninfo).cursor() as cur: + cur.execute("select * from origin") + origins = cur.fetchall() + assert len(origins) == 1 + + +def test_cli_swh_db_initialization_fail_without_creation_first( + cli_runner, test_db, mock_package_sql +): + """Init command on an inexisting db cannot work + + """ + module_name = "anything" # it's mocked here + conninfo = craft_conninfo(test_db, "inexisting-db") + + result = cli_runner.invoke(swhdb, ["init", module_name, "--db-name", conninfo]) + # Fails because we cannot connect to an inexisting db + assert result.exit_code == 1, f"Unexpected output: {result.output}" + + +def test_cli_swh_db_initialization_fail_without_extension( + cli_runner, test_db, mock_package_sql +): + """Init command cannot work without privileged extension. + + In this test, the schema needs privileged extension to work. + + """ + module_name = "anything" # it's mocked here + conninfo = craft_conninfo(test_db) + + result = cli_runner.invoke(swhdb, ["init", module_name, "--db-name", conninfo]) + # Fails as the function `public.digest` is not installed, init-admin calls is needed + # first (the next tests show such behavior) + assert result.exit_code == 1, f"Unexpected output: {result.output}" + + +def test_cli_swh_db_initialization_works_with_flags( + cli_runner, test_db, mock_package_sql +): + """Init commands with carefully crafted libpq conninfo works + + """ + module_name = "anything" # it's mocked here + conninfo = craft_conninfo(test_db) + + result = cli_runner.invoke( + swhdb, ["init-admin", module_name, "--db-name", conninfo] + ) + assert result.exit_code == 0, f"Unexpected output: {result.output}" + + result = cli_runner.invoke(swhdb, ["init", module_name, "--db-name", conninfo]) + + assert result.exit_code == 0, f"Unexpected output: {result.output}" + # the origin values in the scripts uses a hash function (which implementation wise + # uses a function from the pgcrypt extension, init-admin calls installs it) + with BaseDb.connect(test_db.dsn).cursor() as cur: + cur.execute("select * from origin") + origins = cur.fetchall() + assert len(origins) == 1 + + +def test_cli_swh_db_initialization_with_env(swh_db_cli, mock_package_sql, test_db): + """Init commands with standard environment variables works + + """ + module_name = "anything" # it's mocked here + cli_runner, db_params = swh_db_cli + result = cli_runner.invoke( + swhdb, ["init-admin", module_name, "--db-name", db_params["dbname"]] + ) + assert result.exit_code == 0, f"Unexpected output: {result.output}" + + result = cli_runner.invoke( + swhdb, ["init", module_name, "--db-name", db_params["dbname"]] + ) + + assert result.exit_code == 0, f"Unexpected output: {result.output}" + # the origin values in the scripts uses a hash function (which implementation wise + # uses a function from the pgcrypt extension, init-admin calls installs it) + with BaseDb.connect(test_db.dsn).cursor() as cur: + cur.execute("select * from origin") + origins = cur.fetchall() + assert len(origins) == 1 + + +def test_cli_swh_db_initialization_idempotent(swh_db_cli, mock_package_sql, test_db): + """Multiple runs of the init commands are idempotent + + """ + module_name = "anything" # mocked + cli_runner, db_params = swh_db_cli + + result = cli_runner.invoke( + swhdb, ["init-admin", module_name, "--db-name", db_params["dbname"]] + ) + assert result.exit_code == 0, f"Unexpected output: {result.output}" + + result = cli_runner.invoke( + swhdb, ["init", module_name, "--db-name", db_params["dbname"]] + ) + assert result.exit_code == 0, f"Unexpected output: {result.output}" + + result = cli_runner.invoke( + swhdb, ["init-admin", module_name, "--db-name", db_params["dbname"]] + ) + assert result.exit_code == 0, f"Unexpected output: {result.output}" + + result = cli_runner.invoke( + swhdb, ["init", module_name, "--db-name", db_params["dbname"]] + ) + assert result.exit_code == 0, f"Unexpected output: {result.output}" + + # the origin values in the scripts uses a hash function (which implementation wise + # uses a function from the pgcrypt extension, init-admin calls installs it) + with BaseDb.connect(test_db.dsn).cursor() as cur: + cur.execute("select * from origin") + origins = cur.fetchall() + assert len(origins) == 1 diff --git a/swh/core/pytest_plugin.py b/swh/core/pytest_plugin.py index 7a45019..c9da42c 100644 --- a/swh/core/pytest_plugin.py +++ b/swh/core/pytest_plugin.py @@ -1,317 +1,318 @@ # Copyright (C) 2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from functools import partial import logging from os import path import re from typing import Dict, List, Optional from urllib.parse import unquote, urlparse +from _pytest.fixtures import FixtureRequest import pytest import requests from requests.adapters import BaseAdapter from requests.structures import CaseInsensitiveDict from requests.utils import get_encoding_from_headers logger = logging.getLogger(__name__) # Check get_local_factory function # Maximum number of iteration checks to generate requests responses MAX_VISIT_FILES = 10 def get_response_cb( request: requests.Request, context, datadir, ignore_urls: List[str] = [], visits: Optional[Dict] = None, ): """Mount point callback to fetch on disk the request's content. The request urls provided are url decoded first to resolve the associated file on disk. This is meant to be used as 'body' argument of the requests_mock.get() method. It will look for files on the local filesystem based on the requested URL, using the following rules: - files are searched in the datadir/ directory - the local file name is the path part of the URL with path hierarchy markers (aka '/') replaced by '_' Eg. if you use the requests_mock fixture in your test file as: requests_mock.get('https?://nowhere.com', body=get_response_cb) # or even requests_mock.get(re.compile('https?://'), body=get_response_cb) then a call requests.get like: requests.get('https://nowhere.com/path/to/resource?a=b&c=d') will look the content of the response in: datadir/https_nowhere.com/path_to_resource,a=b,c=d or a call requests.get like: requests.get('http://nowhere.com/path/to/resource?a=b&c=d') will look the content of the response in: datadir/http_nowhere.com/path_to_resource,a=b,c=d Args: request: Object requests context (requests.Context): Object holding response metadata information (status_code, headers, etc...) datadir: Data files path ignore_urls: urls whose status response should be 404 even if the local file exists visits: Dict of url, number of visits. If None, disable multi visit support (default) Returns: Optional[FileDescriptor] on disk file to read from the test context """ logger.debug("get_response_cb(%s, %s)", request, context) logger.debug("url: %s", request.url) logger.debug("ignore_urls: %s", ignore_urls) unquoted_url = unquote(request.url) if unquoted_url in ignore_urls: context.status_code = 404 return None url = urlparse(unquoted_url) # http://pypi.org ~> http_pypi.org # https://files.pythonhosted.org ~> https_files.pythonhosted.org dirname = "%s_%s" % (url.scheme, url.hostname) # url.path: pypi//json -> local file: pypi__json filename = url.path[1:] if filename.endswith("/"): filename = filename[:-1] filename = filename.replace("/", "_") if url.query: filename += "," + url.query.replace("&", ",") filepath = path.join(datadir, dirname, filename) if visits is not None: visit = visits.get(url, 0) visits[url] = visit + 1 if visit: filepath = filepath + "_visit%s" % visit if not path.isfile(filepath): logger.debug("not found filepath: %s", filepath) context.status_code = 404 return None fd = open(filepath, "rb") context.headers["content-length"] = str(path.getsize(filepath)) return fd @pytest.fixture -def datadir(request): +def datadir(request: FixtureRequest) -> str: """By default, returns the test directory's data directory. This can be overridden on a per file tree basis. Add an override definition in the local conftest, for example:: import pytest from os import path @pytest.fixture def datadir(): return path.join(path.abspath(path.dirname(__file__)), 'resources') """ return path.join(path.dirname(str(request.fspath)), "data") def requests_mock_datadir_factory( ignore_urls: List[str] = [], has_multi_visit: bool = False ): """This factory generates fixture which allow to look for files on the local filesystem based on the requested URL, using the following rules: - files are searched in the datadir/ directory - the local file name is the path part of the URL with path hierarchy markers (aka '/') replaced by '_' Multiple implementations are possible, for example: - requests_mock_datadir_factory([]): This computes the file name from the query and always returns the same result. - requests_mock_datadir_factory(has_multi_visit=True): This computes the file name from the query and returns the content of the filename the first time, the next call returning the content of files suffixed with _visit1 and so on and so forth. If the file is not found, returns a 404. - requests_mock_datadir_factory(ignore_urls=['url1', 'url2']): This will ignore any files corresponding to url1 and url2, always returning 404. Args: ignore_urls: List of urls to always returns 404 (whether file exists or not) has_multi_visit: Activate or not the multiple visits behavior """ @pytest.fixture def requests_mock_datadir(requests_mock, datadir): if not has_multi_visit: cb = partial(get_response_cb, ignore_urls=ignore_urls, datadir=datadir) requests_mock.get(re.compile("https?://"), body=cb) else: visits = {} requests_mock.get( re.compile("https?://"), body=partial( get_response_cb, ignore_urls=ignore_urls, visits=visits, datadir=datadir, ), ) return requests_mock return requests_mock_datadir # Default `requests_mock_datadir` implementation requests_mock_datadir = requests_mock_datadir_factory([]) # Implementation for multiple visits behavior: # - first time, it checks for a file named `filename` # - second time, it checks for a file named `filename`_visit1 # etc... requests_mock_datadir_visits = requests_mock_datadir_factory(has_multi_visit=True) @pytest.fixture def swh_rpc_client(swh_rpc_client_class, swh_rpc_adapter): """This fixture generates an RPCClient instance that uses the class generated by the rpc_client_class fixture as backend. Since it uses the swh_rpc_adapter, HTTP queries will be intercepted and routed directly to the current Flask app (as provided by the `app` fixture). So this stack of fixtures allows to test the RPCClient -> RPCServerApp communication path using a real RPCClient instance and a real Flask (RPCServerApp) app instance. To use this fixture: - ensure an `app` fixture exists and generate a Flask application, - implement an `swh_rpc_client_class` fixtures that returns the RPCClient-based class to use as client side for the tests, - implement your tests using this `swh_rpc_client` fixture. See swh/core/api/tests/test_rpc_client_server.py for an example of usage. """ url = "mock://example.com" cli = swh_rpc_client_class(url=url) # we need to clear the list of existing adapters here so we ensure we # have one and only one adapter which is then used for all the requests. cli.session.adapters.clear() cli.session.mount("mock://", swh_rpc_adapter) return cli @pytest.fixture def swh_rpc_adapter(app): """Fixture that generates a requests.Adapter instance that can be used to test client/servers code based on swh.core.api classes. See swh/core/api/tests/test_rpc_client_server.py for an example of usage. """ with app.test_client() as client: yield RPCTestAdapter(client) class RPCTestAdapter(BaseAdapter): def __init__(self, client): self._client = client def build_response(self, req, resp): response = requests.Response() # Fallback to None if there's no status_code, for whatever reason. response.status_code = resp.status_code # Make headers case-insensitive. response.headers = CaseInsensitiveDict(getattr(resp, "headers", {})) # Set encoding. response.encoding = get_encoding_from_headers(response.headers) response.raw = resp response.reason = response.raw.status if isinstance(req.url, bytes): response.url = req.url.decode("utf-8") else: response.url = req.url # Give the Response some context. response.request = req response.connection = self response._content = resp.data return response def send(self, request, **kw): """ Overrides ``requests.adapters.BaseAdapter.send`` """ resp = self._client.open( request.url, method=request.method, headers=request.headers.items(), data=request.body, ) return self.build_response(request, resp) @pytest.fixture def flask_app_client(app): with app.test_client() as client: yield client # stolen from pytest-flask, required to have url_for() working within tests # using flask_app_client fixture. @pytest.fixture(autouse=True) -def _push_request_context(request): +def _push_request_context(request: FixtureRequest): """During tests execution request context has been pushed, e.g. `url_for`, `session`, etc. can be used in tests as is:: def test_app(app, client): assert client.get(url_for('myview')).status_code == 200 """ if "app" not in request.fixturenames: return app = request.getfixturevalue("app") ctx = app.test_request_context() ctx.push() def teardown(): ctx.pop() request.addfinalizer(teardown) diff --git a/swh/core/tests/test_config.py b/swh/core/tests/test_config.py index df310ce..8d02b2c 100644 --- a/swh/core/tests/test_config.py +++ b/swh/core/tests/test_config.py @@ -1,364 +1,364 @@ # Copyright (C) 2015-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import os import shutil import pkg_resources.extern.packaging.version import pytest import yaml from swh.core import config pytest_v = pkg_resources.get_distribution("pytest").parsed_version if pytest_v < pkg_resources.extern.packaging.version.parse("3.9"): @pytest.fixture - def tmp_path(request): + def tmp_path(): import pathlib import tempfile with tempfile.TemporaryDirectory() as tmpdir: yield pathlib.Path(tmpdir) default_conf = { "a": ("int", 2), "b": ("string", "default-string"), "c": ("bool", True), "d": ("int", 10), "e": ("int", None), "f": ("bool", None), "g": ("string", None), "h": ("bool", True), "i": ("bool", True), "ls": ("list[str]", ["a", "b", "c"]), "li": ("list[int]", [42, 43]), } other_default_conf = { "a": ("int", 3), } full_default_conf = default_conf.copy() full_default_conf["a"] = other_default_conf["a"] parsed_default_conf = {key: value for key, (type, value) in default_conf.items()} parsed_conffile = { "a": 1, "b": "this is a string", "c": True, "d": 10, "e": None, "f": None, "g": None, "h": False, "i": True, "ls": ["list", "of", "strings"], "li": [1, 2, 3, 4], } @pytest.fixture def swh_config(tmp_path): # create a temporary folder conffile = tmp_path / "config.yml" conf_contents = """ a: 1 b: this is a string c: true h: false ls: list, of, strings li: 1, 2, 3, 4 """ conffile.open("w").write(conf_contents) return conffile @pytest.fixture def swh_config_unreadable(swh_config): # Create an unreadable, proper configuration file os.chmod(str(swh_config), 0o000) yield swh_config # Make the broken perms file readable again to be able to remove them os.chmod(str(swh_config), 0o644) @pytest.fixture def swh_config_unreadable_dir(swh_config): # Create a proper configuration file in an unreadable directory perms_broken_dir = swh_config.parent / "unreadabledir" perms_broken_dir.mkdir() shutil.move(str(swh_config), str(perms_broken_dir)) os.chmod(str(perms_broken_dir), 0o000) yield perms_broken_dir / swh_config.name # Make the broken perms items readable again to be able to remove them os.chmod(str(perms_broken_dir), 0o755) @pytest.fixture def swh_config_empty(tmp_path): # create a temporary folder conffile = tmp_path / "config.yml" conffile.touch() return conffile def test_read(swh_config): # when res = config.read(str(swh_config), default_conf) # then assert res == parsed_conffile def test_read_no_default_conf(swh_config): """If no default config if provided to read, this should directly parse the config file yaml """ config_path = str(swh_config) actual_config = config.read(config_path) with open(config_path) as f: expected_config = yaml.safe_load(f) assert actual_config == expected_config def test_read_empty_file(): # when res = config.read(None, default_conf) # then assert res == parsed_default_conf def test_support_non_existing_conffile(tmp_path): # when res = config.read(str(tmp_path / "void.yml"), default_conf) # then assert res == parsed_default_conf def test_support_empty_conffile(swh_config_empty): # when res = config.read(str(swh_config_empty), default_conf) # then assert res == parsed_default_conf def test_raise_on_broken_directory_perms(swh_config_unreadable_dir): with pytest.raises(PermissionError): config.read(str(swh_config_unreadable_dir), default_conf) def test_raise_on_broken_file_perms(swh_config_unreadable): with pytest.raises(PermissionError): config.read(str(swh_config_unreadable), default_conf) def test_merge_default_configs(): # when res = config.merge_default_configs(default_conf, other_default_conf) # then assert res == full_default_conf def test_priority_read_nonexist_conf(swh_config): noexist = str(swh_config.parent / "void.yml") # when res = config.priority_read([noexist, str(swh_config)], default_conf) # then assert res == parsed_conffile def test_priority_read_conf_nonexist_empty(swh_config): noexist = swh_config.parent / "void.yml" empty = swh_config.parent / "empty.yml" empty.touch() # when res = config.priority_read( [str(p) for p in (swh_config, noexist, empty)], default_conf ) # then assert res == parsed_conffile def test_priority_read_empty_conf_nonexist(swh_config): noexist = swh_config.parent / "void.yml" empty = swh_config.parent / "empty.yml" empty.touch() # when res = config.priority_read( [str(p) for p in (empty, swh_config, noexist)], default_conf ) # then assert res == parsed_default_conf def test_swh_config_paths(): res = config.swh_config_paths("foo/bar.yml") assert res == [ "~/.config/swh/foo/bar.yml", "~/.swh/foo/bar.yml", "/etc/softwareheritage/foo/bar.yml", ] def test_prepare_folder(tmp_path): # given conf = { "path1": str(tmp_path / "path1"), "path2": str(tmp_path / "path2" / "depth1"), } # the folders does not exists assert not os.path.exists(conf["path1"]), "path1 should not exist." assert not os.path.exists(conf["path2"]), "path2 should not exist." # when config.prepare_folders(conf, "path1") # path1 exists but not path2 assert os.path.exists(conf["path1"]), "path1 should now exist!" assert not os.path.exists(conf["path2"]), "path2 should not exist." # path1 already exists, skips it but creates path2 config.prepare_folders(conf, "path1", "path2") assert os.path.exists(conf["path1"]), "path1 should still exist!" assert os.path.exists(conf["path2"]), "path2 should now exist." def test_merge_config(): cfg_a = { "a": 42, "b": [1, 2, 3], "c": None, "d": {"gheez": 27}, "e": { "ea": "Mr. Bungle", "eb": None, "ec": [11, 12, 13], "ed": {"eda": "Secret Chief 3", "edb": "Faith No More"}, "ee": 451, }, "f": "Janis", } cfg_b = { "a": 43, "b": [41, 42, 43], "c": "Tom Waits", "d": None, "e": { "ea": "Igorrr", "ec": [51, 52], "ed": {"edb": "Sleepytime Gorilla Museum", "edc": "Nils Peter Molvaer"}, }, "g": "Hüsker Dü", } # merge A, B cfg_m = config.merge_configs(cfg_a, cfg_b) assert cfg_m == { "a": 43, # b takes precedence "b": [41, 42, 43], # b takes precedence "c": "Tom Waits", # b takes precedence "d": None, # b['d'] takes precedence (explicit None) "e": { "ea": "Igorrr", # a takes precedence "eb": None, # only in a "ec": [51, 52], # b takes precedence "ed": { "eda": "Secret Chief 3", # only in a "edb": "Sleepytime Gorilla Museum", # b takes precedence "edc": "Nils Peter Molvaer", }, # only defined in b "ee": 451, }, "f": "Janis", # only defined in a "g": "Hüsker Dü", # only defined in b } # merge B, A cfg_m = config.merge_configs(cfg_b, cfg_a) assert cfg_m == { "a": 42, # a takes precedence "b": [1, 2, 3], # a takes precedence "c": None, # a takes precedence "d": {"gheez": 27}, # a takes precedence "e": { "ea": "Mr. Bungle", # a takes precedence "eb": None, # only defined in a "ec": [11, 12, 13], # a takes precedence "ed": { "eda": "Secret Chief 3", # only in a "edb": "Faith No More", # a takes precedence "edc": "Nils Peter Molvaer", }, # only in b "ee": 451, }, "f": "Janis", # only in a "g": "Hüsker Dü", # only in b } def test_merge_config_type_error(): for v in (1, "str", None): with pytest.raises(TypeError): config.merge_configs(v, {}) with pytest.raises(TypeError): config.merge_configs({}, v) for v in (1, "str"): with pytest.raises(TypeError): config.merge_configs({"a": v}, {"a": {}}) with pytest.raises(TypeError): config.merge_configs({"a": {}}, {"a": v}) def test_load_from_envvar_no_environment_var_swh_config_filename_set(): """Without SWH_CONFIG_FILENAME set, load_from_envvar raises""" with pytest.raises(AssertionError, match="SWH_CONFIG_FILENAME environment"): config.load_from_envvar() def test_load_from_envvar_no_default_config(swh_config, monkeypatch): config_path = str(swh_config) monkeypatch.setenv("SWH_CONFIG_FILENAME", config_path) actual_config = config.load_from_envvar() expected_config = config.read(config_path) assert actual_config == expected_config def test_load_from_envvar_with_default_config(swh_config, monkeypatch): default_config = { "number": 666, "something-cool": ["something", "cool"], } config_path = str(swh_config) monkeypatch.setenv("SWH_CONFIG_FILENAME", config_path) actual_config = config.load_from_envvar(default_config) expected_config = config.read(config_path) expected_config.update( {"number": 666, "something-cool": ["something", "cool"],} ) assert actual_config == expected_config