diff --git a/requirements.txt b/requirements.txt --- a/requirements.txt +++ b/requirements.txt @@ -8,4 +8,4 @@ requests Flask systemd-python -negotiate +decorator diff --git a/swh/core/api.py b/swh/core/api/__init__.py rename from swh/core/api.py rename to swh/core/api/__init__.py --- a/swh/core/api.py +++ b/swh/core/api/__init__.py @@ -12,16 +12,39 @@ import requests import datetime -from flask import Flask, Request, Response +from flask import Flask, Request, Response, request, abort from .serializers import (decode_response, encode_data_client as encode_data, msgpack_dumps, msgpack_loads, SWHJSONDecoder) -from negotiate.flask import Formatter +from .negotiate import (Formatter as FormatterBase, + Negotiator as NegotiatorBase, + negotiate as _negotiate) + logger = logging.getLogger(__name__) +# support for content negotiation + +class Negotiator(NegotiatorBase): + def best_mimetype(self): + return request.accept_mimetypes.best_match( + self.accept_mimetypes, 'text/html') + + def _abort(self, status_code, err=None): + return abort(status_code, err) + + +def negotiate(formatter_cls, *args, **kwargs): + return _negotiate(Negotiator, formatter_cls, *args, **kwargs) + + +class Formatter(FormatterBase): + def _make_response(self, body, content_type): + return Response(body, content_type=content_type) + + class SWHJSONEncoder(json.JSONEncoder): def default(self, obj): if isinstance(obj, (datetime.datetime, datetime.date)): @@ -48,6 +71,8 @@ return msgpack_dumps(obj) +# base API classes + class RemoteException(Exception): pass diff --git a/swh/core/api_async.py b/swh/core/api/asynchronous.py copy from swh/core/api_async.py copy to swh/core/api/asynchronous.py diff --git a/swh/core/api/negotiate.py b/swh/core/api/negotiate.py new file mode 100644 --- /dev/null +++ b/swh/core/api/negotiate.py @@ -0,0 +1,152 @@ +# This code is a partial and adapted copy of +# https://github.com/nickstenning/negotiate +# +# Copyright 2012-2013 Nick Stenning +# 2019 Software Heritage +# 
Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+# + +from collections import defaultdict +from decorator import decorator + +from inspect import getcallargs + + +class FormatterNotFound(Exception): + pass + + +class Formatter: + format = None + mimetypes = [] + + def __init__(self, request_mimetype=None): + if request_mimetype is None or request_mimetype not in self.mimetypes: + try: + self.response_mimetype = self.mimetypes[0] + except IndexError: + raise NotImplementedError( + "%s.mimetypes should be a non-empty list" % + self.__class__.__name__) + else: + self.response_mimetype = request_mimetype + + def configure(self): + pass + + def render(self, obj): + raise NotImplementedError( + "render() should be implemented by Formatter subclasses") + + def __call__(self, obj): + return self._make_response( + self.render(obj), content_type=self.response_mimetype) + + def _make_response(self, body, content_type): + raise NotImplementedError( + "_make_response() should be implemented by " + "framework-specific subclasses of Formatter" + ) + + +class Negotiator: + + def __init__(self, func): + self.func = func + self._formatters = [] + self._formatters_by_format = defaultdict(list) + self._formatters_by_mimetype = defaultdict(list) + + def __call__(self, *args, **kwargs): + result = self.func(*args, **kwargs) + format = getcallargs(self.func, *args, **kwargs).get('format') + mimetype = self.best_mimetype() + + try: + formatter = self.get_formatter(format, mimetype) + except FormatterNotFound as e: + return self._abort(404, str(e)) + + return formatter(result) + + def register_formatter(self, formatter, *args, **kwargs): + self._formatters.append(formatter) + self._formatters_by_format[formatter.format].append( + (formatter, args, kwargs)) + for mimetype in formatter.mimetypes: + self._formatters_by_mimetype[mimetype].append( + (formatter, args, kwargs)) + + def get_formatter(self, format=None, mimetype=None): + if format is None and mimetype is None: + raise TypeError( + "get_formatter expects one of the 'format' or 
'mimetype' " + "kwargs to be set") + + if format is not None: + try: + # the first added will be the most specific + formatter_cls, args, kwargs = ( + self._formatters_by_format[format][0]) + except IndexError: + raise FormatterNotFound( + "Formatter for format '%s' not found!" % format) + elif mimetype is not None: + try: + # the first added will be the most specific + formatter_cls, args, kwargs = ( + self._formatters_by_mimetype[mimetype][0]) + except IndexError: + raise FormatterNotFound( + "Formatter for mimetype '%s' not found!" % mimetype) + + formatter = formatter_cls(request_mimetype=mimetype) + formatter.configure(*args, **kwargs) + return formatter + + @property + def accept_mimetypes(self): + return [m for f in self._formatters for m in f.mimetypes] + + def best_mimetype(self): + raise NotImplementedError( + "best_mimetype() should be implemented in " + "framework-specific subclasses of Negotiator" + ) + + def _abort(self, status_code, err=None): + raise NotImplementedError( + "_abort() should be implemented in framework-specific " + "subclasses of Negotiator" + ) + + +def negotiate(negotiator_cls, formatter_cls, *args, **kwargs): + def _negotiate(f, *args, **kwargs): + return f.negotiator(*args, **kwargs) + + def decorate(f): + if not hasattr(f, 'negotiator'): + f.negotiator = negotiator_cls(f) + + f.negotiator.register_formatter(formatter_cls, *args, **kwargs) + return decorator(_negotiate, f) + + return decorate diff --git a/swh/core/serializers.py b/swh/core/api/serializers.py rename from swh/core/serializers.py rename to swh/core/api/serializers.py diff --git a/swh/core/api_async.py b/swh/core/api_async.py --- a/swh/core/api_async.py +++ b/swh/core/api_async.py @@ -1,56 +1 @@ -import aiohttp.web -import asyncio -import json -import logging -import multidict -import pickle -import sys -import traceback - -from .serializers import msgpack_dumps, msgpack_loads, SWHJSONDecoder - - -def encode_data_server(data, **kwargs): - return aiohttp.web.Response( 
- body=msgpack_dumps(data), - headers=multidict.MultiDict({'Content-Type': 'application/x-msgpack'}), - **kwargs - ) - - -@asyncio.coroutine -def decode_request(request): - content_type = request.headers.get('Content-Type') - data = yield from request.read() - - if content_type == 'application/x-msgpack': - r = msgpack_loads(data) - elif content_type == 'application/json': - r = json.loads(data, cls=SWHJSONDecoder) - else: - raise ValueError('Wrong content type `%s` for API request' - % content_type) - return r - - -@asyncio.coroutine -def error_middleware(app, handler): - @asyncio.coroutine - def middleware_handler(request): - try: - return (yield from handler(request)) - except Exception as e: - if isinstance(e, aiohttp.web.HTTPException): - raise - logging.exception(e) - exception = traceback.format_exception(*sys.exc_info()) - res = {'exception': exception, - 'exception_pickled': pickle.dumps(e)} - return encode_data_server(res, status=500) - return middleware_handler - - -class SWHRemoteAPI(aiohttp.web.Application): - def __init__(self, *args, middlewares=(), **kwargs): - middlewares = (error_middleware,) + middlewares - super().__init__(*args, middlewares=middlewares, **kwargs) +from swh.core.api.asynchronous import * # noqa, for bw compat diff --git a/swh/core/tests/test_serializers.py b/swh/core/tests/test_serializers.py --- a/swh/core/tests/test_serializers.py +++ b/swh/core/tests/test_serializers.py @@ -10,7 +10,7 @@ import arrow -from swh.core.serializers import ( +from swh.core.api.serializers import ( SWHJSONDecoder, SWHJSONEncoder, msgpack_dumps,