Changeset View
Changeset View
Standalone View
Standalone View
swh/core/api/__init__.py
# Copyright (C) 2015-2020 The Software Heritage developers | # Copyright (C) 2015-2020 The Software Heritage developers | ||||
# See the AUTHORS file at the top-level directory of this distribution | # See the AUTHORS file at the top-level directory of this distribution | ||||
# License: GNU General Public License version 3, or any later version | # License: GNU General Public License version 3, or any later version | ||||
# See top-level LICENSE file for more information | # See top-level LICENSE file for more information | ||||
from collections import abc | from collections import abc | ||||
import functools | import functools | ||||
import inspect | import inspect | ||||
import logging | import logging | ||||
import pickle | import pickle | ||||
import requests | import requests | ||||
from typing import ( | from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Type, Union | ||||
Any, Callable, ClassVar, Dict, List, Optional, Tuple, Type, Union, | |||||
) | |||||
from flask import Flask, Request, Response, request, abort | from flask import Flask, Request, Response, request, abort | ||||
from werkzeug.exceptions import HTTPException | from werkzeug.exceptions import HTTPException | ||||
from .serializers import (decode_response, | from .serializers import ( | ||||
decode_response, | |||||
encode_data_client as encode_data, | encode_data_client as encode_data, | ||||
msgpack_dumps, msgpack_loads, | msgpack_dumps, | ||||
json_dumps, json_loads, | msgpack_loads, | ||||
exception_to_dict) | json_dumps, | ||||
json_loads, | |||||
exception_to_dict, | |||||
) | |||||
from .negotiation import (Formatter as FormatterBase, | from .negotiation import ( | ||||
Formatter as FormatterBase, | |||||
Negotiator as NegotiatorBase, | Negotiator as NegotiatorBase, | ||||
negotiate as _negotiate) | negotiate as _negotiate, | ||||
) | |||||
logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||
# support for content negotiation | # support for content negotiation | ||||
class Negotiator(NegotiatorBase): | class Negotiator(NegotiatorBase): | ||||
def best_mimetype(self): | def best_mimetype(self): | ||||
return request.accept_mimetypes.best_match( | return request.accept_mimetypes.best_match( | ||||
self.accept_mimetypes, 'application/json') | self.accept_mimetypes, 'application/json' | ||||
) | |||||
def _abort(self, status_code, err=None): | def _abort(self, status_code, err=None): | ||||
return abort(status_code, err) | return abort(status_code, err) | ||||
def negotiate(formatter_cls, *args, **kwargs): | def negotiate(formatter_cls, *args, **kwargs): | ||||
return _negotiate(Negotiator, formatter_cls, *args, **kwargs) | return _negotiate(Negotiator, formatter_cls, *args, **kwargs) | ||||
Show All 19 Lines | class MsgpackFormatter(Formatter): | ||||
mimetypes = ['application/x-msgpack'] | mimetypes = ['application/x-msgpack'] | ||||
def render(self, obj): | def render(self, obj): | ||||
return msgpack_dumps(obj, extra_encoders=self.extra_encoders) | return msgpack_dumps(obj, extra_encoders=self.extra_encoders) | ||||
# base API classes | # base API classes | ||||
class RemoteException(Exception): | class RemoteException(Exception): | ||||
"""raised when remote returned an out-of-band failure notification, e.g., as a | """raised when remote returned an out-of-band failure notification, e.g., as a | ||||
HTTP status code or serialized exception | HTTP status code or serialized exception | ||||
Attributes: | Attributes: | ||||
response: HTTP response corresponding to the failure | response: HTTP response corresponding to the failure | ||||
""" | """ | ||||
def __init__(self, payload: Optional[Any] = None, | |||||
response: Optional[requests.Response] = None): | def __init__( | ||||
self, | |||||
payload: Optional[Any] = None, | |||||
response: Optional[requests.Response] = None, | |||||
): | |||||
if payload is not None: | if payload is not None: | ||||
super().__init__(payload) | super().__init__(payload) | ||||
else: | else: | ||||
super().__init__() | super().__init__() | ||||
self.response = response | self.response = response | ||||
def __str__(self): | def __str__(self): | ||||
if self.args and isinstance(self.args[0], dict) \ | if ( | ||||
and 'type' in self.args[0] and 'args' in self.args[0]: | self.args | ||||
and isinstance(self.args[0], dict) | |||||
and 'type' in self.args[0] | |||||
and 'args' in self.args[0] | |||||
): | |||||
return ( | return ( | ||||
f'<RemoteException {self.response.status_code} ' | f'<RemoteException {self.response.status_code} ' | ||||
f'{self.args[0]["type"]}: {self.args[0]["args"]}>') | f'{self.args[0]["type"]}: {self.args[0]["args"]}>' | ||||
) | |||||
else: | else: | ||||
return super().__str__() | return super().__str__() | ||||
def remote_api_endpoint(path): | def remote_api_endpoint(path): | ||||
def dec(f): | def dec(f): | ||||
f._endpoint_path = path | f._endpoint_path = path | ||||
return f | return f | ||||
return dec | return dec | ||||
class APIError(Exception): | class APIError(Exception): | ||||
"""API Error""" | """API Error""" | ||||
def __str__(self): | def __str__(self): | ||||
return ('An unexpected error occurred in the backend: {}' | return 'An unexpected error occurred in the backend: {}'.format(self.args) | ||||
.format(self.args)) | |||||
class MetaRPCClient(type): | class MetaRPCClient(type): | ||||
"""Metaclass for RPCClient, which adds a method for each endpoint | """Metaclass for RPCClient, which adds a method for each endpoint | ||||
of the database it is designed to access. | of the database it is designed to access. | ||||
See for example :class:`swh.indexer.storage.api.client.RemoteStorage`""" | See for example :class:`swh.indexer.storage.api.client.RemoteStorage`""" | ||||
def __new__(cls, name, bases, attributes): | def __new__(cls, name, bases, attributes): | ||||
# For each method wrapped with @remote_api_endpoint in an API backend | # For each method wrapped with @remote_api_endpoint in an API backend | ||||
# (eg. :class:`swh.indexer.storage.IndexerStorage`), add a new | # (eg. :class:`swh.indexer.storage.IndexerStorage`), add a new | ||||
# method in RemoteStorage, with the same documentation. | # method in RemoteStorage, with the same documentation. | ||||
# | # | ||||
# Note that, despite the usage of decorator magic (eg. functools.wrap), | # Note that, despite the usage of decorator magic (eg. functools.wrap), | ||||
# this never actually calls an IndexerStorage method. | # this never actually calls an IndexerStorage method. | ||||
backend_class = attributes.get('backend_class', None) | backend_class = attributes.get('backend_class', None) | ||||
Show All 9 Lines | class MetaRPCClient(type): | ||||
@staticmethod | @staticmethod | ||||
def __add_endpoint(meth_name, meth, attributes): | def __add_endpoint(meth_name, meth, attributes): | ||||
wrapped_meth = inspect.unwrap(meth) | wrapped_meth = inspect.unwrap(meth) | ||||
@functools.wraps(meth) # Copy signature and doc | @functools.wraps(meth) # Copy signature and doc | ||||
def meth_(*args, **kwargs): | def meth_(*args, **kwargs): | ||||
# Match arguments and parameters | # Match arguments and parameters | ||||
post_data = inspect.getcallargs( | post_data = inspect.getcallargs(wrapped_meth, *args, **kwargs) | ||||
wrapped_meth, *args, **kwargs) | |||||
# Remove arguments that should not be passed | # Remove arguments that should not be passed | ||||
self = post_data.pop('self') | self = post_data.pop('self') | ||||
post_data.pop('cur', None) | post_data.pop('cur', None) | ||||
post_data.pop('db', None) | post_data.pop('db', None) | ||||
# Send the request. | # Send the request. | ||||
return self.post(meth._endpoint_path, post_data) | return self.post(meth._endpoint_path, post_data) | ||||
if meth_name not in attributes: | if meth_name not in attributes: | ||||
attributes[meth_name] = meth_ | attributes[meth_name] = meth_ | ||||
class RPCClient(metaclass=MetaRPCClient): | class RPCClient(metaclass=MetaRPCClient): | ||||
"""Proxy to an internal SWH RPC | """Proxy to an internal SWH RPC | ||||
""" | """ | ||||
Show All 18 Lines | class RPCClient(metaclass=MetaRPCClient): | ||||
extra_type_encoders: List[Tuple[type, str, Callable]] = [] | extra_type_encoders: List[Tuple[type, str, Callable]] = [] | ||||
"""Value of `extra_encoders` passed to `json_dumps` or `msgpack_dumps` | """Value of `extra_encoders` passed to `json_dumps` or `msgpack_dumps` | ||||
to be able to serialize more object types.""" | to be able to serialize more object types.""" | ||||
extra_type_decoders: Dict[str, Callable] = {} | extra_type_decoders: Dict[str, Callable] = {} | ||||
"""Value of `extra_decoders` passed to `json_loads` or `msgpack_loads` | """Value of `extra_decoders` passed to `json_loads` or `msgpack_loads` | ||||
to be able to deserialize more object types.""" | to be able to deserialize more object types.""" | ||||
def __init__(self, url, api_exception=None, | def __init__( | ||||
timeout=None, chunk_size=4096, | self, | ||||
reraise_exceptions=None, **kwargs): | url, | ||||
api_exception=None, | |||||
timeout=None, | |||||
chunk_size=4096, | |||||
reraise_exceptions=None, | |||||
**kwargs, | |||||
): | |||||
if api_exception: | if api_exception: | ||||
self.api_exception = api_exception | self.api_exception = api_exception | ||||
if reraise_exceptions: | if reraise_exceptions: | ||||
self.reraise_exceptions = reraise_exceptions | self.reraise_exceptions = reraise_exceptions | ||||
base_url = url if url.endswith('/') else url + '/' | base_url = url if url.endswith('/') else url + '/' | ||||
self.url = base_url | self.url = base_url | ||||
self.session = requests.Session() | self.session = requests.Session() | ||||
adapter = requests.adapters.HTTPAdapter( | adapter = requests.adapters.HTTPAdapter( | ||||
max_retries=kwargs.get('max_retries', 3), | max_retries=kwargs.get('max_retries', 3), | ||||
pool_connections=kwargs.get('pool_connections', 20), | pool_connections=kwargs.get('pool_connections', 20), | ||||
pool_maxsize=kwargs.get('pool_maxsize', 100)) | pool_maxsize=kwargs.get('pool_maxsize', 100), | ||||
) | |||||
self.session.mount(self.url, adapter) | self.session.mount(self.url, adapter) | ||||
self.timeout = timeout | self.timeout = timeout | ||||
self.chunk_size = chunk_size | self.chunk_size = chunk_size | ||||
def _url(self, endpoint): | def _url(self, endpoint): | ||||
return '%s%s' % (self.url, endpoint) | return '%s%s' % (self.url, endpoint) | ||||
def raw_verb(self, verb, endpoint, **opts): | def raw_verb(self, verb, endpoint, **opts): | ||||
if 'chunk_size' in opts: | if 'chunk_size' in opts: | ||||
# if the chunk_size argument has been passed, consider the user | # if the chunk_size argument has been passed, consider the user | ||||
# also wants stream=True, otherwise, what's the point. | # also wants stream=True, otherwise, what's the point. | ||||
opts['stream'] = True | opts['stream'] = True | ||||
if self.timeout and 'timeout' not in opts: | if self.timeout and 'timeout' not in opts: | ||||
opts['timeout'] = self.timeout | opts['timeout'] = self.timeout | ||||
try: | try: | ||||
return getattr(self.session, verb)( | return getattr(self.session, verb)(self._url(endpoint), **opts) | ||||
self._url(endpoint), | |||||
**opts | |||||
) | |||||
except requests.exceptions.ConnectionError as e: | except requests.exceptions.ConnectionError as e: | ||||
raise self.api_exception(e) | raise self.api_exception(e) | ||||
def post(self, endpoint, data, **opts): | def post(self, endpoint, data, **opts): | ||||
if isinstance(data, (abc.Iterator, abc.Generator)): | if isinstance(data, (abc.Iterator, abc.Generator)): | ||||
data = (self._encode_data(x) for x in data) | data = (self._encode_data(x) for x in data) | ||||
else: | else: | ||||
data = self._encode_data(data) | data = self._encode_data(data) | ||||
chunk_size = opts.pop('chunk_size', self.chunk_size) | chunk_size = opts.pop('chunk_size', self.chunk_size) | ||||
response = self.raw_verb( | response = self.raw_verb( | ||||
'post', endpoint, data=data, | 'post', | ||||
headers={'content-type': 'application/x-msgpack', | endpoint, | ||||
'accept': 'application/x-msgpack'}, | data=data, | ||||
**opts) | headers={ | ||||
if opts.get('stream') or \ | 'content-type': 'application/x-msgpack', | ||||
response.headers.get('transfer-encoding') == 'chunked': | 'accept': 'application/x-msgpack', | ||||
}, | |||||
**opts, | |||||
) | |||||
if opts.get('stream') or response.headers.get('transfer-encoding') == 'chunked': | |||||
self.raise_for_status(response) | self.raise_for_status(response) | ||||
return response.iter_content(chunk_size) | return response.iter_content(chunk_size) | ||||
else: | else: | ||||
return self._decode_response(response) | return self._decode_response(response) | ||||
def _encode_data(self, data): | def _encode_data(self, data): | ||||
return encode_data(data, extra_encoders=self.extra_type_encoders) | return encode_data(data, extra_encoders=self.extra_type_encoders) | ||||
post_stream = post | post_stream = post | ||||
def get(self, endpoint, **opts): | def get(self, endpoint, **opts): | ||||
chunk_size = opts.pop('chunk_size', self.chunk_size) | chunk_size = opts.pop('chunk_size', self.chunk_size) | ||||
response = self.raw_verb( | response = self.raw_verb( | ||||
'get', endpoint, | 'get', endpoint, headers={'accept': 'application/x-msgpack'}, **opts | ||||
headers={'accept': 'application/x-msgpack'}, | ) | ||||
**opts) | if opts.get('stream') or response.headers.get('transfer-encoding') == 'chunked': | ||||
if opts.get('stream') or \ | |||||
response.headers.get('transfer-encoding') == 'chunked': | |||||
self.raise_for_status(response) | self.raise_for_status(response) | ||||
return response.iter_content(chunk_size) | return response.iter_content(chunk_size) | ||||
else: | else: | ||||
return self._decode_response(response) | return self._decode_response(response) | ||||
def get_stream(self, endpoint, **opts): | def get_stream(self, endpoint, **opts): | ||||
return self.get(endpoint, stream=True, **opts) | return self.get(endpoint, stream=True, **opts) | ||||
Show All 16 Lines | def raise_for_status(self, response) -> None: | ||||
if status_class == 4: | if status_class == 4: | ||||
data = self._decode_response(response, check_status=False) | data = self._decode_response(response, check_status=False) | ||||
if isinstance(data, dict): | if isinstance(data, dict): | ||||
for exc_type in self.reraise_exceptions: | for exc_type in self.reraise_exceptions: | ||||
if exc_type.__name__ == data['exception']['type']: | if exc_type.__name__ == data['exception']['type']: | ||||
exception = exc_type(*data['exception']['args']) | exception = exc_type(*data['exception']['args']) | ||||
break | break | ||||
else: | else: | ||||
exception = RemoteException(payload=data['exception'], | exception = RemoteException( | ||||
response=response) | payload=data['exception'], response=response | ||||
) | |||||
else: | else: | ||||
exception = pickle.loads(data) | exception = pickle.loads(data) | ||||
elif status_class == 5: | elif status_class == 5: | ||||
data = self._decode_response(response, check_status=False) | data = self._decode_response(response, check_status=False) | ||||
if 'exception_pickled' in data: | if 'exception_pickled' in data: | ||||
exception = pickle.loads(data['exception_pickled']) | exception = pickle.loads(data['exception_pickled']) | ||||
else: | else: | ||||
exception = RemoteException(payload=data['exception'], | exception = RemoteException( | ||||
response=response) | payload=data['exception'], response=response | ||||
) | |||||
except (TypeError, pickle.UnpicklingError): | except (TypeError, pickle.UnpicklingError): | ||||
raise RemoteException(payload=data, response=response) | raise RemoteException(payload=data, response=response) | ||||
if exception: | if exception: | ||||
raise exception from None | raise exception from None | ||||
if status_class != 2: | if status_class != 2: | ||||
raise RemoteException( | raise RemoteException( | ||||
payload=f'API HTTP error: {status_code} {response.content}', | payload=f'API HTTP error: {status_code} {response.content}', | ||||
response=response) | response=response, | ||||
) | |||||
def _decode_response(self, response, check_status=True): | def _decode_response(self, response, check_status=True): | ||||
if check_status: | if check_status: | ||||
self.raise_for_status(response) | self.raise_for_status(response) | ||||
return decode_response( | return decode_response(response, extra_decoders=self.extra_type_decoders) | ||||
response, extra_decoders=self.extra_type_decoders) | |||||
def __repr__(self): | def __repr__(self): | ||||
return '<{} url={}>'.format(self.__class__.__name__, self.url) | return '<{} url={}>'.format(self.__class__.__name__, self.url) | ||||
class BytesRequest(Request): | class BytesRequest(Request): | ||||
"""Request with proper escaping of arbitrary byte sequences.""" | """Request with proper escaping of arbitrary byte sequences.""" | ||||
encoding = 'utf-8' | encoding = 'utf-8' | ||||
encoding_errors = 'surrogateescape' | encoding_errors = 'surrogateescape' | ||||
ENCODERS: Dict[str, Callable[[Any], Union[bytes, str]]] = { | ENCODERS: Dict[str, Callable[[Any], Union[bytes, str]]] = { | ||||
'application/x-msgpack': msgpack_dumps, | 'application/x-msgpack': msgpack_dumps, | ||||
'application/json': json_dumps, | 'application/json': json_dumps, | ||||
} | } | ||||
def encode_data_server( | def encode_data_server( | ||||
data, content_type='application/x-msgpack', extra_type_encoders=None): | data, content_type='application/x-msgpack', extra_type_encoders=None | ||||
encoded_data = ENCODERS[content_type]( | ): | ||||
data, extra_encoders=extra_type_encoders) | encoded_data = ENCODERS[content_type](data, extra_encoders=extra_type_encoders) | ||||
return Response( | return Response(encoded_data, mimetype=content_type) | ||||
encoded_data, | |||||
mimetype=content_type, | |||||
) | |||||
def decode_request(request, extra_decoders=None): | def decode_request(request, extra_decoders=None): | ||||
content_type = request.mimetype | content_type = request.mimetype | ||||
data = request.get_data() | data = request.get_data() | ||||
if not data: | if not data: | ||||
return {} | return {} | ||||
if content_type == 'application/x-msgpack': | if content_type == 'application/x-msgpack': | ||||
r = msgpack_loads(data, extra_decoders=extra_decoders) | r = msgpack_loads(data, extra_decoders=extra_decoders) | ||||
elif content_type == 'application/json': | elif content_type == 'application/json': | ||||
# XXX this .decode() is needed for py35. | # XXX this .decode() is needed for py35. | ||||
# Should not be needed any more with py37 | # Should not be needed any more with py37 | ||||
r = json_loads(data.decode('utf-8'), extra_decoders=extra_decoders) | r = json_loads(data.decode('utf-8'), extra_decoders=extra_decoders) | ||||
else: | else: | ||||
raise ValueError('Wrong content type `%s` for API request' | raise ValueError('Wrong content type `%s` for API request' % content_type) | ||||
% content_type) | |||||
return r | return r | ||||
def error_handler(exception, encoder, status_code=500): | def error_handler(exception, encoder, status_code=500): | ||||
logging.exception(exception) | logging.exception(exception) | ||||
response = encoder(exception_to_dict(exception)) | response = encoder(exception_to_dict(exception)) | ||||
if isinstance(exception, HTTPException): | if isinstance(exception, HTTPException): | ||||
Show All 12 Lines | class RPCServerApp(Flask): | ||||
:param Any backend_class: | :param Any backend_class: | ||||
The class of the backend, which will be analyzed to look | The class of the backend, which will be analyzed to look | ||||
for API endpoints. | for API endpoints. | ||||
:param Optional[Callable[[], backend_class]] backend_factory: | :param Optional[Callable[[], backend_class]] backend_factory: | ||||
A function with no argument that returns an instance of | A function with no argument that returns an instance of | ||||
`backend_class`. If unset, defaults to calling `backend_class` | `backend_class`. If unset, defaults to calling `backend_class` | ||||
constructor directly. | constructor directly. | ||||
""" | """ | ||||
request_class = BytesRequest | request_class = BytesRequest | ||||
extra_type_encoders: List[Tuple[type, str, Callable]] = [] | extra_type_encoders: List[Tuple[type, str, Callable]] = [] | ||||
"""Value of `extra_encoders` passed to `json_dumps` or `msgpack_dumps` | """Value of `extra_encoders` passed to `json_dumps` or `msgpack_dumps` | ||||
to be able to serialize more object types.""" | to be able to serialize more object types.""" | ||||
extra_type_decoders: Dict[str, Callable] = {} | extra_type_decoders: Dict[str, Callable] = {} | ||||
"""Value of `extra_decoders` passed to `json_loads` or `msgpack_loads` | """Value of `extra_decoders` passed to `json_loads` or `msgpack_loads` | ||||
to be able to deserialize more object types.""" | to be able to deserialize more object types.""" | ||||
def __init__(self, *args, backend_class=None, backend_factory=None, | def __init__(self, *args, backend_class=None, backend_factory=None, **kwargs): | ||||
**kwargs): | |||||
super().__init__(*args, **kwargs) | super().__init__(*args, **kwargs) | ||||
self.backend_class = backend_class | self.backend_class = backend_class | ||||
if backend_class is not None: | if backend_class is not None: | ||||
if backend_factory is None: | if backend_factory is None: | ||||
backend_factory = backend_class | backend_factory = backend_class | ||||
for (meth_name, meth) in backend_class.__dict__.items(): | for (meth_name, meth) in backend_class.__dict__.items(): | ||||
if hasattr(meth, '_endpoint_path'): | if hasattr(meth, '_endpoint_path'): | ||||
self.__add_endpoint(meth_name, meth, backend_factory) | self.__add_endpoint(meth_name, meth, backend_factory) | ||||
def __add_endpoint(self, meth_name, meth, backend_factory): | def __add_endpoint(self, meth_name, meth, backend_factory): | ||||
from flask import request | from flask import request | ||||
@self.route('/'+meth._endpoint_path, methods=['POST']) | @self.route('/' + meth._endpoint_path, methods=['POST']) | ||||
@negotiate(MsgpackFormatter, extra_encoders=self.extra_type_encoders) | @negotiate(MsgpackFormatter, extra_encoders=self.extra_type_encoders) | ||||
@negotiate(JSONFormatter, extra_encoders=self.extra_type_encoders) | @negotiate(JSONFormatter, extra_encoders=self.extra_type_encoders) | ||||
@functools.wraps(meth) # Copy signature and doc | @functools.wraps(meth) # Copy signature and doc | ||||
def _f(): | def _f(): | ||||
# Call the actual code | # Call the actual code | ||||
obj_meth = getattr(backend_factory(), meth_name) | obj_meth = getattr(backend_factory(), meth_name) | ||||
kw = decode_request( | kw = decode_request(request, extra_decoders=self.extra_type_decoders) | ||||
request, extra_decoders=self.extra_type_decoders) | |||||
return obj_meth(**kw) | return obj_meth(**kw) |