diff --git a/requirements-test.txt b/requirements-test.txt
index 078a4e3..66b4544 100644
--- a/requirements-test.txt
+++ b/requirements-test.txt
@@ -1,7 +1,8 @@
pytest
pytest-aiohttp
pytest-postgresql
dulwich >= 0.18.7
swh.loader.core
swh.loader.git >= 0.0.52
swh.storage[testing]
+pytest-mock
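
For context: pytest-mock supplies the `mocker` fixture that the new test_init.py at the end of this patch relies on to stub psycopg2 before VaultBackend opens a connection pool. A minimal sketch of the idiom (the test name is hypothetical):

    def test_backend_without_db(mocker):
        # Patch the pool and extras so VaultBackend(...) needs no running PostgreSQL.
        mocker.patch("swh.vault.backend.psycopg2.pool")
        mocker.patch("swh.vault.backend.psycopg2.extras")
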
diff --git a/swh/vault/__init__.py b/swh/vault/__init__.py
index a39a171..db16ff9 100644
--- a/swh/vault/__init__.py
+++ b/swh/vault/__init__.py
@@ -1,41 +1,54 @@
-# Copyright (C) 2018 The Software Heritage developers
+# Copyright (C) 2018-2020 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU Affero General Public License version 3, or any later version
# See top-level LICENSE file for more information
+
+from __future__ import annotations
+
+import importlib
import logging
+from typing import Dict
+import warnings
logger = logging.getLogger(__name__)
-def get_vault(cls="remote", args={}):
+BACKEND_TYPES: Dict[str, str] = {
+ "remote": ".api.client.RemoteVaultClient",
+ "local": ".backend.VaultBackend",
+}
+
+
+def get_vault(cls: str = "remote", **kwargs):
"""
Get a vault object of class `cls` instantiated with arguments `kwargs`.
Args:
- vault (dict): dictionary with keys:
- - cls (str): vault's class, either 'remote'
- - args (dict): dictionary with keys
+ cls: vault's class, either 'remote' or 'local'
+ kwargs: arguments to pass to the class' constructor
Returns:
an instance of VaultBackend (either local or remote)
Raises:
ValueError if passed an unknown vault class.
"""
- if cls == "remote":
- from .api.client import RemoteVaultClient as Vault
- elif cls == "local":
- from swh.scheduler import get_scheduler
- from swh.storage import get_storage
- from swh.vault.backend import VaultBackend as Vault
- from swh.vault.cache import VaultCache
-
- args["cache"] = VaultCache(**args["cache"])
- args["storage"] = get_storage(**args["storage"])
- args["scheduler"] = get_scheduler(**args["scheduler"])
- else:
- raise ValueError("Unknown storage class `%s`" % cls)
- logger.debug("Instantiating %s with %s" % (Vault, args))
- return Vault(**args)
+ if "args" in kwargs:
+ warnings.warn(
+ 'Explicit "args" key is deprecated, use keys directly instead.',
+ DeprecationWarning,
+ )
+ kwargs = kwargs["args"]
+
+ class_path = BACKEND_TYPES.get(cls)
+ if class_path is None:
+ raise ValueError(
+ f"Unknown Vault class `{cls}`. " f"Supported: {', '.join(BACKEND_TYPES)}"
+ )
+
+ (module_path, class_name) = class_path.rsplit(".", 1)
+ module = importlib.import_module(module_path, package=__package__)
+ Vault = getattr(module, class_name)
+ return Vault(**kwargs)
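
The refactored get_vault resolves the backend class lazily from BACKEND_TYPES via importlib and takes the constructor arguments as plain keyword arguments. A minimal usage sketch (the URL is a placeholder):

    from swh.vault import get_vault

    # New calling convention: keyword arguments go straight to the backend.
    vault = get_vault("remote", url="http://localhost:5005/")

    # The old nested dict still works, but now emits a DeprecationWarning.
    vault = get_vault("remote", args={"url": "http://localhost:5005/"})
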
diff --git a/swh/vault/api/server.py b/swh/vault/api/server.py
index 6c178e0..6440fc2 100644
--- a/swh/vault/api/server.py
+++ b/swh/vault/api/server.py
@@ -1,233 +1,233 @@
# Copyright (C) 2016-2020 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import asyncio
import collections
import os
import aiohttp.web
from swh.core import config
from swh.core.api.asynchronous import RPCServerApp, decode_request
from swh.core.api.asynchronous import encode_data_server as encode_data
from swh.model import hashutil
from swh.vault import get_vault
from swh.vault.backend import NotFoundExc
from swh.vault.cookers import COOKER_TYPES
DEFAULT_CONFIG_PATH = "vault/server"
DEFAULT_CONFIG = {
"storage": ("dict", {"cls": "remote", "args": {"url": "http://localhost:5002/",},}),
"cache": (
"dict",
{
"cls": "pathslicing",
"args": {"root": "/srv/softwareheritage/vault", "slicing": "0:1/1:5",},
},
),
"client_max_size": ("int", 1024 ** 3),
"vault": (
"dict",
{"cls": "local", "args": {"db": "dbname=softwareheritage-vault-dev",},},
),
"scheduler": ("dict", {"cls": "remote", "url": "http://localhost:5008/",},),
}
@asyncio.coroutine
def index(request):
return aiohttp.web.Response(body="SWH Vault API server")
# Web API endpoints
@asyncio.coroutine
def vault_fetch(request):
obj_type = request.match_info["type"]
obj_id = request.match_info["id"]
if not request.app["backend"].is_available(obj_type, obj_id):
raise NotFoundExc(f"{obj_type} {obj_id} is not available.")
return encode_data(request.app["backend"].fetch(obj_type, obj_id))
def user_info(task_info):
return {
"id": task_info["id"],
"status": task_info["task_status"],
"progress_message": task_info["progress_msg"],
"obj_type": task_info["type"],
"obj_id": hashutil.hash_to_hex(task_info["object_id"]),
}
@asyncio.coroutine
def vault_cook(request):
obj_type = request.match_info["type"]
obj_id = request.match_info["id"]
email = request.query.get("email")
sticky = request.query.get("sticky") in ("true", "1")
if obj_type not in COOKER_TYPES:
raise NotFoundExc(f"{obj_type} is an unknown type.")
info = request.app["backend"].cook_request(
obj_type, obj_id, email=email, sticky=sticky
)
# TODO: return 201 status (Created) once the api supports it
return encode_data(user_info(info))
@asyncio.coroutine
def vault_progress(request):
obj_type = request.match_info["type"]
obj_id = request.match_info["id"]
info = request.app["backend"].task_info(obj_type, obj_id)
if not info:
raise NotFoundExc(f"{obj_type} {obj_id} was not found.")
return encode_data(user_info(info))
# Cookers endpoints
@asyncio.coroutine
def set_progress(request):
obj_type = request.match_info["type"]
obj_id = request.match_info["id"]
progress = yield from decode_request(request)
request.app["backend"].set_progress(obj_type, obj_id, progress)
return encode_data(True) # FIXME: success value?
@asyncio.coroutine
def set_status(request):
obj_type = request.match_info["type"]
obj_id = request.match_info["id"]
status = yield from decode_request(request)
request.app["backend"].set_status(obj_type, obj_id, status)
return encode_data(True) # FIXME: success value?
@asyncio.coroutine
def put_bundle(request):
obj_type = request.match_info["type"]
obj_id = request.match_info["id"]
# TODO: handle streaming properly
content = yield from decode_request(request)
request.app["backend"].cache.add(obj_type, obj_id, content)
return encode_data(True) # FIXME: success value?
@asyncio.coroutine
def send_notif(request):
obj_type = request.match_info["type"]
obj_id = request.match_info["id"]
request.app["backend"].send_all_notifications(obj_type, obj_id)
return encode_data(True) # FIXME: success value?
# Batch endpoints
@asyncio.coroutine
def batch_cook(request):
batch = yield from decode_request(request)
for obj_type, obj_id in batch:
if obj_type not in COOKER_TYPES:
raise NotFoundExc(f"{obj_type} is an unknown type.")
batch_id = request.app["backend"].batch_cook(batch)
return encode_data({"id": batch_id})
@asyncio.coroutine
def batch_progress(request):
batch_id = request.match_info["batch_id"]
bundles = request.app["backend"].batch_info(batch_id)
if not bundles:
raise NotFoundExc(f"Batch {batch_id} does not exist.")
bundles = [user_info(bundle) for bundle in bundles]
counter = collections.Counter(b["status"] for b in bundles)
res = {
"bundles": bundles,
"total": len(bundles),
**{k: 0 for k in ("new", "pending", "done", "failed")},
**dict(counter),
}
return encode_data(res)
# Web server
def make_app(backend, **kwargs):
app = RPCServerApp(**kwargs)
app.router.add_route("GET", "/", index)
app.client_exception_classes = (NotFoundExc,)
# Endpoints used by the web API
app.router.add_route("GET", "/fetch/{type}/{id}", vault_fetch)
app.router.add_route("POST", "/cook/{type}/{id}", vault_cook)
app.router.add_route("GET", "/progress/{type}/{id}", vault_progress)
# Endpoints used by the Cookers
app.router.add_route("POST", "/set_progress/{type}/{id}", set_progress)
app.router.add_route("POST", "/set_status/{type}/{id}", set_status)
app.router.add_route("POST", "/put_bundle/{type}/{id}", put_bundle)
app.router.add_route("POST", "/send_notif/{type}/{id}", send_notif)
# Endpoints for batch requests
app.router.add_route("POST", "/batch_cook", batch_cook)
app.router.add_route("GET", "/batch_progress/{batch_id}", batch_progress)
app["backend"] = backend
return app
def get_local_backend(cfg):
if "vault" not in cfg:
raise ValueError("missing '%vault' configuration")
vcfg = cfg["vault"]
if vcfg["cls"] != "local":
raise EnvironmentError(
"The vault backend can only be started with a 'local' configuration"
)
args = vcfg["args"]
if "cache" not in args:
args["cache"] = cfg.get("cache")
if "storage" not in args:
args["storage"] = cfg.get("storage")
if "scheduler" not in args:
args["scheduler"] = cfg.get("scheduler")
for key in ("cache", "storage", "scheduler"):
if not args.get(key):
raise ValueError("invalid configuration; missing %s config entry." % key)
- return get_vault("local", args)
+ return get_vault("local", **args)
def make_app_from_configfile(config_file=None, **kwargs):
if config_file is None:
config_file = DEFAULT_CONFIG_PATH
config_file = os.environ.get("SWH_CONFIG_FILENAME", config_file)
if os.path.isfile(config_file):
cfg = config.read(config_file, DEFAULT_CONFIG)
else:
cfg = config.load_named_config(config_file, DEFAULT_CONFIG)
vault = get_local_backend(cfg)
return make_app(backend=vault, client_max_size=cfg["client_max_size"], **kwargs)
if __name__ == "__main__":
print("Deprecated. Use swh-vault ")
diff --git a/swh/vault/backend.py b/swh/vault/backend.py
index 1974e9e..69d4690 100644
--- a/swh/vault/backend.py
+++ b/swh/vault/backend.py
@@ -1,487 +1,490 @@
-# Copyright (C) 2017-2018 The Software Heritage developers
+# Copyright (C) 2017-2020 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
from email.mime.text import MIMEText
import smtplib
import psycopg2.extras
import psycopg2.pool
from swh.core.db import BaseDb
from swh.core.db.common import db_transaction
from swh.model import hashutil
+from swh.scheduler import get_scheduler
from swh.scheduler.utils import create_oneshot_task_dict
+from swh.storage import get_storage
+from swh.vault.cache import VaultCache
from swh.vault.cookers import get_cooker_cls
from swh.vault.exc import NotFoundExc
cooking_task_name = "swh.vault.cooking_tasks.SWHCookingTask"
NOTIF_EMAIL_FROM = '"Software Heritage Vault" ' "<bot@softwareheritage.org>"
NOTIF_EMAIL_SUBJECT_SUCCESS = "Bundle ready: {obj_type} {short_id}"
NOTIF_EMAIL_SUBJECT_FAILURE = "Bundle failed: {obj_type} {short_id}"
NOTIF_EMAIL_BODY_SUCCESS = """
You have requested the following bundle from the Software Heritage
Vault:
Object Type: {obj_type}
Object ID: {hex_id}
This bundle is now available for download at the following address:
{url}
Please keep in mind that this link might expire at some point, in which
case you will need to request the bundle again.
--\x20
The Software Heritage Developers
"""
NOTIF_EMAIL_BODY_FAILURE = """
You have requested the following bundle from the Software Heritage
Vault:
Object Type: {obj_type}
Object ID: {hex_id}
This bundle could not be cooked for the following reason:
{progress_msg}
We apologize for the inconvenience.
--\x20
The Software Heritage Developers
"""
def batch_to_bytes(batch):
return [(obj_type, hashutil.hash_to_bytes(obj_id)) for obj_type, obj_id in batch]
class VaultBackend:
"""
Backend for the Software Heritage vault.
"""
- def __init__(self, db, cache, scheduler, storage=None, **config):
+ def __init__(self, db, **config):
self.config = config
- self.cache = cache
- self.scheduler = scheduler
- self.storage = storage
+ self.cache = VaultCache(**config["cache"])
+ self.scheduler = get_scheduler(**config["scheduler"])
+ self.storage = get_storage(**config["storage"])
self.smtp_server = smtplib.SMTP()
self._pool = psycopg2.pool.ThreadedConnectionPool(
config.get("min_pool_conns", 1),
config.get("max_pool_conns", 10),
db,
cursor_factory=psycopg2.extras.RealDictCursor,
)
self._db = None
def get_db(self):
if self._db:
return self._db
return BaseDb.from_pool(self._pool)
def put_db(self, db):
if db is not self._db:
db.put_conn()
@db_transaction()
def task_info(self, obj_type, obj_id, db=None, cur=None):
"""Fetch information from a bundle"""
obj_id = hashutil.hash_to_bytes(obj_id)
cur.execute(
"""
SELECT id, type, object_id, task_id, task_status, sticky,
ts_created, ts_done, ts_last_access, progress_msg
FROM vault_bundle
WHERE type = %s AND object_id = %s""",
(obj_type, obj_id),
)
res = cur.fetchone()
if res:
res["object_id"] = bytes(res["object_id"])
return res
def _send_task(self, *args):
"""Send a cooking task to the celery scheduler"""
task = create_oneshot_task_dict("cook-vault-bundle", *args)
added_tasks = self.scheduler.create_tasks([task])
return added_tasks[0]["id"]
@db_transaction()
def create_task(self, obj_type, obj_id, sticky=False, db=None, cur=None):
"""Create and send a cooking task"""
obj_id = hashutil.hash_to_bytes(obj_id)
hex_id = hashutil.hash_to_hex(obj_id)
cooker_class = get_cooker_cls(obj_type)
cooker = cooker_class(obj_type, hex_id, backend=self, storage=self.storage)
if not cooker.check_exists():
raise NotFoundExc("Object {} was not found.".format(hex_id))
cur.execute(
"""
INSERT INTO vault_bundle (type, object_id, sticky)
VALUES (%s, %s, %s)""",
(obj_type, obj_id, sticky),
)
db.conn.commit()
task_id = self._send_task(obj_type, hex_id)
cur.execute(
"""
UPDATE vault_bundle
SET task_id = %s
WHERE type = %s AND object_id = %s""",
(task_id, obj_type, obj_id),
)
@db_transaction()
def add_notif_email(self, obj_type, obj_id, email, db=None, cur=None):
"""Add an e-mail address to notify when a given bundle is ready"""
obj_id = hashutil.hash_to_bytes(obj_id)
cur.execute(
"""
INSERT INTO vault_notif_email (email, bundle_id)
VALUES (%s, (SELECT id FROM vault_bundle
WHERE type = %s AND object_id = %s))""",
(email, obj_type, obj_id),
)
@db_transaction()
def cook_request(
self, obj_type, obj_id, *, sticky=False, email=None, db=None, cur=None
):
"""Main entry point for cooking requests. This starts a cooking task if
needed, and add the given e-mail to the notify list"""
obj_id = hashutil.hash_to_bytes(obj_id)
info = self.task_info(obj_type, obj_id)
# If there's a failed bundle entry, delete it first.
if info is not None and info["task_status"] == "failed":
cur.execute(
"""DELETE FROM vault_bundle
WHERE type = %s AND object_id = %s""",
(obj_type, obj_id),
)
db.conn.commit()
info = None
# If there's no bundle entry, create the task.
if info is None:
self.create_task(obj_type, obj_id, sticky)
if email is not None:
# If the task is already done, send the email directly
if info is not None and info["task_status"] == "done":
self.send_notification(
None, email, obj_type, obj_id, info["task_status"]
)
# Else, add it to the notification queue
else:
self.add_notif_email(obj_type, obj_id, email)
info = self.task_info(obj_type, obj_id)
return info
@db_transaction()
def batch_cook(self, batch, db=None, cur=None):
"""Cook a batch of bundles and returns the cooking id."""
# Import execute_values at runtime only, because it requires
# psycopg2 >= 2.7 (only available on postgresql servers)
from psycopg2.extras import execute_values
cur.execute(
"""
INSERT INTO vault_batch (id)
VALUES (DEFAULT)
RETURNING id"""
)
batch_id = cur.fetchone()["id"]
batch = batch_to_bytes(batch)
# Delete all failed bundles from the batch
cur.execute(
"""
DELETE FROM vault_bundle
WHERE task_status = 'failed'
AND (type, object_id) IN %s""",
(tuple(batch),),
)
# Insert all the bundles, return the new ones
execute_values(
cur,
"""
INSERT INTO vault_bundle (type, object_id)
VALUES %s ON CONFLICT DO NOTHING""",
batch,
)
# Get the bundle ids and task status
cur.execute(
"""
SELECT id, type, object_id, task_id FROM vault_bundle
WHERE (type, object_id) IN %s""",
(tuple(batch),),
)
bundles = cur.fetchall()
# Insert the batch-bundle entries
batch_id_bundle_ids = [(batch_id, row["id"]) for row in bundles]
execute_values(
cur,
"""
INSERT INTO vault_batch_bundle (batch_id, bundle_id)
VALUES %s ON CONFLICT DO NOTHING""",
batch_id_bundle_ids,
)
db.conn.commit()
# Get the tasks to fetch
batch_new = [
(row["type"], bytes(row["object_id"]))
for row in bundles
if row["task_id"] is None
]
# Send the tasks
args_batch = [
(obj_type, hashutil.hash_to_hex(obj_id)) for obj_type, obj_id in batch_new
]
# TODO: change once the scheduler handles priority tasks
tasks = [
create_oneshot_task_dict("swh-vault-batch-cooking", *args)
for args in args_batch
]
added_tasks = self.scheduler.create_tasks(tasks)
tasks_ids_bundle_ids = zip([task["id"] for task in added_tasks], batch_new)
tasks_ids_bundle_ids = [
(task_id, obj_type, obj_id)
for task_id, (obj_type, obj_id) in tasks_ids_bundle_ids
]
# Update the task ids
execute_values(
cur,
"""
UPDATE vault_bundle
SET task_id = s_task_id
FROM (VALUES %s) AS sub (s_task_id, s_type, s_object_id)
WHERE type = s_type::cook_type AND object_id = s_object_id """,
tasks_ids_bundle_ids,
)
return batch_id
@db_transaction()
def batch_info(self, batch_id, db=None, cur=None):
"""Fetch information from a batch of bundles"""
cur.execute(
"""
SELECT vault_bundle.id as id,
type, object_id, task_id, task_status, sticky,
ts_created, ts_done, ts_last_access, progress_msg
FROM vault_batch_bundle
LEFT JOIN vault_bundle ON vault_bundle.id = bundle_id
WHERE batch_id = %s""",
(batch_id,),
)
res = cur.fetchall()
if res:
for d in res:
d["object_id"] = bytes(d["object_id"])
return res
@db_transaction()
def is_available(self, obj_type, obj_id, db=None, cur=None):
"""Check whether a bundle is available for retrieval"""
info = self.task_info(obj_type, obj_id, cur=cur)
return (
info is not None
and info["task_status"] == "done"
and self.cache.is_cached(obj_type, obj_id)
)
@db_transaction()
def fetch(self, obj_type, obj_id, db=None, cur=None):
"""Retrieve a bundle from the cache"""
if not self.is_available(obj_type, obj_id, cur=cur):
return None
self.update_access_ts(obj_type, obj_id, cur=cur)
return self.cache.get(obj_type, obj_id)
@db_transaction()
def update_access_ts(self, obj_type, obj_id, db=None, cur=None):
"""Update the last access timestamp of a bundle"""
obj_id = hashutil.hash_to_bytes(obj_id)
cur.execute(
"""
UPDATE vault_bundle
SET ts_last_access = NOW()
WHERE type = %s AND object_id = %s""",
(obj_type, obj_id),
)
@db_transaction()
def set_status(self, obj_type, obj_id, status, db=None, cur=None):
"""Set the cooking status of a bundle"""
obj_id = hashutil.hash_to_bytes(obj_id)
req = (
"""
UPDATE vault_bundle
SET task_status = %s """
+ (""", ts_done = NOW() """ if status == "done" else "")
+ """WHERE type = %s AND object_id = %s"""
)
cur.execute(req, (status, obj_type, obj_id))
@db_transaction()
def set_progress(self, obj_type, obj_id, progress, db=None, cur=None):
"""Set the cooking progress of a bundle"""
obj_id = hashutil.hash_to_bytes(obj_id)
cur.execute(
"""
UPDATE vault_bundle
SET progress_msg = %s
WHERE type = %s AND object_id = %s""",
(progress, obj_type, obj_id),
)
@db_transaction()
def send_all_notifications(self, obj_type, obj_id, db=None, cur=None):
"""Send all the e-mails in the notification list of a bundle"""
obj_id = hashutil.hash_to_bytes(obj_id)
cur.execute(
"""
SELECT vault_notif_email.id AS id, email, task_status, progress_msg
FROM vault_notif_email
INNER JOIN vault_bundle ON bundle_id = vault_bundle.id
WHERE vault_bundle.type = %s AND vault_bundle.object_id = %s""",
(obj_type, obj_id),
)
for d in cur:
self.send_notification(
d["id"],
d["email"],
obj_type,
obj_id,
status=d["task_status"],
progress_msg=d["progress_msg"],
)
@db_transaction()
def send_notification(
self,
n_id,
email,
obj_type,
obj_id,
status,
progress_msg=None,
db=None,
cur=None,
):
"""Send the notification of a bundle to a specific e-mail"""
hex_id = hashutil.hash_to_hex(obj_id)
short_id = hex_id[:7]
# TODO: instead of hardcoding this, we should probably:
# * add a "fetch_url" field in the vault_notif_email table
# * generate the url with flask.url_for() on the web-ui side
# * send this url as part of the cook request and store it in
# the table
# * use this url for the notification e-mail
url = "https://archive.softwareheritage.org/api/1/vault/{}/{}/" "raw".format(
obj_type, hex_id
)
if status == "done":
text = NOTIF_EMAIL_BODY_SUCCESS.strip()
text = text.format(obj_type=obj_type, hex_id=hex_id, url=url)
msg = MIMEText(text)
msg["Subject"] = NOTIF_EMAIL_SUBJECT_SUCCESS.format(
obj_type=obj_type, short_id=short_id
)
elif status == "failed":
text = NOTIF_EMAIL_BODY_FAILURE.strip()
text = text.format(
obj_type=obj_type, hex_id=hex_id, progress_msg=progress_msg
)
msg = MIMEText(text)
msg["Subject"] = NOTIF_EMAIL_SUBJECT_FAILURE.format(
obj_type=obj_type, short_id=short_id
)
else:
raise RuntimeError(
"send_notification called on a '{}' bundle".format(status)
)
msg["From"] = NOTIF_EMAIL_FROM
msg["To"] = email
self._smtp_send(msg)
if n_id is not None:
cur.execute(
"""
DELETE FROM vault_notif_email
WHERE id = %s""",
(n_id,),
)
def _smtp_send(self, msg):
# Reconnect if needed
try:
status = self.smtp_server.noop()[0]
except smtplib.SMTPException:
status = -1
if status != 250:
self.smtp_server.connect("localhost", 25)
# Send the message
self.smtp_server.send_message(msg)
@db_transaction()
def _cache_expire(self, cond, *args, db=None, cur=None):
"""Low-level expiration method, used by cache_expire_* methods"""
# Embedded SELECT query to be able to use ORDER BY and LIMIT
cur.execute(
"""
DELETE FROM vault_bundle
WHERE ctid IN (
SELECT ctid
FROM vault_bundle
WHERE sticky = false
{}
)
RETURNING type, object_id
""".format(
cond
),
args,
)
for d in cur:
self.cache.delete(d["type"], bytes(d["object_id"]))
@db_transaction()
def cache_expire_oldest(self, n=1, by="last_access", db=None, cur=None):
"""Expire the `n` oldest bundles"""
assert by in ("created", "done", "last_access")
filter = """ORDER BY ts_{} LIMIT {}""".format(by, n)
return self._cache_expire(filter)
@db_transaction()
def cache_expire_until(self, date, by="last_access", db=None, cur=None):
"""Expire all the bundles until a certain date"""
assert by in ("created", "done", "last_access")
filter = """AND ts_{} <= %s""".format(by)
return self._cache_expire(filter, date)
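
With the constructor change above, VaultBackend now builds its own cache, storage and scheduler clients from sub-configurations instead of receiving pre-built objects. A minimal sketch of direct instantiation, mirroring the shape used in the tests below (the DSN and URLs are placeholders):

    from swh.vault.backend import VaultBackend

    vault = VaultBackend(
        db="dbname=softwareheritage-vault-dev",
        cache={"cls": "memory", "args": {}},
        storage={"cls": "remote", "url": "http://localhost:5002/"},
        scheduler={"cls": "remote", "url": "http://localhost:5008/"},
    )
    # Note: the psycopg2 connection pool is opened eagerly, so the DSN must
    # point at a reachable vault database (or be patched as in test_init.py).
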
diff --git a/swh/vault/tests/conftest.py b/swh/vault/tests/conftest.py
index 9090e46..163ca80 100644
--- a/swh/vault/tests/conftest.py
+++ b/swh/vault/tests/conftest.py
@@ -1,88 +1,88 @@
import glob
import os
import subprocess
+from typing import Any, Dict
import pkg_resources.extern.packaging.version
import pytest
from pytest_postgresql import factories
from swh.core.utils import numfile_sortkey as sortkey
from swh.storage.tests import SQL_DIR as STORAGE_SQL_DIR
from swh.vault import get_vault
from swh.vault.tests import SQL_DIR
os.environ["LC_ALL"] = "C.UTF-8"
pytest_v = pkg_resources.get_distribution("pytest").parsed_version
if pytest_v < pkg_resources.extern.packaging.version.parse("3.9"):
@pytest.fixture
def tmp_path(request):
import pathlib
import tempfile
with tempfile.TemporaryDirectory() as tmpdir:
yield pathlib.Path(tmpdir)
def db_url(name, postgresql_proc):
return "postgresql://{user}@{host}:{port}/{dbname}".format(
host=postgresql_proc.host,
port=postgresql_proc.port,
user="postgres",
dbname=name,
)
postgresql2 = factories.postgresql("postgresql_proc", "tests2")
@pytest.fixture
-def swh_vault(request, postgresql_proc, postgresql, postgresql2, tmp_path):
+def swh_vault_config(postgresql, postgresql2, tmp_path) -> Dict[str, Any]:
+ tmp_path = str(tmp_path)
+ return {
+ "db": postgresql.dsn,
+ "storage": {
+ "cls": "local",
+ "db": postgresql2.dsn,
+ "objstorage": {
+ "cls": "pathslicing",
+ "args": {"root": tmp_path, "slicing": "0:1/1:5",},
+ },
+ },
+ "cache": {
+ "cls": "pathslicing",
+ "args": {"root": tmp_path, "slicing": "0:1/1:5", "allow_delete": True,},
+ },
+ "scheduler": {"cls": "remote", "url": "http://swh-scheduler:5008",},
+ }
+
+@pytest.fixture
+def swh_vault(request, swh_vault_config, postgresql, postgresql2, tmp_path):
for sql_dir, pg in ((SQL_DIR, postgresql), (STORAGE_SQL_DIR, postgresql2)):
dump_files = os.path.join(sql_dir, "*.sql")
all_dump_files = sorted(glob.glob(dump_files), key=sortkey)
for fname in all_dump_files:
subprocess.check_call(
[
"psql",
"--quiet",
"--no-psqlrc",
"-v",
"ON_ERROR_STOP=1",
"-d",
pg.dsn,
"-f",
fname,
]
)
- vault_config = {
- "db": db_url("tests", postgresql_proc),
- "storage": {
- "cls": "local",
- "db": db_url("tests2", postgresql_proc),
- "objstorage": {
- "cls": "pathslicing",
- "args": {"root": str(tmp_path), "slicing": "0:1/1:5",},
- },
- },
- "cache": {
- "cls": "pathslicing",
- "args": {
- "root": str(tmp_path),
- "slicing": "0:1/1:5",
- "allow_delete": True,
- },
- },
- "scheduler": {"cls": "remote", "url": "http://swh-scheduler:5008",},
- }
-
- return get_vault("local", vault_config)
+ return get_vault("local", **swh_vault_config)
@pytest.fixture
def swh_storage(swh_vault):
return swh_vault.storage
diff --git a/swh/vault/tests/test_init.py b/swh/vault/tests/test_init.py
new file mode 100644
index 0000000..7f402d6
--- /dev/null
+++ b/swh/vault/tests/test_init.py
@@ -0,0 +1,55 @@
+# Copyright (C) 2020 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+import pytest
+
+from swh.vault import get_vault
+from swh.vault.api.client import RemoteVaultClient
+from swh.vault.backend import VaultBackend
+
+SERVER_IMPLEMENTATIONS = [
+ ("remote", RemoteVaultClient, {"url": "localhost"}),
+ (
+ "local",
+ VaultBackend,
+ {
+ "db": "something",
+ "cache": {"cls": "memory", "args": {}},
+ "storage": {"cls": "remote", "url": "mock://storage-url"},
+ "scheduler": {"cls": "remote", "url": "mock://scheduler-url"},
+ },
+ ),
+]
+
+
+@pytest.fixture
+def mock_psycopg2(mocker):
+ mocker.patch("swh.vault.backend.psycopg2.pool")
+ mocker.patch("swh.vault.backend.psycopg2.extras")
+
+
+def test_init_get_vault_failure():
+ with pytest.raises(ValueError, match="Unknown Vault class"):
+ get_vault("unknown-vault-storage")
+
+
+@pytest.mark.parametrize("class_name,expected_class,kwargs", SERVER_IMPLEMENTATIONS)
+def test_init_get_vault(class_name, expected_class, kwargs, mock_psycopg2):
+ concrete_vault = get_vault(class_name, **kwargs)
+ assert isinstance(concrete_vault, expected_class)
+
+
+@pytest.mark.parametrize("class_name,expected_class,kwargs", SERVER_IMPLEMENTATIONS)
+def test_init_get_vault_deprecation_warning(
+ class_name, expected_class, kwargs, mock_psycopg2
+):
+ with pytest.warns(DeprecationWarning):
+ concrete_vault = get_vault(class_name, args=kwargs)
+ assert isinstance(concrete_vault, expected_class)
+
+
+def test_init_get_vault_ok(swh_vault_config):
+ concrete_vault = get_vault("local", **swh_vault_config)
+ assert isinstance(concrete_vault, VaultBackend)
