diff --git a/swh/core/api/__init__.py b/swh/core/api/__init__.py
--- a/swh/core/api/__init__.py
+++ b/swh/core/api/__init__.py
@@ -429,11 +429,14 @@
 
     def __init__(self, *args, backend_class=None, backend_factory=None, **kwargs):
         super().__init__(*args, **kwargs)
+        if backend_class is None and backend_factory is not None:
+            raise ValueError(
+                "backend_factory should only be provided if backend_class is"
+            )
 
         self.backend_class = backend_class
         if backend_class is not None:
-            if backend_factory is None:
-                backend_factory = backend_class
+            backend_factory = backend_factory or backend_class
             for (meth_name, meth) in backend_class.__dict__.items():
                 if hasattr(meth, "_endpoint_path"):
                     self.__add_endpoint(meth_name, meth, backend_factory)
diff --git a/swh/core/api/tests/test_rpc_server.py b/swh/core/api/tests/test_rpc_server.py
--- a/swh/core/api/tests/test_rpc_server.py
+++ b/swh/core/api/tests/test_rpc_server.py
@@ -25,24 +25,43 @@
 extra_type_decoders = extra_decoders
 
 
+class TestStorage:
+    @remote_api_endpoint("test_endpoint_url")
+    def endpoint_test(self, test_data, db=None, cur=None):
+        assert test_data == "spam"
+        return "egg"
+
+    @remote_api_endpoint("path/to/endpoint")
+    def something(self, data, db=None, cur=None):
+        return data
+
+    @remote_api_endpoint("serializer_test")
+    def serializer_test(self, data, db=None, cur=None):
+        assert data == ["foo", ExtraType("bar", b"baz")]
+        return ExtraType({"spam": "egg"}, "qux")
+
+
 @pytest.fixture
 def app():
-    class TestStorage:
-        @remote_api_endpoint("test_endpoint_url")
-        def test_endpoint(self, test_data, db=None, cur=None):
-            assert test_data == "spam"
-            return "egg"
+    return MyRPCServerApp("testapp", backend_class=TestStorage)
 
-        @remote_api_endpoint("path/to/endpoint")
-        def something(self, data, db=None, cur=None):
-            return data
 
-        @remote_api_endpoint("serializer_test")
-        def serializer_test(self, data, db=None, cur=None):
-            assert data == ["foo", ExtraType("bar", b"baz")]
-            return ExtraType({"spam": "egg"}, "qux")
+def test_api_rpc_server_app_ok(app):
+    assert isinstance(app, MyRPCServerApp)
 
-    return MyRPCServerApp("testapp", backend_class=TestStorage)
+    actual_rpc_server2 = MyRPCServerApp(
+        "app2", backend_class=TestStorage, backend_factory=TestStorage
+    )
+    assert isinstance(actual_rpc_server2, MyRPCServerApp)
+
+    actual_rpc_server3 = MyRPCServerApp("app3")
+    assert isinstance(actual_rpc_server3, MyRPCServerApp)
+
+
+def test_api_rpc_server_app_misconfigured():
+    expected_error = "backend_factory should only be provided if backend_class is"
+    with pytest.raises(ValueError, match=expected_error):
+        MyRPCServerApp("failed-app", backend_factory="something-to-make-it-raise")
 
 
 def test_api_endpoint(flask_app_client):
@@ -82,7 +101,7 @@
 
 def test_rpc_server(flask_app_client):
     res = flask_app_client.post(
-        url_for("test_endpoint"),
+        url_for("endpoint_test"),
         headers=[
             ("Content-Type", "application/x-msgpack"),
             ("Accept", "application/x-msgpack"),