diff --git a/swh/core/api/__init__.py b/swh/core/api/__init__.py index dc3bdea..2bba4bb 100644 --- a/swh/core/api/__init__.py +++ b/swh/core/api/__init__.py @@ -1,429 +1,420 @@ -# Copyright (C) 2015-2017 The Software Heritage developers +# Copyright (C) 2015-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from collections import abc import functools import inspect import logging import pickle import requests -import traceback from typing import ( Any, Callable, ClassVar, Dict, List, Optional, Tuple, Type, Union, ) from flask import Flask, Request, Response, request, abort from werkzeug.exceptions import HTTPException from .serializers import (decode_response, encode_data_client as encode_data, msgpack_dumps, msgpack_loads, - json_dumps, json_loads) + json_dumps, json_loads, + exception_to_dict) from .negotiation import (Formatter as FormatterBase, Negotiator as NegotiatorBase, negotiate as _negotiate) logger = logging.getLogger(__name__) # support for content negotiation class Negotiator(NegotiatorBase): def best_mimetype(self): return request.accept_mimetypes.best_match( self.accept_mimetypes, 'application/json') def _abort(self, status_code, err=None): return abort(status_code, err) def negotiate(formatter_cls, *args, **kwargs): return _negotiate(Negotiator, formatter_cls, *args, **kwargs) class Formatter(FormatterBase): def _make_response(self, body, content_type): return Response(body, content_type=content_type) def configure(self, extra_encoders): self.extra_encoders = extra_encoders class JSONFormatter(Formatter): format = 'json' mimetypes = ['application/json'] def render(self, obj): return json_dumps(obj, extra_encoders=self.extra_encoders) class MsgpackFormatter(Formatter): format = 'msgpack' mimetypes = ['application/x-msgpack'] def render(self, obj): return msgpack_dumps(obj, extra_encoders=self.extra_encoders) # base API classes class RemoteException(Exception): """raised when remote returned an out-of-band failure notification, e.g., as a HTTP status code or serialized exception Attributes: response: HTTP response corresponding to the failure """ def __init__(self, payload: Optional[Any] = None, response: Optional[requests.Response] = None): if payload is not None: super().__init__(payload) else: super().__init__() self.response = response def __str__(self): if self.args and isinstance(self.args[0], dict) \ and 'type' in self.args[0] and 'args' in self.args[0]: return ( f'') else: return super().__str__() def remote_api_endpoint(path): def dec(f): f._endpoint_path = path return f return dec class APIError(Exception): """API Error""" def __str__(self): return ('An unexpected error occurred in the backend: {}' .format(self.args)) class MetaRPCClient(type): """Metaclass for RPCClient, which adds a method for each endpoint of the database it is designed to access. See for example :class:`swh.indexer.storage.api.client.RemoteStorage`""" def __new__(cls, name, bases, attributes): # For each method wrapped with @remote_api_endpoint in an API backend # (eg. :class:`swh.indexer.storage.IndexerStorage`), add a new # method in RemoteStorage, with the same documentation. # # Note that, despite the usage of decorator magic (eg. functools.wrap), # this never actually calls an IndexerStorage method. backend_class = attributes.get('backend_class', None) for base in bases: if backend_class is not None: break backend_class = getattr(base, 'backend_class', None) if backend_class: for (meth_name, meth) in backend_class.__dict__.items(): if hasattr(meth, '_endpoint_path'): cls.__add_endpoint(meth_name, meth, attributes) return super().__new__(cls, name, bases, attributes) @staticmethod def __add_endpoint(meth_name, meth, attributes): wrapped_meth = inspect.unwrap(meth) @functools.wraps(meth) # Copy signature and doc def meth_(*args, **kwargs): # Match arguments and parameters post_data = inspect.getcallargs( wrapped_meth, *args, **kwargs) # Remove arguments that should not be passed self = post_data.pop('self') post_data.pop('cur', None) post_data.pop('db', None) # Send the request. return self.post(meth._endpoint_path, post_data) attributes[meth_name] = meth_ class RPCClient(metaclass=MetaRPCClient): """Proxy to an internal SWH RPC """ backend_class = None # type: ClassVar[Optional[type]] """For each method of `backend_class` decorated with :func:`remote_api_endpoint`, a method with the same prototype and docstring will be added to this class. Calls to this new method will be translated into HTTP requests to a remote server. This backend class will never be instantiated, it only serves as a template.""" api_exception = APIError # type: ClassVar[Type[Exception]] """The exception class to raise in case of communication error with the server.""" reraise_exceptions: ClassVar[List[Type[Exception]]] = [] """On server errors, if any of the exception classes in this list has the same name as the error name, then the exception will be instantiated and raised instead of a generic RemoteException.""" extra_type_encoders: List[Tuple[type, str, Callable]] = [] """Value of `extra_encoders` passed to `json_dumps` or `msgpack_dumps` to be able to serialize more object types.""" extra_type_decoders: Dict[str, Callable] = {} """Value of `extra_decoders` passed to `json_loads` or `msgpack_loads` to be able to deserialize more object types.""" def __init__(self, url, api_exception=None, timeout=None, chunk_size=4096, reraise_exceptions=None, **kwargs): if api_exception: self.api_exception = api_exception if reraise_exceptions: self.reraise_exceptions = reraise_exceptions base_url = url if url.endswith('/') else url + '/' self.url = base_url self.session = requests.Session() adapter = requests.adapters.HTTPAdapter( max_retries=kwargs.get('max_retries', 3), pool_connections=kwargs.get('pool_connections', 20), pool_maxsize=kwargs.get('pool_maxsize', 100)) self.session.mount(self.url, adapter) self.timeout = timeout self.chunk_size = chunk_size def _url(self, endpoint): return '%s%s' % (self.url, endpoint) def raw_verb(self, verb, endpoint, **opts): if 'chunk_size' in opts: # if the chunk_size argument has been passed, consider the user # also wants stream=True, otherwise, what's the point. opts['stream'] = True if self.timeout and 'timeout' not in opts: opts['timeout'] = self.timeout try: return getattr(self.session, verb)( self._url(endpoint), **opts ) except requests.exceptions.ConnectionError as e: raise self.api_exception(e) def post(self, endpoint, data, **opts): if isinstance(data, (abc.Iterator, abc.Generator)): data = (self._encode_data(x) for x in data) else: data = self._encode_data(data) chunk_size = opts.pop('chunk_size', self.chunk_size) response = self.raw_verb( 'post', endpoint, data=data, headers={'content-type': 'application/x-msgpack', 'accept': 'application/x-msgpack'}, **opts) if opts.get('stream') or \ response.headers.get('transfer-encoding') == 'chunked': self.raise_for_status(response) return response.iter_content(chunk_size) else: return self._decode_response(response) def _encode_data(self, data): return encode_data(data, extra_encoders=self.extra_type_encoders) post_stream = post def get(self, endpoint, **opts): chunk_size = opts.pop('chunk_size', self.chunk_size) response = self.raw_verb( 'get', endpoint, headers={'accept': 'application/x-msgpack'}, **opts) if opts.get('stream') or \ response.headers.get('transfer-encoding') == 'chunked': self.raise_for_status(response) return response.iter_content(chunk_size) else: return self._decode_response(response) def get_stream(self, endpoint, **opts): return self.get(endpoint, stream=True, **opts) def raise_for_status(self, response) -> None: """check response HTTP status code and raise an exception if it denotes an error; do nothing otherwise """ status_code = response.status_code status_class = response.status_code // 100 if status_code == 404: raise RemoteException(payload='404 not found', response=response) exception = None # TODO: only old servers send pickled error; stop trying to unpickle # after they are all upgraded try: if status_class == 4: data = self._decode_response(response, check_status=False) if isinstance(data, dict): for exc_type in self.reraise_exceptions: if exc_type.__name__ == data['exception']['type']: exception = exc_type(*data['exception']['args']) break else: exception = RemoteException(payload=data['exception'], response=response) else: exception = pickle.loads(data) elif status_class == 5: data = self._decode_response(response, check_status=False) if 'exception_pickled' in data: exception = pickle.loads(data['exception_pickled']) else: exception = RemoteException(payload=data['exception'], response=response) except (TypeError, pickle.UnpicklingError): raise RemoteException(payload=data, response=response) if exception: raise exception from None if status_class != 2: raise RemoteException( payload=f'API HTTP error: {status_code} {response.content}', response=response) def _decode_response(self, response, check_status=True): if check_status: self.raise_for_status(response) return decode_response( response, extra_decoders=self.extra_type_decoders) def __repr__(self): return '<{} url={}>'.format(self.__class__.__name__, self.url) class BytesRequest(Request): """Request with proper escaping of arbitrary byte sequences.""" encoding = 'utf-8' encoding_errors = 'surrogateescape' ENCODERS: Dict[str, Callable[[Any], Union[bytes, str]]] = { 'application/x-msgpack': msgpack_dumps, 'application/json': json_dumps, } def encode_data_server( data, content_type='application/x-msgpack', extra_type_encoders=None): encoded_data = ENCODERS[content_type]( data, extra_encoders=extra_type_encoders) return Response( encoded_data, mimetype=content_type, ) def decode_request(request, extra_decoders=None): content_type = request.mimetype data = request.get_data() if not data: return {} if content_type == 'application/x-msgpack': r = msgpack_loads(data, extra_decoders=extra_decoders) elif content_type == 'application/json': # XXX this .decode() is needed for py35. # Should not be needed any more with py37 r = json_loads(data.decode('utf-8'), extra_decoders=extra_decoders) else: raise ValueError('Wrong content type `%s` for API request' % content_type) return r def error_handler(exception, encoder, status_code=500): logging.exception(exception) - tb = traceback.format_exception(None, exception, exception.__traceback__) - error = { - 'exception': { - 'type': type(exception).__name__, - 'args': exception.args, - 'message': str(exception), - 'traceback': tb, - } - } - response = encoder(error) + response = encoder(exception_to_dict(exception)) if isinstance(exception, HTTPException): response.status_code = exception.code else: # TODO: differentiate between server errors and client errors response.status_code = status_code return response class RPCServerApp(Flask): """For each endpoint of the given `backend_class`, tells app.route to call a function that decodes the request and sends it to the backend object provided by the factory. :param Any backend_class: The class of the backend, which will be analyzed to look for API endpoints. :param Optional[Callable[[], backend_class]] backend_factory: A function with no argument that returns an instance of `backend_class`. If unset, defaults to calling `backend_class` constructor directly. """ request_class = BytesRequest extra_type_encoders: List[Tuple[type, str, Callable]] = [] """Value of `extra_encoders` passed to `json_dumps` or `msgpack_dumps` to be able to serialize more object types.""" extra_type_decoders: Dict[str, Callable] = {} """Value of `extra_decoders` passed to `json_loads` or `msgpack_loads` to be able to deserialize more object types.""" def __init__(self, *args, backend_class=None, backend_factory=None, **kwargs): super().__init__(*args, **kwargs) self.backend_class = backend_class if backend_class is not None: if backend_factory is None: backend_factory = backend_class for (meth_name, meth) in backend_class.__dict__.items(): if hasattr(meth, '_endpoint_path'): self.__add_endpoint(meth_name, meth, backend_factory) def __add_endpoint(self, meth_name, meth, backend_factory): from flask import request @self.route('/'+meth._endpoint_path, methods=['POST']) @negotiate(MsgpackFormatter, extra_encoders=self.extra_type_encoders) @negotiate(JSONFormatter, extra_encoders=self.extra_type_encoders) @functools.wraps(meth) # Copy signature and doc def _f(): # Call the actual code obj_meth = getattr(backend_factory(), meth_name) kw = decode_request( request, extra_decoders=self.extra_type_decoders) return obj_meth(**kw) diff --git a/swh/core/api/asynchronous.py b/swh/core/api/asynchronous.py index f38b913..1746915 100644 --- a/swh/core/api/asynchronous.py +++ b/swh/core/api/asynchronous.py @@ -1,88 +1,97 @@ -import json -import logging -import pickle -import sys -import traceback +# Copyright (C) 2017-2020 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + from collections import OrderedDict -import multidict +import logging +from typing import Tuple, Type import aiohttp.web from deprecated import deprecated +import multidict from .serializers import msgpack_dumps, msgpack_loads -from .serializers import SWHJSONDecoder, SWHJSONEncoder +from .serializers import json_dumps, json_loads +from .serializers import exception_to_dict from aiohttp_utils import negotiation, Response def encode_msgpack(data, **kwargs): return aiohttp.web.Response( body=msgpack_dumps(data), headers=multidict.MultiDict( {'Content-Type': 'application/x-msgpack'}), **kwargs ) encode_data_server = Response def render_msgpack(request, data): return msgpack_dumps(data) def render_json(request, data): - return json.dumps(data, cls=SWHJSONEncoder) + return json_dumps(data) async def decode_request(request): content_type = request.headers.get('Content-Type').split(';')[0].strip() data = await request.read() if not data: return {} if content_type == 'application/x-msgpack': r = msgpack_loads(data) elif content_type == 'application/json': - r = json.loads(data.decode(), cls=SWHJSONDecoder) + r = json_loads(data) else: raise ValueError('Wrong content type `%s` for API request' % content_type) return r async def error_middleware(app, handler): async def middleware_handler(request): try: return await handler(request) except Exception as e: if isinstance(e, aiohttp.web.HTTPException): raise logging.exception(e) - exception = traceback.format_exception(*sys.exc_info()) - res = {'exception': exception, - 'exception_pickled': pickle.dumps(e)} - return encode_data_server(res, status=500) + res = exception_to_dict(e) + if isinstance(e, app.client_exception_classes): + status = 400 + else: + status = 500 + return encode_data_server(res, status=status) return middleware_handler class RPCServerApp(aiohttp.web.Application): + client_exception_classes: Tuple[Type[Exception], ...] = () + """Exceptions that should be handled as a client error (eg. object not + found, invalid argument)""" + def __init__(self, *args, middlewares=(), **kwargs): middlewares = (error_middleware,) + middlewares # renderers are sorted in order of increasing desirability (!) # see mimeparse.best_match() docstring. renderers = OrderedDict([ ('application/json', render_json), ('application/x-msgpack', render_msgpack), ]) nego_middleware = negotiation.negotiation_middleware( renderers=renderers, force_rendering=True) middlewares = (nego_middleware,) + middlewares super().__init__(*args, middlewares=middlewares, **kwargs) @deprecated(version='0.0.64', reason='Use the RPCServerApp instead') class SWHRemoteAPI(RPCServerApp): pass diff --git a/swh/core/api/serializers.py b/swh/core/api/serializers.py index 57f37ae..0b81dd5 100644 --- a/swh/core/api/serializers.py +++ b/swh/core/api/serializers.py @@ -1,214 +1,227 @@ -# Copyright (C) 2015-2018 The Software Heritage developers +# Copyright (C) 2015-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import base64 import datetime import json +import traceback import types from uuid import UUID import arrow import iso8601 import msgpack from typing import Any, Dict, Union, Tuple from requests import Response ENCODERS = [ (arrow.Arrow, 'arrow', arrow.Arrow.isoformat), (datetime.datetime, 'datetime', datetime.datetime.isoformat), (datetime.timedelta, 'timedelta', lambda o: { 'days': o.days, 'seconds': o.seconds, 'microseconds': o.microseconds, }), (UUID, 'uuid', str), # Only for JSON: (bytes, 'bytes', lambda o: base64.b85encode(o).decode('ascii')), ] DECODERS = { 'arrow': arrow.get, 'datetime': lambda d: iso8601.parse_date(d, default_timezone=None), 'timedelta': lambda d: datetime.timedelta(**d), 'uuid': UUID, # Only for JSON: 'bytes': base64.b85decode, } def encode_data_client(data: Any, extra_encoders=None) -> bytes: try: return msgpack_dumps(data, extra_encoders=extra_encoders) except OverflowError as e: raise ValueError('Limits were reached. Please, check your input.\n' + str(e)) def decode_response(response: Response, extra_decoders=None) -> Any: content_type = response.headers['content-type'] if content_type.startswith('application/x-msgpack'): r = msgpack_loads(response.content, extra_decoders=extra_decoders) elif content_type.startswith('application/json'): r = json_loads(response.text, extra_decoders=extra_decoders) elif content_type.startswith('text/'): r = response.text else: raise ValueError('Wrong content type `%s` for API response' % content_type) return r class SWHJSONEncoder(json.JSONEncoder): """JSON encoder for data structures generated by Software Heritage. This JSON encoder extends the default Python JSON encoder and adds awareness for the following specific types: - bytes (get encoded as a Base85 string); - datetime.datetime (get encoded as an ISO8601 string). Non-standard types get encoded as a a dictionary with two keys: - swhtype with value 'bytes' or 'datetime'; - d containing the encoded value. SWHJSONEncoder also encodes arbitrary iterables as a list (allowing serialization of generators). Caveats: Limitations in the JSONEncoder extension mechanism prevent us from "escaping" dictionaries that only contain the swhtype and d keys, and therefore arbitrary data structures can't be round-tripped through SWHJSONEncoder and SWHJSONDecoder. """ def __init__(self, extra_encoders=None, **kwargs): super().__init__(**kwargs) self.encoders = ENCODERS if extra_encoders: self.encoders += extra_encoders def default(self, o: Any ) -> Union[Dict[str, Union[Dict[str, int], str]], list]: for (type_, type_name, encoder) in self.encoders: if isinstance(o, type_): return { 'swhtype': type_name, 'd': encoder(o), } try: return super().default(o) except TypeError as e: try: iterable = iter(o) except TypeError: raise e from None else: return list(iterable) class SWHJSONDecoder(json.JSONDecoder): """JSON decoder for data structures encoded with SWHJSONEncoder. This JSON decoder extends the default Python JSON decoder, allowing the decoding of: - bytes (encoded as a Base85 string); - datetime.datetime (encoded as an ISO8601 string). Non-standard types must be encoded as a a dictionary with exactly two keys: - swhtype with value 'bytes' or 'datetime'; - d containing the encoded value. To limit the impact our encoding, if the swhtype key doesn't contain a known value, the dictionary is decoded as-is. """ def __init__(self, extra_decoders=None, **kwargs): super().__init__(**kwargs) self.decoders = DECODERS if extra_decoders: self.decoders = {**self.decoders, **extra_decoders} def decode_data(self, o: Any) -> Any: if isinstance(o, dict): if set(o.keys()) == {'d', 'swhtype'}: if o['swhtype'] == 'bytes': return base64.b85decode(o['d']) decoder = self.decoders.get(o['swhtype']) if decoder: return decoder(self.decode_data(o['d'])) return {key: self.decode_data(value) for key, value in o.items()} if isinstance(o, list): return [self.decode_data(value) for value in o] else: return o def raw_decode(self, s: str, idx: int = 0) -> Tuple[Any, int]: data, index = super().raw_decode(s, idx) return self.decode_data(data), index def json_dumps(data: Any, extra_encoders=None) -> str: return json.dumps(data, cls=SWHJSONEncoder, extra_encoders=extra_encoders) def json_loads(data: str, extra_decoders=None) -> Any: return json.loads(data, cls=SWHJSONDecoder, extra_decoders=extra_decoders) def msgpack_dumps(data: Any, extra_encoders=None) -> bytes: """Write data as a msgpack stream""" encoders = ENCODERS if extra_encoders: encoders += extra_encoders def encode_types(obj): if isinstance(obj, types.GeneratorType): return list(obj) for (type_, type_name, encoder) in encoders: if isinstance(obj, type_): return { b'swhtype': type_name, b'd': encoder(obj), } return obj return msgpack.packb(data, use_bin_type=True, default=encode_types) def msgpack_loads(data: bytes, extra_decoders=None) -> Any: """Read data as a msgpack stream""" decoders = DECODERS if extra_decoders: decoders = {**decoders, **extra_decoders} def decode_types(obj): if set(obj.keys()) == {b'd', b'swhtype'}: decoder = decoders.get(obj[b'swhtype']) if decoder: return decoder(obj[b'd']) return obj try: return msgpack.unpackb(data, raw=False, object_hook=decode_types) except TypeError: # msgpack < 0.5.2 return msgpack.unpackb(data, encoding='utf-8', object_hook=decode_types) + + +def exception_to_dict(exception): + tb = traceback.format_exception(None, exception, exception.__traceback__) + return { + 'exception': { + 'type': type(exception).__name__, + 'args': exception.args, + 'message': str(exception), + 'traceback': tb, + } + } diff --git a/swh/core/api/tests/test_async.py b/swh/core/api/tests/test_async.py index 5086b59..11d03d6 100644 --- a/swh/core/api/tests/test_async.py +++ b/swh/core/api/tests/test_async.py @@ -1,186 +1,223 @@ -# Copyright (C) 2019 The Software Heritage developers +# Copyright (C) 2019-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import datetime import msgpack import json import pytest from swh.core.api.asynchronous import RPCServerApp, Response from swh.core.api.asynchronous import encode_msgpack, decode_request from swh.core.api.serializers import msgpack_dumps, SWHJSONEncoder pytest_plugins = ['aiohttp.pytest_plugin', 'pytester'] +class TestServerException(Exception): + pass + + +class TestClientError(Exception): + pass + + async def root(request): return Response('toor') STRUCT = {'txt': 'something stupid', # 'date': datetime.date(2019, 6, 9), # not supported 'datetime': datetime.datetime(2019, 6, 9, 10, 12), 'timedelta': datetime.timedelta(days=-2, hours=3), 'int': 42, 'float': 3.14, 'subdata': {'int': 42, 'datetime': datetime.datetime(2019, 6, 10, 11, 12), }, 'list': [42, datetime.datetime(2019, 9, 10, 11, 12), 'ok'], } async def struct(request): return Response(STRUCT) async def echo(request): data = await decode_request(request) return Response(data) +async def server_exception(request): + raise TestServerException() + + +async def client_error(request): + raise TestClientError() + + async def echo_no_nego(request): # let the content negotiation handle the serialization for us... data = await decode_request(request) ret = encode_msgpack(data) return ret def check_mimetype(src, dst): src = src.split(';')[0].strip() dst = dst.split(';')[0].strip() assert src == dst @pytest.fixture def async_app(): app = RPCServerApp() + app.client_exception_classes = (TestClientError,) app.router.add_route('GET', '/', root) app.router.add_route('GET', '/struct', struct) app.router.add_route('POST', '/echo', echo) + app.router.add_route('GET', '/server_exception', server_exception) + app.router.add_route('GET', '/client_error', client_error) app.router.add_route('POST', '/echo-no-nego', echo_no_nego) return app async def test_get_simple(async_app, aiohttp_client) -> None: assert async_app is not None cli = await aiohttp_client(async_app) resp = await cli.get('/') assert resp.status == 200 check_mimetype(resp.headers['Content-Type'], 'application/x-msgpack') data = await resp.read() value = msgpack.unpackb(data, raw=False) assert value == 'toor' +async def test_get_server_exception(async_app, aiohttp_client) -> None: + cli = await aiohttp_client(async_app) + resp = await cli.get('/server_exception') + assert resp.status == 500 + data = await resp.read() + data = msgpack.unpackb(data) + assert data[b'exception'][b'type'] == b'TestServerException' + + +async def test_get_client_error(async_app, aiohttp_client) -> None: + cli = await aiohttp_client(async_app) + resp = await cli.get('/client_error') + assert resp.status == 400 + data = await resp.read() + data = msgpack.unpackb(data) + assert data[b'exception'][b'type'] == b'TestClientError' + + async def test_get_simple_nego(async_app, aiohttp_client) -> None: cli = await aiohttp_client(async_app) for ctype in ('x-msgpack', 'json'): resp = await cli.get('/', headers={'Accept': 'application/%s' % ctype}) assert resp.status == 200 check_mimetype(resp.headers['Content-Type'], 'application/%s' % ctype) assert (await decode_request(resp)) == 'toor' async def test_get_struct(async_app, aiohttp_client) -> None: """Test returned structured from a simple GET data is OK""" cli = await aiohttp_client(async_app) resp = await cli.get('/struct') assert resp.status == 200 check_mimetype(resp.headers['Content-Type'], 'application/x-msgpack') assert (await decode_request(resp)) == STRUCT async def test_get_struct_nego(async_app, aiohttp_client) -> None: """Test returned structured from a simple GET data is OK""" cli = await aiohttp_client(async_app) for ctype in ('x-msgpack', 'json'): resp = await cli.get('/struct', headers={'Accept': 'application/%s' % ctype}) assert resp.status == 200 check_mimetype(resp.headers['Content-Type'], 'application/%s' % ctype) assert (await decode_request(resp)) == STRUCT async def test_post_struct_msgpack(async_app, aiohttp_client) -> None: """Test that msgpack encoded posted struct data is returned as is""" cli = await aiohttp_client(async_app) # simple struct resp = await cli.post( '/echo', headers={'Content-Type': 'application/x-msgpack'}, data=msgpack_dumps({'toto': 42})) assert resp.status == 200 check_mimetype(resp.headers['Content-Type'], 'application/x-msgpack') assert (await decode_request(resp)) == {'toto': 42} # complex struct resp = await cli.post( '/echo', headers={'Content-Type': 'application/x-msgpack'}, data=msgpack_dumps(STRUCT)) assert resp.status == 200 check_mimetype(resp.headers['Content-Type'], 'application/x-msgpack') assert (await decode_request(resp)) == STRUCT async def test_post_struct_json(async_app, aiohttp_client) -> None: """Test that json encoded posted struct data is returned as is""" cli = await aiohttp_client(async_app) resp = await cli.post( '/echo', headers={'Content-Type': 'application/json'}, data=json.dumps({'toto': 42}, cls=SWHJSONEncoder)) assert resp.status == 200 check_mimetype(resp.headers['Content-Type'], 'application/x-msgpack') assert (await decode_request(resp)) == {'toto': 42} resp = await cli.post( '/echo', headers={'Content-Type': 'application/json'}, data=json.dumps(STRUCT, cls=SWHJSONEncoder)) assert resp.status == 200 check_mimetype(resp.headers['Content-Type'], 'application/x-msgpack') # assert resp.headers['Content-Type'] == 'application/x-msgpack' assert (await decode_request(resp)) == STRUCT async def test_post_struct_nego(async_app, aiohttp_client) -> None: """Test that json encoded posted struct data is returned as is using content negotiation (accept json or msgpack). """ cli = await aiohttp_client(async_app) for ctype in ('x-msgpack', 'json'): resp = await cli.post( '/echo', headers={'Content-Type': 'application/json', 'Accept': 'application/%s' % ctype}, data=json.dumps(STRUCT, cls=SWHJSONEncoder)) assert resp.status == 200 check_mimetype(resp.headers['Content-Type'], 'application/%s' % ctype) assert (await decode_request(resp)) == STRUCT async def test_post_struct_no_nego(async_app, aiohttp_client) -> None: """Test that json encoded posted struct data is returned as msgpack when using non-negotiation-compatible handlers. """ cli = await aiohttp_client(async_app) for ctype in ('x-msgpack', 'json'): resp = await cli.post( '/echo-no-nego', headers={'Content-Type': 'application/json', 'Accept': 'application/%s' % ctype}, data=json.dumps(STRUCT, cls=SWHJSONEncoder)) assert resp.status == 200 check_mimetype(resp.headers['Content-Type'], 'application/x-msgpack') assert (await decode_request(resp)) == STRUCT