diff --git a/swh/loader/bzr/tasks.py b/swh/loader/bzr/tasks.py index d192163..964020e 100644 --- a/swh/loader/bzr/tasks.py +++ b/swh/loader/bzr/tasks.py @@ -1,26 +1,26 @@ # Copyright (C) 2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information -from typing import Optional - from celery import shared_task from swh.loader.core.utils import parse_visit_date from .loader import BazaarLoader +def _process_kwargs(kwargs): + if "visit_date" in kwargs: + kwargs["visit_date"] = parse_visit_date(kwargs["visit_date"]) + return kwargs + + @shared_task(name=__name__ + ".LoadBazaar") -def load_bzr( - *, url: str, directory: Optional[str] = None, visit_date: Optional[str] = None -): +def load_bzr(**kwargs): """Bazaar repository loading Args: see :func:`BazaarLoader` constructor. """ - loader = BazaarLoader.from_configfile( - url=url, directory=directory, visit_date=parse_visit_date(visit_date) - ) + loader = BazaarLoader.from_configfile(**_process_kwargs(kwargs)) return loader.load() diff --git a/swh/loader/bzr/tests/test_tasks.py b/swh/loader/bzr/tests/test_tasks.py index 7e7158d..1d6c7f6 100644 --- a/swh/loader/bzr/tests/test_tasks.py +++ b/swh/loader/bzr/tests/test_tasks.py @@ -1,27 +1,76 @@ # Copyright (C) 2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information +import uuid + +import pytest + +from swh.scheduler.model import ListedOrigin, Lister +from swh.scheduler.utils import create_origin_task_dict + + +@pytest.fixture(autouse=True) +def celery_worker_and_swh_config(swh_scheduler_celery_worker, swh_config): + pass + + +@pytest.fixture +def bzr_lister(): + return Lister(name="bzr-lister", instance_name="example", id=uuid.uuid4()) + + +@pytest.fixture +def bzr_listed_origin(bzr_lister): + return ListedOrigin( + lister_id=bzr_lister.id, url="https://bzr.example.org/repo", visit_type="bzr" + ) + def test_loader( - mocker, swh_config, swh_scheduler_celery_app, swh_scheduler_celery_worker + mocker, + swh_scheduler_celery_app, ): mock_loader = mocker.patch("swh.loader.bzr.loader.BazaarLoader.load") mock_loader.return_value = {"status": "eventful"} res = swh_scheduler_celery_app.send_task( "swh.loader.bzr.tasks.LoadBazaar", kwargs={ "url": "origin_url", "directory": "/some/repo", "visit_date": "now", }, ) assert res res.wait() assert res.successful() assert res.result == {"status": "eventful"} mock_loader.assert_called_once_with() + + +def test_loader_for_listed_origin( + mocker, + swh_scheduler_celery_app, + bzr_lister, + bzr_listed_origin, +): + mock_loader = mocker.patch("swh.loader.bzr.loader.BazaarLoader.load") + mock_loader.return_value = {"status": "eventful"} + + task_dict = create_origin_task_dict(bzr_listed_origin, bzr_lister) + + res = swh_scheduler_celery_app.send_task( + "swh.loader.bzr.tasks.LoadBazaar", + kwargs=task_dict["arguments"]["kwargs"], + ) + + assert res + res.wait() + assert res.successful() + + assert res.result == {"status": "eventful"} + mock_loader.assert_called_once_with()