diff --git a/swh/storage/cassandra/storage.py b/swh/storage/cassandra/storage.py
--- a/swh/storage/cassandra/storage.py
+++ b/swh/storage/cassandra/storage.py
@@ -705,26 +705,33 @@
             results.append(None)
         return results

-    def origin_list(self, page_token: Optional[str] = None, limit: int = 100) -> dict:
+    def origin_list(
+        self, page_token: Optional[str] = None, limit: int = 100
+    ) -> PagedResult[Origin]:
         # Compute what token to begin the listing from
         start_token = TOKEN_BEGIN
         if page_token:
             start_token = int(page_token)
             if not (TOKEN_BEGIN <= start_token <= TOKEN_END):
                 raise StorageArgumentException("Invalid page_token.")
+        next_page_token = None

-        rows = self._cql_runner.origin_list(start_token, limit)
-        rows = list(rows)
+        origins = []
+        # Fetch one extra origin so we can tell whether there is a next page
+        for row in self._cql_runner.origin_list(start_token, limit + 1):
+            origins.append(Origin(url=row.url))
+            # keep track of the last token for pagination purposes
+            last_id = row.tok

-        if len(rows) == limit:
-            next_page_token: Optional[str] = str(rows[-1].tok + 1)
-        else:
-            next_page_token = None
+        if len(origins) > limit:
+            # the token of the extra origin is the next page token
+            next_page_token = str(last_id)
+            # exclude that origin from the results to respect the limit
+            origins = origins[:limit]

-        return {
-            "origins": [{"url": row.url} for row in rows],
-            "next_page_token": next_page_token,
-        }
+        assert len(origins) <= limit
+
+        return PagedResult(results=origins, next_page_token=next_page_token)

     def origin_search(
         self, url_pattern, offset=0, limit=50, regexp=False, with_visit=False
diff --git a/swh/storage/db.py b/swh/storage/db.py
--- a/swh/storage/db.py
+++ b/swh/storage/db.py
@@ -960,15 +960,16 @@
     origin_get_range_cols = ["id", "url"]

-    def origin_get_range(self, origin_from=1, origin_count=100, cur=None):
+    def origin_get_range(self, origin_from: int = 1, origin_count: int = 100, cur=None):
         """Retrieve ``origin_count`` origins whose ids are greater than or
         equal to ``origin_from``. Origins are sorted by id before
         retrieving them.
         Args:
-            origin_from (int): the minimum id of origins to retrieve
-            origin_count (int): the maximum number of origins to retrieve
+            origin_from: the minimum id of origins to retrieve
+            origin_count: the maximum number of origins to retrieve
+
         """
         cur = self._cursor(cur)
diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py
--- a/swh/storage/in_memory.py
+++ b/swh/storage/in_memory.py
@@ -690,23 +690,25 @@
             origin = self._convert_origin(self._origins[self._origins_by_id[idx]])
             yield {"id": idx + 1, **origin}

-    def origin_list(self, page_token: Optional[str] = None, limit: int = 100) -> dict:
+    def origin_list(
+        self, page_token: Optional[str] = None, limit: int = 100
+    ) -> PagedResult[Origin]:
         origin_urls = sorted(self._origins)
-        if page_token:
-            from_ = bisect.bisect_left(origin_urls, page_token)
-        else:
-            from_ = 0
+        from_ = bisect.bisect_left(origin_urls, page_token) if page_token else 0
+        next_page_token = None

-        result = {
-            "origins": [
-                {"url": origin_url} for origin_url in origin_urls[from_ : from_ + limit]
-            ]
-        }
+        # Fetch one extra origin so we can tell whether there is a next page
+        origins = [Origin(url=url) for url in origin_urls[from_ : from_ + limit + 1]]

-        if from_ + limit < len(origin_urls):
-            result["next_page_token"] = origin_urls[from_ + limit]
+        if len(origins) > limit:
+            # the url of the extra origin is the next page token
+            next_page_token = origins[-1].url
+            # exclude that origin from the results to respect the limit
+            origins = origins[:limit]

-        return result
+        assert len(origins) <= limit
+
+        return PagedResult(results=origins, next_page_token=next_page_token)

     def origin_search(
         self, url_pattern, offset=0, limit=50, regexp=False, with_visit=False
diff --git a/swh/storage/interface.py b/swh/storage/interface.py
--- a/swh/storage/interface.py
+++ b/swh/storage/interface.py
@@ -1022,7 +1022,9 @@
         ...

     @remote_api_endpoint("origin/list")
-    def origin_list(self, page_token: Optional[str] = None, limit: int = 100) -> dict:
+    def origin_list(
+        self, page_token: Optional[str] = None, limit: int = 100
+    ) -> PagedResult[Origin]:
         """Returns the list of origins

         Args:
@@ -1030,12 +1032,9 @@
             limit: the maximum number of results to return

         Returns:
-            dict: dict with the following keys:
-            - **next_page_token** (str, optional): opaque token to be used as
-              `page_token` for retrieving the next page. if absent, there is
-              no more pages to gather.
-            - **origins** (List[dict]): list of origins, as returned by
-              `origin_get`.
+            Page of Origin data model objects. If next_page_token is None,
+            there is no more data to retrieve.
+
         """
         ...
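Note: the backend implementations above (and the postgresql one below) all converge on the same over-fetch pattern: ask the store for `limit + 1` sorted entries, and if the extra entry comes back, derive `next_page_token` from it and trim the page back down to `limit`. A minimal, backend-agnostic sketch of that pattern; the `paginate` helper and the `fetch_sorted` callable are hypothetical names for illustration only, not part of this diff:

```python
from typing import Callable, List, Optional, Tuple


def paginate(
    fetch_sorted: Callable[[int], List[str]], limit: int
) -> Tuple[List[str], Optional[str]]:
    """Fetch limit + 1 sorted keys; the extra one, if any, is the next page token."""
    # hypothetical fetch_sorted(n): returns at most n keys, in sorted order
    keys = fetch_sorted(limit + 1)
    next_page_token: Optional[str] = None
    if len(keys) > limit:
        # the first key past the limit is exactly where the next page starts
        next_page_token = keys[limit]
        keys = keys[:limit]
    assert len(keys) <= limit
    return keys, next_page_token
```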
diff --git a/swh/storage/storage.py b/swh/storage/storage.py
--- a/swh/storage/storage.py
+++ b/swh/storage/storage.py
@@ -10,7 +10,6 @@
 from collections import defaultdict
 from contextlib import contextmanager
 from typing import (
-    Any,
     Counter,
     Dict,
     Iterable,
@@ -1083,26 +1082,28 @@
     @db_transaction()
     def origin_list(
         self, page_token: Optional[str] = None, limit: int = 100, *, db=None, cur=None
-    ) -> dict:
+    ) -> PagedResult[Origin]:
         page_token = page_token or "0"
         if not isinstance(page_token, str):
             raise StorageArgumentException("page_token must be a string.")
         origin_from = int(page_token)
-        result: Dict[str, Any] = {
-            "origins": [
-                dict(zip(db.origin_get_range_cols, origin))
-                for origin in db.origin_get_range(origin_from, limit, cur)
-            ],
-        }
-
-        assert len(result["origins"]) <= limit
-        if len(result["origins"]) == limit:
-            result["next_page_token"] = str(result["origins"][limit - 1]["id"] + 1)
-
-        for origin in result["origins"]:
-            del origin["id"]
+        next_page_token = None

-        return result
+        origins: List[Origin] = []
+        # Fetch one extra origin so we can tell whether there is a next page
+        for row_d in self.origin_get_range(origin_from, limit + 1, db=db, cur=cur):
+            origins.append(Origin(url=row_d["url"]))
+            # keep the last id for pagination purposes
+            last_id = row_d["id"]
+
+        if len(origins) > limit:  # data left for a subsequent call
+            # the id of the extra origin is the next page token
+            next_page_token = str(last_id)
+            # exclude that origin from the results to respect the limit
+            origins = origins[:limit]
+
+        assert len(origins) <= limit
+        return PagedResult(results=origins, next_page_token=next_page_token)

     @timed
     @db_transaction_generator()
diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py
--- a/swh/storage/tests/test_storage.py
+++ b/swh/storage/tests/test_storage.py
@@ -4005,22 +4005,21 @@
         page_token = None
         i = 0
         while True:
-            result = swh_storage.origin_list(page_token=page_token, limit=limit)
-            assert len(result["origins"]) <= limit
+            actual_page = swh_storage.origin_list(page_token=page_token, limit=limit)
+            assert len(actual_page.results) <= limit

-            returned_origins.extend(origin["url"] for origin in result["origins"])
+            returned_origins.extend(actual_page.results)

             i += 1
-            page_token = result.get("next_page_token")
+            page_token = actual_page.next_page_token

             if page_token is None:
                 assert i * limit >= len(swh_origins)
                 break
             else:
-                assert len(result["origins"]) == limit
+                assert len(actual_page.results) == limit

-        expected_origins = [origin.url for origin in swh_origins]
-        assert sorted(returned_origins) == sorted(expected_origins)
+        assert sorted(returned_origins) == sorted(swh_origins)

     def test_origin_count(self, swh_storage, sample_data):
         swh_storage.origin_add(sample_data.origins)
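From the caller's side, the new return type makes pagination uniform across backends: loop until `next_page_token` comes back as None, feeding each token into the next call, much as the updated test above does. A minimal sketch, assuming `storage` is any object implementing the `origin_list` endpoint from this diff; the `list_all_origins` helper is a hypothetical name for illustration, not part of the change:

```python
from typing import List

from swh.model.model import Origin


def list_all_origins(storage, limit: int = 100) -> List[Origin]:
    """Drain every page of origin_list into one list of Origin objects."""
    origins: List[Origin] = []
    page_token = None
    while True:
        page = storage.origin_list(page_token=page_token, limit=limit)
        origins.extend(page.results)
        page_token = page.next_page_token
        if page_token is None:
            # a None token means the last page has been reached
            return origins
```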