diff --git a/requirements-test.txt b/requirements-test.txt
index 078a4e3..66b4544 100644
--- a/requirements-test.txt
+++ b/requirements-test.txt
@@ -1,7 +1,8 @@
pytest
pytest-aiohttp
pytest-postgresql
dulwich >= 0.18.7
swh.loader.core
swh.loader.git >= 0.0.52
swh.storage[testing]
+pytest-mock
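
pytest-mock is added for its `mocker` fixture, which the new swh/vault/tests/test_init.py below uses to patch psycopg2 so that a VaultBackend can be instantiated without a live database.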
diff --git a/swh/vault/__init__.py b/swh/vault/__init__.py
index a39a171..db16ff9 100644
--- a/swh/vault/__init__.py
+++ b/swh/vault/__init__.py
@@ -1,41 +1,54 @@
-# Copyright (C) 2018 The Software Heritage developers
+# Copyright (C) 2018-2020 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU Affero General Public License version 3, or any later version
# See top-level LICENSE file for more information
+
+from __future__ import annotations
+
+import importlib
import logging
+from typing import Dict
+import warnings
logger = logging.getLogger(__name__)
-def get_vault(cls="remote", args={}):
+BACKEND_TYPES: Dict[str, str] = {
+ "remote": ".api.client.RemoteVaultClient",
+ "local": ".backend.VaultBackend",
+}
+
+
+def get_vault(cls: str = "remote", **kwargs):
"""
Get a vault object of class `cls` with arguments
`kwargs`.
Args:
- vault (dict): dictionary with keys:
- - cls (str): vault's class, either 'remote'
- - args (dict): dictionary with keys
+ cls: vault's class, either 'remote' or 'local'
+ kwargs: arguments to pass to the class' constructor
Returns:
an instance of VaultBackend (if local) or RemoteVaultClient (if remote)
Raises:
ValueError if passed an unknown vault class.
"""
- if cls == "remote":
- from .api.client import RemoteVaultClient as Vault
- elif cls == "local":
- from swh.scheduler import get_scheduler
- from swh.storage import get_storage
- from swh.vault.backend import VaultBackend as Vault
- from swh.vault.cache import VaultCache
-
- args["cache"] = VaultCache(**args["cache"])
- args["storage"] = get_storage(**args["storage"])
- args["scheduler"] = get_scheduler(**args["scheduler"])
- else:
- raise ValueError("Unknown storage class `%s`" % cls)
- logger.debug("Instantiating %s with %s" % (Vault, args))
- return Vault(**args)
+ if "args" in kwargs:
+ warnings.warn(
+ 'Explicit "args" key is deprecated, use keys directly instead.',
+ DeprecationWarning,
+ )
+ kwargs = kwargs["args"]
+
+ class_path = BACKEND_TYPES.get(cls)
+ if class_path is None:
+ raise ValueError(
+ f"Unknown Vault class `{cls}`. " f"Supported: {', '.join(BACKEND_TYPES)}"
+ )
+
+ (module_path, class_name) = class_path.rsplit(".", 1)
+ module = importlib.import_module(module_path, package=__package__)
+ Vault = getattr(module, class_name)
+ return Vault(**kwargs)
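
For context, a minimal sketch of the two calling conventions the new get_vault() accepts; the URL is illustrative:

    from swh.vault import get_vault

    # New style: constructor arguments are passed directly as keyword
    # arguments and forwarded to the backend class.
    vault = get_vault(cls="remote", url="http://localhost:5005/")

    # Deprecated style: a nested "args" dict still works, but get_vault()
    # emits a DeprecationWarning before unpacking it.
    vault = get_vault(cls="remote", args={"url": "http://localhost:5005/"})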
diff --git a/swh/vault/api/server.py b/swh/vault/api/server.py
index 6c178e0..6440fc2 100644
--- a/swh/vault/api/server.py
+++ b/swh/vault/api/server.py
@@ -1,233 +1,233 @@
# Copyright (C) 2016-2020 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import asyncio
import collections
import os
import aiohttp.web
from swh.core import config
from swh.core.api.asynchronous import RPCServerApp, decode_request
from swh.core.api.asynchronous import encode_data_server as encode_data
from swh.model import hashutil
from swh.vault import get_vault
from swh.vault.backend import NotFoundExc
from swh.vault.cookers import COOKER_TYPES
DEFAULT_CONFIG_PATH = "vault/server"
DEFAULT_CONFIG = {
"storage": ("dict", {"cls": "remote", "args": {"url": "http://localhost:5002/",},}),
"cache": (
"dict",
{
"cls": "pathslicing",
"args": {"root": "/srv/softwareheritage/vault", "slicing": "0:1/1:5",},
},
),
"client_max_size": ("int", 1024 ** 3),
"vault": (
"dict",
{"cls": "local", "args": {"db": "dbname=softwareheritage-vault-dev",},},
),
"scheduler": ("dict", {"cls": "remote", "url": "http://localhost:5008/",},),
}
@asyncio.coroutine
def index(request):
return aiohttp.web.Response(body="SWH Vault API server")
# Web API endpoints
@asyncio.coroutine
def vault_fetch(request):
obj_type = request.match_info["type"]
obj_id = request.match_info["id"]
if not request.app["backend"].is_available(obj_type, obj_id):
raise NotFoundExc(f"{obj_type} {obj_id} is not available.")
return encode_data(request.app["backend"].fetch(obj_type, obj_id))
def user_info(task_info):
return {
"id": task_info["id"],
"status": task_info["task_status"],
"progress_message": task_info["progress_msg"],
"obj_type": task_info["type"],
"obj_id": hashutil.hash_to_hex(task_info["object_id"]),
}
@asyncio.coroutine
def vault_cook(request):
obj_type = request.match_info["type"]
obj_id = request.match_info["id"]
email = request.query.get("email")
sticky = request.query.get("sticky") in ("true", "1")
if obj_type not in COOKER_TYPES:
raise NotFoundExc(f"{obj_type} is an unknown type.")
info = request.app["backend"].cook_request(
obj_type, obj_id, email=email, sticky=sticky
)
# TODO: return 201 status (Created) once the api supports it
return encode_data(user_info(info))
@asyncio.coroutine
def vault_progress(request):
obj_type = request.match_info["type"]
obj_id = request.match_info["id"]
info = request.app["backend"].task_info(obj_type, obj_id)
if not info:
raise NotFoundExc(f"{obj_type} {obj_id} was not found.")
return encode_data(user_info(info))
# Cookers endpoints
@asyncio.coroutine
def set_progress(request):
obj_type = request.match_info["type"]
obj_id = request.match_info["id"]
progress = yield from decode_request(request)
request.app["backend"].set_progress(obj_type, obj_id, progress)
return encode_data(True) # FIXME: success value?
@asyncio.coroutine
def set_status(request):
obj_type = request.match_info["type"]
obj_id = request.match_info["id"]
status = yield from decode_request(request)
request.app["backend"].set_status(obj_type, obj_id, status)
return encode_data(True) # FIXME: success value?
@asyncio.coroutine
def put_bundle(request):
obj_type = request.match_info["type"]
obj_id = request.match_info["id"]
# TODO: handle streaming properly
content = yield from decode_request(request)
request.app["backend"].cache.add(obj_type, obj_id, content)
return encode_data(True) # FIXME: success value?
@asyncio.coroutine
def send_notif(request):
obj_type = request.match_info["type"]
obj_id = request.match_info["id"]
request.app["backend"].send_all_notifications(obj_type, obj_id)
return encode_data(True) # FIXME: success value?
# Batch endpoints
@asyncio.coroutine
def batch_cook(request):
batch = yield from decode_request(request)
for obj_type, obj_id in batch:
if obj_type not in COOKER_TYPES:
raise NotFoundExc(f"{obj_type} is an unknown type.")
batch_id = request.app["backend"].batch_cook(batch)
return encode_data({"id": batch_id})
@asyncio.coroutine
def batch_progress(request):
batch_id = request.match_info["batch_id"]
bundles = request.app["backend"].batch_info(batch_id)
if not bundles:
raise NotFoundExc(f"Batch {batch_id} does not exist.")
bundles = [user_info(bundle) for bundle in bundles]
counter = collections.Counter(b["status"] for b in bundles)
res = {
"bundles": bundles,
"total": len(bundles),
**{k: 0 for k in ("new", "pending", "done", "failed")},
**dict(counter),
}
return encode_data(res)
# Web server
def make_app(backend, **kwargs):
app = RPCServerApp(**kwargs)
app.router.add_route("GET", "/", index)
app.client_exception_classes = (NotFoundExc,)
# Endpoints used by the web API
app.router.add_route("GET", "/fetch/{type}/{id}", vault_fetch)
app.router.add_route("POST", "/cook/{type}/{id}", vault_cook)
app.router.add_route("GET", "/progress/{type}/{id}", vault_progress)
# Endpoints used by the Cookers
app.router.add_route("POST", "/set_progress/{type}/{id}", set_progress)
app.router.add_route("POST", "/set_status/{type}/{id}", set_status)
app.router.add_route("POST", "/put_bundle/{type}/{id}", put_bundle)
app.router.add_route("POST", "/send_notif/{type}/{id}", send_notif)
# Endpoints for batch requests
app.router.add_route("POST", "/batch_cook", batch_cook)
app.router.add_route("GET", "/batch_progress/{batch_id}", batch_progress)
app["backend"] = backend
return app
def get_local_backend(cfg):
if "vault" not in cfg:
raise ValueError("missing '%vault' configuration")
vcfg = cfg["vault"]
if vcfg["cls"] != "local":
raise EnvironmentError(
"The vault backend can only be started with a 'local' "
"configuration"
)
args = vcfg["args"]
if "cache" not in args:
args["cache"] = cfg.get("cache")
if "storage" not in args:
args["storage"] = cfg.get("storage")
if "scheduler" not in args:
args["scheduler"] = cfg.get("scheduler")
for key in ("cache", "storage", "scheduler"):
if not args.get(key):
raise ValueError("invalid configuration; missing %s config entry." % key)
- return get_vault("local", args)
+ return get_vault("local", **args)
def make_app_from_configfile(config_file=None, **kwargs):
if config_file is None:
config_file = DEFAULT_CONFIG_PATH
config_file = os.environ.get("SWH_CONFIG_FILENAME", config_file)
if os.path.isfile(config_file):
cfg = config.read(config_file, DEFAULT_CONFIG)
else:
cfg = config.load_named_config(config_file, DEFAULT_CONFIG)
vault = get_local_backend(cfg)
return make_app(backend=vault, client_max_size=cfg["client_max_size"], **kwargs)
if __name__ == "__main__":
print("Deprecated. Use swh-vault ")
diff --git a/swh/vault/backend.py b/swh/vault/backend.py
index 1974e9e..69d4690 100644
--- a/swh/vault/backend.py
+++ b/swh/vault/backend.py
@@ -1,487 +1,490 @@
-# Copyright (C) 2017-2018 The Software Heritage developers
+# Copyright (C) 2017-2020 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
from email.mime.text import MIMEText
import smtplib
import psycopg2.extras
import psycopg2.pool
from swh.core.db import BaseDb
from swh.core.db.common import db_transaction
from swh.model import hashutil
+from swh.scheduler import get_scheduler
from swh.scheduler.utils import create_oneshot_task_dict
+from swh.storage import get_storage
+from swh.vault.cache import VaultCache
from swh.vault.cookers import get_cooker_cls
from swh.vault.exc import NotFoundExc
cooking_task_name = "swh.vault.cooking_tasks.SWHCookingTask"
NOTIF_EMAIL_FROM = '"Software Heritage Vault" ' "<bot@softwareheritage.org>"
NOTIF_EMAIL_SUBJECT_SUCCESS = "Bundle ready: {obj_type} {short_id}"
NOTIF_EMAIL_SUBJECT_FAILURE = "Bundle failed: {obj_type} {short_id}"
NOTIF_EMAIL_BODY_SUCCESS = """
You have requested the following bundle from the Software Heritage
Vault:
Object Type: {obj_type}
Object ID: {hex_id}
This bundle is now available for download at the following address:
{url}
Please keep in mind that this link might expire at some point, in which
case you will need to request the bundle again.
--\x20
The Software Heritage Developers
"""
NOTIF_EMAIL_BODY_FAILURE = """
You have requested the following bundle from the Software Heritage
Vault:
Object Type: {obj_type}
Object ID: {hex_id}
This bundle could not be cooked for the following reason:
{progress_msg}
We apologize for the inconvenience.
--\x20
The Software Heritage Developers
"""
def batch_to_bytes(batch):
return [(obj_type, hashutil.hash_to_bytes(obj_id)) for obj_type, obj_id in batch]
class VaultBackend:
"""
Backend for the Software Heritage vault.
"""
- def __init__(self, db, cache, scheduler, storage=None, **config):
+ def __init__(self, db, **config):
self.config = config
- self.cache = cache
- self.scheduler = scheduler
- self.storage = storage
+ self.cache = VaultCache(**config["cache"])
+ self.scheduler = get_scheduler(**config["scheduler"])
+ self.storage = get_storage(**config["storage"])
self.smtp_server = smtplib.SMTP()
self._pool = psycopg2.pool.ThreadedConnectionPool(
config.get("min_pool_conns", 1),
config.get("max_pool_conns", 10),
db,
cursor_factory=psycopg2.extras.RealDictCursor,
)
self._db = None
def get_db(self):
if self._db:
return self._db
return BaseDb.from_pool(self._pool)
def put_db(self, db):
if db is not self._db:
db.put_conn()
@db_transaction()
def task_info(self, obj_type, obj_id, db=None, cur=None):
"""Fetch information from a bundle"""
obj_id = hashutil.hash_to_bytes(obj_id)
cur.execute(
"""
SELECT id, type, object_id, task_id, task_status, sticky,
ts_created, ts_done, ts_last_access, progress_msg
FROM vault_bundle
WHERE type = %s AND object_id = %s""",
(obj_type, obj_id),
)
res = cur.fetchone()
if res:
res["object_id"] = bytes(res["object_id"])
return res
def _send_task(self, *args):
"""Send a cooking task to the celery scheduler"""
task = create_oneshot_task_dict("cook-vault-bundle", *args)
added_tasks = self.scheduler.create_tasks([task])
return added_tasks[0]["id"]
@db_transaction()
def create_task(self, obj_type, obj_id, sticky=False, db=None, cur=None):
"""Create and send a cooking task"""
obj_id = hashutil.hash_to_bytes(obj_id)
hex_id = hashutil.hash_to_hex(obj_id)
cooker_class = get_cooker_cls(obj_type)
cooker = cooker_class(obj_type, hex_id, backend=self, storage=self.storage)
if not cooker.check_exists():
raise NotFoundExc("Object {} was not found.".format(hex_id))
cur.execute(
"""
INSERT INTO vault_bundle (type, object_id, sticky)
VALUES (%s, %s, %s)""",
(obj_type, obj_id, sticky),
)
db.conn.commit()
task_id = self._send_task(obj_type, hex_id)
cur.execute(
"""
UPDATE vault_bundle
SET task_id = %s
WHERE type = %s AND object_id = %s""",
(task_id, obj_type, obj_id),
)
@db_transaction()
def add_notif_email(self, obj_type, obj_id, email, db=None, cur=None):
"""Add an e-mail address to notify when a given bundle is ready"""
obj_id = hashutil.hash_to_bytes(obj_id)
cur.execute(
"""
INSERT INTO vault_notif_email (email, bundle_id)
VALUES (%s, (SELECT id FROM vault_bundle
WHERE type = %s AND object_id = %s))""",
(email, obj_type, obj_id),
)
@db_transaction()
def cook_request(
self, obj_type, obj_id, *, sticky=False, email=None, db=None, cur=None
):
"""Main entry point for cooking requests. This starts a cooking task if
needed, and adds the given e-mail to the notification list"""
obj_id = hashutil.hash_to_bytes(obj_id)
info = self.task_info(obj_type, obj_id)
# If there's a failed bundle entry, delete it first.
if info is not None and info["task_status"] == "failed":
cur.execute(
"""DELETE FROM vault_bundle
WHERE type = %s AND object_id = %s""",
(obj_type, obj_id),
)
db.conn.commit()
info = None
# If there's no bundle entry, create the task.
if info is None:
self.create_task(obj_type, obj_id, sticky)
if email is not None:
# If the task is already done, send the email directly
if info is not None and info["task_status"] == "done":
self.send_notification(
None, email, obj_type, obj_id, info["task_status"]
)
# Else, add it to the notification queue
else:
self.add_notif_email(obj_type, obj_id, email)
info = self.task_info(obj_type, obj_id)
return info
@db_transaction()
def batch_cook(self, batch, db=None, cur=None):
"""Cook a batch of bundles and returns the cooking id."""
# Import execute_values at runtime only, because it requires
# psycopg2 >= 2.7, which is only deployed on the postgresql servers
from psycopg2.extras import execute_values
cur.execute(
"""
INSERT INTO vault_batch (id)
VALUES (DEFAULT)
RETURNING id"""
)
batch_id = cur.fetchone()["id"]
batch = batch_to_bytes(batch)
# Delete all failed bundles from the batch
cur.execute(
"""
DELETE FROM vault_bundle
WHERE task_status = 'failed'
AND (type, object_id) IN %s""",
(tuple(batch),),
)
# Insert all the bundles, return the new ones
execute_values(
cur,
"""
INSERT INTO vault_bundle (type, object_id)
VALUES %s ON CONFLICT DO NOTHING""",
batch,
)
# Get the bundle ids and task status
cur.execute(
"""
SELECT id, type, object_id, task_id FROM vault_bundle
WHERE (type, object_id) IN %s""",
(tuple(batch),),
)
bundles = cur.fetchall()
# Insert the batch-bundle entries
batch_id_bundle_ids = [(batch_id, row["id"]) for row in bundles]
execute_values(
cur,
"""
INSERT INTO vault_batch_bundle (batch_id, bundle_id)
VALUES %s ON CONFLICT DO NOTHING""",
batch_id_bundle_ids,
)
db.conn.commit()
# Get the tasks to fetch
batch_new = [
(row["type"], bytes(row["object_id"]))
for row in bundles
if row["task_id"] is None
]
# Send the tasks
args_batch = [
(obj_type, hashutil.hash_to_hex(obj_id)) for obj_type, obj_id in batch_new
]
# TODO: change once the scheduler handles priority tasks
tasks = [
create_oneshot_task_dict("swh-vault-batch-cooking", *args)
for args in args_batch
]
added_tasks = self.scheduler.create_tasks(tasks)
tasks_ids_bundle_ids = zip([task["id"] for task in added_tasks], batch_new)
tasks_ids_bundle_ids = [
(task_id, obj_type, obj_id)
for task_id, (obj_type, obj_id) in tasks_ids_bundle_ids
]
# Update the task ids
execute_values(
cur,
"""
UPDATE vault_bundle
SET task_id = s_task_id
FROM (VALUES %s) AS sub (s_task_id, s_type, s_object_id)
WHERE type = s_type::cook_type AND object_id = s_object_id """,
tasks_ids_bundle_ids,
)
return batch_id
@db_transaction()
def batch_info(self, batch_id, db=None, cur=None):
"""Fetch information from a batch of bundles"""
cur.execute(
"""
SELECT vault_bundle.id as id,
type, object_id, task_id, task_status, sticky,
ts_created, ts_done, ts_last_access, progress_msg
FROM vault_batch_bundle
LEFT JOIN vault_bundle ON vault_bundle.id = bundle_id
WHERE batch_id = %s""",
(batch_id,),
)
res = cur.fetchall()
if res:
for d in res:
d["object_id"] = bytes(d["object_id"])
return res
@db_transaction()
def is_available(self, obj_type, obj_id, db=None, cur=None):
"""Check whether a bundle is available for retrieval"""
info = self.task_info(obj_type, obj_id, cur=cur)
return (
info is not None
and info["task_status"] == "done"
and self.cache.is_cached(obj_type, obj_id)
)
@db_transaction()
def fetch(self, obj_type, obj_id, db=None, cur=None):
"""Retrieve a bundle from the cache"""
if not self.is_available(obj_type, obj_id, cur=cur):
return None
self.update_access_ts(obj_type, obj_id, cur=cur)
return self.cache.get(obj_type, obj_id)
@db_transaction()
def update_access_ts(self, obj_type, obj_id, db=None, cur=None):
"""Update the last access timestamp of a bundle"""
obj_id = hashutil.hash_to_bytes(obj_id)
cur.execute(
"""
UPDATE vault_bundle
SET ts_last_access = NOW()
WHERE type = %s AND object_id = %s""",
(obj_type, obj_id),
)
@db_transaction()
def set_status(self, obj_type, obj_id, status, db=None, cur=None):
"""Set the cooking status of a bundle"""
obj_id = hashutil.hash_to_bytes(obj_id)
req = (
"""
UPDATE vault_bundle
SET task_status = %s """
+ (""", ts_done = NOW() """ if status == "done" else "")
+ """WHERE type = %s AND object_id = %s"""
)
cur.execute(req, (status, obj_type, obj_id))
@db_transaction()
def set_progress(self, obj_type, obj_id, progress, db=None, cur=None):
"""Set the cooking progress of a bundle"""
obj_id = hashutil.hash_to_bytes(obj_id)
cur.execute(
"""
UPDATE vault_bundle
SET progress_msg = %s
WHERE type = %s AND object_id = %s""",
(progress, obj_type, obj_id),
)
@db_transaction()
def send_all_notifications(self, obj_type, obj_id, db=None, cur=None):
"""Send all the e-mails in the notification list of a bundle"""
obj_id = hashutil.hash_to_bytes(obj_id)
cur.execute(
"""
SELECT vault_notif_email.id AS id, email, task_status, progress_msg
FROM vault_notif_email
INNER JOIN vault_bundle ON bundle_id = vault_bundle.id
WHERE vault_bundle.type = %s AND vault_bundle.object_id = %s""",
(obj_type, obj_id),
)
for d in cur:
self.send_notification(
d["id"],
d["email"],
obj_type,
obj_id,
status=d["task_status"],
progress_msg=d["progress_msg"],
)
@db_transaction()
def send_notification(
self,
n_id,
email,
obj_type,
obj_id,
status,
progress_msg=None,
db=None,
cur=None,
):
"""Send the notification of a bundle to a specific e-mail"""
hex_id = hashutil.hash_to_hex(obj_id)
short_id = hex_id[:7]
# TODO: instead of hardcoding this, we should probably:
# * add a "fetch_url" field in the vault_notif_email table
# * generate the url with flask.url_for() on the web-ui side
# * send this url as part of the cook request and store it in
# the table
# * use this url for the notification e-mail
url = "https://archive.softwareheritage.org/api/1/vault/{}/{}/" "raw".format(
obj_type, hex_id
)
if status == "done":
text = NOTIF_EMAIL_BODY_SUCCESS.strip()
text = text.format(obj_type=obj_type, hex_id=hex_id, url=url)
msg = MIMEText(text)
msg["Subject"] = NOTIF_EMAIL_SUBJECT_SUCCESS.format(
obj_type=obj_type, short_id=short_id
)
elif status == "failed":
text = NOTIF_EMAIL_BODY_FAILURE.strip()
text = text.format(
obj_type=obj_type, hex_id=hex_id, progress_msg=progress_msg
)
msg = MIMEText(text)
msg["Subject"] = NOTIF_EMAIL_SUBJECT_FAILURE.format(
obj_type=obj_type, short_id=short_id
)
else:
raise RuntimeError(
"send_notification called on a '{}' bundle".format(status)
)
msg["From"] = NOTIF_EMAIL_FROM
msg["To"] = email
self._smtp_send(msg)
if n_id is not None:
cur.execute(
"""
DELETE FROM vault_notif_email
WHERE id = %s""",
(n_id,),
)
def _smtp_send(self, msg):
# Reconnect if needed
try:
status = self.smtp_server.noop()[0]
except smtplib.SMTPException:
status = -1
if status != 250:
self.smtp_server.connect("localhost", 25)
# Send the message
self.smtp_server.send_message(msg)
@db_transaction()
def _cache_expire(self, cond, *args, db=None, cur=None):
"""Low-level expiration method, used by cache_expire_* methods"""
# Embedded SELECT query to be able to use ORDER BY and LIMIT
cur.execute(
"""
DELETE FROM vault_bundle
WHERE ctid IN (
SELECT ctid
FROM vault_bundle
WHERE sticky = false
{}
)
RETURNING type, object_id
""".format(
cond
),
args,
)
for d in cur:
self.cache.delete(d["type"], bytes(d["object_id"]))
@db_transaction()
def cache_expire_oldest(self, n=1, by="last_access", db=None, cur=None):
"""Expire the `n` oldest bundles"""
assert by in ("created", "done", "last_access")
filter = """ORDER BY ts_{} LIMIT {}""".format(by, n)
return self._cache_expire(filter)
@db_transaction()
def cache_expire_until(self, date, by="last_access", db=None, cur=None):
"""Expire all the bundles until a certain date"""
assert by in ("created", "done", "last_access")
filter = """AND ts_{} <= %s""".format(by)
return self._cache_expire(filter, date)
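
The constructor change above is the crux of this diff: VaultBackend now builds its own cache, storage and scheduler from nested config dicts instead of receiving pre-built objects, and it opens its psycopg2 connection pool eagerly. A minimal sketch of instantiating it without a live database, using the same patching trick as the new tests (config values taken from test_init.py below):

    from unittest import mock

    from swh.vault.backend import VaultBackend

    # With the pool module patched, ThreadedConnectionPool returns a mock
    # and no PostgreSQL connection is attempted at construction time.
    with mock.patch("swh.vault.backend.psycopg2.pool"):
        backend = VaultBackend(
            db="dbname=ignored",
            cache={"cls": "memory", "args": {}},
            storage={"cls": "remote", "url": "mock://storage-url"},
            scheduler={"cls": "remote", "url": "mock://scheduler-url"},
        )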
diff --git a/swh/vault/tests/conftest.py b/swh/vault/tests/conftest.py
index 9090e46..163ca80 100644
--- a/swh/vault/tests/conftest.py
+++ b/swh/vault/tests/conftest.py
@@ -1,88 +1,88 @@
import glob
import os
import subprocess
+from typing import Any, Dict
import pkg_resources.extern.packaging.version
import pytest
from pytest_postgresql import factories
from swh.core.utils import numfile_sortkey as sortkey
from swh.storage.tests import SQL_DIR as STORAGE_SQL_DIR
from swh.vault import get_vault
from swh.vault.tests import SQL_DIR
os.environ["LC_ALL"] = "C.UTF-8"
pytest_v = pkg_resources.get_distribution("pytest").parsed_version
if pytest_v < pkg_resources.extern.packaging.version.parse("3.9"):
@pytest.fixture
def tmp_path(request):
import pathlib
import tempfile
with tempfile.TemporaryDirectory() as tmpdir:
yield pathlib.Path(tmpdir)
def db_url(name, postgresql_proc):
return "postgresql://{user}@{host}:{port}/{dbname}".format(
host=postgresql_proc.host,
port=postgresql_proc.port,
user="postgres",
dbname=name,
)
postgresql2 = factories.postgresql("postgresql_proc", "tests2")
@pytest.fixture
-def swh_vault(request, postgresql_proc, postgresql, postgresql2, tmp_path):
+def swh_vault_config(postgresql, postgresql2, tmp_path) -> Dict[str, Any]:
+ tmp_path = str(tmp_path)
+ return {
+ "db": postgresql.dsn,
+ "storage": {
+ "cls": "local",
+ "db": postgresql2.dsn,
+ "objstorage": {
+ "cls": "pathslicing",
+ "args": {"root": tmp_path, "slicing": "0:1/1:5",},
+ },
+ },
+ "cache": {
+ "cls": "pathslicing",
+ "args": {"root": tmp_path, "slicing": "0:1/1:5", "allow_delete": True,},
+ },
+ "scheduler": {"cls": "remote", "url": "http://swh-scheduler:5008",},
+ }
+
+@pytest.fixture
+def swh_vault(request, swh_vault_config, postgresql, postgresql2, tmp_path):
for sql_dir, pg in ((SQL_DIR, postgresql), (STORAGE_SQL_DIR, postgresql2)):
dump_files = os.path.join(sql_dir, "*.sql")
all_dump_files = sorted(glob.glob(dump_files), key=sortkey)
for fname in all_dump_files:
subprocess.check_call(
[
"psql",
"--quiet",
"--no-psqlrc",
"-v",
"ON_ERROR_STOP=1",
"-d",
pg.dsn,
"-f",
fname,
]
)
- vault_config = {
- "db": db_url("tests", postgresql_proc),
- "storage": {
- "cls": "local",
- "db": db_url("tests2", postgresql_proc),
- "objstorage": {
- "cls": "pathslicing",
- "args": {"root": str(tmp_path), "slicing": "0:1/1:5",},
- },
- },
- "cache": {
- "cls": "pathslicing",
- "args": {
- "root": str(tmp_path),
- "slicing": "0:1/1:5",
- "allow_delete": True,
- },
- },
- "scheduler": {"cls": "remote", "url": "http://swh-scheduler:5008",},
- }
-
- return get_vault("local", vault_config)
+ return get_vault("local", **swh_vault_config)
@pytest.fixture
def swh_storage(swh_vault):
return swh_vault.storage
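
Extracting the configuration into a standalone swh_vault_config fixture is what lets test_init.py below exercise get_vault("local", ...) against the same config shape the integration fixtures use, without duplicating it.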
diff --git a/swh/vault/tests/test_init.py b/swh/vault/tests/test_init.py
new file mode 100644
index 0000000..7f402d6
--- /dev/null
+++ b/swh/vault/tests/test_init.py
@@ -0,0 +1,55 @@
+# Copyright (C) 2020 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+import pytest
+
+from swh.vault import get_vault
+from swh.vault.api.client import RemoteVaultClient
+from swh.vault.backend import VaultBackend
+
+SERVER_IMPLEMENTATIONS = [
+ ("remote", RemoteVaultClient, {"url": "localhost"}),
+ (
+ "local",
+ VaultBackend,
+ {
+ "db": "something",
+ "cache": {"cls": "memory", "args": {}},
+ "storage": {"cls": "remote", "url": "mock://storage-url"},
+ "scheduler": {"cls": "remote", "url": "mock://scheduler-url"},
+ },
+ ),
+]
+
+
+@pytest.fixture
+def mock_psycopg2(mocker):
+ mocker.patch("swh.vault.backend.psycopg2.pool")
+ mocker.patch("swh.vault.backend.psycopg2.extras")
+
+
+def test_init_get_vault_failure():
+ with pytest.raises(ValueError, match="Unknown Vault class"):
+ get_vault("unknown-vault-storage")
+
+
+@pytest.mark.parametrize("class_name,expected_class,kwargs", SERVER_IMPLEMENTATIONS)
+def test_init_get_vault(class_name, expected_class, kwargs, mock_psycopg2):
+ concrete_vault = get_vault(class_name, **kwargs)
+ assert isinstance(concrete_vault, expected_class)
+
+
+@pytest.mark.parametrize("class_name,expected_class,kwargs", SERVER_IMPLEMENTATIONS)
+def test_init_get_vault_deprecation_warning(
+ class_name, expected_class, kwargs, mock_psycopg2
+):
+ with pytest.warns(DeprecationWarning):
+ concrete_vault = get_vault(class_name, args=kwargs)
+ assert isinstance(concrete_vault, expected_class)
+
+
+def test_init_get_vault_ok(swh_vault_config):
+ concrete_vault = get_vault("local", **swh_vault_config)
+ assert isinstance(concrete_vault, VaultBackend)
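
Together these tests cover every path through the new get_vault(): the unknown-class error, direct keyword arguments for both backends, the deprecated args= convention (asserting the DeprecationWarning), and one instantiation against the real swh_vault_config fixture.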