Changeset View
Changeset View
Standalone View
Standalone View
swh/core/api/__init__.py
# Copyright (C) 2015-2017 The Software Heritage developers | # Copyright (C) 2015-2017 The Software Heritage developers | ||||
# See the AUTHORS file at the top-level directory of this distribution | # See the AUTHORS file at the top-level directory of this distribution | ||||
# License: GNU General Public License version 3, or any later version | # License: GNU General Public License version 3, or any later version | ||||
# See top-level LICENSE file for more information | # See top-level LICENSE file for more information | ||||
from collections import abc | from collections import abc | ||||
import datetime | |||||
import functools | import functools | ||||
import inspect | import inspect | ||||
import json | |||||
import logging | import logging | ||||
import pickle | import pickle | ||||
import requests | import requests | ||||
import traceback | import traceback | ||||
from typing import Any, ClassVar, List, Optional, Type | from typing import ( | ||||
Any, Callable, ClassVar, Dict, List, Optional, Tuple, Type, Union, | |||||
) | |||||
from flask import Flask, Request, Response, request, abort | from flask import Flask, Request, Response, request, abort | ||||
from werkzeug.exceptions import HTTPException | from werkzeug.exceptions import HTTPException | ||||
from .serializers import (decode_response, | from .serializers import (decode_response, | ||||
encode_data_client as encode_data, | encode_data_client as encode_data, | ||||
msgpack_dumps, msgpack_loads, SWHJSONDecoder) | msgpack_dumps, msgpack_loads, | ||||
json_dumps, json_loads) | |||||
from .negotiation import (Formatter as FormatterBase, | from .negotiation import (Formatter as FormatterBase, | ||||
Negotiator as NegotiatorBase, | Negotiator as NegotiatorBase, | ||||
negotiate as _negotiate) | negotiate as _negotiate) | ||||
logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||
Show All 12 Lines | |||||
def negotiate(formatter_cls, *args, **kwargs):
    """Call the generic ``_negotiate`` helper with this module's Negotiator.

    ``formatter_cls`` and every extra positional/keyword argument are
    forwarded unchanged; the helper's return value is returned as is.
    """
    negotiator_cls = Negotiator
    return _negotiate(negotiator_cls, formatter_cls, *args, **kwargs)
class Formatter(FormatterBase):
    """Common base of this module's formatters, producing flask responses."""

    def _make_response(self, body, content_type):
        """Wrap ``body`` in a ``Response`` declared as ``content_type``."""
        return Response(body, content_type=content_type)

    def configure(self, extra_encoders=None):
        """Store the extra type encoders used by ``render`` in subclasses.

        :param extra_encoders: value forwarded as ``extra_encoders`` to the
            serializer by the subclasses' ``render``. Defaults to ``None``
            so a formatter registered without the keyword (e.g.
            ``negotiate(JSONFormatter)``) can still be configured.
            NOTE(review): presumably called by the negotiation machinery
            with the keyword arguments given to ``negotiate`` -- confirm
            against ``.negotiation``.
        """
        self.extra_encoders = extra_encoders
def default(self, obj): | |||||
if isinstance(obj, (datetime.datetime, datetime.date)): | |||||
return obj.isoformat() | |||||
if isinstance(obj, datetime.timedelta): | |||||
return str(obj) | |||||
# Let the base class default method raise the TypeError | |||||
return super().default(obj) | |||||
class JSONFormatter(Formatter):
    """Formatter registered for the ``application/json`` mimetype."""

    format = 'json'
    mimetypes = ['application/json']

    def render(self, obj):
        """Serialize ``obj`` with ``json_dumps`` and the stored encoders."""
        encoders = self.extra_encoders
        return json_dumps(obj, extra_encoders=encoders)
class MsgpackFormatter(Formatter):
    """Formatter registered for the ``application/x-msgpack`` mimetype."""

    format = 'msgpack'
    mimetypes = ['application/x-msgpack']

    def render(self, obj):
        """Serialize ``obj`` with ``msgpack_dumps`` and the stored encoders."""
        encoders = self.extra_encoders
        return msgpack_dumps(obj, extra_encoders=encoders)
