diff --git a/swh/core/api/serializers.py b/swh/core/api/serializers.py
--- a/swh/core/api/serializers.py
+++ b/swh/core/api/serializers.py
@@ -9,7 +9,7 @@
 import json
 import traceback
 import types
-from typing import Any, Dict, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
 from uuid import UUID
 
 import arrow
@@ -88,6 +88,34 @@
 }
 
 
+def get_encoders(
+    extra_encoders: Optional[List[Tuple[Type, str, Callable]]]
+) -> List[Tuple[Type, str, Callable]]:
+    """Return :data:`ENCODERS` extended with ``extra_encoders``.
+
+    A new list is built when ``extra_encoders`` is not ``None``, so the
+    module-level :data:`ENCODERS` is never mutated. When it is ``None``,
+    :data:`ENCODERS` itself is returned: callers must not mutate the result.
+    """
+    if extra_encoders is not None:
+        return [*ENCODERS, *extra_encoders]
+    else:
+        return ENCODERS
+
+
+def get_decoders(extra_decoders: Optional[Dict[str, Callable]]) -> Dict[str, Callable]:
+    """Return :data:`DECODERS` updated with ``extra_decoders``.
+
+    A new dict is built when ``extra_decoders`` is not ``None`` (its entries
+    override builtin ones), so :data:`DECODERS` is never mutated. When it is
+    ``None``, :data:`DECODERS` itself is returned: do not mutate the result.
+    """
+    if extra_decoders is not None:
+        return {**DECODERS, **extra_decoders}
+    else:
+        return DECODERS
+
+
 class MsgpackExtTypeCodes(Enum):
     LONG_INT = 1
     LONG_NEG_INT = 2
@@ -141,9 +169,7 @@
 
     def __init__(self, extra_encoders=None, **kwargs):
         super().__init__(**kwargs)
-        self.encoders = ENCODERS
-        if extra_encoders:
-            self.encoders += extra_encoders
+        self.encoders = get_encoders(extra_encoders)
 
     def default(self, o: Any) -> Union[Dict[str, Union[Dict[str, int], str]], list]:
         for (type_, type_name, encoder) in self.encoders:
@@ -185,9 +211,7 @@
 
     def __init__(self, extra_decoders=None, **kwargs):
         super().__init__(**kwargs)
-        self.decoders = DECODERS
-        if extra_decoders:
-            self.decoders = {**self.decoders, **extra_decoders}
+        self.decoders = get_decoders(extra_decoders)
 
     def decode_data(self, o: Any) -> Any:
         if isinstance(o, dict):
@@ -218,9 +242,7 @@
 
 def msgpack_dumps(data: Any, extra_encoders=None) -> bytes:
     """Write data as a msgpack stream"""
-    encoders = ENCODERS
-    if extra_encoders:
-        encoders += extra_encoders
+    encoders = get_encoders(extra_encoders)
 
     def encode_types(obj):
         if isinstance(obj, int):
@@ -256,9 +278,7 @@
        This function is used by swh.journal to decode the contents of the
        journal. This function **must** be kept backwards-compatible.
     """
-    decoders = DECODERS
-    if extra_decoders:
-        decoders = {**decoders, **extra_decoders}
+    decoders = get_decoders(extra_decoders)
 
     def ext_hook(code, data):
         if code == MsgpackExtTypeCodes.LONG_INT.value:
diff --git a/swh/core/api/tests/test_serializers.py b/swh/core/api/tests/test_serializers.py
--- a/swh/core/api/tests/test_serializers.py
+++ b/swh/core/api/tests/test_serializers.py
@@ -16,6 +16,7 @@
 
 from swh.core.api.classes import PagedResult
 from swh.core.api.serializers import (
+    ENCODERS,
     SWHJSONDecoder,
     SWHJSONEncoder,
     decode_response,
@@ -270,3 +271,19 @@
         )
         == expected_dt
     )
+
+
+def test_msgpack_extra_encoders_mutation():
+    # passing extra encoders must leave the module-level ENCODERS untouched
+    encoders_before = list(ENCODERS)
+    data = msgpack_dumps({}, extra_encoders=extra_encoders)
+    assert data is not None
+    assert ENCODERS == encoders_before
+
+
+def test_json_extra_encoders_mutation():
+    # passing extra encoders must leave the module-level ENCODERS untouched
+    encoders_before = list(ENCODERS)
+    data = json.dumps({}, cls=SWHJSONEncoder, extra_encoders=extra_encoders)
+    assert data is not None
+    assert ENCODERS == encoders_before