diff --git a/swh/core/api/tests/test_rpc_client_server.py b/swh/core/api/tests/test_rpc_client_server.py
new file mode 100644
--- /dev/null
+++ b/swh/core/api/tests/test_rpc_client_server.py
@@ -0,0 +1,101 @@
+# Copyright (C) 2018-2019 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+import pytest
+
+import requests
+from swh.core.api import remote_api_endpoint, RPCServerApp, RPCClient
+from swh.core.api import error_handler, encode_data_server
+
+
+class RPCTest:
+    """Backend served by the test application."""
+
+    @remote_api_endpoint('endpoint_url')
+    def endpoint(self, test_data, db=None, cur=None):
+        assert test_data == 'spam'
+        return 'egg'
+
+    @remote_api_endpoint('path/to/endpoint')
+    def something(self, data, db=None, cur=None):
+        return data
+
+
+class RPCTest2:
+    """Backend declared to the client.
+
+    Same endpoints as RPCTest plus 'not_on_server', which the application
+    built from RPCTest does not serve.
+    """
+
+    @remote_api_endpoint('endpoint_url')
+    def endpoint(self, test_data, db=None, cur=None):
+        assert test_data == 'spam'
+        return 'egg'
+
+    @remote_api_endpoint('path/to/endpoint')
+    def something(self, data, db=None, cur=None):
+        return data
+
+    @remote_api_endpoint('not_on_server')
+    def not_on_server(self, db=None, cur=None):
+        return 'ok'
+
+
+class RPCTestClient(RPCClient):
+    """Client configured with RPCTest2 as its backend class."""
+
+    backend_class = RPCTest2
+
+
+@pytest.fixture
+def app():
+    """Application under test, serving the RPCTest backend."""
+    # needed by the 'swh_rpc_adapter' fixture (defined in pytest_plugin.py)
+    application = RPCServerApp('testapp', backend_class=RPCTest)
+
+    @application.errorhandler(Exception)
+    def my_error_handler(exception):
+        return error_handler(exception, encode_data_server)
+
+    return application
+
+
+@pytest.fixture
+def rpc_client(swh_rpc_adapter):
+    """RPCTestClient whose session is routed through swh_rpc_adapter."""
+    url = 'mock://example.com'
+    cli = RPCTestClient(url=url)
+    cli.session = requests.Session()
+    cli.session.mount(url, swh_rpc_adapter)
+    return cli
+
+
+def test_api_client_endpoint_missing(rpc_client):
+    """Methods absent from the client backend raise AttributeError."""
+    with pytest.raises(AttributeError):
+        rpc_client.missing(data='whatever')
+
+
+def test_api_server_endpoint_missing(rpc_client):
+    """An endpoint known to the client but not to the server yields a 404."""
+    with pytest.raises(Exception, match='404 Not Found'):
+        rpc_client.not_on_server()
+
+
+def test_api_endpoint_kwargs(rpc_client):
+    """Endpoints can be called with keyword arguments."""
+    res = rpc_client.something(data='whatever')
+    assert res == 'whatever'
+    res = rpc_client.endpoint(test_data='spam')
+    assert res == 'egg'
+
+
+def test_api_endpoint_args(rpc_client):
+    """Endpoints can be called with positional arguments."""
+    res = rpc_client.something('whatever')
+    assert res == 'whatever'
+    res = rpc_client.endpoint('spam')
+    assert res == 'egg'
diff --git a/swh/core/pytest_plugin.py b/swh/core/pytest_plugin.py
--- a/swh/core/pytest_plugin.py
+++ b/swh/core/pytest_plugin.py
@@ -6,12 +6,17 @@
 import logging
 import re
 import pytest
+import requests
 
 from functools import partial
 from os import path
 from typing import Dict, List, Optional
 from urllib.parse import urlparse
 
+from requests.adapters import BaseAdapter
+from requests.structures import CaseInsensitiveDict
+from requests.utils import get_encoding_from_headers
+
 logger = logging.getLogger(__name__)
 
 
@@ -178,6 +183,77 @@
     has_multi_visit=True)
 
 
+@pytest.fixture
+def swh_rpc_adapter(app):
+    """Fixture that generates a requests adapter (an RPCTestAdapter instance)
+    that can be used to test client/servers code based on swh.core.api
+    classes.
+
+    It depends on an 'app' fixture, which the test module must provide.
+
+    See swh/core/api/tests/test_rpc_client_server.py for an example of usage.
+    """
+    with app.test_client() as client:
+        yield RPCTestAdapter(client)
+
+
+class RPCTestAdapter(BaseAdapter):
+    """requests adapter that sends requests to an application test client
+    instead of the network.
+    """
+
+    def __init__(self, client):
+        super().__init__()
+        self._client = client
+
+    def build_response(self, req, resp):
+        """Build a requests.Response from the test client response 'resp'
+        obtained for the request 'req'.
+        """
+        response = requests.Response()
+
+        # No fallback: a response without status_code raises AttributeError.
+        response.status_code = resp.status_code
+
+        # Make headers case-insensitive.
+        response.headers = CaseInsensitiveDict(getattr(resp, 'headers', {}))
+
+        # Set encoding.
+        response.encoding = get_encoding_from_headers(response.headers)
+        response.raw = resp
+        response.reason = response.raw.status
+
+        if isinstance(req.url, bytes):
+            response.url = req.url.decode('utf-8')
+        else:
+            response.url = req.url
+
+        # Give the Response some context.
+        response.request = req
+        response.connection = self
+        response._content = resp.data
+
+        return response
+
+    def send(self, request, **kw):
+        """Replay 'request' on the test client and return the matching
+        requests.Response. Extra keyword arguments are ignored.
+        """
+        resp = self._client.open(
+            request.url, method=request.method,
+            headers=request.headers.items(),
+            data=request.body,
+        )
+        return self.build_response(request, resp)
+
+    def close(self):
+        """Do nothing: the adapter does not own the test client.
+
+        BaseAdapter.close() raises NotImplementedError, which would make
+        Session.close() fail for any session this adapter is mounted on.
+        """
+
+
 @pytest.yield_fixture
 def flask_app_client(app):
     with app.test_client() as client: