diff --git a/swh/graphql/resolvers/base_connection.py b/swh/graphql/resolvers/base_connection.py index e8f5b07..a970bd4 100644 --- a/swh/graphql/resolvers/base_connection.py +++ b/swh/graphql/resolvers/base_connection.py @@ -1,143 +1,146 @@ # Copyright (C) 2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from abc import ABC, abstractmethod import binascii from dataclasses import dataclass from typing import Any, List, Optional, Type, Union +from graphql.type import GraphQLResolveInfo + from swh.graphql.backends.archive import Archive from swh.graphql.backends.search import Search from swh.graphql.errors import PaginationError from swh.graphql.utils import utils from swh.storage.interface import PagedResult from .base_node import BaseNode @dataclass class PageInfo: hasNextPage: bool endCursor: Optional[str] @dataclass class ConnectionEdge: node: Any cursor: Optional[str] class BaseConnection(ABC): """ Base resolver for all the connections """ _node_class: Optional[Type[BaseNode]] = None _page_size: int = 50 # default page size (default value for the first arg) _max_page_size: int = 1000 # maximum page size(max value for the first arg) def __init__(self, obj, info, paged_data=None, **kwargs): - self.obj: Optional[Any] = obj - self.info = info + self.obj: Optional[BaseNode] = obj + self.info: GraphQLResolveInfo = info self.kwargs = kwargs + # initialize commonly used vars self.archive = Archive() self.search = Search() self._paged_data: PagedResult = paged_data @property def edges(self) -> List[ConnectionEdge]: """ Return the list of connection edges, each with a cursor """ return [ ConnectionEdge(node=node, cursor=self._get_index_cursor(index, node)) for (index, node) in enumerate(self.nodes) ] @property def nodes(self) -> List[Union[BaseNode, object]]: """ Override if needed; return a list of objects If a node class is set, return a list of its (Node) instances else a list of raw results """ if self._node_class is not None: return [ self._node_class( obj=self, info=self.info, node_data=result, **self.kwargs ) for result in self.get_paged_data().results ] return self.get_paged_data().results @property def pageInfo(self) -> PageInfo: # To support the schema naming convention # FIXME, add more details like startCursor return PageInfo( hasNextPage=bool(self.get_paged_data().next_page_token), endCursor=utils.get_encoded_cursor(self.get_paged_data().next_page_token), ) @property def totalCount(self) -> Optional[int]: # To support the schema naming convention """ Will be None for most of the connections override if needed/possible """ return None def get_paged_data(self) -> PagedResult: """ Cache to avoid multiple calls to the backend :meth:`_get_paged_result` return a PagedResult object """ if self._paged_data is None: # FIXME, make this call async (not for v1) self._paged_data = self._get_paged_result() return self._paged_data @abstractmethod def _get_paged_result(self): """ Override for desired behaviour return a PagedResult object """ # FIXME, make this call async (not for v1) return None - def _get_after_arg(self) -> str: + def _get_after_arg(self): """ Return the decoded next page token. Override to support a different cursor type """ # different implementation is used in SnapshotBranchConnection try: cursor = utils.get_decoded_cursor(self.kwargs.get("after")) except (UnicodeDecodeError, binascii.Error) as e: raise PaginationError("Invalid value for argument 'after'", errors=e) return cursor def _get_first_arg(self) -> int: """ """ # page_size is set to 50 by default # Input type check is not required; It is defined in schema as an int first = self.kwargs.get("first", self._page_size) if first < 0 or first > self._max_page_size: raise PaginationError( f"Value for argument 'first' is invalid; it must be between 0 and {self._max_page_size}" # noqa: B950 ) return first def _get_index_cursor(self, index: int, node: Any) -> Optional[str]: """ Get the cursor to the given item index """ # default implementation which works with swh-storage pagaination # override this function to support other types (eg: SnapshotBranchConnection) offset_index = self._get_after_arg() or "0" index_cursor = int(offset_index) + index return utils.get_encoded_cursor(str(index_cursor)) diff --git a/swh/graphql/resolvers/base_node.py b/swh/graphql/resolvers/base_node.py index 06384e1..da828b1 100644 --- a/swh/graphql/resolvers/base_node.py +++ b/swh/graphql/resolvers/base_node.py @@ -1,83 +1,87 @@ # Copyright (C) 2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information -from abc import ABC from collections import namedtuple +from typing import Any, Optional, Union +from graphql.type import GraphQLResolveInfo + +from swh.graphql import resolvers as rs from swh.graphql.backends.archive import Archive from swh.graphql.errors import ObjectNotFoundError -class BaseNode(ABC): +class BaseNode: """ Base resolver for all the nodes """ - def __init__(self, obj, info, node_data=None, **kwargs): - self.obj = obj - self.info = info + def __init__(self, obj, info, node_data: Optional[Any] = None, **kwargs): + self.obj: Optional[Union[BaseNode, rs.base_connection.BaseConnection]] = obj + self.info: GraphQLResolveInfo = info self.kwargs = kwargs + # initialize commonly used vars self.archive = Archive() - self._node = self._get_node(node_data) + self._node: Optional[Any] = self._get_node(node_data) # handle the errors, if any, after _node is set self._handle_node_errors() - def _get_node(self, node_data): + def _get_node(self, node_data: Optional[Any]) -> Optional[Any]: """ Get the node object from the given data if the data (node_data) is none make a function call to get data from backend """ if node_data is None: node_data = self._get_node_data() return self._get_node_from_data(node_data) - def _get_node_from_data(self, node_data): + def _get_node_from_data(self, node_data: Any) -> Optional[Any]: """ Get the object from node_data In case of a dict, convert it to an object Override to support different data structures """ if type(node_data) is dict: return namedtuple("NodeObj", node_data.keys())(*node_data.values()) return node_data - def _handle_node_errors(self): + def _handle_node_errors(self) -> None: """ Handle any error related to node data raise an error in case the object returned is None override for specific behaviour """ if self._node is None: raise ObjectNotFoundError("Requested object is not available") - def _get_node_data(self): + def _get_node_data(self) -> Optional[Any]: """ Override for desired behaviour This will be called only when node_data is None """ # FIXME, make this call async (not for v1) return None - def __getattr__(self, name): + def __getattr__(self, name: str) -> Any: """ Any property defined in the sub-class will get precedence over the _node attributes """ return getattr(self._node, name) - def is_type_of(self): + def is_type_of(self) -> str: return self.__class__.__name__ class BaseSWHNode(BaseNode): """ Base resolver for all the nodes with a SWHID field """ @property def swhid(self): return self._node.swhid() diff --git a/swh/graphql/resolvers/resolver_factory.py b/swh/graphql/resolvers/resolver_factory.py index b4fa41d..7ca2805 100644 --- a/swh/graphql/resolvers/resolver_factory.py +++ b/swh/graphql/resolvers/resolver_factory.py @@ -1,87 +1,87 @@ # Copyright (C) 2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from .content import ContentNode, HashContentNode, TargetContentNode from .directory import DirectoryNode, RevisionDirectoryNode, TargetDirectoryNode from .directory_entry import DirectoryEntryConnection, DirectoryEntryNode from .origin import OriginConnection, OriginNode, TargetOriginNode from .release import ReleaseNode, TargetReleaseNode from .revision import ( LogRevisionConnection, ParentRevisionConnection, RevisionNode, TargetRevisionNode, ) from .search import ResolveSwhidConnection, SearchConnection from .snapshot import ( OriginSnapshotConnection, SnapshotNode, TargetSnapshotNode, VisitSnapshotNode, ) from .snapshot_branch import AliasSnapshotBranchNode, SnapshotBranchConnection from .visit import LatestVisitNode, OriginVisitConnection, OriginVisitNode from .visit_status import LatestVisitStatusNode, VisitStatusConnection -def get_node_resolver(resolver_type): +def get_node_resolver(resolver_type: str): # FIXME, replace with a proper factory method mapping = { "origin": OriginNode, "visit": OriginVisitNode, "latest-visit": LatestVisitNode, "latest-status": LatestVisitStatusNode, "visit-snapshot": VisitSnapshotNode, "snapshot": SnapshotNode, "branch-alias": AliasSnapshotBranchNode, "branch-revision": TargetRevisionNode, "branch-release": TargetReleaseNode, "branch-directory": TargetDirectoryNode, "branch-content": TargetContentNode, "branch-snapshot": TargetSnapshotNode, "revision": RevisionNode, "revision-directory": RevisionDirectoryNode, "release": ReleaseNode, "release-revision": TargetRevisionNode, "release-release": TargetReleaseNode, "release-directory": TargetDirectoryNode, "release-content": TargetContentNode, "directory": DirectoryNode, "directory-entry": DirectoryEntryNode, "content": ContentNode, "content-by-hash": HashContentNode, "dir-entry-dir": TargetDirectoryNode, "dir-entry-file": TargetContentNode, "dir-entry-dir": TargetDirectoryNode, "dir-entry-rev": TargetRevisionNode, "search-result-origin": TargetOriginNode, "search-result-snapshot": TargetSnapshotNode, "search-result-revision": TargetRevisionNode, "search-result-release": TargetReleaseNode, "search-result-directory": TargetDirectoryNode, "search-result-content": TargetContentNode, } if resolver_type not in mapping: raise AttributeError(f"Invalid node type: {resolver_type}") return mapping[resolver_type] -def get_connection_resolver(resolver_type): +def get_connection_resolver(resolver_type: str): # FIXME, replace with a proper factory method mapping = { "origins": OriginConnection, "origin-visits": OriginVisitConnection, "origin-snapshots": OriginSnapshotConnection, "visit-status": VisitStatusConnection, "snapshot-branches": SnapshotBranchConnection, "revision-parents": ParentRevisionConnection, "revision-log": LogRevisionConnection, "directory-entries": DirectoryEntryConnection, "resolve-swhid": ResolveSwhidConnection, "search": SearchConnection, } if resolver_type not in mapping: raise AttributeError(f"Invalid connection type: {resolver_type}") return mapping[resolver_type] diff --git a/swh/graphql/resolvers/resolvers.py b/swh/graphql/resolvers/resolvers.py index a637ef9..c430e81 100644 --- a/swh/graphql/resolvers/resolvers.py +++ b/swh/graphql/resolvers/resolvers.py @@ -1,311 +1,349 @@ # Copyright (C) 2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information """ High level resolvers """ # Any schema attribute can be resolved by any of the following ways # and in the following priority order # - In this module using a decorator (eg: @visitstatus.field("snapshot") # Every object (type) is expected to resolve this way as they can accept arguments # eg: origin.visits takes arguments to paginate # - As a property in the Node object (eg: resolvers.visit.BaseVisitNode.id) # Every scalar is expected to resolve this way # - As an attribute/item in the object/dict returned by a backend (eg: Origin.url) -from typing import Optional +from typing import Optional, Union from ariadne import ObjectType, UnionType from graphql.type import GraphQLResolveInfo from swh.graphql import resolvers as rs from swh.graphql.utils import utils from .resolver_factory import get_connection_resolver, get_node_resolver query: ObjectType = ObjectType("Query") origin: ObjectType = ObjectType("Origin") visit: ObjectType = ObjectType("Visit") visit_status: ObjectType = ObjectType("VisitStatus") snapshot: ObjectType = ObjectType("Snapshot") snapshot_branch: ObjectType = ObjectType("Branch") revision: ObjectType = ObjectType("Revision") release: ObjectType = ObjectType("Release") directory: ObjectType = ObjectType("Directory") directory_entry: ObjectType = ObjectType("DirectoryEntry") search_result: ObjectType = ObjectType("SearchResult") binary_string: ObjectType = ObjectType("BinaryString") branch_target: UnionType = UnionType("BranchTarget") release_target: UnionType = UnionType("ReleaseTarget") directory_entry_target: UnionType = UnionType("DirectoryEntryTarget") search_result_target: UnionType = UnionType("SearchResultTarget") # Node resolvers # A node resolver should return an instance of BaseNode @query.field("origin") def origin_resolver(obj: None, info: GraphQLResolveInfo, **kw) -> rs.origin.OriginNode: """ """ resolver = get_node_resolver("origin") return resolver(obj, info, **kw) @origin.field("latestVisit") def latest_visit_resolver( - obj: rs.origin.OriginNode, info: GraphQLResolveInfo, **kw + obj: rs.origin.BaseOriginNode, info: GraphQLResolveInfo, **kw ) -> rs.visit.LatestVisitNode: """ """ resolver = get_node_resolver("latest-visit") return resolver(obj, info, **kw) @query.field("visit") def visit_resolver( obj: None, info: GraphQLResolveInfo, **kw ) -> rs.visit.OriginVisitNode: """ """ resolver = get_node_resolver("visit") return resolver(obj, info, **kw) @visit.field("latestStatus") def latest_visit_status_resolver( - obj, info: GraphQLResolveInfo, **kw + obj: rs.visit.BaseVisitNode, info: GraphQLResolveInfo, **kw ) -> rs.visit_status.LatestVisitStatusNode: """ """ resolver = get_node_resolver("latest-status") return resolver(obj, info, **kw) @query.field("snapshot") def snapshot_resolver( obj: None, info: GraphQLResolveInfo, **kw ) -> rs.snapshot.SnapshotNode: """ """ resolver = get_node_resolver("snapshot") return resolver(obj, info, **kw) @visit_status.field("snapshot") def visit_snapshot_resolver( obj: rs.visit_status.BaseVisitStatusNode, info: GraphQLResolveInfo, **kw ) -> Optional[rs.snapshot.VisitSnapshotNode]: if obj.snapshotSWHID is None: return None resolver = get_node_resolver("visit-snapshot") return resolver(obj, info, **kw) @snapshot_branch.field("target") def snapshot_branch_target_resolver( obj: rs.snapshot_branch.BaseSnapshotBranchNode, info: GraphQLResolveInfo, **kw -): +) -> Union[ + rs.revision.BaseRevisionNode, + rs.release.BaseReleaseNode, + rs.directory.BaseDirectoryNode, + rs.content.BaseContentNode, + rs.snapshot.BaseSnapshotNode, + rs.snapshot_branch.BaseSnapshotBranchNode, +]: """ Snapshot branch target can be a revision, release, directory, content, snapshot or a branch itself (alias type) """ resolver_type = f"branch-{obj.type}" resolver = get_node_resolver(resolver_type) return resolver(obj, info, **kw) @query.field("revision") def revision_resolver( obj: None, info: GraphQLResolveInfo, **kw ) -> rs.revision.RevisionNode: resolver = get_node_resolver("revision") return resolver(obj, info, **kw) @revision.field("directory") def revision_directory_resolver( - obj, info: GraphQLResolveInfo, **kw + obj: rs.revision.BaseRevisionNode, info: GraphQLResolveInfo, **kw ) -> rs.directory.RevisionDirectoryNode: resolver = get_node_resolver("revision-directory") return resolver(obj, info, **kw) @query.field("release") def release_resolver( obj: None, info: GraphQLResolveInfo, **kw ) -> rs.release.ReleaseNode: resolver = get_node_resolver("release") return resolver(obj, info, **kw) @release.field("target") -def release_target_resolver(obj, info: GraphQLResolveInfo, **kw): +def release_target_resolver( + obj: rs.release.BaseReleaseNode, info: GraphQLResolveInfo, **kw +) -> Union[ + rs.revision.BaseRevisionNode, + rs.release.BaseReleaseNode, + rs.directory.BaseDirectoryNode, + rs.content.BaseContentNode, +]: """ release target can be a release, revision, directory or content obj is release here, target type is obj.target_type """ resolver_type = f"release-{obj.target_type.value}" resolver = get_node_resolver(resolver_type) return resolver(obj, info, **kw) @query.field("directory") def directory_resolver( obj: None, info: GraphQLResolveInfo, **kw ) -> rs.directory.DirectoryNode: resolver = get_node_resolver("directory") return resolver(obj, info, **kw) @query.field("directoryEntry") def directory_entry_resolver( obj: None, info: GraphQLResolveInfo, **kw -) -> rs.directory.DirectoryNode: +) -> rs.directory_entry.DirectoryEntryNode: resolver = get_node_resolver("directory-entry") return resolver(obj, info, **kw) @directory_entry.field("target") def directory_entry_target_resolver( obj: rs.directory_entry.BaseDirectoryEntryNode, info: GraphQLResolveInfo, **kw -): +) -> Union[ + rs.revision.BaseRevisionNode, + rs.directory.BaseDirectoryNode, + rs.content.BaseContentNode, +]: """ - directory entry target can be a directory or a content + directory entry target can be a directory, content or a revision """ resolver_type = f"dir-entry-{obj.type}" resolver = get_node_resolver(resolver_type) return resolver(obj, info, **kw) @query.field("content") def content_resolver( obj: None, info: GraphQLResolveInfo, **kw ) -> rs.content.ContentNode: resolver = get_node_resolver("content") return resolver(obj, info, **kw) @search_result.field("target") def search_result_target_resolver( obj: rs.search.SearchResultNode, info: GraphQLResolveInfo, **kw -): +) -> Union[ + rs.origin.BaseOriginNode, + rs.snapshot.BaseSnapshotNode, + rs.revision.BaseRevisionNode, + rs.release.BaseReleaseNode, + rs.directory.BaseDirectoryNode, + rs.content.BaseContentNode, +]: resolver_type = f"search-result-{obj.type}" resolver = get_node_resolver(resolver_type) return resolver(obj, info, **kw) @query.field("contentByHash") def content_by_hash_resolver( obj: None, info: GraphQLResolveInfo, **kw ) -> rs.content.ContentNode: resolver = get_node_resolver("content-by-hash") return resolver(obj, info, **kw) # Connection resolvers # A connection resolver should return an instance of BaseConnection @query.field("origins") def origins_resolver( obj: None, info: GraphQLResolveInfo, **kw ) -> rs.origin.OriginConnection: resolver = get_connection_resolver("origins") return resolver(obj, info, **kw) @origin.field("visits") def visits_resolver( - obj: rs.origin.OriginNode, info: GraphQLResolveInfo, **kw + obj: rs.origin.BaseOriginNode, info: GraphQLResolveInfo, **kw ) -> rs.visit.OriginVisitConnection: resolver = get_connection_resolver("origin-visits") return resolver(obj, info, **kw) @origin.field("snapshots") def origin_snapshots_resolver( - obj: rs.origin.OriginNode, info: GraphQLResolveInfo, **kw + obj: rs.origin.BaseOriginNode, info: GraphQLResolveInfo, **kw ) -> rs.snapshot.OriginSnapshotConnection: """ """ resolver = get_connection_resolver("origin-snapshots") return resolver(obj, info, **kw) @visit.field("status") def visitstatus_resolver( - obj, info: GraphQLResolveInfo, **kw + obj: rs.visit.BaseVisitNode, info: GraphQLResolveInfo, **kw ) -> rs.visit_status.VisitStatusConnection: resolver = get_connection_resolver("visit-status") return resolver(obj, info, **kw) @snapshot.field("branches") def snapshot_branches_resolver( - obj, info: GraphQLResolveInfo, **kw + obj: rs.snapshot.BaseSnapshotNode, info: GraphQLResolveInfo, **kw ) -> rs.snapshot_branch.SnapshotBranchConnection: resolver = get_connection_resolver("snapshot-branches") return resolver(obj, info, **kw) @revision.field("parents") def revision_parents_resolver( - obj, info: GraphQLResolveInfo, **kw + obj: rs.revision.BaseRevisionNode, info: GraphQLResolveInfo, **kw ) -> rs.revision.ParentRevisionConnection: resolver = get_connection_resolver("revision-parents") return resolver(obj, info, **kw) @revision.field("revisionLog") -def revision_log_resolver(obj, info, **kw): +def revision_log_resolver( + obj: rs.revision.BaseRevisionNode, info: GraphQLResolveInfo, **kw +) -> rs.revision.LogRevisionConnection: resolver = get_connection_resolver("revision-log") return resolver(obj, info, **kw) @directory.field("entries") def directory_entries_resolver( - obj, info: GraphQLResolveInfo, **kw + obj: rs.directory.BaseDirectoryNode, info: GraphQLResolveInfo, **kw ) -> rs.directory_entry.DirectoryEntryConnection: resolver = get_connection_resolver("directory-entries") return resolver(obj, info, **kw) @query.field("resolveSwhid") def search_swhid_resolver( - obj, info: GraphQLResolveInfo, **kw + obj: None, info: GraphQLResolveInfo, **kw ) -> rs.search.ResolveSwhidConnection: resolver = get_connection_resolver("resolve-swhid") return resolver(obj, info, **kw) @query.field("search") def search_resolver( - obj, info: GraphQLResolveInfo, **kw -) -> rs.search.ResolveSwhidConnection: + obj: None, info: GraphQLResolveInfo, **kw +) -> rs.search.SearchConnection: resolver = get_connection_resolver("search") return resolver(obj, info, **kw) # Any other type of resolver @release_target.type_resolver @directory_entry_target.type_resolver @branch_target.type_resolver @search_result_target.type_resolver -def union_resolver(obj, *_) -> str: +def union_resolver( + obj: Union[ + rs.origin.BaseOriginNode, + rs.revision.BaseRevisionNode, + rs.release.BaseReleaseNode, + rs.directory.BaseDirectoryNode, + rs.content.BaseContentNode, + rs.snapshot.BaseSnapshotNode, + rs.snapshot_branch.BaseSnapshotBranchNode, + ], + *_, +) -> str: """ Generic resolver for all the union types """ return obj.is_type_of() @binary_string.field("text") -def binary_string_text_resolver(obj, *args, **kw): +def binary_string_text_resolver(obj: bytes, *args, **kw) -> str: return obj.decode(utils.ENCODING, "replace") @binary_string.field("base64") -def binary_string_base64_resolver(obj, *args, **kw): +def binary_string_base64_resolver(obj: bytes, *args, **kw) -> str: return utils.get_b64_string(obj) diff --git a/swh/graphql/utils/utils.py b/swh/graphql/utils/utils.py index f22be76..5e7f2e5 100644 --- a/swh/graphql/utils/utils.py +++ b/swh/graphql/utils/utils.py @@ -1,53 +1,53 @@ # Copyright (C) 2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import base64 from datetime import datetime from typing import List, Optional from swh.storage.interface import PagedResult ENCODING = "utf-8" def get_b64_string(source) -> str: if type(source) is str: source = source.encode(ENCODING) return base64.b64encode(source).decode("ascii") def get_encoded_cursor(cursor: Optional[str]) -> Optional[str]: if cursor is None: return None return get_b64_string(cursor) -def get_decoded_cursor(cursor: str) -> str: +def get_decoded_cursor(cursor: Optional[str]) -> Optional[str]: if cursor is None: return None return base64.b64decode(cursor, validate=True).decode() def get_formatted_date(date: datetime) -> str: # FIXME, handle error + return other formats return date.isoformat() def paginated(source: List, first: int, after=0) -> PagedResult: """ Pagination at the GraphQL level This is a temporary fix and inefficient. Should eventually be moved to the backend (storage) level """ # FIXME, handle data errors here after = 0 if after is None else int(after) end_cursor = after + first results = source[after:end_cursor] next_page_token = None if len(source) > end_cursor: next_page_token = str(end_cursor) return PagedResult(results=results, next_page_token=next_page_token)