diff --git a/swh/core/api/asynchronous.py b/swh/core/api/asynchronous.py
--- a/swh/core/api/asynchronous.py
+++ b/swh/core/api/asynchronous.py
@@ -4,8 +4,9 @@
 # See top-level LICENSE file for more information
 
 from collections import OrderedDict
+import functools
 import logging
-from typing import Tuple, Type
+from typing import Callable, Dict, List, Optional, Tuple, Type
 
 import aiohttp.web
 from aiohttp_utils import Response, negotiation
@@ -32,28 +33,41 @@
 encode_data_server = Response
 
 
-def render_msgpack(request, data):
-    return msgpack_dumps(data)
+def render_msgpack(request, data, extra_encoders=None):
+    return msgpack_dumps(data, extra_encoders=extra_encoders)
 
 
-def render_json(request, data):
-    return json_dumps(data)
+def render_json(request, data, extra_encoders=None):
+    return json_dumps(data, extra_encoders=extra_encoders)
 
 
-async def decode_request(request):
-    content_type = request.headers.get("Content-Type").split(";")[0].strip()
-    data = await request.read()
+def decode_data(data, content_type, extra_decoders=None):
+    """Decode data according to its content type, optionally using extra decoders.
+
+    Empty data (e.g. an empty request body) decodes to an empty dict.
+
+    :raises ValueError: if `data` is not empty and `content_type` is unsupported
+    """
     if not data:
         return {}
     if content_type == "application/x-msgpack":
-        r = msgpack_loads(data)
+        r = msgpack_loads(data, extra_decoders=extra_decoders)
     elif content_type == "application/json":
-        r = json_loads(data)
+        r = json_loads(data, extra_decoders=extra_decoders)
     else:
-        raise ValueError("Wrong content type `%s` for API request" % content_type)
+        raise ValueError(f"Wrong content type `{content_type}` for API request")
+
     return r
 
 
+async def decode_request(request, extra_decoders=None):
+    """Asynchronously read the body of `request` and decode it with `decode_data`.
+
+    """
+    data = await request.read()
+    return decode_data(data, request.content_type, extra_decoders=extra_decoders)
+
+
 async def error_middleware(app, handler):
     async def middleware_handler(request):
         try:
@@ -73,26 +87,105 @@
 
 
 class RPCServerApp(aiohttp.web.Application):
+    """aiohttp application serving the API endpoints of a backend class.
+
+    For each method of `backend_class` carrying an `_endpoint_path` attribute
+    (as set by `remote_api_endpoint`), a POST route is registered on that
+    endpoint path. Its handler decodes the request body and passes it as
+    keyword arguments to the same method of a backend object obtained by
+    calling `backend_factory`, which happens once per request.
+
+    :param Optional[str] app_name: application name, stored as `self.app_name`
+    :param Any backend_class:
+        The class of the backend, which will be analyzed to look
+        for API endpoints.
+    :param Optional[Callable[[], backend_class]] backend_factory:
+        A function with no argument that returns an instance of
+        `backend_class`. If unset, defaults to calling `backend_class`
+        constructor directly.
+    """
+
     client_exception_classes: Tuple[Type[Exception], ...] = ()
     """Exceptions that should be handled as a client error (eg. object not
     found, invalid argument)"""
-
-    def __init__(self, *args, middlewares=(), **kwargs):
+    extra_type_encoders: List[Tuple[type, str, Callable]] = []
+    """Value of `extra_encoders` passed to `json_dumps` or `msgpack_dumps`
+    to be able to serialize more object types."""
+    extra_type_decoders: Dict[str, Callable] = {}
+    """Value of `extra_decoders` passed to `json_loads` or `msgpack_loads`
+    to be able to deserialize more object types."""
+
+    def __init__(
+        self,
+        app_name: Optional[str] = None,
+        backend_class: Optional[Callable] = None,
+        backend_factory: Optional[Callable] = None,
+        middlewares=(),
+        **kwargs,
+    ):
         middlewares = (error_middleware,) + middlewares
+
         # renderers are sorted in order of increasing desirability (!)
         # see mimeparse.best_match() docstring.
         renderers = OrderedDict(
             [
-                ("application/json", render_json),
-                ("application/x-msgpack", render_msgpack),
+                (
+                    "application/json",
+                    lambda request, data: render_json(
+                        request, data, extra_encoders=self.extra_type_encoders
+                    ),
+                ),
+                (
+                    "application/x-msgpack",
+                    lambda request, data: render_msgpack(
+                        request, data, extra_encoders=self.extra_type_encoders
+                    ),
+                ),
             ]
         )
         nego_middleware = negotiation.negotiation_middleware(
             renderers=renderers, force_rendering=True
         )
         middlewares = (nego_middleware,) + middlewares
