diff --git a/requirements.txt b/requirements.txt --- a/requirements.txt +++ b/requirements.txt @@ -3,4 +3,4 @@ psycopg2 python-dateutil fastimport - +typing-extensions diff --git a/swh/vault/api/client.py b/swh/vault/api/client.py --- a/swh/vault/api/client.py +++ b/swh/vault/api/client.py @@ -4,56 +4,12 @@ # See top-level LICENSE file for more information from swh.core.api import RPCClient -from swh.model import hashutil from swh.vault.exc import NotFoundExc +from swh.vault.interface import VaultInterface class RemoteVaultClient(RPCClient): """Client to the Software Heritage vault cache.""" + backend_class = VaultInterface reraise_exceptions = [NotFoundExc] - - # Web API endpoints - - def fetch(self, obj_type, obj_id): - hex_id = hashutil.hash_to_hex(obj_id) - return self.get("fetch/{}/{}".format(obj_type, hex_id)) - - def cook(self, obj_type, obj_id, email=None): - hex_id = hashutil.hash_to_hex(obj_id) - return self.post( - "cook/{}/{}".format(obj_type, hex_id), - data={}, - params=({"email": email} if email else None), - ) - - def progress(self, obj_type, obj_id): - hex_id = hashutil.hash_to_hex(obj_id) - return self.get("progress/{}/{}".format(obj_type, hex_id)) - - # Cookers endpoints - - def set_progress(self, obj_type, obj_id, progress): - hex_id = hashutil.hash_to_hex(obj_id) - return self.post("set_progress/{}/{}".format(obj_type, hex_id), data=progress) - - def set_status(self, obj_type, obj_id, status): - hex_id = hashutil.hash_to_hex(obj_id) - return self.post("set_status/{}/{}".format(obj_type, hex_id), data=status) - - # TODO: handle streaming properly - def put_bundle(self, obj_type, obj_id, bundle): - hex_id = hashutil.hash_to_hex(obj_id) - return self.post("put_bundle/{}/{}".format(obj_type, hex_id), data=bundle) - - def send_notif(self, obj_type, obj_id): - hex_id = hashutil.hash_to_hex(obj_id) - return self.post("send_notif/{}/{}".format(obj_type, hex_id), data=None) - - # Batch endpoints - - def batch_cook(self, batch): - return self.post("batch_cook", data=batch) - - def batch_progress(self, batch_id): - return self.get("batch_progress/{}".format(batch_id)) diff --git a/swh/vault/api/server.py b/swh/vault/api/server.py --- a/swh/vault/api/server.py +++ b/swh/vault/api/server.py @@ -4,196 +4,61 @@ # See top-level LICENSE file for more information import asyncio -import collections -import os from typing import Any, Dict, Optional import aiohttp.web -from swh.core import config -from swh.core.api.asynchronous import RPCServerApp, decode_request -from swh.core.api.asynchronous import encode_data_server as encode_data -from swh.model import hashutil -from swh.vault import get_vault +from swh.core.api.asynchronous import RPCServerApp +from swh.core.config import load_from_envvar +from swh.vault import get_vault as get_swhvault from swh.vault.backend import NotFoundExc -from swh.vault.cookers import COOKER_TYPES +from swh.vault.interface import VaultInterface -DEFAULT_CONFIG_PATH = "vault/server" DEFAULT_CONFIG = { - "storage": ("dict", {"cls": "remote", "args": {"url": "http://localhost:5002/",},}), - "cache": ( - "dict", - { - "cls": "pathslicing", - "args": {"root": "/srv/softwareheritage/vault", "slicing": "0:1/1:5",}, - }, - ), - "client_max_size": ("int", 1024 ** 3), - "vault": ( - "dict", - {"cls": "local", "args": {"db": "dbname=softwareheritage-vault-dev",},}, - ), - "scheduler": ("dict", {"cls": "remote", "url": "http://localhost:5008/",},), + "storage": {"cls": "remote", "args": {"url": "http://localhost:5002/",},}, + "cache": { + "cls": "pathslicing", + "args": {"root": "/srv/softwareheritage/vault", "slicing": "0:1/1:5",}, + }, + "client_max_size": 1024 ** 3, + "vault": {"cls": "local", "args": {"db": "dbname=softwareheritage-vault-dev",},}, + "scheduler": {"cls": "remote", "url": "http://localhost:5008/",}, } -@asyncio.coroutine -def index(request): - return aiohttp.web.Response(body="SWH Vault API server") - - -# Web API endpoints - - -@asyncio.coroutine -def vault_fetch(request): - obj_type = request.match_info["type"] - obj_id = request.match_info["id"] - - if not request.app["backend"].is_available(obj_type, obj_id): - raise NotFoundExc(f"{obj_type} {obj_id} is not available.") - - return encode_data(request.app["backend"].fetch(obj_type, obj_id)) - - -def user_info(task_info): - return { - "id": task_info["id"], - "status": task_info["task_status"], - "progress_message": task_info["progress_msg"], - "obj_type": task_info["type"], - "obj_id": hashutil.hash_to_hex(task_info["object_id"]), - } - - -@asyncio.coroutine -def vault_cook(request): - obj_type = request.match_info["type"] - obj_id = request.match_info["id"] - email = request.query.get("email") - sticky = request.query.get("sticky") in ("true", "1") - - if obj_type not in COOKER_TYPES: - raise NotFoundExc(f"{obj_type} is an unknown type.") - - info = request.app["backend"].cook_request( - obj_type, obj_id, email=email, sticky=sticky - ) - - # TODO: return 201 status (Created) once the api supports it - return encode_data(user_info(info)) - - -@asyncio.coroutine -def vault_progress(request): - obj_type = request.match_info["type"] - obj_id = request.match_info["id"] - - info = request.app["backend"].task_info(obj_type, obj_id) - if not info: - raise NotFoundExc(f"{obj_type} {obj_id} was not found.") - - return encode_data(user_info(info)) - - -# Cookers endpoints - - -@asyncio.coroutine -def set_progress(request): - obj_type = request.match_info["type"] - obj_id = request.match_info["id"] - progress = yield from decode_request(request) - request.app["backend"].set_progress(obj_type, obj_id, progress) - return encode_data(True) # FIXME: success value? +vault = None +app = None -@asyncio.coroutine -def set_status(request): - obj_type = request.match_info["type"] - obj_id = request.match_info["id"] - status = yield from decode_request(request) - request.app["backend"].set_status(obj_type, obj_id, status) - return encode_data(True) # FIXME: success value? - - -@asyncio.coroutine -def put_bundle(request): - obj_type = request.match_info["type"] - obj_id = request.match_info["id"] - - # TODO: handle streaming properly - content = yield from decode_request(request) - request.app["backend"].cache.add(obj_type, obj_id, content) - return encode_data(True) # FIXME: success value? - - -@asyncio.coroutine -def send_notif(request): - obj_type = request.match_info["type"] - obj_id = request.match_info["id"] - request.app["backend"].send_all_notifications(obj_type, obj_id) - return encode_data(True) # FIXME: success value? +def get_vault(config: Optional[Dict[str, Any]] = None) -> VaultInterface: + global vault + if not vault: + assert config is not None + vault = get_swhvault(**config) + return vault -# Batch endpoints +class VaultServerApp(RPCServerApp): + client_exception_classes = (NotFoundExc,) @asyncio.coroutine -def batch_cook(request): - batch = yield from decode_request(request) - for obj_type, obj_id in batch: - if obj_type not in COOKER_TYPES: - raise NotFoundExc(f"{obj_type} is an unknown type.") - batch_id = request.app["backend"].batch_cook(batch) - return encode_data({"id": batch_id}) - - -@asyncio.coroutine -def batch_progress(request): - batch_id = request.match_info["batch_id"] - bundles = request.app["backend"].batch_info(batch_id) - if not bundles: - raise NotFoundExc(f"Batch {batch_id} does not exist.") - bundles = [user_info(bundle) for bundle in bundles] - counter = collections.Counter(b["status"] for b in bundles) - res = { - "bundles": bundles, - "total": len(bundles), - **{k: 0 for k in ("new", "pending", "done", "failed")}, - **dict(counter), - } - return encode_data(res) - - -# Web server - - -def make_app(backend, **kwargs): - app = RPCServerApp(**kwargs) - app.router.add_route("GET", "/", index) - app.client_exception_classes = (NotFoundExc,) - - # Endpoints used by the web API - app.router.add_route("GET", "/fetch/{type}/{id}", vault_fetch) - app.router.add_route("POST", "/cook/{type}/{id}", vault_cook) - app.router.add_route("GET", "/progress/{type}/{id}", vault_progress) +def index(request): + return aiohttp.web.Response(body="SWH Vault API server") - # Endpoints used by the Cookers - app.router.add_route("POST", "/set_progress/{type}/{id}", set_progress) - app.router.add_route("POST", "/set_status/{type}/{id}", set_status) - app.router.add_route("POST", "/put_bundle/{type}/{id}", put_bundle) - app.router.add_route("POST", "/send_notif/{type}/{id}", send_notif) - # Endpoints for batch requests - app.router.add_route("POST", "/batch_cook", batch_cook) - app.router.add_route("GET", "/batch_progress/{batch_id}", batch_progress) +def check_config(cfg: Dict[str, Any]) -> Dict[str, Any]: + """Ensure the configuration is ok to run a local vault server - app["backend"] = backend - return app + Raises: + EnvironmentError if the configuration is not for local instance + ValueError if one of the following keys is missing: vault, cache, storage, + scheduler + Returns: + Configuration dict to instantiate a local vault server instance -def check_config(cfg: Dict[str, Any]) -> Dict[str, Any]: + """ if "vault" not in cfg: raise ValueError("missing 'vault' configuration") @@ -214,21 +79,34 @@ if not args.get(key): raise ValueError(f"invalid configuration: missing {key} config entry.") - return args - - -def make_app_from_configfile(config_file: Optional[str] = None, **kwargs): - if config_file is None: - config_file = DEFAULT_CONFIG_PATH - config_file = os.environ.get("SWH_CONFIG_FILENAME", config_file) - assert config_file is not None - if os.path.isfile(config_file): - cfg = config.read(config_file, DEFAULT_CONFIG) - else: - cfg = config.load_named_config(config_file, DEFAULT_CONFIG) - kwargs = check_config(cfg) - vault = get_vault("local", **kwargs) - return make_app(backend=vault, client_max_size=cfg["client_max_size"], **kwargs) + return cfg + + +def make_app(config_to_check: Dict[str, Any]) -> VaultServerApp: + """Ensure the configuration is ok, then instantiate the server application + + """ + config_ok = check_config(config_to_check) + app = VaultServerApp( + __name__, + backend_class=VaultInterface, + backend_factory=lambda: get_vault(config_ok["vault"]), + client_max_size=config_ok["client_max_size"], + ) + app.router.add_route("GET", "/", index) + return app + + +def make_app_from_configfile(**kwargs) -> VaultServerApp: + """Check configuration then instantiate once a vault server application. + + """ + global app + if not app: + app_config = load_from_envvar(DEFAULT_CONFIG) + app = make_app(app_config) + + return app if __name__ == "__main__": diff --git a/swh/vault/backend.py b/swh/vault/backend.py --- a/swh/vault/backend.py +++ b/swh/vault/backend.py @@ -3,8 +3,10 @@ # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information +import collections from email.mime.text import MIMEText import smtplib +from typing import Any, Dict, List, Optional, Tuple import psycopg2.extras import psycopg2.pool @@ -16,7 +18,7 @@ from swh.scheduler.utils import create_oneshot_task_dict from swh.storage import get_storage from swh.vault.cache import VaultCache -from swh.vault.cookers import get_cooker_cls +from swh.vault.cookers import COOKER_TYPES, get_cooker_cls from swh.vault.exc import NotFoundExc cooking_task_name = "swh.vault.cooking_tasks.SWHCookingTask" @@ -61,7 +63,7 @@ """ -def batch_to_bytes(batch): +def batch_to_bytes(batch: List[Tuple[str, str]]) -> List[Tuple[str, bytes]]: return [(obj_type, hashutil.hash_to_bytes(obj_id)) for obj_type, obj_id in batch] @@ -95,9 +97,16 @@ db.put_conn() @db_transaction() - def task_info(self, obj_type, obj_id, db=None, cur=None): - """Fetch information from a bundle""" - obj_id = hashutil.hash_to_bytes(obj_id) + def task_info( + self, + obj_type: str, + obj_id: bytes, + raise_notfound: bool = True, + db=None, + cur=None, + ) -> Optional[Dict[str, Any]]: + hex_id = hashutil.hash_to_hex(obj_id) + cur.execute( """ SELECT id, type, object_id, task_id, task_status, sticky, @@ -107,26 +116,33 @@ (obj_type, obj_id), ) res = cur.fetchone() - if res: - res["object_id"] = bytes(res["object_id"]) + if not res: + if raise_notfound: + hex_id = hashutil.hash_to_hex(obj_id) + raise NotFoundExc(f"{obj_type} {hex_id} was not found.") + return None + + res["object_id"] = bytes(res["object_id"]) return res - def _send_task(self, *args): + def _send_task(self, obj_type: str, hex_id: str): """Send a cooking task to the celery scheduler""" - task = create_oneshot_task_dict("cook-vault-bundle", *args) + task = create_oneshot_task_dict("cook-vault-bundle", obj_type, hex_id) added_tasks = self.scheduler.create_tasks([task]) return added_tasks[0]["id"] @db_transaction() - def create_task(self, obj_type, obj_id, sticky=False, db=None, cur=None): + def create_task( + self, obj_type: str, obj_id: bytes, sticky: bool = False, db=None, cur=None + ): """Create and send a cooking task""" - obj_id = hashutil.hash_to_bytes(obj_id) hex_id = hashutil.hash_to_hex(obj_id) cooker_class = get_cooker_cls(obj_type) cooker = cooker_class(obj_type, hex_id, backend=self, storage=self.storage) + if not cooker.check_exists(): - raise NotFoundExc("Object {} was not found.".format(hex_id)) + raise NotFoundExc(f"{obj_type} {hex_id} was not found.") cur.execute( """ @@ -147,9 +163,10 @@ ) @db_transaction() - def add_notif_email(self, obj_type, obj_id, email, db=None, cur=None): + def add_notif_email( + self, obj_type: str, obj_id: bytes, email: str, db=None, cur=None + ): """Add an e-mail address to notify when a given bundle is ready""" - obj_id = hashutil.hash_to_bytes(obj_id) cur.execute( """ INSERT INTO vault_notif_email (email, bundle_id) @@ -158,20 +175,30 @@ (email, obj_type, obj_id), ) + def put_bundle(self, obj_type: str, obj_id: bytes, bundle) -> bool: + self.cache.add(obj_type, obj_id, bundle) + return True + @db_transaction() def cook_request( - self, obj_type, obj_id, *, sticky=False, email=None, db=None, cur=None - ): - """Main entry point for cooking requests. This starts a cooking task if - needed, and add the given e-mail to the notify list""" - obj_id = hashutil.hash_to_bytes(obj_id) - info = self.task_info(obj_type, obj_id) + self, + obj_type: str, + obj_id: bytes, + *, + sticky: bool = False, + email: Optional[str] = None, + db=None, + cur=None, + ) -> Dict[str, Any]: + info = self.task_info(obj_type, obj_id, raise_notfound=False) + + if obj_type not in COOKER_TYPES: + raise NotFoundExc(f"{obj_type} is an unknown type.") # If there's a failed bundle entry, delete it first. if info is not None and info["task_status"] == "failed": cur.execute( - """DELETE FROM vault_bundle - WHERE type = %s AND object_id = %s""", + "DELETE FROM vault_bundle WHERE type = %s AND object_id = %s", (obj_type, obj_id), ) db.conn.commit() @@ -191,16 +218,20 @@ else: self.add_notif_email(obj_type, obj_id, email) - info = self.task_info(obj_type, obj_id) - return info + return self.task_info(obj_type, obj_id) @db_transaction() - def batch_cook(self, batch, db=None, cur=None): - """Cook a batch of bundles and returns the cooking id.""" + def batch_cook( + self, batch: List[Tuple[str, str]], db=None, cur=None + ) -> Dict[str, int]: # Import execute_values at runtime only, because it requires # psycopg2 >= 2.7 (only available on postgresql servers) from psycopg2.extras import execute_values + for obj_type, _ in batch: + if obj_type not in COOKER_TYPES: + raise NotFoundExc(f"{obj_type} is an unknown type.") + cur.execute( """ INSERT INTO vault_batch (id) @@ -208,7 +239,7 @@ RETURNING id""" ) batch_id = cur.fetchone()["id"] - batch = batch_to_bytes(batch) + batch_bytes = batch_to_bytes(batch) # Delete all failed bundles from the batch cur.execute( @@ -216,7 +247,7 @@ DELETE FROM vault_bundle WHERE task_status = 'failed' AND (type, object_id) IN %s""", - (tuple(batch),), + (tuple(batch_bytes),), ) # Insert all the bundles, return the new ones @@ -225,7 +256,7 @@ """ INSERT INTO vault_bundle (type, object_id) VALUES %s ON CONFLICT DO NOTHING""", - batch, + batch_bytes, ) # Get the bundle ids and task status @@ -233,7 +264,7 @@ """ SELECT id, type, object_id, task_id FROM vault_bundle WHERE (type, object_id) IN %s""", - (tuple(batch),), + (tuple(batch_bytes),), ) bundles = cur.fetchall() @@ -266,10 +297,11 @@ ] added_tasks = self.scheduler.create_tasks(tasks) - tasks_ids_bundle_ids = zip([task["id"] for task in added_tasks], batch_new) tasks_ids_bundle_ids = [ (task_id, obj_type, obj_id) - for task_id, (obj_type, obj_id) in tasks_ids_bundle_ids + for task_id, (obj_type, obj_id) in zip( + [task["id"] for task in added_tasks], batch_new + ) ] # Update the task ids @@ -282,11 +314,10 @@ WHERE type = s_type::cook_type AND object_id = s_object_id """, tasks_ids_bundle_ids, ) - return batch_id + return {"id": batch_id} @db_transaction() - def batch_info(self, batch_id, db=None, cur=None): - """Fetch information from a batch of bundles""" + def batch_progress(self, batch_id: int, db=None, cur=None) -> Dict[str, Any]: cur.execute( """ SELECT vault_bundle.id as id, @@ -297,16 +328,26 @@ WHERE batch_id = %s""", (batch_id,), ) - res = cur.fetchall() - if res: - for d in res: - d["object_id"] = bytes(d["object_id"]) + bundles = cur.fetchall() + if not bundles: + raise NotFoundExc(f"Batch {batch_id} does not exist.") + + for bundle in bundles: + bundle["object_id"] = bytes(bundle["object_id"]) + counter = collections.Counter(b["status"] for b in bundles) + res = { + "bundles": bundles, + "total": len(bundles), + **{k: 0 for k in ("new", "pending", "done", "failed")}, + **dict(counter), + } + return res @db_transaction() - def is_available(self, obj_type, obj_id, db=None, cur=None): + def is_available(self, obj_type: str, obj_id: bytes, db=None, cur=None): """Check whether a bundle is available for retrieval""" - info = self.task_info(obj_type, obj_id, cur=cur) + info = self.task_info(obj_type, obj_id, raise_notfound=False, cur=cur) return ( info is not None and info["task_status"] == "done" @@ -314,15 +355,21 @@ ) @db_transaction() - def fetch(self, obj_type, obj_id, db=None, cur=None): + def fetch( + self, obj_type: str, obj_id: bytes, raise_notfound=True, db=None, cur=None + ): """Retrieve a bundle from the cache""" - if not self.is_available(obj_type, obj_id, cur=cur): + available = self.is_available(obj_type, obj_id, cur=cur) + if not available: + if raise_notfound: + hex_id = hashutil.hash_to_hex(obj_id) + raise NotFoundExc(f"{obj_type} {hex_id} is not available.") return None self.update_access_ts(obj_type, obj_id, cur=cur) return self.cache.get(obj_type, obj_id) @db_transaction() - def update_access_ts(self, obj_type, obj_id, db=None, cur=None): + def update_access_ts(self, obj_type: str, obj_id: bytes, db=None, cur=None): """Update the last access timestamp of a bundle""" obj_id = hashutil.hash_to_bytes(obj_id) cur.execute( @@ -334,8 +381,9 @@ ) @db_transaction() - def set_status(self, obj_type, obj_id, status, db=None, cur=None): - """Set the cooking status of a bundle""" + def set_status( + self, obj_type: str, obj_id: bytes, status: str, db=None, cur=None + ) -> bool: obj_id = hashutil.hash_to_bytes(obj_id) req = ( """ @@ -345,10 +393,12 @@ + """WHERE type = %s AND object_id = %s""" ) cur.execute(req, (status, obj_type, obj_id)) + return True @db_transaction() - def set_progress(self, obj_type, obj_id, progress, db=None, cur=None): - """Set the cooking progress of a bundle""" + def set_progress( + self, obj_type: str, obj_id: bytes, progress: str, db=None, cur=None + ) -> bool: obj_id = hashutil.hash_to_bytes(obj_id) cur.execute( """ @@ -357,11 +407,12 @@ WHERE type = %s AND object_id = %s""", (progress, obj_type, obj_id), ) + return True @db_transaction() - def send_all_notifications(self, obj_type, obj_id, db=None, cur=None): - """Send all the e-mails in the notification list of a bundle""" - obj_id = hashutil.hash_to_bytes(obj_id) + def send_all_notifications( + self, obj_type: str, obj_id: bytes, db=None, cur=None + ) -> bool: cur.execute( """ SELECT vault_notif_email.id AS id, email, task_status, progress_msg @@ -379,19 +430,20 @@ status=d["task_status"], progress_msg=d["progress_msg"], ) + return True @db_transaction() def send_notification( self, - n_id, - email, - obj_type, - obj_id, - status, - progress_msg=None, + n_id: Optional[int], + email: str, + obj_type: str, + obj_id: bytes, + status: str, + progress_msg: Optional[str] = None, db=None, cur=None, - ): + ) -> None: """Send the notification of a bundle to a specific e-mail""" hex_id = hashutil.hash_to_hex(obj_id) short_id = hex_id[:7] @@ -440,7 +492,7 @@ (n_id,), ) - def _smtp_send(self, msg): + def _smtp_send(self, msg: MIMEText): # Reconnect if needed try: status = self.smtp_server.noop()[0] @@ -453,7 +505,7 @@ self.smtp_server.send_message(msg) @db_transaction() - def _cache_expire(self, cond, *args, db=None, cur=None): + def _cache_expire(self, cond, *args, db=None, cur=None) -> None: """Low-level expiration method, used by cache_expire_* methods""" # Embedded SELECT query to be able to use ORDER BY and LIMIT cur.execute( @@ -476,14 +528,14 @@ self.cache.delete(d["type"], bytes(d["object_id"])) @db_transaction() - def cache_expire_oldest(self, n=1, by="last_access", db=None, cur=None): + def cache_expire_oldest(self, n=1, by="last_access", db=None, cur=None) -> None: """Expire the `n` oldest bundles""" assert by in ("created", "done", "last_access") filter = """ORDER BY ts_{} LIMIT {}""".format(by, n) return self._cache_expire(filter) @db_transaction() - def cache_expire_until(self, date, by="last_access", db=None, cur=None): + def cache_expire_until(self, date, by="last_access", db=None, cur=None) -> None: """Expire all the bundles until a certain date""" assert by in ("created", "done", "last_access") filter = """AND ts_{} <= %s""".format(by) diff --git a/swh/vault/cookers/base.py b/swh/vault/cookers/base.py --- a/swh/vault/cookers/base.py +++ b/swh/vault/cookers/base.py @@ -133,4 +133,4 @@ self.backend.set_status(self.obj_type, self.obj_id, "done") self.backend.set_progress(self.obj_type, self.obj_id, None) finally: - self.backend.send_notif(self.obj_type, self.obj_id) + self.backend.send_all_notifications(self.obj_type, self.obj_id) diff --git a/swh/vault/interface.py b/swh/vault/interface.py new file mode 100644 --- /dev/null +++ b/swh/vault/interface.py @@ -0,0 +1,68 @@ +# Copyright (C) 2017-2020 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + +from typing import Any, Dict, List, Optional, Tuple + +from typing_extensions import Protocol, runtime_checkable + +from swh.core.api import remote_api_endpoint + + +@runtime_checkable +class VaultInterface(Protocol): + """ + Backend Interface for the Software Heritage vault. + """ + + @remote_api_endpoint("fetch") + def fetch(self, obj_type: str, obj_id: bytes) -> Dict[str, Any]: + """Fetch information from a bundle""" + ... + + @remote_api_endpoint("cook") + def cook_request( + self, obj_type: str, obj_id: bytes, email: Optional[str] = None + ) -> Dict[str, Any]: + """Main entry point for cooking requests. This starts a cooking task if + needed, and add the given e-mail to the notify list""" + ... + + @remote_api_endpoint("progress") + def task_info(self, obj_type: str, obj_id: bytes): + ... + + # Cookers endpoints + + @remote_api_endpoint("set_progress") + def set_progress(self, obj_type: str, obj_id: bytes, progress: str) -> None: + """Set the cooking progress of a bundle""" + ... + + @remote_api_endpoint("set_status") + def set_status(self, obj_type: str, obj_id: bytes, status: str) -> None: + """Set the cooking status of a bundle""" + ... + + @remote_api_endpoint("put_bundle") + def put_bundle(self, obj_type: str, obj_id: bytes, bundle): + """Store bundle in vault cache""" + ... + + @remote_api_endpoint("send_notif") + def send_all_notifications(self, obj_type: str, obj_id: bytes): + """Send all the e-mails in the notification list of a bundle""" + ... + + # Batch endpoints + + @remote_api_endpoint("batch_cook") + def batch_cook(self, batch: List[Tuple[str, bytes]]) -> int: + """Cook a batch of bundles and returns the cooking id.""" + ... + + @remote_api_endpoint("batch_progress") + def batch_progress(self, batch_id: int) -> Dict[str, Any]: + """Fetch information from a batch of bundles""" + ... diff --git a/swh/vault/tests/conftest.py b/swh/vault/tests/conftest.py --- a/swh/vault/tests/conftest.py +++ b/swh/vault/tests/conftest.py @@ -8,6 +8,7 @@ import pkg_resources.extern.packaging.version import pytest +import yaml from swh.core.db.pytest_plugin import postgresql_fact from swh.storage.tests import SQL_DIR as STORAGE_SQL_DIR @@ -63,12 +64,29 @@ }, "cache": { "cls": "pathslicing", - "args": {"root": tmp_path, "slicing": "0:1/1:5", "allow_delete": True,}, + "args": {"root": tmp_path, "slicing": "0:1/1:5", "allow_delete": True}, }, "scheduler": {"cls": "remote", "url": "http://swh-scheduler:5008",}, } +@pytest.fixture +def swh_local_vault_config(swh_vault_config: Dict[str, Any]) -> Dict[str, Any]: + return { + "vault": {"cls": "local", "args": swh_vault_config}, + "client_max_size": 1024 ** 3, + } + + +@pytest.fixture +def swh_vault_config_file(swh_local_vault_config, monkeypatch, tmp_path): + conf_path = os.path.join(str(tmp_path), "vault-server.yml") + with open(conf_path, "w") as f: + f.write(yaml.dump(swh_local_vault_config)) + monkeypatch.setenv("SWH_CONFIG_FILENAME", conf_path) + return conf_path + + @pytest.fixture def swh_vault(request, swh_vault_config): return get_vault("local", **swh_vault_config) diff --git a/swh/vault/tests/test_backend.py b/swh/vault/tests/test_backend.py --- a/swh/vault/tests/test_backend.py +++ b/swh/vault/tests/test_backend.py @@ -1,4 +1,4 @@ -# Copyright (C) 2017 The Software Heritage developers +# Copyright (C) 2017-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information @@ -7,10 +7,12 @@ import datetime from unittest.mock import MagicMock, patch +import attr import psycopg2 import pytest from swh.model import hashutil +from swh.vault.exc import NotFoundExc from swh.vault.tests.vault_testing import hash_content @@ -64,6 +66,14 @@ TEST_EMAIL = "ouiche@lorraine.fr" +@pytest.fixture +def swh_vault(swh_vault, sample_data): + # make the vault's storage consistent with test data + revision = attr.evolve(sample_data.revision, id=TEST_OBJ_ID) + swh_vault.storage.revision_add([revision]) + return swh_vault + + def test_create_task_simple(swh_vault): with mock_cooking(swh_vault) as m: swh_vault.create_task(TEST_TYPE, TEST_OBJ_ID) @@ -159,7 +169,7 @@ assert access_ts_2 < access_ts_3 -def test_cook_request_idempotent(swh_vault): +def test_cook_request_idempotent(swh_vault, sample_data): with mock_cooking(swh_vault): info1 = swh_vault.cook_request(TEST_TYPE, TEST_OBJ_ID) info2 = swh_vault.cook_request(TEST_TYPE, TEST_OBJ_ID) @@ -238,7 +248,13 @@ def test_fetch(swh_vault): - assert swh_vault.fetch(TEST_TYPE, TEST_OBJ_ID) is None + assert swh_vault.fetch(TEST_TYPE, TEST_OBJ_ID, raise_notfound=False) is None + + with pytest.raises( + NotFoundExc, match=f"{TEST_TYPE} {TEST_HEX_ID} is not available." + ): + swh_vault.fetch(TEST_TYPE, TEST_OBJ_ID) + obj_id, content = fake_cook(swh_vault, TEST_TYPE, b"content") info = swh_vault.task_info(TEST_TYPE, obj_id) diff --git a/swh/vault/tests/test_cookers_base.py b/swh/vault/tests/test_cookers_base.py --- a/swh/vault/tests/test_cookers_base.py +++ b/swh/vault/tests/test_cookers_base.py @@ -44,7 +44,7 @@ ) cooker.backend.set_status.assert_called_with(TEST_OBJ_TYPE, TEST_OBJ_ID, "done") cooker.backend.set_progress.assert_called_with(TEST_OBJ_TYPE, TEST_OBJ_ID, None) - cooker.backend.send_notif.assert_called_with(TEST_OBJ_TYPE, TEST_OBJ_ID) + cooker.backend.send_all_notifications.assert_called_with(TEST_OBJ_TYPE, TEST_OBJ_ID) def test_code_exception_cook(): @@ -58,7 +58,7 @@ cooker.backend.set_status.assert_called_with(TEST_OBJ_TYPE, TEST_OBJ_ID, "failed") assert "Nope" not in cooker.backend.set_progress.call_args[0][2] - cooker.backend.send_notif.assert_called_with(TEST_OBJ_TYPE, TEST_OBJ_ID) + cooker.backend.send_all_notifications.assert_called_with(TEST_OBJ_TYPE, TEST_OBJ_ID) def test_policy_exception_cook(): @@ -71,4 +71,4 @@ cooker.backend.set_status.assert_called_with(TEST_OBJ_TYPE, TEST_OBJ_ID, "failed") assert "exceeds" in cooker.backend.set_progress.call_args[0][2] - cooker.backend.send_notif.assert_called_with(TEST_OBJ_TYPE, TEST_OBJ_ID) + cooker.backend.send_all_notifications.assert_called_with(TEST_OBJ_TYPE, TEST_OBJ_ID) diff --git a/swh/vault/tests/test_server.py b/swh/vault/tests/test_server.py --- a/swh/vault/tests/test_server.py +++ b/swh/vault/tests/test_server.py @@ -4,44 +4,79 @@ # See top-level LICENSE file for more information import copy +from typing import Any, Dict import pytest -from swh.core.api.serializers import msgpack_dumps, msgpack_loads -from swh.vault.api.server import check_config, make_app +from swh.core.api.serializers import json_dumps, msgpack_dumps, msgpack_loads +from swh.vault.api.server import ( + VaultServerApp, + check_config, + make_app, + make_app_from_configfile, +) +from swh.vault.tests.test_backend import TEST_HEX_ID, TEST_OBJ_ID + + +def test_make_app_from_env_variable(swh_vault_config_file): + """Instantiation of the server should happen once (through environment variable) + + """ + app0 = make_app_from_configfile() + assert app0 is not None + app1 = make_app_from_configfile() + assert app1 == app0 @pytest.fixture -def client(swh_vault, loop, aiohttp_client): - app = make_app(backend=swh_vault) - return loop.run_until_complete(aiohttp_client(app)) +def async_app(swh_local_vault_config: Dict[str, Any],) -> VaultServerApp: + """Instantiate the vault server application. + Note: This requires the db setup to run (fixture swh_vault in charge of this) -async def test_index(client): - resp = await client.get("/") + """ + return make_app(swh_local_vault_config) + + +@pytest.fixture +def cli(async_app, aiohttp_client, loop): + return loop.run_until_complete(aiohttp_client(async_app)) + + +async def test_client_index(cli): + resp = await cli.get("/") assert resp.status == 200 -async def test_cook_notfound(client): - resp = await client.post("/cook/directory/000000") +async def test_client_cook_notfound(cli): + resp = await cli.post( + "/cook", + data=json_dumps({"obj_type": "directory", "obj_id": TEST_OBJ_ID}), + headers=[("Content-Type", "application/json")], + ) assert resp.status == 400 content = msgpack_loads(await resp.content.read()) assert content["exception"]["type"] == "NotFoundExc" - assert content["exception"]["args"] == ["Object 000000 was not found."] + assert content["exception"]["args"] == [f"directory {TEST_HEX_ID} was not found."] -async def test_progress_notfound(client): - resp = await client.get("/progress/directory/000000") +async def test_client_progress_notfound(cli): + resp = await cli.post( + "/progress", + data=json_dumps({"obj_type": "directory", "obj_id": TEST_OBJ_ID}), + headers=[("Content-Type", "application/json")], + ) assert resp.status == 400 content = msgpack_loads(await resp.content.read()) assert content["exception"]["type"] == "NotFoundExc" - assert content["exception"]["args"] == ["directory 000000 was not found."] + assert content["exception"]["args"] == [f"directory {TEST_HEX_ID} was not found."] -async def test_batch_cook_invalid_type(client): - data = msgpack_dumps([("foobar", [])]) - resp = await client.post( - "/batch_cook", data=data, headers={"Content-Type": "application/x-msgpack"} +async def test_client_batch_cook_invalid_type(cli): + resp = await cli.post( + "/batch_cook", + data=msgpack_dumps({"batch": [("foobar", [])]}), + headers={"Content-Type": "application/x-msgpack"}, ) assert resp.status == 400 content = msgpack_loads(await resp.content.read()) @@ -49,8 +84,12 @@ assert content["exception"]["args"] == ["foobar is an unknown type."] -async def test_batch_progress_notfound(client): - resp = await client.get("/batch_progress/1") +async def test_client_batch_progress_notfound(cli): + resp = await cli.post( + "/batch_progress", + data=msgpack_dumps({"batch_id": 1}), + headers={"Content-Type": "application/x-msgpack"}, + ) assert resp.status == 400 content = msgpack_loads(await resp.content.read()) assert content["exception"]["type"] == "NotFoundExc"