diff --git a/PKG-INFO b/PKG-INFO index dfe1bee..0ab367d 100644 --- a/PKG-INFO +++ b/PKG-INFO @@ -1,91 +1,91 @@ Metadata-Version: 2.1 Name: swh.core -Version: 0.0.88 +Version: 0.0.89 Summary: Software Heritage core utilities Home-page: https://forge.softwareheritage.org/diffusion/DCORE/ Author: Software Heritage developers Author-email: swh-devel@inria.fr License: UNKNOWN Project-URL: Bug Reports, https://forge.softwareheritage.org/maniphest Project-URL: Funding, https://www.softwareheritage.org/donate Project-URL: Source, https://forge.softwareheritage.org/source/swh-core Description: swh-core ======== core library for swh's modules: - config parser - hash computations - serialization - logging mechanism - database connection - http-based RPC client/server Development ----------- We strongly recommend you to use a [virtualenv][1] if you want to run tests or hack the code. To set up your development environment: ``` (swh) user@host:~/swh-environment/swh-core$ pip install -e .[testing] ``` This will install every Python package needed to run this package's tests. Unit tests can be executed using [pytest][2] or [tox][3]. ``` (swh) user@host:~/swh-environment/swh-core$ pytest ============================== test session starts ============================== platform linux -- Python 3.7.3, pytest-3.10.1, py-1.8.0, pluggy-0.12.0 hypothesis profile 'default' -> database=DirectoryBasedExampleDatabase('/home/ddouard/src/swh-environment/swh-core/.hypothesis/examples') rootdir: /home/ddouard/src/swh-environment/swh-core, inifile: pytest.ini plugins: requests-mock-1.6.0, hypothesis-4.26.4, celery-4.3.0, postgresql-1.4.1 collected 89 items swh/core/api/tests/test_api.py .. [ 2%] swh/core/api/tests/test_async.py .... [ 6%] swh/core/api/tests/test_serializers.py ..... [ 12%] swh/core/db/tests/test_db.py .... [ 16%] swh/core/tests/test_cli.py ...... [ 23%] swh/core/tests/test_config.py .............. [ 39%] swh/core/tests/test_statsd.py ........................................... [ 87%] .... [ 92%] swh/core/tests/test_utils.py ....... [100%] ===================== 89 passed, 9 warnings in 6.94 seconds ===================== ``` Note: this git repository uses [pre-commit][4] hooks to ensure better and more consistent code. It should already be installed in your virtualenv (if not, just type `pip install pre-commit`). Make sure to activate it in your local copy of the git repository: ``` (swh) user@host:~/swh-environment/swh-core$ pre-commit install pre-commit installed at .git/hooks/pre-commit ``` Please read the [developer setup manual][5] for more information on how to hack on Software Heritage. 
[1]: https://virtualenv.pypa.io [2]: https://docs.pytest.org [3]: https://tox.readthedocs.io [4]: https://pre-commit.com [5]: https://docs.softwareheritage.org/devel/developer-setup.html Platform: UNKNOWN Classifier: Programming Language :: Python :: 3 Classifier: Intended Audience :: Developers Classifier: License :: OSI Approved :: GNU General Public License v3 (GPLv3) Classifier: Operating System :: OS Independent Classifier: Development Status :: 5 - Production/Stable Description-Content-Type: text/markdown Provides-Extra: testing-core Provides-Extra: logging Provides-Extra: db Provides-Extra: testing-db Provides-Extra: http Provides-Extra: testing diff --git a/requirements-http.txt b/requirements-http.txt index c66192b..f307d47 100644 --- a/requirements-http.txt +++ b/requirements-http.txt @@ -1,10 +1,10 @@ # requirements for swh.core.api aiohttp aiohttp_utils >= 3.1.1 arrow decorator Flask +iso8601 msgpack > 0.5 -python-dateutil requests blinker # dependency of sentry-sdk[flask] diff --git a/swh.core.egg-info/PKG-INFO b/swh.core.egg-info/PKG-INFO index dfe1bee..0ab367d 100644 --- a/swh.core.egg-info/PKG-INFO +++ b/swh.core.egg-info/PKG-INFO @@ -1,91 +1,91 @@ Metadata-Version: 2.1 Name: swh.core -Version: 0.0.88 +Version: 0.0.89 Summary: Software Heritage core utilities Home-page: https://forge.softwareheritage.org/diffusion/DCORE/ Author: Software Heritage developers Author-email: swh-devel@inria.fr License: UNKNOWN Project-URL: Bug Reports, https://forge.softwareheritage.org/maniphest Project-URL: Funding, https://www.softwareheritage.org/donate Project-URL: Source, https://forge.softwareheritage.org/source/swh-core Description: swh-core ======== core library for swh's modules: - config parser - hash computations - serialization - logging mechanism - database connection - http-based RPC client/server Development ----------- We strongly recommend you to use a [virtualenv][1] if you want to run tests or hack the code. To set up your development environment: ``` (swh) user@host:~/swh-environment/swh-core$ pip install -e .[testing] ``` This will install every Python package needed to run this package's tests. Unit tests can be executed using [pytest][2] or [tox][3]. ``` (swh) user@host:~/swh-environment/swh-core$ pytest ============================== test session starts ============================== platform linux -- Python 3.7.3, pytest-3.10.1, py-1.8.0, pluggy-0.12.0 hypothesis profile 'default' -> database=DirectoryBasedExampleDatabase('/home/ddouard/src/swh-environment/swh-core/.hypothesis/examples') rootdir: /home/ddouard/src/swh-environment/swh-core, inifile: pytest.ini plugins: requests-mock-1.6.0, hypothesis-4.26.4, celery-4.3.0, postgresql-1.4.1 collected 89 items swh/core/api/tests/test_api.py .. [ 2%] swh/core/api/tests/test_async.py .... [ 6%] swh/core/api/tests/test_serializers.py ..... [ 12%] swh/core/db/tests/test_db.py .... [ 16%] swh/core/tests/test_cli.py ...... [ 23%] swh/core/tests/test_config.py .............. [ 39%] swh/core/tests/test_statsd.py ........................................... [ 87%] .... [ 92%] swh/core/tests/test_utils.py ....... [100%] ===================== 89 passed, 9 warnings in 6.94 seconds ===================== ``` Note: this git repository uses [pre-commit][4] hooks to ensure better and more consistent code. It should already be installed in your virtualenv (if not, just type `pip install pre-commit`). 
Make sure to activate it in your local copy of the git repository: ``` (swh) user@host:~/swh-environment/swh-core$ pre-commit install pre-commit installed at .git/hooks/pre-commit ``` Please read the [developer setup manual][5] for more information on how to hack on Software Heritage. [1]: https://virtualenv.pypa.io [2]: https://docs.pytest.org [3]: https://tox.readthedocs.io [4]: https://pre-commit.com [5]: https://docs.softwareheritage.org/devel/developer-setup.html Platform: UNKNOWN Classifier: Programming Language :: Python :: 3 Classifier: Intended Audience :: Developers Classifier: License :: OSI Approved :: GNU General Public License v3 (GPLv3) Classifier: Operating System :: OS Independent Classifier: Development Status :: 5 - Production/Stable Description-Content-Type: text/markdown Provides-Extra: testing-core Provides-Extra: logging Provides-Extra: db Provides-Extra: testing-db Provides-Extra: http Provides-Extra: testing diff --git a/swh.core.egg-info/requires.txt b/swh.core.egg-info/requires.txt index f41c6f7..e25eb0c 100644 --- a/swh.core.egg-info/requires.txt +++ b/swh.core.egg-info/requires.txt @@ -1,52 +1,52 @@ Click Deprecated PyYAML sentry-sdk [db] psycopg2 [http] aiohttp aiohttp_utils>=3.1.1 arrow decorator Flask +iso8601 msgpack>0.5 -python-dateutil requests blinker [logging] systemd-python [testing] pytest pytest-mock requests-mock hypothesis>=3.11.0 pre-commit pytz pytest-postgresql psycopg2 aiohttp aiohttp_utils>=3.1.1 arrow decorator Flask +iso8601 msgpack>0.5 -python-dateutil requests blinker systemd-python [testing-core] pytest pytest-mock requests-mock hypothesis>=3.11.0 pre-commit pytz [testing-db] pytest-postgresql diff --git a/swh/core/api/__init__.py b/swh/core/api/__init__.py index ba08caa..dc3bdea 100644 --- a/swh/core/api/__init__.py +++ b/swh/core/api/__init__.py @@ -1,413 +1,429 @@ # Copyright (C) 2015-2017 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from collections import abc -import datetime import functools import inspect -import json import logging import pickle import requests import traceback -from typing import Any, ClassVar, List, Optional, Type +from typing import ( + Any, Callable, ClassVar, Dict, List, Optional, Tuple, Type, Union, +) from flask import Flask, Request, Response, request, abort from werkzeug.exceptions import HTTPException from .serializers import (decode_response, encode_data_client as encode_data, - msgpack_dumps, msgpack_loads, SWHJSONDecoder) + msgpack_dumps, msgpack_loads, + json_dumps, json_loads) from .negotiation import (Formatter as FormatterBase, Negotiator as NegotiatorBase, negotiate as _negotiate) logger = logging.getLogger(__name__) # support for content negotiation class Negotiator(NegotiatorBase): def best_mimetype(self): return request.accept_mimetypes.best_match( self.accept_mimetypes, 'application/json') def _abort(self, status_code, err=None): return abort(status_code, err) def negotiate(formatter_cls, *args, **kwargs): return _negotiate(Negotiator, formatter_cls, *args, **kwargs) class Formatter(FormatterBase): def _make_response(self, body, content_type): return Response(body, content_type=content_type) - -class SWHJSONEncoder(json.JSONEncoder): - def default(self, obj): - if isinstance(obj, (datetime.datetime, datetime.date)): - return obj.isoformat() - if isinstance(obj, datetime.timedelta): - return str(obj) - # Let the base 
class default method raise the TypeError - return super().default(obj) + def configure(self, extra_encoders): + self.extra_encoders = extra_encoders class JSONFormatter(Formatter): format = 'json' mimetypes = ['application/json'] def render(self, obj): - return json.dumps(obj, cls=SWHJSONEncoder) + return json_dumps(obj, extra_encoders=self.extra_encoders) class MsgpackFormatter(Formatter): format = 'msgpack' mimetypes = ['application/x-msgpack'] def render(self, obj): - return msgpack_dumps(obj) + return msgpack_dumps(obj, extra_encoders=self.extra_encoders) # base API classes class RemoteException(Exception): """raised when remote returned an out-of-band failure notification, e.g., as a HTTP status code or serialized exception Attributes: response: HTTP response corresponding to the failure """ def __init__(self, payload: Optional[Any] = None, response: Optional[requests.Response] = None): if payload is not None: super().__init__(payload) else: super().__init__() self.response = response def __str__(self): if self.args and isinstance(self.args[0], dict) \ and 'type' in self.args[0] and 'args' in self.args[0]: return ( f'') else: return super().__str__() def remote_api_endpoint(path): def dec(f): f._endpoint_path = path return f return dec class APIError(Exception): """API Error""" def __str__(self): return ('An unexpected error occurred in the backend: {}' .format(self.args)) class MetaRPCClient(type): """Metaclass for RPCClient, which adds a method for each endpoint of the database it is designed to access. See for example :class:`swh.indexer.storage.api.client.RemoteStorage`""" def __new__(cls, name, bases, attributes): # For each method wrapped with @remote_api_endpoint in an API backend # (eg. :class:`swh.indexer.storage.IndexerStorage`), add a new # method in RemoteStorage, with the same documentation. # # Note that, despite the usage of decorator magic (eg. functools.wrap), # this never actually calls an IndexerStorage method. backend_class = attributes.get('backend_class', None) for base in bases: if backend_class is not None: break backend_class = getattr(base, 'backend_class', None) if backend_class: for (meth_name, meth) in backend_class.__dict__.items(): if hasattr(meth, '_endpoint_path'): cls.__add_endpoint(meth_name, meth, attributes) return super().__new__(cls, name, bases, attributes) @staticmethod def __add_endpoint(meth_name, meth, attributes): wrapped_meth = inspect.unwrap(meth) @functools.wraps(meth) # Copy signature and doc def meth_(*args, **kwargs): # Match arguments and parameters post_data = inspect.getcallargs( wrapped_meth, *args, **kwargs) # Remove arguments that should not be passed self = post_data.pop('self') post_data.pop('cur', None) post_data.pop('db', None) # Send the request. return self.post(meth._endpoint_path, post_data) attributes[meth_name] = meth_ class RPCClient(metaclass=MetaRPCClient): """Proxy to an internal SWH RPC """ backend_class = None # type: ClassVar[Optional[type]] """For each method of `backend_class` decorated with :func:`remote_api_endpoint`, a method with the same prototype and docstring will be added to this class. Calls to this new method will be translated into HTTP requests to a remote server. 
This backend class will never be instantiated, it only serves as a template.""" api_exception = APIError # type: ClassVar[Type[Exception]] """The exception class to raise in case of communication error with the server.""" reraise_exceptions: ClassVar[List[Type[Exception]]] = [] """On server errors, if any of the exception classes in this list has the same name as the error name, then the exception will be instantiated and raised instead of a generic RemoteException.""" + extra_type_encoders: List[Tuple[type, str, Callable]] = [] + """Value of `extra_encoders` passed to `json_dumps` or `msgpack_dumps` + to be able to serialize more object types.""" + extra_type_decoders: Dict[str, Callable] = {} + """Value of `extra_decoders` passed to `json_loads` or `msgpack_loads` + to be able to deserialize more object types.""" + def __init__(self, url, api_exception=None, timeout=None, chunk_size=4096, reraise_exceptions=None, **kwargs): if api_exception: self.api_exception = api_exception if reraise_exceptions: self.reraise_exceptions = reraise_exceptions base_url = url if url.endswith('/') else url + '/' self.url = base_url self.session = requests.Session() adapter = requests.adapters.HTTPAdapter( max_retries=kwargs.get('max_retries', 3), pool_connections=kwargs.get('pool_connections', 20), pool_maxsize=kwargs.get('pool_maxsize', 100)) self.session.mount(self.url, adapter) self.timeout = timeout self.chunk_size = chunk_size def _url(self, endpoint): return '%s%s' % (self.url, endpoint) def raw_verb(self, verb, endpoint, **opts): if 'chunk_size' in opts: # if the chunk_size argument has been passed, consider the user # also wants stream=True, otherwise, what's the point. opts['stream'] = True if self.timeout and 'timeout' not in opts: opts['timeout'] = self.timeout try: return getattr(self.session, verb)( self._url(endpoint), **opts ) except requests.exceptions.ConnectionError as e: raise self.api_exception(e) def post(self, endpoint, data, **opts): if isinstance(data, (abc.Iterator, abc.Generator)): - data = (encode_data(x) for x in data) + data = (self._encode_data(x) for x in data) else: - data = encode_data(data) + data = self._encode_data(data) chunk_size = opts.pop('chunk_size', self.chunk_size) response = self.raw_verb( 'post', endpoint, data=data, headers={'content-type': 'application/x-msgpack', 'accept': 'application/x-msgpack'}, **opts) if opts.get('stream') or \ response.headers.get('transfer-encoding') == 'chunked': self.raise_for_status(response) return response.iter_content(chunk_size) else: return self._decode_response(response) + def _encode_data(self, data): + return encode_data(data, extra_encoders=self.extra_type_encoders) + post_stream = post def get(self, endpoint, **opts): chunk_size = opts.pop('chunk_size', self.chunk_size) response = self.raw_verb( 'get', endpoint, headers={'accept': 'application/x-msgpack'}, **opts) if opts.get('stream') or \ response.headers.get('transfer-encoding') == 'chunked': self.raise_for_status(response) return response.iter_content(chunk_size) else: return self._decode_response(response) def get_stream(self, endpoint, **opts): return self.get(endpoint, stream=True, **opts) def raise_for_status(self, response) -> None: """check response HTTP status code and raise an exception if it denotes an error; do nothing otherwise """ status_code = response.status_code status_class = response.status_code // 100 if status_code == 404: raise RemoteException(payload='404 not found', response=response) exception = None # TODO: only old servers send pickled error; 
stop trying to unpickle # after they are all upgraded try: if status_class == 4: - data = decode_response(response) + data = self._decode_response(response, check_status=False) if isinstance(data, dict): for exc_type in self.reraise_exceptions: if exc_type.__name__ == data['exception']['type']: exception = exc_type(*data['exception']['args']) break else: exception = RemoteException(payload=data['exception'], response=response) else: exception = pickle.loads(data) elif status_class == 5: - data = decode_response(response) + data = self._decode_response(response, check_status=False) if 'exception_pickled' in data: exception = pickle.loads(data['exception_pickled']) else: exception = RemoteException(payload=data['exception'], response=response) except (TypeError, pickle.UnpicklingError): raise RemoteException(payload=data, response=response) if exception: raise exception from None if status_class != 2: raise RemoteException( payload=f'API HTTP error: {status_code} {response.content}', response=response) - def _decode_response(self, response): - self.raise_for_status(response) - return decode_response(response) + def _decode_response(self, response, check_status=True): + if check_status: + self.raise_for_status(response) + return decode_response( + response, extra_decoders=self.extra_type_decoders) def __repr__(self): return '<{} url={}>'.format(self.__class__.__name__, self.url) class BytesRequest(Request): """Request with proper escaping of arbitrary byte sequences.""" encoding = 'utf-8' encoding_errors = 'surrogateescape' -ENCODERS = { +ENCODERS: Dict[str, Callable[[Any], Union[bytes, str]]] = { 'application/x-msgpack': msgpack_dumps, - 'application/json': json.dumps, + 'application/json': json_dumps, } -def encode_data_server(data, content_type='application/x-msgpack'): - encoded_data = ENCODERS[content_type](data) +def encode_data_server( + data, content_type='application/x-msgpack', extra_type_encoders=None): + encoded_data = ENCODERS[content_type]( + data, extra_encoders=extra_type_encoders) return Response( encoded_data, mimetype=content_type, ) -def decode_request(request): +def decode_request(request, extra_decoders=None): content_type = request.mimetype data = request.get_data() if not data: return {} if content_type == 'application/x-msgpack': - r = msgpack_loads(data) + r = msgpack_loads(data, extra_decoders=extra_decoders) elif content_type == 'application/json': # XXX this .decode() is needed for py35. # Should not be needed any more with py37 - r = json.loads(data.decode('utf-8'), cls=SWHJSONDecoder) + r = json_loads(data.decode('utf-8'), extra_decoders=extra_decoders) else: raise ValueError('Wrong content type `%s` for API request' % content_type) return r def error_handler(exception, encoder, status_code=500): logging.exception(exception) tb = traceback.format_exception(None, exception, exception.__traceback__) error = { 'exception': { 'type': type(exception).__name__, 'args': exception.args, 'message': str(exception), 'traceback': tb, } } response = encoder(error) if isinstance(exception, HTTPException): response.status_code = exception.code else: # TODO: differentiate between server errors and client errors response.status_code = status_code return response class RPCServerApp(Flask): """For each endpoint of the given `backend_class`, tells app.route to call a function that decodes the request and sends it to the backend object provided by the factory. :param Any backend_class: The class of the backend, which will be analyzed to look for API endpoints. 
:param Optional[Callable[[], backend_class]] backend_factory: A function with no argument that returns an instance of `backend_class`. If unset, defaults to calling `backend_class` constructor directly. """ request_class = BytesRequest + extra_type_encoders: List[Tuple[type, str, Callable]] = [] + """Value of `extra_encoders` passed to `json_dumps` or `msgpack_dumps` + to be able to serialize more object types.""" + extra_type_decoders: Dict[str, Callable] = {} + """Value of `extra_decoders` passed to `json_loads` or `msgpack_loads` + to be able to deserialize more object types.""" + def __init__(self, *args, backend_class=None, backend_factory=None, **kwargs): super().__init__(*args, **kwargs) self.backend_class = backend_class if backend_class is not None: if backend_factory is None: backend_factory = backend_class for (meth_name, meth) in backend_class.__dict__.items(): if hasattr(meth, '_endpoint_path'): self.__add_endpoint(meth_name, meth, backend_factory) def __add_endpoint(self, meth_name, meth, backend_factory): from flask import request @self.route('/'+meth._endpoint_path, methods=['POST']) - @negotiate(MsgpackFormatter) - @negotiate(JSONFormatter) + @negotiate(MsgpackFormatter, extra_encoders=self.extra_type_encoders) + @negotiate(JSONFormatter, extra_encoders=self.extra_type_encoders) @functools.wraps(meth) # Copy signature and doc def _f(): # Call the actual code obj_meth = getattr(backend_factory(), meth_name) - kw = decode_request(request) + kw = decode_request( + request, extra_decoders=self.extra_type_decoders) return obj_meth(**kw) diff --git a/swh/core/api/serializers.py b/swh/core/api/serializers.py index f9aca96..57f37ae 100644 --- a/swh/core/api/serializers.py +++ b/swh/core/api/serializers.py @@ -1,198 +1,214 @@ # Copyright (C) 2015-2018 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import base64 import datetime import json import types from uuid import UUID import arrow -import dateutil.parser +import iso8601 import msgpack from typing import Any, Dict, Union, Tuple from requests import Response -def encode_data_client(data: Any) -> bytes: +ENCODERS = [ + (arrow.Arrow, 'arrow', arrow.Arrow.isoformat), + (datetime.datetime, 'datetime', datetime.datetime.isoformat), + (datetime.timedelta, 'timedelta', lambda o: { + 'days': o.days, + 'seconds': o.seconds, + 'microseconds': o.microseconds, + }), + (UUID, 'uuid', str), + + # Only for JSON: + (bytes, 'bytes', lambda o: base64.b85encode(o).decode('ascii')), +] + +DECODERS = { + 'arrow': arrow.get, + 'datetime': lambda d: iso8601.parse_date(d, default_timezone=None), + 'timedelta': lambda d: datetime.timedelta(**d), + 'uuid': UUID, + + # Only for JSON: + 'bytes': base64.b85decode, +} + + +def encode_data_client(data: Any, extra_encoders=None) -> bytes: try: - return msgpack_dumps(data) + return msgpack_dumps(data, extra_encoders=extra_encoders) except OverflowError as e: raise ValueError('Limits were reached. 
Please, check your input.\n' + str(e)) -def decode_response(response: Response) -> Any: +def decode_response(response: Response, extra_decoders=None) -> Any: content_type = response.headers['content-type'] if content_type.startswith('application/x-msgpack'): - r = msgpack_loads(response.content) + r = msgpack_loads(response.content, + extra_decoders=extra_decoders) elif content_type.startswith('application/json'): - r = json.loads(response.text, cls=SWHJSONDecoder) + r = json_loads(response.text, + extra_decoders=extra_decoders) elif content_type.startswith('text/'): r = response.text else: raise ValueError('Wrong content type `%s` for API response' % content_type) return r class SWHJSONEncoder(json.JSONEncoder): """JSON encoder for data structures generated by Software Heritage. This JSON encoder extends the default Python JSON encoder and adds awareness for the following specific types: - bytes (get encoded as a Base85 string); - datetime.datetime (get encoded as an ISO8601 string). Non-standard types get encoded as a a dictionary with two keys: - swhtype with value 'bytes' or 'datetime'; - d containing the encoded value. SWHJSONEncoder also encodes arbitrary iterables as a list (allowing serialization of generators). Caveats: Limitations in the JSONEncoder extension mechanism prevent us from "escaping" dictionaries that only contain the swhtype and d keys, and therefore arbitrary data structures can't be round-tripped through SWHJSONEncoder and SWHJSONDecoder. """ + def __init__(self, extra_encoders=None, **kwargs): + super().__init__(**kwargs) + self.encoders = ENCODERS + if extra_encoders: + self.encoders += extra_encoders + def default(self, o: Any ) -> Union[Dict[str, Union[Dict[str, int], str]], list]: - if isinstance(o, bytes): - return { - 'swhtype': 'bytes', - 'd': base64.b85encode(o).decode('ascii'), - } - elif isinstance(o, datetime.datetime): - return { - 'swhtype': 'datetime', - 'd': o.isoformat(), - } - elif isinstance(o, UUID): - return { - 'swhtype': 'uuid', - 'd': str(o), - } - elif isinstance(o, datetime.timedelta): - return { - 'swhtype': 'timedelta', - 'd': { - 'days': o.days, - 'seconds': o.seconds, - 'microseconds': o.microseconds, - }, - } - elif isinstance(o, arrow.Arrow): - return { - 'swhtype': 'arrow', - 'd': o.isoformat(), - } + for (type_, type_name, encoder) in self.encoders: + if isinstance(o, type_): + return { + 'swhtype': type_name, + 'd': encoder(o), + } try: return super().default(o) except TypeError as e: try: iterable = iter(o) except TypeError: raise e from None else: return list(iterable) class SWHJSONDecoder(json.JSONDecoder): """JSON decoder for data structures encoded with SWHJSONEncoder. This JSON decoder extends the default Python JSON decoder, allowing the decoding of: - bytes (encoded as a Base85 string); - datetime.datetime (encoded as an ISO8601 string). Non-standard types must be encoded as a a dictionary with exactly two keys: - swhtype with value 'bytes' or 'datetime'; - d containing the encoded value. To limit the impact our encoding, if the swhtype key doesn't contain a known value, the dictionary is decoded as-is. 
""" + def __init__(self, extra_decoders=None, **kwargs): + super().__init__(**kwargs) + self.decoders = DECODERS + if extra_decoders: + self.decoders = {**self.decoders, **extra_decoders} + def decode_data(self, o: Any) -> Any: if isinstance(o, dict): if set(o.keys()) == {'d', 'swhtype'}: - datatype = o['swhtype'] - if datatype == 'bytes': + if o['swhtype'] == 'bytes': return base64.b85decode(o['d']) - elif datatype == 'datetime': - return dateutil.parser.parse(o['d']) - elif datatype == 'uuid': - return UUID(o['d']) - elif datatype == 'timedelta': - return datetime.timedelta(**o['d']) - elif datatype == 'arrow': - return arrow.get(o['d']) + decoder = self.decoders.get(o['swhtype']) + if decoder: + return decoder(self.decode_data(o['d'])) return {key: self.decode_data(value) for key, value in o.items()} if isinstance(o, list): return [self.decode_data(value) for value in o] else: return o def raw_decode(self, s: str, idx: int = 0) -> Tuple[Any, int]: data, index = super().raw_decode(s, idx) return self.decode_data(data), index -def msgpack_dumps(data: Any) -> bytes: +def json_dumps(data: Any, extra_encoders=None) -> str: + return json.dumps(data, cls=SWHJSONEncoder, + extra_encoders=extra_encoders) + + +def json_loads(data: str, extra_decoders=None) -> Any: + return json.loads(data, cls=SWHJSONDecoder, + extra_decoders=extra_decoders) + + +def msgpack_dumps(data: Any, extra_encoders=None) -> bytes: """Write data as a msgpack stream""" + encoders = ENCODERS + if extra_encoders: + encoders += extra_encoders + def encode_types(obj): - if isinstance(obj, datetime.datetime): - return {b'__datetime__': True, b's': obj.isoformat()} if isinstance(obj, types.GeneratorType): return list(obj) - if isinstance(obj, UUID): - return {b'__uuid__': True, b's': str(obj)} - if isinstance(obj, datetime.timedelta): - return { - b'__timedelta__': True, - b's': { - 'days': obj.days, - 'seconds': obj.seconds, - 'microseconds': obj.microseconds, - }, - } - if isinstance(obj, arrow.Arrow): - return {b'__arrow__': True, b's': obj.isoformat()} + + for (type_, type_name, encoder) in encoders: + if isinstance(obj, type_): + return { + b'swhtype': type_name, + b'd': encoder(obj), + } return obj return msgpack.packb(data, use_bin_type=True, default=encode_types) -def msgpack_loads(data: bytes) -> Any: +def msgpack_loads(data: bytes, extra_decoders=None) -> Any: """Read data as a msgpack stream""" + decoders = DECODERS + if extra_decoders: + decoders = {**decoders, **extra_decoders} + def decode_types(obj): - if b'__datetime__' in obj and obj[b'__datetime__']: - return dateutil.parser.parse(obj[b's']) - if b'__uuid__' in obj and obj[b'__uuid__']: - return UUID(obj[b's']) - if b'__timedelta__' in obj and obj[b'__timedelta__']: - return datetime.timedelta(**obj[b's']) - if b'__arrow__' in obj and obj[b'__arrow__']: - return arrow.get(obj[b's']) + if set(obj.keys()) == {b'd', b'swhtype'}: + decoder = decoders.get(obj[b'swhtype']) + if decoder: + return decoder(obj[b'd']) return obj try: return msgpack.unpackb(data, raw=False, object_hook=decode_types) except TypeError: # msgpack < 0.5.2 return msgpack.unpackb(data, encoding='utf-8', object_hook=decode_types) diff --git a/swh/core/api/tests/test_rpc_client.py b/swh/core/api/tests/test_rpc_client.py index 307a5e7..58450a9 100644 --- a/swh/core/api/tests/test_rpc_client.py +++ b/swh/core/api/tests/test_rpc_client.py @@ -1,56 +1,73 @@ # Copyright (C) 2018-2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU 
General Public License version 3, or any later version # See top-level LICENSE file for more information import re import pytest from swh.core.api import remote_api_endpoint, RPCClient +from .test_serializers import ExtraType, extra_encoders, extra_decoders + @pytest.fixture def rpc_client(requests_mock): class TestStorage: @remote_api_endpoint('test_endpoint_url') def test_endpoint(self, test_data, db=None, cur=None): - return 'egg' + ... @remote_api_endpoint('path/to/endpoint') def something(self, data, db=None, cur=None): - return 'spam' + ... + + @remote_api_endpoint('serializer_test') + def serializer_test(self, data, db=None, cur=None): + ... class Testclient(RPCClient): backend_class = TestStorage + extra_type_encoders = extra_encoders + extra_type_decoders = extra_decoders def callback(request, context): assert request.headers['Content-Type'] == 'application/x-msgpack' context.headers['Content-Type'] = 'application/x-msgpack' if request.path == '/test_endpoint_url': context.content = b'\xa3egg' elif request.path == '/path/to/endpoint': context.content = b'\xa4spam' + elif request.path == '/serializer_test': + context.content = ( + b'\x82\xc4\x07swhtype\xa9extratype' + b'\xc4\x01d\x92\x81\xa4spam\xa3egg\xa3qux') else: assert False return context.content requests_mock.post(re.compile('mock://example.com/'), content=callback) return Testclient(url='mock://example.com') def test_client(rpc_client): assert hasattr(rpc_client, 'test_endpoint') assert hasattr(rpc_client, 'something') res = rpc_client.test_endpoint('spam') assert res == 'egg' res = rpc_client.test_endpoint(test_data='spam') assert res == 'egg' res = rpc_client.something('whatever') assert res == 'spam' res = rpc_client.something(data='whatever') assert res == 'spam' + + +def test_client_extra_serializers(rpc_client): + res = rpc_client.serializer_test(['foo', ExtraType('bar', b'baz')]) + assert res == ExtraType({'spam': 'egg'}, 'qux') diff --git a/swh/core/api/tests/test_rpc_server.py b/swh/core/api/tests/test_rpc_server.py index 9399f62..81beb12 100644 --- a/swh/core/api/tests/test_rpc_server.py +++ b/swh/core/api/tests/test_rpc_server.py @@ -1,73 +1,100 @@ # Copyright (C) 2018-2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import pytest import json import msgpack from flask import url_for from swh.core.api import remote_api_endpoint, RPCServerApp +from .test_serializers import ExtraType, extra_encoders, extra_decoders + + +class MyRPCServerApp(RPCServerApp): + extra_type_encoders = extra_encoders + extra_type_decoders = extra_decoders + @pytest.fixture def app(): class TestStorage: @remote_api_endpoint('test_endpoint_url') def test_endpoint(self, test_data, db=None, cur=None): assert test_data == 'spam' return 'egg' @remote_api_endpoint('path/to/endpoint') def something(self, data, db=None, cur=None): return data - return RPCServerApp('testapp', backend_class=TestStorage) + @remote_api_endpoint('serializer_test') + def serializer_test(self, data, db=None, cur=None): + assert data == ['foo', ExtraType('bar', b'baz')] + return ExtraType({'spam': 'egg'}, 'qux') + + return MyRPCServerApp('testapp', backend_class=TestStorage) def test_api_endpoint(flask_app_client): res = flask_app_client.post( url_for('something'), headers=[('Content-Type', 'application/json'), ('Accept', 'application/json')], data=json.dumps({'data': 'toto'}), ) assert 
res.status_code == 200 assert res.mimetype == 'application/json' def test_api_nego_default(flask_app_client): res = flask_app_client.post( url_for('something'), headers=[('Content-Type', 'application/json')], data=json.dumps({'data': 'toto'}), ) assert res.status_code == 200 assert res.mimetype == 'application/json' assert res.data == b'"toto"' def test_api_nego_accept(flask_app_client): res = flask_app_client.post( url_for('something'), headers=[('Accept', 'application/x-msgpack'), ('Content-Type', 'application/x-msgpack')], data=msgpack.dumps({'data': 'toto'}), ) assert res.status_code == 200 assert res.mimetype == 'application/x-msgpack' assert res.data == b'\xa4toto' def test_rpc_server(flask_app_client): res = flask_app_client.post( url_for('test_endpoint'), headers=[('Content-Type', 'application/x-msgpack'), ('Accept', 'application/x-msgpack')], data=b'\x81\xa9test_data\xa4spam') assert res.status_code == 200 assert res.mimetype == 'application/x-msgpack' assert res.data == b'\xa3egg' + + +def test_rpc_server_extra_serializers(flask_app_client): + res = flask_app_client.post( + url_for('serializer_test'), + headers=[('Content-Type', 'application/x-msgpack'), + ('Accept', 'application/x-msgpack')], + data=b'\x81\xa4data\x92\xa3foo\x82\xc4\x07swhtype\xa9extratype' + b'\xc4\x01d\x92\xa3bar\xc4\x03baz') + + assert res.status_code == 200 + assert res.mimetype == 'application/x-msgpack' + assert res.data == ( + b'\x82\xc4\x07swhtype\xa9extratype\xc4' + b'\x01d\x92\x81\xa4spam\xa3egg\xa3qux') diff --git a/swh/core/api/tests/test_serializers.py b/swh/core/api/tests/test_serializers.py index 64f13e2..3f1a7aa 100644 --- a/swh/core/api/tests/test_serializers.py +++ b/swh/core/api/tests/test_serializers.py @@ -1,92 +1,133 @@ # Copyright (C) 2015-2018 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import datetime import json +from typing import Any, Callable, List, Tuple import unittest from uuid import UUID import arrow import requests import requests_mock from swh.core.api.serializers import ( SWHJSONDecoder, SWHJSONEncoder, msgpack_dumps, msgpack_loads, decode_response ) +class ExtraType: + def __init__(self, arg1, arg2): + self.arg1 = arg1 + self.arg2 = arg2 + + def __repr__(self): + return f'ExtraType({self.arg1}, {self.arg2})' + + def __eq__(self, other): + return isinstance(other, ExtraType) \ + and (self.arg1, self.arg2) == (other.arg1, other.arg2) + + +extra_encoders: List[Tuple[type, str, Callable[..., Any]]] = [ + (ExtraType, 'extratype', lambda o: (o.arg1, o.arg2)) +] + + +extra_decoders = { + 'extratype': lambda o: ExtraType(*o), +} + + class Serializers(unittest.TestCase): def setUp(self): self.tz = datetime.timezone(datetime.timedelta(minutes=118)) self.data = { 'bytes': b'123456789\x99\xaf\xff\x00\x12', 'datetime_naive': datetime.datetime(2015, 1, 1, 12, 4, 42, 231455), 'datetime_tz': datetime.datetime(2015, 3, 4, 18, 25, 13, 1234, tzinfo=self.tz), 'datetime_utc': datetime.datetime(2015, 3, 4, 18, 25, 13, 1234, tzinfo=datetime.timezone.utc), 'datetime_delta': datetime.timedelta(64), 'arrow_date': arrow.get('2018-04-25T16:17:53.533672+00:00'), 'swhtype': 'fake', 'swh_dict': {'swhtype': 42, 'd': 'test'}, 'random_dict': {'swhtype': 43}, 'uuid': UUID('cdd8f804-9db6-40c3-93ab-5955d3836234'), } self.encoded_data = { 'bytes': {'swhtype': 'bytes', 'd': 'F)}kWH8wXmIhn8j01^'}, 'datetime_naive': {'swhtype': 
'datetime', 'd': '2015-01-01T12:04:42.231455'}, 'datetime_tz': {'swhtype': 'datetime', 'd': '2015-03-04T18:25:13.001234+01:58'}, 'datetime_utc': {'swhtype': 'datetime', 'd': '2015-03-04T18:25:13.001234+00:00'}, 'datetime_delta': {'swhtype': 'timedelta', 'd': {'days': 64, 'seconds': 0, 'microseconds': 0}}, 'arrow_date': {'swhtype': 'arrow', 'd': '2018-04-25T16:17:53.533672+00:00'}, 'swhtype': 'fake', 'swh_dict': {'swhtype': 42, 'd': 'test'}, 'random_dict': {'swhtype': 43}, 'uuid': {'swhtype': 'uuid', 'd': 'cdd8f804-9db6-40c3-93ab-5955d3836234'}, } self.generator = (i for i in range(5)) self.gen_lst = list(range(5)) def test_round_trip_json(self): data = json.dumps(self.data, cls=SWHJSONEncoder) self.assertEqual(self.data, json.loads(data, cls=SWHJSONDecoder)) + def test_round_trip_json_extra_types(self): + original_data = [ExtraType('baz', self.data), 'qux'] + + data = json.dumps(original_data, cls=SWHJSONEncoder, + extra_encoders=extra_encoders) + self.assertEqual( + original_data, + json.loads( + data, cls=SWHJSONDecoder, extra_decoders=extra_decoders)) + def test_encode_swh_json(self): data = json.dumps(self.data, cls=SWHJSONEncoder) self.assertEqual(self.encoded_data, json.loads(data)) def test_round_trip_msgpack(self): data = msgpack_dumps(self.data) self.assertEqual(self.data, msgpack_loads(data)) + def test_round_trip_msgpack_extra_types(self): + original_data = [ExtraType('baz', self.data), 'qux'] + + data = msgpack_dumps(original_data, extra_encoders=extra_encoders) + self.assertEqual( + original_data, msgpack_loads(data, extra_decoders=extra_decoders)) + def test_generator_json(self): data = json.dumps(self.generator, cls=SWHJSONEncoder) self.assertEqual(self.gen_lst, json.loads(data, cls=SWHJSONDecoder)) def test_generator_msgpack(self): data = msgpack_dumps(self.generator) self.assertEqual(self.gen_lst, msgpack_loads(data)) @requests_mock.Mocker() def test_decode_response_json(self, mock_requests): mock_requests.get('https://example.org/test/data', json=self.encoded_data, headers={'content-type': 'application/json'}) response = requests.get('https://example.org/test/data') assert decode_response(response) == self.data diff --git a/version.txt b/version.txt index ced1abc..43ba54e 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -v0.0.88-0-g08289ce \ No newline at end of file +v0.0.89-0-g9667700 \ No newline at end of file
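
The serializer changes above make the msgpack codec use the same two-key envelope that the JSON codec already used: every non-standard value is written as a mapping with `swhtype` naming the codec and `d` carrying the encoded payload, replacing the per-type magic keys such as `b'__datetime__'` and `b'__uuid__'`. A small self-contained check of that shape, using the `timedelta` value from the updated test data:

```
import datetime
import json

from swh.core.api.serializers import json_dumps

# A timedelta is wrapped in the {'swhtype': ..., 'd': ...} envelope, matching
# the 'datetime_delta' entry of the encoded test data in test_serializers.py.
encoded = json.loads(json_dumps(datetime.timedelta(days=64)))
assert encoded == {'swhtype': 'timedelta',
                   'd': {'days': 64, 'seconds': 0, 'microseconds': 0}}
```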
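
Timestamp decoding also moves from python-dateutil to iso8601 (see the requirements changes above); the `datetime` decoder calls `iso8601.parse_date(d, default_timezone=None)`, so naive datetimes round-trip as naive, as exercised by the `datetime_naive` entry in the test data. A minimal round-trip sketch of that behaviour:

```
import datetime

from swh.core.api.serializers import json_dumps, json_loads

naive = datetime.datetime(2015, 1, 1, 12, 4, 42, 231455)
decoded = json_loads(json_dumps(naive))

# No timezone is attached on the way back: the value stays naive and equal.
assert decoded == naive
assert decoded.tzinfo is None
```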
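
The main functional change is the `extra_encoders` / `extra_decoders` hooks threaded through `json_dumps`, `json_loads`, `msgpack_dumps` and `msgpack_loads`, which let callers register codecs for additional types. A minimal sketch modeled on the new `test_serializers.py` fixtures; `Point` and its codec are hypothetical examples, not part of swh.core:

```
from swh.core.api.serializers import (
    json_dumps, json_loads, msgpack_dumps, msgpack_loads,
)


class Point:
    """Toy type standing in for a real domain object."""

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __eq__(self, other):
        return isinstance(other, Point) and (self.x, self.y) == (other.x, other.y)

    def __repr__(self):
        return f'Point({self.x}, {self.y})'


# (type, tag, encode): the tag becomes the 'swhtype' marker in the envelope.
extra_encoders = [(Point, 'point', lambda p: [p.x, p.y])]
# tag -> decode: rebuilds the object from the encoded payload.
extra_decoders = {'point': lambda d: Point(*d)}

data = {'origin': Point(0, 0), 'path': [Point(1, 2), Point(3, 4)]}

# Both formats round-trip the custom type once the codecs are passed in.
assert json_loads(json_dumps(data, extra_encoders=extra_encoders),
                  extra_decoders=extra_decoders) == data
assert msgpack_loads(msgpack_dumps(data, extra_encoders=extra_encoders),
                     extra_decoders=extra_decoders) == data
```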
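
At the RPC layer the same hooks surface as the `extra_type_encoders` / `extra_type_decoders` class attributes added to `RPCServerApp` and `RPCClient`, so a matching server and client can exchange custom types transparently. A sketch modeled on the new `test_rpc_server.py` / `test_rpc_client.py` fixtures; `Tag`, `TagStorage` and the URL are hypothetical:

```
from swh.core.api import remote_api_endpoint, RPCClient, RPCServerApp


class Tag:
    """Toy value object used to exercise the extra codecs."""

    def __init__(self, name):
        self.name = name

    def __eq__(self, other):
        return isinstance(other, Tag) and self.name == other.name


extra_encoders = [(Tag, 'tag', lambda t: t.name)]
extra_decoders = {'tag': Tag}


class TagStorage:
    @remote_api_endpoint('tag/get')
    def get_tag(self, name, db=None, cur=None):
        return Tag(name)


class TagServerApp(RPCServerApp):
    extra_type_encoders = extra_encoders
    extra_type_decoders = extra_decoders


class TagClient(RPCClient):
    backend_class = TagStorage
    extra_type_encoders = extra_encoders
    extra_type_decoders = extra_decoders


# The app exposes POST /tag/get; the client grows a get_tag() method with the
# same signature, and both sides (de)serialize Tag through the shared codecs.
app = TagServerApp('tag-app', backend_class=TagStorage)
client = TagClient(url='http://localhost:5007/')  # hypothetical server URL
# Served behind `app`, client.get_tag(name='v1.0') would return Tag('v1.0').
```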