diff --git a/PKG-INFO b/PKG-INFO index 44c40d5..dfe1bee 100644 --- a/PKG-INFO +++ b/PKG-INFO @@ -1,91 +1,91 @@ Metadata-Version: 2.1 Name: swh.core -Version: 0.0.87 +Version: 0.0.88 Summary: Software Heritage core utilities Home-page: https://forge.softwareheritage.org/diffusion/DCORE/ Author: Software Heritage developers Author-email: swh-devel@inria.fr License: UNKNOWN Project-URL: Bug Reports, https://forge.softwareheritage.org/maniphest Project-URL: Funding, https://www.softwareheritage.org/donate Project-URL: Source, https://forge.softwareheritage.org/source/swh-core Description: swh-core ======== core library for swh's modules: - config parser - hash computations - serialization - logging mechanism - database connection - http-based RPC client/server Development ----------- We strongly recommend you to use a [virtualenv][1] if you want to run tests or hack the code. To set up your development environment: ``` (swh) user@host:~/swh-environment/swh-core$ pip install -e .[testing] ``` This will install every Python package needed to run this package's tests. Unit tests can be executed using [pytest][2] or [tox][3]. ``` (swh) user@host:~/swh-environment/swh-core$ pytest ============================== test session starts ============================== platform linux -- Python 3.7.3, pytest-3.10.1, py-1.8.0, pluggy-0.12.0 hypothesis profile 'default' -> database=DirectoryBasedExampleDatabase('/home/ddouard/src/swh-environment/swh-core/.hypothesis/examples') rootdir: /home/ddouard/src/swh-environment/swh-core, inifile: pytest.ini plugins: requests-mock-1.6.0, hypothesis-4.26.4, celery-4.3.0, postgresql-1.4.1 collected 89 items swh/core/api/tests/test_api.py .. [ 2%] swh/core/api/tests/test_async.py .... [ 6%] swh/core/api/tests/test_serializers.py ..... [ 12%] swh/core/db/tests/test_db.py .... [ 16%] swh/core/tests/test_cli.py ...... [ 23%] swh/core/tests/test_config.py .............. [ 39%] swh/core/tests/test_statsd.py ........................................... [ 87%] .... [ 92%] swh/core/tests/test_utils.py ....... [100%] ===================== 89 passed, 9 warnings in 6.94 seconds ===================== ``` Note: this git repository uses [pre-commit][4] hooks to ensure better and more consistent code. It should already be installed in your virtualenv (if not, just type `pip install pre-commit`). Make sure to activate it in your local copy of the git repository: ``` (swh) user@host:~/swh-environment/swh-core$ pre-commit install pre-commit installed at .git/hooks/pre-commit ``` Please read the [developer setup manual][5] for more information on how to hack on Software Heritage. [1]: https://virtualenv.pypa.io [2]: https://docs.pytest.org [3]: https://tox.readthedocs.io [4]: https://pre-commit.com [5]: https://docs.softwareheritage.org/devel/developer-setup.html Platform: UNKNOWN Classifier: Programming Language :: Python :: 3 Classifier: Intended Audience :: Developers Classifier: License :: OSI Approved :: GNU General Public License v3 (GPLv3) Classifier: Operating System :: OS Independent Classifier: Development Status :: 5 - Production/Stable Description-Content-Type: text/markdown Provides-Extra: testing-core Provides-Extra: logging Provides-Extra: db Provides-Extra: testing-db Provides-Extra: http Provides-Extra: testing diff --git a/swh.core.egg-info/PKG-INFO b/swh.core.egg-info/PKG-INFO index 44c40d5..dfe1bee 100644 --- a/swh.core.egg-info/PKG-INFO +++ b/swh.core.egg-info/PKG-INFO @@ -1,91 +1,91 @@ Metadata-Version: 2.1 Name: swh.core -Version: 0.0.87 +Version: 0.0.88 Summary: Software Heritage core utilities Home-page: https://forge.softwareheritage.org/diffusion/DCORE/ Author: Software Heritage developers Author-email: swh-devel@inria.fr License: UNKNOWN Project-URL: Bug Reports, https://forge.softwareheritage.org/maniphest Project-URL: Funding, https://www.softwareheritage.org/donate Project-URL: Source, https://forge.softwareheritage.org/source/swh-core Description: swh-core ======== core library for swh's modules: - config parser - hash computations - serialization - logging mechanism - database connection - http-based RPC client/server Development ----------- We strongly recommend you to use a [virtualenv][1] if you want to run tests or hack the code. To set up your development environment: ``` (swh) user@host:~/swh-environment/swh-core$ pip install -e .[testing] ``` This will install every Python package needed to run this package's tests. Unit tests can be executed using [pytest][2] or [tox][3]. ``` (swh) user@host:~/swh-environment/swh-core$ pytest ============================== test session starts ============================== platform linux -- Python 3.7.3, pytest-3.10.1, py-1.8.0, pluggy-0.12.0 hypothesis profile 'default' -> database=DirectoryBasedExampleDatabase('/home/ddouard/src/swh-environment/swh-core/.hypothesis/examples') rootdir: /home/ddouard/src/swh-environment/swh-core, inifile: pytest.ini plugins: requests-mock-1.6.0, hypothesis-4.26.4, celery-4.3.0, postgresql-1.4.1 collected 89 items swh/core/api/tests/test_api.py .. [ 2%] swh/core/api/tests/test_async.py .... [ 6%] swh/core/api/tests/test_serializers.py ..... [ 12%] swh/core/db/tests/test_db.py .... [ 16%] swh/core/tests/test_cli.py ...... [ 23%] swh/core/tests/test_config.py .............. [ 39%] swh/core/tests/test_statsd.py ........................................... [ 87%] .... [ 92%] swh/core/tests/test_utils.py ....... [100%] ===================== 89 passed, 9 warnings in 6.94 seconds ===================== ``` Note: this git repository uses [pre-commit][4] hooks to ensure better and more consistent code. It should already be installed in your virtualenv (if not, just type `pip install pre-commit`). Make sure to activate it in your local copy of the git repository: ``` (swh) user@host:~/swh-environment/swh-core$ pre-commit install pre-commit installed at .git/hooks/pre-commit ``` Please read the [developer setup manual][5] for more information on how to hack on Software Heritage. [1]: https://virtualenv.pypa.io [2]: https://docs.pytest.org [3]: https://tox.readthedocs.io [4]: https://pre-commit.com [5]: https://docs.softwareheritage.org/devel/developer-setup.html Platform: UNKNOWN Classifier: Programming Language :: Python :: 3 Classifier: Intended Audience :: Developers Classifier: License :: OSI Approved :: GNU General Public License v3 (GPLv3) Classifier: Operating System :: OS Independent Classifier: Development Status :: 5 - Production/Stable Description-Content-Type: text/markdown Provides-Extra: testing-core Provides-Extra: logging Provides-Extra: db Provides-Extra: testing-db Provides-Extra: http Provides-Extra: testing diff --git a/swh/core/api/__init__.py b/swh/core/api/__init__.py index 63ea3e0..ba08caa 100644 --- a/swh/core/api/__init__.py +++ b/swh/core/api/__init__.py @@ -1,376 +1,413 @@ # Copyright (C) 2015-2017 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from collections import abc +import datetime import functools import inspect import json import logging import pickle import requests -import datetime +import traceback -from typing import Any, ClassVar, Optional, Type +from typing import Any, ClassVar, List, Optional, Type from flask import Flask, Request, Response, request, abort +from werkzeug.exceptions import HTTPException + from .serializers import (decode_response, encode_data_client as encode_data, msgpack_dumps, msgpack_loads, SWHJSONDecoder) from .negotiation import (Formatter as FormatterBase, Negotiator as NegotiatorBase, negotiate as _negotiate) logger = logging.getLogger(__name__) # support for content negotiation class Negotiator(NegotiatorBase): def best_mimetype(self): return request.accept_mimetypes.best_match( self.accept_mimetypes, 'application/json') def _abort(self, status_code, err=None): return abort(status_code, err) def negotiate(formatter_cls, *args, **kwargs): return _negotiate(Negotiator, formatter_cls, *args, **kwargs) class Formatter(FormatterBase): def _make_response(self, body, content_type): return Response(body, content_type=content_type) class SWHJSONEncoder(json.JSONEncoder): def default(self, obj): if isinstance(obj, (datetime.datetime, datetime.date)): return obj.isoformat() if isinstance(obj, datetime.timedelta): return str(obj) # Let the base class default method raise the TypeError return super().default(obj) class JSONFormatter(Formatter): format = 'json' mimetypes = ['application/json'] def render(self, obj): return json.dumps(obj, cls=SWHJSONEncoder) class MsgpackFormatter(Formatter): format = 'msgpack' mimetypes = ['application/x-msgpack'] def render(self, obj): return msgpack_dumps(obj) # base API classes class RemoteException(Exception): """raised when remote returned an out-of-band failure notification, e.g., as a HTTP status code or serialized exception Attributes: response: HTTP response corresponding to the failure """ def __init__(self, payload: Optional[Any] = None, response: Optional[requests.Response] = None): if payload is not None: super().__init__(payload) else: super().__init__() self.response = response + def __str__(self): + if self.args and isinstance(self.args[0], dict) \ + and 'type' in self.args[0] and 'args' in self.args[0]: + return ( + f'') + else: + return super().__str__() + def remote_api_endpoint(path): def dec(f): f._endpoint_path = path return f return dec class APIError(Exception): """API Error""" def __str__(self): return ('An unexpected error occurred in the backend: {}' .format(self.args)) class MetaRPCClient(type): """Metaclass for RPCClient, which adds a method for each endpoint of the database it is designed to access. See for example :class:`swh.indexer.storage.api.client.RemoteStorage`""" def __new__(cls, name, bases, attributes): # For each method wrapped with @remote_api_endpoint in an API backend # (eg. :class:`swh.indexer.storage.IndexerStorage`), add a new # method in RemoteStorage, with the same documentation. # # Note that, despite the usage of decorator magic (eg. functools.wrap), # this never actually calls an IndexerStorage method. backend_class = attributes.get('backend_class', None) for base in bases: if backend_class is not None: break backend_class = getattr(base, 'backend_class', None) if backend_class: for (meth_name, meth) in backend_class.__dict__.items(): if hasattr(meth, '_endpoint_path'): cls.__add_endpoint(meth_name, meth, attributes) return super().__new__(cls, name, bases, attributes) @staticmethod def __add_endpoint(meth_name, meth, attributes): wrapped_meth = inspect.unwrap(meth) @functools.wraps(meth) # Copy signature and doc def meth_(*args, **kwargs): # Match arguments and parameters post_data = inspect.getcallargs( wrapped_meth, *args, **kwargs) # Remove arguments that should not be passed self = post_data.pop('self') post_data.pop('cur', None) post_data.pop('db', None) # Send the request. return self.post(meth._endpoint_path, post_data) attributes[meth_name] = meth_ class RPCClient(metaclass=MetaRPCClient): """Proxy to an internal SWH RPC """ backend_class = None # type: ClassVar[Optional[type]] """For each method of `backend_class` decorated with :func:`remote_api_endpoint`, a method with the same prototype and docstring will be added to this class. Calls to this new method will be translated into HTTP requests to a remote server. This backend class will never be instantiated, it only serves as a template.""" api_exception = APIError # type: ClassVar[Type[Exception]] """The exception class to raise in case of communication error with the server.""" + reraise_exceptions: ClassVar[List[Type[Exception]]] = [] + """On server errors, if any of the exception classes in this list + has the same name as the error name, then the exception will + be instantiated and raised instead of a generic RemoteException.""" + def __init__(self, url, api_exception=None, - timeout=None, chunk_size=4096, **kwargs): + timeout=None, chunk_size=4096, + reraise_exceptions=None, **kwargs): if api_exception: self.api_exception = api_exception + if reraise_exceptions: + self.reraise_exceptions = reraise_exceptions base_url = url if url.endswith('/') else url + '/' self.url = base_url self.session = requests.Session() adapter = requests.adapters.HTTPAdapter( max_retries=kwargs.get('max_retries', 3), pool_connections=kwargs.get('pool_connections', 20), pool_maxsize=kwargs.get('pool_maxsize', 100)) self.session.mount(self.url, adapter) self.timeout = timeout self.chunk_size = chunk_size def _url(self, endpoint): return '%s%s' % (self.url, endpoint) def raw_verb(self, verb, endpoint, **opts): if 'chunk_size' in opts: # if the chunk_size argument has been passed, consider the user # also wants stream=True, otherwise, what's the point. opts['stream'] = True if self.timeout and 'timeout' not in opts: opts['timeout'] = self.timeout try: return getattr(self.session, verb)( self._url(endpoint), **opts ) except requests.exceptions.ConnectionError as e: raise self.api_exception(e) def post(self, endpoint, data, **opts): if isinstance(data, (abc.Iterator, abc.Generator)): data = (encode_data(x) for x in data) else: data = encode_data(data) chunk_size = opts.pop('chunk_size', self.chunk_size) response = self.raw_verb( 'post', endpoint, data=data, headers={'content-type': 'application/x-msgpack', 'accept': 'application/x-msgpack'}, **opts) if opts.get('stream') or \ response.headers.get('transfer-encoding') == 'chunked': self.raise_for_status(response) return response.iter_content(chunk_size) else: return self._decode_response(response) post_stream = post def get(self, endpoint, **opts): chunk_size = opts.pop('chunk_size', self.chunk_size) response = self.raw_verb( 'get', endpoint, headers={'accept': 'application/x-msgpack'}, **opts) if opts.get('stream') or \ response.headers.get('transfer-encoding') == 'chunked': self.raise_for_status(response) return response.iter_content(chunk_size) else: return self._decode_response(response) def get_stream(self, endpoint, **opts): return self.get(endpoint, stream=True, **opts) def raise_for_status(self, response) -> None: """check response HTTP status code and raise an exception if it denotes an error; do nothing otherwise """ - # XXX: unpickling below breaks language-independence and should be - # replaced by proper language-independent [de]serialization status_code = response.status_code status_class = response.status_code // 100 if status_code == 404: raise RemoteException(payload='404 not found', response=response) exception = None + # TODO: only old servers send pickled error; stop trying to unpickle + # after they are all upgraded try: if status_class == 4: data = decode_response(response) - exception = pickle.loads(data) + if isinstance(data, dict): + for exc_type in self.reraise_exceptions: + if exc_type.__name__ == data['exception']['type']: + exception = exc_type(*data['exception']['args']) + break + else: + exception = RemoteException(payload=data['exception'], + response=response) + else: + exception = pickle.loads(data) elif status_class == 5: data = decode_response(response) if 'exception_pickled' in data: exception = pickle.loads(data['exception_pickled']) else: exception = RemoteException(payload=data['exception'], response=response) except (TypeError, pickle.UnpicklingError): raise RemoteException(payload=data, response=response) if exception: raise exception from None if status_class != 2: raise RemoteException( payload=f'API HTTP error: {status_code} {response.content}', response=response) def _decode_response(self, response): - if response.status_code == 404: - return None - else: - self.raise_for_status(response) - return decode_response(response) + self.raise_for_status(response) + return decode_response(response) def __repr__(self): return '<{} url={}>'.format(self.__class__.__name__, self.url) class BytesRequest(Request): """Request with proper escaping of arbitrary byte sequences.""" encoding = 'utf-8' encoding_errors = 'surrogateescape' ENCODERS = { 'application/x-msgpack': msgpack_dumps, 'application/json': json.dumps, } def encode_data_server(data, content_type='application/x-msgpack'): encoded_data = ENCODERS[content_type](data) return Response( encoded_data, mimetype=content_type, ) def decode_request(request): content_type = request.mimetype data = request.get_data() if not data: return {} if content_type == 'application/x-msgpack': r = msgpack_loads(data) elif content_type == 'application/json': # XXX this .decode() is needed for py35. # Should not be needed any more with py37 r = json.loads(data.decode('utf-8'), cls=SWHJSONDecoder) else: raise ValueError('Wrong content type `%s` for API request' % content_type) return r -def error_handler(exception, encoder): - # XXX: this breaks language-independence and should be - # replaced by proper serialization of errors +def error_handler(exception, encoder, status_code=500): logging.exception(exception) - response = encoder(pickle.dumps(exception)) - response.status_code = 400 + tb = traceback.format_exception(None, exception, exception.__traceback__) + error = { + 'exception': { + 'type': type(exception).__name__, + 'args': exception.args, + 'message': str(exception), + 'traceback': tb, + } + } + response = encoder(error) + if isinstance(exception, HTTPException): + response.status_code = exception.code + else: + # TODO: differentiate between server errors and client errors + response.status_code = status_code return response class RPCServerApp(Flask): """For each endpoint of the given `backend_class`, tells app.route to call a function that decodes the request and sends it to the backend object provided by the factory. :param Any backend_class: The class of the backend, which will be analyzed to look for API endpoints. :param Optional[Callable[[], backend_class]] backend_factory: A function with no argument that returns an instance of `backend_class`. If unset, defaults to calling `backend_class` constructor directly. """ request_class = BytesRequest def __init__(self, *args, backend_class=None, backend_factory=None, **kwargs): super().__init__(*args, **kwargs) self.backend_class = backend_class if backend_class is not None: if backend_factory is None: backend_factory = backend_class for (meth_name, meth) in backend_class.__dict__.items(): if hasattr(meth, '_endpoint_path'): self.__add_endpoint(meth_name, meth, backend_factory) def __add_endpoint(self, meth_name, meth, backend_factory): from flask import request @self.route('/'+meth._endpoint_path, methods=['POST']) @negotiate(MsgpackFormatter) @negotiate(JSONFormatter) @functools.wraps(meth) # Copy signature and doc def _f(): # Call the actual code obj_meth = getattr(backend_factory(), meth_name) kw = decode_request(request) return obj_meth(**kw) diff --git a/swh/core/api/negotiation.py b/swh/core/api/negotiation.py index de57742..1322862 100644 --- a/swh/core/api/negotiation.py +++ b/swh/core/api/negotiation.py @@ -1,153 +1,159 @@ # This code is a partial and adapted copy of # https://github.com/nickstenning/negotiate # # Copyright 2012-2013 Nick Stenning # 2019 The Software Heritage developers # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in # all copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # from collections import defaultdict from decorator import decorator from inspect import getcallargs -from typing import Any, List, Optional +from typing import Any, List, Optional, Callable, \ + Type, NoReturn, DefaultDict + +from requests import Response class FormatterNotFound(Exception): pass class Formatter: - format = None # type: Optional[str] - mimetypes = [] # type: List[Any] + format: Optional[str] = None + mimetypes: List[str] = [] - def __init__(self, request_mimetype=None): + def __init__(self, request_mimetype: Optional[str] = None) -> None: if request_mimetype is None or request_mimetype not in self.mimetypes: try: self.response_mimetype = self.mimetypes[0] except IndexError: raise NotImplementedError( "%s.mimetypes should be a non-empty list" % self.__class__.__name__) else: self.response_mimetype = request_mimetype - def configure(self): + def configure(self) -> None: pass - def render(self, obj): + def render(self, obj: Any) -> bytes: raise NotImplementedError( "render() should be implemented by Formatter subclasses") - def __call__(self, obj): + def __call__(self, obj: Any) -> Response: return self._make_response( self.render(obj), content_type=self.response_mimetype) - def _make_response(self, body, content_type): + def _make_response(self, body: bytes, content_type: str) -> Response: raise NotImplementedError( "_make_response() should be implemented by " "framework-specific subclasses of Formatter" ) class Negotiator: - def __init__(self, func): + def __init__(self, func: Callable[..., Any]) -> None: self.func = func - self._formatters = [] - self._formatters_by_format = defaultdict(list) - self._formatters_by_mimetype = defaultdict(list) + self._formatters: List[Type[Formatter]] = [] + self._formatters_by_format: DefaultDict = defaultdict(list) + self._formatters_by_mimetype: DefaultDict = defaultdict(list) - def __call__(self, *args, **kwargs): + def __call__(self, *args, **kwargs) -> Response: result = self.func(*args, **kwargs) format = getcallargs(self.func, *args, **kwargs).get('format') mimetype = self.best_mimetype() try: formatter = self.get_formatter(format, mimetype) except FormatterNotFound as e: return self._abort(404, str(e)) return formatter(result) - def register_formatter(self, formatter, *args, **kwargs): + def register_formatter(self, formatter: Type[Formatter], + *args, **kwargs) -> None: self._formatters.append(formatter) self._formatters_by_format[formatter.format].append( (formatter, args, kwargs)) for mimetype in formatter.mimetypes: self._formatters_by_mimetype[mimetype].append( (formatter, args, kwargs)) - def get_formatter(self, format=None, mimetype=None): + def get_formatter(self, format: Optional[str] = None, + mimetype: Optional[str] = None) -> Formatter: if format is None and mimetype is None: raise TypeError( "get_formatter expects one of the 'format' or 'mimetype' " "kwargs to be set") if format is not None: try: # the first added will be the most specific formatter_cls, args, kwargs = ( self._formatters_by_format[format][0]) except IndexError: raise FormatterNotFound( "Formatter for format '%s' not found!" % format) elif mimetype is not None: try: # the first added will be the most specific formatter_cls, args, kwargs = ( self._formatters_by_mimetype[mimetype][0]) except IndexError: raise FormatterNotFound( "Formatter for mimetype '%s' not found!" % mimetype) formatter = formatter_cls(request_mimetype=mimetype) formatter.configure(*args, **kwargs) return formatter @property - def accept_mimetypes(self): + def accept_mimetypes(self) -> List[str]: return [m for f in self._formatters for m in f.mimetypes] - def best_mimetype(self): + def best_mimetype(self) -> str: raise NotImplementedError( "best_mimetype() should be implemented in " "framework-specific subclasses of Negotiator" ) - def _abort(self, status_code, err=None): + def _abort(self, status_code: int, err: Optional[str] = None) -> NoReturn: raise NotImplementedError( "_abort() should be implemented in framework-specific " "subclasses of Negotiator" ) -def negotiate(negotiator_cls, formatter_cls, *args, **kwargs): +def negotiate(negotiator_cls: Type[Negotiator], formatter_cls: Type[Formatter], + *args, **kwargs) -> Callable: def _negotiate(f, *args, **kwargs): return f.negotiator(*args, **kwargs) def decorate(f): if not hasattr(f, 'negotiator'): f.negotiator = negotiator_cls(f) f.negotiator.register_formatter(formatter_cls, *args, **kwargs) return decorator(_negotiate, f) return decorate diff --git a/swh/core/api/serializers.py b/swh/core/api/serializers.py index ac43bfb..f9aca96 100644 --- a/swh/core/api/serializers.py +++ b/swh/core/api/serializers.py @@ -1,193 +1,198 @@ # Copyright (C) 2015-2018 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import base64 import datetime import json import types from uuid import UUID import arrow import dateutil.parser import msgpack +from typing import Any, Dict, Union, Tuple +from requests import Response -def encode_data_client(data): + +def encode_data_client(data: Any) -> bytes: try: return msgpack_dumps(data) except OverflowError as e: raise ValueError('Limits were reached. Please, check your input.\n' + str(e)) -def decode_response(response): +def decode_response(response: Response) -> Any: content_type = response.headers['content-type'] if content_type.startswith('application/x-msgpack'): r = msgpack_loads(response.content) elif content_type.startswith('application/json'): r = json.loads(response.text, cls=SWHJSONDecoder) elif content_type.startswith('text/'): r = response.text else: raise ValueError('Wrong content type `%s` for API response' % content_type) return r class SWHJSONEncoder(json.JSONEncoder): """JSON encoder for data structures generated by Software Heritage. This JSON encoder extends the default Python JSON encoder and adds awareness for the following specific types: - bytes (get encoded as a Base85 string); - datetime.datetime (get encoded as an ISO8601 string). Non-standard types get encoded as a a dictionary with two keys: - swhtype with value 'bytes' or 'datetime'; - d containing the encoded value. SWHJSONEncoder also encodes arbitrary iterables as a list (allowing serialization of generators). Caveats: Limitations in the JSONEncoder extension mechanism prevent us from "escaping" dictionaries that only contain the swhtype and d keys, and therefore arbitrary data structures can't be round-tripped through SWHJSONEncoder and SWHJSONDecoder. """ - def default(self, o): + def default(self, o: Any + ) -> Union[Dict[str, Union[Dict[str, int], str]], list]: if isinstance(o, bytes): return { 'swhtype': 'bytes', 'd': base64.b85encode(o).decode('ascii'), } elif isinstance(o, datetime.datetime): return { 'swhtype': 'datetime', 'd': o.isoformat(), } elif isinstance(o, UUID): return { 'swhtype': 'uuid', 'd': str(o), } elif isinstance(o, datetime.timedelta): return { 'swhtype': 'timedelta', 'd': { 'days': o.days, 'seconds': o.seconds, 'microseconds': o.microseconds, }, } elif isinstance(o, arrow.Arrow): return { 'swhtype': 'arrow', 'd': o.isoformat(), } try: return super().default(o) except TypeError as e: try: iterable = iter(o) except TypeError: raise e from None else: return list(iterable) class SWHJSONDecoder(json.JSONDecoder): """JSON decoder for data structures encoded with SWHJSONEncoder. This JSON decoder extends the default Python JSON decoder, allowing the decoding of: - bytes (encoded as a Base85 string); - datetime.datetime (encoded as an ISO8601 string). Non-standard types must be encoded as a a dictionary with exactly two keys: - swhtype with value 'bytes' or 'datetime'; - d containing the encoded value. To limit the impact our encoding, if the swhtype key doesn't contain a known value, the dictionary is decoded as-is. """ - def decode_data(self, o): + + def decode_data(self, o: Any) -> Any: if isinstance(o, dict): if set(o.keys()) == {'d', 'swhtype'}: datatype = o['swhtype'] if datatype == 'bytes': return base64.b85decode(o['d']) elif datatype == 'datetime': return dateutil.parser.parse(o['d']) elif datatype == 'uuid': return UUID(o['d']) elif datatype == 'timedelta': return datetime.timedelta(**o['d']) elif datatype == 'arrow': return arrow.get(o['d']) return {key: self.decode_data(value) for key, value in o.items()} if isinstance(o, list): return [self.decode_data(value) for value in o] else: return o - def raw_decode(self, s, idx=0): + def raw_decode(self, s: str, idx: int = 0) -> Tuple[Any, int]: data, index = super().raw_decode(s, idx) return self.decode_data(data), index -def msgpack_dumps(data): +def msgpack_dumps(data: Any) -> bytes: """Write data as a msgpack stream""" def encode_types(obj): if isinstance(obj, datetime.datetime): return {b'__datetime__': True, b's': obj.isoformat()} if isinstance(obj, types.GeneratorType): return list(obj) if isinstance(obj, UUID): return {b'__uuid__': True, b's': str(obj)} if isinstance(obj, datetime.timedelta): return { b'__timedelta__': True, b's': { 'days': obj.days, 'seconds': obj.seconds, 'microseconds': obj.microseconds, }, } if isinstance(obj, arrow.Arrow): return {b'__arrow__': True, b's': obj.isoformat()} return obj return msgpack.packb(data, use_bin_type=True, default=encode_types) -def msgpack_loads(data): +def msgpack_loads(data: bytes) -> Any: """Read data as a msgpack stream""" def decode_types(obj): if b'__datetime__' in obj and obj[b'__datetime__']: return dateutil.parser.parse(obj[b's']) if b'__uuid__' in obj and obj[b'__uuid__']: return UUID(obj[b's']) if b'__timedelta__' in obj and obj[b'__timedelta__']: return datetime.timedelta(**obj[b's']) if b'__arrow__' in obj and obj[b'__arrow__']: return arrow.get(obj[b's']) return obj try: return msgpack.unpackb(data, raw=False, object_hook=decode_types) except TypeError: # msgpack < 0.5.2 return msgpack.unpackb(data, encoding='utf-8', object_hook=decode_types) diff --git a/swh/core/api/tests/test_rpc_client_server.py b/swh/core/api/tests/test_rpc_client_server.py index 38b6370..16a84b2 100644 --- a/swh/core/api/tests/test_rpc_client_server.py +++ b/swh/core/api/tests/test_rpc_client_server.py @@ -1,102 +1,107 @@ # Copyright (C) 2018-2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import pytest from swh.core.api import remote_api_endpoint, RPCServerApp, RPCClient -from swh.core.api import error_handler, encode_data_server +from swh.core.api import error_handler, encode_data_server, RemoteException # this class is used on the server part class RPCTest: @remote_api_endpoint('endpoint_url') def endpoint(self, test_data, db=None, cur=None): assert test_data == 'spam' return 'egg' @remote_api_endpoint('path/to/endpoint') def something(self, data, db=None, cur=None): return data @remote_api_endpoint('raises_typeerror') def raise_typeerror(self): raise TypeError('Did I pass through?') # this class is used on the client part. We cannot inherit from RPCTest # because the automagic metaclass based code that generates the RPCClient # proxy class from this does not handle inheritance properly. # We do add an endpoint on the client side that has no implementation # server-side to test this very situation (in should generate a 404) class RPCTest2: @remote_api_endpoint('endpoint_url') def endpoint(self, test_data, db=None, cur=None): assert test_data == 'spam' return 'egg' @remote_api_endpoint('path/to/endpoint') def something(self, data, db=None, cur=None): return data @remote_api_endpoint('not_on_server') def not_on_server(self, db=None, cur=None): return 'ok' @remote_api_endpoint('raises_typeerror') def raise_typeerror(self): return 'data' class RPCTestClient(RPCClient): backend_class = RPCTest2 @pytest.fixture def app(): # This fixture is used by the 'swh_rpc_adapter' fixture # which is defined in swh/core/pytest_plugin.py application = RPCServerApp('testapp', backend_class=RPCTest) @application.errorhandler(Exception) def my_error_handler(exception): return error_handler(exception, encode_data_server) return application @pytest.fixture def swh_rpc_client_class(): # This fixture is used by the 'swh_rpc_client' fixture # which is defined in swh/core/pytest_plugin.py return RPCTestClient def test_api_client_endpoint_missing(swh_rpc_client): with pytest.raises(AttributeError): swh_rpc_client.missing(data='whatever') def test_api_server_endpoint_missing(swh_rpc_client): # A 'missing' endpoint (server-side) should raise an exception # due to a 404, since at the end, we do a GET/POST an inexistent URL - with pytest.raises(Exception, match='404 Not Found'): + with pytest.raises(Exception, match='404 not found'): swh_rpc_client.not_on_server() def test_api_endpoint_kwargs(swh_rpc_client): res = swh_rpc_client.something(data='whatever') assert res == 'whatever' res = swh_rpc_client.endpoint(test_data='spam') assert res == 'egg' def test_api_endpoint_args(swh_rpc_client): res = swh_rpc_client.something('whatever') assert res == 'whatever' res = swh_rpc_client.endpoint('spam') assert res == 'egg' def test_api_typeerror(swh_rpc_client): - with pytest.raises(TypeError, match='Did I pass through?'): + with pytest.raises(RemoteException) as exc_info: swh_rpc_client.raise_typeerror() + + assert exc_info.value.args[0]['type'] == 'TypeError' + assert exc_info.value.args[0]['args'] == ['Did I pass through?'] + assert str(exc_info.value) \ + == "" diff --git a/swh/core/db/__init__.py b/swh/core/db/__init__.py index 0090457..19a5867 100644 --- a/swh/core/db/__init__.py +++ b/swh/core/db/__init__.py @@ -1,212 +1,212 @@ # Copyright (C) 2015-2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import binascii import datetime import enum import json import logging import os import sys import threading from contextlib import contextmanager import psycopg2 import psycopg2.extras logger = logging.getLogger(__name__) psycopg2.extras.register_uuid() def escape(data): if data is None: return '' if isinstance(data, bytes): return '\\x%s' % binascii.hexlify(data).decode('ascii') elif isinstance(data, str): return '"%s"' % data.replace('"', '""') elif isinstance(data, datetime.datetime): # We escape twice to make sure the string generated by # isoformat gets escaped return escape(data.isoformat()) elif isinstance(data, dict): return escape(json.dumps(data)) elif isinstance(data, list): return escape("{%s}" % ','.join(escape(d) for d in data)) elif isinstance(data, psycopg2.extras.Range): # We escape twice here too, so that we make sure # everything gets passed to copy properly return escape( '%s%s,%s%s' % ( '[' if data.lower_inc else '(', '-infinity' if data.lower_inf else escape(data.lower), 'infinity' if data.upper_inf else escape(data.upper), ']' if data.upper_inc else ')', ) ) elif isinstance(data, enum.IntEnum): return escape(int(data)) else: # We don't escape here to make sure we pass literals properly return str(data) def typecast_bytea(value, cur): if value is not None: data = psycopg2.BINARY(value, cur) return data.tobytes() class BaseDb: """Base class for swh.*.*Db. cf. swh.storage.db.Db, swh.archiver.db.ArchiverDb """ @classmethod def adapt_conn(cls, conn): """Makes psycopg2 use 'bytes' to decode bytea instead of 'memoryview', for this connection.""" t_bytes = psycopg2.extensions.new_type( (17,), "bytea", typecast_bytea) psycopg2.extensions.register_type(t_bytes, conn) t_bytes_array = psycopg2.extensions.new_array_type( (1001,), "bytea[]", t_bytes) psycopg2.extensions.register_type(t_bytes_array, conn) @classmethod def connect(cls, *args, **kwargs): """factory method to create a DB proxy Accepts all arguments of psycopg2.connect; only some specific possibilities are reported below. Args: connstring: libpq2 connection string """ conn = psycopg2.connect(*args, **kwargs) return cls(conn) @classmethod def from_pool(cls, pool): conn = pool.getconn() return cls(conn, pool=pool) def __init__(self, conn, pool=None): """create a DB proxy Args: conn: psycopg2 connection to the SWH DB pool: psycopg2 pool of connections """ self.adapt_conn(conn) self.conn = conn self.pool = pool def put_conn(self): if self.pool: self.pool.putconn(self.conn) def cursor(self, cur_arg=None): """get a cursor: from cur_arg if given, or a fresh one otherwise meant to avoid boilerplate if/then/else in methods that proxy stored procedures """ if cur_arg is not None: return cur_arg else: return self.conn.cursor() _cursor = cursor # for bw compat @contextmanager def transaction(self): """context manager to execute within a DB transaction Yields: a psycopg2 cursor """ with self.conn.cursor() as cur: try: yield cur self.conn.commit() except Exception: if not self.conn.closed: self.conn.rollback() raise def copy_to(self, items, tblname, columns, cur=None, item_cb=None, default_values={}): """Copy items' entries to table tblname with columns information. Args: items (List[dict]): dictionaries of data to copy over tblname. tblname (str): destination table's name. columns ([str]): keys to access data in items and also the column names in the destination table. default_values (dict): dictionary of default values to use when inserting entried int the tblname table. cur: a db cursor; if not given, a new cursor will be created. item_cb (fn): optional function to apply to items's entry. """ read_file, write_file = os.pipe() exc_info = None def writer(): nonlocal exc_info cursor = self.cursor(cur) with open(read_file, 'r') as f: try: cursor.copy_expert('COPY %s (%s) FROM STDIN CSV' % ( tblname, ', '.join(columns)), f) except Exception: # Tell the main thread about the exception exc_info = sys.exc_info() write_thread = threading.Thread(target=writer) write_thread.start() try: with open(write_file, 'w') as f: for d in items: if item_cb is not None: item_cb(d) line = [] for k in columns: + value = d.get(k, default_values.get(k)) try: - value = d.get(k, default_values.get(k)) line.append(escape(value)) except Exception as e: logger.error( 'Could not escape value `%r` for column `%s`:' 'Received exception: `%s`', value, k, e ) raise e from None f.write(','.join(line)) f.write('\n') finally: # No problem bubbling up exceptions, but we still need to make sure # we finish copying, even though we're probably going to cancel the # transaction. write_thread.join() if exc_info: # postgresql returned an error, let's raise it. raise exc_info[1].with_traceback(exc_info[2]) def mktemp(self, tblname, cur=None): self.cursor(cur).execute('SELECT swh_mktemp(%s)', (tblname,)) diff --git a/version.txt b/version.txt index 081a206..ced1abc 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -v0.0.87-0-g32423fb \ No newline at end of file +v0.0.88-0-g08289ce \ No newline at end of file