Changeset View
Changeset View
Standalone View
Standalone View
swh/core/api/serializers.py
Show All 11 Lines | |||||
import arrow | import arrow | ||||
import dateutil.parser | import dateutil.parser | ||||
import msgpack | import msgpack | ||||
from typing import Any, Dict, Union, Tuple | from typing import Any, Dict, Union, Tuple | ||||
from requests import Response | from requests import Response | ||||
def _timedelta_components(delta: datetime.timedelta) -> Dict[str, int]:
    """Split a timedelta into the day/second/microsecond dict we serialize."""
    return {
        'days': delta.days,
        'seconds': delta.seconds,
        'microseconds': delta.microseconds,
    }


# (type, tag, encoder) triples, tried in order by both the JSON and the
# msgpack serializers; the first isinstance() match wins.
ENCODERS = [
    (arrow.Arrow, 'arrow', arrow.Arrow.isoformat),
    (datetime.datetime, 'datetime', datetime.datetime.isoformat),
    (datetime.timedelta, 'timedelta', _timedelta_components),
    (UUID, 'uuid', str),
    # Only for JSON:
    (bytes, 'bytes', lambda o: base64.b85encode(o).decode('ascii')),
]
# Maps a 'swhtype' tag back to the callable rebuilding the original object;
# the inverse of ENCODERS, shared by the JSON and msgpack deserializers.
DECODERS = {
    'arrow': arrow.get,
    # isoparse is strict: it only accepts ISO-8601 strings, which is exactly
    # what the encoder side emits via isoformat(), instead of dateutil's
    # permissive heuristics that happily parse garbage.
    'datetime': dateutil.parser.isoparse,
    'timedelta': lambda d: datetime.timedelta(**d),
    'uuid': UUID,
    # Only for JSON:
    'bytes': base64.b85decode,
}
def encode_data_client(data: Any) -> bytes:
    """Serialize ``data`` as msgpack bytes for transmission to a server.

    Raises:
        ValueError: when msgpack's size limits are hit; the original
            OverflowError is chained as the cause.
    """
    try:
        return msgpack_dumps(data)
    except OverflowError as e:
        # Chain the cause so tracebacks still show the underlying overflow.
        raise ValueError('Limits were reached. Please, check your input.\n' +
                         str(e)) from e
