diff --git a/requirements-http.txt b/requirements-http.txt
index f2d16cc..e52b574 100644
--- a/requirements-http.txt
+++ b/requirements-http.txt
@@ -1,9 +1,9 @@
 # requirements for swh.core.api
 aiohttp
 aiohttp_utils >= 3.1.1
 decorator
 Flask
 iso8601
-msgpack > 0.5
+msgpack >= 1.0.0
 requests
 blinker  # dependency of sentry-sdk[flask]
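The version bump is what enables the rest of this patch. As a minimal standalone sketch (not part of the patch itself), these are the msgpack >= 1.0 features relied on below: native `Timestamp` packing via `datetime=True` and unpacking via `timestamp=3`:

```python
import datetime

import msgpack

dt = datetime.datetime(2015, 3, 4, 18, 25, 13, tzinfo=datetime.timezone.utc)

# datetime=True packs timezone-aware datetimes as the msgpack Timestamp
# extension type (ext code -1) instead of routing them through `default`.
packed = msgpack.packb(dt, datetime=True)

# timestamp=3 unpacks Timestamp values as tz-aware datetimes (UTC);
# timestamp=0, the default, yields msgpack.Timestamp objects instead.
assert msgpack.unpackb(packed, timestamp=3) == dt
assert isinstance(msgpack.unpackb(packed, timestamp=0), msgpack.Timestamp)
```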
diff --git a/swh/core/api/serializers.py b/swh/core/api/serializers.py
index 34511d0..f732656 100644
--- a/swh/core/api/serializers.py
+++ b/swh/core/api/serializers.py
@@ -1,311 +1,318 @@
 # Copyright (C) 2015-2020 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information

 import base64
 import datetime
 from enum import Enum
 import json
 import traceback
 import types
 from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
 from uuid import UUID

 import iso8601
 import msgpack
 from requests import Response

 from swh.core.api.classes import PagedResult


 def encode_datetime(dt: datetime.datetime) -> str:
     """Wrapper of datetime.datetime.isoformat() that forbids naive datetimes."""
     if dt.tzinfo is None:
-        raise ValueError(f"{dt} is a naive datetime.")
+        raise TypeError("can not serialize naive 'datetime.datetime' object")
     return dt.isoformat()


 def _encode_paged_result(obj: PagedResult) -> Dict[str, Any]:
     """Serialize PagedResult to a Dict."""
     return {
         "results": obj.results,
         "next_page_token": obj.next_page_token,
     }


 def _decode_paged_result(obj: Dict[str, Any]) -> PagedResult:
     """Deserialize Dict into PagedResult"""
     return PagedResult(results=obj["results"], next_page_token=obj["next_page_token"],)


 def exception_to_dict(exception: Exception) -> Dict[str, Any]:
     tb = traceback.format_exception(None, exception, exception.__traceback__)
     exc_type = type(exception)
     return {
         "type": exc_type.__name__,
         "module": exc_type.__module__,
         "args": exception.args,
         "message": str(exception),
         "traceback": tb,
     }


 def dict_to_exception(exc_dict: Dict[str, Any]) -> Exception:
     temp = __import__(exc_dict["module"], fromlist=[exc_dict["type"]])
     return getattr(temp, exc_dict["type"])(*exc_dict["args"])


 def encode_timedelta(td: datetime.timedelta) -> Dict[str, int]:
     return {
         "days": td.days,
         "seconds": td.seconds,
         "microseconds": td.microseconds,
     }


 ENCODERS: List[Tuple[type, str, Callable]] = [
-    (datetime.datetime, "datetime", encode_datetime),
     (UUID, "uuid", str),
     (datetime.timedelta, "timedelta", encode_timedelta),
     (PagedResult, "paged_result", _encode_paged_result),
     (Exception, "exception", exception_to_dict),
 ]


 JSON_ENCODERS: List[Tuple[type, str, Callable]] = [
+    (datetime.datetime, "datetime", encode_datetime),
     (bytes, "bytes", lambda o: base64.b85encode(o).decode("ascii")),
 ]


 DECODERS: Dict[str, Callable] = {
-    "datetime": lambda d: iso8601.parse_date(d, default_timezone=None),
     "timedelta": lambda d: datetime.timedelta(**d),
     "uuid": UUID,
     "paged_result": _decode_paged_result,
     "exception": dict_to_exception,
+    # for BW compat; to be moved to JSON_DECODERS ASAP
+    "datetime": lambda d: iso8601.parse_date(d, default_timezone=None),
 }


 JSON_DECODERS: Dict[str, Callable] = {
     "bytes": base64.b85decode,
 }


 def get_encoders(
     extra_encoders: Optional[List[Tuple[Type, str, Callable]]], with_json: bool = False
 ) -> List[Tuple[Type, str, Callable]]:
     encoders = ENCODERS
     if with_json:
         encoders = [*encoders, *JSON_ENCODERS]
     if extra_encoders:
         encoders = [*encoders, *extra_encoders]
     return encoders


 def get_decoders(
     extra_decoders: Optional[Dict[str, Callable]], with_json: bool = False
 ) -> Dict[str, Callable]:
     decoders = DECODERS
     if with_json:
         decoders = {**decoders, **JSON_DECODERS}
     if extra_decoders is not None:
         decoders = {**decoders, **extra_decoders}
     return decoders
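The registry split above is the heart of the change: `datetime` moves from the shared ENCODERS to JSON_ENCODERS (msgpack now handles datetimes natively), while the `datetime` entry deliberately stays in DECODERS so that payloads produced by older writers, such as existing journal contents, still decode. A sketch of that compatibility property, mirroring `test_serializers_decode_datetime_compat_msgpack` added further down:

```python
import datetime

from swh.core.api.serializers import msgpack_dumps, msgpack_loads

dt = datetime.datetime(2015, 3, 4, 18, 25, 13, tzinfo=datetime.timezone.utc)

# Emulate a legacy producer: older swh.core wrapped datetimes in a
# {b"swhtype": "datetime", b"d": "<iso8601>"} map on the msgpack side.
legacy = msgpack_dumps({b"swhtype": "datetime", b"d": dt.isoformat()})

# The BW-compat "datetime" entry kept in DECODERS still turns the old
# envelope back into a datetime.
assert msgpack_loads(legacy) == dt
```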
""" def __init__(self, extra_decoders=None, **kwargs): super().__init__(**kwargs) self.decoders = get_decoders(extra_decoders, with_json=True) def decode_data(self, o: Any) -> Any: if isinstance(o, dict): if set(o.keys()) == {"d", "swhtype"}: if o["swhtype"] == "bytes": return base64.b85decode(o["d"]) decoder = self.decoders.get(o["swhtype"]) if decoder: return decoder(self.decode_data(o["d"])) return {key: self.decode_data(value) for key, value in o.items()} if isinstance(o, list): return [self.decode_data(value) for value in o] else: return o def raw_decode(self, s: str, idx: int = 0) -> Tuple[Any, int]: data, index = super().raw_decode(s, idx) return self.decode_data(data), index def json_dumps(data: Any, extra_encoders=None) -> str: return json.dumps(data, cls=SWHJSONEncoder, extra_encoders=extra_encoders) def json_loads(data: str, extra_decoders=None) -> Any: return json.loads(data, cls=SWHJSONDecoder, extra_decoders=extra_decoders) def msgpack_dumps(data: Any, extra_encoders=None) -> bytes: """Write data as a msgpack stream""" encoders = get_encoders(extra_encoders) def encode_types(obj): if isinstance(obj, int): # integer overflowed while packing. Handle it as an extended type if obj > 0: code = MsgpackExtTypeCodes.LONG_INT.value else: code = MsgpackExtTypeCodes.LONG_NEG_INT.value obj = -obj length, rem = divmod(obj.bit_length(), 8) if rem: length += 1 return msgpack.ExtType(code, int.to_bytes(obj, length, "big")) if isinstance(obj, types.GeneratorType): return list(obj) for (type_, type_name, encoder) in encoders: if isinstance(obj, type_): return { b"swhtype": type_name, b"d": encoder(obj), } return obj - return msgpack.packb(data, use_bin_type=True, default=encode_types) + return msgpack.packb( + data, + use_bin_type=True, + datetime=True, # encode datetime as msgpack.Timestamp + default=encode_types, + ) def msgpack_loads(data: bytes, extra_decoders=None) -> Any: """Read data as a msgpack stream. .. Caution:: This function is used by swh.journal to decode the contents of the journal. This function **must** be kept backwards-compatible. 
""" decoders = get_decoders(extra_decoders) def ext_hook(code, data): if code == MsgpackExtTypeCodes.LONG_INT.value: return int.from_bytes(data, "big") elif code == MsgpackExtTypeCodes.LONG_NEG_INT.value: return -int.from_bytes(data, "big") raise ValueError("Unknown msgpack extended code %s" % code) def decode_types(obj): # Support for current encodings if set(obj.keys()) == {b"d", b"swhtype"}: decoder = decoders.get(obj[b"swhtype"]) if decoder: return decoder(obj[b"d"]) # Fallthrough return obj try: try: return msgpack.unpackb( data, raw=False, object_hook=decode_types, ext_hook=ext_hook, strict_map_key=False, + timestamp=3, # convert Timestamp in datetime objects (tz UTC) ) except TypeError: # msgpack < 0.6.0 return msgpack.unpackb( data, raw=False, object_hook=decode_types, ext_hook=ext_hook ) except TypeError: # msgpack < 0.5.2 return msgpack.unpackb( data, encoding="utf-8", object_hook=decode_types, ext_hook=ext_hook ) diff --git a/swh/core/api/tests/test_serializers.py b/swh/core/api/tests/test_serializers.py index 12ed922..e75e01b 100644 --- a/swh/core/api/tests/test_serializers.py +++ b/swh/core/api/tests/test_serializers.py @@ -1,238 +1,260 @@ # Copyright (C) 2015-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import datetime import json from typing import Any, Callable, List, Tuple, Union from uuid import UUID +import msgpack import pytest import requests from requests.exceptions import ConnectionError from swh.core.api.classes import PagedResult from swh.core.api.serializers import ( ENCODERS, decode_response, json_dumps, json_loads, msgpack_dumps, msgpack_loads, ) class ExtraType: def __init__(self, arg1, arg2): self.arg1 = arg1 self.arg2 = arg2 def __repr__(self): return f"ExtraType({self.arg1}, {self.arg2})" def __eq__(self, other): return isinstance(other, ExtraType) and (self.arg1, self.arg2) == ( other.arg1, other.arg2, ) extra_encoders: List[Tuple[type, str, Callable[..., Any]]] = [ (ExtraType, "extratype", lambda o: (o.arg1, o.arg2)) ] extra_decoders = { "extratype": lambda o: ExtraType(*o), } TZ = datetime.timezone(datetime.timedelta(minutes=118)) DATA_BYTES = b"123456789\x99\xaf\xff\x00\x12" ENCODED_DATA_BYTES = {"swhtype": "bytes", "d": "F)}kWH8wXmIhn8j01^"} DATA_DATETIME = datetime.datetime(2015, 3, 4, 18, 25, 13, 1234, tzinfo=TZ,) ENCODED_DATA_DATETIME = { "swhtype": "datetime", "d": "2015-03-04T18:25:13.001234+01:58", } DATA_TIMEDELTA = datetime.timedelta(64) ENCODED_DATA_TIMEDELTA = { "swhtype": "timedelta", "d": {"days": 64, "seconds": 0, "microseconds": 0}, } DATA_UUID = UUID("cdd8f804-9db6-40c3-93ab-5955d3836234") ENCODED_DATA_UUID = {"swhtype": "uuid", "d": "cdd8f804-9db6-40c3-93ab-5955d3836234"} # For test demonstration purposes TestPagedResultStr = PagedResult[ Union[UUID, datetime.datetime, datetime.timedelta], str ] DATA_PAGED_RESULT = TestPagedResultStr( results=[DATA_UUID, DATA_DATETIME, DATA_TIMEDELTA], next_page_token="10", ) ENCODED_DATA_PAGED_RESULT = { "d": { "results": [ENCODED_DATA_UUID, ENCODED_DATA_DATETIME, ENCODED_DATA_TIMEDELTA,], "next_page_token": "10", }, "swhtype": "paged_result", } TestPagedResultTuple = PagedResult[ Union[str, bytes, datetime.datetime], List[Union[str, UUID]] ] DATA_PAGED_RESULT2 = TestPagedResultTuple( results=["data0", DATA_BYTES, DATA_DATETIME], next_page_token=["10", DATA_UUID], ) ENCODED_DATA_PAGED_RESULT2 = { "d": { "results": 
["data0", ENCODED_DATA_BYTES, ENCODED_DATA_DATETIME,], "next_page_token": ["10", ENCODED_DATA_UUID], }, "swhtype": "paged_result", } DATA = { "bytes": DATA_BYTES, "datetime_tz": DATA_DATETIME, "datetime_utc": datetime.datetime( 2015, 3, 4, 18, 25, 13, 1234, tzinfo=datetime.timezone.utc ), "datetime_delta": DATA_TIMEDELTA, "swhtype": "fake", "swh_dict": {"swhtype": 42, "d": "test"}, "random_dict": {"swhtype": 43}, "uuid": DATA_UUID, "paged-result": DATA_PAGED_RESULT, "paged-result2": DATA_PAGED_RESULT2, } ENCODED_DATA = { "bytes": ENCODED_DATA_BYTES, "datetime_tz": ENCODED_DATA_DATETIME, "datetime_utc": {"swhtype": "datetime", "d": "2015-03-04T18:25:13.001234+00:00",}, "datetime_delta": ENCODED_DATA_TIMEDELTA, "swhtype": "fake", "swh_dict": {"swhtype": 42, "d": "test"}, "random_dict": {"swhtype": 43}, "uuid": ENCODED_DATA_UUID, "paged-result": ENCODED_DATA_PAGED_RESULT, "paged-result2": ENCODED_DATA_PAGED_RESULT2, } def test_serializers_round_trip_json(): json_data = json_dumps(DATA) actual_data = json_loads(json_data) assert actual_data == DATA def test_serializers_round_trip_json_extra_types(): expected_original_data = [ExtraType("baz", DATA), "qux"] data = json_dumps(expected_original_data, extra_encoders=extra_encoders) actual_data = json_loads(data, extra_decoders=extra_decoders) assert actual_data == expected_original_data def test_exception_serializer_round_trip_json(): error_message = "unreachable host" json_data = json_dumps({"exception": ConnectionError(error_message)},) actual_data = json_loads(json_data) assert "exception" in actual_data assert type(actual_data["exception"]) == ConnectionError assert str(actual_data["exception"]) == error_message def test_serializers_encode_swh_json(): json_str = json_dumps(DATA) actual_data = json.loads(json_str) assert actual_data == ENCODED_DATA def test_serializers_round_trip_msgpack(): expected_original_data = { **DATA, "none_dict_key": {None: 42}, "long_int_is_loooong": 10000000000000000000000000000000, "long_negative_int_is_loooong": -10000000000000000000000000000000, } data = msgpack_dumps(expected_original_data) actual_data = msgpack_loads(data) assert actual_data == expected_original_data def test_serializers_round_trip_msgpack_extra_types(): original_data = [ExtraType("baz", DATA), "qux"] data = msgpack_dumps(original_data, extra_encoders=extra_encoders) actual_data = msgpack_loads(data, extra_decoders=extra_decoders) assert actual_data == original_data def test_exception_serializer_round_trip_msgpack(): error_message = "unreachable host" data = msgpack_dumps({"exception": ConnectionError(error_message)}) actual_data = msgpack_loads(data) assert "exception" in actual_data assert type(actual_data["exception"]) == ConnectionError assert str(actual_data["exception"]) == error_message def test_serializers_generator_json(): data = json_dumps((i for i in range(5))) assert json_loads(data) == [i for i in range(5)] def test_serializers_generator_msgpack(): data = msgpack_dumps((i for i in range(5))) assert msgpack_loads(data) == [i for i in range(5)] def test_serializers_decode_response_json(requests_mock): requests_mock.get( "https://example.org/test/data", json=ENCODED_DATA, headers={"content-type": "application/json"}, ) response = requests.get("https://example.org/test/data") assert decode_response(response) == DATA -def test_serializers_encode_native_datetime(): +def test_serializers_encode_datetime_msgpack(): + dt = datetime.datetime.now(tz=datetime.timezone.utc) + encmsg = msgpack_dumps(dt) + decmsg = msgpack.loads(encmsg, timestamp=0) 
diff --git a/swh/core/api/tests/test_serializers.py b/swh/core/api/tests/test_serializers.py
index 12ed922..e75e01b 100644
--- a/swh/core/api/tests/test_serializers.py
+++ b/swh/core/api/tests/test_serializers.py
@@ -1,238 +1,260 @@
 # Copyright (C) 2015-2020 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information

 import datetime
 import json
 from typing import Any, Callable, List, Tuple, Union
 from uuid import UUID

+import msgpack
 import pytest
 import requests
 from requests.exceptions import ConnectionError

 from swh.core.api.classes import PagedResult
 from swh.core.api.serializers import (
     ENCODERS,
     decode_response,
     json_dumps,
     json_loads,
     msgpack_dumps,
     msgpack_loads,
 )


 class ExtraType:
     def __init__(self, arg1, arg2):
         self.arg1 = arg1
         self.arg2 = arg2

     def __repr__(self):
         return f"ExtraType({self.arg1}, {self.arg2})"

     def __eq__(self, other):
         return isinstance(other, ExtraType) and (self.arg1, self.arg2) == (
             other.arg1,
             other.arg2,
         )


 extra_encoders: List[Tuple[type, str, Callable[..., Any]]] = [
     (ExtraType, "extratype", lambda o: (o.arg1, o.arg2))
 ]


 extra_decoders = {
     "extratype": lambda o: ExtraType(*o),
 }

 TZ = datetime.timezone(datetime.timedelta(minutes=118))

 DATA_BYTES = b"123456789\x99\xaf\xff\x00\x12"
 ENCODED_DATA_BYTES = {"swhtype": "bytes", "d": "F)}kWH8wXmIhn8j01^"}

 DATA_DATETIME = datetime.datetime(2015, 3, 4, 18, 25, 13, 1234, tzinfo=TZ,)
 ENCODED_DATA_DATETIME = {
     "swhtype": "datetime",
     "d": "2015-03-04T18:25:13.001234+01:58",
 }

 DATA_TIMEDELTA = datetime.timedelta(64)
 ENCODED_DATA_TIMEDELTA = {
     "swhtype": "timedelta",
     "d": {"days": 64, "seconds": 0, "microseconds": 0},
 }

 DATA_UUID = UUID("cdd8f804-9db6-40c3-93ab-5955d3836234")
 ENCODED_DATA_UUID = {"swhtype": "uuid", "d": "cdd8f804-9db6-40c3-93ab-5955d3836234"}

 # For test demonstration purposes
 TestPagedResultStr = PagedResult[
     Union[UUID, datetime.datetime, datetime.timedelta], str
 ]

 DATA_PAGED_RESULT = TestPagedResultStr(
     results=[DATA_UUID, DATA_DATETIME, DATA_TIMEDELTA], next_page_token="10",
 )

 ENCODED_DATA_PAGED_RESULT = {
     "d": {
         "results": [ENCODED_DATA_UUID, ENCODED_DATA_DATETIME, ENCODED_DATA_TIMEDELTA,],
         "next_page_token": "10",
     },
     "swhtype": "paged_result",
 }

 TestPagedResultTuple = PagedResult[
     Union[str, bytes, datetime.datetime], List[Union[str, UUID]]
 ]

 DATA_PAGED_RESULT2 = TestPagedResultTuple(
     results=["data0", DATA_BYTES, DATA_DATETIME], next_page_token=["10", DATA_UUID],
 )

 ENCODED_DATA_PAGED_RESULT2 = {
     "d": {
         "results": ["data0", ENCODED_DATA_BYTES, ENCODED_DATA_DATETIME,],
         "next_page_token": ["10", ENCODED_DATA_UUID],
     },
     "swhtype": "paged_result",
 }

 DATA = {
     "bytes": DATA_BYTES,
     "datetime_tz": DATA_DATETIME,
     "datetime_utc": datetime.datetime(
         2015, 3, 4, 18, 25, 13, 1234, tzinfo=datetime.timezone.utc
     ),
     "datetime_delta": DATA_TIMEDELTA,
     "swhtype": "fake",
     "swh_dict": {"swhtype": 42, "d": "test"},
     "random_dict": {"swhtype": 43},
     "uuid": DATA_UUID,
     "paged-result": DATA_PAGED_RESULT,
     "paged-result2": DATA_PAGED_RESULT2,
 }

 ENCODED_DATA = {
     "bytes": ENCODED_DATA_BYTES,
     "datetime_tz": ENCODED_DATA_DATETIME,
     "datetime_utc": {"swhtype": "datetime", "d": "2015-03-04T18:25:13.001234+00:00",},
     "datetime_delta": ENCODED_DATA_TIMEDELTA,
     "swhtype": "fake",
     "swh_dict": {"swhtype": 42, "d": "test"},
     "random_dict": {"swhtype": 43},
     "uuid": ENCODED_DATA_UUID,
     "paged-result": ENCODED_DATA_PAGED_RESULT,
     "paged-result2": ENCODED_DATA_PAGED_RESULT2,
 }


 def test_serializers_round_trip_json():
     json_data = json_dumps(DATA)
     actual_data = json_loads(json_data)
     assert actual_data == DATA


 def test_serializers_round_trip_json_extra_types():
     expected_original_data = [ExtraType("baz", DATA), "qux"]
     data = json_dumps(expected_original_data, extra_encoders=extra_encoders)
     actual_data = json_loads(data, extra_decoders=extra_decoders)
     assert actual_data == expected_original_data


 def test_exception_serializer_round_trip_json():
     error_message = "unreachable host"
     json_data = json_dumps({"exception": ConnectionError(error_message)},)
     actual_data = json_loads(json_data)
     assert "exception" in actual_data
     assert type(actual_data["exception"]) == ConnectionError
     assert str(actual_data["exception"]) == error_message


 def test_serializers_encode_swh_json():
     json_str = json_dumps(DATA)
     actual_data = json.loads(json_str)
     assert actual_data == ENCODED_DATA


 def test_serializers_round_trip_msgpack():
     expected_original_data = {
         **DATA,
         "none_dict_key": {None: 42},
         "long_int_is_loooong": 10000000000000000000000000000000,
         "long_negative_int_is_loooong": -10000000000000000000000000000000,
     }
     data = msgpack_dumps(expected_original_data)
     actual_data = msgpack_loads(data)
     assert actual_data == expected_original_data


 def test_serializers_round_trip_msgpack_extra_types():
     original_data = [ExtraType("baz", DATA), "qux"]
     data = msgpack_dumps(original_data, extra_encoders=extra_encoders)
     actual_data = msgpack_loads(data, extra_decoders=extra_decoders)
     assert actual_data == original_data


 def test_exception_serializer_round_trip_msgpack():
     error_message = "unreachable host"
     data = msgpack_dumps({"exception": ConnectionError(error_message)})
     actual_data = msgpack_loads(data)
     assert "exception" in actual_data
     assert type(actual_data["exception"]) == ConnectionError
     assert str(actual_data["exception"]) == error_message


 def test_serializers_generator_json():
     data = json_dumps((i for i in range(5)))
     assert json_loads(data) == [i for i in range(5)]


 def test_serializers_generator_msgpack():
     data = msgpack_dumps((i for i in range(5)))
     assert msgpack_loads(data) == [i for i in range(5)]


 def test_serializers_decode_response_json(requests_mock):
     requests_mock.get(
         "https://example.org/test/data",
         json=ENCODED_DATA,
         headers={"content-type": "application/json"},
     )

     response = requests.get("https://example.org/test/data")
     assert decode_response(response) == DATA


-def test_serializers_encode_native_datetime():
+def test_serializers_encode_datetime_msgpack():
+    dt = datetime.datetime.now(tz=datetime.timezone.utc)
+    encmsg = msgpack_dumps(dt)
+    decmsg = msgpack.loads(encmsg, timestamp=0)
+    assert isinstance(decmsg, msgpack.Timestamp)
+    assert decmsg.to_datetime() == dt
+
+
+def test_serializers_decode_datetime_compat_msgpack():
+    dt = datetime.datetime.now(tz=datetime.timezone.utc)
+    encmsg = msgpack_dumps({b"swhtype": "datetime", b"d": dt.isoformat()})
+    decmsg = msgpack_loads(encmsg)
+    assert decmsg == dt
+
+
+def test_serializers_encode_native_datetime_msgpack():
     dt = datetime.datetime(2015, 1, 1, 12, 4, 42, 231455)
-    with pytest.raises(ValueError, match="naive datetime"):
+    with pytest.raises(TypeError, match="datetime"):
         msgpack_dumps(dt)


+def test_serializers_encode_native_datetime_json():
+    dt = datetime.datetime(2015, 1, 1, 12, 4, 42, 231455)
+    with pytest.raises(TypeError, match="datetime"):
+        json_dumps(dt)
+
+
 def test_serializers_decode_naive_datetime():
     expected_dt = datetime.datetime(2015, 1, 1, 12, 4, 42, 231455)

     # Current encoding
     assert (
         msgpack_loads(
             b"\x82\xc4\x07swhtype\xa8datetime\xc4\x01d\xba"
             b"2015-01-01T12:04:42.231455"
         )
         == expected_dt
     )


 def test_msgpack_extra_encoders_mutation():
     data = msgpack_dumps({}, extra_encoders=extra_encoders)
     assert data is not None
     assert ENCODERS[-1][0] != ExtraType


 def test_json_extra_encoders_mutation():
     data = json_dumps({}, extra_encoders=extra_encoders)
     assert data is not None
     assert ENCODERS[-1][0] != ExtraType
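For completeness, the extra_encoders/extra_decoders hooks exercised by the ExtraType tests above are unaffected by this patch. A minimal sketch with a hypothetical `Point` type (illustration only, not from the codebase):

```python
from swh.core.api.serializers import msgpack_dumps, msgpack_loads


class Point:  # hypothetical user type, for illustration only
    def __init__(self, x, y):
        self.x, self.y = x, y

    def __eq__(self, other):
        return isinstance(other, Point) and (self.x, self.y) == (other.x, other.y)


# Callers register (type, tag, encode_fn) triples plus a tag -> decode_fn map;
# the global ENCODERS list is not mutated, as the *_mutation tests check.
encoders = [(Point, "point", lambda p: (p.x, p.y))]
decoders = {"point": lambda d: Point(*d)}

data = msgpack_dumps(Point(1, 2), extra_encoders=encoders)
assert msgpack_loads(data, extra_decoders=decoders) == Point(1, 2)
```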