# base API classes | # base API classes | ||||
class RemoteException(Exception): | class RemoteException(Exception): | ||||
"""raised when remote returned an out-of-band failure notification, e.g., as a | """raised when remote returned an out-of-band failure notification, e.g., as a | ||||
HTTP status code or serialized exception | HTTP status code or serialized exception | ||||
▲ Show 20 Lines • Show All 94 Lines • ▼ Show 20 Lines | class RPCClient(metaclass=MetaRPCClient): | ||||
"""The exception class to raise in case of communication error with | """The exception class to raise in case of communication error with | ||||
the server.""" | the server.""" | ||||
reraise_exceptions: ClassVar[List[Type[Exception]]] = []
"""On server errors, if any of the exception classes in this list
has the same name as the error name, then the exception will
be instantiated and raised instead of a generic RemoteException."""

# Declared as ClassVar, like reraise_exceptions above: these are class-level
# defaults that subclasses are expected to override with their own values.
extra_type_encoders: ClassVar[List[Tuple[type, str, Callable]]] = []
"""Value of `extra_encoders` passed to `json_dumps` or `msgpack_dumps`
to be able to serialize more object types."""

extra_type_decoders: ClassVar[Dict[str, Callable]] = {}
"""Value of `extra_decoders` passed to `json_loads` or `msgpack_loads`
to be able to deserialize more object types."""
def __init__(self, url, api_exception=None, | def __init__(self, url, api_exception=None, | ||||
timeout=None, chunk_size=4096, | timeout=None, chunk_size=4096, | ||||
reraise_exceptions=None, **kwargs): | reraise_exceptions=None, **kwargs): | ||||
if api_exception: | if api_exception: | ||||
self.api_exception = api_exception | self.api_exception = api_exception | ||||
if reraise_exceptions: | if reraise_exceptions: | ||||
self.reraise_exceptions = reraise_exceptions | self.reraise_exceptions = reraise_exceptions | ||||
base_url = url if url.endswith('/') else url + '/' | base_url = url if url.endswith('/') else url + '/' | ||||
Show All 23 Lines | def raw_verb(self, verb, endpoint, **opts): | ||||
self._url(endpoint), | self._url(endpoint), | ||||
**opts | **opts | ||||
) | ) | ||||
except requests.exceptions.ConnectionError as e: | except requests.exceptions.ConnectionError as e: | ||||
raise self.api_exception(e) | raise self.api_exception(e) | ||||
def post(self, endpoint, data, **opts):
    """Send ``data`` to ``endpoint`` with a POST and return the reply.

    An iterator is encoded lazily, item by item; anything else is encoded
    in one go. ``chunk_size`` may be given in ``opts`` (default:
    ``self.chunk_size``); the remaining options go to ``raw_verb``.

    Returns an iterator over the raw content when ``stream`` was requested
    or the reply is chunked, otherwise the decoded response.
    """
    # abc.Generator is a subclass of abc.Iterator: one check covers both.
    if isinstance(data, abc.Iterator):
        payload = (self._encode_data(item) for item in data)
    else:
        payload = self._encode_data(data)

    chunk_size = opts.pop('chunk_size', self.chunk_size)
    msgpack_headers = {
        'content-type': 'application/x-msgpack',
        'accept': 'application/x-msgpack',
    }
    response = self.raw_verb(
        'post', endpoint, data=payload, headers=msgpack_headers, **opts)

    streamed = (
        opts.get('stream')
        or response.headers.get('transfer-encoding') == 'chunked')
    if not streamed:
        return self._decode_response(response)
    self.raise_for_status(response)
    return response.iter_content(chunk_size)


def _encode_data(self, data):
    """Encode ``data`` for sending, using ``self.extra_type_encoders``."""
    encoders = self.extra_type_encoders
    return encode_data(data, extra_encoders=encoders)


