diff --git a/swh/core/api/serializers.py b/swh/core/api/serializers.py
--- a/swh/core/api/serializers.py
+++ b/swh/core/api/serializers.py
@@ -17,6 +17,36 @@
 from requests import Response
 
 
+# (type, swhtype tag, encoder) triples shared by the JSON and msgpack
+# encoders. Both scan the list in order with isinstance() and use the
+# first matching entry.
+ENCODERS = [
+    (arrow.Arrow, 'arrow', arrow.Arrow.isoformat),
+    (datetime.datetime, 'datetime', datetime.datetime.isoformat),
+    (datetime.timedelta, 'timedelta', lambda o: {
+        'days': o.days,
+        'seconds': o.seconds,
+        'microseconds': o.microseconds,
+    }),
+    (UUID, 'uuid', str),
+
+    # Only for JSON: msgpack packs bytes natively, so its default hook
+    # never receives them.
+    (bytes, 'bytes', lambda o: base64.b85encode(o).decode('ascii')),
+]
+
+# swhtype tag -> callable rebuilding the value from its encoded payload.
+DECODERS = {
+    'arrow': arrow.get,
+    'datetime': dateutil.parser.parse,
+    'timedelta': lambda d: datetime.timedelta(**d),
+    'uuid': UUID,
+
+    # Only for JSON:
+    'bytes': base64.b85decode,
+}
+
+
 def encode_data_client(data: Any) -> bytes:
     try:
         return msgpack_dumps(data)
@@ -67,35 +97,12 @@
 
     def default(self, o: Any
                 ) -> Union[Dict[str, Union[Dict[str, int], str]], list]:
-        if isinstance(o, bytes):
-            return {
-                'swhtype': 'bytes',
-                'd': base64.b85encode(o).decode('ascii'),
-            }
-        elif isinstance(o, datetime.datetime):
-            return {
-                'swhtype': 'datetime',
-                'd': o.isoformat(),
-            }
-        elif isinstance(o, UUID):
-            return {
-                'swhtype': 'uuid',
-                'd': str(o),
-            }
-        elif isinstance(o, datetime.timedelta):
-            return {
-                'swhtype': 'timedelta',
-                'd': {
-                    'days': o.days,
-                    'seconds': o.seconds,
-                    'microseconds': o.microseconds,
-                },
-            }
-        elif isinstance(o, arrow.Arrow):
-            return {
-                'swhtype': 'arrow',
-                'd': o.isoformat(),
-            }
+        for (type_, type_name, encoder) in ENCODERS:
+            if isinstance(o, type_):
+                return {
+                    'swhtype': type_name,
+                    'd': encoder(o),
+                }
         try:
             return super().default(o)
         except TypeError as e:
@@ -130,17 +137,11 @@
     def decode_data(self, o: Any) -> Any:
         if isinstance(o, dict):
             if set(o.keys()) == {'d', 'swhtype'}:
-                datatype = o['swhtype']
-                if datatype == 'bytes':
+                if o['swhtype'] == 'bytes':
                     return base64.b85decode(o['d'])
-                elif datatype == 'datetime':
-                    return dateutil.parser.parse(o['d'])
-                elif datatype == 'uuid':
-                    return UUID(o['d'])
-                elif datatype == 'timedelta':
-                    return datetime.timedelta(**o['d'])
-                elif datatype == 'arrow':
-                    return arrow.get(o['d'])
+                decoder = DECODERS.get(o['swhtype'])
+                if decoder is not None:
+                    return decoder(self.decode_data(o['d']))
             return {key: self.decode_data(value) for key, value in o.items()}
         if isinstance(o, list):
             return [self.decode_data(value) for value in o]
@@ -155,23 +156,17 @@
 def msgpack_dumps(data: Any) -> bytes:
     """Write data as a msgpack stream"""
     def encode_types(obj):
-        if isinstance(obj, datetime.datetime):
-            return {b'__datetime__': True, b's': obj.isoformat()}
         if isinstance(obj, types.GeneratorType):
             return list(obj)
-        if isinstance(obj, UUID):
-            return {b'__uuid__': True, b's': str(obj)}
-        if isinstance(obj, datetime.timedelta):
-            return {
-                b'__timedelta__': True,
-                b's': {
-                    'days': obj.days,
-                    'seconds': obj.seconds,
-                    'microseconds': obj.microseconds,
-                },
-            }
-        if isinstance(obj, arrow.Arrow):
-            return {b'__arrow__': True, b's': obj.isoformat()}
+
+        # Readers predating this format only understand the legacy
+        # b'__<type>__' markers, so upgrade readers before writers.
+        for (type_, type_name, encoder) in ENCODERS:
+            if isinstance(obj, type_):
+                return {
+                    b'swhtype': type_name,
+                    b'd': encoder(obj),
+                }
         return obj
 
     return msgpack.packb(data, use_bin_type=True, default=encode_types)
@@ -180,14 +175,22 @@
 def msgpack_loads(data: bytes) -> Any:
     """Read data as a msgpack stream"""
     def decode_types(obj):
+        # Current encoding, produced by msgpack_dumps from ENCODERS.
+        if set(obj.keys()) == {b'd', b'swhtype'}:
+            decoder = DECODERS.get(obj[b'swhtype'])
+            if decoder is not None:
+                return decoder(obj[b'd'])
+
+        # Legacy encoding, still accepted so that payloads written before
+        # the ENCODERS/DECODERS tables existed remain readable.
         if b'__datetime__' in obj and obj[b'__datetime__']:
             return dateutil.parser.parse(obj[b's'])
         if b'__uuid__' in obj and obj[b'__uuid__']:
             return UUID(obj[b's'])
         if b'__timedelta__' in obj and obj[b'__timedelta__']:
             return datetime.timedelta(**obj[b's'])
         if b'__arrow__' in obj and obj[b'__arrow__']:
             return arrow.get(obj[b's'])
         return obj
 
     try: