diff --git a/swh/core/api/serializers.py b/swh/core/api/serializers.py index 1b33c59..d67dbe5 100644 --- a/swh/core/api/serializers.py +++ b/swh/core/api/serializers.py @@ -1,191 +1,191 @@ # Copyright (C) 2015-2018 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import base64 import datetime -from json import JSONDecoder, JSONEncoder +import json import types from uuid import UUID import arrow import dateutil.parser import msgpack def encode_data_client(data): try: return msgpack_dumps(data) except OverflowError as e: raise ValueError('Limits were reached. Please, check your input.\n' + str(e)) def decode_response(response): content_type = response.headers['content-type'] if content_type.startswith('application/x-msgpack'): r = msgpack_loads(response.content) elif content_type.startswith('application/json'): - r = response.json(cls=SWHJSONDecoder) + r = json.loads(response.text, cls=SWHJSONDecoder) else: raise ValueError('Wrong content type `%s` for API response' % content_type) return r -class SWHJSONEncoder(JSONEncoder): +class SWHJSONEncoder(json.JSONEncoder): """JSON encoder for data structures generated by Software Heritage. This JSON encoder extends the default Python JSON encoder and adds awareness for the following specific types: - bytes (get encoded as a Base85 string); - datetime.datetime (get encoded as an ISO8601 string). Non-standard types get encoded as a a dictionary with two keys: - swhtype with value 'bytes' or 'datetime'; - d containing the encoded value. SWHJSONEncoder also encodes arbitrary iterables as a list (allowing serialization of generators). Caveats: Limitations in the JSONEncoder extension mechanism prevent us from "escaping" dictionaries that only contain the swhtype and d keys, and therefore arbitrary data structures can't be round-tripped through SWHJSONEncoder and SWHJSONDecoder. """ def default(self, o): if isinstance(o, bytes): return { 'swhtype': 'bytes', 'd': base64.b85encode(o).decode('ascii'), } elif isinstance(o, datetime.datetime): return { 'swhtype': 'datetime', 'd': o.isoformat(), } elif isinstance(o, UUID): return { 'swhtype': 'uuid', 'd': str(o), } elif isinstance(o, datetime.timedelta): return { 'swhtype': 'timedelta', 'd': { 'days': o.days, 'seconds': o.seconds, 'microseconds': o.microseconds, }, } elif isinstance(o, arrow.Arrow): return { 'swhtype': 'arrow', 'd': o.isoformat(), } try: return super().default(o) except TypeError as e: try: iterable = iter(o) except TypeError: raise e from None else: return list(iterable) -class SWHJSONDecoder(JSONDecoder): +class SWHJSONDecoder(json.JSONDecoder): """JSON decoder for data structures encoded with SWHJSONEncoder. This JSON decoder extends the default Python JSON decoder, allowing the decoding of: - bytes (encoded as a Base85 string); - datetime.datetime (encoded as an ISO8601 string). Non-standard types must be encoded as a a dictionary with exactly two keys: - swhtype with value 'bytes' or 'datetime'; - d containing the encoded value. To limit the impact our encoding, if the swhtype key doesn't contain a known value, the dictionary is decoded as-is. """ def decode_data(self, o): if isinstance(o, dict): if set(o.keys()) == {'d', 'swhtype'}: datatype = o['swhtype'] if datatype == 'bytes': return base64.b85decode(o['d']) elif datatype == 'datetime': return dateutil.parser.parse(o['d']) elif datatype == 'uuid': return UUID(o['d']) elif datatype == 'timedelta': return datetime.timedelta(**o['d']) elif datatype == 'arrow': return arrow.get(o['d']) return {key: self.decode_data(value) for key, value in o.items()} if isinstance(o, list): return [self.decode_data(value) for value in o] else: return o def raw_decode(self, s, idx=0): data, index = super().raw_decode(s, idx) return self.decode_data(data), index def msgpack_dumps(data): """Write data as a msgpack stream""" def encode_types(obj): if isinstance(obj, datetime.datetime): return {b'__datetime__': True, b's': obj.isoformat()} if isinstance(obj, types.GeneratorType): return list(obj) if isinstance(obj, UUID): return {b'__uuid__': True, b's': str(obj)} if isinstance(obj, datetime.timedelta): return { b'__timedelta__': True, b's': { 'days': obj.days, 'seconds': obj.seconds, 'microseconds': obj.microseconds, }, } if isinstance(obj, arrow.Arrow): return {b'__arrow__': True, b's': obj.isoformat()} return obj return msgpack.packb(data, use_bin_type=True, default=encode_types) def msgpack_loads(data): """Read data as a msgpack stream""" def decode_types(obj): if b'__datetime__' in obj and obj[b'__datetime__']: return dateutil.parser.parse(obj[b's']) if b'__uuid__' in obj and obj[b'__uuid__']: return UUID(obj[b's']) if b'__timedelta__' in obj and obj[b'__timedelta__']: return datetime.timedelta(**obj[b's']) if b'__arrow__' in obj and obj[b'__arrow__']: return arrow.get(obj[b's']) return obj try: return msgpack.unpackb(data, raw=False, object_hook=decode_types) except TypeError: # msgpack < 0.5.2 return msgpack.unpackb(data, encoding='utf-8', object_hook=decode_types) diff --git a/swh/core/api/tests/test_serializers.py b/swh/core/api/tests/test_serializers.py index 40aec9a..64f13e2 100644 --- a/swh/core/api/tests/test_serializers.py +++ b/swh/core/api/tests/test_serializers.py @@ -1,81 +1,92 @@ # Copyright (C) 2015-2018 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import datetime import json import unittest from uuid import UUID import arrow +import requests +import requests_mock from swh.core.api.serializers import ( SWHJSONDecoder, SWHJSONEncoder, msgpack_dumps, - msgpack_loads + msgpack_loads, + decode_response ) class Serializers(unittest.TestCase): def setUp(self): self.tz = datetime.timezone(datetime.timedelta(minutes=118)) self.data = { 'bytes': b'123456789\x99\xaf\xff\x00\x12', 'datetime_naive': datetime.datetime(2015, 1, 1, 12, 4, 42, 231455), 'datetime_tz': datetime.datetime(2015, 3, 4, 18, 25, 13, 1234, tzinfo=self.tz), 'datetime_utc': datetime.datetime(2015, 3, 4, 18, 25, 13, 1234, tzinfo=datetime.timezone.utc), 'datetime_delta': datetime.timedelta(64), 'arrow_date': arrow.get('2018-04-25T16:17:53.533672+00:00'), 'swhtype': 'fake', 'swh_dict': {'swhtype': 42, 'd': 'test'}, 'random_dict': {'swhtype': 43}, 'uuid': UUID('cdd8f804-9db6-40c3-93ab-5955d3836234'), } self.encoded_data = { 'bytes': {'swhtype': 'bytes', 'd': 'F)}kWH8wXmIhn8j01^'}, 'datetime_naive': {'swhtype': 'datetime', 'd': '2015-01-01T12:04:42.231455'}, 'datetime_tz': {'swhtype': 'datetime', 'd': '2015-03-04T18:25:13.001234+01:58'}, 'datetime_utc': {'swhtype': 'datetime', 'd': '2015-03-04T18:25:13.001234+00:00'}, 'datetime_delta': {'swhtype': 'timedelta', 'd': {'days': 64, 'seconds': 0, 'microseconds': 0}}, 'arrow_date': {'swhtype': 'arrow', 'd': '2018-04-25T16:17:53.533672+00:00'}, 'swhtype': 'fake', 'swh_dict': {'swhtype': 42, 'd': 'test'}, 'random_dict': {'swhtype': 43}, 'uuid': {'swhtype': 'uuid', 'd': 'cdd8f804-9db6-40c3-93ab-5955d3836234'}, } self.generator = (i for i in range(5)) self.gen_lst = list(range(5)) def test_round_trip_json(self): data = json.dumps(self.data, cls=SWHJSONEncoder) self.assertEqual(self.data, json.loads(data, cls=SWHJSONDecoder)) def test_encode_swh_json(self): data = json.dumps(self.data, cls=SWHJSONEncoder) self.assertEqual(self.encoded_data, json.loads(data)) def test_round_trip_msgpack(self): data = msgpack_dumps(self.data) self.assertEqual(self.data, msgpack_loads(data)) def test_generator_json(self): data = json.dumps(self.generator, cls=SWHJSONEncoder) self.assertEqual(self.gen_lst, json.loads(data, cls=SWHJSONDecoder)) def test_generator_msgpack(self): data = msgpack_dumps(self.generator) self.assertEqual(self.gen_lst, msgpack_loads(data)) + + @requests_mock.Mocker() + def test_decode_response_json(self, mock_requests): + mock_requests.get('https://example.org/test/data', + json=self.encoded_data, + headers={'content-type': 'application/json'}) + response = requests.get('https://example.org/test/data') + assert decode_response(response) == self.data