Add extra type encoders/decoders to the RPC client and server

RPCClient and RPCServerApp get `extra_type_encoders` and
`extra_type_decoders` class attributes. These are passed as
`extra_encoders` / `extra_decoders` to the JSON and msgpack
(de)serializers, on both the request and the response path.

Review notes:
* Formatter.configure() defaults `extra_encoders` to None, so a
  formatter registered with a bare `negotiate(SomeFormatter)` can
  still be configured. (This assumes negotiation forwards the keyword
  arguments of negotiate() to configure(); confirm in negotiation.py.)
* ENCODERS values are annotated Callable[..., ...] because
  encode_data_server() calls them with an `extra_encoders` keyword,
  which Callable[[Any], ...] does not allow.

diff --git a/swh/core/api/__init__.py b/swh/core/api/__init__.py
--- a/swh/core/api/__init__.py
+++ b/swh/core/api/__init__.py
@@ -4,23 +4,24 @@
 # See top-level LICENSE file for more information
 
 from collections import abc
-import datetime
 import functools
 import inspect
-import json
 import logging
 import pickle
 import requests
 import traceback
 
-from typing import Any, ClassVar, List, Optional, Type
+from typing import (
+    Any, Callable, ClassVar, Dict, List, Optional, Tuple, Type, Union,
+)
 
 from flask import Flask, Request, Response, request, abort
 from werkzeug.exceptions import HTTPException
 
 from .serializers import (decode_response,
                           encode_data_client as encode_data,
-                          msgpack_dumps, msgpack_loads, SWHJSONDecoder)
+                          msgpack_dumps, msgpack_loads,
+                          json_dumps, json_loads)
 
 from .negotiation import (Formatter as FormatterBase,
                           Negotiator as NegotiatorBase,
@@ -49,15 +50,8 @@
     def _make_response(self, body, content_type):
         return Response(body, content_type=content_type)
 
-
-class SWHJSONEncoder(json.JSONEncoder):
-    def default(self, obj):
-        if isinstance(obj, (datetime.datetime, datetime.date)):
-            return obj.isoformat()
-        if isinstance(obj, datetime.timedelta):
-            return str(obj)
-        # Let the base class default method raise the TypeError
-        return super().default(obj)
+    def configure(self, extra_encoders=None):
+        self.extra_encoders = extra_encoders
 
 
 class JSONFormatter(Formatter):
@@ -65,7 +59,7 @@
     mimetypes = ['application/json']
 
     def render(self, obj):
-        return json.dumps(obj, cls=SWHJSONEncoder)
+        return json_dumps(obj, extra_encoders=self.extra_encoders)
 
 
 class MsgpackFormatter(Formatter):
@@ -73,7 +67,7 @@
     mimetypes = ['application/x-msgpack']
 
     def render(self, obj):
-        return msgpack_dumps(obj)
+        return msgpack_dumps(obj, extra_encoders=self.extra_encoders)
 
 
 # base API classes
@@ -184,6 +178,13 @@
     has the same name as the error name, then the exception will be
     instantiated and raised instead of a generic RemoteException."""
 
+    extra_type_encoders: List[Tuple[type, str, Callable]] = []
+    """Value of `extra_encoders` passed to `json_dumps` or `msgpack_dumps`
+    to be able to serialize more object types."""
+    extra_type_decoders: Dict[str, Callable] = {}
+    """Value of `extra_decoders` passed to `json_loads` or `msgpack_loads`
+    to be able to deserialize more object types."""
+
     def __init__(self, url, api_exception=None,
                  timeout=None, chunk_size=4096,
                  reraise_exceptions=None, **kwargs):
@@ -223,9 +224,9 @@
 
     def post(self, endpoint, data, **opts):
         if isinstance(data, (abc.Iterator, abc.Generator)):
-            data = (encode_data(x) for x in data)
+            data = (self._encode_data(x) for x in data)
         else:
-            data = encode_data(data)
+            data = self._encode_data(data)
         chunk_size = opts.pop('chunk_size', self.chunk_size)
         response = self.raw_verb(
             'post', endpoint, data=data,
@@ -239,6 +240,9 @@
         else:
             return self._decode_response(response)
 
+    def _encode_data(self, data):
+        return encode_data(data, extra_encoders=self.extra_type_encoders)
+
     post_stream = post
 
     def get(self, endpoint, **opts):
@@ -274,7 +278,7 @@
         # after they are all upgraded
         try:
             if status_class == 4:
-                data = decode_response(response)
+                data = self._decode_response(response, check_status=False)
                 if isinstance(data, dict):
                     for exc_type in self.reraise_exceptions:
                         if exc_type.__name__ == data['exception']['type']:
@@ -287,7 +291,7 @@
                     exception = pickle.loads(data)
 
             elif status_class == 5:
-                data = decode_response(response)
+                data = self._decode_response(response, check_status=False)
                 if 'exception_pickled' in data:
                     exception = pickle.loads(data['exception_pickled'])
                 else:
@@ -305,9 +309,11 @@
                 payload=f'API HTTP error: {status_code} {response.content}',
                 response=response)
 
-    def _decode_response(self, response):
-        self.raise_for_status(response)
-        return decode_response(response)
+    def _decode_response(self, response, check_status=True):
+        if check_status:
+            self.raise_for_status(response)
+        return decode_response(
+            response, extra_decoders=self.extra_type_decoders)
 
     def __repr__(self):
         return '<{} url={}>'.format(self.__class__.__name__, self.url)
@@ -319,32 +325,34 @@
     encoding_errors = 'surrogateescape'
 
 
-ENCODERS = {
+ENCODERS: Dict[str, Callable[..., Union[bytes, str]]] = {
     'application/x-msgpack': msgpack_dumps,
-    'application/json': json.dumps,
+    'application/json': json_dumps,
 }
 
 
-def encode_data_server(data, content_type='application/x-msgpack'):
-    encoded_data = ENCODERS[content_type](data)
+def encode_data_server(
+        data, content_type='application/x-msgpack', extra_type_encoders=None):
+    encoded_data = ENCODERS[content_type](
+        data, extra_encoders=extra_type_encoders)
     return Response(
         encoded_data,
         mimetype=content_type,
     )
 
 
-def decode_request(request):
+def decode_request(request, extra_decoders=None):
     content_type = request.mimetype
     data = request.get_data()
     if not data:
         return {}
 
     if content_type == 'application/x-msgpack':
-        r = msgpack_loads(data)
+        r = msgpack_loads(data, extra_decoders=extra_decoders)
     elif content_type == 'application/json':
         # XXX this .decode() is needed for py35.
         # Should not be needed any more with py37
-        r = json.loads(data.decode('utf-8'), cls=SWHJSONDecoder)
+        r = json_loads(data.decode('utf-8'), extra_decoders=extra_decoders)
     else:
         raise ValueError('Wrong content type `%s` for API request'
                          % content_type)
@@ -387,6 +395,13 @@
     """
     request_class = BytesRequest
 
+    extra_type_encoders: List[Tuple[type, str, Callable]] = []
+    """Value of `extra_encoders` passed to `json_dumps` or `msgpack_dumps`
+    to be able to serialize more object types."""
+    extra_type_decoders: Dict[str, Callable] = {}
+    """Value of `extra_decoders` passed to `json_loads` or `msgpack_loads`
+    to be able to deserialize more object types."""
+
     def __init__(self, *args, backend_class=None, backend_factory=None,
                  **kwargs):
         super().__init__(*args, **kwargs)
@@ -403,11 +418,12 @@
         from flask import request
 
         @self.route('/'+meth._endpoint_path, methods=['POST'])
-        @negotiate(MsgpackFormatter)
-        @negotiate(JSONFormatter)
+        @negotiate(MsgpackFormatter, extra_encoders=self.extra_type_encoders)
+        @negotiate(JSONFormatter, extra_encoders=self.extra_type_encoders)
         @functools.wraps(meth)  # Copy signature and doc
         def _f():
             # Call the actual code
             obj_meth = getattr(backend_factory(), meth_name)
-            kw = decode_request(request)
+            kw = decode_request(
+                request, extra_decoders=self.extra_type_decoders)
             return obj_meth(**kw)
diff --git a/swh/core/api/serializers.py b/swh/core/api/serializers.py
--- a/swh/core/api/serializers.py
+++ b/swh/core/api/serializers.py
@@ -57,7 +57,7 @@
         r = msgpack_loads(response.content,
                           extra_decoders=extra_decoders)
     elif content_type.startswith('application/json'):
-        r = json.loads(response.text, cls=SWHJSONDecoder,
+        r = json_loads(response.text,
                        extra_decoders=extra_decoders)
     elif content_type.startswith('text/'):
         r = response.text
@@ -162,6 +162,18 @@
         return self.decode_data(data), index
 
 
+def json_dumps(data: Any, extra_encoders=None) -> str:
+    """Write data as a JSON string, using SWHJSONEncoder"""
+    return json.dumps(data, cls=SWHJSONEncoder,
+                      extra_encoders=extra_encoders)
+
+
+def json_loads(data: str, extra_decoders=None) -> Any:
+    """Read data from a JSON string, using SWHJSONDecoder"""
+    return json.loads(data, cls=SWHJSONDecoder,
+                      extra_decoders=extra_decoders)
+
+
 def msgpack_dumps(data: Any, extra_encoders=None) -> bytes:
     """Write data as a msgpack stream"""
     encoders = ENCODERS
diff --git a/swh/core/api/tests/test_rpc_client.py b/swh/core/api/tests/test_rpc_client.py
--- a/swh/core/api/tests/test_rpc_client.py
+++ b/swh/core/api/tests/test_rpc_client.py
@@ -8,20 +8,28 @@
 
 from swh.core.api import remote_api_endpoint, RPCClient
 
+from .test_serializers import ExtraType, extra_encoders, extra_decoders
+
 
 @pytest.fixture
 def rpc_client(requests_mock):
     class TestStorage:
         @remote_api_endpoint('test_endpoint_url')
         def test_endpoint(self, test_data, db=None, cur=None):
-            return 'egg'
+            ...
 
         @remote_api_endpoint('path/to/endpoint')
         def something(self, data, db=None, cur=None):
-            return 'spam'
+            ...
+
+        @remote_api_endpoint('serializer_test')
+        def serializer_test(self, data, db=None, cur=None):
+            ...
 
     class Testclient(RPCClient):
         backend_class = TestStorage
+        extra_type_encoders = extra_encoders
+        extra_type_decoders = extra_decoders
 
     def callback(request, context):
         assert request.headers['Content-Type'] == 'application/x-msgpack'
@@ -30,6 +38,10 @@
             context.content = b'\xa3egg'
         elif request.path == '/path/to/endpoint':
             context.content = b'\xa4spam'
+        elif request.path == '/serializer_test':
+            context.content = (
+                b'\x82\xc4\x07swhtype\xa9extratype'
+                b'\xc4\x01d\x92\x81\xa4spam\xa3egg\xa3qux')
         else:
             assert False
         return context.content
@@ -54,3 +66,8 @@
     assert res == 'spam'
     res = rpc_client.something(data='whatever')
     assert res == 'spam'
+
+
+def test_client_extra_serializers(rpc_client):
+    res = rpc_client.serializer_test(['foo', ExtraType('bar', b'baz')])
+    assert res == ExtraType({'spam': 'egg'}, 'qux')
diff --git a/swh/core/api/tests/test_rpc_server.py b/swh/core/api/tests/test_rpc_server.py
--- a/swh/core/api/tests/test_rpc_server.py
+++ b/swh/core/api/tests/test_rpc_server.py
@@ -11,6 +11,13 @@
 
 from swh.core.api import remote_api_endpoint, RPCServerApp
 
+from .test_serializers import ExtraType, extra_encoders, extra_decoders
+
+
+class MyRPCServerApp(RPCServerApp):
+    extra_type_encoders = extra_encoders
+    extra_type_decoders = extra_decoders
+
 
 @pytest.fixture
 def app():
@@ -24,7 +31,12 @@
         def something(self, data, db=None, cur=None):
             return data
 
-    return RPCServerApp('testapp', backend_class=TestStorage)
+        @remote_api_endpoint('serializer_test')
+        def serializer_test(self, data, db=None, cur=None):
+            assert data == ['foo', ExtraType('bar', b'baz')]
+            return ExtraType({'spam': 'egg'}, 'qux')
+
+    return MyRPCServerApp('testapp', backend_class=TestStorage)
 
 
 def test_api_endpoint(flask_app_client):
@@ -71,3 +83,18 @@
     assert res.status_code == 200
     assert res.mimetype == 'application/x-msgpack'
     assert res.data == b'\xa3egg'
+
+
+def test_rpc_server_extra_serializers(flask_app_client):
+    res = flask_app_client.post(
+        url_for('serializer_test'),
+        headers=[('Content-Type', 'application/x-msgpack'),
+                 ('Accept', 'application/x-msgpack')],
+        data=b'\x81\xa4data\x92\xa3foo\x82\xc4\x07swhtype\xa9extratype'
+             b'\xc4\x01d\x92\xa3bar\xc4\x03baz')
+
+    assert res.status_code == 200
+    assert res.mimetype == 'application/x-msgpack'
+    assert res.data == (
+        b'\x82\xc4\x07swhtype\xa9extratype\xc4'
+        b'\x01d\x92\x81\xa4spam\xa3egg\xa3qux')
diff --git a/swh/core/api/tests/test_serializers.py b/swh/core/api/tests/test_serializers.py
--- a/swh/core/api/tests/test_serializers.py
+++ b/swh/core/api/tests/test_serializers.py
@@ -5,6 +5,7 @@
 
 import datetime
 import json
+from typing import Any, Callable, List, Tuple
 import unittest
 from uuid import UUID
 
@@ -30,10 +31,11 @@
         return f'ExtraType({self.arg1}, {self.arg2})'
 
     def __eq__(self, other):
-        return (self.arg1, self.arg2) == (other.arg1, other.arg2)
+        return isinstance(other, ExtraType) \
+            and (self.arg1, self.arg2) == (other.arg1, other.arg2)
 
 
-extra_encoders = [
+extra_encoders: List[Tuple[type, str, Callable[..., Any]]] = [
     (ExtraType, 'extratype',
      lambda o: (o.arg1, o.arg2))
 ]