diff --git a/swh/scheduler/backend.py b/swh/scheduler/backend.py --- a/swh/scheduler/backend.py +++ b/swh/scheduler/backend.py @@ -19,14 +19,8 @@ from swh.scheduler.utils import utcnow from .exc import SchedulerException, StaleData, UnknownPolicy -from .model import ( - ListedOrigin, - ListedOriginPageToken, - Lister, - OriginVisitStats, - PaginatedListedOriginList, - SchedulerMetrics, -) +from .interface import ListedOriginPageToken, PaginatedListedOriginList +from .model import ListedOrigin, Lister, OriginVisitStats, SchedulerMetrics logger = logging.getLogger(__name__) @@ -309,7 +303,7 @@ origins = [ListedOrigin(**d) for d in cur] if len(origins) == limit: - page_token = (origins[-1].lister_id, origins[-1].url) + page_token = (str(origins[-1].lister_id), origins[-1].url) else: page_token = None diff --git a/swh/scheduler/interface.py b/swh/scheduler/interface.py --- a/swh/scheduler/interface.py +++ b/swh/scheduler/interface.py @@ -5,20 +5,32 @@ import datetime -from typing import Any, Dict, Iterable, List, Optional, Tuple +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union from uuid import UUID from typing_extensions import Protocol, runtime_checkable from swh.core.api import remote_api_endpoint -from swh.scheduler.model import ( - ListedOrigin, - ListedOriginPageToken, - Lister, - OriginVisitStats, - PaginatedListedOriginList, - SchedulerMetrics, -) +from swh.core.api.classes import PagedResult +from swh.scheduler.model import ListedOrigin, Lister, OriginVisitStats, SchedulerMetrics + +ListedOriginPageToken = Tuple[str, str] + + +class PaginatedListedOriginList(PagedResult[ListedOrigin, ListedOriginPageToken]): + """A list of listed origins, with a continuation token""" + + def __init__( + self, + results: List[ListedOrigin], + next_page_token: Union[None, ListedOriginPageToken, List[str]], + ): + parsed_next_page_token: Optional[Tuple[str, str]] = None + if next_page_token is not None: + if len(next_page_token) != 2: + raise TypeError("Expected Tuple[str, str] or list of size 2.") + parsed_next_page_token = tuple(next_page_token) # type: ignore + super().__init__(results, parsed_next_page_token) @runtime_checkable diff --git a/swh/scheduler/model.py b/swh/scheduler/model.py --- a/swh/scheduler/model.py +++ b/swh/scheduler/model.py @@ -4,7 +4,7 @@ # See top-level LICENSE file for more information import datetime -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple from uuid import UUID import attr @@ -177,37 +177,6 @@ } -ListedOriginPageToken = Tuple[UUID, str] - - -def convert_listed_origin_page_token( - input: Union[None, ListedOriginPageToken, List[Union[UUID, str]]] -) -> Optional[ListedOriginPageToken]: - if input is None: - return None - - if isinstance(input, tuple): - return input - - x, y = input - assert isinstance(x, UUID) - assert isinstance(y, str) - return (x, y) - - -@attr.s -class PaginatedListedOriginList(BaseSchedulerModel): - """A list of listed origins, with a continuation token""" - - origins = attr.ib(type=List[ListedOrigin], validator=[type_validator()]) - next_page_token = attr.ib( - type=Optional[ListedOriginPageToken], - validator=[type_validator()], - converter=convert_listed_origin_page_token, - default=None, - ) - - @attr.s(frozen=True, slots=True) class OriginVisitStats(BaseSchedulerModel): """Represents an aggregated origin visits view. diff --git a/swh/scheduler/tests/test_scheduler.py b/swh/scheduler/tests/test_scheduler.py --- a/swh/scheduler/tests/test_scheduler.py +++ b/swh/scheduler/tests/test_scheduler.py @@ -16,13 +16,8 @@ from swh.model.hashutil import hash_to_bytes from swh.scheduler.exc import SchedulerException, StaleData, UnknownPolicy -from swh.scheduler.interface import SchedulerInterface -from swh.scheduler.model import ( - ListedOrigin, - ListedOriginPageToken, - OriginVisitStats, - SchedulerMetrics, -) +from swh.scheduler.interface import ListedOriginPageToken, SchedulerInterface +from swh.scheduler.model import ListedOrigin, OriginVisitStats, SchedulerMetrics from swh.scheduler.utils import utcnow from .common import LISTERS, TASK_TYPES, TEMPLATES, tasks_from_template @@ -713,9 +708,9 @@ ) assert ret.next_page_token is None - assert len(ret.origins) == 1 - assert ret.origins[0].lister_id == origin.lister_id - assert ret.origins[0].url == origin.url + assert len(ret.results) == 1 + assert ret.results[0].lister_id == origin.lister_id + assert ret.results[0].url == origin.url @pytest.mark.parametrize("num_origins,limit", [(20, 6), (5, 42), (20, 20)]) def test_get_listed_origins_limit( @@ -736,7 +731,7 @@ limit=limit, page_token=next_page_token, ) - returned_origins.extend(ret.origins) + returned_origins.extend(ret.results) next_page_token = ret.next_page_token if next_page_token is None: break @@ -753,7 +748,7 @@ ret = swh_scheduler.get_listed_origins(limit=len(listed_origins) + 1) assert ret.next_page_token is None - assert len(ret.origins) == len(listed_origins) + assert len(ret.results) == len(listed_origins) def _grab_next_visits_setup(self, swh_scheduler, listed_origins_by_type): """Basic origins setup for scheduling policy tests""" diff --git a/swh/scheduler/tests/test_simulator.py b/swh/scheduler/tests/test_simulator.py --- a/swh/scheduler/tests/test_simulator.py +++ b/swh/scheduler/tests/test_simulator.py @@ -5,6 +5,7 @@ import pytest +from swh.core.api.classes import stream_results import swh.scheduler.simulator as simulator from swh.scheduler.tests.common import TASK_TYPES @@ -18,9 +19,8 @@ simulator.fill_test_data(swh_scheduler, num_origins=NUM_ORIGINS) - res = swh_scheduler.get_listed_origins() - assert len(res.origins) == NUM_ORIGINS - assert res.next_page_token is None + origins = list(stream_results(swh_scheduler.get_listed_origins)) + assert len(origins) == NUM_ORIGINS res = swh_scheduler.search_tasks() assert len(res) == NUM_ORIGINS