diff --git a/swh/core/api/asynchronous.py b/swh/core/api/asynchronous.py index 4522834..acc8e50 100644 --- a/swh/core/api/asynchronous.py +++ b/swh/core/api/asynchronous.py @@ -1,100 +1,183 @@ # Copyright (C) 2017-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from collections import OrderedDict +import functools import logging -from typing import Tuple, Type +from typing import Callable, Dict, List, Optional, Tuple, Type, Union import aiohttp.web from aiohttp_utils import Response, negotiation from deprecated import deprecated import multidict from .serializers import ( exception_to_dict, json_dumps, json_loads, msgpack_dumps, msgpack_loads, ) def encode_msgpack(data, **kwargs): return aiohttp.web.Response( body=msgpack_dumps(data), headers=multidict.MultiDict({"Content-Type": "application/x-msgpack"}), **kwargs, ) encode_data_server = Response -def render_msgpack(request, data): - return msgpack_dumps(data) +def render_msgpack(request, data, extra_encoders=None): + return msgpack_dumps(data, extra_encoders=extra_encoders) -def render_json(request, data): - return json_dumps(data) +def render_json(request, data, extra_encoders=None): + return json_dumps(data, extra_encoders=extra_encoders) -async def decode_request(request): - content_type = request.headers.get("Content-Type").split(";")[0].strip() - data = await request.read() +def decode_data(data, content_type, extra_decoders=None): + """Decode data according to content type, eventually using some extra decoders. + + """ if not data: return {} if content_type == "application/x-msgpack": - r = msgpack_loads(data) + r = msgpack_loads(data, extra_decoders=extra_decoders) elif content_type == "application/json": - r = json_loads(data) + r = json_loads(data, extra_decoders=extra_decoders) else: - raise ValueError("Wrong content type `%s` for API request" % content_type) + raise ValueError(f"Wrong content type `{content_type}` for API request") + return r +async def decode_request(request, extra_decoders=None): + """Decode asynchronously the request + + """ + data = await request.read() + return decode_data(data, request.content_type, extra_decoders=extra_decoders) + + async def error_middleware(app, handler): async def middleware_handler(request): try: return await handler(request) except Exception as e: if isinstance(e, aiohttp.web.HTTPException): raise logging.exception(e) res = exception_to_dict(e) if isinstance(e, app.client_exception_classes): status = 400 else: status = 500 return encode_data_server(res, status=status) return middleware_handler class RPCServerApp(aiohttp.web.Application): + """For each endpoint of the given `backend_class`, tells app.route to call + a function that decodes the request and sends it to the backend object + provided by the factory. + + :param Any backend_class: + The class of the backend, which will be analyzed to look + for API endpoints. + :param Optional[Callable[[], backend_class]] backend_factory: + A function with no argument that returns an instance of + `backend_class`. If unset, defaults to calling `backend_class` + constructor directly. + """ + client_exception_classes: Tuple[Type[Exception], ...] = () """Exceptions that should be handled as a client error (eg. object not found, invalid argument)""" - - def __init__(self, *args, middlewares=(), **kwargs): - middlewares = (error_middleware,) + middlewares - # renderers are sorted in order of increasing desirability (!) - # see mimeparse.best_match() docstring. - renderers = OrderedDict( + extra_type_encoders: List[Tuple[type, str, Callable]] = [] + """Value of `extra_encoders` passed to `json_dumps` or `msgpack_dumps` + to be able to serialize more object types.""" + extra_type_decoders: Dict[str, Callable] = {} + """Value of `extra_decoders` passed to `json_loads` or `msgpack_loads` + to be able to deserialize more object types.""" + + def __init__( + self, + app_name: Optional[str] = None, + backend_class: Optional[Callable] = None, + backend_factory: Optional[Union[Callable, str]] = None, + middlewares=(), + **kwargs, + ): + nego_middleware = negotiation.negotiation_middleware( + renderers=self._renderers(), force_rendering=True + ) + middlewares = (nego_middleware, error_middleware,) + middlewares + super().__init__(middlewares=middlewares, **kwargs) + + # swh decorations starts here + self.app_name = app_name + if backend_class is None and backend_factory is not None: + raise ValueError( + "backend_factory should only be provided if backend_class is" + ) + self.backend_class = backend_class + if backend_class is not None: + backend_factory = backend_factory or backend_class + for (meth_name, meth) in backend_class.__dict__.items(): + if hasattr(meth, "_endpoint_path"): + path = meth._endpoint_path + path = path if path.startswith("/") else f"/{path}" + self.router.add_route( + "POST", path, self._endpoint(meth_name, meth, backend_factory) + ) + + def _renderers(self): + """Return an ordered list of renderers in order of increasing desirability (!) + See mimetype.best_match() docstring + + """ + return OrderedDict( [ - ("application/json", render_json), - ("application/x-msgpack", render_msgpack), + ( + "application/json", + lambda request, data: render_json( + request, data, extra_encoders=self.extra_type_encoders + ), + ), + ( + "application/x-msgpack", + lambda request, data: render_msgpack( + request, data, extra_encoders=self.extra_type_encoders + ), + ), ] ) - nego_middleware = negotiation.negotiation_middleware( - renderers=renderers, force_rendering=True - ) - middlewares = (nego_middleware,) + middlewares - super().__init__(*args, middlewares=middlewares, **kwargs) + def _endpoint(self, meth_name, meth, backend_factory): + """Create endpoint out of the method `meth`. + + """ + + @functools.wraps(meth) # Copy signature and doc + async def decorated_meth(request, *args, **kwargs): + obj_meth = getattr(backend_factory(), meth_name) + data = await request.read() + kw = decode_data( + data, request.content_type, extra_decoders=self.extra_type_decoders + ) + result = obj_meth(**kw) + return encode_data_server(result) + + return decorated_meth @deprecated(version="0.0.64", reason="Use the RPCServerApp instead") class SWHRemoteAPI(RPCServerApp): pass diff --git a/swh/core/api/tests/conftest.py b/swh/core/api/tests/conftest.py new file mode 100644 index 0000000..b7566b6 --- /dev/null +++ b/swh/core/api/tests/conftest.py @@ -0,0 +1,3 @@ +# This is coming from the aiohttp library directly. Beware the desynchronized +# https://github.com/aio-libs/pytest-aiohttp module which wraps that library... +pytest_plugins = ["aiohttp.pytest_plugin"] diff --git a/swh/core/api/tests/test_async.py b/swh/core/api/tests/test_async.py index 8923222..dafbb81 100644 --- a/swh/core/api/tests/test_async.py +++ b/swh/core/api/tests/test_async.py @@ -1,233 +1,256 @@ # Copyright (C) 2019-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import datetime import json import msgpack import pytest from swh.core.api.asynchronous import ( Response, RPCServerApp, + decode_data, decode_request, encode_msgpack, ) -from swh.core.api.serializers import SWHJSONEncoder, msgpack_dumps +from swh.core.api.serializers import SWHJSONEncoder, json_dumps, msgpack_dumps pytest_plugins = ["aiohttp.pytest_plugin", "pytester"] class TestServerException(Exception): pass class TestClientError(Exception): pass async def root(request): return Response("toor") STRUCT = { "txt": "something stupid", # 'date': datetime.date(2019, 6, 9), # not supported "datetime": datetime.datetime(2019, 6, 9, 10, 12, tzinfo=datetime.timezone.utc), "timedelta": datetime.timedelta(days=-2, hours=3), "int": 42, "float": 3.14, "subdata": { "int": 42, "datetime": datetime.datetime( 2019, 6, 10, 11, 12, tzinfo=datetime.timezone.utc ), }, "list": [ 42, datetime.datetime(2019, 9, 10, 11, 12, tzinfo=datetime.timezone.utc), "ok", ], } async def struct(request): return Response(STRUCT) async def echo(request): data = await decode_request(request) return Response(data) async def server_exception(request): raise TestServerException() async def client_error(request): raise TestClientError() async def echo_no_nego(request): # let the content negotiation handle the serialization for us... data = await decode_request(request) ret = encode_msgpack(data) return ret def check_mimetype(src, dst): src = src.split(";")[0].strip() dst = dst.split(";")[0].strip() assert src == dst @pytest.fixture def async_app(): app = RPCServerApp() app.client_exception_classes = (TestClientError,) app.router.add_route("GET", "/", root) app.router.add_route("GET", "/struct", struct) app.router.add_route("POST", "/echo", echo) app.router.add_route("GET", "/server_exception", server_exception) app.router.add_route("GET", "/client_error", client_error) app.router.add_route("POST", "/echo-no-nego", echo_no_nego) return app @pytest.fixture def cli(async_app, aiohttp_client, loop): return loop.run_until_complete(aiohttp_client(async_app)) async def test_get_simple(cli) -> None: resp = await cli.get("/") assert resp.status == 200 check_mimetype(resp.headers["Content-Type"], "application/x-msgpack") data = await resp.read() value = msgpack.unpackb(data, raw=False) assert value == "toor" async def test_get_server_exception(cli) -> None: resp = await cli.get("/server_exception") assert resp.status == 500 data = await resp.read() data = msgpack.unpackb(data, raw=False) assert data["exception"]["type"] == "TestServerException" async def test_get_client_error(cli) -> None: resp = await cli.get("/client_error") assert resp.status == 400 data = await resp.read() data = msgpack.unpackb(data, raw=False) assert data["exception"]["type"] == "TestClientError" async def test_get_simple_nego(cli) -> None: for ctype in ("x-msgpack", "json"): resp = await cli.get("/", headers={"Accept": "application/%s" % ctype}) assert resp.status == 200 check_mimetype(resp.headers["Content-Type"], "application/%s" % ctype) assert (await decode_request(resp)) == "toor" async def test_get_struct(cli) -> None: """Test returned structured from a simple GET data is OK""" resp = await cli.get("/struct") assert resp.status == 200 check_mimetype(resp.headers["Content-Type"], "application/x-msgpack") assert (await decode_request(resp)) == STRUCT async def test_get_struct_nego(cli) -> None: """Test returned structured from a simple GET data is OK""" for ctype in ("x-msgpack", "json"): resp = await cli.get("/struct", headers={"Accept": "application/%s" % ctype}) assert resp.status == 200 check_mimetype(resp.headers["Content-Type"], "application/%s" % ctype) assert (await decode_request(resp)) == STRUCT async def test_post_struct_msgpack(cli) -> None: """Test that msgpack encoded posted struct data is returned as is""" # simple struct resp = await cli.post( "/echo", headers={"Content-Type": "application/x-msgpack"}, data=msgpack_dumps({"toto": 42}), ) assert resp.status == 200 check_mimetype(resp.headers["Content-Type"], "application/x-msgpack") assert (await decode_request(resp)) == {"toto": 42} # complex struct resp = await cli.post( "/echo", headers={"Content-Type": "application/x-msgpack"}, data=msgpack_dumps(STRUCT), ) assert resp.status == 200 check_mimetype(resp.headers["Content-Type"], "application/x-msgpack") assert (await decode_request(resp)) == STRUCT async def test_post_struct_json(cli) -> None: """Test that json encoded posted struct data is returned as is""" resp = await cli.post( "/echo", headers={"Content-Type": "application/json"}, data=json.dumps({"toto": 42}, cls=SWHJSONEncoder), ) assert resp.status == 200 check_mimetype(resp.headers["Content-Type"], "application/x-msgpack") assert (await decode_request(resp)) == {"toto": 42} resp = await cli.post( "/echo", headers={"Content-Type": "application/json"}, data=json.dumps(STRUCT, cls=SWHJSONEncoder), ) assert resp.status == 200 check_mimetype(resp.headers["Content-Type"], "application/x-msgpack") # assert resp.headers['Content-Type'] == 'application/x-msgpack' assert (await decode_request(resp)) == STRUCT async def test_post_struct_nego(cli) -> None: """Test that json encoded posted struct data is returned as is using content negotiation (accept json or msgpack). """ for ctype in ("x-msgpack", "json"): resp = await cli.post( "/echo", headers={ "Content-Type": "application/json", "Accept": "application/%s" % ctype, }, data=json.dumps(STRUCT, cls=SWHJSONEncoder), ) assert resp.status == 200 check_mimetype(resp.headers["Content-Type"], "application/%s" % ctype) assert (await decode_request(resp)) == STRUCT async def test_post_struct_no_nego(cli) -> None: """Test that json encoded posted struct data is returned as msgpack when using non-negotiation-compatible handlers. """ for ctype in ("x-msgpack", "json"): resp = await cli.post( "/echo-no-nego", headers={ "Content-Type": "application/json", "Accept": "application/%s" % ctype, }, data=json.dumps(STRUCT, cls=SWHJSONEncoder), ) assert resp.status == 200 check_mimetype(resp.headers["Content-Type"], "application/x-msgpack") assert (await decode_request(resp)) == STRUCT + + +def test_async_decode_data_failure(): + with pytest.raises(ValueError, match="Wrong content type"): + decode_data("some-data", "unknown-content-type") + + +@pytest.mark.parametrize("data", [None, "", {}, []]) +def test_async_decode_data_empty_cases(data): + assert decode_data(data, "unknown-content-type") == {} + + +@pytest.mark.parametrize( + "data,content_type,encode_data_fn", + [ + ({"a": 1}, "application/json", json_dumps), + ({"a": 1}, "application/x-msgpack", msgpack_dumps), + ], +) +def test_async_decode_data_nominal(data, content_type, encode_data_fn): + actual_data = decode_data(encode_data_fn(data), content_type) + assert actual_data == data diff --git a/swh/core/api/tests/test_rpc_server_asynchronous.py b/swh/core/api/tests/test_rpc_server_asynchronous.py new file mode 100644 index 0000000..bfa80bf --- /dev/null +++ b/swh/core/api/tests/test_rpc_server_asynchronous.py @@ -0,0 +1,153 @@ +# Copyright (C) 2018-2020 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + +import pytest + +from swh.core.api import remote_api_endpoint +from swh.core.api.asynchronous import RPCServerApp +from swh.core.api.serializers import json_dumps, msgpack_dumps + +from .test_serializers import ExtraType, extra_decoders, extra_encoders + + +class MyRPCServerApp(RPCServerApp): + extra_type_encoders = extra_encoders + extra_type_decoders = extra_decoders + + +class BackendStorageTest: + """Backend Storage to use as backend class of the rpc server (test only)""" + + @remote_api_endpoint("test_endpoint_url") + def test_endpoint(self, test_data, db=None, cur=None): + assert test_data == "spam" + return "egg" + + @remote_api_endpoint("path/to/identity") + def identity(self, data, db=None, cur=None): + return data + + @remote_api_endpoint("serializer_test") + def serializer_test(self, data, db=None, cur=None): + assert data == ["foo", ExtraType("bar", b"baz")] + return ExtraType({"spam": "egg"}, "qux") + + +@pytest.fixture +def async_app(): + return MyRPCServerApp("testapp", backend_class=BackendStorageTest) + + +def test_api_async_rpc_server_app_ok(async_app): + assert isinstance(async_app, MyRPCServerApp) + + actual_rpc_server2 = MyRPCServerApp( + "app2", backend_class=BackendStorageTest, backend_factory=BackendStorageTest + ) + assert isinstance(actual_rpc_server2, MyRPCServerApp) + + actual_rpc_server3 = MyRPCServerApp("app3") + assert isinstance(actual_rpc_server3, MyRPCServerApp) + + +def test_api_async_rpc_server_app_misconfigured(): + expected_error = "backend_factory should only be provided if backend_class is" + with pytest.raises(ValueError, match=expected_error): + MyRPCServerApp("failed-app", backend_factory="something-to-make-it-raise") + + +@pytest.fixture +def cli(loop, aiohttp_client, async_app): + """aiohttp client fixture to ease testing + + source: https://docs.aiohttp.org/en/stable/testing.html + """ + loop.set_debug(True) + return loop.run_until_complete(aiohttp_client(async_app)) + + +async def test_api_async_endpoint(cli, async_app): + res = await cli.post( + "/path/to/identity", + headers=[("Content-Type", "application/json"), ("Accept", "application/json")], + data=json_dumps({"data": "toto"}), + ) + assert res.status == 200 + assert res.content_type == "application/json" + assert await res.read() == json_dumps("toto").encode() + + +async def test_api_async_nego_default_msgpack(cli): + res = await cli.post( + "/path/to/identity", + headers=[("Content-Type", "application/json")], + data=json_dumps({"data": "toto"}), + ) + assert res.status == 200 + assert res.content_type == "application/x-msgpack" + assert await res.read() == msgpack_dumps("toto") + + +async def test_api_async_nego_default(cli): + res = await cli.post( + "/path/to/identity", + headers=[ + ("Content-Type", "application/json"), + ("Accept", "application/x-msgpack"), + ], + data=json_dumps({"data": "toto"}), + ) + assert res.status == 200 + assert res.content_type == "application/x-msgpack" + assert await res.read() == msgpack_dumps("toto") + + +async def test_api_async_nego_accept(cli): + res = await cli.post( + "/path/to/identity", + headers=[ + ("Accept", "application/x-msgpack"), + ("Content-Type", "application/x-msgpack"), + ], + data=msgpack_dumps({"data": "toto"}), + ) + assert res.status == 200 + assert res.content_type == "application/x-msgpack" + assert await res.read() == msgpack_dumps("toto") + + +async def test_api_async_rpc_server(cli): + res = await cli.post( + "/test_endpoint_url", + headers=[ + ("Content-Type", "application/x-msgpack"), + ("Accept", "application/x-msgpack"), + ], + data=msgpack_dumps({"test_data": "spam"}), + ) + + assert res.status == 200 + assert res.content_type == "application/x-msgpack" + assert await res.read() == msgpack_dumps("egg") + + +async def test_api_async_rpc_server_extra_serializers(cli): + res = await cli.post( + "/serializer_test", + headers=[ + ("Content-Type", "application/x-msgpack"), + ("Accept", "application/x-msgpack"), + ], + data=( + b"\x81\xa4data\x92\xa3foo\x82\xc4\x07swhtype\xa9extratype" + b"\xc4\x01d\x92\xa3bar\xc4\x03baz" + ), + ) + + assert res.status == 200 + assert res.content_type == "application/x-msgpack" + assert await res.read() == ( + b"\x82\xc4\x07swhtype\xa9extratype\xc4\x01d\x92\x81\xa4spam\xa3egg\xa3qux" + ) diff --git a/tox.ini b/tox.ini index 4e0be5e..b051c9d 100644 --- a/tox.ini +++ b/tox.ini @@ -1,52 +1,53 @@ [tox] envlist=black,flake8,mypy,py3-{core,db,server} [testenv] +passenv = PYTHONASYNCIODEBUG extras = testing-core core: logging db: db, testing-db server: http deps = db: pifpaf cover: pytest-cov commands = db: pifpaf run postgresql -- \ pytest --doctest-modules \ slow: --hypothesis-profile=slow \ cover: --cov={envsitepackagesdir}/swh/core --cov-branch \ core: {envsitepackagesdir}/swh/core/tests \ db: {envsitepackagesdir}/swh/core/db/tests \ server: {envsitepackagesdir}/swh/core/api/tests \ {posargs} [testenv:py3] skip_install = true deps = tox commands = tox -e py3-core-db-server-slow-cover -- {posargs} [testenv:black] skip_install = true deps = black==19.10b0 commands = {envpython} -m black --check swh [testenv:flake8] skip_install = true deps = flake8 commands = {envpython} -m flake8 [testenv:mypy] extras = testing-core logging db testing-db http deps = mypy commands = mypy swh