diff --git a/requirements-test.txt b/requirements-test.txt
--- a/requirements-test.txt
+++ b/requirements-test.txt
@@ -1,3 +1,4 @@
 hypothesis
 pytest
+pytest-postgresql
 celery >= 4
diff --git a/swh/scheduler/tests/celery_testing.py b/swh/scheduler/tests/celery_testing.py
deleted file mode 100644
--- a/swh/scheduler/tests/celery_testing.py
+++ /dev/null
@@ -1,18 +0,0 @@
-import os
-
-
-def setup_celery():
-    os.environ.setdefault('CELERY_BROKER_URL', 'memory://')
-    os.environ.setdefault('CELERY_RESULT_BACKEND', 'cache+memory://')
-
-
-class CeleryTestFixture:
-    """Mix this in a test subject class to setup Celery config for testing
-    purpose.
-
-    Can be overriden by CELERY_BROKER_URL and CELERY_RESULT_BACKEND env vars.
-    """
-
-    def setUp(self):
-        setup_celery()
-        super().setUp()
diff --git a/swh/scheduler/tests/conftest.py b/swh/scheduler/tests/conftest.py
--- a/swh/scheduler/tests/conftest.py
+++ b/swh/scheduler/tests/conftest.py
@@ -1,4 +1,17 @@
+import os
 import pytest
+import glob
+from datetime import timedelta
+
+from swh.core.utils import numfile_sortkey as sortkey
+from swh.scheduler import get_scheduler
+from swh.scheduler.tests import SQL_DIR
+
+DUMP_FILES = os.path.join(SQL_DIR, '*.sql')
+
+# celery tasks for testing purposes; tasks themselves should be
+# in swh/scheduler/tests/tasks.py
+TASK_NAMES = ['ping', 'multiping', 'add', 'error']
 
 
 @pytest.fixture(scope='session')
@@ -16,7 +29,16 @@
 @pytest.fixture(scope='session')
 def celery_parameters():
     return {
-      'task_cls': 'swh.scheduler.task:SWHTask',
+        'task_cls': 'swh.scheduler.task:SWHTask',
+    }
+
+
+@pytest.fixture(scope='session')
+def celery_config():
+    return {
+        'accept_content': ['application/x-msgpack', 'application/json'],
+        'task_serializer': 'msgpack',
+        'result_serializer': 'msgpack',
     }
 
 
@@ -28,3 +50,44 @@
     import swh.scheduler.celery_backend.config
     swh.scheduler.celery_backend.config.app = celery_session_app
     yield celery_session_app
