diff --git a/swh/core/api/__init__.py b/swh/core/api/__init__.py
--- a/swh/core/api/__init__.py
+++ b/swh/core/api/__init__.py
@@ -4,17 +4,20 @@
 # See top-level LICENSE file for more information
 
 from collections import abc
+import datetime
 import functools
 import inspect
 import json
 import logging
 import pickle
 import requests
-import datetime
+import traceback
 
-from typing import Any, ClassVar, Optional, Type
+from typing import Any, ClassVar, List, Optional, Type
 
 from flask import Flask, Request, Response, request, abort
+from werkzeug.exceptions import HTTPException
+
 from .serializers import (decode_response,
                           encode_data_client as encode_data,
                           msgpack_dumps, msgpack_loads, SWHJSONDecoder)
@@ -91,6 +94,15 @@
             super().__init__()
         self.response = response
 
+    def __str__(self):
+        if self.args and isinstance(self.args[0], dict) \
+                and 'type' in self.args[0] and 'args' in self.args[0]:
+            return (
+                f'<RemoteException {self.response.status_code} '
+                f'{self.args[0]["type"]}: {self.args[0]["args"]}>')
+        else:
+            return super().__str__()
+
 
 def remote_api_endpoint(path):
     def dec(f):
@@ -167,10 +179,18 @@
     """The exception class to raise in case of communication error with
     the server."""
 
+    reraise_exceptions: ClassVar[List[Type[Exception]]] = []
+    """On server errors, if any of the exception classes in this list
+    has the same name as the error name, then the exception will
+    be instantiated and raised instead of a generic RemoteException."""
+
     def __init__(self, url, api_exception=None,
-                 timeout=None, chunk_size=4096, **kwargs):
+                 timeout=None, chunk_size=4096,
+                 reraise_exceptions=None, **kwargs):
         if api_exception:
             self.api_exception = api_exception
+        if reraise_exceptions:
+            self.reraise_exceptions = reraise_exceptions
         base_url = url if url.endswith('/') else url + '/'
         self.url = base_url
         self.session = requests.Session()
@@ -242,8 +262,6 @@
         error; do nothing otherwise
 
         """
-        # XXX: unpickling below breaks language-independence and should be
-        # replaced by proper language-independent [de]serialization
         status_code = response.status_code
         status_class = response.status_code // 100
 
@@ -252,10 +270,21 @@
 
         exception = None
 
+        # TODO: only old servers send pickled error; stop trying to unpickle
+        # after they are all upgraded
         try:
             if status_class == 4:
                 data = decode_response(response)
-                exception = pickle.loads(data)
+                if isinstance(data, dict):
+                    for exc_type in self.reraise_exceptions:
+                        if exc_type.__name__ == data['exception']['type']:
+                            exception = exc_type(*data['exception']['args'])
+                            break
+                    else:
+                        exception = RemoteException(payload=data['exception'],
+                                                    response=response)
+                else:
+                    exception = pickle.loads(data)
 
             elif status_class == 5:
                 data = decode_response(response)
@@ -277,11 +306,8 @@
                 response=response)
 
     def _decode_response(self, response):
-        if response.status_code == 404:
-            return None
-        else:
-            self.raise_for_status(response)
-            return decode_response(response)
+        self.raise_for_status(response)
+        return decode_response(response)
 
     def __repr__(self):
         return '<{} url={}>'.format(self.__class__.__name__, self.url)
@@ -326,12 +352,23 @@
     return r
 
 
-def error_handler(exception, encoder):
-    # XXX: this breaks language-independence and should be
-    # replaced by proper serialization of errors
+def error_handler(exception, encoder, status_code=500):
     logging.exception(exception)
-    response = encoder(pickle.dumps(exception))
-    response.status_code = 400
+    tb = traceback.format_exception(None, exception, exception.__traceback__)
+    error = {
+        'exception': {
+            'type': type(exception).__name__,
+            'args': exception.args,
+            'message': str(exception),
+            'traceback': tb,
+        }
+    }
+    response = encoder(error)
+    if isinstance(exception, HTTPException):
+        response.status_code = exception.code
+    else:
+        # TODO: differentiate between server errors and client errors
+        response.status_code = status_code
     return response
 
 
diff --git a/swh/core/api/tests/test_rpc_client_server.py b/swh/core/api/tests/test_rpc_client_server.py
--- a/swh/core/api/tests/test_rpc_client_server.py
+++ b/swh/core/api/tests/test_rpc_client_server.py
@@ -6,7 +6,7 @@
 import pytest
 
 from swh.core.api import remote_api_endpoint, RPCServerApp, RPCClient
-from swh.core.api import error_handler, encode_data_server
+from swh.core.api import error_handler, encode_data_server, RemoteException
 
 # this class is used on the server part
 
@@ -79,7 +79,7 @@
 def test_api_server_endpoint_missing(swh_rpc_client):
     # A 'missing' endpoint (server-side) should raise an exception
     # due to a 404, since at the end, we do a GET/POST an inexistent URL
-    with pytest.raises(Exception, match='404 Not Found'):
+    with pytest.raises(Exception, match='404 not found'):
         swh_rpc_client.not_on_server()
 
 
@@ -98,5 +98,10 @@
 
 
 def test_api_typeerror(swh_rpc_client):
-    with pytest.raises(TypeError, match='Did I pass through?'):
+    with pytest.raises(RemoteException) as exc_info:
         swh_rpc_client.raise_typeerror()
+
+    assert exc_info.value.args[0]['type'] == 'TypeError'
+    assert exc_info.value.args[0]['args'] == ['Did I pass through?']
+    assert str(exc_info.value) \
+        == "<RemoteException 500 TypeError: ['Did I pass through?']>"