diff --git a/swh/scheduler/celery_backend/recurrent_visits.py b/swh/scheduler/celery_backend/recurrent_visits.py
--- a/swh/scheduler/celery_backend/recurrent_visits.py
+++ b/swh/scheduler/celery_backend/recurrent_visits.py
@@ -32,30 +32,28 @@
 logger = logging.getLogger(__name__)
 
 
-_VCS_POLICY_WEIGHTS: Dict[str, float] = {
-    "already_visited_order_by_lag": 49,
-    "never_visited_oldest_update_first": 49,
-    "origins_without_last_update": 2,
+DEFAULT_POLICY: Dict[str, Dict[str, Any]] = {
+    "already_visited_order_by_lag": {"weight": 50},
+    "never_visited_oldest_update_first": {"weight": 50},
 }
-
-POLICY_WEIGHTS: Dict[str, Dict[str, float]] = {
-    "default": {
-        "already_visited_order_by_lag": 50,
-        "never_visited_oldest_update_first": 50,
-    },
-    "git": _VCS_POLICY_WEIGHTS,
-    "hg": _VCS_POLICY_WEIGHTS,
-    "svn": _VCS_POLICY_WEIGHTS,
-    "cvs": _VCS_POLICY_WEIGHTS,
-    "bzr": _VCS_POLICY_WEIGHTS,
+DEFAULT_DVCS_POLICY: Dict[str, Dict[str, Any]] = {
+    "already_visited_order_by_lag": {"weight": 49},
+    "never_visited_oldest_update_first": {"weight": 49},
+    "origins_without_last_update": {"weight": 2},
+}
+DEFAULT_GIT_POLICY: Dict[str, Dict[str, Any]] = {
+    "already_visited_order_by_lag": {"weight": 49, "tablesample": 0.1},
+    "never_visited_oldest_update_first": {"weight": 49, "tablesample": 0.1},
+    "origins_without_last_update": {"weight": 2, "tablesample": 0.1},
 }
 
-POLICY_ADDITIONAL_PARAMETERS: Dict[str, Dict[str, Any]] = {
-    "git": {
-        "already_visited_order_by_lag": {"tablesample": 0.1},
-        "never_visited_oldest_update_first": {"tablesample": 0.1},
-        "origins_without_last_update": {"tablesample": 0.1},
-    }
+DEFAULT_POLICY_CONFIG: Dict[str, Dict[str, Dict[str, Any]]] = {
+    "default": DEFAULT_POLICY,
+    "hg": DEFAULT_DVCS_POLICY,
+    "svn": DEFAULT_DVCS_POLICY,
+    "cvs": DEFAULT_DVCS_POLICY,
+    "bzr": DEFAULT_DVCS_POLICY,
+    "git": DEFAULT_GIT_POLICY,
 }
 
 """Scheduling policies to use to retrieve visits for the given visit types, with their
@@ -82,7 +80,7 @@
 
 
 def grab_next_visits_policy_weights(
-    scheduler: SchedulerInterface, visit_type: str, num_visits: int
+    scheduler: SchedulerInterface, visit_type: str, num_visits: int, policy_cfg: Dict
 ) -> List[ListedOrigin]:
     """Get the next ``num_visits`` for the given ``visit_type`` using the
     corresponding set of scheduling policies.
@@ -97,7 +95,8 @@
     Returns:
         at most ``num_visits`` :py:class:`~swh.scheduler.model.ListedOrigin` objects
     """
-    policy_weights = POLICY_WEIGHTS.get(visit_type, POLICY_WEIGHTS["default"])
+    # policy_cfg: {policy name: {"weight": w, **extra grab_next_visits kwargs}}
+    policy_weights = {name: cfg["weight"] for name, cfg in policy_cfg.items()}
     total_weight = sum(policy_weights.values())
 
     if not total_weight:
@@ -107,6 +106,10 @@
         policy: weight / total_weight for policy, weight in policy_weights.items()
     }
 
+    policy_extra_params = {
+        name: {k: v for k, v in cfg.items() if k != "weight"}
+        for name, cfg in policy_cfg.items()
+    }
     fetched_origins: Dict[str, List[ListedOrigin]] = {}
 
     for policy, ratio in policy_ratio.items():
@@ -115,7 +118,7 @@
             visit_type,
             num_tasks_to_send,
             policy=policy,
-            **POLICY_ADDITIONAL_PARAMETERS.get(visit_type, {}).get(policy, {}),
+            **policy_extra_params[policy],
         )
 
     all_origins: List[ListedOrigin] = list(
@@ -157,6 +160,7 @@
     app,
     visit_type: str,
     task_type: Dict,
+    policy_config: Dict,
 ) -> float:
     """Schedule the next batch of visits for the given ``visit_type``.
 
@@ -199,7 +203,9 @@
     if available_slots < min_available_slots:
         return current_iteration_start + QUEUE_FULL_BACKOFF
 
-    origins = grab_next_visits_policy_weights(scheduler, visit_type, available_slots)
+    origins = grab_next_visits_policy_weights(
+        scheduler, visit_type, available_slots, policy_config
+    )
 
     if not origins:
         logger.debug("No origins to visit for type %s", visit_type)
@@ -248,10 +254,24 @@
         app = build_app(config.get("celery"))
         scheduler = get_scheduler(**config["scheduler"])
         task_type = scheduler.get_task_type(f"load-{visit_type}")
-
         if task_type is None:
             raise ValueError(f"Unknown task type: load-{visit_type}")
 
+        # User-supplied per-visit-type policies override the built-in defaults.
+        policy_cfg = {
+            **DEFAULT_POLICY_CONFIG,
+            **(config.get("scheduling_policy") or {}),
+        }
+        # Every policy of every visit type must define a weight. Validate with an
+        # explicit raise: ``assert`` statements are stripped under ``python -O``.
+        for configured_type, policies in policy_cfg.items():
+            for policy_name, policy_settings in policies.items():
+                if "weight" not in policy_settings:
+                    raise ValueError(
+                        f"Missing 'weight' for scheduling policy {policy_name!r} "
+                        f"of visit type {configured_type!r}"
+                    )
+
         next_iteration = time.monotonic()
 
         while True:
@@ -270,7 +290,11 @@
                     logger.warn("Received unexpected message %s in command queue", msg)
 
             next_iteration = send_visits_for_visit_type(
-                scheduler, app, visit_type, task_type
+                scheduler,
+                app,
+                visit_type,
+                task_type,
+                policy_cfg.get(visit_type, policy_cfg["default"]),
             )
 
         except BaseException as e:
diff --git a/swh/scheduler/tests/test_recurrent_visits.py b/swh/scheduler/tests/test_recurrent_visits.py
--- a/swh/scheduler/tests/test_recurrent_visits.py
+++ b/swh/scheduler/tests/test_recurrent_visits.py
@@ -6,12 +6,12 @@
 from datetime import timedelta
 import logging
 from queue import Queue
-from unittest.mock import MagicMock, patch
+from unittest.mock import MagicMock
 
 import pytest
 
 from swh.scheduler.celery_backend.recurrent_visits import (
-    POLICY_ADDITIONAL_PARAMETERS,
+    DEFAULT_DVCS_POLICY,
     VisitSchedulerThreads,
     grab_next_visits_policy_weights,
     send_visits_for_visit_type,
@@ -120,7 +120,11 @@
     for visit_type in ["test-git", "test-svn"]:
         task_type = f"load-{visit_type}"
         send_visits_for_visit_type(
-            swh_scheduler, mock_celery_app, visit_type, all_task_types[task_type]
+            swh_scheduler,
+            mock_celery_app,
+            visit_type,
+            all_task_types[task_type],
+            DEFAULT_DVCS_POLICY,
         )
 
     assert mock_available_slots.called, "The available slots functions should be called"
@@ -143,27 +147,25 @@
     assert expected_record in set(records)
 
 
-@patch.dict(
-    POLICY_ADDITIONAL_PARAMETERS, {"test-git": POLICY_ADDITIONAL_PARAMETERS["git"]}
-)
 @pytest.mark.parametrize(
-    "visit_type, tablesamples",
-    [("test-hg", {}), ("test-git", POLICY_ADDITIONAL_PARAMETERS["git"])],
+    "visit_type, extras",
+    [("test-hg", {}), ("test-git", {"tablesample": 0.1})],
 )
 def test_recurrent_visit_additional_parameters(
-    swh_scheduler, mocker, visit_type, tablesamples
+    swh_scheduler, mocker, visit_type, extras
 ):
     """Testing additional policy parameters"""
 
     mock_grab_next_visits = mocker.patch.object(swh_scheduler, "grab_next_visits")
     mock_grab_next_visits.return_value = []
 
-    grab_next_visits_policy_weights(swh_scheduler, visit_type, 10)
+    policy_cfg = {name: {**cfg, **extras} for name, cfg in DEFAULT_DVCS_POLICY.items()}
+    grab_next_visits_policy_weights(swh_scheduler, visit_type, 10, policy_cfg)
+    assert mock_grab_next_visits.called
 
     for call in mock_grab_next_visits.call_args_list:
-        assert call[1].get("tablesample") == tablesamples.get(
-            call[1]["policy"], {}
-        ).get("tablesample")
+        assert call[1].get("tablesample") == extras.get("tablesample")
+        assert "weight" not in call[1]
 
 
 @pytest.fixture