diff --git a/swh/scheduler/backend.py b/swh/scheduler/backend.py
--- a/swh/scheduler/backend.py
+++ b/swh/scheduler/backend.py
@@ -376,6 +376,8 @@
         policy: str,
         enabled: bool = True,
         lister_uuid: Optional[str] = None,
+        lister_name: Optional[str] = None,
+        lister_instance_name: Optional[str] = None,
         timestamp: Optional[datetime.datetime] = None,
         absolute_cooldown: Optional[datetime.timedelta] = datetime.timedelta(hours=12),
         scheduled_cooldown: Optional[datetime.timedelta] = datetime.timedelta(days=7),
@@ -390,6 +392,14 @@
 
         origin_select_cols = ", ".join(ListedOrigin.select_columns())
 
+        # SQL join clauses of the "selected_origins" query, keyed by joined table
+        # so that several filters needing the same table only join it once.
+        joins: Dict[str, str] = {
+            "origin_visit_stats": (
+                "LEFT JOIN origin_visit_stats USING (url, visit_type)"
+            ),
+        }
+
         query_args: List[Any] = []
         where_clauses = []
 
@@ -511,14 +521,30 @@
             where_clauses.append("lister_id = %s")
             query_args.append(lister_uuid)
 
+        if lister_name or lister_instance_name:
+            # The WHERE clauses below reject rows without a matching lister, so a
+            # LEFT JOIN would return the same rows; INNER JOIN states the intent.
+            joins["listers"] = (
+                "INNER JOIN listers ON listed_origins.lister_id = listers.id"
+            )
+
+        if lister_name:
+            where_clauses.append("listers.name = %s")
+            query_args.append(lister_name)
+
+        if lister_instance_name:
+            where_clauses.append("listers.instance_name = %s")
+            query_args.append(lister_instance_name)
+
+        join_clause = "\n".join(joins.values())
+
         # fmt: off
         common_table_expressions.insert(0, ("selected_origins", f"""
             SELECT
               {origin_select_cols}, next_visit_queue_position
             FROM
               {table}
-            LEFT JOIN
-              origin_visit_stats USING (url, visit_type)
+            {join_clause}
             WHERE
               ({") AND (".join(where_clauses)})
             ORDER BY
diff --git a/swh/scheduler/interface.py b/swh/scheduler/interface.py
--- a/swh/scheduler/interface.py
+++ b/swh/scheduler/interface.py
@@ -413,6 +413,8 @@
         policy: str,
         enabled: bool = True,
         lister_uuid: Optional[str] = None,
+        lister_name: Optional[str] = None,
+        lister_instance_name: Optional[str] = None,
         timestamp: Optional[datetime.datetime] = None,
         absolute_cooldown: Optional[datetime.timedelta] = datetime.timedelta(hours=12),
         scheduled_cooldown: Optional[datetime.timedelta] = datetime.timedelta(days=7),
@@ -434,6 +436,9 @@
               default, we want reasonably enabled origins. For some edge case, we
               might want the others.
             lister_uuid: Determine the list of origins listed from the lister with uuid
+            lister_name: only select origins listed by a lister with this name
+            lister_instance_name: only select origins listed by a lister with this
+              instance name (combined with ``lister_name`` when both are given)
             timestamp: the mocked timestamp at which we're recording that the visits are
               being scheduled (defaults to the current time)
             absolute_cooldown: the minimal interval between two visits of the same origin
diff --git a/swh/scheduler/tests/test_scheduler.py b/swh/scheduler/tests/test_scheduler.py
--- a/swh/scheduler/tests/test_scheduler.py
+++ b/swh/scheduler/tests/test_scheduler.py
@@ -827,10 +827,15 @@
         assert ret.next_page_token is None
         assert len(ret.results) == len(listed_origins_with_non_enabled)
 
-    def _grab_next_visits_setup(self, swh_scheduler, listed_origins_by_type):
-        """Basic origins setup for scheduling policy tests"""
+    def _grab_next_visits_setup(self, swh_scheduler, listed_origins_by_type, limit=100):
+        """Basic origins setup for scheduling policy tests
+
+        Records the first ``limit`` origins of the first visit type found in
+        ``listed_origins_by_type`` (all of them when ``limit`` is None) and
+        returns ``(visit_type, recorded_origins)``.
+        """
         visit_type = next(iter(listed_origins_by_type))
-        origins = listed_origins_by_type[visit_type][:100]
+        origins = listed_origins_by_type[visit_type][:limit]
         assert len(origins) > 0
 
         recorded_origins = swh_scheduler.record_listed_origins(origins)
@@ -1303,6 +1308,48 @@
             expected=expected_origins,
         )
 
+    def test_grab_next_visit_for_specific_lister(
+        self, swh_scheduler, listed_origins_by_type, stored_lister
+    ):
+        """Checks grab_next_visits filters on the given lister {name, instance name}"""
+
+        visit_type, origins = self._grab_next_visits_setup(
+            swh_scheduler, listed_origins_by_type, limit=None
+        )
+
+        # Filters matching no lister must select nothing. These run before the
+        # positive query: grab_next_visits records the origins it returns as
+        # being scheduled, which could hide them from any later query.
+        unknown_name = f"{stored_lister.name}-unknown"
+        unknown_instance = f"{stored_lister.instance_name}-unknown"
+        for lister_filter in (
+            {"lister_name": unknown_name},
+            {"lister_instance_name": unknown_instance},
+            {
+                "lister_name": stored_lister.name,
+                "lister_instance_name": unknown_instance,
+            },
+        ):
+            assert not swh_scheduler.grab_next_visits(
+                visit_type=visit_type,
+                count=len(origins),
+                policy="never_visited_oldest_update_first",
+                **lister_filter,
+            )
+
+        ret = swh_scheduler.grab_next_visits(
+            visit_type=visit_type,
+            count=len(origins),
+            policy="never_visited_oldest_update_first",
+            lister_name=stored_lister.name,
+            lister_instance_name=stored_lister.instance_name,
+        )
+
+        assert len(ret) == len(origins)
+        assert {origin.url for origin in ret} == {origin.url for origin in origins}
+        for origin in ret:
+            assert origin.lister_id == stored_lister.id
+
     def _create_task_types(self, scheduler):
         for tt in TASK_TYPES.values():
             scheduler.create_task_type(tt)