diff --git a/swh/scheduler/celery_backend/config.py b/swh/scheduler/celery_backend/config.py --- a/swh/scheduler/celery_backend/config.py +++ b/swh/scheduler/celery_backend/config.py @@ -3,6 +3,8 @@ # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information +import itertools +import importlib import logging import os import urllib.parse @@ -17,6 +19,8 @@ import requests +from swh.scheduler.task import Task + from swh.core.config import load_named_config from swh.core.logger import JournalHandler @@ -78,6 +82,33 @@ @celeryd_after_setup.connect def setup_queues_and_tasks(sender, instance, **kwargs): + """Signal called on worker start. + + This automatically registers swh.scheduler.task.Task subclasses as + available celery tasks. + + This also subscribes the worker to the "implicit" per-task queues defined + for these task classes. + + """ + + for module_name in itertools.chain( + # celery worker -I flag + instance.app.conf['include'], + # set from the celery / swh worker instance configuration file + instance.app.conf['imports'], + ): + module = importlib.import_module(module_name) + for name in dir(module): + obj = getattr(module, name) + if ( + isinstance(obj, type) + and issubclass(obj, Task) + and obj is not Task # Don't register the abstract class itself + ): + class_name = '%s.%s' % (module_name, name) + instance.app.register_task_class(class_name, obj) + for task_name in instance.app.tasks: if task_name.startswith('swh.'): instance.app.amqp.queues.select_add(task_name) @@ -140,6 +171,15 @@ if stats: return stats.get('messages') + def register_task_class(self, name, cls): + """Register a class-based task under the given name""" + if name in self.tasks: + return + + task_instance = cls() + task_instance.name = name + self.register_task(task_instance) + INSTANCE_NAME = os.environ.get(CONFIG_NAME_ENVVAR) if INSTANCE_NAME: diff --git a/swh/scheduler/task.py b/swh/scheduler/task.py --- a/swh/scheduler/task.py 
+++ b/swh/scheduler/task.py @@ -6,139 +6,11 @@ import celery.app.task from celery.utils.log import get_task_logger -from celery.app.task import TaskType -if TaskType is type: - # From Celery 3.1.25, celery/celery/app/task.py - # Copyright (c) 2015 Ask Solem & contributors. All rights reserved. - # Copyright (c) 2012-2014 GoPivotal, Inc. All rights reserved. - # Copyright (c) 2009, 2010, 2011, 2012 Ask Solem, and individual - # contributors. All rights reserved. - # - # Redistribution and use in source and binary forms, with or without - # modification, are permitted provided that the following conditions are - # met: - # * Redistributions of source code must retain the above copyright - # notice, this list of conditions and the following disclaimer. - # * Redistributions in binary form must reproduce the above copyright - # notice, this list of conditions and the following disclaimer in the - # documentation and/or other materials provided with the - # distribution. - # * Neither the name of Ask Solem, nor the names of its contributors - # may be used to endorse or promote products derived from this - # software without specific prior written permission. - # - # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS - # IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, - # THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR - # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL Ask Solem OR CONTRIBUTORS BE - # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS - # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN - # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF - # THE POSSIBILITY OF SUCH DAMAGE. 
- from celery import current_app - from celery.local import Proxy - from celery.utils import gen_task_name - - class _CompatShared(object): - - def __init__(self, name, cons): - self.name = name - self.cons = cons - - def __hash__(self): - return hash(self.name) - - def __repr__(self): - return '' % (self.name, ) - - def __call__(self, app): - return self.cons(app) - - class TaskType(type): - """Meta class for tasks. - - Automatically registers the task in the task registry (except if the - :attr:`Task.abstract`` attribute is set). - - If no :attr:`Task.name` attribute is provided, then the name is - generated from the module and class name. - - """ - _creation_count = {} # used by old non-abstract task classes - - def __new__(cls, name, bases, attrs): - new = super(TaskType, cls).__new__ - task_module = attrs.get('__module__') or '__main__' - - # - Abstract class: abstract attribute should not be inherited. - abstract = attrs.pop('abstract', None) - if abstract or not attrs.get('autoregister', True): - return new(cls, name, bases, attrs) - - # The 'app' attribute is now a property, with the real app located - # in the '_app' attribute. Previously this was a regular attribute, - # so we should support classes defining it. - app = attrs.pop('_app', None) or attrs.pop('app', None) - - # Attempt to inherit app from one the bases - if not isinstance(app, Proxy) and app is None: - for base in bases: - if getattr(base, '_app', None): - app = base._app - break - else: - app = current_app._get_current_object() - attrs['_app'] = app - - # - Automatically generate missing/empty name. - task_name = attrs.get('name') - if not task_name: - attrs['name'] = task_name = gen_task_name(app, name, - task_module) - - if not attrs.get('_decorated'): - # non decorated tasks must also be shared in case - # an app is created multiple times due to modules - # imported under multiple names. - # Hairy stuff, here to be compatible with 2.x. 
- # People should not use non-abstract task classes anymore, - # use the task decorator. - from celery._state import connect_on_app_finalize - unique_name = '.'.join([task_module, name]) - if unique_name not in cls._creation_count: - # the creation count is used as a safety - # so that the same task is not added recursively - # to the set of constructors. - cls._creation_count[unique_name] = 1 - connect_on_app_finalize(_CompatShared( - unique_name, - lambda app: TaskType.__new__(cls, name, bases, - dict(attrs, _app=app)), - )) - - # - Create and register class. - # Because of the way import happens (recursively) - # we may or may not be the first time the task tries to register - # with the framework. There should only be one class for each task - # name, so we always return the registered version. - tasks = app._tasks - if task_name not in tasks: - tasks.register(new(cls, name, bases, attrs)) - instance = tasks[task_name] - instance.bind(app) - return instance.__class__ - - -class Task(celery.app.task.Task, metaclass=TaskType): +class Task(celery.app.task.Task): """a schedulable task (abstract class) - Sub-classes must implement the run_task() method. Sub-classes that - want their tasks to get routed to a non-default task queue must - override the task_queue attribute. + Sub-classes must implement the run_task() method. Current implementation is based on Celery. 
See http://docs.celeryproject.org/en/latest/reference/celery.app.task.html for diff --git a/swh/scheduler/tests/scheduler_testing.py b/swh/scheduler/tests/scheduler_testing.py --- a/swh/scheduler/tests/scheduler_testing.py +++ b/swh/scheduler/tests/scheduler_testing.py @@ -1,11 +1,11 @@ import glob -import pytest import os.path import datetime from celery.result import AsyncResult from celery.contrib.testing.worker import start_worker import celery.contrib.testing.tasks # noqa +import pytest from swh.core.tests.db_testing import DbTestFixture, DB_DUMP_TYPES from swh.core.utils import numfile_sortkey as sortkey @@ -26,7 +26,8 @@ the `scheduler` attribute.""" SCHEDULER_DB_NAME = 'softwareheritage-scheduler-test-fixture' - def add_scheduler_task_type(self, task_type, backend_name): + def add_scheduler_task_type(self, task_type, backend_name, + task_class=None): task_type = { 'type': task_type, 'description': 'Update a git repository', @@ -40,6 +41,8 @@ 'retry_delay': datetime.timedelta(hours=2), } self.scheduler.create_task_type(task_type) + if task_class: + app.register_task_class(backend_name, task_class) def run_ready_tasks(self): """Runs the scheduler and a Celery worker, then blocks until @@ -52,6 +55,7 @@ with start_worker(app): backend_tasks = run_ready_tasks(self.scheduler, app) for task in backend_tasks: + # Make sure the task completed AsyncResult(id=task['backend_id']).get() @classmethod diff --git a/swh/scheduler/tests/test_fixtures.py b/swh/scheduler/tests/test_fixtures.py --- a/swh/scheduler/tests/test_fixtures.py +++ b/swh/scheduler/tests/test_fixtures.py @@ -24,7 +24,9 @@ super().setUp() self.add_scheduler_task_type( 'some_test_task_type', - 'swh.scheduler.tests.test_fixtures.SomeTestTask') + 'swh.scheduler.tests.test_fixtures.SomeTestTask', + SomeTestTask, + ) def test_task_run(self): self.scheduler.create_tasks([create_task_dict( @@ -32,7 +34,6 @@ 'oneshot', foo='bar', )]) - self.assertEqual(task_has_run, False) self.run_ready_tasks() 
self.assertEqual(task_has_run, True) diff --git a/swh/scheduler/tests/test_task.py b/swh/scheduler/tests/test_task.py --- a/swh/scheduler/tests/test_task.py +++ b/swh/scheduler/tests/test_task.py @@ -5,6 +5,8 @@ import unittest +from celery import current_app as app + from swh.scheduler import task from .celery_testing import CeleryTestFixture @@ -13,16 +15,24 @@ def test_not_implemented_task(self): class NotImplementedTask(task.Task): + name = 'NotImplementedTask' + pass + app.register_task(NotImplementedTask()) + with self.assertRaises(NotImplementedError): NotImplementedTask().run() def test_add_task(self): class AddTask(task.Task): + name = 'AddTask' + def run_task(self, x, y): return x + y + app.register_task(AddTask()) + r = AddTask().apply([2, 3]) self.assertTrue(r.successful()) self.assertEqual(r.result, 5)