diff --git a/swh/search/__init__.py b/swh/search/__init__.py
--- a/swh/search/__init__.py
+++ b/swh/search/__init__.py
@@ -8,6 +8,8 @@
 
 from typing import Any, Dict
 
+from swh.search.interface import SearchInterface
+
 
 SEARCH_IMPLEMENTATIONS = {
     "elasticsearch": ".elasticsearch.ElasticSearch",
@@ -16,7 +18,7 @@
 }
 
 
-def get_search(cls: str, **kwargs: Dict[str, Any]):
+def get_search(cls: str, **kwargs: Dict[str, Any]) -> SearchInterface:
     """Get an search object of class `cls` with arguments `args`.
 
     Args:
diff --git a/swh/search/api/client.py b/swh/search/api/client.py
--- a/swh/search/api/client.py
+++ b/swh/search/api/client.py
@@ -1,14 +1,14 @@
-# Copyright (C) 2019  The Software Heritage developers
+# Copyright (C) 2019-2020  The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
 from swh.core.api import RPCClient
 
-from ..elasticsearch import ElasticSearch
+from ..interface import SearchInterface
 
 
 class RemoteSearch(RPCClient):
     """Proxy to a remote search API"""
 
-    backend_class = ElasticSearch
+    backend_class = SearchInterface
diff --git a/swh/search/api/server.py b/swh/search/api/server.py
--- a/swh/search/api/server.py
+++ b/swh/search/api/server.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2019  The Software Heritage developers
+# Copyright (C) 2019-2020  The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
@@ -10,7 +10,7 @@
 from swh.core.api import RPCServerApp, error_handler, encode_data_server as encode_data
 
 from .. import get_search
-from ..elasticsearch import ElasticSearch
+from ..interface import SearchInterface
 
 
 def _get_search():
@@ -21,7 +21,7 @@
     return search
 
 
-app = RPCServerApp(__name__, backend_class=ElasticSearch, backend_factory=_get_search)
+app = RPCServerApp(__name__, backend_class=SearchInterface, backend_factory=_get_search)
 
 
 search = None
diff --git a/swh/search/elasticsearch.py b/swh/search/elasticsearch.py
--- a/swh/search/elasticsearch.py
+++ b/swh/search/elasticsearch.py
@@ -11,7 +11,6 @@
 from elasticsearch import Elasticsearch
 from elasticsearch.helpers import bulk, scan
 
-from swh.core.api import remote_api_endpoint
 from swh.model.identifiers import origin_identifier
 from swh.model import model
 
@@ -46,7 +45,6 @@
     def __init__(self, hosts: List[str]):
         self._backend = Elasticsearch(hosts=hosts)
 
-    @remote_api_endpoint("check")
     def check(self):
         return self._backend.ping()
 
@@ -89,14 +87,12 @@
             },
         )
 
-    @remote_api_endpoint("flush")
     def flush(self) -> None:
         """Blocks until all previous calls to _update() are completely
         applied."""
         self._backend.indices.refresh(index="_all")
 
-    @remote_api_endpoint("origin/update")
-    def origin_update(self, documents: Iterable[dict]) -> None:
+    def origin_update(self, documents: Iterable[Dict]) -> None:
         documents = map(_sanitize_origin, documents)
         documents_with_sha1 = (
             (origin_identifier(document), document) for document in documents
@@ -120,12 +116,11 @@
         for hit in results:
             yield self._backend.termvectors(index="origin", id=hit["_id"], fields=["*"])
 
-    @remote_api_endpoint("origin/search")
     def origin_search(
         self,
         *,
         url_pattern: Optional[str] = None,
-        metadata_pattern: str = None,
+        metadata_pattern: Optional[str] = None,
         with_visit: bool = False,
         page_token: Optional[str] = None,
         limit: int = 50,
diff --git a/swh/search/in_memory.py b/swh/search/in_memory.py
--- a/swh/search/in_memory.py
+++ b/swh/search/in_memory.py
@@ -9,7 +9,6 @@
 from collections import defaultdict
 from typing import Any, Dict, Iterable, Iterator, List, Optional
 
-from swh.core.api import remote_api_endpoint
 from swh.model.identifiers import origin_identifier
 
 from swh.search.interface import PagedResult
@@ -19,7 +18,6 @@
     def __init__(self):
         pass
 
-    @remote_api_endpoint("check")
     def check(self):
         return True
 
@@ -37,8 +35,7 @@
 
     _url_splitter = re.compile(r"\W")
 
-    @remote_api_endpoint("origin/update")
-    def origin_update(self, documents: Iterable[dict]) -> None:
+    def origin_update(self, documents: Iterable[Dict]) -> None:
         for document in documents:
             document = document.copy()
             id_ = origin_identifier(document)
@@ -48,7 +45,6 @@
             if id_ not in self._origin_ids:
                 self._origin_ids.append(id_)
 
-    @remote_api_endpoint("origin/search")
     def origin_search(
         self,
         *,
diff --git a/swh/search/interface.py b/swh/search/interface.py
--- a/swh/search/interface.py
+++ b/swh/search/interface.py
@@ -3,10 +3,62 @@
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
-from typing import TypeVar
+from typing import Any, Dict, Iterable, Optional, TypeVar
 
 from swh.core.api.classes import PagedResult as CorePagedResult
 
+from swh.core.api import remote_api_endpoint
+
 
 TResult = TypeVar("TResult")
 PagedResult = CorePagedResult[TResult, str]
+
+
+class SearchInterface:
+    @remote_api_endpoint("check")
+    def check(self):
+        """Dedicated method to execute some specific check per implementation.
+
+        """
+        ...
+
+    @remote_api_endpoint("flush")
+    def flush(self) -> None:
+        """Blocks until all previous calls to _update() are completely
+        applied.
+
+        """
+        ...
+
+    @remote_api_endpoint("origin/update")
+    def origin_update(self, documents: Iterable[Dict]) -> None:
+        """Persist documents to the search backend.
+
+        """
+        ...
+
+    @remote_api_endpoint("origin/search")
+    def origin_search(
+        self,
+        *,
+        url_pattern: Optional[str] = None,
+        metadata_pattern: Optional[str] = None,
+        with_visit: bool = False,
+        page_token: Optional[str] = None,
+        limit: int = 50,
+    ) -> PagedResult[Dict[str, Any]]:
+        """Searches for origins matching the given search criteria.
+
+        Args:
+            url_pattern: Part of the URL to search for
+            metadata_pattern: Pattern to search for in the origins' metadata
+            with_visit: Whether origins with no visit are to be filtered out
+            page_token: Opaque value used for pagination
+            limit: number of results to return
+
+        Returns:
+            PagedResult of origin dicts matching the search criteria. If next_page_token
+            is None, there is no longer data to retrieve.
+
+        """
+        ...
diff --git a/swh/search/tests/test_init.py b/swh/search/tests/test_init.py
--- a/swh/search/tests/test_init.py
+++ b/swh/search/tests/test_init.py
@@ -3,10 +3,13 @@
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
+import inspect
+
 import pytest
 
 from swh.search import get_search
+from swh.search.interface import SearchInterface
 from swh.search.elasticsearch import ElasticSearch
 from swh.search.api.client import RemoteSearch
 from swh.search.in_memory import InMemorySearch
 
@@ -42,3 +45,43 @@
     with pytest.warns(DeprecationWarning):
         concrete_search = get_search(class_, args=kwargs)
     assert isinstance(concrete_search, expected_class)
+
+
+@pytest.mark.parametrize("class_,expected_class,kwargs", SEARCH_IMPLEMENTATIONS)
+def test_types(mocker, class_, expected_class, kwargs):
+    """Checks all methods of SearchInterface are implemented by this
+    backend, and that they have the same signature.
+
+    """
+    mocker.patch("swh.search.elasticsearch.Elasticsearch")
+    if kwargs:
+        concrete_search = get_search(class_, **kwargs)
+    else:
+        concrete_search = get_search(class_)
+
+    # Instantiate (a trivial subclass of) the interface, so that its methods
+    # are bound like the backend's and their signatures are comparable
+    interface = type("_", (SearchInterface,), {})()
+
+    # Collected over every interface method; checked once after the loop
+    missing_methods = []
+
+    for meth_name in dir(interface):
+        if meth_name.startswith("_"):
+            continue
+        interface_meth = getattr(interface, meth_name)
+
+        try:
+            concrete_meth = getattr(concrete_search, meth_name)
+        except AttributeError:
+            if not getattr(interface_meth, "deprecated_endpoint", False):
+                # The backend is missing a (non-deprecated) endpoint
+                missing_methods.append(meth_name)
+            continue
+
+        expected_signature = inspect.signature(interface_meth)
+        actual_signature = inspect.signature(concrete_meth)
+
+        assert expected_signature == actual_signature, meth_name
+
+    assert missing_methods == []