# Schema migration and scheduler API changes for tracking when an origin was
# last scheduled for a visit.
#
# NOTE(review): the line structure of this patch section was restored from a
# collapsed copy using the hunk-header line counts. Indentation and the
# position of blank context lines are reconstructed, not recovered; verify
# them against the target tree before applying.
#
# sql/updates/18.sql (new file)
#   Records dbversion 18, adds the nullable `listed_origins.last_scheduled`
#   timestamptz column, attaches a column comment and indexes the column.
# swh/scheduler/backend.py
#   Imports UnknownPolicy and adds `grab_next_visits(count, policy)`. For the
#   "oldest_scheduled_first" policy it selects up to `count` rows ordered by
#   `last_scheduled nulls first` using `for update skip locked`, sets
#   `last_scheduled = now()` on them and returns them as ListedOrigin
#   objects. Any other policy raises UnknownPolicy.
# swh/scheduler/exc.py
#   Adds the UnknownPolicy exception (a SchedulerException subclass) and
#   lists it in `__all__`.
# swh/scheduler/interface.py
#   Declares `grab_next_visits` on the "origins/grab_next" remote endpoint.
diff --git a/sql/updates/18.sql b/sql/updates/18.sql
new file mode 100644
--- /dev/null
+++ b/sql/updates/18.sql
@@ -0,0 +1,8 @@
+
+insert into dbversion (version, release, description)
+       values (18, now(), 'Work In Progress');
+
+alter table listed_origins add column last_scheduled timestamptz;
+comment on column listed_origins.last_scheduled is 'Time when this origin was scheduled to be visited last';
+
+create index on listed_origins (last_scheduled);
diff --git a/swh/scheduler/backend.py b/swh/scheduler/backend.py
--- a/swh/scheduler/backend.py
+++ b/swh/scheduler/backend.py
@@ -16,7 +16,7 @@
 from swh.core.db.common import db_transaction
 from swh.scheduler.utils import utcnow
 
