diff --git a/swh/search/tests/test_translator.py b/swh/search/tests/test_translator.py
new file mode 100644
--- /dev/null
+++ b/swh/search/tests/test_translator.py
@@ -0,0 +1,111 @@
+import pytest
+
+from swh.search.translator import Translator
+
+
+def _test_results(query, expected):
+    output = Translator().parse_query(query)
+    print(output)
+    assert output == expected
+
+
+def test_empty_query():
+    query = ""
+    with pytest.raises(Exception):
+        _test_results(query, {})
+
+
+def test_conjunction_operators():
+    query = "visited = false or visits > 2 and visits < 5"
+    expected = {
+        "filters": {
+            "bool": {
+                "should": [
+                    {"term": {"has_visits": False}},
+                    {
+                        "bool": {
+                            "must": [
+                                {"range": {"nb_visits": {"gt": "2"}}},
+                                {"range": {"nb_visits": {"lt": "5"}}},
+                            ]
+                        }
+                    },
+                ]
+            }
+        }
+    }
+    _test_results(query, expected)
+
+
+def test_conjunction_op_precedence_override():
+    query = "(visited = false or visits > 2) and visits < 5"
+    expected = {
+        "filters": {
+            "bool": {
+                "must": [
+                    {
+                        "bool": {
+                            "should": [
+                                {"term": {"has_visits": False}},
+                                {"range": {"nb_visits": {"gt": "2"}}},
+                            ]
+                        }
+                    },
+                    {"range": {"nb_visits": {"lt": "5"}}},
+                ]
+            }
+        }
+    }
+
+    _test_results(query, expected)
+
+
+def test_origin_and_metadata_filters():
+    query = 'origin = django or metadata = "framework and web"'
+    expected = {
+        "filters": {
+            "bool": {
+                "should": [
+                    {
+                        "multi_match": {
+                            "query": "django",
+                            "type": "bool_prefix",
+                            "operator": "and",
+                            "fields": [
+                                "url.as_you_type",
+                                "url.as_you_type._2gram",
+                                "url.as_you_type._3gram",
+                            ],
+                        }
+                    },
+                    {
+                        "nested": {
+                            "path": "intrinsic_metadata",
+                            "query": {
+                                "multi_match": {
+                                    "query": '"framework and web"',
+                                    "type": "cross_fields",
+                                    "operator": "and",
+                                    "fields": ["intrinsic_metadata.*"],
+                                    "lenient": True,
+                                }
+                            },
+                        }
+                    },
+                ]
+            }
+        }
+    }
+
+    _test_results(query, expected)
+
+
+def test_limit_and_sortby():
+    query = "visited = true sort_by = [-visits,last_visit] limit = 15"
+    expected = {
+        "filters": {"term": {"has_visits": True}},
+        "sortBy": ["-visits", "last_visit"],
+        "limit": 15,
+    }
+
+    _test_results(query, expected)
diff --git a/swh/search/translator.py b/swh/search/translator.py
new file mode 100644
--- /dev/null
+++ b/swh/search/translator.py
@@ -0,0 +1,199 @@
+"""Translate swh search query language (swh_ql) into Elasticsearch queries."""
+
+from tree_sitter import Language, Parser
+
+
+class Translator:
+    """Parse swh_ql query strings and translate them to Elasticsearch clauses."""
+
+    # Maps swh_ql comparison operators to Elasticsearch ``range`` query keys.
+    RANGE_OPERATOR_MAP = {
+        ">": "gt",
+        "<": "lt",
+        ">=": "gte",
+        "<=": "lte",
+    }
+
+    def __init__(self):
+        # NOTE(review): this path is resolved against the current working
+        # directory, not this module's location -- confirm that is intended.
+        swh_ql = Language("static/swh_ql.so", "swh_search_ql")
+        self.parser = Parser()
+        self.parser.set_language(swh_ql)
+        self.query = ""
+        # UTF-8 encoding of ``self.query``; tree-sitter node offsets are byte
+        # offsets into this buffer, not character indices into the str.
+        self.query_bytes = b""
+        self.query_node = None
+
+    def parse_query(self, query):
+        """Parse ``query`` and return its Elasticsearch translation.
+
+        The result is a dict keyed by the type of each child of the root
+        ``query`` node (e.g. ``filters``, ``sortBy``, ``limit``).
+
+        Raises:
+            Exception: if tree-sitter reports a syntax error in ``query``.
+        """
+        self.query = query
+        self.query_bytes = bytes(query, "utf8")
+        tree = self.parser.parse(self.query_bytes)
+        self.query_node = tree.root_node
+
+        if self.query_node.has_error:
+            raise Exception("Invalid query")
+
+        return self._traverse(self.query_node)
+
+    def _traverse(self, node):
+        """Recursively translate the parse-tree ``node``.
+
+        Returns a dict for query/filter nodes, the parsed value for ``sortBy``
+        and ``limit`` nodes, and ``{}`` for anything it does not handle.
+        """
+        if len(node.children) == 3 and node.children[1].type == "filters":
+            # filters => ( filters )
+            return self._traverse(node.children[1])  # Go past the () brackets
+
+        if node.type == "query":
+            # query => filters sort_by limit
+            return {child.type: self._traverse(child) for child in node.children}
+
+        if node.type == "filters":
+            if len(node.children) == 1:
+                # filters => filters
+                # filters => filter
+                # Current node is just a wrapper, so go one level deep
+                return self._traverse(node.children[0])
+
+            if len(node.children) == 3:
+                # filters => filters conj_op filters
+                filters1 = self._traverse(node.children[0])
+                conj_op = self._get_value(node.children[1])
+                filters2 = self._traverse(node.children[2])
+
+                if "and" in conj_op:
+                    return {"bool": {"must": [filters1, filters2]}}
+                if "or" in conj_op:
+                    return {"bool": {"should": [filters1, filters2]}}
+
+        if node.type == "filter":
+            # filter => <category node>; _parse_filter dispatches on its type
+            return self._parse_filter(node.children[0])
+
+        if node.type in ("sortBy", "limit"):
+            return self._parse_filter(node)
+
+        return {}
+
+    def _get_value(self, node):
+        """Return the source text of ``node`` converted to a Python value.
+
+        ``[ ... ]`` array nodes become a list of their named children's values,
+        ``number`` nodes become ``int``, anything else is returned as ``str``.
+        """
+        if (
+            len(node.children) > 0
+            and node.children[0].type == "["
+            and node.children[-1].type == "]"
+        ):
+            # array
+            return [self._get_value(child) for child in node.children if child.is_named]
+
+        # Slice by byte offsets: the column of start_point/end_point is relative
+        # to the node's row, so it breaks on multi-line queries, and it counts
+        # bytes, so indexing the str with it breaks on non-ASCII input.
+        value = self.query_bytes[node.start_byte : node.end_byte].decode("utf8")
+
+        if node.type == "number":
+            return int(value)
+
+        return value
+
+    def _parse_filter(self, filter):
+        """Translate a filter, ``sortBy`` or ``limit`` node.
+
+        The node (after unwrapping ``boundedListFilter``) must have exactly
+        three children: ``name op value``. Some categories are not
+        implemented yet and return the placeholder string ``"TODO"``.
+
+        Raises:
+            Exception: if the node does not have three children, or if its
+                category/name combination is not handled.
+        """
+        category = filter.type
+        if category == "boundedListFilter":
+            # The wrapper itself has no ``name op value`` children: descend
+            # into the concrete list filter (e.g. visitTypeFilter) and
+            # translate according to its type.
+            filter = filter.children[0]
+            category = filter.type
+
+        children = filter.children
+        if len(children) != 3:
+            raise Exception(
+                f"Invalid {category} node: expected 3 children, got {len(children)}"
+            )
+
+        name, op, value = [self._get_value(child) for child in children]
+
+        if category == "patternFilter":
+            if name == "origin":
+                return {
+                    "multi_match": {
+                        "query": value,
+                        "type": "bool_prefix",
+                        "operator": "and",
+                        "fields": [
+                            "url.as_you_type",
+                            "url.as_you_type._2gram",
+                            "url.as_you_type._3gram",
+                        ],
+                    }
+                }
+            if name == "metadata":
+                return {
+                    "nested": {
+                        "path": "intrinsic_metadata",
+                        "query": {
+                            "multi_match": {
+                                "query": value,
+                                "type": "cross_fields",
+                                "operator": "and",
+                                "fields": ["intrinsic_metadata.*"],
+                                "lenient": True,
+                            }
+                        },
+                    }
+                }
+
+        if category == "booleanFilter":
+            if name == "visited":
+                return {"term": {"has_visits": value == "true"}}
+
+        if category == "numericFilter":
+            if name == "visits":
+                if op in ["=", "!="]:
+                    return "TODO"
+                else:
+                    return {
+                        "range": {"nb_visits": {self.RANGE_OPERATOR_MAP[op]: value}}
+                    }
+
+        if category == "visitTypeFilter":
+            if name == "visit_type":
+                return {"terms": {"visit_types": value}}
+
+        if category == "unboundedListFilter":
+            return "TODO"
+
+        if category == "dateFilter":
+            return "TODO"
+
+        if category == "sortBy":
+            return value
+
+        if category == "limit":
+            return value
+
+        raise Exception("Unknown filter category")