# Review notes for this patch. Presumably tools like `git apply` / `patch` skip text before the first `diff --git` header; verify before relying on it.
#
# Summary of the change:
# - conftest.py: the `register_test_tasks` import is dropped. The session-scoped `swh_app` fixture now assigns `celery_session_app` to the module attribute `swh.scheduler.celery_backend.config.app` and yields it.
#   - The previous code imported `app` and then rebound the local name (`app = celery_session_app  # noqa`).
#   - That rebinding did not change `config.app` outside the fixture.
# - tasks.py: the `ping`, `multiping`, `error` (`not_implemented`) and `add` tasks move from inside `register_test_tasks(app)` to module level. They are decorated with `app.task`, where `app` is imported from `swh.scheduler.celery_backend.config` at module import time.
#
# NOTE(review): tasks.py captures `config.app` when it is first imported.
#   - If it is imported before the `swh_app` fixture has run (e.g. a test module importing it at collection time), the tasks would presumably be registered on the original app, not the test app.
#   - Confirm the import order.
# NOTE(review): `swh_app` never restores the original `config.app` after the session. Presumably acceptable for a session-scoped fixture; confirm nothing later expects the original app.
# NOTE(review): this patch appears whitespace-collapsed (newlines turned into spaces), so it cannot be applied as-is. Regenerate it from the repository if it needs to be applied.
diff --git a/swh/scheduler/tests/conftest.py b/swh/scheduler/tests/conftest.py --- a/swh/scheduler/tests/conftest.py +++ b/swh/scheduler/tests/conftest.py @@ -13,7 +13,6 @@ from swh.core.utils import numfile_sortkey as sortkey from swh.scheduler import get_scheduler from swh.scheduler.tests import SQL_DIR -from swh.scheduler.tests.tasks import register_test_tasks # make sure we are not fooled by CELERY_ config environment vars @@ -62,15 +61,14 @@ } -# override the celery_session_app fixture to monkeypatch the 'main' +# use the celery_session_app fixture to monkeypatch the 'main' # swh.scheduler.celery_backend.config.app Celery application -# with the test application (and also register test tasks) +# with the test application @pytest.fixture(scope='session') def swh_app(celery_session_app): - from swh.scheduler.celery_backend.config import app - register_test_tasks(celery_session_app) - app = celery_session_app # noqa - yield app + from swh.scheduler.celery_backend import config + config.app = celery_session_app + yield celery_session_app @pytest.fixture diff --git a/swh/scheduler/tests/tasks.py b/swh/scheduler/tests/tasks.py --- a/swh/scheduler/tests/tasks.py +++ b/swh/scheduler/tests/tasks.py @@ -5,39 +5,34 @@ from celery import group +from swh.scheduler.celery_backend.config import app -def register_test_tasks(app): - """Register test tasks for the specific app passed as parameter. - - In the test context, app is the swh_app and not the runtime one. - - Args: - app: Celery app.
Expects the tests application (swh.scheduler.tests.conftest.swh_app) - - """ - @app.task(name='swh.scheduler.tests.tasks.ping', bind=True) - def ping(self, **kw): - # check this is a SWHTask - assert hasattr(self, 'log') - assert not hasattr(self, 'run_task') - assert 'SWHTask' in [x.__name__ for x in self.__class__.__mro__] - self.log.debug(self.name) - if kw: - return 'OK (kw=%s)' % kw - return 'OK' - - @app.task(name='swh.scheduler.tests.tasks.multiping', bind=True) - def multiping(self, n=10): - promise = group(ping.s(i=i) for i in range(n))() - self.log.debug('%s OK (spawned %s subtasks)' % (self.name, n)) - promise.save() - return promise.id - - @app.task(name='swh.scheduler.tests.tasks.error') - def not_implemented(): - raise NotImplementedError('Nope') - - @app.task(name='swh.scheduler.tests.tasks.add') - def add(x, y): - return x + y + +@app.task(name='swh.scheduler.tests.tasks.ping', bind=True) +def ping(self, **kw): + # check this is a SWHTask + assert hasattr(self, 'log') + assert not hasattr(self, 'run_task') + assert 'SWHTask' in [x.__name__ for x in self.__class__.__mro__] + self.log.debug(self.name) + if kw: + return 'OK (kw=%s)' % kw + return 'OK' + + +@app.task(name='swh.scheduler.tests.tasks.multiping', bind=True) +def multiping(self, n=10): + promise = group(ping.s(i=i) for i in range(n))() + self.log.debug('%s OK (spawned %s subtasks)' % (self.name, n)) + promise.save() + return promise.id + + +@app.task(name='swh.scheduler.tests.tasks.error') +def not_implemented(): + raise NotImplementedError('Nope') + + +@app.task(name='swh.scheduler.tests.tasks.add') +def add(x, y): + return x + y