+
+
+@pytest.fixture
+def swh_scheduler(request, postgresql_proc, postgresql):
+    """Yield a 'local' scheduler using the pytest-postgresql test database.
+
+    Every ``*.sql`` file of ``SQL_DIR`` (ordered by ``numfile_sortkey``)
+    is executed against the test database, then one ``swh-test-<name>``
+    task type is registered for each entry of ``TASK_NAMES``.
+    """
+    scheduler_config = {
+        'scheduling_db': 'postgresql://{user}@{host}:{port}/{dbname}'.format(
+            host=postgresql_proc.host,
+            port=postgresql_proc.port,
+            user='postgres',
+            dbname='tests')
+    }
+
+    all_dump_files = sorted(glob.glob(DUMP_FILES), key=sortkey)
+
+    with postgresql.cursor() as cursor:
+        for fname in all_dump_files:
+            with open(fname) as fobj:
+                cursor.execute(fobj.read())
+    postgresql.commit()
+
+    scheduler = get_scheduler('local', scheduler_config)
+    for taskname in TASK_NAMES:
+        scheduler.create_task_type({
+            'type': 'swh-test-{}'.format(taskname),
+            'description': 'The {} testing task'.format(taskname),
+            'backend_name': 'swh.scheduler.tests.tasks.{}'.format(taskname),
+            'default_interval': timedelta(days=1),
+            'min_interval': timedelta(hours=6),
+            'max_interval': timedelta(days=12),
+        })
+
+    yield scheduler
+    # Teardown: release the scheduler's own DB connection; pytest runs this
+    # before tearing down the `postgresql` fixture this one depends on.
+    scheduler.close_connection()
diff --git a/swh/scheduler/tests/scheduler_testing.py b/swh/scheduler/tests/scheduler_testing.py
deleted file mode 100644
--- a/swh/scheduler/tests/scheduler_testing.py
+++ /dev/null
@@ -1,80 +0,0 @@
-import glob
-import os.path
-import datetime
-
-from celery.result import AsyncResult
-from celery.contrib.testing.worker import start_worker
-import celery.contrib.testing.tasks  # noqa
-import pytest
-
-from swh.core.tests.db_testing import DbTestFixture, DB_DUMP_TYPES
-from swh.core.utils import numfile_sortkey as sortkey
-
-from swh.scheduler import get_scheduler
-from swh.scheduler.celery_backend.runner import run_ready_tasks
-from swh.scheduler.celery_backend.config import app, register_task_class
-from swh.scheduler.tests.celery_testing import CeleryTestFixture
-
-from . import SQL_DIR
-
-DUMP_FILES = os.path.join(SQL_DIR, '*.sql')
-
-
-@pytest.mark.db
-class SchedulerTestFixture(CeleryTestFixture, DbTestFixture):
-    """Base class for test case classes, providing an SWH scheduler as
-    the `scheduler` attribute."""
-    SCHEDULER_DB_NAME = 'softwareheritage-scheduler-test-fixture'
-
-    def add_scheduler_task_type(self, task_type, backend_name,
-                                task_class=None):
-        task_type = {
-            'type': task_type,
-            'description': 'Update a git repository',
-            'backend_name': backend_name,
-            'default_interval': datetime.timedelta(days=64),
-            'min_interval': datetime.timedelta(hours=12),
-            'max_interval': datetime.timedelta(days=64),
-            'backoff_factor': 2,
-            'max_queue_length': None,
-            'num_retries': 7,
-            'retry_delay': datetime.timedelta(hours=2),
-        }
-        self.scheduler.create_task_type(task_type)
-        if task_class:
-            register_task_class(app, backend_name, task_class)
-
-    def run_ready_tasks(self):
-        """Runs the scheduler and a Celery worker, then blocks until
-        all tasks are completed."""
-
-        # Make sure the worker is listening to all task-specific queues
-        for task in self.scheduler.get_task_types():
-            app.amqp.queues.select_add(task['backend_name'])
-
-        with start_worker(app):
-            backend_tasks = run_ready_tasks(self.scheduler, app)
-            for task in backend_tasks:
-                # Make sure the task completed
-                AsyncResult(id=task['backend_id']).get()
-
-    @classmethod
-    def setUpClass(cls):
-        all_dump_files = sorted(glob.glob(DUMP_FILES), key=sortkey)
-
-        all_dump_files = [(x, DB_DUMP_TYPES[os.path.splitext(x)[1]])
-                          for x in all_dump_files]
-
-        cls.add_db(name=cls.SCHEDULER_DB_NAME,
-                   dumps=all_dump_files)
-        super().setUpClass()
-
-    def setUp(self):
-        super().setUp()
-        self.scheduler_config = {
-            'scheduling_db': 'dbname=' + self.SCHEDULER_DB_NAME}
-        self.scheduler = get_scheduler('local', self.scheduler_config)
-
-    def tearDown(self):
-        self.scheduler.close_connection()
-        super().tearDown()
diff --git a/swh/scheduler/tests/tasks.py b/swh/scheduler/tests/tasks.py
--- a/swh/scheduler/tests/tasks.py
+++ b/swh/scheduler/tests/tasks.py
@@ -25,3 +25,17 @@
     self.log.debug('%s OK (spawned %s subtasks)' % (self.name, n))
     promise.save()
     return promise.id
+
+
+@app.task(name='swh.scheduler.tests.tasks.error',
+          bind=True)
+def not_implemented(self):
+    self.log.debug(self.name)
+    raise NotImplementedError('Nope')
+
+
+@app.task(name='swh.scheduler.tests.tasks.add',
+          bind=True)
+def add(self, x, y):
+    self.log.debug(self.name)
+    return x + y
diff --git a/swh/scheduler/tests/test_celery_tasks.py b/swh/scheduler/tests/test_celery_tasks.py
--- a/swh/scheduler/tests/test_celery_tasks.py
+++ b/swh/scheduler/tests/test_celery_tasks.py
@@ -1,5 +1,11 @@
 from time import sleep
 from celery.result import GroupResult
+from celery.result import AsyncResult
+
+import pytest
+
+from swh.scheduler.utils import create_task_dict
+from swh.scheduler.celery_backend.runner import run_ready_tasks
 
 
 def test_ping(swh_app, celery_session_worker):
@@ -33,3 +39,52 @@
     results = [x.get() for x in promise.results]
     for i in range(5):
         assert ("OK (kw={'i': %s})" % i) in results
