diff --git a/swh/search/elasticsearch.py b/swh/search/elasticsearch.py
index 9a74266..5db5d34 100644
--- a/swh/search/elasticsearch.py
+++ b/swh/search/elasticsearch.py
@@ -1,213 +1,213 @@
 # Copyright (C) 2019 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information

 import base64
 from typing import Any, Iterable, Dict, List, Iterator, Optional

 from elasticsearch import Elasticsearch
 from elasticsearch.helpers import bulk, scan
 import msgpack

 from swh.core.api import remote_api_endpoint
 from swh.model import model
 from swh.model.identifiers import origin_identifier


 def _sanitize_origin(origin):
     origin = origin.copy()
     res = {"url": origin.pop("url")}
     for field_name in ("intrinsic_metadata", "has_visits"):
         if field_name in origin:
             res[field_name] = origin.pop(field_name)
     return res


 class ElasticSearch:
     def __init__(self, hosts: List[str]):
         self._backend = Elasticsearch(hosts=hosts)

     @remote_api_endpoint("check")
     def check(self):
         return self._backend.ping()

     def deinitialize(self) -> None:
         """Removes all indices from the Elasticsearch backend"""
         self._backend.indices.delete(index="*")

     def initialize(self) -> None:
         """Declare Elasticsearch indices and mappings"""
         if not self._backend.indices.exists(index="origin"):
             self._backend.indices.create(index="origin")
         self._backend.indices.put_mapping(
             index="origin",
             body={
                 "properties": {
                     "sha1": {"type": "keyword", "doc_values": True,},
                     "url": {
                         "type": "text",
                         # To split URLs into tokens on any character
                         # that is not alphanumerical
                         "analyzer": "simple",
                         "fields": {
                             "as_you_type": {
                                 "type": "search_as_you_type",
                                 "analyzer": "simple",
                             }
                         },
                     },
                     "has_visits": {"type": "boolean",},
                     "intrinsic_metadata": {
                         "type": "nested",
                         "properties": {
                             "@context": {
                                 # don't bother indexing tokens
                                 "type": "keyword",
                             }
                         },
                     },
                 }
             },
         )

     @remote_api_endpoint("flush")
     def flush(self) -> None:
         """Blocks until all previous calls to _update() are completely
         applied."""
         self._backend.indices.refresh(index="_all")

     @remote_api_endpoint("origin/update")
     def origin_update(self, documents: Iterable[dict]) -> None:
         documents = map(_sanitize_origin, documents)
         documents_with_sha1 = (
             (origin_identifier(document), document) for document in documents
         )
         actions = [
             {
                 "_op_type": "update",
                 "_id": sha1,
                 "_index": "origin",
                 "doc": {**document, "sha1": sha1,},
                 "doc_as_upsert": True,
             }
             for (sha1, document) in documents_with_sha1
         ]
         bulk(self._backend, actions, index="origin")

     def origin_dump(self) -> Iterator[model.Origin]:
         """Returns all content in Elasticsearch's index. Not exposed
         publicly, but useful for tests."""
         results = scan(self._backend, index="*")
         for hit in results:
             yield self._backend.termvectors(index="origin", id=hit["_id"], fields=["*"])

     @remote_api_endpoint("origin/search")
     def origin_search(
         self,
         *,
         url_pattern: str = None,
         metadata_pattern: str = None,
         with_visit: bool = False,
         page_token: str = None,
-        count: int = 50
+        count: int = 50,
     ) -> Dict[str, object]:
         """Searches for origins matching the `url_pattern`.

         Args:
             url_pattern (str): Part of the URL to search for
             metadata_pattern (str): Part of the intrinsic metadata to
                 search for
             with_visit (bool): Whether origins with no visit are to be
                 filtered out
             page_token (str): Opaque value used for pagination.
             count (int): number of results to return.

         Returns:
             a dictionary with keys:

             * `next_page_token`: opaque value used for fetching more
               results. `None` if there are no more results.
             * `results`: list of dictionaries with key:

               * `url`: URL of a matching origin
         """
         query_clauses = []  # type: List[Dict[str, Any]]

         if url_pattern:
             query_clauses.append(
                 {
                     "multi_match": {
                         "query": url_pattern,
                         "type": "bool_prefix",
                         "operator": "and",
                         "fields": [
                             "url.as_you_type",
                             "url.as_you_type._2gram",
                             "url.as_you_type._3gram",
                         ],
                     }
                 }
             )

         if metadata_pattern:
             query_clauses.append(
                 {
                     "nested": {
                         "path": "intrinsic_metadata",
                         "query": {
                             "multi_match": {
                                 "query": metadata_pattern,
                                 "operator": "and",
                                 "fields": ["intrinsic_metadata.*"],
                             }
                         },
                     }
                 }
             )

         if not query_clauses:
             raise ValueError(
-                "At least one of url_pattern and metadata_pattern " "must be provided."
+                "At least one of url_pattern and metadata_pattern must be provided."
             )

         if with_visit:
             query_clauses.append({"term": {"has_visits": True,}})

         body = {
             "query": {"bool": {"must": query_clauses,}},
             "size": count,
             "sort": [{"_score": "desc"}, {"sha1": "asc"},],
         }
         if page_token:
             # TODO: use ElasticSearch's scroll API?
             page_token_content = msgpack.loads(base64.b64decode(page_token), raw=True)
             body["search_after"] = [
                 page_token_content[b"score"],
                 page_token_content[b"sha1"].decode("ascii"),
             ]

         res = self._backend.search(index="origin", body=body, size=count,)

         hits = res["hits"]["hits"]

         if len(hits) == count:
             last_hit = hits[-1]
             next_page_token_content = {
                 b"score": last_hit["_score"],
                 b"sha1": last_hit["_source"]["sha1"],
             }
             next_page_token = base64.b64encode(
                 msgpack.dumps(next_page_token_content)
             )  # type: Optional[bytes]
         else:
             next_page_token = None

         return {
             "next_page_token": next_page_token,
             "results": [
                 {
                     # TODO: also add 'id'?
                     "url": hit["_source"]["url"],
                 }
                 for hit in hits
             ],
         }
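The paging scheme above serializes the last hit's (score, sha1) pair into an opaque token rather than using the scroll API (see the TODO in origin_search). A standalone sketch of the token round-trip follows; the score and sha1 are placeholders, not real index data:

    import base64
    import msgpack

    # A made-up last hit of a result page, shaped like an Elasticsearch hit.
    last_hit = {"_score": 1.5, "_source": {"sha1": "0" * 40}}

    # Encoding, as at the end of origin_search():
    token = base64.b64encode(
        msgpack.dumps(
            {b"score": last_hit["_score"], b"sha1": last_hit["_source"]["sha1"]}
        )
    )

    # Decoding, as when a client sends the token back as page_token:
    content = msgpack.loads(base64.b64decode(token), raw=True)
    assert [content[b"score"], content[b"sha1"].decode("ascii")] == [1.5, "0" * 40]

Loading with raw=True keeps msgpack map keys as bytes, which is why both origin_search and this sketch index the decoded token with b"score" and b"sha1".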
diff --git a/swh/search/in_memory.py b/swh/search/in_memory.py
index f5fc665..0699241 100644
--- a/swh/search/in_memory.py
+++ b/swh/search/in_memory.py
@@ -1,127 +1,127 @@
 # Copyright (C) 2019 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information

 import base64
 from collections import defaultdict
 import itertools
 import re
 from typing import Any, Dict, Iterable, Iterator, List, Optional

 import msgpack

 from swh.core.api import remote_api_endpoint
 from swh.model.identifiers import origin_identifier


 def _sanitize_origin(origin):
     origin = origin.copy()
     res = {"url": origin.pop("url")}
     for field_name in ("type", "intrinsic_metadata"):
         if field_name in origin:
             res[field_name] = origin.pop(field_name)
     return res


 class InMemorySearch:
     def __init__(self):
         pass

     @remote_api_endpoint("check")
     def check(self):
         return True

     def deinitialize(self) -> None:
         if hasattr(self, "_origins"):
             del self._origins
             del self._origin_ids

     def initialize(self) -> None:
         self._origins = defaultdict(dict)  # type: Dict[str, Dict[str, Any]]
         self._origin_ids = []  # type: List[str]

     def flush(self) -> None:
         pass

     _url_splitter = re.compile(r"\W")

     @remote_api_endpoint("origin/update")
     def origin_update(self, documents: Iterable[dict]) -> None:
         for document in documents:
             document = document.copy()
             id_ = origin_identifier(document)
             if "url" in document:
                 document["_url_tokens"] = set(self._url_splitter.split(document["url"]))
             self._origins[id_].update(document)
             if id_ not in self._origin_ids:
                 self._origin_ids.append(id_)

     @remote_api_endpoint("origin/search")
     def origin_search(
         self,
         *,
         url_pattern: str = None,
         metadata_pattern: str = None,
         with_visit: bool = False,
         page_token: str = None,
-        count: int = 50
+        count: int = 50,
     ) -> Dict[str, object]:
         matches = (
             self._origins[id_] for id_ in self._origin_ids
         )  # type: Iterator[Dict[str, Any]]

         if url_pattern:
             tokens = set(self._url_splitter.split(url_pattern))

             def predicate(match):
                 missing_tokens = tokens - match["_url_tokens"]
                 if len(missing_tokens) == 0:
                     return True
                 elif len(missing_tokens) > 1:
                     return False
                 else:
                     # There is one missing token, look up by prefix.
                     (missing_token,) = missing_tokens
                     return any(
                         token.startswith(missing_token)
                         for token in match["_url_tokens"]
                     )

             matches = filter(predicate, matches)

         if metadata_pattern:
             raise NotImplementedError(
                 "Metadata search is not implemented in the in-memory backend."
             )

         if not url_pattern and not metadata_pattern:
             raise ValueError(
-                "At least one of url_pattern and metadata_pattern " "must be provided."
+                "At least one of url_pattern and metadata_pattern must be provided."
             )

         if with_visit:
             matches = filter(lambda o: o.get("has_visits"), matches)

         if page_token:
             page_token_content = msgpack.loads(base64.b64decode(page_token))
             start_at_index = page_token_content[b"start_at_index"]
         else:
             start_at_index = 0

         hits = list(itertools.islice(matches, start_at_index, start_at_index + count))

         if len(hits) == count:
             next_page_token_content = {
                 b"start_at_index": start_at_index + count,
             }
             next_page_token = base64.b64encode(
                 msgpack.dumps(next_page_token_content)
             )  # type: Optional[bytes]
         else:
             next_page_token = None

         return {
             "next_page_token": next_page_token,
             "results": [{"url": hit["url"]} for hit in hits],
         }
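The in-memory predicate above is a cheap approximation of Elasticsearch's bool_prefix matching: every query token must be an indexed token, except that a single missing token may still match one indexed token by prefix. A self-contained sketch of that rule (the URLs are invented examples):

    import re

    _url_splitter = re.compile(r"\W")

    # Tokens stored for an indexed origin, mirroring origin_update():
    indexed = set(_url_splitter.split("http://github.com/user/repo"))

    def matches(url_pattern):
        # Same logic as the predicate in origin_search(): all query tokens
        # must be present, except one may match by prefix.
        missing = set(_url_splitter.split(url_pattern)) - indexed
        if not missing:
            return True
        if len(missing) > 1:
            return False
        (missing_token,) = missing
        return any(token.startswith(missing_token) for token in indexed)

    assert matches("github.com/user")  # every token matches exactly
    assert matches("github.com/rep")   # "rep" prefix-matches "repo"
    assert not matches("gitlab.com")   # "gitlab" matches no indexed token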
diff --git a/swh/search/tests/test_api_client.py b/swh/search/tests/test_api_client.py
index a1fe8e3..ad6b4d0 100644
--- a/swh/search/tests/test_api_client.py
+++ b/swh/search/tests/test_api_client.py
@@ -1,43 +1,43 @@
 # Copyright (C) 2019 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information

 import unittest

 import pytest

 from swh.core.api.tests.server_testing import ServerTestFixture
 from swh.search import get_search
 from swh.search.api.server import app

 from .test_search import CommonSearchTest


 class TestRemoteSearch(CommonSearchTest, ServerTestFixture, unittest.TestCase):
     @pytest.fixture(autouse=True)
     def _instantiate_search(self, elasticsearch_host):
         self._elasticsearch_host = elasticsearch_host

     def setUp(self):
         self.config = {
             "search": {
                 "cls": "elasticsearch",
                 "args": {"hosts": [self._elasticsearch_host],},
             }
         }
         self.app = app
         super().setUp()
         self.reset()
         self.search = get_search("remote", {"url": self.url(),})

     def reset(self):
         search = get_search("elasticsearch", {"hosts": [self._elasticsearch_host],})
         search.deinitialize()
         search.initialize()

     @pytest.mark.skip(
-        "Elasticsearch also returns close matches, " "so this test would fail"
+        "Elasticsearch also returns close matches, so this test would fail"
     )
     def test_origin_url_paging(self, count):
         pass
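For reference, the get_search factory exercised by this test selects a backend class behind one common interface; a minimal sketch of both usages, with illustrative host and URL values (no instantiation here contacts a server eagerly):

    from swh.search import get_search

    # Direct Elasticsearch backend; the host value is an assumption.
    search = get_search("elasticsearch", {"hosts": ["localhost:9200"]})

    # The same interface proxied over the HTTP API server; URL is illustrative.
    remote = get_search("remote", {"url": "http://localhost:5010/"})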
diff --git a/swh/search/tests/test_cli.py b/swh/search/tests/test_cli.py
index dc1d77a..7e32c1c 100644
--- a/swh/search/tests/test_cli.py
+++ b/swh/search/tests/test_cli.py
@@ -1,125 +1,125 @@
 # Copyright (C) 2019-2020 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information

 import tempfile
 from unittest.mock import patch, MagicMock

 from click.testing import CliRunner

 from swh.journal.serializers import value_to_kafka
 from swh.journal.tests.utils import MockedKafkaConsumer

 from swh.search.cli import cli
 from .test_elasticsearch import BaseElasticsearchTest

 CLI_CONFIG = """
 search:
     cls: elasticsearch
     args:
         hosts:
         - '{elasticsearch_host}'
 """

 JOURNAL_OBJECTS_CONFIG = """
 journal:
     brokers:
         - 192.0.2.1
     prefix: swh.journal.objects
     group_id: test-consumer
 """


 class MockedKafkaConsumerWithTopics(MockedKafkaConsumer):
     def list_topics(self, timeout=None):
         return {
             "swh.journal.objects.origin",
             "swh.journal.objects.origin_visit",
         }


 def invoke(catch_exceptions, args, config="", *, elasticsearch_host):
     runner = CliRunner()
     with tempfile.NamedTemporaryFile("a", suffix=".yml") as config_fd:
         config_fd.write(
             (CLI_CONFIG + config).format(elasticsearch_host=elasticsearch_host)
         )
         config_fd.seek(0)
         result = runner.invoke(cli, ["-C" + config_fd.name] + args)
     if not catch_exceptions and result.exception:
         print(result.output)
         raise result.exception
     return result


 class CliTestCase(BaseElasticsearchTest):
     def test__journal_client__origin(self):
         """Tests the re-indexing when origin_batch_size*task_batch_size is a
         divisor of nb_origins."""
         topic = "swh.journal.objects.origin"
         value = value_to_kafka({"url": "http://foobar.baz",})
         message = MagicMock()
         message.error.return_value = None
         message.topic.return_value = topic
         message.value.return_value = value

         mock_consumer = MockedKafkaConsumerWithTopics([message])

         with patch("swh.journal.client.Consumer", return_value=mock_consumer):
             result = invoke(
                 False,
                 ["journal-client", "objects", "--stop-after-objects", "1",],
                 JOURNAL_OBJECTS_CONFIG,
                 elasticsearch_host=self._elasticsearch_host,
             )

         # Check the output
-        expected_output = "Processed 1 messages.\n" "Done.\n"
+        expected_output = "Processed 1 messages.\nDone.\n"
         assert result.exit_code == 0, result.output
         assert result.output == expected_output

         self.search.flush()

         results = self.search.origin_search(url_pattern="foobar")
         assert results == {
             "next_page_token": None,
             "results": [{"url": "http://foobar.baz"}],
         }

         results = self.search.origin_search(url_pattern="foobar", with_visit=True)
         assert results == {"next_page_token": None, "results": []}

     def test__journal_client__origin_visit(self):
         """Tests the re-indexing when origin_batch_size*task_batch_size is a
         divisor of nb_origins."""
         topic = "swh.journal.objects.origin_visit"
         value = value_to_kafka({"origin": "http://foobar.baz",})
         message = MagicMock()
         message.error.return_value = None
         message.topic.return_value = topic
         message.value.return_value = value

         mock_consumer = MockedKafkaConsumerWithTopics([message])

         with patch("swh.journal.client.Consumer", return_value=mock_consumer):
             result = invoke(
                 False,
                 ["journal-client", "objects", "--stop-after-objects", "1",],
                 JOURNAL_OBJECTS_CONFIG,
                 elasticsearch_host=self._elasticsearch_host,
             )

         # Check the output
-        expected_output = "Processed 1 messages.\n" "Done.\n"
+        expected_output = "Processed 1 messages.\nDone.\n"
         assert result.exit_code == 0, result.output
         assert result.output == expected_output

         self.search.flush()

         results = self.search.origin_search(url_pattern="foobar", with_visit=True)
         assert results == {
             "next_page_token": None,
             "results": [{"url": "http://foobar.baz"}],
         }
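Both tests fabricate their Kafka message identically; extracted as a standalone sketch, the pattern looks like this (nothing here touches a real broker, and the URL is just the example value the tests use):

    from unittest.mock import MagicMock

    from swh.journal.serializers import value_to_kafka

    # Build a fake consumed message exposing the confluent-kafka-style
    # accessors the journal client calls: error(), topic(), value().
    message = MagicMock()
    message.error.return_value = None
    message.topic.return_value = "swh.journal.objects.origin"
    message.value.return_value = value_to_kafka({"url": "http://foobar.baz"})

    assert message.error() is None
    assert message.topic() == "swh.journal.objects.origin"

Feeding such messages to MockedKafkaConsumerWithTopics is what lets the journal-client CLI run end to end with no Kafka broker (the 192.0.2.1 broker address in JOURNAL_OBJECTS_CONFIG is a documentation-range IP that is never contacted).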