diff --git a/swh/storage/__init__.py b/swh/storage/__init__.py --- a/swh/storage/__init__.py +++ b/swh/storage/__init__.py @@ -3,22 +3,24 @@ # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information +import importlib import warnings +from .interface import StorageInterface -STORAGE_IMPLEMENTATION = { - "pipeline", - "local", - "remote", - "memory", - "filter", - "buffer", - "retry", - "cassandra", + +STORAGE_IMPLEMENTATIONS = { + "local": ".storage.Storage", + "remote": ".api.client.RemoteStorage", + "memory": ".in_memory.InMemoryStorage", + "filter": ".filter.FilteringProxyStorage", + "buffer": ".buffer.BufferingProxyStorage", + "retry": ".retry.RetryingProxyStorage", + "cassandra": ".cassandra.CassandraStorage", } -def get_storage(cls, **kwargs): +def get_storage(cls: str, **kwargs) -> StorageInterface: """Get a storage object of class `storage_class` with arguments `storage_args`. @@ -35,12 +37,6 @@ ValueError if passed an unknown storage class. """ - if cls not in STORAGE_IMPLEMENTATION: - raise ValueError( - "Unknown storage class `%s`. Supported: %s" - % (cls, ", ".join(STORAGE_IMPLEMENTATION)) - ) - if "args" in kwargs: warnings.warn( 'Explicit "args" key is deprecated, use keys directly instead.', @@ -51,21 +47,16 @@ if cls == "pipeline": return get_storage_pipeline(**kwargs) - if cls == "remote": - from .api.client import RemoteStorage as Storage - elif cls == "local": - from .storage import Storage - elif cls == "cassandra": - from .cassandra import CassandraStorage as Storage - elif cls == "memory": - from .in_memory import InMemoryStorage as Storage - elif cls == "filter": - from .filter import FilteringProxyStorage as Storage - elif cls == "buffer": - from .buffer import BufferingProxyStorage as Storage - elif cls == "retry": - from .retry import RetryingProxyStorage as Storage + class_path = STORAGE_IMPLEMENTATIONS.get(cls) + if class_path is None: + raise ValueError( + "Unknown storage class `%s`. Supported: %s" + % (cls, ", ".join(STORAGE_IMPLEMENTATIONS)) + ) + (module_path, class_name) = class_path.rsplit(".", 1) + module = importlib.import_module(module_path, package=__package__) + Storage = getattr(module, class_name) return Storage(**kwargs) diff --git a/swh/storage/tests/test_kafka_writer.py b/swh/storage/tests/test_kafka_writer.py --- a/swh/storage/tests/test_kafka_writer.py +++ b/swh/storage/tests/test_kafka_writer.py @@ -3,6 +3,8 @@ # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information +from typing import Any, Dict + from confluent_kafka import Consumer from swh.storage import get_storage @@ -26,7 +28,7 @@ "prefix": kafka_prefix, "anonymize": False, } - storage_config = { + storage_config: Dict[str, Any] = { "cls": "pipeline", "steps": [{"cls": "memory", "journal_writer": writer_config},], } @@ -93,7 +95,7 @@ "prefix": kafka_prefix, "anonymize": True, } - storage_config = { + storage_config: Dict[str, Any] = { "cls": "pipeline", "steps": [{"cls": "memory", "journal_writer": writer_config},], } diff --git a/swh/storage/tests/test_replay.py b/swh/storage/tests/test_replay.py --- a/swh/storage/tests/test_replay.py +++ b/swh/storage/tests/test_replay.py @@ -6,8 +6,7 @@ import datetime import functools import logging - -from typing import Container, Dict, Optional +from typing import Any, Container, Dict, Optional import pytest @@ -40,7 +39,7 @@ "client_id": "kafka_writer", "prefix": kafka_prefix, } - storage_config = { + storage_config: Dict[str, Any] = { "cls": "memory", "journal_writer": journal_writer_config, } @@ -278,7 +277,7 @@ "prefix": kafka_prefix, "anonymize": True, } - src_config = {"cls": "memory", "journal_writer": writer_config} + src_config: Dict[str, Any] = {"cls": "memory", "journal_writer": writer_config} storage = get_storage(**src_config)