-
-        super().__init__(*args, middlewares=middlewares, **kwargs)
+        super().__init__(middlewares=middlewares, **kwargs)
+
+        # swh decorations start here
+        self.app_name = app_name
+
+        # FIXME: Can't we transform this into a middleware?
+        self.backend_class = backend_class
+        if backend_class is not None:
+            if backend_factory is None:
+                backend_factory = backend_class
+            for (meth_name, meth) in backend_class.__dict__.items():
+                if hasattr(meth, "_endpoint_path"):
+                    # Route on the declared endpoint path (the one RPC clients
+                    # post to), not on the Python method name.
+                    path = meth._endpoint_path
+                    path = path if path.startswith("/") else f"/{path}"
+                    self.router.add_route(
+                        "POST", path, self._endpoint(meth_name, meth, backend_factory)
+                    )
+
+    def _endpoint(self, meth_name, meth, backend_factory):
+        """Build the aiohttp handler serving the backend method `meth_name`.
+
+        On each request, a backend object is obtained by calling
+        `backend_factory`, the request body is decoded (with
+        `extra_type_decoders`) into keyword arguments for its `meth_name`
+        method, and the result is wrapped in `encode_data_server` so that the
+        negotiation middleware renders it in the format the client accepts.
+        """
+
+        @functools.wraps(meth)  # Copy signature and doc
+        async def decorated_meth(request, *args, **kwargs):
+            obj_meth = getattr(backend_factory(), meth_name)
+            kw = await decode_request(request, extra_decoders=self.extra_type_decoders)
+            # NOTE: synchronous call; the event loop is blocked while it runs.
+            result = obj_meth(**kw)
+            return encode_data_server(result)
+
+        return decorated_meth
 
 
 @deprecated(version="0.0.64", reason="Use the RPCServerApp instead")
diff --git a/swh/core/api/tests/conftest.py b/swh/core/api/tests/conftest.py
new file mode 100644
--- /dev/null
+++ b/swh/core/api/tests/conftest.py
@@ -0,0 +1,3 @@
+# This is coming from the aiohttp library directly. Beware the desynchronized
+# https://github.com/aio-libs/pytest-aiohttp module which wraps that library...
+pytest_plugins = ["aiohttp.pytest_plugin"]
diff --git a/swh/core/api/tests/test_async.py b/swh/core/api/tests/test_async.py
--- a/swh/core/api/tests/test_async.py
+++ b/swh/core/api/tests/test_async.py
@@ -12,10 +12,11 @@
 from swh.core.api.asynchronous import (
     Response,
     RPCServerApp,
+    decode_data,
     decode_request,
     encode_msgpack,
 )
-from swh.core.api.serializers import SWHJSONEncoder, msgpack_dumps
+from swh.core.api.serializers import SWHJSONEncoder, json_dumps, msgpack_dumps
 
 pytest_plugins = ["aiohttp.pytest_plugin", "pytester"]
 
@@ -231,3 +232,25 @@
     assert resp.status == 200
     check_mimetype(resp.headers["Content-Type"], "application/x-msgpack")
     assert (await decode_request(resp)) == STRUCT
