diff --git a/swh/search/api/client.py b/swh/search/api/client.py index bd2bdee..80b63b8 100644 --- a/swh/search/api/client.py +++ b/swh/search/api/client.py @@ -1,14 +1,16 @@ -# Copyright (C) 2019-2020 The Software Heritage developers +# Copyright (C) 2019-2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from swh.core.api import RPCClient +from .. import exc from ..interface import SearchInterface class RemoteSearch(RPCClient): """Proxy to a remote search API""" backend_class = SearchInterface + reraise_exceptions = [getattr(exc, exc_name) for exc_name in exc.__all__] diff --git a/swh/search/exc.py b/swh/search/exc.py new file mode 100644 index 0000000..0dd6d9f --- /dev/null +++ b/swh/search/exc.py @@ -0,0 +1,18 @@ +# Copyright (C) 2022 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + +__all__ = ("SearchException", "SearchQuerySyntaxError") + + +class SearchException(Exception): + """Base exception for errors specific to swh-search""" + + pass + + +class SearchQuerySyntaxError(SearchException): + """Raised when the 'query' argument of origin_search cannot be parsed""" + + pass diff --git a/swh/search/tests/test_elasticsearch.py b/swh/search/tests/test_elasticsearch.py index 943f9ed..585d9c9 100644 --- a/swh/search/tests/test_elasticsearch.py +++ b/swh/search/tests/test_elasticsearch.py @@ -1,167 +1,179 @@ -# Copyright (C) 2019-2021 The Software Heritage developers +# Copyright (C) 2019-2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from datetime import datetime, timedelta, timezone from textwrap import dedent import types import unittest from elasticsearch.helpers.errors import BulkIndexError import pytest +from swh.search.exc import SearchQuerySyntaxError from swh.search.metrics import OPERATIONS_METRIC from .test_search import CommonSearchTest class BaseElasticsearchTest(unittest.TestCase): @pytest.fixture(autouse=True) def _instantiate_search(self, swh_search, elasticsearch_host, mocker): self._elasticsearch_host = elasticsearch_host self.search = swh_search self.mocker = mocker # override self.search.origin_update to catch painless script errors # and pretty print them origin_update = self.search.origin_update def _origin_update(self, *args, **kwargs): script_error = False error_detail = "" try: origin_update(*args, **kwargs) except BulkIndexError as e: error = e.errors[0].get("update", {}).get("error", {}).get("caused_by") if error and "script_stack" in error: script_error = True error_detail = dedent( f""" Painless update script failed ({error.get('reason')}). 
                        error type: {error.get('caused_by', {}).get('type')}
                        error reason: {error.get('caused_by', {}).get('reason')}
                        script stack:

                        """
                    )
                    error_detail += "\n".join(error["script_stack"])
                else:
                    raise e

            assert script_error is False, error_detail[1:]

        self.search.origin_update = types.MethodType(_origin_update, self.search)

    def reset(self):
        self.search.deinitialize()
        self.search.initialize()


class TestElasticsearchSearch(CommonSearchTest, BaseElasticsearchTest):
    def test_metrics_update_duration(self):
        mock = self.mocker.patch("swh.search.metrics.statsd.timing")

        for url in ["http://foobar.bar", "http://foobar.baz"]:
            self.search.origin_update([{"url": url}])

        assert mock.call_count == 2

    def test_metrics_search_duration(self):
        mock = self.mocker.patch("swh.search.metrics.statsd.timing")

        for url_pattern in ["foobar", "foobaz"]:
            self.search.origin_search(url_pattern=url_pattern, with_visit=True)

        assert mock.call_count == 2

    def test_metrics_indexation_counters(self):
        mock_es = self.mocker.patch("elasticsearch.helpers.bulk")
        mock_es.return_value = 2, ["error"]

        mock_metrics = self.mocker.patch("swh.search.metrics.statsd.increment")

        self.search.origin_update([{"url": "http://foobar.baz"}])

        assert mock_metrics.call_count == 2
        mock_metrics.assert_any_call(
            OPERATIONS_METRIC,
            2,
            tags={
                "endpoint": "origin_update",
                "object_type": "document",
                "operation": "index",
            },
        )
        mock_metrics.assert_any_call(
            OPERATIONS_METRIC,
            1,
            tags={
                "endpoint": "origin_update",
                "object_type": "document",
                "operation": "index_error",
            },
        )

    def test_write_alias_usage(self):
        mock = self.mocker.patch("elasticsearch.helpers.bulk")
        mock.return_value = 2, ["result"]

        self.search.origin_update([{"url": "http://foobar.baz"}])

        assert mock.call_args[1]["index"] == "test-write"

    def test_read_alias_usage(self):
        mock = self.mocker.patch("elasticsearch.Elasticsearch.search")
        mock.return_value = {"hits": {"hits": []}}

        self.search.origin_search(url_pattern="foobar.baz")

        assert mock.call_args[1]["index"] == "test-read"

    def test_sort_by_and_limit_query(self):
        now = datetime.now(tz=timezone.utc).isoformat()
        now_minus_5_hours = (
            datetime.now(tz=timezone.utc) - timedelta(hours=5)
        ).isoformat()
        now_plus_5_hours = (
            datetime.now(tz=timezone.utc) + timedelta(hours=5)
        ).isoformat()

        ORIGINS = [
            {
                "url": "http://foobar.1.com",
                "nb_visits": 1,
                "last_visit_date": now_minus_5_hours,
                "last_eventful_visit_date": now_minus_5_hours,
            },
            {
                "url": "http://foobar.2.com",
                "nb_visits": 2,
                "last_visit_date": now,
                "last_eventful_visit_date": now,
            },
            {
                "url": "http://foobar.3.com",
                "nb_visits": 3,
                "last_visit_date": now_plus_5_hours,
                "last_eventful_visit_date": now_minus_5_hours,
            },
        ]

        self.search.origin_update(ORIGINS)
        self.search.flush()

        def _check_results(query, origin_indices):
            page = self.search.origin_search(url_pattern="foobar", query=query)
            results = [r["url"] for r in page.results]
            assert results == [ORIGINS[index]["url"] for index in origin_indices]

        _check_results("sort_by = [-visits]", [2, 1, 0])
        _check_results("sort_by = [last_visit]", [0, 1, 2])
        _check_results("sort_by = [-last_eventful_visit, visits]", [1, 0, 2])
        _check_results("sort_by = [last_eventful_visit,-last_visit]", [2, 0, 1])

        _check_results("sort_by = [-visits] limit = 1", [2])
        _check_results("sort_by = [last_visit] and limit = 2", [0, 1])
        _check_results("sort_by = [-last_eventful_visit, visits] limit = 3", [1, 0, 2])
+
+    def test_query_syntax_error(self):
+        ORIGINS = [
+            {"url": "http://foobar.1.com",},
+        ]
+
+        self.search.origin_update(ORIGINS)
+        self.search.flush()
+
+        with pytest.raises(SearchQuerySyntaxError):
+            self.search.origin_search(query="foobar")
diff --git a/swh/search/translator.py b/swh/search/translator.py index af3b16d..b8bff14 100644 --- a/swh/search/translator.py +++ b/swh/search/translator.py @@ -1,307 +1,308 @@ -# Copyright (C) 2021 The Software Heritage developers +# Copyright (C) 2021-2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information

import logging
import os
import tempfile

from pkg_resources import resource_filename
from tree_sitter import Language, Parser

+from swh.search.exc import SearchQuerySyntaxError
from swh.search.utils import get_expansion, unescape

logger = logging.getLogger(__name__)


class Translator:

    RANGE_OPERATOR_MAP = {
        ">": "gt",
        "<": "lt",
        ">=": "gte",
        "<=": "lte",
    }

    def __init__(self):
        ql_path = resource_filename("swh.search", "static/swh_ql.so")
        if not os.path.exists(ql_path):
            logging.info("%s does not exist, building in temporary directory", ql_path)
            self._build_dir = tempfile.TemporaryDirectory(prefix="swh.search-build")
            source_path = resource_filename("swh.search", "query_language")
            ql_path = os.path.join(self._build_dir.name, "swh_ql.so")
            Language.build_library(ql_path, [source_path])
        search_ql = Language(ql_path, "swh_search_ql")
        self.parser = Parser()
        self.parser.set_language(search_ql)
        self.query = ""

    def parse_query(self, query):
        self.query = query.encode()
        tree = self.parser.parse(self.query)
        self.query_node = tree.root_node

        if self.query_node.has_error:
-            raise Exception("Invalid query")
+            raise SearchQuerySyntaxError("Invalid query")

        return self._traverse(self.query_node)

    def _traverse(self, node):
        if len(node.children) == 3 and node.children[1].type == "filters":
            # filters => ( filters )
            return self._traverse(node.children[1])  # Go past the () brackets
        if node.type == "query":
            result = {}
            for child in node.children:
                # query => filters sort_by limit
                result[child.type] = self._traverse(child)
            return result

        if node.type == "filters":
            if len(node.children) == 1:
                # query => filters
                # filters => filters
                # filters => filter
                # Current node is just a wrapper, so go one level deep
                return self._traverse(node.children[0])

            if len(node.children) == 3:
                # filters => filters conj_op filters
                filters1 = self._traverse(node.children[0])
                conj_op = self._get_value(node.children[1])
                filters2 = self._traverse(node.children[2])

                if conj_op == "and":
                    # "must" is equivalent to "AND"
                    return {"bool": {"must": [filters1, filters2]}}
                if conj_op == "or":
                    # "should" is equivalent to "OR"
                    return {"bool": {"should": [filters1, filters2]}}

        if node.type == "filter":
            filter_category = node.children[0]
            return self._parse_filter(filter_category)

        if node.type == "sortBy":
            return self._parse_filter(node)

        if node.type == "limit":
            return self._parse_filter(node)

        raise Exception(
            f"Unknown node type ({node.type}) "
            f"or unexpected number of children ({node.children})"
        )

    def _get_value(self, node):
        if (
            len(node.children) > 0
            and node.children[0].type == "["
            and node.children[-1].type == "]"
        ):
            # array
            return [self._get_value(child) for child in node.children if child.is_named]

        start = node.start_point[1]
        end = node.end_point[1]

        value = self.query[start:end].decode()
        if len(value) > 1 and (
            (value[0] == "'" and value[-1] == "'")
            or (value[0] == '"' and value[-1] == '"')
        ):
            return unescape(value[1:-1])

        if node.type in ["number", "numberVal"]:
            return int(value)

        return unescape(value)
    def _parse_filter(self, filter):

        if filter.type == "boundedListFilter":
            filter = filter.children[0]

        children = filter.children
        assert len(children) == 3

        category = filter.type
        name, op, value = [self._get_value(child) for child in children]

        if category == "patternFilter":
            if name == "origin":
                return {
                    "multi_match": {
                        "query": value,
                        "type": "bool_prefix",
                        "operator": "and",
                        "fields": [
                            "url.as_you_type",
                            "url.as_you_type._2gram",
                            "url.as_you_type._3gram",
                        ],
                    }
                }
            elif name == "metadata":
                return {
                    "nested": {
                        "path": "intrinsic_metadata",
                        "query": {
                            "multi_match": {
                                "query": value,
                                # Makes it so that the "foo bar" query returns
                                # documents which contain "foo" in a field and "bar"
                                # in a different field
                                "type": "cross_fields",
                                # All keywords must be found in a document for it to
                                # be considered a match.
                                # TODO: allow missing keywords?
                                "operator": "and",
                                # Searches on all fields of the intrinsic_metadata dict,
                                # recursively.
                                "fields": ["intrinsic_metadata.*"],
                                # date{Created,Modified,Published} are of type date
                                "lenient": True,
                            }
                        },
                    }
                }

        if category == "booleanFilter":
            if name == "visited":
                return {"term": {"has_visits": value == "true"}}

        if category == "numericFilter":
            if name == "visits":
                if op in ["=", "!="]:
                    return {
                        "bool": {
                            ("must" if op == "=" else "must_not"): [
                                {"range": {"nb_visits": {"gte": value, "lte": value}}}
                            ]
                        }
                    }
                else:
                    return {
                        "range": {"nb_visits": {self.RANGE_OPERATOR_MAP[op]: value}}
                    }

        if category == "visitTypeFilter":
            if name == "visit_type":
                return {"terms": {"visit_types": value}}

        if category == "unboundedListFilter":
            value_array = value

            if name == "keyword":
                return {
                    "nested": {
                        "path": "intrinsic_metadata",
                        "query": {
                            "multi_match": {
                                "query": " ".join(value_array),
                                "fields": [
                                    get_expansion("keywords", ".") + "^2",
                                    get_expansion("descriptions", "."),
                                    # "^2" boosts an origin's score by 2x
                                    # if the queried keywords are
                                    # found in its intrinsic_metadata.keywords
                                ],
                            }
                        },
                    }
                }
            elif name in ["language", "license"]:
                name_mapping = {
                    "language": "programming_languages",
                    "license": "licenses",
                }
                name = name_mapping[name]

                return {
                    "nested": {
                        "path": "intrinsic_metadata",
                        "query": {
                            "bool": {
                                "should": [
                                    {"match": {get_expansion(name, "."): val}}
                                    for val in value_array
                                ],
                            }
                        },
                    }
                }

        if category == "dateFilter":
            if name in ["created", "modified", "published"]:
                if op in ["=", "!="]:
                    return {
                        "nested": {
                            "path": "intrinsic_metadata",
                            "query": {
                                "bool": {
                                    ("must" if op == "=" else "must_not"): [
                                        {
                                            "range": {
                                                get_expansion(f"date_{name}", "."): {
                                                    "gte": value,
                                                    "lte": value,
                                                }
                                            }
                                        }
                                    ],
                                }
                            },
                        }
                    }

                return {
                    "nested": {
                        "path": "intrinsic_metadata",
                        "query": {
                            "bool": {
                                "must": [
                                    {
                                        "range": {
                                            get_expansion(f"date_{name}", "."): {
                                                self.RANGE_OPERATOR_MAP[op]: value,
                                            }
                                        }
                                    }
                                ],
                            }
                        },
                    }
                }
            else:
                if op in ["=", "!="]:
                    return {
                        "bool": {
                            ("must" if op == "=" else "must_not"): [
                                {
                                    "range": {
                                        f"{name}_date": {"gte": value, "lte": value,}
                                    }
                                }
                            ],
                        }
                    }

                return {
                    "range": {
                        f"{name}_date": {
                            self.RANGE_OPERATOR_MAP[op]: value.replace("Z", "+00:00"),
                        }
                    }
                }

        if category == "sortBy":
            return value

        if category == "limit":
            return value

-        raise Exception(f"Unknown filter {category}.{name}")
+        raise SearchQuerySyntaxError(f"Unknown filter {category}.{name}")
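
Usage sketch (reviewer note, not part of the patch): because RemoteSearch now lists every exception from swh.search.exc in reraise_exceptions, a SearchQuerySyntaxError raised server-side by Translator.parse_query is re-raised as-is on the client, so callers can catch the typed exception instead of a generic RPC error. A minimal sketch follows; the service URL is hypothetical, everything else uses only names introduced or touched by this patch.

from swh.search.api.client import RemoteSearch
from swh.search.exc import SearchQuerySyntaxError

# Hypothetical endpoint for a running swh-search RPC server.
search = RemoteSearch(url="http://localhost:5010/")

try:
    # "foobar" is not valid query-language syntax (a valid query looks
    # like "sort_by = [-visits] limit = 1"), so the server-side parser
    # raises SearchQuerySyntaxError and the client re-raises it.
    search.origin_search(query="foobar")
except SearchQuerySyntaxError as e:
    print(f"invalid query: {e}")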