diff --git a/requirements-test.txt b/requirements-test.txt --- a/requirements-test.txt +++ b/requirements-test.txt @@ -5,3 +5,4 @@ swh.loader.core swh.loader.git >= 0.0.52 swh.storage[testing] +pytest-mock diff --git a/swh/vault/__init__.py b/swh/vault/__init__.py --- a/swh/vault/__init__.py +++ b/swh/vault/__init__.py @@ -1,21 +1,32 @@ -# Copyright (C) 2018 The Software Heritage developers +# Copyright (C) 2018-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU Affero General Public License version 3, or any later version # See top-level LICENSE file for more information + +from __future__ import annotations + +import importlib import logging +from typing import Dict +import warnings logger = logging.getLogger(__name__) -def get_vault(cls="remote", args={}): +BACKEND_TYPES: Dict[str, str] = { + "remote": ".api.client.RemoteVaultClient", + "local": ".backend.VaultBackend", +} + + +def get_vault(cls: str = "remote", **kwargs): """ Get a vault object of class `vault_class` with arguments `vault_args`. Args: - vault (dict): dictionary with keys: - - cls (str): vault's class, either 'remote' - - args (dict): dictionary with keys + cls: vault's class, either 'remote' or 'local' + kwargs: arguments to pass to the class' constructor Returns: an instance of VaultBackend (either local or remote) @@ -24,18 +35,20 @@ ValueError if passed an unknown storage class. """ - if cls == "remote": - from .api.client import RemoteVaultClient as Vault - elif cls == "local": - from swh.scheduler import get_scheduler - from swh.storage import get_storage - from swh.vault.backend import VaultBackend as Vault - from swh.vault.cache import VaultCache - - args["cache"] = VaultCache(**args["cache"]) - args["storage"] = get_storage(**args["storage"]) - args["scheduler"] = get_scheduler(**args["scheduler"]) - else: - raise ValueError("Unknown storage class `%s`" % cls) - logger.debug("Instantiating %s with %s" % (Vault, args)) - return Vault(**args) + if "args" in kwargs: + warnings.warn( + 'Explicit "args" key is deprecated, use keys directly instead.', + DeprecationWarning, + ) + kwargs = kwargs["args"] + + class_path = BACKEND_TYPES.get(cls) + if class_path is None: + raise ValueError( + f"Unknown Vault class `{cls}`. " f"Supported: {', '.join(BACKEND_TYPES)}" + ) + + (module_path, class_name) = class_path.rsplit(".", 1) + module = importlib.import_module(module_path, package=__package__) + Vault = getattr(module, class_name) + return Vault(**kwargs) diff --git a/swh/vault/api/server.py b/swh/vault/api/server.py --- a/swh/vault/api/server.py +++ b/swh/vault/api/server.py @@ -214,7 +214,7 @@ if not args.get(key): raise ValueError("invalid configuration; missing %s config entry." % key) - return get_vault("local", args) + return get_vault("local", **args) def make_app_from_configfile(config_file=None, **kwargs): diff --git a/swh/vault/backend.py b/swh/vault/backend.py --- a/swh/vault/backend.py +++ b/swh/vault/backend.py @@ -1,4 +1,4 @@ -# Copyright (C) 2017-2018 The Software Heritage developers +# Copyright (C) 2017-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information @@ -12,7 +12,10 @@ from swh.core.db import BaseDb from swh.core.db.common import db_transaction from swh.model import hashutil +from swh.scheduler import get_scheduler from swh.scheduler.utils import create_oneshot_task_dict +from swh.storage import get_storage +from swh.vault.cache import VaultCache from swh.vault.cookers import get_cooker_cls from swh.vault.exc import NotFoundExc @@ -67,11 +70,11 @@ Backend for the Software Heritage vault. """ - def __init__(self, db, cache, scheduler, storage=None, **config): + def __init__(self, db, **config): self.config = config - self.cache = cache - self.scheduler = scheduler - self.storage = storage + self.cache = VaultCache(**config["cache"]) + self.scheduler = get_scheduler(**config["scheduler"]) + self.storage = get_storage(**config["storage"]) self.smtp_server = smtplib.SMTP() self._pool = psycopg2.pool.ThreadedConnectionPool( diff --git a/swh/vault/tests/conftest.py b/swh/vault/tests/conftest.py --- a/swh/vault/tests/conftest.py +++ b/swh/vault/tests/conftest.py @@ -1,6 +1,7 @@ import glob import os import subprocess +from typing import Any, Dict import pkg_resources.extern.packaging.version import pytest @@ -38,8 +39,28 @@ @pytest.fixture -def swh_vault(request, postgresql_proc, postgresql, postgresql2, tmp_path): +def swh_vault_config(postgresql, postgresql2, tmp_path) -> Dict[str, Any]: + tmp_path = str(tmp_path) + return { + "db": postgresql.dsn, + "storage": { + "cls": "local", + "db": postgresql2.dsn, + "objstorage": { + "cls": "pathslicing", + "args": {"root": tmp_path, "slicing": "0:1/1:5",}, + }, + }, + "cache": { + "cls": "pathslicing", + "args": {"root": tmp_path, "slicing": "0:1/1:5", "allow_delete": True,}, + }, + "scheduler": {"cls": "remote", "url": "http://swh-scheduler:5008",}, + } + +@pytest.fixture +def swh_vault(request, swh_vault_config, postgresql, postgresql2, tmp_path): for sql_dir, pg in ((SQL_DIR, postgresql), (STORAGE_SQL_DIR, postgresql2)): dump_files = os.path.join(sql_dir, "*.sql") all_dump_files = sorted(glob.glob(dump_files), key=sortkey) @@ -59,28 +80,7 @@ ] ) - vault_config = { - "db": db_url("tests", postgresql_proc), - "storage": { - "cls": "local", - "db": db_url("tests2", postgresql_proc), - "objstorage": { - "cls": "pathslicing", - "args": {"root": str(tmp_path), "slicing": "0:1/1:5",}, - }, - }, - "cache": { - "cls": "pathslicing", - "args": { - "root": str(tmp_path), - "slicing": "0:1/1:5", - "allow_delete": True, - }, - }, - "scheduler": {"cls": "remote", "url": "http://swh-scheduler:5008",}, - } - - return get_vault("local", vault_config) + return get_vault("local", **swh_vault_config) @pytest.fixture diff --git a/swh/vault/tests/test_init.py b/swh/vault/tests/test_init.py new file mode 100644 --- /dev/null +++ b/swh/vault/tests/test_init.py @@ -0,0 +1,55 @@ +# Copyright (C) 2020 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + +import pytest + +from swh.vault import get_vault +from swh.vault.api.client import RemoteVaultClient +from swh.vault.backend import VaultBackend + +SERVER_IMPLEMENTATIONS = [ + ("remote", RemoteVaultClient, {"url": "localhost"}), + ( + "local", + VaultBackend, + { + "db": "something", + "cache": {"cls": "memory", "args": {}}, + "storage": {"cls": "remote", "url": "mock://storage-url"}, + "scheduler": {"cls": "remote", "url": "mock://scheduler-url"}, + }, + ), +] + + +@pytest.fixture +def mock_psycopg2(mocker): + mocker.patch("swh.vault.backend.psycopg2.pool") + mocker.patch("swh.vault.backend.psycopg2.extras") + + +def test_init_get_vault_failure(): + with pytest.raises(ValueError, match="Unknown Vault class"): + get_vault("unknown-vault-storage") + + +@pytest.mark.parametrize("class_name,expected_class,kwargs", SERVER_IMPLEMENTATIONS) +def test_init_get_vault(class_name, expected_class, kwargs, mock_psycopg2): + concrete_vault = get_vault(class_name, **kwargs) + assert isinstance(concrete_vault, expected_class) + + +@pytest.mark.parametrize("class_name,expected_class,kwargs", SERVER_IMPLEMENTATIONS) +def test_init_get_vault_deprecation_warning( + class_name, expected_class, kwargs, mock_psycopg2 +): + with pytest.warns(DeprecationWarning): + concrete_vault = get_vault(class_name, args=kwargs) + assert isinstance(concrete_vault, expected_class) + + +def test_init_get_vault_ok(swh_vault_config): + concrete_vault = get_vault("local", **swh_vault_config) + assert isinstance(concrete_vault, VaultBackend)