diff --git a/swh/scheduler/celery_backend/config.py b/swh/scheduler/celery_backend/config.py
index aee97ea..6bed5d0 100644
--- a/swh/scheduler/celery_backend/config.py
+++ b/swh/scheduler/celery_backend/config.py
@@ -1,217 +1,257 @@
 # Copyright (C) 2015 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
+import importlib
+import itertools
 import logging
 import os
 import urllib.parse
 
 from celery import Celery
 from celery.signals import setup_logging, celeryd_after_setup
 from celery.utils.log import ColorFormatter
 from celery.worker.control import Panel
 from kombu import Exchange, Queue
 from kombu.five import monotonic as _monotonic
 import requests
 
 from swh.core.config import load_named_config
 from swh.core.logger import JournalHandler
+
+from swh.scheduler.task import Task
 
 DEFAULT_CONFIG_NAME = 'worker'
 CONFIG_NAME_ENVVAR = 'SWH_WORKER_INSTANCE'
 CONFIG_NAME_TEMPLATE = 'worker/%s'
 
 DEFAULT_CONFIG = {
     'task_broker': ('str', 'amqp://guest@localhost//'),
     'task_modules': ('list[str]', []),
     'task_queues': ('list[str]', []),
     'task_soft_time_limit': ('int', 0),
 }
 
 
 @setup_logging.connect
 def setup_log_handler(loglevel=None, logfile=None, format=None,
                       colorize=None, **kwargs):
     """Set up logging according to Software Heritage preferences.
 
     We use the command-line loglevel for tasks only, as we never
     really care about the debug messages from celery.
     """
     if loglevel is None:
         loglevel = logging.DEBUG
 
     formatter = logging.Formatter(format)
 
     root_logger = logging.getLogger('')
     root_logger.setLevel(logging.INFO)
 
     if loglevel == logging.DEBUG:
         color_formatter = ColorFormatter(format) if colorize else formatter
         console = logging.StreamHandler()
         console.setLevel(logging.DEBUG)
         console.setFormatter(color_formatter)
         root_logger.addHandler(console)
 
     systemd_journal = JournalHandler()
     systemd_journal.setLevel(logging.DEBUG)
     systemd_journal.setFormatter(formatter)
     root_logger.addHandler(systemd_journal)
 
     celery_logger = logging.getLogger('celery')
     celery_logger.setLevel(logging.INFO)
 
     # Silence useless "Starting new HTTP connection" messages
     urllib3_logger = logging.getLogger('urllib3')
     urllib3_logger.setLevel(logging.WARNING)
 
     swh_logger = logging.getLogger('swh')
     swh_logger.setLevel(loglevel)
 
     # get_task_logger makes the swh tasks loggers children of celery.task
     celery_task_logger = logging.getLogger('celery.task')
     celery_task_logger.setLevel(loglevel)
 
 
 @celeryd_after_setup.connect
 def setup_queues_and_tasks(sender, instance, **kwargs):
+    """Signal called on worker start.
+
+    This automatically registers swh.scheduler.task.Task subclasses as
+    available celery tasks.
+
+    This also subscribes the worker to the "implicit" per-task queues defined
+    for these task classes.
+ + """ + + for module_name in itertools.chain( + # celery worker -I flag + instance.app.conf['include'], + # set from the celery / swh worker instance configuration file + instance.app.conf['imports'], + ): + module = importlib.import_module(module_name) + for name in dir(module): + obj = getattr(module, name) + if ( + isinstance(obj, type) + and issubclass(obj, Task) + and obj != Task # Don't register the abstract class itself + ): + class_name = '%s.%s' % (module_name, name) + instance.app.register_task_class(class_name, obj) + for task_name in instance.app.tasks: if task_name.startswith('swh.'): instance.app.amqp.queues.select_add(task_name) @Panel.register def monotonic(state): """Get the current value for the monotonic clock""" return {'monotonic': _monotonic()} class TaskRouter: """Route tasks according to the task_queue attribute in the task class""" def route_for_task(self, task, *args, **kwargs): if task.startswith('swh.'): return {'queue': task} class CustomCelery(Celery): def get_queue_stats(self, queue_name): """Get the statistics regarding a queue on the broker. Arguments: queue_name: name of the queue to check Returns a dictionary raw from the RabbitMQ management API; or `None` if the current configuration does not use RabbitMQ. Interesting keys: - consumers (number of consumers for the queue) - messages (number of messages in queue) - messages_unacknowledged (number of messages currently being processed) Documentation: https://www.rabbitmq.com/management.html#http-api """ conn_info = self.connection().info() if conn_info['transport'] == 'memory': # We're running in a test environment, without RabbitMQ. return None url = 'http://{hostname}:{port}/api/queues/{vhost}/{queue}'.format( hostname=conn_info['hostname'], port=conn_info['port'] + 10000, vhost=urllib.parse.quote(conn_info['virtual_host'], safe=''), queue=urllib.parse.quote(queue_name, safe=''), ) credentials = (conn_info['userid'], conn_info['password']) r = requests.get(url, auth=credentials) if r.status_code == 404: return {} if r.status_code != 200: raise ValueError('Got error %s when reading queue stats: %s' % ( r.status_code, r.json())) return r.json() def get_queue_length(self, queue_name): """Shortcut to get a queue's length""" stats = self.get_queue_stats(queue_name) if stats: return stats.get('messages') + def register_task_class(self, name, cls): + """Register a class-based task under the given name""" + if name in self.tasks: + return + + task_instance = cls() + task_instance.name = name + self.register_task(task_instance) + INSTANCE_NAME = os.environ.get(CONFIG_NAME_ENVVAR) if INSTANCE_NAME: CONFIG_NAME = CONFIG_NAME_TEMPLATE % INSTANCE_NAME else: CONFIG_NAME = DEFAULT_CONFIG_NAME # Load the Celery config CONFIG = load_named_config(CONFIG_NAME, DEFAULT_CONFIG) # Celery Queues CELERY_QUEUES = [Queue('celery', Exchange('celery'), routing_key='celery')] for queue in CONFIG['task_queues']: CELERY_QUEUES.append(Queue(queue, Exchange(queue), routing_key=queue)) # Instantiate the Celery app app = CustomCelery() app.conf.update( # The broker broker_url=CONFIG['task_broker'], # Timezone configuration: all in UTC enable_utc=True, timezone='UTC', # Imported modules imports=CONFIG['task_modules'], # Time (in seconds, or a timedelta object) for when after stored task # tombstones will be deleted. None means to never expire results. result_expires=None, # A string identifying the default serialization method to use. 
     # be json (default), pickle, yaml, msgpack, or any custom
     # serialization methods that have been registered with
     # kombu.serialization.registry.
     task_serializer='msgpack',
     # Result serialization format
     result_serializer='msgpack',
     # Late ack means the task messages will be acknowledged after the task
     # has been executed, not just before, which is the default behavior.
     task_acks_late=True,
     # A list of content-types/serializers to accept for task messages.
     # Can be pickle, json, yaml, msgpack, or any custom serialization
     # methods that have been registered with kombu.serialization.registry.
     accept_content=['msgpack', 'json'],
     # If True the task will report its status as “started”
     # when the task is executed by a worker.
     task_track_started=True,
     # Default compression used for task messages. Can be gzip, bzip2
     # (if available), or any custom compression schemes registered
     # in the Kombu compression registry.
     # result_compression='bzip2',
     # task_compression='bzip2',
     # Disable all rate limits, even if tasks have explicit rate limits set.
     # (Disabling rate limits altogether is recommended if you don’t have any
     # tasks using them.)
     worker_disable_rate_limits=True,
     # Task hard time limit in seconds. The worker processing the task will be
     # killed and replaced with a new one when this is exceeded.
     # task_time_limit=3600,
     # Task soft time limit in seconds.
     # The SoftTimeLimitExceeded exception will be raised when this is
     # exceeded. The task can catch this to e.g. clean up before the hard
     # time limit comes.
     task_soft_time_limit=CONFIG['task_soft_time_limit'],
     # Task routing
     task_routes=TaskRouter(),
     # Task queues this worker will consume from
     task_queues=CELERY_QUEUES,
     # Allow pool restarts from remote
     worker_pool_restarts=True,
     # Do not prefetch tasks
     worker_prefetch_multiplier=1,
     # Send events
     worker_send_task_events=True,
     # Do not send useless task_sent events
     task_send_sent_event=False,
 )
diff --git a/swh/scheduler/task.py b/swh/scheduler/task.py
index d26eb56..0bea878 100644
--- a/swh/scheduler/task.py
+++ b/swh/scheduler/task.py
@@ -1,178 +1,50 @@
 # Copyright (C) 2015-2017 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
 import celery.app.task
 from celery.utils.log import get_task_logger
 
-from celery.app.task import TaskType
-
-if TaskType is type:
-    # From Celery 3.1.25, celery/celery/app/task.py
-    # Copyright (c) 2015 Ask Solem & contributors. All rights reserved.
-    # Copyright (c) 2012-2014 GoPivotal, Inc. All rights reserved.
-    # Copyright (c) 2009, 2010, 2011, 2012 Ask Solem, and individual
-    # contributors. All rights reserved.
-    #
-    # Redistribution and use in source and binary forms, with or without
-    # modification, are permitted provided that the following conditions are
-    # met:
-    #     * Redistributions of source code must retain the above copyright
-    #       notice, this list of conditions and the following disclaimer.
-    #     * Redistributions in binary form must reproduce the above copyright
-    #       notice, this list of conditions and the following disclaimer in
-    #       the documentation and/or other materials provided with the
-    #       distribution.
-    #     * Neither the name of Ask Solem, nor the names of its contributors
-    #       may be used to endorse or promote products derived from this
-    #       software without specific prior written permission.
-    #
-    # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
-    # IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
-    # THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
-    # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL Ask Solem OR CONTRIBUTORS BE
-    # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
-    # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
-    # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
-    # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
-    # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
-    # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
-    # THE POSSIBILITY OF SUCH DAMAGE.
-
-    from celery import current_app
-    from celery.local import Proxy
-    from celery.utils import gen_task_name
-
-    class _CompatShared(object):
-
-        def __init__(self, name, cons):
-            self.name = name
-            self.cons = cons
-
-        def __hash__(self):
-            return hash(self.name)
-
-        def __repr__(self):
-            return '<OldTask: %r>' % (self.name, )
-
-        def __call__(self, app):
-            return self.cons(app)
-
-    class TaskType(type):
-        """Meta class for tasks.
-
-        Automatically registers the task in the task registry (except if the
-        :attr:`Task.abstract`` attribute is set).
-
-        If no :attr:`Task.name` attribute is provided, then the name is
-        generated from the module and class name.
-
-        """
-        _creation_count = {}  # used by old non-abstract task classes
-
-        def __new__(cls, name, bases, attrs):
-            new = super(TaskType, cls).__new__
-            task_module = attrs.get('__module__') or '__main__'
-
-            # - Abstract class: abstract attribute should not be inherited.
-            abstract = attrs.pop('abstract', None)
-            if abstract or not attrs.get('autoregister', True):
-                return new(cls, name, bases, attrs)
-
-            # The 'app' attribute is now a property, with the real app located
-            # in the '_app' attribute. Previously this was a regular attribute,
-            # so we should support classes defining it.
-            app = attrs.pop('_app', None) or attrs.pop('app', None)
-
-            # Attempt to inherit app from one the bases
-            if not isinstance(app, Proxy) and app is None:
-                for base in bases:
-                    if getattr(base, '_app', None):
-                        app = base._app
-                        break
-                else:
-                    app = current_app._get_current_object()
-                attrs['_app'] = app
-
-            # - Automatically generate missing/empty name.
-            task_name = attrs.get('name')
-            if not task_name:
-                attrs['name'] = task_name = gen_task_name(app, name,
-                                                          task_module)
-
-            if not attrs.get('_decorated'):
-                # non decorated tasks must also be shared in case
-                # an app is created multiple times due to modules
-                # imported under multiple names.
-                # Hairy stuff, here to be compatible with 2.x.
-                # People should not use non-abstract task classes anymore,
-                # use the task decorator.
-                from celery._state import connect_on_app_finalize
-                unique_name = '.'.join([task_module, name])
-                if unique_name not in cls._creation_count:
-                    # the creation count is used as a safety
-                    # so that the same task is not added recursively
-                    # to the set of constructors.
-                    cls._creation_count[unique_name] = 1
-                    connect_on_app_finalize(_CompatShared(
-                        unique_name,
-                        lambda app: TaskType.__new__(cls, name, bases,
                                                     dict(attrs, _app=app)),
-                    ))
-
-            # - Create and register class.
-            # Because of the way import happens (recursively)
-            # we may or may not be the first time the task tries to register
-            # with the framework. There should only be one class for each task
-            # name, so we always return the registered version.
-            tasks = app._tasks
-            if task_name not in tasks:
-                tasks.register(new(cls, name, bases, attrs))
-            instance = tasks[task_name]
-            instance.bind(app)
-            return instance.__class__
-
-
-class Task(celery.app.task.Task, metaclass=TaskType):
+class Task(celery.app.task.Task):
     """A schedulable task (abstract class).
 
-    Sub-classes must implement the run_task() method. Sub-classes that
-    want their tasks to get routed to a non-default task queue must
-    override the task_queue attribute.
+    Sub-classes must implement the run_task() method.
 
     Current implementation is based on Celery. See
     http://docs.celeryproject.org/en/latest/reference/celery.app.task.html for
     how to use tasks once instantiated
     """
 
     abstract = True
 
     def run(self, *args, **kwargs):
         """This method is called by the celery worker when a task is
         received.
 
         Should not be overridden as we need our special events to be sent for
         the recurrent scheduler. Override run_task instead."""
         try:
             result = self.run_task(*args, **kwargs)
         except Exception as e:
             self.send_event('task-result-exception')
             raise e from None
         else:
             self.send_event('task-result', result=result)
             return result
 
     def run_task(self, *args, **kwargs):
         """Perform the task.
 
         Must return a json-serializable value as it is passed back to the
         task scheduler using a celery event.
         """
         raise NotImplementedError('tasks must implement the run_task() method')
 
     @property
     def log(self):
-        if not hasattr(self, '__log'):
-            self.__log = get_task_logger('%s.%s' % (
-                __name__, self.__class__.__name__))
-        return self.__log
+        if not hasattr(self, '_log'):
+            self._log = get_task_logger('%s.%s' % (
+                __name__, self.__class__.__name__))
+        return self._log
diff --git a/swh/scheduler/tests/scheduler_testing.py b/swh/scheduler/tests/scheduler_testing.py
index ef4e441..1fad1f6 100644
--- a/swh/scheduler/tests/scheduler_testing.py
+++ b/swh/scheduler/tests/scheduler_testing.py
@@ -1,76 +1,80 @@
 import glob
-import pytest
 import os.path
 import datetime
 
 from celery.result import AsyncResult
 from celery.contrib.testing.worker import start_worker
 import celery.contrib.testing.tasks  # noqa
 
+import pytest
+
 from swh.core.tests.db_testing import DbTestFixture, DB_DUMP_TYPES
 from swh.core.utils import numfile_sortkey as sortkey
 
 from swh.scheduler import get_scheduler
 from swh.scheduler.celery_backend.runner import run_ready_tasks
 from swh.scheduler.celery_backend.config import app
 from swh.scheduler.tests.celery_testing import CeleryTestFixture
 
 from . import SQL_DIR
 
 DUMP_FILES = os.path.join(SQL_DIR, '*.sql')
 
 
 @pytest.mark.db
 class SchedulerTestFixture(CeleryTestFixture, DbTestFixture):
     """Base class for test case classes, providing an SWH scheduler as
     the `scheduler` attribute."""
     SCHEDULER_DB_NAME = 'softwareheritage-scheduler-test-fixture'
 
-    def add_scheduler_task_type(self, task_type, backend_name):
+    def add_scheduler_task_type(self, task_type, backend_name,
+                                task_class=None):
         task_type = {
             'type': task_type,
             'description': 'Update a git repository',
             'backend_name': backend_name,
             'default_interval': datetime.timedelta(days=64),
             'min_interval': datetime.timedelta(hours=12),
             'max_interval': datetime.timedelta(days=64),
             'backoff_factor': 2,
             'max_queue_length': None,
             'num_retries': 7,
             'retry_delay': datetime.timedelta(hours=2),
         }
         self.scheduler.create_task_type(task_type)
+        if task_class:
+            app.register_task_class(backend_name, task_class)
 
     def run_ready_tasks(self):
         """Runs the scheduler and a Celery worker, then blocks until
         all tasks are completed."""
         # Make sure the worker is listening to all task-specific queues
         for task in self.scheduler.get_task_types():
             app.amqp.queues.select_add(task['backend_name'])
         with start_worker(app):
             backend_tasks = run_ready_tasks(self.scheduler, app)
             for task in backend_tasks:
+                # Make sure the task completed
                 AsyncResult(id=task['backend_id']).get()
 
     @classmethod
     def setUpClass(cls):
         all_dump_files = sorted(glob.glob(DUMP_FILES), key=sortkey)
         all_dump_files = [(x, DB_DUMP_TYPES[os.path.splitext(x)[1]])
                           for x in all_dump_files]
 
         cls.add_db(name=cls.SCHEDULER_DB_NAME, dumps=all_dump_files)
         super().setUpClass()
 
     def setUp(self):
         super().setUp()
         self.scheduler_config = {
             'scheduling_db': 'dbname=' + self.SCHEDULER_DB_NAME}
         self.scheduler = get_scheduler('local', self.scheduler_config)
 
     def tearDown(self):
         self.scheduler.close_connection()
         super().tearDown()
diff --git a/swh/scheduler/tests/test_fixtures.py b/swh/scheduler/tests/test_fixtures.py
index 9e9146e..e69ac9d 100644
--- a/swh/scheduler/tests/test_fixtures.py
+++ b/swh/scheduler/tests/test_fixtures.py
@@ -1,38 +1,39 @@
 # Copyright (C) 2018 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
 import unittest
 
 from swh.scheduler.tests.scheduler_testing import SchedulerTestFixture
 from swh.scheduler.task import Task
 from swh.scheduler.utils import create_task_dict
 
 task_has_run = False
 
 
 class SomeTestTask(Task):
     def run(self, *, foo):
         global task_has_run
         assert foo == 'bar'
         task_has_run = True
 
 
 class FixtureTest(SchedulerTestFixture, unittest.TestCase):
     def setUp(self):
         super().setUp()
         self.add_scheduler_task_type(
             'some_test_task_type',
-            'swh.scheduler.tests.test_fixtures.SomeTestTask')
+            'swh.scheduler.tests.test_fixtures.SomeTestTask',
+            SomeTestTask,
+        )
 
     def test_task_run(self):
         self.scheduler.create_tasks([create_task_dict(
             'some_test_task_type', 'oneshot',
             foo='bar',
         )])
-        self.assertEqual(task_has_run, False)
         self.run_ready_tasks()
         self.assertEqual(task_has_run, True)
diff --git a/swh/scheduler/tests/test_task.py b/swh/scheduler/tests/test_task.py
index 7e2130e..9abe842 100644
--- a/swh/scheduler/tests/test_task.py
+++ b/swh/scheduler/tests/test_task.py
@@ -1,28 +1,38 @@
 # Copyright (C) 2015 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
 import unittest
 
+from celery import current_app as app
+
 from swh.scheduler import task
 
 from .celery_testing import CeleryTestFixture
 
 
 class Task(CeleryTestFixture, unittest.TestCase):
 
     def test_not_implemented_task(self):
         class NotImplementedTask(task.Task):
+            name = 'NotImplementedTask'
+
             pass
 
+        app.register_task(NotImplementedTask())
+
         with self.assertRaises(NotImplementedError):
             NotImplementedTask().run()
 
     def test_add_task(self):
         class AddTask(task.Task):
+            name = 'AddTask'
+
             def run_task(self, x, y):
                 return x + y
 
+        app.register_task(AddTask())
+
         r = AddTask().apply([2, 3])
         self.assertTrue(r.successful())
         self.assertEqual(r.result, 5)
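
Note for reviewers (not part of the patch): under the new registration scheme in config.py above, a worker no longer relies on the TaskType metaclass; setup_queues_and_tasks() scans the modules listed in the worker's task_modules configuration (which feeds app.conf['imports']) for Task subclasses. A minimal sketch of such a module; the module path swh/loader/git/tasks.py and the class name UpdateGitRepository are illustrative, not part of this diff:

    from swh.scheduler.task import Task


    class UpdateGitRepository(Task):
        """Hypothetical loader task; module path and class name are
        illustrative only."""

        def run_task(self, *, url):
            # run_task() must return a json-serializable value: Task.run()
            # forwards it to the scheduler in a 'task-result' celery event.
            return {'status': 'eventful', 'url': url}

At worker startup, setup_queues_and_tasks() imports the module, registers the class under the name 'swh.loader.git.tasks.UpdateGitRepository' via register_task_class(), and, because the name starts with 'swh.', subscribes the worker to a queue of the same name.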
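Likewise, a sketch of how TaskRouter and CustomCelery.get_queue_length() fit together at runtime. It assumes a running RabbitMQ broker with the management plugin reachable on the AMQP port + 10000 (5672 -> 15672 with default settings), and reuses the hypothetical task name from the previous sketch:

    from swh.scheduler.celery_backend.config import app

    task_name = 'swh.loader.git.tasks.UpdateGitRepository'  # hypothetical

    # TaskRouter sends any swh.* task to the queue named after the task
    assert app.conf.task_routes.route_for_task(task_name) == \
        {'queue': task_name}

    # get_queue_length() polls the RabbitMQ management API; it returns None
    # on the in-memory test transport, and the 'messages' count otherwise.
    backlog = app.get_queue_length(task_name)
    if backlog is not None and backlog > 1000:
        print('queue %s is backed up: %d messages' % (task_name, backlog))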
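One remark on the log property in swh/scheduler/task.py: the pre-existing hasattr(self, '__log') check could never succeed, because the assignment self.__log is name-mangled to _Task__log while the string literal passed to hasattr is not, so the logger was rebuilt on every access. The hunk above renames the attribute to _log to restore the caching. A stdlib-only repro of the original behavior:

    class Demo:
        @property
        def log(self):
            # self.__log is stored as _Demo__log, but the string '__log'
            # is not mangled, so this check is always False.
            if not hasattr(self, '__log'):
                self.__log = object()  # stands in for get_task_logger(...)
            return self.__log


    d = Demo()
    assert d.log is not d.log  # a fresh "logger" is built on every access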