@db_transaction()
def get_listed_origins(
    self,
    lister_id: Optional[UUID] = None,
    url: Optional[str] = None,
    limit: int = 1000,
    page_token: Optional[ListedOriginPageToken] = None,
    db=None,
    cur=None,
) -> PaginatedListedOriginList:
    """Get information on the listed origins matching either the `url` or
    `lister_id`, or both arguments.

    Results are ordered by ``(lister_id, url)`` and returned one page at a
    time.

    Args:
        lister_id: only return origins recorded for this lister.
        url: only return origins with exactly this url.
        limit: maximum number of origins in the returned page (1 to 1000).
        page_token: ``(lister_id, url)`` pair returned as ``next_page_token``
            by a previous call; only origins sorting strictly after it are
            returned.
        db: database handle; not used directly in this method
            (presumably injected by ``db_transaction`` -- confirm).
        cur: cursor the query is executed on.

    Returns:
        The page of matching origins. Its ``next_page_token`` is set when
        the page is full (exactly `limit` origins) and is None otherwise.

    Raises:
        ArgumentError: if `limit` is greater than 1000 or lower than 1.
    """
    if limit > 1000:
        raise ArgumentError("get_listed_origins: max page size is 1000.")
    if limit < 1:
        # LIMIT 0 yields no rows, which used to trip the "page is full"
        # check below on an empty list; negative limits are rejected by
        # the database. Fail early with the dedicated exception instead.
        raise ArgumentError("get_listed_origins: min page size is 1.")

    query_filters: List[str] = []
    query_params: List[Union[int, str, UUID, Tuple[UUID, str]]] = []

    # Both filters use an explicit None check. UUID instances are always
    # truthy, so this is equivalent to the previous truthiness test for
    # valid arguments.
    if lister_id is not None:
        query_filters.append("lister_id = %s")
        query_params.append(lister_id)

    if url is not None:
        query_filters.append("url = %s")
        query_params.append(url)

    if page_token is not None:
        # Row-wise comparison: resume right after the last (lister_id, url)
        # pair of the previous page, in the same order as the ORDER BY.
        query_filters.append("(lister_id, url) > %s")
        # the typeshed annotation for tuple() is too strict.
        query_params.append(tuple(page_token))  # type: ignore

    # The LIMIT placeholder comes last in the query, so its value goes last.
    query_params.append(limit)

    select_cols = ", ".join(ListedOrigin.select_columns())
    if query_filters:
        where_clause = "where " + " and ".join(query_filters)
    else:
        where_clause = ""

    # Only column names and fixed filter fragments are interpolated here;
    # every user-supplied value is passed as a query parameter.
    query = f"""SELECT {select_cols}
                from listed_origins
                {where_clause}
                ORDER BY lister_id, url
                LIMIT %s"""

    cur.execute(query, tuple(query_params))
    # Assumes the cursor yields rows as column-name mappings -- confirm.
    origins = [ListedOrigin(**d) for d in cur]

    # A full page may have more rows after it: hand back the key of the last
    # row as continuation token. A short page means the listing is done.
    next_page_token: Optional[ListedOriginPageToken] = None
    if len(origins) == limit:
        last_origin = origins[-1]
        next_page_token = (last_origin.lister_id, last_origin.url)

    return PaginatedListedOriginList(origins, next_page_token)
# Type of a listed origins page token: a (UUID, str) pair.
ListedOriginPageToken = Tuple[UUID, str]


def convert_listed_origin_page_token(
    input: Union[None, ListedOriginPageToken, List[Union[UUID, str]]]
) -> Optional[ListedOriginPageToken]:
    """Normalize a page token to None or a ``(UUID, str)`` tuple.

    Args:
        input: None, an existing tuple, or a two-element list
            ``[UUID, str]``.

    Returns:
        None when `input` is None; `input` itself when it is already a
        tuple (its contents are not checked here); otherwise a new
        ``(UUID, str)`` tuple built from the two elements of `input`.

    Raises:
        TypeError: if a non-tuple `input` does not hold a UUID followed by
            a str.
        ValueError: if a non-tuple `input` does not have exactly two
            elements.
    """
    if input is None:
        return None

    if isinstance(input, tuple):
        return input

    lister_id, url = input
    # Explicit checks rather than `assert`, which is stripped when Python
    # runs with -O and would let malformed tokens through.
    if not isinstance(lister_id, UUID):
        raise TypeError(
            "page token: expected a UUID as first element, got %s"
            % type(lister_id).__name__
        )
    if not isinstance(url, str):
        raise TypeError(
            "page token: expected a str as second element, got %s"
            % type(url).__name__
        )
    return (lister_id, url)
def test_get_listed_origins_exact(self, swh_scheduler, listed_origins):
    """Each recorded origin can be fetched by its exact (lister_id, url)."""
    swh_scheduler.record_listed_origins(listed_origins)

    # The loop index was never used, so iterate over the origins directly.
    for origin in listed_origins:
        ret = swh_scheduler.get_listed_origins(
            lister_id=origin.lister_id, url=origin.url
        )

        # An exact lookup yields a single origin and no continuation token.
        assert ret.next_page_token is None
        assert len(ret.origins) == 1
        assert ret.origins[0].lister_id == origin.lister_id
        assert ret.origins[0].url == origin.url
def test_get_listed_origins_max_limit(self, swh_scheduler):
    """The default page size is accepted; a limit of 1001 is rejected."""
    # Nothing has been recorded: the default call gives back an empty,
    # final page.
    empty_page = swh_scheduler.get_listed_origins()
    assert empty_page.origins == []
    assert empty_page.next_page_token is None

    oversized_limit = 1001
    with pytest.raises(ArgumentError, match="max page size"):
        swh_scheduler.get_listed_origins(limit=oversized_limit)