diff --git a/swh/scheduler/celery_backend/config.py b/swh/scheduler/celery_backend/config.py --- a/swh/scheduler/celery_backend/config.py +++ b/swh/scheduler/celery_backend/config.py @@ -108,7 +108,7 @@ and obj != Task # Don't register the abstract class itself ): class_name = '%s.%s' % (module_name, name) - instance.app.register_task_class(class_name, obj) + register_task_class(instance.app, class_name, obj) for task_name in instance.app.tasks: if task_name.startswith('swh.'): @@ -127,58 +127,59 @@ return {'queue': name} -class CustomCelery(Celery): - def get_queue_stats(self, queue_name): - """Get the statistics regarding a queue on the broker. - - Arguments: - queue_name: name of the queue to check - - Returns a dictionary raw from the RabbitMQ management API; - or `None` if the current configuration does not use RabbitMQ. - - Interesting keys: - - consumers (number of consumers for the queue) - - messages (number of messages in queue) - - messages_unacknowledged (number of messages currently being - processed) - - Documentation: https://www.rabbitmq.com/management.html#http-api - """ - - conn_info = self.connection().info() - if conn_info['transport'] == 'memory': - # We're running in a test environment, without RabbitMQ. - return None - url = 'http://{hostname}:{port}/api/queues/{vhost}/{queue}'.format( - hostname=conn_info['hostname'], - port=conn_info['port'] + 10000, - vhost=urllib.parse.quote(conn_info['virtual_host'], safe=''), - queue=urllib.parse.quote(queue_name, safe=''), - ) - credentials = (conn_info['userid'], conn_info['password']) - r = requests.get(url, auth=credentials) - if r.status_code == 404: - return {} - if r.status_code != 200: - raise ValueError('Got error %s when reading queue stats: %s' % ( - r.status_code, r.json())) - return r.json() - - def get_queue_length(self, queue_name): - """Shortcut to get a queue's length""" - stats = self.get_queue_stats(queue_name) - if stats: - return stats.get('messages') - - def register_task_class(self, name, cls): - """Register a class-based task under the given name""" - if name in self.tasks: - return - - task_instance = cls() - task_instance.name = name - self.register_task(task_instance) +def get_queue_stats(app, queue_name): + """Get the statistics regarding a queue on the broker. + + Arguments: + queue_name: name of the queue to check + + Returns a dictionary raw from the RabbitMQ management API; + or `None` if the current configuration does not use RabbitMQ. + + Interesting keys: + - consumers (number of consumers for the queue) + - messages (number of messages in queue) + - messages_unacknowledged (number of messages currently being + processed) + + Documentation: https://www.rabbitmq.com/management.html#http-api + """ + + conn_info = app.connection().info() + if conn_info['transport'] == 'memory': + # We're running in a test environment, without RabbitMQ. + return None + url = 'http://{hostname}:{port}/api/queues/{vhost}/{queue}'.format( + hostname=conn_info['hostname'], + port=conn_info['port'] + 10000, + vhost=urllib.parse.quote(conn_info['virtual_host'], safe=''), + queue=urllib.parse.quote(queue_name, safe=''), + ) + credentials = (conn_info['userid'], conn_info['password']) + r = requests.get(url, auth=credentials) + if r.status_code == 404: + return {} + if r.status_code != 200: + raise ValueError('Got error %s when reading queue stats: %s' % ( + r.status_code, r.json())) + return r.json() + + +def get_queue_length(app, queue_name): + """Shortcut to get a queue's length""" + stats = get_queue_stats(app, queue_name) + if stats: + return stats.get('messages') + + +def register_task_class(app, name, cls): + """Register a class-based task under the given name""" + if name in app.tasks: + return + + task_instance = cls() + task_instance.name = name + app.register_task(task_instance) INSTANCE_NAME = os.environ.get(CONFIG_NAME_ENVVAR) @@ -196,11 +197,7 @@ for queue in CONFIG['task_queues']: CELERY_QUEUES.append(Queue(queue, Exchange(queue), routing_key=queue)) -# Instantiate the Celery app -app = CustomCelery() -app.conf.update( - # The broker - broker_url=CONFIG['task_broker'], +CELERY_DEFAULT_CONFIG = dict( # Timezone configuration: all in UTC enable_utc=True, timezone='UTC', @@ -254,4 +251,11 @@ worker_send_task_events=True, # Do not send useless task_sent events task_send_sent_event=False, -) + ) + +# Instantiate the Celery app +app = Celery(broker=CONFIG['task_broker']) +app.add_defaults(CELERY_DEFAULT_CONFIG) + +# XXX for BW compat +Celery.get_queue_length = get_queue_length diff --git a/swh/scheduler/tests/scheduler_testing.py b/swh/scheduler/tests/scheduler_testing.py --- a/swh/scheduler/tests/scheduler_testing.py +++ b/swh/scheduler/tests/scheduler_testing.py @@ -4,7 +4,7 @@ from celery.result import AsyncResult from celery.contrib.testing.worker import start_worker -import celery.contrib.testing.tasks # noqa +import celery.contrib.testing.tasks # noqa import pytest from swh.core.tests.db_testing import DbTestFixture, DB_DUMP_TYPES @@ -12,7 +12,7 @@ from swh.scheduler import get_scheduler from swh.scheduler.celery_backend.runner import run_ready_tasks -from swh.scheduler.celery_backend.config import app +from swh.scheduler.celery_backend.config import app, register_task_class from swh.scheduler.tests.celery_testing import CeleryTestFixture from . import SQL_DIR @@ -42,7 +42,7 @@ } self.scheduler.create_task_type(task_type) if task_class: - app.register_task_class(backend_name, task_class) + register_task_class(app, backend_name, task_class) def run_ready_tasks(self): """Runs the scheduler and a Celery worker, then blocks until