diff --git a/swh/core/api/serializers.py b/swh/core/api/serializers.py index 1dcfe88..405f460 100644 --- a/swh/core/api/serializers.py +++ b/swh/core/api/serializers.py @@ -1,182 +1,204 @@ # Copyright (C) 2015-2018 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import base64 import datetime import json import types from uuid import UUID import arrow import dateutil.parser import msgpack from typing import Any, Dict, Union, Tuple from requests import Response ENCODERS = [ (arrow.Arrow, 'arrow', arrow.Arrow.isoformat), (datetime.datetime, 'datetime', datetime.datetime.isoformat), (datetime.timedelta, 'timedelta', lambda o: { 'days': o.days, 'seconds': o.seconds, 'microseconds': o.microseconds, }), (UUID, 'uuid', str), # Only for JSON: (bytes, 'bytes', lambda o: base64.b85encode(o).decode('ascii')), ] DECODERS = { 'arrow': arrow.get, 'datetime': dateutil.parser.parse, 'timedelta': lambda d: datetime.timedelta(**d), 'uuid': UUID, # Only for JSON: 'bytes': base64.b85decode, } -def encode_data_client(data: Any) -> bytes: +def encode_data_client(data: Any, extra_encoders=None) -> bytes: try: - return msgpack_dumps(data) + return msgpack_dumps(data, extra_encoders=extra_encoders) except OverflowError as e: raise ValueError('Limits were reached. Please, check your input.\n' + str(e)) -def decode_response(response: Response) -> Any: +def decode_response(response: Response, extra_decoders=None) -> Any: content_type = response.headers['content-type'] if content_type.startswith('application/x-msgpack'): - r = msgpack_loads(response.content) + r = msgpack_loads(response.content, + extra_decoders=extra_decoders) elif content_type.startswith('application/json'): - r = json.loads(response.text, cls=SWHJSONDecoder) + r = json.loads(response.text, cls=SWHJSONDecoder, + extra_decoders=extra_decoders) elif content_type.startswith('text/'): r = response.text else: raise ValueError('Wrong content type `%s` for API response' % content_type) return r class SWHJSONEncoder(json.JSONEncoder): """JSON encoder for data structures generated by Software Heritage. This JSON encoder extends the default Python JSON encoder and adds awareness for the following specific types: - bytes (get encoded as a Base85 string); - datetime.datetime (get encoded as an ISO8601 string). Non-standard types get encoded as a a dictionary with two keys: - swhtype with value 'bytes' or 'datetime'; - d containing the encoded value. SWHJSONEncoder also encodes arbitrary iterables as a list (allowing serialization of generators). Caveats: Limitations in the JSONEncoder extension mechanism prevent us from "escaping" dictionaries that only contain the swhtype and d keys, and therefore arbitrary data structures can't be round-tripped through SWHJSONEncoder and SWHJSONDecoder. """ + def __init__(self, extra_encoders=None, **kwargs): + super().__init__(**kwargs) + self.encoders = ENCODERS + if extra_encoders: + self.encoders += extra_encoders + def default(self, o: Any ) -> Union[Dict[str, Union[Dict[str, int], str]], list]: - for (type_, type_name, encoder) in ENCODERS: + for (type_, type_name, encoder) in self.encoders: if isinstance(o, type_): return { 'swhtype': type_name, 'd': encoder(o), } try: return super().default(o) except TypeError as e: try: iterable = iter(o) except TypeError: raise e from None else: return list(iterable) class SWHJSONDecoder(json.JSONDecoder): """JSON decoder for data structures encoded with SWHJSONEncoder. This JSON decoder extends the default Python JSON decoder, allowing the decoding of: - bytes (encoded as a Base85 string); - datetime.datetime (encoded as an ISO8601 string). Non-standard types must be encoded as a a dictionary with exactly two keys: - swhtype with value 'bytes' or 'datetime'; - d containing the encoded value. To limit the impact our encoding, if the swhtype key doesn't contain a known value, the dictionary is decoded as-is. """ + def __init__(self, extra_decoders=None, **kwargs): + super().__init__(**kwargs) + self.decoders = DECODERS + if extra_decoders: + self.decoders = {**self.decoders, **extra_decoders} + def decode_data(self, o: Any) -> Any: if isinstance(o, dict): if set(o.keys()) == {'d', 'swhtype'}: if o['swhtype'] == 'bytes': return base64.b85decode(o['d']) - decoder = DECODERS.get(o['swhtype']) + decoder = self.decoders.get(o['swhtype']) if decoder: return decoder(self.decode_data(o['d'])) return {key: self.decode_data(value) for key, value in o.items()} if isinstance(o, list): return [self.decode_data(value) for value in o] else: return o def raw_decode(self, s: str, idx: int = 0) -> Tuple[Any, int]: data, index = super().raw_decode(s, idx) return self.decode_data(data), index -def msgpack_dumps(data: Any) -> bytes: +def msgpack_dumps(data: Any, extra_encoders=None) -> bytes: """Write data as a msgpack stream""" + encoders = ENCODERS + if extra_encoders: + encoders += extra_encoders + def encode_types(obj): if isinstance(obj, types.GeneratorType): return list(obj) - for (type_, type_name, encoder) in ENCODERS: + for (type_, type_name, encoder) in encoders: if isinstance(obj, type_): return { b'swhtype': type_name, b'd': encoder(obj), } return obj return msgpack.packb(data, use_bin_type=True, default=encode_types) -def msgpack_loads(data: bytes) -> Any: +def msgpack_loads(data: bytes, extra_decoders=None) -> Any: """Read data as a msgpack stream""" + decoders = DECODERS + if extra_decoders: + decoders = {**decoders, **extra_decoders} + def decode_types(obj): if set(obj.keys()) == {b'd', b'swhtype'}: - decoder = DECODERS.get(obj[b'swhtype']) + decoder = decoders.get(obj[b'swhtype']) if decoder: return decoder(obj[b'd']) return obj try: return msgpack.unpackb(data, raw=False, object_hook=decode_types) except TypeError: # msgpack < 0.5.2 return msgpack.unpackb(data, encoding='utf-8', object_hook=decode_types) diff --git a/swh/core/api/tests/test_serializers.py b/swh/core/api/tests/test_serializers.py index 64f13e2..373518b 100644 --- a/swh/core/api/tests/test_serializers.py +++ b/swh/core/api/tests/test_serializers.py @@ -1,92 +1,131 @@ # Copyright (C) 2015-2018 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import datetime import json import unittest from uuid import UUID import arrow import requests import requests_mock from swh.core.api.serializers import ( SWHJSONDecoder, SWHJSONEncoder, msgpack_dumps, msgpack_loads, decode_response ) +class ExtraType: + def __init__(self, arg1, arg2): + self.arg1 = arg1 + self.arg2 = arg2 + + def __repr__(self): + return f'ExtraType({self.arg1}, {self.arg2})' + + def __eq__(self, other): + return (self.arg1, self.arg2) == (other.arg1, other.arg2) + + +extra_encoders = [ + (ExtraType, 'extratype', lambda o: (o.arg1, o.arg2)) +] + + +extra_decoders = { + 'extratype': lambda o: ExtraType(*o), +} + + class Serializers(unittest.TestCase): def setUp(self): self.tz = datetime.timezone(datetime.timedelta(minutes=118)) self.data = { 'bytes': b'123456789\x99\xaf\xff\x00\x12', 'datetime_naive': datetime.datetime(2015, 1, 1, 12, 4, 42, 231455), 'datetime_tz': datetime.datetime(2015, 3, 4, 18, 25, 13, 1234, tzinfo=self.tz), 'datetime_utc': datetime.datetime(2015, 3, 4, 18, 25, 13, 1234, tzinfo=datetime.timezone.utc), 'datetime_delta': datetime.timedelta(64), 'arrow_date': arrow.get('2018-04-25T16:17:53.533672+00:00'), 'swhtype': 'fake', 'swh_dict': {'swhtype': 42, 'd': 'test'}, 'random_dict': {'swhtype': 43}, 'uuid': UUID('cdd8f804-9db6-40c3-93ab-5955d3836234'), } self.encoded_data = { 'bytes': {'swhtype': 'bytes', 'd': 'F)}kWH8wXmIhn8j01^'}, 'datetime_naive': {'swhtype': 'datetime', 'd': '2015-01-01T12:04:42.231455'}, 'datetime_tz': {'swhtype': 'datetime', 'd': '2015-03-04T18:25:13.001234+01:58'}, 'datetime_utc': {'swhtype': 'datetime', 'd': '2015-03-04T18:25:13.001234+00:00'}, 'datetime_delta': {'swhtype': 'timedelta', 'd': {'days': 64, 'seconds': 0, 'microseconds': 0}}, 'arrow_date': {'swhtype': 'arrow', 'd': '2018-04-25T16:17:53.533672+00:00'}, 'swhtype': 'fake', 'swh_dict': {'swhtype': 42, 'd': 'test'}, 'random_dict': {'swhtype': 43}, 'uuid': {'swhtype': 'uuid', 'd': 'cdd8f804-9db6-40c3-93ab-5955d3836234'}, } self.generator = (i for i in range(5)) self.gen_lst = list(range(5)) def test_round_trip_json(self): data = json.dumps(self.data, cls=SWHJSONEncoder) self.assertEqual(self.data, json.loads(data, cls=SWHJSONDecoder)) + def test_round_trip_json_extra_types(self): + original_data = [ExtraType('baz', self.data), 'qux'] + + data = json.dumps(original_data, cls=SWHJSONEncoder, + extra_encoders=extra_encoders) + self.assertEqual( + original_data, + json.loads( + data, cls=SWHJSONDecoder, extra_decoders=extra_decoders)) + def test_encode_swh_json(self): data = json.dumps(self.data, cls=SWHJSONEncoder) self.assertEqual(self.encoded_data, json.loads(data)) def test_round_trip_msgpack(self): data = msgpack_dumps(self.data) self.assertEqual(self.data, msgpack_loads(data)) + def test_round_trip_msgpack_extra_types(self): + original_data = [ExtraType('baz', self.data), 'qux'] + + data = msgpack_dumps(original_data, extra_encoders=extra_encoders) + self.assertEqual( + original_data, msgpack_loads(data, extra_decoders=extra_decoders)) + def test_generator_json(self): data = json.dumps(self.generator, cls=SWHJSONEncoder) self.assertEqual(self.gen_lst, json.loads(data, cls=SWHJSONDecoder)) def test_generator_msgpack(self): data = msgpack_dumps(self.generator) self.assertEqual(self.gen_lst, msgpack_loads(data)) @requests_mock.Mocker() def test_decode_response_json(self, mock_requests): mock_requests.get('https://example.org/test/data', json=self.encoded_data, headers={'content-type': 'application/json'}) response = requests.get('https://example.org/test/data') assert decode_response(response) == self.data