diff --git a/swh/graphql/resolvers/base_connection.py b/swh/graphql/resolvers/base_connection.py --- a/swh/graphql/resolvers/base_connection.py +++ b/swh/graphql/resolvers/base_connection.py @@ -6,10 +6,11 @@ from abc import ABC, abstractmethod import binascii from dataclasses import dataclass -from typing import Any, Optional, Type +from typing import Any, List, Optional, Type, Union from swh.graphql.errors import PaginationError from swh.graphql.utils import utils +from swh.storage.interface import PagedResult from .base_node import BaseNode @@ -17,13 +18,13 @@ @dataclass class PageInfo: hasNextPage: bool - endCursor: str + endCursor: Optional[str] @dataclass class ConnectionEdge: node: Any - cursor: str + cursor: Optional[str] class BaseConnection(ABC): @@ -36,17 +37,23 @@ _max_page_size: int = 1000 # maximum page size(max value for the first arg) def __init__(self, obj, info, paged_data=None, **kwargs): - self.obj = obj + self.obj: Optional[Any] = obj self.info = info self.kwargs = kwargs - self._paged_data = paged_data + self._paged_data: PagedResult = paged_data @property - def edges(self): - return self._get_edges() + def edges(self) -> List[ConnectionEdge]: + """ + Return the list of connection edges, each with a cursor + """ + return [ + ConnectionEdge(node=node, cursor=self._get_index_cursor(index, node)) + for (index, node) in enumerate(self.nodes) + ] @property - def nodes(self): + def nodes(self) -> List[Union[BaseNode, object]]: """ Override if needed; return a list of objects @@ -63,7 +70,7 @@ return self.get_paged_data().results @property - def pageInfo(self): # To support the schema naming convention + def pageInfo(self) -> PageInfo: # To support the schema naming convention # FIXME, add more details like startCursor return PageInfo( hasNextPage=bool(self.get_paged_data().next_page_token), @@ -71,20 +78,17 @@ ) @property - def totalCount(self): # To support the schema naming convention - return self._get_total_count() - - def _get_total_count(self): + def totalCount(self) -> Optional[int]: # To support the schema naming convention """ Will be None for most of the connections override if needed/possible """ + return None - def get_paged_data(self): + def get_paged_data(self) -> PagedResult: """ - Cache to avoid multiple calls to - the backend (_get_paged_result) + Cache to avoid multiple calls to the backend :meth:`_get_paged_result` return a PagedResult object """ if self._paged_data is None: @@ -101,15 +105,6 @@ # FIXME, make this call async (not for v1) return None - def _get_edges(self): - """ - Return the list of connection edges, each with a cursor - """ - return [ - ConnectionEdge(node=node, cursor=self._get_index_cursor(index, node)) - for (index, node) in enumerate(self.nodes) - ] - def _get_after_arg(self) -> str: """ Return the decoded next page token. Override to support a different @@ -133,7 +128,7 @@ ) return first - def _get_index_cursor(self, index: int, node: Any): + def _get_index_cursor(self, index: int, node: Any) -> Optional[str]: """ Get the cursor to the given item index """ diff --git a/swh/graphql/resolvers/directory_entry.py b/swh/graphql/resolvers/directory_entry.py --- a/swh/graphql/resolvers/directory_entry.py +++ b/swh/graphql/resolvers/directory_entry.py @@ -5,6 +5,7 @@ from swh.graphql.backends import archive from swh.graphql.utils import utils +from swh.storage.interface import PagedResult from .base_connection import BaseConnection from .base_node import BaseNode @@ -31,7 +32,7 @@ _node_class = DirectoryEntryNode - def _get_paged_result(self): + def _get_paged_result(self) -> PagedResult: # FIXME, using dummy(local) pagination, move pagination to backend # To remove localpagination, just drop the paginated call # STORAGE-TODO diff --git a/swh/graphql/resolvers/origin.py b/swh/graphql/resolvers/origin.py --- a/swh/graphql/resolvers/origin.py +++ b/swh/graphql/resolvers/origin.py @@ -4,6 +4,7 @@ # See top-level LICENSE file for more information from swh.graphql.backends import archive +from swh.storage.interface import PagedResult from .base_connection import BaseConnection from .base_node import BaseSWHNode @@ -25,7 +26,7 @@ _node_class = OriginNode - def _get_paged_result(self): + def _get_paged_result(self) -> PagedResult: return archive.Archive().get_origins( after=self._get_after_arg(), first=self._get_first_arg(), diff --git a/swh/graphql/resolvers/revision.py b/swh/graphql/resolvers/revision.py --- a/swh/graphql/resolvers/revision.py +++ b/swh/graphql/resolvers/revision.py @@ -8,6 +8,7 @@ from swh.graphql.backends import archive from swh.graphql.utils import utils from swh.model.swhids import CoreSWHID, ObjectType +from swh.storage.interface import PagedResult from .base_connection import BaseConnection from .base_node import BaseSWHNode @@ -76,7 +77,7 @@ _node_class = BaseRevisionNode - def _get_paged_result(self): + def _get_paged_result(self) -> PagedResult: # self.obj is the current(child) revision # self.obj.parent_swhids is the list of parent SWHIDs @@ -98,7 +99,7 @@ _node_class = BaseRevisionNode - def _get_paged_result(self): + def _get_paged_result(self) -> PagedResult: # STORAGE-TODO (date in revisionlog is a dict) log = archive.Archive().get_revision_log([self.obj.swhid.object_id]) # FIXME, using dummy(local) pagination, move pagination to backend diff --git a/swh/graphql/resolvers/snapshot.py b/swh/graphql/resolvers/snapshot.py --- a/swh/graphql/resolvers/snapshot.py +++ b/swh/graphql/resolvers/snapshot.py @@ -8,6 +8,7 @@ from swh.graphql.backends import archive from swh.graphql.utils import utils from swh.model.model import Snapshot +from swh.storage.interface import PagedResult from .base_connection import BaseConnection from .base_node import BaseSWHNode @@ -81,7 +82,7 @@ _node_class = BaseSnapshotNode - def _get_paged_result(self): + def _get_paged_result(self) -> PagedResult: results = archive.Archive().get_origin_snapshots(self.obj.url) snapshots = [Snapshot(id=snapshot, branches={}) for snapshot in results] # FIXME, using dummy(local) pagination, move pagination to backend diff --git a/swh/graphql/resolvers/snapshot_branch.py b/swh/graphql/resolvers/snapshot_branch.py --- a/swh/graphql/resolvers/snapshot_branch.py +++ b/swh/graphql/resolvers/snapshot_branch.py @@ -50,7 +50,7 @@ _node_class = SnapshotBranchNode - def _get_paged_result(self): + def _get_paged_result(self) -> PagedResult: result = archive.Archive().get_snapshot_branches( self.obj.swhid.object_id, after=self._get_after_arg(), diff --git a/swh/graphql/resolvers/visit.py b/swh/graphql/resolvers/visit.py --- a/swh/graphql/resolvers/visit.py +++ b/swh/graphql/resolvers/visit.py @@ -5,6 +5,7 @@ from swh.graphql.backends import archive from swh.graphql.utils import utils +from swh.storage.interface import PagedResult from .base_connection import BaseConnection from .base_node import BaseNode @@ -59,7 +60,7 @@ _node_class = BaseVisitNode - def _get_paged_result(self): + def _get_paged_result(self) -> PagedResult: # self.obj.url is the origin URL return archive.Archive().get_origin_visits( self.obj.url, after=self._get_after_arg(), first=self._get_first_arg() diff --git a/swh/graphql/resolvers/visit_status.py b/swh/graphql/resolvers/visit_status.py --- a/swh/graphql/resolvers/visit_status.py +++ b/swh/graphql/resolvers/visit_status.py @@ -5,6 +5,7 @@ from swh.graphql.backends import archive from swh.model.swhids import CoreSWHID, ObjectType +from swh.storage.interface import PagedResult from .base_connection import BaseConnection from .base_node import BaseNode @@ -43,7 +44,7 @@ obj: BaseVisitNode _node_class = BaseVisitStatusNode - def _get_paged_result(self): + def _get_paged_result(self) -> PagedResult: # self.obj.origin is the origin URL return archive.Archive().get_visit_status( self.obj.origin, diff --git a/swh/graphql/utils/utils.py b/swh/graphql/utils/utils.py --- a/swh/graphql/utils/utils.py +++ b/swh/graphql/utils/utils.py @@ -5,7 +5,7 @@ import base64 from datetime import datetime -from typing import List +from typing import List, Optional from swh.storage.interface import PagedResult @@ -18,7 +18,7 @@ return base64.b64encode(source).decode("ascii") -def get_encoded_cursor(cursor: str) -> str: +def get_encoded_cursor(cursor: Optional[str]) -> Optional[str]: if cursor is None: return None return get_b64_string(cursor)