diff --git a/swh/core/api/__init__.py b/swh/core/api/__init__.py index f88f959..8ffc3b0 100644 --- a/swh/core/api/__init__.py +++ b/swh/core/api/__init__.py @@ -1,455 +1,456 @@ # Copyright (C) 2015-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from collections import abc import functools import inspect import logging import pickle from typing import ( Any, Callable, ClassVar, Dict, List, Optional, Tuple, Type, TypeVar, Union, ) from flask import Flask, Request, Response, abort, request import requests from werkzeug.exceptions import HTTPException from .negotiation import Formatter as FormatterBase from .negotiation import Negotiator as NegotiatorBase from .negotiation import negotiate as _negotiate from .serializers import ( exception_to_dict, json_dumps, json_loads, msgpack_dumps, msgpack_loads, ) from .serializers import decode_response from .serializers import encode_data_client as encode_data logger = logging.getLogger(__name__) # support for content negotiation class Negotiator(NegotiatorBase): def best_mimetype(self): return request.accept_mimetypes.best_match( self.accept_mimetypes, "application/json" ) def _abort(self, status_code, err=None): return abort(status_code, err) def negotiate(formatter_cls, *args, **kwargs): return _negotiate(Negotiator, formatter_cls, *args, **kwargs) class Formatter(FormatterBase): def _make_response(self, body, content_type): return Response(body, content_type=content_type) def configure(self, extra_encoders=None): self.extra_encoders = extra_encoders class JSONFormatter(Formatter): format = "json" mimetypes = ["application/json"] def render(self, obj): return json_dumps(obj, extra_encoders=self.extra_encoders) class MsgpackFormatter(Formatter): format = "msgpack" mimetypes = ["application/x-msgpack"] def render(self, obj): return msgpack_dumps(obj, extra_encoders=self.extra_encoders) # base API classes class RemoteException(Exception): """raised when remote returned an out-of-band failure notification, e.g., as a HTTP status code or serialized exception Attributes: response: HTTP response corresponding to the failure """ def __init__( self, payload: Optional[Any] = None, response: Optional[requests.Response] = None, ): if payload is not None: super().__init__(payload) else: super().__init__() self.response = response def __str__(self): if ( self.args and isinstance(self.args[0], dict) and "type" in self.args[0] and "args" in self.args[0] ): return ( f"' ) else: return super().__str__() F = TypeVar("F", bound=Callable) -def remote_api_endpoint(path) -> Callable[[F], F]: +def remote_api_endpoint(path: str, method: str = "POST") -> Callable[[F], F]: def dec(f: F) -> F: f._endpoint_path = path # type: ignore + f._method = method # type: ignore return f return dec class APIError(Exception): """API Error""" def __str__(self): return "An unexpected error occurred in the backend: {}".format(self.args) class MetaRPCClient(type): """Metaclass for RPCClient, which adds a method for each endpoint of the database it is designed to access. See for example :class:`swh.indexer.storage.api.client.RemoteStorage`""" def __new__(cls, name, bases, attributes): # For each method wrapped with @remote_api_endpoint in an API backend # (eg. :class:`swh.indexer.storage.IndexerStorage`), add a new # method in RemoteStorage, with the same documentation. # # Note that, despite the usage of decorator magic (eg. functools.wrap), # this never actually calls an IndexerStorage method. backend_class = attributes.get("backend_class", None) for base in bases: if backend_class is not None: break backend_class = getattr(base, "backend_class", None) if backend_class: for (meth_name, meth) in backend_class.__dict__.items(): if hasattr(meth, "_endpoint_path"): cls.__add_endpoint(meth_name, meth, attributes) return super().__new__(cls, name, bases, attributes) @staticmethod def __add_endpoint(meth_name, meth, attributes): wrapped_meth = inspect.unwrap(meth) @functools.wraps(meth) # Copy signature and doc def meth_(*args, **kwargs): # Match arguments and parameters post_data = inspect.getcallargs(wrapped_meth, *args, **kwargs) # Remove arguments that should not be passed self = post_data.pop("self") post_data.pop("cur", None) post_data.pop("db", None) # Send the request. return self.post(meth._endpoint_path, post_data) if meth_name not in attributes: attributes[meth_name] = meth_ class RPCClient(metaclass=MetaRPCClient): """Proxy to an internal SWH RPC """ backend_class = None # type: ClassVar[Optional[type]] """For each method of `backend_class` decorated with :func:`remote_api_endpoint`, a method with the same prototype and docstring will be added to this class. Calls to this new method will be translated into HTTP requests to a remote server. This backend class will never be instantiated, it only serves as a template.""" api_exception = APIError # type: ClassVar[Type[Exception]] """The exception class to raise in case of communication error with the server.""" reraise_exceptions: ClassVar[List[Type[Exception]]] = [] """On server errors, if any of the exception classes in this list has the same name as the error name, then the exception will be instantiated and raised instead of a generic RemoteException.""" extra_type_encoders: List[Tuple[type, str, Callable]] = [] """Value of `extra_encoders` passed to `json_dumps` or `msgpack_dumps` to be able to serialize more object types.""" extra_type_decoders: Dict[str, Callable] = {} """Value of `extra_decoders` passed to `json_loads` or `msgpack_loads` to be able to deserialize more object types.""" def __init__( self, url, api_exception=None, timeout=None, chunk_size=4096, reraise_exceptions=None, **kwargs, ): if api_exception: self.api_exception = api_exception if reraise_exceptions: self.reraise_exceptions = reraise_exceptions base_url = url if url.endswith("/") else url + "/" self.url = base_url self.session = requests.Session() adapter = requests.adapters.HTTPAdapter( max_retries=kwargs.get("max_retries", 3), pool_connections=kwargs.get("pool_connections", 20), pool_maxsize=kwargs.get("pool_maxsize", 100), ) self.session.mount(self.url, adapter) self.timeout = timeout self.chunk_size = chunk_size def _url(self, endpoint): return "%s%s" % (self.url, endpoint) def raw_verb(self, verb, endpoint, **opts): if "chunk_size" in opts: # if the chunk_size argument has been passed, consider the user # also wants stream=True, otherwise, what's the point. opts["stream"] = True if self.timeout and "timeout" not in opts: opts["timeout"] = self.timeout try: return getattr(self.session, verb)(self._url(endpoint), **opts) except requests.exceptions.ConnectionError as e: raise self.api_exception(e) def post(self, endpoint, data, **opts): if isinstance(data, (abc.Iterator, abc.Generator)): data = (self._encode_data(x) for x in data) else: data = self._encode_data(data) chunk_size = opts.pop("chunk_size", self.chunk_size) response = self.raw_verb( "post", endpoint, data=data, headers={ "content-type": "application/x-msgpack", "accept": "application/x-msgpack", }, **opts, ) if opts.get("stream") or response.headers.get("transfer-encoding") == "chunked": self.raise_for_status(response) return response.iter_content(chunk_size) else: return self._decode_response(response) def _encode_data(self, data): return encode_data(data, extra_encoders=self.extra_type_encoders) post_stream = post def get(self, endpoint, **opts): chunk_size = opts.pop("chunk_size", self.chunk_size) response = self.raw_verb( "get", endpoint, headers={"accept": "application/x-msgpack"}, **opts ) if opts.get("stream") or response.headers.get("transfer-encoding") == "chunked": self.raise_for_status(response) return response.iter_content(chunk_size) else: return self._decode_response(response) def get_stream(self, endpoint, **opts): return self.get(endpoint, stream=True, **opts) def raise_for_status(self, response) -> None: """check response HTTP status code and raise an exception if it denotes an error; do nothing otherwise """ status_code = response.status_code status_class = response.status_code // 100 if status_code == 404: raise RemoteException(payload="404 not found", response=response) exception = None # TODO: only old servers send pickled error; stop trying to unpickle # after they are all upgraded try: if status_class == 4: data = self._decode_response(response, check_status=False) if isinstance(data, dict): for exc_type in self.reraise_exceptions: if exc_type.__name__ == data["exception"]["type"]: exception = exc_type(*data["exception"]["args"]) break else: exception = RemoteException( payload=data["exception"], response=response ) else: exception = pickle.loads(data) elif status_class == 5: data = self._decode_response(response, check_status=False) if "exception_pickled" in data: exception = pickle.loads(data["exception_pickled"]) else: exception = RemoteException( payload=data["exception"], response=response ) except (TypeError, pickle.UnpicklingError): raise RemoteException(payload=data, response=response) if exception: raise exception from None if status_class != 2: raise RemoteException( payload=f"API HTTP error: {status_code} {response.content}", response=response, ) def _decode_response(self, response, check_status=True): if check_status: self.raise_for_status(response) return decode_response(response, extra_decoders=self.extra_type_decoders) def __repr__(self): return "<{} url={}>".format(self.__class__.__name__, self.url) class BytesRequest(Request): """Request with proper escaping of arbitrary byte sequences.""" encoding = "utf-8" encoding_errors = "surrogateescape" ENCODERS: Dict[str, Callable[[Any], Union[bytes, str]]] = { "application/x-msgpack": msgpack_dumps, "application/json": json_dumps, } def encode_data_server( data, content_type="application/x-msgpack", extra_type_encoders=None ): encoded_data = ENCODERS[content_type](data, extra_encoders=extra_type_encoders) return Response(encoded_data, mimetype=content_type,) def decode_request(request, extra_decoders=None): content_type = request.mimetype data = request.get_data() if not data: return {} if content_type == "application/x-msgpack": r = msgpack_loads(data, extra_decoders=extra_decoders) elif content_type == "application/json": # XXX this .decode() is needed for py35. # Should not be needed any more with py37 r = json_loads(data.decode("utf-8"), extra_decoders=extra_decoders) else: raise ValueError("Wrong content type `%s` for API request" % content_type) return r def error_handler(exception, encoder, status_code=500): logging.exception(exception) response = encoder(exception_to_dict(exception)) if isinstance(exception, HTTPException): response.status_code = exception.code else: # TODO: differentiate between server errors and client errors response.status_code = status_code return response class RPCServerApp(Flask): """For each endpoint of the given `backend_class`, tells app.route to call a function that decodes the request and sends it to the backend object provided by the factory. :param Any backend_class: The class of the backend, which will be analyzed to look for API endpoints. :param Optional[Callable[[], backend_class]] backend_factory: A function with no argument that returns an instance of `backend_class`. If unset, defaults to calling `backend_class` constructor directly. """ request_class = BytesRequest extra_type_encoders: List[Tuple[type, str, Callable]] = [] """Value of `extra_encoders` passed to `json_dumps` or `msgpack_dumps` to be able to serialize more object types.""" extra_type_decoders: Dict[str, Callable] = {} """Value of `extra_decoders` passed to `json_loads` or `msgpack_loads` to be able to deserialize more object types.""" def __init__(self, *args, backend_class=None, backend_factory=None, **kwargs): super().__init__(*args, **kwargs) if backend_class is None and backend_factory is not None: raise ValueError( "backend_factory should only be provided if backend_class is" ) self.backend_class = backend_class if backend_class is not None: backend_factory = backend_factory or backend_class for (meth_name, meth) in backend_class.__dict__.items(): if hasattr(meth, "_endpoint_path"): self.__add_endpoint(meth_name, meth, backend_factory) def __add_endpoint(self, meth_name, meth, backend_factory): from flask import request @self.route("/" + meth._endpoint_path, methods=["POST"]) @negotiate(MsgpackFormatter, extra_encoders=self.extra_type_encoders) @negotiate(JSONFormatter, extra_encoders=self.extra_type_encoders) @functools.wraps(meth) # Copy signature and doc def _f(): # Call the actual code obj_meth = getattr(backend_factory(), meth_name) kw = decode_request(request, extra_decoders=self.extra_type_decoders) return obj_meth(**kw) diff --git a/swh/core/api/asynchronous.py b/swh/core/api/asynchronous.py index acc8e50..65f3916 100644 --- a/swh/core/api/asynchronous.py +++ b/swh/core/api/asynchronous.py @@ -1,183 +1,186 @@ # Copyright (C) 2017-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from collections import OrderedDict import functools import logging from typing import Callable, Dict, List, Optional, Tuple, Type, Union import aiohttp.web from aiohttp_utils import Response, negotiation from deprecated import deprecated import multidict from .serializers import ( exception_to_dict, json_dumps, json_loads, msgpack_dumps, msgpack_loads, ) def encode_msgpack(data, **kwargs): return aiohttp.web.Response( body=msgpack_dumps(data), headers=multidict.MultiDict({"Content-Type": "application/x-msgpack"}), **kwargs, ) encode_data_server = Response def render_msgpack(request, data, extra_encoders=None): return msgpack_dumps(data, extra_encoders=extra_encoders) def render_json(request, data, extra_encoders=None): return json_dumps(data, extra_encoders=extra_encoders) def decode_data(data, content_type, extra_decoders=None): """Decode data according to content type, eventually using some extra decoders. """ if not data: return {} if content_type == "application/x-msgpack": r = msgpack_loads(data, extra_decoders=extra_decoders) elif content_type == "application/json": r = json_loads(data, extra_decoders=extra_decoders) else: raise ValueError(f"Wrong content type `{content_type}` for API request") return r async def decode_request(request, extra_decoders=None): """Decode asynchronously the request """ data = await request.read() return decode_data(data, request.content_type, extra_decoders=extra_decoders) async def error_middleware(app, handler): async def middleware_handler(request): try: return await handler(request) except Exception as e: if isinstance(e, aiohttp.web.HTTPException): raise logging.exception(e) res = exception_to_dict(e) if isinstance(e, app.client_exception_classes): status = 400 else: status = 500 return encode_data_server(res, status=status) return middleware_handler class RPCServerApp(aiohttp.web.Application): """For each endpoint of the given `backend_class`, tells app.route to call a function that decodes the request and sends it to the backend object provided by the factory. :param Any backend_class: The class of the backend, which will be analyzed to look for API endpoints. :param Optional[Callable[[], backend_class]] backend_factory: A function with no argument that returns an instance of `backend_class`. If unset, defaults to calling `backend_class` constructor directly. """ client_exception_classes: Tuple[Type[Exception], ...] = () """Exceptions that should be handled as a client error (eg. object not found, invalid argument)""" extra_type_encoders: List[Tuple[type, str, Callable]] = [] """Value of `extra_encoders` passed to `json_dumps` or `msgpack_dumps` to be able to serialize more object types.""" extra_type_decoders: Dict[str, Callable] = {} """Value of `extra_decoders` passed to `json_loads` or `msgpack_loads` to be able to deserialize more object types.""" def __init__( self, app_name: Optional[str] = None, backend_class: Optional[Callable] = None, backend_factory: Optional[Union[Callable, str]] = None, middlewares=(), **kwargs, ): nego_middleware = negotiation.negotiation_middleware( renderers=self._renderers(), force_rendering=True ) middlewares = (nego_middleware, error_middleware,) + middlewares super().__init__(middlewares=middlewares, **kwargs) # swh decorations starts here self.app_name = app_name if backend_class is None and backend_factory is not None: raise ValueError( "backend_factory should only be provided if backend_class is" ) self.backend_class = backend_class if backend_class is not None: backend_factory = backend_factory or backend_class for (meth_name, meth) in backend_class.__dict__.items(): if hasattr(meth, "_endpoint_path"): path = meth._endpoint_path + http_method = meth._method path = path if path.startswith("/") else f"/{path}" self.router.add_route( - "POST", path, self._endpoint(meth_name, meth, backend_factory) + http_method, + path, + self._endpoint(meth_name, meth, backend_factory), ) def _renderers(self): """Return an ordered list of renderers in order of increasing desirability (!) See mimetype.best_match() docstring """ return OrderedDict( [ ( "application/json", lambda request, data: render_json( request, data, extra_encoders=self.extra_type_encoders ), ), ( "application/x-msgpack", lambda request, data: render_msgpack( request, data, extra_encoders=self.extra_type_encoders ), ), ] ) def _endpoint(self, meth_name, meth, backend_factory): """Create endpoint out of the method `meth`. """ @functools.wraps(meth) # Copy signature and doc async def decorated_meth(request, *args, **kwargs): obj_meth = getattr(backend_factory(), meth_name) data = await request.read() kw = decode_data( data, request.content_type, extra_decoders=self.extra_type_decoders ) result = obj_meth(**kw) return encode_data_server(result) return decorated_meth @deprecated(version="0.0.64", reason="Use the RPCServerApp instead") class SWHRemoteAPI(RPCServerApp): pass diff --git a/swh/core/api/tests/test_init.py b/swh/core/api/tests/test_init.py new file mode 100644 index 0000000..2495336 --- /dev/null +++ b/swh/core/api/tests/test_init.py @@ -0,0 +1,29 @@ +# Copyright (C) 2017-2020 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + + +from swh.core.api import remote_api_endpoint + + +def test_remote_api_endpoint(): + @remote_api_endpoint("hello_route") + def hello(): + pass + + assert hasattr(hello, "_endpoint_path") + assert hello._endpoint_path == "hello_route" + assert hasattr(hello, "_method") + assert hello._method == "POST" + + +def test_remote_api_endpoint_2(): + @remote_api_endpoint("another_route", method="GET") + def hello2(): + pass + + assert hasattr(hello2, "_endpoint_path") + assert hello2._endpoint_path == "another_route" + assert hasattr(hello2, "_method") + assert hello2._method == "GET" diff --git a/swh/core/api/tests/test_rpc_server_asynchronous.py b/swh/core/api/tests/test_rpc_server_asynchronous.py index bfa80bf..77cd16e 100644 --- a/swh/core/api/tests/test_rpc_server_asynchronous.py +++ b/swh/core/api/tests/test_rpc_server_asynchronous.py @@ -1,153 +1,153 @@ # Copyright (C) 2018-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import pytest from swh.core.api import remote_api_endpoint from swh.core.api.asynchronous import RPCServerApp from swh.core.api.serializers import json_dumps, msgpack_dumps from .test_serializers import ExtraType, extra_decoders, extra_encoders class MyRPCServerApp(RPCServerApp): extra_type_encoders = extra_encoders extra_type_decoders = extra_decoders class BackendStorageTest: """Backend Storage to use as backend class of the rpc server (test only)""" - @remote_api_endpoint("test_endpoint_url") + @remote_api_endpoint("test_endpoint_url", method="GET") def test_endpoint(self, test_data, db=None, cur=None): assert test_data == "spam" return "egg" @remote_api_endpoint("path/to/identity") def identity(self, data, db=None, cur=None): return data @remote_api_endpoint("serializer_test") def serializer_test(self, data, db=None, cur=None): assert data == ["foo", ExtraType("bar", b"baz")] return ExtraType({"spam": "egg"}, "qux") @pytest.fixture def async_app(): return MyRPCServerApp("testapp", backend_class=BackendStorageTest) def test_api_async_rpc_server_app_ok(async_app): assert isinstance(async_app, MyRPCServerApp) actual_rpc_server2 = MyRPCServerApp( "app2", backend_class=BackendStorageTest, backend_factory=BackendStorageTest ) assert isinstance(actual_rpc_server2, MyRPCServerApp) actual_rpc_server3 = MyRPCServerApp("app3") assert isinstance(actual_rpc_server3, MyRPCServerApp) def test_api_async_rpc_server_app_misconfigured(): expected_error = "backend_factory should only be provided if backend_class is" with pytest.raises(ValueError, match=expected_error): MyRPCServerApp("failed-app", backend_factory="something-to-make-it-raise") @pytest.fixture def cli(loop, aiohttp_client, async_app): """aiohttp client fixture to ease testing source: https://docs.aiohttp.org/en/stable/testing.html """ loop.set_debug(True) return loop.run_until_complete(aiohttp_client(async_app)) async def test_api_async_endpoint(cli, async_app): res = await cli.post( "/path/to/identity", headers=[("Content-Type", "application/json"), ("Accept", "application/json")], data=json_dumps({"data": "toto"}), ) assert res.status == 200 assert res.content_type == "application/json" assert await res.read() == json_dumps("toto").encode() async def test_api_async_nego_default_msgpack(cli): res = await cli.post( "/path/to/identity", headers=[("Content-Type", "application/json")], data=json_dumps({"data": "toto"}), ) assert res.status == 200 assert res.content_type == "application/x-msgpack" assert await res.read() == msgpack_dumps("toto") async def test_api_async_nego_default(cli): res = await cli.post( "/path/to/identity", headers=[ ("Content-Type", "application/json"), ("Accept", "application/x-msgpack"), ], data=json_dumps({"data": "toto"}), ) assert res.status == 200 assert res.content_type == "application/x-msgpack" assert await res.read() == msgpack_dumps("toto") async def test_api_async_nego_accept(cli): res = await cli.post( "/path/to/identity", headers=[ ("Accept", "application/x-msgpack"), ("Content-Type", "application/x-msgpack"), ], data=msgpack_dumps({"data": "toto"}), ) assert res.status == 200 assert res.content_type == "application/x-msgpack" assert await res.read() == msgpack_dumps("toto") async def test_api_async_rpc_server(cli): - res = await cli.post( + res = await cli.get( "/test_endpoint_url", headers=[ ("Content-Type", "application/x-msgpack"), ("Accept", "application/x-msgpack"), ], data=msgpack_dumps({"test_data": "spam"}), ) assert res.status == 200 assert res.content_type == "application/x-msgpack" assert await res.read() == msgpack_dumps("egg") async def test_api_async_rpc_server_extra_serializers(cli): res = await cli.post( "/serializer_test", headers=[ ("Content-Type", "application/x-msgpack"), ("Accept", "application/x-msgpack"), ], data=( b"\x81\xa4data\x92\xa3foo\x82\xc4\x07swhtype\xa9extratype" b"\xc4\x01d\x92\xa3bar\xc4\x03baz" ), ) assert res.status == 200 assert res.content_type == "application/x-msgpack" assert await res.read() == ( b"\x82\xc4\x07swhtype\xa9extratype\xc4\x01d\x92\x81\xa4spam\xa3egg\xa3qux" )