NOTE(review): this patch was recovered from a whitespace-collapsed copy. Line
breaks and leading indentation were rebuilt from the hunk headers and Python
syntax; confirm with `git apply --check` before applying.

diff --git a/swh/lister/pattern.py b/swh/lister/pattern.py
--- a/swh/lister/pattern.py
+++ b/swh/lister/pattern.py
@@ -84,6 +84,8 @@
         identifies the :attr:`LISTER_NAME`, the second level the lister
         :attr:`instance`. The final level is a list of dicts containing the
         expected credentials for the given instance of that lister.
+      max_pages: the maximum number of pages listed in a full listing operation
+      max_origins_per_page: the maximum number of origins processed per page
 
     Generic types:
       - *StateType*: concrete lister type; should usually be a :class:`dataclass` for
@@ -102,6 +104,8 @@
         url: str,
         instance: Optional[str] = None,
         credentials: CredentialsType = None,
+        max_origins_per_page: Optional[int] = None,
+        max_pages: Optional[int] = None,
         with_github_session: bool = False,
     ):
         if not self.LISTER_NAME:
@@ -140,6 +144,8 @@
         )
 
         self.recorded_origins: Set[str] = set()
+        self.max_pages = max_pages
+        self.max_origins_per_page = max_origins_per_page
 
     @http_retry(before_sleep=before_sleep_log(logger, logging.WARNING))
     def http_request(self, url: str, method="GET", **kwargs) -> requests.Response:
@@ -172,11 +178,25 @@
         try:
             for page in self.get_pages():
                 full_stats.pages += 1
-                origins = self.get_origins_from_page(page)
+                origins = list(self.get_origins_from_page(page))
+                if (
+                    self.max_origins_per_page
+                    and len(origins) > self.max_origins_per_page
+                ):
+                    logger.info(
+                        "Max origins per page set, truncated %s page results down to %s",
+                        len(origins),
+                        self.max_origins_per_page,
+                    )
+                    origins = origins[: self.max_origins_per_page]
                 sent_origins = self.send_origins(origins)
                 self.recorded_origins.update(sent_origins)
                 full_stats.origins = len(self.recorded_origins)
                 self.commit_page(page)
+
+                if self.max_pages and full_stats.pages >= self.max_pages:
+                    logger.info("Reached page limit of %s, terminating", self.max_pages)
+                    break
         finally:
             self.finalize()
             if self.updated:
diff --git a/swh/lister/tests/test_pattern.py b/swh/lister/tests/test_pattern.py
--- a/swh/lister/tests/test_pattern.py
+++ b/swh/lister/tests/test_pattern.py
@@ -215,3 +215,70 @@
 
     assert run_result.pages == 2
     assert run_result.origins == 1
+
+
+class ListerWithALotOfPagesWithALotOfOrigins(RunnableStatelessLister):
+    def get_pages(self) -> Iterator[PageType]:
+        for page in range(10):
+            yield [
+                {"url": f"https://example.org/page{page}/origin{origin}"}
+                for origin in range(10)
+            ]
+
+
+@pytest.mark.parametrize(
+    "max_pages,expected_pages",
+    [
+        (2, 2),
+        (10, 10),
+        (100, 10),
+        # The default returns all 10 pages
+        (None, 10),
+    ],
+)
+def test_lister_max_pages(swh_scheduler, max_pages, expected_pages):
+    extra_kwargs = {}
+    if max_pages is not None:
+        extra_kwargs["max_pages"] = max_pages
+
+    lister = ListerWithALotOfPagesWithALotOfOrigins(
+        scheduler=swh_scheduler,
+        url="https://example.org",
+        instance="example.org",
+        **extra_kwargs,
+    )
+
+    run_result = lister.run()
+
+    assert run_result.pages == expected_pages
+    assert run_result.origins == expected_pages * 10
+
+
+@pytest.mark.parametrize(
+    "max_origins_per_page,expected_origins_per_page",
+    [
+        (2, 2),
+        (10, 10),
+        (100, 10),
+        # The default returns all 10 origins per page
+        (None, 10),
+    ],
+)
+def test_lister_max_origins_per_page(
+    swh_scheduler, max_origins_per_page, expected_origins_per_page
+):
+    extra_kwargs = {}
+    if max_origins_per_page is not None:
+        extra_kwargs["max_origins_per_page"] = max_origins_per_page
+
+    lister = ListerWithALotOfPagesWithALotOfOrigins(
+        scheduler=swh_scheduler,
+        url="https://example.org",
+        instance="example.org",
+        **extra_kwargs,
+    )
+
+    run_result = lister.run()
+
+    assert run_result.pages == 10
+    assert run_result.origins == 10 * expected_origins_per_page