diff --git a/swh/core/api/serializers.py b/swh/core/api/serializers.py
--- a/swh/core/api/serializers.py
+++ b/swh/core/api/serializers.py
@@ -40,6 +40,49 @@
     return PagedResult(results=obj["results"], next_page_token=obj["next_page_token"],)
 
 
+def exception_to_dict(exception: Exception) -> Dict[str, Any]:
+    """Serialize an exception into a dict of plain, encodable values.
+
+    The result has a single ``"exception"`` key whose value records the
+    exception class name and module, its ``args``, its ``str()`` message and
+    its formatted traceback.
+    """
+    tb = traceback.format_exception(None, exception, exception.__traceback__)
+    exc_type = type(exception)
+    return {
+        "exception": {
+            "type": exc_type.__name__,
+            "module": exc_type.__module__,
+            "args": exception.args,
+            "message": str(exception),
+            "traceback": tb,
+        }
+    }
+
+
+def dict_to_exception(input_dict: Dict[str, Any]) -> Exception:
+    """Rebuild an exception from the dict made by :func:`exception_to_dict`.
+
+    Raises:
+        ValueError: if ``input_dict`` has no ``"exception"`` entry.
+        TypeError: if the named object is not an :class:`Exception` subclass.
+    """
+    if "exception" not in input_dict:
+        raise ValueError("missing 'exception' entry in serialized exception")
+    exc_dict = input_dict["exception"]
+    # NOTE(security): module and type names come from the decoded payload and
+    # importing a module runs its top-level code; only decode trusted data.
+    module = __import__(exc_dict["module"], fromlist=[exc_dict["type"]])
+    exc_class = getattr(module, exc_dict["type"])
+    # Refuse to call anything but an exception class, so a crafted payload
+    # cannot invoke an arbitrary callable (e.g. ``os.system``) with its args.
+    if not (isinstance(exc_class, type) and issubclass(exc_class, Exception)):
+        raise TypeError(
+            "%s.%s is not an exception class" % (exc_dict["module"], exc_dict["type"])
+        )
+    return exc_class(*exc_dict["args"])
+
+
 ENCODERS = [
     (arrow.Arrow, "arrow", arrow.Arrow.isoformat),
     (datetime.datetime, "datetime", encode_datetime),
@@ -56,6 +99,7 @@
     (PagedResult, "paged_result", _encode_paged_result),
     # Only for JSON:
     (bytes, "bytes", lambda o: base64.b85encode(o).decode("ascii")),
+    (Exception, "exception", exception_to_dict),
 ]
 
 DECODERS = {
@@ -66,6 +110,7 @@
     "paged_result": _decode_paged_result,
     # Only for JSON:
     "bytes": base64.b85decode,
+    "exception": dict_to_exception,
 }
 
 
@@ -279,15 +324,3 @@
     return msgpack.unpackb(
         data, encoding="utf-8", object_hook=decode_types, ext_hook=ext_hook
     )
-
-
-def exception_to_dict(exception):
-    tb = traceback.format_exception(None, exception, exception.__traceback__)
-    return {
-        "exception": {
-            "type": type(exception).__name__,
-            "args": exception.args,
-            "message": str(exception),
-            "traceback": tb,
-        }
-    }
diff --git a/swh/core/api/tests/test_rpc_client_server.py b/swh/core/api/tests/test_rpc_client_server.py
--- a/swh/core/api/tests/test_rpc_client_server.py
+++ b/swh/core/api/tests/test_rpc_client_server.py
@@ -30,6 +30,10 @@
     def raise_typeerror(self):
         raise TypeError("Did I pass through?")
 
+    @remote_api_endpoint("raise_exception_exc_arg")
+    def raise_exception_exc_arg(self):
+        raise Exception(Exception("error"))
+
 
 # this class is used on the client part. We cannot inherit from RPCTest
 # because the automagic metaclass based code that generates the RPCClient
@@ -115,3 +119,12 @@
         str(exc_info.value)
         == ""
     )
+
+
+def test_api_raise_exception_exc_arg(swh_rpc_client):
+    with pytest.raises(RemoteException) as exc_info:
+        swh_rpc_client.post("raise_exception_exc_arg", data={})
+
+    assert exc_info.value.args[0]["type"] == "Exception"
+    assert type(exc_info.value.args[0]["args"][0]) is Exception
+    assert str(exc_info.value.args[0]["args"][0]) == "error"
diff --git a/swh/core/api/tests/test_serializers.py b/swh/core/api/tests/test_serializers.py
--- a/swh/core/api/tests/test_serializers.py
+++ b/swh/core/api/tests/test_serializers.py
@@ -12,6 +12,7 @@
 from arrow import Arrow
 import pytest
 import requests
+from requests.exceptions import ConnectionError
 
 from swh.core.api.classes import PagedResult
 from swh.core.api.serializers import (
@@ -148,6 +149,17 @@
     assert actual_data == expected_original_data
 
 
+def test_exception_serializer_round_trip_json():
+    error_message = "unreachable host"
+    json_data = json.dumps(
+        {"exception": ConnectionError(error_message)}, cls=SWHJSONEncoder
+    )
+    actual_data = json.loads(json_data, cls=SWHJSONDecoder)
+    assert "exception" in actual_data
+    assert type(actual_data["exception"]) is ConnectionError
+    assert str(actual_data["exception"]) == error_message
+
+
 def test_serializers_encode_swh_json():
     json_str = json.dumps(DATA, cls=SWHJSONEncoder)
     actual_data = json.loads(json_str)
@@ -172,6 +184,15 @@
     assert actual_data == original_data
 
 
+def test_exception_serializer_round_trip_msgpack():
+    error_message = "unreachable host"
+    data = msgpack_dumps({"exception": ConnectionError(error_message)})
+    actual_data = msgpack_loads(data)
+    assert "exception" in actual_data
+    assert type(actual_data["exception"]) is ConnectionError
+    assert str(actual_data["exception"]) == error_message
+
+
 def test_serializers_generator_json():
     data = json.dumps((i for i in range(5)), cls=SWHJSONEncoder)
     assert json.loads(data, cls=SWHJSONDecoder) == [i for i in range(5)]