diff --git a/swh/core/api/classes.py b/swh/core/api/classes.py
--- a/swh/core/api/classes.py
+++ b/swh/core/api/classes.py
@@ -4,6 +4,7 @@
 # See top-level LICENSE file for more information
 
 from dataclasses import dataclass, field
+import itertools
 from typing import Callable, Generic, Iterable, List, Optional, TypeVar
 
 TResult = TypeVar("TResult")
@@ -18,6 +19,22 @@
     next_page_token: Optional[TToken] = field(default=None)
 
 
+def _stream_results(f, *args, page_token, **kwargs):
+    """Yield the results of every page returned by ``f``, starting from
+    ``page_token`` and following ``next_page_token`` until it is None.
+
+    Helper for stream_results() and stream_results_optional(), which validate
+    their arguments before calling it. Each page is only fetched once the
+    results of the previous one have been consumed.
+    """
+    while True:
+        page_result = f(*args, page_token=page_token, **kwargs)
+        yield from page_result.results
+        page_token = page_result.next_page_token
+        if page_token is None:
+            break
+
+
 def stream_results(
     f: Callable[..., PagedResult[TResult, TToken]], *args, **kwargs
 ) -> Iterable[TResult]:
@@ -26,10 +43,32 @@
     """
     if "page_token" in kwargs:
         raise TypeError('stream_results has no argument "page_token".')
-    page_token = None
-    while True:
-        page_result = f(*args, page_token=page_token, **kwargs)
-        yield from page_result.results
-        page_token = page_result.next_page_token
-        if page_token is None:
-            break
+    # "return", not "yield from": the check above must fire at call time, not
+    # on the first next(); page fetching stays lazy in _stream_results().
+    return _stream_results(f, *args, page_token=None, **kwargs)
+
+
+def stream_results_optional(
+    f: Callable[..., Optional[PagedResult[TResult, TToken]]], *args, **kwargs
+) -> Optional[Iterable[TResult]]:
+    """Like stream_results(), but for functions ``f`` that return an Optional.
+
+    ``f`` is called immediately for the first page. If it returns None, this
+    function returns None rather than an empty iterator, so that callers can
+    tell both cases apart. Otherwise it returns an iterator over the results of
+    all pages; the following pages are fetched lazily, during iteration, and
+    are expected not to be None.
+    """
+    if "page_token" in kwargs:
+        raise TypeError('stream_results_optional has no argument "page_token".')
+    first_page = f(*args, page_token=None, **kwargs)
+    if first_page is None:
+        return None
+    if first_page.next_page_token is None:
+        return iter(first_page.results)
+    # Chain the already-fetched first page with the following ones, so that
+    # ``f`` is not called a second time for the first page.
+    return itertools.chain(
+        first_page.results,
+        _stream_results(f, *args, page_token=first_page.next_page_token, **kwargs),
+    )
diff --git a/swh/core/api/tests/test_classes.py b/swh/core/api/tests/test_classes.py
--- a/swh/core/api/tests/test_classes.py
+++ b/swh/core/api/tests/test_classes.py
@@ -3,10 +3,12 @@
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
-from typing import TypeVar
+from typing import Optional, TypeVar
+
+import pytest
 
 from swh.core.api.classes import PagedResult as CorePagedResult
-from swh.core.api.classes import stream_results
+from swh.core.api.classes import stream_results, stream_results_optional
 
 T = TypeVar("T")
 TestPagedResult = CorePagedResult[T, bytes]
@@ -21,7 +23,37 @@
     assert list(actual_data) == []
 
 
-def test_stream_results_no_pagination():
+def test_stream_results_optional():
+    def paged_results(page_token) -> Optional[TestPagedResult]:
+        return None
+
+    # Should be None, not an empty iterator!
+    actual_data = stream_results_optional(paged_results)
+    assert actual_data is None
+
+
+def test_stream_results_optional_empty_page():
+    def paged_results(page_token) -> Optional[TestPagedResult]:
+        return TestPagedResult(results=[], next_page_token=None)
+
+    # An empty page is not None: the result is an empty iterator.
+    actual_data = stream_results_optional(paged_results)
+    assert actual_data is not None
+    assert list(actual_data) == []
+
+
+@pytest.mark.parametrize("stream_results", [stream_results, stream_results_optional])
+def test_stream_results_kwarg(stream_results):
+    def paged_results(page_token):
+        assert False, "should not be called"
+
+    # Both functions must reject the argument when called, before iterating.
+    with pytest.raises(TypeError, match="page_token"):
+        stream_results(paged_results, page_token=42)
+
+
+@pytest.mark.parametrize("stream_results", [stream_results, stream_results_optional])
+def test_stream_results_no_pagination(stream_results):
     input_data = [
         {"url": "something"},
         {"url": "something2"},
@@ -35,7 +67,8 @@
     assert list(actual_data) == input_data
 
 
-def test_stream_results_pagination():
+@pytest.mark.parametrize("stream_results", [stream_results, stream_results_optional])
+def test_stream_results_pagination(stream_results):
     input_data = [
         {"url": "something"},
         {"url": "something2"},