Changeset View
Changeset View
Standalone View
Standalone View
swh/core/api/__init__.py
# Copyright (C) 2015-2020 The Software Heritage developers | # Copyright (C) 2015-2021 The Software Heritage developers | ||||
# See the AUTHORS file at the top-level directory of this distribution | # See the AUTHORS file at the top-level directory of this distribution | ||||
# License: GNU General Public License version 3, or any later version | # License: GNU General Public License version 3, or any later version | ||||
# See top-level LICENSE file for more information | # See top-level LICENSE file for more information | ||||
from collections import abc | from collections import abc | ||||
import contextlib | |||||
import functools | import functools | ||||
import inspect | import inspect | ||||
import logging | import logging | ||||
import pickle | import pickle | ||||
from typing import ( | from typing import ( | ||||
Any, | Any, | ||||
AsyncIterable, | |||||
Callable, | Callable, | ||||
ClassVar, | ClassVar, | ||||
Dict, | Dict, | ||||
Generic, | |||||
Iterable, | |||||
List, | List, | ||||
Optional, | Optional, | ||||
Tuple, | Tuple, | ||||
Type, | Type, | ||||
TypeVar, | TypeVar, | ||||
Union, | Union, | ||||
) | ) | ||||
import aiohttp | |||||
from flask import Flask, Request, Response, abort, request | from flask import Flask, Request, Response, abort, request | ||||
import requests | import requests | ||||
from werkzeug.exceptions import HTTPException | from werkzeug.exceptions import HTTPException | ||||
from .negotiation import Formatter as FormatterBase | from .negotiation import Formatter as FormatterBase | ||||
from .negotiation import Negotiator as NegotiatorBase | from .negotiation import Negotiator as NegotiatorBase | ||||
from .negotiation import negotiate as _negotiate | from .negotiation import negotiate as _negotiate | ||||
from .serializers import ( | from .serializers import ( | ||||
exception_to_dict, | exception_to_dict, | ||||
json_dumps, | json_dumps, | ||||
json_loads, | json_loads, | ||||
msgpack_dumps, | msgpack_dumps, | ||||
msgpack_loads, | msgpack_loads, | ||||
) | ) | ||||
from .serializers import decode_response | from .serializers import decode_response | ||||
from .serializers import encode_data_client as encode_data | from .serializers import encode_data_client as encode_data | ||||
logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||
# Genericity helpers | |||||
TSession = TypeVar("TSession", requests.Session, aiohttp.ClientSession) | |||||
"""Session class used by the underlying HTTP library""" | |||||
TResponse = TypeVar("TResponse", requests.Response, aiohttp.ClientResponse) | |||||
"""Response class used by the underlying HTTP library""" | |||||
# support for content negotiation | # support for content negotiation | ||||
class Negotiator(NegotiatorBase): | class Negotiator(NegotiatorBase): | ||||
def best_mimetype(self): | def best_mimetype(self): | ||||
return request.accept_mimetypes.best_match( | return request.accept_mimetypes.best_match( | ||||
self.accept_mimetypes, "application/json" | self.accept_mimetypes, "application/json" | ||||
) | ) | ||||
Show All 28 Lines | class MsgpackFormatter(Formatter): | ||||
def render(self, obj): | def render(self, obj): | ||||
return msgpack_dumps(obj, extra_encoders=self.extra_encoders) | return msgpack_dumps(obj, extra_encoders=self.extra_encoders) | ||||
# base API classes | # base API classes | ||||
class RemoteException(Exception): | class RemoteException(Generic[TResponse], Exception): | ||||
"""raised when remote returned an out-of-band failure notification, e.g., as a | """raised when remote returned an out-of-band failure notification, e.g., as a | ||||
HTTP status code or serialized exception | HTTP status code or serialized exception | ||||
Attributes: | Attributes: | ||||
response: HTTP response corresponding to the failure | response: HTTP response corresponding to the failure | ||||
""" | """ | ||||
def __init__( | def __init__( | ||||
self, | self, payload: Optional[Any] = None, response: Optional[TResponse] = None, | ||||
payload: Optional[Any] = None, | |||||
response: Optional[requests.Response] = None, | |||||
): | ): | ||||
if payload is not None: | if payload is not None: | ||||
super().__init__(payload) | super().__init__(payload) | ||||
else: | else: | ||||
super().__init__() | super().__init__() | ||||
self.response = response | self.response: Optional[TResponse] = response | ||||
def __str__(self): | def __str__(self): | ||||
if ( | if ( | ||||
self.args | self.args | ||||
and isinstance(self.args[0], dict) | and isinstance(self.args[0], dict) | ||||
and "type" in self.args[0] | and "type" in self.args[0] | ||||
and "args" in self.args[0] | and "args" in self.args[0] | ||||
): | ): | ||||
▲ Show 20 Lines • Show All 64 Lines • ▼ Show 20 Lines | def __add_endpoint(meth_name, meth, attributes): | ||||
# Send the request. | # Send the request. | ||||
return self.post(meth._endpoint_path, post_data) | return self.post(meth._endpoint_path, post_data) | ||||
if meth_name not in attributes: | if meth_name not in attributes: | ||||
attributes[meth_name] = meth_ | attributes[meth_name] = meth_ | ||||
class RPCClient(metaclass=MetaRPCClient): | class BaseRPCClient(Generic[TSession, TResponse], metaclass=MetaRPCClient): | ||||
"""Proxy to an internal SWH RPC | """Base class for :class:`RPCClient` and :class:`AsyncRPCClient`. | ||||
""" | """ | ||||
backend_class = None # type: ClassVar[Optional[type]] | backend_class = None # type: ClassVar[Optional[type]] | ||||
"""For each method of `backend_class` decorated with | """For each method of `backend_class` decorated with | ||||
:func:`remote_api_endpoint`, a method with the same prototype and | :func:`remote_api_endpoint`, a method with the same prototype and | ||||
docstring will be added to this class. Calls to this new method will | docstring will be added to this class. Calls to this new method will | ||||
be translated into HTTP requests to a remote server. | be translated into HTTP requests to a remote server. | ||||
Show All 27 Lines | def __init__( | ||||
**kwargs, | **kwargs, | ||||
): | ): | ||||
if api_exception: | if api_exception: | ||||
self.api_exception = api_exception | self.api_exception = api_exception | ||||
if reraise_exceptions: | if reraise_exceptions: | ||||
self.reraise_exceptions = reraise_exceptions | self.reraise_exceptions = reraise_exceptions | ||||
base_url = url if url.endswith("/") else url + "/" | base_url = url if url.endswith("/") else url + "/" | ||||
self.url = base_url | self.url = base_url | ||||
self.session = requests.Session() | |||||
self.timeout = timeout | |||||
self.chunk_size = chunk_size | |||||
self.session = self._make_session(kwargs) | |||||
def _url(self, endpoint): | |||||
return "%s%s" % (self.url, endpoint) | |||||
def _get_status_code(self, response: TResponse) -> int: | |||||
raise NotImplementedError("_get_status_code") | |||||
def _decode_response(self, response: TResponse, check_status=True) -> Any: | |||||
raise NotImplementedError("_decode_response") | |||||
def raise_for_status(self, response: TResponse) -> None: | |||||
"""check response HTTP status code and raise an exception if it denotes an | |||||
error; do nothing otherwise | |||||
""" | |||||
status_code = self._get_status_code(response) | |||||
status_class = status_code // 100 | |||||
if status_code == 404: | |||||
raise RemoteException(payload="404 not found", response=response) | |||||
exception = None | |||||
# TODO: only old servers send pickled error; stop trying to unpickle | |||||
# after they are all upgraded | |||||
try: | |||||
if status_class == 4: | |||||
data = self._decode_response(response, check_status=False) | |||||
if isinstance(data, dict): | |||||
# TODO: remove "exception" key check once all servers | |||||
# are using new schema | |||||
exc_data = data["exception"] if "exception" in data else data | |||||
for exc_type in self.reraise_exceptions: | |||||
if exc_type.__name__ == exc_data["type"]: | |||||
exception = exc_type(*exc_data["args"]) | |||||
break | |||||
else: | |||||
exception = RemoteException(payload=exc_data, response=response) | |||||
else: | |||||
exception = pickle.loads(data) | |||||
elif status_class == 5: | |||||
data = self._decode_response(response, check_status=False) | |||||
if "exception_pickled" in data: | |||||
exception = pickle.loads(data["exception_pickled"]) | |||||
else: | |||||
# TODO: remove "exception" key check once all servers | |||||
# are using new schema | |||||
exc_data = data["exception"] if "exception" in data else data | |||||
exception = RemoteException(payload=exc_data, response=response) | |||||
except (TypeError, pickle.UnpicklingError): | |||||
raise RemoteException(payload=data, response=response) | |||||
if exception: | |||||
raise exception from None | |||||
if status_class != 2: | |||||
raise RemoteException( | |||||
payload=f"API HTTP error: {status_code} {response.content}", | |||||
response=response, | |||||
) | |||||
def __repr__(self): | |||||
return "<{} url={}>".format(self.__class__.__name__, self.url) | |||||
class RPCClient(BaseRPCClient[requests.Session, requests.Response]): | |||||
"""Proxy to an internal SWH RPC | |||||
""" | |||||
def _make_session(self, kwargs): | |||||
session = requests.Session() | |||||
adapter = requests.adapters.HTTPAdapter( | adapter = requests.adapters.HTTPAdapter( | ||||
max_retries=kwargs.get("max_retries", 3), | max_retries=kwargs.get("max_retries", 3), | ||||
pool_connections=kwargs.get("pool_connections", 20), | pool_connections=kwargs.get("pool_connections", 20), | ||||
pool_maxsize=kwargs.get("pool_maxsize", 100), | pool_maxsize=kwargs.get("pool_maxsize", 100), | ||||
) | ) | ||||
self.session.mount(self.url, adapter) | session.mount(self.url, adapter) | ||||
return session | |||||
self.timeout = timeout | def _get_status_code(self, response: requests.Response) -> int: | ||||
self.chunk_size = chunk_size | return response.status_code | ||||
def _url(self, endpoint): | def _decode_response(self, response: requests.Response, check_status=True) -> Any: | ||||
return "%s%s" % (self.url, endpoint) | if check_status: | ||||
self.raise_for_status(response) | |||||
return decode_response( | |||||
response.content, | |||||
response.headers["content-type"], | |||||
extra_decoders=self.extra_type_decoders, | |||||
) | |||||
def _stream_response( | |||||
self, response: requests.Response, chunk_size: int | |||||
) -> Iterable[Any]: | |||||
return response.iter_content(chunk_size) | |||||
def raw_verb(self, verb, endpoint, **opts): | def raw_verb(self, verb: str, endpoint: str, **opts) -> TResponse: | ||||
if "chunk_size" in opts: | if "chunk_size" in opts: | ||||
# if the chunk_size argument has been passed, consider the user | # if the chunk_size argument has been passed, consider the user | ||||
# also wants stream=True, otherwise, what's the point. | # also wants stream=True, otherwise, what's the point. | ||||
opts["stream"] = True | opts["stream"] = True | ||||
if self.timeout and "timeout" not in opts: | if self.timeout and "timeout" not in opts: | ||||
opts["timeout"] = self.timeout | opts["timeout"] = self.timeout | ||||
try: | try: | ||||
return getattr(self.session, verb)(self._url(endpoint), **opts) | return getattr(self.session, verb)(self._url(endpoint), **opts) | ||||
except requests.exceptions.ConnectionError as e: | except requests.exceptions.ConnectionError as e: | ||||
raise self.api_exception(e) | raise self.api_exception(e) | ||||
def post(self, endpoint, data, **opts): | def post(self, endpoint: str, data: Any, **opts) -> Any: | ||||
if isinstance(data, (abc.Iterator, abc.Generator)): | if isinstance(data, (abc.Iterator, abc.Generator)): | ||||
data = (self._encode_data(x) for x in data) | data = (self._encode_data(x) for x in data) | ||||
else: | else: | ||||
data = self._encode_data(data) | data = self._encode_data(data) | ||||
chunk_size = opts.pop("chunk_size", self.chunk_size) | chunk_size = opts.pop("chunk_size", self.chunk_size) | ||||
response = self.raw_verb( | response = self.raw_verb( | ||||
"post", | "post", | ||||
endpoint, | endpoint, | ||||
data=data, | data=data, | ||||
headers={ | headers={ | ||||
"content-type": "application/x-msgpack", | "content-type": "application/x-msgpack", | ||||
"accept": "application/x-msgpack", | "accept": "application/x-msgpack", | ||||
}, | }, | ||||
**opts, | **opts, | ||||
) | ) | ||||
if opts.get("stream") or response.headers.get("transfer-encoding") == "chunked": | if opts.get("stream") or response.headers.get("transfer-encoding") == "chunked": | ||||
self.raise_for_status(response) | self.raise_for_status(response) | ||||
return response.iter_content(chunk_size) | return self._stream_response(response, chunk_size) | ||||
else: | else: | ||||
return self._decode_response(response) | return self._decode_response(response) | ||||
def _encode_data(self, data): | def _encode_data(self, data: Any) -> bytes: | ||||
return encode_data(data, extra_encoders=self.extra_type_encoders) | return encode_data(data, extra_encoders=self.extra_type_encoders) | ||||
post_stream = post | post_stream = post | ||||
def get(self, endpoint, **opts): | def get(self, endpoint: str, **opts) -> Any: | ||||
chunk_size = opts.pop("chunk_size", self.chunk_size) | chunk_size = opts.pop("chunk_size", self.chunk_size) | ||||
response = self.raw_verb( | response = self.raw_verb( | ||||
"get", endpoint, headers={"accept": "application/x-msgpack"}, **opts | "get", endpoint, headers={"accept": "application/x-msgpack"}, **opts | ||||
) | ) | ||||
if opts.get("stream") or response.headers.get("transfer-encoding") == "chunked": | if opts.get("stream") or response.headers.get("transfer-encoding") == "chunked": | ||||
self.raise_for_status(response) | self.raise_for_status(response) | ||||
return response.iter_content(chunk_size) | return self._stream_response(response, chunk_size) | ||||
else: | else: | ||||
return self._decode_response(response) | return self._decode_response(response) | ||||
def get_stream(self, endpoint, **opts): | def get_stream(self, endpoint: str, **opts) -> Iterable[Any]: | ||||
return self.get(endpoint, stream=True, **opts) | return self.get(endpoint, stream=True, **opts) | ||||
def raise_for_status(self, response) -> None: | |||||
"""check response HTTP status code and raise an exception if it denotes an | class AsyncRPCClient(BaseRPCClient[aiohttp.ClientSession, aiohttp.ClientResponse]): | ||||
error; do nothing otherwise | """Asynchronous proxy to an internal SWH RPC | ||||
""" | """ | ||||
status_code = response.status_code | |||||
status_class = response.status_code // 100 | |||||
if status_code == 404: | def _make_session(self, kwargs): | ||||
raise RemoteException(payload="404 not found", response=response) | return aiohttp.ClientSession( | ||||
connector=aiohttp.TCPConnector(limit=kwargs.get("pool_maxsize", 100),), | |||||
timeout=aiohttp.ClientTimeout(total=self.timeout), | |||||
connector_owner=True, | |||||
) | |||||
exception = None | async def __aenter__(self): | ||||
print("open") | |||||
await self.session.__aenter__() | |||||
async def __aexit__(self, *args): | |||||
print("close") | |||||
await self.session.__aexit__(*args) | |||||
def _get_status_code(self, response: aiohttp.ClientResponse) -> int: | |||||
return response.status | |||||
async def _decode_response( | |||||
self, response: aiohttp.ClientResponse, check_status=True | |||||
) -> Any: | |||||
if check_status: | |||||
self.raise_for_status(response) | |||||
return decode_response( | |||||
await response.read(), | |||||
response.headers["content-type"], | |||||
extra_decoders=self.extra_type_decoders, | |||||
) | |||||
# TODO: only old servers send pickled error; stop trying to unpickle | @contextlib.asynccontextmanager | ||||
# after they are all upgraded | async def raw_verb(self, verb: str, endpoint: str, **opts): | ||||
if "chunk_size" in opts: | |||||
# if the chunk_size argument has been passed, consider the user | |||||
# also wants stream=True, otherwise, what's the point. | |||||
opts["stream"] = True | |||||
if self.timeout and "timeout" not in opts: | |||||
opts["timeout"] = self.timeout | |||||
try: | try: | ||||
if status_class == 4: | async with getattr(self.session, verb)(self._url(endpoint), **opts) as r: | ||||
data = self._decode_response(response, check_status=False) | yield r | ||||
if isinstance(data, dict): | except requests.exceptions.ConnectionError as e: | ||||
# TODO: remove "exception" key check once all servers | raise self.api_exception(e) | ||||
# are using new schema | |||||
exc_data = data["exception"] if "exception" in data else data | |||||
for exc_type in self.reraise_exceptions: | |||||
if exc_type.__name__ == exc_data["type"]: | |||||
exception = exc_type(*exc_data["args"]) | |||||
break | |||||
else: | |||||
exception = RemoteException(payload=exc_data, response=response) | |||||
else: | |||||
exception = pickle.loads(data) | |||||
elif status_class == 5: | async def post(self, endpoint: str, data: Any, **opts) -> Any: | ||||
data = self._decode_response(response, check_status=False) | if isinstance(data, (abc.Iterator, abc.Generator)): | ||||
if "exception_pickled" in data: | data = (self._encode_data(x) for x in data) | ||||
exception = pickle.loads(data["exception_pickled"]) | |||||
else: | else: | ||||
# TODO: remove "exception" key check once all servers | data = self._encode_data(data) | ||||
# are using new schema | chunk_size = opts.pop("chunk_size", self.chunk_size) | ||||
exc_data = data["exception"] if "exception" in data else data | async with self.raw_verb( | ||||
exception = RemoteException(payload=exc_data, response=response) | "post", | ||||
endpoint, | |||||
except (TypeError, pickle.UnpicklingError): | data=data, | ||||
raise RemoteException(payload=data, response=response) | headers={ | ||||
"content-type": "application/x-msgpack", | |||||
"accept": "application/x-msgpack", | |||||
}, | |||||
**opts, | |||||
) as response: | |||||
if ( | |||||
opts.get("stream") | |||||
or response.headers.get("transfer-encoding") == "chunked" | |||||
): | |||||
self.raise_for_status(response) | |||||
return await self._stream_response(response, chunk_size) | |||||
else: | |||||
return await self._decode_response(response) | |||||
if exception: | def _encode_data(self, data: Any) -> bytes: | ||||
raise exception from None | return encode_data(data, extra_encoders=self.extra_type_encoders) | ||||
if status_class != 2: | post_stream = post | ||||
raise RemoteException( | |||||
payload=f"API HTTP error: {status_code} {response.content}", | |||||
response=response, | |||||
) | |||||
def _decode_response(self, response, check_status=True): | async def get(self, endpoint: str, **opts) -> Any: | ||||
if check_status: | chunk_size = opts.pop("chunk_size", self.chunk_size) | ||||
async with self.raw_verb( | |||||
"get", endpoint, headers={"accept": "application/x-msgpack"}, **opts | |||||
) as response: | |||||
if ( | |||||
opts.get("stream") | |||||
or response.headers.get("transfer-encoding") == "chunked" | |||||
): | |||||
self.raise_for_status(response) | self.raise_for_status(response) | ||||
return decode_response(response, extra_decoders=self.extra_type_decoders) | return await self._stream_response(response, chunk_size) | ||||
else: | |||||
return await self._decode_response(response) | |||||
def __repr__(self): | async def get_stream(self, endpoint: str, **opts) -> AsyncIterable[Any]: | ||||
return "<{} url={}>".format(self.__class__.__name__, self.url) | async for v in self.get(endpoint, stream=True, **opts): | ||||
yield v | |||||
class BytesRequest(Request): | class BytesRequest(Request): | ||||
"""Request with proper escaping of arbitrary byte sequences.""" | """Request with proper escaping of arbitrary byte sequences.""" | ||||
encoding = "utf-8" | encoding = "utf-8" | ||||
encoding_errors = "surrogateescape" | encoding_errors = "surrogateescape" | ||||
▲ Show 20 Lines • Show All 94 Lines • Show Last 20 Lines |