diff --git a/swh/storage/__init__.py b/swh/storage/__init__.py --- a/swh/storage/__init__.py +++ b/swh/storage/__init__.py @@ -3,22 +3,25 @@ # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information +import importlib +from typing import Any, Dict, List import warnings +from .interface import StorageInterface -STORAGE_IMPLEMENTATION = { - "pipeline", - "local", - "remote", - "memory", - "filter", - "buffer", - "retry", - "cassandra", + +STORAGE_IMPLEMENTATIONS = { + "local": ".storage.Storage", + "remote": ".api.client.RemoteStorage", + "memory": ".in_memory.InMemoryStorage", + "filter": ".filter.FilteringProxyStorage", + "buffer": ".buffer.BufferingProxyStorage", + "retry": ".retry.RetryingProxyStorage", + "cassandra": ".cassandra.CassandraStorage", } -def get_storage(cls, **kwargs): +def get_storage(cls: str, **kwargs) -> StorageInterface: """Get a storage object of class `storage_class` with arguments `storage_args`. @@ -35,12 +38,6 @@ ValueError if passed an unknown storage class. """ - if cls not in STORAGE_IMPLEMENTATION: - raise ValueError( - "Unknown storage class `%s`.
Supported: %s" - % (cls, ", ".join(STORAGE_IMPLEMENTATION)) - ) - if "args" in kwargs: warnings.warn( 'Explicit "args" key is deprecated, use keys directly instead.', @@ -51,25 +48,20 @@ if cls == "pipeline": return get_storage_pipeline(**kwargs) - if cls == "remote": - from .api.client import RemoteStorage as Storage - elif cls == "local": - from .storage import Storage - elif cls == "cassandra": - from .cassandra import CassandraStorage as Storage - elif cls == "memory": - from .in_memory import InMemoryStorage as Storage - elif cls == "filter": - from .filter import FilteringProxyStorage as Storage - elif cls == "buffer": - from .buffer import BufferingProxyStorage as Storage - elif cls == "retry": - from .retry import RetryingProxyStorage as Storage + class_path = STORAGE_IMPLEMENTATIONS.get(cls) + if class_path is None: + raise ValueError( + "Unknown storage class `%s`. Supported: %s" + % (cls, ", ".join(["pipeline", *STORAGE_IMPLEMENTATIONS])) + ) + (module_path, class_name) = class_path.rsplit(".", 1) + module = importlib.import_module(module_path, package=__package__) + Storage = getattr(module, class_name) return Storage(**kwargs) -def get_storage_pipeline(steps): +def get_storage_pipeline(steps: List[Dict[str, Any]]) -> StorageInterface: """Recursively get a storage object that may use other storage objects as backends.
@@ -98,4 +90,7 @@ step["storage"] = storage_config storage_config = step + if storage_config is None: + raise ValueError("'pipeline' has no steps.") + return get_storage(**storage_config) diff --git a/swh/storage/tests/test_kafka_writer.py b/swh/storage/tests/test_kafka_writer.py --- a/swh/storage/tests/test_kafka_writer.py +++ b/swh/storage/tests/test_kafka_writer.py @@ -3,6 +3,8 @@ # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information +from typing import Any, Dict + from confluent_kafka import Consumer from swh.storage import get_storage @@ -26,7 +28,7 @@ "prefix": kafka_prefix, "anonymize": False, } - storage_config = { + storage_config: Dict[str, Any] = { "cls": "pipeline", "steps": [{"cls": "memory", "journal_writer": writer_config},], } @@ -93,7 +95,7 @@ "prefix": kafka_prefix, "anonymize": True, } - storage_config = { + storage_config: Dict[str, Any] = { "cls": "pipeline", "steps": [{"cls": "memory", "journal_writer": writer_config},], } diff --git a/swh/storage/tests/test_replay.py b/swh/storage/tests/test_replay.py --- a/swh/storage/tests/test_replay.py +++ b/swh/storage/tests/test_replay.py @@ -6,8 +6,7 @@ import datetime import functools import logging - -from typing import Container, Dict, Optional +from typing import Any, Container, Dict, Optional import pytest @@ -40,7 +39,7 @@ "client_id": "kafka_writer", "prefix": kafka_prefix, } - storage_config = { + storage_config: Dict[str, Any] = { "cls": "memory", "journal_writer": journal_writer_config, } @@ -278,7 +277,7 @@ "prefix": kafka_prefix, "anonymize": True, } - src_config = {"cls": "memory", "journal_writer": writer_config} + src_config: Dict[str, Any] = {"cls": "memory", "journal_writer": writer_config} storage = get_storage(**src_config)