diff --git a/swh/core/api/__init__.py b/swh/core/api/__init__.py --- a/swh/core/api/__init__.py +++ b/swh/core/api/__init__.py @@ -160,11 +160,12 @@ if backend_class: for (meth_name, meth) in backend_class.__dict__.items(): if hasattr(meth, "_endpoint_path"): - cls.__add_endpoint(meth_name, meth, attributes) + http_method = getattr(meth, "_method", "POST") # POST by default + cls.__add_endpoint(http_method, meth_name, meth, attributes) return super().__new__(cls, name, bases, attributes) @staticmethod - def __add_endpoint(meth_name, meth, attributes): + def __add_endpoint(http_method: str, meth_name: str, meth, attributes): wrapped_meth = inspect.unwrap(meth) @functools.wraps(meth) # Copy signature and doc @@ -178,7 +179,11 @@ post_data.pop("db", None) # Send the request. - return self.post(meth._endpoint_path, post_data) + if http_method == "POST": + return self.post(meth._endpoint_path, post_data) + else: + data = post_data or {} + return self.get(meth._endpoint_path, data=data) if meth_name not in attributes: attributes[meth_name] = meth_ @@ -256,15 +261,11 @@ raise self.api_exception(e) def post(self, endpoint, data, **opts): - if isinstance(data, (abc.Iterator, abc.Generator)): - data = (self._encode_data(x) for x in data) - else: - data = self._encode_data(data) chunk_size = opts.pop("chunk_size", self.chunk_size) response = self.raw_verb( "post", endpoint, - data=data, + data=self._encode_data(data), headers={ "content-type": "application/x-msgpack", "accept": "application/x-msgpack", @@ -278,15 +279,34 @@ return self._decode_response(response) def _encode_data(self, data): - return encode_data(data, extra_encoders=self.extra_type_encoders) + if isinstance(data, (abc.Iterator, abc.Generator)): # streamed input: encode item by item, lazily + data = ( + encode_data(x, extra_encoders=self.extra_type_encoders) for x in data + ) + else: + data = encode_data(data, extra_encoders=self.extra_type_encoders) + return data post_stream = post - def get(self, endpoint, **opts): + def get(self, endpoint: str, data=None, **opts): 
chunk_size = opts.pop("chunk_size", self.chunk_size) - response = self.raw_verb( - "get", endpoint, headers={"accept": "application/x-msgpack"}, **opts - ) + if data: # send the msgpack-encoded arguments as the GET request body + response = self.raw_verb( + "get", + endpoint, + headers={ + "accept": "application/x-msgpack", + "content-type": "application/x-msgpack", + }, + data=self._encode_data(data), + **opts, + ) + else: + response = self.raw_verb( + "get", endpoint, headers={"accept": "application/x-msgpack"}, **opts + ) + if opts.get("stream") or response.headers.get("transfer-encoding") == "chunked": self.raise_for_status(response) return response.iter_content(chunk_size) @@ -440,12 +460,13 @@ backend_factory = backend_factory or backend_class for (meth_name, meth) in backend_class.__dict__.items(): if hasattr(meth, "_endpoint_path"): - self.__add_endpoint(meth_name, meth, backend_factory) + http_method = getattr(meth, "_method", "POST") # default to POST + self.__add_endpoint(http_method, meth_name, meth, backend_factory) - def __add_endpoint(self, meth_name, meth, backend_factory): + def __add_endpoint(self, http_method: str, meth_name: str, meth, backend_factory): from flask import request - @self.route("/" + meth._endpoint_path, methods=["POST"]) + @self.route(f"/{meth._endpoint_path}", methods=[http_method]) @negotiate(MsgpackFormatter, extra_encoders=self.extra_type_encoders) @negotiate(JSONFormatter, extra_encoders=self.extra_type_encoders) @functools.wraps(meth) # Copy signature and doc diff --git a/swh/core/api/tests/test_rpc_client.py b/swh/core/api/tests/test_rpc_client.py --- a/swh/core/api/tests/test_rpc_client.py +++ b/swh/core/api/tests/test_rpc_client.py @@ -1,4 +1,4 @@ -# Copyright (C) 2018-2019 The Software Heritage developers +# Copyright (C) 2018-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for 
more information @@ -27,6 +27,10 @@ def serializer_test(self,
data, db=None, cur=None): ... + @remote_api_endpoint("another_endpoint", method="GET") + def some_get_method(self, data): + ... + @remote_api_endpoint("overridden/endpoint") def overridden_method(self, data): return "foo" @@ -39,7 +43,7 @@ def overridden_method(self, data): return "bar" - def callback(request, context): + def callback_post(request, context): assert request.headers["Content-Type"] == "application/x-msgpack" context.headers["Content-Type"] = "application/x-msgpack" if request.path == "/test_endpoint_url": @@ -55,7 +59,17 @@ assert False return context.content - requests_mock.post(re.compile("mock://example.com/"), content=callback) + def callback_get(request, context): + context.headers["Content-Type"] = "application/x-msgpack" + + if request.path == "/another_endpoint": + context.content = b"\xc4\x0eanother-result" # msgpack bin8: 0xc4 marker, 0x0e = 14-byte payload + else: + assert False + return context.content + + requests_mock.post(re.compile("mock://example.com/"), content=callback_post) + requests_mock.get(re.compile("mock://example.com/"), content=callback_get) # GET-method endpoints are answered by callback_get return Testclient(url="mock://example.com") @@ -75,6 +89,9 @@ res = rpc_client.something(data="whatever") assert res == "spam" + res = rpc_client.some_get_method(data="something") + assert res == b"another-result" + def test_client_extra_serializers(rpc_client): res = rpc_client.serializer_test(["foo", ExtraType("bar", b"baz")]) diff --git a/swh/core/api/tests/test_rpc_client_server.py b/swh/core/api/tests/test_rpc_client_server.py --- a/swh/core/api/tests/test_rpc_client_server.py +++ b/swh/core/api/tests/test_rpc_client_server.py @@ -30,6 +30,10 @@ def raise_typeerror(self): raise TypeError("Did I pass through?") + @remote_api_endpoint("stuff", method="GET") + def get_stuff(self, test_input, db=None, cur=None): + return test_input + # this class is used on the client part.
We cannot inherit from RPCTest # because the automagic metaclass based code that generates the RPCClient @@ -54,6 +58,10 @@ def raise_typeerror(self): return "data" + @remote_api_endpoint("stuff", method="GET") + def get_stuff(self, test_input, db=None, cur=None): + return test_input + class RPCTestClient(RPCClient): backend_class = RPCTest2 @@ -98,6 +106,11 @@ assert res == "egg" +def test_api_endpoint_get_stuff(swh_rpc_client): + res = swh_rpc_client.get_stuff("something") + assert res == "something" + + def test_api_endpoint_args(swh_rpc_client): res = swh_rpc_client.something("whatever") assert res == "whatever"