diff --git a/swh/graphql/errors/__init__.py b/swh/graphql/errors/__init__.py index 3e8bd41..7bc04e9 100644 --- a/swh/graphql/errors/__init__.py +++ b/swh/graphql/errors/__init__.py @@ -1,9 +1,9 @@ # Copyright (C) 2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information -from .errors import ObjectNotFoundError +from .errors import ObjectNotFoundError, PaginationError from .handlers import format_error -__all__ = ["ObjectNotFoundError", "format_error"] +__all__ = ["ObjectNotFoundError", "PaginationError", "format_error"] diff --git a/swh/graphql/errors/errors.py b/swh/graphql/errors/errors.py index 8036b65..37ecd35 100644 --- a/swh/graphql/errors/errors.py +++ b/swh/graphql/errors/errors.py @@ -1,8 +1,18 @@ # Copyright (C) 2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information class ObjectNotFoundError(Exception): """ """ + + +class PaginationError(Exception): + """ """ + + msg: str = "Pagination error" + + def __init__(self, message, errors=None): + # FIXME, log this error + super().__init__(f"{self.msg}: {message}") diff --git a/swh/graphql/resolvers/base_connection.py b/swh/graphql/resolvers/base_connection.py index ee16b26..f78ccbb 100644 --- a/swh/graphql/resolvers/base_connection.py +++ b/swh/graphql/resolvers/base_connection.py @@ -1,132 +1,145 @@ # Copyright (C) 2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from abc import ABC, abstractmethod +import binascii from dataclasses import dataclass from typing import Any, Optional, Type +from swh.graphql.errors import PaginationError from swh.graphql.utils import utils from .base_node import BaseNode @dataclass class PageInfo: hasNextPage: bool endCursor: str @dataclass class ConnectionEdge: node: Any cursor: str class BaseConnection(ABC): """ Base resolver for all the connections """ _node_class: Optional[Type[BaseNode]] = None - _page_size = 50 # default page size + _page_size: int = 50 # default page size (default value for the first arg) + _max_page_size: int = 1000 # maximum page size(max value for the first arg) def __init__(self, obj, info, paged_data=None, **kwargs): self.obj = obj self.info = info self.kwargs = kwargs self._paged_data = paged_data def __call__(self, *args, **kw): return self @property def edges(self): return self._get_edges() @property def nodes(self): """ Override if needed; return a list of objects If a node class is set, return a list of its (Node) instances else a list of raw results """ if self._node_class is not None: return [ self._node_class(self, self.info, node_data=result, **self.kwargs) for result in self.get_paged_data().results ] return self.get_paged_data().results @property def pageInfo(self): # To support the schema naming convention # FIXME, add more details like startCursor return PageInfo( hasNextPage=bool(self.get_paged_data().next_page_token), endCursor=utils.get_encoded_cursor(self.get_paged_data().next_page_token), ) @property def totalCount(self): # To support the schema naming convention return self._get_total_count() def _get_total_count(self): """ Will be None for most of the connections override if needed/possible """ return None def get_paged_data(self): """ Cache to avoid multiple calls to the backend (_get_paged_result) return a PagedResult object """ if self._paged_data is None: # FIXME, make this call async (not for v1) self._paged_data = self._get_paged_result() return self._paged_data @abstractmethod def _get_paged_result(self): """ Override for desired behaviour return a PagedResult object """ # FIXME, make this call async (not for v1) return None def _get_edges(self): """ Return the list of connection edges, each with a cursor """ return [ ConnectionEdge(node=node, cursor=self._get_index_cursor(index, node)) for (index, node) in enumerate(self.nodes) ] - def _get_after_arg(self): - """ - Return the decoded next page token - override to use a specific token - """ - return utils.get_decoded_cursor(self.kwargs.get("after")) - - def _get_first_arg(self): - """ - page_size is set to 50 by default - """ - return self.kwargs.get("first", self._page_size) + def _get_after_arg(self) -> str: + """ + Return the decoded next page token. Override to support a different + cursor type + """ + # different implementation is used in SnapshotBranchConnection + try: + cursor = utils.get_decoded_cursor(self.kwargs.get("after")) + except (UnicodeDecodeError, binascii.Error, Exception) as e: + raise PaginationError("Invalid value for argument 'after'", errors=e) + return cursor + + def _get_first_arg(self) -> int: + """ """ + # page_size is set to 50 by default + # Input type check is not required; It is defined in schema as an int + first = self.kwargs.get("first", self._page_size) + if first < 0 or first > self._max_page_size: + raise PaginationError( + f"Value for argument 'first' is invalid; it must be between 0 and {self._max_page_size}" # noqa: B950 + ) + return first def _get_index_cursor(self, index: int, node: Any): """ Get the cursor to the given item index """ # default implementation which works with swh-storage pagaination # override this function to support other types (eg: SnapshotBranchConnection) offset_index = self._get_after_arg() or "0" index_cursor = int(offset_index) + index return utils.get_encoded_cursor(str(index_cursor)) diff --git a/swh/graphql/resolvers/snapshot_branch.py b/swh/graphql/resolvers/snapshot_branch.py index 6f780b4..349150a 100644 --- a/swh/graphql/resolvers/snapshot_branch.py +++ b/swh/graphql/resolvers/snapshot_branch.py @@ -1,85 +1,84 @@ # Copyright (C) 2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from collections import namedtuple from swh.graphql.backends import archive from swh.graphql.utils import utils from swh.storage.interface import PagedResult from .base_connection import BaseConnection from .base_node import BaseNode class SnapshotBranchNode(BaseNode): """ Node resolver for a snapshot branch """ # target field for this Node is a UNION type # It is resolved in the top level (resolvers.resolvers.py) def _get_node_from_data(self, node_data): # node_data is not a dict in this case # overriding to support this special data structure # STORAGE-TODO; return an object in the normal format branch_name, branch_obj = node_data node = { "name": branch_name, "type": branch_obj.target_type.value, "target": branch_obj.target, } return namedtuple("NodeObj", node.keys())(*node.values()) @property def targetHash(self): # To support the schema naming convention return self._node.target class SnapshotBranchConnection(BaseConnection): """ Connection resolver for the branches in a snapshot """ from .snapshot import SnapshotNode obj: SnapshotNode _node_class = SnapshotBranchNode def _get_paged_result(self): - # self.obj.swhid is the snapshot SWHID result = archive.Archive().get_snapshot_branches( self.obj.swhid.object_id, after=self._get_after_arg(), first=self._get_first_arg(), target_types=self.kwargs.get("types"), name_include=self._get_name_include_arg(), ) - # FIXME Cursor must be a hex to be consistent with - # the base class, hack to make that work + + # endCursor is the last branch name, logic for that end_cusrsor = ( result["next_branch"] if result["next_branch"] is not None else None ) # FIXME, this pagination is not consistent with other connections # FIX in swh-storage to return PagedResult # STORAGE-TODO return PagedResult( results=result["branches"].items(), next_page_token=end_cusrsor ) def _get_after_arg(self): - # Snapshot branch is using a different cursor; logic to handle that - after = utils.get_decoded_cursor(self.kwargs.get("after", "")) + # after argument must be an empty string by default + after = super()._get_after_arg() return after.encode() if after else b"" def _get_name_include_arg(self): name_include = self.kwargs.get("nameInclude", None) return name_include.encode() if name_include else None def _get_index_cursor(self, index: int, node: SnapshotBranchNode): # Snapshot branch is using a different cursor, hence the override - return utils.get_encoded_cursor(node.name.hex()) + return utils.get_encoded_cursor(node.name) diff --git a/swh/graphql/tests/functional/test_pagination.py b/swh/graphql/tests/functional/test_pagination.py index d8b4962..2370d49 100644 --- a/swh/graphql/tests/functional/test_pagination.py +++ b/swh/graphql/tests/functional/test_pagination.py @@ -1,101 +1,109 @@ # Copyright (C) 2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from ..data import get_origins from .utils import get_query_response # Using Origin object to run functional tests for pagination -def test_pagination(client): - # requesting the max number of nodes available - # endCursor must be None - query_str = f""" - {{ - origins(first: {len(get_origins())}) {{ - nodes {{ - id - }} - pageInfo {{ - hasNextPage - endCursor - }} - }} - }} - """ - - data, _ = get_query_response(client, query_str) - assert len(data["origins"]["nodes"]) == len(get_origins()) - assert data["origins"]["pageInfo"] == {"hasNextPage": False, "endCursor": None} - - -def get_first_node(client): +def get_origin_nodes(client, first=1, after=""): query_str = """ { - origins(first: 1) { + origins(first: %s, %s) { nodes { id } pageInfo { hasNextPage endCursor } } } - """ - data, _ = get_query_response(client, query_str) - return data["origins"] + """ % ( + first, + after, + ) + return get_query_response(client, query_str) + + +def test_pagination(client): + # requesting the max number of nodes available + # endCursor must be None + data, _ = get_origin_nodes(client, len(get_origins())) + assert len(data["origins"]["nodes"]) == len(get_origins()) + assert data["origins"]["pageInfo"] == {"hasNextPage": False, "endCursor": None} def test_first_arg(client): - origins = get_first_node(client) - assert len(origins["nodes"]) == 1 - assert origins["pageInfo"]["hasNextPage"] is True + data, _ = get_origin_nodes(client, 1) + assert len(data["origins"]["nodes"]) == 1 + assert data["origins"]["pageInfo"]["hasNextPage"] is True + + +def test_invalid_first_arg(client): + data, errors = get_origin_nodes(client, -1) + assert data["origins"] is None + assert (len(errors)) == 2 # one error for origins and anotehr one for pageInfo + assert ( + errors[0]["message"] + == "Pagination error: Value for argument 'first' is invalid; it must be between 0 and 1000" # noqa: B950 + ) + + +def test_too_big_first_arg(client): + data, errors = get_origin_nodes(client, 1001) # max page size is 1000 + assert data["origins"] is None + assert (len(errors)) == 2 + assert ( + errors[0]["message"] + == "Pagination error: Value for argument 'first' is invalid; it must be between 0 and 1000" # noqa: B950 + ) def test_after_arg(client): - origins = get_first_node(client) - end_cursor = origins["pageInfo"]["endCursor"] - query_str = f""" - {{ - origins(first: 1, after: "{end_cursor}") {{ - nodes {{ - id - }} - pageInfo {{ - hasNextPage - endCursor - }} - }} - }} - """ - data, _ = get_query_response(client, query_str) + first_data, _ = get_origin_nodes(client) + end_cursor = first_data["origins"]["pageInfo"]["endCursor"] + # get again with endcursor as the after argument + data, _ = get_origin_nodes(client, 1, f'after: "{end_cursor}"') assert len(data["origins"]["nodes"]) == 1 assert data["origins"]["pageInfo"] == {"hasNextPage": False, "endCursor": None} +def test_invalid_after_arg(client): + data, errors = get_origin_nodes(client, 1, 'after: "invalid"') + assert data["origins"] is None + assert (len(errors)) == 2 + assert ( + errors[0]["message"] == "Pagination error: Invalid value for argument 'after'" + ) + + def test_edge_cursor(client): - origins = get_first_node(client) + origins = get_origin_nodes(client)[0]["origins"] # end cursor here must be the item cursor for the second item end_cursor = origins["pageInfo"]["endCursor"] - query_str = f""" - {{ - origins(first: 1, after: "{end_cursor}") {{ - edges {{ + query_str = ( + """ + { + origins(first: 1, after: "%s") { + edges { cursor - node {{ + node { id - }} - }} - nodes {{ + } + } + nodes { id - }} - }} - }} + } + } + } """ + % end_cursor + ) data, _ = get_query_response(client, query_str) origins = data["origins"] assert [edge["node"] for edge in origins["edges"]] == origins["nodes"] assert origins["edges"][0]["cursor"] == end_cursor diff --git a/swh/graphql/utils/utils.py b/swh/graphql/utils/utils.py index 87984af..5dd9411 100644 --- a/swh/graphql/utils/utils.py +++ b/swh/graphql/utils/utils.py @@ -1,58 +1,58 @@ # Copyright (C) 2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import base64 from datetime import datetime from typing import List from swh.storage.interface import PagedResult ENCODING = "utf-8" def get_b64_string(source) -> str: if type(source) is str: source = source.encode(ENCODING) return base64.b64encode(source).decode("ascii") def get_encoded_cursor(cursor: str) -> str: if cursor is None: return None return get_b64_string(cursor) def get_decoded_cursor(cursor: str) -> str: if cursor is None: return None - return base64.b64decode(cursor).decode(ENCODING) + return base64.b64decode(cursor, validate=True).decode() def str_to_sha1(sha1: str) -> bytearray: # FIXME, use core function return bytearray.fromhex(sha1) def get_formatted_date(date: datetime) -> str: # FIXME, handle error + return other formats return date.isoformat() def paginated(source: List, first: int, after=0) -> PagedResult: """ Pagination at the GraphQL level This is a temporary fix and inefficient. Should eventually be moved to the backend (storage) level """ # FIXME, handle data errors here after = 0 if after is None else int(after) end_cursor = after + first results = source[after:end_cursor] next_page_token = None if len(source) > end_cursor: next_page_token = str(end_cursor) return PagedResult(results=results, next_page_token=next_page_token)