diff --git a/requirements-test.txt b/requirements-test.txt
index 078a4e3..66b4544 100644
--- a/requirements-test.txt
+++ b/requirements-test.txt
@@ -1,7 +1,8 @@
 pytest
 pytest-aiohttp
 pytest-postgresql
 dulwich >= 0.18.7
 swh.loader.core
 swh.loader.git >= 0.0.52
 swh.storage[testing]
+pytest-mock
diff --git a/swh/vault/__init__.py b/swh/vault/__init__.py
index a39a171..db16ff9 100644
--- a/swh/vault/__init__.py
+++ b/swh/vault/__init__.py
@@ -1,41 +1,54 @@
-# Copyright (C) 2018 The Software Heritage developers
+# Copyright (C) 2018-2020 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU Affero General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
+from __future__ import annotations
+
+import importlib
 import logging
+from typing import Dict
+import warnings
 
 logger = logging.getLogger(__name__)
 
 
-def get_vault(cls="remote", args={}):
+BACKEND_TYPES: Dict[str, str] = {
+    "remote": ".api.client.RemoteVaultClient",
+    "local": ".backend.VaultBackend",
+}
+
+
+def get_vault(cls: str = "remote", **kwargs):
     """
     Get a vault object of class `vault_class` with arguments
     `vault_args`.
 
     Args:
-        vault (dict): dictionary with keys:
-        - cls (str): vault's class, either 'remote'
-        - args (dict): dictionary with keys
+        cls: vault's class, either 'remote' or 'local'
+        kwargs: arguments to pass to the class' constructor
 
     Returns:
         an instance of VaultBackend (either local or remote)
 
     Raises:
         ValueError if passed an unknown storage class.
     """
-    if cls == "remote":
-        from .api.client import RemoteVaultClient as Vault
-    elif cls == "local":
-        from swh.scheduler import get_scheduler
-        from swh.storage import get_storage
-        from swh.vault.backend import VaultBackend as Vault
-        from swh.vault.cache import VaultCache
-
-        args["cache"] = VaultCache(**args["cache"])
-        args["storage"] = get_storage(**args["storage"])
-        args["scheduler"] = get_scheduler(**args["scheduler"])
-    else:
-        raise ValueError("Unknown storage class `%s`" % cls)
-    logger.debug("Instantiating %s with %s" % (Vault, args))
-    return Vault(**args)
+    if "args" in kwargs:
+        warnings.warn(
+            'Explicit "args" key is deprecated, use keys directly instead.',
+            DeprecationWarning,
+        )
+        kwargs = kwargs["args"]
+
+    class_path = BACKEND_TYPES.get(cls)
+    if class_path is None:
+        raise ValueError(
" f"Supported: {', '.join(BACKEND_TYPES)}" + ) + + (module_path, class_name) = class_path.rsplit(".", 1) + module = importlib.import_module(module_path, package=__package__) + Vault = getattr(module, class_name) + return Vault(**kwargs) diff --git a/swh/vault/api/server.py b/swh/vault/api/server.py index 6c178e0..6440fc2 100644 --- a/swh/vault/api/server.py +++ b/swh/vault/api/server.py @@ -1,233 +1,233 @@ # Copyright (C) 2016-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import asyncio import collections import os import aiohttp.web from swh.core import config from swh.core.api.asynchronous import RPCServerApp, decode_request from swh.core.api.asynchronous import encode_data_server as encode_data from swh.model import hashutil from swh.vault import get_vault from swh.vault.backend import NotFoundExc from swh.vault.cookers import COOKER_TYPES DEFAULT_CONFIG_PATH = "vault/server" DEFAULT_CONFIG = { "storage": ("dict", {"cls": "remote", "args": {"url": "http://localhost:5002/",},}), "cache": ( "dict", { "cls": "pathslicing", "args": {"root": "/srv/softwareheritage/vault", "slicing": "0:1/1:5",}, }, ), "client_max_size": ("int", 1024 ** 3), "vault": ( "dict", {"cls": "local", "args": {"db": "dbname=softwareheritage-vault-dev",},}, ), "scheduler": ("dict", {"cls": "remote", "url": "http://localhost:5008/",},), } @asyncio.coroutine def index(request): return aiohttp.web.Response(body="SWH Vault API server") # Web API endpoints @asyncio.coroutine def vault_fetch(request): obj_type = request.match_info["type"] obj_id = request.match_info["id"] if not request.app["backend"].is_available(obj_type, obj_id): raise NotFoundExc(f"{obj_type} {obj_id} is not available.") return encode_data(request.app["backend"].fetch(obj_type, obj_id)) def user_info(task_info): return { "id": task_info["id"], "status": task_info["task_status"], "progress_message": task_info["progress_msg"], "obj_type": task_info["type"], "obj_id": hashutil.hash_to_hex(task_info["object_id"]), } @asyncio.coroutine def vault_cook(request): obj_type = request.match_info["type"] obj_id = request.match_info["id"] email = request.query.get("email") sticky = request.query.get("sticky") in ("true", "1") if obj_type not in COOKER_TYPES: raise NotFoundExc(f"{obj_type} is an unknown type.") info = request.app["backend"].cook_request( obj_type, obj_id, email=email, sticky=sticky ) # TODO: return 201 status (Created) once the api supports it return encode_data(user_info(info)) @asyncio.coroutine def vault_progress(request): obj_type = request.match_info["type"] obj_id = request.match_info["id"] info = request.app["backend"].task_info(obj_type, obj_id) if not info: raise NotFoundExc(f"{obj_type} {obj_id} was not found.") return encode_data(user_info(info)) # Cookers endpoints @asyncio.coroutine def set_progress(request): obj_type = request.match_info["type"] obj_id = request.match_info["id"] progress = yield from decode_request(request) request.app["backend"].set_progress(obj_type, obj_id, progress) return encode_data(True) # FIXME: success value? @asyncio.coroutine def set_status(request): obj_type = request.match_info["type"] obj_id = request.match_info["id"] status = yield from decode_request(request) request.app["backend"].set_status(obj_type, obj_id, status) return encode_data(True) # FIXME: success value? 
 
 
 @asyncio.coroutine
 def put_bundle(request):
     obj_type = request.match_info["type"]
     obj_id = request.match_info["id"]
 
     # TODO: handle streaming properly
     content = yield from decode_request(request)
     request.app["backend"].cache.add(obj_type, obj_id, content)
     return encode_data(True)  # FIXME: success value?
 
 
 @asyncio.coroutine
 def send_notif(request):
     obj_type = request.match_info["type"]
     obj_id = request.match_info["id"]
     request.app["backend"].send_all_notifications(obj_type, obj_id)
     return encode_data(True)  # FIXME: success value?
 
 
 # Batch endpoints
 
 
 @asyncio.coroutine
 def batch_cook(request):
     batch = yield from decode_request(request)
     for obj_type, obj_id in batch:
         if obj_type not in COOKER_TYPES:
             raise NotFoundExc(f"{obj_type} is an unknown type.")
     batch_id = request.app["backend"].batch_cook(batch)
     return encode_data({"id": batch_id})
 
 
 @asyncio.coroutine
 def batch_progress(request):
     batch_id = request.match_info["batch_id"]
     bundles = request.app["backend"].batch_info(batch_id)
     if not bundles:
         raise NotFoundExc(f"Batch {batch_id} does not exist.")
     bundles = [user_info(bundle) for bundle in bundles]
     counter = collections.Counter(b["status"] for b in bundles)
     res = {
         "bundles": bundles,
         "total": len(bundles),
         **{k: 0 for k in ("new", "pending", "done", "failed")},
         **dict(counter),
     }
     return encode_data(res)
 
 
 # Web server
 
 
 def make_app(backend, **kwargs):
     app = RPCServerApp(**kwargs)
     app.router.add_route("GET", "/", index)
     app.client_exception_classes = (NotFoundExc,)
 
     # Endpoints used by the web API
     app.router.add_route("GET", "/fetch/{type}/{id}", vault_fetch)
     app.router.add_route("POST", "/cook/{type}/{id}", vault_cook)
     app.router.add_route("GET", "/progress/{type}/{id}", vault_progress)
 
     # Endpoints used by the Cookers
     app.router.add_route("POST", "/set_progress/{type}/{id}", set_progress)
     app.router.add_route("POST", "/set_status/{type}/{id}", set_status)
     app.router.add_route("POST", "/put_bundle/{type}/{id}", put_bundle)
     app.router.add_route("POST", "/send_notif/{type}/{id}", send_notif)
 
     # Endpoints for batch requests
     app.router.add_route("POST", "/batch_cook", batch_cook)
     app.router.add_route("GET", "/batch_progress/{batch_id}", batch_progress)
 
     app["backend"] = backend
     return app
 
 
 def get_local_backend(cfg):
     if "vault" not in cfg:
         raise ValueError("missing 'vault' configuration")
 
     vcfg = cfg["vault"]
     if vcfg["cls"] != "local":
         raise EnvironmentError(
             "The vault backend can only be started with a 'local' configuration"
         )
 
     args = vcfg["args"]
     if "cache" not in args:
         args["cache"] = cfg.get("cache")
     if "storage" not in args:
         args["storage"] = cfg.get("storage")
     if "scheduler" not in args:
         args["scheduler"] = cfg.get("scheduler")
 
     for key in ("cache", "storage", "scheduler"):
         if not args.get(key):
             raise ValueError("invalid configuration; missing %s config entry." % key)
 
-    return get_vault("local", args)
+    return get_vault("local", **args)
 
 
 def make_app_from_configfile(config_file=None, **kwargs):
     if config_file is None:
         config_file = DEFAULT_CONFIG_PATH
     config_file = os.environ.get("SWH_CONFIG_FILENAME", config_file)
     if os.path.isfile(config_file):
         cfg = config.read(config_file, DEFAULT_CONFIG)
     else:
         cfg = config.load_named_config(config_file, DEFAULT_CONFIG)
     vault = get_local_backend(cfg)
     return make_app(backend=vault, client_max_size=cfg["client_max_size"], **kwargs)
 
 
 if __name__ == "__main__":
     print("Deprecated. Use swh-vault ")
diff --git a/swh/vault/backend.py b/swh/vault/backend.py
index 1974e9e..69d4690 100644
--- a/swh/vault/backend.py
+++ b/swh/vault/backend.py
@@ -1,487 +1,490 @@
-# Copyright (C) 2017-2018 The Software Heritage developers
+# Copyright (C) 2017-2020 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
 from email.mime.text import MIMEText
 import smtplib
 
 import psycopg2.extras
 import psycopg2.pool
 
 from swh.core.db import BaseDb
 from swh.core.db.common import db_transaction
 from swh.model import hashutil
+from swh.scheduler import get_scheduler
 from swh.scheduler.utils import create_oneshot_task_dict
+from swh.storage import get_storage
+from swh.vault.cache import VaultCache
 from swh.vault.cookers import get_cooker_cls
 from swh.vault.exc import NotFoundExc
 
 cooking_task_name = "swh.vault.cooking_tasks.SWHCookingTask"
 
 NOTIF_EMAIL_FROM = '"Software Heritage Vault" ' "<bot@softwareheritage.org>"
 NOTIF_EMAIL_SUBJECT_SUCCESS = "Bundle ready: {obj_type} {short_id}"
 NOTIF_EMAIL_SUBJECT_FAILURE = "Bundle failed: {obj_type} {short_id}"
 
 NOTIF_EMAIL_BODY_SUCCESS = """
 You have requested the following bundle from the Software Heritage
 Vault:
 
 Object Type: {obj_type}
 Object ID: {hex_id}
 
 This bundle is now available for download at the following address:
 
 {url}
 
 Please keep in mind that this link might expire at some point, in which
 case you will need to request the bundle again.
 
 --\x20
 The Software Heritage Developers
 """
 
 NOTIF_EMAIL_BODY_FAILURE = """
 You have requested the following bundle from the Software Heritage
 Vault:
 
 Object Type: {obj_type}
 Object ID: {hex_id}
 
 This bundle could not be cooked for the following reason:
 
 {progress_msg}
 
 We apologize for the inconvenience.
 
 --\x20
 The Software Heritage Developers
 """
 
 
 def batch_to_bytes(batch):
     return [(obj_type, hashutil.hash_to_bytes(obj_id)) for obj_type, obj_id in batch]
 
 
 class VaultBackend:
""" - def __init__(self, db, cache, scheduler, storage=None, **config): + def __init__(self, db, **config): self.config = config - self.cache = cache - self.scheduler = scheduler - self.storage = storage + self.cache = VaultCache(**config["cache"]) + self.scheduler = get_scheduler(**config["scheduler"]) + self.storage = get_storage(**config["storage"]) self.smtp_server = smtplib.SMTP() self._pool = psycopg2.pool.ThreadedConnectionPool( config.get("min_pool_conns", 1), config.get("max_pool_conns", 10), db, cursor_factory=psycopg2.extras.RealDictCursor, ) self._db = None def get_db(self): if self._db: return self._db return BaseDb.from_pool(self._pool) def put_db(self, db): if db is not self._db: db.put_conn() @db_transaction() def task_info(self, obj_type, obj_id, db=None, cur=None): """Fetch information from a bundle""" obj_id = hashutil.hash_to_bytes(obj_id) cur.execute( """ SELECT id, type, object_id, task_id, task_status, sticky, ts_created, ts_done, ts_last_access, progress_msg FROM vault_bundle WHERE type = %s AND object_id = %s""", (obj_type, obj_id), ) res = cur.fetchone() if res: res["object_id"] = bytes(res["object_id"]) return res def _send_task(self, *args): """Send a cooking task to the celery scheduler""" task = create_oneshot_task_dict("cook-vault-bundle", *args) added_tasks = self.scheduler.create_tasks([task]) return added_tasks[0]["id"] @db_transaction() def create_task(self, obj_type, obj_id, sticky=False, db=None, cur=None): """Create and send a cooking task""" obj_id = hashutil.hash_to_bytes(obj_id) hex_id = hashutil.hash_to_hex(obj_id) cooker_class = get_cooker_cls(obj_type) cooker = cooker_class(obj_type, hex_id, backend=self, storage=self.storage) if not cooker.check_exists(): raise NotFoundExc("Object {} was not found.".format(hex_id)) cur.execute( """ INSERT INTO vault_bundle (type, object_id, sticky) VALUES (%s, %s, %s)""", (obj_type, obj_id, sticky), ) db.conn.commit() task_id = self._send_task(obj_type, hex_id) cur.execute( """ UPDATE vault_bundle SET task_id = %s WHERE type = %s AND object_id = %s""", (task_id, obj_type, obj_id), ) @db_transaction() def add_notif_email(self, obj_type, obj_id, email, db=None, cur=None): """Add an e-mail address to notify when a given bundle is ready""" obj_id = hashutil.hash_to_bytes(obj_id) cur.execute( """ INSERT INTO vault_notif_email (email, bundle_id) VALUES (%s, (SELECT id FROM vault_bundle WHERE type = %s AND object_id = %s))""", (email, obj_type, obj_id), ) @db_transaction() def cook_request( self, obj_type, obj_id, *, sticky=False, email=None, db=None, cur=None ): """Main entry point for cooking requests. This starts a cooking task if needed, and add the given e-mail to the notify list""" obj_id = hashutil.hash_to_bytes(obj_id) info = self.task_info(obj_type, obj_id) # If there's a failed bundle entry, delete it first. if info is not None and info["task_status"] == "failed": cur.execute( """DELETE FROM vault_bundle WHERE type = %s AND object_id = %s""", (obj_type, obj_id), ) db.conn.commit() info = None # If there's no bundle entry, create the task. 
         if info is None:
             self.create_task(obj_type, obj_id, sticky)
 
         if email is not None:
             # If the task is already done, send the email directly
             if info is not None and info["task_status"] == "done":
                 self.send_notification(
                     None, email, obj_type, obj_id, info["task_status"]
                 )
             # Else, add it to the notification queue
             else:
                 self.add_notif_email(obj_type, obj_id, email)
 
         info = self.task_info(obj_type, obj_id)
         return info
 
     @db_transaction()
     def batch_cook(self, batch, db=None, cur=None):
         """Cook a batch of bundles and return the cooking id."""
         # Import execute_values at runtime only, because it requires
         # psycopg2 >= 2.7 (only available on postgresql servers)
         from psycopg2.extras import execute_values
 
         cur.execute(
             """
             INSERT INTO vault_batch (id)
             VALUES (DEFAULT)
             RETURNING id"""
         )
         batch_id = cur.fetchone()["id"]
         batch = batch_to_bytes(batch)
 
         # Delete all failed bundles from the batch
         cur.execute(
             """
             DELETE FROM vault_bundle
             WHERE task_status = 'failed'
               AND (type, object_id) IN %s""",
             (tuple(batch),),
         )
 
         # Insert all the bundles, return the new ones
         execute_values(
             cur,
             """
             INSERT INTO vault_bundle (type, object_id)
             VALUES %s ON CONFLICT DO NOTHING""",
             batch,
         )
 
         # Get the bundle ids and task status
         cur.execute(
             """
             SELECT id, type, object_id, task_id FROM vault_bundle
             WHERE (type, object_id) IN %s""",
             (tuple(batch),),
         )
         bundles = cur.fetchall()
 
         # Insert the batch-bundle entries
         batch_id_bundle_ids = [(batch_id, row["id"]) for row in bundles]
         execute_values(
             cur,
             """
             INSERT INTO vault_batch_bundle (batch_id, bundle_id)
             VALUES %s ON CONFLICT DO NOTHING""",
             batch_id_bundle_ids,
         )
         db.conn.commit()
 
         # Get the tasks to fetch
         batch_new = [
             (row["type"], bytes(row["object_id"]))
             for row in bundles
             if row["task_id"] is None
         ]
 
         # Send the tasks
         args_batch = [
             (obj_type, hashutil.hash_to_hex(obj_id)) for obj_type, obj_id in batch_new
         ]
         # TODO: change once the scheduler handles priority tasks
         tasks = [
             create_oneshot_task_dict("swh-vault-batch-cooking", *args)
             for args in args_batch
         ]
 
         added_tasks = self.scheduler.create_tasks(tasks)
         tasks_ids_bundle_ids = zip([task["id"] for task in added_tasks], batch_new)
         tasks_ids_bundle_ids = [
             (task_id, obj_type, obj_id)
             for task_id, (obj_type, obj_id) in tasks_ids_bundle_ids
         ]
 
         # Update the task ids
         execute_values(
             cur,
             """
             UPDATE vault_bundle
             SET task_id = s_task_id
             FROM (VALUES %s) AS sub (s_task_id, s_type, s_object_id)
             WHERE type = s_type::cook_type AND object_id = s_object_id """,
             tasks_ids_bundle_ids,
         )
         return batch_id
 
     @db_transaction()
     def batch_info(self, batch_id, db=None, cur=None):
         """Fetch information from a batch of bundles"""
         cur.execute(
             """
             SELECT vault_bundle.id as id, type, object_id, task_id, task_status,
                    sticky, ts_created, ts_done, ts_last_access, progress_msg
             FROM vault_batch_bundle
             LEFT JOIN vault_bundle ON vault_bundle.id = bundle_id
             WHERE batch_id = %s""",
             (batch_id,),
         )
         res = cur.fetchall()
         if res:
             for d in res:
                 d["object_id"] = bytes(d["object_id"])
         return res
 
     @db_transaction()
     def is_available(self, obj_type, obj_id, db=None, cur=None):
         """Check whether a bundle is available for retrieval"""
         info = self.task_info(obj_type, obj_id, cur=cur)
         return (
             info is not None
             and info["task_status"] == "done"
             and self.cache.is_cached(obj_type, obj_id)
         )
 
     @db_transaction()
     def fetch(self, obj_type, obj_id, db=None, cur=None):
         """Retrieve a bundle from the cache"""
         if not self.is_available(obj_type, obj_id, cur=cur):
             return None
 
         self.update_access_ts(obj_type, obj_id, cur=cur)
         return self.cache.get(obj_type, obj_id)
 
     @db_transaction()
     def update_access_ts(self, obj_type, obj_id, db=None, cur=None):
         """Update the last access timestamp of a bundle"""
         obj_id = hashutil.hash_to_bytes(obj_id)
         cur.execute(
             """
             UPDATE vault_bundle
             SET ts_last_access = NOW()
             WHERE type = %s AND object_id = %s""",
             (obj_type, obj_id),
         )
 
     @db_transaction()
     def set_status(self, obj_type, obj_id, status, db=None, cur=None):
         """Set the cooking status of a bundle"""
         obj_id = hashutil.hash_to_bytes(obj_id)
         req = (
             """
             UPDATE vault_bundle
             SET task_status = %s """
             + (""", ts_done = NOW() """ if status == "done" else "")
             + """WHERE type = %s AND object_id = %s"""
         )
         cur.execute(req, (status, obj_type, obj_id))
 
     @db_transaction()
     def set_progress(self, obj_type, obj_id, progress, db=None, cur=None):
         """Set the cooking progress of a bundle"""
         obj_id = hashutil.hash_to_bytes(obj_id)
         cur.execute(
             """
             UPDATE vault_bundle
             SET progress_msg = %s
             WHERE type = %s AND object_id = %s""",
             (progress, obj_type, obj_id),
         )
 
     @db_transaction()
     def send_all_notifications(self, obj_type, obj_id, db=None, cur=None):
         """Send all the e-mails in the notification list of a bundle"""
         obj_id = hashutil.hash_to_bytes(obj_id)
         cur.execute(
             """
             SELECT vault_notif_email.id AS id, email, task_status, progress_msg
             FROM vault_notif_email
             INNER JOIN vault_bundle ON bundle_id = vault_bundle.id
             WHERE vault_bundle.type = %s AND vault_bundle.object_id = %s""",
             (obj_type, obj_id),
         )
         for d in cur:
             self.send_notification(
                 d["id"],
                 d["email"],
                 obj_type,
                 obj_id,
                 status=d["task_status"],
                 progress_msg=d["progress_msg"],
             )
 
     @db_transaction()
     def send_notification(
         self,
         n_id,
         email,
         obj_type,
         obj_id,
         status,
         progress_msg=None,
         db=None,
         cur=None,
     ):
         """Send the notification of a bundle to a specific e-mail"""
         hex_id = hashutil.hash_to_hex(obj_id)
         short_id = hex_id[:7]
 
         # TODO: instead of hardcoding this, we should probably:
         # * add a "fetch_url" field in the vault_notif_email table
         # * generate the url with flask.url_for() on the web-ui side
         # * send this url as part of the cook request and store it in
         #   the table
         # * use this url for the notification e-mail
         url = "https://archive.softwareheritage.org/api/1/vault/{}/{}/" "raw".format(
             obj_type, hex_id
         )
 
         if status == "done":
             text = NOTIF_EMAIL_BODY_SUCCESS.strip()
             text = text.format(obj_type=obj_type, hex_id=hex_id, url=url)
             msg = MIMEText(text)
             msg["Subject"] = NOTIF_EMAIL_SUBJECT_SUCCESS.format(
                 obj_type=obj_type, short_id=short_id
             )
         elif status == "failed":
             text = NOTIF_EMAIL_BODY_FAILURE.strip()
             text = text.format(
                 obj_type=obj_type, hex_id=hex_id, progress_msg=progress_msg
             )
             msg = MIMEText(text)
             msg["Subject"] = NOTIF_EMAIL_SUBJECT_FAILURE.format(
                 obj_type=obj_type, short_id=short_id
             )
         else:
             raise RuntimeError(
                 "send_notification called on a '{}' bundle".format(status)
             )
         msg["From"] = NOTIF_EMAIL_FROM
         msg["To"] = email
 
         self._smtp_send(msg)
 
         if n_id is not None:
             cur.execute(
                 """
                 DELETE FROM vault_notif_email
                 WHERE id = %s""",
                 (n_id,),
             )
 
     def _smtp_send(self, msg):
         # Reconnect if needed
         try:
             status = self.smtp_server.noop()[0]
         except smtplib.SMTPException:
             status = -1
         if status != 250:
             self.smtp_server.connect("localhost", 25)
 
         # Send the message
         self.smtp_server.send_message(msg)
 
     @db_transaction()
     def _cache_expire(self, cond, *args, db=None, cur=None):
         """Low-level expiration method, used by cache_expire_* methods"""
         # Embedded SELECT query to be able to use ORDER BY and LIMIT
         cur.execute(
             """
             DELETE FROM vault_bundle
             WHERE ctid IN (
                 SELECT ctid
                 FROM vault_bundle
                 WHERE sticky = false
                 {}
             )
             RETURNING type, object_id
             """.format(
                 cond
             ),
             args,
         )
 
         for d in cur:
bytes(d["object_id"])) @db_transaction() def cache_expire_oldest(self, n=1, by="last_access", db=None, cur=None): """Expire the `n` oldest bundles""" assert by in ("created", "done", "last_access") filter = """ORDER BY ts_{} LIMIT {}""".format(by, n) return self._cache_expire(filter) @db_transaction() def cache_expire_until(self, date, by="last_access", db=None, cur=None): """Expire all the bundles until a certain date""" assert by in ("created", "done", "last_access") filter = """AND ts_{} <= %s""".format(by) return self._cache_expire(filter, date) diff --git a/swh/vault/tests/conftest.py b/swh/vault/tests/conftest.py index 9090e46..163ca80 100644 --- a/swh/vault/tests/conftest.py +++ b/swh/vault/tests/conftest.py @@ -1,88 +1,88 @@ import glob import os import subprocess +from typing import Any, Dict import pkg_resources.extern.packaging.version import pytest from pytest_postgresql import factories from swh.core.utils import numfile_sortkey as sortkey from swh.storage.tests import SQL_DIR as STORAGE_SQL_DIR from swh.vault import get_vault from swh.vault.tests import SQL_DIR os.environ["LC_ALL"] = "C.UTF-8" pytest_v = pkg_resources.get_distribution("pytest").parsed_version if pytest_v < pkg_resources.extern.packaging.version.parse("3.9"): @pytest.fixture def tmp_path(request): import pathlib import tempfile with tempfile.TemporaryDirectory() as tmpdir: yield pathlib.Path(tmpdir) def db_url(name, postgresql_proc): return "postgresql://{user}@{host}:{port}/{dbname}".format( host=postgresql_proc.host, port=postgresql_proc.port, user="postgres", dbname=name, ) postgresql2 = factories.postgresql("postgresql_proc", "tests2") @pytest.fixture -def swh_vault(request, postgresql_proc, postgresql, postgresql2, tmp_path): +def swh_vault_config(postgresql, postgresql2, tmp_path) -> Dict[str, Any]: + tmp_path = str(tmp_path) + return { + "db": postgresql.dsn, + "storage": { + "cls": "local", + "db": postgresql2.dsn, + "objstorage": { + "cls": "pathslicing", + "args": {"root": tmp_path, "slicing": "0:1/1:5",}, + }, + }, + "cache": { + "cls": "pathslicing", + "args": {"root": tmp_path, "slicing": "0:1/1:5", "allow_delete": True,}, + }, + "scheduler": {"cls": "remote", "url": "http://swh-scheduler:5008",}, + } + +@pytest.fixture +def swh_vault(request, swh_vault_config, postgresql, postgresql2, tmp_path): for sql_dir, pg in ((SQL_DIR, postgresql), (STORAGE_SQL_DIR, postgresql2)): dump_files = os.path.join(sql_dir, "*.sql") all_dump_files = sorted(glob.glob(dump_files), key=sortkey) for fname in all_dump_files: subprocess.check_call( [ "psql", "--quiet", "--no-psqlrc", "-v", "ON_ERROR_STOP=1", "-d", pg.dsn, "-f", fname, ] ) - vault_config = { - "db": db_url("tests", postgresql_proc), - "storage": { - "cls": "local", - "db": db_url("tests2", postgresql_proc), - "objstorage": { - "cls": "pathslicing", - "args": {"root": str(tmp_path), "slicing": "0:1/1:5",}, - }, - }, - "cache": { - "cls": "pathslicing", - "args": { - "root": str(tmp_path), - "slicing": "0:1/1:5", - "allow_delete": True, - }, - }, - "scheduler": {"cls": "remote", "url": "http://swh-scheduler:5008",}, - } - - return get_vault("local", vault_config) + return get_vault("local", **swh_vault_config) @pytest.fixture def swh_storage(swh_vault): return swh_vault.storage diff --git a/swh/vault/tests/test_init.py b/swh/vault/tests/test_init.py new file mode 100644 index 0000000..7f402d6 --- /dev/null +++ b/swh/vault/tests/test_init.py @@ -0,0 +1,55 @@ +# Copyright (C) 2020 The Software Heritage developers +# See the AUTHORS file at the top-level 
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+import pytest
+
+from swh.vault import get_vault
+from swh.vault.api.client import RemoteVaultClient
+from swh.vault.backend import VaultBackend
+
+SERVER_IMPLEMENTATIONS = [
+    ("remote", RemoteVaultClient, {"url": "localhost"}),
+    (
+        "local",
+        VaultBackend,
+        {
+            "db": "something",
+            "cache": {"cls": "memory", "args": {}},
+            "storage": {"cls": "remote", "url": "mock://storage-url"},
+            "scheduler": {"cls": "remote", "url": "mock://scheduler-url"},
+        },
+    ),
+]
+
+
+@pytest.fixture
+def mock_psycopg2(mocker):
+    mocker.patch("swh.vault.backend.psycopg2.pool")
+    mocker.patch("swh.vault.backend.psycopg2.extras")
+
+
+def test_init_get_vault_failure():
+    with pytest.raises(ValueError, match="Unknown Vault class"):
+        get_vault("unknown-vault-storage")
+
+
+@pytest.mark.parametrize("class_name,expected_class,kwargs", SERVER_IMPLEMENTATIONS)
+def test_init_get_vault(class_name, expected_class, kwargs, mock_psycopg2):
+    concrete_vault = get_vault(class_name, **kwargs)
+    assert isinstance(concrete_vault, expected_class)
+
+
+@pytest.mark.parametrize("class_name,expected_class,kwargs", SERVER_IMPLEMENTATIONS)
+def test_init_get_vault_deprecation_warning(
+    class_name, expected_class, kwargs, mock_psycopg2
+):
+    with pytest.warns(DeprecationWarning):
+        concrete_vault = get_vault(class_name, args=kwargs)
+        assert isinstance(concrete_vault, expected_class)
+
+
+def test_init_get_vault_ok(swh_vault_config):
+    concrete_vault = get_vault("local", **swh_vault_config)
+    assert isinstance(concrete_vault, VaultBackend)
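
For reviewers, a minimal sketch of the calling convention this diff introduces (not part of the patch; the URL below is illustrative, not a real deployment):

```python
import warnings

from swh.vault import get_vault
from swh.vault.api.client import RemoteVaultClient

# New style: constructor arguments are passed directly as keyword arguments.
vault = get_vault("remote", url="http://localhost:5005/")
assert isinstance(vault, RemoteVaultClient)

# Old style: a single "args" dict is still accepted for backward
# compatibility, but now triggers a DeprecationWarning.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    vault = get_vault("remote", args={"url": "http://localhost:5005/"})

assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```

Note that instantiating the `local` backend the same way requires a reachable PostgreSQL database, since `VaultBackend.__init__` now opens a connection pool itself; the `mock_psycopg2` fixture added in `test_init.py` exists precisely to stub that out in unit tests.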