diff --git a/swh/lister/pattern.py b/swh/lister/pattern.py
--- a/swh/lister/pattern.py
+++ b/swh/lister/pattern.py
@@ -10,6 +10,7 @@
 from typing import Any, Dict, Generic, Iterable, Iterator, List, Optional, Set, TypeVar
 from urllib.parse import urlparse
 
+import attr
 import requests
 from tenacity.before_sleep import before_sleep_log
 
@@ -86,6 +87,7 @@
         expected credentials for the given instance of that lister.
       max_pages: the maximum number of pages listed in a full listing operation
       max_origins_per_page: the maximum number of origins processed per page
+      enable_origins: if False, origins are sent to the scheduler with enabled=False
 
     Generic types:
       - *StateType*: concrete lister type; should usually be a :class:`dataclass` for
@@ -106,6 +108,7 @@
         credentials: CredentialsType = None,
         max_origins_per_page: Optional[int] = None,
         max_pages: Optional[int] = None,
+        enable_origins: bool = True,
         with_github_session: bool = False,
     ):
         if not self.LISTER_NAME:
@@ -146,6 +149,7 @@
         self.recorded_origins: Set[str] = set()
         self.max_pages = max_pages
         self.max_origins_per_page = max_origins_per_page
+        self.enable_origins = enable_origins
 
     @http_retry(before_sleep=before_sleep_log(logger, logging.WARNING))
     def http_request(self, url: str, method="GET", **kwargs) -> requests.Response:
@@ -189,6 +193,11 @@
                             self.max_origins_per_page,
                         )
                         origins = origins[: self.max_origins_per_page]
+                if not self.enable_origins:
+                    logger.info(
+                        "Disabling origins before sending them to the scheduler"
+                    )
+                    origins = [attr.evolve(origin, enabled=False) for origin in origins]
                 sent_origins = self.send_origins(origins)
                 self.recorded_origins.update(sent_origins)
                 full_stats.origins = len(self.recorded_origins)
diff --git a/swh/lister/tests/test_pattern.py b/swh/lister/tests/test_pattern.py
--- a/swh/lister/tests/test_pattern.py
+++ b/swh/lister/tests/test_pattern.py
@@ -282,3 +282,37 @@
 
     assert run_result.pages == 10
     assert run_result.origins == 10 * expected_origins_per_page
+
+
+@pytest.mark.parametrize(
+    "enable_origins,expected",
+    [
+        (True, True),
+        (False, False),
+        # enable_origins not passed to the lister: origins are enabled by default
+        (None, True),
+    ],
+)
+def test_lister_enable_origins(swh_scheduler, enable_origins, expected):
+    extra_kwargs = {}
+    if enable_origins is not None:
+        extra_kwargs["enable_origins"] = enable_origins
+
+    lister = ListerWithALotOfPagesWithALotOfOrigins(
+        scheduler=swh_scheduler,
+        url="https://example.org",
+        instance="example.org",
+        **extra_kwargs,
+    )
+
+    run_result = lister.run()
+    assert run_result.pages == 10
+    assert run_result.origins == 100
+
+    origins = swh_scheduler.get_listed_origins(
+        lister_id=lister.lister_obj.id, enabled=None
+    ).results
+
+    assert origins
+
+    assert all(origin.enabled == expected for origin in origins)