diff --git a/requirements-http.txt b/requirements-http.txt --- a/requirements-http.txt +++ b/requirements-http.txt @@ -4,6 +4,6 @@ decorator Flask iso8601 -msgpack > 0.5 +msgpack >= 1.0.0 requests blinker # dependency of sentry-sdk[flask] diff --git a/swh/core/api/serializers.py b/swh/core/api/serializers.py --- a/swh/core/api/serializers.py +++ b/swh/core/api/serializers.py @@ -22,7 +22,7 @@ def encode_datetime(dt: datetime.datetime) -> str: """Wrapper of datetime.datetime.isoformat() that forbids naive datetimes.""" if dt.tzinfo is None: - raise ValueError(f"{dt} is a naive datetime.") + raise TypeError("can not serialize naive 'datetime.datetime' object") return dt.isoformat() @@ -65,40 +65,51 @@ ENCODERS: List[Tuple[type, str, Callable]] = [ - (datetime.datetime, "datetime", encode_datetime), (UUID, "uuid", str), (datetime.timedelta, "timedelta", encode_timedelta), (PagedResult, "paged_result", _encode_paged_result), - # Only for JSON: - (bytes, "bytes", lambda o: base64.b85encode(o).decode("ascii")), (Exception, "exception", exception_to_dict), ] +JSON_ENCODERS: List[Tuple[type, str, Callable]] = [ + (datetime.datetime, "datetime", encode_datetime), + (bytes, "bytes", lambda o: base64.b85encode(o).decode("ascii")), +] + DECODERS: Dict[str, Callable] = { - "datetime": lambda d: iso8601.parse_date(d, default_timezone=None), "timedelta": lambda d: datetime.timedelta(**d), "uuid": UUID, "paged_result": _decode_paged_result, - # Only for JSON: - "bytes": base64.b85decode, "exception": dict_to_exception, + # for backward compat, to be moved to JSON_DECODERS ASAP + "datetime": lambda d: iso8601.parse_date(d, default_timezone=None), +} + +JSON_DECODERS: Dict[str, Callable] = { + "bytes": base64.b85decode, } def get_encoders( - extra_encoders: Optional[List[Tuple[Type, str, Callable]]] + extra_encoders: Optional[List[Tuple[Type, str, Callable]]], with_json: bool = False ) -> List[Tuple[Type, str, Callable]]: - if extra_encoders is not None: - return [*ENCODERS, 
*extra_encoders] - else: - return ENCODERS - - -def get_decoders(extra_decoders: Optional[Dict[str, Callable]]) -> Dict[str, Callable]: + encoders = ENCODERS + if with_json: + encoders = [*encoders, *JSON_ENCODERS] + if extra_encoders: + encoders = [*encoders, *extra_encoders] + return encoders + + +def get_decoders( + extra_decoders: Optional[Dict[str, Callable]], with_json: bool = False +) -> Dict[str, Callable]: + decoders = DECODERS + if with_json: + decoders = {**decoders, **JSON_DECODERS} if extra_decoders is not None: - return {**DECODERS, **extra_decoders} - else: - return DECODERS + decoders = {**decoders, **extra_decoders} + return decoders class MsgpackExtTypeCodes(Enum): @@ -154,7 +165,7 @@ def __init__(self, extra_encoders=None, **kwargs): super().__init__(**kwargs) - self.encoders = get_encoders(extra_encoders) + self.encoders = get_encoders(extra_encoders, with_json=True) def default(self, o: Any) -> Union[Dict[str, Union[Dict[str, int], str]], list]: for (type_, type_name, encoder) in self.encoders: @@ -196,7 +207,7 @@ def __init__(self, extra_decoders=None, **kwargs): super().__init__(**kwargs) - self.decoders = get_decoders(extra_decoders) + self.decoders = get_decoders(extra_decoders, with_json=True) def decode_data(self, o: Any) -> Any: if isinstance(o, dict): @@ -253,7 +264,12 @@ } return obj - return msgpack.packb(data, use_bin_type=True, default=encode_types) + return msgpack.packb( + data, + use_bin_type=True, + datetime=True, # encode datetime as msgpack.Timestamp + default=encode_types, + ) def msgpack_loads(data: bytes, extra_decoders=None) -> Any: @@ -290,6 +306,7 @@ object_hook=decode_types, ext_hook=ext_hook, strict_map_key=False, + timestamp=3, # convert Timestamp into datetime objects (tz UTC) ) except TypeError: # msgpack < 0.6.0 return msgpack.unpackb( diff --git a/swh/core/api/tests/test_serializers.py b/swh/core/api/tests/test_serializers.py --- a/swh/core/api/tests/test_serializers.py +++ b/swh/core/api/tests/test_serializers.py 
@@ -8,6 +8,7 @@ from typing import Any, Callable, List, Tuple, Union from uuid import UUID +import msgpack import pytest import requests from requests.exceptions import ConnectionError @@ -15,9 +16,9 @@ from swh.core.api.classes import PagedResult from swh.core.api.serializers import ( ENCODERS, - SWHJSONDecoder, - SWHJSONEncoder, decode_response, + json_dumps, + json_loads, msgpack_dumps, msgpack_loads, ) @@ -131,33 +132,29 @@ def test_serializers_round_trip_json(): - json_data = json.dumps(DATA, cls=SWHJSONEncoder) - actual_data = json.loads(json_data, cls=SWHJSONDecoder) + json_data = json_dumps(DATA) + actual_data = json_loads(json_data) assert actual_data == DATA def test_serializers_round_trip_json_extra_types(): expected_original_data = [ExtraType("baz", DATA), "qux"] - data = json.dumps( - expected_original_data, cls=SWHJSONEncoder, extra_encoders=extra_encoders - ) - actual_data = json.loads(data, cls=SWHJSONDecoder, extra_decoders=extra_decoders) + data = json_dumps(expected_original_data, extra_encoders=extra_encoders) + actual_data = json_loads(data, extra_decoders=extra_decoders) assert actual_data == expected_original_data def test_exception_serializer_round_trip_json(): error_message = "unreachable host" - json_data = json.dumps( - {"exception": ConnectionError(error_message)}, cls=SWHJSONEncoder - ) - actual_data = json.loads(json_data, cls=SWHJSONDecoder) + json_data = json_dumps({"exception": ConnectionError(error_message)},) + actual_data = json_loads(json_data) assert "exception" in actual_data assert type(actual_data["exception"]) == ConnectionError assert str(actual_data["exception"]) == error_message def test_serializers_encode_swh_json(): - json_str = json.dumps(DATA, cls=SWHJSONEncoder) + json_str = json_dumps(DATA) actual_data = json.loads(json_str) assert actual_data == ENCODED_DATA @@ -191,8 +188,8 @@ def test_serializers_generator_json(): - data = json.dumps((i for i in range(5)), cls=SWHJSONEncoder) - assert json.loads(data, 
cls=SWHJSONDecoder) == [i for i in range(5)] + data = json_dumps((i for i in range(5))) + assert json_loads(data) == [i for i in range(5)] def test_serializers_generator_msgpack(): @@ -211,12 +208,33 @@ assert decode_response(response) == DATA -def test_serializers_encode_native_datetime(): +def test_serializers_encode_datetime_msgpack(): + dt = datetime.datetime.now(tz=datetime.timezone.utc) + encmsg = msgpack_dumps(dt) + decmsg = msgpack.loads(encmsg, timestamp=0) + assert isinstance(decmsg, msgpack.Timestamp) + assert decmsg.to_datetime() == dt + + +def test_serializers_decode_datetime_compat_msgpack(): + dt = datetime.datetime.now(tz=datetime.timezone.utc) + encmsg = msgpack_dumps({b"swhtype": "datetime", b"d": dt.isoformat()}) + decmsg = msgpack_loads(encmsg) + assert decmsg == dt + + +def test_serializers_encode_native_datetime_msgpack(): dt = datetime.datetime(2015, 1, 1, 12, 4, 42, 231455) - with pytest.raises(ValueError, match="naive datetime"): + with pytest.raises(TypeError, match="datetime"): msgpack_dumps(dt) +def test_serializers_encode_native_datetime_json(): + dt = datetime.datetime(2015, 1, 1, 12, 4, 42, 231455) + with pytest.raises(TypeError, match="datetime"): + json_dumps(dt) + + def test_serializers_decode_naive_datetime(): expected_dt = datetime.datetime(2015, 1, 1, 12, 4, 42, 231455) @@ -237,6 +255,6 @@ def test_json_extra_encoders_mutation(): - data = json.dumps({}, cls=SWHJSONEncoder, extra_encoders=extra_encoders) + data = json_dumps({}, extra_encoders=extra_encoders) assert data is not None assert ENCODERS[-1][0] != ExtraType