diff --git a/swh/graphql/resolvers/base_connection.py b/swh/graphql/resolvers/base_connection.py index 3adbb5f..8399166 100644 --- a/swh/graphql/resolvers/base_connection.py +++ b/swh/graphql/resolvers/base_connection.py @@ -1,112 +1,132 @@ # Copyright (C) 2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Optional, Type +from typing import Any, Optional, Type from swh.graphql.utils import utils from .base_node import BaseNode @dataclass class PageInfo: hasNextPage: bool endCursor: str +@dataclass +class ConnectionEdge: + node: Any + cursor: str + + class BaseConnection(ABC): """ Base resolver for all the connections """ _node_class: Optional[Type[BaseNode]] = None _page_size = 50 # default page size def __init__(self, obj, info, paged_data=None, **kwargs): self.obj = obj self.info = info self.kwargs = kwargs self._paged_data = paged_data def __call__(self, *args, **kw): return self @property def edges(self): return self._get_edges() @property def nodes(self): """ Override if needed; return a list of objects If a node class is set, return a list of its (Node) instances else a list of raw results """ if self._node_class is not None: return [ self._node_class(self.obj, self.info, node_data=result, **self.kwargs) for result in self.get_paged_data().results ] return self.get_paged_data().results @property def pageInfo(self): # To support the schema naming convention # FIXME, add more details like startCursor return PageInfo( hasNextPage=bool(self.get_paged_data().next_page_token), endCursor=utils.get_encoded_cursor(self.get_paged_data().next_page_token), ) @property def totalCount(self): # To support the schema naming convention return self._get_total_count() def _get_total_count(self): """ Will be None for most of the connections override if needed/possible """ return None def get_paged_data(self): """ Cache to avoid multiple calls to the backend (_get_paged_result) return a PagedResult object """ if self._paged_data is None: # FIXME, make this call async (not for v1) self._paged_data = self._get_paged_result() return self._paged_data @abstractmethod def _get_paged_result(self): """ Override for desired behaviour return a PagedResult object """ # FIXME, make this call async (not for v1) return None def _get_edges(self): - # FIXME, make cursor work per item - # Cursor can't be None here - return [{"cursor": "dummy", "node": node} for node in self.nodes] + """ + Return the list of connection edges, each with a cursor + """ + return [ + ConnectionEdge(node=node, cursor=self._get_index_cursor(index, node)) + for (index, node) in enumerate(self.nodes) + ] def _get_after_arg(self): """ Return the decoded next page token override to use a specific token """ return utils.get_decoded_cursor(self.kwargs.get("after")) def _get_first_arg(self): """ page_size is set to 50 by default """ return self.kwargs.get("first", self._page_size) + + def _get_index_cursor(self, index: int, node: Any): + """ + Get the cursor to the given item index + """ + # default implementation which works with swh-storage pagaination + # override this function to support other types (eg: SnapshotBranchConnection) + offset_index = self._get_after_arg() or "0" + index_cursor = int(offset_index) + index + return utils.get_encoded_cursor(str(index_cursor)) diff --git a/swh/graphql/resolvers/snapshot_branch.py b/swh/graphql/resolvers/snapshot_branch.py index 2ac39cb..78e989f 100644 --- a/swh/graphql/resolvers/snapshot_branch.py +++ b/swh/graphql/resolvers/snapshot_branch.py @@ -1,79 +1,83 @@ # Copyright (C) 2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from collections import namedtuple from swh.graphql.backends import archive from swh.graphql.utils import utils from swh.storage.interface import PagedResult from .base_connection import BaseConnection from .base_node import BaseNode from .snapshot import SnapshotNode class SnapshotBranchNode(BaseNode): """ Node resolver for a snapshot branch """ # target field for this Node is a UNION type # It is resolved in the top level (resolvers.resolvers.py) def _get_node_from_data(self, node_data): # node_data is not a dict in this case # overriding to support this special data structure # STORAGE-TODO; return an object in the normal format branch_name, branch_obj = node_data node = { "name": branch_name, "type": branch_obj.target_type.value, "target": branch_obj.target, } return namedtuple("NodeObj", node.keys())(*node.values()) @property def targetHash(self): # To support the schema naming convention return self._node.target class SnapshotBranchConnection(BaseConnection): """ Connection resolver for the branches in a snapshot """ obj: SnapshotNode _node_class = SnapshotBranchNode def _get_paged_result(self): # self.obj.swhid is the snapshot SWHID result = archive.Archive().get_snapshot_branches( self.obj.swhid.object_id, after=self._get_after_arg(), first=self._get_first_arg(), target_types=self.kwargs.get("types"), name_include=self.kwargs.get("nameInclude"), ) # FIXME Cursor must be a hex to be consistent with # the base class, hack to make that work end_cusrsor = ( result["next_branch"].hex() if result["next_branch"] is not None else None ) # FIXME, this pagination is not consistent with other connections # FIX in swh-storage to return PagedResult # STORAGE-TODO return PagedResult( results=result["branches"].items(), next_page_token=end_cusrsor ) def _get_after_arg(self): # Snapshot branch is using a different cursor; logic to handle that # FIXME Cursor must be a hex to be consistent with # the base class, hack to make that work after = utils.get_decoded_cursor(self.kwargs.get("after", "")) return bytes.fromhex(after) + + def _get_index_cursor(self, index: int, node: SnapshotBranchNode): + # Snapshot branch is using a different cursor, hence the override + return utils.get_encoded_cursor(node.name.hex()) diff --git a/swh/graphql/tests/functional/test_pagination.py b/swh/graphql/tests/functional/test_pagination.py new file mode 100644 index 0000000..2db0289 --- /dev/null +++ b/swh/graphql/tests/functional/test_pagination.py @@ -0,0 +1,102 @@ +# Copyright (C) 2022 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + +from ..data import get_origins +from .utils import get_query_response + + +# Using Origin object to run functional tests for pagination +def test_pagination(client): + # requesting the max number of nodes available + # endCursor must be None + query_str = f""" + {{ + origins(first: {len(get_origins())}) {{ + nodes {{ + id + }} + pageInfo {{ + hasNextPage + endCursor + }} + }} + }} + """ + + data, _ = get_query_response(client, query_str) + assert len(data["origins"]["nodes"]) == len(get_origins()) + assert data["origins"]["pageInfo"] == {"hasNextPage": False, "endCursor": None} + + +def get_first_node(client): + query_str = """ + { + origins(first: 1) { + nodes { + id + } + pageInfo { + hasNextPage + endCursor + } + } + } + """ + data, _ = get_query_response(client, query_str) + return data["origins"] + + +def test_first_arg(client): + origins = get_first_node(client) + assert len(origins["nodes"]) == 1 + assert origins["pageInfo"]["hasNextPage"] is True + + +def test_after_arg(client): + origins = get_first_node(client) + end_cursor = origins["pageInfo"]["endCursor"] + query_str = f""" + {{ + origins(first: 1, after: "{end_cursor}") {{ + nodes {{ + id + }} + pageInfo {{ + hasNextPage + endCursor + }} + }} + }} + """ + data, _ = get_query_response(client, query_str) + assert len(data["origins"]["nodes"]) == 1 + assert data["origins"]["pageInfo"] == {"hasNextPage": False, "endCursor": None} + + +def test_edge_cursor(client): + origins = get_first_node(client) + # end cursor here must be the item cursor for the second item + end_cursor = origins["pageInfo"]["endCursor"] + + query_str = f""" + {{ + origins(first: 1, after: "{end_cursor}") {{ + edges {{ + cursor + node {{ + id + }} + }} + nodes {{ + id + }} + }} + }} + """ + data, _ = get_query_response(client, query_str) + origins = data["origins"] + # nodes in list node fields in edges must be the same + assert [edge["node"] for edge in origins["edges"]] == origins["nodes"] + assert origins["edges"][0]["cursor"] == end_cursor