diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 380c658..5aaa4c1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,46 +1,45 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v2.4.0 hooks: - id: trailing-whitespace - id: flake8 - id: check-json - id: check-yaml - repo: https://github.com/codespell-project/codespell rev: v1.16.0 hooks: - id: codespell - repo: local hooks: - id: mypy name: mypy entry: mypy args: [swh] pass_filenames: false language: system types: [python] +- repo: https://github.com/python/black + rev: 19.10b0 + hooks: + - id: black + # unfortunately, we are far from being able to enable this... # - repo: https://github.com/PyCQA/pydocstyle.git # rev: 4.0.0 # hooks: # - id: pydocstyle # name: pydocstyle # description: pydocstyle is a static analysis tool for checking compliance with Python docstring conventions. # entry: pydocstyle --convention=google # language: python # types: [python] -# black requires py3.6+ -#- repo: https://github.com/python/black -# rev: 19.3b0 -# hooks: -# - id: black -# language_version: python3 #- repo: https://github.com/asottile/blacken-docs # rev: v1.0.0-1 # hooks: # - id: blacken-docs # additional_dependencies: [black==19.3b0] diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..b5413f6 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,2 @@ +[tool.black] +target-version = ['py37'] diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..8d79b7e --- /dev/null +++ b/setup.cfg @@ -0,0 +1,6 @@ +[flake8] +# E203: whitespaces before ':' +# E231: missing whitespace after ',' +# W503: line break before binary operator +ignore = E203,E231,W503 +max-line-length = 88 diff --git a/swh/core/api/__init__.py b/swh/core/api/__init__.py index 3f1ce49..c7337a4 100644 --- a/swh/core/api/__init__.py +++ b/swh/core/api/__init__.py @@ -1,421 +1,453 @@ # Copyright (C) 2015-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from collections import abc import functools import inspect import logging import pickle import requests from typing import ( - Any, Callable, ClassVar, Dict, List, Optional, Tuple, Type, Union, + Any, + Callable, + ClassVar, + Dict, + List, + Optional, + Tuple, + Type, + Union, ) from flask import Flask, Request, Response, request, abort from werkzeug.exceptions import HTTPException -from .serializers import (decode_response, - encode_data_client as encode_data, - msgpack_dumps, msgpack_loads, - json_dumps, json_loads, - exception_to_dict) +from .serializers import ( + decode_response, + encode_data_client as encode_data, + msgpack_dumps, + msgpack_loads, + json_dumps, + json_loads, + exception_to_dict, +) -from .negotiation import (Formatter as FormatterBase, - Negotiator as NegotiatorBase, - negotiate as _negotiate) +from .negotiation import ( + Formatter as FormatterBase, + Negotiator as NegotiatorBase, + negotiate as _negotiate, +) logger = logging.getLogger(__name__) # support for content negotiation + class Negotiator(NegotiatorBase): def best_mimetype(self): return request.accept_mimetypes.best_match( - self.accept_mimetypes, 'application/json') + self.accept_mimetypes, "application/json" + ) def _abort(self, status_code, err=None): return abort(status_code, err) def negotiate(formatter_cls, *args, **kwargs): return _negotiate(Negotiator, formatter_cls, *args, **kwargs) class Formatter(FormatterBase): def _make_response(self, body, content_type): return Response(body, content_type=content_type) def configure(self, extra_encoders=None): self.extra_encoders = extra_encoders class JSONFormatter(Formatter): - format = 'json' - mimetypes = ['application/json'] + format = "json" + mimetypes = ["application/json"] def render(self, obj): return json_dumps(obj, extra_encoders=self.extra_encoders) class MsgpackFormatter(Formatter): - format = 'msgpack' - mimetypes = ['application/x-msgpack'] + format = "msgpack" + mimetypes = ["application/x-msgpack"] def render(self, obj): return msgpack_dumps(obj, extra_encoders=self.extra_encoders) # base API classes + class RemoteException(Exception): """raised when remote returned an out-of-band failure notification, e.g., as a HTTP status code or serialized exception Attributes: response: HTTP response corresponding to the failure """ - def __init__(self, payload: Optional[Any] = None, - response: Optional[requests.Response] = None): + + def __init__( + self, + payload: Optional[Any] = None, + response: Optional[requests.Response] = None, + ): if payload is not None: super().__init__(payload) else: super().__init__() self.response = response def __str__(self): - if self.args and isinstance(self.args[0], dict) \ - and 'type' in self.args[0] and 'args' in self.args[0]: + if ( + self.args + and isinstance(self.args[0], dict) + and "type" in self.args[0] + and "args" in self.args[0] + ): return ( - f'') + f"' + ) else: return super().__str__() def remote_api_endpoint(path): def dec(f): f._endpoint_path = path return f + return dec class APIError(Exception): """API Error""" + def __str__(self): - return ('An unexpected error occurred in the backend: {}' - .format(self.args)) + return "An unexpected error occurred in the backend: {}".format(self.args) class MetaRPCClient(type): """Metaclass for RPCClient, which adds a method for each endpoint of the database it is designed to access. See for example :class:`swh.indexer.storage.api.client.RemoteStorage`""" + def __new__(cls, name, bases, attributes): # For each method wrapped with @remote_api_endpoint in an API backend # (eg. :class:`swh.indexer.storage.IndexerStorage`), add a new # method in RemoteStorage, with the same documentation. # # Note that, despite the usage of decorator magic (eg. functools.wrap), # this never actually calls an IndexerStorage method. - backend_class = attributes.get('backend_class', None) + backend_class = attributes.get("backend_class", None) for base in bases: if backend_class is not None: break - backend_class = getattr(base, 'backend_class', None) + backend_class = getattr(base, "backend_class", None) if backend_class: for (meth_name, meth) in backend_class.__dict__.items(): - if hasattr(meth, '_endpoint_path'): + if hasattr(meth, "_endpoint_path"): cls.__add_endpoint(meth_name, meth, attributes) return super().__new__(cls, name, bases, attributes) @staticmethod def __add_endpoint(meth_name, meth, attributes): wrapped_meth = inspect.unwrap(meth) @functools.wraps(meth) # Copy signature and doc def meth_(*args, **kwargs): # Match arguments and parameters - post_data = inspect.getcallargs( - wrapped_meth, *args, **kwargs) + post_data = inspect.getcallargs(wrapped_meth, *args, **kwargs) # Remove arguments that should not be passed - self = post_data.pop('self') - post_data.pop('cur', None) - post_data.pop('db', None) + self = post_data.pop("self") + post_data.pop("cur", None) + post_data.pop("db", None) # Send the request. return self.post(meth._endpoint_path, post_data) + if meth_name not in attributes: attributes[meth_name] = meth_ class RPCClient(metaclass=MetaRPCClient): """Proxy to an internal SWH RPC """ backend_class = None # type: ClassVar[Optional[type]] """For each method of `backend_class` decorated with :func:`remote_api_endpoint`, a method with the same prototype and docstring will be added to this class. Calls to this new method will be translated into HTTP requests to a remote server. This backend class will never be instantiated, it only serves as a template.""" api_exception = APIError # type: ClassVar[Type[Exception]] """The exception class to raise in case of communication error with the server.""" reraise_exceptions: ClassVar[List[Type[Exception]]] = [] """On server errors, if any of the exception classes in this list has the same name as the error name, then the exception will be instantiated and raised instead of a generic RemoteException.""" extra_type_encoders: List[Tuple[type, str, Callable]] = [] """Value of `extra_encoders` passed to `json_dumps` or `msgpack_dumps` to be able to serialize more object types.""" extra_type_decoders: Dict[str, Callable] = {} """Value of `extra_decoders` passed to `json_loads` or `msgpack_loads` to be able to deserialize more object types.""" - def __init__(self, url, api_exception=None, - timeout=None, chunk_size=4096, - reraise_exceptions=None, **kwargs): + def __init__( + self, + url, + api_exception=None, + timeout=None, + chunk_size=4096, + reraise_exceptions=None, + **kwargs, + ): if api_exception: self.api_exception = api_exception if reraise_exceptions: self.reraise_exceptions = reraise_exceptions - base_url = url if url.endswith('/') else url + '/' + base_url = url if url.endswith("/") else url + "/" self.url = base_url self.session = requests.Session() adapter = requests.adapters.HTTPAdapter( - max_retries=kwargs.get('max_retries', 3), - pool_connections=kwargs.get('pool_connections', 20), - pool_maxsize=kwargs.get('pool_maxsize', 100)) + max_retries=kwargs.get("max_retries", 3), + pool_connections=kwargs.get("pool_connections", 20), + pool_maxsize=kwargs.get("pool_maxsize", 100), + ) self.session.mount(self.url, adapter) self.timeout = timeout self.chunk_size = chunk_size def _url(self, endpoint): - return '%s%s' % (self.url, endpoint) + return "%s%s" % (self.url, endpoint) def raw_verb(self, verb, endpoint, **opts): - if 'chunk_size' in opts: + if "chunk_size" in opts: # if the chunk_size argument has been passed, consider the user # also wants stream=True, otherwise, what's the point. - opts['stream'] = True - if self.timeout and 'timeout' not in opts: - opts['timeout'] = self.timeout + opts["stream"] = True + if self.timeout and "timeout" not in opts: + opts["timeout"] = self.timeout try: - return getattr(self.session, verb)( - self._url(endpoint), - **opts - ) + return getattr(self.session, verb)(self._url(endpoint), **opts) except requests.exceptions.ConnectionError as e: raise self.api_exception(e) def post(self, endpoint, data, **opts): if isinstance(data, (abc.Iterator, abc.Generator)): data = (self._encode_data(x) for x in data) else: data = self._encode_data(data) - chunk_size = opts.pop('chunk_size', self.chunk_size) + chunk_size = opts.pop("chunk_size", self.chunk_size) response = self.raw_verb( - 'post', endpoint, data=data, - headers={'content-type': 'application/x-msgpack', - 'accept': 'application/x-msgpack'}, - **opts) - if opts.get('stream') or \ - response.headers.get('transfer-encoding') == 'chunked': + "post", + endpoint, + data=data, + headers={ + "content-type": "application/x-msgpack", + "accept": "application/x-msgpack", + }, + **opts, + ) + if opts.get("stream") or response.headers.get("transfer-encoding") == "chunked": self.raise_for_status(response) return response.iter_content(chunk_size) else: return self._decode_response(response) def _encode_data(self, data): return encode_data(data, extra_encoders=self.extra_type_encoders) post_stream = post def get(self, endpoint, **opts): - chunk_size = opts.pop('chunk_size', self.chunk_size) + chunk_size = opts.pop("chunk_size", self.chunk_size) response = self.raw_verb( - 'get', endpoint, - headers={'accept': 'application/x-msgpack'}, - **opts) - if opts.get('stream') or \ - response.headers.get('transfer-encoding') == 'chunked': + "get", endpoint, headers={"accept": "application/x-msgpack"}, **opts + ) + if opts.get("stream") or response.headers.get("transfer-encoding") == "chunked": self.raise_for_status(response) return response.iter_content(chunk_size) else: return self._decode_response(response) def get_stream(self, endpoint, **opts): return self.get(endpoint, stream=True, **opts) def raise_for_status(self, response) -> None: """check response HTTP status code and raise an exception if it denotes an error; do nothing otherwise """ status_code = response.status_code status_class = response.status_code // 100 if status_code == 404: - raise RemoteException(payload='404 not found', response=response) + raise RemoteException(payload="404 not found", response=response) exception = None # TODO: only old servers send pickled error; stop trying to unpickle # after they are all upgraded try: if status_class == 4: data = self._decode_response(response, check_status=False) if isinstance(data, dict): for exc_type in self.reraise_exceptions: - if exc_type.__name__ == data['exception']['type']: - exception = exc_type(*data['exception']['args']) + if exc_type.__name__ == data["exception"]["type"]: + exception = exc_type(*data["exception"]["args"]) break else: - exception = RemoteException(payload=data['exception'], - response=response) + exception = RemoteException( + payload=data["exception"], response=response + ) else: exception = pickle.loads(data) elif status_class == 5: data = self._decode_response(response, check_status=False) - if 'exception_pickled' in data: - exception = pickle.loads(data['exception_pickled']) + if "exception_pickled" in data: + exception = pickle.loads(data["exception_pickled"]) else: - exception = RemoteException(payload=data['exception'], - response=response) + exception = RemoteException( + payload=data["exception"], response=response + ) except (TypeError, pickle.UnpicklingError): raise RemoteException(payload=data, response=response) if exception: raise exception from None if status_class != 2: raise RemoteException( - payload=f'API HTTP error: {status_code} {response.content}', - response=response) + payload=f"API HTTP error: {status_code} {response.content}", + response=response, + ) def _decode_response(self, response, check_status=True): if check_status: self.raise_for_status(response) - return decode_response( - response, extra_decoders=self.extra_type_decoders) + return decode_response(response, extra_decoders=self.extra_type_decoders) def __repr__(self): - return '<{} url={}>'.format(self.__class__.__name__, self.url) + return "<{} url={}>".format(self.__class__.__name__, self.url) class BytesRequest(Request): """Request with proper escaping of arbitrary byte sequences.""" - encoding = 'utf-8' - encoding_errors = 'surrogateescape' + + encoding = "utf-8" + encoding_errors = "surrogateescape" ENCODERS: Dict[str, Callable[[Any], Union[bytes, str]]] = { - 'application/x-msgpack': msgpack_dumps, - 'application/json': json_dumps, + "application/x-msgpack": msgpack_dumps, + "application/json": json_dumps, } def encode_data_server( - data, content_type='application/x-msgpack', extra_type_encoders=None): - encoded_data = ENCODERS[content_type]( - data, extra_encoders=extra_type_encoders) - return Response( - encoded_data, - mimetype=content_type, - ) + data, content_type="application/x-msgpack", extra_type_encoders=None +): + encoded_data = ENCODERS[content_type](data, extra_encoders=extra_type_encoders) + return Response(encoded_data, mimetype=content_type,) def decode_request(request, extra_decoders=None): content_type = request.mimetype data = request.get_data() if not data: return {} - if content_type == 'application/x-msgpack': + if content_type == "application/x-msgpack": r = msgpack_loads(data, extra_decoders=extra_decoders) - elif content_type == 'application/json': + elif content_type == "application/json": # XXX this .decode() is needed for py35. # Should not be needed any more with py37 - r = json_loads(data.decode('utf-8'), extra_decoders=extra_decoders) + r = json_loads(data.decode("utf-8"), extra_decoders=extra_decoders) else: - raise ValueError('Wrong content type `%s` for API request' - % content_type) + raise ValueError("Wrong content type `%s` for API request" % content_type) return r def error_handler(exception, encoder, status_code=500): logging.exception(exception) response = encoder(exception_to_dict(exception)) if isinstance(exception, HTTPException): response.status_code = exception.code else: # TODO: differentiate between server errors and client errors response.status_code = status_code return response class RPCServerApp(Flask): """For each endpoint of the given `backend_class`, tells app.route to call a function that decodes the request and sends it to the backend object provided by the factory. :param Any backend_class: The class of the backend, which will be analyzed to look for API endpoints. :param Optional[Callable[[], backend_class]] backend_factory: A function with no argument that returns an instance of `backend_class`. If unset, defaults to calling `backend_class` constructor directly. """ + request_class = BytesRequest extra_type_encoders: List[Tuple[type, str, Callable]] = [] """Value of `extra_encoders` passed to `json_dumps` or `msgpack_dumps` to be able to serialize more object types.""" extra_type_decoders: Dict[str, Callable] = {} """Value of `extra_decoders` passed to `json_loads` or `msgpack_loads` to be able to deserialize more object types.""" - def __init__(self, *args, backend_class=None, backend_factory=None, - **kwargs): + def __init__(self, *args, backend_class=None, backend_factory=None, **kwargs): super().__init__(*args, **kwargs) self.backend_class = backend_class if backend_class is not None: if backend_factory is None: backend_factory = backend_class for (meth_name, meth) in backend_class.__dict__.items(): - if hasattr(meth, '_endpoint_path'): + if hasattr(meth, "_endpoint_path"): self.__add_endpoint(meth_name, meth, backend_factory) def __add_endpoint(self, meth_name, meth, backend_factory): from flask import request - @self.route('/'+meth._endpoint_path, methods=['POST']) + @self.route("/" + meth._endpoint_path, methods=["POST"]) @negotiate(MsgpackFormatter, extra_encoders=self.extra_type_encoders) @negotiate(JSONFormatter, extra_encoders=self.extra_type_encoders) @functools.wraps(meth) # Copy signature and doc def _f(): # Call the actual code obj_meth = getattr(backend_factory(), meth_name) - kw = decode_request( - request, extra_decoders=self.extra_type_decoders) + kw = decode_request(request, extra_decoders=self.extra_type_decoders) return obj_meth(**kw) diff --git a/swh/core/api/asynchronous.py b/swh/core/api/asynchronous.py index 1746915..0483043 100644 --- a/swh/core/api/asynchronous.py +++ b/swh/core/api/asynchronous.py @@ -1,97 +1,97 @@ # Copyright (C) 2017-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from collections import OrderedDict import logging from typing import Tuple, Type import aiohttp.web from deprecated import deprecated import multidict from .serializers import msgpack_dumps, msgpack_loads from .serializers import json_dumps, json_loads from .serializers import exception_to_dict from aiohttp_utils import negotiation, Response def encode_msgpack(data, **kwargs): return aiohttp.web.Response( body=msgpack_dumps(data), - headers=multidict.MultiDict( - {'Content-Type': 'application/x-msgpack'}), - **kwargs + headers=multidict.MultiDict({"Content-Type": "application/x-msgpack"}), + **kwargs, ) encode_data_server = Response def render_msgpack(request, data): return msgpack_dumps(data) def render_json(request, data): return json_dumps(data) async def decode_request(request): - content_type = request.headers.get('Content-Type').split(';')[0].strip() + content_type = request.headers.get("Content-Type").split(";")[0].strip() data = await request.read() if not data: return {} - if content_type == 'application/x-msgpack': + if content_type == "application/x-msgpack": r = msgpack_loads(data) - elif content_type == 'application/json': + elif content_type == "application/json": r = json_loads(data) else: - raise ValueError('Wrong content type `%s` for API request' - % content_type) + raise ValueError("Wrong content type `%s` for API request" % content_type) return r async def error_middleware(app, handler): async def middleware_handler(request): try: return await handler(request) except Exception as e: if isinstance(e, aiohttp.web.HTTPException): raise logging.exception(e) res = exception_to_dict(e) if isinstance(e, app.client_exception_classes): status = 400 else: status = 500 return encode_data_server(res, status=status) + return middleware_handler class RPCServerApp(aiohttp.web.Application): client_exception_classes: Tuple[Type[Exception], ...] = () """Exceptions that should be handled as a client error (eg. object not found, invalid argument)""" def __init__(self, *args, middlewares=(), **kwargs): middlewares = (error_middleware,) + middlewares # renderers are sorted in order of increasing desirability (!) # see mimeparse.best_match() docstring. - renderers = OrderedDict([ - ('application/json', render_json), - ('application/x-msgpack', render_msgpack), - ]) + renderers = OrderedDict( + [ + ("application/json", render_json), + ("application/x-msgpack", render_msgpack), + ] + ) nego_middleware = negotiation.negotiation_middleware( - renderers=renderers, - force_rendering=True) + renderers=renderers, force_rendering=True + ) middlewares = (nego_middleware,) + middlewares super().__init__(*args, middlewares=middlewares, **kwargs) -@deprecated(version='0.0.64', - reason='Use the RPCServerApp instead') +@deprecated(version="0.0.64", reason="Use the RPCServerApp instead") class SWHRemoteAPI(RPCServerApp): pass diff --git a/swh/core/api/gunicorn_config.py b/swh/core/api/gunicorn_config.py index 56bb170..ce618e0 100644 --- a/swh/core/api/gunicorn_config.py +++ b/swh/core/api/gunicorn_config.py @@ -1,30 +1,39 @@ # Copyright (C) 2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information """Default values for gunicorn's configuration. Other packages may override them by importing `*` from this module and redefining functions and variables they want. May be imported by gunicorn using `--config 'python:swh.core.api.gunicorn_config'`.""" from ..sentry import init_sentry def post_fork( - server, worker, *, default_sentry_dsn=None, flask=True, - sentry_integrations=None, extra_sentry_kwargs={}): + server, + worker, + *, + default_sentry_dsn=None, + flask=True, + sentry_integrations=None, + extra_sentry_kwargs={}, +): # Initializes sentry as soon as possible in gunicorn's worker processes. sentry_integrations = sentry_integrations or [] if flask: from sentry_sdk.integrations.flask import FlaskIntegration + sentry_integrations.append(FlaskIntegration()) init_sentry( - default_sentry_dsn, integrations=sentry_integrations, - extra_kwargs=extra_sentry_kwargs) + default_sentry_dsn, + integrations=sentry_integrations, + extra_kwargs=extra_sentry_kwargs, + ) diff --git a/swh/core/api/negotiation.py b/swh/core/api/negotiation.py index 1322862..4e2abab 100644 --- a/swh/core/api/negotiation.py +++ b/swh/core/api/negotiation.py @@ -1,159 +1,157 @@ # This code is a partial and adapted copy of # https://github.com/nickstenning/negotiate # # Copyright 2012-2013 Nick Stenning # 2019 The Software Heritage developers # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in # all copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # from collections import defaultdict from decorator import decorator from inspect import getcallargs -from typing import Any, List, Optional, Callable, \ - Type, NoReturn, DefaultDict +from typing import Any, List, Optional, Callable, Type, NoReturn, DefaultDict from requests import Response class FormatterNotFound(Exception): pass class Formatter: format: Optional[str] = None mimetypes: List[str] = [] def __init__(self, request_mimetype: Optional[str] = None) -> None: if request_mimetype is None or request_mimetype not in self.mimetypes: try: self.response_mimetype = self.mimetypes[0] except IndexError: raise NotImplementedError( - "%s.mimetypes should be a non-empty list" % - self.__class__.__name__) + "%s.mimetypes should be a non-empty list" % self.__class__.__name__ + ) else: self.response_mimetype = request_mimetype def configure(self) -> None: pass def render(self, obj: Any) -> bytes: raise NotImplementedError( - "render() should be implemented by Formatter subclasses") + "render() should be implemented by Formatter subclasses" + ) def __call__(self, obj: Any) -> Response: return self._make_response( - self.render(obj), content_type=self.response_mimetype) + self.render(obj), content_type=self.response_mimetype + ) def _make_response(self, body: bytes, content_type: str) -> Response: raise NotImplementedError( "_make_response() should be implemented by " "framework-specific subclasses of Formatter" ) class Negotiator: - def __init__(self, func: Callable[..., Any]) -> None: self.func = func self._formatters: List[Type[Formatter]] = [] self._formatters_by_format: DefaultDict = defaultdict(list) self._formatters_by_mimetype: DefaultDict = defaultdict(list) def __call__(self, *args, **kwargs) -> Response: result = self.func(*args, **kwargs) - format = getcallargs(self.func, *args, **kwargs).get('format') + format = getcallargs(self.func, *args, **kwargs).get("format") mimetype = self.best_mimetype() try: formatter = self.get_formatter(format, mimetype) except FormatterNotFound as e: return self._abort(404, str(e)) return formatter(result) - def register_formatter(self, formatter: Type[Formatter], - *args, **kwargs) -> None: + def register_formatter(self, formatter: Type[Formatter], *args, **kwargs) -> None: self._formatters.append(formatter) - self._formatters_by_format[formatter.format].append( - (formatter, args, kwargs)) + self._formatters_by_format[formatter.format].append((formatter, args, kwargs)) for mimetype in formatter.mimetypes: - self._formatters_by_mimetype[mimetype].append( - (formatter, args, kwargs)) + self._formatters_by_mimetype[mimetype].append((formatter, args, kwargs)) - def get_formatter(self, format: Optional[str] = None, - mimetype: Optional[str] = None) -> Formatter: + def get_formatter( + self, format: Optional[str] = None, mimetype: Optional[str] = None + ) -> Formatter: if format is None and mimetype is None: raise TypeError( "get_formatter expects one of the 'format' or 'mimetype' " - "kwargs to be set") + "kwargs to be set" + ) if format is not None: try: # the first added will be the most specific - formatter_cls, args, kwargs = ( - self._formatters_by_format[format][0]) + formatter_cls, args, kwargs = self._formatters_by_format[format][0] except IndexError: - raise FormatterNotFound( - "Formatter for format '%s' not found!" % format) + raise FormatterNotFound("Formatter for format '%s' not found!" % format) elif mimetype is not None: try: # the first added will be the most specific - formatter_cls, args, kwargs = ( - self._formatters_by_mimetype[mimetype][0]) + formatter_cls, args, kwargs = self._formatters_by_mimetype[mimetype][0] except IndexError: raise FormatterNotFound( - "Formatter for mimetype '%s' not found!" % mimetype) + "Formatter for mimetype '%s' not found!" % mimetype + ) formatter = formatter_cls(request_mimetype=mimetype) formatter.configure(*args, **kwargs) return formatter @property def accept_mimetypes(self) -> List[str]: return [m for f in self._formatters for m in f.mimetypes] def best_mimetype(self) -> str: raise NotImplementedError( "best_mimetype() should be implemented in " "framework-specific subclasses of Negotiator" ) def _abort(self, status_code: int, err: Optional[str] = None) -> NoReturn: raise NotImplementedError( "_abort() should be implemented in framework-specific " "subclasses of Negotiator" ) -def negotiate(negotiator_cls: Type[Negotiator], formatter_cls: Type[Formatter], - *args, **kwargs) -> Callable: +def negotiate( + negotiator_cls: Type[Negotiator], formatter_cls: Type[Formatter], *args, **kwargs +) -> Callable: def _negotiate(f, *args, **kwargs): return f.negotiator(*args, **kwargs) def decorate(f): - if not hasattr(f, 'negotiator'): + if not hasattr(f, "negotiator"): f.negotiator = negotiator_cls(f) f.negotiator.register_formatter(formatter_cls, *args, **kwargs) return decorator(_negotiate, f) return decorate diff --git a/swh/core/api/serializers.py b/swh/core/api/serializers.py index 9d7c9ec..2d01a67 100644 --- a/swh/core/api/serializers.py +++ b/swh/core/api/serializers.py @@ -1,250 +1,243 @@ # Copyright (C) 2015-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import base64 import datetime import json import traceback import types from uuid import UUID import arrow import iso8601 import msgpack from typing import Any, Dict, Union, Tuple from requests import Response ENCODERS = [ - (arrow.Arrow, 'arrow', arrow.Arrow.isoformat), - (datetime.datetime, 'datetime', datetime.datetime.isoformat), - (datetime.timedelta, 'timedelta', lambda o: { - 'days': o.days, - 'seconds': o.seconds, - 'microseconds': o.microseconds, - }), - (UUID, 'uuid', str), - + (arrow.Arrow, "arrow", arrow.Arrow.isoformat), + (datetime.datetime, "datetime", datetime.datetime.isoformat), + ( + datetime.timedelta, + "timedelta", + lambda o: { + "days": o.days, + "seconds": o.seconds, + "microseconds": o.microseconds, + }, + ), + (UUID, "uuid", str), # Only for JSON: - (bytes, 'bytes', lambda o: base64.b85encode(o).decode('ascii')), + (bytes, "bytes", lambda o: base64.b85encode(o).decode("ascii")), ] DECODERS = { - 'arrow': arrow.get, - 'datetime': lambda d: iso8601.parse_date(d, default_timezone=None), - 'timedelta': lambda d: datetime.timedelta(**d), - 'uuid': UUID, - + "arrow": arrow.get, + "datetime": lambda d: iso8601.parse_date(d, default_timezone=None), + "timedelta": lambda d: datetime.timedelta(**d), + "uuid": UUID, # Only for JSON: - 'bytes': base64.b85decode, + "bytes": base64.b85decode, } def encode_data_client(data: Any, extra_encoders=None) -> bytes: try: return msgpack_dumps(data, extra_encoders=extra_encoders) except OverflowError as e: - raise ValueError('Limits were reached. Please, check your input.\n' + - str(e)) + raise ValueError("Limits were reached. Please, check your input.\n" + str(e)) def decode_response(response: Response, extra_decoders=None) -> Any: - content_type = response.headers['content-type'] - - if content_type.startswith('application/x-msgpack'): - r = msgpack_loads(response.content, - extra_decoders=extra_decoders) - elif content_type.startswith('application/json'): - r = json_loads(response.text, - extra_decoders=extra_decoders) - elif content_type.startswith('text/'): + content_type = response.headers["content-type"] + + if content_type.startswith("application/x-msgpack"): + r = msgpack_loads(response.content, extra_decoders=extra_decoders) + elif content_type.startswith("application/json"): + r = json_loads(response.text, extra_decoders=extra_decoders) + elif content_type.startswith("text/"): r = response.text else: - raise ValueError('Wrong content type `%s` for API response' - % content_type) + raise ValueError("Wrong content type `%s` for API response" % content_type) return r class SWHJSONEncoder(json.JSONEncoder): """JSON encoder for data structures generated by Software Heritage. This JSON encoder extends the default Python JSON encoder and adds awareness for the following specific types: - bytes (get encoded as a Base85 string); - datetime.datetime (get encoded as an ISO8601 string). Non-standard types get encoded as a a dictionary with two keys: - swhtype with value 'bytes' or 'datetime'; - d containing the encoded value. SWHJSONEncoder also encodes arbitrary iterables as a list (allowing serialization of generators). Caveats: Limitations in the JSONEncoder extension mechanism prevent us from "escaping" dictionaries that only contain the swhtype and d keys, and therefore arbitrary data structures can't be round-tripped through SWHJSONEncoder and SWHJSONDecoder. """ def __init__(self, extra_encoders=None, **kwargs): super().__init__(**kwargs) self.encoders = ENCODERS if extra_encoders: self.encoders += extra_encoders - def default(self, o: Any - ) -> Union[Dict[str, Union[Dict[str, int], str]], list]: + def default(self, o: Any) -> Union[Dict[str, Union[Dict[str, int], str]], list]: for (type_, type_name, encoder) in self.encoders: if isinstance(o, type_): return { - 'swhtype': type_name, - 'd': encoder(o), + "swhtype": type_name, + "d": encoder(o), } try: return super().default(o) except TypeError as e: try: iterable = iter(o) except TypeError: raise e from None else: return list(iterable) class SWHJSONDecoder(json.JSONDecoder): """JSON decoder for data structures encoded with SWHJSONEncoder. This JSON decoder extends the default Python JSON decoder, allowing the decoding of: - bytes (encoded as a Base85 string); - datetime.datetime (encoded as an ISO8601 string). Non-standard types must be encoded as a a dictionary with exactly two keys: - swhtype with value 'bytes' or 'datetime'; - d containing the encoded value. To limit the impact our encoding, if the swhtype key doesn't contain a known value, the dictionary is decoded as-is. """ def __init__(self, extra_decoders=None, **kwargs): super().__init__(**kwargs) self.decoders = DECODERS if extra_decoders: self.decoders = {**self.decoders, **extra_decoders} def decode_data(self, o: Any) -> Any: if isinstance(o, dict): - if set(o.keys()) == {'d', 'swhtype'}: - if o['swhtype'] == 'bytes': - return base64.b85decode(o['d']) - decoder = self.decoders.get(o['swhtype']) + if set(o.keys()) == {"d", "swhtype"}: + if o["swhtype"] == "bytes": + return base64.b85decode(o["d"]) + decoder = self.decoders.get(o["swhtype"]) if decoder: - return decoder(self.decode_data(o['d'])) + return decoder(self.decode_data(o["d"])) return {key: self.decode_data(value) for key, value in o.items()} if isinstance(o, list): return [self.decode_data(value) for value in o] else: return o def raw_decode(self, s: str, idx: int = 0) -> Tuple[Any, int]: data, index = super().raw_decode(s, idx) return self.decode_data(data), index def json_dumps(data: Any, extra_encoders=None) -> str: - return json.dumps(data, cls=SWHJSONEncoder, - extra_encoders=extra_encoders) + return json.dumps(data, cls=SWHJSONEncoder, extra_encoders=extra_encoders) def json_loads(data: str, extra_decoders=None) -> Any: - return json.loads(data, cls=SWHJSONDecoder, - extra_decoders=extra_decoders) + return json.loads(data, cls=SWHJSONDecoder, extra_decoders=extra_decoders) def msgpack_dumps(data: Any, extra_encoders=None) -> bytes: """Write data as a msgpack stream""" encoders = ENCODERS if extra_encoders: encoders += extra_encoders def encode_types(obj): if isinstance(obj, types.GeneratorType): return list(obj) for (type_, type_name, encoder) in encoders: if isinstance(obj, type_): return { - b'swhtype': type_name, - b'd': encoder(obj), + b"swhtype": type_name, + b"d": encoder(obj), } return obj return msgpack.packb(data, use_bin_type=True, default=encode_types) def msgpack_loads(data: bytes, extra_decoders=None) -> Any: """Read data as a msgpack stream. .. Caution:: This function is used by swh.journal to decode the contents of the journal. This function **must** be kept backwards-compatible. """ decoders = DECODERS if extra_decoders: decoders = {**decoders, **extra_decoders} def decode_types(obj): # Support for current encodings - if set(obj.keys()) == {b'd', b'swhtype'}: - decoder = decoders.get(obj[b'swhtype']) + if set(obj.keys()) == {b"d", b"swhtype"}: + decoder = decoders.get(obj[b"swhtype"]) if decoder: - return decoder(obj[b'd']) + return decoder(obj[b"d"]) # Support for legacy encodings - if b'__datetime__' in obj and obj[b'__datetime__']: - return iso8601.parse_date(obj[b's'], default_timezone=None) - if b'__uuid__' in obj and obj[b'__uuid__']: - return UUID(obj[b's']) - if b'__timedelta__' in obj and obj[b'__timedelta__']: - return datetime.timedelta(**obj[b's']) - if b'__arrow__' in obj and obj[b'__arrow__']: - return arrow.get(obj[b's']) + if b"__datetime__" in obj and obj[b"__datetime__"]: + return iso8601.parse_date(obj[b"s"], default_timezone=None) + if b"__uuid__" in obj and obj[b"__uuid__"]: + return UUID(obj[b"s"]) + if b"__timedelta__" in obj and obj[b"__timedelta__"]: + return datetime.timedelta(**obj[b"s"]) + if b"__arrow__" in obj and obj[b"__arrow__"]: + return arrow.get(obj[b"s"]) # Fallthrough return obj try: try: - return msgpack.unpackb(data, raw=False, - object_hook=decode_types, - strict_map_key=False) + return msgpack.unpackb( + data, raw=False, object_hook=decode_types, strict_map_key=False + ) except TypeError: # msgpack < 0.6.0 - return msgpack.unpackb(data, raw=False, - object_hook=decode_types) + return msgpack.unpackb(data, raw=False, object_hook=decode_types) except TypeError: # msgpack < 0.5.2 - return msgpack.unpackb(data, encoding='utf-8', - object_hook=decode_types) + return msgpack.unpackb(data, encoding="utf-8", object_hook=decode_types) def exception_to_dict(exception): tb = traceback.format_exception(None, exception, exception.__traceback__) return { - 'exception': { - 'type': type(exception).__name__, - 'args': exception.args, - 'message': str(exception), - 'traceback': tb, + "exception": { + "type": type(exception).__name__, + "args": exception.args, + "message": str(exception), + "traceback": tb, } } diff --git a/swh/core/api/tests/server_testing.py b/swh/core/api/tests/server_testing.py index f007d95..0c6e2f4 100644 --- a/swh/core/api/tests/server_testing.py +++ b/swh/core/api/tests/server_testing.py @@ -1,144 +1,146 @@ # Copyright (C) 2015-2018 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import abc import multiprocessing import os import time from urllib.request import urlopen import aiohttp import aiohttp.test_utils class ServerTestFixtureBaseClass(metaclass=abc.ABCMeta): """Base class for http client/server testing implementations. Override this class to implement the following methods: - process_config: to do something needed for the server configuration (e.g propagate the configuration to other part) - define_worker_function: define the function that will actually run the server. To ensure test isolation, each test will run in a different server and a different folder. In order to correctly work, the subclass must call the parents class's setUp() and tearDown() methods. """ + def setUp(self): super().setUp() self.start_server() def tearDown(self): self.stop_server() super().tearDown() def url(self): - return 'http://127.0.0.1:%d/' % self.port + return "http://127.0.0.1:%d/" % self.port def process_config(self): """Process the server's configuration. Do something useful for example, pass along the self.config dictionary inside the self.app. By default, do nothing. """ pass @abc.abstractmethod def define_worker_function(self, app, port): """Define how the actual implementation server will run. """ pass def start_server(self): """ Spawn the API server using multiprocessing. """ self.process = None self.process_config() self.port = aiohttp.test_utils.unused_port() worker_fn = self.define_worker_function() self.process = multiprocessing.Process( target=worker_fn, args=(self.app, self.port) ) self.process.start() # Wait max 5 seconds for server to spawn i = 0 while i < 500: try: urlopen(self.url()) except Exception: i += 1 time.sleep(0.01) else: return def stop_server(self): """ Terminate the API server's process. """ if self.process: self.process.terminate() class ServerTestFixture(ServerTestFixtureBaseClass): """Base class for http client/server testing (e.g flask). Mix this in a test class in order to have access to an http server running in background. Note that the subclass should define a dictionary in self.config that contains the server config. And an application in self.app that corresponds to the type of server the tested client needs. To ensure test isolation, each test will run in a different server and a different folder. In order to correctly work, the subclass must call the parents class's setUp() and tearDown() methods. """ + def process_config(self): # WSGI app configuration for key, value in self.config.items(): self.app.config[key] = value def define_worker_function(self): def worker(app, port): # Make Flask 1.0 stop printing its server banner - os.environ['WERKZEUG_RUN_MAIN'] = 'true' + os.environ["WERKZEUG_RUN_MAIN"] = "true" return app.run(port=port, use_reloader=False) return worker class ServerTestFixtureAsync(ServerTestFixtureBaseClass): """Base class for http client/server async testing (e.g aiohttp). Mix this in a test class in order to have access to an http server running in background. Note that the subclass should define an application in self.app that corresponds to the type of server the tested client needs. To ensure test isolation, each test will run in a different server and a different folder. In order to correctly work, the subclass must call the parents class's setUp() and tearDown() methods. """ + def define_worker_function(self): def worker(app, port): - return aiohttp.web.run_app(app, port=int(port), - print=lambda *_: None) + return aiohttp.web.run_app(app, port=int(port), print=lambda *_: None) return worker diff --git a/swh/core/api/tests/test_async.py b/swh/core/api/tests/test_async.py index f3cea5b..7d24408 100644 --- a/swh/core/api/tests/test_async.py +++ b/swh/core/api/tests/test_async.py @@ -1,223 +1,232 @@ # Copyright (C) 2019-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import datetime import msgpack import json import pytest from swh.core.api.asynchronous import RPCServerApp, Response from swh.core.api.asynchronous import encode_msgpack, decode_request from swh.core.api.serializers import msgpack_dumps, SWHJSONEncoder -pytest_plugins = ['aiohttp.pytest_plugin', 'pytester'] +pytest_plugins = ["aiohttp.pytest_plugin", "pytester"] class TestServerException(Exception): pass class TestClientError(Exception): pass async def root(request): - return Response('toor') + return Response("toor") -STRUCT = {'txt': 'something stupid', - # 'date': datetime.date(2019, 6, 9), # not supported - 'datetime': datetime.datetime(2019, 6, 9, 10, 12), - 'timedelta': datetime.timedelta(days=-2, hours=3), - 'int': 42, - 'float': 3.14, - 'subdata': {'int': 42, - 'datetime': datetime.datetime(2019, 6, 10, 11, 12), - }, - 'list': [42, datetime.datetime(2019, 9, 10, 11, 12), 'ok'], - } + +STRUCT = { + "txt": "something stupid", + # 'date': datetime.date(2019, 6, 9), # not supported + "datetime": datetime.datetime(2019, 6, 9, 10, 12), + "timedelta": datetime.timedelta(days=-2, hours=3), + "int": 42, + "float": 3.14, + "subdata": {"int": 42, "datetime": datetime.datetime(2019, 6, 10, 11, 12),}, + "list": [42, datetime.datetime(2019, 9, 10, 11, 12), "ok"], +} async def struct(request): return Response(STRUCT) async def echo(request): data = await decode_request(request) return Response(data) async def server_exception(request): raise TestServerException() async def client_error(request): raise TestClientError() async def echo_no_nego(request): # let the content negotiation handle the serialization for us... data = await decode_request(request) ret = encode_msgpack(data) return ret def check_mimetype(src, dst): - src = src.split(';')[0].strip() - dst = dst.split(';')[0].strip() + src = src.split(";")[0].strip() + dst = dst.split(";")[0].strip() assert src == dst @pytest.fixture def async_app(): app = RPCServerApp() app.client_exception_classes = (TestClientError,) - app.router.add_route('GET', '/', root) - app.router.add_route('GET', '/struct', struct) - app.router.add_route('POST', '/echo', echo) - app.router.add_route('GET', '/server_exception', server_exception) - app.router.add_route('GET', '/client_error', client_error) - app.router.add_route('POST', '/echo-no-nego', echo_no_nego) + app.router.add_route("GET", "/", root) + app.router.add_route("GET", "/struct", struct) + app.router.add_route("POST", "/echo", echo) + app.router.add_route("GET", "/server_exception", server_exception) + app.router.add_route("GET", "/client_error", client_error) + app.router.add_route("POST", "/echo-no-nego", echo_no_nego) return app async def test_get_simple(async_app, aiohttp_client) -> None: assert async_app is not None cli = await aiohttp_client(async_app) - resp = await cli.get('/') + resp = await cli.get("/") assert resp.status == 200 - check_mimetype(resp.headers['Content-Type'], 'application/x-msgpack') + check_mimetype(resp.headers["Content-Type"], "application/x-msgpack") data = await resp.read() value = msgpack.unpackb(data, raw=False) - assert value == 'toor' + assert value == "toor" async def test_get_server_exception(async_app, aiohttp_client) -> None: cli = await aiohttp_client(async_app) - resp = await cli.get('/server_exception') + resp = await cli.get("/server_exception") assert resp.status == 500 data = await resp.read() data = msgpack.unpackb(data, raw=False) - assert data['exception']['type'] == 'TestServerException' + assert data["exception"]["type"] == "TestServerException" async def test_get_client_error(async_app, aiohttp_client) -> None: cli = await aiohttp_client(async_app) - resp = await cli.get('/client_error') + resp = await cli.get("/client_error") assert resp.status == 400 data = await resp.read() data = msgpack.unpackb(data, raw=False) - assert data['exception']['type'] == 'TestClientError' + assert data["exception"]["type"] == "TestClientError" async def test_get_simple_nego(async_app, aiohttp_client) -> None: cli = await aiohttp_client(async_app) - for ctype in ('x-msgpack', 'json'): - resp = await cli.get('/', headers={'Accept': 'application/%s' % ctype}) + for ctype in ("x-msgpack", "json"): + resp = await cli.get("/", headers={"Accept": "application/%s" % ctype}) assert resp.status == 200 - check_mimetype(resp.headers['Content-Type'], 'application/%s' % ctype) - assert (await decode_request(resp)) == 'toor' + check_mimetype(resp.headers["Content-Type"], "application/%s" % ctype) + assert (await decode_request(resp)) == "toor" async def test_get_struct(async_app, aiohttp_client) -> None: """Test returned structured from a simple GET data is OK""" cli = await aiohttp_client(async_app) - resp = await cli.get('/struct') + resp = await cli.get("/struct") assert resp.status == 200 - check_mimetype(resp.headers['Content-Type'], 'application/x-msgpack') + check_mimetype(resp.headers["Content-Type"], "application/x-msgpack") assert (await decode_request(resp)) == STRUCT async def test_get_struct_nego(async_app, aiohttp_client) -> None: """Test returned structured from a simple GET data is OK""" cli = await aiohttp_client(async_app) - for ctype in ('x-msgpack', 'json'): - resp = await cli.get('/struct', - headers={'Accept': 'application/%s' % ctype}) + for ctype in ("x-msgpack", "json"): + resp = await cli.get("/struct", headers={"Accept": "application/%s" % ctype}) assert resp.status == 200 - check_mimetype(resp.headers['Content-Type'], 'application/%s' % ctype) + check_mimetype(resp.headers["Content-Type"], "application/%s" % ctype) assert (await decode_request(resp)) == STRUCT async def test_post_struct_msgpack(async_app, aiohttp_client) -> None: """Test that msgpack encoded posted struct data is returned as is""" cli = await aiohttp_client(async_app) # simple struct resp = await cli.post( - '/echo', - headers={'Content-Type': 'application/x-msgpack'}, - data=msgpack_dumps({'toto': 42})) + "/echo", + headers={"Content-Type": "application/x-msgpack"}, + data=msgpack_dumps({"toto": 42}), + ) assert resp.status == 200 - check_mimetype(resp.headers['Content-Type'], 'application/x-msgpack') - assert (await decode_request(resp)) == {'toto': 42} + check_mimetype(resp.headers["Content-Type"], "application/x-msgpack") + assert (await decode_request(resp)) == {"toto": 42} # complex struct resp = await cli.post( - '/echo', - headers={'Content-Type': 'application/x-msgpack'}, - data=msgpack_dumps(STRUCT)) + "/echo", + headers={"Content-Type": "application/x-msgpack"}, + data=msgpack_dumps(STRUCT), + ) assert resp.status == 200 - check_mimetype(resp.headers['Content-Type'], 'application/x-msgpack') + check_mimetype(resp.headers["Content-Type"], "application/x-msgpack") assert (await decode_request(resp)) == STRUCT async def test_post_struct_json(async_app, aiohttp_client) -> None: """Test that json encoded posted struct data is returned as is""" cli = await aiohttp_client(async_app) resp = await cli.post( - '/echo', - headers={'Content-Type': 'application/json'}, - data=json.dumps({'toto': 42}, cls=SWHJSONEncoder)) + "/echo", + headers={"Content-Type": "application/json"}, + data=json.dumps({"toto": 42}, cls=SWHJSONEncoder), + ) assert resp.status == 200 - check_mimetype(resp.headers['Content-Type'], 'application/x-msgpack') - assert (await decode_request(resp)) == {'toto': 42} + check_mimetype(resp.headers["Content-Type"], "application/x-msgpack") + assert (await decode_request(resp)) == {"toto": 42} resp = await cli.post( - '/echo', - headers={'Content-Type': 'application/json'}, - data=json.dumps(STRUCT, cls=SWHJSONEncoder)) + "/echo", + headers={"Content-Type": "application/json"}, + data=json.dumps(STRUCT, cls=SWHJSONEncoder), + ) assert resp.status == 200 - check_mimetype(resp.headers['Content-Type'], 'application/x-msgpack') + check_mimetype(resp.headers["Content-Type"], "application/x-msgpack") # assert resp.headers['Content-Type'] == 'application/x-msgpack' assert (await decode_request(resp)) == STRUCT async def test_post_struct_nego(async_app, aiohttp_client) -> None: """Test that json encoded posted struct data is returned as is using content negotiation (accept json or msgpack). """ cli = await aiohttp_client(async_app) - for ctype in ('x-msgpack', 'json'): + for ctype in ("x-msgpack", "json"): resp = await cli.post( - '/echo', - headers={'Content-Type': 'application/json', - 'Accept': 'application/%s' % ctype}, - data=json.dumps(STRUCT, cls=SWHJSONEncoder)) + "/echo", + headers={ + "Content-Type": "application/json", + "Accept": "application/%s" % ctype, + }, + data=json.dumps(STRUCT, cls=SWHJSONEncoder), + ) assert resp.status == 200 - check_mimetype(resp.headers['Content-Type'], 'application/%s' % ctype) + check_mimetype(resp.headers["Content-Type"], "application/%s" % ctype) assert (await decode_request(resp)) == STRUCT async def test_post_struct_no_nego(async_app, aiohttp_client) -> None: """Test that json encoded posted struct data is returned as msgpack when using non-negotiation-compatible handlers. """ cli = await aiohttp_client(async_app) - for ctype in ('x-msgpack', 'json'): + for ctype in ("x-msgpack", "json"): resp = await cli.post( - '/echo-no-nego', - headers={'Content-Type': 'application/json', - 'Accept': 'application/%s' % ctype}, - data=json.dumps(STRUCT, cls=SWHJSONEncoder)) + "/echo-no-nego", + headers={ + "Content-Type": "application/json", + "Accept": "application/%s" % ctype, + }, + data=json.dumps(STRUCT, cls=SWHJSONEncoder), + ) assert resp.status == 200 - check_mimetype(resp.headers['Content-Type'], 'application/x-msgpack') + check_mimetype(resp.headers["Content-Type"], "application/x-msgpack") assert (await decode_request(resp)) == STRUCT diff --git a/swh/core/api/tests/test_gunicorn.py b/swh/core/api/tests/test_gunicorn.py index 1f3cacb..c0d12ef 100644 --- a/swh/core/api/tests/test_gunicorn.py +++ b/swh/core/api/tests/test_gunicorn.py @@ -1,107 +1,116 @@ # Copyright (C) 2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import os import pkg_resources from unittest.mock import patch import swh.core.api.gunicorn_config as gunicorn_config def test_post_fork_default(): - with patch('sentry_sdk.init') as sentry_sdk_init: + with patch("sentry_sdk.init") as sentry_sdk_init: gunicorn_config.post_fork(None, None) sentry_sdk_init.assert_not_called() def test_post_fork_with_dsn_env(): flask_integration = object() # unique object to check for equality - with patch('sentry_sdk.integrations.flask.FlaskIntegration', - new=lambda: flask_integration): - with patch('sentry_sdk.init') as sentry_sdk_init: - with patch.dict(os.environ, {'SWH_SENTRY_DSN': 'test_dsn'}): + with patch( + "sentry_sdk.integrations.flask.FlaskIntegration", new=lambda: flask_integration + ): + with patch("sentry_sdk.init") as sentry_sdk_init: + with patch.dict(os.environ, {"SWH_SENTRY_DSN": "test_dsn"}): gunicorn_config.post_fork(None, None) sentry_sdk_init.assert_called_once_with( - dsn='test_dsn', + dsn="test_dsn", integrations=[flask_integration], debug=False, release=None, environment=None, ) def test_post_fork_with_package_env(): flask_integration = object() # unique object to check for equality - with patch('sentry_sdk.integrations.flask.FlaskIntegration', - new=lambda: flask_integration): - with patch('sentry_sdk.init') as sentry_sdk_init: - with patch.dict(os.environ, {'SWH_SENTRY_DSN': 'test_dsn', - 'SWH_SENTRY_ENVIRONMENT': 'tests', - 'SWH_MAIN_PACKAGE': 'swh.core'}): + with patch( + "sentry_sdk.integrations.flask.FlaskIntegration", new=lambda: flask_integration + ): + with patch("sentry_sdk.init") as sentry_sdk_init: + with patch.dict( + os.environ, + { + "SWH_SENTRY_DSN": "test_dsn", + "SWH_SENTRY_ENVIRONMENT": "tests", + "SWH_MAIN_PACKAGE": "swh.core", + }, + ): gunicorn_config.post_fork(None, None) - version = pkg_resources.get_distribution('swh.core').version + version = pkg_resources.get_distribution("swh.core").version sentry_sdk_init.assert_called_once_with( - dsn='test_dsn', + dsn="test_dsn", integrations=[flask_integration], debug=False, - release='swh.core@' + version, - environment='tests', + release="swh.core@" + version, + environment="tests", ) def test_post_fork_debug(): flask_integration = object() # unique object to check for equality - with patch('sentry_sdk.integrations.flask.FlaskIntegration', - new=lambda: flask_integration): - with patch('sentry_sdk.init') as sentry_sdk_init: - with patch.dict(os.environ, {'SWH_SENTRY_DSN': 'test_dsn', - 'SWH_SENTRY_DEBUG': '1'}): + with patch( + "sentry_sdk.integrations.flask.FlaskIntegration", new=lambda: flask_integration + ): + with patch("sentry_sdk.init") as sentry_sdk_init: + with patch.dict( + os.environ, {"SWH_SENTRY_DSN": "test_dsn", "SWH_SENTRY_DEBUG": "1"} + ): gunicorn_config.post_fork(None, None) sentry_sdk_init.assert_called_once_with( - dsn='test_dsn', + dsn="test_dsn", integrations=[flask_integration], debug=True, release=None, environment=None, ) def test_post_fork_no_flask(): - with patch('sentry_sdk.init') as sentry_sdk_init: - with patch.dict(os.environ, {'SWH_SENTRY_DSN': 'test_dsn'}): + with patch("sentry_sdk.init") as sentry_sdk_init: + with patch.dict(os.environ, {"SWH_SENTRY_DSN": "test_dsn"}): gunicorn_config.post_fork(None, None, flask=False) sentry_sdk_init.assert_called_once_with( - dsn='test_dsn', - integrations=[], - debug=False, - release=None, - environment=None, + dsn="test_dsn", integrations=[], debug=False, release=None, environment=None, ) def test_post_fork_extras(): flask_integration = object() # unique object to check for equality - with patch('sentry_sdk.integrations.flask.FlaskIntegration', - new=lambda: flask_integration): - with patch('sentry_sdk.init') as sentry_sdk_init: - with patch.dict(os.environ, {'SWH_SENTRY_DSN': 'test_dsn'}): + with patch( + "sentry_sdk.integrations.flask.FlaskIntegration", new=lambda: flask_integration + ): + with patch("sentry_sdk.init") as sentry_sdk_init: + with patch.dict(os.environ, {"SWH_SENTRY_DSN": "test_dsn"}): gunicorn_config.post_fork( - None, None, sentry_integrations=['foo'], - extra_sentry_kwargs={'bar': 'baz'}) + None, + None, + sentry_integrations=["foo"], + extra_sentry_kwargs={"bar": "baz"}, + ) sentry_sdk_init.assert_called_once_with( - dsn='test_dsn', - integrations=['foo', flask_integration], + dsn="test_dsn", + integrations=["foo", flask_integration], debug=False, - bar='baz', + bar="baz", release=None, environment=None, ) diff --git a/swh/core/api/tests/test_rpc_client.py b/swh/core/api/tests/test_rpc_client.py index eff708e..24a2fba 100644 --- a/swh/core/api/tests/test_rpc_client.py +++ b/swh/core/api/tests/test_rpc_client.py @@ -1,85 +1,85 @@ # Copyright (C) 2018-2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import re import pytest from swh.core.api import remote_api_endpoint, RPCClient from .test_serializers import ExtraType, extra_encoders, extra_decoders @pytest.fixture def rpc_client(requests_mock): class TestStorage: - @remote_api_endpoint('test_endpoint_url') + @remote_api_endpoint("test_endpoint_url") def test_endpoint(self, test_data, db=None, cur=None): ... - @remote_api_endpoint('path/to/endpoint') + @remote_api_endpoint("path/to/endpoint") def something(self, data, db=None, cur=None): ... - @remote_api_endpoint('serializer_test') + @remote_api_endpoint("serializer_test") def serializer_test(self, data, db=None, cur=None): ... - @remote_api_endpoint('overridden/endpoint') + @remote_api_endpoint("overridden/endpoint") def overridden_method(self, data): - return 'foo' + return "foo" class Testclient(RPCClient): backend_class = TestStorage extra_type_encoders = extra_encoders extra_type_decoders = extra_decoders def overridden_method(self, data): - return 'bar' + return "bar" def callback(request, context): - assert request.headers['Content-Type'] == 'application/x-msgpack' - context.headers['Content-Type'] = 'application/x-msgpack' - if request.path == '/test_endpoint_url': - context.content = b'\xa3egg' - elif request.path == '/path/to/endpoint': - context.content = b'\xa4spam' - elif request.path == '/serializer_test': + assert request.headers["Content-Type"] == "application/x-msgpack" + context.headers["Content-Type"] = "application/x-msgpack" + if request.path == "/test_endpoint_url": + context.content = b"\xa3egg" + elif request.path == "/path/to/endpoint": + context.content = b"\xa4spam" + elif request.path == "/serializer_test": context.content = ( - b'\x82\xc4\x07swhtype\xa9extratype' - b'\xc4\x01d\x92\x81\xa4spam\xa3egg\xa3qux') + b"\x82\xc4\x07swhtype\xa9extratype" + b"\xc4\x01d\x92\x81\xa4spam\xa3egg\xa3qux" + ) else: assert False return context.content - requests_mock.post(re.compile('mock://example.com/'), - content=callback) + requests_mock.post(re.compile("mock://example.com/"), content=callback) - return Testclient(url='mock://example.com') + return Testclient(url="mock://example.com") def test_client(rpc_client): - assert hasattr(rpc_client, 'test_endpoint') - assert hasattr(rpc_client, 'something') + assert hasattr(rpc_client, "test_endpoint") + assert hasattr(rpc_client, "something") - res = rpc_client.test_endpoint('spam') - assert res == 'egg' - res = rpc_client.test_endpoint(test_data='spam') - assert res == 'egg' + res = rpc_client.test_endpoint("spam") + assert res == "egg" + res = rpc_client.test_endpoint(test_data="spam") + assert res == "egg" - res = rpc_client.something('whatever') - assert res == 'spam' - res = rpc_client.something(data='whatever') - assert res == 'spam' + res = rpc_client.something("whatever") + assert res == "spam" + res = rpc_client.something(data="whatever") + assert res == "spam" def test_client_extra_serializers(rpc_client): - res = rpc_client.serializer_test(['foo', ExtraType('bar', b'baz')]) - assert res == ExtraType({'spam': 'egg'}, 'qux') + res = rpc_client.serializer_test(["foo", ExtraType("bar", b"baz")]) + assert res == ExtraType({"spam": "egg"}, "qux") def test_client_overridden_method(rpc_client): - res = rpc_client.overridden_method('foo') - assert res == 'bar' + res = rpc_client.overridden_method("foo") + assert res == "bar" diff --git a/swh/core/api/tests/test_rpc_client_server.py b/swh/core/api/tests/test_rpc_client_server.py index 16a84b2..ec651eb 100644 --- a/swh/core/api/tests/test_rpc_client_server.py +++ b/swh/core/api/tests/test_rpc_client_server.py @@ -1,107 +1,111 @@ # Copyright (C) 2018-2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import pytest from swh.core.api import remote_api_endpoint, RPCServerApp, RPCClient from swh.core.api import error_handler, encode_data_server, RemoteException # this class is used on the server part class RPCTest: - @remote_api_endpoint('endpoint_url') + @remote_api_endpoint("endpoint_url") def endpoint(self, test_data, db=None, cur=None): - assert test_data == 'spam' - return 'egg' + assert test_data == "spam" + return "egg" - @remote_api_endpoint('path/to/endpoint') + @remote_api_endpoint("path/to/endpoint") def something(self, data, db=None, cur=None): return data - @remote_api_endpoint('raises_typeerror') + @remote_api_endpoint("raises_typeerror") def raise_typeerror(self): - raise TypeError('Did I pass through?') + raise TypeError("Did I pass through?") # this class is used on the client part. We cannot inherit from RPCTest # because the automagic metaclass based code that generates the RPCClient # proxy class from this does not handle inheritance properly. # We do add an endpoint on the client side that has no implementation # server-side to test this very situation (in should generate a 404) class RPCTest2: - @remote_api_endpoint('endpoint_url') + @remote_api_endpoint("endpoint_url") def endpoint(self, test_data, db=None, cur=None): - assert test_data == 'spam' - return 'egg' + assert test_data == "spam" + return "egg" - @remote_api_endpoint('path/to/endpoint') + @remote_api_endpoint("path/to/endpoint") def something(self, data, db=None, cur=None): return data - @remote_api_endpoint('not_on_server') + @remote_api_endpoint("not_on_server") def not_on_server(self, db=None, cur=None): - return 'ok' + return "ok" - @remote_api_endpoint('raises_typeerror') + @remote_api_endpoint("raises_typeerror") def raise_typeerror(self): - return 'data' + return "data" class RPCTestClient(RPCClient): backend_class = RPCTest2 @pytest.fixture def app(): # This fixture is used by the 'swh_rpc_adapter' fixture # which is defined in swh/core/pytest_plugin.py - application = RPCServerApp('testapp', backend_class=RPCTest) + application = RPCServerApp("testapp", backend_class=RPCTest) + @application.errorhandler(Exception) def my_error_handler(exception): return error_handler(exception, encode_data_server) + return application @pytest.fixture def swh_rpc_client_class(): # This fixture is used by the 'swh_rpc_client' fixture # which is defined in swh/core/pytest_plugin.py return RPCTestClient def test_api_client_endpoint_missing(swh_rpc_client): with pytest.raises(AttributeError): - swh_rpc_client.missing(data='whatever') + swh_rpc_client.missing(data="whatever") def test_api_server_endpoint_missing(swh_rpc_client): # A 'missing' endpoint (server-side) should raise an exception # due to a 404, since at the end, we do a GET/POST an inexistent URL - with pytest.raises(Exception, match='404 not found'): + with pytest.raises(Exception, match="404 not found"): swh_rpc_client.not_on_server() def test_api_endpoint_kwargs(swh_rpc_client): - res = swh_rpc_client.something(data='whatever') - assert res == 'whatever' - res = swh_rpc_client.endpoint(test_data='spam') - assert res == 'egg' + res = swh_rpc_client.something(data="whatever") + assert res == "whatever" + res = swh_rpc_client.endpoint(test_data="spam") + assert res == "egg" def test_api_endpoint_args(swh_rpc_client): - res = swh_rpc_client.something('whatever') - assert res == 'whatever' - res = swh_rpc_client.endpoint('spam') - assert res == 'egg' + res = swh_rpc_client.something("whatever") + assert res == "whatever" + res = swh_rpc_client.endpoint("spam") + assert res == "egg" def test_api_typeerror(swh_rpc_client): with pytest.raises(RemoteException) as exc_info: swh_rpc_client.raise_typeerror() - assert exc_info.value.args[0]['type'] == 'TypeError' - assert exc_info.value.args[0]['args'] == ['Did I pass through?'] - assert str(exc_info.value) \ + assert exc_info.value.args[0]["type"] == "TypeError" + assert exc_info.value.args[0]["args"] == ["Did I pass through?"] + assert ( + str(exc_info.value) == "" + ) diff --git a/swh/core/api/tests/test_rpc_server.py b/swh/core/api/tests/test_rpc_server.py index ace42b1..31fdb27 100644 --- a/swh/core/api/tests/test_rpc_server.py +++ b/swh/core/api/tests/test_rpc_server.py @@ -1,118 +1,122 @@ # Copyright (C) 2018-2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import pytest import json import msgpack from flask import url_for from swh.core.api import remote_api_endpoint, RPCServerApp from swh.core.api import negotiate, JSONFormatter, MsgpackFormatter from .test_serializers import ExtraType, extra_encoders, extra_decoders class MyRPCServerApp(RPCServerApp): extra_type_encoders = extra_encoders extra_type_decoders = extra_decoders @pytest.fixture def app(): class TestStorage: - @remote_api_endpoint('test_endpoint_url') + @remote_api_endpoint("test_endpoint_url") def test_endpoint(self, test_data, db=None, cur=None): - assert test_data == 'spam' - return 'egg' + assert test_data == "spam" + return "egg" - @remote_api_endpoint('path/to/endpoint') + @remote_api_endpoint("path/to/endpoint") def something(self, data, db=None, cur=None): return data - @remote_api_endpoint('serializer_test') + @remote_api_endpoint("serializer_test") def serializer_test(self, data, db=None, cur=None): - assert data == ['foo', ExtraType('bar', b'baz')] - return ExtraType({'spam': 'egg'}, 'qux') + assert data == ["foo", ExtraType("bar", b"baz")] + return ExtraType({"spam": "egg"}, "qux") - return MyRPCServerApp('testapp', backend_class=TestStorage) + return MyRPCServerApp("testapp", backend_class=TestStorage) def test_api_endpoint(flask_app_client): res = flask_app_client.post( - url_for('something'), - headers=[('Content-Type', 'application/json'), - ('Accept', 'application/json')], - data=json.dumps({'data': 'toto'}), + url_for("something"), + headers=[("Content-Type", "application/json"), ("Accept", "application/json")], + data=json.dumps({"data": "toto"}), ) assert res.status_code == 200 - assert res.mimetype == 'application/json' + assert res.mimetype == "application/json" def test_api_nego_default(flask_app_client): res = flask_app_client.post( - url_for('something'), - headers=[('Content-Type', 'application/json')], - data=json.dumps({'data': 'toto'}), + url_for("something"), + headers=[("Content-Type", "application/json")], + data=json.dumps({"data": "toto"}), ) assert res.status_code == 200 - assert res.mimetype == 'application/json' + assert res.mimetype == "application/json" assert res.data == b'"toto"' def test_api_nego_accept(flask_app_client): res = flask_app_client.post( - url_for('something'), - headers=[('Accept', 'application/x-msgpack'), - ('Content-Type', 'application/x-msgpack')], - data=msgpack.dumps({'data': 'toto'}), + url_for("something"), + headers=[ + ("Accept", "application/x-msgpack"), + ("Content-Type", "application/x-msgpack"), + ], + data=msgpack.dumps({"data": "toto"}), ) assert res.status_code == 200 - assert res.mimetype == 'application/x-msgpack' - assert res.data == b'\xa4toto' + assert res.mimetype == "application/x-msgpack" + assert res.data == b"\xa4toto" def test_rpc_server(flask_app_client): res = flask_app_client.post( - url_for('test_endpoint'), - headers=[('Content-Type', 'application/x-msgpack'), - ('Accept', 'application/x-msgpack')], - data=b'\x81\xa9test_data\xa4spam') + url_for("test_endpoint"), + headers=[ + ("Content-Type", "application/x-msgpack"), + ("Accept", "application/x-msgpack"), + ], + data=b"\x81\xa9test_data\xa4spam", + ) assert res.status_code == 200 - assert res.mimetype == 'application/x-msgpack' - assert res.data == b'\xa3egg' + assert res.mimetype == "application/x-msgpack" + assert res.data == b"\xa3egg" def test_rpc_server_extra_serializers(flask_app_client): res = flask_app_client.post( - url_for('serializer_test'), - headers=[('Content-Type', 'application/x-msgpack'), - ('Accept', 'application/x-msgpack')], - data=b'\x81\xa4data\x92\xa3foo\x82\xc4\x07swhtype\xa9extratype' - b'\xc4\x01d\x92\xa3bar\xc4\x03baz') + url_for("serializer_test"), + headers=[ + ("Content-Type", "application/x-msgpack"), + ("Accept", "application/x-msgpack"), + ], + data=b"\x81\xa4data\x92\xa3foo\x82\xc4\x07swhtype\xa9extratype" + b"\xc4\x01d\x92\xa3bar\xc4\x03baz", + ) assert res.status_code == 200 - assert res.mimetype == 'application/x-msgpack' + assert res.mimetype == "application/x-msgpack" assert res.data == ( - b'\x82\xc4\x07swhtype\xa9extratype\xc4' - b'\x01d\x92\x81\xa4spam\xa3egg\xa3qux') + b"\x82\xc4\x07swhtype\xa9extratype\xc4" b"\x01d\x92\x81\xa4spam\xa3egg\xa3qux" + ) def test_api_negotiate_no_extra_encoders(app, flask_app_client): - url = '/test/negotiate/no/extra/encoders' + url = "/test/negotiate/no/extra/encoders" - @app.route(url, methods=['POST']) + @app.route(url, methods=["POST"]) @negotiate(MsgpackFormatter) @negotiate(JSONFormatter) def endpoint(): - return 'test' + return "test" - res = flask_app_client.post( - url, - headers=[('Content-Type', 'application/json')], - ) + res = flask_app_client.post(url, headers=[("Content-Type", "application/json")],) assert res.status_code == 200 - assert res.mimetype == 'application/json' + assert res.mimetype == "application/json" assert res.data == b'"test"' diff --git a/swh/core/api/tests/test_serializers.py b/swh/core/api/tests/test_serializers.py index 041fdc3..1a4f8be 100644 --- a/swh/core/api/tests/test_serializers.py +++ b/swh/core/api/tests/test_serializers.py @@ -1,172 +1,185 @@ # Copyright (C) 2015-2018 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import datetime import json from typing import Any, Callable, List, Tuple import unittest from uuid import UUID import arrow import requests import requests_mock from swh.core.api.serializers import ( SWHJSONDecoder, SWHJSONEncoder, msgpack_dumps, msgpack_loads, - decode_response + decode_response, ) class ExtraType: def __init__(self, arg1, arg2): self.arg1 = arg1 self.arg2 = arg2 def __repr__(self): - return f'ExtraType({self.arg1}, {self.arg2})' + return f"ExtraType({self.arg1}, {self.arg2})" def __eq__(self, other): - return isinstance(other, ExtraType) \ - and (self.arg1, self.arg2) == (other.arg1, other.arg2) + return isinstance(other, ExtraType) and (self.arg1, self.arg2) == ( + other.arg1, + other.arg2, + ) extra_encoders: List[Tuple[type, str, Callable[..., Any]]] = [ - (ExtraType, 'extratype', lambda o: (o.arg1, o.arg2)) + (ExtraType, "extratype", lambda o: (o.arg1, o.arg2)) ] extra_decoders = { - 'extratype': lambda o: ExtraType(*o), + "extratype": lambda o: ExtraType(*o), } class Serializers(unittest.TestCase): def setUp(self): self.tz = datetime.timezone(datetime.timedelta(minutes=118)) self.data = { - 'bytes': b'123456789\x99\xaf\xff\x00\x12', - 'datetime_naive': datetime.datetime(2015, 1, 1, 12, 4, 42, 231455), - 'datetime_tz': datetime.datetime(2015, 3, 4, 18, 25, 13, 1234, - tzinfo=self.tz), - 'datetime_utc': datetime.datetime(2015, 3, 4, 18, 25, 13, 1234, - tzinfo=datetime.timezone.utc), - 'datetime_delta': datetime.timedelta(64), - 'arrow_date': arrow.get('2018-04-25T16:17:53.533672+00:00'), - 'swhtype': 'fake', - 'swh_dict': {'swhtype': 42, 'd': 'test'}, - 'random_dict': {'swhtype': 43}, - 'uuid': UUID('cdd8f804-9db6-40c3-93ab-5955d3836234'), + "bytes": b"123456789\x99\xaf\xff\x00\x12", + "datetime_naive": datetime.datetime(2015, 1, 1, 12, 4, 42, 231455), + "datetime_tz": datetime.datetime( + 2015, 3, 4, 18, 25, 13, 1234, tzinfo=self.tz + ), + "datetime_utc": datetime.datetime( + 2015, 3, 4, 18, 25, 13, 1234, tzinfo=datetime.timezone.utc + ), + "datetime_delta": datetime.timedelta(64), + "arrow_date": arrow.get("2018-04-25T16:17:53.533672+00:00"), + "swhtype": "fake", + "swh_dict": {"swhtype": 42, "d": "test"}, + "random_dict": {"swhtype": 43}, + "uuid": UUID("cdd8f804-9db6-40c3-93ab-5955d3836234"), } self.encoded_data = { - 'bytes': {'swhtype': 'bytes', 'd': 'F)}kWH8wXmIhn8j01^'}, - 'datetime_naive': {'swhtype': 'datetime', - 'd': '2015-01-01T12:04:42.231455'}, - 'datetime_tz': {'swhtype': 'datetime', - 'd': '2015-03-04T18:25:13.001234+01:58'}, - 'datetime_utc': {'swhtype': 'datetime', - 'd': '2015-03-04T18:25:13.001234+00:00'}, - 'datetime_delta': {'swhtype': 'timedelta', - 'd': {'days': 64, 'seconds': 0, - 'microseconds': 0}}, - 'arrow_date': {'swhtype': 'arrow', - 'd': '2018-04-25T16:17:53.533672+00:00'}, - 'swhtype': 'fake', - 'swh_dict': {'swhtype': 42, 'd': 'test'}, - 'random_dict': {'swhtype': 43}, - 'uuid': {'swhtype': 'uuid', - 'd': 'cdd8f804-9db6-40c3-93ab-5955d3836234'}, + "bytes": {"swhtype": "bytes", "d": "F)}kWH8wXmIhn8j01^"}, + "datetime_naive": { + "swhtype": "datetime", + "d": "2015-01-01T12:04:42.231455", + }, + "datetime_tz": { + "swhtype": "datetime", + "d": "2015-03-04T18:25:13.001234+01:58", + }, + "datetime_utc": { + "swhtype": "datetime", + "d": "2015-03-04T18:25:13.001234+00:00", + }, + "datetime_delta": { + "swhtype": "timedelta", + "d": {"days": 64, "seconds": 0, "microseconds": 0}, + }, + "arrow_date": {"swhtype": "arrow", "d": "2018-04-25T16:17:53.533672+00:00"}, + "swhtype": "fake", + "swh_dict": {"swhtype": 42, "d": "test"}, + "random_dict": {"swhtype": 43}, + "uuid": {"swhtype": "uuid", "d": "cdd8f804-9db6-40c3-93ab-5955d3836234"}, } self.legacy_msgpack = { - 'bytes': b'\xc4\x0e123456789\x99\xaf\xff\x00\x12', - 'datetime_naive': ( - b'\x82\xc4\x0c__datetime__\xc3\xc4\x01s\xba' - b'2015-01-01T12:04:42.231455' + "bytes": b"\xc4\x0e123456789\x99\xaf\xff\x00\x12", + "datetime_naive": ( + b"\x82\xc4\x0c__datetime__\xc3\xc4\x01s\xba" + b"2015-01-01T12:04:42.231455" ), - 'datetime_tz': ( - b'\x82\xc4\x0c__datetime__\xc3\xc4\x01s\xd9 ' - b'2015-03-04T18:25:13.001234+01:58' + "datetime_tz": ( + b"\x82\xc4\x0c__datetime__\xc3\xc4\x01s\xd9 " + b"2015-03-04T18:25:13.001234+01:58" ), - 'datetime_utc': ( - b'\x82\xc4\x0c__datetime__\xc3\xc4\x01s\xd9 ' - b'2015-03-04T18:25:13.001234+00:00' + "datetime_utc": ( + b"\x82\xc4\x0c__datetime__\xc3\xc4\x01s\xd9 " + b"2015-03-04T18:25:13.001234+00:00" ), - 'datetime_delta': ( - b'\x82\xc4\r__timedelta__\xc3\xc4\x01s\x83\xa4' - b'days@\xa7seconds\x00\xacmicroseconds\x00' + "datetime_delta": ( + b"\x82\xc4\r__timedelta__\xc3\xc4\x01s\x83\xa4" + b"days@\xa7seconds\x00\xacmicroseconds\x00" ), - 'arrow_date': ( - b'\x82\xc4\t__arrow__\xc3\xc4\x01s\xd9 ' - b'2018-04-25T16:17:53.533672+00:00' + "arrow_date": ( + b"\x82\xc4\t__arrow__\xc3\xc4\x01s\xd9 " + b"2018-04-25T16:17:53.533672+00:00" ), - 'swhtype': b'\xa4fake', - 'swh_dict': b'\x82\xa7swhtype*\xa1d\xa4test', - 'random_dict': b'\x81\xa7swhtype+', - 'uuid': ( - b'\x82\xc4\x08__uuid__\xc3\xc4\x01s\xd9$' - b'cdd8f804-9db6-40c3-93ab-5955d3836234' + "swhtype": b"\xa4fake", + "swh_dict": b"\x82\xa7swhtype*\xa1d\xa4test", + "random_dict": b"\x81\xa7swhtype+", + "uuid": ( + b"\x82\xc4\x08__uuid__\xc3\xc4\x01s\xd9$" + b"cdd8f804-9db6-40c3-93ab-5955d3836234" ), } self.generator = (i for i in range(5)) self.gen_lst = list(range(5)) def test_round_trip_json(self): data = json.dumps(self.data, cls=SWHJSONEncoder) self.assertEqual(self.data, json.loads(data, cls=SWHJSONDecoder)) def test_round_trip_json_extra_types(self): - original_data = [ExtraType('baz', self.data), 'qux'] + original_data = [ExtraType("baz", self.data), "qux"] - data = json.dumps(original_data, cls=SWHJSONEncoder, - extra_encoders=extra_encoders) + data = json.dumps( + original_data, cls=SWHJSONEncoder, extra_encoders=extra_encoders + ) self.assertEqual( original_data, - json.loads( - data, cls=SWHJSONDecoder, extra_decoders=extra_decoders)) + json.loads(data, cls=SWHJSONDecoder, extra_decoders=extra_decoders), + ) def test_encode_swh_json(self): data = json.dumps(self.data, cls=SWHJSONEncoder) self.assertEqual(self.encoded_data, json.loads(data)) def test_round_trip_msgpack(self): original_data = { **self.data, - 'none_dict_key': {None: 42}, + "none_dict_key": {None: 42}, } data = msgpack_dumps(original_data) self.assertEqual(original_data, msgpack_loads(data)) def test_round_trip_msgpack_extra_types(self): - original_data = [ExtraType('baz', self.data), 'qux'] + original_data = [ExtraType("baz", self.data), "qux"] data = msgpack_dumps(original_data, extra_encoders=extra_encoders) self.assertEqual( - original_data, msgpack_loads(data, extra_decoders=extra_decoders)) + original_data, msgpack_loads(data, extra_decoders=extra_decoders) + ) def test_generator_json(self): data = json.dumps(self.generator, cls=SWHJSONEncoder) self.assertEqual(self.gen_lst, json.loads(data, cls=SWHJSONDecoder)) def test_generator_msgpack(self): data = msgpack_dumps(self.generator) self.assertEqual(self.gen_lst, msgpack_loads(data)) @requests_mock.Mocker() def test_decode_response_json(self, mock_requests): - mock_requests.get('https://example.org/test/data', - json=self.encoded_data, - headers={'content-type': 'application/json'}) - response = requests.get('https://example.org/test/data') + mock_requests.get( + "https://example.org/test/data", + json=self.encoded_data, + headers={"content-type": "application/json"}, + ) + response = requests.get("https://example.org/test/data") assert decode_response(response) == self.data def test_decode_legacy_msgpack(self): for k, v in self.legacy_msgpack.items(): assert msgpack_loads(v) == self.data[k] diff --git a/swh/core/cli/__init__.py b/swh/core/cli/__init__.py index 9dfb041..bcaf77c 100644 --- a/swh/core/cli/__init__.py +++ b/swh/core/cli/__init__.py @@ -1,115 +1,126 @@ # Copyright (C) 2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import logging import logging.config import signal import click import pkg_resources import yaml from ..sentry import init_sentry -LOG_LEVEL_NAMES = ['NOTSET', 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'] +LOG_LEVEL_NAMES = ["NOTSET", "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] -CONTEXT_SETTINGS = dict(help_option_names=['-h', '--help']) +CONTEXT_SETTINGS = dict(help_option_names=["-h", "--help"]) logger = logging.getLogger(__name__) class AliasedGroup(click.Group): - '''A simple Group that supports command aliases, as well as notes related to - options''' + """A simple Group that supports command aliases, as well as notes related to + options""" def __init__(self, name=None, commands=None, **attrs): - self.option_notes = attrs.pop('option_notes', None) + self.option_notes = attrs.pop("option_notes", None) self.aliases = {} super().__init__(name, commands, **attrs) def get_command(self, ctx, cmd_name): return super().get_command(ctx, self.aliases.get(cmd_name, cmd_name)) def add_alias(self, name, alias): if not isinstance(name, str): name = name.name self.aliases[alias] = name def format_options(self, ctx, formatter): click.Command.format_options(self, ctx, formatter) if self.option_notes: - with formatter.section('Notes'): + with formatter.section("Notes"): formatter.write_text(self.option_notes) self.format_commands(ctx, formatter) def clean_exit_on_signal(signal, frame): """Raise a SystemExit exception to let command-line clients wind themselves down on exit""" raise SystemExit(0) @click.group( - context_settings=CONTEXT_SETTINGS, cls=AliasedGroup, - option_notes='''\ + context_settings=CONTEXT_SETTINGS, + cls=AliasedGroup, + option_notes="""\ If both options are present, --log-level will override the root logger configuration set in --log-config. The --log-config YAML must conform to the logging.config.dictConfig schema documented at https://docs.python.org/3/library/logging.config.html. -''' +""", +) +@click.option( + "--log-level", + "-l", + default=None, + type=click.Choice(LOG_LEVEL_NAMES), + help="Log level (defaults to INFO).", +) +@click.option( + "--log-config", + default=None, + type=click.File("r"), + help="Python yaml logging configuration file.", +) +@click.option( + "--sentry-dsn", default=None, help="DSN of the Sentry instance to report to" +) +@click.option( + "--sentry-debug/--no-sentry-debug", + default=False, + hidden=True, + help="Enable debugging of sentry", ) -@click.option('--log-level', '-l', default=None, - type=click.Choice(LOG_LEVEL_NAMES), - help="Log level (defaults to INFO).") -@click.option('--log-config', default=None, - type=click.File('r'), - help="Python yaml logging configuration file.") -@click.option('--sentry-dsn', default=None, - help="DSN of the Sentry instance to report to") -@click.option('--sentry-debug/--no-sentry-debug', - default=False, hidden=True, - help="Enable debugging of sentry") @click.pass_context def swh(ctx, log_level, log_config, sentry_dsn, sentry_debug): """Command line interface for Software Heritage. """ signal.signal(signal.SIGTERM, clean_exit_on_signal) signal.signal(signal.SIGINT, clean_exit_on_signal) init_sentry(sentry_dsn, debug=sentry_debug) if log_level is None and log_config is None: - log_level = 'INFO' + log_level = "INFO" if log_config: logging.config.dictConfig(yaml.safe_load(log_config.read())) if log_level: log_level = logging.getLevelName(log_level) logging.root.setLevel(log_level) ctx.ensure_object(dict) - ctx.obj['log_level'] = log_level + ctx.obj["log_level"] = log_level def main(): # Even though swh() sets up logging, we need an earlier basic logging setup # for the next few logging statements logging.basicConfig() # load plugins that define cli sub commands - for entry_point in pkg_resources.iter_entry_points('swh.cli.subcommands'): + for entry_point in pkg_resources.iter_entry_points("swh.cli.subcommands"): try: cmd = entry_point.load() swh.add_command(cmd, name=entry_point.name) except Exception as e: - logger.warning('Could not load subcommand %s: %s', - entry_point.name, str(e)) + logger.warning("Could not load subcommand %s: %s", entry_point.name, str(e)) - return swh(auto_envvar_prefix='SWH') + return swh(auto_envvar_prefix="SWH") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/swh/core/cli/db.py b/swh/core/cli/db.py index 59f29aa..d018792 100755 --- a/swh/core/cli/db.py +++ b/swh/core/cli/db.py @@ -1,172 +1,188 @@ #!/usr/bin/env python3 # Copyright (C) 2018 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import glob import logging from os import path, environ import subprocess import warnings warnings.filterwarnings("ignore") # noqa prevent psycopg from telling us sh*t import click from swh.core.cli import CONTEXT_SETTINGS from swh.core.config import read as config_read logger = logging.getLogger(__name__) @click.group(name="db", context_settings=CONTEXT_SETTINGS) -@click.option("--config-file", "-C", default=None, - type=click.Path(exists=True, dir_okay=False), - help="Configuration file.") +@click.option( + "--config-file", + "-C", + default=None, + type=click.Path(exists=True, dir_okay=False), + help="Configuration file.", +) @click.pass_context def db(ctx, config_file): """Software Heritage database generic tools. """ ctx.ensure_object(dict) if config_file is None: - config_file = environ.get('SWH_CONFIG_FILENAME') + config_file = environ.get("SWH_CONFIG_FILENAME") cfg = config_read(config_file) ctx.obj["config"] = cfg @db.command(name="init", context_settings=CONTEXT_SETTINGS) @click.pass_context def init(ctx): """Initialize the database for every Software Heritage module found in the configuration file. For every configuration section in the config file that: 1. has the name of an existing swh package, 2. has credentials for a local db access, it will run the initialization scripts from the swh package against the given database. Example for the config file:: \b storage: cls: local args: db: postgresql:///?service=swh-storage objstorage: cls: remote args: url: http://swh-objstorage:5003/ the command: swh db -C /path/to/config.yml init will initialize the database for the `storage` section using initialization scripts from the `swh.storage` package. """ for modname, cfg in ctx.obj["config"].items(): if cfg.get("cls") == "local" and cfg.get("args"): try: sqlfiles = get_sql_for_package(modname) except click.BadParameter: logger.info( - "Failed to load/find sql initialization files for %s", - modname) + "Failed to load/find sql initialization files for %s", modname + ) if sqlfiles: conninfo = cfg["args"]["db"] for sqlfile in sqlfiles: subprocess.check_call( [ "psql", "--quiet", "--no-psqlrc", "-v", "ON_ERROR_STOP=1", "-d", conninfo, "-f", sqlfile, ] ) @click.command(context_settings=CONTEXT_SETTINGS) -@click.argument('module', nargs=-1, required=True) -@click.option('--db-name', '-d', help='Database name.', - default='softwareheritage-dev', show_default=True) -@click.option('--create-db/--no-create-db', '-C', - help='Attempt to create the database.', - default=False) +@click.argument("module", nargs=-1, required=True) +@click.option( + "--db-name", + "-d", + help="Database name.", + default="softwareheritage-dev", + show_default=True, +) +@click.option( + "--create-db/--no-create-db", + "-C", + help="Attempt to create the database.", + default=False, +) def db_init(module, db_name, create_db): """Initialise a database for the Software Heritage . By default, does not attempt to create the database. Example: swh db-init -d swh-test storage If you want to specify non-default postgresql connection parameters, please provide them using standard environment variables. See psql(1) man page (section ENVIRONMENTS) for details. Example: PGPORT=5434 swh db-init indexer """ # put import statements here so we can keep startup time of the main swh # command as short as possible from swh.core.db.tests.db_testing import ( - pg_createdb, pg_restore, DB_DUMP_TYPES, - swh_db_version + pg_createdb, + pg_restore, + DB_DUMP_TYPES, + swh_db_version, ) - logger.debug('db_init %s dn_name=%s', module, db_name) + logger.debug("db_init %s dn_name=%s", module, db_name) dump_files = [] for modname in module: dump_files.extend(get_sql_for_package(modname)) if create_db: # Create the db (or fail silently if already existing) pg_createdb(db_name, check=False) # Try to retrieve the db version if any db_version = swh_db_version(db_name) if not db_version: # Initialize the db - dump_files = [(x, DB_DUMP_TYPES[path.splitext(x)[1]]) - for x in dump_files] + dump_files = [(x, DB_DUMP_TYPES[path.splitext(x)[1]]) for x in dump_files] for dump, dtype in dump_files: - click.secho('Loading {}'.format(dump), fg='yellow') + click.secho("Loading {}".format(dump), fg="yellow") pg_restore(db_name, dump, dtype) db_version = swh_db_version(db_name) # TODO: Ideally migrate the version from db_version to the latest # db version - click.secho('DONE database is {} version {}'.format(db_name, db_version), - fg='green', bold=True) + click.secho( + "DONE database is {} version {}".format(db_name, db_version), + fg="green", + bold=True, + ) def get_sql_for_package(modname): from importlib import import_module from swh.core.utils import numfile_sortkey as sortkey if not modname.startswith("swh."): modname = "swh.{}".format(modname) try: m = import_module(modname) except ImportError: raise click.BadParameter("Unable to load module {}".format(modname)) sqldir = path.join(path.dirname(m.__file__), "sql") if not path.isdir(sqldir): raise click.BadParameter( - "Module {} does not provide a db schema " - "(no sql/ dir)".format(modname)) + "Module {} does not provide a db schema " "(no sql/ dir)".format(modname) + ) return list(sorted(glob.glob(path.join(sqldir, "*.sql")), key=sortkey)) diff --git a/swh/core/config.py b/swh/core/config.py index 748fce1..4c7fcc7 100644 --- a/swh/core/config.py +++ b/swh/core/config.py @@ -1,362 +1,364 @@ # Copyright (C) 2015 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import configparser import logging import os import yaml from itertools import chain from copy import deepcopy from typing import Any, Dict, Optional, Tuple logger = logging.getLogger(__name__) SWH_CONFIG_DIRECTORIES = [ - '~/.config/swh', - '~/.swh', - '/etc/softwareheritage', + "~/.config/swh", + "~/.swh", + "/etc/softwareheritage", ] -SWH_GLOBAL_CONFIG = 'global.ini' +SWH_GLOBAL_CONFIG = "global.ini" SWH_DEFAULT_GLOBAL_CONFIG = { - 'max_content_size': ('int', 100 * 1024 * 1024), - 'log_db': ('str', 'dbname=softwareheritage-log'), + "max_content_size": ("int", 100 * 1024 * 1024), + "log_db": ("str", "dbname=softwareheritage-log"), } SWH_CONFIG_EXTENSIONS = [ - '.yml', - '.ini', + ".yml", + ".ini", ] # conversion per type _map_convert_fn = { - 'int': int, - 'bool': lambda x: x.lower() == 'true', - 'list[str]': lambda x: [value.strip() for value in x.split(',')], - 'list[int]': lambda x: [int(value.strip()) for value in x.split(',')], + "int": int, + "bool": lambda x: x.lower() == "true", + "list[str]": lambda x: [value.strip() for value in x.split(",")], + "list[int]": lambda x: [int(value.strip()) for value in x.split(",")], } _map_check_fn = { - 'int': lambda x: isinstance(x, int), - 'bool': lambda x: isinstance(x, bool), - 'list[str]': lambda x: (isinstance(x, list) and - all(isinstance(y, str) for y in x)), - 'list[int]': lambda x: (isinstance(x, list) and - all(isinstance(y, int) for y in x)), + "int": lambda x: isinstance(x, int), + "bool": lambda x: isinstance(x, bool), + "list[str]": lambda x: (isinstance(x, list) and all(isinstance(y, str) for y in x)), + "list[int]": lambda x: (isinstance(x, list) and all(isinstance(y, int) for y in x)), } def exists_accessible(file): """Check whether a file exists, and is accessible. Returns: True if the file exists and is accessible False if the file does not exist Raises: PermissionError if the file cannot be read. """ try: os.stat(file) except PermissionError: raise except FileNotFoundError: return False else: if os.access(file, os.R_OK): return True else: raise PermissionError("Permission denied: %r" % file) def config_basepath(config_path): """Return the base path of a configuration file""" - if config_path.endswith(('.ini', '.yml')): + if config_path.endswith((".ini", ".yml")): return config_path[:-4] return config_path def read_raw_config(base_config_path): """Read the raw config corresponding to base_config_path. Can read yml or ini files. """ - yml_file = base_config_path + '.yml' + yml_file = base_config_path + ".yml" if exists_accessible(yml_file): - logger.info('Loading config file %s', yml_file) + logger.info("Loading config file %s", yml_file) with open(yml_file) as f: return yaml.safe_load(f) - ini_file = base_config_path + '.ini' + ini_file = base_config_path + ".ini" if exists_accessible(ini_file): config = configparser.ConfigParser() config.read(ini_file) - if 'main' in config._sections: - logger.info('Loading config file %s', ini_file) - return config._sections['main'] + if "main" in config._sections: + logger.info("Loading config file %s", ini_file) + return config._sections["main"] else: - logger.warning('Ignoring config file %s (no [main] section)', - ini_file) + logger.warning("Ignoring config file %s (no [main] section)", ini_file) return {} def config_exists(config_path): """Check whether the given config exists""" basepath = config_basepath(config_path) - return any(exists_accessible(basepath + extension) - for extension in SWH_CONFIG_EXTENSIONS) + return any( + exists_accessible(basepath + extension) for extension in SWH_CONFIG_EXTENSIONS + ) def read(conf_file=None, default_conf=None): """Read the user's configuration file. Fill in the gap using `default_conf`. `default_conf` is similar to this:: DEFAULT_CONF = { 'a': ('str', '/tmp/swh-loader-git/log'), 'b': ('str', 'dbname=swhloadergit') 'c': ('bool', true) 'e': ('bool', None) 'd': ('int', 10) } If conf_file is None, return the default config. """ conf = {} if conf_file: base_config_path = config_basepath(os.path.expanduser(conf_file)) conf = read_raw_config(base_config_path) if not default_conf: default_conf = {} # remaining missing default configuration key are set # also type conversion is enforced for underneath layer for key in default_conf: nature_type, default_value = default_conf[key] val = conf.get(key, None) if val is None: # fallback to default value conf[key] = default_value elif not _map_check_fn.get(nature_type, lambda x: True)(val): # value present but not in the proper format, force type conversion conf[key] = _map_convert_fn.get(nature_type, lambda x: x)(val) return conf def priority_read(conf_filenames, default_conf=None): """Try reading the configuration files from conf_filenames, in order, and return the configuration from the first one that exists. default_conf has the same specification as it does in read. """ # Try all the files in order for filename in conf_filenames: full_filename = os.path.expanduser(filename) if config_exists(full_filename): return read(full_filename, default_conf) # Else, return the default configuration return read(None, default_conf) def merge_default_configs(base_config, *other_configs): """Merge several default config dictionaries, from left to right""" full_config = base_config.copy() for config in other_configs: full_config.update(config) return full_config def merge_configs(base, other): """Merge two config dictionaries This does merge config dicts recursively, with the rules, for every value of the dicts (with 'val' not being a dict): - None + type -> type - type + None -> None - dict + dict -> dict (merged) - val + dict -> TypeError - dict + val -> TypeError - val + val -> val (other) for instance: >>> d1 = { ... 'key1': { ... 'skey1': 'value1', ... 'skey2': {'sskey1': 'value2'}, ... }, ... 'key2': 'value3', ... } with >>> d2 = { ... 'key1': { ... 'skey1': 'value4', ... 'skey2': {'sskey2': 'value5'}, ... }, ... 'key3': 'value6', ... } will give: >>> d3 = { ... 'key1': { ... 'skey1': 'value4', # <-- note this ... 'skey2': { ... 'sskey1': 'value2', ... 'sskey2': 'value5', ... }, ... }, ... 'key2': 'value3', ... 'key3': 'value6', ... } >>> assert merge_configs(d1, d2) == d3 Note that no type checking is done for anything but dicts. """ if not isinstance(base, dict) or not isinstance(other, dict): - raise TypeError( - 'Cannot merge a %s with a %s' % (type(base), type(other))) + raise TypeError("Cannot merge a %s with a %s" % (type(base), type(other))) output = {} allkeys = set(chain(base.keys(), other.keys())) for k in allkeys: vb = base.get(k) vo = other.get(k) if isinstance(vo, dict): output[k] = merge_configs(vb is not None and vb or {}, vo) elif isinstance(vb, dict) and k in other and other[k] is not None: output[k] = merge_configs(vb, vo is not None and vo or {}) elif k in other: output[k] = deepcopy(vo) else: output[k] = deepcopy(vb) return output def swh_config_paths(base_filename): """Return the Software Heritage specific configuration paths for the given filename.""" - return [os.path.join(dirname, base_filename) - for dirname in SWH_CONFIG_DIRECTORIES] + return [os.path.join(dirname, base_filename) for dirname in SWH_CONFIG_DIRECTORIES] def prepare_folders(conf, *keys): """Prepare the folder mentioned in config under keys. """ + def makedir(folder): if not os.path.exists(folder): os.makedirs(folder) for key in keys: makedir(conf[key]) def load_global_config(): """Load the global Software Heritage config""" return priority_read( - swh_config_paths(SWH_GLOBAL_CONFIG), - SWH_DEFAULT_GLOBAL_CONFIG, + swh_config_paths(SWH_GLOBAL_CONFIG), SWH_DEFAULT_GLOBAL_CONFIG, ) def load_named_config(name, default_conf=None, global_conf=True): """Load the config named `name` from the Software Heritage configuration paths. If global_conf is True (default), read the global configuration too. """ conf = {} if global_conf: conf.update(load_global_config()) conf.update(priority_read(swh_config_paths(name), default_conf)) return conf class SWHConfig: """Mixin to add configuration parsing abilities to classes The class should override the class attributes: - DEFAULT_CONFIG (default configuration to be parsed) - CONFIG_BASE_FILENAME (the filename of the configuration to be used) This class defines one classmethod, parse_config_file, which parses a configuration file using the default config as set in the class attribute. """ DEFAULT_CONFIG = {} # type: Dict[str, Tuple[str, Any]] - CONFIG_BASE_FILENAME = '' # type: Optional[str] + CONFIG_BASE_FILENAME = "" # type: Optional[str] @classmethod - def parse_config_file(cls, base_filename=None, config_filename=None, - additional_configs=None, global_config=True): + def parse_config_file( + cls, + base_filename=None, + config_filename=None, + additional_configs=None, + global_config=True, + ): """Parse the configuration file associated to the current class. By default, parse_config_file will load the configuration cls.CONFIG_BASE_FILENAME from one of the Software Heritage configuration directories, in order, unless it is overridden by base_filename or config_filename (which shortcuts the file lookup completely). Args: - base_filename (str): overrides the default cls.CONFIG_BASE_FILENAME - config_filename (str): sets the file to parse instead of the defaults set from cls.CONFIG_BASE_FILENAME - additional_configs: (list of default configuration dicts) allows to override or extend the configuration set in cls.DEFAULT_CONFIG. - global_config (bool): Load the global configuration (default: True) """ if config_filename: config_filenames = [config_filename] - elif 'SWH_CONFIG_FILENAME' in os.environ: - config_filenames = [os.environ['SWH_CONFIG_FILENAME']] + elif "SWH_CONFIG_FILENAME" in os.environ: + config_filenames = [os.environ["SWH_CONFIG_FILENAME"]] else: if not base_filename: base_filename = cls.CONFIG_BASE_FILENAME config_filenames = swh_config_paths(base_filename) if not additional_configs: additional_configs = [] - full_default_config = merge_default_configs(cls.DEFAULT_CONFIG, - *additional_configs) + full_default_config = merge_default_configs( + cls.DEFAULT_CONFIG, *additional_configs + ) config = {} if global_config: config = load_global_config() config.update(priority_read(config_filenames, full_default_config)) return config diff --git a/swh/core/db/__init__.py b/swh/core/db/__init__.py index 19a5867..307f038 100644 --- a/swh/core/db/__init__.py +++ b/swh/core/db/__init__.py @@ -1,212 +1,216 @@ # Copyright (C) 2015-2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import binascii import datetime import enum import json import logging import os import sys import threading from contextlib import contextmanager import psycopg2 import psycopg2.extras logger = logging.getLogger(__name__) psycopg2.extras.register_uuid() def escape(data): if data is None: - return '' + return "" if isinstance(data, bytes): - return '\\x%s' % binascii.hexlify(data).decode('ascii') + return "\\x%s" % binascii.hexlify(data).decode("ascii") elif isinstance(data, str): return '"%s"' % data.replace('"', '""') elif isinstance(data, datetime.datetime): # We escape twice to make sure the string generated by # isoformat gets escaped return escape(data.isoformat()) elif isinstance(data, dict): return escape(json.dumps(data)) elif isinstance(data, list): - return escape("{%s}" % ','.join(escape(d) for d in data)) + return escape("{%s}" % ",".join(escape(d) for d in data)) elif isinstance(data, psycopg2.extras.Range): # We escape twice here too, so that we make sure # everything gets passed to copy properly return escape( - '%s%s,%s%s' % ( - '[' if data.lower_inc else '(', - '-infinity' if data.lower_inf else escape(data.lower), - 'infinity' if data.upper_inf else escape(data.upper), - ']' if data.upper_inc else ')', + "%s%s,%s%s" + % ( + "[" if data.lower_inc else "(", + "-infinity" if data.lower_inf else escape(data.lower), + "infinity" if data.upper_inf else escape(data.upper), + "]" if data.upper_inc else ")", ) ) elif isinstance(data, enum.IntEnum): return escape(int(data)) else: # We don't escape here to make sure we pass literals properly return str(data) def typecast_bytea(value, cur): if value is not None: data = psycopg2.BINARY(value, cur) return data.tobytes() class BaseDb: """Base class for swh.*.*Db. cf. swh.storage.db.Db, swh.archiver.db.ArchiverDb """ @classmethod def adapt_conn(cls, conn): """Makes psycopg2 use 'bytes' to decode bytea instead of 'memoryview', for this connection.""" - t_bytes = psycopg2.extensions.new_type( - (17,), "bytea", typecast_bytea) + t_bytes = psycopg2.extensions.new_type((17,), "bytea", typecast_bytea) psycopg2.extensions.register_type(t_bytes, conn) - t_bytes_array = psycopg2.extensions.new_array_type( - (1001,), "bytea[]", t_bytes) + t_bytes_array = psycopg2.extensions.new_array_type((1001,), "bytea[]", t_bytes) psycopg2.extensions.register_type(t_bytes_array, conn) @classmethod def connect(cls, *args, **kwargs): """factory method to create a DB proxy Accepts all arguments of psycopg2.connect; only some specific possibilities are reported below. Args: connstring: libpq2 connection string """ conn = psycopg2.connect(*args, **kwargs) return cls(conn) @classmethod def from_pool(cls, pool): conn = pool.getconn() return cls(conn, pool=pool) def __init__(self, conn, pool=None): """create a DB proxy Args: conn: psycopg2 connection to the SWH DB pool: psycopg2 pool of connections """ self.adapt_conn(conn) self.conn = conn self.pool = pool def put_conn(self): if self.pool: self.pool.putconn(self.conn) def cursor(self, cur_arg=None): """get a cursor: from cur_arg if given, or a fresh one otherwise meant to avoid boilerplate if/then/else in methods that proxy stored procedures """ if cur_arg is not None: return cur_arg else: return self.conn.cursor() + _cursor = cursor # for bw compat @contextmanager def transaction(self): """context manager to execute within a DB transaction Yields: a psycopg2 cursor """ with self.conn.cursor() as cur: try: yield cur self.conn.commit() except Exception: if not self.conn.closed: self.conn.rollback() raise - def copy_to(self, items, tblname, columns, - cur=None, item_cb=None, default_values={}): + def copy_to( + self, items, tblname, columns, cur=None, item_cb=None, default_values={} + ): """Copy items' entries to table tblname with columns information. Args: items (List[dict]): dictionaries of data to copy over tblname. tblname (str): destination table's name. columns ([str]): keys to access data in items and also the column names in the destination table. default_values (dict): dictionary of default values to use when inserting entried int the tblname table. cur: a db cursor; if not given, a new cursor will be created. item_cb (fn): optional function to apply to items's entry. """ read_file, write_file = os.pipe() exc_info = None def writer(): nonlocal exc_info cursor = self.cursor(cur) - with open(read_file, 'r') as f: + with open(read_file, "r") as f: try: - cursor.copy_expert('COPY %s (%s) FROM STDIN CSV' % ( - tblname, ', '.join(columns)), f) + cursor.copy_expert( + "COPY %s (%s) FROM STDIN CSV" % (tblname, ", ".join(columns)), f + ) except Exception: # Tell the main thread about the exception exc_info = sys.exc_info() write_thread = threading.Thread(target=writer) write_thread.start() try: - with open(write_file, 'w') as f: + with open(write_file, "w") as f: for d in items: if item_cb is not None: item_cb(d) line = [] for k in columns: value = d.get(k, default_values.get(k)) try: line.append(escape(value)) except Exception as e: logger.error( - 'Could not escape value `%r` for column `%s`:' - 'Received exception: `%s`', - value, k, e + "Could not escape value `%r` for column `%s`:" + "Received exception: `%s`", + value, + k, + e, ) raise e from None - f.write(','.join(line)) - f.write('\n') + f.write(",".join(line)) + f.write("\n") finally: # No problem bubbling up exceptions, but we still need to make sure # we finish copying, even though we're probably going to cancel the # transaction. write_thread.join() if exc_info: # postgresql returned an error, let's raise it. raise exc_info[1].with_traceback(exc_info[2]) def mktemp(self, tblname, cur=None): - self.cursor(cur).execute('SELECT swh_mktemp(%s)', (tblname,)) + self.cursor(cur).execute("SELECT swh_mktemp(%s)", (tblname,)) diff --git a/swh/core/db/common.py b/swh/core/db/common.py index b5f163a..17c46be 100644 --- a/swh/core/db/common.py +++ b/swh/core/db/common.py @@ -1,102 +1,103 @@ # Copyright (C) 2015-2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import inspect import functools def remove_kwargs(names): def decorator(f): sig = inspect.signature(f) params = sig.parameters - params = [param for param in params.values() - if param.name not in names] + params = [param for param in params.values() if param.name not in names] sig = sig.replace(parameters=params) f.__signature__ = sig return f return decorator def apply_options(cursor, options): """Applies the given postgresql client options to the given cursor. Returns a dictionary with the old values if they changed.""" old_options = {} for option, value in options.items(): - cursor.execute('SHOW %s' % option) + cursor.execute("SHOW %s" % option) old_value = cursor.fetchall()[0][0] if old_value != value: - cursor.execute('SET LOCAL %s TO %%s' % option, (value,)) + cursor.execute("SET LOCAL %s TO %%s" % option, (value,)) old_options[option] = old_value return old_options def db_transaction(**client_options): """decorator to execute Backend methods within DB transactions The decorated method must accept a `cur` and `db` keyword argument Client options are passed as `set` options to the postgresql server """ + def decorator(meth, __client_options=client_options): if inspect.isgeneratorfunction(meth): - raise ValueError( - 'Use db_transaction_generator for generator functions.') + raise ValueError("Use db_transaction_generator for generator functions.") - @remove_kwargs(['cur', 'db']) + @remove_kwargs(["cur", "db"]) @functools.wraps(meth) def _meth(self, *args, **kwargs): - if 'cur' in kwargs and kwargs['cur']: - cur = kwargs['cur'] + if "cur" in kwargs and kwargs["cur"]: + cur = kwargs["cur"] old_options = apply_options(cur, __client_options) ret = meth(self, *args, **kwargs) apply_options(cur, old_options) return ret else: db = self.get_db() try: with db.transaction() as cur: apply_options(cur, __client_options) return meth(self, *args, db=db, cur=cur, **kwargs) finally: self.put_db(db) + return _meth return decorator def db_transaction_generator(**client_options): """decorator to execute Backend methods within DB transactions, while returning a generator The decorated method must accept a `cur` and `db` keyword argument Client options are passed as `set` options to the postgresql server """ + def decorator(meth, __client_options=client_options): if not inspect.isgeneratorfunction(meth): - raise ValueError( - 'Use db_transaction for non-generator functions.') + raise ValueError("Use db_transaction for non-generator functions.") - @remove_kwargs(['cur', 'db']) + @remove_kwargs(["cur", "db"]) @functools.wraps(meth) def _meth(self, *args, **kwargs): - if 'cur' in kwargs and kwargs['cur']: - cur = kwargs['cur'] + if "cur" in kwargs and kwargs["cur"]: + cur = kwargs["cur"] old_options = apply_options(cur, __client_options) yield from meth(self, *args, **kwargs) apply_options(cur, old_options) else: db = self.get_db() try: with db.transaction() as cur: apply_options(cur, __client_options) yield from meth(self, *args, db=db, cur=cur, **kwargs) finally: self.put_db(db) return _meth + return decorator diff --git a/swh/core/db/db_utils.py b/swh/core/db/db_utils.py index 451fb58..aa97a93 100644 --- a/swh/core/db/db_utils.py +++ b/swh/core/db/db_utils.py @@ -1,149 +1,151 @@ # Copyright (C) 2015-2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # # This code has been imported from psycopg2, version 2.7.4, # https://github.com/psycopg/psycopg2/tree/5afb2ce803debea9533e293eef73c92ffce95bcd # and modified by Software Heritage. # # Original file: lib/extras.py # # psycopg2 is free software: you can redistribute it and/or modify it under the # terms of the GNU Lesser General Public License as published by the Free # Software Foundation, either version 3 of the License, or (at your option) any # later version. import re import functools import psycopg2.extensions def stored_procedure(stored_proc): """decorator to execute remote stored procedure, specified as argument Generally, the body of the decorated function should be empty. If it is not, the stored procedure will be executed first; the function body then. """ + def wrap(meth): @functools.wraps(meth) def _meth(self, *args, **kwargs): - cur = kwargs.get('cur', None) - self._cursor(cur).execute('SELECT %s()' % stored_proc) + cur = kwargs.get("cur", None) + self._cursor(cur).execute("SELECT %s()" % stored_proc) meth(self, *args, **kwargs) + return _meth + return wrap def jsonize(value): """Convert a value to a psycopg2 JSON object if necessary""" if isinstance(value, dict): return psycopg2.extras.Json(value) return value def _paginate(seq, page_size): """Consume an iterable and return it in chunks. Every chunk is at most `page_size`. Never return an empty chunk. """ page = [] it = iter(seq) while 1: try: for i in range(page_size): page.append(next(it)) yield page page = [] except StopIteration: if page: yield page return def _split_sql(sql): """Split *sql* on a single ``%s`` placeholder. Split on the %s, perform %% replacement and return pre, post lists of snippets. """ curr = pre = [] post = [] - tokens = re.split(br'(%.)', sql) + tokens = re.split(br"(%.)", sql) for token in tokens: - if len(token) != 2 or token[:1] != b'%': + if len(token) != 2 or token[:1] != b"%": curr.append(token) continue - if token[1:] == b's': + if token[1:] == b"s": if curr is pre: curr = post else: - raise ValueError( - "the query contains more than one '%s' placeholder") - elif token[1:] == b'%': - curr.append(b'%') + raise ValueError("the query contains more than one '%s' placeholder") + elif token[1:] == b"%": + curr.append(b"%") else: - raise ValueError("unsupported format character: '%s'" - % token[1:].decode('ascii', 'replace')) + raise ValueError( + "unsupported format character: '%s'" + % token[1:].decode("ascii", "replace") + ) if curr is pre: raise ValueError("the query doesn't contain any '%s' placeholder") return pre, post def execute_values_generator(cur, sql, argslist, template=None, page_size=100): - '''Execute a statement using SQL ``VALUES`` with a sequence of parameters. + """Execute a statement using SQL ``VALUES`` with a sequence of parameters. Rows returned by the query are returned through a generator. You need to consume the generator for the queries to be executed! :param cur: the cursor to use to execute the query. :param sql: the query to execute. It must contain a single ``%s`` placeholder, which will be replaced by a `VALUES list`__. Example: ``"INSERT INTO mytable (id, f1, f2) VALUES %s"``. :param argslist: sequence of sequences or dictionaries with the arguments to send to the query. The type and content must be consistent with *template*. :param template: the snippet to merge to every item in *argslist* to compose the query. - If the *argslist* items are sequences it should contain positional placeholders (e.g. ``"(%s, %s, %s)"``, or ``"(%s, %s, 42)``" if there are constants value...). - If the *argslist* items are mappings it should contain named placeholders (e.g. ``"(%(id)s, %(f1)s, 42)"``). If not specified, assume the arguments are sequence and use a simple positional template (i.e. ``(%s, %s, ...)``), with the number of placeholders sniffed by the first element in *argslist*. :param page_size: maximum number of *argslist* items to include in every statement. If there are more items the function will execute more than one statement. :param yield_from_cur: Whether to yield results from the cursor in this function directly. .. __: https://www.postgresql.org/docs/current/static/queries-values.html After the execution of the function the `cursor.rowcount` property will **not** contain a total result. - ''' + """ # we can't just use sql % vals because vals is bytes: if sql is bytes # there will be some decoding error because of stupid codec used, and Py3 # doesn't implement % on bytes. if not isinstance(sql, bytes): - sql = sql.encode( - psycopg2.extensions.encodings[cur.connection.encoding] - ) + sql = sql.encode(psycopg2.extensions.encodings[cur.connection.encoding]) pre, post = _split_sql(sql) for page in _paginate(argslist, page_size=page_size): if template is None: - template = b'(' + b','.join([b'%s'] * len(page[0])) + b')' + template = b"(" + b",".join([b"%s"] * len(page[0])) + b")" parts = pre[:] for args in page: parts.append(cur.mogrify(template, args)) - parts.append(b',') + parts.append(b",") parts[-1:] = post - cur.execute(b''.join(parts)) + cur.execute(b"".join(parts)) yield from cur diff --git a/swh/core/db/tests/conftest.py b/swh/core/db/tests/conftest.py index 5d8dcd5..7ce4272 100644 --- a/swh/core/db/tests/conftest.py +++ b/swh/core/db/tests/conftest.py @@ -1,2 +1,3 @@ import os -os.environ['LC_ALL'] = 'C.UTF-8' + +os.environ["LC_ALL"] = "C.UTF-8" diff --git a/swh/core/db/tests/db_testing.py b/swh/core/db/tests/db_testing.py index c8bed92..3c8c34d 100644 --- a/swh/core/db/tests/db_testing.py +++ b/swh/core/db/tests/db_testing.py @@ -1,320 +1,340 @@ # Copyright (C) 2015-2018 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import os import glob import subprocess import psycopg2 from typing import Dict, Iterable, Optional, Tuple, Union from swh.core.utils import numfile_sortkey as sortkey -DB_DUMP_TYPES = {'.sql': 'psql', '.dump': 'pg_dump'} # type: Dict[str, str] +DB_DUMP_TYPES = {".sql": "psql", ".dump": "pg_dump"} # type: Dict[str, str] def swh_db_version(dbname_or_service): """Retrieve the swh version if any. In case of the db not initialized, this returns None. Otherwise, this returns the db's version. Args: dbname_or_service (str): The db's name or service Returns: Optional[Int]: Either the db's version or None """ - query = 'select version from dbversion order by dbversion desc limit 1' + query = "select version from dbversion order by dbversion desc limit 1" cmd = [ - 'psql', '--tuples-only', '--no-psqlrc', '--quiet', - '-v', 'ON_ERROR_STOP=1', "--command=%s" % query, - dbname_or_service + "psql", + "--tuples-only", + "--no-psqlrc", + "--quiet", + "-v", + "ON_ERROR_STOP=1", + "--command=%s" % query, + dbname_or_service, ] try: - r = subprocess.run(cmd, check=True, stdout=subprocess.PIPE, - universal_newlines=True) + r = subprocess.run( + cmd, check=True, stdout=subprocess.PIPE, universal_newlines=True + ) result = int(r.stdout.strip()) except Exception: # db not initialized result = None return result -def pg_restore(dbname, dumpfile, dumptype='pg_dump'): +def pg_restore(dbname, dumpfile, dumptype="pg_dump"): """ Args: dbname: name of the DB to restore into dumpfile: path of the dump file dumptype: one of 'pg_dump' (for binary dumps), 'psql' (for SQL dumps) """ - assert dumptype in ['pg_dump', 'psql'] - if dumptype == 'pg_dump': - subprocess.check_call(['pg_restore', '--no-owner', '--no-privileges', - '--dbname', dbname, dumpfile]) - elif dumptype == 'psql': - subprocess.check_call(['psql', '--quiet', - '--no-psqlrc', - '-v', 'ON_ERROR_STOP=1', - '-f', dumpfile, - dbname]) + assert dumptype in ["pg_dump", "psql"] + if dumptype == "pg_dump": + subprocess.check_call( + [ + "pg_restore", + "--no-owner", + "--no-privileges", + "--dbname", + dbname, + dumpfile, + ] + ) + elif dumptype == "psql": + subprocess.check_call( + [ + "psql", + "--quiet", + "--no-psqlrc", + "-v", + "ON_ERROR_STOP=1", + "-f", + dumpfile, + dbname, + ] + ) def pg_dump(dbname, dumpfile): - subprocess.check_call(['pg_dump', '--no-owner', '--no-privileges', '-Fc', - '-f', dumpfile, dbname]) + subprocess.check_call( + ["pg_dump", "--no-owner", "--no-privileges", "-Fc", "-f", dumpfile, dbname] + ) def pg_dropdb(dbname): - subprocess.check_call(['dropdb', dbname]) + subprocess.check_call(["dropdb", dbname]) def pg_createdb(dbname, check=True): """Create a db. If check is True and the db already exists, this will raise an exception (original behavior). If check is False and the db already exists, this will fail silently. If the db does not exist, the db will be created. """ - subprocess.run(['createdb', dbname], check=check) + subprocess.run(["createdb", dbname], check=check) def db_create(dbname, dumps=None): """create the test DB and load the test data dumps into it dumps is an iterable of couples (dump_file, dump_type). context: setUpClass """ try: pg_createdb(dbname) except subprocess.CalledProcessError: # try recovering once, in case - pg_dropdb(dbname) # the db already existed + pg_dropdb(dbname) # the db already existed pg_createdb(dbname) for dump, dtype in dumps: pg_restore(dbname, dump, dtype) return dbname def db_destroy(dbname): """destroy the test DB context: tearDownClass """ pg_dropdb(dbname) def db_connect(dbname): """connect to the test DB and open a cursor context: setUp """ - conn = psycopg2.connect('dbname=' + dbname) - return { - 'conn': conn, - 'cursor': conn.cursor() - } + conn = psycopg2.connect("dbname=" + dbname) + return {"conn": conn, "cursor": conn.cursor()} def db_close(conn): """rollback current transaction and disconnect from the test DB context: tearDown """ if not conn.closed: conn.rollback() conn.close() class DbTestConn: def __init__(self, dbname): self.dbname = dbname def __enter__(self): self.db_setup = db_connect(self.dbname) - self.conn = self.db_setup['conn'] - self.cursor = self.db_setup['cursor'] + self.conn = self.db_setup["conn"] + self.cursor = self.db_setup["cursor"] return self def __exit__(self, *_): db_close(self.conn) class DbTestContext: - def __init__(self, name='softwareheritage-test', dumps=None): + def __init__(self, name="softwareheritage-test", dumps=None): self.dbname = name self.dumps = dumps def __enter__(self): - db_create(dbname=self.dbname, - dumps=self.dumps) + db_create(dbname=self.dbname, dumps=self.dumps) return self def __exit__(self, *_): db_destroy(self.dbname) class DbTestFixture: """Mix this in a test subject class to get DB testing support. Use the class method add_db() to add a new database to be tested. Using this will create a DbTestConn entry in the `test_db` dictionary for all the tests, indexed by the name of the database. Example: class TestDb(DbTestFixture, unittest.TestCase): @classmethod def setUpClass(cls): cls.add_db('db_name', DUMP) super().setUpClass() def setUp(self): db = self.test_db['db_name'] print('conn: {}, cursor: {}'.format(db.conn, db.cursor)) To ensure test isolation, each test method of the test case class will execute in its own connection, cursor, and transaction. Note that if you want to define setup/teardown methods, you need to explicitly call super() to ensure that the fixture setup/teardown methods are invoked. Here is an example where all setup/teardown methods are defined in a test case: class TestDb(DbTestFixture, unittest.TestCase): @classmethod def setUpClass(cls): # your add_db() calls here super().setUpClass() # your class setup code here def setUp(self): super().setUp() # your instance setup code here def tearDown(self): # your instance teardown code here super().tearDown() @classmethod def tearDownClass(cls): # your class teardown code here super().tearDownClass() """ _DB_DUMP_LIST = {} # type: Dict[str, Iterable[Tuple[str, str]]] _DB_LIST = {} # type: Dict[str, DbTestContext] DB_TEST_FIXTURE_IMPORTED = True @classmethod - def add_db(cls, name='softwareheritage-test', dumps=None): + def add_db(cls, name="softwareheritage-test", dumps=None): cls._DB_DUMP_LIST[name] = dumps @classmethod def setUpClass(cls): for name, dumps in cls._DB_DUMP_LIST.items(): cls._DB_LIST[name] = DbTestContext(name, dumps) cls._DB_LIST[name].__enter__() super().setUpClass() @classmethod def tearDownClass(cls): super().tearDownClass() for name, context in cls._DB_LIST.items(): context.__exit__() def setUp(self, *args, **kwargs): self.test_db = {} for name in self._DB_LIST.keys(): self.test_db[name] = DbTestConn(name) self.test_db[name].__enter__() super().setUp(*args, **kwargs) def tearDown(self): super().tearDown() for name in self._DB_LIST.keys(): self.test_db[name].__exit__() def reset_db_tables(self, name, excluded=None): db = self.test_db[name] conn = db.conn cursor = db.cursor - cursor.execute("""SELECT table_name FROM information_schema.tables - WHERE table_schema = %s""", ('public',)) + cursor.execute( + """SELECT table_name FROM information_schema.tables + WHERE table_schema = %s""", + ("public",), + ) tables = set(table for (table,) in cursor.fetchall()) if excluded is not None: tables -= set(excluded) for table in tables: - cursor.execute('truncate table %s cascade' % table) + cursor.execute("truncate table %s cascade" % table) conn.commit() class SingleDbTestFixture(DbTestFixture): """Simplified fixture like DbTest but that can only handle a single DB. Gives access to shortcuts like self.cursor and self.conn. DO NOT use this with other fixtures that need to access databases, like StorageTestFixture. The class can override the following class attributes: TEST_DB_NAME: name of the DB used for testing TEST_DB_DUMP: DB dump to be restored before running test methods; can be set to None if no restore from dump is required. If the dump file name endswith" - '.sql' it will be loaded via psql, - '.dump' it will be loaded via pg_restore. Other file extensions will be ignored. Can be a string or a list of strings; each path will be expanded using glob pattern matching. The test case class will then have the following attributes, accessible via self: dbname: name of the test database conn: psycopg2 connection object cursor: open psycopg2 cursor to the DB """ - TEST_DB_NAME = 'softwareheritage-test' + TEST_DB_NAME = "softwareheritage-test" TEST_DB_DUMP = None # type: Optional[Union[str, Iterable[str]]] @classmethod def setUpClass(cls): cls.dbname = cls.TEST_DB_NAME # XXX to kill? dump_files = cls.TEST_DB_DUMP if dump_files is None: dump_files = [] elif isinstance(dump_files, str): dump_files = [dump_files] all_dump_files = [] for files in dump_files: - all_dump_files.extend( - sorted(glob.glob(files), key=sortkey)) + all_dump_files.extend(sorted(glob.glob(files), key=sortkey)) - all_dump_files = [(x, DB_DUMP_TYPES[os.path.splitext(x)[1]]) - for x in all_dump_files] + all_dump_files = [ + (x, DB_DUMP_TYPES[os.path.splitext(x)[1]]) for x in all_dump_files + ] - cls.add_db(name=cls.TEST_DB_NAME, - dumps=all_dump_files) + cls.add_db(name=cls.TEST_DB_NAME, dumps=all_dump_files) super().setUpClass() def setUp(self, *args, **kwargs): super().setUp(*args, **kwargs) db = self.test_db[self.TEST_DB_NAME] self.conn = db.conn self.cursor = db.cursor diff --git a/swh/core/db/tests/test_cli.py b/swh/core/db/tests/test_cli.py index 118ff7d..d815271 100644 --- a/swh/core/db/tests/test_cli.py +++ b/swh/core/db/tests/test_cli.py @@ -1,57 +1,57 @@ # from click.testing import CliRunner from swh.core.cli.db import db as swhdb -help_msg = '''Usage: swh [OPTIONS] COMMAND [ARGS]... +help_msg = """Usage: swh [OPTIONS] COMMAND [ARGS]... Command line interface for Software Heritage. Options: -l, --log-level [NOTSET|DEBUG|INFO|WARNING|ERROR|CRITICAL] Log level (defaults to INFO). --log-config FILENAME Python yaml logging configuration file. --sentry-dsn TEXT DSN of the Sentry instance to report to -h, --help Show this message and exit. Notes: If both options are present, --log-level will override the root logger configuration set in --log-config. The --log-config YAML must conform to the logging.config.dictConfig schema documented at https://docs.python.org/3/library/logging.config.html. Commands: db Software Heritage database generic tools. -''' +""" def test_swh_help(swhmain): swhmain.add_command(swhdb) runner = CliRunner() - result = runner.invoke(swhmain, ['-h']) + result = runner.invoke(swhmain, ["-h"]) assert result.exit_code == 0 assert result.output == help_msg -help_db_msg = '''Usage: swh db [OPTIONS] COMMAND [ARGS]... +help_db_msg = """Usage: swh db [OPTIONS] COMMAND [ARGS]... Software Heritage database generic tools. Options: -C, --config-file FILE Configuration file. -h, --help Show this message and exit. Commands: init Initialize the database for every Software Heritage module found in... -''' +""" def test_swh_db_help(swhmain): swhmain.add_command(swhdb) runner = CliRunner() - result = runner.invoke(swhmain, ['db', '-h']) + result = runner.invoke(swhmain, ["db", "-h"]) assert result.exit_code == 0 assert result.output == help_db_msg diff --git a/swh/core/db/tests/test_db.py b/swh/core/db/tests/test_db.py index 506b60f..2cee3fc 100644 --- a/swh/core/db/tests/test_db.py +++ b/swh/core/db/tests/test_db.py @@ -1,222 +1,226 @@ # Copyright (C) 2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import inspect import os.path import tempfile import unittest from unittest.mock import Mock, MagicMock from hypothesis import strategies, given import psycopg2 import pytest from swh.core.db import BaseDb from swh.core.db.common import db_transaction, db_transaction_generator from .db_testing import ( - SingleDbTestFixture, db_create, db_destroy, db_close, + SingleDbTestFixture, + db_create, + db_destroy, + db_close, ) -INIT_SQL = ''' +INIT_SQL = """ create table test_table ( i int, txt text, bytes bytea ); -''' - -db_rows = strategies.lists(strategies.tuples( - strategies.integers(-2147483648, +2147483647), - strategies.text( - alphabet=strategies.characters( - blacklist_categories=['Cs'], # surrogates - blacklist_characters=[ - '\x00', # pgsql does not support the null codepoint - '\r', # pgsql normalizes those - ] +""" + +db_rows = strategies.lists( + strategies.tuples( + strategies.integers(-2147483648, +2147483647), + strategies.text( + alphabet=strategies.characters( + blacklist_categories=["Cs"], # surrogates + blacklist_characters=[ + "\x00", # pgsql does not support the null codepoint + "\r", # pgsql normalizes those + ], + ), ), - ), - strategies.binary(), -)) + strategies.binary(), + ) +) @pytest.mark.db def test_connect(): - db_name = db_create('test-db2', dumps=[]) + db_name = db_create("test-db2", dumps=[]) try: - db = BaseDb.connect('dbname=%s' % db_name) + db = BaseDb.connect("dbname=%s" % db_name) with db.cursor() as cur: cur.execute(INIT_SQL) - cur.execute("insert into test_table values (1, %s, %s);", - ('foo', b'bar')) + cur.execute("insert into test_table values (1, %s, %s);", ("foo", b"bar")) cur.execute("select * from test_table;") - assert list(cur) == [(1, 'foo', b'bar')] + assert list(cur) == [(1, "foo", b"bar")] finally: db_close(db.conn) db_destroy(db_name) @pytest.mark.db class TestDb(SingleDbTestFixture, unittest.TestCase): - TEST_DB_NAME = 'test-db' + TEST_DB_NAME = "test-db" @classmethod def setUpClass(cls): with tempfile.TemporaryDirectory() as td: - with open(os.path.join(td, 'init.sql'), 'a') as fd: + with open(os.path.join(td, "init.sql"), "a") as fd: fd.write(INIT_SQL) - cls.TEST_DB_DUMP = os.path.join(td, '*.sql') + cls.TEST_DB_DUMP = os.path.join(td, "*.sql") super().setUpClass() def setUp(self): super().setUp() self.db = BaseDb(self.conn) def test_initialized(self): cur = self.db.cursor() - cur.execute("insert into test_table values (1, %s, %s);", - ('foo', b'bar')) + cur.execute("insert into test_table values (1, %s, %s);", ("foo", b"bar")) cur.execute("select * from test_table;") - self.assertEqual(list(cur), [(1, 'foo', b'bar')]) + self.assertEqual(list(cur), [(1, "foo", b"bar")]) def test_reset_tables(self): cur = self.db.cursor() - cur.execute("insert into test_table values (1, %s, %s);", - ('foo', b'bar')) - self.reset_db_tables('test-db') + cur.execute("insert into test_table values (1, %s, %s);", ("foo", b"bar")) + self.reset_db_tables("test-db") cur.execute("select * from test_table;") self.assertEqual(list(cur), []) @given(db_rows) def test_copy_to(self, data): # the table is not reset between runs by hypothesis - self.reset_db_tables('test-db') + self.reset_db_tables("test-db") - items = [dict(zip(['i', 'txt', 'bytes'], item)) for item in data] - self.db.copy_to(items, 'test_table', ['i', 'txt', 'bytes']) + items = [dict(zip(["i", "txt", "bytes"], item)) for item in data] + self.db.copy_to(items, "test_table", ["i", "txt", "bytes"]) cur = self.db.cursor() - cur.execute('select * from test_table;') + cur.execute("select * from test_table;") self.assertCountEqual(list(cur), data) def test_copy_to_thread_exception(self): - data = [(2**65, 'foo', b'bar')] + data = [(2 ** 65, "foo", b"bar")] - items = [dict(zip(['i', 'txt', 'bytes'], item)) for item in data] + items = [dict(zip(["i", "txt", "bytes"], item)) for item in data] with self.assertRaises(psycopg2.errors.NumericValueOutOfRange): - self.db.copy_to(items, 'test_table', ['i', 'txt', 'bytes']) + self.db.copy_to(items, "test_table", ["i", "txt", "bytes"]) def test_db_transaction(mocker): expected_cur = object() called = False class Storage: @db_transaction() def endpoint(self, cur=None, db=None): nonlocal called called = True assert cur is expected_cur storage = Storage() # 'with storage.get_db().transaction() as cur:' should cause # 'cur' to be 'expected_cur' db_mock = Mock() db_mock.transaction.return_value = MagicMock() db_mock.transaction.return_value.__enter__.return_value = expected_cur - mocker.patch.object( - storage, 'get_db', return_value=db_mock, create=True) + mocker.patch.object(storage, "get_db", return_value=db_mock, create=True) - put_db_mock = mocker.patch.object( - storage, 'put_db', create=True) + put_db_mock = mocker.patch.object(storage, "put_db", create=True) storage.endpoint() assert called put_db_mock.assert_called_once_with(db_mock) def test_db_transaction__with_generator(): - with pytest.raises(ValueError, match='generator'): + with pytest.raises(ValueError, match="generator"): + class Storage: @db_transaction() def endpoint(self, cur=None, db=None): yield None def test_db_transaction_signature(): """Checks db_transaction removes the 'cur' and 'db' arguments.""" + def f(self, foo, *, bar=None): pass + expected_sig = inspect.signature(f) @db_transaction() def g(self, foo, *, bar=None, db=None, cur=None): pass actual_sig = inspect.signature(g) assert actual_sig == expected_sig def test_db_transaction_generator(mocker): expected_cur = object() called = False class Storage: @db_transaction_generator() def endpoint(self, cur=None, db=None): nonlocal called called = True assert cur is expected_cur yield None storage = Storage() # 'with storage.get_db().transaction() as cur:' should cause # 'cur' to be 'expected_cur' db_mock = Mock() db_mock.transaction.return_value = MagicMock() db_mock.transaction.return_value.__enter__.return_value = expected_cur - mocker.patch.object( - storage, 'get_db', return_value=db_mock, create=True) + mocker.patch.object(storage, "get_db", return_value=db_mock, create=True) - put_db_mock = mocker.patch.object( - storage, 'put_db', create=True) + put_db_mock = mocker.patch.object(storage, "put_db", create=True) list(storage.endpoint()) assert called put_db_mock.assert_called_once_with(db_mock) def test_db_transaction_generator__with_nongenerator(): - with pytest.raises(ValueError, match='generator'): + with pytest.raises(ValueError, match="generator"): + class Storage: @db_transaction_generator() def endpoint(self, cur=None, db=None): pass def test_db_transaction_generator_signature(): """Checks db_transaction removes the 'cur' and 'db' arguments.""" + def f(self, foo, *, bar=None): pass + expected_sig = inspect.signature(f) @db_transaction_generator() def g(self, foo, *, bar=None, db=None, cur=None): yield None actual_sig = inspect.signature(g) assert actual_sig == expected_sig diff --git a/swh/core/logger.py b/swh/core/logger.py index d75580f..f0163a6 100644 --- a/swh/core/logger.py +++ b/swh/core/logger.py @@ -1,114 +1,117 @@ # Copyright (C) 2015 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import datetime import logging from typing import Any, Generator, List, Tuple from systemd.journal import JournalHandler as _JournalHandler, send try: from celery import current_task except ImportError: current_task = None -EXTRA_LOGDATA_PREFIX = 'swh_' +EXTRA_LOGDATA_PREFIX = "swh_" def db_level_of_py_level(lvl): """convert a log level of the logging module to a log level suitable for the logging Postgres DB """ return logging.getLevelName(lvl).lower() def get_extra_data(record, task_args=True): """Get the extra data to insert to the database from the logging record""" log_data = record.__dict__ - extra_data = {k[len(EXTRA_LOGDATA_PREFIX):]: v - for k, v in log_data.items() - if k.startswith(EXTRA_LOGDATA_PREFIX)} + extra_data = { + k[len(EXTRA_LOGDATA_PREFIX) :]: v + for k, v in log_data.items() + if k.startswith(EXTRA_LOGDATA_PREFIX) + } - args = log_data.get('args') + args = log_data.get("args") if args: - extra_data['logging_args'] = args + extra_data["logging_args"] = args # Retrieve Celery task info if current_task and current_task.request: - extra_data['task'] = { - 'id': current_task.request.id, - 'name': current_task.name, + extra_data["task"] = { + "id": current_task.request.id, + "name": current_task.name, } if task_args: - extra_data['task'].update({ - 'kwargs': current_task.request.kwargs, - 'args': current_task.request.args, - }) + extra_data["task"].update( + { + "kwargs": current_task.request.kwargs, + "args": current_task.request.args, + } + ) return extra_data -def flatten( - data: Any, - separator: str = "_" -) -> Generator[Tuple[str, Any], None, None]: +def flatten(data: Any, separator: str = "_") -> Generator[Tuple[str, Any], None, None]: """Flatten the data dictionary into a flat structure""" def inner_flatten( data: Any, prefix: List[str] ) -> Generator[Tuple[List[str], Any], None, None]: if isinstance(data, dict): if all(isinstance(key, str) for key in data): for key, value in data.items(): yield from inner_flatten(value, prefix + [key]) else: yield prefix, str(data) elif isinstance(data, (list, tuple)): for key, value in enumerate(data): yield from inner_flatten(value, prefix + [str(key)]) else: yield prefix, data for path, value in inner_flatten(data, []): yield separator.join(path), value def stringify(value): """Convert value to string""" if isinstance(value, datetime.datetime): return value.isoformat() return str(value) class JournalHandler(_JournalHandler): def emit(self, record): """Write `record` as a journal event. MESSAGE is taken from the message provided by the user, and PRIORITY, LOGGER, THREAD_NAME, CODE_{FILE,LINE,FUNC} fields are appended automatically. In addition, record.MESSAGE_ID will be used if present. """ try: extra_data = flatten(get_extra_data(record, task_args=False)) extra_data = { (EXTRA_LOGDATA_PREFIX + key).upper(): stringify(value) for key, value in extra_data } msg = self.format(record) pri = self.mapPriority(record.levelno) - send(msg, - PRIORITY=format(pri), - LOGGER=record.name, - THREAD_NAME=record.threadName, - CODE_FILE=record.pathname, - CODE_LINE=record.lineno, - CODE_FUNC=record.funcName, - **extra_data) + send( + msg, + PRIORITY=format(pri), + LOGGER=record.name, + THREAD_NAME=record.threadName, + CODE_FILE=record.pathname, + CODE_LINE=record.lineno, + CODE_FUNC=record.funcName, + **extra_data, + ) except Exception: self.handleError(record) diff --git a/swh/core/pytest_plugin.py b/swh/core/pytest_plugin.py index 653d463..233d554 100644 --- a/swh/core/pytest_plugin.py +++ b/swh/core/pytest_plugin.py @@ -1,311 +1,319 @@ # Copyright (C) 2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import logging import re import pytest import requests from functools import partial from os import path from typing import Dict, List, Optional from urllib.parse import urlparse, unquote from requests.adapters import BaseAdapter from requests.structures import CaseInsensitiveDict from requests.utils import get_encoding_from_headers logger = logging.getLogger(__name__) # Check get_local_factory function # Maximum number of iteration checks to generate requests responses MAX_VISIT_FILES = 10 def get_response_cb( - request: requests.Request, context, datadir, - ignore_urls: List[str] = [], - visits: Optional[Dict] = None): + request: requests.Request, + context, + datadir, + ignore_urls: List[str] = [], + visits: Optional[Dict] = None, +): """Mount point callback to fetch on disk the request's content. The request urls provided are url decoded first to resolve the associated file on disk. This is meant to be used as 'body' argument of the requests_mock.get() method. It will look for files on the local filesystem based on the requested URL, using the following rules: - files are searched in the datadir/ directory - the local file name is the path part of the URL with path hierarchy markers (aka '/') replaced by '_' Eg. if you use the requests_mock fixture in your test file as: requests_mock.get('https?://nowhere.com', body=get_response_cb) # or even requests_mock.get(re.compile('https?://'), body=get_response_cb) then a call requests.get like: requests.get('https://nowhere.com/path/to/resource?a=b&c=d') will look the content of the response in: datadir/https_nowhere.com/path_to_resource,a=b,c=d or a call requests.get like: requests.get('http://nowhere.com/path/to/resource?a=b&c=d') will look the content of the response in: datadir/http_nowhere.com/path_to_resource,a=b,c=d Args: request: Object requests context (requests.Context): Object holding response metadata information (status_code, headers, etc...) datadir: Data files path ignore_urls: urls whose status response should be 404 even if the local file exists visits: Dict of url, number of visits. If None, disable multi visit support (default) Returns: Optional[FileDescriptor] on disk file to read from the test context """ - logger.debug('get_response_cb(%s, %s)', request, context) - logger.debug('url: %s', request.url) - logger.debug('ignore_urls: %s', ignore_urls) + logger.debug("get_response_cb(%s, %s)", request, context) + logger.debug("url: %s", request.url) + logger.debug("ignore_urls: %s", ignore_urls) unquoted_url = unquote(request.url) if unquoted_url in ignore_urls: context.status_code = 404 return None url = urlparse(unquoted_url) # http://pypi.org ~> http_pypi.org # https://files.pythonhosted.org ~> https_files.pythonhosted.org - dirname = '%s_%s' % (url.scheme, url.hostname) + dirname = "%s_%s" % (url.scheme, url.hostname) # url.path: pypi//json -> local file: pypi__json filename = url.path[1:] - if filename.endswith('/'): + if filename.endswith("/"): filename = filename[:-1] - filename = filename.replace('/', '_') + filename = filename.replace("/", "_") if url.query: - filename += ',' + url.query.replace('&', ',') + filename += "," + url.query.replace("&", ",") filepath = path.join(datadir, dirname, filename) if visits is not None: visit = visits.get(url, 0) visits[url] = visit + 1 if visit: - filepath = filepath + '_visit%s' % visit + filepath = filepath + "_visit%s" % visit if not path.isfile(filepath): - logger.debug('not found filepath: %s', filepath) + logger.debug("not found filepath: %s", filepath) context.status_code = 404 return None - fd = open(filepath, 'rb') - context.headers['content-length'] = str(path.getsize(filepath)) + fd = open(filepath, "rb") + context.headers["content-length"] = str(path.getsize(filepath)) return fd @pytest.fixture def datadir(request): """By default, returns the test directory's data directory. This can be overridden on a per file tree basis. Add an override definition in the local conftest, for example:: import pytest from os import path @pytest.fixture def datadir(): return path.join(path.abspath(path.dirname(__file__)), 'resources') """ - return path.join(path.dirname(str(request.fspath)), 'data') + return path.join(path.dirname(str(request.fspath)), "data") -def requests_mock_datadir_factory(ignore_urls: List[str] = [], - has_multi_visit: bool = False): +def requests_mock_datadir_factory( + ignore_urls: List[str] = [], has_multi_visit: bool = False +): """This factory generates fixture which allow to look for files on the local filesystem based on the requested URL, using the following rules: - files are searched in the datadir/ directory - the local file name is the path part of the URL with path hierarchy markers (aka '/') replaced by '_' Multiple implementations are possible, for example: - requests_mock_datadir_factory([]): This computes the file name from the query and always returns the same result. - requests_mock_datadir_factory(has_multi_visit=True): This computes the file name from the query and returns the content of the filename the first time, the next call returning the content of files suffixed with _visit1 and so on and so forth. If the file is not found, returns a 404. - requests_mock_datadir_factory(ignore_urls=['url1', 'url2']): This will ignore any files corresponding to url1 and url2, always returning 404. Args: ignore_urls: List of urls to always returns 404 (whether file exists or not) has_multi_visit: Activate or not the multiple visits behavior """ + @pytest.fixture def requests_mock_datadir(requests_mock, datadir): if not has_multi_visit: - cb = partial(get_response_cb, - ignore_urls=ignore_urls, - datadir=datadir) - requests_mock.get(re.compile('https?://'), body=cb) + cb = partial(get_response_cb, ignore_urls=ignore_urls, datadir=datadir) + requests_mock.get(re.compile("https?://"), body=cb) else: visits = {} - requests_mock.get(re.compile('https?://'), body=partial( - get_response_cb, ignore_urls=ignore_urls, visits=visits, - datadir=datadir) + requests_mock.get( + re.compile("https?://"), + body=partial( + get_response_cb, + ignore_urls=ignore_urls, + visits=visits, + datadir=datadir, + ), ) return requests_mock return requests_mock_datadir # Default `requests_mock_datadir` implementation requests_mock_datadir = requests_mock_datadir_factory([]) # Implementation for multiple visits behavior: # - first time, it checks for a file named `filename` # - second time, it checks for a file named `filename`_visit1 # etc... -requests_mock_datadir_visits = requests_mock_datadir_factory( - has_multi_visit=True) +requests_mock_datadir_visits = requests_mock_datadir_factory(has_multi_visit=True) @pytest.fixture def swh_rpc_client(swh_rpc_client_class, swh_rpc_adapter): """This fixture generates an RPCClient instance that uses the class generated by the rpc_client_class fixture as backend. Since it uses the swh_rpc_adapter, HTTP queries will be intercepted and routed directly to the current Flask app (as provided by the `app` fixture). So this stack of fixtures allows to test the RPCClient -> RPCServerApp communication path using a real RPCClient instance and a real Flask (RPCServerApp) app instance. To use this fixture: - ensure an `app` fixture exists and generate a Flask application, - implement an `swh_rpc_client_class` fixtures that returns the RPCClient-based class to use as client side for the tests, - implement your tests using this `swh_rpc_client` fixture. See swh/core/api/tests/test_rpc_client_server.py for an example of usage. """ - url = 'mock://example.com' + url = "mock://example.com" cli = swh_rpc_client_class(url=url) # we need to clear the list of existing adapters here so we ensure we # have one and only one adapter which is then used for all the requests. cli.session.adapters.clear() - cli.session.mount('mock://', swh_rpc_adapter) + cli.session.mount("mock://", swh_rpc_adapter) return cli @pytest.fixture def swh_rpc_adapter(app): """Fixture that generates a requests.Adapter instance that can be used to test client/servers code based on swh.core.api classes. See swh/core/api/tests/test_rpc_client_server.py for an example of usage. """ with app.test_client() as client: yield RPCTestAdapter(client) class RPCTestAdapter(BaseAdapter): def __init__(self, client): self._client = client def build_response(self, req, resp): response = requests.Response() # Fallback to None if there's no status_code, for whatever reason. response.status_code = resp.status_code # Make headers case-insensitive. - response.headers = CaseInsensitiveDict(getattr(resp, 'headers', {})) + response.headers = CaseInsensitiveDict(getattr(resp, "headers", {})) # Set encoding. response.encoding = get_encoding_from_headers(response.headers) response.raw = resp response.reason = response.raw.status if isinstance(req.url, bytes): - response.url = req.url.decode('utf-8') + response.url = req.url.decode("utf-8") else: response.url = req.url # Give the Response some context. response.request = req response.connection = self response._content = resp.data return response def send(self, request, **kw): """ Overrides ``requests.adapters.BaseAdapter.send`` """ resp = self._client.open( - request.url, method=request.method, + request.url, + method=request.method, headers=request.headers.items(), data=request.body, - ) + ) return self.build_response(request, resp) @pytest.fixture def flask_app_client(app): with app.test_client() as client: yield client # stolen from pytest-flask, required to have url_for() working within tests # using flask_app_client fixture. @pytest.fixture(autouse=True) def _push_request_context(request): """During tests execution request context has been pushed, e.g. `url_for`, `session`, etc. can be used in tests as is:: def test_app(app, client): assert client.get(url_for('myview')).status_code == 200 """ - if 'app' not in request.fixturenames: + if "app" not in request.fixturenames: return - app = request.getfixturevalue('app') + app = request.getfixturevalue("app") ctx = app.test_request_context() ctx.push() def teardown(): ctx.pop() request.addfinalizer(teardown) diff --git a/swh/core/sentry.py b/swh/core/sentry.py index 536bc4b..2af66a6 100644 --- a/swh/core/sentry.py +++ b/swh/core/sentry.py @@ -1,37 +1,35 @@ # Copyright (C) 2019-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import pkg_resources import os def get_sentry_release(): - main_package = os.environ.get('SWH_MAIN_PACKAGE') + main_package = os.environ.get("SWH_MAIN_PACKAGE") if main_package: version = pkg_resources.get_distribution(main_package).version - return f'{main_package}@{version}' + return f"{main_package}@{version}" else: return None -def init_sentry( - sentry_dsn, *, debug=None, integrations=[], - extra_kwargs={}): +def init_sentry(sentry_dsn, *, debug=None, integrations=[], extra_kwargs={}): if debug is None: - debug = bool(os.environ.get('SWH_SENTRY_DEBUG')) - sentry_dsn = sentry_dsn or os.environ.get('SWH_SENTRY_DSN') - environment = os.environ.get('SWH_SENTRY_ENVIRONMENT') + debug = bool(os.environ.get("SWH_SENTRY_DEBUG")) + sentry_dsn = sentry_dsn or os.environ.get("SWH_SENTRY_DSN") + environment = os.environ.get("SWH_SENTRY_ENVIRONMENT") if sentry_dsn: import sentry_sdk sentry_sdk.init( release=get_sentry_release(), environment=environment, dsn=sentry_dsn, integrations=integrations, debug=debug, **extra_kwargs, ) diff --git a/swh/core/statsd.py b/swh/core/statsd.py index 30be881..c212e71 100644 --- a/swh/core/statsd.py +++ b/swh/core/statsd.py @@ -1,430 +1,441 @@ # Copyright (C) 2018 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information # Initially imported from https://github.com/DataDog/datadogpy/ # at revision 62b3a3e89988dc18d78c282fe3ff5d1813917436 # # Copyright (c) 2015, Datadog # All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are met: # * Redistributions of source code must retain the above copyright # notice, this list of conditions and the following disclaimer. # * Redistributions in binary form must reproduce the above copyright # notice, this list of conditions and the following disclaimer in the # documentation and/or other materials provided with the distribution. # * Neither the name of Datadog nor the names of its contributors may be # used to endorse or promote products derived from this software without # specific prior written permission. # # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE # ARE DISCLAIMED. IN NO EVENT SHALL BE LIABLE FOR ANY # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # # # Vastly adapted for integration in swh.core: # # - Removed python < 3.5 compat code # - trimmed the imports down to be a single module # - adjust some options: # - drop unix socket connection option # - add environment variable support for setting the statsd host and # port (pulled the idea from the main python statsd module) # - only send timer metrics in milliseconds (that's what # prometheus-statsd-exporter expects) # - drop DataDog-specific metric types (that are unsupported in # prometheus-statsd-exporter) # - made the tags a dict instead of a list (prometheus-statsd-exporter only # supports tags with a value, mirroring prometheus) # - switch from time.time to time.monotonic # - improve unit test coverage # - documentation cleanup from asyncio import iscoroutinefunction from functools import wraps from random import random from time import monotonic import itertools import logging import os import socket import threading import warnings -log = logging.getLogger('swh.core.statsd') +log = logging.getLogger("swh.core.statsd") class TimedContextManagerDecorator(object): """ A context manager and a decorator which will report the elapsed time in the context OR in a function call. Attributes: elapsed (float): the elapsed time at the point of completion """ - def __init__(self, statsd, metric=None, error_metric=None, - tags=None, sample_rate=1): + + def __init__( + self, statsd, metric=None, error_metric=None, tags=None, sample_rate=1 + ): self.statsd = statsd self.metric = metric self.error_metric = error_metric self.tags = tags self.sample_rate = sample_rate self.elapsed = None # this is for testing purpose def __call__(self, func): """ Decorator which returns the elapsed time of the function call. Default to the function name if metric was not provided. """ if not self.metric: - self.metric = '%s.%s' % (func.__module__, func.__name__) + self.metric = "%s.%s" % (func.__module__, func.__name__) # Coroutines if iscoroutinefunction(func): + @wraps(func) async def wrapped_co(*args, **kwargs): start = monotonic() try: result = await func(*args, **kwargs) except: # noqa self._send_error() raise self._send(start) return result + return wrapped_co # Others @wraps(func) def wrapped(*args, **kwargs): start = monotonic() try: result = func(*args, **kwargs) except: # noqa self._send_error() raise self._send(start) return result + return wrapped def __enter__(self): if not self.metric: raise TypeError("Cannot used timed without a metric!") self._start = monotonic() return self def __exit__(self, type, value, traceback): # Report the elapsed time of the context manager if no error. if type is None: self._send(self._start) else: self._send_error() def _send(self, start): elapsed = (monotonic() - start) * 1000 - self.statsd.timing(self.metric, elapsed, - tags=self.tags, sample_rate=self.sample_rate) + self.statsd.timing( + self.metric, elapsed, tags=self.tags, sample_rate=self.sample_rate + ) self.elapsed = elapsed def _send_error(self): if self.error_metric is None: - self.error_metric = self.metric + '_error_count' + self.error_metric = self.metric + "_error_count" self.statsd.increment(self.error_metric, tags=self.tags) def start(self): """Start the timer""" self.__enter__() def stop(self): """Stop the timer, send the metric value""" self.__exit__(None, None, None) class Statsd(object): """Initialize a client to send metrics to a StatsD server. Arguments: host (str): the host of the StatsD server. Defaults to localhost. port (int): the port of the StatsD server. Defaults to 8125. max_buffer_size (int): Maximum number of metrics to buffer before sending to the server if sending metrics in batch namespace (str): Namespace to prefix all metric names constant_tags (Dict[str, str]): Tags to attach to all metrics Note: This class also supports the following environment variables: STATSD_HOST Override the default host of the statsd server STATSD_PORT Override the default port of the statsd server STATSD_TAGS Tags to attach to every metric reported. Example value: "label:value,other_label:other_value" """ - def __init__(self, host=None, port=None, max_buffer_size=50, - namespace=None, constant_tags=None): + def __init__( + self, + host=None, + port=None, + max_buffer_size=50, + namespace=None, + constant_tags=None, + ): # Connection if host is None: - host = os.environ.get('STATSD_HOST') or 'localhost' + host = os.environ.get("STATSD_HOST") or "localhost" self.host = host if port is None: - port = os.environ.get('STATSD_PORT') or 8125 + port = os.environ.get("STATSD_PORT") or 8125 self.port = int(port) # Socket self._socket = None self.lock = threading.Lock() self.max_buffer_size = max_buffer_size self._send = self._send_to_server - self.encoding = 'utf-8' + self.encoding = "utf-8" # Tags self.constant_tags = {} - tags_envvar = os.environ.get('STATSD_TAGS', '') - for tag in tags_envvar.split(','): + tags_envvar = os.environ.get("STATSD_TAGS", "") + for tag in tags_envvar.split(","): if not tag: continue - if ':' not in tag: + if ":" not in tag: warnings.warn( - 'STATSD_TAGS needs to be in key:value format, ' - '%s invalid' % tag, + "STATSD_TAGS needs to be in key:value format, " "%s invalid" % tag, UserWarning, ) continue - k, v = tag.split(':', 1) + k, v = tag.split(":", 1) self.constant_tags[k] = v if constant_tags: - self.constant_tags.update({ - str(k): str(v) - for k, v in constant_tags.items() - }) + self.constant_tags.update( + {str(k): str(v) for k, v in constant_tags.items()} + ) # Namespace if namespace is not None: namespace = str(namespace) self.namespace = namespace def __enter__(self): self.open_buffer(self.max_buffer_size) return self def __exit__(self, type, value, traceback): self.close_buffer() def gauge(self, metric, value, tags=None, sample_rate=1): """ Record the value of a gauge, optionally setting a list of tags and a sample rate. >>> statsd.gauge('users.online', 123) >>> statsd.gauge('active.connections', 1001, tags={"protocol": "http"}) """ - return self._report(metric, 'g', value, tags, sample_rate) + return self._report(metric, "g", value, tags, sample_rate) def increment(self, metric, value=1, tags=None, sample_rate=1): """ Increment a counter, optionally setting a value, tags and a sample rate. >>> statsd.increment('page.views') >>> statsd.increment('files.transferred', 124) """ - self._report(metric, 'c', value, tags, sample_rate) + self._report(metric, "c", value, tags, sample_rate) def decrement(self, metric, value=1, tags=None, sample_rate=1): """ Decrement a counter, optionally setting a value, tags and a sample rate. >>> statsd.decrement('files.remaining') >>> statsd.decrement('active.connections', 2) """ metric_value = -value if value else value - self._report(metric, 'c', metric_value, tags, sample_rate) + self._report(metric, "c", metric_value, tags, sample_rate) def histogram(self, metric, value, tags=None, sample_rate=1): """ Sample a histogram value, optionally setting tags and a sample rate. >>> statsd.histogram('uploaded.file.size', 1445) >>> statsd.histogram('file.count', 26, tags={"filetype": "python"}) """ - self._report(metric, 'h', value, tags, sample_rate) + self._report(metric, "h", value, tags, sample_rate) def timing(self, metric, value, tags=None, sample_rate=1): """ Record a timing, optionally setting tags and a sample rate. >>> statsd.timing("query.response.time", 1234) """ - self._report(metric, 'ms', value, tags, sample_rate) + self._report(metric, "ms", value, tags, sample_rate) def timed(self, metric=None, error_metric=None, tags=None, sample_rate=1): """ A decorator or context manager that will measure the distribution of a function's/context's run time. Optionally specify a list of tags or a sample rate. If the metric is not defined as a decorator, the module name and function name will be used. The metric is required as a context manager. :: @statsd.timed('user.query.time', sample_rate=0.5) def get_user(user_id): # Do what you need to ... pass # Is equivalent to ... with statsd.timed('user.query.time', sample_rate=0.5): # Do what you need to ... pass # Is equivalent to ... start = time.monotonic() try: get_user(user_id) finally: statsd.timing('user.query.time', time.monotonic() - start) """ return TimedContextManagerDecorator( - statsd=self, metric=metric, + statsd=self, + metric=metric, error_metric=error_metric, - tags=tags, sample_rate=sample_rate) + tags=tags, + sample_rate=sample_rate, + ) def set(self, metric, value, tags=None, sample_rate=1): """ Sample a set value. >>> statsd.set('visitors.uniques', 999) """ - self._report(metric, 's', value, tags, sample_rate) + self._report(metric, "s", value, tags, sample_rate) @property def socket(self): """ Return a connected socket. Note: connect the socket before assigning it to the class instance to avoid bad thread race conditions. """ with self.lock: if not self._socket: sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) sock.connect((self.host, self.port)) self._socket = sock return self._socket def open_buffer(self, max_buffer_size=50): """ Open a buffer to send a batch of metrics in one packet. You can also use this as a context manager. >>> with Statsd() as batch: ... batch.gauge('users.online', 123) ... batch.gauge('active.connections', 1001) """ self.max_buffer_size = max_buffer_size self.buffer = [] self._send = self._send_to_buffer def close_buffer(self): """ Flush the buffer and switch back to single metric packets. """ self._send = self._send_to_server if self.buffer: # Only send packets if there are packets to send self._flush_buffer() def close_socket(self): """ Closes connected socket if connected. """ with self.lock: if self._socket: self._socket.close() self._socket = None def _report(self, metric, metric_type, value, tags, sample_rate): """ Create a metric packet and send it. """ if value is None: return if sample_rate != 1 and random() > sample_rate: return # Resolve the full tag list tags = self._add_constant_tags(tags) # Create/format the metric packet payload = "%s%s:%s|%s%s%s" % ( (self.namespace + ".") if self.namespace else "", metric, value, metric_type, ("|@" + str(sample_rate)) if sample_rate != 1 else "", - ("|#" + ",".join( - "%s:%s" % (k, v) - for (k, v) in sorted(tags.items()) - )) if tags else "", + ("|#" + ",".join("%s:%s" % (k, v) for (k, v) in sorted(tags.items()))) + if tags + else "", ) # Send it self._send(payload) def _send_to_server(self, packet): try: # If set, use socket directly - self.socket.send(packet.encode('utf-8')) + self.socket.send(packet.encode("utf-8")) except socket.timeout: return except socket.error: log.debug( "Error submitting statsd packet." " Dropping the packet and closing the socket." ) self.close_socket() def _send_to_buffer(self, packet): self.buffer.append(packet) if len(self.buffer) >= self.max_buffer_size: self._flush_buffer() def _flush_buffer(self): self._send_to_server("\n".join(self.buffer)) self.buffer = [] def _add_constant_tags(self, tags): return { str(k): str(v) for k, v in itertools.chain( - self.constant_tags.items(), - (tags if tags else {}).items(), + self.constant_tags.items(), (tags if tags else {}).items(), ) } statsd = Statsd() diff --git a/swh/core/tarball.py b/swh/core/tarball.py index 1544432..d557efd 100644 --- a/swh/core/tarball.py +++ b/swh/core/tarball.py @@ -1,147 +1,148 @@ # Copyright (C) 2015-2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import os import shutil import stat import tarfile import zipfile from subprocess import run from . import utils def _unpack_tar(tarpath: str, extract_dir: str) -> str: """Unpack tarballs unsupported by the standard python library. Examples include tar.Z, tar.lz, tar.x, etc.... As this implementation relies on the `tar` command, this function supports the same compression the tar command supports. This expects the `extract_dir` to exist. Raises shutil.ReadError in case of issue uncompressing the archive (tarpath does not exist, extract_dir does not exist, etc...) Returns full path to the uncompressed directory. """ try: - run(['tar', 'xf', tarpath, '-C', extract_dir], check=True) + run(["tar", "xf", tarpath, "-C", extract_dir], check=True) return extract_dir except Exception as e: raise shutil.ReadError( - f'Unable to uncompress {tarpath} to {extract_dir}. Reason: {e}') + f"Unable to uncompress {tarpath} to {extract_dir}. Reason: {e}" + ) def register_new_archive_formats(): """Register new archive formats to uncompress """ registered_formats = [f[0] for f in shutil.get_unpack_formats()] for name, extensions, function in ADDITIONAL_ARCHIVE_FORMATS: if name in registered_formats: continue shutil.register_unpack_format(name, extensions, function) def uncompress(tarpath: str, dest: str): """Uncompress tarpath to dest folder if tarball is supported. Note that this fixes permissions after successfully uncompressing the archive. Args: tarpath: path to tarball to uncompress dest: the destination folder where to uncompress the tarball Returns: The nature of the tarball, zip or tar. Raises: ValueError when a problem occurs during unpacking """ try: shutil.unpack_archive(tarpath, extract_dir=dest) except shutil.ReadError as e: - raise ValueError(f'Problem during unpacking {tarpath}. Reason: {e}') + raise ValueError(f"Problem during unpacking {tarpath}. Reason: {e}") # Fix permissions for dirpath, _, fnames in os.walk(dest): os.chmod(dirpath, 0o755) for fname in fnames: fpath = os.path.join(dirpath, fname) if not os.path.islink(fpath): fpath_exec = os.stat(fpath).st_mode & stat.S_IXUSR if not fpath_exec: os.chmod(fpath, 0o644) def _ls(rootdir): """Generator of filepath, filename from rootdir. """ for dirpath, dirnames, fnames in os.walk(rootdir): - for fname in (dirnames+fnames): + for fname in dirnames + fnames: fpath = os.path.join(dirpath, fname) fname = utils.commonname(rootdir, fpath) yield fpath, fname def _compress_zip(tarpath, files): """Compress dirpath's content as tarpath. """ - with zipfile.ZipFile(tarpath, 'w') as z: + with zipfile.ZipFile(tarpath, "w") as z: for fpath, fname in files: z.write(fpath, arcname=fname) def _compress_tar(tarpath, files): """Compress dirpath's content as tarpath. """ - with tarfile.open(tarpath, 'w:bz2') as t: + with tarfile.open(tarpath, "w:bz2") as t: for fpath, fname in files: t.add(fpath, arcname=fname, recursive=False) def compress(tarpath, nature, dirpath_or_files): """Create a tarball tarpath with nature nature. The content of the tarball is either dirpath's content (if representing a directory path) or dirpath's iterable contents. Compress the directory dirpath's content to a tarball. The tarball being dumped at tarpath. The nature of the tarball is determined by the nature argument. """ if isinstance(dirpath_or_files, str): files = _ls(dirpath_or_files) else: # iterable of 'filepath, filename' files = dirpath_or_files - if nature == 'zip': + if nature == "zip": _compress_zip(tarpath, files) else: _compress_tar(tarpath, files) return tarpath # Additional uncompression archive format support ADDITIONAL_ARCHIVE_FORMATS = [ # name , extensions, function - ('tar.Z|x', ['.tar.Z', '.tar.x'], _unpack_tar), + ("tar.Z|x", [".tar.Z", ".tar.x"], _unpack_tar), # FIXME: make this optional depending on the runtime lzip package install - ('tar.lz', ['.tar.lz'], _unpack_tar), + ("tar.lz", [".tar.lz"], _unpack_tar), ] register_new_archive_formats() diff --git a/swh/core/tests/__init__.py b/swh/core/tests/__init__.py index 0e407c3..e70ce2a 100644 --- a/swh/core/tests/__init__.py +++ b/swh/core/tests/__init__.py @@ -1,5 +1,5 @@ from os import path import swh.core -SQL_DIR = path.join(path.dirname(swh.core.__file__), 'sql') +SQL_DIR = path.join(path.dirname(swh.core.__file__), "sql") diff --git a/swh/core/tests/fixture/conftest.py b/swh/core/tests/fixture/conftest.py index 399adac..412d102 100644 --- a/swh/core/tests/fixture/conftest.py +++ b/swh/core/tests/fixture/conftest.py @@ -1,16 +1,16 @@ # Copyright (C) 2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import pytest from os import path -DATADIR = path.join(path.abspath(path.dirname(__file__)), 'data') +DATADIR = path.join(path.abspath(path.dirname(__file__)), "data") @pytest.fixture def datadir(): return DATADIR diff --git a/swh/core/tests/fixture/test_pytest_plugin.py b/swh/core/tests/fixture/test_pytest_plugin.py index 534a7b7..dbabbd8 100644 --- a/swh/core/tests/fixture/test_pytest_plugin.py +++ b/swh/core/tests/fixture/test_pytest_plugin.py @@ -1,25 +1,24 @@ # Copyright (C) 2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import requests from .conftest import DATADIR # In this arborescence, we override in the local conftest.py module the # "datadir" fixture to specify where to retrieve the data files from. -def test_requests_mock_datadir_with_datadir_fixture_override( - requests_mock_datadir): +def test_requests_mock_datadir_with_datadir_fixture_override(requests_mock_datadir): """Override datadir fixture should retrieve data from elsewhere """ - response = requests.get('https://example.com/file.json') + response = requests.get("https://example.com/file.json") assert response.ok - assert response.json() == {'welcome': 'you'} + assert response.json() == {"welcome": "you"} def test_data_dir_override(datadir): assert datadir == DATADIR diff --git a/swh/core/tests/test_cli.py b/swh/core/tests/test_cli.py index 1c7ec19..089eb93 100644 --- a/swh/core/tests/test_cli.py +++ b/swh/core/tests/test_cli.py @@ -1,298 +1,283 @@ # Copyright (C) 2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import logging import pkg_resources import textwrap from unittest.mock import patch import click from click.testing import CliRunner import pytest -help_msg = '''Usage: swh [OPTIONS] COMMAND [ARGS]... +help_msg = """Usage: swh [OPTIONS] COMMAND [ARGS]... Command line interface for Software Heritage. Options: -l, --log-level [NOTSET|DEBUG|INFO|WARNING|ERROR|CRITICAL] Log level (defaults to INFO). --log-config FILENAME Python yaml logging configuration file. --sentry-dsn TEXT DSN of the Sentry instance to report to -h, --help Show this message and exit. Notes: If both options are present, --log-level will override the root logger configuration set in --log-config. The --log-config YAML must conform to the logging.config.dictConfig schema documented at https://docs.python.org/3/library/logging.config.html. -''' +""" def test_swh_help(swhmain): runner = CliRunner() - result = runner.invoke(swhmain, ['-h']) + result = runner.invoke(swhmain, ["-h"]) assert result.exit_code == 0 assert result.output.startswith(help_msg) - result = runner.invoke(swhmain, ['--help']) + result = runner.invoke(swhmain, ["--help"]) assert result.exit_code == 0 assert result.output.startswith(help_msg) def test_command(swhmain): - @swhmain.command(name='test') + @swhmain.command(name="test") @click.pass_context def swhtest(ctx): - click.echo('Hello SWH!') + click.echo("Hello SWH!") runner = CliRunner() - with patch('sentry_sdk.init') as sentry_sdk_init: - result = runner.invoke(swhmain, ['test']) + with patch("sentry_sdk.init") as sentry_sdk_init: + result = runner.invoke(swhmain, ["test"]) sentry_sdk_init.assert_not_called() assert result.exit_code == 0 - assert result.output.strip() == 'Hello SWH!' + assert result.output.strip() == "Hello SWH!" def test_loglevel_default(caplog, swhmain): - @swhmain.command(name='test') + @swhmain.command(name="test") @click.pass_context def swhtest(ctx): assert logging.root.level == 20 - click.echo('Hello SWH!') + click.echo("Hello SWH!") runner = CliRunner() - result = runner.invoke(swhmain, ['test']) + result = runner.invoke(swhmain, ["test"]) assert result.exit_code == 0 print(result.output) - assert result.output.strip() == '''Hello SWH!''' + assert result.output.strip() == """Hello SWH!""" def test_loglevel_error(caplog, swhmain): - @swhmain.command(name='test') + @swhmain.command(name="test") @click.pass_context def swhtest(ctx): assert logging.root.level == 40 - click.echo('Hello SWH!') + click.echo("Hello SWH!") runner = CliRunner() - result = runner.invoke(swhmain, ['-l', 'ERROR', 'test']) + result = runner.invoke(swhmain, ["-l", "ERROR", "test"]) assert result.exit_code == 0 - assert result.output.strip() == '''Hello SWH!''' + assert result.output.strip() == """Hello SWH!""" def test_loglevel_debug(caplog, swhmain): - @swhmain.command(name='test') + @swhmain.command(name="test") @click.pass_context def swhtest(ctx): assert logging.root.level == 10 - click.echo('Hello SWH!') + click.echo("Hello SWH!") runner = CliRunner() - result = runner.invoke(swhmain, ['-l', 'DEBUG', 'test']) + result = runner.invoke(swhmain, ["-l", "DEBUG", "test"]) assert result.exit_code == 0 - assert result.output.strip() == '''Hello SWH!''' + assert result.output.strip() == """Hello SWH!""" def test_sentry(swhmain): - @swhmain.command(name='test') + @swhmain.command(name="test") @click.pass_context def swhtest(ctx): - click.echo('Hello SWH!') + click.echo("Hello SWH!") runner = CliRunner() - with patch('sentry_sdk.init') as sentry_sdk_init: - result = runner.invoke(swhmain, ['--sentry-dsn', 'test_dsn', 'test']) + with patch("sentry_sdk.init") as sentry_sdk_init: + result = runner.invoke(swhmain, ["--sentry-dsn", "test_dsn", "test"]) assert result.exit_code == 0 - assert result.output.strip() == '''Hello SWH!''' + assert result.output.strip() == """Hello SWH!""" sentry_sdk_init.assert_called_once_with( - dsn='test_dsn', - debug=False, - integrations=[], - release=None, - environment=None, + dsn="test_dsn", debug=False, integrations=[], release=None, environment=None, ) def test_sentry_debug(swhmain): - @swhmain.command(name='test') + @swhmain.command(name="test") @click.pass_context def swhtest(ctx): - click.echo('Hello SWH!') + click.echo("Hello SWH!") runner = CliRunner() - with patch('sentry_sdk.init') as sentry_sdk_init: + with patch("sentry_sdk.init") as sentry_sdk_init: result = runner.invoke( - swhmain, ['--sentry-dsn', 'test_dsn', '--sentry-debug', 'test']) + swhmain, ["--sentry-dsn", "test_dsn", "--sentry-debug", "test"] + ) assert result.exit_code == 0 - assert result.output.strip() == '''Hello SWH!''' + assert result.output.strip() == """Hello SWH!""" sentry_sdk_init.assert_called_once_with( - dsn='test_dsn', - debug=True, - integrations=[], - release=None, - environment=None, + dsn="test_dsn", debug=True, integrations=[], release=None, environment=None, ) def test_sentry_env(swhmain): - @swhmain.command(name='test') + @swhmain.command(name="test") @click.pass_context def swhtest(ctx): - click.echo('Hello SWH!') + click.echo("Hello SWH!") runner = CliRunner() - with patch('sentry_sdk.init') as sentry_sdk_init: + with patch("sentry_sdk.init") as sentry_sdk_init: env = { - 'SWH_SENTRY_DSN': 'test_dsn', - 'SWH_SENTRY_DEBUG': '1', + "SWH_SENTRY_DSN": "test_dsn", + "SWH_SENTRY_DEBUG": "1", } - result = runner.invoke( - swhmain, ['test'], env=env, auto_envvar_prefix='SWH') + result = runner.invoke(swhmain, ["test"], env=env, auto_envvar_prefix="SWH") assert result.exit_code == 0 - assert result.output.strip() == '''Hello SWH!''' + assert result.output.strip() == """Hello SWH!""" sentry_sdk_init.assert_called_once_with( - dsn='test_dsn', - debug=True, - integrations=[], - release=None, - environment=None, + dsn="test_dsn", debug=True, integrations=[], release=None, environment=None, ) def test_sentry_env_main_package(swhmain): - @swhmain.command(name='test') + @swhmain.command(name="test") @click.pass_context def swhtest(ctx): - click.echo('Hello SWH!') + click.echo("Hello SWH!") runner = CliRunner() - with patch('sentry_sdk.init') as sentry_sdk_init: + with patch("sentry_sdk.init") as sentry_sdk_init: env = { - 'SWH_SENTRY_DSN': 'test_dsn', - 'SWH_MAIN_PACKAGE': 'swh.core', - 'SWH_SENTRY_ENVIRONMENT': 'tests', + "SWH_SENTRY_DSN": "test_dsn", + "SWH_MAIN_PACKAGE": "swh.core", + "SWH_SENTRY_ENVIRONMENT": "tests", } - result = runner.invoke( - swhmain, ['test'], env=env, auto_envvar_prefix='SWH') + result = runner.invoke(swhmain, ["test"], env=env, auto_envvar_prefix="SWH") assert result.exit_code == 0 - version = pkg_resources.get_distribution('swh.core').version + version = pkg_resources.get_distribution("swh.core").version - assert result.output.strip() == '''Hello SWH!''' + assert result.output.strip() == """Hello SWH!""" sentry_sdk_init.assert_called_once_with( - dsn='test_dsn', + dsn="test_dsn", debug=False, integrations=[], - release='swh.core@' + version, - environment='tests', + release="swh.core@" + version, + environment="tests", ) @pytest.fixture def log_config_path(tmp_path): - log_config = textwrap.dedent('''\ + log_config = textwrap.dedent( + """\ --- version: 1 formatters: formatter: format: 'custom format:%(name)s:%(levelname)s:%(message)s' handlers: console: class: logging.StreamHandler stream: ext://sys.stdout formatter: formatter level: DEBUG root: level: DEBUG handlers: - console loggers: dontshowdebug: level: INFO - ''') + """ + ) - (tmp_path / 'log_config.yml').write_text(log_config) + (tmp_path / "log_config.yml").write_text(log_config) - yield str(tmp_path / 'log_config.yml') + yield str(tmp_path / "log_config.yml") def test_log_config(caplog, log_config_path, swhmain): - @swhmain.command(name='test') + @swhmain.command(name="test") @click.pass_context def swhtest(ctx): - logging.debug('Root log debug') - logging.info('Root log info') - logging.getLogger('dontshowdebug').debug('Not shown') - logging.getLogger('dontshowdebug').info('Shown') + logging.debug("Root log debug") + logging.info("Root log info") + logging.getLogger("dontshowdebug").debug("Not shown") + logging.getLogger("dontshowdebug").info("Shown") runner = CliRunner() - result = runner.invoke( - swhmain, [ - '--log-config', log_config_path, - 'test', - ], - ) + result = runner.invoke(swhmain, ["--log-config", log_config_path, "test",],) assert result.exit_code == 0 - assert result.output.strip() == '\n'.join([ - 'custom format:root:DEBUG:Root log debug', - 'custom format:root:INFO:Root log info', - 'custom format:dontshowdebug:INFO:Shown', - ]) + assert result.output.strip() == "\n".join( + [ + "custom format:root:DEBUG:Root log debug", + "custom format:root:INFO:Root log info", + "custom format:dontshowdebug:INFO:Shown", + ] + ) def test_log_config_log_level_interaction(caplog, log_config_path, swhmain): - @swhmain.command(name='test') + @swhmain.command(name="test") @click.pass_context def swhtest(ctx): - logging.debug('Root log debug') - logging.info('Root log info') - logging.getLogger('dontshowdebug').debug('Not shown') - logging.getLogger('dontshowdebug').info('Shown') + logging.debug("Root log debug") + logging.info("Root log info") + logging.getLogger("dontshowdebug").debug("Not shown") + logging.getLogger("dontshowdebug").info("Shown") runner = CliRunner() result = runner.invoke( - swhmain, [ - '--log-config', log_config_path, - '--log-level', 'INFO', - 'test', - ], + swhmain, ["--log-config", log_config_path, "--log-level", "INFO", "test",], ) assert result.exit_code == 0 - assert result.output.strip() == '\n'.join([ - 'custom format:root:INFO:Root log info', - 'custom format:dontshowdebug:INFO:Shown', - ]) + assert result.output.strip() == "\n".join( + [ + "custom format:root:INFO:Root log info", + "custom format:dontshowdebug:INFO:Shown", + ] + ) def test_aliased_command(swhmain): - @swhmain.command(name='canonical-test') + @swhmain.command(name="canonical-test") @click.pass_context def swhtest(ctx): - 'A test command.' - click.echo('Hello SWH!') - swhmain.add_alias(swhtest, 'othername') + "A test command." + click.echo("Hello SWH!") + + swhmain.add_alias(swhtest, "othername") runner = CliRunner() # check we have only 'canonical-test' listed in the usage help msg - result = runner.invoke(swhmain, ['-h']) + result = runner.invoke(swhmain, ["-h"]) assert result.exit_code == 0 - assert 'canonical-test A test command.' in result.output - assert 'othername' not in result.output + assert "canonical-test A test command." in result.output + assert "othername" not in result.output # check we can execute the cmd with 'canonical-test' - result = runner.invoke(swhmain, ['canonical-test']) + result = runner.invoke(swhmain, ["canonical-test"]) assert result.exit_code == 0 - assert result.output.strip() == '''Hello SWH!''' + assert result.output.strip() == """Hello SWH!""" # check we can also execute the cmd with the alias 'othername' - result = runner.invoke(swhmain, ['othername']) + result = runner.invoke(swhmain, ["othername"]) assert result.exit_code == 0 - assert result.output.strip() == '''Hello SWH!''' + assert result.output.strip() == """Hello SWH!""" diff --git a/swh/core/tests/test_config.py b/swh/core/tests/test_config.py index 8e5bbf8..973b98d 100644 --- a/swh/core/tests/test_config.py +++ b/swh/core/tests/test_config.py @@ -1,312 +1,314 @@ # Copyright (C) 2015 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import os import shutil import pytest import pkg_resources.extern.packaging.version from swh.core import config pytest_v = pkg_resources.get_distribution("pytest").parsed_version -if pytest_v < pkg_resources.extern.packaging.version.parse('3.9'): +if pytest_v < pkg_resources.extern.packaging.version.parse("3.9"): + @pytest.fixture def tmp_path(request): import tempfile import pathlib + with tempfile.TemporaryDirectory() as tmpdir: yield pathlib.Path(tmpdir) default_conf = { - 'a': ('int', 2), - 'b': ('string', 'default-string'), - 'c': ('bool', True), - 'd': ('int', 10), - 'e': ('int', None), - 'f': ('bool', None), - 'g': ('string', None), - 'h': ('bool', True), - 'i': ('bool', True), - 'ls': ('list[str]', ['a', 'b', 'c']), - 'li': ('list[int]', [42, 43]), + "a": ("int", 2), + "b": ("string", "default-string"), + "c": ("bool", True), + "d": ("int", 10), + "e": ("int", None), + "f": ("bool", None), + "g": ("string", None), + "h": ("bool", True), + "i": ("bool", True), + "ls": ("list[str]", ["a", "b", "c"]), + "li": ("list[int]", [42, 43]), } other_default_conf = { - 'a': ('int', 3), + "a": ("int", 3), } full_default_conf = default_conf.copy() -full_default_conf['a'] = other_default_conf['a'] +full_default_conf["a"] = other_default_conf["a"] -parsed_default_conf = { - key: value - for key, (type, value) - in default_conf.items() -} +parsed_default_conf = {key: value for key, (type, value) in default_conf.items()} parsed_conffile = { - 'a': 1, - 'b': 'this is a string', - 'c': True, - 'd': 10, - 'e': None, - 'f': None, - 'g': None, - 'h': False, - 'i': True, - 'ls': ['list', 'of', 'strings'], - 'li': [1, 2, 3, 4], + "a": 1, + "b": "this is a string", + "c": True, + "d": 10, + "e": None, + "f": None, + "g": None, + "h": False, + "i": True, + "ls": ["list", "of", "strings"], + "li": [1, 2, 3, 4], } @pytest.fixture def swh_config(tmp_path): # create a temporary folder - conffile = tmp_path / 'config.ini' + conffile = tmp_path / "config.ini" conf_contents = """[main] a = 1 b = this is a string c = true h = false ls = list, of, strings li = 1, 2, 3, 4 """ - conffile.open('w').write(conf_contents) + conffile.open("w").write(conf_contents) return conffile @pytest.fixture def swh_config_unreadable(swh_config): # Create an unreadable, proper configuration file os.chmod(str(swh_config), 0o000) yield swh_config # Make the broken perms file readable again to be able to remove them os.chmod(str(swh_config), 0o644) @pytest.fixture def swh_config_unreadable_dir(swh_config): # Create a proper configuration file in an unreadable directory - perms_broken_dir = swh_config.parent / 'unreadabledir' + perms_broken_dir = swh_config.parent / "unreadabledir" perms_broken_dir.mkdir() shutil.move(str(swh_config), str(perms_broken_dir)) os.chmod(str(perms_broken_dir), 0o000) yield perms_broken_dir / swh_config.name # Make the broken perms items readable again to be able to remove them os.chmod(str(perms_broken_dir), 0o755) @pytest.fixture def swh_config_empty(tmp_path): # create a temporary folder - conffile = tmp_path / 'config.ini' + conffile = tmp_path / "config.ini" conffile.touch() return conffile def test_read(swh_config): # when res = config.read(str(swh_config), default_conf) # then assert res == parsed_conffile def test_read_empty_file(): # when res = config.read(None, default_conf) # then assert res == parsed_default_conf def test_support_non_existing_conffile(tmp_path): # when - res = config.read(str(tmp_path / 'void.ini'), default_conf) + res = config.read(str(tmp_path / "void.ini"), default_conf) # then assert res == parsed_default_conf def test_support_empty_conffile(swh_config_empty): # when res = config.read(str(swh_config_empty), default_conf) # then assert res == parsed_default_conf def test_raise_on_broken_directory_perms(swh_config_unreadable_dir): with pytest.raises(PermissionError): config.read(str(swh_config_unreadable_dir), default_conf) def test_raise_on_broken_file_perms(swh_config_unreadable): with pytest.raises(PermissionError): config.read(str(swh_config_unreadable), default_conf) def test_merge_default_configs(): # when res = config.merge_default_configs(default_conf, other_default_conf) # then assert res == full_default_conf def test_priority_read_nonexist_conf(swh_config): - noexist = str(swh_config.parent / 'void.ini') + noexist = str(swh_config.parent / "void.ini") # when res = config.priority_read([noexist, str(swh_config)], default_conf) # then assert res == parsed_conffile def test_priority_read_conf_nonexist_empty(swh_config): - noexist = swh_config.parent / 'void.ini' - empty = swh_config.parent / 'empty.ini' + noexist = swh_config.parent / "void.ini" + empty = swh_config.parent / "empty.ini" empty.touch() # when - res = config.priority_read([str(p) for p in ( - swh_config, noexist, empty)], default_conf) + res = config.priority_read( + [str(p) for p in (swh_config, noexist, empty)], default_conf + ) # then assert res == parsed_conffile def test_priority_read_empty_conf_nonexist(swh_config): - noexist = swh_config.parent / 'void.ini' - empty = swh_config.parent / 'empty.ini' + noexist = swh_config.parent / "void.ini" + empty = swh_config.parent / "empty.ini" empty.touch() # when - res = config.priority_read([str(p) for p in ( - empty, swh_config, noexist)], default_conf) + res = config.priority_read( + [str(p) for p in (empty, swh_config, noexist)], default_conf + ) # then assert res == parsed_default_conf def test_swh_config_paths(): - res = config.swh_config_paths('foo/bar.ini') + res = config.swh_config_paths("foo/bar.ini") assert res == [ - '~/.config/swh/foo/bar.ini', - '~/.swh/foo/bar.ini', - '/etc/softwareheritage/foo/bar.ini', + "~/.config/swh/foo/bar.ini", + "~/.swh/foo/bar.ini", + "/etc/softwareheritage/foo/bar.ini", ] def test_prepare_folder(tmp_path): # given - conf = {'path1': str(tmp_path / 'path1'), - 'path2': str(tmp_path / 'path2' / 'depth1')} + conf = { + "path1": str(tmp_path / "path1"), + "path2": str(tmp_path / "path2" / "depth1"), + } # the folders does not exists - assert not os.path.exists(conf['path1']), "path1 should not exist." - assert not os.path.exists(conf['path2']), "path2 should not exist." + assert not os.path.exists(conf["path1"]), "path1 should not exist." + assert not os.path.exists(conf["path2"]), "path2 should not exist." # when - config.prepare_folders(conf, 'path1') + config.prepare_folders(conf, "path1") # path1 exists but not path2 - assert os.path.exists(conf['path1']), "path1 should now exist!" - assert not os.path.exists(conf['path2']), "path2 should not exist." + assert os.path.exists(conf["path1"]), "path1 should now exist!" + assert not os.path.exists(conf["path2"]), "path2 should not exist." # path1 already exists, skips it but creates path2 - config.prepare_folders(conf, 'path1', 'path2') + config.prepare_folders(conf, "path1", "path2") - assert os.path.exists(conf['path1']), "path1 should still exist!" - assert os.path.exists(conf['path2']), "path2 should now exist." + assert os.path.exists(conf["path1"]), "path1 should still exist!" + assert os.path.exists(conf["path2"]), "path2 should now exist." def test_merge_config(): cfg_a = { - 'a': 42, - 'b': [1, 2, 3], - 'c': None, - 'd': {'gheez': 27}, - 'e': { - 'ea': 'Mr. Bungle', - 'eb': None, - 'ec': [11, 12, 13], - 'ed': {'eda': 'Secret Chief 3', - 'edb': 'Faith No More'}, - 'ee': 451, + "a": 42, + "b": [1, 2, 3], + "c": None, + "d": {"gheez": 27}, + "e": { + "ea": "Mr. Bungle", + "eb": None, + "ec": [11, 12, 13], + "ed": {"eda": "Secret Chief 3", "edb": "Faith No More"}, + "ee": 451, }, - 'f': 'Janis', + "f": "Janis", } cfg_b = { - 'a': 43, - 'b': [41, 42, 43], - 'c': 'Tom Waits', - 'd': None, - 'e': { - 'ea': 'Igorrr', - 'ec': [51, 52], - 'ed': {'edb': 'Sleepytime Gorilla Museum', - 'edc': 'Nils Peter Molvaer'}, + "a": 43, + "b": [41, 42, 43], + "c": "Tom Waits", + "d": None, + "e": { + "ea": "Igorrr", + "ec": [51, 52], + "ed": {"edb": "Sleepytime Gorilla Museum", "edc": "Nils Peter Molvaer"}, }, - 'g': 'Hüsker Dü', + "g": "Hüsker Dü", } # merge A, B cfg_m = config.merge_configs(cfg_a, cfg_b) assert cfg_m == { - 'a': 43, # b takes precedence - 'b': [41, 42, 43], # b takes precedence - 'c': 'Tom Waits', # b takes precedence - 'd': None, # b['d'] takes precedence (explicit None) - 'e': { - 'ea': 'Igorrr', # a takes precedence - 'eb': None, # only in a - 'ec': [51, 52], # b takes precedence - 'ed': { - 'eda': 'Secret Chief 3', # only in a - 'edb': 'Sleepytime Gorilla Museum', # b takes precedence - 'edc': 'Nils Peter Molvaer'}, # only defined in b - 'ee': 451, + "a": 43, # b takes precedence + "b": [41, 42, 43], # b takes precedence + "c": "Tom Waits", # b takes precedence + "d": None, # b['d'] takes precedence (explicit None) + "e": { + "ea": "Igorrr", # a takes precedence + "eb": None, # only in a + "ec": [51, 52], # b takes precedence + "ed": { + "eda": "Secret Chief 3", # only in a + "edb": "Sleepytime Gorilla Museum", # b takes precedence + "edc": "Nils Peter Molvaer", + }, # only defined in b + "ee": 451, }, - 'f': 'Janis', # only defined in a - 'g': 'Hüsker Dü', # only defined in b + "f": "Janis", # only defined in a + "g": "Hüsker Dü", # only defined in b } # merge B, A cfg_m = config.merge_configs(cfg_b, cfg_a) assert cfg_m == { - 'a': 42, # a takes precedence - 'b': [1, 2, 3], # a takes precedence - 'c': None, # a takes precedence - 'd': {'gheez': 27}, # a takes precedence - 'e': { - 'ea': 'Mr. Bungle', # a takes precedence - 'eb': None, # only defined in a - 'ec': [11, 12, 13], # a takes precedence - 'ed': { - 'eda': 'Secret Chief 3', # only in a - 'edb': 'Faith No More', # a takes precedence - 'edc': 'Nils Peter Molvaer'}, # only in b - 'ee': 451, + "a": 42, # a takes precedence + "b": [1, 2, 3], # a takes precedence + "c": None, # a takes precedence + "d": {"gheez": 27}, # a takes precedence + "e": { + "ea": "Mr. Bungle", # a takes precedence + "eb": None, # only defined in a + "ec": [11, 12, 13], # a takes precedence + "ed": { + "eda": "Secret Chief 3", # only in a + "edb": "Faith No More", # a takes precedence + "edc": "Nils Peter Molvaer", + }, # only in b + "ee": 451, }, - 'f': 'Janis', # only in a - 'g': 'Hüsker Dü', # only in b + "f": "Janis", # only in a + "g": "Hüsker Dü", # only in b } def test_merge_config_type_error(): - for v in (1, 'str', None): + for v in (1, "str", None): with pytest.raises(TypeError): config.merge_configs(v, {}) with pytest.raises(TypeError): config.merge_configs({}, v) - for v in (1, 'str'): + for v in (1, "str"): with pytest.raises(TypeError): - config.merge_configs({'a': v}, {'a': {}}) + config.merge_configs({"a": v}, {"a": {}}) with pytest.raises(TypeError): - config.merge_configs({'a': {}}, {'a': v}) + config.merge_configs({"a": {}}, {"a": v}) diff --git a/swh/core/tests/test_logger.py b/swh/core/tests/test_logger.py index 09afaaf..6980e6e 100644 --- a/swh/core/tests/test_logger.py +++ b/swh/core/tests/test_logger.py @@ -1,133 +1,130 @@ # Copyright (C) 2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from datetime import datetime import logging import pytz import inspect from unittest.mock import patch from swh.core import logger def lineno(): """Returns the current line number in our program.""" return inspect.currentframe().f_back.f_lineno def test_db_level(): - assert logger.db_level_of_py_level(10) == 'debug' - assert logger.db_level_of_py_level(20) == 'info' - assert logger.db_level_of_py_level(30) == 'warning' - assert logger.db_level_of_py_level(40) == 'error' - assert logger.db_level_of_py_level(50) == 'critical' + assert logger.db_level_of_py_level(10) == "debug" + assert logger.db_level_of_py_level(20) == "info" + assert logger.db_level_of_py_level(30) == "warning" + assert logger.db_level_of_py_level(40) == "error" + assert logger.db_level_of_py_level(50) == "critical" def test_flatten_scalar(): - assert list(logger.flatten('')) == [('', '')] - assert list(logger.flatten('toto')) == [('', 'toto')] + assert list(logger.flatten("")) == [("", "")] + assert list(logger.flatten("toto")) == [("", "toto")] - assert list(logger.flatten(10)) == [('', 10)] - assert list(logger.flatten(10.5)) == [('', 10.5)] + assert list(logger.flatten(10)) == [("", 10)] + assert list(logger.flatten(10.5)) == [("", 10.5)] def test_flatten_list(): assert list(logger.flatten([])) == [] - assert list(logger.flatten([1])) == [('0', 1)] + assert list(logger.flatten([1])) == [("0", 1)] - assert list(logger.flatten([1, 2, ['a', 'b']])) == [ - ('0', 1), - ('1', 2), - ('2_0', 'a'), - ('2_1', 'b'), + assert list(logger.flatten([1, 2, ["a", "b"]])) == [ + ("0", 1), + ("1", 2), + ("2_0", "a"), + ("2_1", "b"), ] - assert list(logger.flatten([1, 2, ['a', ('x', 1)]])) == [ - ('0', 1), - ('1', 2), - ('2_0', 'a'), - ('2_1_0', 'x'), - ('2_1_1', 1), + assert list(logger.flatten([1, 2, ["a", ("x", 1)]])) == [ + ("0", 1), + ("1", 2), + ("2_0", "a"), + ("2_1_0", "x"), + ("2_1_1", 1), ] def test_flatten_dict(): assert list(logger.flatten({})) == [] - assert list(logger.flatten({'a': 1})) == [('a', 1)] - - assert sorted(logger.flatten({'a': 1, - 'b': (2, 3,), - 'c': {'d': 4, 'e': 'f'}})) == [ - ('a', 1), - ('b_0', 2), - ('b_1', 3), - ('c_d', 4), - ('c_e', 'f'), + assert list(logger.flatten({"a": 1})) == [("a", 1)] + + assert sorted(logger.flatten({"a": 1, "b": (2, 3,), "c": {"d": 4, "e": "f"}})) == [ + ("a", 1), + ("b_0", 2), + ("b_1", 3), + ("c_d", 4), + ("c_e", "f"), ] def test_flatten_dict_binary_keys(): d = {b"a": "a"} str_d = str(d) assert list(logger.flatten(d)) == [("", str_d)] assert list(logger.flatten({"a": d})) == [("a", str_d)] - assert list(logger.flatten({"a": [d, d]})) == [ - ("a_0", str_d), ("a_1", str_d) - ] + assert list(logger.flatten({"a": [d, d]})) == [("a_0", str_d), ("a_1", str_d)] def test_stringify(): - assert logger.stringify(None) == 'None' - assert logger.stringify(123) == '123' - assert logger.stringify('abc') == 'abc' + assert logger.stringify(None) == "None" + assert logger.stringify(123) == "123" + assert logger.stringify("abc") == "abc" date = datetime(2019, 9, 1, 16, 32) - assert logger.stringify(date) == '2019-09-01T16:32:00' + assert logger.stringify(date) == "2019-09-01T16:32:00" tzdate = datetime(2019, 9, 1, 16, 32, tzinfo=pytz.utc) - assert logger.stringify(tzdate) == '2019-09-01T16:32:00+00:00' + assert logger.stringify(tzdate) == "2019-09-01T16:32:00+00:00" -@patch('swh.core.logger.send') +@patch("swh.core.logger.send") def test_journal_handler(send): - log = logging.getLogger('test_logger') + log = logging.getLogger("test_logger") log.addHandler(logger.JournalHandler()) log.setLevel(logging.DEBUG) - _, ln = log.info('hello world'), lineno() + _, ln = log.info("hello world"), lineno() send.assert_called_with( - 'hello world', + "hello world", CODE_FILE=__file__, - CODE_FUNC='test_journal_handler', + CODE_FUNC="test_journal_handler", CODE_LINE=ln, - LOGGER='test_logger', - PRIORITY='6', - THREAD_NAME='MainThread') + LOGGER="test_logger", + PRIORITY="6", + THREAD_NAME="MainThread", + ) -@patch('swh.core.logger.send') +@patch("swh.core.logger.send") def test_journal_handler_w_data(send): - log = logging.getLogger('test_logger') + log = logging.getLogger("test_logger") log.addHandler(logger.JournalHandler()) log.setLevel(logging.DEBUG) _, ln = ( - log.debug('something cool %s', ['with', {'extra': 'data'}]), + log.debug("something cool %s", ["with", {"extra": "data"}]), lineno() - 1, ) send.assert_called_with( "something cool ['with', {'extra': 'data'}]", CODE_FILE=__file__, - CODE_FUNC='test_journal_handler_w_data', + CODE_FUNC="test_journal_handler_w_data", CODE_LINE=ln, - LOGGER='test_logger', - PRIORITY='7', - THREAD_NAME='MainThread', - SWH_LOGGING_ARGS_0_0='with', - SWH_LOGGING_ARGS_0_1_EXTRA='data' + LOGGER="test_logger", + PRIORITY="7", + THREAD_NAME="MainThread", + SWH_LOGGING_ARGS_0_0="with", + SWH_LOGGING_ARGS_0_1_EXTRA="data", ) diff --git a/swh/core/tests/test_pytest_plugin.py b/swh/core/tests/test_pytest_plugin.py index 547d2b0..7c72df3 100644 --- a/swh/core/tests/test_pytest_plugin.py +++ b/swh/core/tests/test_pytest_plugin.py @@ -1,120 +1,117 @@ # Copyright (C) 2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import requests from os import path from urllib.parse import unquote from swh.core.pytest_plugin import requests_mock_datadir_factory def test_get_response_cb_with_encoded_url(requests_mock_datadir): # The following urls (quoted, unquoted) will be resolved as the same file for encoded_url, expected_response in [ - ('https://forge.s.o/api/diffusion?attachments%5Buris%5D=1', - "something"), - ('https://www.reference.com/web?q=What+Is+an+Example+of+a+URL?&qo=contentPageRelatedSearch&o=600605&l=dir&sga=1', # noqa - "something else"), + ("https://forge.s.o/api/diffusion?attachments%5Buris%5D=1", "something"), + ( + "https://www.reference.com/web?q=What+Is+an+Example+of+a+URL?&qo=contentPageRelatedSearch&o=600605&l=dir&sga=1", # noqa + "something else", + ), ]: for url in [encoded_url, unquote(encoded_url)]: response = requests.get(url) assert response.ok assert response.json() == expected_response def test_get_response_cb_with_visits_nominal(requests_mock_datadir_visits): - response = requests.get('https://example.com/file.json') + response = requests.get("https://example.com/file.json") assert response.ok - assert response.json() == {'hello': 'you'} + assert response.json() == {"hello": "you"} - response = requests.get('http://example.com/something.json') + response = requests.get("http://example.com/something.json") assert response.ok assert response.json() == "something" - response = requests.get('https://example.com/file.json') + response = requests.get("https://example.com/file.json") assert response.ok - assert response.json() == {'hello': 'world'} + assert response.json() == {"hello": "world"} - response = requests.get('https://example.com/file.json') + response = requests.get("https://example.com/file.json") assert not response.ok assert response.status_code == 404 def test_get_response_cb_with_visits(requests_mock_datadir_visits): - response = requests.get('https://example.com/file.json') + response = requests.get("https://example.com/file.json") assert response.ok - assert response.json() == {'hello': 'you'} + assert response.json() == {"hello": "you"} - response = requests.get('https://example.com/other.json') + response = requests.get("https://example.com/other.json") assert response.ok assert response.json() == "foobar" - response = requests.get('https://example.com/file.json') + response = requests.get("https://example.com/file.json") assert response.ok - assert response.json() == {'hello': 'world'} + assert response.json() == {"hello": "world"} - response = requests.get('https://example.com/other.json') + response = requests.get("https://example.com/other.json") assert not response.ok assert response.status_code == 404 - response = requests.get('https://example.com/file.json') + response = requests.get("https://example.com/file.json") assert not response.ok assert response.status_code == 404 def test_get_response_cb_no_visit(requests_mock_datadir): - response = requests.get('https://example.com/file.json') + response = requests.get("https://example.com/file.json") assert response.ok - assert response.json() == {'hello': 'you'} + assert response.json() == {"hello": "you"} - response = requests.get('https://example.com/file.json') + response = requests.get("https://example.com/file.json") assert response.ok - assert response.json() == {'hello': 'you'} + assert response.json() == {"hello": "you"} def test_get_response_cb_query_params(requests_mock_datadir): - response = requests.get('https://example.com/file.json?toto=42') + response = requests.get("https://example.com/file.json?toto=42") assert not response.ok assert response.status_code == 404 - response = requests.get( - 'https://example.com/file.json?name=doe&firstname=jane') + response = requests.get("https://example.com/file.json?name=doe&firstname=jane") assert response.ok - assert response.json() == {'hello': 'jane doe'} + assert response.json() == {"hello": "jane doe"} requests_mock_datadir_ignore = requests_mock_datadir_factory( - ignore_urls=['https://example.com/file.json'], - has_multi_visit=False, + ignore_urls=["https://example.com/file.json"], has_multi_visit=False, ) def test_get_response_cb_ignore_url(requests_mock_datadir_ignore): - response = requests.get('https://example.com/file.json') + response = requests.get("https://example.com/file.json") assert not response.ok assert response.status_code == 404 requests_mock_datadir_ignore_and_visit = requests_mock_datadir_factory( - ignore_urls=['https://example.com/file.json'], - has_multi_visit=True, + ignore_urls=["https://example.com/file.json"], has_multi_visit=True, ) -def test_get_response_cb_ignore_url_with_visit( - requests_mock_datadir_ignore_and_visit): - response = requests.get('https://example.com/file.json') +def test_get_response_cb_ignore_url_with_visit(requests_mock_datadir_ignore_and_visit): + response = requests.get("https://example.com/file.json") assert not response.ok assert response.status_code == 404 - response = requests.get('https://example.com/file.json') + response = requests.get("https://example.com/file.json") assert not response.ok assert response.status_code == 404 def test_data_dir(datadir): - expected_datadir = path.join(path.abspath(path.dirname(__file__)), 'data') + expected_datadir = path.join(path.abspath(path.dirname(__file__)), "data") assert datadir == expected_datadir diff --git a/swh/core/tests/test_statsd.py b/swh/core/tests/test_statsd.py index 56d1aa9..c0fa1ff 100644 --- a/swh/core/tests/test_statsd.py +++ b/swh/core/tests/test_statsd.py @@ -1,563 +1,560 @@ # Copyright (C) 2018-2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information # Initially imported from https://github.com/DataDog/datadogpy/ # at revision 62b3a3e89988dc18d78c282fe3ff5d1813917436 # # Copyright (c) 2015, Datadog # All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are met: # * Redistributions of source code must retain the above copyright # notice, this list of conditions and the following disclaimer. # * Redistributions in binary form must reproduce the above copyright # notice, this list of conditions and the following disclaimer in the # documentation and/or other materials provided with the distribution. # * Neither the name of Datadog nor the names of its contributors may be # used to endorse or promote products derived from this software without # specific prior written permission. # # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE # ARE DISCLAIMED. IN NO EVENT SHALL BE LIABLE FOR ANY # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # from collections import deque from contextlib import contextmanager import os import socket import time import unittest import pytest from swh.core.statsd import Statsd, TimedContextManagerDecorator @contextmanager def preserve_envvars(*envvars): """Context manager preserving the value of environment variables""" preserved = {} to_delete = object() for var in envvars: preserved[var] = os.environ.get(var, to_delete) yield for var in envvars: old = preserved[var] if old is not to_delete: os.environ[var] = old else: del os.environ[var] class FakeSocket(object): """ A fake socket for testing. """ def __init__(self): self.payloads = deque() def send(self, payload): assert type(payload) == bytes self.payloads.append(payload) def recv(self): try: - return self.payloads.popleft().decode('utf-8') + return self.payloads.popleft().decode("utf-8") except IndexError: return None def close(self): pass def __repr__(self): return str(self.payloads) class BrokenSocket(FakeSocket): def send(self, payload): raise socket.error("Socket error") class SlowSocket(FakeSocket): def send(self, payload): raise socket.timeout("Socket timeout") class TestStatsd(unittest.TestCase): - def setUp(self): """ Set up a default Statsd instance and mock the socket. """ # self.statsd = Statsd() self.statsd._socket = FakeSocket() def recv(self): return self.statsd.socket.recv() def test_set(self): - self.statsd.set('set', 123) - assert self.recv() == 'set:123|s' + self.statsd.set("set", 123) + assert self.recv() == "set:123|s" def test_gauge(self): - self.statsd.gauge('gauge', 123.4) - assert self.recv() == 'gauge:123.4|g' + self.statsd.gauge("gauge", 123.4) + assert self.recv() == "gauge:123.4|g" def test_counter(self): - self.statsd.increment('page.views') - self.assertEqual('page.views:1|c', self.recv()) + self.statsd.increment("page.views") + self.assertEqual("page.views:1|c", self.recv()) - self.statsd.increment('page.views', 11) - self.assertEqual('page.views:11|c', self.recv()) + self.statsd.increment("page.views", 11) + self.assertEqual("page.views:11|c", self.recv()) - self.statsd.decrement('page.views') - self.assertEqual('page.views:-1|c', self.recv()) + self.statsd.decrement("page.views") + self.assertEqual("page.views:-1|c", self.recv()) - self.statsd.decrement('page.views', 12) - self.assertEqual('page.views:-12|c', self.recv()) + self.statsd.decrement("page.views", 12) + self.assertEqual("page.views:-12|c", self.recv()) def test_histogram(self): - self.statsd.histogram('histo', 123.4) - self.assertEqual('histo:123.4|h', self.recv()) + self.statsd.histogram("histo", 123.4) + self.assertEqual("histo:123.4|h", self.recv()) def test_tagged_gauge(self): - self.statsd.gauge('gt', 123.4, tags={'country': 'china', 'age': 45}) - self.assertEqual('gt:123.4|g|#age:45,country:china', self.recv()) + self.statsd.gauge("gt", 123.4, tags={"country": "china", "age": 45}) + self.assertEqual("gt:123.4|g|#age:45,country:china", self.recv()) def test_tagged_counter(self): - self.statsd.increment('ct', tags={'country': 'españa'}) - self.assertEqual('ct:1|c|#country:españa', self.recv()) + self.statsd.increment("ct", tags={"country": "españa"}) + self.assertEqual("ct:1|c|#country:españa", self.recv()) def test_tagged_histogram(self): - self.statsd.histogram('h', 1, tags={'test_tag': 'tag_value'}) - self.assertEqual('h:1|h|#test_tag:tag_value', self.recv()) + self.statsd.histogram("h", 1, tags={"test_tag": "tag_value"}) + self.assertEqual("h:1|h|#test_tag:tag_value", self.recv()) def test_sample_rate(self): - self.statsd.increment('c', sample_rate=0) + self.statsd.increment("c", sample_rate=0) assert not self.recv() for i in range(10000): - self.statsd.increment('sampled_counter', sample_rate=0.3) + self.statsd.increment("sampled_counter", sample_rate=0.3) self.assert_almost_equal(3000, len(self.statsd.socket.payloads), 150) - self.assertEqual('sampled_counter:1|c|@0.3', self.recv()) + self.assertEqual("sampled_counter:1|c|@0.3", self.recv()) def test_tags_and_samples(self): for i in range(100): - self.statsd.gauge('gst', 23, tags={"sampled": True}, - sample_rate=0.9) + self.statsd.gauge("gst", 23, tags={"sampled": True}, sample_rate=0.9) self.assert_almost_equal(90, len(self.statsd.socket.payloads), 10) - self.assertEqual('gst:23|g|@0.9|#sampled:True', self.recv()) + self.assertEqual("gst:23|g|@0.9|#sampled:True", self.recv()) def test_timing(self): - self.statsd.timing('t', 123) - self.assertEqual('t:123|ms', self.recv()) + self.statsd.timing("t", 123) + self.assertEqual("t:123|ms", self.recv()) def test_metric_namespace(self): """ Namespace prefixes all metric names. """ self.statsd.namespace = "foo" - self.statsd.gauge('gauge', 123.4) - self.assertEqual('foo.gauge:123.4|g', self.recv()) + self.statsd.gauge("gauge", 123.4) + self.assertEqual("foo.gauge:123.4|g", self.recv()) # Test Client level constant tags def test_gauge_constant_tags(self): self.statsd.constant_tags = { - 'bar': 'baz', + "bar": "baz", } - self.statsd.gauge('gauge', 123.4) - assert self.recv() == 'gauge:123.4|g|#bar:baz' + self.statsd.gauge("gauge", 123.4) + assert self.recv() == "gauge:123.4|g|#bar:baz" def test_counter_constant_tag_with_metric_level_tags(self): self.statsd.constant_tags = { - 'bar': 'baz', - 'foo': True, + "bar": "baz", + "foo": True, } - self.statsd.increment('page.views', tags={'extra': 'extra'}) + self.statsd.increment("page.views", tags={"extra": "extra"}) self.assertEqual( - 'page.views:1|c|#bar:baz,extra:extra,foo:True', - self.recv(), + "page.views:1|c|#bar:baz,extra:extra,foo:True", self.recv(), ) def test_gauge_constant_tags_with_metric_level_tags_twice(self): - metric_level_tag = {'foo': 'bar'} - self.statsd.constant_tags = {'bar': 'baz'} - self.statsd.gauge('gauge', 123.4, tags=metric_level_tag) - assert self.recv() == 'gauge:123.4|g|#bar:baz,foo:bar' + metric_level_tag = {"foo": "bar"} + self.statsd.constant_tags = {"bar": "baz"} + self.statsd.gauge("gauge", 123.4, tags=metric_level_tag) + assert self.recv() == "gauge:123.4|g|#bar:baz,foo:bar" # sending metrics multiple times with same metric-level tags # should not duplicate the tags being sent - self.statsd.gauge('gauge', 123.4, tags=metric_level_tag) - assert self.recv() == 'gauge:123.4|g|#bar:baz,foo:bar' + self.statsd.gauge("gauge", 123.4, tags=metric_level_tag) + assert self.recv() == "gauge:123.4|g|#bar:baz,foo:bar" def assert_almost_equal(self, a, b, delta): self.assertTrue( - 0 <= abs(a - b) <= delta, - "%s - %s not within %s" % (a, b, delta) + 0 <= abs(a - b) <= delta, "%s - %s not within %s" % (a, b, delta) ) def test_socket_error(self): self.statsd._socket = BrokenSocket() - self.statsd.gauge('no error', 1) - assert True, 'success' + self.statsd.gauge("no error", 1) + assert True, "success" def test_socket_timeout(self): self.statsd._socket = SlowSocket() - self.statsd.gauge('no error', 1) - assert True, 'success' + self.statsd.gauge("no error", 1) + assert True, "success" def test_timed(self): """ Measure the distribution of a function's run time. """ - @self.statsd.timed('timed.test') + + @self.statsd.timed("timed.test") def func(a, b, c=1, d=1): """docstring""" time.sleep(0.5) return (a, b, c, d) - self.assertEqual('func', func.__name__) - self.assertEqual('docstring', func.__doc__) + self.assertEqual("func", func.__name__) + self.assertEqual("docstring", func.__doc__) result = func(1, 2, d=3) # Assert it handles args and kwargs correctly. self.assertEqual(result, (1, 2, 1, 3)) packet = self.recv() - name_value, type_ = packet.split('|') - name, value = name_value.split(':') + name_value, type_ = packet.split("|") + name, value = name_value.split(":") - self.assertEqual('ms', type_) - self.assertEqual('timed.test', name) + self.assertEqual("ms", type_) + self.assertEqual("timed.test", name) self.assert_almost_equal(500, float(value), 100) def test_timed_exception(self): """ Exception bubble out of the decorator and is reported to statsd as a dedicated counter. """ - @self.statsd.timed('timed.test') + + @self.statsd.timed("timed.test") def func(a, b, c=1, d=1): """docstring""" time.sleep(0.5) return (a / b, c, d) - self.assertEqual('func', func.__name__) - self.assertEqual('docstring', func.__doc__) + self.assertEqual("func", func.__name__) + self.assertEqual("docstring", func.__doc__) with self.assertRaises(ZeroDivisionError): func(1, 0) packet = self.recv() - name_value, type_ = packet.split('|') - name, value = name_value.split(':') + name_value, type_ = packet.split("|") + name, value = name_value.split(":") - self.assertEqual('c', type_) - self.assertEqual('timed.test_error_count', name) + self.assertEqual("c", type_) + self.assertEqual("timed.test_error_count", name) self.assertEqual(int(value), 1) - def test_timed_no_metric(self, ): + def test_timed_no_metric(self,): """ Test using a decorator without providing a metric. """ @self.statsd.timed() def func(a, b, c=1, d=1): """docstring""" time.sleep(0.5) return (a, b, c, d) - self.assertEqual('func', func.__name__) - self.assertEqual('docstring', func.__doc__) + self.assertEqual("func", func.__name__) + self.assertEqual("docstring", func.__doc__) result = func(1, 2, d=3) # Assert it handles args and kwargs correctly. self.assertEqual(result, (1, 2, 1, 3)) packet = self.recv() - name_value, type_ = packet.split('|') - name, value = name_value.split(':') + name_value, type_ = packet.split("|") + name, value = name_value.split(":") - self.assertEqual('ms', type_) - self.assertEqual('swh.core.tests.test_statsd.func', name) + self.assertEqual("ms", type_) + self.assertEqual("swh.core.tests.test_statsd.func", name) self.assert_almost_equal(500, float(value), 100) def test_timed_coroutine(self): """ Measure the distribution of a coroutine function's run time. Warning: Python >= 3.5 only. """ import asyncio - @self.statsd.timed('timed.test') + @self.statsd.timed("timed.test") @asyncio.coroutine def print_foo(): """docstring""" time.sleep(0.5) print("foo") loop = asyncio.new_event_loop() loop.run_until_complete(print_foo()) loop.close() # Assert packet = self.recv() - name_value, type_ = packet.split('|') - name, value = name_value.split(':') + name_value, type_ = packet.split("|") + name, value = name_value.split(":") - self.assertEqual('ms', type_) - self.assertEqual('timed.test', name) + self.assertEqual("ms", type_) + self.assertEqual("timed.test", name) self.assert_almost_equal(500, float(value), 100) def test_timed_context(self): """ Measure the distribution of a context's run time. """ # In milliseconds - with self.statsd.timed('timed_context.test') as timer: + with self.statsd.timed("timed_context.test") as timer: self.assertIsInstance(timer, TimedContextManagerDecorator) time.sleep(0.5) packet = self.recv() - name_value, type_ = packet.split('|') - name, value = name_value.split(':') + name_value, type_ = packet.split("|") + name, value = name_value.split(":") - self.assertEqual('ms', type_) - self.assertEqual('timed_context.test', name) + self.assertEqual("ms", type_) + self.assertEqual("timed_context.test", name) self.assert_almost_equal(500, float(value), 100) self.assert_almost_equal(500, timer.elapsed, 100) def test_timed_context_exception(self): """ Exception bubbles out of the `timed` context manager and is reported to statsd as a dedicated counter. """ + class ContextException(Exception): pass def func(self): - with self.statsd.timed('timed_context.test'): + with self.statsd.timed("timed_context.test"): time.sleep(0.5) raise ContextException() # Ensure the exception was raised. self.assertRaises(ContextException, func, self) # Ensure the timing was recorded. packet = self.recv() - name_value, type_ = packet.split('|') - name, value = name_value.split(':') + name_value, type_ = packet.split("|") + name, value = name_value.split(":") - self.assertEqual('c', type_) - self.assertEqual('timed_context.test_error_count', name) + self.assertEqual("c", type_) + self.assertEqual("timed_context.test_error_count", name) self.assertEqual(int(value), 1) def test_timed_context_no_metric_name_exception(self): """Test that an exception occurs if using a context manager without a metric name. """ def func(self): with self.statsd.timed(): time.sleep(0.5) # Ensure the exception was raised. self.assertRaises(TypeError, func, self) # Ensure the timing was recorded. packet = self.recv() self.assertEqual(packet, None) def test_timed_start_stop_calls(self): - timer = self.statsd.timed('timed_context.test') + timer = self.statsd.timed("timed_context.test") timer.start() time.sleep(0.5) timer.stop() packet = self.recv() - name_value, type_ = packet.split('|') - name, value = name_value.split(':') + name_value, type_ = packet.split("|") + name, value = name_value.split(":") - self.assertEqual('ms', type_) - self.assertEqual('timed_context.test', name) + self.assertEqual("ms", type_) + self.assertEqual("timed_context.test", name) self.assert_almost_equal(500, float(value), 100) def test_batched(self): self.statsd.open_buffer() - self.statsd.gauge('page.views', 123) - self.statsd.timing('timer', 123) + self.statsd.gauge("page.views", 123) + self.statsd.timing("timer", 123) self.statsd.close_buffer() - self.assertEqual('page.views:123|g\ntimer:123|ms', self.recv()) + self.assertEqual("page.views:123|g\ntimer:123|ms", self.recv()) def test_context_manager(self): fake_socket = FakeSocket() with Statsd() as statsd: statsd._socket = fake_socket - statsd.gauge('page.views', 123) - statsd.timing('timer', 123) + statsd.gauge("page.views", 123) + statsd.timing("timer", 123) - self.assertEqual('page.views:123|g\ntimer:123|ms', fake_socket.recv()) + self.assertEqual("page.views:123|g\ntimer:123|ms", fake_socket.recv()) def test_batched_buffer_autoflush(self): fake_socket = FakeSocket() with Statsd() as statsd: statsd._socket = fake_socket for i in range(51): - statsd.increment('mycounter') + statsd.increment("mycounter") self.assertEqual( - '\n'.join(['mycounter:1|c' for i in range(50)]), - fake_socket.recv(), + "\n".join(["mycounter:1|c" for i in range(50)]), fake_socket.recv(), ) - self.assertEqual('mycounter:1|c', fake_socket.recv()) + self.assertEqual("mycounter:1|c", fake_socket.recv()) def test_module_level_instance(self): from swh.core.statsd import statsd + self.assertTrue(isinstance(statsd, Statsd)) def test_instantiating_does_not_connect(self): local_statsd = Statsd() self.assertEqual(None, local_statsd._socket) def test_accessing_socket_opens_socket(self): local_statsd = Statsd() try: self.assertIsNotNone(local_statsd.socket) finally: local_statsd.close_socket() def test_accessing_socket_multiple_times_returns_same_socket(self): local_statsd = Statsd() fresh_socket = FakeSocket() local_statsd._socket = fresh_socket self.assertEqual(fresh_socket, local_statsd.socket) self.assertNotEqual(FakeSocket(), local_statsd.socket) def test_tags_from_environment(self): - with preserve_envvars('STATSD_TAGS'): - os.environ['STATSD_TAGS'] = 'country:china,age:45' + with preserve_envvars("STATSD_TAGS"): + os.environ["STATSD_TAGS"] = "country:china,age:45" statsd = Statsd() statsd._socket = FakeSocket() - statsd.gauge('gt', 123.4) - self.assertEqual('gt:123.4|g|#age:45,country:china', - statsd.socket.recv()) + statsd.gauge("gt", 123.4) + self.assertEqual("gt:123.4|g|#age:45,country:china", statsd.socket.recv()) def test_tags_from_environment_and_constant(self): - with preserve_envvars('STATSD_TAGS'): - os.environ['STATSD_TAGS'] = 'country:china,age:45' - statsd = Statsd(constant_tags={'country': 'canada'}) + with preserve_envvars("STATSD_TAGS"): + os.environ["STATSD_TAGS"] = "country:china,age:45" + statsd = Statsd(constant_tags={"country": "canada"}) statsd._socket = FakeSocket() - statsd.gauge('gt', 123.4) - self.assertEqual('gt:123.4|g|#age:45,country:canada', - statsd.socket.recv()) + statsd.gauge("gt", 123.4) + self.assertEqual("gt:123.4|g|#age:45,country:canada", statsd.socket.recv()) def test_tags_from_environment_warning(self): - with preserve_envvars('STATSD_TAGS'): - os.environ['STATSD_TAGS'] = 'valid:tag,invalid_tag' + with preserve_envvars("STATSD_TAGS"): + os.environ["STATSD_TAGS"] = "valid:tag,invalid_tag" with pytest.warns(UserWarning) as record: statsd = Statsd() assert len(record) == 1 - assert 'invalid_tag' in record[0].message.args[0] - assert 'valid:tag' not in record[0].message.args[0] - assert statsd.constant_tags == {'valid': 'tag'} + assert "invalid_tag" in record[0].message.args[0] + assert "valid:tag" not in record[0].message.args[0] + assert statsd.constant_tags == {"valid": "tag"} def test_gauge_doesnt_send_none(self): - self.statsd.gauge('metric', None) + self.statsd.gauge("metric", None) assert self.recv() is None def test_increment_doesnt_send_none(self): - self.statsd.increment('metric', None) + self.statsd.increment("metric", None) assert self.recv() is None def test_decrement_doesnt_send_none(self): - self.statsd.decrement('metric', None) + self.statsd.decrement("metric", None) assert self.recv() is None def test_timing_doesnt_send_none(self): - self.statsd.timing('metric', None) + self.statsd.timing("metric", None) assert self.recv() is None def test_histogram_doesnt_send_none(self): - self.statsd.histogram('metric', None) + self.statsd.histogram("metric", None) assert self.recv() is None def test_param_host(self): - with preserve_envvars('STATSD_HOST', 'STATSD_PORT'): - os.environ['STATSD_HOST'] = 'test-value' - os.environ['STATSD_PORT'] = '' - local_statsd = Statsd(host='actual-test-value') + with preserve_envvars("STATSD_HOST", "STATSD_PORT"): + os.environ["STATSD_HOST"] = "test-value" + os.environ["STATSD_PORT"] = "" + local_statsd = Statsd(host="actual-test-value") - self.assertEqual(local_statsd.host, 'actual-test-value') + self.assertEqual(local_statsd.host, "actual-test-value") self.assertEqual(local_statsd.port, 8125) def test_param_port(self): - with preserve_envvars('STATSD_HOST', 'STATSD_PORT'): - os.environ['STATSD_HOST'] = '' - os.environ['STATSD_PORT'] = '12345' + with preserve_envvars("STATSD_HOST", "STATSD_PORT"): + os.environ["STATSD_HOST"] = "" + os.environ["STATSD_PORT"] = "12345" local_statsd = Statsd(port=4321) - self.assertEqual(local_statsd.host, 'localhost') + self.assertEqual(local_statsd.host, "localhost") self.assertEqual(local_statsd.port, 4321) def test_envvar_host(self): - with preserve_envvars('STATSD_HOST', 'STATSD_PORT'): - os.environ['STATSD_HOST'] = 'test-value' - os.environ['STATSD_PORT'] = '' + with preserve_envvars("STATSD_HOST", "STATSD_PORT"): + os.environ["STATSD_HOST"] = "test-value" + os.environ["STATSD_PORT"] = "" local_statsd = Statsd() - self.assertEqual(local_statsd.host, 'test-value') + self.assertEqual(local_statsd.host, "test-value") self.assertEqual(local_statsd.port, 8125) def test_envvar_port(self): - with preserve_envvars('STATSD_HOST', 'STATSD_PORT'): - os.environ['STATSD_HOST'] = '' - os.environ['STATSD_PORT'] = '12345' + with preserve_envvars("STATSD_HOST", "STATSD_PORT"): + os.environ["STATSD_HOST"] = "" + os.environ["STATSD_PORT"] = "12345" local_statsd = Statsd() - self.assertEqual(local_statsd.host, 'localhost') + self.assertEqual(local_statsd.host, "localhost") self.assertEqual(local_statsd.port, 12345) def test_namespace_added(self): - local_statsd = Statsd(namespace='test-namespace') + local_statsd = Statsd(namespace="test-namespace") local_statsd._socket = FakeSocket() - local_statsd.gauge('gauge', 123.4) - assert local_statsd.socket.recv() == 'test-namespace.gauge:123.4|g' + local_statsd.gauge("gauge", 123.4) + assert local_statsd.socket.recv() == "test-namespace.gauge:123.4|g" def test_contextmanager_empty(self): with self.statsd: - assert True, 'success' + assert True, "success" def test_contextmanager_buffering(self): with self.statsd as s: - s.gauge('gauge', 123.4) - s.gauge('gauge_other', 456.78) + s.gauge("gauge", 123.4) + s.gauge("gauge_other", 456.78) self.assertIsNone(s.socket.recv()) - self.assertEqual(self.recv(), 'gauge:123.4|g\ngauge_other:456.78|g') + self.assertEqual(self.recv(), "gauge:123.4|g\ngauge_other:456.78|g") def test_timed_elapsed(self): - with self.statsd.timed('test_timer') as t: + with self.statsd.timed("test_timer") as t: pass self.assertGreaterEqual(t.elapsed, 0) - self.assertEqual(self.recv(), 'test_timer:%s|ms' % t.elapsed) + self.assertEqual(self.recv(), "test_timer:%s|ms" % t.elapsed) diff --git a/swh/core/tests/test_tarball.py b/swh/core/tests/test_tarball.py index 7c7f189..ab432e6 100644 --- a/swh/core/tests/test_tarball.py +++ b/swh/core/tests/test_tarball.py @@ -1,169 +1,171 @@ # Copyright (C) 2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import os import pytest import shutil from swh.core import tarball @pytest.fixture def prepare_shutil_state(): """Reset any shutil modification in its current state """ import shutil registered_formats = [f[0] for f in shutil.get_unpack_formats()] for format_id in tarball.ADDITIONAL_ARCHIVE_FORMATS: name = format_id[0] if name in registered_formats: shutil.unregister_unpack_format(name) return shutil def test_compress_uncompress_zip(tmp_path): - tocompress = tmp_path / 'compressme' + tocompress = tmp_path / "compressme" tocompress.mkdir() for i in range(10): - fpath = tocompress / ('file%s.txt' % i) - fpath.write_text('content of file %s' % i) + fpath = tocompress / ("file%s.txt" % i) + fpath.write_text("content of file %s" % i) - zipfile = tmp_path / 'archive.zip' - tarball.compress(str(zipfile), 'zip', str(tocompress)) + zipfile = tmp_path / "archive.zip" + tarball.compress(str(zipfile), "zip", str(tocompress)) - destdir = tmp_path / 'destdir' + destdir = tmp_path / "destdir" tarball.uncompress(str(zipfile), str(destdir)) lsdir = sorted(x.name for x in destdir.iterdir()) - assert ['file%s.txt' % i for i in range(10)] == lsdir + assert ["file%s.txt" % i for i in range(10)] == lsdir def test_compress_uncompress_tar(tmp_path): - tocompress = tmp_path / 'compressme' + tocompress = tmp_path / "compressme" tocompress.mkdir() for i in range(10): - fpath = tocompress / ('file%s.txt' % i) - fpath.write_text('content of file %s' % i) + fpath = tocompress / ("file%s.txt" % i) + fpath.write_text("content of file %s" % i) - tarfile = tmp_path / 'archive.tar' - tarball.compress(str(tarfile), 'tar', str(tocompress)) + tarfile = tmp_path / "archive.tar" + tarball.compress(str(tarfile), "tar", str(tocompress)) - destdir = tmp_path / 'destdir' + destdir = tmp_path / "destdir" tarball.uncompress(str(tarfile), str(destdir)) lsdir = sorted(x.name for x in destdir.iterdir()) - assert ['file%s.txt' % i for i in range(10)] == lsdir + assert ["file%s.txt" % i for i in range(10)] == lsdir def test__unpack_tar_failure(tmp_path, datadir): """Unpack inexistent tarball should fail """ - tarpath = os.path.join(datadir, 'archives', 'inexistent-archive.tar.Z') + tarpath = os.path.join(datadir, "archives", "inexistent-archive.tar.Z") assert not os.path.exists(tarpath) - with pytest.raises(shutil.ReadError, - match=f'Unable to uncompress {tarpath} to {tmp_path}'): + with pytest.raises( + shutil.ReadError, match=f"Unable to uncompress {tarpath} to {tmp_path}" + ): tarball._unpack_tar(tarpath, tmp_path) def test__unpack_tar_failure2(tmp_path, datadir): """Unpack Existent tarball into an inexistent folder should fail """ - filename = 'groff-1.02.tar.Z' - tarpath = os.path.join(datadir, 'archives', filename) + filename = "groff-1.02.tar.Z" + tarpath = os.path.join(datadir, "archives", filename) assert os.path.exists(tarpath) - extract_dir = os.path.join(tmp_path, 'dir', 'inexistent') + extract_dir = os.path.join(tmp_path, "dir", "inexistent") - with pytest.raises(shutil.ReadError, - match=f'Unable to uncompress {tarpath} to {tmp_path}'): + with pytest.raises( + shutil.ReadError, match=f"Unable to uncompress {tarpath} to {tmp_path}" + ): tarball._unpack_tar(tarpath, extract_dir) def test__unpack_tar_failure3(tmp_path, datadir): """Unpack unsupported tarball should fail """ - filename = 'hello.zip' - tarpath = os.path.join(datadir, 'archives', filename) + filename = "hello.zip" + tarpath = os.path.join(datadir, "archives", filename) assert os.path.exists(tarpath) - with pytest.raises(shutil.ReadError, - match=f'Unable to uncompress {tarpath} to {tmp_path}'): + with pytest.raises( + shutil.ReadError, match=f"Unable to uncompress {tarpath} to {tmp_path}" + ): tarball._unpack_tar(tarpath, tmp_path) def test__unpack_tar(tmp_path, datadir): """Unpack supported tarball into an existent folder should be ok """ - filename = 'groff-1.02.tar.Z' - tarpath = os.path.join(datadir, 'archives', filename) + filename = "groff-1.02.tar.Z" + tarpath = os.path.join(datadir, "archives", filename) assert os.path.exists(tarpath) extract_dir = os.path.join(tmp_path, filename) os.makedirs(extract_dir, exist_ok=True) output_directory = tarball._unpack_tar(tarpath, extract_dir) assert extract_dir == output_directory assert len(os.listdir(extract_dir)) > 0 def test_register_new_archive_formats(prepare_shutil_state): """Registering new archive formats should be fine """ unpack_formats_v1 = [f[0] for f in shutil.get_unpack_formats()] for format_id in tarball.ADDITIONAL_ARCHIVE_FORMATS: assert format_id[0] not in unpack_formats_v1 # when tarball.register_new_archive_formats() # then unpack_formats_v2 = [f[0] for f in shutil.get_unpack_formats()] for format_id in tarball.ADDITIONAL_ARCHIVE_FORMATS: assert format_id[0] in unpack_formats_v2 def test_uncompress_tarpaths(tmp_path, datadir, prepare_shutil_state): """High level call uncompression on un/supported tarballs """ - archive_dir = os.path.join(datadir, 'archives') + archive_dir = os.path.join(datadir, "archives") tarfiles = os.listdir(archive_dir) tarpaths = [os.path.join(archive_dir, tarfile) for tarfile in tarfiles] unsupported_tarpaths = [] for t in tarpaths: - if t.endswith('.Z') or t.endswith('.x') or t.endswith('.lz'): + if t.endswith(".Z") or t.endswith(".x") or t.endswith(".lz"): unsupported_tarpaths.append(t) # not supported yet for tarpath in unsupported_tarpaths: - with pytest.raises(ValueError, - match=f'Problem during unpacking {tarpath}.'): + with pytest.raises(ValueError, match=f"Problem during unpacking {tarpath}."): tarball.uncompress(tarpath, dest=tmp_path) # register those unsupported formats tarball.register_new_archive_formats() # unsupported formats are now supported for n, tarpath in enumerate(tarpaths, start=1): tarball.uncompress(tarpath, dest=tmp_path) assert n == len(tarpaths) diff --git a/swh/core/tests/test_utils.py b/swh/core/tests/test_utils.py index 0b8b6e9..f84c34a 100644 --- a/swh/core/tests/test_utils.py +++ b/swh/core/tests/test_utils.py @@ -1,140 +1,134 @@ # Copyright (C) 2015-2018 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import unittest from swh.core import utils class UtilsLib(unittest.TestCase): - def test_grouper(self): # given actual_data = utils.grouper((i for i in range(0, 9)), 2) out = [] for d in actual_data: out.append(list(d)) # force generator resolution for checks self.assertEqual(out, [[0, 1], [2, 3], [4, 5], [6, 7], [8]]) # given actual_data = utils.grouper((i for i in range(9, 0, -1)), 4) out = [] for d in actual_data: out.append(list(d)) # force generator resolution for checks self.assertEqual(out, [[9, 8, 7, 6], [5, 4, 3, 2], [1]]) def test_grouper_with_stop_value(self): # given - actual_data = utils.grouper(((i, i+1) for i in range(0, 9)), 2) + actual_data = utils.grouper(((i, i + 1) for i in range(0, 9)), 2) out = [] for d in actual_data: out.append(list(d)) # force generator resolution for checks - self.assertEqual(out, [ - [(0, 1), (1, 2)], - [(2, 3), (3, 4)], - [(4, 5), (5, 6)], - [(6, 7), (7, 8)], - [(8, 9)]]) + self.assertEqual( + out, + [ + [(0, 1), (1, 2)], + [(2, 3), (3, 4)], + [(4, 5), (5, 6)], + [(6, 7), (7, 8)], + [(8, 9)], + ], + ) # given actual_data = utils.grouper((i for i in range(9, 0, -1)), 4) out = [] for d in actual_data: out.append(list(d)) # force generator resolution for checks self.assertEqual(out, [[9, 8, 7, 6], [5, 4, 3, 2], [1]]) def test_backslashescape_errors(self): - raw_data_err = b'abcd\x80' + raw_data_err = b"abcd\x80" with self.assertRaises(UnicodeDecodeError): - raw_data_err.decode('utf-8', 'strict') + raw_data_err.decode("utf-8", "strict") self.assertEqual( - raw_data_err.decode('utf-8', 'backslashescape'), - 'abcd\\x80', + raw_data_err.decode("utf-8", "backslashescape"), "abcd\\x80", ) - raw_data_ok = b'abcd\xc3\xa9' + raw_data_ok = b"abcd\xc3\xa9" self.assertEqual( - raw_data_ok.decode('utf-8', 'backslashescape'), - raw_data_ok.decode('utf-8', 'strict'), + raw_data_ok.decode("utf-8", "backslashescape"), + raw_data_ok.decode("utf-8", "strict"), ) - unicode_data = 'abcdef\u00a3' + unicode_data = "abcdef\u00a3" self.assertEqual( - unicode_data.encode('ascii', 'backslashescape'), - b'abcdef\\xa3', + unicode_data.encode("ascii", "backslashescape"), b"abcdef\\xa3", ) def test_encode_with_unescape(self): - valid_data = '\\x01020304\\x00' - valid_data_encoded = b'\x01020304\x00' + valid_data = "\\x01020304\\x00" + valid_data_encoded = b"\x01020304\x00" - self.assertEqual( - valid_data_encoded, - utils.encode_with_unescape(valid_data) - ) + self.assertEqual(valid_data_encoded, utils.encode_with_unescape(valid_data)) def test_encode_with_unescape_invalid_escape(self): - invalid_data = 'test\\abcd' + invalid_data = "test\\abcd" with self.assertRaises(ValueError) as exc: utils.encode_with_unescape(invalid_data) - self.assertIn('invalid escape', exc.exception.args[0]) - self.assertIn('position 4', exc.exception.args[0]) + self.assertIn("invalid escape", exc.exception.args[0]) + self.assertIn("position 4", exc.exception.args[0]) def test_decode_with_escape(self): - backslashes = b'foo\\bar\\\\baz' - backslashes_escaped = 'foo\\\\bar\\\\\\\\baz' + backslashes = b"foo\\bar\\\\baz" + backslashes_escaped = "foo\\\\bar\\\\\\\\baz" self.assertEqual( - backslashes_escaped, - utils.decode_with_escape(backslashes), + backslashes_escaped, utils.decode_with_escape(backslashes), ) - valid_utf8 = b'foo\xc3\xa2' - valid_utf8_escaped = 'foo\u00e2' + valid_utf8 = b"foo\xc3\xa2" + valid_utf8_escaped = "foo\u00e2" self.assertEqual( - valid_utf8_escaped, - utils.decode_with_escape(valid_utf8), + valid_utf8_escaped, utils.decode_with_escape(valid_utf8), ) - invalid_utf8 = b'foo\xa2' - invalid_utf8_escaped = 'foo\\xa2' + invalid_utf8 = b"foo\xa2" + invalid_utf8_escaped = "foo\\xa2" self.assertEqual( - invalid_utf8_escaped, - utils.decode_with_escape(invalid_utf8), + invalid_utf8_escaped, utils.decode_with_escape(invalid_utf8), ) - valid_utf8_nul = b'foo\xc3\xa2\x00' - valid_utf8_nul_escaped = 'foo\u00e2\\x00' + valid_utf8_nul = b"foo\xc3\xa2\x00" + valid_utf8_nul_escaped = "foo\u00e2\\x00" self.assertEqual( - valid_utf8_nul_escaped, - utils.decode_with_escape(valid_utf8_nul), + valid_utf8_nul_escaped, utils.decode_with_escape(valid_utf8_nul), ) def test_commonname(self): # when - actual_commonname = utils.commonname('/some/where/to/', - '/some/where/to/go/to') + actual_commonname = utils.commonname("/some/where/to/", "/some/where/to/go/to") # then - self.assertEqual('go/to', actual_commonname) + self.assertEqual("go/to", actual_commonname) # when - actual_commonname2 = utils.commonname(b'/some/where/to/', - b'/some/where/to/go/to') + actual_commonname2 = utils.commonname( + b"/some/where/to/", b"/some/where/to/go/to" + ) # then - self.assertEqual(b'go/to', actual_commonname2) + self.assertEqual(b"go/to", actual_commonname2) diff --git a/swh/core/utils.py b/swh/core/utils.py index 5a32d95..0b6fcbd 100644 --- a/swh/core/utils.py +++ b/swh/core/utils.py @@ -1,124 +1,123 @@ # Copyright (C) 2016-2017 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import os import itertools import codecs import re from contextlib import contextmanager @contextmanager def cwd(path): """Contextually change the working directory to do thy bidding. Then gets back to the original location. """ prev_cwd = os.getcwd() os.chdir(path) try: yield finally: os.chdir(prev_cwd) def grouper(iterable, n): """Collect data into fixed-length size iterables. The last block might contain less elements as it will hold only the remaining number of elements. The invariant here is that the number of elements in the input iterable and the sum of the number of elements of all iterables generated from this function should be equal. Args: iterable (Iterable): an iterable n (int): size of block to slice the iterable into Yields: fixed-length blocks as iterables. As mentioned, the last iterable might be less populated. """ args = [iter(iterable)] * n stop_value = object() for _data in itertools.zip_longest(*args, fillvalue=stop_value): yield (d for d in _data if d is not stop_value) def backslashescape_errors(exception): if isinstance(exception, UnicodeDecodeError): - bad_data = exception.object[exception.start:exception.end] - escaped = ''.join(r'\x%02x' % x for x in bad_data) + bad_data = exception.object[exception.start : exception.end] + escaped = "".join(r"\x%02x" % x for x in bad_data) return escaped, exception.end return codecs.backslashreplace_errors(exception) -codecs.register_error('backslashescape', backslashescape_errors) +codecs.register_error("backslashescape", backslashescape_errors) def encode_with_unescape(value): """Encode an unicode string containing \\x backslash escapes""" slices = [] start = 0 odd_backslashes = False i = 0 while i < len(value): - if value[i] == '\\': + if value[i] == "\\": odd_backslashes = not odd_backslashes else: if odd_backslashes: - if value[i] != 'x': - raise ValueError('invalid escape for %r at position %d' % - (value, i-1)) + if value[i] != "x": + raise ValueError( + "invalid escape for %r at position %d" % (value, i - 1) + ) slices.append( - value[start:i-1].replace('\\\\', '\\').encode('utf-8') + value[start : i - 1].replace("\\\\", "\\").encode("utf-8") ) - slices.append(bytes.fromhex(value[i+1:i+3])) + slices.append(bytes.fromhex(value[i + 1 : i + 3])) odd_backslashes = False start = i = i + 3 continue i += 1 - slices.append( - value[start:i].replace('\\\\', '\\').encode('utf-8') - ) + slices.append(value[start:i].replace("\\\\", "\\").encode("utf-8")) - return b''.join(slices) + return b"".join(slices) def decode_with_escape(value): """Decode a bytestring as utf-8, escaping the bytes of invalid utf-8 sequences as \\x. We also escape NUL bytes as they are invalid in JSON strings. """ # escape backslashes - value = value.replace(b'\\', b'\\\\') - value = value.replace(b'\x00', b'\\x00') - return value.decode('utf-8', 'backslashescape') + value = value.replace(b"\\", b"\\\\") + value = value.replace(b"\x00", b"\\x00") + return value.decode("utf-8", "backslashescape") def commonname(path0, path1, as_str=False): """Compute the commonname between the path0 and path1. """ return path1.split(path0)[1] def numfile_sortkey(fname): """Simple function to sort filenames of the form: nnxxx.ext where nn is a number according to the numbers. Typically used to sort sql/nn-swh-xxx.sql files. """ - num, rem = re.match(r'(\d*)(.*)', fname).groups() + num, rem = re.match(r"(\d*)(.*)", fname).groups() return (num and int(num) or 99, rem) diff --git a/tox.ini b/tox.ini index 7a3ffcc..f87c7ee 100644 --- a/tox.ini +++ b/tox.ini @@ -1,45 +1,52 @@ [tox] -envlist=flake8,mypy,py3-{core,db,server} +envlist=black,flake8,mypy,py3-{core,db,server} [testenv] extras = testing-core core: logging db: db, testing-db server: http deps = db: pifpaf cover: pytest-cov commands = db: pifpaf run postgresql -- \ pytest --doctest-modules \ slow: --hypothesis-profile=slow \ cover: --cov={envsitepackagesdir}/swh/core --cov-branch \ core: {envsitepackagesdir}/swh/core/tests \ db: {envsitepackagesdir}/swh/core/db/tests \ server: {envsitepackagesdir}/swh/core/api/tests \ {posargs} [testenv:py3] skip_install = true deps = tox commands = tox -e py3-core-db-server-slow-cover -- {posargs} +[testenv:black] +skip_install = true +deps = + black +commands = + {envpython} -m black --check swh + [testenv:flake8] skip_install = true deps = flake8 commands = {envpython} -m flake8 [testenv:mypy] extras = testing-core logging db testing-db http deps = mypy commands = mypy swh