diff --git a/swh/core/api/__init__.py b/swh/core/api/__init__.py
--- a/swh/core/api/__init__.py
+++ b/swh/core/api/__init__.py
@@ -316,13 +316,18 @@
                 data = self._decode_response(response, check_status=False)
                 if isinstance(data, dict):
+                    # Old servers nest the payload under an "exception" key;
+                    # normalize both schemas before reading "type"/"args".
+                    # TODO: Remove that code once all servers are using new schema
+                    exc_data = data["exception"] if "exception" in data else data
                     for exc_type in self.reraise_exceptions:
-                        if exc_type.__name__ == data["exception"]["type"]:
-                            exception = exc_type(*data["exception"]["args"])
+                        if exc_type.__name__ == exc_data["type"]:
+                            exception = exc_type(*exc_data["args"])
                             break
                     else:
+                        # for/else: only reached when no reraise type matched
                         exception = RemoteException(
-                            payload=data["exception"], response=response
+                            payload=exc_data, response=response
                         )
                 else:
                     exception = pickle.loads(data)
 
@@ -330,10 +335,13 @@
                 data = self._decode_response(response, check_status=False)
                 if "exception_pickled" in data:
                     exception = pickle.loads(data["exception_pickled"])
                 else:
+                    # Old servers nest the payload under an "exception" key.
+                    # TODO: Remove that code once all servers are using new schema
+                    exc_data = data["exception"] if "exception" in data else data
                     exception = RemoteException(
-                        payload=data["exception"], response=response
+                        payload=exc_data, response=response
                     )
 
         except (TypeError, pickle.UnpicklingError):
             raise RemoteException(payload=data, response=response)
diff --git a/swh/core/api/serializers.py b/swh/core/api/serializers.py
--- a/swh/core/api/serializers.py
+++ b/swh/core/api/serializers.py
@@ -40,6 +40,41 @@
     return PagedResult(results=obj["results"], next_page_token=obj["next_page_token"],)
 
 
+def exception_to_dict(exception: Exception) -> Dict[str, Any]:
+    """Encode ``exception`` as a dict that :func:`dict_to_exception` can decode.
+
+    The exception class is identified by its ``type`` name and ``module``;
+    ``message`` and ``traceback`` are not used by :func:`dict_to_exception`.
+    """
+    tb = traceback.format_exception(None, exception, exception.__traceback__)
+    exc_type = type(exception)
+    return {
+        "type": exc_type.__name__,
+        "module": exc_type.__module__,
+        "args": exception.args,
+        "message": str(exception),
+        "traceback": tb,
+    }
+
+
+def dict_to_exception(exc_dict: Dict[str, Any]) -> Exception:
+    """Rebuild an exception from a dict produced by :func:`exception_to_dict`.
+
+    Raises:
+        TypeError: if ``module``/``type`` do not name an Exception subclass.
+    """
+    module = __import__(exc_dict["module"], fromlist=[exc_dict["type"]])
+    exc_type = getattr(module, exc_dict["type"])
+    # exc_dict comes from decoded network payloads: refuse to call anything
+    # that is not an exception class (e.g. module "os", type "system").
+    # NOTE(review): the import itself still runs for any named module.
+    if not (isinstance(exc_type, type) and issubclass(exc_type, Exception)):
+        raise TypeError(
+            f"{exc_dict['module']}.{exc_dict['type']} is not an exception type"
+        )
+    return exc_type(*exc_dict["args"])
+
+
 ENCODERS = [
     (arrow.Arrow, "arrow", arrow.Arrow.isoformat),
     (datetime.datetime, "datetime", encode_datetime),
@@ -56,6 +91,7 @@
     (PagedResult, "paged_result", _encode_paged_result),
     # Only for JSON:
     (bytes, "bytes", lambda o: base64.b85encode(o).decode("ascii")),
+    (Exception, "exception", exception_to_dict),
 ]
 
 DECODERS = {
@@ -66,6 +102,7 @@
     "paged_result": _decode_paged_result,
     # Only for JSON:
     "bytes": base64.b85decode,
+    "exception": dict_to_exception,
 }
 
 
@@ -279,15 +316,3 @@
     return msgpack.unpackb(
         data, encoding="utf-8", object_hook=decode_types, ext_hook=ext_hook
     )
