diff --git a/swh/core/api/serializers.py b/swh/core/api/serializers.py index b846cf6..b4e0a3f 100644 --- a/swh/core/api/serializers.py +++ b/swh/core/api/serializers.py @@ -1,270 +1,277 @@ # Copyright (C) 2015-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import base64 import datetime from enum import Enum import json import traceback import types from uuid import UUID import arrow import iso8601 import msgpack from typing import Any, Dict, Union, Tuple from requests import Response +def encode_datetime(dt: datetime.datetime) -> str: + """Wrapper of datetime.datetime.isoformat() that forbids naive datetimes.""" + if dt.tzinfo is None: + raise ValueError(f"{dt} is a naive datetime.") + return dt.isoformat() + + ENCODERS = [ (arrow.Arrow, "arrow", arrow.Arrow.isoformat), - (datetime.datetime, "datetime", datetime.datetime.isoformat), + (datetime.datetime, "datetime", encode_datetime), ( datetime.timedelta, "timedelta", lambda o: { "days": o.days, "seconds": o.seconds, "microseconds": o.microseconds, }, ), (UUID, "uuid", str), # Only for JSON: (bytes, "bytes", lambda o: base64.b85encode(o).decode("ascii")), ] DECODERS = { "arrow": arrow.get, "datetime": lambda d: iso8601.parse_date(d, default_timezone=None), "timedelta": lambda d: datetime.timedelta(**d), "uuid": UUID, # Only for JSON: "bytes": base64.b85decode, } class MsgpackExtTypeCodes(Enum): LONG_INT = 1 def encode_data_client(data: Any, extra_encoders=None) -> bytes: try: return msgpack_dumps(data, extra_encoders=extra_encoders) except OverflowError as e: raise ValueError("Limits were reached. Please, check your input.\n" + str(e)) def decode_response(response: Response, extra_decoders=None) -> Any: content_type = response.headers["content-type"] if content_type.startswith("application/x-msgpack"): r = msgpack_loads(response.content, extra_decoders=extra_decoders) elif content_type.startswith("application/json"): r = json_loads(response.text, extra_decoders=extra_decoders) elif content_type.startswith("text/"): r = response.text else: raise ValueError("Wrong content type `%s` for API response" % content_type) return r class SWHJSONEncoder(json.JSONEncoder): """JSON encoder for data structures generated by Software Heritage. This JSON encoder extends the default Python JSON encoder and adds awareness for the following specific types: - bytes (get encoded as a Base85 string); - datetime.datetime (get encoded as an ISO8601 string). Non-standard types get encoded as a a dictionary with two keys: - swhtype with value 'bytes' or 'datetime'; - d containing the encoded value. SWHJSONEncoder also encodes arbitrary iterables as a list (allowing serialization of generators). Caveats: Limitations in the JSONEncoder extension mechanism prevent us from "escaping" dictionaries that only contain the swhtype and d keys, and therefore arbitrary data structures can't be round-tripped through SWHJSONEncoder and SWHJSONDecoder. """ def __init__(self, extra_encoders=None, **kwargs): super().__init__(**kwargs) self.encoders = ENCODERS if extra_encoders: self.encoders += extra_encoders def default(self, o: Any) -> Union[Dict[str, Union[Dict[str, int], str]], list]: for (type_, type_name, encoder) in self.encoders: if isinstance(o, type_): return { "swhtype": type_name, "d": encoder(o), } try: return super().default(o) except TypeError as e: try: iterable = iter(o) except TypeError: raise e from None else: return list(iterable) class SWHJSONDecoder(json.JSONDecoder): """JSON decoder for data structures encoded with SWHJSONEncoder. This JSON decoder extends the default Python JSON decoder, allowing the decoding of: - bytes (encoded as a Base85 string); - datetime.datetime (encoded as an ISO8601 string). Non-standard types must be encoded as a a dictionary with exactly two keys: - swhtype with value 'bytes' or 'datetime'; - d containing the encoded value. To limit the impact our encoding, if the swhtype key doesn't contain a known value, the dictionary is decoded as-is. """ def __init__(self, extra_decoders=None, **kwargs): super().__init__(**kwargs) self.decoders = DECODERS if extra_decoders: self.decoders = {**self.decoders, **extra_decoders} def decode_data(self, o: Any) -> Any: if isinstance(o, dict): if set(o.keys()) == {"d", "swhtype"}: if o["swhtype"] == "bytes": return base64.b85decode(o["d"]) decoder = self.decoders.get(o["swhtype"]) if decoder: return decoder(self.decode_data(o["d"])) return {key: self.decode_data(value) for key, value in o.items()} if isinstance(o, list): return [self.decode_data(value) for value in o] else: return o def raw_decode(self, s: str, idx: int = 0) -> Tuple[Any, int]: data, index = super().raw_decode(s, idx) return self.decode_data(data), index def json_dumps(data: Any, extra_encoders=None) -> str: return json.dumps(data, cls=SWHJSONEncoder, extra_encoders=extra_encoders) def json_loads(data: str, extra_decoders=None) -> Any: return json.loads(data, cls=SWHJSONDecoder, extra_decoders=extra_decoders) def msgpack_dumps(data: Any, extra_encoders=None) -> bytes: """Write data as a msgpack stream""" encoders = ENCODERS if extra_encoders: encoders += extra_encoders def encode_types(obj): if isinstance(obj, int): # integer overflowed while packing. Handle it as an extended type length, rem = divmod(obj.bit_length(), 8) if rem: length += 1 return msgpack.ExtType( MsgpackExtTypeCodes.LONG_INT.value, int.to_bytes(obj, length, "big") ) if isinstance(obj, types.GeneratorType): return list(obj) for (type_, type_name, encoder) in encoders: if isinstance(obj, type_): return { b"swhtype": type_name, b"d": encoder(obj), } return obj return msgpack.packb(data, use_bin_type=True, default=encode_types) def msgpack_loads(data: bytes, extra_decoders=None) -> Any: """Read data as a msgpack stream. .. Caution:: This function is used by swh.journal to decode the contents of the journal. This function **must** be kept backwards-compatible. """ decoders = DECODERS if extra_decoders: decoders = {**decoders, **extra_decoders} def ext_hook(code, data): if code == MsgpackExtTypeCodes.LONG_INT.value: return int.from_bytes(data, "big") raise ValueError("Unknown msgpack extended code %s" % code) def decode_types(obj): # Support for current encodings if set(obj.keys()) == {b"d", b"swhtype"}: decoder = decoders.get(obj[b"swhtype"]) if decoder: return decoder(obj[b"d"]) # Support for legacy encodings if b"__datetime__" in obj and obj[b"__datetime__"]: return iso8601.parse_date(obj[b"s"], default_timezone=None) if b"__uuid__" in obj and obj[b"__uuid__"]: return UUID(obj[b"s"]) if b"__timedelta__" in obj and obj[b"__timedelta__"]: return datetime.timedelta(**obj[b"s"]) if b"__arrow__" in obj and obj[b"__arrow__"]: return arrow.get(obj[b"s"]) # Fallthrough return obj try: try: return msgpack.unpackb( data, raw=False, object_hook=decode_types, ext_hook=ext_hook, strict_map_key=False, ) except TypeError: # msgpack < 0.6.0 return msgpack.unpackb( data, raw=False, object_hook=decode_types, ext_hook=ext_hook ) except TypeError: # msgpack < 0.5.2 return msgpack.unpackb( data, encoding="utf-8", object_hook=decode_types, ext_hook=ext_hook ) def exception_to_dict(exception): tb = traceback.format_exception(None, exception, exception.__traceback__) return { "exception": { "type": type(exception).__name__, "args": exception.args, "message": str(exception), "traceback": tb, } } diff --git a/swh/core/api/tests/test_async.py b/swh/core/api/tests/test_async.py index 7d24408..7b8400a 100644 --- a/swh/core/api/tests/test_async.py +++ b/swh/core/api/tests/test_async.py @@ -1,232 +1,241 @@ # Copyright (C) 2019-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import datetime import msgpack import json import pytest from swh.core.api.asynchronous import RPCServerApp, Response from swh.core.api.asynchronous import encode_msgpack, decode_request from swh.core.api.serializers import msgpack_dumps, SWHJSONEncoder pytest_plugins = ["aiohttp.pytest_plugin", "pytester"] class TestServerException(Exception): pass class TestClientError(Exception): pass async def root(request): return Response("toor") STRUCT = { "txt": "something stupid", # 'date': datetime.date(2019, 6, 9), # not supported - "datetime": datetime.datetime(2019, 6, 9, 10, 12), + "datetime": datetime.datetime(2019, 6, 9, 10, 12, tzinfo=datetime.timezone.utc), "timedelta": datetime.timedelta(days=-2, hours=3), "int": 42, "float": 3.14, - "subdata": {"int": 42, "datetime": datetime.datetime(2019, 6, 10, 11, 12),}, - "list": [42, datetime.datetime(2019, 9, 10, 11, 12), "ok"], + "subdata": { + "int": 42, + "datetime": datetime.datetime( + 2019, 6, 10, 11, 12, tzinfo=datetime.timezone.utc + ), + }, + "list": [ + 42, + datetime.datetime(2019, 9, 10, 11, 12, tzinfo=datetime.timezone.utc), + "ok", + ], } async def struct(request): return Response(STRUCT) async def echo(request): data = await decode_request(request) return Response(data) async def server_exception(request): raise TestServerException() async def client_error(request): raise TestClientError() async def echo_no_nego(request): # let the content negotiation handle the serialization for us... data = await decode_request(request) ret = encode_msgpack(data) return ret def check_mimetype(src, dst): src = src.split(";")[0].strip() dst = dst.split(";")[0].strip() assert src == dst @pytest.fixture def async_app(): app = RPCServerApp() app.client_exception_classes = (TestClientError,) app.router.add_route("GET", "/", root) app.router.add_route("GET", "/struct", struct) app.router.add_route("POST", "/echo", echo) app.router.add_route("GET", "/server_exception", server_exception) app.router.add_route("GET", "/client_error", client_error) app.router.add_route("POST", "/echo-no-nego", echo_no_nego) return app async def test_get_simple(async_app, aiohttp_client) -> None: assert async_app is not None cli = await aiohttp_client(async_app) resp = await cli.get("/") assert resp.status == 200 check_mimetype(resp.headers["Content-Type"], "application/x-msgpack") data = await resp.read() value = msgpack.unpackb(data, raw=False) assert value == "toor" async def test_get_server_exception(async_app, aiohttp_client) -> None: cli = await aiohttp_client(async_app) resp = await cli.get("/server_exception") assert resp.status == 500 data = await resp.read() data = msgpack.unpackb(data, raw=False) assert data["exception"]["type"] == "TestServerException" async def test_get_client_error(async_app, aiohttp_client) -> None: cli = await aiohttp_client(async_app) resp = await cli.get("/client_error") assert resp.status == 400 data = await resp.read() data = msgpack.unpackb(data, raw=False) assert data["exception"]["type"] == "TestClientError" async def test_get_simple_nego(async_app, aiohttp_client) -> None: cli = await aiohttp_client(async_app) for ctype in ("x-msgpack", "json"): resp = await cli.get("/", headers={"Accept": "application/%s" % ctype}) assert resp.status == 200 check_mimetype(resp.headers["Content-Type"], "application/%s" % ctype) assert (await decode_request(resp)) == "toor" async def test_get_struct(async_app, aiohttp_client) -> None: """Test returned structured from a simple GET data is OK""" cli = await aiohttp_client(async_app) resp = await cli.get("/struct") assert resp.status == 200 check_mimetype(resp.headers["Content-Type"], "application/x-msgpack") assert (await decode_request(resp)) == STRUCT async def test_get_struct_nego(async_app, aiohttp_client) -> None: """Test returned structured from a simple GET data is OK""" cli = await aiohttp_client(async_app) for ctype in ("x-msgpack", "json"): resp = await cli.get("/struct", headers={"Accept": "application/%s" % ctype}) assert resp.status == 200 check_mimetype(resp.headers["Content-Type"], "application/%s" % ctype) assert (await decode_request(resp)) == STRUCT async def test_post_struct_msgpack(async_app, aiohttp_client) -> None: """Test that msgpack encoded posted struct data is returned as is""" cli = await aiohttp_client(async_app) # simple struct resp = await cli.post( "/echo", headers={"Content-Type": "application/x-msgpack"}, data=msgpack_dumps({"toto": 42}), ) assert resp.status == 200 check_mimetype(resp.headers["Content-Type"], "application/x-msgpack") assert (await decode_request(resp)) == {"toto": 42} # complex struct resp = await cli.post( "/echo", headers={"Content-Type": "application/x-msgpack"}, data=msgpack_dumps(STRUCT), ) assert resp.status == 200 check_mimetype(resp.headers["Content-Type"], "application/x-msgpack") assert (await decode_request(resp)) == STRUCT async def test_post_struct_json(async_app, aiohttp_client) -> None: """Test that json encoded posted struct data is returned as is""" cli = await aiohttp_client(async_app) resp = await cli.post( "/echo", headers={"Content-Type": "application/json"}, data=json.dumps({"toto": 42}, cls=SWHJSONEncoder), ) assert resp.status == 200 check_mimetype(resp.headers["Content-Type"], "application/x-msgpack") assert (await decode_request(resp)) == {"toto": 42} resp = await cli.post( "/echo", headers={"Content-Type": "application/json"}, data=json.dumps(STRUCT, cls=SWHJSONEncoder), ) assert resp.status == 200 check_mimetype(resp.headers["Content-Type"], "application/x-msgpack") # assert resp.headers['Content-Type'] == 'application/x-msgpack' assert (await decode_request(resp)) == STRUCT async def test_post_struct_nego(async_app, aiohttp_client) -> None: """Test that json encoded posted struct data is returned as is using content negotiation (accept json or msgpack). """ cli = await aiohttp_client(async_app) for ctype in ("x-msgpack", "json"): resp = await cli.post( "/echo", headers={ "Content-Type": "application/json", "Accept": "application/%s" % ctype, }, data=json.dumps(STRUCT, cls=SWHJSONEncoder), ) assert resp.status == 200 check_mimetype(resp.headers["Content-Type"], "application/%s" % ctype) assert (await decode_request(resp)) == STRUCT async def test_post_struct_no_nego(async_app, aiohttp_client) -> None: """Test that json encoded posted struct data is returned as msgpack when using non-negotiation-compatible handlers. """ cli = await aiohttp_client(async_app) for ctype in ("x-msgpack", "json"): resp = await cli.post( "/echo-no-nego", headers={ "Content-Type": "application/json", "Accept": "application/%s" % ctype, }, data=json.dumps(STRUCT, cls=SWHJSONEncoder), ) assert resp.status == 200 check_mimetype(resp.headers["Content-Type"], "application/x-msgpack") assert (await decode_request(resp)) == STRUCT diff --git a/swh/core/api/tests/test_serializers.py b/swh/core/api/tests/test_serializers.py index 5211310..b710657 100644 --- a/swh/core/api/tests/test_serializers.py +++ b/swh/core/api/tests/test_serializers.py @@ -1,186 +1,203 @@ # Copyright (C) 2015-2018 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import datetime import json from typing import Any, Callable, List, Tuple import unittest from uuid import UUID import arrow import requests import requests_mock from swh.core.api.serializers import ( SWHJSONDecoder, SWHJSONEncoder, msgpack_dumps, msgpack_loads, decode_response, ) class ExtraType: def __init__(self, arg1, arg2): self.arg1 = arg1 self.arg2 = arg2 def __repr__(self): return f"ExtraType({self.arg1}, {self.arg2})" def __eq__(self, other): return isinstance(other, ExtraType) and (self.arg1, self.arg2) == ( other.arg1, other.arg2, ) extra_encoders: List[Tuple[type, str, Callable[..., Any]]] = [ (ExtraType, "extratype", lambda o: (o.arg1, o.arg2)) ] extra_decoders = { "extratype": lambda o: ExtraType(*o), } class Serializers(unittest.TestCase): def setUp(self): self.tz = datetime.timezone(datetime.timedelta(minutes=118)) self.data = { "bytes": b"123456789\x99\xaf\xff\x00\x12", - "datetime_naive": datetime.datetime(2015, 1, 1, 12, 4, 42, 231455), "datetime_tz": datetime.datetime( 2015, 3, 4, 18, 25, 13, 1234, tzinfo=self.tz ), "datetime_utc": datetime.datetime( 2015, 3, 4, 18, 25, 13, 1234, tzinfo=datetime.timezone.utc ), "datetime_delta": datetime.timedelta(64), "arrow_date": arrow.get("2018-04-25T16:17:53.533672+00:00"), "swhtype": "fake", "swh_dict": {"swhtype": 42, "d": "test"}, "random_dict": {"swhtype": 43}, "uuid": UUID("cdd8f804-9db6-40c3-93ab-5955d3836234"), } self.encoded_data = { "bytes": {"swhtype": "bytes", "d": "F)}kWH8wXmIhn8j01^"}, - "datetime_naive": { - "swhtype": "datetime", - "d": "2015-01-01T12:04:42.231455", - }, "datetime_tz": { "swhtype": "datetime", "d": "2015-03-04T18:25:13.001234+01:58", }, "datetime_utc": { "swhtype": "datetime", "d": "2015-03-04T18:25:13.001234+00:00", }, "datetime_delta": { "swhtype": "timedelta", "d": {"days": 64, "seconds": 0, "microseconds": 0}, }, "arrow_date": {"swhtype": "arrow", "d": "2018-04-25T16:17:53.533672+00:00"}, "swhtype": "fake", "swh_dict": {"swhtype": 42, "d": "test"}, "random_dict": {"swhtype": 43}, "uuid": {"swhtype": "uuid", "d": "cdd8f804-9db6-40c3-93ab-5955d3836234"}, } self.legacy_msgpack = { "bytes": b"\xc4\x0e123456789\x99\xaf\xff\x00\x12", - "datetime_naive": ( - b"\x82\xc4\x0c__datetime__\xc3\xc4\x01s\xba" - b"2015-01-01T12:04:42.231455" - ), "datetime_tz": ( b"\x82\xc4\x0c__datetime__\xc3\xc4\x01s\xd9 " b"2015-03-04T18:25:13.001234+01:58" ), "datetime_utc": ( b"\x82\xc4\x0c__datetime__\xc3\xc4\x01s\xd9 " b"2015-03-04T18:25:13.001234+00:00" ), "datetime_delta": ( b"\x82\xc4\r__timedelta__\xc3\xc4\x01s\x83\xa4" b"days@\xa7seconds\x00\xacmicroseconds\x00" ), "arrow_date": ( b"\x82\xc4\t__arrow__\xc3\xc4\x01s\xd9 " b"2018-04-25T16:17:53.533672+00:00" ), "swhtype": b"\xa4fake", "swh_dict": b"\x82\xa7swhtype*\xa1d\xa4test", "random_dict": b"\x81\xa7swhtype+", "uuid": ( b"\x82\xc4\x08__uuid__\xc3\xc4\x01s\xd9$" b"cdd8f804-9db6-40c3-93ab-5955d3836234" ), } self.generator = (i for i in range(5)) self.gen_lst = list(range(5)) def test_round_trip_json(self): data = json.dumps(self.data, cls=SWHJSONEncoder) self.assertEqual(self.data, json.loads(data, cls=SWHJSONDecoder)) def test_round_trip_json_extra_types(self): original_data = [ExtraType("baz", self.data), "qux"] data = json.dumps( original_data, cls=SWHJSONEncoder, extra_encoders=extra_encoders ) self.assertEqual( original_data, json.loads(data, cls=SWHJSONDecoder, extra_decoders=extra_decoders), ) def test_encode_swh_json(self): data = json.dumps(self.data, cls=SWHJSONEncoder) self.assertEqual(self.encoded_data, json.loads(data)) def test_round_trip_msgpack(self): original_data = { **self.data, "none_dict_key": {None: 42}, "long_int_is_loooong": 10000000000000000000000000000000, } data = msgpack_dumps(original_data) self.assertEqual(original_data, msgpack_loads(data)) def test_round_trip_msgpack_extra_types(self): original_data = [ExtraType("baz", self.data), "qux"] data = msgpack_dumps(original_data, extra_encoders=extra_encoders) self.assertEqual( original_data, msgpack_loads(data, extra_decoders=extra_decoders) ) def test_generator_json(self): data = json.dumps(self.generator, cls=SWHJSONEncoder) self.assertEqual(self.gen_lst, json.loads(data, cls=SWHJSONDecoder)) def test_generator_msgpack(self): data = msgpack_dumps(self.generator) self.assertEqual(self.gen_lst, msgpack_loads(data)) @requests_mock.Mocker() def test_decode_response_json(self, mock_requests): mock_requests.get( "https://example.org/test/data", json=self.encoded_data, headers={"content-type": "application/json"}, ) response = requests.get("https://example.org/test/data") assert decode_response(response) == self.data def test_decode_legacy_msgpack(self): for k, v in self.legacy_msgpack.items(): assert msgpack_loads(v) == self.data[k] + + def test_encode_native_datetime(self): + dt = datetime.datetime(2015, 1, 1, 12, 4, 42, 231455) + with self.assertRaises(ValueError, matches="naive datetime"): + msgpack_dumps(dt) + + def test_decode_naive_datetime(self): + expected_dt = datetime.datetime(2015, 1, 1, 12, 4, 42, 231455) + + # Current encoding + assert ( + msgpack_loads( + b"\x82\xc4\x07swhtype\xa8datetime\xc4\x01d\xba" + b"2015-01-01T12:04:42.231455" + ) + == expected_dt + ) + + # Legacy encoding + assert ( + msgpack_loads( + b"\x82\xc4\x0c__datetime__\xc3\xc4\x01s\xba" + b"2015-01-01T12:04:42.231455" + ) + == expected_dt + )