Page MenuHomeSoftware Heritage

D6651.diff
No OneTemporary

D6651.diff

diff --git a/mypy.ini b/mypy.ini
--- a/mypy.ini
+++ b/mypy.ini
@@ -8,6 +8,9 @@
[mypy-aiohttp_utils.*]
ignore_missing_imports = True
+[mypy-aioresponses.*]
+ignore_missing_imports = True
+
[mypy-arrow.*]
ignore_missing_imports = True
diff --git a/swh/core/api/__init__.py b/swh/core/api/__init__.py
--- a/swh/core/api/__init__.py
+++ b/swh/core/api/__init__.py
@@ -1,18 +1,22 @@
-# Copyright (C) 2015-2020 The Software Heritage developers
+# Copyright (C) 2015-2021 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
from collections import abc
+import contextlib
import functools
import inspect
import logging
import pickle
from typing import (
Any,
+ AsyncIterable,
Callable,
ClassVar,
Dict,
+ Generic,
+ Iterable,
List,
Optional,
Tuple,
@@ -21,6 +25,7 @@
Union,
)
+import aiohttp
from flask import Flask, Request, Response, abort, request
import requests
from werkzeug.exceptions import HTTPException
@@ -41,6 +46,14 @@
logger = logging.getLogger(__name__)
+# Genericity helpers
+
+TSession = TypeVar("TSession", requests.Session, aiohttp.ClientSession)
+"""Session class used by the underlying HTTP library"""
+TResponse = TypeVar("TResponse", requests.Response, aiohttp.ClientResponse)
+"""Response class used by the underlying HTTP library"""
+
+
# support for content negotiation
@@ -85,7 +98,7 @@
# base API classes
-class RemoteException(Exception):
+class RemoteException(Generic[TResponse], Exception):
"""raised when remote returned an out-of-band failure notification, e.g., as a
HTTP status code or serialized exception
@@ -95,15 +108,13 @@
"""
def __init__(
- self,
- payload: Optional[Any] = None,
- response: Optional[requests.Response] = None,
+ self, payload: Optional[Any] = None, response: Optional[TResponse] = None,
):
if payload is not None:
super().__init__(payload)
else:
super().__init__()
- self.response = response
+ self.response: Optional[TResponse] = response
def __str__(self):
if (
@@ -184,8 +195,8 @@
attributes[meth_name] = meth_
-class RPCClient(metaclass=MetaRPCClient):
- """Proxy to an internal SWH RPC
+class BaseRPCClient(Generic[TSession, TResponse], metaclass=MetaRPCClient):
+ """Base class for :class:`RPCClient` and :class:`AsyncRPCClient`.
"""
@@ -229,21 +240,111 @@
self.reraise_exceptions = reraise_exceptions
base_url = url if url.endswith("/") else url + "/"
self.url = base_url
- self.session = requests.Session()
+
+ self.timeout = timeout
+ self.chunk_size = chunk_size
+
+ self.session = self._make_session(kwargs)
+
+ def _url(self, endpoint):
+ return "%s%s" % (self.url, endpoint)
+
+ def _get_status_code(self, response: TResponse) -> int:
+ raise NotImplementedError("_get_status_code")
+
+ def _decode_response(self, response: TResponse, check_status=True) -> Any:
+ raise NotImplementedError("_decode_response")
+
+ def raise_for_status(self, response: TResponse) -> None:
+ """check response HTTP status code and raise an exception if it denotes an
+ error; do nothing otherwise
+
+ """
+ status_code = self._get_status_code(response)
+ status_class = status_code // 100
+
+ if status_code == 404:
+ raise RemoteException(payload="404 not found", response=response)
+
+ exception = None
+
+ # TODO: only old servers send pickled error; stop trying to unpickle
+ # after they are all upgraded
+ try:
+ if status_class == 4:
+ data = self._decode_response(response, check_status=False)
+ if isinstance(data, dict):
+ # TODO: remove "exception" key check once all servers
+ # are using new schema
+ exc_data = data["exception"] if "exception" in data else data
+ for exc_type in self.reraise_exceptions:
+ if exc_type.__name__ == exc_data["type"]:
+ exception = exc_type(*exc_data["args"])
+ break
+ else:
+ exception = RemoteException(payload=exc_data, response=response)
+ else:
+ exception = pickle.loads(data)
+
+ elif status_class == 5:
+ data = self._decode_response(response, check_status=False)
+ if "exception_pickled" in data:
+ exception = pickle.loads(data["exception_pickled"])
+ else:
+ # TODO: remove "exception" key check once all servers
+ # are using new schema
+ exc_data = data["exception"] if "exception" in data else data
+ exception = RemoteException(payload=exc_data, response=response)
+
+ except (TypeError, pickle.UnpicklingError):
+ raise RemoteException(payload=data, response=response)
+
+ if exception:
+ raise exception from None
+
+ if status_class != 2:
+ raise RemoteException(
+ payload=f"API HTTP error: {status_code} {response.content}",
+ response=response,
+ )
+
+ def __repr__(self):
+ return "<{} url={}>".format(self.__class__.__name__, self.url)
+
+
+class RPCClient(BaseRPCClient[requests.Session, requests.Response]):
+ """Proxy to an internal SWH RPC
+
+ """
+
+ def _make_session(self, kwargs):
+ session = requests.Session()
adapter = requests.adapters.HTTPAdapter(
max_retries=kwargs.get("max_retries", 3),
pool_connections=kwargs.get("pool_connections", 20),
pool_maxsize=kwargs.get("pool_maxsize", 100),
)
- self.session.mount(self.url, adapter)
+ session.mount(self.url, adapter)
+ return session
- self.timeout = timeout
- self.chunk_size = chunk_size
+ def _get_status_code(self, response: requests.Response) -> int:
+ return response.status_code
- def _url(self, endpoint):
- return "%s%s" % (self.url, endpoint)
+ def _decode_response(self, response: requests.Response, check_status=True) -> Any:
+ if check_status:
+ self.raise_for_status(response)
+ return decode_response(
+ response.content,
+ response.headers["content-type"],
+ extra_decoders=self.extra_type_decoders,
+ )
+
+ def _stream_response(
+ self, response: requests.Response, chunk_size: int
+ ) -> Iterable[Any]:
+ return response.iter_content(chunk_size)
- def raw_verb(self, verb, endpoint, **opts):
+ def raw_verb(self, verb: str, endpoint: str, **opts) -> requests.Response:
if "chunk_size" in opts:
# if the chunk_size argument has been passed, consider the user
# also wants stream=True, otherwise, what's the point.
@@ -255,7 +356,7 @@
except requests.exceptions.ConnectionError as e:
raise self.api_exception(e)
- def post(self, endpoint, data, **opts):
+ def post(self, endpoint: str, data: Any, **opts) -> Any:
if isinstance(data, (abc.Iterator, abc.Generator)):
data = (self._encode_data(x) for x in data)
else:
@@ -273,89 +374,125 @@
)
if opts.get("stream") or response.headers.get("transfer-encoding") == "chunked":
self.raise_for_status(response)
- return response.iter_content(chunk_size)
+ return self._stream_response(response, chunk_size)
else:
return self._decode_response(response)
- def _encode_data(self, data):
+ def _encode_data(self, data: Any) -> bytes:
return encode_data(data, extra_encoders=self.extra_type_encoders)
post_stream = post
- def get(self, endpoint, **opts):
+ def get(self, endpoint: str, **opts) -> Any:
chunk_size = opts.pop("chunk_size", self.chunk_size)
response = self.raw_verb(
"get", endpoint, headers={"accept": "application/x-msgpack"}, **opts
)
if opts.get("stream") or response.headers.get("transfer-encoding") == "chunked":
self.raise_for_status(response)
- return response.iter_content(chunk_size)
+ return self._stream_response(response, chunk_size)
else:
return self._decode_response(response)
- def get_stream(self, endpoint, **opts):
+ def get_stream(self, endpoint: str, **opts) -> Iterable[Any]:
return self.get(endpoint, stream=True, **opts)
- def raise_for_status(self, response) -> None:
- """check response HTTP status code and raise an exception if it denotes an
- error; do nothing otherwise
- """
- status_code = response.status_code
- status_class = response.status_code // 100
+class AsyncRPCClient(BaseRPCClient[aiohttp.ClientSession, aiohttp.ClientResponse]):
+ """Asynchronous proxy to an internal SWH RPC
- if status_code == 404:
- raise RemoteException(payload="404 not found", response=response)
+ """
- exception = None
+ def _make_session(self, kwargs):
+ return aiohttp.ClientSession(
+ connector=aiohttp.TCPConnector(limit=kwargs.get("pool_maxsize", 100),),
+ timeout=aiohttp.ClientTimeout(total=self.timeout),
+ connector_owner=True,
+ )
- # TODO: only old servers send pickled error; stop trying to unpickle
- # after they are all upgraded
- try:
- if status_class == 4:
- data = self._decode_response(response, check_status=False)
- if isinstance(data, dict):
- # TODO: remove "exception" key check once all servers
- # are using new schema
- exc_data = data["exception"] if "exception" in data else data
- for exc_type in self.reraise_exceptions:
- if exc_type.__name__ == exc_data["type"]:
- exception = exc_type(*exc_data["args"])
- break
- else:
- exception = RemoteException(payload=exc_data, response=response)
- else:
- exception = pickle.loads(data)
+ async def __aenter__(self):
+ await self.session.__aenter__()
+ return self
- elif status_class == 5:
- data = self._decode_response(response, check_status=False)
- if "exception_pickled" in data:
- exception = pickle.loads(data["exception_pickled"])
- else:
- # TODO: remove "exception" key check once all servers
- # are using new schema
- exc_data = data["exception"] if "exception" in data else data
- exception = RemoteException(payload=exc_data, response=response)
+ async def __aexit__(self, *args):
+ await self.session.__aexit__(*args)
- except (TypeError, pickle.UnpicklingError):
- raise RemoteException(payload=data, response=response)
+ def _get_status_code(self, response: aiohttp.ClientResponse) -> int:
+ return response.status
- if exception:
- raise exception from None
-
- if status_class != 2:
- raise RemoteException(
- payload=f"API HTTP error: {status_code} {response.content}",
- response=response,
- )
-
- def _decode_response(self, response, check_status=True):
+ async def _decode_response(
+ self, response: aiohttp.ClientResponse, check_status=True
+ ) -> Any:
if check_status:
self.raise_for_status(response)
- return decode_response(response, extra_decoders=self.extra_type_decoders)
+ return decode_response(
+ await response.read(),
+ response.headers["content-type"],
+ extra_decoders=self.extra_type_decoders,
+ )
- def __repr__(self):
- return "<{} url={}>".format(self.__class__.__name__, self.url)
+ @contextlib.asynccontextmanager
+ async def raw_verb(self, verb: str, endpoint: str, **opts):
+ if "chunk_size" in opts:
+ # if the chunk_size argument has been passed, consider the user
+ # also wants stream=True, otherwise, what's the point.
+ opts["stream"] = True
+ if self.timeout and "timeout" not in opts:
+ opts["timeout"] = self.timeout
+ try:
+ async with getattr(self.session, verb)(self._url(endpoint), **opts) as r:
+ yield r
+ except requests.exceptions.ConnectionError as e:
+ raise self.api_exception(e)
+
+ async def post(self, endpoint: str, data: Any, **opts) -> Any:
+ if isinstance(data, (abc.Iterator, abc.Generator)):
+ data = (self._encode_data(x) for x in data)
+ else:
+ data = self._encode_data(data)
+ chunk_size = opts.pop("chunk_size", self.chunk_size)
+ async with self.raw_verb(
+ "post",
+ endpoint,
+ data=data,
+ headers={
+ "content-type": "application/x-msgpack",
+ "accept": "application/x-msgpack",
+ },
+ **opts,
+ ) as response:
+ if (
+ opts.get("stream")
+ or response.headers.get("transfer-encoding") == "chunked"
+ ):
+ self.raise_for_status(response)
+ return await self._stream_response(response, chunk_size)
+ else:
+ return await self._decode_response(response)
+
+ def _encode_data(self, data: Any) -> bytes:
+ return encode_data(data, extra_encoders=self.extra_type_encoders)
+
+ post_stream = post
+
+ async def get(self, endpoint: str, **opts) -> Any:
+ chunk_size = opts.pop("chunk_size", self.chunk_size)
+ async with self.raw_verb(
+ "get", endpoint, headers={"accept": "application/x-msgpack"}, **opts
+ ) as response:
+ if (
+ opts.get("stream")
+ or response.headers.get("transfer-encoding") == "chunked"
+ ):
+ self.raise_for_status(response)
+ return await self._stream_response(response, chunk_size)
+ else:
+ return await self._decode_response(response)
+
+ async def get_stream(self, endpoint: str, **opts) -> AsyncIterable[Any]:
+ async for v in self.get(endpoint, stream=True, **opts):
+ yield v
class BytesRequest(Request):
diff --git a/swh/core/api/classes.py b/swh/core/api/classes.py
--- a/swh/core/api/classes.py
+++ b/swh/core/api/classes.py
@@ -3,9 +3,19 @@
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
+import asyncio
from dataclasses import dataclass, field
import itertools
-from typing import Callable, Generic, Iterable, List, Optional, TypeVar
+from typing import (
+ AsyncIterable,
+ Awaitable,
+ Callable,
+ Generic,
+ Iterable,
+ List,
+ Optional,
+ TypeVar,
+)
TResult = TypeVar("TResult")
TToken = TypeVar("TToken")
@@ -59,3 +69,64 @@
res.results,
_stream_results(f, *args, page_token=res.next_page_token, **kwargs),
)
+
+
+async def asyncchain(iterable1, iterable2):
+ """Like itertools.chain, but async"""
+ async for item in iterable1:
+ yield item
+ async for item in iterable2:
+ yield item
+
+
+async def asynciter(iterable):
+ """Like iter(), but async"""
+ for item in iterable:
+ yield item
+
+
+async def _stream_results_async(f, *args, page_token, **kwargs):
+ """Helper for stream_results_async() and stream_results_optional_async()"""
+ while True:
+ page_result = await f(*args, page_token=page_token, **kwargs)
+ for res in page_result.results:
+ yield res
+ page_token = page_result.next_page_token
+ if page_token is None:
+ break
+
+
+async def stream_results_async(
+ f: Callable[..., Awaitable[PagedResult[TResult, TToken]]], *args, **kwargs
+) -> AsyncIterable[TResult]:
+ """Consume the paginated result and stream the page results
+
+ """
+ if "page_token" in kwargs:
+ raise TypeError('stream_results has no argument "page_token".')
+ return _stream_results_async(f, *args, page_token=None, **kwargs)
+
+
+async def stream_results_optional_async(
+ f: Callable[..., Awaitable[Optional[PagedResult[TResult, TToken]]]], *args, **kwargs
+) -> Optional[AsyncIterable[TResult]]:
+ """Like stream_results_async(), but for functions ``f`` that return an Optional.
+
+ """
+ if "page_token" in kwargs:
+ raise TypeError('stream_results_optional has no argument "page_token".')
+ res = await f(*args, page_token=None, **kwargs)
+ if res is None:
+ return None
+ else:
+ if res.next_page_token is None:
+ return asynciter(res.results)
+ else:
+ return asyncchain(
+ asynciter(res.results),
+ _stream_results_async(
+ f, *args, page_token=res.next_page_token, **kwargs
+ ),
+ )
diff --git a/swh/core/api/serializers.py b/swh/core/api/serializers.py
--- a/swh/core/api/serializers.py
+++ b/swh/core/api/serializers.py
@@ -12,6 +12,7 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
from uuid import UUID
+import aiohttp
import iso8601
import msgpack
from requests import Response
@@ -128,15 +129,13 @@
raise ValueError("Limits were reached. Please, check your input.\n" + str(e))
-def decode_response(response: Response, extra_decoders=None) -> Any:
- content_type = response.headers["content-type"]
-
+def decode_response(content: bytes, content_type, extra_decoders=None) -> Any:
if content_type.startswith("application/x-msgpack"):
- r = msgpack_loads(response.content, extra_decoders=extra_decoders)
+ r = msgpack_loads(content, extra_decoders=extra_decoders)
elif content_type.startswith("application/json"):
- r = json_loads(response.text, extra_decoders=extra_decoders)
+ r = json_loads(content.decode(), extra_decoders=extra_decoders)
elif content_type.startswith("text/"):
- r = response.text
+ r = content.decode()
else:
raise ValueError("Wrong content type `%s` for API response" % content_type)
diff --git a/swh/core/api/tests/test_rpc_client.py b/swh/core/api/tests/test_rpc_client.py
--- a/swh/core/api/tests/test_rpc_client.py
+++ b/swh/core/api/tests/test_rpc_client.py
@@ -1,25 +1,38 @@
-# Copyright (C) 2018-2019 The Software Heritage developers
+# Copyright (C) 2018-2021 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import re
+import urllib.parse
+import aioresponses
import pytest
from requests.exceptions import ConnectionError
-from swh.core.api import APIError, RemoteException, RPCClient, remote_api_endpoint
+from swh.core.api import (
+ APIError,
+ AsyncRPCClient,
+ RemoteException,
+ RPCClient,
+ remote_api_endpoint,
+)
from swh.core.api.serializers import exception_to_dict, msgpack_dumps
from .test_serializers import ExtraType, extra_decoders, extra_encoders
+@pytest.fixture
+def aiohttp_mock():
+ with aioresponses.aioresponses() as m:
+ yield m
+
+
class ReraiseException(Exception):
pass
-@pytest.fixture
-def rpc_client(requests_mock):
+def make_rpc_client(cls):
class TestStorage:
@remote_api_endpoint("test_endpoint_url")
def test_endpoint(self, test_data, db=None, cur=None):
@@ -37,36 +50,90 @@
def overridden_method(self, data):
return "foo"
- class Testclient(RPCClient):
+ class BaseTestclient(cls):
backend_class = TestStorage
extra_type_encoders = extra_encoders
extra_type_decoders = extra_decoders
reraise_exceptions = [ReraiseException]
+ return BaseTestclient
+
+
+@pytest.fixture
+def rpc_client(requests_mock):
+ BaseTestclient = make_rpc_client(RPCClient)
+
+ class Testclient(BaseTestclient):
def overridden_method(self, data):
return "bar"
+ class MockResponse:
+ def __init__(self, text, status):
+ self._text = text
+ self.status = status
+
+ async def text(self):
+ return self._text
+
+ async def __aexit__(self, exc_type, exc, tb):
+ pass
+
+ async def __aenter__(self):
+ return self
+
def callback(request, context):
assert request.headers["Content-Type"] == "application/x-msgpack"
context.headers["Content-Type"] = "application/x-msgpack"
if request.path == "/test_endpoint_url":
- context.content = b"\xa3egg"
+ content = b"\xa3egg"
elif request.path == "/path/to/endpoint":
- context.content = b"\xa4spam"
+ content = b"\xa4spam"
elif request.path == "/serializer_test":
- context.content = (
+ content = (
b"\x82\xc4\x07swhtype\xa9extratype"
b"\xc4\x01d\x92\x81\xa4spam\xa3egg\xa3qux"
)
else:
assert False
- return context.content
+ return content
requests_mock.post(re.compile("mock://example.com/"), content=callback)
return Testclient(url="mock://example.com")
+@pytest.fixture
+def async_rpc_client(aiohttp_mock):
+ BaseTestclient = make_rpc_client(AsyncRPCClient)
+
+ class Testclient(BaseTestclient):
+ async def overridden_method(self, data):
+ return "bar"
+
+ async def callback(url, *, headers, **kwargs):
+ assert headers["content-type"] == "application/x-msgpack"
+ if url.path == "/test_endpoint_url":
+ content = b"\xa3egg"
+ elif url.path == "/path/to/endpoint":
+ content = b"\xa4spam"
+ elif url.path == "/serializer_test":
+ content = (
+ b"\x82\xc4\x07swhtype\xa9extratype"
+ b"\xc4\x01d\x92\x81\xa4spam\xa3egg\xa3qux"
+ )
+ else:
+ assert False
+ return aioresponses.CallbackResult(
+ body=content, headers={"Content-Type": "application/x-msgpack"}
+ )
+
+ aiohttp_mock.post(
+ re.compile("mock://example.com/.*"), callback=callback, repeat=True
+ )
+
+ return Testclient(url="mock://example.com")
+
+
def test_client(rpc_client):
assert hasattr(rpc_client, "test_endpoint")
@@ -83,6 +150,21 @@
assert res == "spam"
+async def test_async_client(async_rpc_client, event_loop):
+ assert hasattr(async_rpc_client, "test_endpoint")
+ assert hasattr(async_rpc_client, "something")
+
+ res = await async_rpc_client.test_endpoint("spam")
+ assert res == "egg"
+ res = await async_rpc_client.test_endpoint(test_data="spam")
+ assert res == "egg"
+
+ res = await async_rpc_client.something("whatever")
+ assert res == "spam"
+ res = await async_rpc_client.something(data="whatever")
+ assert res == "spam"
+
+
def test_client_extra_serializers(rpc_client):
res = rpc_client.serializer_test(["foo", ExtraType("bar", b"baz")])
assert res == ExtraType({"spam": "egg"}, "qux")

File Metadata

Mime Type
text/plain
Expires
Jul 3 2025, 8:00 AM (10 w, 3 d ago)
Storage Engine
blob
Storage Format
Raw Data
Storage Handle
3218581

Event Timeline