Changeset View
Changeset View
Standalone View
Standalone View
swh/core/api/tests/test_rpc_client.py
# Copyright (C) 2018-2019 The Software Heritage developers | # Copyright (C) 2018-2019 The Software Heritage developers | ||||
# See the AUTHORS file at the top-level directory of this distribution | # See the AUTHORS file at the top-level directory of this distribution | ||||
# License: GNU General Public License version 3, or any later version | # License: GNU General Public License version 3, or any later version | ||||
# See top-level LICENSE file for more information | # See top-level LICENSE file for more information | ||||
import re | import re | ||||
import pytest | import pytest | ||||
from requests.exceptions import ConnectionError | |||||
from swh.core.api import RPCClient, remote_api_endpoint | from swh.core.api import APIError, RemoteException, RPCClient, remote_api_endpoint | ||||
from swh.core.api.serializers import exception_to_dict, msgpack_dumps | |||||
from .test_serializers import ExtraType, extra_decoders, extra_encoders | from .test_serializers import ExtraType, extra_decoders, extra_encoders | ||||
class ReraiseException(Exception): | |||||
pass | |||||
@pytest.fixture | @pytest.fixture | ||||
def rpc_client(requests_mock): | def rpc_client(requests_mock): | ||||
class TestStorage: | class TestStorage: | ||||
@remote_api_endpoint("test_endpoint_url") | @remote_api_endpoint("test_endpoint_url") | ||||
def test_endpoint(self, test_data, db=None, cur=None): | def test_endpoint(self, test_data, db=None, cur=None): | ||||
... | ... | ||||
@remote_api_endpoint("path/to/endpoint") | @remote_api_endpoint("path/to/endpoint") | ||||
def something(self, data, db=None, cur=None): | def something(self, data, db=None, cur=None): | ||||
... | ... | ||||
@remote_api_endpoint("serializer_test") | @remote_api_endpoint("serializer_test") | ||||
def serializer_test(self, data, db=None, cur=None): | def serializer_test(self, data, db=None, cur=None): | ||||
... | ... | ||||
@remote_api_endpoint("overridden/endpoint") | @remote_api_endpoint("overridden/endpoint") | ||||
def overridden_method(self, data): | def overridden_method(self, data): | ||||
return "foo" | return "foo" | ||||
class Testclient(RPCClient): | class Testclient(RPCClient): | ||||
backend_class = TestStorage | backend_class = TestStorage | ||||
extra_type_encoders = extra_encoders | extra_type_encoders = extra_encoders | ||||
extra_type_decoders = extra_decoders | extra_type_decoders = extra_decoders | ||||
reraise_exceptions = [ReraiseException] | |||||
def overridden_method(self, data): | def overridden_method(self, data): | ||||
return "bar" | return "bar" | ||||
def callback(request, context): | def callback(request, context): | ||||
assert request.headers["Content-Type"] == "application/x-msgpack" | assert request.headers["Content-Type"] == "application/x-msgpack" | ||||
context.headers["Content-Type"] = "application/x-msgpack" | context.headers["Content-Type"] = "application/x-msgpack" | ||||
if request.path == "/test_endpoint_url": | if request.path == "/test_endpoint_url": | ||||
Show All 33 Lines | |||||
def test_client_extra_serializers(rpc_client): | def test_client_extra_serializers(rpc_client): | ||||
res = rpc_client.serializer_test(["foo", ExtraType("bar", b"baz")]) | res = rpc_client.serializer_test(["foo", ExtraType("bar", b"baz")]) | ||||
assert res == ExtraType({"spam": "egg"}, "qux") | assert res == ExtraType({"spam": "egg"}, "qux") | ||||
def test_client_overridden_method(rpc_client): | def test_client_overridden_method(rpc_client): | ||||
res = rpc_client.overridden_method("foo") | res = rpc_client.overridden_method("foo") | ||||
assert res == "bar" | assert res == "bar" | ||||
def test_client_connexion_error(rpc_client, requests_mock): | |||||
""" | |||||
ConnectionError should be wrapped and raised as an APIError. | |||||
""" | |||||
error_message = "unreachable host" | |||||
requests_mock.post( | |||||
re.compile("mock://example.com/connection_error"), | |||||
exc=ConnectionError(error_message), | |||||
) | |||||
with pytest.raises(APIError) as exc_info: | |||||
rpc_client.post("connection_error", data={}) | |||||
assert type(exc_info.value.args[0]) == ConnectionError | |||||
assert str(exc_info.value.args[0]) == error_message | |||||
def _exception_response(exception, status_code, old_exception_schema=False): | |||||
def callback(request, context): | |||||
assert request.headers["Content-Type"] == "application/x-msgpack" | |||||
context.headers["Content-Type"] = "application/x-msgpack" | |||||
exc_dict = exception_to_dict(exception) | |||||
if old_exception_schema: | |||||
exc_dict = {"exception": exc_dict} | |||||
context.content = msgpack_dumps(exc_dict) | |||||
context.status_code = status_code | |||||
return context.content | |||||
return callback | |||||
@pytest.mark.parametrize("old_exception_schema", [False, True]) | |||||
def test_client_reraise_exception(rpc_client, requests_mock, old_exception_schema): | |||||
""" | |||||
Exception caught server-side and whitelisted will be raised again client-side. | |||||
""" | |||||
error_message = "something went wrong" | |||||
endpoint = "reraise_exception" | |||||
requests_mock.post( | |||||
re.compile(f"mock://example.com/{endpoint}"), | |||||
content=_exception_response( | |||||
exception=ReraiseException(error_message), | |||||
status_code=400, | |||||
old_exception_schema=old_exception_schema, | |||||
), | |||||
) | |||||
with pytest.raises(ReraiseException) as exc_info: | |||||
rpc_client.post(endpoint, data={}) | |||||
assert str(exc_info.value) == error_message | |||||
vlorentz: with a docstring, explaining it raises a generic RemoteException because Exception isn't… | |||||
@pytest.mark.parametrize( | |||||
"status_code, old_exception_schema", | |||||
[(400, False), (500, False), (400, True), (500, True),], | |||||
) | |||||
def test_client_raise_remote_exception( | |||||
rpc_client, requests_mock, status_code, old_exception_schema | |||||
): | |||||
""" | |||||
Exception caught server-side and not whitelisted will be wrapped and raised | |||||
as a RemoteException client-side. | |||||
""" | |||||
error_message = "something went wrong" | |||||
endpoint = "raise_remote_exception" | |||||
requests_mock.post( | |||||
re.compile(f"mock://example.com/{endpoint}"), | |||||
content=_exception_response( | |||||
exception=Exception(error_message), | |||||
status_code=status_code, | |||||
old_exception_schema=old_exception_schema, | |||||
), | |||||
) | |||||
with pytest.raises(RemoteException) as exc_info: | |||||
rpc_client.post(endpoint, data={}) | |||||
assert str(exc_info.value.args[0]["type"]) == "Exception" | |||||
assert str(exc_info.value.args[0]["message"]) == error_message |
with a docstring, explaining it raises a generic RemoteException because Exception isn't whitelisted