diff --git a/mypy.ini b/mypy.ini --- a/mypy.ini +++ b/mypy.ini @@ -8,6 +8,9 @@ [mypy-aiohttp_utils.*] ignore_missing_imports = True +[mypy-aioresponses.*] +ignore_missing_imports = True + [mypy-arrow.*] ignore_missing_imports = True diff --git a/swh/core/api/__init__.py b/swh/core/api/__init__.py --- a/swh/core/api/__init__.py +++ b/swh/core/api/__init__.py @@ -1,18 +1,22 @@ -# Copyright (C) 2015-2020 The Software Heritage developers +# Copyright (C) 2015-2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from collections import abc +import contextlib import functools import inspect import logging import pickle from typing import ( Any, + AsyncIterable, Callable, ClassVar, Dict, + Generic, + Iterable, List, Optional, Tuple, @@ -21,6 +25,7 @@ Union, ) +import aiohttp from flask import Flask, Request, Response, abort, request import requests from werkzeug.exceptions import HTTPException @@ -41,6 +46,14 @@ logger = logging.getLogger(__name__) +# Genericity helpers + +TSession = TypeVar("TSession", requests.Session, aiohttp.ClientSession) +"""Session class used by the underlying HTTP library""" +TResponse = TypeVar("TResponse", requests.Response, aiohttp.ClientResponse) +"""Response class used by the underlying HTTP library""" + + # support for content negotiation @@ -85,7 +98,7 @@ # base API classes -class RemoteException(Exception): +class RemoteException(Generic[TResponse], Exception): """raised when remote returned an out-of-band failure notification, e.g., as a HTTP status code or serialized exception @@ -95,15 +108,13 @@ """ def __init__( - self, - payload: Optional[Any] = None, - response: Optional[requests.Response] = None, + self, payload: Optional[Any] = None, response: Optional[TResponse] = None, ): if payload is not None: super().__init__(payload) else: super().__init__() - 
self.response = response + self.response: Optional[TResponse] = response def __str__(self): if ( @@ -184,8 +195,8 @@ attributes[meth_name] = meth_ -class RPCClient(metaclass=MetaRPCClient): - """Proxy to an internal SWH RPC +class BaseRPCClient(Generic[TSession, TResponse], metaclass=MetaRPCClient): + """Base class for :class:`RPCClient` and :class:`AsyncRPCClient`. """ @@ -229,21 +240,111 @@ self.reraise_exceptions = reraise_exceptions base_url = url if url.endswith("/") else url + "/" self.url = base_url - self.session = requests.Session() + + self.timeout = timeout + self.chunk_size = chunk_size + + self.session = self._make_session(kwargs) + + def _url(self, endpoint): + return "%s%s" % (self.url, endpoint) + + def _get_status_code(self, response: TResponse) -> int: + raise NotImplementedError("_get_status_code") + + def _decode_response(self, response: TResponse, check_status=True) -> Any: + raise NotImplementedError("_decode_response") + + def raise_for_status(self, response: TResponse) -> None: + """check response HTTP status code and raise an exception if it denotes an + error; do nothing otherwise + + """ + status_code = self._get_status_code(response) + status_class = status_code // 100 + + if status_code == 404: + raise RemoteException(payload="404 not found", response=response) + + exception = None + + # TODO: only old servers send pickled error; stop trying to unpickle + # after they are all upgraded + try: + if status_class == 4: + data = self._decode_response(response, check_status=False) + if isinstance(data, dict): + # TODO: remove "exception" key check once all servers + # are using new schema + exc_data = data["exception"] if "exception" in data else data + for exc_type in self.reraise_exceptions: + if exc_type.__name__ == exc_data["type"]: + exception = exc_type(*exc_data["args"]) + break + else: + exception = RemoteException(payload=exc_data, response=response) + else: + exception = pickle.loads(data) + + elif status_class == 5: + data = 
self._decode_response(response, check_status=False) + if "exception_pickled" in data: + exception = pickle.loads(data["exception_pickled"]) + else: + # TODO: remove "exception" key check once all servers + # are using new schema + exc_data = data["exception"] if "exception" in data else data + exception = RemoteException(payload=exc_data, response=response) + + except (TypeError, pickle.UnpicklingError): + raise RemoteException(payload=data, response=response) + + if exception: + raise exception from None + + if status_class != 2: + raise RemoteException( + payload=f"API HTTP error: {status_code} {response.content}", + response=response, + ) + + def __repr__(self): + return "<{} url={}>".format(self.__class__.__name__, self.url) + + +class RPCClient(BaseRPCClient[requests.Session, requests.Response]): + """Proxy to an internal SWH RPC + + """ + + def _make_session(self, kwargs): + session = requests.Session() adapter = requests.adapters.HTTPAdapter( max_retries=kwargs.get("max_retries", 3), pool_connections=kwargs.get("pool_connections", 20), pool_maxsize=kwargs.get("pool_maxsize", 100), ) - self.session.mount(self.url, adapter) + session.mount(self.url, adapter) + return session - self.timeout = timeout - self.chunk_size = chunk_size + def _get_status_code(self, response: requests.Response) -> int: + return response.status_code - def _url(self, endpoint): - return "%s%s" % (self.url, endpoint) + def _decode_response(self, response: requests.Response, check_status=True) -> Any: + if check_status: + self.raise_for_status(response) + return decode_response( + response.content, + response.headers["content-type"], + extra_decoders=self.extra_type_decoders, + ) + + def _stream_response( + self, response: requests.Response, chunk_size: int + ) -> Iterable[Any]: + return response.iter_content(chunk_size) - def raw_verb(self, verb, endpoint, **opts): + def raw_verb(self, verb: str, endpoint: str, **opts) -> TResponse: if "chunk_size" in opts: # if the chunk_size argument 
has been passed, consider the user # also wants stream=True, otherwise, what's the point. @@ -255,7 +356,7 @@ except requests.exceptions.ConnectionError as e: raise self.api_exception(e) - def post(self, endpoint, data, **opts): + def post(self, endpoint: str, data: Any, **opts) -> Any: if isinstance(data, (abc.Iterator, abc.Generator)): data = (self._encode_data(x) for x in data) else: @@ -273,89 +374,125 @@ ) if opts.get("stream") or response.headers.get("transfer-encoding") == "chunked": self.raise_for_status(response) - return response.iter_content(chunk_size) + return self._stream_response(response, chunk_size) else: return self._decode_response(response) - def _encode_data(self, data): + def _encode_data(self, data: Any) -> bytes: return encode_data(data, extra_encoders=self.extra_type_encoders) post_stream = post - def get(self, endpoint, **opts): + def get(self, endpoint: str, **opts) -> Any: chunk_size = opts.pop("chunk_size", self.chunk_size) response = self.raw_verb( "get", endpoint, headers={"accept": "application/x-msgpack"}, **opts ) if opts.get("stream") or response.headers.get("transfer-encoding") == "chunked": self.raise_for_status(response) - return response.iter_content(chunk_size) + return self._stream_response(response, chunk_size) else: return self._decode_response(response) - def get_stream(self, endpoint, **opts): + def get_stream(self, endpoint: str, **opts) -> Iterable[Any]: return self.get(endpoint, stream=True, **opts) - def raise_for_status(self, response) -> None: - """check response HTTP status code and raise an exception if it denotes an - error; do nothing otherwise - """ - status_code = response.status_code - status_class = response.status_code // 100 +class AsyncRPCClient(BaseRPCClient[aiohttp.ClientSession, aiohttp.ClientResponse]): + """Asynchronous proxy to an internal SWH RPC - if status_code == 404: - raise RemoteException(payload="404 not found", response=response) + """ - exception = None + def _make_session(self, kwargs): 
+        return aiohttp.ClientSession(
+            connector=aiohttp.TCPConnector(limit=kwargs.get("pool_maxsize", 100),),
+            timeout=aiohttp.ClientTimeout(total=self.timeout),
+            connector_owner=True,
+        )
 
-        # TODO: only old servers send pickled error; stop trying to unpickle
-        # after they are all upgraded
-        try:
-            if status_class == 4:
-                data = self._decode_response(response, check_status=False)
-                if isinstance(data, dict):
-                    # TODO: remove "exception" key check once all servers
-                    # are using new schema
-                    exc_data = data["exception"] if "exception" in data else data
-                    for exc_type in self.reraise_exceptions:
-                        if exc_type.__name__ == exc_data["type"]:
-                            exception = exc_type(*exc_data["args"])
-                            break
-                    else:
-                        exception = RemoteException(payload=exc_data, response=response)
-                else:
-                    exception = pickle.loads(data)
 
+    async def __aenter__(self):
+        await self.session.__aenter__()
+        return self
 
-            elif status_class == 5:
-                data = self._decode_response(response, check_status=False)
-                if "exception_pickled" in data:
-                    exception = pickle.loads(data["exception_pickled"])
-                else:
-                    # TODO: remove "exception" key check once all servers
-                    # are using new schema
-                    exc_data = data["exception"] if "exception" in data else data
-                    exception = RemoteException(payload=exc_data, response=response)
 
+    async def __aexit__(self, *args):
+        # delegate cleanup to the underlying aiohttp session
+        await self.session.__aexit__(*args)
 
-        except (TypeError, pickle.UnpicklingError):
-            raise RemoteException(payload=data, response=response)
 
+    def _get_status_code(self, response: aiohttp.ClientResponse) -> int:
+        return response.status
 
-        if exception:
-            raise exception from None
-
-        if status_class != 2:
-            raise RemoteException(
-                payload=f"API HTTP error: {status_code} {response.content}",
-                response=response,
-            )
-
-    def _decode_response(self, response, check_status=True):
+    async def _decode_response(
+        self, response: aiohttp.ClientResponse, check_status=True
+    ) -> Any:
         if check_status:
             self.raise_for_status(response)
-        return 
decode_response(response, extra_decoders=self.extra_type_decoders)
+        return decode_response(
+            await response.read(),
+            response.headers["content-type"],
+            extra_decoders=self.extra_type_decoders,
+        )
 
-    def __repr__(self):
-        return "<{} url={}>".format(self.__class__.__name__, self.url)
+    @contextlib.asynccontextmanager
+    async def raw_verb(self, verb: str, endpoint: str, **opts):
+        if "chunk_size" in opts:
+            # if the chunk_size argument has been passed, consider the user
+            # also wants stream=True, otherwise, what's the point.
+            opts["stream"] = True
+        if self.timeout and "timeout" not in opts:
+            opts["timeout"] = self.timeout
+        try:
+            async with getattr(self.session, verb)(self._url(endpoint), **opts) as r:
+                yield r
+        except aiohttp.ClientConnectionError as e:
+            raise self.api_exception(e)
+
+    async def post(self, endpoint: str, data: Any, **opts) -> Any:
+        if isinstance(data, (abc.Iterator, abc.Generator)):
+            data = (self._encode_data(x) for x in data)
+        else:
+            data = self._encode_data(data)
+        chunk_size = opts.pop("chunk_size", self.chunk_size)
+        async with self.raw_verb(
+            "post",
+            endpoint,
+            data=data,
+            headers={
+                "content-type": "application/x-msgpack",
+                "accept": "application/x-msgpack",
+            },
+            **opts,
+        ) as response:
+            if (
+                opts.get("stream")
+                or response.headers.get("transfer-encoding") == "chunked"
+            ):
+                self.raise_for_status(response)
+                return await self._stream_response(response, chunk_size)
+            else:
+                return await self._decode_response(response)
+
+    def _encode_data(self, data: Any) -> bytes:
+        return encode_data(data, extra_encoders=self.extra_type_encoders)
+
+    post_stream = post
+
+    async def get(self, endpoint: str, **opts) -> Any:
+        chunk_size = opts.pop("chunk_size", self.chunk_size)
+        async with self.raw_verb(
+            "get", endpoint, headers={"accept": "application/x-msgpack"}, **opts
+        ) as response:
+            if (
+                opts.get("stream")
+                or response.headers.get("transfer-encoding") == "chunked"
+            ):
+                self.raise_for_status(response)
+                
return await self._stream_response(response, chunk_size)
+            else:
+                return await self._decode_response(response)
+
+    async def get_stream(self, endpoint: str, **opts) -> AsyncIterable[Any]:
+        async for v in self.get(endpoint, stream=True, **opts):
+            yield v
 
 
 class BytesRequest(Request):
diff --git a/swh/core/api/classes.py b/swh/core/api/classes.py
--- a/swh/core/api/classes.py
+++ b/swh/core/api/classes.py
@@ -3,9 +3,19 @@
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
+import asyncio
 from dataclasses import dataclass, field
 import itertools
-from typing import Callable, Generic, Iterable, List, Optional, TypeVar
+from typing import (
+    AsyncIterable,
+    Awaitable,
+    Callable,
+    Generic,
+    Iterable,
+    List,
+    Optional,
+    TypeVar,
+)
 
 TResult = TypeVar("TResult")
 TToken = TypeVar("TToken")
@@ -59,3 +69,64 @@
         res.results,
         _stream_results(f, *args, page_token=res.next_page_token, **kwargs),
     )
+
+
+async def asyncchain(iterable1, iterable2):
+    """Like itertools.chain, but async"""
+    async for item in iterable1:
+        yield item
+    async for item in iterable2:
+        yield item
+
+
+async def asynciter(iterable):
+    """Like iter(), but async"""
+    for item in iterable:
+        yield item
+
+
+async def _stream_results_async(f, *args, page_token, **kwargs):
+    """Helper for stream_results_async() and stream_results_optional_async()"""
+    while True:
+        page_result = await f(*args, page_token=page_token, **kwargs)
+        for res in page_result.results:
+            yield res
+        page_token = page_result.next_page_token
+        if page_token is None:
+            break
+
+
+async def stream_results_async(
+    f: Callable[..., Awaitable[PagedResult[TResult, TToken]]], *args, **kwargs
+) -> AsyncIterable[TResult]:
+    """Consume the paginated result and stream the page results
+
+    """
+    if "page_token" in kwargs:
+        raise TypeError('stream_results has no argument "page_token".')
+    return _stream_results_async(f, *args, page_token=None, **kwargs)
+
+
+async def 
stream_results_optional_async(
+    f: Callable[..., Awaitable[Optional[PagedResult[TResult, TToken]]]], *args, **kwargs
+) -> Optional[AsyncIterable[TResult]]:
+    """Like stream_results(), but for functions ``f`` that return an Optional.
+
+    """
+    if "page_token" in kwargs:
+        raise TypeError('stream_results_optional has no argument "page_token".')
+    res = await f(*args, page_token=None, **kwargs)
+    if res is None:
+        # mirror stream_results_optional(): no result at all means None,
+        # not an empty iterable
+        return None
+    else:
+        if res.next_page_token is None:
+            return asynciter(res.results)
+        else:
+            return asyncchain(
+                asynciter(res.results),
+                _stream_results_async(
+                    f, *args, page_token=res.next_page_token, **kwargs
+                ),
+            )
diff --git a/swh/core/api/serializers.py b/swh/core/api/serializers.py
--- a/swh/core/api/serializers.py
+++ b/swh/core/api/serializers.py
@@ -12,6 +12,7 @@
 from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
 from uuid import UUID
 
+import aiohttp
 import iso8601
 import msgpack
 from requests import Response
@@ -128,15 +129,13 @@
         raise ValueError("Limits were reached. 
Please, check your input.\n" + str(e)) -def decode_response(response: Response, extra_decoders=None) -> Any: - content_type = response.headers["content-type"] - +def decode_response(content: bytes, content_type, extra_decoders=None) -> Any: if content_type.startswith("application/x-msgpack"): - r = msgpack_loads(response.content, extra_decoders=extra_decoders) + r = msgpack_loads(content, extra_decoders=extra_decoders) elif content_type.startswith("application/json"): - r = json_loads(response.text, extra_decoders=extra_decoders) + r = json_loads(content.decode(), extra_decoders=extra_decoders) elif content_type.startswith("text/"): - r = response.text + r = content.decode() else: raise ValueError("Wrong content type `%s` for API response" % content_type) diff --git a/swh/core/api/tests/test_rpc_client.py b/swh/core/api/tests/test_rpc_client.py --- a/swh/core/api/tests/test_rpc_client.py +++ b/swh/core/api/tests/test_rpc_client.py @@ -1,25 +1,38 @@ -# Copyright (C) 2018-2019 The Software Heritage developers +# Copyright (C) 2018-2021 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import re +import urllib.parse +import aioresponses import pytest from requests.exceptions import ConnectionError -from swh.core.api import APIError, RemoteException, RPCClient, remote_api_endpoint +from swh.core.api import ( + APIError, + AsyncRPCClient, + RemoteException, + RPCClient, + remote_api_endpoint, +) from swh.core.api.serializers import exception_to_dict, msgpack_dumps from .test_serializers import ExtraType, extra_decoders, extra_encoders +@pytest.fixture +def aiohttp_mock(): + with aioresponses.aioresponses() as m: + yield m + + class ReraiseException(Exception): pass -@pytest.fixture -def rpc_client(requests_mock): +def make_rpc_client(cls): class TestStorage: @remote_api_endpoint("test_endpoint_url") 
def test_endpoint(self, test_data, db=None, cur=None):
@@ -37,36 +50,90 @@
         def overridden_method(self, data):
             return "foo"
 
-    class Testclient(RPCClient):
+    class BaseTestclient(cls):
         backend_class = TestStorage
         extra_type_encoders = extra_encoders
         extra_type_decoders = extra_decoders
         reraise_exceptions = [ReraiseException]
 
+    return BaseTestclient
+
+
+@pytest.fixture
+def rpc_client(requests_mock):
+    BaseTestclient = make_rpc_client(RPCClient)
+
+    class Testclient(BaseTestclient):
         def overridden_method(self, data):
             return "bar"
 
+    class MockResponse:
+        def __init__(self, text, status):
+            self._text = text
+            self.status = status
+
+        async def text(self):
+            return self._text
+
+        async def __aexit__(self, exc_type, exc, tb):
+            pass
+
+        async def __aenter__(self):
+            return self
+
     def callback(request, context):
         assert request.headers["Content-Type"] == "application/x-msgpack"
         context.headers["Content-Type"] = "application/x-msgpack"
         if request.path == "/test_endpoint_url":
-            context.content = b"\xa3egg"
+            content = b"\xa3egg"
         elif request.path == "/path/to/endpoint":
-            context.content = b"\xa4spam"
+            content = b"\xa4spam"
         elif request.path == "/serializer_test":
-            context.content = (
+            content = (
                 b"\x82\xc4\x07swhtype\xa9extratype"
                 b"\xc4\x01d\x92\x81\xa4spam\xa3egg\xa3qux"
            )
         else:
             assert False
-        return context.content
+        return content
 
     requests_mock.post(re.compile("mock://example.com/"), content=callback)
 
     return Testclient(url="mock://example.com")
 
 
+@pytest.fixture
+def async_rpc_client(aiohttp_mock):
+    BaseTestclient = make_rpc_client(AsyncRPCClient)
+
+    class Testclient(BaseTestclient):
+        async def overridden_method(self, data):
+            return "bar"
+
+    async def callback(url, *, headers, **kwargs):
+        assert headers["content-type"] == "application/x-msgpack"
+        if url.path == "/test_endpoint_url":
+            content = b"\xa3egg"
+        elif url.path == "/path/to/endpoint":
+            content = b"\xa4spam"
+        elif url.path 
== "/serializer_test": + content = ( + b"\x82\xc4\x07swhtype\xa9extratype" + b"\xc4\x01d\x92\x81\xa4spam\xa3egg\xa3qux" + ) + else: + assert False + return aioresponses.CallbackResult( + body=content, headers={"Content-Type": "application/x-msgpack"} + ) + + aiohttp_mock.post( + re.compile("mock://example.com/.*"), callback=callback, repeat=True + ) + + return Testclient(url="mock://example.com") + + def test_client(rpc_client): assert hasattr(rpc_client, "test_endpoint") @@ -83,6 +150,21 @@ assert res == "spam" +async def test_async_client(async_rpc_client, event_loop): + assert hasattr(async_rpc_client, "test_endpoint") + assert hasattr(async_rpc_client, "something") + + res = await async_rpc_client.test_endpoint("spam") + assert res == "egg" + res = await async_rpc_client.test_endpoint(test_data="spam") + assert res == "egg" + + res = await async_rpc_client.something("whatever") + assert res == "spam" + res = await async_rpc_client.something(data="whatever") + assert res == "spam" + + def test_client_extra_serializers(rpc_client): res = rpc_client.serializer_test(["foo", ExtraType("bar", b"baz")]) assert res == ExtraType({"spam": "egg"}, "qux")