diff --git a/swh/core/api/__init__.py b/swh/core/api/__init__.py
--- a/swh/core/api/__init__.py
+++ b/swh/core/api/__init__.py
@@ -123,9 +123,10 @@
 F = TypeVar("F", bound=Callable)
 
 
-def remote_api_endpoint(path) -> Callable[[F], F]:
+def remote_api_endpoint(path: str, method: str = "POST") -> Callable[[F], F]:
     def dec(f: F) -> F:
         f._endpoint_path = path  # type: ignore
+        f._method = method  # type: ignore
         return f
 
     return dec
diff --git a/swh/core/api/asynchronous.py b/swh/core/api/asynchronous.py
--- a/swh/core/api/asynchronous.py
+++ b/swh/core/api/asynchronous.py
@@ -133,9 +133,13 @@
             for (meth_name, meth) in backend_class.__dict__.items():
                 if hasattr(meth, "_endpoint_path"):
                     path = meth._endpoint_path
+                    # default to POST for callables tagged without an HTTP method
+                    http_method = getattr(meth, "_method", "POST")
                     path = path if path.startswith("/") else f"/{path}"
                     self.router.add_route(
-                        "POST", path, self._endpoint(meth_name, meth, backend_factory)
+                        http_method,
+                        path,
+                        self._endpoint(meth_name, meth, backend_factory),
                     )
 
     def _renderers(self):
diff --git a/swh/core/api/tests/test_init.py b/swh/core/api/tests/test_init.py
new file mode 100644
--- /dev/null
+++ b/swh/core/api/tests/test_init.py
@@ -0,0 +1,29 @@
+# Copyright (C) 2017-2020 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+
+from swh.core.api import remote_api_endpoint
+
+
+def test_remote_api_endpoint():
+    @remote_api_endpoint("hello_route")
+    def hello():
+        pass
+
+    assert hasattr(hello, "_endpoint_path")
+    assert hello._endpoint_path == "hello_route"
+    assert hasattr(hello, "_method")
+    assert hello._method == "POST"
+
+
+def test_remote_api_endpoint_custom_method():
+    @remote_api_endpoint("another_route", method="GET")
+    def hello2():
+        pass
+
+    assert hasattr(hello2, "_endpoint_path")
+    assert hello2._endpoint_path == "another_route"
+    assert hasattr(hello2, "_method")
+    assert hello2._method == "GET"
diff --git a/swh/core/api/tests/test_rpc_server_asynchronous.py b/swh/core/api/tests/test_rpc_server_asynchronous.py
--- a/swh/core/api/tests/test_rpc_server_asynchronous.py
+++ b/swh/core/api/tests/test_rpc_server_asynchronous.py
@@ -20,7 +20,7 @@
 class BackendStorageTest:
     """Backend Storage to use as backend class of the rpc server (test only)"""
 
-    @remote_api_endpoint("test_endpoint_url")
+    @remote_api_endpoint("test_endpoint_url", method="GET")
     def test_endpoint(self, test_data, db=None, cur=None):
         assert test_data == "spam"
         return "egg"
@@ -119,7 +119,7 @@
 
 
 async def test_api_async_rpc_server(cli):
-    res = await cli.post(
+    res = await cli.get(
         "/test_endpoint_url",
         headers=[
             ("Content-Type", "application/x-msgpack"),