Switch the origins listing API to token-based pagination.

* api/views/origin.py: the origins view rejects the old `origin_from`
  query parameter with BadInputExc, reads `page_token` and `origin_count`,
  and emits the `link-next` header from the page's `next_page_token`.
* common/service.py: `lookup_origins` takes `(page_token, limit)`, calls
  `storage.origin_list` and returns a `PagedResult[OriginInfo]`.
* common/typing.py: add the `PagedResult` alias over swh.core's PagedResult.
* tests: cover the rejected `origin_from` input and port the existing tests
  to `origin_list`.

NOTE(review): this patch had been collapsed onto three lines; the line
structure below was rebuilt so that every hunk matches the old/new line
counts in its header. Leading indentation of the Python content is inferred
(4-space levels) -- confirm against the target tree before applying.

diff --git a/swh/web/api/views/origin.py b/swh/web/api/views/origin.py
--- a/swh/web/api/views/origin.py
+++ b/swh/web/api/views/origin.py
@@ -80,28 +80,27 @@
         .. parsed-literal::
 
             :swh_web_api:`origins?origin_count=500`
+
     """
-    origin_from = int(request.query_params.get("origin_from", "1"))
-    origin_count = int(request.query_params.get("origin_count", "100"))
-    origin_count = min(origin_count, 10000)
-    results = api_lookup(
-        service.lookup_origins,
-        origin_from,
-        origin_count + 1,
-        enrich_fn=enrich_origin,
-        request=request,
-    )
-    response = {"results": results, "headers": {}}
-    if len(results) > origin_count:
-        origin_from = results.pop()["id"]
+    old_param_origin_from = request.query_params.get("origin_from")
+
+    if old_param_origin_from:
+        raise BadInputExc("Please use the Link header to browse through result")
+
+    page_token = request.query_params.get("page_token", None)
+    limit = min(int(request.query_params.get("origin_count", "100")), 10000)
+
+    page_result = service.lookup_origins(page_token, limit)
+    origins = [enrich_origin(o, request=request) for o in page_result.results]
+    next_page_token = page_result.next_page_token
+
+    response = {"results": origins, "headers": {}}
+    if next_page_token is not None:
         response["headers"]["link-next"] = reverse(
             "api-1-origins",
-            query_params={"origin_from": origin_from, "origin_count": origin_count},
+            query_params={"page_token": next_page_token, "origin_count": limit},
             request=request,
         )
-    for result in results:
-        if "id" in result:
-            del result["id"]
     return response
 
 
diff --git a/swh/web/common/service.py b/swh/web/common/service.py
--- a/swh/web/common/service.py
+++ b/swh/web/common/service.py
@@ -22,7 +22,12 @@
 from swh.web.common import query
 from swh.web.common.exc import BadInputExc, NotFoundExc
 from swh.web.common.origin_visits import get_origin_visit
-from swh.web.common.typing import OriginInfo, OriginVisitInfo, OriginMetadataInfo
+from swh.web.common.typing import (
+    OriginInfo,
+    OriginVisitInfo,
+    OriginMetadataInfo,
+    PagedResult,
+)
 
 search = config.search()
 
@@ -243,8 +248,8 @@
 
 
 def lookup_origins(
-    origin_from: int = 1, origin_count: int = 100
-) -> Iterator[OriginInfo]:
+    page_token: Optional[str], limit: int = 100
+) -> PagedResult[OriginInfo]:
     """Get list of archived software origins in a paginated way.
 
     Origins are sorted by id before returning them
@@ -253,11 +258,15 @@
-        origin_from (int): The minimum id of the origins to return
-        origin_count (int): The maximum number of origins to return
+        page_token: token of the page to fetch (None for the first page)
+        limit: The maximum number of origins to return
 
-    Yields:
-        origins information as dicts
+    Returns:
+        Page of OriginInfo
+
     """
-    origins = storage.origin_get_range(origin_from, origin_count)
-    return map(converters.from_origin, origins)
+    page = storage.origin_list(page_token=page_token, limit=limit)
+    return PagedResult(
+        [converters.from_origin(o.to_dict()) for o in page.results],
+        next_page_token=page.next_page_token,
+    )
 
 
 def search_origin(
diff --git a/swh/web/common/typing.py b/swh/web/common/typing.py
--- a/swh/web/common/typing.py
+++ b/swh/web/common/typing.py
@@ -3,7 +3,9 @@
 # License: GNU Affero General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
-from typing import Any, Dict, List, Optional, Union
+from swh.core.api.classes import PagedResult as CorePagedResult
+
+from typing import Any, Dict, List, Optional, TypeVar, Union
 from typing_extensions import TypedDict
 
 from django.http import QueryDict
@@ -213,3 +215,9 @@
     synthetic: bool
     type: str
     snapshot: Optional[str]
+
+
+TResult = TypeVar("TResult")
+
+
+PagedResult = CorePagedResult[TResult, str]
diff --git a/swh/web/tests/api/views/test_origin.py b/swh/web/tests/api/views/test_origin.py
--- a/swh/web/tests/api/views/test_origin.py
+++ b/swh/web/tests/api/views/test_origin.py
@@ -364,13 +364,32 @@
     }
 
 
+def test_api_origins_wrong_input(api_client, archive_data):
+    """Should fail with 400 if the input is deprecated.
+
+    """
+    # fail if wrong input
+    url = reverse("api-1-origins", query_params={"origin_from": 1})
+    rv = api_client.get(url)
+
+    assert rv.status_code == 400, rv.data
+    assert rv["Content-Type"] == "application/json"
+
+    assert rv.data == {
+        "exception": "BadInputExc",
+        "reason": "Please use the Link header to browse through result",
+    }
+
+
 def test_api_origins(api_client, archive_data):
-    origins = list(archive_data.origin_get_range(0, 10000))
-    origin_urls = {origin["url"] for origin in origins}
+    page_result = archive_data.origin_list(limit=10000)
+    origins = page_result.results
+    origin_urls = {origin.url for origin in origins}
 
     # Get only one
     url = reverse("api-1-origins", query_params={"origin_count": 1})
     rv = api_client.get(url)
+    assert rv.status_code == 200, rv.data
     assert rv["Content-Type"] == "application/json"
     assert len(rv.data) == 1
 
@@ -395,8 +414,9 @@
 
 @pytest.mark.parametrize("origin_count", [1, 2, 10, 100])
 def test_api_origins_scroll(api_client, archive_data, origin_count):
-    origins = list(archive_data.origin_get_range(0, 10000))
-    origin_urls = {origin["url"] for origin in origins}
+    page_result = archive_data.origin_list(limit=10000)
+    origins = page_result.results
+    origin_urls = {origin.url for origin in origins}
 
     url = reverse("api-1-origins", query_params={"origin_count": origin_count})
 