Page Menu
Home
Software Heritage
Search
Configure Global Search
Log In
Files
F7124835
D888.diff
No One
Temporary
Actions
View File
Edit File
Delete File
View Transforms
Subscribe
Mute Notifications
Award Token
Flag For Later
Size
6 KB
Subscribers
None
D888.diff
View Options
diff --git a/swh/scheduler/celery_backend/config.py b/swh/scheduler/celery_backend/config.py
--- a/swh/scheduler/celery_backend/config.py
+++ b/swh/scheduler/celery_backend/config.py
@@ -108,7 +108,7 @@
and obj != Task # Don't register the abstract class itself
):
class_name = '%s.%s' % (module_name, name)
- instance.app.register_task_class(class_name, obj)
+ register_task_class(instance.app, class_name, obj)
for task_name in instance.app.tasks:
if task_name.startswith('swh.'):
@@ -127,58 +127,59 @@
return {'queue': name}
-class CustomCelery(Celery):
- def get_queue_stats(self, queue_name):
- """Get the statistics regarding a queue on the broker.
-
- Arguments:
- queue_name: name of the queue to check
-
- Returns a dictionary raw from the RabbitMQ management API;
- or `None` if the current configuration does not use RabbitMQ.
-
- Interesting keys:
- - consumers (number of consumers for the queue)
- - messages (number of messages in queue)
- - messages_unacknowledged (number of messages currently being
- processed)
-
- Documentation: https://www.rabbitmq.com/management.html#http-api
- """
-
- conn_info = self.connection().info()
- if conn_info['transport'] == 'memory':
- # We're running in a test environment, without RabbitMQ.
- return None
- url = 'http://{hostname}:{port}/api/queues/{vhost}/{queue}'.format(
- hostname=conn_info['hostname'],
- port=conn_info['port'] + 10000,
- vhost=urllib.parse.quote(conn_info['virtual_host'], safe=''),
- queue=urllib.parse.quote(queue_name, safe=''),
- )
- credentials = (conn_info['userid'], conn_info['password'])
- r = requests.get(url, auth=credentials)
- if r.status_code == 404:
- return {}
- if r.status_code != 200:
- raise ValueError('Got error %s when reading queue stats: %s' % (
- r.status_code, r.json()))
- return r.json()
-
- def get_queue_length(self, queue_name):
- """Shortcut to get a queue's length"""
- stats = self.get_queue_stats(queue_name)
- if stats:
- return stats.get('messages')
-
- def register_task_class(self, name, cls):
- """Register a class-based task under the given name"""
- if name in self.tasks:
- return
-
- task_instance = cls()
- task_instance.name = name
- self.register_task(task_instance)
+def get_queue_stats(app, queue_name):
+ """Get the statistics regarding a queue on the broker.
+
+ Arguments:
+ queue_name: name of the queue to check
+
+ Returns a dictionary raw from the RabbitMQ management API;
+ or `None` if the current configuration does not use RabbitMQ.
+
+ Interesting keys:
+ - consumers (number of consumers for the queue)
+ - messages (number of messages in queue)
+ - messages_unacknowledged (number of messages currently being
+ processed)
+
+ Documentation: https://www.rabbitmq.com/management.html#http-api
+ """
+
+ conn_info = app.connection().info()
+ if conn_info['transport'] == 'memory':
+ # We're running in a test environment, without RabbitMQ.
+ return None
+ url = 'http://{hostname}:{port}/api/queues/{vhost}/{queue}'.format(
+ hostname=conn_info['hostname'],
+ port=conn_info['port'] + 10000,
+ vhost=urllib.parse.quote(conn_info['virtual_host'], safe=''),
+ queue=urllib.parse.quote(queue_name, safe=''),
+ )
+ credentials = (conn_info['userid'], conn_info['password'])
+ r = requests.get(url, auth=credentials)
+ if r.status_code == 404:
+ return {}
+ if r.status_code != 200:
+ raise ValueError('Got error %s when reading queue stats: %s' % (
+ r.status_code, r.json()))
+ return r.json()
+
+
+def get_queue_length(app, queue_name):
+ """Shortcut to get a queue's length"""
+ stats = get_queue_stats(app, queue_name)
+ if stats:
+ return stats.get('messages')
+
+
+def register_task_class(app, name, cls):
+ """Register a class-based task under the given name"""
+ if name in app.tasks:
+ return
+
+ task_instance = cls()
+ task_instance.name = name
+ app.register_task(task_instance)
INSTANCE_NAME = os.environ.get(CONFIG_NAME_ENVVAR)
@@ -196,11 +197,7 @@
for queue in CONFIG['task_queues']:
CELERY_QUEUES.append(Queue(queue, Exchange(queue), routing_key=queue))
-# Instantiate the Celery app
-app = CustomCelery()
-app.conf.update(
- # The broker
- broker_url=CONFIG['task_broker'],
+CELERY_DEFAULT_CONFIG = dict(
# Timezone configuration: all in UTC
enable_utc=True,
timezone='UTC',
@@ -254,4 +251,11 @@
worker_send_task_events=True,
# Do not send useless task_sent events
task_send_sent_event=False,
-)
+ )
+
+# Instantiate the Celery app
+app = Celery(broker=CONFIG['task_broker'])
+app.add_defaults(CELERY_DEFAULT_CONFIG)
+
+# XXX for BW compat
+Celery.get_queue_length = get_queue_length
diff --git a/swh/scheduler/tests/scheduler_testing.py b/swh/scheduler/tests/scheduler_testing.py
--- a/swh/scheduler/tests/scheduler_testing.py
+++ b/swh/scheduler/tests/scheduler_testing.py
@@ -4,7 +4,7 @@
from celery.result import AsyncResult
from celery.contrib.testing.worker import start_worker
-import celery.contrib.testing.tasks # noqa
+import celery.contrib.testing.tasks # noqa
import pytest
from swh.core.tests.db_testing import DbTestFixture, DB_DUMP_TYPES
@@ -12,7 +12,7 @@
from swh.scheduler import get_scheduler
from swh.scheduler.celery_backend.runner import run_ready_tasks
-from swh.scheduler.celery_backend.config import app
+from swh.scheduler.celery_backend.config import app, register_task_class
from swh.scheduler.tests.celery_testing import CeleryTestFixture
from . import SQL_DIR
@@ -42,7 +42,7 @@
}
self.scheduler.create_task_type(task_type)
if task_class:
- app.register_task_class(backend_name, task_class)
+ register_task_class(app, backend_name, task_class)
def run_ready_tasks(self):
"""Runs the scheduler and a Celery worker, then blocks until
File Metadata
Details
Attached
Mime Type
text/plain
Expires
Dec 21 2024, 9:25 PM (11 w, 4 d ago)
Storage Engine
blob
Storage Format
Raw Data
Storage Handle
3222069
Attached To
D888: Kill the CustomCelery class
Event Timeline
Log In to Comment