
diff --git a/swh/scheduler/celery_backend/config.py b/swh/scheduler/celery_backend/config.py
index 20b21bd..82f701f 100644
--- a/swh/scheduler/celery_backend/config.py
+++ b/swh/scheduler/celery_backend/config.py
@@ -1,342 +1,363 @@
-# Copyright (C) 2015-2019 The Software Heritage developers
+# Copyright (C) 2015-2021 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import functools
import logging
import os
from time import monotonic as _monotonic
import traceback
-from typing import Any, Dict
+from typing import Any, Dict, Optional
import urllib.parse
from celery import Celery
from celery.signals import celeryd_after_setup, setup_logging, worker_init
from celery.utils.log import ColorFormatter
from celery.worker.control import Panel
from kombu import Exchange, Queue
import pkg_resources
import requests
from swh.core.config import load_named_config, merge_configs
from swh.core.sentry import init_sentry
from swh.scheduler import CONFIG as SWH_CONFIG
try:
from swh.core.logger import JournalHandler
except ImportError:
JournalHandler = None # type: ignore
DEFAULT_CONFIG_NAME = "worker"
CONFIG_NAME_ENVVAR = "SWH_WORKER_INSTANCE"
CONFIG_NAME_TEMPLATE = "worker/%s"
DEFAULT_CONFIG = {
"task_broker": ("str", "amqp://guest@localhost//"),
"task_modules": ("list[str]", []),
"task_queues": ("list[str]", []),
"task_soft_time_limit": ("int", 0),
}
logger = logging.getLogger(__name__)
# Celery eats tracebacks in signal callbacks; this decorator catches
# and prints them.
# It also tries to notify Sentry if possible.
def _print_errors(f):
@functools.wraps(f)
def newf(*args, **kwargs):
try:
return f(*args, **kwargs)
except Exception as exc:
traceback.print_exc()
try:
import sentry_sdk
sentry_sdk.capture_exception(exc)
except Exception:
traceback.print_exc()
return newf
@setup_logging.connect
@_print_errors
def setup_log_handler(
loglevel=None,
logfile=None,
format=None,
colorize=None,
log_console=None,
log_journal=None,
**kwargs,
):
"""Setup logging according to Software Heritage preferences.
We use the command-line loglevel for tasks only, as we never
really care about the debug messages from celery.
"""
if loglevel is None:
loglevel = logging.DEBUG
if isinstance(loglevel, str):
loglevel = logging._nameToLevel[loglevel]
formatter = logging.Formatter(format)
root_logger = logging.getLogger("")
root_logger.setLevel(logging.INFO)
log_target = os.environ.get("SWH_LOG_TARGET", "console")
if log_target == "console":
log_console = True
elif log_target == "journal":
log_journal = True
# default to console logging when the requested level is DEBUG or more verbose
if loglevel <= logging.DEBUG and log_console is None:
log_console = True
if log_console:
color_formatter = ColorFormatter(format) if colorize else formatter
console = logging.StreamHandler()
console.setLevel(logging.DEBUG)
console.setFormatter(color_formatter)
root_logger.addHandler(console)
if log_journal:
if not JournalHandler:
root_logger.warning(
"JournalHandler is not available, skipping. "
"Please install swh-core[logging]."
)
else:
systemd_journal = JournalHandler()
systemd_journal.setLevel(logging.DEBUG)
systemd_journal.setFormatter(formatter)
root_logger.addHandler(systemd_journal)
logging.getLogger("celery").setLevel(logging.INFO)
# Silence amqp heartbeat_tick messages
logger = logging.getLogger("amqp")
logger.addFilter(lambda record: not record.msg.startswith("heartbeat_tick"))
logger.setLevel(loglevel)
# Silence useless "Starting new HTTP connection" messages
logging.getLogger("urllib3").setLevel(logging.WARNING)
# Completely disable azure logspam
azure_logger = logging.getLogger("azure.core.pipeline.policies.http_logging_policy")
azure_logger.setLevel(logging.WARNING)
logging.getLogger("swh").setLevel(loglevel)
# get_task_logger makes the swh tasks loggers children of celery.task
logging.getLogger("celery.task").setLevel(loglevel)
return loglevel
@celeryd_after_setup.connect
@_print_errors
def setup_queues_and_tasks(sender, instance, **kwargs):
"""Signal called on worker start.
This automatically registers swh.scheduler.task.Task subclasses as
available celery tasks.
This also subscribes the worker to the "implicit" per-task queues defined
for these task classes.
"""
logger.info("Setup Queues & Tasks for %s", sender)
instance.app.conf["worker_name"] = sender
@worker_init.connect
@_print_errors
def on_worker_init(*args, **kwargs):
try:
from sentry_sdk.integrations.celery import CeleryIntegration
except ImportError:
integrations = []
else:
integrations = [CeleryIntegration()]
sentry_dsn = None # will be set in `init_sentry` function
init_sentry(sentry_dsn, integrations=integrations)
@Panel.register
def monotonic(state):
"""Get the current value for the monotonic clock"""
return {"monotonic": _monotonic()}
def route_for_task(name, args, kwargs, options, task=None, **kw):
"""Route tasks according to the task_queue attribute in the task class"""
if name is not None and name.startswith("swh."):
return {"queue": name}
def get_queue_stats(app, queue_name):
"""Get the statistics regarding a queue on the broker.
Arguments:
app: Celery application
queue_name: name of the queue to check
Returns the raw dictionary from the RabbitMQ management API,
or `None` if the current configuration does not use RabbitMQ.
Interesting keys:
- consumers (number of consumers for the queue)
- messages (number of messages in queue)
- messages_unacknowledged (number of messages currently being
processed)
Documentation: https://www.rabbitmq.com/management.html#http-api
"""
conn_info = app.connection().info()
if conn_info["transport"] == "memory":
# We're running in a test environment, without RabbitMQ.
return None
url = "http://{hostname}:{port}/api/queues/{vhost}/{queue}".format(
hostname=conn_info["hostname"],
port=conn_info["port"] + 10000,
vhost=urllib.parse.quote(conn_info["virtual_host"], safe=""),
queue=urllib.parse.quote(queue_name, safe=""),
)
credentials = (conn_info["userid"], conn_info["password"])
r = requests.get(url, auth=credentials)
if r.status_code == 404:
return {}
if r.status_code != 200:
raise ValueError(
"Got error %s when reading queue stats: %s" % (r.status_code, r.json())
)
return r.json()
def get_queue_length(app, queue_name):
"""Shortcut to get a queue's length"""
stats = get_queue_stats(app, queue_name)
if stats:
return stats.get("messages")
+MAX_NUM_TASKS = 10000
+
+
+def get_available_slots(app, queue_name: str, max_length: Optional[int]):
+ """Get the number of tasks that can be sent to `queue_name`, when
+ the queue is limited to `max_length`."""
+
+ if not max_length:
+ return MAX_NUM_TASKS
+
+ try:
+ queue_length = get_queue_length(app, queue_name)
+ # Clamp the return value to MAX_NUM_TASKS
+ max_val = min(max_length - queue_length, MAX_NUM_TASKS)
+ except (ValueError, TypeError):
+ # Unknown queue length, just schedule all the tasks
+ max_val = MAX_NUM_TASKS
+
+ return max_val
+
+
def register_task_class(app, name, cls):
"""Register a class-based task under the given name"""
if name in app.tasks:
return
task_instance = cls()
task_instance.name = name
app.register_task(task_instance)
INSTANCE_NAME = os.environ.get(CONFIG_NAME_ENVVAR)
CONFIG_NAME = os.environ.get("SWH_CONFIG_FILENAME")
CONFIG = {} # type: Dict[str, Any]
if CONFIG_NAME:
# load the celery config from the main config file given as
# SWH_CONFIG_FILENAME environment variable.
# This is expected to have a [celery] section in which we have the
# celery specific configuration.
SWH_CONFIG.clear()
SWH_CONFIG.update(load_named_config(CONFIG_NAME))
CONFIG = SWH_CONFIG.get("celery", {})
if not CONFIG:
# otherwise, back to compat config loading mechanism
if INSTANCE_NAME:
CONFIG_NAME = CONFIG_NAME_TEMPLATE % INSTANCE_NAME
else:
CONFIG_NAME = DEFAULT_CONFIG_NAME
# Load the Celery config
CONFIG = load_named_config(CONFIG_NAME, DEFAULT_CONFIG)
CONFIG.setdefault("task_modules", [])
# load tasks modules declared as plugin entry points
for entrypoint in pkg_resources.iter_entry_points("swh.workers"):
worker_registrer_fn = entrypoint.load()
# The registry function is expected to return a dict in which the 'task_modules'
# key is a string (or a list of strings) naming the python module(s) in which
# celery tasks are defined.
task_modules = worker_registrer_fn().get("task_modules", [])
CONFIG["task_modules"].extend(task_modules)
# Celery Queues
CELERY_QUEUES = [Queue("celery", Exchange("celery"), routing_key="celery")]
CELERY_DEFAULT_CONFIG = dict(
# Timezone configuration: all in UTC
enable_utc=True,
timezone="UTC",
# Imported modules
imports=CONFIG.get("task_modules", []),
# Time (in seconds, or a timedelta object) after which stored task
# tombstones will be deleted. None means results never expire.
result_expires=None,
# A string identifying the default serialization method to use. Can
# be json (default), pickle, yaml, msgpack, or any custom
# serialization methods that have been registered with
# kombu.serialization.registry
task_serializer="msgpack",
# Result serialization format
result_serializer="msgpack",
# Acknowledge tasks as soon as they're received. We can do this as we have
# external monitoring to decide if we need to retry tasks.
task_acks_late=False,
# A string identifying the default serialization method to use.
# Can be pickle (default), json, yaml, msgpack or any custom serialization
# methods that have been registered with kombu.serialization.registry
accept_content=["msgpack", "json"],
# If True the task will report its status as “started”
# when the task is executed by a worker.
task_track_started=True,
# Default compression used for task messages. Can be gzip, bzip2
# (if available), or any custom compression schemes registered
# in the Kombu compression registry.
# result_compression='bzip2',
# task_compression='bzip2',
# Disable all rate limits, even if tasks have explicit rate limits set.
# (Disabling rate limits altogether is recommended if you don’t have any
# tasks using them.)
worker_disable_rate_limits=True,
# Task routing
task_routes=route_for_task,
# Allow pool restarts from remote
worker_pool_restarts=True,
# Do not prefetch tasks
worker_prefetch_multiplier=1,
# Send events
worker_send_task_events=True,
# Do not send useless task_sent events
task_send_sent_event=False,
)
def build_app(config=None):
config = merge_configs(
{k: v for (k, (_, v)) in DEFAULT_CONFIG.items()}, config or {}
)
config["task_queues"] = CELERY_QUEUES + [
Queue(queue, Exchange(queue), routing_key=queue)
for queue in config.get("task_queues", ())
]
logger.debug("Creating a Celery app with %s", config)
# Instantiate the Celery app
app = Celery(broker=config["task_broker"], task_cls="swh.scheduler.task:SWHTask")
app.add_defaults(CELERY_DEFAULT_CONFIG)
app.add_defaults(config)
return app
app = build_app(CONFIG)
# XXX for BW compat
Celery.get_queue_length = get_queue_length
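
A note on the new helper above: get_available_slots caps the number of tasks the runner may enqueue at MAX_NUM_TASKS, and falls back to that cap whenever the queue length cannot be determined. A minimal sketch of its behaviour, with a made-up queue name and get_queue_length stubbed out so no RabbitMQ is needed (the same approach as the tests further down):

    from unittest.mock import patch

    from swh.scheduler.celery_backend.config import app, get_available_slots

    # No maximum length configured: the full MAX_NUM_TASKS batch is allowed.
    assert get_available_slots(app, "swh.example.queue", None) == 10000

    # 900 of 1000 slots are taken: only 100 tasks may still be scheduled.
    with patch(
        "swh.scheduler.celery_backend.config.get_queue_length", return_value=900
    ):
        assert get_available_slots(app, "swh.example.queue", 1000) == 100
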
diff --git a/swh/scheduler/celery_backend/runner.py b/swh/scheduler/celery_backend/runner.py
index ed088d5..e1fe84d 100644
--- a/swh/scheduler/celery_backend/runner.py
+++ b/swh/scheduler/celery_backend/runner.py
@@ -1,165 +1,154 @@
# Copyright (C) 2015-2021 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import logging
from typing import Dict, List, Tuple
from kombu.utils.uuid import uuid
from swh.core.statsd import statsd
from swh.scheduler import get_scheduler
+from swh.scheduler.celery_backend.config import get_available_slots
from swh.scheduler.interface import SchedulerInterface
from swh.scheduler.utils import utcnow
logger = logging.getLogger(__name__)
# Max batch size for tasks
MAX_NUM_TASKS = 10000
def run_ready_tasks(backend: SchedulerInterface, app) -> List[Dict]:
"""Schedule tasks ready to be scheduled.
This looks up ready tasks per task type and mass-schedules them accordingly
(sends messages to RabbitMQ and marks the corresponding tasks as scheduled in the
scheduler backend).
If tasks (per task type) with priority exist, they will get redirected to dedicated
high priority queue (standard queue name prefixed with `save_code_now:`).
Args:
backend: scheduler backend to interact with (read/update tasks)
app (App): Celery application to send tasks to
Returns:
A list of dictionaries::
{
'task': the scheduler's task id,
'backend_id': Celery's task id,
'scheduled': utcnow()
}
The result can be used to block-wait for the tasks' results::
backend_tasks = run_ready_tasks(self.scheduler, app)
for task in backend_tasks:
AsyncResult(id=task['backend_id']).get()
"""
all_backend_tasks: List[Dict] = []
while True:
task_types = {}
pending_tasks = []
for task_type in backend.get_task_types():
task_type_name = task_type["type"]
task_types[task_type_name] = task_type
max_queue_length = task_type["max_queue_length"]
if max_queue_length is None:
max_queue_length = 0
backend_name = task_type["backend_name"]
- if max_queue_length:
- try:
- queue_length = app.get_queue_length(backend_name)
- except ValueError:
- queue_length = None
-
- if queue_length is None:
- # Running without RabbitMQ (probably a test env).
- num_tasks = MAX_NUM_TASKS
- else:
- num_tasks = min(max_queue_length - queue_length, MAX_NUM_TASKS)
- else:
- num_tasks = MAX_NUM_TASKS
+ num_tasks = get_available_slots(app, backend_name, max_queue_length)
# only pull tasks if the buffer is at least 1/5th empty (= 80%
# full), to help postgresql use properly indexed queries.
if num_tasks > min(MAX_NUM_TASKS, max_queue_length) // 5:
# Only grab num_tasks tasks with no priority
grabbed_tasks = backend.grab_ready_tasks(
task_type_name, num_tasks=num_tasks
)
if grabbed_tasks:
pending_tasks.extend(grabbed_tasks)
logger.info(
"Grabbed %s tasks %s", len(grabbed_tasks), task_type_name
)
statsd.increment(
"swh_scheduler_runner_scheduled_task_total",
len(grabbed_tasks),
tags={"task_type": task_type_name},
)
# grab max_queue_length (or 10) potential tasks with any priority for the
# same type (limit the result to avoid overly long-running queries)
grabbed_priority_tasks = backend.grab_ready_priority_tasks(
task_type_name, num_tasks=max_queue_length or 10
)
if grabbed_priority_tasks:
pending_tasks.extend(grabbed_priority_tasks)
logger.info(
"Grabbed %s tasks %s (priority)",
len(grabbed_priority_tasks),
task_type_name,
)
statsd.increment(
"swh_scheduler_runner_scheduled_task_total",
len(grabbed_priority_tasks),
tags={"task_type": task_type_name},
)
if not pending_tasks:
return all_backend_tasks
backend_tasks = []
celery_tasks: List[Tuple[bool, str, str, List, Dict]] = []
for task in pending_tasks:
args = task["arguments"]["args"]
kwargs = task["arguments"]["kwargs"]
backend_name = task_types[task["type"]]["backend_name"]
backend_id = uuid()
celery_tasks.append(
(
task.get("priority") is not None,
backend_name,
backend_id,
args,
kwargs,
)
)
data = {
"task": task["id"],
"backend_id": backend_id,
"scheduled": utcnow(),
}
backend_tasks.append(data)
logger.debug("Sent %s celery tasks", len(backend_tasks))
backend.mass_schedule_task_runs(backend_tasks)
for with_priority, backend_name, backend_id, args, kwargs in celery_tasks:
kw = dict(task_id=backend_id, args=args, kwargs=kwargs,)
if with_priority:
kw["queue"] = f"save_code_now:{backend_name}"
app.send_task(backend_name, **kw)
all_backend_tasks.extend(backend_tasks)
def main():
from .config import app as main_app
for module in main_app.conf.CELERY_IMPORTS:
__import__(module)
main_backend = get_scheduler("local")
try:
run_ready_tasks(main_backend, main_app)
except Exception:
main_backend.rollback()
raise
if __name__ == "__main__":
main()
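
For reference, the block-wait pattern from the run_ready_tasks docstring can be used end to end roughly as follows; this is only a sketch and assumes the scheduler backend connection is configured the same way main() above expects it:

    from celery.result import AsyncResult

    from swh.scheduler import get_scheduler
    from swh.scheduler.celery_backend.config import app
    from swh.scheduler.celery_backend.runner import run_ready_tasks

    scheduler = get_scheduler("local")  # backend settings assumed to come from config
    backend_tasks = run_ready_tasks(scheduler, app)
    for task in backend_tasks:
        # block until each submitted Celery task has finished
        AsyncResult(id=task["backend_id"]).get()
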
diff --git a/swh/scheduler/tests/test_config.py b/swh/scheduler/tests/test_config.py
index c166f62..f0a705c 100644
--- a/swh/scheduler/tests/test_config.py
+++ b/swh/scheduler/tests/test_config.py
@@ -1,18 +1,55 @@
# Copyright (C) 2021 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import pytest
-from swh.scheduler.celery_backend.config import route_for_task
+from swh.scheduler.celery_backend.config import (
+ MAX_NUM_TASKS,
+ app,
+ get_available_slots,
+ route_for_task,
+)
@pytest.mark.parametrize("name", ["swh.something", "swh.anything"])
def test_route_for_task_routing(name):
assert route_for_task(name, [], {}, {}) == {"queue": name}
@pytest.mark.parametrize("name", [None, "foobar"])
def test_route_for_task_no_routing(name):
assert route_for_task(name, [], {}, {}) is None
+
+
+def test_get_available_slots_no_max_length():
+ actual_num = get_available_slots(app, "anything", None)
+ assert actual_num == MAX_NUM_TASKS
+
+
+def test_get_available_slots_issue_when_reading_queue(mocker):
+ mock = mocker.patch("swh.scheduler.celery_backend.config.get_queue_length")
+ mock.side_effect = ValueError
+
+ actual_num = get_available_slots(app, "anything", max_length=10)
+ assert actual_num == MAX_NUM_TASKS
+ assert mock.called
+
+
+def test_get_available_slots_no_queue_length(mocker):
+ mock = mocker.patch("swh.scheduler.celery_backend.config.get_queue_length")
+ mock.return_value = None
+ actual_num = get_available_slots(app, "anything", max_length=100)
+ assert actual_num == MAX_NUM_TASKS
+ assert mock.called
+
+
+def test_get_available_slots(mocker):
+ mock = mocker.patch("swh.scheduler.celery_backend.config.get_queue_length")
+ max_length = 100
+ queue_length = 90
+ mock.return_value = queue_length
+ actual_num = get_available_slots(app, "anything", max_length)
+ assert actual_num == max_length - queue_length
+ assert mock.called
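
These tests stub out get_queue_length through the mocker fixture provided by pytest-mock, so no RabbitMQ instance is required. Assuming a development checkout with the test dependencies installed, this module alone can be run with, for example:

    pytest -v swh/scheduler/tests/test_config.py
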
