diff --git a/requirements.txt b/requirements.txt --- a/requirements.txt +++ b/requirements.txt @@ -8,3 +8,4 @@ requests Flask systemd-python +negotiate diff --git a/swh/core/api.py b/swh/core/api.py --- a/swh/core/api.py +++ b/swh/core/api.py @@ -10,12 +10,43 @@ import logging import pickle import requests +import datetime from flask import Flask, Request, Response from .serializers import (decode_response, encode_data_client as encode_data, msgpack_dumps, msgpack_loads, SWHJSONDecoder) +from negotiate.flask import Formatter + +logger = logging.getLogger(__name__) + + +class SWHJSONEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, (datetime.datetime, datetime.date)): + return obj.isoformat() + if isinstance(obj, datetime.timedelta): + return str(obj) + # Let the base class default method raise the TypeError + return super().default(obj) + + +class JSONFormatter(Formatter): + format = 'json' + mimetypes = ['application/json'] + + def render(self, obj): + return json.dumps(obj, cls=SWHJSONEncoder) + + +class MsgpackFormatter(Formatter): + format = 'msgpack' + mimetypes = ['application/x-msgpack'] + + def render(self, obj): + return msgpack_dumps(obj) + class RemoteException(Exception): pass @@ -124,21 +155,28 @@ data = encode_data(data) response = self.raw_post( endpoint, data, params=params, - headers={'content-type': 'application/x-msgpack'}) + headers={'content-type': 'application/x-msgpack', + 'accept': 'application/x-msgpack'}) return self._decode_response(response) def get(self, endpoint, params=None): - response = self.raw_get(endpoint, params=params) + response = self.raw_get( + endpoint, params=params, + headers={'accept': 'application/x-msgpack'}) return self._decode_response(response) def post_stream(self, endpoint, data, params=None): if not isinstance(data, collections.Iterable): raise ValueError("`data` must be Iterable") - response = self.raw_post(endpoint, data, params=params) + response = self.raw_post( + endpoint, data, params=params, + headers={'accept': 'application/x-msgpack'}) + return self._decode_response(response) def get_stream(self, endpoint, params=None, chunk_size=4096): - response = self.raw_get(endpoint, params=params, stream=True) + response = self.raw_get(endpoint, params=params, stream=True, + headers={'accept': 'application/x-msgpack'}) return response.iter_content(chunk_size) def _decode_response(self, response): @@ -171,16 +209,25 @@ encoding_errors = 'surrogateescape' -def encode_data_server(data): +ENCODERS = { + 'application/x-msgpack': msgpack_dumps, + 'application/json': json.dumps, +} + + +def encode_data_server(data, content_type='application/x-msgpack'): + encoded_data = ENCODERS[content_type](data) return Response( - msgpack_dumps(data), - mimetype='application/x-msgpack', - ) + encoded_data, + mimetype=content_type, + ) def decode_request(request): content_type = request.mimetype data = request.get_data() + if not data: + return {} if content_type == 'application/x-msgpack': r = msgpack_loads(data)