diff --git a/swh/core/api/tests/test_rpc_client.py b/swh/core/api/tests/test_rpc_client.py
--- a/swh/core/api/tests/test_rpc_client.py
+++ b/swh/core/api/tests/test_rpc_client.py
@@ -6,12 +6,18 @@
 import re
 
 import pytest
+from requests.exceptions import ConnectionError as RequestsConnectionError
 
-from swh.core.api import RPCClient, remote_api_endpoint
+from swh.core.api import APIError, RemoteException, RPCClient, remote_api_endpoint
+from swh.core.api.serializers import exception_to_dict, msgpack_dumps
 
 from .test_serializers import ExtraType, extra_decoders, extra_encoders
 
 
+class ReraiseException(Exception):
+    pass
+
+
 @pytest.fixture
 def rpc_client(requests_mock):
     class TestStorage:
@@ -35,6 +41,7 @@
         backend_class = TestStorage
         extra_type_encoders = extra_encoders
         extra_type_decoders = extra_decoders
+        reraise_exceptions = [ReraiseException]
 
         def overridden_method(self, data):
             return "bar"
@@ -84,3 +91,65 @@
 def test_client_overridden_method(rpc_client):
     res = rpc_client.overridden_method("foo")
     assert res == "bar"
+
+
+def test_client_connection_error(rpc_client, requests_mock):
+    error_message = "unreachable host"
+    requests_mock.post(
+        re.compile("mock://example.com/connection_error"),
+        exc=RequestsConnectionError(error_message),
+    )
+
+    with pytest.raises(APIError) as exc_info:
+        rpc_client.post("connection_error", data={})
+
+    assert isinstance(exc_info.value.args[0], RequestsConnectionError)
+    assert str(exc_info.value.args[0]) == error_message
+
+
+def _exception_response(exception, status_code):
+    """Return a requests_mock callback replying with ``exception`` as msgpack."""
+
+    def callback(request, context):
+        assert request.headers["Content-Type"] == "application/x-msgpack"
+        context.headers["Content-Type"] = "application/x-msgpack"
+        context.status_code = status_code
+        return msgpack_dumps(exception_to_dict(exception))
+
+    return callback
+
+
+def test_client_reraise_exception(rpc_client, requests_mock):
+    error_message = "something went wrong"
+    endpoint = "reraise_exception"
+
+    requests_mock.post(
+        re.compile(f"mock://example.com/{endpoint}"),
+        content=_exception_response(
+            exception=ReraiseException(error_message), status_code=400
+        ),
+    )
+
+    with pytest.raises(ReraiseException) as exc_info:
+        rpc_client.post(endpoint, data={})
+
+    assert str(exc_info.value) == error_message
+
+
+@pytest.mark.parametrize("status_code", [400, 500])
+def test_client_raise_remote_exception(rpc_client, requests_mock, status_code):
+    error_message = "something went wrong"
+    endpoint = "raise_remote_exception"
+
+    requests_mock.post(
+        re.compile(f"mock://example.com/{endpoint}"),
+        content=_exception_response(
+            exception=Exception(error_message), status_code=status_code
+        ),
+    )
+
+    with pytest.raises(RemoteException) as exc_info:
+        rpc_client.post(endpoint, data={})
+
+    assert exc_info.value.args[0]["type"] == "Exception"
+    assert exc_info.value.args[0]["message"] == error_message