+
+
+def test_scheduler_fixture(swh_app, celery_session_worker, swh_scheduler):
+    "Test that the scheduler fixture works properly"
+    task_type = swh_scheduler.get_task_type('swh-test-ping')
+
+    assert task_type
+    assert task_type['backend_name'] == 'swh.scheduler.tests.tasks.ping'
+
+    swh_scheduler.create_tasks([create_task_dict(
+        'swh-test-ping', 'oneshot')])
+
+    backend_tasks = run_ready_tasks(swh_scheduler, swh_app)
+    assert backend_tasks
+    for task in backend_tasks:
+        # Make sure the task completed
+        AsyncResult(id=task['backend_id']).get()
+
+
+def test_task_return_value(swh_app, celery_session_worker, swh_scheduler):
+    task_type = swh_scheduler.get_task_type('swh-test-add')
+    assert task_type
+    assert task_type['backend_name'] == 'swh.scheduler.tests.tasks.add'
+
+    swh_scheduler.create_tasks([create_task_dict(
+        'swh-test-add', 'oneshot', 12, 30)])
+
+    backend_tasks = run_ready_tasks(swh_scheduler, swh_app)
+    assert len(backend_tasks) == 1
+    task = backend_tasks[0]
+    value = AsyncResult(id=task['backend_id']).get()
+    assert value == 42
+
+
+def test_task_exception(swh_app, celery_session_worker, swh_scheduler):
+    task_type = swh_scheduler.get_task_type('swh-test-error')
+    assert task_type
+    assert task_type['backend_name'] == 'swh.scheduler.tests.tasks.error'
+
+    swh_scheduler.create_tasks([create_task_dict(
+        'swh-test-error', 'oneshot')])
+
+    backend_tasks = run_ready_tasks(swh_scheduler, swh_app)
+    assert len(backend_tasks) == 1
+
+    task = backend_tasks[0]
+    result = AsyncResult(id=task['backend_id'])
+    with pytest.raises(NotImplementedError):
+        result.get()
diff --git a/swh/scheduler/tests/test_fixtures.py b/swh/scheduler/tests/test_fixtures.py
deleted file mode 100644
--- a/swh/scheduler/tests/test_fixtures.py
+++ /dev/null
@@ -1,39 +0,0 @@
-# Copyright (C) 2018 The Software Heritage developers
-# See the AUTHORS file at the top-level directory of this distribution
-# License: GNU General Public License version 3, or any later version
-# See top-level LICENSE file for more information
-
-import unittest
-
-from swh.scheduler.tests.scheduler_testing import SchedulerTestFixture
-from swh.scheduler.task import Task
-from swh.scheduler.utils import create_task_dict
-
-task_has_run = False
-
-
-class SomeTestTask(Task):
-    def run(self, *, foo):
-        global task_has_run
-        assert foo == 'bar'
-        task_has_run = True
-
-
-class FixtureTest(SchedulerTestFixture, unittest.TestCase):
-    def setUp(self):
-        super().setUp()
-        self.add_scheduler_task_type(
-            'some_test_task_type',
-            'swh.scheduler.tests.test_fixtures.SomeTestTask',
-            SomeTestTask,
-        )
-
-    def test_task_run(self):
-        self.scheduler.create_tasks([create_task_dict(
-            'some_test_task_type',
-            'oneshot',
-            foo='bar',
-        )])
-        self.assertEqual(task_has_run, False)
-        self.run_ready_tasks()
-        self.assertEqual(task_has_run, True)
diff --git a/swh/scheduler/tests/test_task.py b/swh/scheduler/tests/test_task.py
deleted file mode 100644
--- a/swh/scheduler/tests/test_task.py
+++ /dev/null
@@ -1,38 +0,0 @@
-# Copyright (C) 2015 The Software Heritage developers
-# See the AUTHORS file at the top-level directory of this distribution
-# License: GNU General Public License version 3, or any later version
-# See top-level LICENSE file for more information
-
-import unittest
-
-from celery import current_app as app
-
-from swh.scheduler import task
-from .celery_testing import CeleryTestFixture
-
-
-class Task(CeleryTestFixture, unittest.TestCase):
-
-    def test_not_implemented_task(self):
-        class NotImplementedTask(task.Task):
-            name = 'NotImplementedTask'
-
-            pass
-
-        app.register_task(NotImplementedTask())
-
-        with self.assertRaises(NotImplementedError):
-            NotImplementedTask().run()
-
-    def test_add_task(self):
-        class AddTask(task.Task):
-            name = 'AddTask'
-
-            def run_task(self, x, y):
-                return x + y
-
-        app.register_task(AddTask())
-
-        r = AddTask().apply([2, 3])
-        self.assertTrue(r.successful())
-        self.assertEqual(r.result, 5)