diff --git a/swh/search/elasticsearch.py b/swh/search/elasticsearch.py
--- a/swh/search/elasticsearch.py
+++ b/swh/search/elasticsearch.py
@@ -13,7 +13,12 @@
 from swh.indexer import codemeta
 from swh.model import model
 from swh.model.identifiers import origin_identifier
-from swh.search.interface import MinimalOriginDict, OriginDict, PagedResult
+from swh.search.interface import (
+    SORT_BY_OPTIONS,
+    MinimalOriginDict,
+    OriginDict,
+    PagedResult,
+)
 from swh.search.metrics import send_metric, timed
 
 INDEX_NAME_PARAM = "index"
@@ -309,6 +314,7 @@
         min_last_revision_date: str = "",
         min_last_release_date: str = "",
         page_token: Optional[str] = None,
+        sort_by: List[str] = [],
         limit: int = 50,
     ) -> PagedResult[MinimalOriginDict]:
         query_clauses: List[Dict[str, Any]] = []
@@ -407,6 +413,21 @@
         if visit_types is not None:
             query_clauses.append({"terms": {"visit_types": visit_types}})
 
+        sorting_params = []
+
+        for field in sort_by:
+            order = "asc"
+            if field and field[0] == "-":
+                field = field[1:]
+                order = "desc"
+
+            if field in SORT_BY_OPTIONS:
+                sorting_params.append({field: order})
+
+        sorting_params.extend(
+            [{"_score": "desc"}, {"sha1": "asc"},]
+        )
+
         body = {
             "query": {
                 "bool": {
@@ -414,7 +435,7 @@
                     "must_not": [{"term": {"blocklisted": True}}],
                 }
             },
-            "sort": [{"_score": "desc"}, {"sha1": "asc"},],
+            "sort": sorting_params,
         }
 
         if page_token:
diff --git a/swh/search/in_memory.py b/swh/search/in_memory.py
--- a/swh/search/in_memory.py
+++ b/swh/search/in_memory.py
@@ -4,13 +4,17 @@
 # See top-level LICENSE file for more information
 
 from collections import defaultdict
-from datetime import datetime
-import itertools
+from datetime import datetime, timezone
 import re
 from typing import Any, Dict, Iterable, Iterator, List, Optional
 
 from swh.model.identifiers import origin_identifier
-from swh.search.interface import MinimalOriginDict, OriginDict, PagedResult
+from swh.search.interface import (
+    SORT_BY_OPTIONS,
+    MinimalOriginDict,
+    OriginDict,
+    PagedResult,
+)
 
 
 _words_regexp = re.compile(r"\w+")
@@ -33,6 +37,38 @@
     return extract(d, values)
 
 
+def _get_sorting_key(origin, field):
+    """Get value of the field from an origin for sorting origins.
+
+    Here field should be a member of SORT_BY_OPTIONS, optionally prefixed
+    with "-" to request a descending order: the value is then inverted so
+    that a plain ascending sort yields descending results.
+    Any other field (including an empty string) yields None for every
+    origin, so it does not affect the order.
+    """
+    descending = field.startswith("-")
+    if descending:
+        field = field[1:]
+
+    if field == "nb_visits":  # unlike other options, nb_visits is of type integer
+        nb_visits = origin.get(field, 0)
+        return -nb_visits if descending else nb_visits
+
+    if field not in SORT_BY_OPTIONS:
+        return None
+
+    date = datetime.fromisoformat(
+        origin.get(field, "0001-01-01T00:00:00Z").replace("Z", "+00:00")
+    )
+    if date.tzinfo is None:
+        # naive and aware datetimes cannot be compared; assume naive ones are UTC
+        date = date.replace(tzinfo=timezone.utc)
+
+    if descending:
+        return datetime.max.replace(tzinfo=timezone.utc) - date
+    return date
+
+
 class InMemorySearch:
     def __init__(self):
         pass
@@ -136,6 +172,7 @@
         min_last_eventful_visit_date: str = "",
         min_last_revision_date: str = "",
         min_last_release_date: str = "",
