diff --git a/requirements-swh.txt b/requirements-swh.txt --- a/requirements-swh.txt +++ b/requirements-swh.txt @@ -1,4 +1,4 @@ # Add here internal Software Heritage dependencies, one per line. -swh.core[http] +swh.core[http] >= 0.2.0 swh.journal >= 0.1.0 swh.model diff --git a/swh/search/api/client.py b/swh/search/api/client.py --- a/swh/search/api/client.py +++ b/swh/search/api/client.py @@ -1,4 +1,4 @@ -# Copyright (C) 2019 The Software Heritage developers +# Copyright (C) 2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information diff --git a/swh/search/api/server.py b/swh/search/api/server.py --- a/swh/search/api/server.py +++ b/swh/search/api/server.py @@ -1,4 +1,4 @@ -# Copyright (C) 2019 The Software Heritage developers +# Copyright (C) 2019-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information diff --git a/swh/search/elasticsearch.py b/swh/search/elasticsearch.py --- a/swh/search/elasticsearch.py +++ b/swh/search/elasticsearch.py @@ -1,18 +1,21 @@ -# Copyright (C) 2019 The Software Heritage developers +# Copyright (C) 2019-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import base64 +import msgpack + from typing import Any, Iterable, Dict, List, Iterator, Optional from elasticsearch import Elasticsearch from elasticsearch.helpers import bulk, scan -import msgpack from swh.core.api import remote_api_endpoint -from swh.model import model from swh.model.identifiers import origin_identifier +from swh.model import model + +from swh.search.interface import PagedResult def _sanitize_origin(origin): @@ -24,6 +27,21 @@ return res +def token_encode(index_to_tokenize: Dict[bytes, Any]) -> str: + """Tokenize as string an index page result from a search + + """ + page_token = base64.b64encode(msgpack.dumps(index_to_tokenize)) + return page_token.decode() + + +def token_decode(page_token: str) -> Dict[bytes, Any]: + """Read the page_token + + """ + return msgpack.loads(base64.b64decode(page_token.encode()), raw=True) + + class ElasticSearch: def __init__(self, hosts: List[str]): self._backend = Elasticsearch(hosts=hosts) @@ -106,31 +124,27 @@ def origin_search( self, *, - url_pattern: str = None, + url_pattern: Optional[str] = None, metadata_pattern: str = None, with_visit: bool = False, - page_token: str = None, - count: int = 50, - ) -> Dict[str, object]: + page_token: Optional[str] = None, + limit: int = 50, + ) -> PagedResult[Dict[str, Any]]: """Searches for origins matching the `url_pattern`. Args: - url_pattern (str): Part of thr URL to search for - with_visit (bool): Whether origins with no visit are to be - filtered out - page_token (str): Opaque value used for pagination. - count (int): number of results to return. + url_pattern: Part of the URL to search for + with_visit: Whether origins with no visit are to be + filtered out + page_token: Opaque value used for pagination + limit: number of results to return Returns: - a dictionary with keys: - * `next_page_token`: - opaque value used for fetching more results. `None` if there - are no more result. - * `results`: - list of dictionaries with key: - * `url`: URL of a matching origin + PagedResult of origin dicts matching the search criteria. If next_page_token + is None, there is no longer data to retrieve. + """ - query_clauses = [] # type: List[Dict[str, Any]] + query_clauses: List[Dict[str, Any]] = [] if url_pattern: query_clauses.append( @@ -169,45 +183,38 @@ "At least one of url_pattern and metadata_pattern must be provided." ) + next_page_token = None + if with_visit: query_clauses.append({"term": {"has_visits": True,}}) body = { "query": {"bool": {"must": query_clauses,}}, - "size": count, "sort": [{"_score": "desc"}, {"sha1": "asc"},], } if page_token: # TODO: use ElasticSearch's scroll API? - page_token_content = msgpack.loads(base64.b64decode(page_token), raw=True) + page_token_content = token_decode(page_token) body["search_after"] = [ page_token_content[b"score"], page_token_content[b"sha1"].decode("ascii"), ] - res = self._backend.search(index="origin", body=body, size=count,) + res = self._backend.search(index="origin", body=body, size=limit) hits = res["hits"]["hits"] - if len(hits) == count: + if len(hits) == limit: last_hit = hits[-1] next_page_token_content = { b"score": last_hit["_score"], b"sha1": last_hit["_source"]["sha1"], } - next_page_token = base64.b64encode( - msgpack.dumps(next_page_token_content) - ) # type: Optional[bytes] - else: - next_page_token = None - - return { - "next_page_token": next_page_token, - "results": [ - { - # TODO: also add 'id'? - "url": hit["_source"]["url"], - } - for hit in hits - ], - } + next_page_token = token_encode(next_page_token_content) + + assert len(hits) <= limit + + return PagedResult( + results=[{"url": hit["_source"]["url"]} for hit in hits], + next_page_token=next_page_token, + ) diff --git a/swh/search/in_memory.py b/swh/search/in_memory.py --- a/swh/search/in_memory.py +++ b/swh/search/in_memory.py @@ -1,19 +1,19 @@ -# Copyright (C) 2019 The Software Heritage developers +# Copyright (C) 2019-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information -import base64 -from collections import defaultdict import itertools import re -from typing import Any, Dict, Iterable, Iterator, List, Optional -import msgpack +from collections import defaultdict +from typing import Any, Dict, Iterable, Iterator, List, Optional from swh.core.api import remote_api_endpoint from swh.model.identifiers import origin_identifier +from swh.search.interface import PagedResult + def _sanitize_origin(origin): origin = origin.copy() @@ -38,8 +38,8 @@ del self._origin_ids def initialize(self) -> None: - self._origins = defaultdict(dict) # type: Dict[str, Dict[str, Any]] - self._origin_ids = [] # type: List[str] + self._origins: Dict[str, Dict[str, Any]] = defaultdict(dict) + self._origin_ids: List[str] = [] def flush(self) -> None: pass @@ -61,15 +61,15 @@ def origin_search( self, *, - url_pattern: str = None, - metadata_pattern: str = None, + url_pattern: Optional[str] = None, + metadata_pattern: Optional[str] = None, with_visit: bool = False, - page_token: str = None, - count: int = 50, - ) -> Dict[str, object]: - matches = ( + page_token: Optional[str] = None, + limit: int = 50, + ) -> PagedResult[Dict[str, Any]]: + hits: Iterator[Dict[str, Any]] = ( self._origins[id_] for id_ in self._origin_ids - ) # type: Iterator[Dict[str, Any]] + ) if url_pattern: tokens = set(self._url_splitter.split(url_pattern)) @@ -88,7 +88,7 @@ for token in match["_url_tokens"] ) - matches = filter(predicate, matches) + hits = filter(predicate, hits) if metadata_pattern: raise NotImplementedError( @@ -100,28 +100,24 @@ "At least one of url_pattern and metadata_pattern must be provided." ) + next_page_token = None + if with_visit: - matches = filter(lambda o: o.get("has_visits"), matches) - - if page_token: - page_token_content = msgpack.loads(base64.b64decode(page_token)) - start_at_index = page_token_content[b"start_at_index"] - else: - start_at_index = 0 - - hits = list(itertools.islice(matches, start_at_index, start_at_index + count)) - - if len(hits) == count: - next_page_token_content = { - b"start_at_index": start_at_index + count, - } - next_page_token = base64.b64encode( - msgpack.dumps(next_page_token_content) - ) # type: Optional[bytes] - else: - next_page_token = None - - return { - "next_page_token": next_page_token, - "results": [{"url": hit["url"]} for hit in hits], - } + hits = filter(lambda o: o.get("has_visits"), hits) + + start_at_index = int(page_token) if page_token else 0 + + origins = [ + {"url": hit["url"]} + for hit in itertools.islice( + hits, start_at_index, start_at_index + limit + 1 + ) + ] + + if len(origins) > limit: + next_page_token = str(start_at_index + limit) + origins = origins[:limit] + + assert len(origins) <= limit + + return PagedResult(results=origins, next_page_token=next_page_token,) diff --git a/swh/search/interface.py b/swh/search/interface.py new file mode 100644 --- /dev/null +++ b/swh/search/interface.py @@ -0,0 +1,12 @@ +# Copyright (C) 2020 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + +from typing import TypeVar + +from swh.core.api.classes import PagedResult as CorePagedResult + + +TResult = TypeVar("TResult") +PagedResult = CorePagedResult[TResult, str] diff --git a/swh/search/tests/test_cli.py b/swh/search/tests/test_cli.py --- a/swh/search/tests/test_cli.py +++ b/swh/search/tests/test_cli.py @@ -16,6 +16,8 @@ from swh.search.cli import cli +from swh.search.tests.utils import assert_page_match + CLI_CONFIG = """ search: @@ -82,17 +84,14 @@ swh_search.flush() # searching origin without visit as requirement - results = swh_search.origin_search(url_pattern="foobar") + actual_page = swh_search.origin_search(url_pattern="foobar") # We find it - assert results == { - "next_page_token": None, - "results": [{"url": "http://foobar.baz"}], - } + assert_page_match(actual_page, [{"url": "http://foobar.baz"}]) # It's an origin with no visit, searching for it with visit - results = swh_search.origin_search(url_pattern="foobar", with_visit=True) + actual_page = swh_search.origin_search(url_pattern="foobar", with_visit=True) # returns nothing - assert results == {"next_page_token": None, "results": []} + assert_page_match(actual_page, []) def test__journal_client__origin_visit( @@ -100,6 +99,7 @@ ): """Tests the re-indexing when origin_batch_size*task_batch_size is a divisor of nb_origins.""" + origin_foobar = {"url": "http://baz.foobar"} producer = Producer( { "bootstrap.servers": kafka_server, @@ -108,7 +108,7 @@ } ) topic = f"{kafka_prefix}.origin_visit" - value = value_to_kafka({"origin": "http://baz.foobar",}) + value = value_to_kafka({"origin": origin_foobar["url"]}) producer.produce(topic=topic, key=b"bogus-origin-visit", value=value) journal_objects_config = JOURNAL_OBJECTS_CONFIG_TEMPLATE.format( @@ -128,15 +128,12 @@ swh_search.flush() - expected_result = { - "next_page_token": None, - "results": [{"url": "http://baz.foobar"}], - } # Both search returns the visit - results = swh_search.origin_search(url_pattern="foobar", with_visit=False) - assert results == expected_result - results = swh_search.origin_search(url_pattern="foobar", with_visit=True) - assert results == expected_result + actual_page = swh_search.origin_search(url_pattern="foobar", with_visit=False) + assert_page_match(actual_page, [origin_foobar]) + + actual_page = swh_search.origin_search(url_pattern="foobar", with_visit=True) + assert_page_match(actual_page, [origin_foobar]) def test__journal_client__missing_main_journal_config_key(elasticsearch_host): diff --git a/swh/search/tests/test_search.py b/swh/search/tests/test_search.py --- a/swh/search/tests/test_search.py +++ b/swh/search/tests/test_search.py @@ -1,4 +1,4 @@ -# Copyright (C) 2019 The Software Heritage developers +# Copyright (C) 2019-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information @@ -7,125 +7,105 @@ from swh.search.utils import stream_results +from swh.search.tests.utils import assert_page_match, assert_results_match + class CommonSearchTest: def test_origin_url_unique_word_prefix(self): - self.search.origin_update( - [ - {"url": "http://foobar.baz"}, - {"url": "http://barbaz.qux"}, - {"url": "http://qux.quux"}, - ] - ) + origin_foobar_baz = {"url": "http://foobar.baz"} + origin_barbaz_qux = {"url": "http://barbaz.qux"} + origin_qux_quux = {"url": "http://qux.quux"} + origins = [origin_foobar_baz, origin_barbaz_qux, origin_qux_quux] + + self.search.origin_update(origins) self.search.flush() - results = self.search.origin_search(url_pattern="foobar") - assert results == { - "next_page_token": None, - "results": [{"url": "http://foobar.baz"}], - } + actual_page = self.search.origin_search(url_pattern="foobar") + assert_page_match(actual_page, [origin_foobar_baz]) - results = self.search.origin_search(url_pattern="barb") - assert results == { - "next_page_token": None, - "results": [{"url": "http://barbaz.qux"}], - } + actual_page = self.search.origin_search(url_pattern="barb") + assert_page_match(actual_page, [origin_barbaz_qux]) # 'bar' is part of 'foobar', but is not the beginning of it - results = self.search.origin_search(url_pattern="bar") - assert results == { - "next_page_token": None, - "results": [{"url": "http://barbaz.qux"}], - } - - results = self.search.origin_search(url_pattern="barbaz") - assert results == { - "next_page_token": None, - "results": [{"url": "http://barbaz.qux"}], - } + actual_page = self.search.origin_search(url_pattern="bar") + assert_page_match(actual_page, [origin_barbaz_qux]) + + actual_page = self.search.origin_search(url_pattern="barbaz") + assert_page_match(actual_page, [origin_barbaz_qux]) def test_origin_url_unique_word_prefix_multiple_results(self): + origin_foobar_baz = {"url": "http://foobar.baz"} + origin_barbaz_qux = {"url": "http://barbaz.qux"} + origin_qux_quux = {"url": "http://qux.quux"} + self.search.origin_update( - [ - {"url": "http://foobar.baz"}, - {"url": "http://barbaz.qux"}, - {"url": "http://qux.quux"}, - ] + [origin_foobar_baz, origin_barbaz_qux, origin_qux_quux] ) self.search.flush() - results = self.search.origin_search(url_pattern="qu") - assert results["next_page_token"] is None - - results = [res["url"] for res in results["results"]] - expected_results = ["http://qux.quux", "http://barbaz.qux"] - assert sorted(results) == sorted(expected_results) + actual_page = self.search.origin_search(url_pattern="qu") + assert_page_match(actual_page, [origin_qux_quux, origin_barbaz_qux]) - results = self.search.origin_search(url_pattern="qux") - assert results["next_page_token"] is None - - results = [res["url"] for res in results["results"]] - expected_results = ["http://barbaz.qux", "http://qux.quux"] - assert sorted(results) == sorted(expected_results) + actual_page = self.search.origin_search(url_pattern="qux") + assert_page_match(actual_page, [origin_qux_quux, origin_barbaz_qux]) def test_origin_url_all_terms(self): - self.search.origin_update( - [{"url": "http://foo.bar/baz"}, {"url": "http://foo.bar/foo.bar"},] - ) + origin_foo_bar_baz = {"url": "http://foo.bar/baz"} + origin_foo_bar_foo_bar = {"url": "http://foo.bar/foo.bar"} + origins = [origin_foo_bar_baz, origin_foo_bar_foo_bar] + + self.search.origin_update(origins) self.search.flush() # Only results containing all terms should be returned. - results = self.search.origin_search(url_pattern="foo bar baz") - assert results == { - "next_page_token": None, - "results": [{"url": "http://foo.bar/baz"},], - } + actual_page = self.search.origin_search(url_pattern="foo bar baz") + assert_page_match(actual_page, [origin_foo_bar_baz]) def test_origin_with_visit(self): + origin_foobar_baz = {"url": "http://foobar/baz"} + self.search.origin_update( - [{"url": "http://foobar.baz", "has_visits": True},] + [{**o, "has_visits": True} for o in [origin_foobar_baz]] ) self.search.flush() - results = self.search.origin_search(url_pattern="foobar", with_visit=True) - assert results == { - "next_page_token": None, - "results": [{"url": "http://foobar.baz"}], - } + actual_page = self.search.origin_search(url_pattern="foobar", with_visit=True) + assert_page_match(actual_page, [origin_foobar_baz]) def test_origin_with_visit_added(self): - self.search.origin_update( - [{"url": "http://foobar.baz"},] - ) + origin_foobar_baz = {"url": "http://foobar.baz"} + + self.search.origin_update([origin_foobar_baz]) self.search.flush() - results = self.search.origin_search(url_pattern="foobar", with_visit=True) - assert results == {"next_page_token": None, "results": []} + actual_page = self.search.origin_search(url_pattern="foobar", with_visit=True) + assert_page_match(actual_page, []) self.search.origin_update( - [{"url": "http://foobar.baz", "has_visits": True},] + [{**o, "has_visits": True} for o in [origin_foobar_baz]] ) self.search.flush() - results = self.search.origin_search(url_pattern="foobar", with_visit=True) - assert results == { - "next_page_token": None, - "results": [{"url": "http://foobar.baz"}], - } + actual_page = self.search.origin_search(url_pattern="foobar", with_visit=True) + assert_page_match(actual_page, [origin_foobar_baz]) def test_origin_intrinsic_metadata_description(self): + origin1_nothin = {"url": "http://origin1"} + origin2_foobar = {"url": "http://origin2"} + origin3_barbaz = {"url": "http://origin3"} + self.search.origin_update( [ - {"url": "http://origin1", "intrinsic_metadata": {},}, + {**origin1_nothin, "intrinsic_metadata": {},}, { - "url": "http://origin2", + **origin2_foobar, "intrinsic_metadata": { "@context": "https://doi.org/10.5063/schema/codemeta-2.0", "description": "foo bar", }, }, { - "url": "http://origin3", + **origin3_barbaz, "intrinsic_metadata": { "@context": "https://doi.org/10.5063/schema/codemeta-2.0", "description": "bar baz", @@ -135,36 +115,30 @@ ) self.search.flush() - results = self.search.origin_search(metadata_pattern="foo") - assert results == { - "next_page_token": None, - "results": [{"url": "http://origin2"}], - } + actual_page = self.search.origin_search(metadata_pattern="foo") + assert_page_match(actual_page, [origin2_foobar]) - results = self.search.origin_search(metadata_pattern="foo bar") - assert results == { - "next_page_token": None, - "results": [{"url": "http://origin2"}], - } + actual_page = self.search.origin_search(metadata_pattern="foo bar") + assert_page_match(actual_page, [origin2_foobar]) - results = self.search.origin_search(metadata_pattern="bar baz") - assert results == { - "next_page_token": None, - "results": [{"url": "http://origin3"}], - } + actual_page = self.search.origin_search(metadata_pattern="bar baz") + assert_page_match(actual_page, [origin3_barbaz]) def test_origin_intrinsic_metadata_all_terms(self): + origin1_foobarfoobar = {"url": "http://origin1"} + origin3_foobarbaz = {"url": "http://origin2"} + self.search.origin_update( [ { - "url": "http://origin1", + **origin1_foobarfoobar, "intrinsic_metadata": { "@context": "https://doi.org/10.5063/schema/codemeta-2.0", "description": "foo bar foo bar", }, }, { - "url": "http://origin3", + **origin3_foobarbaz, "intrinsic_metadata": { "@context": "https://doi.org/10.5063/schema/codemeta-2.0", "description": "foo bar baz", @@ -174,25 +148,26 @@ ) self.search.flush() - results = self.search.origin_search(metadata_pattern="foo bar baz") - assert results == { - "next_page_token": None, - "results": [{"url": "http://origin3"}], - } + actual_page = self.search.origin_search(metadata_pattern="foo bar baz") + assert_page_match(actual_page, [origin3_foobarbaz]) def test_origin_intrinsic_metadata_nested(self): + origin1_nothin = {"url": "http://origin1"} + origin2_foobar = {"url": "http://origin2"} + origin3_barbaz = {"url": "http://origin3"} + self.search.origin_update( [ - {"url": "http://origin1", "intrinsic_metadata": {},}, + {**origin1_nothin, "intrinsic_metadata": {},}, { - "url": "http://origin2", + **origin2_foobar, "intrinsic_metadata": { "@context": "https://doi.org/10.5063/schema/codemeta-2.0", "keywords": ["foo", "bar"], }, }, { - "url": "http://origin3", + **origin3_barbaz, "intrinsic_metadata": { "@context": "https://doi.org/10.5063/schema/codemeta-2.0", "keywords": ["bar", "baz"], @@ -202,23 +177,14 @@ ) self.search.flush() - results = self.search.origin_search(metadata_pattern="foo") - assert results == { - "next_page_token": None, - "results": [{"url": "http://origin2"}], - } + actual_page = self.search.origin_search(metadata_pattern="foo") + assert_page_match(actual_page, [origin2_foobar]) - results = self.search.origin_search(metadata_pattern="foo bar") - assert results == { - "next_page_token": None, - "results": [{"url": "http://origin2"}], - } + actual_page = self.search.origin_search(metadata_pattern="foo bar") + assert_page_match(actual_page, [origin2_foobar]) - results = self.search.origin_search(metadata_pattern="bar baz") - assert results == { - "next_page_token": None, - "results": [{"url": "http://origin3"}], - } + actual_page = self.search.origin_search(metadata_pattern="bar baz") + assert_page_match(actual_page, [origin3_barbaz]) # TODO: add more tests with more codemeta terms @@ -226,71 +192,58 @@ @settings(deadline=None) @given(strategies.integers(min_value=1, max_value=4)) - def test_origin_url_paging(self, count): + def test_origin_url_paging(self, limit): # TODO: no hypothesis + origin1_foo = {"url": "http://origin1/foo"} + origin2_foobar = {"url": "http://origin2/foo/bar"} + origin3_foobarbaz = {"url": "http://origin3/foo/bar/baz"} + self.reset() - self.search.origin_update( - [ - {"url": "http://origin1/foo"}, - {"url": "http://origin2/foo/bar"}, - {"url": "http://origin3/foo/bar/baz"}, - ] - ) + self.search.origin_update([origin1_foo, origin2_foobar, origin3_foobarbaz]) self.search.flush() results = stream_results( - self.search.origin_search, url_pattern="foo bar baz", count=count + self.search.origin_search, url_pattern="foo bar baz", limit=limit ) - results = [res["url"] for res in results] - expected_results = [ - "http://origin3/foo/bar/baz", - ] - assert sorted(results[0 : len(expected_results)]) == sorted(expected_results) + assert_results_match(results, [origin3_foobarbaz]) results = stream_results( - self.search.origin_search, url_pattern="foo bar", count=count + self.search.origin_search, url_pattern="foo bar", limit=limit ) - expected_results = [ - "http://origin2/foo/bar", - "http://origin3/foo/bar/baz", - ] - results = [res["url"] for res in results] - assert sorted(results[0 : len(expected_results)]) == sorted(expected_results) + assert_results_match(results, [origin2_foobar, origin3_foobarbaz]) results = stream_results( - self.search.origin_search, url_pattern="foo", count=count + self.search.origin_search, url_pattern="foo", limit=limit ) - expected_results = [ - "http://origin1/foo", - "http://origin2/foo/bar", - "http://origin3/foo/bar/baz", - ] - results = [res["url"] for res in results] - assert sorted(results[0 : len(expected_results)]) == sorted(expected_results) + assert_results_match(results, [origin1_foo, origin2_foobar, origin3_foobarbaz]) @settings(deadline=None) @given(strategies.integers(min_value=1, max_value=4)) - def test_origin_intrinsic_metadata_paging(self, count): + def test_origin_intrinsic_metadata_paging(self, limit): # TODO: no hypothesis + origin1_foo = {"url": "http://origin1/foo"} + origin2_foobar = {"url": "http://origin2/foo/bar"} + origin3_foobarbaz = {"url": "http://origin3/foo/bar/baz"} + self.reset() self.search.origin_update( [ { - "url": "http://origin1", + **origin1_foo, "intrinsic_metadata": { "@context": "https://doi.org/10.5063/schema/codemeta-2.0", "keywords": ["foo"], }, }, { - "url": "http://origin2", + **origin2_foobar, "intrinsic_metadata": { "@context": "https://doi.org/10.5063/schema/codemeta-2.0", "keywords": ["foo", "bar"], }, }, { - "url": "http://origin3", + **origin3_foobarbaz, "intrinsic_metadata": { "@context": "https://doi.org/10.5063/schema/codemeta-2.0", "keywords": ["foo", "bar", "baz"], @@ -301,20 +254,16 @@ self.search.flush() results = stream_results( - self.search.origin_search, metadata_pattern="foo bar baz", count=count + self.search.origin_search, metadata_pattern="foo bar baz", limit=limit ) - assert list(results) == [{"url": "http://origin3"}] + assert_results_match(results, [origin3_foobarbaz]) results = stream_results( - self.search.origin_search, metadata_pattern="foo bar", count=count + self.search.origin_search, metadata_pattern="foo bar", limit=limit ) - assert list(results) == [{"url": "http://origin2"}, {"url": "http://origin3"}] + assert_results_match(results, [origin2_foobar, origin3_foobarbaz]) results = stream_results( - self.search.origin_search, metadata_pattern="foo", count=count + self.search.origin_search, metadata_pattern="foo", limit=limit ) - assert list(results) == [ - {"url": "http://origin1"}, - {"url": "http://origin2"}, - {"url": "http://origin3"}, - ] + assert_results_match(results, [origin1_foo, origin2_foobar, origin3_foobarbaz]) diff --git a/swh/search/tests/test_utils.py b/swh/search/tests/test_utils.py new file mode 100644 --- /dev/null +++ b/swh/search/tests/test_utils.py @@ -0,0 +1,51 @@ +# Copyright (C) 2020 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + +import pytest + +from swh.search.interface import PagedResult + +from .utils import assert_results_match, assert_page_match + + +def test_assert_result_match(): + actual_result = [{"url": "origin1"}, {"url": "origin2"}, {"url": "origin3"}] + + # order does not count, the results match + assert_results_match( + actual_result, [{"url": "origin2"}, {"url": "origin3"}, {"url": "origin1"}] + ) + + with pytest.raises(AssertionError): + assert_results_match(actual_result, [{"url": "origin1"}]) + + +def test_assert_page_match(): + actual_page = PagedResult( + results=[{"url": "origin1"}, {"url": "origin2"}], next_page_token=None + ) + + # match ok + assert_page_match(actual_page, [{"url": "origin1"}, {"url": "origin2"}]) + assert_page_match( + actual_page, [{"url": "origin1"}, {"url": "origin2"}], expected_page_token=False + ) + + # ko, expected_page_token expected but it's None, raises! + with pytest.raises(AssertionError): + assert_page_match( + actual_page, + [{"url": "origin1"}, {"url": "origin2"}], + expected_page_token=True, + ) + + # ko, results mismatch! Raises! + with pytest.raises(AssertionError): + assert_page_match(actual_page, [{"url": "origin1"}]) + + actual_page = PagedResult(results=[{"url": "origin1"}], next_page_token="something") + + # ok, expected token is not None, results match as well + assert_page_match(actual_page, [{"url": "origin1"}], expected_page_token=True) diff --git a/swh/search/tests/utils.py b/swh/search/tests/utils.py new file mode 100644 --- /dev/null +++ b/swh/search/tests/utils.py @@ -0,0 +1,26 @@ +# Copyright (C) 2020 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + +from swh.search.interface import PagedResult +from typing import Dict, Iterable, List + + +def assert_results_match(actual_origins: Iterable[Dict], expected_origins: List[Dict]): + actual_urls = set(r["url"] for r in actual_origins) + assert actual_urls == set(o["url"] for o in expected_origins) + + +def assert_page_match( + actual_page: PagedResult, + expected_origins: List[Dict], + expected_page_token: bool = False, +): + + if expected_page_token: + assert actual_page.next_page_token is not None + else: + assert actual_page.next_page_token is None + + assert_results_match(actual_page.results, expected_origins) diff --git a/swh/search/utils.py b/swh/search/utils.py --- a/swh/search/utils.py +++ b/swh/search/utils.py @@ -1,16 +1,19 @@ -# Copyright (C) 2019 The Software Heritage developers +# Copyright (C) 2019-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information def stream_results(f, *args, **kwargs): + """Consume the paginated result and stream it directly + + """ if "page_token" in kwargs: raise TypeError('stream_results has no argument "page_token".') page_token = None while True: results = f(*args, page_token=page_token, **kwargs) - yield from results["results"] - page_token = results["next_page_token"] + yield from results.results + page_token = results.next_page_token if page_token is None: break