diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -22,6 +22,12 @@ language: system types: [python] +- repo: https://github.com/python/black + rev: 19.10b0 + hooks: + - id: black + language_version: python3.7 + # unfortunately, we are far from being able to enable this... # - repo: https://github.com/PyCQA/pydocstyle.git # rev: 4.0.0 @@ -32,13 +38,7 @@ # entry: pydocstyle --convention=google # language: python # types: [python] - -# black requires py3.6+ -#- repo: https://github.com/python/black -# rev: 19.3b0 -# hooks: -# - id: black -# language_version: python3 +# #- repo: https://github.com/asottile/blacken-docs # rev: v1.0.0-1 # hooks: diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,2 @@ +[tool.black] +skip-string-normalization = true diff --git a/setup.cfg b/setup.cfg new file mode 100644 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,5 @@ +[flake8] +# E203: whitespaces before ':' +# W503: line break before binary operator +ignore = E203,W503 +max-line-length = 88 diff --git a/swh/core/api/__init__.py b/swh/core/api/__init__.py --- a/swh/core/api/__init__.py +++ b/swh/core/api/__init__.py @@ -10,22 +10,26 @@ import pickle import requests -from typing import ( - Any, Callable, ClassVar, Dict, List, Optional, Tuple, Type, Union, -) +from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Type, Union from flask import Flask, Request, Response, request, abort from werkzeug.exceptions import HTTPException -from .serializers import (decode_response, - encode_data_client as encode_data, - msgpack_dumps, msgpack_loads, - json_dumps, json_loads, - exception_to_dict) +from .serializers import ( + decode_response, + encode_data_client as encode_data, + msgpack_dumps, + msgpack_loads, + json_dumps, + json_loads, + exception_to_dict, +) -from .negotiation import (Formatter as FormatterBase, - Negotiator as NegotiatorBase, - negotiate as _negotiate) +from .negotiation import ( + Formatter as FormatterBase, + Negotiator as NegotiatorBase, + negotiate as _negotiate, +) logger = logging.getLogger(__name__) @@ -33,10 +37,12 @@ # support for content negotiation + class Negotiator(NegotiatorBase): def best_mimetype(self): return request.accept_mimetypes.best_match( - self.accept_mimetypes, 'application/json') + self.accept_mimetypes, 'application/json' + ) def _abort(self, status_code, err=None): return abort(status_code, err) @@ -72,6 +78,7 @@ # base API classes + class RemoteException(Exception): """raised when remote returned an out-of-band failure notification, e.g., as a HTTP status code or serialized exception @@ -80,8 +87,12 @@ response: HTTP response corresponding to the failure """ - def __init__(self, payload: Optional[Any] = None, - response: Optional[requests.Response] = None): + + def __init__( + self, + payload: Optional[Any] = None, + response: Optional[requests.Response] = None, + ): if payload is not None: super().__init__(payload) else: @@ -89,11 +100,16 @@ self.response = response def __str__(self): - if self.args and isinstance(self.args[0], dict) \ - and 'type' in self.args[0] and 'args' in self.args[0]: + if ( + self.args + and isinstance(self.args[0], dict) + and 'type' in self.args[0] + and 'args' in self.args[0] + ): return ( f'') + f'{self.args[0]["type"]}: {self.args[0]["args"]}>' + ) else: return super().__str__() @@ -102,14 +118,15 @@ def dec(f): f._endpoint_path = path return f + return dec class APIError(Exception): """API Error""" + def __str__(self): - return ('An unexpected error occurred in the backend: {}' - .format(self.args)) + return 'An unexpected error occurred in the backend: {}'.format(self.args) class MetaRPCClient(type): @@ -117,6 +134,7 @@ of the database it is designed to access. See for example :class:`swh.indexer.storage.api.client.RemoteStorage`""" + def __new__(cls, name, bases, attributes): # For each method wrapped with @remote_api_endpoint in an API backend # (eg. :class:`swh.indexer.storage.IndexerStorage`), add a new @@ -142,8 +160,7 @@ @functools.wraps(meth) # Copy signature and doc def meth_(*args, **kwargs): # Match arguments and parameters - post_data = inspect.getcallargs( - wrapped_meth, *args, **kwargs) + post_data = inspect.getcallargs(wrapped_meth, *args, **kwargs) # Remove arguments that should not be passed self = post_data.pop('self') @@ -152,6 +169,7 @@ # Send the request. return self.post(meth._endpoint_path, post_data) + if meth_name not in attributes: attributes[meth_name] = meth_ @@ -186,9 +204,15 @@ """Value of `extra_decoders` passed to `json_loads` or `msgpack_loads` to be able to deserialize more object types.""" - def __init__(self, url, api_exception=None, - timeout=None, chunk_size=4096, - reraise_exceptions=None, **kwargs): + def __init__( + self, + url, + api_exception=None, + timeout=None, + chunk_size=4096, + reraise_exceptions=None, + **kwargs, + ): if api_exception: self.api_exception = api_exception if reraise_exceptions: @@ -199,7 +223,8 @@ adapter = requests.adapters.HTTPAdapter( max_retries=kwargs.get('max_retries', 3), pool_connections=kwargs.get('pool_connections', 20), - pool_maxsize=kwargs.get('pool_maxsize', 100)) + pool_maxsize=kwargs.get('pool_maxsize', 100), + ) self.session.mount(self.url, adapter) self.timeout = timeout @@ -216,10 +241,7 @@ if self.timeout and 'timeout' not in opts: opts['timeout'] = self.timeout try: - return getattr(self.session, verb)( - self._url(endpoint), - **opts - ) + return getattr(self.session, verb)(self._url(endpoint), **opts) except requests.exceptions.ConnectionError as e: raise self.api_exception(e) @@ -230,12 +252,16 @@ data = self._encode_data(data) chunk_size = opts.pop('chunk_size', self.chunk_size) response = self.raw_verb( - 'post', endpoint, data=data, - headers={'content-type': 'application/x-msgpack', - 'accept': 'application/x-msgpack'}, - **opts) - if opts.get('stream') or \ - response.headers.get('transfer-encoding') == 'chunked': + 'post', + endpoint, + data=data, + headers={ + 'content-type': 'application/x-msgpack', + 'accept': 'application/x-msgpack', + }, + **opts, + ) + if opts.get('stream') or response.headers.get('transfer-encoding') == 'chunked': self.raise_for_status(response) return response.iter_content(chunk_size) else: @@ -249,11 +275,9 @@ def get(self, endpoint, **opts): chunk_size = opts.pop('chunk_size', self.chunk_size) response = self.raw_verb( - 'get', endpoint, - headers={'accept': 'application/x-msgpack'}, - **opts) - if opts.get('stream') or \ - response.headers.get('transfer-encoding') == 'chunked': + 'get', endpoint, headers={'accept': 'application/x-msgpack'}, **opts + ) + if opts.get('stream') or response.headers.get('transfer-encoding') == 'chunked': self.raise_for_status(response) return response.iter_content(chunk_size) else: @@ -286,8 +310,9 @@ exception = exc_type(*data['exception']['args']) break else: - exception = RemoteException(payload=data['exception'], - response=response) + exception = RemoteException( + payload=data['exception'], response=response + ) else: exception = pickle.loads(data) @@ -296,8 +321,9 @@ if 'exception_pickled' in data: exception = pickle.loads(data['exception_pickled']) else: - exception = RemoteException(payload=data['exception'], - response=response) + exception = RemoteException( + payload=data['exception'], response=response + ) except (TypeError, pickle.UnpicklingError): raise RemoteException(payload=data, response=response) @@ -308,13 +334,13 @@ if status_class != 2: raise RemoteException( payload=f'API HTTP error: {status_code} {response.content}', - response=response) + response=response, + ) def _decode_response(self, response, check_status=True): if check_status: self.raise_for_status(response) - return decode_response( - response, extra_decoders=self.extra_type_decoders) + return decode_response(response, extra_decoders=self.extra_type_decoders) def __repr__(self): return '<{} url={}>'.format(self.__class__.__name__, self.url) @@ -322,6 +348,7 @@ class BytesRequest(Request): """Request with proper escaping of arbitrary byte sequences.""" + encoding = 'utf-8' encoding_errors = 'surrogateescape' @@ -333,13 +360,10 @@ def encode_data_server( - data, content_type='application/x-msgpack', extra_type_encoders=None): - encoded_data = ENCODERS[content_type]( - data, extra_encoders=extra_type_encoders) - return Response( - encoded_data, - mimetype=content_type, - ) + data, content_type='application/x-msgpack', extra_type_encoders=None +): + encoded_data = ENCODERS[content_type](data, extra_encoders=extra_type_encoders) + return Response(encoded_data, mimetype=content_type) def decode_request(request, extra_decoders=None): @@ -355,8 +379,7 @@ # Should not be needed any more with py37 r = json_loads(data.decode('utf-8'), extra_decoders=extra_decoders) else: - raise ValueError('Wrong content type `%s` for API request' - % content_type) + raise ValueError('Wrong content type `%s` for API request' % content_type) return r @@ -385,6 +408,7 @@ `backend_class`. If unset, defaults to calling `backend_class` constructor directly. """ + request_class = BytesRequest extra_type_encoders: List[Tuple[type, str, Callable]] = [] @@ -394,8 +418,7 @@ """Value of `extra_decoders` passed to `json_loads` or `msgpack_loads` to be able to deserialize more object types.""" - def __init__(self, *args, backend_class=None, backend_factory=None, - **kwargs): + def __init__(self, *args, backend_class=None, backend_factory=None, **kwargs): super().__init__(*args, **kwargs) self.backend_class = backend_class @@ -409,13 +432,12 @@ def __add_endpoint(self, meth_name, meth, backend_factory): from flask import request - @self.route('/'+meth._endpoint_path, methods=['POST']) + @self.route('/' + meth._endpoint_path, methods=['POST']) @negotiate(MsgpackFormatter, extra_encoders=self.extra_type_encoders) @negotiate(JSONFormatter, extra_encoders=self.extra_type_encoders) @functools.wraps(meth) # Copy signature and doc def _f(): # Call the actual code obj_meth = getattr(backend_factory(), meth_name) - kw = decode_request( - request, extra_decoders=self.extra_type_decoders) + kw = decode_request(request, extra_decoders=self.extra_type_decoders) return obj_meth(**kw) diff --git a/swh/core/api/asynchronous.py b/swh/core/api/asynchronous.py --- a/swh/core/api/asynchronous.py +++ b/swh/core/api/asynchronous.py @@ -21,8 +21,7 @@ def encode_msgpack(data, **kwargs): return aiohttp.web.Response( body=msgpack_dumps(data), - headers=multidict.MultiDict( - {'Content-Type': 'application/x-msgpack'}), + headers=multidict.MultiDict({'Content-Type': 'application/x-msgpack'}), **kwargs ) @@ -48,8 +47,7 @@ elif content_type == 'application/json': r = json_loads(data) else: - raise ValueError('Wrong content type `%s` for API request' - % content_type) + raise ValueError('Wrong content type `%s` for API request' % content_type) return r @@ -67,6 +65,7 @@ else: status = 500 return encode_data_server(res, status=status) + return middleware_handler @@ -79,19 +78,20 @@ middlewares = (error_middleware,) + middlewares # renderers are sorted in order of increasing desirability (!) # see mimeparse.best_match() docstring. - renderers = OrderedDict([ - ('application/json', render_json), - ('application/x-msgpack', render_msgpack), - ]) + renderers = OrderedDict( + [ + ('application/json', render_json), + ('application/x-msgpack', render_msgpack), + ] + ) nego_middleware = negotiation.negotiation_middleware( - renderers=renderers, - force_rendering=True) + renderers=renderers, force_rendering=True + ) middlewares = (nego_middleware,) + middlewares super().__init__(*args, middlewares=middlewares, **kwargs) -@deprecated(version='0.0.64', - reason='Use the RPCServerApp instead') +@deprecated(version='0.0.64', reason='Use the RPCServerApp instead') class SWHRemoteAPI(RPCServerApp): pass diff --git a/swh/core/api/gunicorn_config.py b/swh/core/api/gunicorn_config.py --- a/swh/core/api/gunicorn_config.py +++ b/swh/core/api/gunicorn_config.py @@ -16,15 +16,24 @@ def post_fork( - server, worker, *, default_sentry_dsn=None, flask=True, - sentry_integrations=None, extra_sentry_kwargs={}): + server, + worker, + *, + default_sentry_dsn=None, + flask=True, + sentry_integrations=None, + extra_sentry_kwargs={} +): # Initializes sentry as soon as possible in gunicorn's worker processes. sentry_integrations = sentry_integrations or [] if flask: from sentry_sdk.integrations.flask import FlaskIntegration + sentry_integrations.append(FlaskIntegration()) init_sentry( - default_sentry_dsn, integrations=sentry_integrations, - extra_kwargs=extra_sentry_kwargs) + default_sentry_dsn, + integrations=sentry_integrations, + extra_kwargs=extra_sentry_kwargs, + ) diff --git a/swh/core/api/negotiation.py b/swh/core/api/negotiation.py --- a/swh/core/api/negotiation.py +++ b/swh/core/api/negotiation.py @@ -27,8 +27,7 @@ from decorator import decorator from inspect import getcallargs -from typing import Any, List, Optional, Callable, \ - Type, NoReturn, DefaultDict +from typing import Any, List, Optional, Callable, Type, NoReturn, DefaultDict from requests import Response @@ -47,8 +46,8 @@ self.response_mimetype = self.mimetypes[0] except IndexError: raise NotImplementedError( - "%s.mimetypes should be a non-empty list" % - self.__class__.__name__) + "%s.mimetypes should be a non-empty list" % self.__class__.__name__ + ) else: self.response_mimetype = request_mimetype @@ -57,11 +56,13 @@ def render(self, obj: Any) -> bytes: raise NotImplementedError( - "render() should be implemented by Formatter subclasses") + "render() should be implemented by Formatter subclasses" + ) def __call__(self, obj: Any) -> Response: return self._make_response( - self.render(obj), content_type=self.response_mimetype) + self.render(obj), content_type=self.response_mimetype + ) def _make_response(self, body: bytes, content_type: str) -> Response: raise NotImplementedError( @@ -71,7 +72,6 @@ class Negotiator: - def __init__(self, func: Callable[..., Any]) -> None: self.func = func self._formatters: List[Type[Formatter]] = [] @@ -90,38 +90,35 @@ return formatter(result) - def register_formatter(self, formatter: Type[Formatter], - *args, **kwargs) -> None: + def register_formatter(self, formatter: Type[Formatter], *args, **kwargs) -> None: self._formatters.append(formatter) - self._formatters_by_format[formatter.format].append( - (formatter, args, kwargs)) + self._formatters_by_format[formatter.format].append((formatter, args, kwargs)) for mimetype in formatter.mimetypes: - self._formatters_by_mimetype[mimetype].append( - (formatter, args, kwargs)) + self._formatters_by_mimetype[mimetype].append((formatter, args, kwargs)) - def get_formatter(self, format: Optional[str] = None, - mimetype: Optional[str] = None) -> Formatter: + def get_formatter( + self, format: Optional[str] = None, mimetype: Optional[str] = None + ) -> Formatter: if format is None and mimetype is None: raise TypeError( "get_formatter expects one of the 'format' or 'mimetype' " - "kwargs to be set") + "kwargs to be set" + ) if format is not None: try: # the first added will be the most specific - formatter_cls, args, kwargs = ( - self._formatters_by_format[format][0]) + formatter_cls, args, kwargs = self._formatters_by_format[format][0] except IndexError: - raise FormatterNotFound( - "Formatter for format '%s' not found!" % format) + raise FormatterNotFound("Formatter for format '%s' not found!" % format) elif mimetype is not None: try: # the first added will be the most specific - formatter_cls, args, kwargs = ( - self._formatters_by_mimetype[mimetype][0]) + formatter_cls, args, kwargs = self._formatters_by_mimetype[mimetype][0] except IndexError: raise FormatterNotFound( - "Formatter for mimetype '%s' not found!" % mimetype) + "Formatter for mimetype '%s' not found!" % mimetype + ) formatter = formatter_cls(request_mimetype=mimetype) formatter.configure(*args, **kwargs) @@ -144,8 +141,9 @@ ) -def negotiate(negotiator_cls: Type[Negotiator], formatter_cls: Type[Formatter], - *args, **kwargs) -> Callable: +def negotiate( + negotiator_cls: Type[Negotiator], formatter_cls: Type[Formatter], *args, **kwargs +) -> Callable: def _negotiate(f, *args, **kwargs): return f.negotiator(*args, **kwargs) diff --git a/swh/core/api/serializers.py b/swh/core/api/serializers.py --- a/swh/core/api/serializers.py +++ b/swh/core/api/serializers.py @@ -21,13 +21,16 @@ ENCODERS = [ (arrow.Arrow, 'arrow', arrow.Arrow.isoformat), (datetime.datetime, 'datetime', datetime.datetime.isoformat), - (datetime.timedelta, 'timedelta', lambda o: { - 'days': o.days, - 'seconds': o.seconds, - 'microseconds': o.microseconds, - }), + ( + datetime.timedelta, + 'timedelta', + lambda o: { + 'days': o.days, + 'seconds': o.seconds, + 'microseconds': o.microseconds, + }, + ), (UUID, 'uuid', str), - # Only for JSON: (bytes, 'bytes', lambda o: base64.b85encode(o).decode('ascii')), ] @@ -37,7 +40,6 @@ 'datetime': lambda d: iso8601.parse_date(d, default_timezone=None), 'timedelta': lambda d: datetime.timedelta(**d), 'uuid': UUID, - # Only for JSON: 'bytes': base64.b85decode, } @@ -47,24 +49,20 @@ try: return msgpack_dumps(data, extra_encoders=extra_encoders) except OverflowError as e: - raise ValueError('Limits were reached. Please, check your input.\n' + - str(e)) + raise ValueError('Limits were reached. Please, check your input.\n' + str(e)) def decode_response(response: Response, extra_decoders=None) -> Any: content_type = response.headers['content-type'] if content_type.startswith('application/x-msgpack'): - r = msgpack_loads(response.content, - extra_decoders=extra_decoders) + r = msgpack_loads(response.content, extra_decoders=extra_decoders) elif content_type.startswith('application/json'): - r = json_loads(response.text, - extra_decoders=extra_decoders) + r = json_loads(response.text, extra_decoders=extra_decoders) elif content_type.startswith('text/'): r = response.text else: - raise ValueError('Wrong content type `%s` for API response' - % content_type) + raise ValueError('Wrong content type `%s` for API response' % content_type) return r @@ -99,14 +97,10 @@ if extra_encoders: self.encoders += extra_encoders - def default(self, o: Any - ) -> Union[Dict[str, Union[Dict[str, int], str]], list]: + def default(self, o: Any) -> Union[Dict[str, Union[Dict[str, int], str]], list]: for (type_, type_name, encoder) in self.encoders: if isinstance(o, type_): - return { - 'swhtype': type_name, - 'd': encoder(o), - } + return {'swhtype': type_name, 'd': encoder(o)} try: return super().default(o) except TypeError as e: @@ -164,13 +158,11 @@ def json_dumps(data: Any, extra_encoders=None) -> str: - return json.dumps(data, cls=SWHJSONEncoder, - extra_encoders=extra_encoders) + return json.dumps(data, cls=SWHJSONEncoder, extra_encoders=extra_encoders) def json_loads(data: str, extra_decoders=None) -> Any: - return json.loads(data, cls=SWHJSONDecoder, - extra_decoders=extra_decoders) + return json.loads(data, cls=SWHJSONDecoder, extra_decoders=extra_decoders) def msgpack_dumps(data: Any, extra_encoders=None) -> bytes: @@ -185,10 +177,7 @@ for (type_, type_name, encoder) in encoders: if isinstance(obj, type_): - return { - b'swhtype': type_name, - b'd': encoder(obj), - } + return {b'swhtype': type_name, b'd': encoder(obj)} return obj return msgpack.packb(data, use_bin_type=True, default=encode_types) @@ -222,15 +211,13 @@ try: try: - return msgpack.unpackb(data, raw=False, - object_hook=decode_types, - strict_map_key=False) + return msgpack.unpackb( + data, raw=False, object_hook=decode_types, strict_map_key=False + ) except TypeError: # msgpack < 0.6.0 - return msgpack.unpackb(data, raw=False, - object_hook=decode_types) + return msgpack.unpackb(data, raw=False, object_hook=decode_types) except TypeError: # msgpack < 0.5.2 - return msgpack.unpackb(data, encoding='utf-8', - object_hook=decode_types) + return msgpack.unpackb(data, encoding='utf-8', object_hook=decode_types) def exception_to_dict(exception): diff --git a/swh/core/api/tests/server_testing.py b/swh/core/api/tests/server_testing.py --- a/swh/core/api/tests/server_testing.py +++ b/swh/core/api/tests/server_testing.py @@ -29,6 +29,7 @@ class's setUp() and tearDown() methods. """ + def setUp(self): super().setUp() self.start_server() @@ -106,6 +107,7 @@ In order to correctly work, the subclass must call the parents class's setUp() and tearDown() methods. """ + def process_config(self): # WSGI app configuration for key, value in self.config.items(): @@ -136,9 +138,9 @@ class's setUp() and tearDown() methods. """ + def define_worker_function(self): def worker(app, port): - return aiohttp.web.run_app(app, port=int(port), - print=lambda *_: None) + return aiohttp.web.run_app(app, port=int(port), print=lambda *_: None) return worker diff --git a/swh/core/api/tests/test_async.py b/swh/core/api/tests/test_async.py --- a/swh/core/api/tests/test_async.py +++ b/swh/core/api/tests/test_async.py @@ -29,17 +29,17 @@ async def root(request): return Response('toor') -STRUCT = {'txt': 'something stupid', - # 'date': datetime.date(2019, 6, 9), # not supported - 'datetime': datetime.datetime(2019, 6, 9, 10, 12), - 'timedelta': datetime.timedelta(days=-2, hours=3), - 'int': 42, - 'float': 3.14, - 'subdata': {'int': 42, - 'datetime': datetime.datetime(2019, 6, 10, 11, 12), - }, - 'list': [42, datetime.datetime(2019, 9, 10, 11, 12), 'ok'], - } + +STRUCT = { + 'txt': 'something stupid', + # 'date': datetime.date(2019, 6, 9), # not supported + 'datetime': datetime.datetime(2019, 6, 9, 10, 12), + 'timedelta': datetime.timedelta(days=-2, hours=3), + 'int': 42, + 'float': 3.14, + 'subdata': {'int': 42, 'datetime': datetime.datetime(2019, 6, 10, 11, 12)}, + 'list': [42, datetime.datetime(2019, 9, 10, 11, 12), 'ok'], +} async def struct(request): @@ -137,8 +137,7 @@ """Test returned structured from a simple GET data is OK""" cli = await aiohttp_client(async_app) for ctype in ('x-msgpack', 'json'): - resp = await cli.get('/struct', - headers={'Accept': 'application/%s' % ctype}) + resp = await cli.get('/struct', headers={'Accept': 'application/%s' % ctype}) assert resp.status == 200 check_mimetype(resp.headers['Content-Type'], 'application/%s' % ctype) assert (await decode_request(resp)) == STRUCT @@ -151,7 +150,8 @@ resp = await cli.post( '/echo', headers={'Content-Type': 'application/x-msgpack'}, - data=msgpack_dumps({'toto': 42})) + data=msgpack_dumps({'toto': 42}), + ) assert resp.status == 200 check_mimetype(resp.headers['Content-Type'], 'application/x-msgpack') assert (await decode_request(resp)) == {'toto': 42} @@ -159,7 +159,8 @@ resp = await cli.post( '/echo', headers={'Content-Type': 'application/x-msgpack'}, - data=msgpack_dumps(STRUCT)) + data=msgpack_dumps(STRUCT), + ) assert resp.status == 200 check_mimetype(resp.headers['Content-Type'], 'application/x-msgpack') assert (await decode_request(resp)) == STRUCT @@ -172,7 +173,8 @@ resp = await cli.post( '/echo', headers={'Content-Type': 'application/json'}, - data=json.dumps({'toto': 42}, cls=SWHJSONEncoder)) + data=json.dumps({'toto': 42}, cls=SWHJSONEncoder), + ) assert resp.status == 200 check_mimetype(resp.headers['Content-Type'], 'application/x-msgpack') assert (await decode_request(resp)) == {'toto': 42} @@ -180,7 +182,8 @@ resp = await cli.post( '/echo', headers={'Content-Type': 'application/json'}, - data=json.dumps(STRUCT, cls=SWHJSONEncoder)) + data=json.dumps(STRUCT, cls=SWHJSONEncoder), + ) assert resp.status == 200 check_mimetype(resp.headers['Content-Type'], 'application/x-msgpack') # assert resp.headers['Content-Type'] == 'application/x-msgpack' @@ -197,9 +200,12 @@ for ctype in ('x-msgpack', 'json'): resp = await cli.post( '/echo', - headers={'Content-Type': 'application/json', - 'Accept': 'application/%s' % ctype}, - data=json.dumps(STRUCT, cls=SWHJSONEncoder)) + headers={ + 'Content-Type': 'application/json', + 'Accept': 'application/%s' % ctype, + }, + data=json.dumps(STRUCT, cls=SWHJSONEncoder), + ) assert resp.status == 200 check_mimetype(resp.headers['Content-Type'], 'application/%s' % ctype) assert (await decode_request(resp)) == STRUCT @@ -215,9 +221,12 @@ for ctype in ('x-msgpack', 'json'): resp = await cli.post( '/echo-no-nego', - headers={'Content-Type': 'application/json', - 'Accept': 'application/%s' % ctype}, - data=json.dumps(STRUCT, cls=SWHJSONEncoder)) + headers={ + 'Content-Type': 'application/json', + 'Accept': 'application/%s' % ctype, + }, + data=json.dumps(STRUCT, cls=SWHJSONEncoder), + ) assert resp.status == 200 check_mimetype(resp.headers['Content-Type'], 'application/x-msgpack') assert (await decode_request(resp)) == STRUCT diff --git a/swh/core/api/tests/test_gunicorn.py b/swh/core/api/tests/test_gunicorn.py --- a/swh/core/api/tests/test_gunicorn.py +++ b/swh/core/api/tests/test_gunicorn.py @@ -19,8 +19,9 @@ def test_post_fork_with_dsn_env(): flask_integration = object() # unique object to check for equality - with patch('sentry_sdk.integrations.flask.FlaskIntegration', - new=lambda: flask_integration): + with patch( + 'sentry_sdk.integrations.flask.FlaskIntegration', new=lambda: flask_integration + ): with patch('sentry_sdk.init') as sentry_sdk_init: with patch.dict(os.environ, {'SWH_SENTRY_DSN': 'test_dsn'}): gunicorn_config.post_fork(None, None) @@ -36,12 +37,18 @@ def test_post_fork_with_package_env(): flask_integration = object() # unique object to check for equality - with patch('sentry_sdk.integrations.flask.FlaskIntegration', - new=lambda: flask_integration): + with patch( + 'sentry_sdk.integrations.flask.FlaskIntegration', new=lambda: flask_integration + ): with patch('sentry_sdk.init') as sentry_sdk_init: - with patch.dict(os.environ, {'SWH_SENTRY_DSN': 'test_dsn', - 'SWH_SENTRY_ENVIRONMENT': 'tests', - 'SWH_MAIN_PACKAGE': 'swh.core'}): + with patch.dict( + os.environ, + { + 'SWH_SENTRY_DSN': 'test_dsn', + 'SWH_SENTRY_ENVIRONMENT': 'tests', + 'SWH_MAIN_PACKAGE': 'swh.core', + }, + ): gunicorn_config.post_fork(None, None) version = pkg_resources.get_distribution('swh.core').version @@ -57,11 +64,13 @@ def test_post_fork_debug(): flask_integration = object() # unique object to check for equality - with patch('sentry_sdk.integrations.flask.FlaskIntegration', - new=lambda: flask_integration): + with patch( + 'sentry_sdk.integrations.flask.FlaskIntegration', new=lambda: flask_integration + ): with patch('sentry_sdk.init') as sentry_sdk_init: - with patch.dict(os.environ, {'SWH_SENTRY_DSN': 'test_dsn', - 'SWH_SENTRY_DEBUG': '1'}): + with patch.dict( + os.environ, {'SWH_SENTRY_DSN': 'test_dsn', 'SWH_SENTRY_DEBUG': '1'} + ): gunicorn_config.post_fork(None, None) sentry_sdk_init.assert_called_once_with( @@ -79,23 +88,23 @@ gunicorn_config.post_fork(None, None, flask=False) sentry_sdk_init.assert_called_once_with( - dsn='test_dsn', - integrations=[], - debug=False, - release=None, - environment=None, + dsn='test_dsn', integrations=[], debug=False, release=None, environment=None ) def test_post_fork_extras(): flask_integration = object() # unique object to check for equality - with patch('sentry_sdk.integrations.flask.FlaskIntegration', - new=lambda: flask_integration): + with patch( + 'sentry_sdk.integrations.flask.FlaskIntegration', new=lambda: flask_integration + ): with patch('sentry_sdk.init') as sentry_sdk_init: with patch.dict(os.environ, {'SWH_SENTRY_DSN': 'test_dsn'}): gunicorn_config.post_fork( - None, None, sentry_integrations=['foo'], - extra_sentry_kwargs={'bar': 'baz'}) + None, + None, + sentry_integrations=['foo'], + extra_sentry_kwargs={'bar': 'baz'}, + ) sentry_sdk_init.assert_called_once_with( dsn='test_dsn', diff --git a/swh/core/api/tests/test_rpc_client.py b/swh/core/api/tests/test_rpc_client.py --- a/swh/core/api/tests/test_rpc_client.py +++ b/swh/core/api/tests/test_rpc_client.py @@ -48,13 +48,13 @@ elif request.path == '/serializer_test': context.content = ( b'\x82\xc4\x07swhtype\xa9extratype' - b'\xc4\x01d\x92\x81\xa4spam\xa3egg\xa3qux') + b'\xc4\x01d\x92\x81\xa4spam\xa3egg\xa3qux' + ) else: assert False return context.content - requests_mock.post(re.compile('mock://example.com/'), - content=callback) + requests_mock.post(re.compile('mock://example.com/'), content=callback) return Testclient(url='mock://example.com') diff --git a/swh/core/api/tests/test_rpc_client_server.py b/swh/core/api/tests/test_rpc_client_server.py --- a/swh/core/api/tests/test_rpc_client_server.py +++ b/swh/core/api/tests/test_rpc_client_server.py @@ -58,9 +58,11 @@ # This fixture is used by the 'swh_rpc_adapter' fixture # which is defined in swh/core/pytest_plugin.py application = RPCServerApp('testapp', backend_class=RPCTest) + @application.errorhandler(Exception) def my_error_handler(exception): return error_handler(exception, encode_data_server) + return application @@ -103,5 +105,7 @@ assert exc_info.value.args[0]['type'] == 'TypeError' assert exc_info.value.args[0]['args'] == ['Did I pass through?'] - assert str(exc_info.value) \ + assert ( + str(exc_info.value) == "" + ) diff --git a/swh/core/api/tests/test_rpc_server.py b/swh/core/api/tests/test_rpc_server.py --- a/swh/core/api/tests/test_rpc_server.py +++ b/swh/core/api/tests/test_rpc_server.py @@ -42,8 +42,7 @@ def test_api_endpoint(flask_app_client): res = flask_app_client.post( url_for('something'), - headers=[('Content-Type', 'application/json'), - ('Accept', 'application/json')], + headers=[('Content-Type', 'application/json'), ('Accept', 'application/json')], data=json.dumps({'data': 'toto'}), ) assert res.status_code == 200 @@ -64,8 +63,10 @@ def test_api_nego_accept(flask_app_client): res = flask_app_client.post( url_for('something'), - headers=[('Accept', 'application/x-msgpack'), - ('Content-Type', 'application/x-msgpack')], + headers=[ + ('Accept', 'application/x-msgpack'), + ('Content-Type', 'application/x-msgpack'), + ], data=msgpack.dumps({'data': 'toto'}), ) assert res.status_code == 200 @@ -76,9 +77,12 @@ def test_rpc_server(flask_app_client): res = flask_app_client.post( url_for('test_endpoint'), - headers=[('Content-Type', 'application/x-msgpack'), - ('Accept', 'application/x-msgpack')], - data=b'\x81\xa9test_data\xa4spam') + headers=[ + ('Content-Type', 'application/x-msgpack'), + ('Accept', 'application/x-msgpack'), + ], + data=b'\x81\xa9test_data\xa4spam', + ) assert res.status_code == 200 assert res.mimetype == 'application/x-msgpack' @@ -88,16 +92,19 @@ def test_rpc_server_extra_serializers(flask_app_client): res = flask_app_client.post( url_for('serializer_test'), - headers=[('Content-Type', 'application/x-msgpack'), - ('Accept', 'application/x-msgpack')], + headers=[ + ('Content-Type', 'application/x-msgpack'), + ('Accept', 'application/x-msgpack'), + ], data=b'\x81\xa4data\x92\xa3foo\x82\xc4\x07swhtype\xa9extratype' - b'\xc4\x01d\x92\xa3bar\xc4\x03baz') + b'\xc4\x01d\x92\xa3bar\xc4\x03baz', + ) assert res.status_code == 200 assert res.mimetype == 'application/x-msgpack' assert res.data == ( - b'\x82\xc4\x07swhtype\xa9extratype\xc4' - b'\x01d\x92\x81\xa4spam\xa3egg\xa3qux') + b'\x82\xc4\x07swhtype\xa9extratype\xc4' b'\x01d\x92\x81\xa4spam\xa3egg\xa3qux' + ) def test_api_negotiate_no_extra_encoders(app, flask_app_client): @@ -109,10 +116,7 @@ def endpoint(): return 'test' - res = flask_app_client.post( - url, - headers=[('Content-Type', 'application/json')], - ) + res = flask_app_client.post(url, headers=[('Content-Type', 'application/json')]) assert res.status_code == 200 assert res.mimetype == 'application/json' assert res.data == b'"test"' diff --git a/swh/core/api/tests/test_serializers.py b/swh/core/api/tests/test_serializers.py --- a/swh/core/api/tests/test_serializers.py +++ b/swh/core/api/tests/test_serializers.py @@ -18,7 +18,7 @@ SWHJSONEncoder, msgpack_dumps, msgpack_loads, - decode_response + decode_response, ) @@ -31,8 +31,10 @@ return f'ExtraType({self.arg1}, {self.arg2})' def __eq__(self, other): - return isinstance(other, ExtraType) \ - and (self.arg1, self.arg2) == (other.arg1, other.arg2) + return isinstance(other, ExtraType) and (self.arg1, self.arg2) == ( + other.arg1, + other.arg2, + ) extra_encoders: List[Tuple[type, str, Callable[..., Any]]] = [ @@ -40,9 +42,7 @@ ] -extra_decoders = { - 'extratype': lambda o: ExtraType(*o), -} +extra_decoders = {'extratype': lambda o: ExtraType(*o)} class Serializers(unittest.TestCase): @@ -51,11 +51,13 @@ self.data = { 'bytes': b'123456789\x99\xaf\xff\x00\x12', - 'datetime_naive': datetime.datetime(2015, 1, 1, 12, 4, 42, 231455), - 'datetime_tz': datetime.datetime(2015, 3, 4, 18, 25, 13, 1234, - tzinfo=self.tz), - 'datetime_utc': datetime.datetime(2015, 3, 4, 18, 25, 13, 1234, - tzinfo=datetime.timezone.utc), + 'datetime_naive': datetime.datetime(2015, 1, 1, 12, 4, 42, 231_455), + 'datetime_tz': datetime.datetime( + 2015, 3, 4, 18, 25, 13, 1234, tzinfo=self.tz + ), + 'datetime_utc': datetime.datetime( + 2015, 3, 4, 18, 25, 13, 1234, tzinfo=datetime.timezone.utc + ), 'datetime_delta': datetime.timedelta(64), 'arrow_date': arrow.get('2018-04-25T16:17:53.533672+00:00'), 'swhtype': 'fake', @@ -66,22 +68,27 @@ self.encoded_data = { 'bytes': {'swhtype': 'bytes', 'd': 'F)}kWH8wXmIhn8j01^'}, - 'datetime_naive': {'swhtype': 'datetime', - 'd': '2015-01-01T12:04:42.231455'}, - 'datetime_tz': {'swhtype': 'datetime', - 'd': '2015-03-04T18:25:13.001234+01:58'}, - 'datetime_utc': {'swhtype': 'datetime', - 'd': '2015-03-04T18:25:13.001234+00:00'}, - 'datetime_delta': {'swhtype': 'timedelta', - 'd': {'days': 64, 'seconds': 0, - 'microseconds': 0}}, - 'arrow_date': {'swhtype': 'arrow', - 'd': '2018-04-25T16:17:53.533672+00:00'}, + 'datetime_naive': { + 'swhtype': 'datetime', + 'd': '2015-01-01T12:04:42.231455', + }, + 'datetime_tz': { + 'swhtype': 'datetime', + 'd': '2015-03-04T18:25:13.001234+01:58', + }, + 'datetime_utc': { + 'swhtype': 'datetime', + 'd': '2015-03-04T18:25:13.001234+00:00', + }, + 'datetime_delta': { + 'swhtype': 'timedelta', + 'd': {'days': 64, 'seconds': 0, 'microseconds': 0}, + }, + 'arrow_date': {'swhtype': 'arrow', 'd': '2018-04-25T16:17:53.533672+00:00'}, 'swhtype': 'fake', 'swh_dict': {'swhtype': 42, 'd': 'test'}, 'random_dict': {'swhtype': 43}, - 'uuid': {'swhtype': 'uuid', - 'd': 'cdd8f804-9db6-40c3-93ab-5955d3836234'}, + 'uuid': {'swhtype': 'uuid', 'd': 'cdd8f804-9db6-40c3-93ab-5955d3836234'}, } self.legacy_msgpack = { @@ -125,22 +132,20 @@ def test_round_trip_json_extra_types(self): original_data = [ExtraType('baz', self.data), 'qux'] - data = json.dumps(original_data, cls=SWHJSONEncoder, - extra_encoders=extra_encoders) + data = json.dumps( + original_data, cls=SWHJSONEncoder, extra_encoders=extra_encoders + ) self.assertEqual( original_data, - json.loads( - data, cls=SWHJSONDecoder, extra_decoders=extra_decoders)) + json.loads(data, cls=SWHJSONDecoder, extra_decoders=extra_decoders), + ) def test_encode_swh_json(self): data = json.dumps(self.data, cls=SWHJSONEncoder) self.assertEqual(self.encoded_data, json.loads(data)) def test_round_trip_msgpack(self): - original_data = { - **self.data, - 'none_dict_key': {None: 42}, - } + original_data = {**self.data, 'none_dict_key': {None: 42}} data = msgpack_dumps(original_data) self.assertEqual(original_data, msgpack_loads(data)) @@ -149,7 +154,8 @@ data = msgpack_dumps(original_data, extra_encoders=extra_encoders) self.assertEqual( - original_data, msgpack_loads(data, extra_decoders=extra_decoders)) + original_data, msgpack_loads(data, extra_decoders=extra_decoders) + ) def test_generator_json(self): data = json.dumps(self.generator, cls=SWHJSONEncoder) @@ -161,9 +167,11 @@ @requests_mock.Mocker() def test_decode_response_json(self, mock_requests): - mock_requests.get('https://example.org/test/data', - json=self.encoded_data, - headers={'content-type': 'application/json'}) + mock_requests.get( + 'https://example.org/test/data', + json=self.encoded_data, + headers={'content-type': 'application/json'}, + ) response = requests.get('https://example.org/test/data') assert decode_response(response) == self.data diff --git a/swh/core/cli/__init__.py b/swh/core/cli/__init__.py --- a/swh/core/cli/__init__.py +++ b/swh/core/cli/__init__.py @@ -52,26 +52,38 @@ @click.group( - context_settings=CONTEXT_SETTINGS, cls=AliasedGroup, + context_settings=CONTEXT_SETTINGS, + cls=AliasedGroup, option_notes='''\ If both options are present, --log-level will override the root logger configuration set in --log-config. The --log-config YAML must conform to the logging.config.dictConfig schema documented at https://docs.python.org/3/library/logging.config.html. -''' +''', +) +@click.option( + '--log-level', + '-l', + default=None, + type=click.Choice(LOG_LEVEL_NAMES), + help="Log level (defaults to INFO).", +) +@click.option( + '--log-config', + default=None, + type=click.File('r'), + help="Python yaml logging configuration file.", +) +@click.option( + '--sentry-dsn', default=None, help="DSN of the Sentry instance to report to" +) +@click.option( + '--sentry-debug/--no-sentry-debug', + default=False, + hidden=True, + help="Enable debugging of sentry", ) -@click.option('--log-level', '-l', default=None, - type=click.Choice(LOG_LEVEL_NAMES), - help="Log level (defaults to INFO).") -@click.option('--log-config', default=None, - type=click.File('r'), - help="Python yaml logging configuration file.") -@click.option('--sentry-dsn', default=None, - help="DSN of the Sentry instance to report to") -@click.option('--sentry-debug/--no-sentry-debug', - default=False, hidden=True, - help="Enable debugging of sentry") @click.pass_context def swh(ctx, log_level, log_config, sentry_dsn, sentry_debug): """Command line interface for Software Heritage. @@ -105,8 +117,7 @@ cmd = entry_point.load() swh.add_command(cmd, name=entry_point.name) except Exception as e: - logger.warning('Could not load subcommand %s: %s', - entry_point.name, str(e)) + logger.warning('Could not load subcommand %s: %s', entry_point.name, str(e)) return swh(auto_envvar_prefix='SWH') diff --git a/swh/core/cli/db.py b/swh/core/cli/db.py --- a/swh/core/cli/db.py +++ b/swh/core/cli/db.py @@ -21,9 +21,13 @@ @click.group(name="db", context_settings=CONTEXT_SETTINGS) -@click.option("--config-file", "-C", default=None, - type=click.Path(exists=True, dir_okay=False), - help="Configuration file.") +@click.option( + "--config-file", + "-C", + default=None, + type=click.Path(exists=True, dir_okay=False), + help="Configuration file.", +) @click.pass_context def db(ctx, config_file): """Software Heritage database generic tools. @@ -74,8 +78,8 @@ sqlfiles = get_sql_for_package(modname) except click.BadParameter: logger.info( - "Failed to load/find sql initialization files for %s", - modname) + "Failed to load/find sql initialization files for %s", modname + ) if sqlfiles: conninfo = cfg["args"]["db"] @@ -97,11 +101,19 @@ @click.command(context_settings=CONTEXT_SETTINGS) @click.argument('module', nargs=-1, required=True) -@click.option('--db-name', '-d', help='Database name.', - default='softwareheritage-dev', show_default=True) -@click.option('--create-db/--no-create-db', '-C', - help='Attempt to create the database.', - default=False) +@click.option( + '--db-name', + '-d', + help='Database name.', + default='softwareheritage-dev', + show_default=True, +) +@click.option( + '--create-db/--no-create-db', + '-C', + help='Attempt to create the database.', + default=False, +) def db_init(module, db_name, create_db): """Initialise a database for the Software Heritage . By default, does not attempt to create the database. @@ -122,8 +134,10 @@ # put import statements here so we can keep startup time of the main swh # command as short as possible from swh.core.db.tests.db_testing import ( - pg_createdb, pg_restore, DB_DUMP_TYPES, - swh_db_version + pg_createdb, + pg_restore, + DB_DUMP_TYPES, + swh_db_version, ) logger.debug('db_init %s dn_name=%s', module, db_name) @@ -138,8 +152,7 @@ # Try to retrieve the db version if any db_version = swh_db_version(db_name) if not db_version: # Initialize the db - dump_files = [(x, DB_DUMP_TYPES[path.splitext(x)[1]]) - for x in dump_files] + dump_files = [(x, DB_DUMP_TYPES[path.splitext(x)[1]]) for x in dump_files] for dump, dtype in dump_files: click.secho('Loading {}'.format(dump), fg='yellow') pg_restore(db_name, dump, dtype) @@ -149,8 +162,11 @@ # TODO: Ideally migrate the version from db_version to the latest # db version - click.secho('DONE database is {} version {}'.format(db_name, db_version), - fg='green', bold=True) + click.secho( + 'DONE database is {} version {}'.format(db_name, db_version), + fg='green', + bold=True, + ) def get_sql_for_package(modname): @@ -167,6 +183,6 @@ sqldir = path.join(path.dirname(m.__file__), "sql") if not path.isdir(sqldir): raise click.BadParameter( - "Module {} does not provide a db schema " - "(no sql/ dir)".format(modname)) + "Module {} does not provide a db schema " "(no sql/ dir)".format(modname) + ) return list(sorted(glob.glob(path.join(sqldir, "*.sql")), key=sortkey)) diff --git a/swh/core/config.py b/swh/core/config.py --- a/swh/core/config.py +++ b/swh/core/config.py @@ -16,11 +16,7 @@ logger = logging.getLogger(__name__) -SWH_CONFIG_DIRECTORIES = [ - '~/.config/swh', - '~/.swh', - '/etc/softwareheritage', -] +SWH_CONFIG_DIRECTORIES = ['~/.config/swh', '~/.swh', '/etc/softwareheritage'] SWH_GLOBAL_CONFIG = 'global.ini' @@ -29,10 +25,7 @@ 'log_db': ('str', 'dbname=softwareheritage-log'), } -SWH_CONFIG_EXTENSIONS = [ - '.yml', - '.ini', -] +SWH_CONFIG_EXTENSIONS = ['.yml', '.ini'] # conversion per type _map_convert_fn = { @@ -45,10 +38,8 @@ _map_check_fn = { 'int': lambda x: isinstance(x, int), 'bool': lambda x: isinstance(x, bool), - 'list[str]': lambda x: (isinstance(x, list) and - all(isinstance(y, str) for y in x)), - 'list[int]': lambda x: (isinstance(x, list) and - all(isinstance(y, int) for y in x)), + 'list[str]': lambda x: (isinstance(x, list) and all(isinstance(y, str) for y in x)), + 'list[int]': lambda x: (isinstance(x, list) and all(isinstance(y, int) for y in x)), } @@ -103,8 +94,7 @@ logger.info('Loading config file %s', ini_file) return config._sections['main'] else: - logger.warning('Ignoring config file %s (no [main] section)', - ini_file) + logger.warning('Ignoring config file %s (no [main] section)', ini_file) return {} @@ -112,8 +102,9 @@ def config_exists(config_path): """Check whether the given config exists""" basepath = config_basepath(config_path) - return any(exists_accessible(basepath + extension) - for extension in SWH_CONFIG_EXTENSIONS) + return any( + exists_accessible(basepath + extension) for extension in SWH_CONFIG_EXTENSIONS + ) def read(conf_file=None, default_conf=None): @@ -233,8 +224,7 @@ Note that no type checking is done for anything but dicts. """ if not isinstance(base, dict) or not isinstance(other, dict): - raise TypeError( - 'Cannot merge a %s with a %s' % (type(base), type(other))) + raise TypeError('Cannot merge a %s with a %s' % (type(base), type(other))) output = {} allkeys = set(chain(base.keys(), other.keys())) @@ -258,13 +248,13 @@ """Return the Software Heritage specific configuration paths for the given filename.""" - return [os.path.join(dirname, base_filename) - for dirname in SWH_CONFIG_DIRECTORIES] + return [os.path.join(dirname, base_filename) for dirname in SWH_CONFIG_DIRECTORIES] def prepare_folders(conf, *keys): """Prepare the folder mentioned in config under keys. """ + def makedir(folder): if not os.path.exists(folder): os.makedirs(folder) @@ -276,10 +266,7 @@ def load_global_config(): """Load the global Software Heritage config""" - return priority_read( - swh_config_paths(SWH_GLOBAL_CONFIG), - SWH_DEFAULT_GLOBAL_CONFIG, - ) + return priority_read(swh_config_paths(SWH_GLOBAL_CONFIG), SWH_DEFAULT_GLOBAL_CONFIG) def load_named_config(name, default_conf=None, global_conf=True): @@ -317,8 +304,13 @@ CONFIG_BASE_FILENAME = '' # type: Optional[str] @classmethod - def parse_config_file(cls, base_filename=None, config_filename=None, - additional_configs=None, global_config=True): + def parse_config_file( + cls, + base_filename=None, + config_filename=None, + additional_configs=None, + global_config=True, + ): """Parse the configuration file associated to the current class. By default, parse_config_file will load the configuration @@ -350,8 +342,9 @@ if not additional_configs: additional_configs = [] - full_default_config = merge_default_configs(cls.DEFAULT_CONFIG, - *additional_configs) + full_default_config = merge_default_configs( + cls.DEFAULT_CONFIG, *additional_configs + ) config = {} if global_config: diff --git a/swh/core/db/__init__.py b/swh/core/db/__init__.py --- a/swh/core/db/__init__.py +++ b/swh/core/db/__init__.py @@ -43,7 +43,8 @@ # We escape twice here too, so that we make sure # everything gets passed to copy properly return escape( - '%s%s,%s%s' % ( + '%s%s,%s%s' + % ( '[' if data.lower_inc else '(', '-infinity' if data.lower_inf else escape(data.lower), 'infinity' if data.upper_inf else escape(data.upper), @@ -74,12 +75,10 @@ def adapt_conn(cls, conn): """Makes psycopg2 use 'bytes' to decode bytea instead of 'memoryview', for this connection.""" - t_bytes = psycopg2.extensions.new_type( - (17,), "bytea", typecast_bytea) + t_bytes = psycopg2.extensions.new_type((17,), "bytea", typecast_bytea) psycopg2.extensions.register_type(t_bytes, conn) - t_bytes_array = psycopg2.extensions.new_array_type( - (1001,), "bytea[]", t_bytes) + t_bytes_array = psycopg2.extensions.new_array_type((1001,), "bytea[]", t_bytes) psycopg2.extensions.register_type(t_bytes_array, conn) @classmethod @@ -128,6 +127,7 @@ return cur_arg else: return self.conn.cursor() + _cursor = cursor # for bw compat @contextmanager @@ -147,8 +147,9 @@ self.conn.rollback() raise - def copy_to(self, items, tblname, columns, - cur=None, item_cb=None, default_values={}): + def copy_to( + self, items, tblname, columns, cur=None, item_cb=None, default_values={} + ): """Copy items' entries to table tblname with columns information. Args: @@ -170,8 +171,9 @@ cursor = self.cursor(cur) with open(read_file, 'r') as f: try: - cursor.copy_expert('COPY %s (%s) FROM STDIN CSV' % ( - tblname, ', '.join(columns)), f) + cursor.copy_expert( + 'COPY %s (%s) FROM STDIN CSV' % (tblname, ', '.join(columns)), f + ) except Exception: # Tell the main thread about the exception exc_info = sys.exc_info() @@ -193,7 +195,9 @@ logger.error( 'Could not escape value `%r` for column `%s`:' 'Received exception: `%s`', - value, k, e + value, + k, + e, ) raise e from None f.write(','.join(line)) diff --git a/swh/core/db/common.py b/swh/core/db/common.py --- a/swh/core/db/common.py +++ b/swh/core/db/common.py @@ -11,8 +11,7 @@ def decorator(f): sig = inspect.signature(f) params = sig.parameters - params = [param for param in params.values() - if param.name not in names] + params = [param for param in params.values() if param.name not in names] sig = sig.replace(parameters=params) f.__signature__ = sig return f @@ -41,10 +40,10 @@ Client options are passed as `set` options to the postgresql server """ + def decorator(meth, __client_options=client_options): if inspect.isgeneratorfunction(meth): - raise ValueError( - 'Use db_transaction_generator for generator functions.') + raise ValueError('Use db_transaction_generator for generator functions.') @remove_kwargs(['cur', 'db']) @functools.wraps(meth) @@ -63,6 +62,7 @@ return meth(self, *args, db=db, cur=cur, **kwargs) finally: self.put_db(db) + return _meth return decorator @@ -76,10 +76,10 @@ Client options are passed as `set` options to the postgresql server """ + def decorator(meth, __client_options=client_options): if not inspect.isgeneratorfunction(meth): - raise ValueError( - 'Use db_transaction for non-generator functions.') + raise ValueError('Use db_transaction for non-generator functions.') @remove_kwargs(['cur', 'db']) @functools.wraps(meth) @@ -99,4 +99,5 @@ self.put_db(db) return _meth + return decorator diff --git a/swh/core/db/db_utils.py b/swh/core/db/db_utils.py --- a/swh/core/db/db_utils.py +++ b/swh/core/db/db_utils.py @@ -26,13 +26,16 @@ not, the stored procedure will be executed first; the function body then. """ + def wrap(meth): @functools.wraps(meth) def _meth(self, *args, **kwargs): cur = kwargs.get('cur', None) self._cursor(cur).execute('SELECT %s()' % stored_proc) meth(self, *args, **kwargs) + return _meth + return wrap @@ -79,13 +82,14 @@ if curr is pre: curr = post else: - raise ValueError( - "the query contains more than one '%s' placeholder") + raise ValueError("the query contains more than one '%s' placeholder") elif token[1:] == b'%': curr.append(b'%') else: - raise ValueError("unsupported format character: '%s'" - % token[1:].decode('ascii', 'replace')) + raise ValueError( + "unsupported format character: '%s'" + % token[1:].decode('ascii', 'replace') + ) if curr is pre: raise ValueError("the query doesn't contain any '%s' placeholder") @@ -132,9 +136,7 @@ # there will be some decoding error because of stupid codec used, and Py3 # doesn't implement % on bytes. if not isinstance(sql, bytes): - sql = sql.encode( - psycopg2.extensions.encodings[cur.connection.encoding] - ) + sql = sql.encode(psycopg2.extensions.encodings[cur.connection.encoding]) pre, post = _split_sql(sql) for page in _paginate(argslist, page_size=page_size): diff --git a/swh/core/db/tests/conftest.py b/swh/core/db/tests/conftest.py --- a/swh/core/db/tests/conftest.py +++ b/swh/core/db/tests/conftest.py @@ -1,2 +1,3 @@ import os + os.environ['LC_ALL'] = 'C.UTF-8' diff --git a/swh/core/db/tests/db_testing.py b/swh/core/db/tests/db_testing.py --- a/swh/core/db/tests/db_testing.py +++ b/swh/core/db/tests/db_testing.py @@ -30,14 +30,20 @@ """ query = 'select version from dbversion order by dbversion desc limit 1' cmd = [ - 'psql', '--tuples-only', '--no-psqlrc', '--quiet', - '-v', 'ON_ERROR_STOP=1', "--command=%s" % query, - dbname_or_service + 'psql', + '--tuples-only', + '--no-psqlrc', + '--quiet', + '-v', + 'ON_ERROR_STOP=1', + "--command=%s" % query, + dbname_or_service, ] try: - r = subprocess.run(cmd, check=True, stdout=subprocess.PIPE, - universal_newlines=True) + r = subprocess.run( + cmd, check=True, stdout=subprocess.PIPE, universal_newlines=True + ) result = int(r.stdout.strip()) except Exception: # db not initialized result = None @@ -53,19 +59,35 @@ """ assert dumptype in ['pg_dump', 'psql'] if dumptype == 'pg_dump': - subprocess.check_call(['pg_restore', '--no-owner', '--no-privileges', - '--dbname', dbname, dumpfile]) + subprocess.check_call( + [ + 'pg_restore', + '--no-owner', + '--no-privileges', + '--dbname', + dbname, + dumpfile, + ] + ) elif dumptype == 'psql': - subprocess.check_call(['psql', '--quiet', - '--no-psqlrc', - '-v', 'ON_ERROR_STOP=1', - '-f', dumpfile, - dbname]) + subprocess.check_call( + [ + 'psql', + '--quiet', + '--no-psqlrc', + '-v', + 'ON_ERROR_STOP=1', + '-f', + dumpfile, + dbname, + ] + ) def pg_dump(dbname, dumpfile): - subprocess.check_call(['pg_dump', '--no-owner', '--no-privileges', '-Fc', - '-f', dumpfile, dbname]) + subprocess.check_call( + ['pg_dump', '--no-owner', '--no-privileges', '-Fc', '-f', dumpfile, dbname] + ) def pg_dropdb(dbname): @@ -92,8 +114,9 @@ """ try: pg_createdb(dbname) - except subprocess.CalledProcessError: # try recovering once, in case - pg_dropdb(dbname) # the db already existed + except subprocess.CalledProcessError: + # try recovering once, in case the db already existed + pg_dropdb(dbname) pg_createdb(dbname) for dump, dtype in dumps: pg_restore(dbname, dump, dtype) @@ -116,10 +139,7 @@ """ conn = psycopg2.connect('dbname=' + dbname) - return { - 'conn': conn, - 'cursor': conn.cursor() - } + return {'conn': conn, 'cursor': conn.cursor()} def db_close(conn): @@ -153,8 +173,7 @@ self.dumps = dumps def __enter__(self): - db_create(dbname=self.dbname, - dumps=self.dumps) + db_create(dbname=self.dbname, dumps=self.dumps) return self def __exit__(self, *_): @@ -248,8 +267,11 @@ conn = db.conn cursor = db.cursor - cursor.execute("""SELECT table_name FROM information_schema.tables - WHERE table_schema = %s""", ('public',)) + cursor.execute( + """SELECT table_name FROM information_schema.tables + WHERE table_schema = %s""", + ('public',), + ) tables = set(table for (table,) in cursor.fetchall()) if excluded is not None: @@ -302,14 +324,13 @@ dump_files = [dump_files] all_dump_files = [] for files in dump_files: - all_dump_files.extend( - sorted(glob.glob(files), key=sortkey)) + all_dump_files.extend(sorted(glob.glob(files), key=sortkey)) - all_dump_files = [(x, DB_DUMP_TYPES[os.path.splitext(x)[1]]) - for x in all_dump_files] + all_dump_files = [ + (x, DB_DUMP_TYPES[os.path.splitext(x)[1]]) for x in all_dump_files + ] - cls.add_db(name=cls.TEST_DB_NAME, - dumps=all_dump_files) + cls.add_db(name=cls.TEST_DB_NAME, dumps=all_dump_files) super().setUpClass() def setUp(self, *args, **kwargs): diff --git a/swh/core/db/tests/test_db.py b/swh/core/db/tests/test_db.py --- a/swh/core/db/tests/test_db.py +++ b/swh/core/db/tests/test_db.py @@ -15,9 +15,7 @@ from swh.core.db import BaseDb from swh.core.db.common import db_transaction, db_transaction_generator -from .db_testing import ( - SingleDbTestFixture, db_create, db_destroy, db_close, -) +from .db_testing import SingleDbTestFixture, db_create, db_destroy, db_close INIT_SQL = ''' @@ -29,19 +27,21 @@ ); ''' -db_rows = strategies.lists(strategies.tuples( - strategies.integers(-2147483648, +2147483647), - strategies.text( - alphabet=strategies.characters( - blacklist_categories=['Cs'], # surrogates - blacklist_characters=[ - '\x00', # pgsql does not support the null codepoint - '\r', # pgsql normalizes those - ] +db_rows = strategies.lists( + strategies.tuples( + strategies.integers(-2147483648, +2147483647), + strategies.text( + alphabet=strategies.characters( + blacklist_categories=['Cs'], # surrogates + blacklist_characters=[ + '\x00', # pgsql does not support the null codepoint + '\r', # pgsql normalizes those + ], + ) ), - ), - strategies.binary(), -)) + strategies.binary(), + ) +) @pytest.mark.db @@ -51,8 +51,7 @@ db = BaseDb.connect('dbname=%s' % db_name) with db.cursor() as cur: cur.execute(INIT_SQL) - cur.execute("insert into test_table values (1, %s, %s);", - ('foo', b'bar')) + cur.execute("insert into test_table values (1, %s, %s);", ('foo', b'bar')) cur.execute("select * from test_table;") assert list(cur) == [(1, 'foo', b'bar')] finally: @@ -80,15 +79,13 @@ def test_initialized(self): cur = self.db.cursor() - cur.execute("insert into test_table values (1, %s, %s);", - ('foo', b'bar')) + cur.execute("insert into test_table values (1, %s, %s);", ('foo', b'bar')) cur.execute("select * from test_table;") self.assertEqual(list(cur), [(1, 'foo', b'bar')]) def test_reset_tables(self): cur = self.db.cursor() - cur.execute("insert into test_table values (1, %s, %s);", - ('foo', b'bar')) + cur.execute("insert into test_table values (1, %s, %s);", ('foo', b'bar')) self.reset_db_tables('test-db') cur.execute("select * from test_table;") self.assertEqual(list(cur), []) @@ -106,7 +103,7 @@ self.assertCountEqual(list(cur), data) def test_copy_to_thread_exception(self): - data = [(2**65, 'foo', b'bar')] + data = [(2 ** 65, 'foo', b'bar')] items = [dict(zip(['i', 'txt', 'bytes'], item)) for item in data] with self.assertRaises(psycopg2.errors.NumericValueOutOfRange): @@ -132,11 +129,9 @@ db_mock = Mock() db_mock.transaction.return_value = MagicMock() db_mock.transaction.return_value.__enter__.return_value = expected_cur - mocker.patch.object( - storage, 'get_db', return_value=db_mock, create=True) + mocker.patch.object(storage, 'get_db', return_value=db_mock, create=True) - put_db_mock = mocker.patch.object( - storage, 'put_db', create=True) + put_db_mock = mocker.patch.object(storage, 'put_db', create=True) storage.endpoint() @@ -146,6 +141,7 @@ def test_db_transaction__with_generator(): with pytest.raises(ValueError, match='generator'): + class Storage: @db_transaction() def endpoint(self, cur=None, db=None): @@ -154,8 +150,10 @@ def test_db_transaction_signature(): """Checks db_transaction removes the 'cur' and 'db' arguments.""" + def f(self, foo, *, bar=None): pass + expected_sig = inspect.signature(f) @db_transaction() @@ -187,11 +185,9 @@ db_mock = Mock() db_mock.transaction.return_value = MagicMock() db_mock.transaction.return_value.__enter__.return_value = expected_cur - mocker.patch.object( - storage, 'get_db', return_value=db_mock, create=True) + mocker.patch.object(storage, 'get_db', return_value=db_mock, create=True) - put_db_mock = mocker.patch.object( - storage, 'put_db', create=True) + put_db_mock = mocker.patch.object(storage, 'put_db', create=True) list(storage.endpoint()) @@ -201,6 +197,7 @@ def test_db_transaction_generator__with_nongenerator(): with pytest.raises(ValueError, match='generator'): + class Storage: @db_transaction_generator() def endpoint(self, cur=None, db=None): @@ -209,8 +206,10 @@ def test_db_transaction_generator_signature(): """Checks db_transaction removes the 'cur' and 'db' arguments.""" + def f(self, foo, *, bar=None): pass + expected_sig = inspect.signature(f) @db_transaction_generator() diff --git a/swh/core/logger.py b/swh/core/logger.py --- a/swh/core/logger.py +++ b/swh/core/logger.py @@ -30,9 +30,11 @@ """Get the extra data to insert to the database from the logging record""" log_data = record.__dict__ - extra_data = {k[len(EXTRA_LOGDATA_PREFIX):]: v - for k, v in log_data.items() - if k.startswith(EXTRA_LOGDATA_PREFIX)} + extra_data = { + k[len(EXTRA_LOGDATA_PREFIX) :]: v + for k, v in log_data.items() + if k.startswith(EXTRA_LOGDATA_PREFIX) + } args = log_data.get('args') if args: @@ -40,23 +42,19 @@ # Retrieve Celery task info if current_task and current_task.request: - extra_data['task'] = { - 'id': current_task.request.id, - 'name': current_task.name, - } + extra_data['task'] = {'id': current_task.request.id, 'name': current_task.name} if task_args: - extra_data['task'].update({ - 'kwargs': current_task.request.kwargs, - 'args': current_task.request.args, - }) + extra_data['task'].update( + { + 'kwargs': current_task.request.kwargs, + 'args': current_task.request.args, + } + ) return extra_data -def flatten( - data: Any, - separator: str = "_" -) -> Generator[Tuple[str, Any], None, None]: +def flatten(data: Any, separator: str = "_") -> Generator[Tuple[str, Any], None, None]: """Flatten the data dictionary into a flat structure""" def inner_flatten( @@ -102,13 +100,15 @@ } msg = self.format(record) pri = self.mapPriority(record.levelno) - send(msg, - PRIORITY=format(pri), - LOGGER=record.name, - THREAD_NAME=record.threadName, - CODE_FILE=record.pathname, - CODE_LINE=record.lineno, - CODE_FUNC=record.funcName, - **extra_data) + send( + msg, + PRIORITY=format(pri), + LOGGER=record.name, + THREAD_NAME=record.threadName, + CODE_FILE=record.pathname, + CODE_LINE=record.lineno, + CODE_FUNC=record.funcName, + **extra_data + ) except Exception: self.handleError(record) diff --git a/swh/core/pytest_plugin.py b/swh/core/pytest_plugin.py --- a/swh/core/pytest_plugin.py +++ b/swh/core/pytest_plugin.py @@ -27,9 +27,12 @@ def get_response_cb( - request: requests.Request, context, datadir, - ignore_urls: List[str] = [], - visits: Optional[Dict] = None): + request: requests.Request, + context, + datadir, + ignore_urls: List[str] = [], + visits: Optional[Dict] = None, +): """Mount point callback to fetch on disk the request's content. The request urls provided are url decoded first to resolve the associated file on disk. @@ -135,8 +138,9 @@ return path.join(path.dirname(str(request.fspath)), 'data') -def requests_mock_datadir_factory(ignore_urls: List[str] = [], - has_multi_visit: bool = False): +def requests_mock_datadir_factory( + ignore_urls: List[str] = [], has_multi_visit: bool = False +): """This factory generates fixture which allow to look for files on the local filesystem based on the requested URL, using the following rules: @@ -167,18 +171,22 @@ has_multi_visit: Activate or not the multiple visits behavior """ + @pytest.fixture def requests_mock_datadir(requests_mock, datadir): if not has_multi_visit: - cb = partial(get_response_cb, - ignore_urls=ignore_urls, - datadir=datadir) + cb = partial(get_response_cb, ignore_urls=ignore_urls, datadir=datadir) requests_mock.get(re.compile('https?://'), body=cb) else: visits = {} - requests_mock.get(re.compile('https?://'), body=partial( - get_response_cb, ignore_urls=ignore_urls, visits=visits, - datadir=datadir) + requests_mock.get( + re.compile('https?://'), + body=partial( + get_response_cb, + ignore_urls=ignore_urls, + visits=visits, + datadir=datadir, + ), ) return requests_mock @@ -193,8 +201,7 @@ # - first time, it checks for a file named `filename` # - second time, it checks for a file named `filename`_visit1 # etc... -requests_mock_datadir_visits = requests_mock_datadir_factory( - has_multi_visit=True) +requests_mock_datadir_visits = requests_mock_datadir_factory(has_multi_visit=True) @pytest.fixture @@ -275,10 +282,11 @@ Overrides ``requests.adapters.BaseAdapter.send`` """ resp = self._client.open( - request.url, method=request.method, + request.url, + method=request.method, headers=request.headers.items(), data=request.body, - ) + ) return self.build_response(request, resp) diff --git a/swh/core/sentry.py b/swh/core/sentry.py --- a/swh/core/sentry.py +++ b/swh/core/sentry.py @@ -16,9 +16,7 @@ return None -def init_sentry( - sentry_dsn, *, debug=None, integrations=[], - extra_kwargs={}): +def init_sentry(sentry_dsn, *, debug=None, integrations=[], extra_kwargs={}): if debug is None: debug = bool(os.environ.get('SWH_SENTRY_DEBUG')) sentry_dsn = sentry_dsn or os.environ.get('SWH_SENTRY_DSN') diff --git a/swh/core/statsd.py b/swh/core/statsd.py --- a/swh/core/statsd.py +++ b/swh/core/statsd.py @@ -74,8 +74,10 @@ Attributes: elapsed (float): the elapsed time at the point of completion """ - def __init__(self, statsd, metric=None, error_metric=None, - tags=None, sample_rate=1): + + def __init__( + self, statsd, metric=None, error_metric=None, tags=None, sample_rate=1 + ): self.statsd = statsd self.metric = metric self.error_metric = error_metric @@ -94,6 +96,7 @@ # Coroutines if iscoroutinefunction(func): + @wraps(func) async def wrapped_co(*args, **kwargs): start = monotonic() @@ -104,6 +107,7 @@ raise self._send(start) return result + return wrapped_co # Others @@ -117,6 +121,7 @@ raise self._send(start) return result + return wrapped def __enter__(self): @@ -134,8 +139,9 @@ def _send(self, start): elapsed = (monotonic() - start) * 1000 - self.statsd.timing(self.metric, elapsed, - tags=self.tags, sample_rate=self.sample_rate) + self.statsd.timing( + self.metric, elapsed, tags=self.tags, sample_rate=self.sample_rate + ) self.elapsed = elapsed def _send_error(self): @@ -179,8 +185,14 @@ "label:value,other_label:other_value" """ - def __init__(self, host=None, port=None, max_buffer_size=50, - namespace=None, constant_tags=None): + def __init__( + self, + host=None, + port=None, + max_buffer_size=50, + namespace=None, + constant_tags=None, + ): # Connection if host is None: host = os.environ.get('STATSD_HOST') or 'localhost' @@ -205,8 +217,7 @@ continue if ':' not in tag: warnings.warn( - 'STATSD_TAGS needs to be in key:value format, ' - '%s invalid' % tag, + 'STATSD_TAGS needs to be in key:value format, %s invalid' % tag, UserWarning, ) continue @@ -214,10 +225,9 @@ self.constant_tags[k] = v if constant_tags: - self.constant_tags.update({ - str(k): str(v) - for k, v in constant_tags.items() - }) + self.constant_tags.update( + {str(k): str(v) for k, v in constant_tags.items()} + ) # Namespace if namespace is not None: @@ -306,9 +316,12 @@ statsd.timing('user.query.time', time.monotonic() - start) """ return TimedContextManagerDecorator( - statsd=self, metric=metric, + statsd=self, + metric=metric, error_metric=error_metric, - tags=tags, sample_rate=sample_rate) + tags=tags, + sample_rate=sample_rate, + ) def set(self, metric, value, tags=None, sample_rate=1): """ @@ -387,10 +400,9 @@ value, metric_type, ("|@" + str(sample_rate)) if sample_rate != 1 else "", - ("|#" + ",".join( - "%s:%s" % (k, v) - for (k, v) in sorted(tags.items()) - )) if tags else "", + ("|#" + ",".join("%s:%s" % (k, v) for (k, v) in sorted(tags.items()))) + if tags + else "", ) # Send it self._send(payload) @@ -421,8 +433,7 @@ return { str(k): str(v) for k, v in itertools.chain( - self.constant_tags.items(), - (tags if tags else {}).items(), + self.constant_tags.items(), (tags if tags else {}).items() ) } diff --git a/swh/core/tarball.py b/swh/core/tarball.py --- a/swh/core/tarball.py +++ b/swh/core/tarball.py @@ -37,7 +37,8 @@ return extract_dir except Exception as e: raise shutil.ReadError( - f'Unable to uncompress {tarpath} to {extract_dir}. Reason: {e}') + f'Unable to uncompress {tarpath} to {extract_dir}. Reason: {e}' + ) def register_new_archive_formats(): @@ -89,7 +90,7 @@ """ for dirpath, dirnames, fnames in os.walk(rootdir): - for fname in (dirnames+fnames): + for fname in dirnames + fnames: fpath = os.path.join(dirpath, fname) fname = utils.commonname(rootdir, fpath) yield fpath, fname diff --git a/swh/core/tests/fixture/test_pytest_plugin.py b/swh/core/tests/fixture/test_pytest_plugin.py --- a/swh/core/tests/fixture/test_pytest_plugin.py +++ b/swh/core/tests/fixture/test_pytest_plugin.py @@ -11,8 +11,7 @@ # "datadir" fixture to specify where to retrieve the data files from. -def test_requests_mock_datadir_with_datadir_fixture_override( - requests_mock_datadir): +def test_requests_mock_datadir_with_datadir_fixture_override(requests_mock_datadir): """Override datadir fixture should retrieve data from elsewhere """ diff --git a/swh/core/tests/test_cli.py b/swh/core/tests/test_cli.py --- a/swh/core/tests/test_cli.py +++ b/swh/core/tests/test_cli.py @@ -110,11 +110,7 @@ assert result.exit_code == 0 assert result.output.strip() == '''Hello SWH!''' sentry_sdk_init.assert_called_once_with( - dsn='test_dsn', - debug=False, - integrations=[], - release=None, - environment=None, + dsn='test_dsn', debug=False, integrations=[], release=None, environment=None ) @@ -127,15 +123,12 @@ runner = CliRunner() with patch('sentry_sdk.init') as sentry_sdk_init: result = runner.invoke( - swhmain, ['--sentry-dsn', 'test_dsn', '--sentry-debug', 'test']) + swhmain, ['--sentry-dsn', 'test_dsn', '--sentry-debug', 'test'] + ) assert result.exit_code == 0 assert result.output.strip() == '''Hello SWH!''' sentry_sdk_init.assert_called_once_with( - dsn='test_dsn', - debug=True, - integrations=[], - release=None, - environment=None, + dsn='test_dsn', debug=True, integrations=[], release=None, environment=None ) @@ -147,20 +140,12 @@ runner = CliRunner() with patch('sentry_sdk.init') as sentry_sdk_init: - env = { - 'SWH_SENTRY_DSN': 'test_dsn', - 'SWH_SENTRY_DEBUG': '1', - } - result = runner.invoke( - swhmain, ['test'], env=env, auto_envvar_prefix='SWH') + env = {'SWH_SENTRY_DSN': 'test_dsn', 'SWH_SENTRY_DEBUG': '1'} + result = runner.invoke(swhmain, ['test'], env=env, auto_envvar_prefix='SWH') assert result.exit_code == 0 assert result.output.strip() == '''Hello SWH!''' sentry_sdk_init.assert_called_once_with( - dsn='test_dsn', - debug=True, - integrations=[], - release=None, - environment=None, + dsn='test_dsn', debug=True, integrations=[], release=None, environment=None ) @@ -177,8 +162,7 @@ 'SWH_MAIN_PACKAGE': 'swh.core', 'SWH_SENTRY_ENVIRONMENT': 'tests', } - result = runner.invoke( - swhmain, ['test'], env=env, auto_envvar_prefix='SWH') + result = runner.invoke(swhmain, ['test'], env=env, auto_envvar_prefix='SWH') assert result.exit_code == 0 version = pkg_resources.get_distribution('swh.core').version @@ -195,7 +179,8 @@ @pytest.fixture def log_config_path(tmp_path): - log_config = textwrap.dedent('''\ + log_config = textwrap.dedent( + '''\ --- version: 1 formatters: @@ -214,7 +199,8 @@ loggers: dontshowdebug: level: INFO - ''') + ''' + ) (tmp_path / 'log_config.yml').write_text(log_config) @@ -231,19 +217,16 @@ logging.getLogger('dontshowdebug').info('Shown') runner = CliRunner() - result = runner.invoke( - swhmain, [ - '--log-config', log_config_path, - 'test', - ], - ) + result = runner.invoke(swhmain, ['--log-config', log_config_path, 'test']) assert result.exit_code == 0 - assert result.output.strip() == '\n'.join([ - 'custom format:root:DEBUG:Root log debug', - 'custom format:root:INFO:Root log info', - 'custom format:dontshowdebug:INFO:Shown', - ]) + assert result.output.strip() == '\n'.join( + [ + 'custom format:root:DEBUG:Root log debug', + 'custom format:root:INFO:Root log info', + 'custom format:dontshowdebug:INFO:Shown', + ] + ) def test_log_config_log_level_interaction(caplog, log_config_path, swhmain): @@ -257,18 +240,16 @@ runner = CliRunner() result = runner.invoke( - swhmain, [ - '--log-config', log_config_path, - '--log-level', 'INFO', - 'test', - ], + swhmain, ['--log-config', log_config_path, '--log-level', 'INFO', 'test'] ) assert result.exit_code == 0 - assert result.output.strip() == '\n'.join([ - 'custom format:root:INFO:Root log info', - 'custom format:dontshowdebug:INFO:Shown', - ]) + assert result.output.strip() == '\n'.join( + [ + 'custom format:root:INFO:Root log info', + 'custom format:dontshowdebug:INFO:Shown', + ] + ) def test_aliased_command(swhmain): @@ -277,6 +258,7 @@ def swhtest(ctx): 'A test command.' click.echo('Hello SWH!') + swhmain.add_alias(swhtest, 'othername') runner = CliRunner() diff --git a/swh/core/tests/test_config.py b/swh/core/tests/test_config.py --- a/swh/core/tests/test_config.py +++ b/swh/core/tests/test_config.py @@ -12,10 +12,12 @@ pytest_v = pkg_resources.get_distribution("pytest").parsed_version if pytest_v < pkg_resources.extern.packaging.version.parse('3.9'): + @pytest.fixture def tmp_path(request): import tempfile import pathlib + with tempfile.TemporaryDirectory() as tmpdir: yield pathlib.Path(tmpdir) @@ -34,18 +36,12 @@ 'li': ('list[int]', [42, 43]), } -other_default_conf = { - 'a': ('int', 3), -} +other_default_conf = {'a': ('int', 3)} full_default_conf = default_conf.copy() full_default_conf['a'] = other_default_conf['a'] -parsed_default_conf = { - key: value - for key, (type, value) - in default_conf.items() -} +parsed_default_conf = {key: value for key, (type, value) in default_conf.items()} parsed_conffile = { 'a': 1, @@ -172,8 +168,9 @@ empty.touch() # when - res = config.priority_read([str(p) for p in ( - swh_config, noexist, empty)], default_conf) + res = config.priority_read( + [str(p) for p in (swh_config, noexist, empty)], default_conf + ) # then assert res == parsed_conffile @@ -185,8 +182,9 @@ empty.touch() # when - res = config.priority_read([str(p) for p in ( - empty, swh_config, noexist)], default_conf) + res = config.priority_read( + [str(p) for p in (empty, swh_config, noexist)], default_conf + ) # then assert res == parsed_default_conf @@ -204,8 +202,10 @@ def test_prepare_folder(tmp_path): # given - conf = {'path1': str(tmp_path / 'path1'), - 'path2': str(tmp_path / 'path2' / 'depth1')} + conf = { + 'path1': str(tmp_path / 'path1'), + 'path2': str(tmp_path / 'path2' / 'depth1'), + } # the folders does not exists assert not os.path.exists(conf['path1']), "path1 should not exist." @@ -235,8 +235,7 @@ 'ea': 'Mr. Bungle', 'eb': None, 'ec': [11, 12, 13], - 'ed': {'eda': 'Secret Chief 3', - 'edb': 'Faith No More'}, + 'ed': {'eda': 'Secret Chief 3', 'edb': 'Faith No More'}, 'ee': 451, }, 'f': 'Janis', @@ -249,8 +248,7 @@ 'e': { 'ea': 'Igorrr', 'ec': [51, 52], - 'ed': {'edb': 'Sleepytime Gorilla Museum', - 'edc': 'Nils Peter Molvaer'}, + 'ed': {'edb': 'Sleepytime Gorilla Museum', 'edc': 'Nils Peter Molvaer'}, }, 'g': 'Hüsker Dü', } @@ -269,7 +267,8 @@ 'ed': { 'eda': 'Secret Chief 3', # only in a 'edb': 'Sleepytime Gorilla Museum', # b takes precedence - 'edc': 'Nils Peter Molvaer'}, # only defined in b + 'edc': 'Nils Peter Molvaer', + }, # only defined in b 'ee': 451, }, 'f': 'Janis', # only defined in a @@ -290,7 +289,8 @@ 'ed': { 'eda': 'Secret Chief 3', # only in a 'edb': 'Faith No More', # a takes precedence - 'edc': 'Nils Peter Molvaer'}, # only in b + 'edc': 'Nils Peter Molvaer', + }, # only in b 'ee': 451, }, 'f': 'Janis', # only in a diff --git a/swh/core/tests/test_logger.py b/swh/core/tests/test_logger.py --- a/swh/core/tests/test_logger.py +++ b/swh/core/tests/test_logger.py @@ -58,9 +58,7 @@ assert list(logger.flatten({})) == [] assert list(logger.flatten({'a': 1})) == [('a', 1)] - assert sorted(logger.flatten({'a': 1, - 'b': (2, 3,), - 'c': {'d': 4, 'e': 'f'}})) == [ + assert sorted(logger.flatten({'a': 1, 'b': (2, 3), 'c': {'d': 4, 'e': 'f'}})) == [ ('a', 1), ('b_0', 2), ('b_1', 3), @@ -74,9 +72,7 @@ str_d = str(d) assert list(logger.flatten(d)) == [("", str_d)] assert list(logger.flatten({"a": d})) == [("a", str_d)] - assert list(logger.flatten({"a": [d, d]})) == [ - ("a_0", str_d), ("a_1", str_d) - ] + assert list(logger.flatten({"a": [d, d]})) == [("a_0", str_d), ("a_1", str_d)] def test_stringify(): @@ -106,7 +102,8 @@ CODE_LINE=ln, LOGGER='test_logger', PRIORITY='6', - THREAD_NAME='MainThread') + THREAD_NAME='MainThread', + ) @patch('swh.core.logger.send') @@ -115,7 +112,10 @@ log.addHandler(logger.JournalHandler()) log.setLevel(logging.DEBUG) - _, ln = log.debug('something cool %s', ['with', {'extra': 'data'}]), lineno() # noqa + _, ln = ( + log.debug('something cool %s', ['with', {'extra': 'data'}]), + lineno() - 1, + ) # noqa send.assert_called_with( "something cool ['with', {'extra': 'data'}]", @@ -126,5 +126,5 @@ PRIORITY='7', THREAD_NAME='MainThread', SWH_LOGGING_ARGS_0_0='with', - SWH_LOGGING_ARGS_0_1_EXTRA='data' + SWH_LOGGING_ARGS_0_1_EXTRA='data', ) diff --git a/swh/core/tests/test_pytest_plugin.py b/swh/core/tests/test_pytest_plugin.py --- a/swh/core/tests/test_pytest_plugin.py +++ b/swh/core/tests/test_pytest_plugin.py @@ -14,10 +14,11 @@ def test_get_response_cb_with_encoded_url(requests_mock_datadir): # The following urls (quoted, unquoted) will be resolved as the same file for encoded_url, expected_response in [ - ('https://forge.s.o/api/diffusion?attachments%5Buris%5D=1', - "something"), - ('https://www.reference.com/web?q=What+Is+an+Example+of+a+URL?&qo=contentPageRelatedSearch&o=600605&l=dir&sga=1', # noqa - "something else"), + ('https://forge.s.o/api/diffusion?attachments%5Buris%5D=1', "something"), + ( + 'https://www.reference.com/web?q=What+Is+an+Example+of+a+URL?&qo=contentPageRelatedSearch&o=600605&l=dir&sga=1', # noqa + "something else", + ), ]: for url in [encoded_url, unquote(encoded_url)]: response = requests.get(url) @@ -80,15 +81,13 @@ assert not response.ok assert response.status_code == 404 - response = requests.get( - 'https://example.com/file.json?name=doe&firstname=jane') + response = requests.get('https://example.com/file.json?name=doe&firstname=jane') assert response.ok assert response.json() == {'hello': 'jane doe'} requests_mock_datadir_ignore = requests_mock_datadir_factory( - ignore_urls=['https://example.com/file.json'], - has_multi_visit=False, + ignore_urls=['https://example.com/file.json'], has_multi_visit=False ) @@ -99,13 +98,11 @@ requests_mock_datadir_ignore_and_visit = requests_mock_datadir_factory( - ignore_urls=['https://example.com/file.json'], - has_multi_visit=True, + ignore_urls=['https://example.com/file.json'], has_multi_visit=True ) -def test_get_response_cb_ignore_url_with_visit( - requests_mock_datadir_ignore_and_visit): +def test_get_response_cb_ignore_url_with_visit(requests_mock_datadir_ignore_and_visit): response = requests.get('https://example.com/file.json') assert not response.ok assert response.status_code == 404 diff --git a/swh/core/tests/test_statsd.py b/swh/core/tests/test_statsd.py --- a/swh/core/tests/test_statsd.py +++ b/swh/core/tests/test_statsd.py @@ -98,7 +98,6 @@ class TestStatsd(unittest.TestCase): - def setUp(self): """ Set up a default Statsd instance and mock the socket. @@ -157,8 +156,7 @@ def test_tags_and_samples(self): for i in range(100): - self.statsd.gauge('gst', 23, tags={"sampled": True}, - sample_rate=0.9) + self.statsd.gauge('gst', 23, tags={"sampled": True}, sample_rate=0.9) self.assert_almost_equal(90, len(self.statsd.socket.payloads), 10) self.assertEqual('gst:23|g|@0.9|#sampled:True', self.recv()) @@ -177,22 +175,14 @@ # Test Client level constant tags def test_gauge_constant_tags(self): - self.statsd.constant_tags = { - 'bar': 'baz', - } + self.statsd.constant_tags = {'bar': 'baz'} self.statsd.gauge('gauge', 123.4) assert self.recv() == 'gauge:123.4|g|#bar:baz' def test_counter_constant_tag_with_metric_level_tags(self): - self.statsd.constant_tags = { - 'bar': 'baz', - 'foo': True, - } + self.statsd.constant_tags = {'bar': 'baz', 'foo': True} self.statsd.increment('page.views', tags={'extra': 'extra'}) - self.assertEqual( - 'page.views:1|c|#bar:baz,extra:extra,foo:True', - self.recv(), - ) + self.assertEqual('page.views:1|c|#bar:baz,extra:extra,foo:True', self.recv()) def test_gauge_constant_tags_with_metric_level_tags_twice(self): metric_level_tag = {'foo': 'bar'} @@ -207,8 +197,7 @@ def assert_almost_equal(self, a, b, delta): self.assertTrue( - 0 <= abs(a - b) <= delta, - "%s - %s not within %s" % (a, b, delta) + 0 <= abs(a - b) <= delta, "%s - %s not within %s" % (a, b, delta) ) def test_socket_error(self): @@ -225,6 +214,7 @@ """ Measure the distribution of a function's run time. """ + @self.statsd.timed('timed.test') def func(a, b, c=1, d=1): """docstring""" @@ -251,6 +241,7 @@ Exception bubble out of the decorator and is reported to statsd as a dedicated counter. """ + @self.statsd.timed('timed.test') def func(a, b, c=1, d=1): """docstring""" @@ -271,7 +262,7 @@ self.assertEqual('timed.test_error_count', name) self.assertEqual(int(value), 1) - def test_timed_no_metric(self, ): + def test_timed_no_metric(self,): """ Test using a decorator without providing a metric. """ @@ -348,6 +339,7 @@ Exception bubbles out of the `timed` context manager and is reported to statsd as a dedicated counter. """ + class ContextException(Exception): pass @@ -422,14 +414,14 @@ for i in range(51): statsd.increment('mycounter') self.assertEqual( - '\n'.join(['mycounter:1|c' for i in range(50)]), - fake_socket.recv(), + '\n'.join(['mycounter:1|c' for i in range(50)]), fake_socket.recv() ) self.assertEqual('mycounter:1|c', fake_socket.recv()) def test_module_level_instance(self): from swh.core.statsd import statsd + self.assertTrue(isinstance(statsd, Statsd)) def test_instantiating_does_not_connect(self): @@ -457,8 +449,7 @@ statsd._socket = FakeSocket() statsd.gauge('gt', 123.4) - self.assertEqual('gt:123.4|g|#age:45,country:china', - statsd.socket.recv()) + self.assertEqual('gt:123.4|g|#age:45,country:china', statsd.socket.recv()) def test_tags_from_environment_and_constant(self): with preserve_envvars('STATSD_TAGS'): @@ -466,8 +457,7 @@ statsd = Statsd(constant_tags={'country': 'canada'}) statsd._socket = FakeSocket() statsd.gauge('gt', 123.4) - self.assertEqual('gt:123.4|g|#age:45,country:canada', - statsd.socket.recv()) + self.assertEqual('gt:123.4|g|#age:45,country:canada', statsd.socket.recv()) def test_tags_from_environment_warning(self): with preserve_envvars('STATSD_TAGS'): diff --git a/swh/core/tests/test_tarball.py b/swh/core/tests/test_tarball.py --- a/swh/core/tests/test_tarball.py +++ b/swh/core/tests/test_tarball.py @@ -70,8 +70,9 @@ assert not os.path.exists(tarpath) - with pytest.raises(shutil.ReadError, - match=f'Unable to uncompress {tarpath} to {tmp_path}'): + with pytest.raises( + shutil.ReadError, match=f'Unable to uncompress {tarpath} to {tmp_path}' + ): tarball._unpack_tar(tarpath, tmp_path) @@ -86,8 +87,9 @@ extract_dir = os.path.join(tmp_path, 'dir', 'inexistent') - with pytest.raises(shutil.ReadError, - match=f'Unable to uncompress {tarpath} to {tmp_path}'): + with pytest.raises( + shutil.ReadError, match=f'Unable to uncompress {tarpath} to {tmp_path}' + ): tarball._unpack_tar(tarpath, extract_dir) @@ -100,8 +102,9 @@ assert os.path.exists(tarpath) - with pytest.raises(shutil.ReadError, - match=f'Unable to uncompress {tarpath} to {tmp_path}'): + with pytest.raises( + shutil.ReadError, match=f'Unable to uncompress {tarpath} to {tmp_path}' + ): tarball._unpack_tar(tarpath, tmp_path) @@ -155,8 +158,7 @@ # not supported yet for tarpath in unsupported_tarpaths: - with pytest.raises(ValueError, - match=f'Problem during unpacking {tarpath}.'): + with pytest.raises(ValueError, match=f'Problem during unpacking {tarpath}.'): tarball.uncompress(tarpath, dest=tmp_path) # register those unsupported formats diff --git a/swh/core/tests/test_utils.py b/swh/core/tests/test_utils.py --- a/swh/core/tests/test_utils.py +++ b/swh/core/tests/test_utils.py @@ -9,7 +9,6 @@ class UtilsLib(unittest.TestCase): - def test_grouper(self): # given actual_data = utils.grouper((i for i in range(0, 9)), 2) @@ -31,18 +30,22 @@ def test_grouper_with_stop_value(self): # given - actual_data = utils.grouper(((i, i+1) for i in range(0, 9)), 2) + actual_data = utils.grouper(((i, i + 1) for i in range(0, 9)), 2) out = [] for d in actual_data: out.append(list(d)) # force generator resolution for checks - self.assertEqual(out, [ - [(0, 1), (1, 2)], - [(2, 3), (3, 4)], - [(4, 5), (5, 6)], - [(6, 7), (7, 8)], - [(8, 9)]]) + self.assertEqual( + out, + [ + [(0, 1), (1, 2)], + [(2, 3), (3, 4)], + [(4, 5), (5, 6)], + [(6, 7), (7, 8)], + [(8, 9)], + ], + ) # given actual_data = utils.grouper((i for i in range(9, 0, -1)), 4) @@ -58,10 +61,7 @@ with self.assertRaises(UnicodeDecodeError): raw_data_err.decode('utf-8', 'strict') - self.assertEqual( - raw_data_err.decode('utf-8', 'backslashescape'), - 'abcd\\x80', - ) + self.assertEqual(raw_data_err.decode('utf-8', 'backslashescape'), 'abcd\\x80') raw_data_ok = b'abcd\xc3\xa9' self.assertEqual( @@ -71,18 +71,14 @@ unicode_data = 'abcdef\u00a3' self.assertEqual( - unicode_data.encode('ascii', 'backslashescape'), - b'abcdef\\xa3', + unicode_data.encode('ascii', 'backslashescape'), b'abcdef\\xa3' ) def test_encode_with_unescape(self): valid_data = '\\x01020304\\x00' valid_data_encoded = b'\x01020304\x00' - self.assertEqual( - valid_data_encoded, - utils.encode_with_unescape(valid_data) - ) + self.assertEqual(valid_data_encoded, utils.encode_with_unescape(valid_data)) def test_encode_with_unescape_invalid_escape(self): invalid_data = 'test\\abcd' @@ -97,44 +93,34 @@ backslashes = b'foo\\bar\\\\baz' backslashes_escaped = 'foo\\\\bar\\\\\\\\baz' - self.assertEqual( - backslashes_escaped, - utils.decode_with_escape(backslashes), - ) + self.assertEqual(backslashes_escaped, utils.decode_with_escape(backslashes)) valid_utf8 = b'foo\xc3\xa2' valid_utf8_escaped = 'foo\u00e2' - self.assertEqual( - valid_utf8_escaped, - utils.decode_with_escape(valid_utf8), - ) + self.assertEqual(valid_utf8_escaped, utils.decode_with_escape(valid_utf8)) invalid_utf8 = b'foo\xa2' invalid_utf8_escaped = 'foo\\xa2' - self.assertEqual( - invalid_utf8_escaped, - utils.decode_with_escape(invalid_utf8), - ) + self.assertEqual(invalid_utf8_escaped, utils.decode_with_escape(invalid_utf8)) valid_utf8_nul = b'foo\xc3\xa2\x00' valid_utf8_nul_escaped = 'foo\u00e2\\x00' self.assertEqual( - valid_utf8_nul_escaped, - utils.decode_with_escape(valid_utf8_nul), + valid_utf8_nul_escaped, utils.decode_with_escape(valid_utf8_nul) ) def test_commonname(self): # when - actual_commonname = utils.commonname('/some/where/to/', - '/some/where/to/go/to') + actual_commonname = utils.commonname('/some/where/to/', '/some/where/to/go/to') # then self.assertEqual('go/to', actual_commonname) # when - actual_commonname2 = utils.commonname(b'/some/where/to/', - b'/some/where/to/go/to') + actual_commonname2 = utils.commonname( + b'/some/where/to/', b'/some/where/to/go/to' + ) # then self.assertEqual(b'go/to', actual_commonname2) diff --git a/swh/core/utils.py b/swh/core/utils.py --- a/swh/core/utils.py +++ b/swh/core/utils.py @@ -51,7 +51,7 @@ def backslashescape_errors(exception): if isinstance(exception, UnicodeDecodeError): - bad_data = exception.object[exception.start:exception.end] + bad_data = exception.object[exception.start : exception.end] escaped = ''.join(r'\x%02x' % x for x in bad_data) return escaped, exception.end @@ -73,12 +73,13 @@ else: if odd_backslashes: if value[i] != 'x': - raise ValueError('invalid escape for %r at position %d' % - (value, i-1)) + raise ValueError( + 'invalid escape for %r at position %d' % (value, i - 1) + ) slices.append( - value[start:i-1].replace('\\\\', '\\').encode('utf-8') + value[start : i - 1].replace('\\\\', '\\').encode('utf-8') ) - slices.append(bytes.fromhex(value[i+1:i+3])) + slices.append(bytes.fromhex(value[i + 1 : i + 3])) odd_backslashes = False start = i = i + 3 @@ -86,9 +87,7 @@ i += 1 - slices.append( - value[start:i].replace('\\\\', '\\').encode('utf-8') - ) + slices.append(value[start:i].replace('\\\\', '\\').encode('utf-8')) return b''.join(slices)