diff --git a/swh/core/api/__init__.py b/swh/core/api/__init__.py index ba08caa..dc3bdea 100644 --- a/swh/core/api/__init__.py +++ b/swh/core/api/__init__.py @@ -1,413 +1,429 @@ # Copyright (C) 2015-2017 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from collections import abc -import datetime import functools import inspect -import json import logging import pickle import requests import traceback -from typing import Any, ClassVar, List, Optional, Type +from typing import ( + Any, Callable, ClassVar, Dict, List, Optional, Tuple, Type, Union, +) from flask import Flask, Request, Response, request, abort from werkzeug.exceptions import HTTPException from .serializers import (decode_response, encode_data_client as encode_data, - msgpack_dumps, msgpack_loads, SWHJSONDecoder) + msgpack_dumps, msgpack_loads, + json_dumps, json_loads) from .negotiation import (Formatter as FormatterBase, Negotiator as NegotiatorBase, negotiate as _negotiate) logger = logging.getLogger(__name__) # support for content negotiation class Negotiator(NegotiatorBase): def best_mimetype(self): return request.accept_mimetypes.best_match( self.accept_mimetypes, 'application/json') def _abort(self, status_code, err=None): return abort(status_code, err) def negotiate(formatter_cls, *args, **kwargs): return _negotiate(Negotiator, formatter_cls, *args, **kwargs) class Formatter(FormatterBase): def _make_response(self, body, content_type): return Response(body, content_type=content_type) - -class SWHJSONEncoder(json.JSONEncoder): - def default(self, obj): - if isinstance(obj, (datetime.datetime, datetime.date)): - return obj.isoformat() - if isinstance(obj, datetime.timedelta): - return str(obj) - # Let the base class default method raise the TypeError - return super().default(obj) + def configure(self, extra_encoders): + self.extra_encoders = extra_encoders class JSONFormatter(Formatter): format = 'json' mimetypes = ['application/json'] def render(self, obj): - return json.dumps(obj, cls=SWHJSONEncoder) + return json_dumps(obj, extra_encoders=self.extra_encoders) class MsgpackFormatter(Formatter): format = 'msgpack' mimetypes = ['application/x-msgpack'] def render(self, obj): - return msgpack_dumps(obj) + return msgpack_dumps(obj, extra_encoders=self.extra_encoders) # base API classes class RemoteException(Exception): """raised when remote returned an out-of-band failure notification, e.g., as a HTTP status code or serialized exception Attributes: response: HTTP response corresponding to the failure """ def __init__(self, payload: Optional[Any] = None, response: Optional[requests.Response] = None): if payload is not None: super().__init__(payload) else: super().__init__() self.response = response def __str__(self): if self.args and isinstance(self.args[0], dict) \ and 'type' in self.args[0] and 'args' in self.args[0]: return ( f'') else: return super().__str__() def remote_api_endpoint(path): def dec(f): f._endpoint_path = path return f return dec class APIError(Exception): """API Error""" def __str__(self): return ('An unexpected error occurred in the backend: {}' .format(self.args)) class MetaRPCClient(type): """Metaclass for RPCClient, which adds a method for each endpoint of the database it is designed to access. See for example :class:`swh.indexer.storage.api.client.RemoteStorage`""" def __new__(cls, name, bases, attributes): # For each method wrapped with @remote_api_endpoint in an API backend # (eg. :class:`swh.indexer.storage.IndexerStorage`), add a new # method in RemoteStorage, with the same documentation. # # Note that, despite the usage of decorator magic (eg. functools.wrap), # this never actually calls an IndexerStorage method. backend_class = attributes.get('backend_class', None) for base in bases: if backend_class is not None: break backend_class = getattr(base, 'backend_class', None) if backend_class: for (meth_name, meth) in backend_class.__dict__.items(): if hasattr(meth, '_endpoint_path'): cls.__add_endpoint(meth_name, meth, attributes) return super().__new__(cls, name, bases, attributes) @staticmethod def __add_endpoint(meth_name, meth, attributes): wrapped_meth = inspect.unwrap(meth) @functools.wraps(meth) # Copy signature and doc def meth_(*args, **kwargs): # Match arguments and parameters post_data = inspect.getcallargs( wrapped_meth, *args, **kwargs) # Remove arguments that should not be passed self = post_data.pop('self') post_data.pop('cur', None) post_data.pop('db', None) # Send the request. return self.post(meth._endpoint_path, post_data) attributes[meth_name] = meth_ class RPCClient(metaclass=MetaRPCClient): """Proxy to an internal SWH RPC """ backend_class = None # type: ClassVar[Optional[type]] """For each method of `backend_class` decorated with :func:`remote_api_endpoint`, a method with the same prototype and docstring will be added to this class. Calls to this new method will be translated into HTTP requests to a remote server. This backend class will never be instantiated, it only serves as a template.""" api_exception = APIError # type: ClassVar[Type[Exception]] """The exception class to raise in case of communication error with the server.""" reraise_exceptions: ClassVar[List[Type[Exception]]] = [] """On server errors, if any of the exception classes in this list has the same name as the error name, then the exception will be instantiated and raised instead of a generic RemoteException.""" + extra_type_encoders: List[Tuple[type, str, Callable]] = [] + """Value of `extra_encoders` passed to `json_dumps` or `msgpack_dumps` + to be able to serialize more object types.""" + extra_type_decoders: Dict[str, Callable] = {} + """Value of `extra_decoders` passed to `json_loads` or `msgpack_loads` + to be able to deserialize more object types.""" + def __init__(self, url, api_exception=None, timeout=None, chunk_size=4096, reraise_exceptions=None, **kwargs): if api_exception: self.api_exception = api_exception if reraise_exceptions: self.reraise_exceptions = reraise_exceptions base_url = url if url.endswith('/') else url + '/' self.url = base_url self.session = requests.Session() adapter = requests.adapters.HTTPAdapter( max_retries=kwargs.get('max_retries', 3), pool_connections=kwargs.get('pool_connections', 20), pool_maxsize=kwargs.get('pool_maxsize', 100)) self.session.mount(self.url, adapter) self.timeout = timeout self.chunk_size = chunk_size def _url(self, endpoint): return '%s%s' % (self.url, endpoint) def raw_verb(self, verb, endpoint, **opts): if 'chunk_size' in opts: # if the chunk_size argument has been passed, consider the user # also wants stream=True, otherwise, what's the point. opts['stream'] = True if self.timeout and 'timeout' not in opts: opts['timeout'] = self.timeout try: return getattr(self.session, verb)( self._url(endpoint), **opts ) except requests.exceptions.ConnectionError as e: raise self.api_exception(e) def post(self, endpoint, data, **opts): if isinstance(data, (abc.Iterator, abc.Generator)): - data = (encode_data(x) for x in data) + data = (self._encode_data(x) for x in data) else: - data = encode_data(data) + data = self._encode_data(data) chunk_size = opts.pop('chunk_size', self.chunk_size) response = self.raw_verb( 'post', endpoint, data=data, headers={'content-type': 'application/x-msgpack', 'accept': 'application/x-msgpack'}, **opts) if opts.get('stream') or \ response.headers.get('transfer-encoding') == 'chunked': self.raise_for_status(response) return response.iter_content(chunk_size) else: return self._decode_response(response) + def _encode_data(self, data): + return encode_data(data, extra_encoders=self.extra_type_encoders) + post_stream = post def get(self, endpoint, **opts): chunk_size = opts.pop('chunk_size', self.chunk_size) response = self.raw_verb( 'get', endpoint, headers={'accept': 'application/x-msgpack'}, **opts) if opts.get('stream') or \ response.headers.get('transfer-encoding') == 'chunked': self.raise_for_status(response) return response.iter_content(chunk_size) else: return self._decode_response(response) def get_stream(self, endpoint, **opts): return self.get(endpoint, stream=True, **opts) def raise_for_status(self, response) -> None: """check response HTTP status code and raise an exception if it denotes an error; do nothing otherwise """ status_code = response.status_code status_class = response.status_code // 100 if status_code == 404: raise RemoteException(payload='404 not found', response=response) exception = None # TODO: only old servers send pickled error; stop trying to unpickle # after they are all upgraded try: if status_class == 4: - data = decode_response(response) + data = self._decode_response(response, check_status=False) if isinstance(data, dict): for exc_type in self.reraise_exceptions: if exc_type.__name__ == data['exception']['type']: exception = exc_type(*data['exception']['args']) break else: exception = RemoteException(payload=data['exception'], response=response) else: exception = pickle.loads(data) elif status_class == 5: - data = decode_response(response) + data = self._decode_response(response, check_status=False) if 'exception_pickled' in data: exception = pickle.loads(data['exception_pickled']) else: exception = RemoteException(payload=data['exception'], response=response) except (TypeError, pickle.UnpicklingError): raise RemoteException(payload=data, response=response) if exception: raise exception from None if status_class != 2: raise RemoteException( payload=f'API HTTP error: {status_code} {response.content}', response=response) - def _decode_response(self, response): - self.raise_for_status(response) - return decode_response(response) + def _decode_response(self, response, check_status=True): + if check_status: + self.raise_for_status(response) + return decode_response( + response, extra_decoders=self.extra_type_decoders) def __repr__(self): return '<{} url={}>'.format(self.__class__.__name__, self.url) class BytesRequest(Request): """Request with proper escaping of arbitrary byte sequences.""" encoding = 'utf-8' encoding_errors = 'surrogateescape' -ENCODERS = { +ENCODERS: Dict[str, Callable[[Any], Union[bytes, str]]] = { 'application/x-msgpack': msgpack_dumps, - 'application/json': json.dumps, + 'application/json': json_dumps, } -def encode_data_server(data, content_type='application/x-msgpack'): - encoded_data = ENCODERS[content_type](data) +def encode_data_server( + data, content_type='application/x-msgpack', extra_type_encoders=None): + encoded_data = ENCODERS[content_type]( + data, extra_encoders=extra_type_encoders) return Response( encoded_data, mimetype=content_type, ) -def decode_request(request): +def decode_request(request, extra_decoders=None): content_type = request.mimetype data = request.get_data() if not data: return {} if content_type == 'application/x-msgpack': - r = msgpack_loads(data) + r = msgpack_loads(data, extra_decoders=extra_decoders) elif content_type == 'application/json': # XXX this .decode() is needed for py35. # Should not be needed any more with py37 - r = json.loads(data.decode('utf-8'), cls=SWHJSONDecoder) + r = json_loads(data.decode('utf-8'), extra_decoders=extra_decoders) else: raise ValueError('Wrong content type `%s` for API request' % content_type) return r def error_handler(exception, encoder, status_code=500): logging.exception(exception) tb = traceback.format_exception(None, exception, exception.__traceback__) error = { 'exception': { 'type': type(exception).__name__, 'args': exception.args, 'message': str(exception), 'traceback': tb, } } response = encoder(error) if isinstance(exception, HTTPException): response.status_code = exception.code else: # TODO: differentiate between server errors and client errors response.status_code = status_code return response class RPCServerApp(Flask): """For each endpoint of the given `backend_class`, tells app.route to call a function that decodes the request and sends it to the backend object provided by the factory. :param Any backend_class: The class of the backend, which will be analyzed to look for API endpoints. :param Optional[Callable[[], backend_class]] backend_factory: A function with no argument that returns an instance of `backend_class`. If unset, defaults to calling `backend_class` constructor directly. """ request_class = BytesRequest + extra_type_encoders: List[Tuple[type, str, Callable]] = [] + """Value of `extra_encoders` passed to `json_dumps` or `msgpack_dumps` + to be able to serialize more object types.""" + extra_type_decoders: Dict[str, Callable] = {} + """Value of `extra_decoders` passed to `json_loads` or `msgpack_loads` + to be able to deserialize more object types.""" + def __init__(self, *args, backend_class=None, backend_factory=None, **kwargs): super().__init__(*args, **kwargs) self.backend_class = backend_class if backend_class is not None: if backend_factory is None: backend_factory = backend_class for (meth_name, meth) in backend_class.__dict__.items(): if hasattr(meth, '_endpoint_path'): self.__add_endpoint(meth_name, meth, backend_factory) def __add_endpoint(self, meth_name, meth, backend_factory): from flask import request @self.route('/'+meth._endpoint_path, methods=['POST']) - @negotiate(MsgpackFormatter) - @negotiate(JSONFormatter) + @negotiate(MsgpackFormatter, extra_encoders=self.extra_type_encoders) + @negotiate(JSONFormatter, extra_encoders=self.extra_type_encoders) @functools.wraps(meth) # Copy signature and doc def _f(): # Call the actual code obj_meth = getattr(backend_factory(), meth_name) - kw = decode_request(request) + kw = decode_request( + request, extra_decoders=self.extra_type_decoders) return obj_meth(**kw) diff --git a/swh/core/api/serializers.py b/swh/core/api/serializers.py index 405f460..c16108e 100644 --- a/swh/core/api/serializers.py +++ b/swh/core/api/serializers.py @@ -1,204 +1,214 @@ # Copyright (C) 2015-2018 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import base64 import datetime import json import types from uuid import UUID import arrow import dateutil.parser import msgpack from typing import Any, Dict, Union, Tuple from requests import Response ENCODERS = [ (arrow.Arrow, 'arrow', arrow.Arrow.isoformat), (datetime.datetime, 'datetime', datetime.datetime.isoformat), (datetime.timedelta, 'timedelta', lambda o: { 'days': o.days, 'seconds': o.seconds, 'microseconds': o.microseconds, }), (UUID, 'uuid', str), # Only for JSON: (bytes, 'bytes', lambda o: base64.b85encode(o).decode('ascii')), ] DECODERS = { 'arrow': arrow.get, 'datetime': dateutil.parser.parse, 'timedelta': lambda d: datetime.timedelta(**d), 'uuid': UUID, # Only for JSON: 'bytes': base64.b85decode, } def encode_data_client(data: Any, extra_encoders=None) -> bytes: try: return msgpack_dumps(data, extra_encoders=extra_encoders) except OverflowError as e: raise ValueError('Limits were reached. Please, check your input.\n' + str(e)) def decode_response(response: Response, extra_decoders=None) -> Any: content_type = response.headers['content-type'] if content_type.startswith('application/x-msgpack'): r = msgpack_loads(response.content, extra_decoders=extra_decoders) elif content_type.startswith('application/json'): - r = json.loads(response.text, cls=SWHJSONDecoder, + r = json_loads(response.text, extra_decoders=extra_decoders) elif content_type.startswith('text/'): r = response.text else: raise ValueError('Wrong content type `%s` for API response' % content_type) return r class SWHJSONEncoder(json.JSONEncoder): """JSON encoder for data structures generated by Software Heritage. This JSON encoder extends the default Python JSON encoder and adds awareness for the following specific types: - bytes (get encoded as a Base85 string); - datetime.datetime (get encoded as an ISO8601 string). Non-standard types get encoded as a a dictionary with two keys: - swhtype with value 'bytes' or 'datetime'; - d containing the encoded value. SWHJSONEncoder also encodes arbitrary iterables as a list (allowing serialization of generators). Caveats: Limitations in the JSONEncoder extension mechanism prevent us from "escaping" dictionaries that only contain the swhtype and d keys, and therefore arbitrary data structures can't be round-tripped through SWHJSONEncoder and SWHJSONDecoder. """ def __init__(self, extra_encoders=None, **kwargs): super().__init__(**kwargs) self.encoders = ENCODERS if extra_encoders: self.encoders += extra_encoders def default(self, o: Any ) -> Union[Dict[str, Union[Dict[str, int], str]], list]: for (type_, type_name, encoder) in self.encoders: if isinstance(o, type_): return { 'swhtype': type_name, 'd': encoder(o), } try: return super().default(o) except TypeError as e: try: iterable = iter(o) except TypeError: raise e from None else: return list(iterable) class SWHJSONDecoder(json.JSONDecoder): """JSON decoder for data structures encoded with SWHJSONEncoder. This JSON decoder extends the default Python JSON decoder, allowing the decoding of: - bytes (encoded as a Base85 string); - datetime.datetime (encoded as an ISO8601 string). Non-standard types must be encoded as a a dictionary with exactly two keys: - swhtype with value 'bytes' or 'datetime'; - d containing the encoded value. To limit the impact our encoding, if the swhtype key doesn't contain a known value, the dictionary is decoded as-is. """ def __init__(self, extra_decoders=None, **kwargs): super().__init__(**kwargs) self.decoders = DECODERS if extra_decoders: self.decoders = {**self.decoders, **extra_decoders} def decode_data(self, o: Any) -> Any: if isinstance(o, dict): if set(o.keys()) == {'d', 'swhtype'}: if o['swhtype'] == 'bytes': return base64.b85decode(o['d']) decoder = self.decoders.get(o['swhtype']) if decoder: return decoder(self.decode_data(o['d'])) return {key: self.decode_data(value) for key, value in o.items()} if isinstance(o, list): return [self.decode_data(value) for value in o] else: return o def raw_decode(self, s: str, idx: int = 0) -> Tuple[Any, int]: data, index = super().raw_decode(s, idx) return self.decode_data(data), index +def json_dumps(data: Any, extra_encoders=None) -> str: + return json.dumps(data, cls=SWHJSONEncoder, + extra_encoders=extra_encoders) + + +def json_loads(data: str, extra_decoders=None) -> Any: + return json.loads(data, cls=SWHJSONDecoder, + extra_decoders=extra_decoders) + + def msgpack_dumps(data: Any, extra_encoders=None) -> bytes: """Write data as a msgpack stream""" encoders = ENCODERS if extra_encoders: encoders += extra_encoders def encode_types(obj): if isinstance(obj, types.GeneratorType): return list(obj) for (type_, type_name, encoder) in encoders: if isinstance(obj, type_): return { b'swhtype': type_name, b'd': encoder(obj), } return obj return msgpack.packb(data, use_bin_type=True, default=encode_types) def msgpack_loads(data: bytes, extra_decoders=None) -> Any: """Read data as a msgpack stream""" decoders = DECODERS if extra_decoders: decoders = {**decoders, **extra_decoders} def decode_types(obj): if set(obj.keys()) == {b'd', b'swhtype'}: decoder = decoders.get(obj[b'swhtype']) if decoder: return decoder(obj[b'd']) return obj try: return msgpack.unpackb(data, raw=False, object_hook=decode_types) except TypeError: # msgpack < 0.5.2 return msgpack.unpackb(data, encoding='utf-8', object_hook=decode_types) diff --git a/swh/core/api/tests/test_rpc_client.py b/swh/core/api/tests/test_rpc_client.py index 307a5e7..58450a9 100644 --- a/swh/core/api/tests/test_rpc_client.py +++ b/swh/core/api/tests/test_rpc_client.py @@ -1,56 +1,73 @@ # Copyright (C) 2018-2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import re import pytest from swh.core.api import remote_api_endpoint, RPCClient +from .test_serializers import ExtraType, extra_encoders, extra_decoders + @pytest.fixture def rpc_client(requests_mock): class TestStorage: @remote_api_endpoint('test_endpoint_url') def test_endpoint(self, test_data, db=None, cur=None): - return 'egg' + ... @remote_api_endpoint('path/to/endpoint') def something(self, data, db=None, cur=None): - return 'spam' + ... + + @remote_api_endpoint('serializer_test') + def serializer_test(self, data, db=None, cur=None): + ... class Testclient(RPCClient): backend_class = TestStorage + extra_type_encoders = extra_encoders + extra_type_decoders = extra_decoders def callback(request, context): assert request.headers['Content-Type'] == 'application/x-msgpack' context.headers['Content-Type'] = 'application/x-msgpack' if request.path == '/test_endpoint_url': context.content = b'\xa3egg' elif request.path == '/path/to/endpoint': context.content = b'\xa4spam' + elif request.path == '/serializer_test': + context.content = ( + b'\x82\xc4\x07swhtype\xa9extratype' + b'\xc4\x01d\x92\x81\xa4spam\xa3egg\xa3qux') else: assert False return context.content requests_mock.post(re.compile('mock://example.com/'), content=callback) return Testclient(url='mock://example.com') def test_client(rpc_client): assert hasattr(rpc_client, 'test_endpoint') assert hasattr(rpc_client, 'something') res = rpc_client.test_endpoint('spam') assert res == 'egg' res = rpc_client.test_endpoint(test_data='spam') assert res == 'egg' res = rpc_client.something('whatever') assert res == 'spam' res = rpc_client.something(data='whatever') assert res == 'spam' + + +def test_client_extra_serializers(rpc_client): + res = rpc_client.serializer_test(['foo', ExtraType('bar', b'baz')]) + assert res == ExtraType({'spam': 'egg'}, 'qux') diff --git a/swh/core/api/tests/test_rpc_server.py b/swh/core/api/tests/test_rpc_server.py index 9399f62..81beb12 100644 --- a/swh/core/api/tests/test_rpc_server.py +++ b/swh/core/api/tests/test_rpc_server.py @@ -1,73 +1,100 @@ # Copyright (C) 2018-2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import pytest import json import msgpack from flask import url_for from swh.core.api import remote_api_endpoint, RPCServerApp +from .test_serializers import ExtraType, extra_encoders, extra_decoders + + +class MyRPCServerApp(RPCServerApp): + extra_type_encoders = extra_encoders + extra_type_decoders = extra_decoders + @pytest.fixture def app(): class TestStorage: @remote_api_endpoint('test_endpoint_url') def test_endpoint(self, test_data, db=None, cur=None): assert test_data == 'spam' return 'egg' @remote_api_endpoint('path/to/endpoint') def something(self, data, db=None, cur=None): return data - return RPCServerApp('testapp', backend_class=TestStorage) + @remote_api_endpoint('serializer_test') + def serializer_test(self, data, db=None, cur=None): + assert data == ['foo', ExtraType('bar', b'baz')] + return ExtraType({'spam': 'egg'}, 'qux') + + return MyRPCServerApp('testapp', backend_class=TestStorage) def test_api_endpoint(flask_app_client): res = flask_app_client.post( url_for('something'), headers=[('Content-Type', 'application/json'), ('Accept', 'application/json')], data=json.dumps({'data': 'toto'}), ) assert res.status_code == 200 assert res.mimetype == 'application/json' def test_api_nego_default(flask_app_client): res = flask_app_client.post( url_for('something'), headers=[('Content-Type', 'application/json')], data=json.dumps({'data': 'toto'}), ) assert res.status_code == 200 assert res.mimetype == 'application/json' assert res.data == b'"toto"' def test_api_nego_accept(flask_app_client): res = flask_app_client.post( url_for('something'), headers=[('Accept', 'application/x-msgpack'), ('Content-Type', 'application/x-msgpack')], data=msgpack.dumps({'data': 'toto'}), ) assert res.status_code == 200 assert res.mimetype == 'application/x-msgpack' assert res.data == b'\xa4toto' def test_rpc_server(flask_app_client): res = flask_app_client.post( url_for('test_endpoint'), headers=[('Content-Type', 'application/x-msgpack'), ('Accept', 'application/x-msgpack')], data=b'\x81\xa9test_data\xa4spam') assert res.status_code == 200 assert res.mimetype == 'application/x-msgpack' assert res.data == b'\xa3egg' + + +def test_rpc_server_extra_serializers(flask_app_client): + res = flask_app_client.post( + url_for('serializer_test'), + headers=[('Content-Type', 'application/x-msgpack'), + ('Accept', 'application/x-msgpack')], + data=b'\x81\xa4data\x92\xa3foo\x82\xc4\x07swhtype\xa9extratype' + b'\xc4\x01d\x92\xa3bar\xc4\x03baz') + + assert res.status_code == 200 + assert res.mimetype == 'application/x-msgpack' + assert res.data == ( + b'\x82\xc4\x07swhtype\xa9extratype\xc4' + b'\x01d\x92\x81\xa4spam\xa3egg\xa3qux') diff --git a/swh/core/api/tests/test_serializers.py b/swh/core/api/tests/test_serializers.py index 373518b..3f1a7aa 100644 --- a/swh/core/api/tests/test_serializers.py +++ b/swh/core/api/tests/test_serializers.py @@ -1,131 +1,133 @@ # Copyright (C) 2015-2018 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import datetime import json +from typing import Any, Callable, List, Tuple import unittest from uuid import UUID import arrow import requests import requests_mock from swh.core.api.serializers import ( SWHJSONDecoder, SWHJSONEncoder, msgpack_dumps, msgpack_loads, decode_response ) class ExtraType: def __init__(self, arg1, arg2): self.arg1 = arg1 self.arg2 = arg2 def __repr__(self): return f'ExtraType({self.arg1}, {self.arg2})' def __eq__(self, other): - return (self.arg1, self.arg2) == (other.arg1, other.arg2) + return isinstance(other, ExtraType) \ + and (self.arg1, self.arg2) == (other.arg1, other.arg2) -extra_encoders = [ +extra_encoders: List[Tuple[type, str, Callable[..., Any]]] = [ (ExtraType, 'extratype', lambda o: (o.arg1, o.arg2)) ] extra_decoders = { 'extratype': lambda o: ExtraType(*o), } class Serializers(unittest.TestCase): def setUp(self): self.tz = datetime.timezone(datetime.timedelta(minutes=118)) self.data = { 'bytes': b'123456789\x99\xaf\xff\x00\x12', 'datetime_naive': datetime.datetime(2015, 1, 1, 12, 4, 42, 231455), 'datetime_tz': datetime.datetime(2015, 3, 4, 18, 25, 13, 1234, tzinfo=self.tz), 'datetime_utc': datetime.datetime(2015, 3, 4, 18, 25, 13, 1234, tzinfo=datetime.timezone.utc), 'datetime_delta': datetime.timedelta(64), 'arrow_date': arrow.get('2018-04-25T16:17:53.533672+00:00'), 'swhtype': 'fake', 'swh_dict': {'swhtype': 42, 'd': 'test'}, 'random_dict': {'swhtype': 43}, 'uuid': UUID('cdd8f804-9db6-40c3-93ab-5955d3836234'), } self.encoded_data = { 'bytes': {'swhtype': 'bytes', 'd': 'F)}kWH8wXmIhn8j01^'}, 'datetime_naive': {'swhtype': 'datetime', 'd': '2015-01-01T12:04:42.231455'}, 'datetime_tz': {'swhtype': 'datetime', 'd': '2015-03-04T18:25:13.001234+01:58'}, 'datetime_utc': {'swhtype': 'datetime', 'd': '2015-03-04T18:25:13.001234+00:00'}, 'datetime_delta': {'swhtype': 'timedelta', 'd': {'days': 64, 'seconds': 0, 'microseconds': 0}}, 'arrow_date': {'swhtype': 'arrow', 'd': '2018-04-25T16:17:53.533672+00:00'}, 'swhtype': 'fake', 'swh_dict': {'swhtype': 42, 'd': 'test'}, 'random_dict': {'swhtype': 43}, 'uuid': {'swhtype': 'uuid', 'd': 'cdd8f804-9db6-40c3-93ab-5955d3836234'}, } self.generator = (i for i in range(5)) self.gen_lst = list(range(5)) def test_round_trip_json(self): data = json.dumps(self.data, cls=SWHJSONEncoder) self.assertEqual(self.data, json.loads(data, cls=SWHJSONDecoder)) def test_round_trip_json_extra_types(self): original_data = [ExtraType('baz', self.data), 'qux'] data = json.dumps(original_data, cls=SWHJSONEncoder, extra_encoders=extra_encoders) self.assertEqual( original_data, json.loads( data, cls=SWHJSONDecoder, extra_decoders=extra_decoders)) def test_encode_swh_json(self): data = json.dumps(self.data, cls=SWHJSONEncoder) self.assertEqual(self.encoded_data, json.loads(data)) def test_round_trip_msgpack(self): data = msgpack_dumps(self.data) self.assertEqual(self.data, msgpack_loads(data)) def test_round_trip_msgpack_extra_types(self): original_data = [ExtraType('baz', self.data), 'qux'] data = msgpack_dumps(original_data, extra_encoders=extra_encoders) self.assertEqual( original_data, msgpack_loads(data, extra_decoders=extra_decoders)) def test_generator_json(self): data = json.dumps(self.generator, cls=SWHJSONEncoder) self.assertEqual(self.gen_lst, json.loads(data, cls=SWHJSONDecoder)) def test_generator_msgpack(self): data = msgpack_dumps(self.generator) self.assertEqual(self.gen_lst, msgpack_loads(data)) @requests_mock.Mocker() def test_decode_response_json(self, mock_requests): mock_requests.get('https://example.org/test/data', json=self.encoded_data, headers={'content-type': 'application/json'}) response = requests.get('https://example.org/test/data') assert decode_response(response) == self.data