diff --git a/swh/graphql/errors/__init__.py b/swh/graphql/errors/__init__.py --- a/swh/graphql/errors/__init__.py +++ b/swh/graphql/errors/__init__.py @@ -3,7 +3,7 @@ # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information -from .errors import ObjectNotFoundError +from .errors import ObjectNotFoundError, PaginationError from .handlers import format_error -__all__ = ["ObjectNotFoundError", "format_error"] +__all__ = ["ObjectNotFoundError", "PaginationError", "format_error"] diff --git a/swh/graphql/errors/errors.py b/swh/graphql/errors/errors.py --- a/swh/graphql/errors/errors.py +++ b/swh/graphql/errors/errors.py @@ -6,3 +6,13 @@ class ObjectNotFoundError(Exception): """ """ + + +class PaginationError(Exception): + """ """ + + msg: str = "Error in pagination input" + + def __init__(self, message, errors=None): + # FIXME, log this error + super().__init__(f"{self.msg}: {message}") diff --git a/swh/graphql/resolvers/base_connection.py b/swh/graphql/resolvers/base_connection.py --- a/swh/graphql/resolvers/base_connection.py +++ b/swh/graphql/resolvers/base_connection.py @@ -4,9 +4,11 @@ # See top-level LICENSE file for more information from abc import ABC, abstractmethod +import binascii from dataclasses import dataclass from typing import Any, Optional, Type +from swh.graphql.errors import PaginationError from swh.graphql.utils import utils from .base_node import BaseNode @@ -31,6 +33,7 @@ _node_class: Optional[Type[BaseNode]] = None _page_size = 50 # default page size + _max_page_size = 1000 # maximum value for the first arg def __init__(self, obj, info, paged_data=None, **kwargs): self.obj = obj @@ -108,18 +111,28 @@ for (index, node) in enumerate(self.nodes) ] - def _get_after_arg(self): - """ - Return the decoded next page token - override to use a specific token - """ - return utils.get_decoded_cursor(self.kwargs.get("after")) - - def _get_first_arg(self): - """ - page_size is set to 50 by default - 
""" - return self.kwargs.get("first", self._page_size) + def _get_after_arg(self) -> str: + """ + Return the decoded next page token. Override to support a different + cursor type + """ + # different cursor is used in SnapshotBranchConnection + try: + cursor = utils.get_decoded_cursor(self.kwargs.get("after")) + except (UnicodeDecodeError, binascii.Error, Exception) as e: + raise PaginationError("Invalid value for argument 'after'", errors=e) + return cursor + + def _get_first_arg(self) -> int: + """ """ + # page_size is set to 50 by default + # Input type check is not required; It is defined in schema as an int + first = self.kwargs.get("first", self._page_size) + if first < 0 or first > self._max_page_size: + raise PaginationError( + "Value for argument 'first' is either too big or invalid" + ) + return first def _get_index_cursor(self, index: int, node: Any): """ diff --git a/swh/graphql/resolvers/snapshot_branch.py b/swh/graphql/resolvers/snapshot_branch.py --- a/swh/graphql/resolvers/snapshot_branch.py +++ b/swh/graphql/resolvers/snapshot_branch.py @@ -51,16 +51,17 @@ _node_class = SnapshotBranchNode def _get_paged_result(self): - # self.obj.swhid is the snapshot SWHID + # after argument must be an empty string by default + after = self._get_after_arg() if self._get_after_arg() else "" result = archive.Archive().get_snapshot_branches( self.obj.swhid.object_id, - after=self._get_after_arg(), + after=after.encode(), first=self._get_first_arg(), target_types=self.kwargs.get("types"), name_include=self._get_name_include_arg(), ) - # FIXME Cursor must be a hex to be consistent with - # the base class, hack to make that work + + # endCursor is the last branch name, logic for that end_cusrsor = ( result["next_branch"] if result["next_branch"] is not None else None ) @@ -71,15 +72,10 @@ results=result["branches"].items(), next_page_token=end_cusrsor ) - def _get_after_arg(self): - # Snapshot branch is using a different cursor; logic to handle that - after = 
utils.get_decoded_cursor(self.kwargs.get("after", "")) - return after.encode() if after else b"" - def _get_name_include_arg(self): name_include = self.kwargs.get("nameInclude", None) return name_include.encode() if name_include else None def _get_index_cursor(self, index: int, node: SnapshotBranchNode): # Snapshot branch is using a different cursor, hence the override - return utils.get_encoded_cursor(node.name.hex()) + return utils.get_encoded_cursor(node.name) diff --git a/swh/graphql/tests/functional/test_pagination.py b/swh/graphql/tests/functional/test_pagination.py --- a/swh/graphql/tests/functional/test_pagination.py +++ b/swh/graphql/tests/functional/test_pagination.py @@ -8,32 +8,10 @@ # Using Origin object to run functional tests for pagination -def test_pagination(client): - # requesting the max number of nodes available - # endCursor must be None - query_str = f""" - {{ - origins(first: {len(get_origins())}) {{ - nodes {{ - id - }} - pageInfo {{ - hasNextPage - endCursor - }} - }} - }} - """ - - data, _ = get_query_response(client, query_str) - assert len(data["origins"]["nodes"]) == len(get_origins()) - assert data["origins"]["pageInfo"] == {"hasNextPage": False, "endCursor": None} - - -def get_first_node(client): +def get_origin_nodes(client, first=1, after=""): query_str = """ { - origins(first: 1) { + origins(first: %s, %s) { nodes { id } @@ -43,58 +21,89 @@ } } } - """ - data, _ = get_query_response(client, query_str) - return data["origins"] + """ % ( + first, + after, + ) + return get_query_response(client, query_str) + + +def test_pagination(client): + # requesting the max number of nodes available + # endCursor must be None + data, _ = get_origin_nodes(client, len(get_origins())) + assert len(data["origins"]["nodes"]) == len(get_origins()) + assert data["origins"]["pageInfo"] == {"hasNextPage": False, "endCursor": None} def test_first_arg(client): - origins = get_first_node(client) - assert len(origins["nodes"]) == 1 - assert 
origins["pageInfo"]["hasNextPage"] is True + data, _ = get_origin_nodes(client, 1) + assert len(data["origins"]["nodes"]) == 1 + assert data["origins"]["pageInfo"]["hasNextPage"] is True + + +def test_invalid_first_arg(client): + data, errors = get_origin_nodes(client, -1) + assert data["origins"] is None + assert (len(errors)) == 2 # one error for origins and another one for pageInfo + assert ( + errors[0]["message"] + == "Error in pagination input: Value for argument 'first' is either too big or invalid" + ) + + +def test_too_big_first_arg(client): + data, errors = get_origin_nodes(client, 1001) # max page size is 1000 + assert data["origins"] is None + assert (len(errors)) == 2 + assert ( + errors[0]["message"] + == "Error in pagination input: Value for argument 'first' is either too big or invalid" + ) def test_after_arg(client): - origins = get_first_node(client) - end_cursor = origins["pageInfo"]["endCursor"] - query_str = f""" - {{ - origins(first: 1, after: "{end_cursor}") {{ - nodes {{ - id - }} - pageInfo {{ - hasNextPage - endCursor - }} - }} - }} - """ - data, _ = get_query_response(client, query_str) + first_data, _ = get_origin_nodes(client) + end_cursor = first_data["origins"]["pageInfo"]["endCursor"] + # get again with endcursor as the after argument + data, _ = get_origin_nodes(client, 1, f'after: "{end_cursor}"') assert len(data["origins"]["nodes"]) == 1 assert data["origins"]["pageInfo"] == {"hasNextPage": False, "endCursor": None} +def test_invalid_after_arg(client): + data, errors = get_origin_nodes(client, 1, 'after: "invalid"') + assert data["origins"] is None + assert (len(errors)) == 2 + assert ( + errors[0]["message"] + == "Error in pagination input: Invalid value for argument 'after'" + ) + + def test_edge_cursor(client): - origins = get_first_node(client) + origins = get_origin_nodes(client)[0]["origins"] # end cursor here must be the item cursor for the second item end_cursor = origins["pageInfo"]["endCursor"] - query_str = f""" - {{ - 
origins(first: 1, after: "{end_cursor}") {{ - edges {{ + query_str = ( + """ + { + origins(first: 1, after: "%s") { + edges { cursor - node {{ + node { id - }} - }} - nodes {{ + } + } + nodes { id - }} - }} - }} + } + } + } """ + % end_cursor + ) data, _ = get_query_response(client, query_str) origins = data["origins"] assert [edge["node"] for edge in origins["edges"]] == origins["nodes"] diff --git a/swh/graphql/utils/utils.py b/swh/graphql/utils/utils.py --- a/swh/graphql/utils/utils.py +++ b/swh/graphql/utils/utils.py @@ -27,7 +27,7 @@ def get_decoded_cursor(cursor: str) -> str: if cursor is None: return None - return base64.b64decode(cursor).decode(ENCODING) + return base64.b64decode(cursor, validate=True).decode() def str_to_sha1(sha1: str) -> bytearray: