serializers: accept extra encoders/decoders without touching the globals

encode_data_client(), decode_response(), SWHJSONEncoder, SWHJSONDecoder,
msgpack_dumps() and msgpack_loads() take an optional extra_encoders /
extra_decoders argument.  Extra encoders are tried after the built-in
ENCODERS, so an object that a built-in encoder already matches keeps its
built-in encoding; extra decoders are merged over DECODERS by type name.

The module-level ENCODERS and DECODERS are never modified: extras are
combined into a new list/dict scoped to the encoder/decoder instance or
to the single call, so passing them once does not register them for
later callers (see test_extra_encoders_are_not_registered_globally).

diff --git a/swh/core/api/serializers.py b/swh/core/api/serializers.py
--- a/swh/core/api/serializers.py
+++ b/swh/core/api/serializers.py
@@ -42,21 +42,23 @@
 }
 
 
-def encode_data_client(data: Any) -> bytes:
+def encode_data_client(data: Any, extra_encoders=None) -> bytes:
     try:
-        return msgpack_dumps(data)
+        return msgpack_dumps(data, extra_encoders=extra_encoders)
     except OverflowError as e:
         raise ValueError('Limits were reached. Please, check your input.\n' +
                          str(e))
 
 
-def decode_response(response: Response) -> Any:
+def decode_response(response: Response, extra_decoders=None) -> Any:
     content_type = response.headers['content-type']
 
     if content_type.startswith('application/x-msgpack'):
-        r = msgpack_loads(response.content)
+        r = msgpack_loads(response.content,
+                          extra_decoders=extra_decoders)
     elif content_type.startswith('application/json'):
-        r = json.loads(response.text, cls=SWHJSONDecoder)
+        r = json.loads(response.text, cls=SWHJSONDecoder,
+                       extra_decoders=extra_decoders)
     elif content_type.startswith('text/'):
         r = response.text
     else:
@@ -90,9 +92,16 @@
 
     """
 
+    def __init__(self, extra_encoders=None, **kwargs):
+        super().__init__(**kwargs)
+        self.encoders = ENCODERS
+        if extra_encoders:
+            # New list: `+=` here would mutate the global ENCODERS.
+            self.encoders = [*self.encoders, *extra_encoders]
+
     def default(self, o: Any
                 ) -> Union[Dict[str, Union[Dict[str, int], str]], list]:
-        for (type_, type_name, encoder) in ENCODERS:
+        for (type_, type_name, encoder) in self.encoders:
             if isinstance(o, type_):
                 return {
                     'swhtype': type_name,
@@ -129,12 +138,18 @@
 
     """
 
+    def __init__(self, extra_decoders=None, **kwargs):
+        super().__init__(**kwargs)
+        self.decoders = DECODERS
+        if extra_decoders:
+            self.decoders = {**self.decoders, **extra_decoders}
+
     def decode_data(self, o: Any) -> Any:
         if isinstance(o, dict):
             if set(o.keys()) == {'d', 'swhtype'}:
                 if o['swhtype'] == 'bytes':
                     return base64.b85decode(o['d'])
-                decoder = DECODERS.get(o['swhtype'])
+                decoder = self.decoders.get(o['swhtype'])
                 if decoder:
                     return decoder(self.decode_data(o['d']))
             return {key: self.decode_data(value) for key, value in o.items()}
@@ -148,13 +163,18 @@
         return self.decode_data(data), index
 
 
-def msgpack_dumps(data: Any) -> bytes:
+def msgpack_dumps(data: Any, extra_encoders=None) -> bytes:
     """Write data as a msgpack stream"""
+    encoders = ENCODERS
+    if extra_encoders:
+        # New list: `+=` here would mutate the global ENCODERS.
+        encoders = [*encoders, *extra_encoders]
+
     def encode_types(obj):
         if isinstance(obj, types.GeneratorType):
             return list(obj)
 
