diff --git a/swh/vault/api/server.py b/swh/vault/api/server.py --- a/swh/vault/api/server.py +++ b/swh/vault/api/server.py @@ -1,17 +1,16 @@ -# Copyright (C) 2016-2020 The Software Heritage developers +# Copyright (C) 2016-2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from __future__ import annotations -import asyncio import os from typing import Any, Dict, Optional -import aiohttp.web - -from swh.core.api.asynchronous import RPCServerApp +from swh.core.api import RPCServerApp +from swh.core.api import encode_data_server as encode_data +from swh.core.api import error_handler from swh.core.config import config_basepath, merge_configs, read_raw_config from swh.vault import get_vault as get_swhvault from swh.vault.backend import NotFoundExc @@ -24,15 +23,12 @@ "client_max_size": 1024 ** 3, } -vault = None -app = None - -def get_vault(config: Optional[Dict[str, Any]] = None) -> VaultInterface: +def get_vault(): global vault if not vault: - assert config is not None - vault = get_swhvault(**config) + vault = get_swhvault(**app.config["vault"]) + return vault @@ -42,9 +38,23 @@ extra_type_encoders = ENCODERS -@asyncio.coroutine -def index(request): - return aiohttp.web.Response(body="SWH Vault API server") +vault = None +app = VaultServerApp(__name__, backend_class=VaultInterface, backend_factory=get_vault,) + + +@app.errorhandler(NotFoundExc) +def argument_error_handler(exception): + return error_handler(exception, encode_data, status_code=400) + + +@app.errorhandler(Exception) +def my_error_handler(exception): + return error_handler(exception, encode_data) + + +@app.route("/") +def index(): + return "SWH Vault API server" def check_config(cfg: Dict[str, Any]) -> Dict[str, Any]: @@ -83,21 +93,6 @@ return vcfg -def make_app(config: Dict[str, Any]) -> VaultServerApp: - """Ensure the configuration is ok, then instantiate the server application - - """ - config = check_config(config) - app = VaultServerApp( - __name__, - backend_class=VaultInterface, - backend_factory=lambda: get_vault(config), - client_max_size=config["client_max_size"], - ) - app.router.add_route("GET", "/", index) - return app - - def make_app_from_configfile( config_path: Optional[str] = None, **kwargs ) -> VaultServerApp: @@ -105,17 +100,14 @@ application. """ - global app - if not app: - config_path = os.environ.get("SWH_CONFIG_FILENAME", config_path) - if not config_path: - raise ValueError("Missing configuration path.") - if not os.path.isfile(config_path): - raise ValueError(f"Configuration path {config_path} should exist.") - - app_config = read_raw_config(config_basepath(config_path)) - app_config = merge_configs(DEFAULT_CONFIG, app_config) - app = make_app(app_config) + config_path = os.environ.get("SWH_CONFIG_FILENAME", config_path) + if not config_path: + raise ValueError("Missing configuration path.") + if not os.path.isfile(config_path): + raise ValueError(f"Configuration path {config_path} should exist.") + + app_config = read_raw_config(config_basepath(config_path)) + app.config.update(merge_configs(DEFAULT_CONFIG, app_config)) return app diff --git a/swh/vault/tests/test_server.py b/swh/vault/tests/test_server.py --- a/swh/vault/tests/test_server.py +++ b/swh/vault/tests/test_server.py @@ -1,11 +1,10 @@ -# Copyright (C) 2020 The Software Heritage developers +# Copyright (C) 2020-2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import copy import os -from typing import Any, Dict import pytest import yaml @@ -13,12 +12,7 @@ from swh.core.api.serializers import json_dumps, msgpack_dumps, msgpack_loads from swh.vault.api.serializers import ENCODERS import swh.vault.api.server -from swh.vault.api.server import ( - VaultServerApp, - check_config, - make_app, - make_app_from_configfile, -) +from swh.vault.api.server import app, check_config, make_app_from_configfile from swh.vault.tests.test_backend import TEST_SWHID @@ -43,6 +37,11 @@ """ app = make_app_from_configfile() assert app is not None + assert "vault" in app.config + + # Cleanup app + del app.config["vault"] + swh.vault.api.server.vault = None def test_make_app_from_file(swh_local_vault_config, tmp_path): @@ -55,79 +54,84 @@ app = make_app_from_configfile(conf_path) assert app is not None + assert "vault" in app.config + + # Cleanup app + del app.config["vault"] + swh.vault.api.server.vault = None @pytest.fixture -def async_app(swh_local_vault_config: Dict[str, Any],) -> VaultServerApp: - """Instantiate the vault server application. +def vault_app(swh_local_vault_config): + # Set app config + app.config["vault"] = swh_local_vault_config["vault"] - Note: This requires the db setup to run (fixture swh_vault in charge of this) + yield app - """ - # make sure a new VaultBackend is instantiated for each test to prevent - # side effects between tests + # Cleanup app + del app.config["vault"] swh.vault.api.server.vault = None - return make_app(swh_local_vault_config) @pytest.fixture -def cli(async_app, aiohttp_client, loop): - return loop.run_until_complete(aiohttp_client(async_app)) +def cli(vault_app): + cli = vault_app.test_client() + return cli -async def test_client_index(cli): - resp = await cli.get("/") - assert resp.status == 200 +def test_client_index(cli): + resp = cli.get("/") + assert resp.status == "200 OK" -async def test_client_cook_notfound(cli): - resp = await cli.post( +def test_client_cook_notfound(cli): + resp = cli.post( "/cook", data=json_dumps( {"bundle_type": "flat", "swhid": TEST_SWHID}, extra_encoders=ENCODERS ), headers=[("Content-Type", "application/json")], ) - assert resp.status == 400 - content = msgpack_loads(await resp.content.read()) + assert resp.status == "400 BAD REQUEST" + content = msgpack_loads(resp.data) assert content["type"] == "NotFoundExc" assert content["args"] == [f"flat {TEST_SWHID} was not found."] -async def test_client_progress_notfound(cli): - resp = await cli.post( +def test_client_progress_notfound(cli): + resp = cli.post( "/progress", data=json_dumps( {"bundle_type": "flat", "swhid": TEST_SWHID}, extra_encoders=ENCODERS ), headers=[("Content-Type", "application/json")], ) - assert resp.status == 400 - content = msgpack_loads(await resp.content.read()) + assert resp.status == "400 BAD REQUEST" + content = msgpack_loads(resp.data) assert content["type"] == "NotFoundExc" assert content["args"] == [f"flat {TEST_SWHID} was not found."] -async def test_client_batch_cook_invalid_type(cli): - resp = await cli.post( +def test_client_batch_cook_invalid_type(cli): + resp = cli.post( "/batch_cook", data=msgpack_dumps({"batch": [("foobar", [])]}), headers={"Content-Type": "application/x-msgpack"}, ) - assert resp.status == 400 - content = msgpack_loads(await resp.content.read()) + assert resp.status == "400 BAD REQUEST" + content = msgpack_loads(resp.data) assert content["type"] == "NotFoundExc" assert content["args"] == ["foobar is an unknown type."] -async def test_client_batch_progress_notfound(cli): - resp = await cli.post( +def test_client_batch_progress_notfound(cli): + resp = cli.post( "/batch_progress", data=msgpack_dumps({"batch_id": 1}), headers={"Content-Type": "application/x-msgpack"}, ) - assert resp.status == 400 - content = msgpack_loads(await resp.content.read()) + assert resp.status == "400 BAD REQUEST" + content = msgpack_loads(resp.data) assert content["type"] == "NotFoundExc" assert content["args"] == ["Batch 1 does not exist."]