diff --git a/swh/loader/tar/tasks.py b/swh/loader/tar/tasks.py index cba07fb..dc3edce 100644 --- a/swh/loader/tar/tasks.py +++ b/swh/loader/tar/tasks.py @@ -1,26 +1,17 @@ # Copyright (C) 2015-2018 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information -from swh.scheduler.task import Task +from celery import current_app as app from swh.loader.tar.loader import RemoteTarLoader -class LoadTarRepository(Task): +@app.task(name=__name__ + '.LoadTarRepository') +def load_tar(origin, visit_date, last_modified): """Import a remote or local archive to Software Heritage - """ - task_queue = 'swh_loader_tar' - - def run_task(self, *, origin, visit_date, last_modified): - """Import a tarball into swh. - - Args: see :func:`TarLoader.prepare`. - - """ - loader = RemoteTarLoader() - loader.log = self.log - return loader.load( - origin=origin, visit_date=visit_date, last_modified=last_modified) + loader = RemoteTarLoader() + return loader.load( + origin=origin, visit_date=visit_date, last_modified=last_modified) diff --git a/swh/loader/tar/tests/conftest.py b/swh/loader/tar/tests/conftest.py new file mode 100644 index 0000000..972dd2f --- /dev/null +++ b/swh/loader/tar/tests/conftest.py @@ -0,0 +1,10 @@ +import pytest + +from swh.scheduler.tests.conftest import * # noqa + + +@pytest.fixture(scope='session') +def celery_includes(): + return [ + 'swh.loader.tar.tasks', + ] diff --git a/swh/loader/tar/tests/test_tasks.py b/swh/loader/tar/tests/test_tasks.py index 2fa9a30..3e5daac 100644 --- a/swh/loader/tar/tests/test_tasks.py +++ b/swh/loader/tar/tests/test_tasks.py @@ -1,31 +1,27 @@ # Copyright (C) 2015-2018 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information -import unittest from unittest.mock import patch -from swh.loader.tar.tasks import LoadTarRepository +@patch('swh.loader.tar.loader.RemoteTarLoader.load') +def test_tar_loader_task(mock_loader, swh_app, celery_session_worker): + mock_loader.return_value = {'status': 'eventful'} -class TestTasks(unittest.TestCase): - def test_check_task_name(self): - task = LoadTarRepository() - self.assertEqual(task.task_queue, 'swh_loader_tar') + res = swh_app.send_task( + 'swh.loader.tar.tasks.LoadTarRepository', + ('origin', 'visit_date', 'last_modified')) + assert res + res.wait() + assert res.successful() - @patch('swh.loader.tar.loader.RemoteTarLoader.load') - def test_task(self, mock_loader): - mock_loader.return_value = {'status': 'eventful'} - task = LoadTarRepository() + # given + actual_result = res.result - # given - actual_result = task.run_task( - origin='origin', visit_date='visit_date', - last_modified='last_modified') + assert actual_result == {'status': 'eventful'} - self.assertEqual(actual_result, {'status': 'eventful'}) - - mock_loader.assert_called_once_with( - origin='origin', visit_date='visit_date', - last_modified='last_modified') + mock_loader.assert_called_once_with( + origin='origin', visit_date='visit_date', + last_modified='last_modified')