get_search: look implementations up in a registry, take constructor arguments
as keywords, and keep the former `args=` dict form working with a
DeprecationWarning. Tests are updated accordingly (pytest-mock is now needed).

NOTE(review): this patch was re-flowed from a whitespace-collapsed copy; the
indentation of context lines is reconstructed -- verify it against the target
tree before applying.
NOTE(review): positional calls such as get_search("memory", {}) are no longer
accepted after this change; only the `args=` keyword form is shimmed.

diff --git a/requirements-test.txt b/requirements-test.txt
--- a/requirements-test.txt
+++ b/requirements-test.txt
@@ -1,2 +1,3 @@
 pytest
+pytest-mock
 confluent-kafka
diff --git a/swh/search/__init__.py b/swh/search/__init__.py
--- a/swh/search/__init__.py
+++ b/swh/search/__init__.py
@@ -1,16 +1,31 @@
-# Copyright (C) 2019 The Software Heritage developers
+# Copyright (C) 2019-2020 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
+import importlib
+import warnings
 
-def get_search(cls, args):
-    """Get an search object of class `search_class` with
-    arguments `search_args`.
+from typing import Any, Dict
+
+
+# Supported class names, mapped to dotted paths relative to this package.
+SEARCH_IMPLEMENTATIONS: Dict[str, str] = {
+    "elasticsearch": ".elasticsearch.ElasticSearch",
+    "remote": ".api.client.RemoteSearch",
+    "memory": ".in_memory.InMemorySearch",
+}
+
+
+def get_search(cls: str, **kwargs: Any):
+    """Get a search object of class `cls`, built from keyword arguments.
+
+    A single ``args`` keyword holding a dict of constructor arguments is
+    still accepted but deprecated.
 
     Args:
-        cls (str): search's class, either 'local' or 'remote'
-        args (dict): dictionary of arguments passed to the
+        cls: search's class, one of the keys of SEARCH_IMPLEMENTATIONS
+        kwargs: keyword arguments passed to the
             search class constructor
 
     Returns:
@@ -20,13 +35,24 @@
         ValueError if passed an unknown search class.
 
     """
-    if cls == "remote":
-        from .api.client import RemoteSearch as Search
-    elif cls == "elasticsearch":
-        from .elasticsearch import ElasticSearch as Search
-    elif cls == "memory":
-        from .in_memory import InMemorySearch as Search
-    else:
-        raise ValueError("Unknown indexer search class `%s`" % cls)
-
-    return Search(**args)
+    if "args" in kwargs:
+        # Backward compatibility with the former get_search(cls, args=...) form.
+        warnings.warn(
+            'Explicit "args" key is deprecated, use keys directly instead.',
+            DeprecationWarning,
+            stacklevel=2,  # attribute the warning to the caller, not this module
+        )
+        kwargs = kwargs["args"]
+
+    class_path = SEARCH_IMPLEMENTATIONS.get(cls)
+    if class_path is None:
+        raise ValueError(
+            "Unknown search class `%s`. Supported: %s"
+            % (cls, ", ".join(SEARCH_IMPLEMENTATIONS))
+        )
+
+    # Resolve the relative dotted path against this package, importing on demand.
+    (module_path, class_name) = class_path.rsplit(".", 1)
+    module = importlib.import_module(module_path, package=__package__)
+    Search = getattr(module, class_name)
+    return Search(**kwargs)
diff --git a/swh/search/tests/conftest.py b/swh/search/tests/conftest.py
--- a/swh/search/tests/conftest.py
+++ b/swh/search/tests/conftest.py
@@ -124,7 +124,7 @@
 
     """
     logger.debug("swh_search: elasticsearch_host: %s", elasticsearch_host)
-    search = get_search("elasticsearch", {"hosts": [elasticsearch_host],})
+    search = get_search("elasticsearch", hosts=[elasticsearch_host],)
     search.deinitialize()  # To reset internal state from previous runs
     search.initialize()  # install required index
     yield search
diff --git a/swh/search/tests/test_api_client.py b/swh/search/tests/test_api_client.py
--- a/swh/search/tests/test_api_client.py
+++ b/swh/search/tests/test_api_client.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2019 The Software Heritage developers
+# Copyright (C) 2019-2020 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
@@ -29,10 +29,10 @@
         self.app = app
         super().setUp()
         self.reset()
-        self.search = get_search("remote", {"url": self.url(),})
+        self.search = get_search("remote", url=self.url(),)
 
     def reset(self):
-        search = get_search("elasticsearch", {"hosts": [self._elasticsearch_host],})
+        search = get_search("elasticsearch", hosts=[self._elasticsearch_host],)
         search.deinitialize()
         search.initialize()
 
diff --git a/swh/search/tests/test_in_memory.py b/swh/search/tests/test_in_memory.py
--- a/swh/search/tests/test_in_memory.py
+++ b/swh/search/tests/test_in_memory.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2019 The Software Heritage developers
+# Copyright (C) 2019-2020 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
@@ -14,7 +14,7 @@
 class InmemorySearchTest(unittest.TestCase, CommonSearchTest):
     @pytest.fixture(autouse=True)
     def _instantiate_search(self):
-        self.search = get_search("memory", {})
+        self.search = get_search("memory")
 
     def setUp(self):
         self.reset()
diff --git a/swh/search/tests/test_init.py b/swh/search/tests/test_init.py
new file mode 100644
--- /dev/null
+++ b/swh/search/tests/test_init.py
@@ -0,0 +1,45 @@
+# Copyright (C) 2020 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+import pytest
+
+from swh.search import get_search
+
+from swh.search.elasticsearch import ElasticSearch
+from swh.search.api.client import RemoteSearch
+from swh.search.in_memory import InMemorySearch
+
+
+SEARCH_IMPLEMENTATIONS = [
+    ("remote", RemoteSearch, {"url": "localhost"}),
+    ("elasticsearch", ElasticSearch, {"hosts": ["localhost"]}),
+    ("memory", InMemorySearch, None),
+]
+
+
+def test_get_search_failure():
+    with pytest.raises(ValueError, match="Unknown search class"):
+        get_search("unknown-search")
+
+
+@pytest.mark.parametrize("clazz,expected_clazz,kwargs", SEARCH_IMPLEMENTATIONS)
+def test_get_search(mocker, clazz, expected_clazz, kwargs):
+    mocker.patch("swh.search.elasticsearch.Elasticsearch")
+    if kwargs:
+        concrete_search = get_search(clazz, **kwargs)
+    else:
+        concrete_search = get_search(clazz)
+    assert isinstance(concrete_search, expected_clazz)
+
+
+def test_get_search_warning(mocker):
+    mocker.patch("swh.search.elasticsearch.Elasticsearch")
+    for (cls, expected_class, kwargs) in [
+        ("remote", RemoteSearch, {"url": "localhost"}),
+        ("elasticsearch", ElasticSearch, {"hosts": ["localhost"]}),
+    ]:
+        with pytest.warns(DeprecationWarning):
+            search = get_search(cls, args=kwargs)
+        assert isinstance(search, expected_class)