Page Menu
Home
Software Heritage
Search
Configure Global Search
Log In
Files
F9337309
D6651.diff
No One
Temporary
Actions
View File
Edit File
Delete File
View Transforms
Subscribe
Mute Notifications
Award Token
Flag For Later
Size
23 KB
Subscribers
None
D6651.diff
View Options
diff --git a/mypy.ini b/mypy.ini
--- a/mypy.ini
+++ b/mypy.ini
@@ -8,6 +8,9 @@
[mypy-aiohttp_utils.*]
ignore_missing_imports = True
+[mypy-aioresponses.*]
+ignore_missing_imports = True
+
[mypy-arrow.*]
ignore_missing_imports = True
diff --git a/swh/core/api/__init__.py b/swh/core/api/__init__.py
--- a/swh/core/api/__init__.py
+++ b/swh/core/api/__init__.py
@@ -1,18 +1,22 @@
-# Copyright (C) 2015-2020 The Software Heritage developers
+# Copyright (C) 2015-2021 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
from collections import abc
+import contextlib
import functools
import inspect
import logging
import pickle
from typing import (
Any,
+ AsyncIterable,
Callable,
ClassVar,
Dict,
+ Generic,
+ Iterable,
List,
Optional,
Tuple,
@@ -21,6 +25,7 @@
Union,
)
+import aiohttp
from flask import Flask, Request, Response, abort, request
import requests
from werkzeug.exceptions import HTTPException
@@ -41,6 +46,14 @@
logger = logging.getLogger(__name__)
+# Genericity helpers
+
+TSession = TypeVar("TSession", requests.Session, aiohttp.ClientSession)
+"""Session class used by the underlying HTTP library"""
+TResponse = TypeVar("TResponse", requests.Response, aiohttp.ClientResponse)
+"""Response class used by the underlying HTTP library"""
+
+
# support for content negotiation
@@ -85,7 +98,7 @@
# base API classes
-class RemoteException(Exception):
+class RemoteException(Generic[TResponse], Exception):
"""raised when remote returned an out-of-band failure notification, e.g., as a
HTTP status code or serialized exception
@@ -95,15 +108,13 @@
"""
def __init__(
- self,
- payload: Optional[Any] = None,
- response: Optional[requests.Response] = None,
+ self, payload: Optional[Any] = None, response: Optional[TResponse] = None,
):
if payload is not None:
super().__init__(payload)
else:
super().__init__()
- self.response = response
+ self.response: Optional[TResponse] = response
def __str__(self):
if (
@@ -184,8 +195,8 @@
attributes[meth_name] = meth_
-class RPCClient(metaclass=MetaRPCClient):
- """Proxy to an internal SWH RPC
+class BaseRPCClient(Generic[TSession, TResponse], metaclass=MetaRPCClient):
+ """Base class for :class:`RPCClient` and :class:`AsyncRPCClient`.
"""
@@ -229,21 +240,111 @@
self.reraise_exceptions = reraise_exceptions
base_url = url if url.endswith("/") else url + "/"
self.url = base_url
- self.session = requests.Session()
+
+ self.timeout = timeout
+ self.chunk_size = chunk_size
+
+ self.session = self._make_session(kwargs)
+
+ def _url(self, endpoint):
+ return "%s%s" % (self.url, endpoint)
+
+ def _get_status_code(self, response: TResponse) -> int:
+ raise NotImplementedError("_get_status_code")
+
+ def _decode_response(self, response: TResponse, check_status=True) -> Any:
+ raise NotImplementedError("_decode_response")
+
+ def raise_for_status(self, response: TResponse) -> None:
+ """check response HTTP status code and raise an exception if it denotes an
+ error; do nothing otherwise
+
+ """
+ status_code = self._get_status_code(response)
+ status_class = status_code // 100
+
+ if status_code == 404:
+ raise RemoteException(payload="404 not found", response=response)
+
+ exception = None
+
+ # TODO: only old servers send pickled error; stop trying to unpickle
+ # after they are all upgraded
+ try:
+ if status_class == 4:
+ data = self._decode_response(response, check_status=False)
+ if isinstance(data, dict):
+ # TODO: remove "exception" key check once all servers
+ # are using new schema
+ exc_data = data["exception"] if "exception" in data else data
+ for exc_type in self.reraise_exceptions:
+ if exc_type.__name__ == exc_data["type"]:
+ exception = exc_type(*exc_data["args"])
+ break
+ else:
+ exception = RemoteException(payload=exc_data, response=response)
+ else:
+ exception = pickle.loads(data)
+
+ elif status_class == 5:
+ data = self._decode_response(response, check_status=False)
+ if "exception_pickled" in data:
+ exception = pickle.loads(data["exception_pickled"])
+ else:
+ # TODO: remove "exception" key check once all servers
+ # are using new schema
+ exc_data = data["exception"] if "exception" in data else data
+ exception = RemoteException(payload=exc_data, response=response)
+
+ except (TypeError, pickle.UnpicklingError):
+ raise RemoteException(payload=data, response=response)
+
+ if exception:
+ raise exception from None
+
+ if status_class != 2:
+ raise RemoteException(
+ payload=f"API HTTP error: {status_code} {response.content}",
+ response=response,
+ )
+
+ def __repr__(self):
+ return "<{} url={}>".format(self.__class__.__name__, self.url)
+
+
+class RPCClient(BaseRPCClient[requests.Session, requests.Response]):
+ """Proxy to an internal SWH RPC
+
+ """
+
+ def _make_session(self, kwargs):
+ session = requests.Session()
adapter = requests.adapters.HTTPAdapter(
max_retries=kwargs.get("max_retries", 3),
pool_connections=kwargs.get("pool_connections", 20),
pool_maxsize=kwargs.get("pool_maxsize", 100),
)
- self.session.mount(self.url, adapter)
+ session.mount(self.url, adapter)
+ return session
- self.timeout = timeout
- self.chunk_size = chunk_size
+ def _get_status_code(self, response: requests.Response) -> int:
+ return response.status_code
- def _url(self, endpoint):
- return "%s%s" % (self.url, endpoint)
+ def _decode_response(self, response: requests.Response, check_status=True) -> Any:
+ if check_status:
+ self.raise_for_status(response)
+ return decode_response(
+ response.content,
+ response.headers["content-type"],
+ extra_decoders=self.extra_type_decoders,
+ )
+
+ def _stream_response(
+ self, response: requests.Response, chunk_size: int
+ ) -> Iterable[Any]:
+ return response.iter_content(chunk_size)
- def raw_verb(self, verb, endpoint, **opts):
+ def raw_verb(self, verb: str, endpoint: str, **opts) -> TResponse:
if "chunk_size" in opts:
# if the chunk_size argument has been passed, consider the user
# also wants stream=True, otherwise, what's the point.
@@ -255,7 +356,7 @@
except requests.exceptions.ConnectionError as e:
raise self.api_exception(e)
- def post(self, endpoint, data, **opts):
+ def post(self, endpoint: str, data: Any, **opts) -> Any:
if isinstance(data, (abc.Iterator, abc.Generator)):
data = (self._encode_data(x) for x in data)
else:
@@ -273,89 +374,125 @@
)
if opts.get("stream") or response.headers.get("transfer-encoding") == "chunked":
self.raise_for_status(response)
- return response.iter_content(chunk_size)
+ return self._stream_response(response, chunk_size)
else:
return self._decode_response(response)
- def _encode_data(self, data):
+ def _encode_data(self, data: Any) -> bytes:
return encode_data(data, extra_encoders=self.extra_type_encoders)
post_stream = post
- def get(self, endpoint, **opts):
+ def get(self, endpoint: str, **opts) -> Any:
chunk_size = opts.pop("chunk_size", self.chunk_size)
response = self.raw_verb(
"get", endpoint, headers={"accept": "application/x-msgpack"}, **opts
)
if opts.get("stream") or response.headers.get("transfer-encoding") == "chunked":
self.raise_for_status(response)
- return response.iter_content(chunk_size)
+ return self._stream_response(response, chunk_size)
else:
return self._decode_response(response)
- def get_stream(self, endpoint, **opts):
+ def get_stream(self, endpoint: str, **opts) -> Iterable[Any]:
return self.get(endpoint, stream=True, **opts)
- def raise_for_status(self, response) -> None:
- """check response HTTP status code and raise an exception if it denotes an
- error; do nothing otherwise
- """
- status_code = response.status_code
- status_class = response.status_code // 100
+class AsyncRPCClient(BaseRPCClient[aiohttp.ClientSession, aiohttp.ClientResponse]):
+ """Asynchronous proxy to an internal SWH RPC
- if status_code == 404:
- raise RemoteException(payload="404 not found", response=response)
+ """
- exception = None
+ def _make_session(self, kwargs):
+ return aiohttp.ClientSession(
+ connector=aiohttp.TCPConnector(limit=kwargs.get("pool_maxsize", 100),),
+ timeout=aiohttp.ClientTimeout(total=self.timeout),
+ connector_owner=True,
+ )
- # TODO: only old servers send pickled error; stop trying to unpickle
- # after they are all upgraded
- try:
- if status_class == 4:
- data = self._decode_response(response, check_status=False)
- if isinstance(data, dict):
- # TODO: remove "exception" key check once all servers
- # are using new schema
- exc_data = data["exception"] if "exception" in data else data
- for exc_type in self.reraise_exceptions:
- if exc_type.__name__ == exc_data["type"]:
- exception = exc_type(*exc_data["args"])
- break
- else:
- exception = RemoteException(payload=exc_data, response=response)
- else:
- exception = pickle.loads(data)
+ async def __aenter__(self):
+ print("open")
+ await self.session.__aenter__()
- elif status_class == 5:
- data = self._decode_response(response, check_status=False)
- if "exception_pickled" in data:
- exception = pickle.loads(data["exception_pickled"])
- else:
- # TODO: remove "exception" key check once all servers
- # are using new schema
- exc_data = data["exception"] if "exception" in data else data
- exception = RemoteException(payload=exc_data, response=response)
+ async def __aexit__(self, *args):
+ print("close")
+ await self.session.__aexit__(*args)
- except (TypeError, pickle.UnpicklingError):
- raise RemoteException(payload=data, response=response)
+ def _get_status_code(self, response: aiohttp.ClientResponse) -> int:
+ return response.status
- if exception:
- raise exception from None
-
- if status_class != 2:
- raise RemoteException(
- payload=f"API HTTP error: {status_code} {response.content}",
- response=response,
- )
-
- def _decode_response(self, response, check_status=True):
+ async def _decode_response(
+ self, response: aiohttp.ClientResponse, check_status=True
+ ) -> Any:
if check_status:
self.raise_for_status(response)
- return decode_response(response, extra_decoders=self.extra_type_decoders)
+ return decode_response(
+ await response.read(),
+ response.headers["content-type"],
+ extra_decoders=self.extra_type_decoders,
+ )
- def __repr__(self):
- return "<{} url={}>".format(self.__class__.__name__, self.url)
+ @contextlib.asynccontextmanager
+ async def raw_verb(self, verb: str, endpoint: str, **opts):
+ if "chunk_size" in opts:
+ # if the chunk_size argument has been passed, consider the user
+ # also wants stream=True, otherwise, what's the point.
+ opts["stream"] = True
+ if self.timeout and "timeout" not in opts:
+ opts["timeout"] = self.timeout
+ try:
+ async with getattr(self.session, verb)(self._url(endpoint), **opts) as r:
+ yield r
+ except requests.exceptions.ConnectionError as e:
+ raise self.api_exception(e)
+
+ async def post(self, endpoint: str, data: Any, **opts) -> Any:
+ if isinstance(data, (abc.Iterator, abc.Generator)):
+ data = (self._encode_data(x) for x in data)
+ else:
+ data = self._encode_data(data)
+ chunk_size = opts.pop("chunk_size", self.chunk_size)
+ async with self.raw_verb(
+ "post",
+ endpoint,
+ data=data,
+ headers={
+ "content-type": "application/x-msgpack",
+ "accept": "application/x-msgpack",
+ },
+ **opts,
+ ) as response:
+ if (
+ opts.get("stream")
+ or response.headers.get("transfer-encoding") == "chunked"
+ ):
+ self.raise_for_status(response)
+ return await self._stream_response(response, chunk_size)
+ else:
+ return await self._decode_response(response)
+
+ def _encode_data(self, data: Any) -> bytes:
+ return encode_data(data, extra_encoders=self.extra_type_encoders)
+
+ post_stream = post
+
+ async def get(self, endpoint: str, **opts) -> Any:
+ chunk_size = opts.pop("chunk_size", self.chunk_size)
+ async with self.raw_verb(
+ "get", endpoint, headers={"accept": "application/x-msgpack"}, **opts
+ ) as response:
+ if (
+ opts.get("stream")
+ or response.headers.get("transfer-encoding") == "chunked"
+ ):
+ self.raise_for_status(response)
+ return await self._stream_response(response, chunk_size)
+ else:
+ return await self._decode_response(response)
+
+ async def get_stream(self, endpoint: str, **opts) -> AsyncIterable[Any]:
+ async for v in self.get(endpoint, stream=True, **opts):
+ yield v
class BytesRequest(Request):
diff --git a/swh/core/api/classes.py b/swh/core/api/classes.py
--- a/swh/core/api/classes.py
+++ b/swh/core/api/classes.py
@@ -3,9 +3,19 @@
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
+import asyncio
from dataclasses import dataclass, field
import itertools
-from typing import Callable, Generic, Iterable, List, Optional, TypeVar
+from typing import (
+ AsyncIterable,
+ Awaitable,
+ Callable,
+ Generic,
+ Iterable,
+ List,
+ Optional,
+ TypeVar,
+)
TResult = TypeVar("TResult")
TToken = TypeVar("TToken")
@@ -59,3 +69,64 @@
res.results,
_stream_results(f, *args, page_token=res.next_page_token, **kwargs),
)
+
+
+async def asyncchain(iterable1, iterable2):
+ """Like itertools.chain, but async"""
+ async for item in iterable1:
+ yield item
+ async for item in iterable2:
+ yield item
+
+
+async def asynciter(iterable):
+ """Like iter(), but async"""
+ for item in iterable:
+ yield item
+
+
+async def _stream_results_async(f, *args, page_token, **kwargs):
+ """Helper for stream_results_async() and stream_results_optional_async()"""
+ while True:
+ page_result = await f(*args, page_token=page_token, **kwargs)
+ for res in page_result.results:
+ yield res
+ page_token = page_result.next_page_token
+ if page_token is None:
+ break
+
+
+async def stream_results_async(
+ f: Callable[..., Awaitable[PagedResult[TResult, TToken]]], *args, **kwargs
+) -> AsyncIterable[TResult]:
+ """Consume the paginated result and stream the page results
+
+ """
+ if "page_token" in kwargs:
+        raise TypeError('stream_results_async has no argument "page_token".')
+ return await _stream_results(f, *args, page_token=None, **kwargs)
+
+
+async def stream_results_optional_async(
+ f: Callable[..., Awaitable[Optional[PagedResult[TResult, TToken]]]], *args, **kwargs
+) -> Awaitable[Optional[AsyncIterable[TResult]]]:
+ """Like stream_results(), but for functions ``f`` that return an Optional.
+
+ """
+ if "page_token" in kwargs:
+        raise TypeError('stream_results_optional_async has no argument "page_token".')
+ res = await f(*args, page_token=None, **kwargs)
+ if res is None:
+ future: asyncio.Future[None] = asyncio.Future(loop=asyncio.get_running_loop())
+ future.set_result(None)
+ return future
+ else:
+ if res.next_page_token is None:
+ return asynciter(res.results)
+ else:
+ return asyncchain(
+ asynciter(res.results),
+ _stream_results_async(
+ f, *args, page_token=res.next_page_token, **kwargs
+ ),
+ )
diff --git a/swh/core/api/serializers.py b/swh/core/api/serializers.py
--- a/swh/core/api/serializers.py
+++ b/swh/core/api/serializers.py
@@ -12,6 +12,7 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
from uuid import UUID
+import aiohttp
import iso8601
import msgpack
from requests import Response
@@ -128,15 +129,13 @@
raise ValueError("Limits were reached. Please, check your input.\n" + str(e))
-def decode_response(response: Response, extra_decoders=None) -> Any:
- content_type = response.headers["content-type"]
-
+def decode_response(content: bytes, content_type, extra_decoders=None) -> Any:
if content_type.startswith("application/x-msgpack"):
- r = msgpack_loads(response.content, extra_decoders=extra_decoders)
+ r = msgpack_loads(content, extra_decoders=extra_decoders)
elif content_type.startswith("application/json"):
- r = json_loads(response.text, extra_decoders=extra_decoders)
+ r = json_loads(content.decode(), extra_decoders=extra_decoders)
elif content_type.startswith("text/"):
- r = response.text
+ r = content.decode()
else:
raise ValueError("Wrong content type `%s` for API response" % content_type)
diff --git a/swh/core/api/tests/test_rpc_client.py b/swh/core/api/tests/test_rpc_client.py
--- a/swh/core/api/tests/test_rpc_client.py
+++ b/swh/core/api/tests/test_rpc_client.py
@@ -1,25 +1,38 @@
-# Copyright (C) 2018-2019 The Software Heritage developers
+# Copyright (C) 2018-2021 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import re
+import urllib.parse
+import aioresponses
import pytest
from requests.exceptions import ConnectionError
-from swh.core.api import APIError, RemoteException, RPCClient, remote_api_endpoint
+from swh.core.api import (
+ APIError,
+ AsyncRPCClient,
+ RemoteException,
+ RPCClient,
+ remote_api_endpoint,
+)
from swh.core.api.serializers import exception_to_dict, msgpack_dumps
from .test_serializers import ExtraType, extra_decoders, extra_encoders
+@pytest.fixture
+def aiohttp_mock():
+ with aioresponses.aioresponses() as m:
+ yield m
+
+
class ReraiseException(Exception):
pass
-@pytest.fixture
-def rpc_client(requests_mock):
+def make_rpc_client(cls):
class TestStorage:
@remote_api_endpoint("test_endpoint_url")
def test_endpoint(self, test_data, db=None, cur=None):
@@ -37,36 +50,90 @@
def overridden_method(self, data):
return "foo"
- class Testclient(RPCClient):
+ class BaseTestclient(cls):
backend_class = TestStorage
extra_type_encoders = extra_encoders
extra_type_decoders = extra_decoders
reraise_exceptions = [ReraiseException]
+ return BaseTestclient
+
+
+@pytest.fixture
+def rpc_client(requests_mock):
+ BaseTestclient = make_rpc_client(RPCClient)
+
+ class Testclient(BaseTestclient):
def overridden_method(self, data):
return "bar"
+ class MockResponse:
+ def __init__(self, text, status):
+ self._text = text
+ self.status = status
+
+ async def text(self):
+ return self._text
+
+ async def __aexit__(self, exc_type, exc, tb):
+ pass
+
+ async def __aenter__(self):
+ return self
+
def callback(request, context):
assert request.headers["Content-Type"] == "application/x-msgpack"
context.headers["Content-Type"] = "application/x-msgpack"
if request.path == "/test_endpoint_url":
- context.content = b"\xa3egg"
+ content = b"\xa3egg"
elif request.path == "/path/to/endpoint":
- context.content = b"\xa4spam"
+ content = b"\xa4spam"
elif request.path == "/serializer_test":
- context.content = (
+ content = (
b"\x82\xc4\x07swhtype\xa9extratype"
b"\xc4\x01d\x92\x81\xa4spam\xa3egg\xa3qux"
)
else:
assert False
- return context.content
+ return MockResponse(content, 200)
requests_mock.post(re.compile("mock://example.com/"), content=callback)
return Testclient(url="mock://example.com")
+@pytest.fixture
+def async_rpc_client(aiohttp_mock):
+ BaseTestclient = make_rpc_client(AsyncRPCClient)
+
+ class Testclient(BaseTestclient):
+ async def overridden_method(self, data):
+ return asyncio.ensure_future("bar", loop=asyncio.get_running_loop())
+
+ async def callback(url, *, headers, **kwargs):
+ assert headers["content-type"] == "application/x-msgpack"
+ if url.path == "/test_endpoint_url":
+ content = b"\xa3egg"
+ elif url.path == "/path/to/endpoint":
+ content = b"\xa4spam"
+ elif url.path == "/serializer_test":
+ content = (
+ b"\x82\xc4\x07swhtype\xa9extratype"
+ b"\xc4\x01d\x92\x81\xa4spam\xa3egg\xa3qux"
+ )
+ else:
+ assert False
+ return aioresponses.CallbackResult(
+ body=content, headers={"Content-Type": "application/x-msgpack"}
+ )
+
+ aiohttp_mock.post(
+ re.compile("mock://example.com/.*"), callback=callback, repeat=True
+ )
+
+ return Testclient(url="mock://example.com")
+
+
def test_client(rpc_client):
assert hasattr(rpc_client, "test_endpoint")
@@ -83,6 +150,21 @@
assert res == "spam"
+async def test_async_client(async_rpc_client, event_loop):
+ assert hasattr(async_rpc_client, "test_endpoint")
+ assert hasattr(async_rpc_client, "something")
+
+ res = await async_rpc_client.test_endpoint("spam")
+ assert res == "egg"
+ res = await async_rpc_client.test_endpoint(test_data="spam")
+ assert res == "egg"
+
+ res = await async_rpc_client.something("whatever")
+ assert res == "spam"
+ res = await async_rpc_client.something(data="whatever")
+ assert res == "spam"
+
+
def test_client_extra_serializers(rpc_client):
res = rpc_client.serializer_test(["foo", ExtraType("bar", b"baz")])
assert res == ExtraType({"spam": "egg"}, "qux")
File Metadata
Details
Attached
Mime Type
text/plain
Expires
Jul 3 2025, 8:00 AM (10 w, 3 d ago)
Storage Engine
blob
Storage Format
Raw Data
Storage Handle
3218581
Attached To
D6651: WIP: Add AsyncRPCClient
Event Timeline
Log In to Comment