-
-
-def exception_to_dict(exception):
-    tb = traceback.format_exception(None, exception, exception.__traceback__)
-    return {
-        "exception": {
-            "type": type(exception).__name__,
-            "args": exception.args,
-            "message": str(exception),
-            "traceback": tb,
-        }
-    }
diff --git a/swh/core/api/tests/test_async.py b/swh/core/api/tests/test_async.py
--- a/swh/core/api/tests/test_async.py
+++ b/swh/core/api/tests/test_async.py
@@ -116,7 +116,7 @@
     assert resp.status == 500
     data = await resp.read()
     data = msgpack.unpackb(data, raw=False)
-    assert data["exception"]["type"] == "TestServerException"
+    assert data["type"] == "TestServerException"
 
 
 async def test_get_client_error(cli) -> None:
@@ -124,7 +124,7 @@
     assert resp.status == 400
     data = await resp.read()
     data = msgpack.unpackb(data, raw=False)
-    assert data["exception"]["type"] == "TestClientError"
+    assert data["type"] == "TestClientError"
 
 
 async def test_get_simple_nego(cli) -> None:
diff --git a/swh/core/api/tests/test_rpc_client_server.py b/swh/core/api/tests/test_rpc_client_server.py
--- a/swh/core/api/tests/test_rpc_client_server.py
+++ b/swh/core/api/tests/test_rpc_client_server.py
@@ -30,6 +30,10 @@
     def raise_typeerror(self):
         raise TypeError("Did I pass through?")
 
+    @remote_api_endpoint("raise_exception_exc_arg")
+    def raise_exception_exc_arg(self):
+        raise Exception(Exception("error"))
+
 
 # this class is used on the client part. We cannot inherit from RPCTest
 # because the automagic metaclass based code that generates the RPCClient
@@ -115,3 +119,12 @@
         str(exc_info.value)
         == ""
     )
+
+
+def test_api_raise_exception_exc_arg(swh_rpc_client):
+    with pytest.raises(RemoteException) as exc_info:
+        swh_rpc_client.post("raise_exception_exc_arg", data={})
+
+    assert exc_info.value.args[0]["type"] == "Exception"
+    assert type(exc_info.value.args[0]["args"][0]) == Exception
+    assert str(exc_info.value.args[0]["args"][0]) == "error"
diff --git a/swh/core/api/tests/test_serializers.py b/swh/core/api/tests/test_serializers.py
--- a/swh/core/api/tests/test_serializers.py
+++ b/swh/core/api/tests/test_serializers.py
@@ -12,6 +12,7 @@
 from arrow import Arrow
 import pytest
 import requests
+from requests.exceptions import ConnectionError as RequestsConnectionError
 
 from swh.core.api.classes import PagedResult
 from swh.core.api.serializers import (
@@ -148,6 +149,17 @@
     assert actual_data == expected_original_data
 
 
+def test_exception_serializer_round_trip_json():
+    error_message = "unreachable host"
+    json_data = json.dumps(
+        {"exception": RequestsConnectionError(error_message)}, cls=SWHJSONEncoder
+    )
+    actual_data = json.loads(json_data, cls=SWHJSONDecoder)
+    assert "exception" in actual_data
+    assert type(actual_data["exception"]) == RequestsConnectionError
+    assert str(actual_data["exception"]) == error_message
+
+
 def test_serializers_encode_swh_json():
     json_str = json.dumps(DATA, cls=SWHJSONEncoder)
     actual_data = json.loads(json_str)
@@ -172,6 +184,28 @@
     assert actual_data == original_data
 
 
+def test_exception_serializer_round_trip_msgpack():
+    error_message = "unreachable host"
+    data = msgpack_dumps({"exception": RequestsConnectionError(error_message)})
+    actual_data = msgpack_loads(data)
+    assert "exception" in actual_data
+    assert type(actual_data["exception"]) == RequestsConnectionError
+    assert str(actual_data["exception"]) == error_message
+
+
+def test_exception_deserializer_rejects_non_exception_types():
+    # A crafted payload must not be able to make the decoder call an
+    # arbitrary callable (here builtins.print) instead of an exception class.
+    payload = json.dumps(
+        {
+            "swhtype": "exception",
+            "d": {"module": "builtins", "type": "print", "args": ["pwned"]},
+        }
+    )
+    with pytest.raises(TypeError, match="not an exception type"):
+        json.loads(payload, cls=SWHJSONDecoder)
+
+
 def test_serializers_generator_json():
     data = json.dumps((i for i in range(5)), cls=SWHJSONEncoder)
     assert json.loads(data, cls=SWHJSONDecoder) == [i for i in range(5)]