diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -22,6 +22,11 @@ language: system types: [python] +- repo: https://github.com/python/black + rev: 19.10b0 + hooks: + - id: black + # unfortunately, we are far from being able to enable this... # - repo: https://github.com/PyCQA/pydocstyle.git # rev: 4.0.0 @@ -33,12 +38,6 @@ # language: python # types: [python] -# black requires py3.6+ -#- repo: https://github.com/python/black -# rev: 19.3b0 -# hooks: -# - id: black -# language_version: python3 #- repo: https://github.com/asottile/blacken-docs # rev: v1.0.0-1 # hooks: diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,2 @@ +[tool.black] +target-version = ['py37'] diff --git a/setup.cfg b/setup.cfg new file mode 100644 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,5 @@ +[flake8] +# E203: whitespaces before ':' +# W503: line break before binary operator +ignore = E203,W503 +max-line-length = 88 diff --git a/setup.py b/setup.py --- a/setup.py +++ b/setup.py @@ -13,7 +13,7 @@ here = path.abspath(path.dirname(__file__)) # Get the long description from the README file -with open(path.join(here, 'README.md'), encoding='utf-8') as f: +with open(path.join(here, "README.md"), encoding="utf-8") as f: long_description = f.read() @@ -21,9 +21,9 @@ requirements = [] for name in names: if name: - reqf = 'requirements-%s.txt' % name + reqf = "requirements-%s.txt" % name else: - reqf = 'requirements.txt' + reqf = "requirements.txt" if not os.path.exists(reqf): return requirements @@ -31,38 +31,37 @@ with open(reqf) as f: for line in f.readlines(): line = line.strip() - if not line or line.startswith('#'): + if not line or line.startswith("#"): continue requirements.append(line) return requirements setup( - name='swh.core', - description='Software Heritage core utilities', + name="swh.core", + description="Software Heritage core utilities", long_description=long_description, - long_description_content_type='text/markdown', - author='Software Heritage developers', - author_email='swh-devel@inria.fr', - url='https://forge.softwareheritage.org/diffusion/DCORE/', + long_description_content_type="text/markdown", + author="Software Heritage developers", + author_email="swh-devel@inria.fr", + url="https://forge.softwareheritage.org/diffusion/DCORE/", packages=find_packages(), - py_modules=['pytest_swh_core'], + py_modules=["pytest_swh_core"], scripts=[], - install_requires=parse_requirements(None, 'swh'), - setup_requires=['vcversioner'], + install_requires=parse_requirements(None, "swh"), + setup_requires=["vcversioner"], extras_require={ - 'testing-core': parse_requirements('test'), - 'logging': parse_requirements('logging'), - 'db': parse_requirements('db'), - 'testing-db': parse_requirements('test-db'), - 'http': parse_requirements('http'), + "testing-core": parse_requirements("test"), + "logging": parse_requirements("logging"), + "db": parse_requirements("db"), + "testing-db": parse_requirements("test-db"), + "http": parse_requirements("http"), # kitchen sink, please do not use - 'testing': parse_requirements('test', 'test-db', 'db', 'http', - 'logging'), + "testing": parse_requirements("test", "test-db", "db", "http", "logging"), }, vcversioner={}, include_package_data=True, - entry_points=''' + entry_points=""" [console_scripts] swh=swh.core.cli:main swh-db-init=swh.core.cli.db:db_init @@ -71,7 +70,7 @@ db-init=swh.core.cli.db:db_init [pytest11] pytest_swh_core = swh.core.pytest_plugin - ''', + """, classifiers=[ "Programming Language :: Python :: 3", "Intended Audience :: Developers", @@ -80,8 +79,8 @@ "Development Status :: 5 - Production/Stable", ], project_urls={ - 'Bug Reports': 'https://forge.softwareheritage.org/maniphest', - 'Funding': 'https://www.softwareheritage.org/donate', - 'Source': 'https://forge.softwareheritage.org/source/swh-core', + "Bug Reports": "https://forge.softwareheritage.org/maniphest", + "Funding": "https://www.softwareheritage.org/donate", + "Source": "https://forge.softwareheritage.org/source/swh-core", }, ) diff --git a/swh/core/api/__init__.py b/swh/core/api/__init__.py --- a/swh/core/api/__init__.py +++ b/swh/core/api/__init__.py @@ -10,22 +10,26 @@ import pickle import requests -from typing import ( - Any, Callable, ClassVar, Dict, List, Optional, Tuple, Type, Union, -) +from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Type, Union from flask import Flask, Request, Response, request, abort from werkzeug.exceptions import HTTPException -from .serializers import (decode_response, - encode_data_client as encode_data, - msgpack_dumps, msgpack_loads, - json_dumps, json_loads, - exception_to_dict) +from .serializers import ( + decode_response, + encode_data_client as encode_data, + msgpack_dumps, + msgpack_loads, + json_dumps, + json_loads, + exception_to_dict, +) -from .negotiation import (Formatter as FormatterBase, - Negotiator as NegotiatorBase, - negotiate as _negotiate) +from .negotiation import ( + Formatter as FormatterBase, + Negotiator as NegotiatorBase, + negotiate as _negotiate, +) logger = logging.getLogger(__name__) @@ -33,10 +37,12 @@ # support for content negotiation + class Negotiator(NegotiatorBase): def best_mimetype(self): return request.accept_mimetypes.best_match( - self.accept_mimetypes, 'application/json') + self.accept_mimetypes, "application/json" + ) def _abort(self, status_code, err=None): return abort(status_code, err) @@ -55,16 +61,16 @@ class JSONFormatter(Formatter): - format = 'json' - mimetypes = ['application/json'] + format = "json" + mimetypes = ["application/json"] def render(self, obj): return json_dumps(obj, extra_encoders=self.extra_encoders) class MsgpackFormatter(Formatter): - format = 'msgpack' - mimetypes = ['application/x-msgpack'] + format = "msgpack" + mimetypes = ["application/x-msgpack"] def render(self, obj): return msgpack_dumps(obj, extra_encoders=self.extra_encoders) @@ -72,6 +78,7 @@ # base API classes + class RemoteException(Exception): """raised when remote returned an out-of-band failure notification, e.g., as a HTTP status code or serialized exception @@ -80,8 +87,12 @@ response: HTTP response corresponding to the failure """ - def __init__(self, payload: Optional[Any] = None, - response: Optional[requests.Response] = None): + + def __init__( + self, + payload: Optional[Any] = None, + response: Optional[requests.Response] = None, + ): if payload is not None: super().__init__(payload) else: @@ -89,11 +100,16 @@ self.response = response def __str__(self): - if self.args and isinstance(self.args[0], dict) \ - and 'type' in self.args[0] and 'args' in self.args[0]: + if ( + self.args + and isinstance(self.args[0], dict) + and "type" in self.args[0] + and "args" in self.args[0] + ): return ( - f'') + f"' + ) else: return super().__str__() @@ -102,14 +118,15 @@ def dec(f): f._endpoint_path = path return f + return dec class APIError(Exception): """API Error""" + def __str__(self): - return ('An unexpected error occurred in the backend: {}' - .format(self.args)) + return "An unexpected error occurred in the backend: {}".format(self.args) class MetaRPCClient(type): @@ -117,6 +134,7 @@ of the database it is designed to access. See for example :class:`swh.indexer.storage.api.client.RemoteStorage`""" + def __new__(cls, name, bases, attributes): # For each method wrapped with @remote_api_endpoint in an API backend # (eg. :class:`swh.indexer.storage.IndexerStorage`), add a new @@ -124,14 +142,14 @@ # # Note that, despite the usage of decorator magic (eg. functools.wrap), # this never actually calls an IndexerStorage method. - backend_class = attributes.get('backend_class', None) + backend_class = attributes.get("backend_class", None) for base in bases: if backend_class is not None: break - backend_class = getattr(base, 'backend_class', None) + backend_class = getattr(base, "backend_class", None) if backend_class: for (meth_name, meth) in backend_class.__dict__.items(): - if hasattr(meth, '_endpoint_path'): + if hasattr(meth, "_endpoint_path"): cls.__add_endpoint(meth_name, meth, attributes) return super().__new__(cls, name, bases, attributes) @@ -142,16 +160,16 @@ @functools.wraps(meth) # Copy signature and doc def meth_(*args, **kwargs): # Match arguments and parameters - post_data = inspect.getcallargs( - wrapped_meth, *args, **kwargs) + post_data = inspect.getcallargs(wrapped_meth, *args, **kwargs) # Remove arguments that should not be passed - self = post_data.pop('self') - post_data.pop('cur', None) - post_data.pop('db', None) + self = post_data.pop("self") + post_data.pop("cur", None) + post_data.pop("db", None) # Send the request. return self.post(meth._endpoint_path, post_data) + if meth_name not in attributes: attributes[meth_name] = meth_ @@ -186,40 +204,44 @@ """Value of `extra_decoders` passed to `json_loads` or `msgpack_loads` to be able to deserialize more object types.""" - def __init__(self, url, api_exception=None, - timeout=None, chunk_size=4096, - reraise_exceptions=None, **kwargs): + def __init__( + self, + url, + api_exception=None, + timeout=None, + chunk_size=4096, + reraise_exceptions=None, + **kwargs, + ): if api_exception: self.api_exception = api_exception if reraise_exceptions: self.reraise_exceptions = reraise_exceptions - base_url = url if url.endswith('/') else url + '/' + base_url = url if url.endswith("/") else url + "/" self.url = base_url self.session = requests.Session() adapter = requests.adapters.HTTPAdapter( - max_retries=kwargs.get('max_retries', 3), - pool_connections=kwargs.get('pool_connections', 20), - pool_maxsize=kwargs.get('pool_maxsize', 100)) + max_retries=kwargs.get("max_retries", 3), + pool_connections=kwargs.get("pool_connections", 20), + pool_maxsize=kwargs.get("pool_maxsize", 100), + ) self.session.mount(self.url, adapter) self.timeout = timeout self.chunk_size = chunk_size def _url(self, endpoint): - return '%s%s' % (self.url, endpoint) + return "%s%s" % (self.url, endpoint) def raw_verb(self, verb, endpoint, **opts): - if 'chunk_size' in opts: + if "chunk_size" in opts: # if the chunk_size argument has been passed, consider the user # also wants stream=True, otherwise, what's the point. - opts['stream'] = True - if self.timeout and 'timeout' not in opts: - opts['timeout'] = self.timeout + opts["stream"] = True + if self.timeout and "timeout" not in opts: + opts["timeout"] = self.timeout try: - return getattr(self.session, verb)( - self._url(endpoint), - **opts - ) + return getattr(self.session, verb)(self._url(endpoint), **opts) except requests.exceptions.ConnectionError as e: raise self.api_exception(e) @@ -228,14 +250,18 @@ data = (self._encode_data(x) for x in data) else: data = self._encode_data(data) - chunk_size = opts.pop('chunk_size', self.chunk_size) + chunk_size = opts.pop("chunk_size", self.chunk_size) response = self.raw_verb( - 'post', endpoint, data=data, - headers={'content-type': 'application/x-msgpack', - 'accept': 'application/x-msgpack'}, - **opts) - if opts.get('stream') or \ - response.headers.get('transfer-encoding') == 'chunked': + "post", + endpoint, + data=data, + headers={ + "content-type": "application/x-msgpack", + "accept": "application/x-msgpack", + }, + **opts, + ) + if opts.get("stream") or response.headers.get("transfer-encoding") == "chunked": self.raise_for_status(response) return response.iter_content(chunk_size) else: @@ -247,13 +273,11 @@ post_stream = post def get(self, endpoint, **opts): - chunk_size = opts.pop('chunk_size', self.chunk_size) + chunk_size = opts.pop("chunk_size", self.chunk_size) response = self.raw_verb( - 'get', endpoint, - headers={'accept': 'application/x-msgpack'}, - **opts) - if opts.get('stream') or \ - response.headers.get('transfer-encoding') == 'chunked': + "get", endpoint, headers={"accept": "application/x-msgpack"}, **opts + ) + if opts.get("stream") or response.headers.get("transfer-encoding") == "chunked": self.raise_for_status(response) return response.iter_content(chunk_size) else: @@ -271,7 +295,7 @@ status_class = response.status_code // 100 if status_code == 404: - raise RemoteException(payload='404 not found', response=response) + raise RemoteException(payload="404 not found", response=response) exception = None @@ -282,22 +306,24 @@ data = self._decode_response(response, check_status=False) if isinstance(data, dict): for exc_type in self.reraise_exceptions: - if exc_type.__name__ == data['exception']['type']: - exception = exc_type(*data['exception']['args']) + if exc_type.__name__ == data["exception"]["type"]: + exception = exc_type(*data["exception"]["args"]) break else: - exception = RemoteException(payload=data['exception'], - response=response) + exception = RemoteException( + payload=data["exception"], response=response + ) else: exception = pickle.loads(data) elif status_class == 5: data = self._decode_response(response, check_status=False) - if 'exception_pickled' in data: - exception = pickle.loads(data['exception_pickled']) + if "exception_pickled" in data: + exception = pickle.loads(data["exception_pickled"]) else: - exception = RemoteException(payload=data['exception'], - response=response) + exception = RemoteException( + payload=data["exception"], response=response + ) except (TypeError, pickle.UnpicklingError): raise RemoteException(payload=data, response=response) @@ -307,39 +333,37 @@ if status_class != 2: raise RemoteException( - payload=f'API HTTP error: {status_code} {response.content}', - response=response) + payload=f"API HTTP error: {status_code} {response.content}", + response=response, + ) def _decode_response(self, response, check_status=True): if check_status: self.raise_for_status(response) - return decode_response( - response, extra_decoders=self.extra_type_decoders) + return decode_response(response, extra_decoders=self.extra_type_decoders) def __repr__(self): - return '<{} url={}>'.format(self.__class__.__name__, self.url) + return "<{} url={}>".format(self.__class__.__name__, self.url) class BytesRequest(Request): """Request with proper escaping of arbitrary byte sequences.""" - encoding = 'utf-8' - encoding_errors = 'surrogateescape' + + encoding = "utf-8" + encoding_errors = "surrogateescape" ENCODERS: Dict[str, Callable[[Any], Union[bytes, str]]] = { - 'application/x-msgpack': msgpack_dumps, - 'application/json': json_dumps, + "application/x-msgpack": msgpack_dumps, + "application/json": json_dumps, } def encode_data_server( - data, content_type='application/x-msgpack', extra_type_encoders=None): - encoded_data = ENCODERS[content_type]( - data, extra_encoders=extra_type_encoders) - return Response( - encoded_data, - mimetype=content_type, - ) + data, content_type="application/x-msgpack", extra_type_encoders=None +): + encoded_data = ENCODERS[content_type](data, extra_encoders=extra_type_encoders) + return Response(encoded_data, mimetype=content_type) def decode_request(request, extra_decoders=None): @@ -348,15 +372,14 @@ if not data: return {} - if content_type == 'application/x-msgpack': + if content_type == "application/x-msgpack": r = msgpack_loads(data, extra_decoders=extra_decoders) - elif content_type == 'application/json': + elif content_type == "application/json": # XXX this .decode() is needed for py35. # Should not be needed any more with py37 - r = json_loads(data.decode('utf-8'), extra_decoders=extra_decoders) + r = json_loads(data.decode("utf-8"), extra_decoders=extra_decoders) else: - raise ValueError('Wrong content type `%s` for API request' - % content_type) + raise ValueError("Wrong content type `%s` for API request" % content_type) return r @@ -385,6 +408,7 @@ `backend_class`. If unset, defaults to calling `backend_class` constructor directly. """ + request_class = BytesRequest extra_type_encoders: List[Tuple[type, str, Callable]] = [] @@ -394,8 +418,7 @@ """Value of `extra_decoders` passed to `json_loads` or `msgpack_loads` to be able to deserialize more object types.""" - def __init__(self, *args, backend_class=None, backend_factory=None, - **kwargs): + def __init__(self, *args, backend_class=None, backend_factory=None, **kwargs): super().__init__(*args, **kwargs) self.backend_class = backend_class @@ -403,19 +426,18 @@ if backend_factory is None: backend_factory = backend_class for (meth_name, meth) in backend_class.__dict__.items(): - if hasattr(meth, '_endpoint_path'): + if hasattr(meth, "_endpoint_path"): self.__add_endpoint(meth_name, meth, backend_factory) def __add_endpoint(self, meth_name, meth, backend_factory): from flask import request - @self.route('/'+meth._endpoint_path, methods=['POST']) + @self.route("/" + meth._endpoint_path, methods=["POST"]) @negotiate(MsgpackFormatter, extra_encoders=self.extra_type_encoders) @negotiate(JSONFormatter, extra_encoders=self.extra_type_encoders) @functools.wraps(meth) # Copy signature and doc def _f(): # Call the actual code obj_meth = getattr(backend_factory(), meth_name) - kw = decode_request( - request, extra_decoders=self.extra_type_decoders) + kw = decode_request(request, extra_decoders=self.extra_type_decoders) return obj_meth(**kw) diff --git a/swh/core/api/asynchronous.py b/swh/core/api/asynchronous.py --- a/swh/core/api/asynchronous.py +++ b/swh/core/api/asynchronous.py @@ -21,9 +21,8 @@ def encode_msgpack(data, **kwargs): return aiohttp.web.Response( body=msgpack_dumps(data), - headers=multidict.MultiDict( - {'Content-Type': 'application/x-msgpack'}), - **kwargs + headers=multidict.MultiDict({"Content-Type": "application/x-msgpack"}), + **kwargs, ) @@ -39,17 +38,16 @@ async def decode_request(request): - content_type = request.headers.get('Content-Type').split(';')[0].strip() + content_type = request.headers.get("Content-Type").split(";")[0].strip() data = await request.read() if not data: return {} - if content_type == 'application/x-msgpack': + if content_type == "application/x-msgpack": r = msgpack_loads(data) - elif content_type == 'application/json': + elif content_type == "application/json": r = json_loads(data) else: - raise ValueError('Wrong content type `%s` for API request' - % content_type) + raise ValueError("Wrong content type `%s` for API request" % content_type) return r @@ -67,6 +65,7 @@ else: status = 500 return encode_data_server(res, status=status) + return middleware_handler @@ -79,19 +78,20 @@ middlewares = (error_middleware,) + middlewares # renderers are sorted in order of increasing desirability (!) # see mimeparse.best_match() docstring. - renderers = OrderedDict([ - ('application/json', render_json), - ('application/x-msgpack', render_msgpack), - ]) + renderers = OrderedDict( + [ + ("application/json", render_json), + ("application/x-msgpack", render_msgpack), + ] + ) nego_middleware = negotiation.negotiation_middleware( - renderers=renderers, - force_rendering=True) + renderers=renderers, force_rendering=True + ) middlewares = (nego_middleware,) + middlewares super().__init__(*args, middlewares=middlewares, **kwargs) -@deprecated(version='0.0.64', - reason='Use the RPCServerApp instead') +@deprecated(version="0.0.64", reason="Use the RPCServerApp instead") class SWHRemoteAPI(RPCServerApp): pass diff --git a/swh/core/api/gunicorn_config.py b/swh/core/api/gunicorn_config.py --- a/swh/core/api/gunicorn_config.py +++ b/swh/core/api/gunicorn_config.py @@ -16,15 +16,24 @@ def post_fork( - server, worker, *, default_sentry_dsn=None, flask=True, - sentry_integrations=None, extra_sentry_kwargs={}): + server, + worker, + *, + default_sentry_dsn=None, + flask=True, + sentry_integrations=None, + extra_sentry_kwargs={}, +): # Initializes sentry as soon as possible in gunicorn's worker processes. sentry_integrations = sentry_integrations or [] if flask: from sentry_sdk.integrations.flask import FlaskIntegration + sentry_integrations.append(FlaskIntegration()) init_sentry( - default_sentry_dsn, integrations=sentry_integrations, - extra_kwargs=extra_sentry_kwargs) + default_sentry_dsn, + integrations=sentry_integrations, + extra_kwargs=extra_sentry_kwargs, + ) diff --git a/swh/core/api/negotiation.py b/swh/core/api/negotiation.py --- a/swh/core/api/negotiation.py +++ b/swh/core/api/negotiation.py @@ -27,8 +27,7 @@ from decorator import decorator from inspect import getcallargs -from typing import Any, List, Optional, Callable, \ - Type, NoReturn, DefaultDict +from typing import Any, List, Optional, Callable, Type, NoReturn, DefaultDict from requests import Response @@ -47,8 +46,8 @@ self.response_mimetype = self.mimetypes[0] except IndexError: raise NotImplementedError( - "%s.mimetypes should be a non-empty list" % - self.__class__.__name__) + "%s.mimetypes should be a non-empty list" % self.__class__.__name__ + ) else: self.response_mimetype = request_mimetype @@ -57,11 +56,13 @@ def render(self, obj: Any) -> bytes: raise NotImplementedError( - "render() should be implemented by Formatter subclasses") + "render() should be implemented by Formatter subclasses" + ) def __call__(self, obj: Any) -> Response: return self._make_response( - self.render(obj), content_type=self.response_mimetype) + self.render(obj), content_type=self.response_mimetype + ) def _make_response(self, body: bytes, content_type: str) -> Response: raise NotImplementedError( @@ -71,7 +72,6 @@ class Negotiator: - def __init__(self, func: Callable[..., Any]) -> None: self.func = func self._formatters: List[Type[Formatter]] = [] @@ -80,7 +80,7 @@ def __call__(self, *args, **kwargs) -> Response: result = self.func(*args, **kwargs) - format = getcallargs(self.func, *args, **kwargs).get('format') + format = getcallargs(self.func, *args, **kwargs).get("format") mimetype = self.best_mimetype() try: @@ -90,38 +90,35 @@ return formatter(result) - def register_formatter(self, formatter: Type[Formatter], - *args, **kwargs) -> None: + def register_formatter(self, formatter: Type[Formatter], *args, **kwargs) -> None: self._formatters.append(formatter) - self._formatters_by_format[formatter.format].append( - (formatter, args, kwargs)) + self._formatters_by_format[formatter.format].append((formatter, args, kwargs)) for mimetype in formatter.mimetypes: - self._formatters_by_mimetype[mimetype].append( - (formatter, args, kwargs)) + self._formatters_by_mimetype[mimetype].append((formatter, args, kwargs)) - def get_formatter(self, format: Optional[str] = None, - mimetype: Optional[str] = None) -> Formatter: + def get_formatter( + self, format: Optional[str] = None, mimetype: Optional[str] = None + ) -> Formatter: if format is None and mimetype is None: raise TypeError( "get_formatter expects one of the 'format' or 'mimetype' " - "kwargs to be set") + "kwargs to be set" + ) if format is not None: try: # the first added will be the most specific - formatter_cls, args, kwargs = ( - self._formatters_by_format[format][0]) + formatter_cls, args, kwargs = self._formatters_by_format[format][0] except IndexError: - raise FormatterNotFound( - "Formatter for format '%s' not found!" % format) + raise FormatterNotFound("Formatter for format '%s' not found!" % format) elif mimetype is not None: try: # the first added will be the most specific - formatter_cls, args, kwargs = ( - self._formatters_by_mimetype[mimetype][0]) + formatter_cls, args, kwargs = self._formatters_by_mimetype[mimetype][0] except IndexError: raise FormatterNotFound( - "Formatter for mimetype '%s' not found!" % mimetype) + "Formatter for mimetype '%s' not found!" % mimetype + ) formatter = formatter_cls(request_mimetype=mimetype) formatter.configure(*args, **kwargs) @@ -144,13 +141,14 @@ ) -def negotiate(negotiator_cls: Type[Negotiator], formatter_cls: Type[Formatter], - *args, **kwargs) -> Callable: +def negotiate( + negotiator_cls: Type[Negotiator], formatter_cls: Type[Formatter], *args, **kwargs +) -> Callable: def _negotiate(f, *args, **kwargs): return f.negotiator(*args, **kwargs) def decorate(f): - if not hasattr(f, 'negotiator'): + if not hasattr(f, "negotiator"): f.negotiator = negotiator_cls(f) f.negotiator.register_formatter(formatter_cls, *args, **kwargs) diff --git a/swh/core/api/serializers.py b/swh/core/api/serializers.py --- a/swh/core/api/serializers.py +++ b/swh/core/api/serializers.py @@ -19,27 +19,29 @@ ENCODERS = [ - (arrow.Arrow, 'arrow', arrow.Arrow.isoformat), - (datetime.datetime, 'datetime', datetime.datetime.isoformat), - (datetime.timedelta, 'timedelta', lambda o: { - 'days': o.days, - 'seconds': o.seconds, - 'microseconds': o.microseconds, - }), - (UUID, 'uuid', str), - + (arrow.Arrow, "arrow", arrow.Arrow.isoformat), + (datetime.datetime, "datetime", datetime.datetime.isoformat), + ( + datetime.timedelta, + "timedelta", + lambda o: { + "days": o.days, + "seconds": o.seconds, + "microseconds": o.microseconds, + }, + ), + (UUID, "uuid", str), # Only for JSON: - (bytes, 'bytes', lambda o: base64.b85encode(o).decode('ascii')), + (bytes, "bytes", lambda o: base64.b85encode(o).decode("ascii")), ] DECODERS = { - 'arrow': arrow.get, - 'datetime': lambda d: iso8601.parse_date(d, default_timezone=None), - 'timedelta': lambda d: datetime.timedelta(**d), - 'uuid': UUID, - + "arrow": arrow.get, + "datetime": lambda d: iso8601.parse_date(d, default_timezone=None), + "timedelta": lambda d: datetime.timedelta(**d), + "uuid": UUID, # Only for JSON: - 'bytes': base64.b85decode, + "bytes": base64.b85decode, } @@ -47,24 +49,20 @@ try: return msgpack_dumps(data, extra_encoders=extra_encoders) except OverflowError as e: - raise ValueError('Limits were reached. Please, check your input.\n' + - str(e)) + raise ValueError("Limits were reached. Please, check your input.\n" + str(e)) def decode_response(response: Response, extra_decoders=None) -> Any: - content_type = response.headers['content-type'] - - if content_type.startswith('application/x-msgpack'): - r = msgpack_loads(response.content, - extra_decoders=extra_decoders) - elif content_type.startswith('application/json'): - r = json_loads(response.text, - extra_decoders=extra_decoders) - elif content_type.startswith('text/'): + content_type = response.headers["content-type"] + + if content_type.startswith("application/x-msgpack"): + r = msgpack_loads(response.content, extra_decoders=extra_decoders) + elif content_type.startswith("application/json"): + r = json_loads(response.text, extra_decoders=extra_decoders) + elif content_type.startswith("text/"): r = response.text else: - raise ValueError('Wrong content type `%s` for API response' - % content_type) + raise ValueError("Wrong content type `%s` for API response" % content_type) return r @@ -99,14 +97,10 @@ if extra_encoders: self.encoders += extra_encoders - def default(self, o: Any - ) -> Union[Dict[str, Union[Dict[str, int], str]], list]: + def default(self, o: Any) -> Union[Dict[str, Union[Dict[str, int], str]], list]: for (type_, type_name, encoder) in self.encoders: if isinstance(o, type_): - return { - 'swhtype': type_name, - 'd': encoder(o), - } + return {"swhtype": type_name, "d": encoder(o)} try: return super().default(o) except TypeError as e: @@ -146,12 +140,12 @@ def decode_data(self, o: Any) -> Any: if isinstance(o, dict): - if set(o.keys()) == {'d', 'swhtype'}: - if o['swhtype'] == 'bytes': - return base64.b85decode(o['d']) - decoder = self.decoders.get(o['swhtype']) + if set(o.keys()) == {"d", "swhtype"}: + if o["swhtype"] == "bytes": + return base64.b85decode(o["d"]) + decoder = self.decoders.get(o["swhtype"]) if decoder: - return decoder(self.decode_data(o['d'])) + return decoder(self.decode_data(o["d"])) return {key: self.decode_data(value) for key, value in o.items()} if isinstance(o, list): return [self.decode_data(value) for value in o] @@ -164,13 +158,11 @@ def json_dumps(data: Any, extra_encoders=None) -> str: - return json.dumps(data, cls=SWHJSONEncoder, - extra_encoders=extra_encoders) + return json.dumps(data, cls=SWHJSONEncoder, extra_encoders=extra_encoders) def json_loads(data: str, extra_decoders=None) -> Any: - return json.loads(data, cls=SWHJSONDecoder, - extra_decoders=extra_decoders) + return json.loads(data, cls=SWHJSONDecoder, extra_decoders=extra_decoders) def msgpack_dumps(data: Any, extra_encoders=None) -> bytes: @@ -185,10 +177,7 @@ for (type_, type_name, encoder) in encoders: if isinstance(obj, type_): - return { - b'swhtype': type_name, - b'd': encoder(obj), - } + return {b"swhtype": type_name, b"d": encoder(obj)} return obj return msgpack.packb(data, use_bin_type=True, default=encode_types) @@ -202,44 +191,42 @@ def decode_types(obj): # Support for current encodings - if set(obj.keys()) == {b'd', b'swhtype'}: - decoder = decoders.get(obj[b'swhtype']) + if set(obj.keys()) == {b"d", b"swhtype"}: + decoder = decoders.get(obj[b"swhtype"]) if decoder: - return decoder(obj[b'd']) + return decoder(obj[b"d"]) # Support for legacy encodings - if b'__datetime__' in obj and obj[b'__datetime__']: - return iso8601.parse_date(obj[b's'], default_timezone=None) - if b'__uuid__' in obj and obj[b'__uuid__']: - return UUID(obj[b's']) - if b'__timedelta__' in obj and obj[b'__timedelta__']: - return datetime.timedelta(**obj[b's']) - if b'__arrow__' in obj and obj[b'__arrow__']: - return arrow.get(obj[b's']) + if b"__datetime__" in obj and obj[b"__datetime__"]: + return iso8601.parse_date(obj[b"s"], default_timezone=None) + if b"__uuid__" in obj and obj[b"__uuid__"]: + return UUID(obj[b"s"]) + if b"__timedelta__" in obj and obj[b"__timedelta__"]: + return datetime.timedelta(**obj[b"s"]) + if b"__arrow__" in obj and obj[b"__arrow__"]: + return arrow.get(obj[b"s"]) # Fallthrough return obj try: try: - return msgpack.unpackb(data, raw=False, - object_hook=decode_types, - strict_map_key=False) + return msgpack.unpackb( + data, raw=False, object_hook=decode_types, strict_map_key=False + ) except TypeError: # msgpack < 0.6.0 - return msgpack.unpackb(data, raw=False, - object_hook=decode_types) + return msgpack.unpackb(data, raw=False, object_hook=decode_types) except TypeError: # msgpack < 0.5.2 - return msgpack.unpackb(data, encoding='utf-8', - object_hook=decode_types) + return msgpack.unpackb(data, encoding="utf-8", object_hook=decode_types) def exception_to_dict(exception): tb = traceback.format_exception(None, exception, exception.__traceback__) return { - 'exception': { - 'type': type(exception).__name__, - 'args': exception.args, - 'message': str(exception), - 'traceback': tb, + "exception": { + "type": type(exception).__name__, + "args": exception.args, + "message": str(exception), + "traceback": tb, } } diff --git a/swh/core/api/tests/server_testing.py b/swh/core/api/tests/server_testing.py --- a/swh/core/api/tests/server_testing.py +++ b/swh/core/api/tests/server_testing.py @@ -29,6 +29,7 @@ class's setUp() and tearDown() methods. """ + def setUp(self): super().setUp() self.start_server() @@ -38,7 +39,7 @@ super().tearDown() def url(self): - return 'http://127.0.0.1:%d/' % self.port + return "http://127.0.0.1:%d/" % self.port def process_config(self): """Process the server's configuration. Do something useful for @@ -106,6 +107,7 @@ In order to correctly work, the subclass must call the parents class's setUp() and tearDown() methods. """ + def process_config(self): # WSGI app configuration for key, value in self.config.items(): @@ -114,7 +116,7 @@ def define_worker_function(self): def worker(app, port): # Make Flask 1.0 stop printing its server banner - os.environ['WERKZEUG_RUN_MAIN'] = 'true' + os.environ["WERKZEUG_RUN_MAIN"] = "true" return app.run(port=port, use_reloader=False) return worker @@ -136,9 +138,9 @@ class's setUp() and tearDown() methods. """ + def define_worker_function(self): def worker(app, port): - return aiohttp.web.run_app(app, port=int(port), - print=lambda *_: None) + return aiohttp.web.run_app(app, port=int(port), print=lambda *_: None) return worker diff --git a/swh/core/api/tests/test_async.py b/swh/core/api/tests/test_async.py --- a/swh/core/api/tests/test_async.py +++ b/swh/core/api/tests/test_async.py @@ -15,7 +15,7 @@ from swh.core.api.serializers import msgpack_dumps, SWHJSONEncoder -pytest_plugins = ['aiohttp.pytest_plugin', 'pytester'] +pytest_plugins = ["aiohttp.pytest_plugin", "pytester"] class TestServerException(Exception): @@ -27,19 +27,19 @@ async def root(request): - return Response('toor') + return Response("toor") -STRUCT = {'txt': 'something stupid', - # 'date': datetime.date(2019, 6, 9), # not supported - 'datetime': datetime.datetime(2019, 6, 9, 10, 12), - 'timedelta': datetime.timedelta(days=-2, hours=3), - 'int': 42, - 'float': 3.14, - 'subdata': {'int': 42, - 'datetime': datetime.datetime(2019, 6, 10, 11, 12), - }, - 'list': [42, datetime.datetime(2019, 9, 10, 11, 12), 'ok'], - } + +STRUCT = { + "txt": "something stupid", + # 'date': datetime.date(2019, 6, 9), # not supported + "datetime": datetime.datetime(2019, 6, 9, 10, 12), + "timedelta": datetime.timedelta(days=-2, hours=3), + "int": 42, + "float": 3.14, + "subdata": {"int": 42, "datetime": datetime.datetime(2019, 6, 10, 11, 12)}, + "list": [42, datetime.datetime(2019, 9, 10, 11, 12), "ok"], +} async def struct(request): @@ -67,8 +67,8 @@ def check_mimetype(src, dst): - src = src.split(';')[0].strip() - dst = dst.split(';')[0].strip() + src = src.split(";")[0].strip() + dst = dst.split(";")[0].strip() assert src == dst @@ -76,12 +76,12 @@ def async_app(): app = RPCServerApp() app.client_exception_classes = (TestClientError,) - app.router.add_route('GET', '/', root) - app.router.add_route('GET', '/struct', struct) - app.router.add_route('POST', '/echo', echo) - app.router.add_route('GET', '/server_exception', server_exception) - app.router.add_route('GET', '/client_error', client_error) - app.router.add_route('POST', '/echo-no-nego', echo_no_nego) + app.router.add_route("GET", "/", root) + app.router.add_route("GET", "/struct", struct) + app.router.add_route("POST", "/echo", echo) + app.router.add_route("GET", "/server_exception", server_exception) + app.router.add_route("GET", "/client_error", client_error) + app.router.add_route("POST", "/echo-no-nego", echo_no_nego) return app @@ -89,58 +89,57 @@ assert async_app is not None cli = await aiohttp_client(async_app) - resp = await cli.get('/') + resp = await cli.get("/") assert resp.status == 200 - check_mimetype(resp.headers['Content-Type'], 'application/x-msgpack') + check_mimetype(resp.headers["Content-Type"], "application/x-msgpack") data = await resp.read() value = msgpack.unpackb(data, raw=False) - assert value == 'toor' + assert value == "toor" async def test_get_server_exception(async_app, aiohttp_client) -> None: cli = await aiohttp_client(async_app) - resp = await cli.get('/server_exception') + resp = await cli.get("/server_exception") assert resp.status == 500 data = await resp.read() data = msgpack.unpackb(data, raw=False) - assert data['exception']['type'] == 'TestServerException' + assert data["exception"]["type"] == "TestServerException" async def test_get_client_error(async_app, aiohttp_client) -> None: cli = await aiohttp_client(async_app) - resp = await cli.get('/client_error') + resp = await cli.get("/client_error") assert resp.status == 400 data = await resp.read() data = msgpack.unpackb(data, raw=False) - assert data['exception']['type'] == 'TestClientError' + assert data["exception"]["type"] == "TestClientError" async def test_get_simple_nego(async_app, aiohttp_client) -> None: cli = await aiohttp_client(async_app) - for ctype in ('x-msgpack', 'json'): - resp = await cli.get('/', headers={'Accept': 'application/%s' % ctype}) + for ctype in ("x-msgpack", "json"): + resp = await cli.get("/", headers={"Accept": "application/%s" % ctype}) assert resp.status == 200 - check_mimetype(resp.headers['Content-Type'], 'application/%s' % ctype) - assert (await decode_request(resp)) == 'toor' + check_mimetype(resp.headers["Content-Type"], "application/%s" % ctype) + assert (await decode_request(resp)) == "toor" async def test_get_struct(async_app, aiohttp_client) -> None: """Test returned structured from a simple GET data is OK""" cli = await aiohttp_client(async_app) - resp = await cli.get('/struct') + resp = await cli.get("/struct") assert resp.status == 200 - check_mimetype(resp.headers['Content-Type'], 'application/x-msgpack') + check_mimetype(resp.headers["Content-Type"], "application/x-msgpack") assert (await decode_request(resp)) == STRUCT async def test_get_struct_nego(async_app, aiohttp_client) -> None: """Test returned structured from a simple GET data is OK""" cli = await aiohttp_client(async_app) - for ctype in ('x-msgpack', 'json'): - resp = await cli.get('/struct', - headers={'Accept': 'application/%s' % ctype}) + for ctype in ("x-msgpack", "json"): + resp = await cli.get("/struct", headers={"Accept": "application/%s" % ctype}) assert resp.status == 200 - check_mimetype(resp.headers['Content-Type'], 'application/%s' % ctype) + check_mimetype(resp.headers["Content-Type"], "application/%s" % ctype) assert (await decode_request(resp)) == STRUCT @@ -149,19 +148,21 @@ cli = await aiohttp_client(async_app) # simple struct resp = await cli.post( - '/echo', - headers={'Content-Type': 'application/x-msgpack'}, - data=msgpack_dumps({'toto': 42})) + "/echo", + headers={"Content-Type": "application/x-msgpack"}, + data=msgpack_dumps({"toto": 42}), + ) assert resp.status == 200 - check_mimetype(resp.headers['Content-Type'], 'application/x-msgpack') - assert (await decode_request(resp)) == {'toto': 42} + check_mimetype(resp.headers["Content-Type"], "application/x-msgpack") + assert (await decode_request(resp)) == {"toto": 42} # complex struct resp = await cli.post( - '/echo', - headers={'Content-Type': 'application/x-msgpack'}, - data=msgpack_dumps(STRUCT)) + "/echo", + headers={"Content-Type": "application/x-msgpack"}, + data=msgpack_dumps(STRUCT), + ) assert resp.status == 200 - check_mimetype(resp.headers['Content-Type'], 'application/x-msgpack') + check_mimetype(resp.headers["Content-Type"], "application/x-msgpack") assert (await decode_request(resp)) == STRUCT @@ -170,19 +171,21 @@ cli = await aiohttp_client(async_app) resp = await cli.post( - '/echo', - headers={'Content-Type': 'application/json'}, - data=json.dumps({'toto': 42}, cls=SWHJSONEncoder)) + "/echo", + headers={"Content-Type": "application/json"}, + data=json.dumps({"toto": 42}, cls=SWHJSONEncoder), + ) assert resp.status == 200 - check_mimetype(resp.headers['Content-Type'], 'application/x-msgpack') - assert (await decode_request(resp)) == {'toto': 42} + check_mimetype(resp.headers["Content-Type"], "application/x-msgpack") + assert (await decode_request(resp)) == {"toto": 42} resp = await cli.post( - '/echo', - headers={'Content-Type': 'application/json'}, - data=json.dumps(STRUCT, cls=SWHJSONEncoder)) + "/echo", + headers={"Content-Type": "application/json"}, + data=json.dumps(STRUCT, cls=SWHJSONEncoder), + ) assert resp.status == 200 - check_mimetype(resp.headers['Content-Type'], 'application/x-msgpack') + check_mimetype(resp.headers["Content-Type"], "application/x-msgpack") # assert resp.headers['Content-Type'] == 'application/x-msgpack' assert (await decode_request(resp)) == STRUCT @@ -194,14 +197,17 @@ """ cli = await aiohttp_client(async_app) - for ctype in ('x-msgpack', 'json'): + for ctype in ("x-msgpack", "json"): resp = await cli.post( - '/echo', - headers={'Content-Type': 'application/json', - 'Accept': 'application/%s' % ctype}, - data=json.dumps(STRUCT, cls=SWHJSONEncoder)) + "/echo", + headers={ + "Content-Type": "application/json", + "Accept": "application/%s" % ctype, + }, + data=json.dumps(STRUCT, cls=SWHJSONEncoder), + ) assert resp.status == 200 - check_mimetype(resp.headers['Content-Type'], 'application/%s' % ctype) + check_mimetype(resp.headers["Content-Type"], "application/%s" % ctype) assert (await decode_request(resp)) == STRUCT @@ -212,12 +218,15 @@ """ cli = await aiohttp_client(async_app) - for ctype in ('x-msgpack', 'json'): + for ctype in ("x-msgpack", "json"): resp = await cli.post( - '/echo-no-nego', - headers={'Content-Type': 'application/json', - 'Accept': 'application/%s' % ctype}, - data=json.dumps(STRUCT, cls=SWHJSONEncoder)) + "/echo-no-nego", + headers={ + "Content-Type": "application/json", + "Accept": "application/%s" % ctype, + }, + data=json.dumps(STRUCT, cls=SWHJSONEncoder), + ) assert resp.status == 200 - check_mimetype(resp.headers['Content-Type'], 'application/x-msgpack') + check_mimetype(resp.headers["Content-Type"], "application/x-msgpack") assert (await decode_request(resp)) == STRUCT diff --git a/swh/core/api/tests/test_gunicorn.py b/swh/core/api/tests/test_gunicorn.py --- a/swh/core/api/tests/test_gunicorn.py +++ b/swh/core/api/tests/test_gunicorn.py @@ -11,7 +11,7 @@ def test_post_fork_default(): - with patch('sentry_sdk.init') as sentry_sdk_init: + with patch("sentry_sdk.init") as sentry_sdk_init: gunicorn_config.post_fork(None, None) sentry_sdk_init.assert_not_called() @@ -19,14 +19,15 @@ def test_post_fork_with_dsn_env(): flask_integration = object() # unique object to check for equality - with patch('sentry_sdk.integrations.flask.FlaskIntegration', - new=lambda: flask_integration): - with patch('sentry_sdk.init') as sentry_sdk_init: - with patch.dict(os.environ, {'SWH_SENTRY_DSN': 'test_dsn'}): + with patch( + "sentry_sdk.integrations.flask.FlaskIntegration", new=lambda: flask_integration + ): + with patch("sentry_sdk.init") as sentry_sdk_init: + with patch.dict(os.environ, {"SWH_SENTRY_DSN": "test_dsn"}): gunicorn_config.post_fork(None, None) sentry_sdk_init.assert_called_once_with( - dsn='test_dsn', + dsn="test_dsn", integrations=[flask_integration], debug=False, release=None, @@ -36,36 +37,44 @@ def test_post_fork_with_package_env(): flask_integration = object() # unique object to check for equality - with patch('sentry_sdk.integrations.flask.FlaskIntegration', - new=lambda: flask_integration): - with patch('sentry_sdk.init') as sentry_sdk_init: - with patch.dict(os.environ, {'SWH_SENTRY_DSN': 'test_dsn', - 'SWH_SENTRY_ENVIRONMENT': 'tests', - 'SWH_MAIN_PACKAGE': 'swh.core'}): + with patch( + "sentry_sdk.integrations.flask.FlaskIntegration", new=lambda: flask_integration + ): + with patch("sentry_sdk.init") as sentry_sdk_init: + with patch.dict( + os.environ, + { + "SWH_SENTRY_DSN": "test_dsn", + "SWH_SENTRY_ENVIRONMENT": "tests", + "SWH_MAIN_PACKAGE": "swh.core", + }, + ): gunicorn_config.post_fork(None, None) - version = pkg_resources.get_distribution('swh.core').version + version = pkg_resources.get_distribution("swh.core").version sentry_sdk_init.assert_called_once_with( - dsn='test_dsn', + dsn="test_dsn", integrations=[flask_integration], debug=False, - release='swh.core@' + version, - environment='tests', + release="swh.core@" + version, + environment="tests", ) def test_post_fork_debug(): flask_integration = object() # unique object to check for equality - with patch('sentry_sdk.integrations.flask.FlaskIntegration', - new=lambda: flask_integration): - with patch('sentry_sdk.init') as sentry_sdk_init: - with patch.dict(os.environ, {'SWH_SENTRY_DSN': 'test_dsn', - 'SWH_SENTRY_DEBUG': '1'}): + with patch( + "sentry_sdk.integrations.flask.FlaskIntegration", new=lambda: flask_integration + ): + with patch("sentry_sdk.init") as sentry_sdk_init: + with patch.dict( + os.environ, {"SWH_SENTRY_DSN": "test_dsn", "SWH_SENTRY_DEBUG": "1"} + ): gunicorn_config.post_fork(None, None) sentry_sdk_init.assert_called_once_with( - dsn='test_dsn', + dsn="test_dsn", integrations=[flask_integration], debug=True, release=None, @@ -74,34 +83,34 @@ def test_post_fork_no_flask(): - with patch('sentry_sdk.init') as sentry_sdk_init: - with patch.dict(os.environ, {'SWH_SENTRY_DSN': 'test_dsn'}): + with patch("sentry_sdk.init") as sentry_sdk_init: + with patch.dict(os.environ, {"SWH_SENTRY_DSN": "test_dsn"}): gunicorn_config.post_fork(None, None, flask=False) sentry_sdk_init.assert_called_once_with( - dsn='test_dsn', - integrations=[], - debug=False, - release=None, - environment=None, + dsn="test_dsn", integrations=[], debug=False, release=None, environment=None ) def test_post_fork_extras(): flask_integration = object() # unique object to check for equality - with patch('sentry_sdk.integrations.flask.FlaskIntegration', - new=lambda: flask_integration): - with patch('sentry_sdk.init') as sentry_sdk_init: - with patch.dict(os.environ, {'SWH_SENTRY_DSN': 'test_dsn'}): + with patch( + "sentry_sdk.integrations.flask.FlaskIntegration", new=lambda: flask_integration + ): + with patch("sentry_sdk.init") as sentry_sdk_init: + with patch.dict(os.environ, {"SWH_SENTRY_DSN": "test_dsn"}): gunicorn_config.post_fork( - None, None, sentry_integrations=['foo'], - extra_sentry_kwargs={'bar': 'baz'}) + None, + None, + sentry_integrations=["foo"], + extra_sentry_kwargs={"bar": "baz"}, + ) sentry_sdk_init.assert_called_once_with( - dsn='test_dsn', - integrations=['foo', flask_integration], + dsn="test_dsn", + integrations=["foo", flask_integration], debug=False, - bar='baz', + bar="baz", release=None, environment=None, ) diff --git a/swh/core/api/tests/test_rpc_client.py b/swh/core/api/tests/test_rpc_client.py --- a/swh/core/api/tests/test_rpc_client.py +++ b/swh/core/api/tests/test_rpc_client.py @@ -14,21 +14,21 @@ @pytest.fixture def rpc_client(requests_mock): class TestStorage: - @remote_api_endpoint('test_endpoint_url') + @remote_api_endpoint("test_endpoint_url") def test_endpoint(self, test_data, db=None, cur=None): ... - @remote_api_endpoint('path/to/endpoint') + @remote_api_endpoint("path/to/endpoint") def something(self, data, db=None, cur=None): ... - @remote_api_endpoint('serializer_test') + @remote_api_endpoint("serializer_test") def serializer_test(self, data, db=None, cur=None): ... - @remote_api_endpoint('overridden/endpoint') + @remote_api_endpoint("overridden/endpoint") def overridden_method(self, data): - return 'foo' + return "foo" class Testclient(RPCClient): backend_class = TestStorage @@ -36,50 +36,50 @@ extra_type_decoders = extra_decoders def overridden_method(self, data): - return 'bar' + return "bar" def callback(request, context): - assert request.headers['Content-Type'] == 'application/x-msgpack' - context.headers['Content-Type'] = 'application/x-msgpack' - if request.path == '/test_endpoint_url': - context.content = b'\xa3egg' - elif request.path == '/path/to/endpoint': - context.content = b'\xa4spam' - elif request.path == '/serializer_test': + assert request.headers["Content-Type"] == "application/x-msgpack" + context.headers["Content-Type"] = "application/x-msgpack" + if request.path == "/test_endpoint_url": + context.content = b"\xa3egg" + elif request.path == "/path/to/endpoint": + context.content = b"\xa4spam" + elif request.path == "/serializer_test": context.content = ( - b'\x82\xc4\x07swhtype\xa9extratype' - b'\xc4\x01d\x92\x81\xa4spam\xa3egg\xa3qux') + b"\x82\xc4\x07swhtype\xa9extratype" + b"\xc4\x01d\x92\x81\xa4spam\xa3egg\xa3qux" + ) else: assert False return context.content - requests_mock.post(re.compile('mock://example.com/'), - content=callback) + requests_mock.post(re.compile("mock://example.com/"), content=callback) - return Testclient(url='mock://example.com') + return Testclient(url="mock://example.com") def test_client(rpc_client): - assert hasattr(rpc_client, 'test_endpoint') - assert hasattr(rpc_client, 'something') + assert hasattr(rpc_client, "test_endpoint") + assert hasattr(rpc_client, "something") - res = rpc_client.test_endpoint('spam') - assert res == 'egg' - res = rpc_client.test_endpoint(test_data='spam') - assert res == 'egg' + res = rpc_client.test_endpoint("spam") + assert res == "egg" + res = rpc_client.test_endpoint(test_data="spam") + assert res == "egg" - res = rpc_client.something('whatever') - assert res == 'spam' - res = rpc_client.something(data='whatever') - assert res == 'spam' + res = rpc_client.something("whatever") + assert res == "spam" + res = rpc_client.something(data="whatever") + assert res == "spam" def test_client_extra_serializers(rpc_client): - res = rpc_client.serializer_test(['foo', ExtraType('bar', b'baz')]) - assert res == ExtraType({'spam': 'egg'}, 'qux') + res = rpc_client.serializer_test(["foo", ExtraType("bar", b"baz")]) + assert res == ExtraType({"spam": "egg"}, "qux") def test_client_overridden_method(rpc_client): - res = rpc_client.overridden_method('foo') - assert res == 'bar' + res = rpc_client.overridden_method("foo") + assert res == "bar" diff --git a/swh/core/api/tests/test_rpc_client_server.py b/swh/core/api/tests/test_rpc_client_server.py --- a/swh/core/api/tests/test_rpc_client_server.py +++ b/swh/core/api/tests/test_rpc_client_server.py @@ -11,18 +11,18 @@ # this class is used on the server part class RPCTest: - @remote_api_endpoint('endpoint_url') + @remote_api_endpoint("endpoint_url") def endpoint(self, test_data, db=None, cur=None): - assert test_data == 'spam' - return 'egg' + assert test_data == "spam" + return "egg" - @remote_api_endpoint('path/to/endpoint') + @remote_api_endpoint("path/to/endpoint") def something(self, data, db=None, cur=None): return data - @remote_api_endpoint('raises_typeerror') + @remote_api_endpoint("raises_typeerror") def raise_typeerror(self): - raise TypeError('Did I pass through?') + raise TypeError("Did I pass through?") # this class is used on the client part. We cannot inherit from RPCTest @@ -31,22 +31,22 @@ # We do add an endpoint on the client side that has no implementation # server-side to test this very situation (in should generate a 404) class RPCTest2: - @remote_api_endpoint('endpoint_url') + @remote_api_endpoint("endpoint_url") def endpoint(self, test_data, db=None, cur=None): - assert test_data == 'spam' - return 'egg' + assert test_data == "spam" + return "egg" - @remote_api_endpoint('path/to/endpoint') + @remote_api_endpoint("path/to/endpoint") def something(self, data, db=None, cur=None): return data - @remote_api_endpoint('not_on_server') + @remote_api_endpoint("not_on_server") def not_on_server(self, db=None, cur=None): - return 'ok' + return "ok" - @remote_api_endpoint('raises_typeerror') + @remote_api_endpoint("raises_typeerror") def raise_typeerror(self): - return 'data' + return "data" class RPCTestClient(RPCClient): @@ -57,10 +57,12 @@ def app(): # This fixture is used by the 'swh_rpc_adapter' fixture # which is defined in swh/core/pytest_plugin.py - application = RPCServerApp('testapp', backend_class=RPCTest) + application = RPCServerApp("testapp", backend_class=RPCTest) + @application.errorhandler(Exception) def my_error_handler(exception): return error_handler(exception, encode_data_server) + return application @@ -73,35 +75,37 @@ def test_api_client_endpoint_missing(swh_rpc_client): with pytest.raises(AttributeError): - swh_rpc_client.missing(data='whatever') + swh_rpc_client.missing(data="whatever") def test_api_server_endpoint_missing(swh_rpc_client): # A 'missing' endpoint (server-side) should raise an exception # due to a 404, since at the end, we do a GET/POST an inexistent URL - with pytest.raises(Exception, match='404 not found'): + with pytest.raises(Exception, match="404 not found"): swh_rpc_client.not_on_server() def test_api_endpoint_kwargs(swh_rpc_client): - res = swh_rpc_client.something(data='whatever') - assert res == 'whatever' - res = swh_rpc_client.endpoint(test_data='spam') - assert res == 'egg' + res = swh_rpc_client.something(data="whatever") + assert res == "whatever" + res = swh_rpc_client.endpoint(test_data="spam") + assert res == "egg" def test_api_endpoint_args(swh_rpc_client): - res = swh_rpc_client.something('whatever') - assert res == 'whatever' - res = swh_rpc_client.endpoint('spam') - assert res == 'egg' + res = swh_rpc_client.something("whatever") + assert res == "whatever" + res = swh_rpc_client.endpoint("spam") + assert res == "egg" def test_api_typeerror(swh_rpc_client): with pytest.raises(RemoteException) as exc_info: swh_rpc_client.raise_typeerror() - assert exc_info.value.args[0]['type'] == 'TypeError' - assert exc_info.value.args[0]['args'] == ['Did I pass through?'] - assert str(exc_info.value) \ + assert exc_info.value.args[0]["type"] == "TypeError" + assert exc_info.value.args[0]["args"] == ["Did I pass through?"] + assert ( + str(exc_info.value) == "" + ) diff --git a/swh/core/api/tests/test_rpc_server.py b/swh/core/api/tests/test_rpc_server.py --- a/swh/core/api/tests/test_rpc_server.py +++ b/swh/core/api/tests/test_rpc_server.py @@ -22,97 +22,101 @@ @pytest.fixture def app(): class TestStorage: - @remote_api_endpoint('test_endpoint_url') + @remote_api_endpoint("test_endpoint_url") def test_endpoint(self, test_data, db=None, cur=None): - assert test_data == 'spam' - return 'egg' + assert test_data == "spam" + return "egg" - @remote_api_endpoint('path/to/endpoint') + @remote_api_endpoint("path/to/endpoint") def something(self, data, db=None, cur=None): return data - @remote_api_endpoint('serializer_test') + @remote_api_endpoint("serializer_test") def serializer_test(self, data, db=None, cur=None): - assert data == ['foo', ExtraType('bar', b'baz')] - return ExtraType({'spam': 'egg'}, 'qux') + assert data == ["foo", ExtraType("bar", b"baz")] + return ExtraType({"spam": "egg"}, "qux") - return MyRPCServerApp('testapp', backend_class=TestStorage) + return MyRPCServerApp("testapp", backend_class=TestStorage) def test_api_endpoint(flask_app_client): res = flask_app_client.post( - url_for('something'), - headers=[('Content-Type', 'application/json'), - ('Accept', 'application/json')], - data=json.dumps({'data': 'toto'}), + url_for("something"), + headers=[("Content-Type", "application/json"), ("Accept", "application/json")], + data=json.dumps({"data": "toto"}), ) assert res.status_code == 200 - assert res.mimetype == 'application/json' + assert res.mimetype == "application/json" def test_api_nego_default(flask_app_client): res = flask_app_client.post( - url_for('something'), - headers=[('Content-Type', 'application/json')], - data=json.dumps({'data': 'toto'}), + url_for("something"), + headers=[("Content-Type", "application/json")], + data=json.dumps({"data": "toto"}), ) assert res.status_code == 200 - assert res.mimetype == 'application/json' + assert res.mimetype == "application/json" assert res.data == b'"toto"' def test_api_nego_accept(flask_app_client): res = flask_app_client.post( - url_for('something'), - headers=[('Accept', 'application/x-msgpack'), - ('Content-Type', 'application/x-msgpack')], - data=msgpack.dumps({'data': 'toto'}), + url_for("something"), + headers=[ + ("Accept", "application/x-msgpack"), + ("Content-Type", "application/x-msgpack"), + ], + data=msgpack.dumps({"data": "toto"}), ) assert res.status_code == 200 - assert res.mimetype == 'application/x-msgpack' - assert res.data == b'\xa4toto' + assert res.mimetype == "application/x-msgpack" + assert res.data == b"\xa4toto" def test_rpc_server(flask_app_client): res = flask_app_client.post( - url_for('test_endpoint'), - headers=[('Content-Type', 'application/x-msgpack'), - ('Accept', 'application/x-msgpack')], - data=b'\x81\xa9test_data\xa4spam') + url_for("test_endpoint"), + headers=[ + ("Content-Type", "application/x-msgpack"), + ("Accept", "application/x-msgpack"), + ], + data=b"\x81\xa9test_data\xa4spam", + ) assert res.status_code == 200 - assert res.mimetype == 'application/x-msgpack' - assert res.data == b'\xa3egg' + assert res.mimetype == "application/x-msgpack" + assert res.data == b"\xa3egg" def test_rpc_server_extra_serializers(flask_app_client): res = flask_app_client.post( - url_for('serializer_test'), - headers=[('Content-Type', 'application/x-msgpack'), - ('Accept', 'application/x-msgpack')], - data=b'\x81\xa4data\x92\xa3foo\x82\xc4\x07swhtype\xa9extratype' - b'\xc4\x01d\x92\xa3bar\xc4\x03baz') + url_for("serializer_test"), + headers=[ + ("Content-Type", "application/x-msgpack"), + ("Accept", "application/x-msgpack"), + ], + data=b"\x81\xa4data\x92\xa3foo\x82\xc4\x07swhtype\xa9extratype" + b"\xc4\x01d\x92\xa3bar\xc4\x03baz", + ) assert res.status_code == 200 - assert res.mimetype == 'application/x-msgpack' + assert res.mimetype == "application/x-msgpack" assert res.data == ( - b'\x82\xc4\x07swhtype\xa9extratype\xc4' - b'\x01d\x92\x81\xa4spam\xa3egg\xa3qux') + b"\x82\xc4\x07swhtype\xa9extratype\xc4" b"\x01d\x92\x81\xa4spam\xa3egg\xa3qux" + ) def test_api_negotiate_no_extra_encoders(app, flask_app_client): - url = '/test/negotiate/no/extra/encoders' + url = "/test/negotiate/no/extra/encoders" - @app.route(url, methods=['POST']) + @app.route(url, methods=["POST"]) @negotiate(MsgpackFormatter) @negotiate(JSONFormatter) def endpoint(): - return 'test' + return "test" - res = flask_app_client.post( - url, - headers=[('Content-Type', 'application/json')], - ) + res = flask_app_client.post(url, headers=[("Content-Type", "application/json")]) assert res.status_code == 200 - assert res.mimetype == 'application/json' + assert res.mimetype == "application/json" assert res.data == b'"test"' diff --git a/swh/core/api/tests/test_serializers.py b/swh/core/api/tests/test_serializers.py --- a/swh/core/api/tests/test_serializers.py +++ b/swh/core/api/tests/test_serializers.py @@ -18,7 +18,7 @@ SWHJSONEncoder, msgpack_dumps, msgpack_loads, - decode_response + decode_response, ) @@ -28,21 +28,21 @@ self.arg2 = arg2 def __repr__(self): - return f'ExtraType({self.arg1}, {self.arg2})' + return f"ExtraType({self.arg1}, {self.arg2})" def __eq__(self, other): - return isinstance(other, ExtraType) \ - and (self.arg1, self.arg2) == (other.arg1, other.arg2) + return isinstance(other, ExtraType) and (self.arg1, self.arg2) == ( + other.arg1, + other.arg2, + ) extra_encoders: List[Tuple[type, str, Callable[..., Any]]] = [ - (ExtraType, 'extratype', lambda o: (o.arg1, o.arg2)) + (ExtraType, "extratype", lambda o: (o.arg1, o.arg2)) ] -extra_decoders = { - 'extratype': lambda o: ExtraType(*o), -} +extra_decoders = {"extratype": lambda o: ExtraType(*o)} class Serializers(unittest.TestCase): @@ -50,68 +50,75 @@ self.tz = datetime.timezone(datetime.timedelta(minutes=118)) self.data = { - 'bytes': b'123456789\x99\xaf\xff\x00\x12', - 'datetime_naive': datetime.datetime(2015, 1, 1, 12, 4, 42, 231455), - 'datetime_tz': datetime.datetime(2015, 3, 4, 18, 25, 13, 1234, - tzinfo=self.tz), - 'datetime_utc': datetime.datetime(2015, 3, 4, 18, 25, 13, 1234, - tzinfo=datetime.timezone.utc), - 'datetime_delta': datetime.timedelta(64), - 'arrow_date': arrow.get('2018-04-25T16:17:53.533672+00:00'), - 'swhtype': 'fake', - 'swh_dict': {'swhtype': 42, 'd': 'test'}, - 'random_dict': {'swhtype': 43}, - 'uuid': UUID('cdd8f804-9db6-40c3-93ab-5955d3836234'), + "bytes": b"123456789\x99\xaf\xff\x00\x12", + "datetime_naive": datetime.datetime(2015, 1, 1, 12, 4, 42, 231_455), + "datetime_tz": datetime.datetime( + 2015, 3, 4, 18, 25, 13, 1234, tzinfo=self.tz + ), + "datetime_utc": datetime.datetime( + 2015, 3, 4, 18, 25, 13, 1234, tzinfo=datetime.timezone.utc + ), + "datetime_delta": datetime.timedelta(64), + "arrow_date": arrow.get("2018-04-25T16:17:53.533672+00:00"), + "swhtype": "fake", + "swh_dict": {"swhtype": 42, "d": "test"}, + "random_dict": {"swhtype": 43}, + "uuid": UUID("cdd8f804-9db6-40c3-93ab-5955d3836234"), } self.encoded_data = { - 'bytes': {'swhtype': 'bytes', 'd': 'F)}kWH8wXmIhn8j01^'}, - 'datetime_naive': {'swhtype': 'datetime', - 'd': '2015-01-01T12:04:42.231455'}, - 'datetime_tz': {'swhtype': 'datetime', - 'd': '2015-03-04T18:25:13.001234+01:58'}, - 'datetime_utc': {'swhtype': 'datetime', - 'd': '2015-03-04T18:25:13.001234+00:00'}, - 'datetime_delta': {'swhtype': 'timedelta', - 'd': {'days': 64, 'seconds': 0, - 'microseconds': 0}}, - 'arrow_date': {'swhtype': 'arrow', - 'd': '2018-04-25T16:17:53.533672+00:00'}, - 'swhtype': 'fake', - 'swh_dict': {'swhtype': 42, 'd': 'test'}, - 'random_dict': {'swhtype': 43}, - 'uuid': {'swhtype': 'uuid', - 'd': 'cdd8f804-9db6-40c3-93ab-5955d3836234'}, + "bytes": {"swhtype": "bytes", "d": "F)}kWH8wXmIhn8j01^"}, + "datetime_naive": { + "swhtype": "datetime", + "d": "2015-01-01T12:04:42.231455", + }, + "datetime_tz": { + "swhtype": "datetime", + "d": "2015-03-04T18:25:13.001234+01:58", + }, + "datetime_utc": { + "swhtype": "datetime", + "d": "2015-03-04T18:25:13.001234+00:00", + }, + "datetime_delta": { + "swhtype": "timedelta", + "d": {"days": 64, "seconds": 0, "microseconds": 0}, + }, + "arrow_date": {"swhtype": "arrow", "d": "2018-04-25T16:17:53.533672+00:00"}, + "swhtype": "fake", + "swh_dict": {"swhtype": 42, "d": "test"}, + "random_dict": {"swhtype": 43}, + "uuid": {"swhtype": "uuid", "d": "cdd8f804-9db6-40c3-93ab-5955d3836234"}, } self.legacy_msgpack = { - 'bytes': b'\xc4\x0e123456789\x99\xaf\xff\x00\x12', - 'datetime_naive': ( - b'\x82\xc4\x0c__datetime__\xc3\xc4\x01s\xba' - b'2015-01-01T12:04:42.231455' + "bytes": b"\xc4\x0e123456789\x99\xaf\xff\x00\x12", + "datetime_naive": ( + b"\x82\xc4\x0c__datetime__\xc3\xc4\x01s\xba" + b"2015-01-01T12:04:42.231455" ), - 'datetime_tz': ( - b'\x82\xc4\x0c__datetime__\xc3\xc4\x01s\xd9 ' - b'2015-03-04T18:25:13.001234+01:58' + "datetime_tz": ( + b"\x82\xc4\x0c__datetime__\xc3\xc4\x01s\xd9 " + b"2015-03-04T18:25:13.001234+01:58" ), - 'datetime_utc': ( - b'\x82\xc4\x0c__datetime__\xc3\xc4\x01s\xd9 ' - b'2015-03-04T18:25:13.001234+00:00' + "datetime_utc": ( + b"\x82\xc4\x0c__datetime__\xc3\xc4\x01s\xd9 " + b"2015-03-04T18:25:13.001234+00:00" ), - 'datetime_delta': ( - b'\x82\xc4\r__timedelta__\xc3\xc4\x01s\x83\xa4' - b'days@\xa7seconds\x00\xacmicroseconds\x00' + "datetime_delta": ( + b"\x82\xc4\r__timedelta__\xc3\xc4\x01s\x83\xa4" + b"days@\xa7seconds\x00\xacmicroseconds\x00" ), - 'arrow_date': ( - b'\x82\xc4\t__arrow__\xc3\xc4\x01s\xd9 ' - b'2018-04-25T16:17:53.533672+00:00' + "arrow_date": ( + b"\x82\xc4\t__arrow__\xc3\xc4\x01s\xd9 " + b"2018-04-25T16:17:53.533672+00:00" ), - 'swhtype': b'\xa4fake', - 'swh_dict': b'\x82\xa7swhtype*\xa1d\xa4test', - 'random_dict': b'\x81\xa7swhtype+', - 'uuid': ( - b'\x82\xc4\x08__uuid__\xc3\xc4\x01s\xd9$' - b'cdd8f804-9db6-40c3-93ab-5955d3836234' + "swhtype": b"\xa4fake", + "swh_dict": b"\x82\xa7swhtype*\xa1d\xa4test", + "random_dict": b"\x81\xa7swhtype+", + "uuid": ( + b"\x82\xc4\x08__uuid__\xc3\xc4\x01s\xd9$" + b"cdd8f804-9db6-40c3-93ab-5955d3836234" ), } @@ -123,33 +130,32 @@ self.assertEqual(self.data, json.loads(data, cls=SWHJSONDecoder)) def test_round_trip_json_extra_types(self): - original_data = [ExtraType('baz', self.data), 'qux'] + original_data = [ExtraType("baz", self.data), "qux"] - data = json.dumps(original_data, cls=SWHJSONEncoder, - extra_encoders=extra_encoders) + data = json.dumps( + original_data, cls=SWHJSONEncoder, extra_encoders=extra_encoders + ) self.assertEqual( original_data, - json.loads( - data, cls=SWHJSONDecoder, extra_decoders=extra_decoders)) + json.loads(data, cls=SWHJSONDecoder, extra_decoders=extra_decoders), + ) def test_encode_swh_json(self): data = json.dumps(self.data, cls=SWHJSONEncoder) self.assertEqual(self.encoded_data, json.loads(data)) def test_round_trip_msgpack(self): - original_data = { - **self.data, - 'none_dict_key': {None: 42}, - } + original_data = {**self.data, "none_dict_key": {None: 42}} data = msgpack_dumps(original_data) self.assertEqual(original_data, msgpack_loads(data)) def test_round_trip_msgpack_extra_types(self): - original_data = [ExtraType('baz', self.data), 'qux'] + original_data = [ExtraType("baz", self.data), "qux"] data = msgpack_dumps(original_data, extra_encoders=extra_encoders) self.assertEqual( - original_data, msgpack_loads(data, extra_decoders=extra_decoders)) + original_data, msgpack_loads(data, extra_decoders=extra_decoders) + ) def test_generator_json(self): data = json.dumps(self.generator, cls=SWHJSONEncoder) @@ -161,10 +167,12 @@ @requests_mock.Mocker() def test_decode_response_json(self, mock_requests): - mock_requests.get('https://example.org/test/data', - json=self.encoded_data, - headers={'content-type': 'application/json'}) - response = requests.get('https://example.org/test/data') + mock_requests.get( + "https://example.org/test/data", + json=self.encoded_data, + headers={"content-type": "application/json"}, + ) + response = requests.get("https://example.org/test/data") assert decode_response(response) == self.data def test_decode_legacy_msgpack(self): diff --git a/swh/core/cli/__init__.py b/swh/core/cli/__init__.py --- a/swh/core/cli/__init__.py +++ b/swh/core/cli/__init__.py @@ -13,19 +13,19 @@ from ..sentry import init_sentry -LOG_LEVEL_NAMES = ['NOTSET', 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'] +LOG_LEVEL_NAMES = ["NOTSET", "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] -CONTEXT_SETTINGS = dict(help_option_names=['-h', '--help']) +CONTEXT_SETTINGS = dict(help_option_names=["-h", "--help"]) logger = logging.getLogger(__name__) class AliasedGroup(click.Group): - '''A simple Group that supports command aliases, as well as notes related to - options''' + """A simple Group that supports command aliases, as well as notes related to + options""" def __init__(self, name=None, commands=None, **attrs): - self.option_notes = attrs.pop('option_notes', None) + self.option_notes = attrs.pop("option_notes", None) self.aliases = {} super().__init__(name, commands, **attrs) @@ -40,7 +40,7 @@ def format_options(self, ctx, formatter): click.Command.format_options(self, ctx, formatter) if self.option_notes: - with formatter.section('Notes'): + with formatter.section("Notes"): formatter.write_text(self.option_notes) self.format_commands(ctx, formatter) @@ -52,26 +52,38 @@ @click.group( - context_settings=CONTEXT_SETTINGS, cls=AliasedGroup, - option_notes='''\ + context_settings=CONTEXT_SETTINGS, + cls=AliasedGroup, + option_notes="""\ If both options are present, --log-level will override the root logger configuration set in --log-config. The --log-config YAML must conform to the logging.config.dictConfig schema documented at https://docs.python.org/3/library/logging.config.html. -''' +""", +) +@click.option( + "--log-level", + "-l", + default=None, + type=click.Choice(LOG_LEVEL_NAMES), + help="Log level (defaults to INFO).", +) +@click.option( + "--log-config", + default=None, + type=click.File("r"), + help="Python yaml logging configuration file.", +) +@click.option( + "--sentry-dsn", default=None, help="DSN of the Sentry instance to report to" +) +@click.option( + "--sentry-debug/--no-sentry-debug", + default=False, + hidden=True, + help="Enable debugging of sentry", ) -@click.option('--log-level', '-l', default=None, - type=click.Choice(LOG_LEVEL_NAMES), - help="Log level (defaults to INFO).") -@click.option('--log-config', default=None, - type=click.File('r'), - help="Python yaml logging configuration file.") -@click.option('--sentry-dsn', default=None, - help="DSN of the Sentry instance to report to") -@click.option('--sentry-debug/--no-sentry-debug', - default=False, hidden=True, - help="Enable debugging of sentry") @click.pass_context def swh(ctx, log_level, log_config, sentry_dsn, sentry_debug): """Command line interface for Software Heritage. @@ -82,7 +94,7 @@ init_sentry(sentry_dsn, debug=sentry_debug) if log_level is None and log_config is None: - log_level = 'INFO' + log_level = "INFO" if log_config: logging.config.dictConfig(yaml.safe_load(log_config.read())) @@ -92,7 +104,7 @@ logging.root.setLevel(log_level) ctx.ensure_object(dict) - ctx.obj['log_level'] = log_level + ctx.obj["log_level"] = log_level def main(): @@ -100,16 +112,15 @@ # for the next few logging statements logging.basicConfig() # load plugins that define cli sub commands - for entry_point in pkg_resources.iter_entry_points('swh.cli.subcommands'): + for entry_point in pkg_resources.iter_entry_points("swh.cli.subcommands"): try: cmd = entry_point.load() swh.add_command(cmd, name=entry_point.name) except Exception as e: - logger.warning('Could not load subcommand %s: %s', - entry_point.name, str(e)) + logger.warning("Could not load subcommand %s: %s", entry_point.name, str(e)) - return swh(auto_envvar_prefix='SWH') + return swh(auto_envvar_prefix="SWH") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/swh/core/cli/db.py b/swh/core/cli/db.py --- a/swh/core/cli/db.py +++ b/swh/core/cli/db.py @@ -21,16 +21,20 @@ @click.group(name="db", context_settings=CONTEXT_SETTINGS) -@click.option("--config-file", "-C", default=None, - type=click.Path(exists=True, dir_okay=False), - help="Configuration file.") +@click.option( + "--config-file", + "-C", + default=None, + type=click.Path(exists=True, dir_okay=False), + help="Configuration file.", +) @click.pass_context def db(ctx, config_file): """Software Heritage database generic tools. """ ctx.ensure_object(dict) if config_file is None: - config_file = environ.get('SWH_CONFIG_FILENAME') + config_file = environ.get("SWH_CONFIG_FILENAME") cfg = config_read(config_file) ctx.obj["config"] = cfg @@ -74,8 +78,8 @@ sqlfiles = get_sql_for_package(modname) except click.BadParameter: logger.info( - "Failed to load/find sql initialization files for %s", - modname) + "Failed to load/find sql initialization files for %s", modname + ) if sqlfiles: conninfo = cfg["args"]["db"] @@ -96,12 +100,20 @@ @click.command(context_settings=CONTEXT_SETTINGS) -@click.argument('module', nargs=-1, required=True) -@click.option('--db-name', '-d', help='Database name.', - default='softwareheritage-dev', show_default=True) -@click.option('--create-db/--no-create-db', '-C', - help='Attempt to create the database.', - default=False) +@click.argument("module", nargs=-1, required=True) +@click.option( + "--db-name", + "-d", + help="Database name.", + default="softwareheritage-dev", + show_default=True, +) +@click.option( + "--create-db/--no-create-db", + "-C", + help="Attempt to create the database.", + default=False, +) def db_init(module, db_name, create_db): """Initialise a database for the Software Heritage . By default, does not attempt to create the database. @@ -122,11 +134,13 @@ # put import statements here so we can keep startup time of the main swh # command as short as possible from swh.core.db.tests.db_testing import ( - pg_createdb, pg_restore, DB_DUMP_TYPES, - swh_db_version + pg_createdb, + pg_restore, + DB_DUMP_TYPES, + swh_db_version, ) - logger.debug('db_init %s dn_name=%s', module, db_name) + logger.debug("db_init %s dn_name=%s", module, db_name) dump_files = [] for modname in module: @@ -138,10 +152,9 @@ # Try to retrieve the db version if any db_version = swh_db_version(db_name) if not db_version: # Initialize the db - dump_files = [(x, DB_DUMP_TYPES[path.splitext(x)[1]]) - for x in dump_files] + dump_files = [(x, DB_DUMP_TYPES[path.splitext(x)[1]]) for x in dump_files] for dump, dtype in dump_files: - click.secho('Loading {}'.format(dump), fg='yellow') + click.secho("Loading {}".format(dump), fg="yellow") pg_restore(db_name, dump, dtype) db_version = swh_db_version(db_name) @@ -149,8 +162,11 @@ # TODO: Ideally migrate the version from db_version to the latest # db version - click.secho('DONE database is {} version {}'.format(db_name, db_version), - fg='green', bold=True) + click.secho( + "DONE database is {} version {}".format(db_name, db_version), + fg="green", + bold=True, + ) def get_sql_for_package(modname): @@ -167,6 +183,6 @@ sqldir = path.join(path.dirname(m.__file__), "sql") if not path.isdir(sqldir): raise click.BadParameter( - "Module {} does not provide a db schema " - "(no sql/ dir)".format(modname)) + "Module {} does not provide a db schema " "(no sql/ dir)".format(modname) + ) return list(sorted(glob.glob(path.join(sqldir, "*.sql")), key=sortkey)) diff --git a/swh/core/config.py b/swh/core/config.py --- a/swh/core/config.py +++ b/swh/core/config.py @@ -16,39 +16,30 @@ logger = logging.getLogger(__name__) -SWH_CONFIG_DIRECTORIES = [ - '~/.config/swh', - '~/.swh', - '/etc/softwareheritage', -] +SWH_CONFIG_DIRECTORIES = ["~/.config/swh", "~/.swh", "/etc/softwareheritage"] -SWH_GLOBAL_CONFIG = 'global.ini' +SWH_GLOBAL_CONFIG = "global.ini" SWH_DEFAULT_GLOBAL_CONFIG = { - 'max_content_size': ('int', 100 * 1024 * 1024), - 'log_db': ('str', 'dbname=softwareheritage-log'), + "max_content_size": ("int", 100 * 1024 * 1024), + "log_db": ("str", "dbname=softwareheritage-log"), } -SWH_CONFIG_EXTENSIONS = [ - '.yml', - '.ini', -] +SWH_CONFIG_EXTENSIONS = [".yml", ".ini"] # conversion per type _map_convert_fn = { - 'int': int, - 'bool': lambda x: x.lower() == 'true', - 'list[str]': lambda x: [value.strip() for value in x.split(',')], - 'list[int]': lambda x: [int(value.strip()) for value in x.split(',')], + "int": int, + "bool": lambda x: x.lower() == "true", + "list[str]": lambda x: [value.strip() for value in x.split(",")], + "list[int]": lambda x: [int(value.strip()) for value in x.split(",")], } _map_check_fn = { - 'int': lambda x: isinstance(x, int), - 'bool': lambda x: isinstance(x, bool), - 'list[str]': lambda x: (isinstance(x, list) and - all(isinstance(y, str) for y in x)), - 'list[int]': lambda x: (isinstance(x, list) and - all(isinstance(y, int) for y in x)), + "int": lambda x: isinstance(x, int), + "bool": lambda x: isinstance(x, bool), + "list[str]": lambda x: (isinstance(x, list) and all(isinstance(y, str) for y in x)), + "list[int]": lambda x: (isinstance(x, list) and all(isinstance(y, int) for y in x)), } @@ -78,7 +69,7 @@ def config_basepath(config_path): """Return the base path of a configuration file""" - if config_path.endswith(('.ini', '.yml')): + if config_path.endswith((".ini", ".yml")): return config_path[:-4] return config_path @@ -89,22 +80,21 @@ Can read yml or ini files. """ - yml_file = base_config_path + '.yml' + yml_file = base_config_path + ".yml" if exists_accessible(yml_file): - logger.info('Loading config file %s', yml_file) + logger.info("Loading config file %s", yml_file) with open(yml_file) as f: return yaml.safe_load(f) - ini_file = base_config_path + '.ini' + ini_file = base_config_path + ".ini" if exists_accessible(ini_file): config = configparser.ConfigParser() config.read(ini_file) - if 'main' in config._sections: - logger.info('Loading config file %s', ini_file) - return config._sections['main'] + if "main" in config._sections: + logger.info("Loading config file %s", ini_file) + return config._sections["main"] else: - logger.warning('Ignoring config file %s (no [main] section)', - ini_file) + logger.warning("Ignoring config file %s (no [main] section)", ini_file) return {} @@ -112,8 +102,9 @@ def config_exists(config_path): """Check whether the given config exists""" basepath = config_basepath(config_path) - return any(exists_accessible(basepath + extension) - for extension in SWH_CONFIG_EXTENSIONS) + return any( + exists_accessible(basepath + extension) for extension in SWH_CONFIG_EXTENSIONS + ) def read(conf_file=None, default_conf=None): @@ -233,8 +224,7 @@ Note that no type checking is done for anything but dicts. """ if not isinstance(base, dict) or not isinstance(other, dict): - raise TypeError( - 'Cannot merge a %s with a %s' % (type(base), type(other))) + raise TypeError("Cannot merge a %s with a %s" % (type(base), type(other))) output = {} allkeys = set(chain(base.keys(), other.keys())) @@ -258,13 +248,13 @@ """Return the Software Heritage specific configuration paths for the given filename.""" - return [os.path.join(dirname, base_filename) - for dirname in SWH_CONFIG_DIRECTORIES] + return [os.path.join(dirname, base_filename) for dirname in SWH_CONFIG_DIRECTORIES] def prepare_folders(conf, *keys): """Prepare the folder mentioned in config under keys. """ + def makedir(folder): if not os.path.exists(folder): os.makedirs(folder) @@ -276,10 +266,7 @@ def load_global_config(): """Load the global Software Heritage config""" - return priority_read( - swh_config_paths(SWH_GLOBAL_CONFIG), - SWH_DEFAULT_GLOBAL_CONFIG, - ) + return priority_read(swh_config_paths(SWH_GLOBAL_CONFIG), SWH_DEFAULT_GLOBAL_CONFIG) def load_named_config(name, default_conf=None, global_conf=True): @@ -314,11 +301,16 @@ """ DEFAULT_CONFIG = {} # type: Dict[str, Tuple[str, Any]] - CONFIG_BASE_FILENAME = '' # type: Optional[str] + CONFIG_BASE_FILENAME = "" # type: Optional[str] @classmethod - def parse_config_file(cls, base_filename=None, config_filename=None, - additional_configs=None, global_config=True): + def parse_config_file( + cls, + base_filename=None, + config_filename=None, + additional_configs=None, + global_config=True, + ): """Parse the configuration file associated to the current class. By default, parse_config_file will load the configuration @@ -341,8 +333,8 @@ if config_filename: config_filenames = [config_filename] - elif 'SWH_CONFIG_FILENAME' in os.environ: - config_filenames = [os.environ['SWH_CONFIG_FILENAME']] + elif "SWH_CONFIG_FILENAME" in os.environ: + config_filenames = [os.environ["SWH_CONFIG_FILENAME"]] else: if not base_filename: base_filename = cls.CONFIG_BASE_FILENAME @@ -350,8 +342,9 @@ if not additional_configs: additional_configs = [] - full_default_config = merge_default_configs(cls.DEFAULT_CONFIG, - *additional_configs) + full_default_config = merge_default_configs( + cls.DEFAULT_CONFIG, *additional_configs + ) config = {} if global_config: diff --git a/swh/core/db/__init__.py b/swh/core/db/__init__.py --- a/swh/core/db/__init__.py +++ b/swh/core/db/__init__.py @@ -26,9 +26,9 @@ def escape(data): if data is None: - return '' + return "" if isinstance(data, bytes): - return '\\x%s' % binascii.hexlify(data).decode('ascii') + return "\\x%s" % binascii.hexlify(data).decode("ascii") elif isinstance(data, str): return '"%s"' % data.replace('"', '""') elif isinstance(data, datetime.datetime): @@ -38,16 +38,17 @@ elif isinstance(data, dict): return escape(json.dumps(data)) elif isinstance(data, list): - return escape("{%s}" % ','.join(escape(d) for d in data)) + return escape("{%s}" % ",".join(escape(d) for d in data)) elif isinstance(data, psycopg2.extras.Range): # We escape twice here too, so that we make sure # everything gets passed to copy properly return escape( - '%s%s,%s%s' % ( - '[' if data.lower_inc else '(', - '-infinity' if data.lower_inf else escape(data.lower), - 'infinity' if data.upper_inf else escape(data.upper), - ']' if data.upper_inc else ')', + "%s%s,%s%s" + % ( + "[" if data.lower_inc else "(", + "-infinity" if data.lower_inf else escape(data.lower), + "infinity" if data.upper_inf else escape(data.upper), + "]" if data.upper_inc else ")", ) ) elif isinstance(data, enum.IntEnum): @@ -74,12 +75,10 @@ def adapt_conn(cls, conn): """Makes psycopg2 use 'bytes' to decode bytea instead of 'memoryview', for this connection.""" - t_bytes = psycopg2.extensions.new_type( - (17,), "bytea", typecast_bytea) + t_bytes = psycopg2.extensions.new_type((17,), "bytea", typecast_bytea) psycopg2.extensions.register_type(t_bytes, conn) - t_bytes_array = psycopg2.extensions.new_array_type( - (1001,), "bytea[]", t_bytes) + t_bytes_array = psycopg2.extensions.new_array_type((1001,), "bytea[]", t_bytes) psycopg2.extensions.register_type(t_bytes_array, conn) @classmethod @@ -128,6 +127,7 @@ return cur_arg else: return self.conn.cursor() + _cursor = cursor # for bw compat @contextmanager @@ -147,8 +147,9 @@ self.conn.rollback() raise - def copy_to(self, items, tblname, columns, - cur=None, item_cb=None, default_values={}): + def copy_to( + self, items, tblname, columns, cur=None, item_cb=None, default_values={} + ): """Copy items' entries to table tblname with columns information. Args: @@ -168,10 +169,11 @@ def writer(): nonlocal exc_info cursor = self.cursor(cur) - with open(read_file, 'r') as f: + with open(read_file, "r") as f: try: - cursor.copy_expert('COPY %s (%s) FROM STDIN CSV' % ( - tblname, ', '.join(columns)), f) + cursor.copy_expert( + "COPY %s (%s) FROM STDIN CSV" % (tblname, ", ".join(columns)), f + ) except Exception: # Tell the main thread about the exception exc_info = sys.exc_info() @@ -180,7 +182,7 @@ write_thread.start() try: - with open(write_file, 'w') as f: + with open(write_file, "w") as f: for d in items: if item_cb is not None: item_cb(d) @@ -191,13 +193,15 @@ line.append(escape(value)) except Exception as e: logger.error( - 'Could not escape value `%r` for column `%s`:' - 'Received exception: `%s`', - value, k, e + "Could not escape value `%r` for column `%s`:" + "Received exception: `%s`", + value, + k, + e, ) raise e from None - f.write(','.join(line)) - f.write('\n') + f.write(",".join(line)) + f.write("\n") finally: # No problem bubbling up exceptions, but we still need to make sure @@ -209,4 +213,4 @@ raise exc_info[1].with_traceback(exc_info[2]) def mktemp(self, tblname, cur=None): - self.cursor(cur).execute('SELECT swh_mktemp(%s)', (tblname,)) + self.cursor(cur).execute("SELECT swh_mktemp(%s)", (tblname,)) diff --git a/swh/core/db/common.py b/swh/core/db/common.py --- a/swh/core/db/common.py +++ b/swh/core/db/common.py @@ -11,8 +11,7 @@ def decorator(f): sig = inspect.signature(f) params = sig.parameters - params = [param for param in params.values() - if param.name not in names] + params = [param for param in params.values() if param.name not in names] sig = sig.replace(parameters=params) f.__signature__ = sig return f @@ -26,10 +25,10 @@ Returns a dictionary with the old values if they changed.""" old_options = {} for option, value in options.items(): - cursor.execute('SHOW %s' % option) + cursor.execute("SHOW %s" % option) old_value = cursor.fetchall()[0][0] if old_value != value: - cursor.execute('SET LOCAL %s TO %%s' % option, (value,)) + cursor.execute("SET LOCAL %s TO %%s" % option, (value,)) old_options[option] = old_value return old_options @@ -41,16 +40,16 @@ Client options are passed as `set` options to the postgresql server """ + def decorator(meth, __client_options=client_options): if inspect.isgeneratorfunction(meth): - raise ValueError( - 'Use db_transaction_generator for generator functions.') + raise ValueError("Use db_transaction_generator for generator functions.") - @remove_kwargs(['cur', 'db']) + @remove_kwargs(["cur", "db"]) @functools.wraps(meth) def _meth(self, *args, **kwargs): - if 'cur' in kwargs and kwargs['cur']: - cur = kwargs['cur'] + if "cur" in kwargs and kwargs["cur"]: + cur = kwargs["cur"] old_options = apply_options(cur, __client_options) ret = meth(self, *args, **kwargs) apply_options(cur, old_options) @@ -63,6 +62,7 @@ return meth(self, *args, db=db, cur=cur, **kwargs) finally: self.put_db(db) + return _meth return decorator @@ -76,16 +76,16 @@ Client options are passed as `set` options to the postgresql server """ + def decorator(meth, __client_options=client_options): if not inspect.isgeneratorfunction(meth): - raise ValueError( - 'Use db_transaction for non-generator functions.') + raise ValueError("Use db_transaction for non-generator functions.") - @remove_kwargs(['cur', 'db']) + @remove_kwargs(["cur", "db"]) @functools.wraps(meth) def _meth(self, *args, **kwargs): - if 'cur' in kwargs and kwargs['cur']: - cur = kwargs['cur'] + if "cur" in kwargs and kwargs["cur"]: + cur = kwargs["cur"] old_options = apply_options(cur, __client_options) yield from meth(self, *args, **kwargs) apply_options(cur, old_options) @@ -99,4 +99,5 @@ self.put_db(db) return _meth + return decorator diff --git a/swh/core/db/db_utils.py b/swh/core/db/db_utils.py --- a/swh/core/db/db_utils.py +++ b/swh/core/db/db_utils.py @@ -26,13 +26,16 @@ not, the stored procedure will be executed first; the function body then. """ + def wrap(meth): @functools.wraps(meth) def _meth(self, *args, **kwargs): - cur = kwargs.get('cur', None) - self._cursor(cur).execute('SELECT %s()' % stored_proc) + cur = kwargs.get("cur", None) + self._cursor(cur).execute("SELECT %s()" % stored_proc) meth(self, *args, **kwargs) + return _meth + return wrap @@ -69,23 +72,24 @@ """ curr = pre = [] post = [] - tokens = re.split(br'(%.)', sql) + tokens = re.split(br"(%.)", sql) for token in tokens: - if len(token) != 2 or token[:1] != b'%': + if len(token) != 2 or token[:1] != b"%": curr.append(token) continue - if token[1:] == b's': + if token[1:] == b"s": if curr is pre: curr = post else: - raise ValueError( - "the query contains more than one '%s' placeholder") - elif token[1:] == b'%': - curr.append(b'%') + raise ValueError("the query contains more than one '%s' placeholder") + elif token[1:] == b"%": + curr.append(b"%") else: - raise ValueError("unsupported format character: '%s'" - % token[1:].decode('ascii', 'replace')) + raise ValueError( + "unsupported format character: '%s'" + % token[1:].decode("ascii", "replace") + ) if curr is pre: raise ValueError("the query doesn't contain any '%s' placeholder") @@ -94,7 +98,7 @@ def execute_values_generator(cur, sql, argslist, template=None, page_size=100): - '''Execute a statement using SQL ``VALUES`` with a sequence of parameters. + """Execute a statement using SQL ``VALUES`` with a sequence of parameters. Rows returned by the query are returned through a generator. You need to consume the generator for the queries to be executed! @@ -127,23 +131,21 @@ After the execution of the function the `cursor.rowcount` property will **not** contain a total result. - ''' + """ # we can't just use sql % vals because vals is bytes: if sql is bytes # there will be some decoding error because of stupid codec used, and Py3 # doesn't implement % on bytes. if not isinstance(sql, bytes): - sql = sql.encode( - psycopg2.extensions.encodings[cur.connection.encoding] - ) + sql = sql.encode(psycopg2.extensions.encodings[cur.connection.encoding]) pre, post = _split_sql(sql) for page in _paginate(argslist, page_size=page_size): if template is None: - template = b'(' + b','.join([b'%s'] * len(page[0])) + b')' + template = b"(" + b",".join([b"%s"] * len(page[0])) + b")" parts = pre[:] for args in page: parts.append(cur.mogrify(template, args)) - parts.append(b',') + parts.append(b",") parts[-1:] = post - cur.execute(b''.join(parts)) + cur.execute(b"".join(parts)) yield from cur diff --git a/swh/core/db/tests/conftest.py b/swh/core/db/tests/conftest.py --- a/swh/core/db/tests/conftest.py +++ b/swh/core/db/tests/conftest.py @@ -1,2 +1,3 @@ import os -os.environ['LC_ALL'] = 'C.UTF-8' + +os.environ["LC_ALL"] = "C.UTF-8" diff --git a/swh/core/db/tests/db_testing.py b/swh/core/db/tests/db_testing.py --- a/swh/core/db/tests/db_testing.py +++ b/swh/core/db/tests/db_testing.py @@ -14,7 +14,7 @@ from swh.core.utils import numfile_sortkey as sortkey -DB_DUMP_TYPES = {'.sql': 'psql', '.dump': 'pg_dump'} # type: Dict[str, str] +DB_DUMP_TYPES = {".sql": "psql", ".dump": "pg_dump"} # type: Dict[str, str] def swh_db_version(dbname_or_service): @@ -28,48 +28,70 @@ Optional[Int]: Either the db's version or None """ - query = 'select version from dbversion order by dbversion desc limit 1' + query = "select version from dbversion order by dbversion desc limit 1" cmd = [ - 'psql', '--tuples-only', '--no-psqlrc', '--quiet', - '-v', 'ON_ERROR_STOP=1', "--command=%s" % query, - dbname_or_service + "psql", + "--tuples-only", + "--no-psqlrc", + "--quiet", + "-v", + "ON_ERROR_STOP=1", + "--command=%s" % query, + dbname_or_service, ] try: - r = subprocess.run(cmd, check=True, stdout=subprocess.PIPE, - universal_newlines=True) + r = subprocess.run( + cmd, check=True, stdout=subprocess.PIPE, universal_newlines=True + ) result = int(r.stdout.strip()) except Exception: # db not initialized result = None return result -def pg_restore(dbname, dumpfile, dumptype='pg_dump'): +def pg_restore(dbname, dumpfile, dumptype="pg_dump"): """ Args: dbname: name of the DB to restore into dumpfile: path of the dump file dumptype: one of 'pg_dump' (for binary dumps), 'psql' (for SQL dumps) """ - assert dumptype in ['pg_dump', 'psql'] - if dumptype == 'pg_dump': - subprocess.check_call(['pg_restore', '--no-owner', '--no-privileges', - '--dbname', dbname, dumpfile]) - elif dumptype == 'psql': - subprocess.check_call(['psql', '--quiet', - '--no-psqlrc', - '-v', 'ON_ERROR_STOP=1', - '-f', dumpfile, - dbname]) + assert dumptype in ["pg_dump", "psql"] + if dumptype == "pg_dump": + subprocess.check_call( + [ + "pg_restore", + "--no-owner", + "--no-privileges", + "--dbname", + dbname, + dumpfile, + ] + ) + elif dumptype == "psql": + subprocess.check_call( + [ + "psql", + "--quiet", + "--no-psqlrc", + "-v", + "ON_ERROR_STOP=1", + "-f", + dumpfile, + dbname, + ] + ) def pg_dump(dbname, dumpfile): - subprocess.check_call(['pg_dump', '--no-owner', '--no-privileges', '-Fc', - '-f', dumpfile, dbname]) + subprocess.check_call( + ["pg_dump", "--no-owner", "--no-privileges", "-Fc", "-f", dumpfile, dbname] + ) def pg_dropdb(dbname): - subprocess.check_call(['dropdb', dbname]) + subprocess.check_call(["dropdb", dbname]) def pg_createdb(dbname, check=True): @@ -79,7 +101,7 @@ not exist, the db will be created. """ - subprocess.run(['createdb', dbname], check=check) + subprocess.run(["createdb", dbname], check=check) def db_create(dbname, dumps=None): @@ -93,7 +115,7 @@ try: pg_createdb(dbname) except subprocess.CalledProcessError: # try recovering once, in case - pg_dropdb(dbname) # the db already existed + pg_dropdb(dbname) # the db already existed pg_createdb(dbname) for dump, dtype in dumps: pg_restore(dbname, dump, dtype) @@ -115,11 +137,8 @@ context: setUp """ - conn = psycopg2.connect('dbname=' + dbname) - return { - 'conn': conn, - 'cursor': conn.cursor() - } + conn = psycopg2.connect("dbname=" + dbname) + return {"conn": conn, "cursor": conn.cursor()} def db_close(conn): @@ -139,8 +158,8 @@ def __enter__(self): self.db_setup = db_connect(self.dbname) - self.conn = self.db_setup['conn'] - self.cursor = self.db_setup['cursor'] + self.conn = self.db_setup["conn"] + self.cursor = self.db_setup["cursor"] return self def __exit__(self, *_): @@ -148,13 +167,12 @@ class DbTestContext: - def __init__(self, name='softwareheritage-test', dumps=None): + def __init__(self, name="softwareheritage-test", dumps=None): self.dbname = name self.dumps = dumps def __enter__(self): - db_create(dbname=self.dbname, - dumps=self.dumps) + db_create(dbname=self.dbname, dumps=self.dumps) return self def __exit__(self, *_): @@ -215,7 +233,7 @@ DB_TEST_FIXTURE_IMPORTED = True @classmethod - def add_db(cls, name='softwareheritage-test', dumps=None): + def add_db(cls, name="softwareheritage-test", dumps=None): cls._DB_DUMP_LIST[name] = dumps @classmethod @@ -248,15 +266,18 @@ conn = db.conn cursor = db.cursor - cursor.execute("""SELECT table_name FROM information_schema.tables - WHERE table_schema = %s""", ('public',)) + cursor.execute( + """SELECT table_name FROM information_schema.tables + WHERE table_schema = %s""", + ("public",), + ) tables = set(table for (table,) in cursor.fetchall()) if excluded is not None: tables -= set(excluded) for table in tables: - cursor.execute('truncate table %s cascade' % table) + cursor.execute("truncate table %s cascade" % table) conn.commit() @@ -288,7 +309,7 @@ cursor: open psycopg2 cursor to the DB """ - TEST_DB_NAME = 'softwareheritage-test' + TEST_DB_NAME = "softwareheritage-test" TEST_DB_DUMP = None # type: Optional[Union[str, Iterable[str]]] @classmethod @@ -302,14 +323,13 @@ dump_files = [dump_files] all_dump_files = [] for files in dump_files: - all_dump_files.extend( - sorted(glob.glob(files), key=sortkey)) + all_dump_files.extend(sorted(glob.glob(files), key=sortkey)) - all_dump_files = [(x, DB_DUMP_TYPES[os.path.splitext(x)[1]]) - for x in all_dump_files] + all_dump_files = [ + (x, DB_DUMP_TYPES[os.path.splitext(x)[1]]) for x in all_dump_files + ] - cls.add_db(name=cls.TEST_DB_NAME, - dumps=all_dump_files) + cls.add_db(name=cls.TEST_DB_NAME, dumps=all_dump_files) super().setUpClass() def setUp(self, *args, **kwargs): diff --git a/swh/core/db/tests/test_cli.py b/swh/core/db/tests/test_cli.py --- a/swh/core/db/tests/test_cli.py +++ b/swh/core/db/tests/test_cli.py @@ -5,7 +5,7 @@ from swh.core.cli.db import db as swhdb -help_msg = '''Usage: swh [OPTIONS] COMMAND [ARGS]... +help_msg = """Usage: swh [OPTIONS] COMMAND [ARGS]... Command line interface for Software Heritage. @@ -25,18 +25,18 @@ Commands: db Software Heritage database generic tools. -''' +""" def test_swh_help(swhmain): swhmain.add_command(swhdb) runner = CliRunner() - result = runner.invoke(swhmain, ['-h']) + result = runner.invoke(swhmain, ["-h"]) assert result.exit_code == 0 assert result.output == help_msg -help_db_msg = '''Usage: swh db [OPTIONS] COMMAND [ARGS]... +help_db_msg = """Usage: swh db [OPTIONS] COMMAND [ARGS]... Software Heritage database generic tools. @@ -46,12 +46,12 @@ Commands: init Initialize the database for every Software Heritage module found in... -''' +""" def test_swh_db_help(swhmain): swhmain.add_command(swhdb) runner = CliRunner() - result = runner.invoke(swhmain, ['db', '-h']) + result = runner.invoke(swhmain, ["db", "-h"]) assert result.exit_code == 0 assert result.output == help_db_msg diff --git a/swh/core/db/tests/test_db.py b/swh/core/db/tests/test_db.py --- a/swh/core/db/tests/test_db.py +++ b/swh/core/db/tests/test_db.py @@ -15,46 +15,45 @@ from swh.core.db import BaseDb from swh.core.db.common import db_transaction, db_transaction_generator -from .db_testing import ( - SingleDbTestFixture, db_create, db_destroy, db_close, -) +from .db_testing import SingleDbTestFixture, db_create, db_destroy, db_close -INIT_SQL = ''' +INIT_SQL = """ create table test_table ( i int, txt text, bytes bytea ); -''' - -db_rows = strategies.lists(strategies.tuples( - strategies.integers(-2147483648, +2147483647), - strategies.text( - alphabet=strategies.characters( - blacklist_categories=['Cs'], # surrogates - blacklist_characters=[ - '\x00', # pgsql does not support the null codepoint - '\r', # pgsql normalizes those - ] +""" + +db_rows = strategies.lists( + strategies.tuples( + strategies.integers(-2147483648, +2147483647), + strategies.text( + alphabet=strategies.characters( + blacklist_categories=["Cs"], # surrogates + blacklist_characters=[ + "\x00", # pgsql does not support the null codepoint + "\r", # pgsql normalizes those + ], + ) ), - ), - strategies.binary(), -)) + strategies.binary(), + ) +) @pytest.mark.db def test_connect(): - db_name = db_create('test-db2', dumps=[]) + db_name = db_create("test-db2", dumps=[]) try: - db = BaseDb.connect('dbname=%s' % db_name) + db = BaseDb.connect("dbname=%s" % db_name) with db.cursor() as cur: cur.execute(INIT_SQL) - cur.execute("insert into test_table values (1, %s, %s);", - ('foo', b'bar')) + cur.execute("insert into test_table values (1, %s, %s);", ("foo", b"bar")) cur.execute("select * from test_table;") - assert list(cur) == [(1, 'foo', b'bar')] + assert list(cur) == [(1, "foo", b"bar")] finally: db_close(db.conn) db_destroy(db_name) @@ -62,15 +61,15 @@ @pytest.mark.db class TestDb(SingleDbTestFixture, unittest.TestCase): - TEST_DB_NAME = 'test-db' + TEST_DB_NAME = "test-db" @classmethod def setUpClass(cls): with tempfile.TemporaryDirectory() as td: - with open(os.path.join(td, 'init.sql'), 'a') as fd: + with open(os.path.join(td, "init.sql"), "a") as fd: fd.write(INIT_SQL) - cls.TEST_DB_DUMP = os.path.join(td, '*.sql') + cls.TEST_DB_DUMP = os.path.join(td, "*.sql") super().setUpClass() @@ -80,37 +79,35 @@ def test_initialized(self): cur = self.db.cursor() - cur.execute("insert into test_table values (1, %s, %s);", - ('foo', b'bar')) + cur.execute("insert into test_table values (1, %s, %s);", ("foo", b"bar")) cur.execute("select * from test_table;") - self.assertEqual(list(cur), [(1, 'foo', b'bar')]) + self.assertEqual(list(cur), [(1, "foo", b"bar")]) def test_reset_tables(self): cur = self.db.cursor() - cur.execute("insert into test_table values (1, %s, %s);", - ('foo', b'bar')) - self.reset_db_tables('test-db') + cur.execute("insert into test_table values (1, %s, %s);", ("foo", b"bar")) + self.reset_db_tables("test-db") cur.execute("select * from test_table;") self.assertEqual(list(cur), []) @given(db_rows) def test_copy_to(self, data): # the table is not reset between runs by hypothesis - self.reset_db_tables('test-db') + self.reset_db_tables("test-db") - items = [dict(zip(['i', 'txt', 'bytes'], item)) for item in data] - self.db.copy_to(items, 'test_table', ['i', 'txt', 'bytes']) + items = [dict(zip(["i", "txt", "bytes"], item)) for item in data] + self.db.copy_to(items, "test_table", ["i", "txt", "bytes"]) cur = self.db.cursor() - cur.execute('select * from test_table;') + cur.execute("select * from test_table;") self.assertCountEqual(list(cur), data) def test_copy_to_thread_exception(self): - data = [(2**65, 'foo', b'bar')] + data = [(2 ** 65, "foo", b"bar")] - items = [dict(zip(['i', 'txt', 'bytes'], item)) for item in data] + items = [dict(zip(["i", "txt", "bytes"], item)) for item in data] with self.assertRaises(psycopg2.errors.NumericValueOutOfRange): - self.db.copy_to(items, 'test_table', ['i', 'txt', 'bytes']) + self.db.copy_to(items, "test_table", ["i", "txt", "bytes"]) def test_db_transaction(mocker): @@ -132,11 +129,9 @@ db_mock = Mock() db_mock.transaction.return_value = MagicMock() db_mock.transaction.return_value.__enter__.return_value = expected_cur - mocker.patch.object( - storage, 'get_db', return_value=db_mock, create=True) + mocker.patch.object(storage, "get_db", return_value=db_mock, create=True) - put_db_mock = mocker.patch.object( - storage, 'put_db', create=True) + put_db_mock = mocker.patch.object(storage, "put_db", create=True) storage.endpoint() @@ -145,7 +140,8 @@ def test_db_transaction__with_generator(): - with pytest.raises(ValueError, match='generator'): + with pytest.raises(ValueError, match="generator"): + class Storage: @db_transaction() def endpoint(self, cur=None, db=None): @@ -154,8 +150,10 @@ def test_db_transaction_signature(): """Checks db_transaction removes the 'cur' and 'db' arguments.""" + def f(self, foo, *, bar=None): pass + expected_sig = inspect.signature(f) @db_transaction() @@ -187,11 +185,9 @@ db_mock = Mock() db_mock.transaction.return_value = MagicMock() db_mock.transaction.return_value.__enter__.return_value = expected_cur - mocker.patch.object( - storage, 'get_db', return_value=db_mock, create=True) + mocker.patch.object(storage, "get_db", return_value=db_mock, create=True) - put_db_mock = mocker.patch.object( - storage, 'put_db', create=True) + put_db_mock = mocker.patch.object(storage, "put_db", create=True) list(storage.endpoint()) @@ -200,7 +196,8 @@ def test_db_transaction_generator__with_nongenerator(): - with pytest.raises(ValueError, match='generator'): + with pytest.raises(ValueError, match="generator"): + class Storage: @db_transaction_generator() def endpoint(self, cur=None, db=None): @@ -209,8 +206,10 @@ def test_db_transaction_generator_signature(): """Checks db_transaction removes the 'cur' and 'db' arguments.""" + def f(self, foo, *, bar=None): pass + expected_sig = inspect.signature(f) @db_transaction_generator() diff --git a/swh/core/logger.py b/swh/core/logger.py --- a/swh/core/logger.py +++ b/swh/core/logger.py @@ -15,7 +15,7 @@ current_task = None -EXTRA_LOGDATA_PREFIX = 'swh_' +EXTRA_LOGDATA_PREFIX = "swh_" def db_level_of_py_level(lvl): @@ -30,33 +30,31 @@ """Get the extra data to insert to the database from the logging record""" log_data = record.__dict__ - extra_data = {k[len(EXTRA_LOGDATA_PREFIX):]: v - for k, v in log_data.items() - if k.startswith(EXTRA_LOGDATA_PREFIX)} + extra_data = { + k[len(EXTRA_LOGDATA_PREFIX) :]: v + for k, v in log_data.items() + if k.startswith(EXTRA_LOGDATA_PREFIX) + } - args = log_data.get('args') + args = log_data.get("args") if args: - extra_data['logging_args'] = args + extra_data["logging_args"] = args # Retrieve Celery task info if current_task and current_task.request: - extra_data['task'] = { - 'id': current_task.request.id, - 'name': current_task.name, - } + extra_data["task"] = {"id": current_task.request.id, "name": current_task.name} if task_args: - extra_data['task'].update({ - 'kwargs': current_task.request.kwargs, - 'args': current_task.request.args, - }) + extra_data["task"].update( + { + "kwargs": current_task.request.kwargs, + "args": current_task.request.args, + } + ) return extra_data -def flatten( - data: Any, - separator: str = "_" -) -> Generator[Tuple[str, Any], None, None]: +def flatten(data: Any, separator: str = "_") -> Generator[Tuple[str, Any], None, None]: """Flatten the data dictionary into a flat structure""" def inner_flatten( @@ -102,13 +100,15 @@ } msg = self.format(record) pri = self.mapPriority(record.levelno) - send(msg, - PRIORITY=format(pri), - LOGGER=record.name, - THREAD_NAME=record.threadName, - CODE_FILE=record.pathname, - CODE_LINE=record.lineno, - CODE_FUNC=record.funcName, - **extra_data) + send( + msg, + PRIORITY=format(pri), + LOGGER=record.name, + THREAD_NAME=record.threadName, + CODE_FILE=record.pathname, + CODE_LINE=record.lineno, + CODE_FUNC=record.funcName, + **extra_data, + ) except Exception: self.handleError(record) diff --git a/swh/core/pytest_plugin.py b/swh/core/pytest_plugin.py --- a/swh/core/pytest_plugin.py +++ b/swh/core/pytest_plugin.py @@ -27,9 +27,12 @@ def get_response_cb( - request: requests.Request, context, datadir, - ignore_urls: List[str] = [], - visits: Optional[Dict] = None): + request: requests.Request, + context, + datadir, + ignore_urls: List[str] = [], + visits: Optional[Dict] = None, +): """Mount point callback to fetch on disk the request's content. The request urls provided are url decoded first to resolve the associated file on disk. @@ -80,9 +83,9 @@ Optional[FileDescriptor] on disk file to read from the test context """ - logger.debug('get_response_cb(%s, %s)', request, context) - logger.debug('url: %s', request.url) - logger.debug('ignore_urls: %s', ignore_urls) + logger.debug("get_response_cb(%s, %s)", request, context) + logger.debug("url: %s", request.url) + logger.debug("ignore_urls: %s", ignore_urls) unquoted_url = unquote(request.url) if unquoted_url in ignore_urls: context.status_code = 404 @@ -90,28 +93,28 @@ url = urlparse(unquoted_url) # http://pypi.org ~> http_pypi.org # https://files.pythonhosted.org ~> https_files.pythonhosted.org - dirname = '%s_%s' % (url.scheme, url.hostname) + dirname = "%s_%s" % (url.scheme, url.hostname) # url.path: pypi//json -> local file: pypi__json filename = url.path[1:] - if filename.endswith('/'): + if filename.endswith("/"): filename = filename[:-1] - filename = filename.replace('/', '_') + filename = filename.replace("/", "_") if url.query: - filename += ',' + url.query.replace('&', ',') + filename += "," + url.query.replace("&", ",") filepath = path.join(datadir, dirname, filename) if visits is not None: visit = visits.get(url, 0) visits[url] = visit + 1 if visit: - filepath = filepath + '_visit%s' % visit + filepath = filepath + "_visit%s" % visit if not path.isfile(filepath): - logger.debug('not found filepath: %s', filepath) + logger.debug("not found filepath: %s", filepath) context.status_code = 404 return None - fd = open(filepath, 'rb') - context.headers['content-length'] = str(path.getsize(filepath)) + fd = open(filepath, "rb") + context.headers["content-length"] = str(path.getsize(filepath)) return fd @@ -132,11 +135,12 @@ """ - return path.join(path.dirname(str(request.fspath)), 'data') + return path.join(path.dirname(str(request.fspath)), "data") -def requests_mock_datadir_factory(ignore_urls: List[str] = [], - has_multi_visit: bool = False): +def requests_mock_datadir_factory( + ignore_urls: List[str] = [], has_multi_visit: bool = False +): """This factory generates fixture which allow to look for files on the local filesystem based on the requested URL, using the following rules: @@ -167,18 +171,22 @@ has_multi_visit: Activate or not the multiple visits behavior """ + @pytest.fixture def requests_mock_datadir(requests_mock, datadir): if not has_multi_visit: - cb = partial(get_response_cb, - ignore_urls=ignore_urls, - datadir=datadir) - requests_mock.get(re.compile('https?://'), body=cb) + cb = partial(get_response_cb, ignore_urls=ignore_urls, datadir=datadir) + requests_mock.get(re.compile("https?://"), body=cb) else: visits = {} - requests_mock.get(re.compile('https?://'), body=partial( - get_response_cb, ignore_urls=ignore_urls, visits=visits, - datadir=datadir) + requests_mock.get( + re.compile("https?://"), + body=partial( + get_response_cb, + ignore_urls=ignore_urls, + visits=visits, + datadir=datadir, + ), ) return requests_mock @@ -193,8 +201,7 @@ # - first time, it checks for a file named `filename` # - second time, it checks for a file named `filename`_visit1 # etc... -requests_mock_datadir_visits = requests_mock_datadir_factory( - has_multi_visit=True) +requests_mock_datadir_visits = requests_mock_datadir_factory(has_multi_visit=True) @pytest.fixture @@ -219,12 +226,12 @@ See swh/core/api/tests/test_rpc_client_server.py for an example of usage. """ - url = 'mock://example.com' + url = "mock://example.com" cli = swh_rpc_client_class(url=url) # we need to clear the list of existing adapters here so we ensure we # have one and only one adapter which is then used for all the requests. cli.session.adapters.clear() - cli.session.mount('mock://', swh_rpc_adapter) + cli.session.mount("mock://", swh_rpc_adapter) return cli @@ -251,7 +258,7 @@ response.status_code = resp.status_code # Make headers case-insensitive. - response.headers = CaseInsensitiveDict(getattr(resp, 'headers', {})) + response.headers = CaseInsensitiveDict(getattr(resp, "headers", {})) # Set encoding. response.encoding = get_encoding_from_headers(response.headers) @@ -259,7 +266,7 @@ response.reason = response.raw.status if isinstance(req.url, bytes): - response.url = req.url.decode('utf-8') + response.url = req.url.decode("utf-8") else: response.url = req.url @@ -275,10 +282,11 @@ Overrides ``requests.adapters.BaseAdapter.send`` """ resp = self._client.open( - request.url, method=request.method, + request.url, + method=request.method, headers=request.headers.items(), data=request.body, - ) + ) return self.build_response(request, resp) @@ -299,9 +307,9 @@ assert client.get(url_for('myview')).status_code == 200 """ - if 'app' not in request.fixturenames: + if "app" not in request.fixturenames: return - app = request.getfixturevalue('app') + app = request.getfixturevalue("app") ctx = app.test_request_context() ctx.push() diff --git a/swh/core/sentry.py b/swh/core/sentry.py --- a/swh/core/sentry.py +++ b/swh/core/sentry.py @@ -8,21 +8,19 @@ def get_sentry_release(): - main_package = os.environ.get('SWH_MAIN_PACKAGE') + main_package = os.environ.get("SWH_MAIN_PACKAGE") if main_package: version = pkg_resources.get_distribution(main_package).version - return f'{main_package}@{version}' + return f"{main_package}@{version}" else: return None -def init_sentry( - sentry_dsn, *, debug=None, integrations=[], - extra_kwargs={}): +def init_sentry(sentry_dsn, *, debug=None, integrations=[], extra_kwargs={}): if debug is None: - debug = bool(os.environ.get('SWH_SENTRY_DEBUG')) - sentry_dsn = sentry_dsn or os.environ.get('SWH_SENTRY_DSN') - environment = os.environ.get('SWH_SENTRY_ENVIRONMENT') + debug = bool(os.environ.get("SWH_SENTRY_DEBUG")) + sentry_dsn = sentry_dsn or os.environ.get("SWH_SENTRY_DSN") + environment = os.environ.get("SWH_SENTRY_ENVIRONMENT") if sentry_dsn: import sentry_sdk diff --git a/swh/core/statsd.py b/swh/core/statsd.py --- a/swh/core/statsd.py +++ b/swh/core/statsd.py @@ -63,7 +63,7 @@ import warnings -log = logging.getLogger('swh.core.statsd') +log = logging.getLogger("swh.core.statsd") class TimedContextManagerDecorator(object): @@ -74,8 +74,10 @@ Attributes: elapsed (float): the elapsed time at the point of completion """ - def __init__(self, statsd, metric=None, error_metric=None, - tags=None, sample_rate=1): + + def __init__( + self, statsd, metric=None, error_metric=None, tags=None, sample_rate=1 + ): self.statsd = statsd self.metric = metric self.error_metric = error_metric @@ -90,10 +92,11 @@ Default to the function name if metric was not provided. """ if not self.metric: - self.metric = '%s.%s' % (func.__module__, func.__name__) + self.metric = "%s.%s" % (func.__module__, func.__name__) # Coroutines if iscoroutinefunction(func): + @wraps(func) async def wrapped_co(*args, **kwargs): start = monotonic() @@ -104,6 +107,7 @@ raise self._send(start) return result + return wrapped_co # Others @@ -117,6 +121,7 @@ raise self._send(start) return result + return wrapped def __enter__(self): @@ -134,13 +139,14 @@ def _send(self, start): elapsed = (monotonic() - start) * 1000 - self.statsd.timing(self.metric, elapsed, - tags=self.tags, sample_rate=self.sample_rate) + self.statsd.timing( + self.metric, elapsed, tags=self.tags, sample_rate=self.sample_rate + ) self.elapsed = elapsed def _send_error(self): if self.error_metric is None: - self.error_metric = self.metric + '_error_count' + self.error_metric = self.metric + "_error_count" self.statsd.increment(self.error_metric, tags=self.tags) def start(self): @@ -179,15 +185,21 @@ "label:value,other_label:other_value" """ - def __init__(self, host=None, port=None, max_buffer_size=50, - namespace=None, constant_tags=None): + def __init__( + self, + host=None, + port=None, + max_buffer_size=50, + namespace=None, + constant_tags=None, + ): # Connection if host is None: - host = os.environ.get('STATSD_HOST') or 'localhost' + host = os.environ.get("STATSD_HOST") or "localhost" self.host = host if port is None: - port = os.environ.get('STATSD_PORT') or 8125 + port = os.environ.get("STATSD_PORT") or 8125 self.port = int(port) # Socket @@ -195,29 +207,27 @@ self.lock = threading.Lock() self.max_buffer_size = max_buffer_size self._send = self._send_to_server - self.encoding = 'utf-8' + self.encoding = "utf-8" # Tags self.constant_tags = {} - tags_envvar = os.environ.get('STATSD_TAGS', '') - for tag in tags_envvar.split(','): + tags_envvar = os.environ.get("STATSD_TAGS", "") + for tag in tags_envvar.split(","): if not tag: continue - if ':' not in tag: + if ":" not in tag: warnings.warn( - 'STATSD_TAGS needs to be in key:value format, ' - '%s invalid' % tag, + "STATSD_TAGS needs to be in key:value format, " "%s invalid" % tag, UserWarning, ) continue - k, v = tag.split(':', 1) + k, v = tag.split(":", 1) self.constant_tags[k] = v if constant_tags: - self.constant_tags.update({ - str(k): str(v) - for k, v in constant_tags.items() - }) + self.constant_tags.update( + {str(k): str(v) for k, v in constant_tags.items()} + ) # Namespace if namespace is not None: @@ -239,7 +249,7 @@ >>> statsd.gauge('users.online', 123) >>> statsd.gauge('active.connections', 1001, tags={"protocol": "http"}) """ - return self._report(metric, 'g', value, tags, sample_rate) + return self._report(metric, "g", value, tags, sample_rate) def increment(self, metric, value=1, tags=None, sample_rate=1): """ @@ -249,7 +259,7 @@ >>> statsd.increment('page.views') >>> statsd.increment('files.transferred', 124) """ - self._report(metric, 'c', value, tags, sample_rate) + self._report(metric, "c", value, tags, sample_rate) def decrement(self, metric, value=1, tags=None, sample_rate=1): """ @@ -260,7 +270,7 @@ >>> statsd.decrement('active.connections', 2) """ metric_value = -value if value else value - self._report(metric, 'c', metric_value, tags, sample_rate) + self._report(metric, "c", metric_value, tags, sample_rate) def histogram(self, metric, value, tags=None, sample_rate=1): """ @@ -269,7 +279,7 @@ >>> statsd.histogram('uploaded.file.size', 1445) >>> statsd.histogram('file.count', 26, tags={"filetype": "python"}) """ - self._report(metric, 'h', value, tags, sample_rate) + self._report(metric, "h", value, tags, sample_rate) def timing(self, metric, value, tags=None, sample_rate=1): """ @@ -277,7 +287,7 @@ >>> statsd.timing("query.response.time", 1234) """ - self._report(metric, 'ms', value, tags, sample_rate) + self._report(metric, "ms", value, tags, sample_rate) def timed(self, metric=None, error_metric=None, tags=None, sample_rate=1): """ @@ -306,9 +316,12 @@ statsd.timing('user.query.time', time.monotonic() - start) """ return TimedContextManagerDecorator( - statsd=self, metric=metric, + statsd=self, + metric=metric, error_metric=error_metric, - tags=tags, sample_rate=sample_rate) + tags=tags, + sample_rate=sample_rate, + ) def set(self, metric, value, tags=None, sample_rate=1): """ @@ -316,7 +329,7 @@ >>> statsd.set('visitors.uniques', 999) """ - self._report(metric, 's', value, tags, sample_rate) + self._report(metric, "s", value, tags, sample_rate) @property def socket(self): @@ -387,10 +400,9 @@ value, metric_type, ("|@" + str(sample_rate)) if sample_rate != 1 else "", - ("|#" + ",".join( - "%s:%s" % (k, v) - for (k, v) in sorted(tags.items()) - )) if tags else "", + ("|#" + ",".join("%s:%s" % (k, v) for (k, v) in sorted(tags.items()))) + if tags + else "", ) # Send it self._send(payload) @@ -398,7 +410,7 @@ def _send_to_server(self, packet): try: # If set, use socket directly - self.socket.send(packet.encode('utf-8')) + self.socket.send(packet.encode("utf-8")) except socket.timeout: return except socket.error: @@ -421,8 +433,7 @@ return { str(k): str(v) for k, v in itertools.chain( - self.constant_tags.items(), - (tags if tags else {}).items(), + self.constant_tags.items(), (tags if tags else {}).items() ) } diff --git a/swh/core/tarball.py b/swh/core/tarball.py --- a/swh/core/tarball.py +++ b/swh/core/tarball.py @@ -33,11 +33,12 @@ """ try: - run(['tar', 'xf', tarpath, '-C', extract_dir], check=True) + run(["tar", "xf", tarpath, "-C", extract_dir], check=True) return extract_dir except Exception as e: raise shutil.ReadError( - f'Unable to uncompress {tarpath} to {extract_dir}. Reason: {e}') + f"Unable to uncompress {tarpath} to {extract_dir}. Reason: {e}" + ) def register_new_archive_formats(): @@ -71,7 +72,7 @@ try: shutil.unpack_archive(tarpath, extract_dir=dest) except shutil.ReadError as e: - raise ValueError(f'Problem during unpacking {tarpath}. Reason: {e}') + raise ValueError(f"Problem during unpacking {tarpath}. Reason: {e}") # Fix permissions for dirpath, _, fnames in os.walk(dest): @@ -89,7 +90,7 @@ """ for dirpath, dirnames, fnames in os.walk(rootdir): - for fname in (dirnames+fnames): + for fname in dirnames + fnames: fpath = os.path.join(dirpath, fname) fname = utils.commonname(rootdir, fpath) yield fpath, fname @@ -99,7 +100,7 @@ """Compress dirpath's content as tarpath. """ - with zipfile.ZipFile(tarpath, 'w') as z: + with zipfile.ZipFile(tarpath, "w") as z: for fpath, fname in files: z.write(fpath, arcname=fname) @@ -108,7 +109,7 @@ """Compress dirpath's content as tarpath. """ - with tarfile.open(tarpath, 'w:bz2') as t: + with tarfile.open(tarpath, "w:bz2") as t: for fpath, fname in files: t.add(fpath, arcname=fname, recursive=False) @@ -128,7 +129,7 @@ else: # iterable of 'filepath, filename' files = dirpath_or_files - if nature == 'zip': + if nature == "zip": _compress_zip(tarpath, files) else: _compress_tar(tarpath, files) @@ -139,9 +140,9 @@ # Additional uncompression archive format support ADDITIONAL_ARCHIVE_FORMATS = [ # name , extensions, function - ('tar.Z|x', ['.tar.Z', '.tar.x'], _unpack_tar), + ("tar.Z|x", [".tar.Z", ".tar.x"], _unpack_tar), # FIXME: make this optional depending on the runtime lzip package install - ('tar.lz', ['.tar.lz'], _unpack_tar), + ("tar.lz", [".tar.lz"], _unpack_tar), ] register_new_archive_formats() diff --git a/swh/core/tests/__init__.py b/swh/core/tests/__init__.py --- a/swh/core/tests/__init__.py +++ b/swh/core/tests/__init__.py @@ -2,4 +2,4 @@ import swh.core -SQL_DIR = path.join(path.dirname(swh.core.__file__), 'sql') +SQL_DIR = path.join(path.dirname(swh.core.__file__), "sql") diff --git a/swh/core/tests/fixture/conftest.py b/swh/core/tests/fixture/conftest.py --- a/swh/core/tests/fixture/conftest.py +++ b/swh/core/tests/fixture/conftest.py @@ -8,7 +8,7 @@ from os import path -DATADIR = path.join(path.abspath(path.dirname(__file__)), 'data') +DATADIR = path.join(path.abspath(path.dirname(__file__)), "data") @pytest.fixture diff --git a/swh/core/tests/fixture/test_pytest_plugin.py b/swh/core/tests/fixture/test_pytest_plugin.py --- a/swh/core/tests/fixture/test_pytest_plugin.py +++ b/swh/core/tests/fixture/test_pytest_plugin.py @@ -11,14 +11,13 @@ # "datadir" fixture to specify where to retrieve the data files from. -def test_requests_mock_datadir_with_datadir_fixture_override( - requests_mock_datadir): +def test_requests_mock_datadir_with_datadir_fixture_override(requests_mock_datadir): """Override datadir fixture should retrieve data from elsewhere """ - response = requests.get('https://example.com/file.json') + response = requests.get("https://example.com/file.json") assert response.ok - assert response.json() == {'welcome': 'you'} + assert response.json() == {"welcome": "you"} def test_data_dir_override(datadir): diff --git a/swh/core/tests/test_cli.py b/swh/core/tests/test_cli.py --- a/swh/core/tests/test_cli.py +++ b/swh/core/tests/test_cli.py @@ -13,7 +13,7 @@ import pytest -help_msg = '''Usage: swh [OPTIONS] COMMAND [ARGS]... +help_msg = """Usage: swh [OPTIONS] COMMAND [ARGS]... Command line interface for Software Heritage. @@ -30,172 +30,157 @@ The --log-config YAML must conform to the logging.config.dictConfig schema documented at https://docs.python.org/3/library/logging.config.html. -''' +""" def test_swh_help(swhmain): runner = CliRunner() - result = runner.invoke(swhmain, ['-h']) + result = runner.invoke(swhmain, ["-h"]) assert result.exit_code == 0 assert result.output.startswith(help_msg) - result = runner.invoke(swhmain, ['--help']) + result = runner.invoke(swhmain, ["--help"]) assert result.exit_code == 0 assert result.output.startswith(help_msg) def test_command(swhmain): - @swhmain.command(name='test') + @swhmain.command(name="test") @click.pass_context def swhtest(ctx): - click.echo('Hello SWH!') + click.echo("Hello SWH!") runner = CliRunner() - with patch('sentry_sdk.init') as sentry_sdk_init: - result = runner.invoke(swhmain, ['test']) + with patch("sentry_sdk.init") as sentry_sdk_init: + result = runner.invoke(swhmain, ["test"]) sentry_sdk_init.assert_not_called() assert result.exit_code == 0 - assert result.output.strip() == 'Hello SWH!' + assert result.output.strip() == "Hello SWH!" def test_loglevel_default(caplog, swhmain): - @swhmain.command(name='test') + @swhmain.command(name="test") @click.pass_context def swhtest(ctx): assert logging.root.level == 20 - click.echo('Hello SWH!') + click.echo("Hello SWH!") runner = CliRunner() - result = runner.invoke(swhmain, ['test']) + result = runner.invoke(swhmain, ["test"]) assert result.exit_code == 0 print(result.output) - assert result.output.strip() == '''Hello SWH!''' + assert result.output.strip() == """Hello SWH!""" def test_loglevel_error(caplog, swhmain): - @swhmain.command(name='test') + @swhmain.command(name="test") @click.pass_context def swhtest(ctx): assert logging.root.level == 40 - click.echo('Hello SWH!') + click.echo("Hello SWH!") runner = CliRunner() - result = runner.invoke(swhmain, ['-l', 'ERROR', 'test']) + result = runner.invoke(swhmain, ["-l", "ERROR", "test"]) assert result.exit_code == 0 - assert result.output.strip() == '''Hello SWH!''' + assert result.output.strip() == """Hello SWH!""" def test_loglevel_debug(caplog, swhmain): - @swhmain.command(name='test') + @swhmain.command(name="test") @click.pass_context def swhtest(ctx): assert logging.root.level == 10 - click.echo('Hello SWH!') + click.echo("Hello SWH!") runner = CliRunner() - result = runner.invoke(swhmain, ['-l', 'DEBUG', 'test']) + result = runner.invoke(swhmain, ["-l", "DEBUG", "test"]) assert result.exit_code == 0 - assert result.output.strip() == '''Hello SWH!''' + assert result.output.strip() == """Hello SWH!""" def test_sentry(swhmain): - @swhmain.command(name='test') + @swhmain.command(name="test") @click.pass_context def swhtest(ctx): - click.echo('Hello SWH!') + click.echo("Hello SWH!") runner = CliRunner() - with patch('sentry_sdk.init') as sentry_sdk_init: - result = runner.invoke(swhmain, ['--sentry-dsn', 'test_dsn', 'test']) + with patch("sentry_sdk.init") as sentry_sdk_init: + result = runner.invoke(swhmain, ["--sentry-dsn", "test_dsn", "test"]) assert result.exit_code == 0 - assert result.output.strip() == '''Hello SWH!''' + assert result.output.strip() == """Hello SWH!""" sentry_sdk_init.assert_called_once_with( - dsn='test_dsn', - debug=False, - integrations=[], - release=None, - environment=None, + dsn="test_dsn", debug=False, integrations=[], release=None, environment=None ) def test_sentry_debug(swhmain): - @swhmain.command(name='test') + @swhmain.command(name="test") @click.pass_context def swhtest(ctx): - click.echo('Hello SWH!') + click.echo("Hello SWH!") runner = CliRunner() - with patch('sentry_sdk.init') as sentry_sdk_init: + with patch("sentry_sdk.init") as sentry_sdk_init: result = runner.invoke( - swhmain, ['--sentry-dsn', 'test_dsn', '--sentry-debug', 'test']) + swhmain, ["--sentry-dsn", "test_dsn", "--sentry-debug", "test"] + ) assert result.exit_code == 0 - assert result.output.strip() == '''Hello SWH!''' + assert result.output.strip() == """Hello SWH!""" sentry_sdk_init.assert_called_once_with( - dsn='test_dsn', - debug=True, - integrations=[], - release=None, - environment=None, + dsn="test_dsn", debug=True, integrations=[], release=None, environment=None ) def test_sentry_env(swhmain): - @swhmain.command(name='test') + @swhmain.command(name="test") @click.pass_context def swhtest(ctx): - click.echo('Hello SWH!') + click.echo("Hello SWH!") runner = CliRunner() - with patch('sentry_sdk.init') as sentry_sdk_init: - env = { - 'SWH_SENTRY_DSN': 'test_dsn', - 'SWH_SENTRY_DEBUG': '1', - } - result = runner.invoke( - swhmain, ['test'], env=env, auto_envvar_prefix='SWH') + with patch("sentry_sdk.init") as sentry_sdk_init: + env = {"SWH_SENTRY_DSN": "test_dsn", "SWH_SENTRY_DEBUG": "1"} + result = runner.invoke(swhmain, ["test"], env=env, auto_envvar_prefix="SWH") assert result.exit_code == 0 - assert result.output.strip() == '''Hello SWH!''' + assert result.output.strip() == """Hello SWH!""" sentry_sdk_init.assert_called_once_with( - dsn='test_dsn', - debug=True, - integrations=[], - release=None, - environment=None, + dsn="test_dsn", debug=True, integrations=[], release=None, environment=None ) def test_sentry_env_main_package(swhmain): - @swhmain.command(name='test') + @swhmain.command(name="test") @click.pass_context def swhtest(ctx): - click.echo('Hello SWH!') + click.echo("Hello SWH!") runner = CliRunner() - with patch('sentry_sdk.init') as sentry_sdk_init: + with patch("sentry_sdk.init") as sentry_sdk_init: env = { - 'SWH_SENTRY_DSN': 'test_dsn', - 'SWH_MAIN_PACKAGE': 'swh.core', - 'SWH_SENTRY_ENVIRONMENT': 'tests', + "SWH_SENTRY_DSN": "test_dsn", + "SWH_MAIN_PACKAGE": "swh.core", + "SWH_SENTRY_ENVIRONMENT": "tests", } - result = runner.invoke( - swhmain, ['test'], env=env, auto_envvar_prefix='SWH') + result = runner.invoke(swhmain, ["test"], env=env, auto_envvar_prefix="SWH") assert result.exit_code == 0 - version = pkg_resources.get_distribution('swh.core').version + version = pkg_resources.get_distribution("swh.core").version - assert result.output.strip() == '''Hello SWH!''' + assert result.output.strip() == """Hello SWH!""" sentry_sdk_init.assert_called_once_with( - dsn='test_dsn', + dsn="test_dsn", debug=False, integrations=[], - release='swh.core@' + version, - environment='tests', + release="swh.core@" + version, + environment="tests", ) @pytest.fixture def log_config_path(tmp_path): - log_config = textwrap.dedent('''\ + log_config = textwrap.dedent( + """\ --- version: 1 formatters: @@ -214,85 +199,82 @@ loggers: dontshowdebug: level: INFO - ''') + """ + ) - (tmp_path / 'log_config.yml').write_text(log_config) + (tmp_path / "log_config.yml").write_text(log_config) - yield str(tmp_path / 'log_config.yml') + yield str(tmp_path / "log_config.yml") def test_log_config(caplog, log_config_path, swhmain): - @swhmain.command(name='test') + @swhmain.command(name="test") @click.pass_context def swhtest(ctx): - logging.debug('Root log debug') - logging.info('Root log info') - logging.getLogger('dontshowdebug').debug('Not shown') - logging.getLogger('dontshowdebug').info('Shown') + logging.debug("Root log debug") + logging.info("Root log info") + logging.getLogger("dontshowdebug").debug("Not shown") + logging.getLogger("dontshowdebug").info("Shown") runner = CliRunner() - result = runner.invoke( - swhmain, [ - '--log-config', log_config_path, - 'test', - ], - ) + result = runner.invoke(swhmain, ["--log-config", log_config_path, "test"]) assert result.exit_code == 0 - assert result.output.strip() == '\n'.join([ - 'custom format:root:DEBUG:Root log debug', - 'custom format:root:INFO:Root log info', - 'custom format:dontshowdebug:INFO:Shown', - ]) + assert result.output.strip() == "\n".join( + [ + "custom format:root:DEBUG:Root log debug", + "custom format:root:INFO:Root log info", + "custom format:dontshowdebug:INFO:Shown", + ] + ) def test_log_config_log_level_interaction(caplog, log_config_path, swhmain): - @swhmain.command(name='test') + @swhmain.command(name="test") @click.pass_context def swhtest(ctx): - logging.debug('Root log debug') - logging.info('Root log info') - logging.getLogger('dontshowdebug').debug('Not shown') - logging.getLogger('dontshowdebug').info('Shown') + logging.debug("Root log debug") + logging.info("Root log info") + logging.getLogger("dontshowdebug").debug("Not shown") + logging.getLogger("dontshowdebug").info("Shown") runner = CliRunner() result = runner.invoke( - swhmain, [ - '--log-config', log_config_path, - '--log-level', 'INFO', - 'test', - ], + swhmain, ["--log-config", log_config_path, "--log-level", "INFO", "test"] ) assert result.exit_code == 0 - assert result.output.strip() == '\n'.join([ - 'custom format:root:INFO:Root log info', - 'custom format:dontshowdebug:INFO:Shown', - ]) + assert result.output.strip() == "\n".join( + [ + "custom format:root:INFO:Root log info", + "custom format:dontshowdebug:INFO:Shown", + ] + ) def test_aliased_command(swhmain): - @swhmain.command(name='canonical-test') + @swhmain.command(name="canonical-test") @click.pass_context def swhtest(ctx): - 'A test command.' - click.echo('Hello SWH!') - swhmain.add_alias(swhtest, 'othername') + "A test command." + click.echo("Hello SWH!") + + swhmain.add_alias(swhtest, "othername") runner = CliRunner() # check we have only 'canonical-test' listed in the usage help msg - result = runner.invoke(swhmain, ['-h']) + result = runner.invoke(swhmain, ["-h"]) assert result.exit_code == 0 - assert 'canonical-test A test command.' in result.output - assert 'othername' not in result.output + assert "canonical-test A test command." in result.output + assert "othername" not in result.output # check we can execute the cmd with 'canonical-test' - result = runner.invoke(swhmain, ['canonical-test']) + result = runner.invoke(swhmain, ["canonical-test"]) assert result.exit_code == 0 - assert result.output.strip() == '''Hello SWH!''' + assert result.output.strip() == """Hello SWH!""" # check we can also execute the cmd with the alias 'othername' - result = runner.invoke(swhmain, ['othername']) + result = runner.invoke(swhmain, ["othername"]) assert result.exit_code == 0 - assert result.output.strip() == '''Hello SWH!''' + assert result.output.strip() == """Hello SWH!""" diff --git a/swh/core/tests/test_config.py b/swh/core/tests/test_config.py --- a/swh/core/tests/test_config.py +++ b/swh/core/tests/test_config.py @@ -11,61 +11,57 @@ from swh.core import config pytest_v = pkg_resources.get_distribution("pytest").parsed_version -if pytest_v < pkg_resources.extern.packaging.version.parse('3.9'): +if pytest_v < pkg_resources.extern.packaging.version.parse("3.9"): + @pytest.fixture def tmp_path(request): import tempfile import pathlib + with tempfile.TemporaryDirectory() as tmpdir: yield pathlib.Path(tmpdir) default_conf = { - 'a': ('int', 2), - 'b': ('string', 'default-string'), - 'c': ('bool', True), - 'd': ('int', 10), - 'e': ('int', None), - 'f': ('bool', None), - 'g': ('string', None), - 'h': ('bool', True), - 'i': ('bool', True), - 'ls': ('list[str]', ['a', 'b', 'c']), - 'li': ('list[int]', [42, 43]), + "a": ("int", 2), + "b": ("string", "default-string"), + "c": ("bool", True), + "d": ("int", 10), + "e": ("int", None), + "f": ("bool", None), + "g": ("string", None), + "h": ("bool", True), + "i": ("bool", True), + "ls": ("list[str]", ["a", "b", "c"]), + "li": ("list[int]", [42, 43]), } -other_default_conf = { - 'a': ('int', 3), -} +other_default_conf = {"a": ("int", 3)} full_default_conf = default_conf.copy() -full_default_conf['a'] = other_default_conf['a'] +full_default_conf["a"] = other_default_conf["a"] -parsed_default_conf = { - key: value - for key, (type, value) - in default_conf.items() -} +parsed_default_conf = {key: value for key, (type, value) in default_conf.items()} parsed_conffile = { - 'a': 1, - 'b': 'this is a string', - 'c': True, - 'd': 10, - 'e': None, - 'f': None, - 'g': None, - 'h': False, - 'i': True, - 'ls': ['list', 'of', 'strings'], - 'li': [1, 2, 3, 4], + "a": 1, + "b": "this is a string", + "c": True, + "d": 10, + "e": None, + "f": None, + "g": None, + "h": False, + "i": True, + "ls": ["list", "of", "strings"], + "li": [1, 2, 3, 4], } @pytest.fixture def swh_config(tmp_path): # create a temporary folder - conffile = tmp_path / 'config.ini' + conffile = tmp_path / "config.ini" conf_contents = """[main] a = 1 b = this is a string @@ -74,7 +70,7 @@ ls = list, of, strings li = 1, 2, 3, 4 """ - conffile.open('w').write(conf_contents) + conffile.open("w").write(conf_contents) return conffile @@ -90,7 +86,7 @@ @pytest.fixture def swh_config_unreadable_dir(swh_config): # Create a proper configuration file in an unreadable directory - perms_broken_dir = swh_config.parent / 'unreadabledir' + perms_broken_dir = swh_config.parent / "unreadabledir" perms_broken_dir.mkdir() shutil.move(str(swh_config), str(perms_broken_dir)) os.chmod(str(perms_broken_dir), 0o000) @@ -102,7 +98,7 @@ @pytest.fixture def swh_config_empty(tmp_path): # create a temporary folder - conffile = tmp_path / 'config.ini' + conffile = tmp_path / "config.ini" conffile.touch() return conffile @@ -125,7 +121,7 @@ def test_support_non_existing_conffile(tmp_path): # when - res = config.read(str(tmp_path / 'void.ini'), default_conf) + res = config.read(str(tmp_path / "void.ini"), default_conf) # then assert res == parsed_default_conf @@ -158,7 +154,7 @@ def test_priority_read_nonexist_conf(swh_config): - noexist = str(swh_config.parent / 'void.ini') + noexist = str(swh_config.parent / "void.ini") # when res = config.priority_read([noexist, str(swh_config)], default_conf) @@ -167,146 +163,150 @@ def test_priority_read_conf_nonexist_empty(swh_config): - noexist = swh_config.parent / 'void.ini' - empty = swh_config.parent / 'empty.ini' + noexist = swh_config.parent / "void.ini" + empty = swh_config.parent / "empty.ini" empty.touch() # when - res = config.priority_read([str(p) for p in ( - swh_config, noexist, empty)], default_conf) + res = config.priority_read( + [str(p) for p in (swh_config, noexist, empty)], default_conf + ) # then assert res == parsed_conffile def test_priority_read_empty_conf_nonexist(swh_config): - noexist = swh_config.parent / 'void.ini' - empty = swh_config.parent / 'empty.ini' + noexist = swh_config.parent / "void.ini" + empty = swh_config.parent / "empty.ini" empty.touch() # when - res = config.priority_read([str(p) for p in ( - empty, swh_config, noexist)], default_conf) + res = config.priority_read( + [str(p) for p in (empty, swh_config, noexist)], default_conf + ) # then assert res == parsed_default_conf def test_swh_config_paths(): - res = config.swh_config_paths('foo/bar.ini') + res = config.swh_config_paths("foo/bar.ini") assert res == [ - '~/.config/swh/foo/bar.ini', - '~/.swh/foo/bar.ini', - '/etc/softwareheritage/foo/bar.ini', + "~/.config/swh/foo/bar.ini", + "~/.swh/foo/bar.ini", + "/etc/softwareheritage/foo/bar.ini", ] def test_prepare_folder(tmp_path): # given - conf = {'path1': str(tmp_path / 'path1'), - 'path2': str(tmp_path / 'path2' / 'depth1')} + conf = { + "path1": str(tmp_path / "path1"), + "path2": str(tmp_path / "path2" / "depth1"), + } # the folders does not exists - assert not os.path.exists(conf['path1']), "path1 should not exist." - assert not os.path.exists(conf['path2']), "path2 should not exist." + assert not os.path.exists(conf["path1"]), "path1 should not exist." + assert not os.path.exists(conf["path2"]), "path2 should not exist." # when - config.prepare_folders(conf, 'path1') + config.prepare_folders(conf, "path1") # path1 exists but not path2 - assert os.path.exists(conf['path1']), "path1 should now exist!" - assert not os.path.exists(conf['path2']), "path2 should not exist." + assert os.path.exists(conf["path1"]), "path1 should now exist!" + assert not os.path.exists(conf["path2"]), "path2 should not exist." # path1 already exists, skips it but creates path2 - config.prepare_folders(conf, 'path1', 'path2') + config.prepare_folders(conf, "path1", "path2") - assert os.path.exists(conf['path1']), "path1 should still exist!" - assert os.path.exists(conf['path2']), "path2 should now exist." + assert os.path.exists(conf["path1"]), "path1 should still exist!" + assert os.path.exists(conf["path2"]), "path2 should now exist." def test_merge_config(): cfg_a = { - 'a': 42, - 'b': [1, 2, 3], - 'c': None, - 'd': {'gheez': 27}, - 'e': { - 'ea': 'Mr. Bungle', - 'eb': None, - 'ec': [11, 12, 13], - 'ed': {'eda': 'Secret Chief 3', - 'edb': 'Faith No More'}, - 'ee': 451, + "a": 42, + "b": [1, 2, 3], + "c": None, + "d": {"gheez": 27}, + "e": { + "ea": "Mr. Bungle", + "eb": None, + "ec": [11, 12, 13], + "ed": {"eda": "Secret Chief 3", "edb": "Faith No More"}, + "ee": 451, }, - 'f': 'Janis', + "f": "Janis", } cfg_b = { - 'a': 43, - 'b': [41, 42, 43], - 'c': 'Tom Waits', - 'd': None, - 'e': { - 'ea': 'Igorrr', - 'ec': [51, 52], - 'ed': {'edb': 'Sleepytime Gorilla Museum', - 'edc': 'Nils Peter Molvaer'}, + "a": 43, + "b": [41, 42, 43], + "c": "Tom Waits", + "d": None, + "e": { + "ea": "Igorrr", + "ec": [51, 52], + "ed": {"edb": "Sleepytime Gorilla Museum", "edc": "Nils Peter Molvaer"}, }, - 'g': 'Hüsker Dü', + "g": "Hüsker Dü", } # merge A, B cfg_m = config.merge_configs(cfg_a, cfg_b) assert cfg_m == { - 'a': 43, # b takes precedence - 'b': [41, 42, 43], # b takes precedence - 'c': 'Tom Waits', # b takes precedence - 'd': None, # b['d'] takes precedence (explicit None) - 'e': { - 'ea': 'Igorrr', # a takes precedence - 'eb': None, # only in a - 'ec': [51, 52], # b takes precedence - 'ed': { - 'eda': 'Secret Chief 3', # only in a - 'edb': 'Sleepytime Gorilla Museum', # b takes precedence - 'edc': 'Nils Peter Molvaer'}, # only defined in b - 'ee': 451, + "a": 43, # b takes precedence + "b": [41, 42, 43], # b takes precedence + "c": "Tom Waits", # b takes precedence + "d": None, # b['d'] takes precedence (explicit None) + "e": { + "ea": "Igorrr", # a takes precedence + "eb": None, # only in a + "ec": [51, 52], # b takes precedence + "ed": { + "eda": "Secret Chief 3", # only in a + "edb": "Sleepytime Gorilla Museum", # b takes precedence + "edc": "Nils Peter Molvaer", + }, # only defined in b + "ee": 451, }, - 'f': 'Janis', # only defined in a - 'g': 'Hüsker Dü', # only defined in b + "f": "Janis", # only defined in a + "g": "Hüsker Dü", # only defined in b } # merge B, A cfg_m = config.merge_configs(cfg_b, cfg_a) assert cfg_m == { - 'a': 42, # a takes precedence - 'b': [1, 2, 3], # a takes precedence - 'c': None, # a takes precedence - 'd': {'gheez': 27}, # a takes precedence - 'e': { - 'ea': 'Mr. Bungle', # a takes precedence - 'eb': None, # only defined in a - 'ec': [11, 12, 13], # a takes precedence - 'ed': { - 'eda': 'Secret Chief 3', # only in a - 'edb': 'Faith No More', # a takes precedence - 'edc': 'Nils Peter Molvaer'}, # only in b - 'ee': 451, + "a": 42, # a takes precedence + "b": [1, 2, 3], # a takes precedence + "c": None, # a takes precedence + "d": {"gheez": 27}, # a takes precedence + "e": { + "ea": "Mr. Bungle", # a takes precedence + "eb": None, # only defined in a + "ec": [11, 12, 13], # a takes precedence + "ed": { + "eda": "Secret Chief 3", # only in a + "edb": "Faith No More", # a takes precedence + "edc": "Nils Peter Molvaer", + }, # only in b + "ee": 451, }, - 'f': 'Janis', # only in a - 'g': 'Hüsker Dü', # only in b + "f": "Janis", # only in a + "g": "Hüsker Dü", # only in b } def test_merge_config_type_error(): - for v in (1, 'str', None): + for v in (1, "str", None): with pytest.raises(TypeError): config.merge_configs(v, {}) with pytest.raises(TypeError): config.merge_configs({}, v) - for v in (1, 'str'): + for v in (1, "str"): with pytest.raises(TypeError): - config.merge_configs({'a': v}, {'a': {}}) + config.merge_configs({"a": v}, {"a": {}}) with pytest.raises(TypeError): - config.merge_configs({'a': {}}, {'a': v}) + config.merge_configs({"a": {}}, {"a": v}) diff --git a/swh/core/tests/test_logger.py b/swh/core/tests/test_logger.py --- a/swh/core/tests/test_logger.py +++ b/swh/core/tests/test_logger.py @@ -19,53 +19,51 @@ def test_db_level(): - assert logger.db_level_of_py_level(10) == 'debug' - assert logger.db_level_of_py_level(20) == 'info' - assert logger.db_level_of_py_level(30) == 'warning' - assert logger.db_level_of_py_level(40) == 'error' - assert logger.db_level_of_py_level(50) == 'critical' + assert logger.db_level_of_py_level(10) == "debug" + assert logger.db_level_of_py_level(20) == "info" + assert logger.db_level_of_py_level(30) == "warning" + assert logger.db_level_of_py_level(40) == "error" + assert logger.db_level_of_py_level(50) == "critical" def test_flatten_scalar(): - assert list(logger.flatten('')) == [('', '')] - assert list(logger.flatten('toto')) == [('', 'toto')] + assert list(logger.flatten("")) == [("", "")] + assert list(logger.flatten("toto")) == [("", "toto")] - assert list(logger.flatten(10)) == [('', 10)] - assert list(logger.flatten(10.5)) == [('', 10.5)] + assert list(logger.flatten(10)) == [("", 10)] + assert list(logger.flatten(10.5)) == [("", 10.5)] def test_flatten_list(): assert list(logger.flatten([])) == [] - assert list(logger.flatten([1])) == [('0', 1)] + assert list(logger.flatten([1])) == [("0", 1)] - assert list(logger.flatten([1, 2, ['a', 'b']])) == [ - ('0', 1), - ('1', 2), - ('2_0', 'a'), - ('2_1', 'b'), + assert list(logger.flatten([1, 2, ["a", "b"]])) == [ + ("0", 1), + ("1", 2), + ("2_0", "a"), + ("2_1", "b"), ] - assert list(logger.flatten([1, 2, ['a', ('x', 1)]])) == [ - ('0', 1), - ('1', 2), - ('2_0', 'a'), - ('2_1_0', 'x'), - ('2_1_1', 1), + assert list(logger.flatten([1, 2, ["a", ("x", 1)]])) == [ + ("0", 1), + ("1", 2), + ("2_0", "a"), + ("2_1_0", "x"), + ("2_1_1", 1), ] def test_flatten_dict(): assert list(logger.flatten({})) == [] - assert list(logger.flatten({'a': 1})) == [('a', 1)] - - assert sorted(logger.flatten({'a': 1, - 'b': (2, 3,), - 'c': {'d': 4, 'e': 'f'}})) == [ - ('a', 1), - ('b_0', 2), - ('b_1', 3), - ('c_d', 4), - ('c_e', 'f'), + assert list(logger.flatten({"a": 1})) == [("a", 1)] + + assert sorted(logger.flatten({"a": 1, "b": (2, 3), "c": {"d": 4, "e": "f"}})) == [ + ("a", 1), + ("b_0", 2), + ("b_1", 3), + ("c_d", 4), + ("c_e", "f"), ] @@ -74,57 +72,59 @@ str_d = str(d) assert list(logger.flatten(d)) == [("", str_d)] assert list(logger.flatten({"a": d})) == [("a", str_d)] - assert list(logger.flatten({"a": [d, d]})) == [ - ("a_0", str_d), ("a_1", str_d) - ] + assert list(logger.flatten({"a": [d, d]})) == [("a_0", str_d), ("a_1", str_d)] def test_stringify(): - assert logger.stringify(None) == 'None' - assert logger.stringify(123) == '123' - assert logger.stringify('abc') == 'abc' + assert logger.stringify(None) == "None" + assert logger.stringify(123) == "123" + assert logger.stringify("abc") == "abc" date = datetime(2019, 9, 1, 16, 32) - assert logger.stringify(date) == '2019-09-01T16:32:00' + assert logger.stringify(date) == "2019-09-01T16:32:00" tzdate = datetime(2019, 9, 1, 16, 32, tzinfo=pytz.utc) - assert logger.stringify(tzdate) == '2019-09-01T16:32:00+00:00' + assert logger.stringify(tzdate) == "2019-09-01T16:32:00+00:00" -@patch('swh.core.logger.send') +@patch("swh.core.logger.send") def test_journal_handler(send): - log = logging.getLogger('test_logger') + log = logging.getLogger("test_logger") log.addHandler(logger.JournalHandler()) log.setLevel(logging.DEBUG) - _, ln = log.info('hello world'), lineno() + _, ln = log.info("hello world"), lineno() send.assert_called_with( - 'hello world', + "hello world", CODE_FILE=__file__, - CODE_FUNC='test_journal_handler', + CODE_FUNC="test_journal_handler", CODE_LINE=ln, - LOGGER='test_logger', - PRIORITY='6', - THREAD_NAME='MainThread') + LOGGER="test_logger", + PRIORITY="6", + THREAD_NAME="MainThread", + ) -@patch('swh.core.logger.send') +@patch("swh.core.logger.send") def test_journal_handler_w_data(send): - log = logging.getLogger('test_logger') + log = logging.getLogger("test_logger") log.addHandler(logger.JournalHandler()) log.setLevel(logging.DEBUG) - _, ln = log.debug('something cool %s', ['with', {'extra': 'data'}]), lineno() # noqa + _, ln = ( + log.debug("something cool %s", ["with", {"extra": "data"}]), + lineno() - 1, + ) # noqa send.assert_called_with( "something cool ['with', {'extra': 'data'}]", CODE_FILE=__file__, - CODE_FUNC='test_journal_handler_w_data', + CODE_FUNC="test_journal_handler_w_data", CODE_LINE=ln, - LOGGER='test_logger', - PRIORITY='7', - THREAD_NAME='MainThread', - SWH_LOGGING_ARGS_0_0='with', - SWH_LOGGING_ARGS_0_1_EXTRA='data' + LOGGER="test_logger", + PRIORITY="7", + THREAD_NAME="MainThread", + SWH_LOGGING_ARGS_0_0="with", + SWH_LOGGING_ARGS_0_1_EXTRA="data", ) diff --git a/swh/core/tests/test_pytest_plugin.py b/swh/core/tests/test_pytest_plugin.py --- a/swh/core/tests/test_pytest_plugin.py +++ b/swh/core/tests/test_pytest_plugin.py @@ -14,10 +14,11 @@ def test_get_response_cb_with_encoded_url(requests_mock_datadir): # The following urls (quoted, unquoted) will be resolved as the same file for encoded_url, expected_response in [ - ('https://forge.s.o/api/diffusion?attachments%5Buris%5D=1', - "something"), - ('https://www.reference.com/web?q=What+Is+an+Example+of+a+URL?&qo=contentPageRelatedSearch&o=600605&l=dir&sga=1', # noqa - "something else"), + ("https://forge.s.o/api/diffusion?attachments%5Buris%5D=1", "something"), + ( + "https://www.reference.com/web?q=What+Is+an+Example+of+a+URL?&qo=contentPageRelatedSearch&o=600605&l=dir&sga=1", # noqa + "something else", + ), ]: for url in [encoded_url, unquote(encoded_url)]: response = requests.get(url) @@ -26,95 +27,91 @@ def test_get_response_cb_with_visits_nominal(requests_mock_datadir_visits): - response = requests.get('https://example.com/file.json') + response = requests.get("https://example.com/file.json") assert response.ok - assert response.json() == {'hello': 'you'} + assert response.json() == {"hello": "you"} - response = requests.get('http://example.com/something.json') + response = requests.get("http://example.com/something.json") assert response.ok assert response.json() == "something" - response = requests.get('https://example.com/file.json') + response = requests.get("https://example.com/file.json") assert response.ok - assert response.json() == {'hello': 'world'} + assert response.json() == {"hello": "world"} - response = requests.get('https://example.com/file.json') + response = requests.get("https://example.com/file.json") assert not response.ok assert response.status_code == 404 def test_get_response_cb_with_visits(requests_mock_datadir_visits): - response = requests.get('https://example.com/file.json') + response = requests.get("https://example.com/file.json") assert response.ok - assert response.json() == {'hello': 'you'} + assert response.json() == {"hello": "you"} - response = requests.get('https://example.com/other.json') + response = requests.get("https://example.com/other.json") assert response.ok assert response.json() == "foobar" - response = requests.get('https://example.com/file.json') + response = requests.get("https://example.com/file.json") assert response.ok - assert response.json() == {'hello': 'world'} + assert response.json() == {"hello": "world"} - response = requests.get('https://example.com/other.json') + response = requests.get("https://example.com/other.json") assert not response.ok assert response.status_code == 404 - response = requests.get('https://example.com/file.json') + response = requests.get("https://example.com/file.json") assert not response.ok assert response.status_code == 404 def test_get_response_cb_no_visit(requests_mock_datadir): - response = requests.get('https://example.com/file.json') + response = requests.get("https://example.com/file.json") assert response.ok - assert response.json() == {'hello': 'you'} + assert response.json() == {"hello": "you"} - response = requests.get('https://example.com/file.json') + response = requests.get("https://example.com/file.json") assert response.ok - assert response.json() == {'hello': 'you'} + assert response.json() == {"hello": "you"} def test_get_response_cb_query_params(requests_mock_datadir): - response = requests.get('https://example.com/file.json?toto=42') + response = requests.get("https://example.com/file.json?toto=42") assert not response.ok assert response.status_code == 404 - response = requests.get( - 'https://example.com/file.json?name=doe&firstname=jane') + response = requests.get("https://example.com/file.json?name=doe&firstname=jane") assert response.ok - assert response.json() == {'hello': 'jane doe'} + assert response.json() == {"hello": "jane doe"} requests_mock_datadir_ignore = requests_mock_datadir_factory( - ignore_urls=['https://example.com/file.json'], - has_multi_visit=False, + ignore_urls=["https://example.com/file.json"], has_multi_visit=False ) def test_get_response_cb_ignore_url(requests_mock_datadir_ignore): - response = requests.get('https://example.com/file.json') + response = requests.get("https://example.com/file.json") assert not response.ok assert response.status_code == 404 requests_mock_datadir_ignore_and_visit = requests_mock_datadir_factory( - ignore_urls=['https://example.com/file.json'], - has_multi_visit=True, + ignore_urls=["https://example.com/file.json"], has_multi_visit=True ) -def test_get_response_cb_ignore_url_with_visit( - requests_mock_datadir_ignore_and_visit): - response = requests.get('https://example.com/file.json') +def test_get_response_cb_ignore_url_with_visit(requests_mock_datadir_ignore_and_visit): + response = requests.get("https://example.com/file.json") assert not response.ok assert response.status_code == 404 - response = requests.get('https://example.com/file.json') + response = requests.get("https://example.com/file.json") assert not response.ok assert response.status_code == 404 def test_data_dir(datadir): - expected_datadir = path.join(path.abspath(path.dirname(__file__)), 'data') + expected_datadir = path.join(path.abspath(path.dirname(__file__)), "data") assert datadir == expected_datadir diff --git a/swh/core/tests/test_statsd.py b/swh/core/tests/test_statsd.py --- a/swh/core/tests/test_statsd.py +++ b/swh/core/tests/test_statsd.py @@ -76,7 +76,7 @@ def recv(self): try: - return self.payloads.popleft().decode('utf-8') + return self.payloads.popleft().decode("utf-8") except IndexError: return None @@ -98,7 +98,6 @@ class TestStatsd(unittest.TestCase): - def setUp(self): """ Set up a default Statsd instance and mock the socket. @@ -111,139 +110,130 @@ return self.statsd.socket.recv() def test_set(self): - self.statsd.set('set', 123) - assert self.recv() == 'set:123|s' + self.statsd.set("set", 123) + assert self.recv() == "set:123|s" def test_gauge(self): - self.statsd.gauge('gauge', 123.4) - assert self.recv() == 'gauge:123.4|g' + self.statsd.gauge("gauge", 123.4) + assert self.recv() == "gauge:123.4|g" def test_counter(self): - self.statsd.increment('page.views') - self.assertEqual('page.views:1|c', self.recv()) + self.statsd.increment("page.views") + self.assertEqual("page.views:1|c", self.recv()) - self.statsd.increment('page.views', 11) - self.assertEqual('page.views:11|c', self.recv()) + self.statsd.increment("page.views", 11) + self.assertEqual("page.views:11|c", self.recv()) - self.statsd.decrement('page.views') - self.assertEqual('page.views:-1|c', self.recv()) + self.statsd.decrement("page.views") + self.assertEqual("page.views:-1|c", self.recv()) - self.statsd.decrement('page.views', 12) - self.assertEqual('page.views:-12|c', self.recv()) + self.statsd.decrement("page.views", 12) + self.assertEqual("page.views:-12|c", self.recv()) def test_histogram(self): - self.statsd.histogram('histo', 123.4) - self.assertEqual('histo:123.4|h', self.recv()) + self.statsd.histogram("histo", 123.4) + self.assertEqual("histo:123.4|h", self.recv()) def test_tagged_gauge(self): - self.statsd.gauge('gt', 123.4, tags={'country': 'china', 'age': 45}) - self.assertEqual('gt:123.4|g|#age:45,country:china', self.recv()) + self.statsd.gauge("gt", 123.4, tags={"country": "china", "age": 45}) + self.assertEqual("gt:123.4|g|#age:45,country:china", self.recv()) def test_tagged_counter(self): - self.statsd.increment('ct', tags={'country': 'españa'}) - self.assertEqual('ct:1|c|#country:españa', self.recv()) + self.statsd.increment("ct", tags={"country": "españa"}) + self.assertEqual("ct:1|c|#country:españa", self.recv()) def test_tagged_histogram(self): - self.statsd.histogram('h', 1, tags={'test_tag': 'tag_value'}) - self.assertEqual('h:1|h|#test_tag:tag_value', self.recv()) + self.statsd.histogram("h", 1, tags={"test_tag": "tag_value"}) + self.assertEqual("h:1|h|#test_tag:tag_value", self.recv()) def test_sample_rate(self): - self.statsd.increment('c', sample_rate=0) + self.statsd.increment("c", sample_rate=0) assert not self.recv() for i in range(10000): - self.statsd.increment('sampled_counter', sample_rate=0.3) + self.statsd.increment("sampled_counter", sample_rate=0.3) self.assert_almost_equal(3000, len(self.statsd.socket.payloads), 150) - self.assertEqual('sampled_counter:1|c|@0.3', self.recv()) + self.assertEqual("sampled_counter:1|c|@0.3", self.recv()) def test_tags_and_samples(self): for i in range(100): - self.statsd.gauge('gst', 23, tags={"sampled": True}, - sample_rate=0.9) + self.statsd.gauge("gst", 23, tags={"sampled": True}, sample_rate=0.9) self.assert_almost_equal(90, len(self.statsd.socket.payloads), 10) - self.assertEqual('gst:23|g|@0.9|#sampled:True', self.recv()) + self.assertEqual("gst:23|g|@0.9|#sampled:True", self.recv()) def test_timing(self): - self.statsd.timing('t', 123) - self.assertEqual('t:123|ms', self.recv()) + self.statsd.timing("t", 123) + self.assertEqual("t:123|ms", self.recv()) def test_metric_namespace(self): """ Namespace prefixes all metric names. """ self.statsd.namespace = "foo" - self.statsd.gauge('gauge', 123.4) - self.assertEqual('foo.gauge:123.4|g', self.recv()) + self.statsd.gauge("gauge", 123.4) + self.assertEqual("foo.gauge:123.4|g", self.recv()) # Test Client level constant tags def test_gauge_constant_tags(self): - self.statsd.constant_tags = { - 'bar': 'baz', - } - self.statsd.gauge('gauge', 123.4) - assert self.recv() == 'gauge:123.4|g|#bar:baz' + self.statsd.constant_tags = {"bar": "baz"} + self.statsd.gauge("gauge", 123.4) + assert self.recv() == "gauge:123.4|g|#bar:baz" def test_counter_constant_tag_with_metric_level_tags(self): - self.statsd.constant_tags = { - 'bar': 'baz', - 'foo': True, - } - self.statsd.increment('page.views', tags={'extra': 'extra'}) - self.assertEqual( - 'page.views:1|c|#bar:baz,extra:extra,foo:True', - self.recv(), - ) + self.statsd.constant_tags = {"bar": "baz", "foo": True} + self.statsd.increment("page.views", tags={"extra": "extra"}) + self.assertEqual("page.views:1|c|#bar:baz,extra:extra,foo:True", self.recv()) def test_gauge_constant_tags_with_metric_level_tags_twice(self): - metric_level_tag = {'foo': 'bar'} - self.statsd.constant_tags = {'bar': 'baz'} - self.statsd.gauge('gauge', 123.4, tags=metric_level_tag) - assert self.recv() == 'gauge:123.4|g|#bar:baz,foo:bar' + metric_level_tag = {"foo": "bar"} + self.statsd.constant_tags = {"bar": "baz"} + self.statsd.gauge("gauge", 123.4, tags=metric_level_tag) + assert self.recv() == "gauge:123.4|g|#bar:baz,foo:bar" # sending metrics multiple times with same metric-level tags # should not duplicate the tags being sent - self.statsd.gauge('gauge', 123.4, tags=metric_level_tag) - assert self.recv() == 'gauge:123.4|g|#bar:baz,foo:bar' + self.statsd.gauge("gauge", 123.4, tags=metric_level_tag) + assert self.recv() == "gauge:123.4|g|#bar:baz,foo:bar" def assert_almost_equal(self, a, b, delta): self.assertTrue( - 0 <= abs(a - b) <= delta, - "%s - %s not within %s" % (a, b, delta) + 0 <= abs(a - b) <= delta, "%s - %s not within %s" % (a, b, delta) ) def test_socket_error(self): self.statsd._socket = BrokenSocket() - self.statsd.gauge('no error', 1) - assert True, 'success' + self.statsd.gauge("no error", 1) + assert True, "success" def test_socket_timeout(self): self.statsd._socket = SlowSocket() - self.statsd.gauge('no error', 1) - assert True, 'success' + self.statsd.gauge("no error", 1) + assert True, "success" def test_timed(self): """ Measure the distribution of a function's run time. """ - @self.statsd.timed('timed.test') + + @self.statsd.timed("timed.test") def func(a, b, c=1, d=1): """docstring""" time.sleep(0.5) return (a, b, c, d) - self.assertEqual('func', func.__name__) - self.assertEqual('docstring', func.__doc__) + self.assertEqual("func", func.__name__) + self.assertEqual("docstring", func.__doc__) result = func(1, 2, d=3) # Assert it handles args and kwargs correctly. self.assertEqual(result, (1, 2, 1, 3)) packet = self.recv() - name_value, type_ = packet.split('|') - name, value = name_value.split(':') + name_value, type_ = packet.split("|") + name, value = name_value.split(":") - self.assertEqual('ms', type_) - self.assertEqual('timed.test', name) + self.assertEqual("ms", type_) + self.assertEqual("timed.test", name) self.assert_almost_equal(500, float(value), 100) def test_timed_exception(self): @@ -251,27 +241,28 @@ Exception bubble out of the decorator and is reported to statsd as a dedicated counter. """ - @self.statsd.timed('timed.test') + + @self.statsd.timed("timed.test") def func(a, b, c=1, d=1): """docstring""" time.sleep(0.5) return (a / b, c, d) - self.assertEqual('func', func.__name__) - self.assertEqual('docstring', func.__doc__) + self.assertEqual("func", func.__name__) + self.assertEqual("docstring", func.__doc__) with self.assertRaises(ZeroDivisionError): func(1, 0) packet = self.recv() - name_value, type_ = packet.split('|') - name, value = name_value.split(':') + name_value, type_ = packet.split("|") + name, value = name_value.split(":") - self.assertEqual('c', type_) - self.assertEqual('timed.test_error_count', name) + self.assertEqual("c", type_) + self.assertEqual("timed.test_error_count", name) self.assertEqual(int(value), 1) - def test_timed_no_metric(self, ): + def test_timed_no_metric(self,): """ Test using a decorator without providing a metric. """ @@ -282,19 +273,19 @@ time.sleep(0.5) return (a, b, c, d) - self.assertEqual('func', func.__name__) - self.assertEqual('docstring', func.__doc__) + self.assertEqual("func", func.__name__) + self.assertEqual("docstring", func.__doc__) result = func(1, 2, d=3) # Assert it handles args and kwargs correctly. self.assertEqual(result, (1, 2, 1, 3)) packet = self.recv() - name_value, type_ = packet.split('|') - name, value = name_value.split(':') + name_value, type_ = packet.split("|") + name, value = name_value.split(":") - self.assertEqual('ms', type_) - self.assertEqual('swh.core.tests.test_statsd.func', name) + self.assertEqual("ms", type_) + self.assertEqual("swh.core.tests.test_statsd.func", name) self.assert_almost_equal(500, float(value), 100) def test_timed_coroutine(self): @@ -305,7 +296,7 @@ """ import asyncio - @self.statsd.timed('timed.test') + @self.statsd.timed("timed.test") @asyncio.coroutine def print_foo(): """docstring""" @@ -318,11 +309,11 @@ # Assert packet = self.recv() - name_value, type_ = packet.split('|') - name, value = name_value.split(':') + name_value, type_ = packet.split("|") + name, value = name_value.split(":") - self.assertEqual('ms', type_) - self.assertEqual('timed.test', name) + self.assertEqual("ms", type_) + self.assertEqual("timed.test", name) self.assert_almost_equal(500, float(value), 100) def test_timed_context(self): @@ -330,16 +321,16 @@ Measure the distribution of a context's run time. """ # In milliseconds - with self.statsd.timed('timed_context.test') as timer: + with self.statsd.timed("timed_context.test") as timer: self.assertIsInstance(timer, TimedContextManagerDecorator) time.sleep(0.5) packet = self.recv() - name_value, type_ = packet.split('|') - name, value = name_value.split(':') + name_value, type_ = packet.split("|") + name, value = name_value.split(":") - self.assertEqual('ms', type_) - self.assertEqual('timed_context.test', name) + self.assertEqual("ms", type_) + self.assertEqual("timed_context.test", name) self.assert_almost_equal(500, float(value), 100) self.assert_almost_equal(500, timer.elapsed, 100) @@ -348,11 +339,12 @@ Exception bubbles out of the `timed` context manager and is reported to statsd as a dedicated counter. """ + class ContextException(Exception): pass def func(self): - with self.statsd.timed('timed_context.test'): + with self.statsd.timed("timed_context.test"): time.sleep(0.5) raise ContextException() @@ -361,11 +353,11 @@ # Ensure the timing was recorded. packet = self.recv() - name_value, type_ = packet.split('|') - name, value = name_value.split(':') + name_value, type_ = packet.split("|") + name, value = name_value.split(":") - self.assertEqual('c', type_) - self.assertEqual('timed_context.test_error_count', name) + self.assertEqual("c", type_) + self.assertEqual("timed_context.test_error_count", name) self.assertEqual(int(value), 1) def test_timed_context_no_metric_name_exception(self): @@ -385,51 +377,51 @@ self.assertEqual(packet, None) def test_timed_start_stop_calls(self): - timer = self.statsd.timed('timed_context.test') + timer = self.statsd.timed("timed_context.test") timer.start() time.sleep(0.5) timer.stop() packet = self.recv() - name_value, type_ = packet.split('|') - name, value = name_value.split(':') + name_value, type_ = packet.split("|") + name, value = name_value.split(":") - self.assertEqual('ms', type_) - self.assertEqual('timed_context.test', name) + self.assertEqual("ms", type_) + self.assertEqual("timed_context.test", name) self.assert_almost_equal(500, float(value), 100) def test_batched(self): self.statsd.open_buffer() - self.statsd.gauge('page.views', 123) - self.statsd.timing('timer', 123) + self.statsd.gauge("page.views", 123) + self.statsd.timing("timer", 123) self.statsd.close_buffer() - self.assertEqual('page.views:123|g\ntimer:123|ms', self.recv()) + self.assertEqual("page.views:123|g\ntimer:123|ms", self.recv()) def test_context_manager(self): fake_socket = FakeSocket() with Statsd() as statsd: statsd._socket = fake_socket - statsd.gauge('page.views', 123) - statsd.timing('timer', 123) + statsd.gauge("page.views", 123) + statsd.timing("timer", 123) - self.assertEqual('page.views:123|g\ntimer:123|ms', fake_socket.recv()) + self.assertEqual("page.views:123|g\ntimer:123|ms", fake_socket.recv()) def test_batched_buffer_autoflush(self): fake_socket = FakeSocket() with Statsd() as statsd: statsd._socket = fake_socket for i in range(51): - statsd.increment('mycounter') + statsd.increment("mycounter") self.assertEqual( - '\n'.join(['mycounter:1|c' for i in range(50)]), - fake_socket.recv(), + "\n".join(["mycounter:1|c" for i in range(50)]), fake_socket.recv() ) - self.assertEqual('mycounter:1|c', fake_socket.recv()) + self.assertEqual("mycounter:1|c", fake_socket.recv()) def test_module_level_instance(self): from swh.core.statsd import statsd + self.assertTrue(isinstance(statsd, Statsd)) def test_instantiating_does_not_connect(self): @@ -451,113 +443,111 @@ self.assertNotEqual(FakeSocket(), local_statsd.socket) def test_tags_from_environment(self): - with preserve_envvars('STATSD_TAGS'): - os.environ['STATSD_TAGS'] = 'country:china,age:45' + with preserve_envvars("STATSD_TAGS"): + os.environ["STATSD_TAGS"] = "country:china,age:45" statsd = Statsd() statsd._socket = FakeSocket() - statsd.gauge('gt', 123.4) - self.assertEqual('gt:123.4|g|#age:45,country:china', - statsd.socket.recv()) + statsd.gauge("gt", 123.4) + self.assertEqual("gt:123.4|g|#age:45,country:china", statsd.socket.recv()) def test_tags_from_environment_and_constant(self): - with preserve_envvars('STATSD_TAGS'): - os.environ['STATSD_TAGS'] = 'country:china,age:45' - statsd = Statsd(constant_tags={'country': 'canada'}) + with preserve_envvars("STATSD_TAGS"): + os.environ["STATSD_TAGS"] = "country:china,age:45" + statsd = Statsd(constant_tags={"country": "canada"}) statsd._socket = FakeSocket() - statsd.gauge('gt', 123.4) - self.assertEqual('gt:123.4|g|#age:45,country:canada', - statsd.socket.recv()) + statsd.gauge("gt", 123.4) + self.assertEqual("gt:123.4|g|#age:45,country:canada", statsd.socket.recv()) def test_tags_from_environment_warning(self): - with preserve_envvars('STATSD_TAGS'): - os.environ['STATSD_TAGS'] = 'valid:tag,invalid_tag' + with preserve_envvars("STATSD_TAGS"): + os.environ["STATSD_TAGS"] = "valid:tag,invalid_tag" with pytest.warns(UserWarning) as record: statsd = Statsd() assert len(record) == 1 - assert 'invalid_tag' in record[0].message.args[0] - assert 'valid:tag' not in record[0].message.args[0] - assert statsd.constant_tags == {'valid': 'tag'} + assert "invalid_tag" in record[0].message.args[0] + assert "valid:tag" not in record[0].message.args[0] + assert statsd.constant_tags == {"valid": "tag"} def test_gauge_doesnt_send_none(self): - self.statsd.gauge('metric', None) + self.statsd.gauge("metric", None) assert self.recv() is None def test_increment_doesnt_send_none(self): - self.statsd.increment('metric', None) + self.statsd.increment("metric", None) assert self.recv() is None def test_decrement_doesnt_send_none(self): - self.statsd.decrement('metric', None) + self.statsd.decrement("metric", None) assert self.recv() is None def test_timing_doesnt_send_none(self): - self.statsd.timing('metric', None) + self.statsd.timing("metric", None) assert self.recv() is None def test_histogram_doesnt_send_none(self): - self.statsd.histogram('metric', None) + self.statsd.histogram("metric", None) assert self.recv() is None def test_param_host(self): - with preserve_envvars('STATSD_HOST', 'STATSD_PORT'): - os.environ['STATSD_HOST'] = 'test-value' - os.environ['STATSD_PORT'] = '' - local_statsd = Statsd(host='actual-test-value') + with preserve_envvars("STATSD_HOST", "STATSD_PORT"): + os.environ["STATSD_HOST"] = "test-value" + os.environ["STATSD_PORT"] = "" + local_statsd = Statsd(host="actual-test-value") - self.assertEqual(local_statsd.host, 'actual-test-value') + self.assertEqual(local_statsd.host, "actual-test-value") self.assertEqual(local_statsd.port, 8125) def test_param_port(self): - with preserve_envvars('STATSD_HOST', 'STATSD_PORT'): - os.environ['STATSD_HOST'] = '' - os.environ['STATSD_PORT'] = '12345' + with preserve_envvars("STATSD_HOST", "STATSD_PORT"): + os.environ["STATSD_HOST"] = "" + os.environ["STATSD_PORT"] = "12345" local_statsd = Statsd(port=4321) - self.assertEqual(local_statsd.host, 'localhost') + self.assertEqual(local_statsd.host, "localhost") self.assertEqual(local_statsd.port, 4321) def test_envvar_host(self): - with preserve_envvars('STATSD_HOST', 'STATSD_PORT'): - os.environ['STATSD_HOST'] = 'test-value' - os.environ['STATSD_PORT'] = '' + with preserve_envvars("STATSD_HOST", "STATSD_PORT"): + os.environ["STATSD_HOST"] = "test-value" + os.environ["STATSD_PORT"] = "" local_statsd = Statsd() - self.assertEqual(local_statsd.host, 'test-value') + self.assertEqual(local_statsd.host, "test-value") self.assertEqual(local_statsd.port, 8125) def test_envvar_port(self): - with preserve_envvars('STATSD_HOST', 'STATSD_PORT'): - os.environ['STATSD_HOST'] = '' - os.environ['STATSD_PORT'] = '12345' + with preserve_envvars("STATSD_HOST", "STATSD_PORT"): + os.environ["STATSD_HOST"] = "" + os.environ["STATSD_PORT"] = "12345" local_statsd = Statsd() - self.assertEqual(local_statsd.host, 'localhost') + self.assertEqual(local_statsd.host, "localhost") self.assertEqual(local_statsd.port, 12345) def test_namespace_added(self): - local_statsd = Statsd(namespace='test-namespace') + local_statsd = Statsd(namespace="test-namespace") local_statsd._socket = FakeSocket() - local_statsd.gauge('gauge', 123.4) - assert local_statsd.socket.recv() == 'test-namespace.gauge:123.4|g' + local_statsd.gauge("gauge", 123.4) + assert local_statsd.socket.recv() == "test-namespace.gauge:123.4|g" def test_contextmanager_empty(self): with self.statsd: - assert True, 'success' + assert True, "success" def test_contextmanager_buffering(self): with self.statsd as s: - s.gauge('gauge', 123.4) - s.gauge('gauge_other', 456.78) + s.gauge("gauge", 123.4) + s.gauge("gauge_other", 456.78) self.assertIsNone(s.socket.recv()) - self.assertEqual(self.recv(), 'gauge:123.4|g\ngauge_other:456.78|g') + self.assertEqual(self.recv(), "gauge:123.4|g\ngauge_other:456.78|g") def test_timed_elapsed(self): - with self.statsd.timed('test_timer') as t: + with self.statsd.timed("test_timer") as t: pass self.assertGreaterEqual(t.elapsed, 0) - self.assertEqual(self.recv(), 'test_timer:%s|ms' % t.elapsed) + self.assertEqual(self.recv(), "test_timer:%s|ms" % t.elapsed) diff --git a/swh/core/tests/test_tarball.py b/swh/core/tests/test_tarball.py --- a/swh/core/tests/test_tarball.py +++ b/swh/core/tests/test_tarball.py @@ -27,51 +27,52 @@ def test_compress_uncompress_zip(tmp_path): - tocompress = tmp_path / 'compressme' + tocompress = tmp_path / "compressme" tocompress.mkdir() for i in range(10): - fpath = tocompress / ('file%s.txt' % i) - fpath.write_text('content of file %s' % i) + fpath = tocompress / ("file%s.txt" % i) + fpath.write_text("content of file %s" % i) - zipfile = tmp_path / 'archive.zip' - tarball.compress(str(zipfile), 'zip', str(tocompress)) + zipfile = tmp_path / "archive.zip" + tarball.compress(str(zipfile), "zip", str(tocompress)) - destdir = tmp_path / 'destdir' + destdir = tmp_path / "destdir" tarball.uncompress(str(zipfile), str(destdir)) lsdir = sorted(x.name for x in destdir.iterdir()) - assert ['file%s.txt' % i for i in range(10)] == lsdir + assert ["file%s.txt" % i for i in range(10)] == lsdir def test_compress_uncompress_tar(tmp_path): - tocompress = tmp_path / 'compressme' + tocompress = tmp_path / "compressme" tocompress.mkdir() for i in range(10): - fpath = tocompress / ('file%s.txt' % i) - fpath.write_text('content of file %s' % i) + fpath = tocompress / ("file%s.txt" % i) + fpath.write_text("content of file %s" % i) - tarfile = tmp_path / 'archive.tar' - tarball.compress(str(tarfile), 'tar', str(tocompress)) + tarfile = tmp_path / "archive.tar" + tarball.compress(str(tarfile), "tar", str(tocompress)) - destdir = tmp_path / 'destdir' + destdir = tmp_path / "destdir" tarball.uncompress(str(tarfile), str(destdir)) lsdir = sorted(x.name for x in destdir.iterdir()) - assert ['file%s.txt' % i for i in range(10)] == lsdir + assert ["file%s.txt" % i for i in range(10)] == lsdir def test__unpack_tar_failure(tmp_path, datadir): """Unpack inexistent tarball should fail """ - tarpath = os.path.join(datadir, 'archives', 'inexistent-archive.tar.Z') + tarpath = os.path.join(datadir, "archives", "inexistent-archive.tar.Z") assert not os.path.exists(tarpath) - with pytest.raises(shutil.ReadError, - match=f'Unable to uncompress {tarpath} to {tmp_path}'): + with pytest.raises( + shutil.ReadError, match=f"Unable to uncompress {tarpath} to {tmp_path}" + ): tarball._unpack_tar(tarpath, tmp_path) @@ -79,15 +80,16 @@ """Unpack Existent tarball into an inexistent folder should fail """ - filename = 'groff-1.02.tar.Z' - tarpath = os.path.join(datadir, 'archives', filename) + filename = "groff-1.02.tar.Z" + tarpath = os.path.join(datadir, "archives", filename) assert os.path.exists(tarpath) - extract_dir = os.path.join(tmp_path, 'dir', 'inexistent') + extract_dir = os.path.join(tmp_path, "dir", "inexistent") - with pytest.raises(shutil.ReadError, - match=f'Unable to uncompress {tarpath} to {tmp_path}'): + with pytest.raises( + shutil.ReadError, match=f"Unable to uncompress {tarpath} to {tmp_path}" + ): tarball._unpack_tar(tarpath, extract_dir) @@ -95,13 +97,14 @@ """Unpack unsupported tarball should fail """ - filename = 'hello.zip' - tarpath = os.path.join(datadir, 'archives', filename) + filename = "hello.zip" + tarpath = os.path.join(datadir, "archives", filename) assert os.path.exists(tarpath) - with pytest.raises(shutil.ReadError, - match=f'Unable to uncompress {tarpath} to {tmp_path}'): + with pytest.raises( + shutil.ReadError, match=f"Unable to uncompress {tarpath} to {tmp_path}" + ): tarball._unpack_tar(tarpath, tmp_path) @@ -109,8 +112,8 @@ """Unpack supported tarball into an existent folder should be ok """ - filename = 'groff-1.02.tar.Z' - tarpath = os.path.join(datadir, 'archives', filename) + filename = "groff-1.02.tar.Z" + tarpath = os.path.join(datadir, "archives", filename) assert os.path.exists(tarpath) @@ -144,19 +147,18 @@ """High level call uncompression on un/supported tarballs """ - archive_dir = os.path.join(datadir, 'archives') + archive_dir = os.path.join(datadir, "archives") tarfiles = os.listdir(archive_dir) tarpaths = [os.path.join(archive_dir, tarfile) for tarfile in tarfiles] unsupported_tarpaths = [] for t in tarpaths: - if t.endswith('.Z') or t.endswith('.x') or t.endswith('.lz'): + if t.endswith(".Z") or t.endswith(".x") or t.endswith(".lz"): unsupported_tarpaths.append(t) # not supported yet for tarpath in unsupported_tarpaths: - with pytest.raises(ValueError, - match=f'Problem during unpacking {tarpath}.'): + with pytest.raises(ValueError, match=f"Problem during unpacking {tarpath}."): tarball.uncompress(tarpath, dest=tmp_path) # register those unsupported formats diff --git a/swh/core/tests/test_utils.py b/swh/core/tests/test_utils.py --- a/swh/core/tests/test_utils.py +++ b/swh/core/tests/test_utils.py @@ -9,7 +9,6 @@ class UtilsLib(unittest.TestCase): - def test_grouper(self): # given actual_data = utils.grouper((i for i in range(0, 9)), 2) @@ -31,18 +30,22 @@ def test_grouper_with_stop_value(self): # given - actual_data = utils.grouper(((i, i+1) for i in range(0, 9)), 2) + actual_data = utils.grouper(((i, i + 1) for i in range(0, 9)), 2) out = [] for d in actual_data: out.append(list(d)) # force generator resolution for checks - self.assertEqual(out, [ - [(0, 1), (1, 2)], - [(2, 3), (3, 4)], - [(4, 5), (5, 6)], - [(6, 7), (7, 8)], - [(8, 9)]]) + self.assertEqual( + out, + [ + [(0, 1), (1, 2)], + [(2, 3), (3, 4)], + [(4, 5), (5, 6)], + [(6, 7), (7, 8)], + [(8, 9)], + ], + ) # given actual_data = utils.grouper((i for i in range(9, 0, -1)), 4) @@ -54,87 +57,70 @@ self.assertEqual(out, [[9, 8, 7, 6], [5, 4, 3, 2], [1]]) def test_backslashescape_errors(self): - raw_data_err = b'abcd\x80' + raw_data_err = b"abcd\x80" with self.assertRaises(UnicodeDecodeError): - raw_data_err.decode('utf-8', 'strict') + raw_data_err.decode("utf-8", "strict") - self.assertEqual( - raw_data_err.decode('utf-8', 'backslashescape'), - 'abcd\\x80', - ) + self.assertEqual(raw_data_err.decode("utf-8", "backslashescape"), "abcd\\x80") - raw_data_ok = b'abcd\xc3\xa9' + raw_data_ok = b"abcd\xc3\xa9" self.assertEqual( - raw_data_ok.decode('utf-8', 'backslashescape'), - raw_data_ok.decode('utf-8', 'strict'), + raw_data_ok.decode("utf-8", "backslashescape"), + raw_data_ok.decode("utf-8", "strict"), ) - unicode_data = 'abcdef\u00a3' + unicode_data = "abcdef\u00a3" self.assertEqual( - unicode_data.encode('ascii', 'backslashescape'), - b'abcdef\\xa3', + unicode_data.encode("ascii", "backslashescape"), b"abcdef\\xa3" ) def test_encode_with_unescape(self): - valid_data = '\\x01020304\\x00' - valid_data_encoded = b'\x01020304\x00' + valid_data = "\\x01020304\\x00" + valid_data_encoded = b"\x01020304\x00" - self.assertEqual( - valid_data_encoded, - utils.encode_with_unescape(valid_data) - ) + self.assertEqual(valid_data_encoded, utils.encode_with_unescape(valid_data)) def test_encode_with_unescape_invalid_escape(self): - invalid_data = 'test\\abcd' + invalid_data = "test\\abcd" with self.assertRaises(ValueError) as exc: utils.encode_with_unescape(invalid_data) - self.assertIn('invalid escape', exc.exception.args[0]) - self.assertIn('position 4', exc.exception.args[0]) + self.assertIn("invalid escape", exc.exception.args[0]) + self.assertIn("position 4", exc.exception.args[0]) def test_decode_with_escape(self): - backslashes = b'foo\\bar\\\\baz' - backslashes_escaped = 'foo\\\\bar\\\\\\\\baz' + backslashes = b"foo\\bar\\\\baz" + backslashes_escaped = "foo\\\\bar\\\\\\\\baz" - self.assertEqual( - backslashes_escaped, - utils.decode_with_escape(backslashes), - ) + self.assertEqual(backslashes_escaped, utils.decode_with_escape(backslashes)) - valid_utf8 = b'foo\xc3\xa2' - valid_utf8_escaped = 'foo\u00e2' + valid_utf8 = b"foo\xc3\xa2" + valid_utf8_escaped = "foo\u00e2" - self.assertEqual( - valid_utf8_escaped, - utils.decode_with_escape(valid_utf8), - ) + self.assertEqual(valid_utf8_escaped, utils.decode_with_escape(valid_utf8)) - invalid_utf8 = b'foo\xa2' - invalid_utf8_escaped = 'foo\\xa2' + invalid_utf8 = b"foo\xa2" + invalid_utf8_escaped = "foo\\xa2" - self.assertEqual( - invalid_utf8_escaped, - utils.decode_with_escape(invalid_utf8), - ) + self.assertEqual(invalid_utf8_escaped, utils.decode_with_escape(invalid_utf8)) - valid_utf8_nul = b'foo\xc3\xa2\x00' - valid_utf8_nul_escaped = 'foo\u00e2\\x00' + valid_utf8_nul = b"foo\xc3\xa2\x00" + valid_utf8_nul_escaped = "foo\u00e2\\x00" self.assertEqual( - valid_utf8_nul_escaped, - utils.decode_with_escape(valid_utf8_nul), + valid_utf8_nul_escaped, utils.decode_with_escape(valid_utf8_nul) ) def test_commonname(self): # when - actual_commonname = utils.commonname('/some/where/to/', - '/some/where/to/go/to') + actual_commonname = utils.commonname("/some/where/to/", "/some/where/to/go/to") # then - self.assertEqual('go/to', actual_commonname) + self.assertEqual("go/to", actual_commonname) # when - actual_commonname2 = utils.commonname(b'/some/where/to/', - b'/some/where/to/go/to') + actual_commonname2 = utils.commonname( + b"/some/where/to/", b"/some/where/to/go/to" + ) # then - self.assertEqual(b'go/to', actual_commonname2) + self.assertEqual(b"go/to", actual_commonname2) diff --git a/swh/core/utils.py b/swh/core/utils.py --- a/swh/core/utils.py +++ b/swh/core/utils.py @@ -51,14 +51,14 @@ def backslashescape_errors(exception): if isinstance(exception, UnicodeDecodeError): - bad_data = exception.object[exception.start:exception.end] - escaped = ''.join(r'\x%02x' % x for x in bad_data) + bad_data = exception.object[exception.start : exception.end] + escaped = "".join(r"\x%02x" % x for x in bad_data) return escaped, exception.end return codecs.backslashreplace_errors(exception) -codecs.register_error('backslashescape', backslashescape_errors) +codecs.register_error("backslashescape", backslashescape_errors) def encode_with_unescape(value): @@ -68,17 +68,18 @@ odd_backslashes = False i = 0 while i < len(value): - if value[i] == '\\': + if value[i] == "\\": odd_backslashes = not odd_backslashes else: if odd_backslashes: - if value[i] != 'x': - raise ValueError('invalid escape for %r at position %d' % - (value, i-1)) + if value[i] != "x": + raise ValueError( + "invalid escape for %r at position %d" % (value, i - 1) + ) slices.append( - value[start:i-1].replace('\\\\', '\\').encode('utf-8') + value[start : i - 1].replace("\\\\", "\\").encode("utf-8") ) - slices.append(bytes.fromhex(value[i+1:i+3])) + slices.append(bytes.fromhex(value[i + 1 : i + 3])) odd_backslashes = False start = i = i + 3 @@ -86,11 +87,9 @@ i += 1 - slices.append( - value[start:i].replace('\\\\', '\\').encode('utf-8') - ) + slices.append(value[start:i].replace("\\\\", "\\").encode("utf-8")) - return b''.join(slices) + return b"".join(slices) def decode_with_escape(value): @@ -99,9 +98,9 @@ strings. """ # escape backslashes - value = value.replace(b'\\', b'\\\\') - value = value.replace(b'\x00', b'\\x00') - return value.decode('utf-8', 'backslashescape') + value = value.replace(b"\\", b"\\\\") + value = value.replace(b"\x00", b"\\x00") + return value.decode("utf-8", "backslashescape") def commonname(path0, path1, as_str=False): @@ -120,5 +119,5 @@ Typically used to sort sql/nn-swh-xxx.sql files. """ - num, rem = re.match(r'(\d*)(.*)', fname).groups() + num, rem = re.match(r"(\d*)(.*)", fname).groups() return (num and int(num) or 99, rem)