diff --git a/swh/scheduler/celery_backend/config.py b/swh/scheduler/celery_backend/config.py
--- a/swh/scheduler/celery_backend/config.py
+++ b/swh/scheduler/celery_backend/config.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2015-2019 The Software Heritage developers
+# Copyright (C) 2015-2021 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
@@ -8,7 +8,7 @@
 import os
 from time import monotonic as _monotonic
 import traceback
-from typing import Any, Dict
+from typing import Any, Dict, Optional
 import urllib.parse
 
 from celery import Celery
@@ -225,6 +225,41 @@
     return stats.get("messages")
 
 
+MAX_NUM_TASKS = 10000
+
+
+def get_available_slots(app, queue_name: str, max_length: Optional[int]) -> int:
+    """Get the number of tasks that can be sent to `queue_name`, when
+    the queue is limited to `max_length`.
+
+    Args:
+        app: application object, passed through to ``get_queue_length``
+        queue_name: name of the queue whose length is checked
+        max_length: maximum length allowed for the queue; a falsy value
+            (``None`` or ``0``) means the queue is not limited
+
+    Returns:
+        ``MAX_NUM_TASKS`` when the queue is not limited, or when reading its
+        length raises ``ValueError`` or ``TypeError`` (unknown length);
+        otherwise ``max_length`` minus the current queue length, clamped to
+        at most ``MAX_NUM_TASKS``. The result is negative when the queue
+        already holds more than ``max_length`` messages.
+    """
+
+    if not max_length:
+        return MAX_NUM_TASKS
+
+    try:
+        queue_length = get_queue_length(app, queue_name)
+        # Clamp the return value to MAX_NUM_TASKS
+        max_val = min(max_length - queue_length, MAX_NUM_TASKS)
+    except (ValueError, TypeError):
+        # Unknown queue length, just schedule all the tasks
+        max_val = MAX_NUM_TASKS
+
+    return max_val
+
+
 def register_task_class(app, name, cls):
     """Register a class-based task under the given name"""
     if name in app.tasks:
diff --git a/swh/scheduler/celery_backend/runner.py b/swh/scheduler/celery_backend/runner.py
--- a/swh/scheduler/celery_backend/runner.py
+++ b/swh/scheduler/celery_backend/runner.py
@@ -10,6 +10,7 @@
 
 from swh.core.statsd import statsd
 from swh.scheduler import get_scheduler
+from swh.scheduler.celery_backend.config import get_available_slots
 from swh.scheduler.interface import SchedulerInterface
 from swh.scheduler.utils import utcnow
 
@@ -60,19 +61,7 @@
             if max_queue_length is None:
                 max_queue_length = 0
             backend_name = task_type["backend_name"]
-            if max_queue_length:
-                try:
-                    queue_length = app.get_queue_length(backend_name)
-                except ValueError:
-                    queue_length = None
-
-                if queue_length is None:
-                    # Running without RabbitMQ (probably a test env).
-                    num_tasks = MAX_NUM_TASKS
-                else:
-                    num_tasks = min(max_queue_length - queue_length, MAX_NUM_TASKS)
-            else:
-                num_tasks = MAX_NUM_TASKS
+            num_tasks = get_available_slots(app, backend_name, max_queue_length)
             # only pull tasks if the buffer is at least 1/5th empty (= 80%
             # full), to help postgresql use properly indexed queries.
             if num_tasks > min(MAX_NUM_TASKS, max_queue_length) // 5:
diff --git a/swh/scheduler/tests/test_config.py b/swh/scheduler/tests/test_config.py
--- a/swh/scheduler/tests/test_config.py
+++ b/swh/scheduler/tests/test_config.py
@@ -5,7 +5,12 @@
 
 import pytest
 
-from swh.scheduler.celery_backend.config import route_for_task
+from swh.scheduler.celery_backend.config import (
+    MAX_NUM_TASKS,
+    app,
+    get_available_slots,
+    route_for_task,
+)
 
 
 @pytest.mark.parametrize("name", ["swh.something", "swh.anything"])
@@ -16,3 +21,35 @@
 @pytest.mark.parametrize("name", [None, "foobar"])
 def test_route_for_task_no_routing(name):
     assert route_for_task(name, [], {}, {}) is None
+
+
+def test_get_available_slots_no_max_length():
+    actual_num = get_available_slots(app, "anything", None)
+    assert actual_num == MAX_NUM_TASKS
+
+
+def test_get_available_slots_issue_when_reading_queue(mocker):
+    mock = mocker.patch("swh.scheduler.celery_backend.config.get_queue_length")
+    mock.side_effect = ValueError
+
+    actual_num = get_available_slots(app, "anything", max_length=10)
+    assert actual_num == MAX_NUM_TASKS
+    assert mock.called
+
+
+def test_get_available_slots_no_queue_length(mocker):
+    mock = mocker.patch("swh.scheduler.celery_backend.config.get_queue_length")
+    mock.return_value = None
+    actual_num = get_available_slots(app, "anything", max_length=100)
+    assert actual_num == MAX_NUM_TASKS
+    assert mock.called
+
+
+def test_get_available_slots(mocker):
+    mock = mocker.patch("swh.scheduler.celery_backend.config.get_queue_length")
+    max_length = 100
+    queue_length = 90
+    mock.return_value = queue_length
+    actual_num = get_available_slots(app, "anything", max_length)
+    assert actual_num == max_length - queue_length
+    assert mock.called