Show All 34 Lines | class SWHJSONEncoder(json.JSONEncoder): | ||||
prevent us from "escaping" dictionaries that only contain the | prevent us from "escaping" dictionaries that only contain the | ||||
swhtype and d keys, and therefore arbitrary data structures can't | swhtype and d keys, and therefore arbitrary data structures can't | ||||
be round-tripped through SWHJSONEncoder and SWHJSONDecoder. | be round-tripped through SWHJSONEncoder and SWHJSONDecoder. | ||||
""" | """ | ||||
def default(self, o: Any | def default(self, o: Any | ||||
) -> Union[Dict[str, Union[Dict[str, int], str]], list]: | ) -> Union[Dict[str, Union[Dict[str, int], str]], list]: | ||||
if isinstance(o, bytes): | for (type_, type_name, encoder) in ENCODERS: | ||||
return { | if isinstance(o, type_): | ||||
'swhtype': 'bytes', | |||||
'd': base64.b85encode(o).decode('ascii'), | |||||
} | |||||
elif isinstance(o, datetime.datetime): | |||||
return { | return { | ||||
'swhtype': 'datetime', | 'swhtype': type_name, | ||||
'd': o.isoformat(), | 'd': encoder(o), | ||||
} | |||||
elif isinstance(o, UUID): | |||||
return { | |||||
'swhtype': 'uuid', | |||||
'd': str(o), | |||||
} | |||||
elif isinstance(o, datetime.timedelta): | |||||
return { | |||||
'swhtype': 'timedelta', | |||||
'd': { | |||||
'days': o.days, | |||||
'seconds': o.seconds, | |||||
'microseconds': o.microseconds, | |||||
}, | |||||
} | |||||
elif isinstance(o, arrow.Arrow): | |||||
return { | |||||
'swhtype': 'arrow', | |||||
'd': o.isoformat(), | |||||
} | } | ||||
try: | try: | ||||
return super().default(o) | return super().default(o) | ||||
except TypeError as e: | except TypeError as e: | ||||
try: | try: | ||||
iterable = iter(o) | iterable = iter(o) | ||||
except TypeError: | except TypeError: | ||||
raise e from None | raise e from None | ||||
else: | else: | ||||
Show All 18 Lines | class SWHJSONDecoder(json.JSONDecoder): | ||||
To limit the impact our encoding, if the swhtype key doesn't | To limit the impact our encoding, if the swhtype key doesn't | ||||
contain a known value, the dictionary is decoded as-is. | contain a known value, the dictionary is decoded as-is. | ||||
""" | """ | ||||
def decode_data(self, o: Any) -> Any: | def decode_data(self, o: Any) -> Any: | ||||
if isinstance(o, dict): | if isinstance(o, dict): | ||||
if set(o.keys()) == {'d', 'swhtype'}: | if set(o.keys()) == {'d', 'swhtype'}: | ||||
datatype = o['swhtype'] | if o['swhtype'] == 'bytes': | ||||
if datatype == 'bytes': | |||||
return base64.b85decode(o['d']) | return base64.b85decode(o['d']) | ||||
elif datatype == 'datetime': | decoder = DECODERS.get(o['swhtype']) | ||||
return dateutil.parser.parse(o['d']) | if decoder: | ||||
elif datatype == 'uuid': | return decoder(self.decode_data(o['d'])) | ||||
return UUID(o['d']) | |||||
elif datatype == 'timedelta': | |||||
return datetime.timedelta(**o['d']) | |||||
elif datatype == 'arrow': | |||||
return arrow.get(o['d']) | |||||
return {key: self.decode_data(value) for key, value in o.items()} | return {key: self.decode_data(value) for key, value in o.items()} | ||||
if isinstance(o, list): | if isinstance(o, list): | ||||
return [self.decode_data(value) for value in o] | return [self.decode_data(value) for value in o] | ||||
else: | else: | ||||
return o | return o | ||||
def raw_decode(self, s: str, idx: int = 0) -> Tuple[Any, int]: | def raw_decode(self, s: str, idx: int = 0) -> Tuple[Any, int]: | ||||
data, index = super().raw_decode(s, idx) | data, index = super().raw_decode(s, idx) | ||||
return self.decode_data(data), index | return self.decode_data(data), index | ||||
def msgpack_dumps(data: Any) -> bytes:
    """Serialize ``data`` as a msgpack byte stream.

    Generators are materialized into lists; each type registered in
    ENCODERS is wrapped as a ``{b'swhtype': tag, b'd': payload}`` map
    that msgpack_loads knows how to undo.
    """
    def encode_types(obj):
        # msgpack only invokes this hook for objects it cannot serialize
        # natively.
        if isinstance(obj, types.GeneratorType):
            return list(obj)
        for cls, tag, encode in ENCODERS:
            if isinstance(obj, cls):
                return {
                    b'swhtype': tag,
                    b'd': encode(obj),
                }
        return obj

    return msgpack.packb(data, use_bin_type=True, default=encode_types)
def msgpack_loads(data: bytes) -> Any:
    """Deserialize a msgpack byte stream produced by msgpack_dumps."""
    def decode_types(obj):
        # Maps whose keys are exactly {b'swhtype', b'd'} carry a wrapped
        # value; an unknown tag leaves the map untouched.
        if set(obj.keys()) == {b'd', b'swhtype'}:
            decode = DECODERS.get(obj[b'swhtype'])
            if decode:
                return decode(obj[b'd'])
        return obj

    try:
        return msgpack.unpackb(data, raw=False,
                               object_hook=decode_types)
    except TypeError:  # msgpack < 0.5.2
        return msgpack.unpackb(data, encoding='utf-8',
                               object_hook=decode_types)
I guess now that we're making a protocol break, it'd be a good time to switch over to iso8601, or at least dateutil.parser.isoparse, instead of letting ourselves parse garbage.