diff --git a/swh/core/api/__init__.py b/swh/core/api/__init__.py --- a/swh/core/api/__init__.py +++ b/swh/core/api/__init__.py @@ -1,4 +1,4 @@ -# Copyright (C) 2015-2017 The Software Heritage developers +# Copyright (C) 2015-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information @@ -9,7 +9,6 @@ import logging import pickle import requests -import traceback from typing import ( Any, Callable, ClassVar, Dict, List, Optional, Tuple, Type, Union, @@ -21,7 +20,8 @@ from .serializers import (decode_response, encode_data_client as encode_data, msgpack_dumps, msgpack_loads, - json_dumps, json_loads) + json_dumps, json_loads, + exception_to_dict) from .negotiation import (Formatter as FormatterBase, Negotiator as NegotiatorBase, @@ -362,16 +362,7 @@ def error_handler(exception, encoder, status_code=500): logging.exception(exception) - tb = traceback.format_exception(None, exception, exception.__traceback__) - error = { - 'exception': { - 'type': type(exception).__name__, - 'args': exception.args, - 'message': str(exception), - 'traceback': tb, - } - } - response = encoder(error) + response = encoder(exception_to_dict(exception)) if isinstance(exception, HTTPException): response.status_code = exception.code else: diff --git a/swh/core/api/asynchronous.py b/swh/core/api/asynchronous.py --- a/swh/core/api/asynchronous.py +++ b/swh/core/api/asynchronous.py @@ -1,16 +1,19 @@ -import json -import logging -import pickle -import sys -import traceback +# Copyright (C) 2017-2020 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + from collections import OrderedDict -import multidict +import logging +from typing import Tuple, Type import 
def render_json(request, data):
    """Serialize ``data`` to JSON through the package's ``json_dumps``.

    ``request`` is part of the signature but is not consulted here
    (presumably required by the renderer calling convention -- confirm
    against the negotiation setup).
    """
    rendered = json_dumps(data)
    return rendered
def exception_to_dict(exception):
    """Describe ``exception`` as a nested dictionary of plain values.

    The result has a single ``'exception'`` key. Its value maps:

    - ``'type'``: the name of the exception's class,
    - ``'args'``: the exception's ``args`` attribute, unchanged,
    - ``'message'``: ``str(exception)``,
    - ``'traceback'``: the list of strings returned by
      ``traceback.format_exception`` for the exception and its
      ``__traceback__``.
    """
    # Format the traceback first, before reading the other attributes.
    formatted = traceback.format_exception(
        None, exception, exception.__traceback__)

    details = {}
    details['type'] = type(exception).__name__
    details['args'] = exception.args
    details['message'] = str(exception)
    details['traceback'] = formatted
    return {'exception': details}
@pytest.fixture
def async_app():
    """Build the ``RPCServerApp`` instance used by the tests in this module.

    ``TestClientError`` is declared as the only client exception class, then
    each handler is registered on the application's router, in the order
    listed below.
    """
    server = RPCServerApp()
    server.client_exception_classes = (TestClientError,)

    # (HTTP method, path, handler) -- registered in exactly this order.
    routes = (
        ('GET', '/', root),
        ('GET', '/struct', struct),
        ('POST', '/echo', echo),
        ('GET', '/server_exception', server_exception),
        ('GET', '/client_error', client_error),
        ('POST', '/echo-no-nego', echo_no_nego),
    )
    for method, path, handler in routes:
        server.router.add_route(method, path, handler)
    return server