diff --git a/swh/loader/tar/tasks.py b/swh/loader/tar/tasks.py --- a/swh/loader/tar/tasks.py +++ b/swh/loader/tar/tasks.py @@ -3,24 +3,15 @@ # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information -from swh.scheduler.task import Task +from celery import current_app as app from swh.loader.tar.loader import RemoteTarLoader -class LoadTarRepository(Task): +@app.task(name=__name__ + '.LoadTarRepository') +def load_tar(origin, visit_date, last_modified): """Import a remote or local archive to Software Heritage - """ - task_queue = 'swh_loader_tar' - - def run_task(self, *, origin, visit_date, last_modified): - """Import a tarball into swh. - - Args: see :func:`TarLoader.prepare`. - - """ - loader = RemoteTarLoader() - loader.log = self.log - return loader.load( - origin=origin, visit_date=visit_date, last_modified=last_modified) + loader = RemoteTarLoader() + return loader.load( + origin=origin, visit_date=visit_date, last_modified=last_modified) diff --git a/swh/loader/tar/tests/conftest.py b/swh/loader/tar/tests/conftest.py new file mode 100644 --- /dev/null +++ b/swh/loader/tar/tests/conftest.py @@ -0,0 +1,10 @@ +import pytest + +from swh.scheduler.tests.conftest import * # noqa + + +@pytest.fixture(scope='session') +def celery_includes(): + return [ + 'swh.loader.tar.tasks', + ] diff --git a/swh/loader/tar/tests/test_tasks.py b/swh/loader/tar/tests/test_tasks.py --- a/swh/loader/tar/tests/test_tasks.py +++ b/swh/loader/tar/tests/test_tasks.py @@ -3,29 +3,25 @@ # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information -import unittest from unittest.mock import patch -from swh.loader.tar.tasks import LoadTarRepository +@patch('swh.loader.tar.loader.RemoteTarLoader.load') +def test_tar_loader_task(mock_loader, swh_app, celery_session_worker): + mock_loader.return_value = {'status': 'eventful'} -class TestTasks(unittest.TestCase): - def test_check_task_name(self): - task = LoadTarRepository() - self.assertEqual(task.task_queue, 'swh_loader_tar') + res = swh_app.send_task( + 'swh.loader.tar.tasks.LoadTarRepository', + ('origin', 'visit_date', 'last_modified')) + assert res + res.wait() + assert res.successful() - @patch('swh.loader.tar.loader.RemoteTarLoader.load') - def test_task(self, mock_loader): - mock_loader.return_value = {'status': 'eventful'} - task = LoadTarRepository() + # given + actual_result = res.result - # given - actual_result = task.run_task( - origin='origin', visit_date='visit_date', - last_modified='last_modified') + assert actual_result == {'status': 'eventful'} - self.assertEqual(actual_result, {'status': 'eventful'}) - - mock_loader.assert_called_once_with( - origin='origin', visit_date='visit_date', - last_modified='last_modified') + mock_loader.assert_called_once_with( + origin='origin', visit_date='visit_date', + last_modified='last_modified')