diff --git a/swh/storage/__init__.py b/swh/storage/__init__.py index 4c306f96..25685fc8 100644 --- a/swh/storage/__init__.py +++ b/swh/storage/__init__.py @@ -1,43 +1,46 @@ # Copyright (C) 2015-2016 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information + from . import storage Storage = storage.Storage class HashCollision(Exception): pass +STORAGE_IMPLEMENTATION = {'local', 'remote', 'memory'} + + def get_storage(cls, args): - """ - Get a storage object of class `storage_class` with arguments + """Get a storage object of class `storage_class` with arguments `storage_args`. Args: storage (dict): dictionary with keys: - - cls (str): storage's class, either 'local', 'remote', - or 'memory' + - cls (str): storage's class, either local, remote, memory - args (dict): dictionary with keys Returns: an instance of swh.storage.Storage (either local or remote) Raises: ValueError if passed an unknown storage class. """ + if cls not in STORAGE_IMPLEMENTATION: + raise ValueError('Unknown storage class `%s`. Supported: %s' % ( + cls, ', '.join(STORAGE_IMPLEMENTATION))) if cls == 'remote': from .api.client import RemoteStorage as Storage elif cls == 'local': from .storage import Storage elif cls == 'memory': from .in_memory import Storage - else: - raise ValueError('Unknown storage class `%s`' % cls) return Storage(**args) diff --git a/swh/storage/tests/test_init.py b/swh/storage/tests/test_init.py new file mode 100644 index 00000000..a61a7ef9 --- /dev/null +++ b/swh/storage/tests/test_init.py @@ -0,0 +1,42 @@ +# Copyright (C) 2019 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + +import pytest + +from unittest.mock import patch + +from swh.storage import get_storage + +from swh.storage.api.client import RemoteStorage +from swh.storage.storage import Storage as DbStorage +from swh.storage.in_memory import Storage as MemoryStorage + + +@patch('swh.storage.storage.psycopg2.pool') +def test_get_storage(mock_pool): + """Instantiating an existing storage should be ok + + """ + mock_pool.ThreadedConnectionPool.return_value = None + for cls, real_class, dummy_args in [ + ('remote', RemoteStorage, {'url': 'url'}), + ('memory', MemoryStorage, {}), + ('local', DbStorage, { + 'db': 'postgresql://db', 'objstorage': { + 'cls': 'memory', 'args': {}, + }, + }), + ]: + actual_storage = get_storage(cls, args=dummy_args) + assert actual_storage is not None + assert isinstance(actual_storage, real_class) + + +def test_get_storage_failure(): + """Instantiating an unknown storage should raise + + """ + with pytest.raises(ValueError, match='Unknown storage class `unknown`'): + get_storage('unknown', args=[])