post_stream = post
def get(self, endpoint, **opts): | def get(self, endpoint, **opts): | ||||
chunk_size = opts.pop('chunk_size', self.chunk_size) | chunk_size = opts.pop('chunk_size', self.chunk_size) | ||||
response = self.raw_verb( | response = self.raw_verb( | ||||
'get', endpoint, | 'get', endpoint, | ||||
headers={'accept': 'application/x-msgpack'}, | headers={'accept': 'application/x-msgpack'}, | ||||
**opts) | **opts) | ||||
Show All 19 Lines | def raise_for_status(self, response) -> None: | ||||
raise RemoteException(payload='404 not found', response=response) | raise RemoteException(payload='404 not found', response=response) | ||||
exception = None | exception = None | ||||
# TODO: only old servers send pickled error; stop trying to unpickle | # TODO: only old servers send pickled error; stop trying to unpickle | ||||
# after they are all upgraded | # after they are all upgraded | ||||
try: | try: | ||||
if status_class == 4: | if status_class == 4: | ||||
data = decode_response(response) | data = self._decode_response(response, check_status=False) | ||||
if isinstance(data, dict): | if isinstance(data, dict): | ||||
for exc_type in self.reraise_exceptions: | for exc_type in self.reraise_exceptions: | ||||
if exc_type.__name__ == data['exception']['type']: | if exc_type.__name__ == data['exception']['type']: | ||||
exception = exc_type(*data['exception']['args']) | exception = exc_type(*data['exception']['args']) | ||||
break | break | ||||
else: | else: | ||||
exception = RemoteException(payload=data['exception'], | exception = RemoteException(payload=data['exception'], | ||||
response=response) | response=response) | ||||
else: | else: | ||||
exception = pickle.loads(data) | exception = pickle.loads(data) | ||||
elif status_class == 5: | elif status_class == 5: | ||||
data = decode_response(response) | data = self._decode_response(response, check_status=False) | ||||
if 'exception_pickled' in data: | if 'exception_pickled' in data: | ||||
exception = pickle.loads(data['exception_pickled']) | exception = pickle.loads(data['exception_pickled']) | ||||
else: | else: | ||||
exception = RemoteException(payload=data['exception'], | exception = RemoteException(payload=data['exception'], | ||||
response=response) | response=response) | ||||
except (TypeError, pickle.UnpicklingError): | except (TypeError, pickle.UnpicklingError): | ||||
raise RemoteException(payload=data, response=response) | raise RemoteException(payload=data, response=response) | ||||
if exception: | if exception: | ||||
raise exception from None | raise exception from None | ||||
if status_class != 2: | if status_class != 2: | ||||
raise RemoteException( | raise RemoteException( | ||||
payload=f'API HTTP error: {status_code} {response.content}', | payload=f'API HTTP error: {status_code} {response.content}', | ||||
response=response) | response=response) | ||||
def _decode_response(self, response, check_status=True):
    """Decode ``response`` using ``self.extra_type_decoders``.

    Unless ``check_status`` is false, ``raise_for_status`` is run on the
    response first.
    """
    if check_status:
        self.raise_for_status(response)
    decoders = self.extra_type_decoders
    return decode_response(response, extra_decoders=decoders)
def __repr__(self):
    """Return ``<ClassName url=...>`` for this object."""
    class_name = self.__class__.__name__
    return f'<{class_name} url={self.url}>'
class BytesRequest(Request):
    """Request with proper escaping of arbitrary byte sequences."""

    # Text is handled as UTF-8, with the 'surrogateescape' error handler.
    encoding = 'utf-8'
    encoding_errors = 'surrogateescape'
# Serializer for each supported response mimetype. The callables are invoked
# as ``encoder(data, extra_encoders=...)``, hence the open ``...`` signature
# (``Callable[[Any], ...]`` would reject the keyword argument).
ENCODERS: Dict[str, Callable[..., Union[bytes, str]]] = {
    'application/x-msgpack': msgpack_dumps,
    'application/json': json_dumps,
}
def encode_data_server(
        data, content_type='application/x-msgpack',
        extra_type_encoders=None):
    """Serialize ``data`` for ``content_type`` and wrap it in a Response.

    The serializer is looked up in ``ENCODERS`` (a ``KeyError`` escapes for
    a mimetype it does not list) and receives ``extra_type_encoders`` as
    its ``extra_encoders`` argument.
    """
    serialize = ENCODERS[content_type]
    body = serialize(data, extra_encoders=extra_type_encoders)
    return Response(body, mimetype=content_type)
