diff --git a/swh/core/api.py b/swh/core/api.py
--- a/swh/core/api.py
+++ b/swh/core/api.py
@@ -3,6 +3,7 @@
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
+import collections.abc
 import json
 import logging
 import pickle
@@ -14,6 +15,10 @@
                           msgpack_dumps, msgpack_loads, SWHJSONDecoder)
 
 
+class RemoteException(Exception):
+    """Raised when the server reports an error without a pickled exception."""
+
+
 class SWHRemoteAPI:
     """Proxy to an internal SWH API
 
@@ -29,34 +34,84 @@
     def _url(self, endpoint):
         return '%s%s' % (self.url, endpoint)
 
-    def post(self, endpoint, data):
+    def raw_post(self, endpoint, data, **opts):
+        """POST ``data`` as-is; return the undecoded HTTP response.
+
+        Extra keyword ``opts`` are forwarded to ``self.session.post``;
+        connection errors are re-raised as ``self.api_exception``.
+        """
         try:
-            response = self.session.post(
+            return self.session.post(
                 self._url(endpoint),
-                data=encode_data(data),
-                headers={'content-type': 'application/x-msgpack'},
+                data=data,
+                **opts
             )
         except requests.exceptions.ConnectionError as e:
             raise self.api_exception(e)
 
-        # XXX: this breaks language-independence and should be
-        # replaced by proper unserialization
-        if response.status_code == 400:
-            raise pickle.loads(decode_response(response))
-
-        return decode_response(response)
-
-    def get(self, endpoint, data=None):
+    def raw_get(self, endpoint, params=None, **opts):
+        """GET ``endpoint``; return the undecoded HTTP response.
+
+        Extra keyword ``opts`` are forwarded to ``self.session.get``;
+        connection errors are re-raised as ``self.api_exception``.
+        """
         try:
-            response = self.session.get(
+            return self.session.get(
                 self._url(endpoint),
-                params=data,
+                params=params,
+                **opts
             )
         except requests.exceptions.ConnectionError as e:
             raise self.api_exception(e)
 
+    def post(self, endpoint, data, params=None):
+        """POST ``data`` msgpack-encoded; return the decoded response."""
+        data = encode_data(data)
+        response = self.raw_post(
+            endpoint, data, params=params,
+            headers={'content-type': 'application/x-msgpack'})
+        return self._decode_response(response)
+
+    def get(self, endpoint, params=None):
+        """GET ``endpoint`` and return the decoded response."""
+        response = self.raw_get(endpoint, params=params)
+        return self._decode_response(response)
+
+    def post_stream(self, endpoint, data, params=None):
+        """POST an iterable ``data`` and return the decoded response.
+
+        Unlike :meth:`post`, ``data`` is handed to requests unencoded, so
+        generators and file objects can be streamed.
+        """
+        if not isinstance(data, collections.abc.Iterable):
+            raise ValueError("`data` must be Iterable")
+        response = self.raw_post(endpoint, data, params=params)
+        return self._decode_response(response)
+
+    def get_stream(self, endpoint, params=None, chunk_size=4096):
+        """Return ``response.iter_content(chunk_size)`` for a streamed GET.
+
+        The response status code is not checked.
+        """
+        response = self.raw_get(endpoint, params=params, stream=True)
+        return response.iter_content(chunk_size)
+
+    def _decode_response(self, response):
+        """Decode ``response``, re-raising errors reported by the server.
+
+        Returns None on 404; raises the server's exception on 500 and 400;
+        otherwise returns the decoded body.
+        """
         if response.status_code == 404:
             return None
+        if response.status_code == 500:
+            # NOTE(review): unpickling bytes sent by the server can run
+            # arbitrary code; only point this client at trusted servers.
+            data = decode_response(response)
+            if 'exception_pickled' in data:
+                raise pickle.loads(data['exception_pickled'])
+            else:
+                raise RemoteException(data['exception'])
 
         # XXX: this breaks language-independence and should be
         # replaced by proper unserialization
diff --git a/swh/core/api_async.py b/swh/core/api_async.py
new file mode 100644
--- /dev/null
+++ b/swh/core/api_async.py
@@ -0,0 +1,81 @@
+"""Server-side aiohttp helpers for the SWH remote API."""
+
+import aiohttp.web
+import asyncio
+import json
+import logging
+import multidict
+import pickle
+import sys
+import traceback
+
+from .serializers import msgpack_dumps, msgpack_loads, SWHJSONDecoder
+
+
+def encode_data_server(data, **kwargs):
+    """Return an aiohttp response whose body is ``data`` msgpack-encoded.
+
+    Extra ``kwargs`` (e.g. ``status``) are passed to the response.
+    """
+    return aiohttp.web.Response(
+        body=msgpack_dumps(data),
+        headers=multidict.MultiDict({'Content-Type': 'application/x-msgpack'}),
+        **kwargs
+    )
+
+
+@asyncio.coroutine
+def decode_request(request):
+    """Read ``request``'s body and decode it according to its Content-Type.
+
+    Handles ``application/x-msgpack`` and ``application/json``; any other
+    content type raises ValueError.
+    """
+    content_type = request.headers.get('Content-Type')
+    data = yield from request.read()
+
+    if content_type == 'application/x-msgpack':
+        r = msgpack_loads(data)
+    elif content_type == 'application/json':
+        r = json.loads(data, cls=SWHJSONDecoder)
+    else:
+        raise ValueError('Wrong content type `%s` for API request'
+                         % content_type)
+    return r
+
+
+@asyncio.coroutine
+def error_middleware(app, handler):
+    """Middleware factory turning uncaught handler exceptions into 500s.
+
+    The 500 body holds the formatted traceback under ``exception`` and, if
+    the exception can be pickled, the pickle under ``exception_pickled``.
+    ``aiohttp.web.HTTPException`` instances are re-raised untouched.
+    """
+    @asyncio.coroutine
+    def middleware_handler(request):
+        try:
+            return (yield from handler(request))
+        except Exception as e:
+            if isinstance(e, aiohttp.web.HTTPException):
+                raise
+            logging.exception(e)
+            exception = traceback.format_exception(*sys.exc_info())
+            res = {'exception': exception}
+            try:
+                res['exception_pickled'] = pickle.dumps(e)
+            except Exception:
+                # Unpicklable exception: still answer with the traceback
+                # instead of failing inside the error handler.
+                pass
+            return encode_data_server(res, status=500)
+    return middleware_handler
+
+
+class SWHRemoteAPI(aiohttp.web.Application):
+    """aiohttp application with :func:`error_middleware` installed first."""
+
+    def __init__(self, *args, middlewares=(), **kwargs):
+        # tuple() so callers may pass any iterable (e.g. a list).
+        middlewares = (error_middleware,) + tuple(middlewares)
+        super().__init__(*args, middlewares=middlewares, **kwargs)