diff --git a/requirements.txt b/requirements.txt index 50a9d83..0b91831 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,11 +1,11 @@ arrow aiohttp msgpack-python psycopg2 python-dateutil vcversioner PyYAML requests Flask systemd-python -negotiate +decorator diff --git a/swh/core/api/__init__.py b/swh/core/api/__init__.py index 015a56b..a62316a 100644 --- a/swh/core/api/__init__.py +++ b/swh/core/api/__init__.py @@ -1,284 +1,309 @@ # Copyright (C) 2015-2017 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import collections import functools import inspect import json import logging import pickle import requests import datetime -from flask import Flask, Request, Response +from flask import Flask, Request, Response, request, abort from .serializers import (decode_response, encode_data_client as encode_data, msgpack_dumps, msgpack_loads, SWHJSONDecoder) -from negotiate.flask import Formatter +from .negotiate import (Formatter as FormatterBase, + Negotiator as NegotiatorBase, + negotiate as _negotiate) + logger = logging.getLogger(__name__) +# support for content negotation + +class Negotiator(NegotiatorBase): + def best_mimetype(self): + return request.accept_mimetypes.best_match( + self.accept_mimetypes, 'text/html') + + def _abort(self, status_code, err=None): + return abort(status_code, err) + + +def negotiate(formatter_cls, *args, **kwargs): + return _negotiate(Negotiator, formatter_cls, *args, **kwargs) + + +class Formatter(FormatterBase): + def _make_response(self, body, content_type): + return Response(body, content_type=content_type) + + class SWHJSONEncoder(json.JSONEncoder): def default(self, obj): if isinstance(obj, (datetime.datetime, datetime.date)): return obj.isoformat() if isinstance(obj, datetime.timedelta): return str(obj) # Let the base class default method raise the TypeError return super().default(obj) class JSONFormatter(Formatter): format = 'json' mimetypes = ['application/json'] def render(self, obj): return json.dumps(obj, cls=SWHJSONEncoder) class MsgpackFormatter(Formatter): format = 'msgpack' mimetypes = ['application/x-msgpack'] def render(self, obj): return msgpack_dumps(obj) +# base API classes + class RemoteException(Exception): pass def remote_api_endpoint(path): def dec(f): f._endpoint_path = path return f return dec class MetaSWHRemoteAPI(type): """Metaclass for SWHRemoteAPI, which adds a method for each endpoint of the database it is designed to access. See for example :class:`swh.indexer.storage.api.client.RemoteStorage`""" def __new__(cls, name, bases, attributes): # For each method wrapped with @remote_api_endpoint in an API backend # (eg. :class:`swh.indexer.storage.IndexerStorage`), add a new # method in RemoteStorage, with the same documentation. # # Note that, despite the usage of decorator magic (eg. functools.wrap), # this never actually calls an IndexerStorage method. backend_class = attributes.get('backend_class', None) for base in bases: if backend_class is not None: break backend_class = getattr(base, 'backend_class', None) if backend_class: for (meth_name, meth) in backend_class.__dict__.items(): if hasattr(meth, '_endpoint_path'): cls.__add_endpoint(meth_name, meth, attributes) return super().__new__(cls, name, bases, attributes) @staticmethod def __add_endpoint(meth_name, meth, attributes): wrapped_meth = inspect.unwrap(meth) @functools.wraps(meth) # Copy signature and doc def meth_(*args, **kwargs): # Match arguments and parameters post_data = inspect.getcallargs( wrapped_meth, *args, **kwargs) # Remove arguments that should not be passed self = post_data.pop('self') post_data.pop('cur', None) post_data.pop('db', None) # Send the request. return self.post(meth._endpoint_path, post_data) attributes[meth_name] = meth_ class SWHRemoteAPI(metaclass=MetaSWHRemoteAPI): """Proxy to an internal SWH API """ backend_class = None """For each method of `backend_class` decorated with :func:`remote_api_endpoint`, a method with the same prototype and docstring will be added to this class. Calls to this new method will be translated into HTTP requests to a remote server. This backend class will never be instantiated, it only serves as a template.""" def __init__(self, api_exception, url, timeout=None): super().__init__() self.api_exception = api_exception base_url = url if url.endswith('/') else url + '/' self.url = base_url self.session = requests.Session() self.timeout = timeout def _url(self, endpoint): return '%s%s' % (self.url, endpoint) def raw_post(self, endpoint, data, **opts): if self.timeout and 'timeout' not in opts: opts['timeout'] = self.timeout try: return self.session.post( self._url(endpoint), data=data, **opts ) except requests.exceptions.ConnectionError as e: raise self.api_exception(e) def raw_get(self, endpoint, params=None, **opts): if self.timeout and 'timeout' not in opts: opts['timeout'] = self.timeout try: return self.session.get( self._url(endpoint), params=params, **opts ) except requests.exceptions.ConnectionError as e: raise self.api_exception(e) def post(self, endpoint, data, params=None): data = encode_data(data) response = self.raw_post( endpoint, data, params=params, headers={'content-type': 'application/x-msgpack', 'accept': 'application/x-msgpack'}) return self._decode_response(response) def get(self, endpoint, params=None): response = self.raw_get( endpoint, params=params, headers={'accept': 'application/x-msgpack'}) return self._decode_response(response) def post_stream(self, endpoint, data, params=None): if not isinstance(data, collections.Iterable): raise ValueError("`data` must be Iterable") response = self.raw_post( endpoint, data, params=params, headers={'accept': 'application/x-msgpack'}) return self._decode_response(response) def get_stream(self, endpoint, params=None, chunk_size=4096): response = self.raw_get(endpoint, params=params, stream=True, headers={'accept': 'application/x-msgpack'}) return response.iter_content(chunk_size) def _decode_response(self, response): if response.status_code == 404: return None if response.status_code == 500: data = decode_response(response) if 'exception_pickled' in data: raise pickle.loads(data['exception_pickled']) else: raise RemoteException(data['exception']) # XXX: this breaks language-independence and should be # replaced by proper unserialization if response.status_code == 400: raise pickle.loads(decode_response(response)) elif response.status_code != 200: raise RemoteException( "Unexpected status code for API request: %s (%s)" % ( response.status_code, response.content, ) ) return decode_response(response) class BytesRequest(Request): """Request with proper escaping of arbitrary byte sequences.""" encoding = 'utf-8' encoding_errors = 'surrogateescape' ENCODERS = { 'application/x-msgpack': msgpack_dumps, 'application/json': json.dumps, } def encode_data_server(data, content_type='application/x-msgpack'): encoded_data = ENCODERS[content_type](data) return Response( encoded_data, mimetype=content_type, ) def decode_request(request): content_type = request.mimetype data = request.get_data() if not data: return {} if content_type == 'application/x-msgpack': r = msgpack_loads(data) elif content_type == 'application/json': r = json.loads(data, cls=SWHJSONDecoder) else: raise ValueError('Wrong content type `%s` for API request' % content_type) return r def error_handler(exception, encoder): # XXX: this breaks language-independence and should be # replaced by proper serialization of errors logging.exception(exception) response = encoder(pickle.dumps(exception)) response.status_code = 400 return response class SWHServerAPIApp(Flask): """For each endpoint of the given `backend_class`, tells app.route to call a function that decodes the request and sends it to the backend object provided by the factory. :param Any backend_class: The class of the backend, which will be analyzed to look for API endpoints. :param Callable[[], backend_class] backend_factory: A function with no argument that returns an instance of `backend_class`.""" request_class = BytesRequest def __init__(self, *args, backend_class=None, backend_factory=None, **kwargs): super().__init__(*args, **kwargs) if backend_class is not None: if backend_factory is None: raise TypeError('Missing argument backend_factory') for (meth_name, meth) in backend_class.__dict__.items(): if hasattr(meth, '_endpoint_path'): self.__add_endpoint(meth_name, meth, backend_factory) def __add_endpoint(self, meth_name, meth, backend_factory): from flask import request @self.route('/'+meth._endpoint_path, methods=['POST']) @functools.wraps(meth) # Copy signature and doc def _f(): # Call the actual code obj_meth = getattr(backend_factory(), meth_name) return encode_data_server(obj_meth(**decode_request(request))) diff --git a/swh/core/api/negotiate.py b/swh/core/api/negotiate.py new file mode 100644 index 0000000..3fce1b6 --- /dev/null +++ b/swh/core/api/negotiate.py @@ -0,0 +1,152 @@ +# This code is a partial and adapted copy of +# https://github.com/nickstenning/negotiate +# +# Copyright 2012-2013 Nick Stenning +# 2019 The Software Heritage developers +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# + +from collections import defaultdict +from decorator import decorator + +from inspect import getcallargs + + +class FormatterNotFound(Exception): + pass + + +class Formatter: + format = None + mimetypes = [] + + def __init__(self, request_mimetype=None): + if request_mimetype is None or request_mimetype not in self.mimetypes: + try: + self.response_mimetype = self.mimetypes[0] + except IndexError: + raise NotImplementedError( + "%s.mimetypes should be a non-empty list" % + self.__class__.__name__) + else: + self.response_mimetype = request_mimetype + + def configure(self): + pass + + def render(self, obj): + raise NotImplementedError( + "render() should be implemented by Formatter subclasses") + + def __call__(self, obj): + return self._make_response( + self.render(obj), content_type=self.response_mimetype) + + def _make_response(self, body, content_type): + raise NotImplementedError( + "_make_response() should be implemented by " + "framework-specific subclasses of Formatter" + ) + + +class Negotiator: + + def __init__(self, func): + self.func = func + self._formatters = [] + self._formatters_by_format = defaultdict(list) + self._formatters_by_mimetype = defaultdict(list) + + def __call__(self, *args, **kwargs): + result = self.func(*args, **kwargs) + format = getcallargs(self.func, *args, **kwargs).get('format') + mimetype = self.best_mimetype() + + try: + formatter = self.get_formatter(format, mimetype) + except FormatterNotFound as e: + return self._abort(404, str(e)) + + return formatter(result) + + def register_formatter(self, formatter, *args, **kwargs): + self._formatters.append(formatter) + self._formatters_by_format[formatter.format].append( + (formatter, args, kwargs)) + for mimetype in formatter.mimetypes: + self._formatters_by_mimetype[mimetype].append( + (formatter, args, kwargs)) + + def get_formatter(self, format=None, mimetype=None): + if format is None and mimetype is None: + raise TypeError( + "get_formatter expects one of the 'format' or 'mimetype' " + "kwargs to be set") + + if format is not None: + try: + # the first added will be the most specific + formatter_cls, args, kwargs = ( + self._formatters_by_format[format][0]) + except IndexError: + raise FormatterNotFound( + "Formatter for format '%s' not found!" % format) + elif mimetype is not None: + try: + # the first added will be the most specific + formatter_cls, args, kwargs = ( + self._formatters_by_mimetype[mimetype][0]) + except IndexError: + raise FormatterNotFound( + "Formatter for mimetype '%s' not found!" % mimetype) + + formatter = formatter_cls(request_mimetype=mimetype) + formatter.configure(*args, **kwargs) + return formatter + + @property + def accept_mimetypes(self): + return [m for f in self._formatters for m in f.mimetypes] + + def best_mimetype(self): + raise NotImplementedError( + "best_mimetype() should be implemented in " + "framework-specific subclasses of Negotiator" + ) + + def _abort(self, status_code, err=None): + raise NotImplementedError( + "_abort() should be implemented in framework-specific " + "subclasses of Negotiator" + ) + + +def negotiate(negotiator_cls, formatter_cls, *args, **kwargs): + def _negotiate(f, *args, **kwargs): + return f.negotiator(*args, **kwargs) + + def decorate(f): + if not hasattr(f, 'negotiator'): + f.negotiator = negotiator_cls(f) + + f.negotiator.register_formatter(formatter_cls, *args, **kwargs) + return decorator(_negotiate, f) + + return decorate