+        sort_by: List[str] = [],
         limit: int = 50,
     ) -> PagedResult[MinimalOriginDict]:
         hits: Iterator[Dict[str, Any]] = (
@@ -239,11 +276,15 @@
             hits,
         )
 
+        hits_list = sorted(
+            hits, key=lambda o: tuple(_get_sorting_key(o, field) for field in sort_by),
+        )
+
         start_at_index = int(page_token) if page_token else 0
 
         origins = [
             {"url": hit["url"]}
-            for hit in itertools.islice(hits, start_at_index, start_at_index + limit)
+            for hit in hits_list[start_at_index : start_at_index + limit]
         ]
 
         if len(origins) == limit:
diff --git a/swh/search/interface.py b/swh/search/interface.py
--- a/swh/search/interface.py
+++ b/swh/search/interface.py
@@ -13,6 +13,14 @@
 TResult = TypeVar("TResult")
 PagedResult = CorePagedResult[TResult, str]
 
+SORT_BY_OPTIONS = [
+    "nb_visits",
+    "last_visit_date",
+    "last_eventful_visit_date",
+    "last_revision_date",
+    "last_release_date",
+]
+
 
 class MinimalOriginDict(TypedDict):
     """Mandatory keys of an :class:`OriginDict`"""
@@ -64,6 +72,7 @@
         min_last_eventful_visit_date: str = "",
         min_last_revision_date: str = "",
         min_last_release_date: str = "",
+        sort_by: List[str] = [],
         limit: int = 50,
     ) -> PagedResult[MinimalOriginDict]:
         """Searches for origins matching the `url_pattern`.
@@ -86,6 +95,11 @@
                 last_revision_date on or after the provided date(ISO format)
             min_last_release_date: Filter origins that have
                 last_release_date on or after the provided date(ISO format)
+            sort_by: Sort results based on a list of fields mentioned in SORT_BY_OPTIONS
+                (nb_visits, last_visit_date, last_eventful_visit_date,
+                last_revision_date, last_release_date).
+                A field prefixed with "-" is sorted in descending order,
+                otherwise in ascending order.
             limit: number of results to return
 
         Returns:
diff --git a/swh/search/tests/test_search.py b/swh/search/tests/test_search.py
--- a/swh/search/tests/test_search.py
+++ b/swh/search/tests/test_search.py
@@ -409,6 +409,49 @@
             date_type="last_revision_date"
         )
 
+    def test_origin_sort_by_search(self):
+
+        now = datetime.now(tz=timezone.utc).isoformat()
+        now_minus_5_hours = (
+            datetime.now(tz=timezone.utc) - timedelta(hours=5)
+        ).isoformat()
+        now_plus_5_hours = (
+            datetime.now(tz=timezone.utc) + timedelta(hours=5)
+        ).isoformat()
+
+        ORIGINS = [
+            {
+                "url": "http://foobar.1.com",
+                "nb_visits": 1,
+                "last_visit_date": now_minus_5_hours,
+            },
+            {"url": "http://foobar.2.com", "nb_visits": 2, "last_visit_date": now,},
+            {
+                "url": "http://foobar.3.com",
+                "nb_visits": 3,
+                "last_visit_date": now_plus_5_hours,
+            },
+        ]
+        self.search.origin_update(ORIGINS)
+        self.search.flush()
+
+        def _check_results(sort_by, origins):
+            page = self.search.origin_search(url_pattern="foobar", sort_by=sort_by)
+            results = [r["url"] for r in page.results]
+            assert results == [origin["url"] for origin in origins]
+
+        _check_results(["nb_visits"], ORIGINS)
+        _check_results(["-nb_visits"], ORIGINS[::-1])
+
+        _check_results(["last_visit_date"], ORIGINS)
+        _check_results(["-last_visit_date"], ORIGINS[::-1])
+
+        _check_results(["nb_visits", "-last_visit_date"], ORIGINS)
+        _check_results(["-last_visit_date", "nb_visits"], ORIGINS[::-1])
+
+        # empty or unknown sort fields are ignored
+        _check_results(["", "unknown_field", "nb_visits"], ORIGINS)
+
     def test_origin_update_with_no_visit_types(self):
         """
         Update an origin with visit types first then with no visit types,