Page MenuHomeSoftware Heritage

D888.diff
No OneTemporary

D888.diff

diff --git a/swh/scheduler/celery_backend/config.py b/swh/scheduler/celery_backend/config.py
--- a/swh/scheduler/celery_backend/config.py
+++ b/swh/scheduler/celery_backend/config.py
@@ -108,7 +108,7 @@
and obj != Task # Don't register the abstract class itself
):
class_name = '%s.%s' % (module_name, name)
- instance.app.register_task_class(class_name, obj)
+ register_task_class(instance.app, class_name, obj)
for task_name in instance.app.tasks:
if task_name.startswith('swh.'):
@@ -127,58 +127,59 @@
return {'queue': name}
-class CustomCelery(Celery):
- def get_queue_stats(self, queue_name):
- """Get the statistics regarding a queue on the broker.
-
- Arguments:
- queue_name: name of the queue to check
-
- Returns a dictionary raw from the RabbitMQ management API;
- or `None` if the current configuration does not use RabbitMQ.
-
- Interesting keys:
- - consumers (number of consumers for the queue)
- - messages (number of messages in queue)
- - messages_unacknowledged (number of messages currently being
- processed)
-
- Documentation: https://www.rabbitmq.com/management.html#http-api
- """
-
- conn_info = self.connection().info()
- if conn_info['transport'] == 'memory':
- # We're running in a test environment, without RabbitMQ.
- return None
- url = 'http://{hostname}:{port}/api/queues/{vhost}/{queue}'.format(
- hostname=conn_info['hostname'],
- port=conn_info['port'] + 10000,
- vhost=urllib.parse.quote(conn_info['virtual_host'], safe=''),
- queue=urllib.parse.quote(queue_name, safe=''),
- )
- credentials = (conn_info['userid'], conn_info['password'])
- r = requests.get(url, auth=credentials)
- if r.status_code == 404:
- return {}
- if r.status_code != 200:
- raise ValueError('Got error %s when reading queue stats: %s' % (
- r.status_code, r.json()))
- return r.json()
-
- def get_queue_length(self, queue_name):
- """Shortcut to get a queue's length"""
- stats = self.get_queue_stats(queue_name)
- if stats:
- return stats.get('messages')
-
- def register_task_class(self, name, cls):
- """Register a class-based task under the given name"""
- if name in self.tasks:
- return
-
- task_instance = cls()
- task_instance.name = name
- self.register_task(task_instance)
+def get_queue_stats(app, queue_name):
+ """Get the statistics regarding a queue on the broker.
+
+ Arguments:
+ queue_name: name of the queue to check
+
+ Returns a dictionary raw from the RabbitMQ management API;
+ or `None` if the current configuration does not use RabbitMQ.
+
+ Interesting keys:
+ - consumers (number of consumers for the queue)
+ - messages (number of messages in queue)
+ - messages_unacknowledged (number of messages currently being
+ processed)
+
+ Documentation: https://www.rabbitmq.com/management.html#http-api
+ """
+
+ conn_info = app.connection().info()
+ if conn_info['transport'] == 'memory':
+ # We're running in a test environment, without RabbitMQ.
+ return None
+ url = 'http://{hostname}:{port}/api/queues/{vhost}/{queue}'.format(
+ hostname=conn_info['hostname'],
+ port=conn_info['port'] + 10000,
+ vhost=urllib.parse.quote(conn_info['virtual_host'], safe=''),
+ queue=urllib.parse.quote(queue_name, safe=''),
+ )
+ credentials = (conn_info['userid'], conn_info['password'])
+ r = requests.get(url, auth=credentials)
+ if r.status_code == 404:
+ return {}
+ if r.status_code != 200:
+ raise ValueError('Got error %s when reading queue stats: %s' % (
+ r.status_code, r.json()))
+ return r.json()
+
+
+def get_queue_length(app, queue_name):
+ """Shortcut to get a queue's length"""
+ stats = get_queue_stats(app, queue_name)
+ if stats:
+ return stats.get('messages')
+
+
+def register_task_class(app, name, cls):
+ """Register a class-based task under the given name"""
+ if name in app.tasks:
+ return
+
+ task_instance = cls()
+ task_instance.name = name
+ app.register_task(task_instance)
INSTANCE_NAME = os.environ.get(CONFIG_NAME_ENVVAR)
@@ -196,11 +197,7 @@
for queue in CONFIG['task_queues']:
CELERY_QUEUES.append(Queue(queue, Exchange(queue), routing_key=queue))
-# Instantiate the Celery app
-app = CustomCelery()
-app.conf.update(
- # The broker
- broker_url=CONFIG['task_broker'],
+CELERY_DEFAULT_CONFIG = dict(
# Timezone configuration: all in UTC
enable_utc=True,
timezone='UTC',
@@ -254,4 +251,11 @@
worker_send_task_events=True,
# Do not send useless task_sent events
task_send_sent_event=False,
-)
+ )
+
+# Instantiate the Celery app
+app = Celery(broker=CONFIG['task_broker'])
+app.add_defaults(CELERY_DEFAULT_CONFIG)
+
+# XXX for BW compat
+Celery.get_queue_length = get_queue_length
diff --git a/swh/scheduler/tests/scheduler_testing.py b/swh/scheduler/tests/scheduler_testing.py
--- a/swh/scheduler/tests/scheduler_testing.py
+++ b/swh/scheduler/tests/scheduler_testing.py
@@ -4,7 +4,7 @@
from celery.result import AsyncResult
from celery.contrib.testing.worker import start_worker
-import celery.contrib.testing.tasks # noqa
+import celery.contrib.testing.tasks # noqa
import pytest
from swh.core.tests.db_testing import DbTestFixture, DB_DUMP_TYPES
@@ -12,7 +12,7 @@
from swh.scheduler import get_scheduler
from swh.scheduler.celery_backend.runner import run_ready_tasks
-from swh.scheduler.celery_backend.config import app
+from swh.scheduler.celery_backend.config import app, register_task_class
from swh.scheduler.tests.celery_testing import CeleryTestFixture
from . import SQL_DIR
@@ -42,7 +42,7 @@
}
self.scheduler.create_task_type(task_type)
if task_class:
- app.register_task_class(backend_name, task_class)
+ register_task_class(app, backend_name, task_class)
def run_ready_tasks(self):
"""Runs the scheduler and a Celery worker, then blocks until

File Metadata

Mime Type
text/plain
Expires
Dec 21 2024, 9:25 PM (11 w, 4 d ago)
Storage Engine
blob
Storage Format
Raw Data
Storage Handle
3222069

Event Timeline