diff --git a/swh/core/api/__init__.py b/swh/core/api/__init__.py --- a/swh/core/api/__init__.py +++ b/swh/core/api/__init__.py @@ -4,23 +4,24 @@ # See top-level LICENSE file for more information from collections import abc -import datetime import functools import inspect -import json import logging import pickle import requests import traceback -from typing import Any, ClassVar, List, Optional, Type +from typing import ( + Any, Callable, ClassVar, Dict, List, Optional, Tuple, Type, Union, +) from flask import Flask, Request, Response, request, abort from werkzeug.exceptions import HTTPException from .serializers import (decode_response, encode_data_client as encode_data, - msgpack_dumps, msgpack_loads, SWHJSONDecoder) + msgpack_dumps, msgpack_loads, + json_dumps, json_loads) from .negotiation import (Formatter as FormatterBase, Negotiator as NegotiatorBase, @@ -49,15 +50,8 @@ def _make_response(self, body, content_type): return Response(body, content_type=content_type) - -class SWHJSONEncoder(json.JSONEncoder): - def default(self, obj): - if isinstance(obj, (datetime.datetime, datetime.date)): - return obj.isoformat() - if isinstance(obj, datetime.timedelta): - return str(obj) - # Let the base class default method raise the TypeError - return super().default(obj) + def configure(self, extra_encoders): + self.extra_encoders = extra_encoders class JSONFormatter(Formatter): @@ -65,7 +59,7 @@ mimetypes = ['application/json'] def render(self, obj): - return json.dumps(obj, cls=SWHJSONEncoder) + return json_dumps(obj, extra_encoders=self.extra_encoders) class MsgpackFormatter(Formatter): @@ -73,7 +67,7 @@ mimetypes = ['application/x-msgpack'] def render(self, obj): - return msgpack_dumps(obj) + return msgpack_dumps(obj, extra_encoders=self.extra_encoders) # base API classes @@ -184,6 +178,13 @@ has the same name as the error name, then the exception will be instantiated and raised instead of a generic RemoteException.""" + extra_type_encoders: List[Tuple[type, str, Callable]] = [] + """Value of `extra_encoders` passed to `json_dumps` or `msgpack_dumps` + to be able to serialize more object types.""" + extra_type_decoders: Dict[str, Callable] = {} + """Value of `extra_decoders` passed to `json_loads` or `msgpack_loads` + to be able to deserialize more object types.""" + def __init__(self, url, api_exception=None, timeout=None, chunk_size=4096, reraise_exceptions=None, **kwargs): @@ -223,9 +224,9 @@ def post(self, endpoint, data, **opts): if isinstance(data, (abc.Iterator, abc.Generator)): - data = (encode_data(x) for x in data) + data = (self._encode_data(x) for x in data) else: - data = encode_data(data) + data = self._encode_data(data) chunk_size = opts.pop('chunk_size', self.chunk_size) response = self.raw_verb( 'post', endpoint, data=data, @@ -239,6 +240,9 @@ else: return self._decode_response(response) + def _encode_data(self, data): + return encode_data(data, extra_encoders=self.extra_type_encoders) + post_stream = post def get(self, endpoint, **opts): @@ -274,7 +278,7 @@ # after they are all upgraded try: if status_class == 4: - data = decode_response(response) + data = self._decode_response(response, check_status=False) if isinstance(data, dict): for exc_type in self.reraise_exceptions: if exc_type.__name__ == data['exception']['type']: @@ -287,7 +291,7 @@ exception = pickle.loads(data) elif status_class == 5: - data = decode_response(response) + data = self._decode_response(response, check_status=False) if 'exception_pickled' in data: exception = pickle.loads(data['exception_pickled']) else: @@ -305,9 +309,11 @@ payload=f'API HTTP error: {status_code} {response.content}', response=response) - def _decode_response(self, response): - self.raise_for_status(response) - return decode_response(response) + def _decode_response(self, response, check_status=True): + if check_status: + self.raise_for_status(response) + return decode_response( + response, extra_decoders=self.extra_type_decoders) def __repr__(self): return '<{} url={}>'.format(self.__class__.__name__, self.url) @@ -319,32 +325,34 @@ encoding_errors = 'surrogateescape' -ENCODERS = { +ENCODERS: Dict[str, Callable[[Any], Union[bytes, str]]] = { 'application/x-msgpack': msgpack_dumps, - 'application/json': json.dumps, + 'application/json': json_dumps, } -def encode_data_server(data, content_type='application/x-msgpack'): - encoded_data = ENCODERS[content_type](data) +def encode_data_server( + data, content_type='application/x-msgpack', extra_type_encoders=None): + encoded_data = ENCODERS[content_type]( + data, extra_encoders=extra_type_encoders) return Response( encoded_data, mimetype=content_type, ) -def decode_request(request): +def decode_request(request, extra_decoders=None): content_type = request.mimetype data = request.get_data() if not data: return {} if content_type == 'application/x-msgpack': - r = msgpack_loads(data) + r = msgpack_loads(data, extra_decoders=extra_decoders) elif content_type == 'application/json': # XXX this .decode() is needed for py35. # Should not be needed any more with py37 - r = json.loads(data.decode('utf-8'), cls=SWHJSONDecoder) + r = json_loads(data.decode('utf-8'), extra_decoders=extra_decoders) else: raise ValueError('Wrong content type `%s` for API request' % content_type) @@ -387,6 +395,13 @@ """ request_class = BytesRequest + extra_type_encoders: List[Tuple[type, str, Callable]] = [] + """Value of `extra_encoders` passed to `json_dumps` or `msgpack_dumps` + to be able to serialize more object types.""" + extra_type_decoders: Dict[str, Callable] = {} + """Value of `extra_decoders` passed to `json_loads` or `msgpack_loads` + to be able to deserialize more object types.""" + def __init__(self, *args, backend_class=None, backend_factory=None, **kwargs): super().__init__(*args, **kwargs) @@ -403,11 +418,12 @@ from flask import request @self.route('/'+meth._endpoint_path, methods=['POST']) - @negotiate(MsgpackFormatter) - @negotiate(JSONFormatter) + @negotiate(MsgpackFormatter, extra_encoders=self.extra_type_encoders) + @negotiate(JSONFormatter, extra_encoders=self.extra_type_encoders) @functools.wraps(meth) # Copy signature and doc def _f(): # Call the actual code obj_meth = getattr(backend_factory(), meth_name) - kw = decode_request(request) + kw = decode_request( + request, extra_decoders=self.extra_type_decoders) return obj_meth(**kw) diff --git a/swh/core/api/serializers.py b/swh/core/api/serializers.py --- a/swh/core/api/serializers.py +++ b/swh/core/api/serializers.py @@ -57,7 +57,7 @@ r = msgpack_loads(response.content, extra_decoders=extra_decoders) elif content_type.startswith('application/json'): - r = json.loads(response.text, cls=SWHJSONDecoder, + r = json_loads(response.text, extra_decoders=extra_decoders) elif content_type.startswith('text/'): r = response.text @@ -162,6 +162,16 @@ return self.decode_data(data), index +def json_dumps(data: Any, extra_encoders=None) -> str: + return json.dumps(data, cls=SWHJSONEncoder, + extra_encoders=extra_encoders) + + +def json_loads(data: str, extra_decoders=None) -> Any: + return json.loads(data, cls=SWHJSONDecoder, + extra_decoders=extra_decoders) + + def msgpack_dumps(data: Any, extra_encoders=None) -> bytes: """Write data as a msgpack stream""" encoders = ENCODERS