diff --git a/swh/search/__init__.py b/swh/search/__init__.py
index f4cb97c..61a9c71 100644
--- a/swh/search/__init__.py
+++ b/swh/search/__init__.py
@@ -1,51 +1,53 @@
 # Copyright (C) 2019-2020 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information

 import importlib
 import warnings

 from typing import Any, Dict

+from swh.search.interface import SearchInterface
+
 SEARCH_IMPLEMENTATIONS = {
     "elasticsearch": ".elasticsearch.ElasticSearch",
     "remote": ".api.client.RemoteSearch",
     "memory": ".in_memory.InMemorySearch",
 }


-def get_search(cls: str, **kwargs: Dict[str, Any]):
+def get_search(cls: str, **kwargs: Dict[str, Any]) -> SearchInterface:
     """Get a search object of class `cls` with arguments `kwargs`.

     Args:
         cls: search's class, either 'local' or 'remote'
         kwargs: dictionary of arguments passed to the
             search class constructor

     Returns:
         an instance of swh.search's classes (either local or remote)

     Raises:
         ValueError if passed an unknown search class.

     """
     if "args" in kwargs:
         warnings.warn(
             'Explicit "args" key is deprecated, use keys directly instead.',
             DeprecationWarning,
         )
         kwargs = kwargs["args"]

     class_path = SEARCH_IMPLEMENTATIONS.get(cls)
     if class_path is None:
         raise ValueError(
             "Unknown search class `%s`. Supported: %s"
             % (cls, ", ".join(SEARCH_IMPLEMENTATIONS))
         )

     (module_path, class_name) = class_path.rsplit(".", 1)
     module = importlib.import_module(module_path, package=__package__)
     Search = getattr(module, class_name)
     return Search(**kwargs)
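
With the factory now annotated as returning SearchInterface, call sites can be
type-checked against the shared interface regardless of which backend the
configuration selects. A minimal usage sketch, using the in-memory backend so it
runs without an Elasticsearch server (the origin URL is an illustrative example):

from swh.search import get_search
from swh.search.interface import SearchInterface

search: SearchInterface = get_search("memory")
search.initialize()

# mypy checks this call against SearchInterface.origin_update, whichever
# concrete class get_search returned at runtime.
search.origin_update([{"url": "https://example.org/repo.git", "has_visits": True}])
search.flush()
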
diff --git a/swh/search/api/client.py b/swh/search/api/client.py
index 786efad..bd2bdee 100644
--- a/swh/search/api/client.py
+++ b/swh/search/api/client.py
@@ -1,14 +1,14 @@
-# Copyright (C) 2019 The Software Heritage developers
+# Copyright (C) 2019-2020 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information

 from swh.core.api import RPCClient

-from ..elasticsearch import ElasticSearch
+from ..interface import SearchInterface


 class RemoteSearch(RPCClient):
     """Proxy to a remote search API"""

-    backend_class = ElasticSearch
+    backend_class = SearchInterface
diff --git a/swh/search/api/server.py b/swh/search/api/server.py
index bf994dc..6d16853 100644
--- a/swh/search/api/server.py
+++ b/swh/search/api/server.py
@@ -1,86 +1,86 @@
-# Copyright (C) 2019 The Software Heritage developers
+# Copyright (C) 2019-2020 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information

 import logging
 import os

 from swh.core import config
 from swh.core.api import RPCServerApp, error_handler, encode_data_server as encode_data

 from .. import get_search
-from ..elasticsearch import ElasticSearch
+from ..interface import SearchInterface


 def _get_search():
     global search
     if not search:
         search = get_search(**app.config["search"])
     return search


-app = RPCServerApp(__name__, backend_class=ElasticSearch, backend_factory=_get_search)
+app = RPCServerApp(__name__, backend_class=SearchInterface, backend_factory=_get_search)
 search = None


 @app.errorhandler(Exception)
 def my_error_handler(exception):
     return error_handler(exception, encode_data)


 @app.route("/")
 def index():
     return "SWH Search API server"


 api_cfg = None


 def load_and_check_config(config_file, type="elasticsearch"):
     """Check the minimal configuration is set to run the api or raise an
        error explanation.

     Args:
         config_file (str): Path to the configuration file to load
         type (str): configuration type. For 'local' type, more checks are
             done.

     Raises:
         Error if the setup is not as expected

     Returns:
         configuration as a dict

     """
     if not config_file:
         raise EnvironmentError("Configuration file must be defined")

     if not os.path.exists(config_file):
         raise FileNotFoundError("Configuration file %s does not exist" % (config_file,))

     cfg = config.read(config_file)
     if "search" not in cfg:
         raise KeyError("Missing 'search' configuration")

     return cfg


 def make_app_from_configfile():
     """Run the WSGI app from the webserver, loading the configuration from
        a configuration file.

     SWH_CONFIG_FILENAME environment variable defines the
     configuration path to load.

     """
     global api_cfg
     if not api_cfg:
         config_file = os.environ.get("SWH_CONFIG_FILENAME")
         api_cfg = load_and_check_config(config_file)
         app.config.update(api_cfg)
     handler = logging.StreamHandler()
     app.logger.addHandler(handler)
     return app
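
Since RPCServerApp and RPCClient both derive their endpoints from backend_class,
pointing them at SearchInterface means the RPC layer is generated from the
@remote_api_endpoint declarations in the interface rather than from the
Elasticsearch implementation. A sketch of serving the app (the config path and
contents are hypothetical examples; the file only has to carry the "search" key
that load_and_check_config verifies):

import os

# Hypothetical location; any YAML file with a top-level "search" key works,
# e.g.:
#   search:
#     cls: elasticsearch
#     hosts:
#       - http://localhost:9200
os.environ["SWH_CONFIG_FILENAME"] = "/etc/softwareheritage/search/server.yml"

from swh.search.api.server import make_app_from_configfile

# WSGI entry point, usable for instance as:
#   gunicorn 'swh.search.api.server:make_app_from_configfile()'
application = make_app_from_configfile()
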
diff --git a/swh/search/elasticsearch.py b/swh/search/elasticsearch.py
index 247e079..33cea95 100644
--- a/swh/search/elasticsearch.py
+++ b/swh/search/elasticsearch.py
@@ -1,220 +1,197 @@
 # Copyright (C) 2019-2020 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information

 import base64
 import msgpack

 from typing import Any, Iterable, Dict, List, Iterator, Optional

 from elasticsearch import Elasticsearch
 from elasticsearch.helpers import bulk, scan

-from swh.core.api import remote_api_endpoint
 from swh.model.identifiers import origin_identifier
 from swh.model import model

 from swh.search.interface import PagedResult


 def _sanitize_origin(origin):
     origin = origin.copy()
     res = {"url": origin.pop("url")}
     for field_name in ("intrinsic_metadata", "has_visits"):
         if field_name in origin:
             res[field_name] = origin.pop(field_name)
     return res


 def token_encode(index_to_tokenize: Dict[bytes, Any]) -> str:
     """Tokenize as string an index page result from a search

     """
     page_token = base64.b64encode(msgpack.dumps(index_to_tokenize))
     return page_token.decode()


 def token_decode(page_token: str) -> Dict[bytes, Any]:
     """Read the page_token

     """
     return msgpack.loads(base64.b64decode(page_token.encode()), raw=True)


 class ElasticSearch:
     def __init__(self, hosts: List[str]):
         self._backend = Elasticsearch(hosts=hosts)

-    @remote_api_endpoint("check")
     def check(self):
         return self._backend.ping()

     def deinitialize(self) -> None:
         """Removes all indices from the Elasticsearch backend"""
         self._backend.indices.delete(index="*")

     def initialize(self) -> None:
         """Declare Elasticsearch indices and mappings"""
         if not self._backend.indices.exists(index="origin"):
             self._backend.indices.create(index="origin")
         self._backend.indices.put_mapping(
             index="origin",
             body={
                 "properties": {
                     "sha1": {"type": "keyword", "doc_values": True,},
                     "url": {
                         "type": "text",
                         # To split URLs into token on any character
                         # that is not alphanumerical
                         "analyzer": "simple",
                         "fields": {
                             "as_you_type": {
                                 "type": "search_as_you_type",
                                 "analyzer": "simple",
                             }
                         },
                     },
                     "has_visits": {"type": "boolean",},
                     "intrinsic_metadata": {
                         "type": "nested",
                         "properties": {
                             "@context": {
                                 # don't bother indexing tokens
                                 "type": "keyword",
                             }
                         },
                     },
                 }
             },
         )

-    @remote_api_endpoint("flush")
     def flush(self) -> None:
-        """Blocks until all previous calls to _update() are completely
-        applied."""
         self._backend.indices.refresh(index="_all")

-    @remote_api_endpoint("origin/update")
-    def origin_update(self, documents: Iterable[dict]) -> None:
+    def origin_update(self, documents: Iterable[Dict]) -> None:
         documents = map(_sanitize_origin, documents)
         documents_with_sha1 = (
             (origin_identifier(document), document) for document in documents
         )
         actions = [
             {
                 "_op_type": "update",
                 "_id": sha1,
                 "_index": "origin",
                 "doc": {**document, "sha1": sha1,},
                 "doc_as_upsert": True,
             }
             for (sha1, document) in documents_with_sha1
         ]
         bulk(self._backend, actions, index="origin")

     def origin_dump(self) -> Iterator[model.Origin]:
-        """Returns all content in Elasticsearch's index. Not exposed
-        publicly; but useful for tests."""
         results = scan(self._backend, index="*")
         for hit in results:
             yield self._backend.termvectors(index="origin", id=hit["_id"], fields=["*"])

-    @remote_api_endpoint("origin/search")
     def origin_search(
         self,
         *,
         url_pattern: Optional[str] = None,
-        metadata_pattern: str = None,
+        metadata_pattern: Optional[str] = None,
         with_visit: bool = False,
         page_token: Optional[str] = None,
         limit: int = 50,
     ) -> PagedResult[Dict[str, Any]]:
-        """Searches for origins matching the `url_pattern`.
-
-        Args:
-            url_pattern: Part of the URL to search for
-            with_visit: Whether origins with no visit are to be
-                filtered out
-            page_token: Opaque value used for pagination
-            limit: number of results to return
-
-        Returns:
-            PagedResult of origin dicts matching the search criteria. If next_page_token
-            is None, there is no longer data to retrieve.
-
-        """
         query_clauses: List[Dict[str, Any]] = []

         if url_pattern:
             query_clauses.append(
                 {
                     "multi_match": {
                         "query": url_pattern,
                         "type": "bool_prefix",
                         "operator": "and",
                         "fields": [
                             "url.as_you_type",
                             "url.as_you_type._2gram",
                             "url.as_you_type._3gram",
                         ],
                     }
                 }
             )

         if metadata_pattern:
             query_clauses.append(
                 {
                     "nested": {
                         "path": "intrinsic_metadata",
                         "query": {
                             "multi_match": {
                                 "query": metadata_pattern,
                                 "operator": "and",
                                 "fields": ["intrinsic_metadata.*"],
                             }
                         },
                     }
                 }
             )

         if not query_clauses:
             raise ValueError(
                 "At least one of url_pattern and metadata_pattern must be provided."
             )

         next_page_token: Optional[str] = None

         if with_visit:
             query_clauses.append({"term": {"has_visits": True,}})

         body = {
             "query": {"bool": {"must": query_clauses,}},
             "sort": [{"_score": "desc"}, {"sha1": "asc"},],
         }
         if page_token:
             # TODO: use ElasticSearch's scroll API?
             page_token_content = token_decode(page_token)
             body["search_after"] = [
                 page_token_content[b"score"],
                 page_token_content[b"sha1"].decode("ascii"),
             ]

         res = self._backend.search(index="origin", body=body, size=limit)

         hits = res["hits"]["hits"]

         if len(hits) == limit:
             last_hit = hits[-1]
             next_page_token_content = {
                 b"score": last_hit["_score"],
                 b"sha1": last_hit["_source"]["sha1"],
             }
             next_page_token = token_encode(next_page_token_content)

         assert len(hits) <= limit

         return PagedResult(
             results=[{"url": hit["_source"]["url"]} for hit in hits],
             next_page_token=next_page_token,
         )
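
The Elasticsearch backend paginates with search_after: the last hit's score and
sha1 are packed into an opaque string token via token_encode/token_decode above.
A round-trip sketch of that token scheme (the sha1 value is an arbitrary
placeholder, not a real origin id):

from swh.search.elasticsearch import token_encode, token_decode

token_content = {b"score": 1.5, b"sha1": b"0123456789abcdef0123456789abcdef01234567"}
page_token = token_encode(token_content)

# token_decode loads with raw=True, so keys and byte values round-trip
# unchanged; origin_search reads exactly this back into "search_after".
assert token_decode(page_token) == token_content
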
diff --git a/swh/search/in_memory.py b/swh/search/in_memory.py
index 1ddbbc9..7858a7f 100644
--- a/swh/search/in_memory.py
+++ b/swh/search/in_memory.py
@@ -1,111 +1,107 @@
 # Copyright (C) 2019-2020 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information

 import itertools
 import re

 from collections import defaultdict
 from typing import Any, Dict, Iterable, Iterator, List, Optional

-from swh.core.api import remote_api_endpoint
 from swh.model.identifiers import origin_identifier

 from swh.search.interface import PagedResult


 class InMemorySearch:
     def __init__(self):
         pass

-    @remote_api_endpoint("check")
     def check(self):
         return True

     def deinitialize(self) -> None:
         if hasattr(self, "_origins"):
             del self._origins
             del self._origin_ids

     def initialize(self) -> None:
         self._origins: Dict[str, Dict[str, Any]] = defaultdict(dict)
         self._origin_ids: List[str] = []

     def flush(self) -> None:
         pass

     _url_splitter = re.compile(r"\W")

-    @remote_api_endpoint("origin/update")
-    def origin_update(self, documents: Iterable[dict]) -> None:
+    def origin_update(self, documents: Iterable[Dict]) -> None:
         for document in documents:
             document = document.copy()
             id_ = origin_identifier(document)
             if "url" in document:
                 document["_url_tokens"] = set(self._url_splitter.split(document["url"]))
             self._origins[id_].update(document)
             if id_ not in self._origin_ids:
                 self._origin_ids.append(id_)

-    @remote_api_endpoint("origin/search")
     def origin_search(
         self,
         *,
         url_pattern: Optional[str] = None,
         metadata_pattern: Optional[str] = None,
         with_visit: bool = False,
         page_token: Optional[str] = None,
         limit: int = 50,
     ) -> PagedResult[Dict[str, Any]]:
         hits: Iterator[Dict[str, Any]] = (
             self._origins[id_] for id_ in self._origin_ids
         )
         if url_pattern:
             tokens = set(self._url_splitter.split(url_pattern))

             def predicate(match):
                 missing_tokens = tokens - match["_url_tokens"]
                 if len(missing_tokens) == 0:
                     return True
                 elif len(missing_tokens) > 1:
                     return False
                 else:
                     # There is one missing token, look up by prefix.
                     (missing_token,) = missing_tokens
                     return any(
                         token.startswith(missing_token)
                         for token in match["_url_tokens"]
                     )

             hits = filter(predicate, hits)

         if metadata_pattern:
             raise NotImplementedError(
                 "Metadata search is not implemented in the in-memory backend."
             )

         if not url_pattern and not metadata_pattern:
             raise ValueError(
                 "At least one of url_pattern and metadata_pattern must be provided."
             )

         next_page_token: Optional[str] = None

         if with_visit:
             hits = filter(lambda o: o.get("has_visits"), hits)

         start_at_index = int(page_token) if page_token else 0

         origins = [
             {"url": hit["url"]}
             for hit in itertools.islice(hits, start_at_index, start_at_index + limit)
         ]

         if len(origins) == limit:
             next_page_token = str(start_at_index + limit)

         assert len(origins) <= limit
         return PagedResult(results=origins, next_page_token=next_page_token,)
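
Unlike the Elasticsearch token, the in-memory page token is just a stringified
offset into the insertion-ordered origin list; both remain opaque to callers. A
runnable sketch (example URLs only):

from swh.search import get_search

search = get_search("memory")
search.initialize()
search.origin_update(
    [{"url": "https://example.org/repo-%d.git" % i} for i in range(5)]
)

page = search.origin_search(url_pattern="example", limit=2)
assert [o["url"] for o in page.results] == [
    "https://example.org/repo-0.git",
    "https://example.org/repo-1.git",
]
# The token is the next offset ("2" here); passing it back resumes the scan.
page = search.origin_search(url_pattern="example", page_token=page.next_page_token, limit=2)
assert page.results[0]["url"] == "https://example.org/repo-2.git"
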
diff --git a/swh/search/interface.py b/swh/search/interface.py
index ae201f6..57c7181 100644
--- a/swh/search/interface.py
+++ b/swh/search/interface.py
@@ -1,12 +1,64 @@
 # Copyright (C) 2020 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information

-from typing import TypeVar
+from typing import Any, Dict, Iterable, Optional, TypeVar

 from swh.core.api.classes import PagedResult as CorePagedResult
+from swh.core.api import remote_api_endpoint
+
 TResult = TypeVar("TResult")
 PagedResult = CorePagedResult[TResult, str]
+
+
+class SearchInterface:
+    @remote_api_endpoint("check")
+    def check(self):
+        """Dedicated method to execute some specific check per implementation.
+
+        """
+        ...
+
+    @remote_api_endpoint("flush")
+    def flush(self) -> None:
+        """Blocks until all previous calls to _update() are completely
+        applied.
+
+        """
+        ...
+
+    @remote_api_endpoint("origin/update")
+    def origin_update(self, documents: Iterable[Dict]) -> None:
+        """Persist documents to the search backend.
+
+        """
+        ...
+
+    @remote_api_endpoint("origin/search")
+    def origin_search(
+        self,
+        *,
+        url_pattern: Optional[str] = None,
+        metadata_pattern: Optional[str] = None,
+        with_visit: bool = False,
+        page_token: Optional[str] = None,
+        limit: int = 50,
+    ) -> PagedResult[Dict[str, Any]]:
+        """Searches for origins matching `url_pattern` and/or `metadata_pattern`.
+
+        Args:
+            url_pattern: Part of the URL to search for
+            metadata_pattern: Part of the origin intrinsic metadata to
+                search for
+            with_visit: Whether origins with no visit are to be
+                filtered out
+            page_token: Opaque value used for pagination
+            limit: number of results to return
+
+        Returns:
+            PagedResult of origin dicts matching the search criteria. If next_page_token
+            is None, there is no more data to retrieve.
+
+        """
+        ...
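
Because every backend now returns the same PagedResult[Dict[str, Any]],
pagination can be written once against the interface. A sketch of a caller that
drains all pages (the helper name is ours, not part of the API):

from typing import Any, Dict, Iterator

from swh.search.interface import SearchInterface


def iter_origins(search: SearchInterface, url_pattern: str) -> Iterator[Dict[str, Any]]:
    # Follow next_page_token until the backend signals the end with None.
    page_token = None
    while True:
        page = search.origin_search(url_pattern=url_pattern, page_token=page_token)
        yield from page.results
        page_token = page.next_page_token
        if page_token is None:
            return
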
diff --git a/swh/search/tests/test_init.py b/swh/search/tests/test_init.py
index 90451aa..2ea535f 100644
--- a/swh/search/tests/test_init.py
+++ b/swh/search/tests/test_init.py
@@ -1,44 +1,86 @@
 # Copyright (C) 2020 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information

+import inspect
+
 import pytest

 from swh.search import get_search
+from swh.search.interface import SearchInterface
 from swh.search.elasticsearch import ElasticSearch
 from swh.search.api.client import RemoteSearch
 from swh.search.in_memory import InMemorySearch

 SEARCH_IMPLEMENTATIONS_KWARGS = [
     ("remote", RemoteSearch, {"url": "localhost"}),
     ("elasticsearch", ElasticSearch, {"hosts": ["localhost"]}),
 ]

 SEARCH_IMPLEMENTATIONS = SEARCH_IMPLEMENTATIONS_KWARGS + [
     ("memory", InMemorySearch, None),
 ]


 def test_get_search_failure():
     with pytest.raises(ValueError, match="Unknown search class"):
         get_search("unknown-search")


 @pytest.mark.parametrize("class_,expected_class,kwargs", SEARCH_IMPLEMENTATIONS)
 def test_get_search(mocker, class_, expected_class, kwargs):
     mocker.patch("swh.search.elasticsearch.Elasticsearch")
     if kwargs:
         concrete_search = get_search(class_, **kwargs)
     else:
         concrete_search = get_search(class_)
     assert isinstance(concrete_search, expected_class)


 @pytest.mark.parametrize("class_,expected_class,kwargs", SEARCH_IMPLEMENTATIONS_KWARGS)
 def test_get_search_deprecation_warning(mocker, class_, expected_class, kwargs):
     with pytest.warns(DeprecationWarning):
         concrete_search = get_search(class_, args=kwargs)
     assert isinstance(concrete_search, expected_class)
+
+
+@pytest.mark.parametrize("class_,expected_class,kwargs", SEARCH_IMPLEMENTATIONS)
+def test_types(mocker, class_, expected_class, kwargs):
+    """Checks all methods of SearchInterface are implemented by this
+    backend, and that they have the same signature.
+
+    """
+    mocker.patch("swh.search.elasticsearch.Elasticsearch")
+    if kwargs:
+        concrete_search = get_search(class_, **kwargs)
+    else:
+        concrete_search = get_search(class_)
+
+    # Create an instance of the protocol (which cannot be instantiated
+    # directly, so this creates a subclass, then instantiates it)
+    interface = type("_", (SearchInterface,), {})()
+
+    # Accumulated across the whole loop, then checked once at the end.
+    missing_methods = []
+
+    for meth_name in dir(interface):
+        if meth_name.startswith("_"):
+            continue
+        interface_meth = getattr(interface, meth_name)
+
+        try:
+            concrete_meth = getattr(concrete_search, meth_name)
+        except AttributeError:
+            if not getattr(interface_meth, "deprecated_endpoint", False):
+                # The backend is missing a (non-deprecated) endpoint
+                missing_methods.append(meth_name)
+            continue
+
+        expected_signature = inspect.signature(interface_meth)
+        actual_signature = inspect.signature(concrete_meth)
+
+        assert expected_signature == actual_signature, meth_name
+
+    assert missing_methods == []
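
The signature check in test_types is strict equality: parameter names, kinds,
defaults, and annotations must all match between the interface and the backend.
A condensed sketch of the per-method comparison it performs, here on the class
attributes directly rather than on instances:

import inspect

from swh.search.in_memory import InMemorySearch
from swh.search.interface import SearchInterface

# Renaming "documents", changing its annotation, or dropping the return
# type on the backend would make these two signatures unequal.
assert inspect.signature(InMemorySearch.origin_update) == inspect.signature(
    SearchInterface.origin_update
)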