diff --git a/swh/graphql/resolvers/base_connection.py b/swh/graphql/resolvers/base_connection.py index fee50f1..fdda631 100644 --- a/swh/graphql/resolvers/base_connection.py +++ b/swh/graphql/resolvers/base_connection.py @@ -1,144 +1,139 @@ # Copyright (C) 2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from abc import ABC, abstractmethod import binascii from dataclasses import dataclass -from typing import Any, Optional, Type +from typing import Any, List, Optional, Type, Union from swh.graphql.errors import PaginationError from swh.graphql.utils import utils +from swh.storage.interface import PagedResult from .base_node import BaseNode @dataclass class PageInfo: hasNextPage: bool - endCursor: str + endCursor: Optional[str] @dataclass class ConnectionEdge: node: Any - cursor: str + cursor: Optional[str] class BaseConnection(ABC): """ Base resolver for all the connections """ _node_class: Optional[Type[BaseNode]] = None _page_size: int = 50 # default page size (default value for the first arg) _max_page_size: int = 1000 # maximum page size(max value for the first arg) def __init__(self, obj, info, paged_data=None, **kwargs): - self.obj = obj + self.obj: Optional[Any] = obj self.info = info self.kwargs = kwargs - self._paged_data = paged_data + self._paged_data: PagedResult = paged_data @property - def edges(self): - return self._get_edges() + def edges(self) -> List[ConnectionEdge]: + """ + Return the list of connection edges, each with a cursor + """ + return [ + ConnectionEdge(node=node, cursor=self._get_index_cursor(index, node)) + for (index, node) in enumerate(self.nodes) + ] @property - def nodes(self): + def nodes(self) -> List[Union[BaseNode, object]]: """ Override if needed; return a list of objects If a node class is set, return a list of its (Node) instances else a list of raw results """ if self._node_class is not None: return [ self._node_class( obj=self, info=self.info, node_data=result, **self.kwargs ) for result in self.get_paged_data().results ] return self.get_paged_data().results @property - def pageInfo(self): # To support the schema naming convention + def pageInfo(self) -> PageInfo: # To support the schema naming convention # FIXME, add more details like startCursor return PageInfo( hasNextPage=bool(self.get_paged_data().next_page_token), endCursor=utils.get_encoded_cursor(self.get_paged_data().next_page_token), ) @property - def totalCount(self): # To support the schema naming convention - return self._get_total_count() - - def _get_total_count(self): + def totalCount(self) -> Optional[int]: # To support the schema naming convention """ Will be None for most of the connections override if needed/possible """ + return None - def get_paged_data(self): + def get_paged_data(self) -> PagedResult: """ - Cache to avoid multiple calls to - the backend (_get_paged_result) + Cache to avoid multiple calls to the backend :meth:`_get_paged_result` return a PagedResult object """ if self._paged_data is None: # FIXME, make this call async (not for v1) self._paged_data = self._get_paged_result() return self._paged_data @abstractmethod def _get_paged_result(self): """ Override for desired behaviour return a PagedResult object """ # FIXME, make this call async (not for v1) return None - def _get_edges(self): - """ - Return the list of connection edges, each with a cursor - """ - return [ - ConnectionEdge(node=node, cursor=self._get_index_cursor(index, node)) - for (index, node) in enumerate(self.nodes) - ] - def _get_after_arg(self) -> str: """ Return the decoded next page token. Override to support a different cursor type """ # different implementation is used in SnapshotBranchConnection try: cursor = utils.get_decoded_cursor(self.kwargs.get("after")) except (UnicodeDecodeError, binascii.Error) as e: raise PaginationError("Invalid value for argument 'after'", errors=e) return cursor def _get_first_arg(self) -> int: """ """ # page_size is set to 50 by default # Input type check is not required; It is defined in schema as an int first = self.kwargs.get("first", self._page_size) if first < 0 or first > self._max_page_size: raise PaginationError( f"Value for argument 'first' is invalid; it must be between 0 and {self._max_page_size}" # noqa: B950 ) return first - def _get_index_cursor(self, index: int, node: Any): + def _get_index_cursor(self, index: int, node: Any) -> Optional[str]: """ Get the cursor to the given item index """ # default implementation which works with swh-storage pagaination # override this function to support other types (eg: SnapshotBranchConnection) offset_index = self._get_after_arg() or "0" index_cursor = int(offset_index) + index return utils.get_encoded_cursor(str(index_cursor)) diff --git a/swh/graphql/resolvers/directory_entry.py b/swh/graphql/resolvers/directory_entry.py index c99da81..6fd30c1 100644 --- a/swh/graphql/resolvers/directory_entry.py +++ b/swh/graphql/resolvers/directory_entry.py @@ -1,41 +1,42 @@ # Copyright (C) 2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from swh.graphql.backends import archive from swh.graphql.utils import utils +from swh.storage.interface import PagedResult from .base_connection import BaseConnection from .base_node import BaseNode class DirectoryEntryNode(BaseNode): """ Node resolver for a directory entry """ @property def target_hash(self): # for DirectoryNode return self._node.target class DirectoryEntryConnection(BaseConnection): """ Connection resolver for entries in a directory """ from .directory import BaseDirectoryNode obj: BaseDirectoryNode _node_class = DirectoryEntryNode - def _get_paged_result(self): + def _get_paged_result(self) -> PagedResult: # FIXME, using dummy(local) pagination, move pagination to backend # To remove localpagination, just drop the paginated call # STORAGE-TODO entries = ( archive.Archive().get_directory_entries(self.obj.swhid.object_id).results ) return utils.paginated(entries, self._get_first_arg(), self._get_after_arg()) diff --git a/swh/graphql/resolvers/origin.py b/swh/graphql/resolvers/origin.py index ee67db1..9ab2f27 100644 --- a/swh/graphql/resolvers/origin.py +++ b/swh/graphql/resolvers/origin.py @@ -1,33 +1,34 @@ # Copyright (C) 2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from swh.graphql.backends import archive +from swh.storage.interface import PagedResult from .base_connection import BaseConnection from .base_node import BaseSWHNode class OriginNode(BaseSWHNode): """ Node resolver for an origin requested directly with its URL """ def _get_node_data(self): return archive.Archive().get_origin(self.kwargs.get("url")) class OriginConnection(BaseConnection): """ Connection resolver for the origins """ _node_class = OriginNode - def _get_paged_result(self): + def _get_paged_result(self) -> PagedResult: return archive.Archive().get_origins( after=self._get_after_arg(), first=self._get_first_arg(), url_pattern=self.kwargs.get("urlPattern"), ) diff --git a/swh/graphql/resolvers/revision.py b/swh/graphql/resolvers/revision.py index 1ccfb42..7272607 100644 --- a/swh/graphql/resolvers/revision.py +++ b/swh/graphql/resolvers/revision.py @@ -1,107 +1,108 @@ # Copyright (C) 2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from typing import Union from swh.graphql.backends import archive from swh.graphql.utils import utils from swh.model.swhids import CoreSWHID, ObjectType +from swh.storage.interface import PagedResult from .base_connection import BaseConnection from .base_node import BaseSWHNode from .release import BaseReleaseNode from .snapshot_branch import SnapshotBranchNode class BaseRevisionNode(BaseSWHNode): """ Base resolver for all the revision nodes """ def _get_revision_by_id(self, revision_id): return (archive.Archive().get_revisions([revision_id]) or None)[0] @property def parent_swhids(self): # for ParentRevisionConnection resolver return [ CoreSWHID(object_type=ObjectType.REVISION, object_id=parent_id) for parent_id in self._node.parents ] @property def directory_swhid(self): # for RevisionDirectoryNode resolver return CoreSWHID( object_type=ObjectType.DIRECTORY, object_id=self._node.directory ) @property def type(self): return self._node.type.value def is_type_of(self): # is_type_of is required only when resolving a UNION type # This is for ariadne to return the right type return "Revision" class RevisionNode(BaseRevisionNode): """ Node resolver for a revision requested directly with its SWHID """ def _get_node_data(self): return self._get_revision_by_id(self.kwargs.get("swhid").object_id) class TargetRevisionNode(BaseRevisionNode): """ Node resolver for a revision requested as a target """ obj: Union[SnapshotBranchNode, BaseReleaseNode] def _get_node_data(self): # self.obj.target_hash is the requested revision id return self._get_revision_by_id(self.obj.target_hash) class ParentRevisionConnection(BaseConnection): """ Connection resolver for parent revisions in a revision """ obj: BaseRevisionNode _node_class = BaseRevisionNode - def _get_paged_result(self): + def _get_paged_result(self) -> PagedResult: # self.obj is the current(child) revision # self.obj.parent_swhids is the list of parent SWHIDs # FIXME, using dummy(local) pagination, move pagination to backend # To remove localpagination, just drop the paginated call # STORAGE-TODO (pagination) parents = archive.Archive().get_revisions( [x.object_id for x in self.obj.parent_swhids] ) return utils.paginated(parents, self._get_first_arg(), self._get_after_arg()) class LogRevisionConnection(BaseConnection): """ Connection resolver for the log (list of revisions) in a revision """ obj: BaseRevisionNode _node_class = BaseRevisionNode - def _get_paged_result(self): + def _get_paged_result(self) -> PagedResult: # STORAGE-TODO (date in revisionlog is a dict) log = archive.Archive().get_revision_log([self.obj.swhid.object_id]) # FIXME, using dummy(local) pagination, move pagination to backend # To remove localpagination, just drop the paginated call # STORAGE-TODO (pagination) return utils.paginated(log, self._get_first_arg(), self._get_after_arg()) diff --git a/swh/graphql/resolvers/snapshot.py b/swh/graphql/resolvers/snapshot.py index cda3305..89f75e0 100644 --- a/swh/graphql/resolvers/snapshot.py +++ b/swh/graphql/resolvers/snapshot.py @@ -1,90 +1,91 @@ # Copyright (C) 2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from typing import Union from swh.graphql.backends import archive from swh.graphql.utils import utils from swh.model.model import Snapshot +from swh.storage.interface import PagedResult from .base_connection import BaseConnection from .base_node import BaseSWHNode from .origin import OriginNode from .visit_status import BaseVisitStatusNode class BaseSnapshotNode(BaseSWHNode): """ Base resolver for all the snapshot nodes """ def _get_snapshot_by_id(self, snapshot_id): # Return a Snapshot model object # branches is initialized as empty # Same pattern is used in directory return Snapshot(id=snapshot_id, branches={}) def is_type_of(self): # is_type_of is required only when resolving a UNION type # This is for ariadne to return the right type return "Snapshot" class SnapshotNode(BaseSnapshotNode): """ Node resolver for a snapshot requested directly with its SWHID """ def _get_node_data(self): """ """ snapshot_id = self.kwargs.get("swhid").object_id if archive.Archive().is_snapshot_available([snapshot_id]): return self._get_snapshot_by_id(snapshot_id) return None class VisitSnapshotNode(BaseSnapshotNode): """ Node resolver for a snapshot requested from a visit-status """ obj: BaseVisitStatusNode def _get_node_data(self): # self.obj.snapshotSWHID is the requested snapshot SWHID snapshot_id = self.obj.snapshotSWHID.object_id return self._get_snapshot_by_id(snapshot_id) class TargetSnapshotNode(BaseSnapshotNode): """ Node resolver for a snapshot requested as a target """ from .snapshot_branch import SnapshotBranchNode obj: Union[BaseVisitStatusNode, SnapshotBranchNode] def _get_node_data(self): snapshot_id = self.obj.target_hash return self._get_snapshot_by_id(snapshot_id) class OriginSnapshotConnection(BaseConnection): """ Connection resolver for the snapshots in an origin """ obj: OriginNode _node_class = BaseSnapshotNode - def _get_paged_result(self): + def _get_paged_result(self) -> PagedResult: results = archive.Archive().get_origin_snapshots(self.obj.url) snapshots = [Snapshot(id=snapshot, branches={}) for snapshot in results] # FIXME, using dummy(local) pagination, move pagination to backend # To remove localpagination, just drop the paginated call # STORAGE-TODO return utils.paginated(snapshots, self._get_first_arg(), self._get_after_arg()) diff --git a/swh/graphql/resolvers/snapshot_branch.py b/swh/graphql/resolvers/snapshot_branch.py index f1b9a79..912b2c0 100644 --- a/swh/graphql/resolvers/snapshot_branch.py +++ b/swh/graphql/resolvers/snapshot_branch.py @@ -1,84 +1,84 @@ # Copyright (C) 2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from collections import namedtuple from swh.graphql.backends import archive from swh.graphql.utils import utils from swh.storage.interface import PagedResult from .base_connection import BaseConnection from .base_node import BaseNode class SnapshotBranchNode(BaseNode): """ Node resolver for a snapshot branch """ # target field for this Node is a UNION type # It is resolved in the top level (resolvers.resolvers.py) def _get_node_from_data(self, node_data): # node_data is not a dict in this case # overriding to support this special data structure # STORAGE-TODO; return an object in the normal format branch_name, branch_obj = node_data node = { "name": branch_name, "type": branch_obj.target_type.value, "target": branch_obj.target, } return namedtuple("NodeObj", node.keys())(*node.values()) @property def target_hash(self): return self._node.target class SnapshotBranchConnection(BaseConnection): """ Connection resolver for the branches in a snapshot """ from .snapshot import SnapshotNode obj: SnapshotNode _node_class = SnapshotBranchNode - def _get_paged_result(self): + def _get_paged_result(self) -> PagedResult: result = archive.Archive().get_snapshot_branches( self.obj.swhid.object_id, after=self._get_after_arg(), first=self._get_first_arg(), target_types=self.kwargs.get("types"), name_include=self._get_name_include_arg(), ) # endCursor is the last branch name, logic for that end_cusrsor = ( result["next_branch"] if result["next_branch"] is not None else None ) # FIXME, this pagination is not consistent with other connections # FIX in swh-storage to return PagedResult # STORAGE-TODO return PagedResult( results=result["branches"].items(), next_page_token=end_cusrsor ) def _get_after_arg(self): # after argument must be an empty string by default after = super()._get_after_arg() return after.encode() if after else b"" def _get_name_include_arg(self): name_include = self.kwargs.get("nameInclude", None) return name_include.encode() if name_include else None def _get_index_cursor(self, index: int, node: SnapshotBranchNode): # Snapshot branch is using a different cursor, hence the override return utils.get_encoded_cursor(node.name) diff --git a/swh/graphql/resolvers/visit.py b/swh/graphql/resolvers/visit.py index e0cb7fb..9d44d0a 100644 --- a/swh/graphql/resolvers/visit.py +++ b/swh/graphql/resolvers/visit.py @@ -1,66 +1,67 @@ # Copyright (C) 2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from swh.graphql.backends import archive from swh.graphql.utils import utils +from swh.storage.interface import PagedResult from .base_connection import BaseConnection from .base_node import BaseNode from .origin import OriginNode class BaseVisitNode(BaseNode): """ Base resolver for all the visit nodes """ @property def id(self): # FIXME, use a better id return utils.get_b64_string(f"{self.origin}-{str(self.visit)}") @property def visitId(self): # To support the schema naming convention return self._node.visit class OriginVisitNode(BaseVisitNode): """ Node resolver for a visit requested directly with an origin URL and a visit ID """ def _get_node_data(self): return archive.Archive().get_origin_visit( self.kwargs.get("originUrl"), int(self.kwargs.get("visitId")) ) class LatestVisitNode(BaseVisitNode): """ Node resolver for the latest visit in an origin """ obj: OriginNode def _get_node_data(self): # self.obj.url is the origin URL return archive.Archive().get_origin_latest_visit(self.obj.url) class OriginVisitConnection(BaseConnection): """ Connection resolver for the visit objects in an origin """ obj: OriginNode _node_class = BaseVisitNode - def _get_paged_result(self): + def _get_paged_result(self) -> PagedResult: # self.obj.url is the origin URL return archive.Archive().get_origin_visits( self.obj.url, after=self._get_after_arg(), first=self._get_first_arg() ) diff --git a/swh/graphql/resolvers/visit_status.py b/swh/graphql/resolvers/visit_status.py index 89c95e1..268b018 100644 --- a/swh/graphql/resolvers/visit_status.py +++ b/swh/graphql/resolvers/visit_status.py @@ -1,53 +1,54 @@ # Copyright (C) 2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from swh.graphql.backends import archive from swh.model.swhids import CoreSWHID, ObjectType +from swh.storage.interface import PagedResult from .base_connection import BaseConnection from .base_node import BaseNode from .visit import BaseVisitNode class BaseVisitStatusNode(BaseNode): """ Base resolver for all the visit-status nodes """ @property def snapshotSWHID(self): # To support the schema naming convention return CoreSWHID(object_type=ObjectType.SNAPSHOT, object_id=self._node.snapshot) class LatestVisitStatusNode(BaseVisitStatusNode): """ Node resolver for a visit-status requested from a visit """ obj: BaseVisitNode def _get_node_data(self): # self.obj.origin is the origin URL return archive.Archive().get_latest_visit_status( self.obj.origin, self.obj.visitId ) class VisitStatusConnection(BaseConnection): """ Connection resolver for the visit-status objects in a visit """ obj: BaseVisitNode _node_class = BaseVisitStatusNode - def _get_paged_result(self): + def _get_paged_result(self) -> PagedResult: # self.obj.origin is the origin URL return archive.Archive().get_visit_status( self.obj.origin, self.obj.visitId, after=self._get_after_arg(), first=self._get_first_arg(), ) diff --git a/swh/graphql/utils/utils.py b/swh/graphql/utils/utils.py index bd9e928..f22be76 100644 --- a/swh/graphql/utils/utils.py +++ b/swh/graphql/utils/utils.py @@ -1,53 +1,53 @@ # Copyright (C) 2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import base64 from datetime import datetime -from typing import List +from typing import List, Optional from swh.storage.interface import PagedResult ENCODING = "utf-8" def get_b64_string(source) -> str: if type(source) is str: source = source.encode(ENCODING) return base64.b64encode(source).decode("ascii") -def get_encoded_cursor(cursor: str) -> str: +def get_encoded_cursor(cursor: Optional[str]) -> Optional[str]: if cursor is None: return None return get_b64_string(cursor) def get_decoded_cursor(cursor: str) -> str: if cursor is None: return None return base64.b64decode(cursor, validate=True).decode() def get_formatted_date(date: datetime) -> str: # FIXME, handle error + return other formats return date.isoformat() def paginated(source: List, first: int, after=0) -> PagedResult: """ Pagination at the GraphQL level This is a temporary fix and inefficient. Should eventually be moved to the backend (storage) level """ # FIXME, handle data errors here after = 0 if after is None else int(after) end_cursor = after + first results = source[after:end_cursor] next_page_token = None if len(source) > end_cursor: next_page_token = str(end_cursor) return PagedResult(results=results, next_page_token=next_page_token)