diff --git a/sql/updates/20.sql b/sql/updates/20.sql new file mode 100644 --- /dev/null +++ b/sql/updates/20.sql @@ -0,0 +1,6 @@ + +insert into dbversion (version, release, description) + values (20, now(), 'Work In Progress'); + +create index on listed_origins (visit_type, last_scheduled); +drop index listed_origins_last_scheduled_idx; diff --git a/swh/scheduler/backend.py b/swh/scheduler/backend.py --- a/swh/scheduler/backend.py +++ b/swh/scheduler/backend.py @@ -289,10 +289,10 @@ @db_transaction() def grab_next_visits( - self, count: int, policy: str, db=None, cur=None, + self, visit_type: str, count: int, policy: str, db=None, cur=None, ) -> List[ListedOrigin]: - """Get at most the `count` next origins that need to be visited - according to the given scheduling `policy`. + """Get at most the `count` next origins that need to be visited with + the `visit_type` loader according to the given scheduling `policy`. This will mark the origins as "being visited" in the listed_origins table, to avoid scheduling multiple visits to the same origin. 
@@ -304,6 +304,7 @@ with filtered_origins as ( select lister_id, url, visit_type from listed_origins + where visit_type = %s order by last_scheduled nulls first limit %s for update skip locked @@ -313,7 +314,7 @@ where (lister_id, url, visit_type) in (select * from filtered_origins) returning {origin_select_cols} """ - cur.execute(query, (count,)) + cur.execute(query, (visit_type, count)) return [ListedOrigin(**d) for d in cur] else: diff --git a/swh/scheduler/cli/origin.py b/swh/scheduler/cli/origin.py --- a/swh/scheduler/cli/origin.py +++ b/swh/scheduler/cli/origin.py @@ -84,10 +84,19 @@ default=True, help="Print the CSV header?", ) +@click.argument("visit_type", metavar="TYPE", type=str) @click.argument("count", type=int) @click.pass_context -def grab_next(ctx, policy: str, fields: Optional[str], with_header: bool, count: int): - """Grab the next COUNT origins to visit from the listed origins table.""" +def grab_next( + ctx, + policy: str, + fields: Optional[str], + with_header: bool, + visit_type: str, + count: int, +): + """Grab the next COUNT origins to visit using the TYPE loader from the + listed origins table.""" if fields: parsed_fields: Optional[List[str]] = fields.split(",") @@ -96,7 +105,7 @@ scheduler = ctx.obj["scheduler"] - origins = scheduler.grab_next_visits(count, policy=policy) + origins = scheduler.grab_next_visits(visit_type, count, policy=policy) for line in format_origins(origins, fields=parsed_fields, with_header=with_header): click.echo(line) @@ -105,16 +114,18 @@ @click.option( "--policy", "-p", default="oldest_scheduled_first", help="Scheduling policy" ) +@click.argument("visit_type", metavar="TYPE", type=str) @click.argument("count", type=int) @click.pass_context -def schedule_next(ctx, policy: str, count: int): - """Send the next COUNT origin visits to the scheduler as one-shot tasks.""" +def schedule_next(ctx, policy: str, visit_type: str, count: int): + """Send the next COUNT origin visits of the TYPE loader to the scheduler as + one-shot tasks.""" from ..utils import utcnow from .task import
pretty_print_task scheduler = ctx.obj["scheduler"] - origins = scheduler.grab_next_visits(count, policy=policy) + origins = scheduler.grab_next_visits(visit_type, count, policy=policy) created = scheduler.create_tasks( [ diff --git a/swh/scheduler/interface.py b/swh/scheduler/interface.py --- a/swh/scheduler/interface.py +++ b/swh/scheduler/interface.py @@ -311,9 +311,11 @@ ... @remote_api_endpoint("origins/grab_next") - def grab_next_visits(self, count: int, policy: str,) -> List[ListedOrigin]: - """Get at most the `count` next origins that need to be visited - according to the given scheduling `policy`. + def grab_next_visits( + self, visit_type: str, count: int, policy: str + ) -> List[ListedOrigin]: + """Get at most the `count` next origins that need to be visited with + the `visit_type` loader according to the given scheduling `policy`. This will mark the origins as "being visited" in the listed_origins table, to avoid scheduling multiple visits to the same origin. diff --git a/swh/scheduler/sql/30-schema.sql b/swh/scheduler/sql/30-schema.sql --- a/swh/scheduler/sql/30-schema.sql +++ b/swh/scheduler/sql/30-schema.sql @@ -11,7 +11,7 @@ comment on column dbversion.description is 'Version description'; insert into dbversion (version, release, description) - values (19, now(), 'Work In Progress'); + values (20, now(), 'Work In Progress'); create table task_type ( type text primary key, diff --git a/swh/scheduler/sql/60-indexes.sql b/swh/scheduler/sql/60-indexes.sql --- a/swh/scheduler/sql/60-indexes.sql +++ b/swh/scheduler/sql/60-indexes.sql @@ -17,4 +17,4 @@ -- listed origins create index on listed_origins (url); -create index on listed_origins (last_scheduled); +create index on listed_origins (visit_type, last_scheduled); diff --git a/swh/scheduler/tests/test_cli_origin.py b/swh/scheduler/tests/test_cli_origin.py --- a/swh/scheduler/tests/test_cli_origin.py +++ b/swh/scheduler/tests/test_cli_origin.py @@ -56,48 +56,57 @@ assert output[i + 1] ==
f"{origin.lister_id},{origin.url},{origin.visit_type}" -def test_grab_next(swh_scheduler, listed_origins): - num_origins = 10 - assert len(listed_origins) >= num_origins +def test_grab_next(swh_scheduler, listed_origins_by_type): + NUM_RESULTS = 10 + # Strict inequality to check that grab_next_visits doesn't return more + # results than requested + visit_type = next(iter(listed_origins_by_type)) + assert len(listed_origins_by_type[visit_type]) > NUM_RESULTS - swh_scheduler.record_listed_origins(listed_origins) + for origins in listed_origins_by_type.values(): + swh_scheduler.record_listed_origins(origins) - result = invoke(swh_scheduler, args=("grab-next", str(num_origins))) + result = invoke(swh_scheduler, args=("grab-next", visit_type, str(NUM_RESULTS))) assert result.exit_code == 0 out_lines = result.stdout.splitlines() - assert len(out_lines) == num_origins + 1 + assert len(out_lines) == NUM_RESULTS + 1 fields = out_lines[0].split(",") returned_origins = [dict(zip(fields, line.split(","))) for line in out_lines[1:]] # Check that we've received origins we had listed in the first place assert set(origin["url"] for origin in returned_origins) <= set( - origin.url for origin in listed_origins + origin.url for origin in listed_origins_by_type[visit_type] ) -def test_schedule_next(swh_scheduler, listed_origins): +def test_schedule_next(swh_scheduler, listed_origins_by_type): for task_type in TASK_TYPES.values(): swh_scheduler.create_task_type(task_type) - num_origins = 10 - assert len(listed_origins) >= num_origins + NUM_RESULTS = 10 + # Strict inequality to check that grab_next_visits doesn't return more + # results than requested + visit_type = next(iter(listed_origins_by_type)) + assert len(listed_origins_by_type[visit_type]) > NUM_RESULTS - swh_scheduler.record_listed_origins(listed_origins) + for origins in listed_origins_by_type.values(): + swh_scheduler.record_listed_origins(origins) - result = invoke(swh_scheduler, args=("schedule-next", str(num_origins))) + 
result = invoke(swh_scheduler, args=("schedule-next", visit_type, str(NUM_RESULTS))) assert result.exit_code == 0 # pull all tasks out of the scheduler tasks = swh_scheduler.search_tasks() - assert len(tasks) == num_origins + assert len(tasks) == NUM_RESULTS scheduled_tasks = { (task["type"], task["arguments"]["kwargs"]["url"]) for task in tasks } all_possible_tasks = { - (f"load-{origin.visit_type}", origin.url) for origin in listed_origins + (f"load-{origin.visit_type}", origin.url) + for origin in listed_origins_by_type[visit_type] } assert scheduled_tasks <= all_possible_tasks diff --git a/swh/scheduler/tests/test_scheduler.py b/swh/scheduler/tests/test_scheduler.py --- a/swh/scheduler/tests/test_scheduler.py +++ b/swh/scheduler/tests/test_scheduler.py @@ -727,16 +727,18 @@ assert len(ret.origins) == len(listed_origins) @pytest.mark.parametrize("policy", ["oldest_scheduled_first"]) - def test_grab_next_visits(self, swh_scheduler, listed_origins, policy): + def test_grab_next_visits(self, swh_scheduler, listed_origins_by_type, policy): NUM_RESULTS = 5 # Strict inequality to check that grab_next_visits doesn't return more # results than requested - assert len(listed_origins) > NUM_RESULTS + visit_type = next(iter(listed_origins_by_type)) + assert len(listed_origins_by_type[visit_type]) > NUM_RESULTS - swh_scheduler.record_listed_origins(listed_origins) + for origins in listed_origins_by_type.values(): + swh_scheduler.record_listed_origins(origins) before = utcnow() - ret = swh_scheduler.grab_next_visits(NUM_RESULTS, policy=policy) + ret = swh_scheduler.grab_next_visits(visit_type, NUM_RESULTS, policy=policy) after = utcnow() assert len(ret) == NUM_RESULTS @@ -744,20 +746,28 @@ assert before <= origin.last_scheduled <= after + assert origin.visit_type == visit_type @pytest.mark.parametrize("policy", ["oldest_scheduled_first"]) - def test_grab_next_visits_underflow(self, swh_scheduler, listed_origins, policy): + def test_grab_next_visits_underflow( + self, swh_scheduler, listed_origins_by_type, policy +
): NUM_RESULTS = 5 - assert len(listed_origins) >= NUM_RESULTS + visit_type = next(iter(listed_origins_by_type)) + assert len(listed_origins_by_type[visit_type]) >= NUM_RESULTS - swh_scheduler.record_listed_origins(listed_origins[:NUM_RESULTS]) + swh_scheduler.record_listed_origins( + listed_origins_by_type[visit_type][:NUM_RESULTS] + ) - ret = swh_scheduler.grab_next_visits(NUM_RESULTS + 2, policy=policy) + ret = swh_scheduler.grab_next_visits(visit_type, NUM_RESULTS + 2, policy=policy) assert len(ret) == NUM_RESULTS def test_grab_next_visits_unknown_policy(self, swh_scheduler): NUM_RESULTS = 5 with pytest.raises(UnknownPolicy, match="non_existing_policy"): - swh_scheduler.grab_next_visits(NUM_RESULTS, policy="non_existing_policy") + swh_scheduler.grab_next_visits( + "type", NUM_RESULTS, policy="non_existing_policy" + ) def _create_task_types(self, scheduler): for tt in TASK_TYPES.values():