diff --git a/swh/core/api/__init__.py b/swh/core/api/__init__.py
--- a/swh/core/api/__init__.py
+++ b/swh/core/api/__init__.py
@@ -315,19 +315,15 @@
             if status_class == 4:
                 data = self._decode_response(response, check_status=False)
                 if isinstance(data, dict):
+                    # TODO: remove "exception" key check once all servers
+                    # are using new schema
+                    exc_data = data["exception"] if "exception" in data else data
                     for exc_type in self.reraise_exceptions:
-                        if exc_type.__name__ == data["type"]:
-                            exception = exc_type(*data["args"])
+                        if exc_type.__name__ == exc_data["type"]:
+                            exception = exc_type(*exc_data["args"])
                             break
                     else:
-                        # old dict encoded exception schema
-                        # TODO: Remove that code once all servers are using new schema
-                        if "exception" in data:
-                            exception = RemoteException(
-                                payload=data["exception"], response=response
-                            )
-                        else:
-                            exception = RemoteException(payload=data, response=response)
+                        exception = RemoteException(payload=exc_data, response=response)
                 else:
                     exception = pickle.loads(data)
 
@@ -335,14 +331,11 @@
                 data = self._decode_response(response, check_status=False)
                 if "exception_pickled" in data:
                     exception = pickle.loads(data["exception_pickled"])
-                # old dict encoded exception schema
-                # TODO: Remove that code once all servers are using new schema
-                elif "exception" in data:
-                    exception = RemoteException(
-                        payload=data["exception"], response=response
-                    )
                 else:
-                    exception = RemoteException(payload=data, response=response)
+                    # TODO: remove "exception" key check once all servers
+                    # are using new schema
+                    exc_data = data["exception"] if "exception" in data else data
+                    exception = RemoteException(payload=exc_data, response=response)
 
         except (TypeError, pickle.UnpicklingError):
             raise RemoteException(payload=data, response=response)
diff --git a/swh/core/api/tests/test_rpc_client.py b/swh/core/api/tests/test_rpc_client.py
--- a/swh/core/api/tests/test_rpc_client.py
+++ b/swh/core/api/tests/test_rpc_client.py
@@ -6,12 +6,18 @@
 import re
 
 import pytest
+from requests.exceptions import ConnectionError
 
-from swh.core.api import RPCClient, remote_api_endpoint
+from swh.core.api import APIError, RemoteException, RPCClient, remote_api_endpoint
+from swh.core.api.serializers import exception_to_dict, msgpack_dumps
 
 from .test_serializers import ExtraType, extra_decoders, extra_encoders
 
 
+class ReraiseException(Exception):
+    pass
+
+
 @pytest.fixture
 def rpc_client(requests_mock):
     class TestStorage:
@@ -35,6 +41,7 @@
         backend_class = TestStorage
         extra_type_encoders = extra_encoders
         extra_type_decoders = extra_decoders
+        reraise_exceptions = [ReraiseException]
 
         def overridden_method(self, data):
             return "bar"
@@ -84,3 +91,88 @@
 def test_client_overridden_method(rpc_client):
     res = rpc_client.overridden_method("foo")
     assert res == "bar"
+
+
+def test_client_connection_error(rpc_client, requests_mock):
+    """
+    ConnectionError should be wrapped and raised as an APIError.
+    """
+    error_message = "unreachable host"
+    requests_mock.post(
+        re.compile("mock://example.com/connection_error"),
+        exc=ConnectionError(error_message),
+    )
+
+    with pytest.raises(APIError) as exc_info:
+        rpc_client.post("connection_error", data={})
+
+    assert isinstance(exc_info.value.args[0], ConnectionError)
+    assert str(exc_info.value.args[0]) == error_message
+
+
+def _exception_response(exception, status_code, old_exception_schema=False):
+    """Return a requests_mock callback answering with msgpack-encoded `exception`."""
+    def callback(request, context):
+        assert request.headers["Content-Type"] == "application/x-msgpack"
+        context.headers["Content-Type"] = "application/x-msgpack"
+        exc_dict = exception_to_dict(exception)
+        if old_exception_schema:
+            exc_dict = {"exception": exc_dict}
+        content = msgpack_dumps(exc_dict)
+        context.status_code = status_code
+        return content
+
+    return callback
+
+
+@pytest.mark.parametrize("old_exception_schema", [False, True])
+def test_client_reraise_exception(rpc_client, requests_mock, old_exception_schema):
+    """
+    Exceptions caught server-side and whitelisted will be raised again client-side.
+    """
+    error_message = "something went wrong"
+    endpoint = "reraise_exception"
+
+    requests_mock.post(
+        re.compile(f"mock://example.com/{endpoint}"),
+        content=_exception_response(
+            exception=ReraiseException(error_message),
+            status_code=400,
+            old_exception_schema=old_exception_schema,
+        ),
+    )
+
+    with pytest.raises(ReraiseException) as exc_info:
+        rpc_client.post(endpoint, data={})
+
+    assert str(exc_info.value) == error_message
+
+
+@pytest.mark.parametrize(
+    "status_code, old_exception_schema",
+    [(400, False), (500, False), (400, True), (500, True),],
+)
+def test_client_raise_remote_exception(
+    rpc_client, requests_mock, status_code, old_exception_schema
+):
+    """
+    Exceptions caught server-side and not whitelisted will be wrapped and raised
+    as a RemoteException client-side.
+    """
+    error_message = "something went wrong"
+    endpoint = "raise_remote_exception"
+
+    requests_mock.post(
+        re.compile(f"mock://example.com/{endpoint}"),
+        content=_exception_response(
+            exception=Exception(error_message),
+            status_code=status_code,
+            old_exception_schema=old_exception_schema,
+        ),
+    )
+
+    with pytest.raises(RemoteException) as exc_info:
+        rpc_client.post(endpoint, data={})
+
+    assert str(exc_info.value.args[0]["type"]) == "Exception"
+    assert str(exc_info.value.args[0]["message"]) == error_message