diff --git a/swh/scheduler/backend.py b/swh/scheduler/backend.py
--- a/swh/scheduler/backend.py
+++ b/swh/scheduler/backend.py
@@ -164,6 +164,26 @@
 
         return [Lister(**ret) for ret in cur.fetchall()]
 
+    @db_transaction()
+    def get_listers_by_id(
+        self, lister_ids: List[str], db=None, cur=None
+    ) -> List[Lister]:
+        """Retrieve listers in batch, using their UUID"""
+        if not lister_ids:
+            # SQL rejects an empty ``IN ()`` list, so short-circuit here
+            return []
+
+        select_cols = ", ".join(Lister.select_columns())
+
+        query = f"""
+            select {select_cols} from listers
+            where id in %s
+        """
+
+        cur.execute(query, (tuple(lister_ids),))
+
+        return [Lister(**row) for row in cur]
+
     @db_transaction()
     def get_lister(
         self, name: str, instance_name: Optional[str] = None, db=None, cur=None
diff --git a/swh/scheduler/celery_backend/recurrent_visits.py b/swh/scheduler/celery_backend/recurrent_visits.py
--- a/swh/scheduler/celery_backend/recurrent_visits.py
+++ b/swh/scheduler/celery_backend/recurrent_visits.py
@@ -24,7 +24,7 @@
 from kombu.utils.uuid import uuid
 
 from swh.scheduler.celery_backend.config import get_available_slots
-from swh.scheduler.utils import create_origin_task_dict
+from swh.scheduler.utils import create_origin_task_dicts
 
 if TYPE_CHECKING:
     from ..interface import SchedulerInterface
@@ -233,8 +233,7 @@
     # scheduling policies have different resource usage patterns
     random.shuffle(origins)
 
-    for origin in origins:
-        task_dict = create_origin_task_dict(origin)
+    for task_dict in create_origin_task_dicts(origins, scheduler):
         app.send_task(
             queue_name,
             task_id=uuid(),
diff --git a/swh/scheduler/cli/origin.py b/swh/scheduler/cli/origin.py
--- a/swh/scheduler/cli/origin.py
+++ b/swh/scheduler/cli/origin.py
@@ -10,7 +10,7 @@
 import click
 
 from . import cli
-from ..utils import create_origin_task_dict
+from ..utils import create_origin_task_dicts
 
 if TYPE_CHECKING:
     from uuid import UUID
@@ -130,12 +130,12 @@
     created = scheduler.create_tasks(
         [
             {
-                **create_origin_task_dict(origin),
+                **task_dict,
                 "policy": "oneshot",
                 "next_run": utcnow(),
                 "retries_left": 1,
             }
-            for origin in origins
+            for task_dict in create_origin_task_dicts(origins, scheduler)
         ]
     )
 
@@ -211,8 +211,7 @@
     )
 
     click.echo(f"{len(origins)} visits to send to celery")
-    for origin in origins:
-        task_dict = create_origin_task_dict(origin)
+    for task_dict in create_origin_task_dicts(origins, scheduler):
         app.send_task(
             task_name,
             task_id=uuid(),
diff --git a/swh/scheduler/interface.py b/swh/scheduler/interface.py
--- a/swh/scheduler/interface.py
+++ b/swh/scheduler/interface.py
@@ -333,6 +333,11 @@
         """Retrieve information about all listers from the database."""
         ...
 
+    @remote_api_endpoint("listers/get_by_id")
+    def get_listers_by_id(self, lister_ids: List[str]) -> List[Lister]:
+        """Retrieve listers in batch, using their UUID"""
+        ...
+
     @remote_api_endpoint("lister/get")
     def get_lister(
         self, name: str, instance_name: Optional[str] = None
diff --git a/swh/scheduler/simulator/__init__.py b/swh/scheduler/simulator/__init__.py
--- a/swh/scheduler/simulator/__init__.py
+++ b/swh/scheduler/simulator/__init__.py
@@ -17,7 +17,7 @@
 from simpy import Event
 
 from swh.scheduler.interface import SchedulerInterface
-from swh.scheduler.utils import create_origin_task_dict
+from swh.scheduler.utils import create_origin_task_dicts
 
 from . import origin_scheduler, task_scheduler
 from .common import Environment, Queue, SimulationReport, Task
@@ -122,12 +122,14 @@
     scheduler.create_tasks(
         [
             {
-                **create_origin_task_dict(origin),
+                **task_dict,
                 "policy": "recurring",
                 "next_run": origin.last_update,
                 "interval": timedelta(days=64),
             }
-            for origin in origins
+            for (origin, task_dict) in zip(
+                origins, create_origin_task_dicts(origins, scheduler)
+            )
         ]
     )
 
diff --git a/swh/scheduler/tests/test_api_client.py b/swh/scheduler/tests/test_api_client.py
--- a/swh/scheduler/tests/test_api_client.py
+++ b/swh/scheduler/tests/test_api_client.py
@@ -46,6 +46,7 @@
             "lister/get_or_create",
             "lister/update",
             "listers/get",
+            "listers/get_by_id",
             "origins/get",
             "origins/grab_next",
             "origins/record",
diff --git a/swh/scheduler/tests/test_scheduler.py b/swh/scheduler/tests/test_scheduler.py
--- a/swh/scheduler/tests/test_scheduler.py
+++ b/swh/scheduler/tests/test_scheduler.py
@@ -671,6 +671,30 @@
 
         assert swh_scheduler.get_listers() == db_listers
 
+    def test_get_listers_by_id(self, swh_scheduler):
+        assert swh_scheduler.get_listers_by_id([]) == []
+        assert swh_scheduler.get_listers_by_id([str(uuid.uuid4())]) == []
+
+        db_listers = []
+        for lister_args in LISTERS:
+            db_listers.append(swh_scheduler.get_or_create_lister(**lister_args))
+
+        id0 = db_listers[0].id
+        id1 = db_listers[1].id
+
+        assert swh_scheduler.get_listers_by_id([id0]) == [db_listers[0]]
+        assert swh_scheduler.get_listers_by_id([id1]) == [db_listers[1]]
+
+        # No ordering is guaranteed on the results: compare them sorted by id
+        listers = swh_scheduler.get_listers_by_id([id0, id1])
+        assert sorted(listers, key=lambda lister: str(lister.id)) == sorted(
+            [db_listers[0], db_listers[1]], key=lambda lister: str(lister.id)
+        )
+
+        assert swh_scheduler.get_listers_by_id([id0, str(uuid.uuid4())]) == [
+            db_listers[0]
+        ]
+
     def test_update_lister(self, swh_scheduler, stored_lister):
         lister = attr.evolve(stored_lister, current_state={"updated": "now"})
 
diff --git a/swh/scheduler/tests/test_utils.py b/swh/scheduler/tests/test_utils.py
--- a/swh/scheduler/tests/test_utils.py
+++ b/swh/scheduler/tests/test_utils.py
@@ -9,6 +9,8 @@
 
 from swh.scheduler import model, utils
 
+from .common import LISTERS
+
 
 @patch("swh.scheduler.utils.datetime")
 def test_create_oneshot_task_dict_simple(mock_datetime):
@@ -84,32 +86,100 @@
 
 
 def test_create_origin_task_dict():
+    lister = model.Lister(**LISTERS[1], id=uuid.uuid4())
     origin = model.ListedOrigin(
-        lister_id=uuid.uuid4(),
+        lister_id=lister.id,
         url="http://example.com/",
         visit_type="git",
     )
 
-    task = utils.create_origin_task_dict(origin)
+    task = utils.create_origin_task_dict(origin, lister)
     assert task == {
         "type": "load-git",
-        "arguments": {"args": [], "kwargs": {"url": "http://example.com/"}},
+        "arguments": {
+            "args": [],
+            "kwargs": {"url": "http://example.com/", "lister_name": LISTERS[1]["name"]},
+        },
     }
 
     loader_args = {"foo": "bar", "baz": {"foo": "bar"}}
     origin_w_args = model.ListedOrigin(
-        lister_id=uuid.uuid4(),
+        lister_id=lister.id,
         url="http://example.com/svn/",
         visit_type="svn",
         extra_loader_arguments=loader_args,
     )
 
-    task_w_args = utils.create_origin_task_dict(origin_w_args)
+    task_w_args = utils.create_origin_task_dict(origin_w_args, lister)
     assert task_w_args == {
         "type": "load-svn",
         "arguments": {
             "args": [],
-            "kwargs": {"url": "http://example.com/svn/", **loader_args},
+            "kwargs": {
+                "url": "http://example.com/svn/",
+                "lister_name": LISTERS[1]["name"],
+                **loader_args,
+            },
         },
     }
+
+
+def test_create_origin_task_dicts(swh_scheduler):
+    assert utils.create_origin_task_dicts([], swh_scheduler) == []
+
+    listers = []
+    for lister_args in LISTERS:
+        listers.append(swh_scheduler.get_or_create_lister(**lister_args))
+
+    origin1 = model.ListedOrigin(
+        lister_id=listers[0].id,
+        url="http://example.com/1",
+        visit_type="git",
+    )
+    origin2 = model.ListedOrigin(
+        lister_id=listers[0].id,
+        url="http://example.com/2",
+        visit_type="git",
+    )
+    origin3 = model.ListedOrigin(
+        lister_id=listers[1].id,
+        url="http://example.com/3",
+        visit_type="git",
+    )
+
+    origins = [origin1, origin2, origin3]
+
+    tasks = utils.create_origin_task_dicts(origins, swh_scheduler)
+    assert tasks == [
+        {
+            "type": "load-git",
+            "arguments": {
+                "args": [],
+                "kwargs": {
+                    "url": "http://example.com/1",
+                    "lister_name": LISTERS[0]["name"],
+                },
+            },
+        },
+        {
+            "type": "load-git",
+            "arguments": {
+                "args": [],
+                "kwargs": {
+                    "url": "http://example.com/2",
+                    "lister_name": LISTERS[0]["name"],
+                },
+            },
+        },
+        {
+            "type": "load-git",
+            "arguments": {
+                "args": [],
+                "kwargs": {
+                    "url": "http://example.com/3",
+                    "lister_name": LISTERS[1]["name"],
+                },
+            },
+        },
+    ]
 
diff --git a/swh/scheduler/utils.py b/swh/scheduler/utils.py
--- a/swh/scheduler/utils.py
+++ b/swh/scheduler/utils.py
@@ -5,9 +5,10 @@
 
 
 from datetime import datetime, timezone
-from typing import Any, Dict
+from typing import Any, Dict, List
 
-from .model import ListedOrigin
+from .interface import SchedulerInterface
+from .model import ListedOrigin, Lister
 
 
 def utcnow():
@@ -67,16 +68,60 @@
     return task
 
 
-def create_origin_task_dict(origin: ListedOrigin) -> Dict[str, Any]:
+def create_origin_task_dict(origin: ListedOrigin, lister: Lister) -> Dict[str, Any]:
+    """Build the task dict used to schedule a visit of ``origin``.
+
+    The loader kwargs contain the origin URL, the name of ``lister`` and the
+    origin's ``extra_loader_arguments``.
+
+    Raises:
+        ValueError: if ``origin.lister_id`` does not match ``lister.id``
+    """
+    if origin.lister_id != lister.id:
+        raise ValueError(
+            "origin.lister_id and lister.id differ", origin.lister_id, lister.id
+        )
     return {
         "type": f"load-{origin.visit_type}",
         "arguments": {
             "args": [],
-            "kwargs": {"url": origin.url, **origin.extra_loader_arguments},
+            "kwargs": {
+                "url": origin.url,
+                "lister_name": lister.name,
+                **origin.extra_loader_arguments,
+            },
         },
     }
 
 
+def create_origin_task_dicts(
+    origins: List[ListedOrigin], scheduler: SchedulerInterface
+) -> List[Dict[str, Any]]:
+    """Returns a task dict for each origin, in the same order.
+
+    The listers of all ``origins`` are fetched from ``scheduler`` with a
+    single ``get_listers_by_id`` call.
+
+    Raises:
+        ValueError: if the lister of any origin cannot be found
+    """
+    if not origins:
+        return []
+
+    lister_ids = {o.lister_id for o in origins}
+    listers = {
+        lister.id: lister
+        for lister in scheduler.get_listers_by_id(list(map(str, lister_ids)))
+    }
+
+    missing_lister_ids = lister_ids - set(listers)
+    if missing_lister_ids:
+        # Explicit check rather than ``assert``, which is stripped under -O
+        raise ValueError(f"Missing listers: {missing_lister_ids}")
+
+    return [create_origin_task_dict(o, listers[o.lister_id]) for o in origins]
+
+
 def create_oneshot_task_dict(type, *args, **kwargs):
     """Create a oneshot task scheduled for as soon as possible.
 