-        for (type_, type_name, encoder) in ENCODERS:
+        for (type_, type_name, encoder) in encoders:
             if isinstance(obj, type_):
                 return {
                     b'swhtype': type_name,
@@ -165,11 +185,15 @@
     return msgpack.packb(data, use_bin_type=True, default=encode_types)
 
 
-def msgpack_loads(data: bytes) -> Any:
+def msgpack_loads(data: bytes, extra_decoders=None) -> Any:
     """Read data as a msgpack stream"""
+    decoders = DECODERS
+    if extra_decoders:
+        decoders = {**decoders, **extra_decoders}
+
     def decode_types(obj):
         if set(obj.keys()) == {b'd', b'swhtype'}:
-            decoder = DECODERS.get(obj[b'swhtype'])
+            decoder = decoders.get(obj[b'swhtype'])
             if decoder:
                 return decoder(obj[b'd'])
         return obj
diff --git a/swh/core/api/tests/test_serializers.py b/swh/core/api/tests/test_serializers.py
--- a/swh/core/api/tests/test_serializers.py
+++ b/swh/core/api/tests/test_serializers.py
@@ -21,6 +21,28 @@
 )
 
 
+class ExtraType:
+    def __init__(self, arg1, arg2):
+        self.arg1 = arg1
+        self.arg2 = arg2
+
+    def __repr__(self):
+        return f'ExtraType({self.arg1}, {self.arg2})'
+
+    def __eq__(self, other):
+        return (self.arg1, self.arg2) == (other.arg1, other.arg2)
+
+
+extra_encoders = [
+    (ExtraType, 'extratype', lambda o: (o.arg1, o.arg2))
+]
+
+
+extra_decoders = {
+    'extratype': lambda o: ExtraType(*o),
+}
+
+
 class Serializers(unittest.TestCase):
     def setUp(self):
         self.tz = datetime.timezone(datetime.timedelta(minutes=118))
@@ -67,6 +89,16 @@
         data = json.dumps(self.data, cls=SWHJSONEncoder)
         self.assertEqual(self.data, json.loads(data, cls=SWHJSONDecoder))
 
+    def test_round_trip_json_extra_types(self):
+        original_data = [ExtraType('baz', self.data), 'qux']
+
+        data = json.dumps(original_data, cls=SWHJSONEncoder,
+                          extra_encoders=extra_encoders)
+        self.assertEqual(
+            original_data,
+            json.loads(
+                data, cls=SWHJSONDecoder, extra_decoders=extra_decoders))
+
     def test_encode_swh_json(self):
         data = json.dumps(self.data, cls=SWHJSONEncoder)
         self.assertEqual(self.encoded_data, json.loads(data))
@@ -75,6 +107,23 @@
         data = msgpack_dumps(self.data)
         self.assertEqual(self.data, msgpack_loads(data))
 
+    def test_round_trip_msgpack_extra_types(self):
+        original_data = [ExtraType('baz', self.data), 'qux']
+
+        data = msgpack_dumps(original_data, extra_encoders=extra_encoders)
+        self.assertEqual(
+            original_data, msgpack_loads(data, extra_decoders=extra_decoders))
+
+    def test_extra_encoders_are_not_registered_globally(self):
+        # Passing extra_encoders once must not leak into later plain calls.
+        obj = ExtraType('baz', 'qux')
+        msgpack_dumps(obj, extra_encoders=extra_encoders)
+        json.dumps(obj, cls=SWHJSONEncoder, extra_encoders=extra_encoders)
+        with self.assertRaises(TypeError):
+            msgpack_dumps(obj)
+        with self.assertRaises(TypeError):
+            json.dumps(obj, cls=SWHJSONEncoder)
+
     def test_generator_json(self):
         data = json.dumps(self.generator, cls=SWHJSONEncoder)
         self.assertEqual(self.gen_lst, json.loads(data, cls=SWHJSONDecoder))