diff --git a/swh/scheduler/celery_backend/config.py b/swh/scheduler/celery_backend/config.py
index 20b21bd..82f701f 100644
--- a/swh/scheduler/celery_backend/config.py
+++ b/swh/scheduler/celery_backend/config.py
@@ -1,342 +1,363 @@
-# Copyright (C) 2015-2019 The Software Heritage developers
+# Copyright (C) 2015-2021 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information

 import functools
 import logging
 import os
 from time import monotonic as _monotonic
 import traceback
-from typing import Any, Dict
+from typing import Any, Dict, Optional
 import urllib.parse

 from celery import Celery
 from celery.signals import celeryd_after_setup, setup_logging, worker_init
 from celery.utils.log import ColorFormatter
 from celery.worker.control import Panel
 from kombu import Exchange, Queue
 import pkg_resources
 import requests

 from swh.core.config import load_named_config, merge_configs
 from swh.core.sentry import init_sentry
 from swh.scheduler import CONFIG as SWH_CONFIG

 try:
     from swh.core.logger import JournalHandler
 except ImportError:
     JournalHandler = None  # type: ignore


 DEFAULT_CONFIG_NAME = "worker"
 CONFIG_NAME_ENVVAR = "SWH_WORKER_INSTANCE"
 CONFIG_NAME_TEMPLATE = "worker/%s"

 DEFAULT_CONFIG = {
     "task_broker": ("str", "amqp://guest@localhost//"),
     "task_modules": ("list[str]", []),
     "task_queues": ("list[str]", []),
     "task_soft_time_limit": ("int", 0),
 }

 logger = logging.getLogger(__name__)


 # Celery eats tracebacks in signal callbacks, this decorator catches
 # and prints them.
 # Also tries to notify Sentry if possible.
 def _print_errors(f):
     @functools.wraps(f)
     def newf(*args, **kwargs):
         try:
             return f(*args, **kwargs)
         except Exception as exc:
             traceback.print_exc()
             try:
                 import sentry_sdk

                 sentry_sdk.capture_exception(exc)
             except Exception:
                 traceback.print_exc()

     return newf


 @setup_logging.connect
 @_print_errors
 def setup_log_handler(
     loglevel=None,
     logfile=None,
     format=None,
     colorize=None,
     log_console=None,
     log_journal=None,
     **kwargs,
 ):
     """Setup logging according to Software Heritage preferences.

     We use the command-line loglevel for tasks only, as we never really care
     about the debug messages from celery.
     """
     if loglevel is None:
         loglevel = logging.DEBUG
     if isinstance(loglevel, str):
         loglevel = logging._nameToLevel[loglevel]

     formatter = logging.Formatter(format)

     root_logger = logging.getLogger("")
     root_logger.setLevel(logging.INFO)

     log_target = os.environ.get("SWH_LOG_TARGET", "console")
     if log_target == "console":
         log_console = True
     elif log_target == "journal":
         log_journal = True

     # this looks for log levels *higher* than DEBUG
     if loglevel <= logging.DEBUG and log_console is None:
         log_console = True

     if log_console:
         color_formatter = ColorFormatter(format) if colorize else formatter
         console = logging.StreamHandler()
         console.setLevel(logging.DEBUG)
         console.setFormatter(color_formatter)
         root_logger.addHandler(console)

     if log_journal:
         if not JournalHandler:
             root_logger.warning(
                 "JournalHandler is not available, skipping. "
                 "Please install swh-core[logging]."
             )
         else:
             systemd_journal = JournalHandler()
             systemd_journal.setLevel(logging.DEBUG)
             systemd_journal.setFormatter(formatter)
             root_logger.addHandler(systemd_journal)

     logging.getLogger("celery").setLevel(logging.INFO)

     # Silence amqp heartbeat_tick messages
     logger = logging.getLogger("amqp")
     logger.addFilter(lambda record: not record.msg.startswith("heartbeat_tick"))
     logger.setLevel(loglevel)

     # Silence useless "Starting new HTTP connection" messages
     logging.getLogger("urllib3").setLevel(logging.WARNING)

     # Completely disable azure logspam
     azure_logger = logging.getLogger("azure.core.pipeline.policies.http_logging_policy")
     azure_logger.setLevel(logging.WARNING)

     logging.getLogger("swh").setLevel(loglevel)

     # get_task_logger makes the swh tasks loggers children of celery.task
     logging.getLogger("celery.task").setLevel(loglevel)

     return loglevel


 @celeryd_after_setup.connect
 @_print_errors
 def setup_queues_and_tasks(sender, instance, **kwargs):
     """Signal called on worker start.

     This automatically registers swh.scheduler.task.Task subclasses as
     available celery tasks.

     This also subscribes the worker to the "implicit" per-task queues defined
     for these task classes.

     """
     logger.info("Setup Queues & Tasks for %s", sender)
     instance.app.conf["worker_name"] = sender


 @worker_init.connect
 @_print_errors
 def on_worker_init(*args, **kwargs):
     try:
         from sentry_sdk.integrations.celery import CeleryIntegration
     except ImportError:
         integrations = []
     else:
         integrations = [CeleryIntegration()]
     sentry_dsn = None  # will be set in `init_sentry` function
     init_sentry(sentry_dsn, integrations=integrations)


 @Panel.register
 def monotonic(state):
     """Get the current value for the monotonic clock"""
     return {"monotonic": _monotonic()}


 def route_for_task(name, args, kwargs, options, task=None, **kw):
     """Route tasks according to the task_queue attribute in the task class"""
     if name is not None and name.startswith("swh."):
         return {"queue": name}


 def get_queue_stats(app, queue_name):
     """Get the statistics regarding a queue on the broker.

     Arguments:
       queue_name: name of the queue to check

     Returns a dictionary raw from the RabbitMQ management API; or `None` if
     the current configuration does not use RabbitMQ.

     Interesting keys:
      - Consumers (number of consumers for the queue)
      - messages (number of messages in queue)
      - messages_unacknowledged (number of messages currently being processed)

     Documentation: https://www.rabbitmq.com/management.html#http-api
     """
     conn_info = app.connection().info()
     if conn_info["transport"] == "memory":
         # We're running in a test environment, without RabbitMQ.
         return None
     url = "http://{hostname}:{port}/api/queues/{vhost}/{queue}".format(
         hostname=conn_info["hostname"],
         port=conn_info["port"] + 10000,
         vhost=urllib.parse.quote(conn_info["virtual_host"], safe=""),
         queue=urllib.parse.quote(queue_name, safe=""),
     )
     credentials = (conn_info["userid"], conn_info["password"])
     r = requests.get(url, auth=credentials)
     if r.status_code == 404:
         return {}
     if r.status_code != 200:
         raise ValueError(
             "Got error %s when reading queue stats: %s" % (r.status_code, r.json())
         )
     return r.json()


 def get_queue_length(app, queue_name):
     """Shortcut to get a queue's length"""
     stats = get_queue_stats(app, queue_name)
     if stats:
         return stats.get("messages")


+MAX_NUM_TASKS = 10000
+
+
+def get_available_slots(app, queue_name: str, max_length: Optional[int]):
+    """Get the number of tasks that can be sent to `queue_name`, when
+    the queue is limited to `max_length`."""
+
+    if not max_length:
+        return MAX_NUM_TASKS
+
+    try:
+        queue_length = get_queue_length(app, queue_name)
+        # Clamp the return value to MAX_NUM_TASKS
+        max_val = min(max_length - queue_length, MAX_NUM_TASKS)
+    except (ValueError, TypeError):
+        # Unknown queue length, just schedule all the tasks
+        max_val = MAX_NUM_TASKS
+
+    return max_val
+
+
 def register_task_class(app, name, cls):
     """Register a class-based task under the given name"""
     if name in app.tasks:
         return

     task_instance = cls()
     task_instance.name = name
     app.register_task(task_instance)


 INSTANCE_NAME = os.environ.get(CONFIG_NAME_ENVVAR)
 CONFIG_NAME = os.environ.get("SWH_CONFIG_FILENAME")
 CONFIG = {}  # type: Dict[str, Any]

 if CONFIG_NAME:
     # load the celery config from the main config file given as
     # SWH_CONFIG_FILENAME environment variable.
     # This is expected to have a [celery] section in which we have the
     # celery specific configuration.
     SWH_CONFIG.clear()
     SWH_CONFIG.update(load_named_config(CONFIG_NAME))
     CONFIG = SWH_CONFIG.get("celery", {})

 if not CONFIG:
     # otherwise, back to compat config loading mechanism
     if INSTANCE_NAME:
         CONFIG_NAME = CONFIG_NAME_TEMPLATE % INSTANCE_NAME
     else:
         CONFIG_NAME = DEFAULT_CONFIG_NAME

     # Load the Celery config
     CONFIG = load_named_config(CONFIG_NAME, DEFAULT_CONFIG)

 CONFIG.setdefault("task_modules", [])
 # load tasks modules declared as plugin entry points
 for entrypoint in pkg_resources.iter_entry_points("swh.workers"):
     worker_registrer_fn = entrypoint.load()
     # The registry function is expected to return a dict which the 'tasks' key
     # is a string (or a list of strings) with the name of the python module in
     # which celery tasks are defined.
     task_modules = worker_registrer_fn().get("task_modules", [])
     CONFIG["task_modules"].extend(task_modules)

 # Celery Queues
 CELERY_QUEUES = [Queue("celery", Exchange("celery"), routing_key="celery")]

 CELERY_DEFAULT_CONFIG = dict(
     # Timezone configuration: all in UTC
     enable_utc=True,
     timezone="UTC",
     # Imported modules
     imports=CONFIG.get("task_modules", []),
     # Time (in seconds, or a timedelta object) for when after stored task
     # tombstones will be deleted. None means to never expire results.
     result_expires=None,
     # A string identifying the default serialization method to use. Can
     # be json (default), pickle, yaml, msgpack, or any custom
     # serialization methods that have been registered with
     task_serializer="msgpack",
     # Result serialization format
     result_serializer="msgpack",
     # Acknowledge tasks as soon as they're received. We can do this as we have
     # external monitoring to decide if we need to retry tasks.
     task_acks_late=False,
     # A string identifying the default serialization method to use.
     # Can be pickle (default), json, yaml, msgpack or any custom serialization
     # methods that have been registered with kombu.serialization.registry
     accept_content=["msgpack", "json"],
     # If True the task will report its status as “started”
     # when the task is executed by a worker.
     task_track_started=True,
     # Default compression used for task messages. Can be gzip, bzip2
     # (if available), or any custom compression schemes registered
     # in the Kombu compression registry.
     # result_compression='bzip2',
     # task_compression='bzip2',
     # Disable all rate limits, even if tasks has explicit rate limits set.
     # (Disabling rate limits altogether is recommended if you don’t have any
     # tasks using them.)
     worker_disable_rate_limits=True,
     # Task routing
     task_routes=route_for_task,
     # Allow pool restarts from remote
     worker_pool_restarts=True,
     # Do not prefetch tasks
     worker_prefetch_multiplier=1,
     # Send events
     worker_send_task_events=True,
     # Do not send useless task_sent events
     task_send_sent_event=False,
 )


 def build_app(config=None):
     config = merge_configs(
         {k: v for (k, (_, v)) in DEFAULT_CONFIG.items()}, config or {}
     )
     config["task_queues"] = CELERY_QUEUES + [
         Queue(queue, Exchange(queue), routing_key=queue)
         for queue in config.get("task_queues", ())
     ]
     logger.debug("Creating a Celery app with %s", config)

     # Instantiate the Celery app
     app = Celery(broker=config["task_broker"], task_cls="swh.scheduler.task:SWHTask")
     app.add_defaults(CELERY_DEFAULT_CONFIG)
     app.add_defaults(config)
     return app


 app = build_app(CONFIG)

 # XXX for BW compat
 Celery.get_queue_length = get_queue_length
diff --git a/swh/scheduler/celery_backend/runner.py b/swh/scheduler/celery_backend/runner.py
index ed088d5..e1fe84d 100644
--- a/swh/scheduler/celery_backend/runner.py
+++ b/swh/scheduler/celery_backend/runner.py
@@ -1,165 +1,154 @@
 # Copyright (C) 2015-2021 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information

 import logging
 from typing import Dict, List, Tuple

 from kombu.utils.uuid import uuid

 from swh.core.statsd import statsd
 from swh.scheduler import get_scheduler
+from swh.scheduler.celery_backend.config import get_available_slots
 from swh.scheduler.interface import SchedulerInterface
 from swh.scheduler.utils import utcnow

 logger = logging.getLogger(__name__)

 # Max batch size for tasks
 MAX_NUM_TASKS = 10000


 def run_ready_tasks(backend: SchedulerInterface, app) -> List[Dict]:
     """Schedule tasks ready to be scheduled.

     This lookups any tasks per task type and mass schedules those accordingly
     (send messages to rabbitmq and mark as scheduled equivalent tasks in the
     scheduler backend).

     If tasks (per task type) with priority exist, they will get redirected to
     dedicated high priority queue (standard queue name prefixed with
     `save_code_now:`).
     Args:
         backend: scheduler backend to interact with (read/update tasks)
         app (App): Celery application to send tasks to

     Returns:
         A list of dictionaries::

             {
               'task': the scheduler's task id,
               'backend_id': Celery's task id,
               'scheduler': utcnow()
             }

         The result can be used to block-wait for the tasks' results::

             backend_tasks = run_ready_tasks(self.scheduler, app)
             for task in backend_tasks:
                 AsyncResult(id=task['backend_id']).get()

     """
     all_backend_tasks: List[Dict] = []
     while True:
         task_types = {}
         pending_tasks = []
         for task_type in backend.get_task_types():
             task_type_name = task_type["type"]
             task_types[task_type_name] = task_type
             max_queue_length = task_type["max_queue_length"]
             if max_queue_length is None:
                 max_queue_length = 0
             backend_name = task_type["backend_name"]
-            if max_queue_length:
-                try:
-                    queue_length = app.get_queue_length(backend_name)
-                except ValueError:
-                    queue_length = None
-
-                if queue_length is None:
-                    # Running without RabbitMQ (probably a test env).
-                    num_tasks = MAX_NUM_TASKS
-                else:
-                    num_tasks = min(max_queue_length - queue_length, MAX_NUM_TASKS)
-            else:
-                num_tasks = MAX_NUM_TASKS
+            num_tasks = get_available_slots(app, backend_name, max_queue_length)
             # only pull tasks if the buffer is at least 1/5th empty (= 80%
             # full), to help postgresql use properly indexed queries.
             if num_tasks > min(MAX_NUM_TASKS, max_queue_length) // 5:
                 # Only grab num_tasks tasks with no priority
                 grabbed_tasks = backend.grab_ready_tasks(
                     task_type_name, num_tasks=num_tasks
                 )
                 if grabbed_tasks:
                     pending_tasks.extend(grabbed_tasks)
                     logger.info(
                         "Grabbed %s tasks %s", len(grabbed_tasks), task_type_name
                     )
                     statsd.increment(
                         "swh_scheduler_runner_scheduled_task_total",
                         len(grabbed_tasks),
                         tags={"task_type": task_type_name},
                     )
             # grab max_queue_length (or 10) potential tasks with any priority for the
             # same type (limit the result to avoid too long running queries)
             grabbed_priority_tasks = backend.grab_ready_priority_tasks(
                 task_type_name, num_tasks=max_queue_length or 10
             )
             if grabbed_priority_tasks:
                 pending_tasks.extend(grabbed_priority_tasks)
                 logger.info(
                     "Grabbed %s tasks %s (priority)",
                     len(grabbed_priority_tasks),
                     task_type_name,
                 )
                 statsd.increment(
                     "swh_scheduler_runner_scheduled_task_total",
                     len(grabbed_priority_tasks),
                     tags={"task_type": task_type_name},
                 )

         if not pending_tasks:
             return all_backend_tasks

         backend_tasks = []
         celery_tasks: List[Tuple[bool, str, str, List, Dict]] = []
         for task in pending_tasks:
             args = task["arguments"]["args"]
             kwargs = task["arguments"]["kwargs"]

             backend_name = task_types[task["type"]]["backend_name"]
             backend_id = uuid()
             celery_tasks.append(
                 (
                     task.get("priority") is not None,
                     backend_name,
                     backend_id,
                     args,
                     kwargs,
                 )
             )
             data = {
                 "task": task["id"],
                 "backend_id": backend_id,
                 "scheduled": utcnow(),
             }

             backend_tasks.append(data)
         logger.debug("Sent %s celery tasks", len(backend_tasks))

         backend.mass_schedule_task_runs(backend_tasks)
         for with_priority, backend_name, backend_id, args, kwargs in celery_tasks:
             kw = dict(task_id=backend_id, args=args, kwargs=kwargs,)
             if with_priority:
                 kw["queue"] = f"save_code_now:{backend_name}"
             app.send_task(backend_name, **kw)

         all_backend_tasks.extend(backend_tasks)


 def main():
     from .config import app as main_app

     for module in main_app.conf.CELERY_IMPORTS:
         __import__(module)

     main_backend = get_scheduler("local")
     try:
         run_ready_tasks(main_backend, main_app)
     except Exception:
         main_backend.rollback()
         raise


 if __name__ == "__main__":
     main()
diff --git a/swh/scheduler/tests/test_config.py b/swh/scheduler/tests/test_config.py
index c166f62..f0a705c 100644
--- a/swh/scheduler/tests/test_config.py
+++ b/swh/scheduler/tests/test_config.py
@@ -1,18 +1,55 @@
 # Copyright (C) 2021 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information

 import pytest

-from swh.scheduler.celery_backend.config import route_for_task
+from swh.scheduler.celery_backend.config import (
+    MAX_NUM_TASKS,
+    app,
+    get_available_slots,
+    route_for_task,
+)


 @pytest.mark.parametrize("name", ["swh.something", "swh.anything"])
 def test_route_for_task_routing(name):
     assert route_for_task(name, [], {}, {}) == {"queue": name}


 @pytest.mark.parametrize("name", [None, "foobar"])
 def test_route_for_task_no_routing(name):
     assert route_for_task(name, [], {}, {}) is None
+
+
+def test_get_available_slots_no_max_length():
+    actual_num = get_available_slots(app, "anything", None)
+    assert actual_num == MAX_NUM_TASKS
+
+
+def test_get_available_slots_issue_when_reading_queue(mocker):
+    mock = mocker.patch("swh.scheduler.celery_backend.config.get_queue_length")
+    mock.side_effect = ValueError
+
+    actual_num = get_available_slots(app, "anything", max_length=10)
+    assert actual_num == MAX_NUM_TASKS
+    assert mock.called
+
+
+def test_get_available_slots_no_queue_length(mocker):
+    mock = mocker.patch("swh.scheduler.celery_backend.config.get_queue_length")
+    mock.return_value = None
+    actual_num = get_available_slots(app, "anything", max_length=100)
+    assert actual_num == MAX_NUM_TASKS
+    assert mock.called
+
+
+def test_get_available_slots(mocker):
+    mock = mocker.patch("swh.scheduler.celery_backend.config.get_queue_length")
+    max_length = 100
+    queue_length = 90
+    mock.return_value = queue_length
+    actual_num = get_available_slots(app, "anything", max_length)
+    assert actual_num == max_length - queue_length
+    assert mock.called
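
A minimal usage sketch, not part of the patch: it only illustrates how the new helper introduced above could be exercised directly against the worker's Celery app. The queue name and per-task-type limit below are made up for illustration; per the patched code, the result is clamped to MAX_NUM_TASKS, and that same ceiling is returned when queue statistics cannot be read (for example with the in-memory transport used in tests).

# Hypothetical example; QUEUE_NAME and MAX_QUEUE_LENGTH are illustrative values only.
from swh.scheduler.celery_backend.config import MAX_NUM_TASKS, app, get_available_slots

QUEUE_NAME = "swh.loader.git.tasks.UpdateGitRepository"  # hypothetical queue name
MAX_QUEUE_LENGTH = 1000  # hypothetical per-task-type limit

# Returns at most MAX_NUM_TASKS (10000); falls back to that ceiling when the
# broker does not expose queue statistics.
slots = get_available_slots(app, QUEUE_NAME, MAX_QUEUE_LENGTH)
print(f"{slots} task(s) can be scheduled on {QUEUE_NAME}")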