diff --git a/swh/core/api/classes.py b/swh/core/api/classes.py index c84db28..17ed731 100644 --- a/swh/core/api/classes.py +++ b/swh/core/api/classes.py @@ -1,35 +1,61 @@ # Copyright (C) 2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from dataclasses import dataclass, field +import itertools from typing import Callable, Generic, Iterable, List, Optional, TypeVar TResult = TypeVar("TResult") TToken = TypeVar("TToken") @dataclass(eq=True) class PagedResult(Generic[TResult, TToken]): """Represents a page of results; with a token to get the next page""" results: List[TResult] = field(default_factory=list) next_page_token: Optional[TToken] = field(default=None) +def _stream_results(f, *args, page_token, **kwargs): + """Helper for stream_results() and stream_results_optional()""" + while True: + page_result = f(*args, page_token=page_token, **kwargs) + yield from page_result.results + page_token = page_result.next_page_token + if page_token is None: + break + + def stream_results( f: Callable[..., PagedResult[TResult, TToken]], *args, **kwargs ) -> Iterable[TResult]: """Consume the paginated result and stream the page results """ if "page_token" in kwargs: raise TypeError('stream_results has no argument "page_token".') - page_token = None - while True: - page_result = f(*args, page_token=page_token, **kwargs) - yield from page_result.results - page_token = page_result.next_page_token - if page_token is None: - break + yield from _stream_results(f, *args, page_token=None, **kwargs) + + +def stream_results_optional( + f: Callable[..., Optional[PagedResult[TResult, TToken]]], *args, **kwargs +) -> Optional[Iterable[TResult]]: + """Like stream_results(), but for functions ``f`` that return an Optional. + + """ + if "page_token" in kwargs: + raise TypeError('stream_results_optional has no argument "page_token".') + res = f(*args, page_token=None, **kwargs) + if res is None: + return None + else: + if res.next_page_token is None: + return iter(res.results) + else: + return itertools.chain( + res.results, + _stream_results(f, *args, page_token=res.next_page_token, **kwargs), + ) diff --git a/swh/core/api/tests/test_classes.py b/swh/core/api/tests/test_classes.py index 7509f51..9fb2dc1 100644 --- a/swh/core/api/tests/test_classes.py +++ b/swh/core/api/tests/test_classes.py @@ -1,60 +1,83 @@ # Copyright (C) 2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information -from typing import TypeVar +from typing import Optional, TypeVar + +import pytest from swh.core.api.classes import PagedResult as CorePagedResult -from swh.core.api.classes import stream_results +from swh.core.api.classes import stream_results, stream_results_optional T = TypeVar("T") TestPagedResult = CorePagedResult[T, bytes] def test_stream_results_no_result(): def paged_results(page_token) -> TestPagedResult: return TestPagedResult(results=[], next_page_token=None) # only 1 call, no pagination actual_data = stream_results(paged_results) assert list(actual_data) == [] -def test_stream_results_no_pagination(): +def test_stream_results_optional(): + def paged_results(page_token) -> Optional[TestPagedResult]: + return None + + # Should be None, not an empty iterator! + actual_data = stream_results_optional(paged_results) + assert actual_data is None + + +@pytest.mark.parametrize("stream_results", [stream_results, stream_results_optional]) +def test_stream_results_kwarg(stream_results): + def paged_results(page_token): + assert False, "should not be called" + + with pytest.raises(TypeError): + actual_data = stream_results(paged_results, page_token=42) + list(actual_data) + + +@pytest.mark.parametrize("stream_results", [stream_results, stream_results_optional]) +def test_stream_results_no_pagination(stream_results): input_data = [ {"url": "something"}, {"url": "something2"}, ] def paged_results(page_token) -> TestPagedResult: return TestPagedResult(results=input_data, next_page_token=None) # only 1 call, no pagination actual_data = stream_results(paged_results) assert list(actual_data) == input_data -def test_stream_results_pagination(): +@pytest.mark.parametrize("stream_results", [stream_results, stream_results_optional]) +def test_stream_results_pagination(stream_results): input_data = [ {"url": "something"}, {"url": "something2"}, ] input_data2 = [ {"url": "something3"}, ] input_data3 = [ {"url": "something4"}, ] def page_results2(page_token=None) -> TestPagedResult: result_per_token = { None: TestPagedResult(results=input_data, next_page_token=b"two"), b"two": TestPagedResult(results=input_data2, next_page_token=b"three"), b"three": TestPagedResult(results=input_data3, next_page_token=None), } return result_per_token[page_token] # multiple calls to solve the pagination calls actual_data = stream_results(page_results2) assert list(actual_data) == input_data + input_data2 + input_data3