+
+
+def test_async_decode_data_failure():
+    with pytest.raises(ValueError, match="Wrong content type"):
+        decode_data("some-data", "unknown-content-type")
+
+
+@pytest.mark.parametrize("data", [None, "", {}, []])
+def test_async_decode_data_empty_cases(data):
+    assert decode_data(data, "unknown-content-type") == {}
+
+
+@pytest.mark.parametrize(
+    "data,content_type,encode_data_fn",
+    [
+        ({"a": 1}, "application/json", json_dumps),
+        ({"a": 1}, "application/x-msgpack", msgpack_dumps),
+    ],
+)
+def test_async_decode_data_nominal(data, content_type, encode_data_fn):
+    actual_data = decode_data(encode_data_fn(data), content_type)
+    assert actual_data == data
diff --git a/swh/core/api/tests/test_rpc_server_asynchronous.py b/swh/core/api/tests/test_rpc_server_asynchronous.py
new file mode 100644
--- /dev/null
+++ b/swh/core/api/tests/test_rpc_server_asynchronous.py
@@ -0,0 +1,132 @@
+# Copyright (C) 2018-2020 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+import pytest
+
+from swh.core.api import remote_api_endpoint
+from swh.core.api.asynchronous import RPCServerApp
+from swh.core.api.serializers import json_dumps, msgpack_dumps
+
+from .test_serializers import ExtraType, extra_decoders, extra_encoders
+
+
+class MyRPCServerApp(RPCServerApp):
+    extra_type_encoders = extra_encoders
+    extra_type_decoders = extra_decoders
+
+
+@pytest.fixture
+def async_app():
+    class TestStorage:
+        @remote_api_endpoint("test_endpoint_url")
+        def test_endpoint(self, test_data, db=None, cur=None):
+            assert test_data == "spam"
+            return "egg"
+
+        @remote_api_endpoint("path/to/endpoint")
+        def identity(self, data, db=None, cur=None):
+            return data
+
+        @remote_api_endpoint("serializer_test")
+        def serializer_test(self, data, db=None, cur=None):
+            assert data == ["foo", ExtraType("bar", b"baz")]
+            return ExtraType({"spam": "egg"}, "qux")
+
+    return MyRPCServerApp("testapp", backend_class=TestStorage)
+
+
+@pytest.fixture
+def cli(loop, aiohttp_client, async_app):
+    """aiohttp client fixture to ease testing
+
+    source: https://docs.aiohttp.org/en/stable/testing.html
+    """
+    loop.set_debug(True)
+    return loop.run_until_complete(aiohttp_client(async_app))
+
+
+async def test_api_async_endpoint(cli, async_app):
+    res = await cli.post(
+        "/path/to/endpoint",
+        headers=[("Content-Type", "application/json"), ("Accept", "application/json")],
+        data=json_dumps({"data": "toto"}),
+    )
+    assert res.status == 200
+    assert res.content_type == "application/json"
+    assert await res.read() == json_dumps("toto").encode()
+
+
+async def test_api_async_nego_default_msgpack(cli):
+    res = await cli.post(
+        "/path/to/endpoint",
+        headers=[("Content-Type", "application/json")],
+        data=json_dumps({"data": "toto"}),
+    )
+    assert res.status == 200
+    assert res.content_type == "application/x-msgpack"
+    assert await res.read() == msgpack_dumps("toto")
+
+
+async def test_api_async_nego_default(cli):
+    res = await cli.post(
+        "/path/to/endpoint",
+        headers=[
+            ("Content-Type", "application/json"),
+            ("Accept", "application/x-msgpack"),
+        ],
+        data=json_dumps({"data": "toto"}),
+    )
+    assert res.status == 200
+    assert res.content_type == "application/x-msgpack"
+    assert await res.read() == msgpack_dumps("toto")
+
+
+async def test_api_async_nego_accept(cli):
+    res = await cli.post(
+        "/path/to/endpoint",
+        headers=[
+            ("Accept", "application/x-msgpack"),
+            ("Content-Type", "application/x-msgpack"),
+        ],
+        data=msgpack_dumps({"data": "toto"}),
+    )
+    assert res.status == 200
+    assert res.content_type == "application/x-msgpack"
+    assert await res.read() == msgpack_dumps("toto")
+
+
+async def test_api_async_rpc_server(cli):
+    res = await cli.post(
+        "/test_endpoint_url",
+        headers=[
+            ("Content-Type", "application/x-msgpack"),
+            ("Accept", "application/x-msgpack"),
+        ],
+        data=msgpack_dumps({"test_data": "spam"}),
+    )
+
+    assert res.status == 200
+    assert res.content_type == "application/x-msgpack"
+    assert await res.read() == msgpack_dumps("egg")
+
+
+async def test_api_async_rpc_server_extra_serializers(cli):
+    res = await cli.post(
+        "/serializer_test",
+        headers=[
+            ("Content-Type", "application/x-msgpack"),
+            ("Accept", "application/x-msgpack"),
+        ],
+        data=(
+            b"\x81\xa4data\x92\xa3foo\x82\xc4\x07swhtype\xa9extratype"
+            b"\xc4\x01d\x92\xa3bar\xc4\x03baz"
+        ),
+    )
+
+    assert res.status == 200
+    assert res.content_type == "application/x-msgpack"
+    assert await res.read() == (
+        b"\x82\xc4\x07swhtype\xa9extratype\xc4\x01d\x92\x81\xa4spam\xa3egg\xa3qux"
+    )
diff --git a/tox.ini b/tox.ini
--- a/tox.ini
+++ b/tox.ini
@@ -2,6 +2,7 @@
 envlist=black,flake8,mypy,py3-{core,db,server}
 
 [testenv]
+passenv = PYTHONASYNCIODEBUG
 extras =
   testing-core
   core: logging