diff --git a/mypy.ini b/mypy.ini --- a/mypy.ini +++ b/mypy.ini @@ -23,6 +23,9 @@ [mypy-django.*] # false positive, only used my hypotesis' extras ignore_missing_imports = True +[mypy-iso8601.*] +ignore_missing_imports = True + [mypy-msgpack.*] ignore_missing_imports = True diff --git a/requirements-http.txt b/requirements-http.txt --- a/requirements-http.txt +++ b/requirements-http.txt @@ -4,7 +4,7 @@ arrow decorator Flask +iso8601 msgpack > 0.5 -python-dateutil requests blinker # dependency of sentry-sdk[flask] diff --git a/swh/core/api/__init__.py b/swh/core/api/__init__.py --- a/swh/core/api/__init__.py +++ b/swh/core/api/__init__.py @@ -4,23 +4,24 @@ # See top-level LICENSE file for more information from collections import abc -import datetime import functools import inspect -import json import logging import pickle import requests import traceback -from typing import Any, ClassVar, List, Optional, Type +from typing import ( + Any, Callable, ClassVar, Dict, List, Optional, Tuple, Type, Union, +) from flask import Flask, Request, Response, request, abort from werkzeug.exceptions import HTTPException from .serializers import (decode_response, encode_data_client as encode_data, - msgpack_dumps, msgpack_loads, SWHJSONDecoder) + msgpack_dumps, msgpack_loads, + json_dumps, json_loads) from .negotiation import (Formatter as FormatterBase, Negotiator as NegotiatorBase, @@ -49,15 +50,8 @@ def _make_response(self, body, content_type): return Response(body, content_type=content_type) - -class SWHJSONEncoder(json.JSONEncoder): - def default(self, obj): - if isinstance(obj, (datetime.datetime, datetime.date)): - return obj.isoformat() - if isinstance(obj, datetime.timedelta): - return str(obj) - # Let the base class default method raise the TypeError - return super().default(obj) + def configure(self, extra_encoders): + self.extra_encoders = extra_encoders class JSONFormatter(Formatter): @@ -65,7 +59,7 @@ mimetypes = ['application/json'] def render(self, obj): - return json.dumps(obj, cls=SWHJSONEncoder) + return json_dumps(obj, extra_encoders=self.extra_encoders) class MsgpackFormatter(Formatter): @@ -73,7 +67,7 @@ mimetypes = ['application/x-msgpack'] def render(self, obj): - return msgpack_dumps(obj) + return msgpack_dumps(obj, extra_encoders=self.extra_encoders) # base API classes @@ -184,6 +178,13 @@ has the same name as the error name, then the exception will be instantiated and raised instead of a generic RemoteException.""" + extra_type_encoders: List[Tuple[type, str, Callable]] = [] + """Value of `extra_encoders` passed to `json_dumps` or `msgpack_dumps` + to be able to serialize more object types.""" + extra_type_decoders: Dict[str, Callable] = {} + """Value of `extra_decoders` passed to `json_loads` or `msgpack_loads` + to be able to deserialize more object types.""" + def __init__(self, url, api_exception=None, timeout=None, chunk_size=4096, reraise_exceptions=None, **kwargs): @@ -223,9 +224,9 @@ def post(self, endpoint, data, **opts): if isinstance(data, (abc.Iterator, abc.Generator)): - data = (encode_data(x) for x in data) + data = (self._encode_data(x) for x in data) else: - data = encode_data(data) + data = self._encode_data(data) chunk_size = opts.pop('chunk_size', self.chunk_size) response = self.raw_verb( 'post', endpoint, data=data, @@ -239,6 +240,9 @@ else: return self._decode_response(response) + def _encode_data(self, data): + return encode_data(data, extra_encoders=self.extra_type_encoders) + post_stream = post def get(self, endpoint, **opts): @@ -274,7 +278,7 @@ # after they are all upgraded try: if status_class == 4: - data = decode_response(response) + data = self._decode_response(response, check_status=False) if isinstance(data, dict): for exc_type in self.reraise_exceptions: if exc_type.__name__ == data['exception']['type']: @@ -287,7 +291,7 @@ exception = pickle.loads(data) elif status_class == 5: - data = decode_response(response) + data = self._decode_response(response, check_status=False) if 'exception_pickled' in data: exception = pickle.loads(data['exception_pickled']) else: @@ -305,9 +309,11 @@ payload=f'API HTTP error: {status_code} {response.content}', response=response) - def _decode_response(self, response): - self.raise_for_status(response) - return decode_response(response) + def _decode_response(self, response, check_status=True): + if check_status: + self.raise_for_status(response) + return decode_response( + response, extra_decoders=self.extra_type_decoders) def __repr__(self): return '<{} url={}>'.format(self.__class__.__name__, self.url) @@ -319,32 +325,34 @@ encoding_errors = 'surrogateescape' -ENCODERS = { +ENCODERS: Dict[str, Callable[[Any], Union[bytes, str]]] = { 'application/x-msgpack': msgpack_dumps, - 'application/json': json.dumps, + 'application/json': json_dumps, } -def encode_data_server(data, content_type='application/x-msgpack'): - encoded_data = ENCODERS[content_type](data) +def encode_data_server( + data, content_type='application/x-msgpack', extra_type_encoders=None): + encoded_data = ENCODERS[content_type]( + data, extra_encoders=extra_type_encoders) return Response( encoded_data, mimetype=content_type, ) -def decode_request(request): +def decode_request(request, extra_decoders=None): content_type = request.mimetype data = request.get_data() if not data: return {} if content_type == 'application/x-msgpack': - r = msgpack_loads(data) + r = msgpack_loads(data, extra_decoders=extra_decoders) elif content_type == 'application/json': # XXX this .decode() is needed for py35. # Should not be needed any more with py37 - r = json.loads(data.decode('utf-8'), cls=SWHJSONDecoder) + r = json_loads(data.decode('utf-8'), extra_decoders=extra_decoders) else: raise ValueError('Wrong content type `%s` for API request' % content_type) @@ -387,6 +395,13 @@ """ request_class = BytesRequest + extra_type_encoders: List[Tuple[type, str, Callable]] = [] + """Value of `extra_encoders` passed to `json_dumps` or `msgpack_dumps` + to be able to serialize more object types.""" + extra_type_decoders: Dict[str, Callable] = {} + """Value of `extra_decoders` passed to `json_loads` or `msgpack_loads` + to be able to deserialize more object types.""" + def __init__(self, *args, backend_class=None, backend_factory=None, **kwargs): super().__init__(*args, **kwargs) @@ -403,11 +418,12 @@ from flask import request @self.route('/'+meth._endpoint_path, methods=['POST']) - @negotiate(MsgpackFormatter) - @negotiate(JSONFormatter) + @negotiate(MsgpackFormatter, extra_encoders=self.extra_type_encoders) + @negotiate(JSONFormatter, extra_encoders=self.extra_type_encoders) @functools.wraps(meth) # Copy signature and doc def _f(): # Call the actual code obj_meth = getattr(backend_factory(), meth_name) - kw = decode_request(request) + kw = decode_request( + request, extra_decoders=self.extra_type_decoders) return obj_meth(**kw) diff --git a/swh/core/api/serializers.py b/swh/core/api/serializers.py --- a/swh/core/api/serializers.py +++ b/swh/core/api/serializers.py @@ -10,28 +10,55 @@ from uuid import UUID import arrow -import dateutil.parser +import iso8601 import msgpack from typing import Any, Dict, Union, Tuple from requests import Response -def encode_data_client(data: Any) -> bytes: +ENCODERS = [ + (arrow.Arrow, 'arrow', arrow.Arrow.isoformat), + (datetime.datetime, 'datetime', datetime.datetime.isoformat), + (datetime.timedelta, 'timedelta', lambda o: { + 'days': o.days, + 'seconds': o.seconds, + 'microseconds': o.microseconds, + }), + (UUID, 'uuid', str), + + # Only for JSON: + (bytes, 'bytes', lambda o: base64.b85encode(o).decode('ascii')), +] + +DECODERS = { + 'arrow': arrow.get, + 'datetime': lambda d: iso8601.parse_date(d, default_timezone=None), + 'timedelta': lambda d: datetime.timedelta(**d), + 'uuid': UUID, + + # Only for JSON: + 'bytes': base64.b85decode, +} + + +def encode_data_client(data: Any, extra_encoders=None) -> bytes: try: - return msgpack_dumps(data) + return msgpack_dumps(data, extra_encoders=extra_encoders) except OverflowError as e: raise ValueError('Limits were reached. Please, check your input.\n' + str(e)) -def decode_response(response: Response) -> Any: +def decode_response(response: Response, extra_decoders=None) -> Any: content_type = response.headers['content-type'] if content_type.startswith('application/x-msgpack'): - r = msgpack_loads(response.content) + r = msgpack_loads(response.content, + extra_decoders=extra_decoders) elif content_type.startswith('application/json'): - r = json.loads(response.text, cls=SWHJSONDecoder) + r = json_loads(response.text, + extra_decoders=extra_decoders) elif content_type.startswith('text/'): r = response.text else: @@ -65,37 +92,20 @@ """ + def __init__(self, extra_encoders=None, **kwargs): + super().__init__(**kwargs) + self.encoders = ENCODERS + if extra_encoders: + self.encoders += extra_encoders + def default(self, o: Any ) -> Union[Dict[str, Union[Dict[str, int], str]], list]: - if isinstance(o, bytes): - return { - 'swhtype': 'bytes', - 'd': base64.b85encode(o).decode('ascii'), - } - elif isinstance(o, datetime.datetime): - return { - 'swhtype': 'datetime', - 'd': o.isoformat(), - } - elif isinstance(o, UUID): - return { - 'swhtype': 'uuid', - 'd': str(o), - } - elif isinstance(o, datetime.timedelta): - return { - 'swhtype': 'timedelta', - 'd': { - 'days': o.days, - 'seconds': o.seconds, - 'microseconds': o.microseconds, - }, - } - elif isinstance(o, arrow.Arrow): - return { - 'swhtype': 'arrow', - 'd': o.isoformat(), - } + for (type_, type_name, encoder) in self.encoders: + if isinstance(o, type_): + return { + 'swhtype': type_name, + 'd': encoder(o), + } try: return super().default(o) except TypeError as e: @@ -127,20 +137,20 @@ """ + def __init__(self, extra_decoders=None, **kwargs): + super().__init__(**kwargs) + self.decoders = DECODERS + if extra_decoders: + self.decoders = {**self.decoders, **extra_decoders} + def decode_data(self, o: Any) -> Any: if isinstance(o, dict): if set(o.keys()) == {'d', 'swhtype'}: - datatype = o['swhtype'] - if datatype == 'bytes': + if o['swhtype'] == 'bytes': return base64.b85decode(o['d']) - elif datatype == 'datetime': - return dateutil.parser.parse(o['d']) - elif datatype == 'uuid': - return UUID(o['d']) - elif datatype == 'timedelta': - return datetime.timedelta(**o['d']) - elif datatype == 'arrow': - return arrow.get(o['d']) + decoder = self.decoders.get(o['swhtype']) + if decoder: + return decoder(self.decode_data(o['d'])) return {key: self.decode_data(value) for key, value in o.items()} if isinstance(o, list): return [self.decode_data(value) for value in o] @@ -152,42 +162,48 @@ return self.decode_data(data), index -def msgpack_dumps(data: Any) -> bytes: +def json_dumps(data: Any, extra_encoders=None) -> str: + return json.dumps(data, cls=SWHJSONEncoder, + extra_encoders=extra_encoders) + + +def json_loads(data: str, extra_decoders=None) -> Any: + return json.loads(data, cls=SWHJSONDecoder, + extra_decoders=extra_decoders) + + +def msgpack_dumps(data: Any, extra_encoders=None) -> bytes: """Write data as a msgpack stream""" + encoders = ENCODERS + if extra_encoders: + encoders += extra_encoders + def encode_types(obj): - if isinstance(obj, datetime.datetime): - return {b'__datetime__': True, b's': obj.isoformat()} if isinstance(obj, types.GeneratorType): return list(obj) - if isinstance(obj, UUID): - return {b'__uuid__': True, b's': str(obj)} - if isinstance(obj, datetime.timedelta): - return { - b'__timedelta__': True, - b's': { - 'days': obj.days, - 'seconds': obj.seconds, - 'microseconds': obj.microseconds, - }, - } - if isinstance(obj, arrow.Arrow): - return {b'__arrow__': True, b's': obj.isoformat()} + + for (type_, type_name, encoder) in encoders: + if isinstance(obj, type_): + return { + b'swhtype': type_name, + b'd': encoder(obj), + } return obj return msgpack.packb(data, use_bin_type=True, default=encode_types) -def msgpack_loads(data: bytes) -> Any: +def msgpack_loads(data: bytes, extra_decoders=None) -> Any: """Read data as a msgpack stream""" + decoders = DECODERS + if extra_decoders: + decoders = {**decoders, **extra_decoders} + def decode_types(obj): - if b'__datetime__' in obj and obj[b'__datetime__']: - return dateutil.parser.parse(obj[b's']) - if b'__uuid__' in obj and obj[b'__uuid__']: - return UUID(obj[b's']) - if b'__timedelta__' in obj and obj[b'__timedelta__']: - return datetime.timedelta(**obj[b's']) - if b'__arrow__' in obj and obj[b'__arrow__']: - return arrow.get(obj[b's']) + if set(obj.keys()) == {b'd', b'swhtype'}: + decoder = decoders.get(obj[b'swhtype']) + if decoder: + return decoder(obj[b'd']) return obj try: diff --git a/swh/core/api/tests/test_serializers.py b/swh/core/api/tests/test_serializers.py --- a/swh/core/api/tests/test_serializers.py +++ b/swh/core/api/tests/test_serializers.py @@ -21,6 +21,28 @@ ) +class ExtraType: + def __init__(self, arg1, arg2): + self.arg1 = arg1 + self.arg2 = arg2 + + def __repr__(self): + return f'ExtraType({self.arg1}, {self.arg2})' + + def __eq__(self, other): + return (self.arg1, self.arg2) == (other.arg1, other.arg2) + + +extra_encoders = [ + (ExtraType, 'extratype', lambda o: (o.arg1, o.arg2)) +] + + +extra_decoders = { + 'extratype': lambda o: ExtraType(*o), +} + + class Serializers(unittest.TestCase): def setUp(self): self.tz = datetime.timezone(datetime.timedelta(minutes=118)) @@ -67,6 +89,16 @@ data = json.dumps(self.data, cls=SWHJSONEncoder) self.assertEqual(self.data, json.loads(data, cls=SWHJSONDecoder)) + def test_round_trip_json_extra_types(self): + original_data = [ExtraType('baz', self.data), 'qux'] + + data = json.dumps(original_data, cls=SWHJSONEncoder, + extra_encoders=extra_encoders) + self.assertEqual( + original_data, + json.loads( + data, cls=SWHJSONDecoder, extra_decoders=extra_decoders)) + def test_encode_swh_json(self): data = json.dumps(self.data, cls=SWHJSONEncoder) self.assertEqual(self.encoded_data, json.loads(data)) @@ -75,6 +107,13 @@ data = msgpack_dumps(self.data) self.assertEqual(self.data, msgpack_loads(data)) + def test_round_trip_msgpack_extra_types(self): + original_data = [ExtraType('baz', self.data), 'qux'] + + data = msgpack_dumps(original_data, extra_encoders=extra_encoders) + self.assertEqual( + original_data, msgpack_loads(data, extra_decoders=extra_decoders)) + def test_generator_json(self): data = json.dumps(self.generator, cls=SWHJSONEncoder) self.assertEqual(self.gen_lst, json.loads(data, cls=SWHJSONDecoder))