def decode_request(request, extra_decoders=None):
    """Deserialize the body of ``request`` according to its mimetype.

    Returns ``{}`` for an empty body, whatever the mimetype. Otherwise the
    body is parsed with ``msgpack_loads`` or ``json_loads`` (both given
    ``extra_decoders``).

    :raises ValueError: for a non-empty body of any other mimetype.
    """
    content_type = request.mimetype
    payload = request.get_data()
    if not payload:
        return {}

    if content_type == 'application/x-msgpack':
        return msgpack_loads(payload, extra_decoders=extra_decoders)
    if content_type == 'application/json':
        # XXX the explicit .decode() is needed for py35; it should not be
        # needed any more with py37.
        text = payload.decode('utf-8')
        return json_loads(text, extra_decoders=extra_decoders)
    raise ValueError(
        'Wrong content type `%s` for API request' % content_type)
def error_handler(exception, encoder, status_code=500): | def error_handler(exception, encoder, status_code=500): | ||||
Show All 26 Lines | :param Any backend_class: | ||||
for API endpoints. | for API endpoints. | ||||
:param Optional[Callable[[], backend_class]] backend_factory: | :param Optional[Callable[[], backend_class]] backend_factory: | ||||
A function with no argument that returns an instance of | A function with no argument that returns an instance of | ||||
`backend_class`. If unset, defaults to calling `backend_class` | `backend_class`. If unset, defaults to calling `backend_class` | ||||
constructor directly. | constructor directly. | ||||
""" | """ | ||||
request_class = BytesRequest

# Both attributes are class-level configuration (hence ClassVar): they are
# read through ``self`` while the endpoints are being registered, so they
# are meant to be overridden on subclasses.
extra_type_encoders: ClassVar[List[Tuple[type, str, Callable]]] = []
"""Value of `extra_encoders` passed to `json_dumps` or `msgpack_dumps`
to be able to serialize more object types."""

extra_type_decoders: ClassVar[Dict[str, Callable]] = {}
"""Value of `extra_decoders` passed to `json_loads` or `msgpack_loads`
to be able to deserialize more object types."""
def __init__(self, *args, backend_class=None, backend_factory=None,
             **kwargs):
    """Initialize the app and register the endpoints of ``backend_class``.

    Remaining positional and keyword arguments go to the parent
    constructor. When ``backend_class`` is given, every entry of its
    ``__dict__`` carrying an ``_endpoint_path`` attribute is added as an
    endpoint; ``backend_factory`` falls back to ``backend_class`` itself.
    """
    super().__init__(*args, **kwargs)

    self.backend_class = backend_class
    if backend_class is None:
        return

    # Without an explicit factory, the backend class itself is called.
    factory = backend_class if backend_factory is None else backend_factory
    for meth_name, meth in backend_class.__dict__.items():
        if not hasattr(meth, '_endpoint_path'):
            continue
        self.__add_endpoint(meth_name, meth, factory)
def __add_endpoint(self, meth_name, meth, backend_factory):
    """Register a POST route at ``/<meth._endpoint_path>`` for ``meth``.

    The view decodes the request body into keyword arguments, calls the
    method named ``meth_name`` on a fresh ``backend_factory()`` object and
    returns its result; both formatters are registered on it via
    ``negotiate``.
    """
    from flask import request

    route_path = '/' + meth._endpoint_path
    encoders = self.extra_type_encoders

    @self.route(route_path, methods=['POST'])
    @negotiate(MsgpackFormatter, extra_encoders=encoders)
    @negotiate(JSONFormatter, extra_encoders=encoders)
    @functools.wraps(meth)  # Copy signature and doc
    def _endpoint():
        # Call the actual code on a newly built backend object
        backend = backend_factory()
        bound_method = getattr(backend, meth_name)
        call_kwargs = decode_request(
            request, extra_decoders=self.extra_type_decoders)
        return bound_method(**call_kwargs)