-from .exc import StaleData
+from .exc import StaleData, UnknownPolicy
 from .model import (
     ListedOrigin,
     ListedOriginPageToken,
@@ -286,6 +286,38 @@
 
         return PaginatedListedOriginList(origins, page_token)
 
+    @db_transaction()
+    def grab_next_visits(
+        self, count: int, policy: str, db=None, cur=None,
+    ) -> List[ListedOrigin]:
+        """Get at most the `count` next origins that need to be visited
+        according to the given scheduling `policy`.
+
+        This will mark the origins as "being visited" in the listed_origins
+        table, to avoid scheduling multiple visits to the same origin.
+        """
+        origin_select_cols = ", ".join(ListedOrigin.select_columns())
+
+        if policy == "oldest_scheduled_first":
+            query = f"""
+            with filtered_origins as (
+                select lister_id, url, visit_type
+                from listed_origins
+                order by last_scheduled nulls first
+                limit %s
+                for update skip locked
+            )
+            update listed_origins
+                set last_scheduled = now()
+                where (lister_id, url, visit_type) in (select * from filtered_origins)
+            returning {origin_select_cols}
+            """
+            cur.execute(query, (count,))
+
+            return [ListedOrigin(**d) for d in cur]
+        else:
+            raise UnknownPolicy(f"Unknown scheduling policy {policy}")
+
     task_create_keys = [
         "type",
         "arguments",
diff --git a/swh/scheduler/exc.py b/swh/scheduler/exc.py
--- a/swh/scheduler/exc.py
+++ b/swh/scheduler/exc.py
@@ -6,6 +6,7 @@
 __all__ = [
     "SchedulerException",
     "StaleData",
+    "UnknownPolicy",
 ]
 
 
@@ -15,3 +16,7 @@
 
 class StaleData(SchedulerException):
     pass
+
+
+class UnknownPolicy(SchedulerException):
+    pass
diff --git a/swh/scheduler/interface.py b/swh/scheduler/interface.py
--- a/swh/scheduler/interface.py
+++ b/swh/scheduler/interface.py
@@ -309,6 +309,16 @@
         """
         ...
 
+    @remote_api_endpoint("origins/grab_next")
+    def grab_next_visits(self, count: int, policy: str,) -> List[ListedOrigin]:
+        """Get at most the `count` next origins that need to be visited
+        according to the given scheduling `policy`.
+
+        This will mark the origins as "being visited" in the listed_origins
+        table, to avoid scheduling multiple visits to the same origin.
+        """
+        ...
+
     @remote_api_endpoint("priority_ratios/get")
     def get_priority_ratios(self):
         ...
# Model, base schema, index and test changes for the `last_scheduled` field
# of listed origins and the `grab_next_visits` scheduler method.
#
# NOTE(review): the line structure of this patch section was restored from a
# collapsed copy using the hunk-header line counts. Indentation and the
# position of blank context lines are reconstructed, not recovered; verify
# them against the target tree before applying.
#
# swh/scheduler/model.py
#   Adds an optional `last_scheduled` datetime attribute (default None).
# swh/scheduler/sql/30-schema.sql
#   Bumps the recorded dbversion from 17 to 18, adds the
#   `listed_origins.last_scheduled` timestamptz column and its column comment.
# swh/scheduler/sql/60-indexes.sql
#   Adds an index on `listed_origins (last_scheduled)`.
# swh/scheduler/tests/test_api_client.py
#   Adds "origins/grab_next" to the list of endpoint names.
# swh/scheduler/tests/test_scheduler.py
#   Adds three tests for grab_next_visits: it returns exactly the requested
#   number of origins with `last_scheduled` set between the timestamps taken
#   around the call; it returns only what exists when fewer origins are
#   recorded than requested; an unknown policy raises UnknownPolicy.
diff --git a/swh/scheduler/model.py b/swh/scheduler/model.py
--- a/swh/scheduler/model.py
+++ b/swh/scheduler/model.py
@@ -146,6 +146,10 @@
         type=Optional[datetime.datetime], validator=[type_validator()], default=None,
     )
 
+    last_scheduled = attr.ib(
+        type=Optional[datetime.datetime], validator=[type_validator()], default=None,
+    )
+
     enabled = attr.ib(type=bool, validator=[type_validator()], default=True)
 
     first_seen = attr.ib(
diff --git a/swh/scheduler/sql/30-schema.sql b/swh/scheduler/sql/30-schema.sql
--- a/swh/scheduler/sql/30-schema.sql
+++ b/swh/scheduler/sql/30-schema.sql
@@ -11,7 +11,7 @@
 comment on column dbversion.description is 'Version description';
 
 insert into dbversion (version, release, description)
-       values (17, now(), 'Work In Progress');
+       values (18, now(), 'Work In Progress');
 
 create table task_type (
   type text primary key,
@@ -145,6 +145,9 @@
   -- potentially provided by the lister
   last_update timestamptz,
 
+  -- visit scheduling information
+  last_scheduled timestamptz,
+
   primary key (lister_id, url, visit_type)
 );
 
@@ -159,3 +162,5 @@
 comment on column listed_origins.last_seen is 'Time at which the origin was last seen by the lister';
 
 comment on column listed_origins.last_update is 'Time of the last update to the origin recorded by the remote';
+
+comment on column listed_origins.last_scheduled is 'Time when this origin was scheduled to be visited last';
diff --git a/swh/scheduler/sql/60-indexes.sql b/swh/scheduler/sql/60-indexes.sql
--- a/swh/scheduler/sql/60-indexes.sql
+++ b/swh/scheduler/sql/60-indexes.sql
@@ -17,3 +17,4 @@
 
 -- listed origins
 create index on listed_origins (url);
+create index on listed_origins (last_scheduled);
diff --git a/swh/scheduler/tests/test_api_client.py b/swh/scheduler/tests/test_api_client.py
--- a/swh/scheduler/tests/test_api_client.py
+++ b/swh/scheduler/tests/test_api_client.py
@@ -45,6 +45,7 @@
             "lister/get_or_create",
             "lister/update",
             "origins/get",
+            "origins/grab_next",
             "origins/record",
             "priority_ratios/get",
             "task/create",
diff --git a/swh/scheduler/tests/test_scheduler.py b/swh/scheduler/tests/test_scheduler.py
--- a/swh/scheduler/tests/test_scheduler.py
+++ b/swh/scheduler/tests/test_scheduler.py
@@ -14,7 +14,7 @@
 import attr
 import pytest
 
-from swh.scheduler.exc import StaleData
+from swh.scheduler.exc import StaleData, UnknownPolicy
 from swh.scheduler.interface import SchedulerInterface
 from swh.scheduler.model import ListedOrigin, ListedOriginPageToken
 from swh.scheduler.utils import utcnow
@@ -726,6 +726,39 @@
         assert ret.next_page_token is None
         assert len(ret.origins) == len(listed_origins)
 
+    @pytest.mark.parametrize("policy", ["oldest_scheduled_first"])
+    def test_grab_next_visits(self, swh_scheduler, listed_origins, policy):
+        NUM_RESULTS = 5
+        # Strict inequality to check that grab_next_visits doesn't return more
+        # results than requested
+        assert len(listed_origins) > NUM_RESULTS
+
+        swh_scheduler.record_listed_origins(listed_origins)
+
+        before = utcnow()
+        ret = swh_scheduler.grab_next_visits(NUM_RESULTS, policy=policy)
+        after = utcnow()
+
+        assert len(ret) == NUM_RESULTS
+        for origin in ret:
+            assert before <= origin.last_scheduled <= after
+
+    @pytest.mark.parametrize("policy", ["oldest_scheduled_first"])
+    def test_grab_next_visits_underflow(self, swh_scheduler, listed_origins, policy):
+        NUM_RESULTS = 5
+        assert len(listed_origins) >= NUM_RESULTS
+
+        swh_scheduler.record_listed_origins(listed_origins[:NUM_RESULTS])
+
+        ret = swh_scheduler.grab_next_visits(NUM_RESULTS + 2, policy=policy)
+
+        assert len(ret) == NUM_RESULTS
+
+    def test_grab_next_visits_unknown_policy(self, swh_scheduler):
+        NUM_RESULTS = 5
+        with pytest.raises(UnknownPolicy, match="non_existing_policy"):
+            swh_scheduler.grab_next_visits(NUM_RESULTS, policy="non_existing_policy")
+
     def _create_task_types(self, scheduler):
         for tt in TASK_TYPES.values():
             scheduler.create_task_type(tt)