diff --git a/swh/graphql/resolvers/base_connection.py b/swh/graphql/resolvers/base_connection.py --- a/swh/graphql/resolvers/base_connection.py +++ b/swh/graphql/resolvers/base_connection.py @@ -8,6 +8,8 @@ from dataclasses import dataclass from typing import Any, List, Optional, Type, Union +from graphql.type import GraphQLResolveInfo + from swh.graphql.backends.archive import Archive from swh.graphql.backends.search import Search from swh.graphql.errors import PaginationError @@ -39,9 +41,10 @@ _max_page_size: int = 1000 # maximum page size(max value for the first arg) def __init__(self, obj, info, paged_data=None, **kwargs): - self.obj: Optional[Any] = obj - self.info = info + self.obj: Optional[BaseNode] = obj + self.info: GraphQLResolveInfo = info self.kwargs = kwargs + # initialize commonly used vars self.archive = Archive() self.search = Search() self._paged_data: PagedResult = paged_data @@ -109,7 +112,7 @@ # FIXME, make this call async (not for v1) return None - def _get_after_arg(self) -> str: + def _get_after_arg(self): """ Return the decoded next page token. Override to support a different cursor type diff --git a/swh/graphql/resolvers/base_node.py b/swh/graphql/resolvers/base_node.py --- a/swh/graphql/resolvers/base_node.py +++ b/swh/graphql/resolvers/base_node.py @@ -3,28 +3,32 @@ # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information -from abc import ABC from collections import namedtuple +from typing import Any, Optional, Union +from graphql.type import GraphQLResolveInfo + +from swh.graphql import resolvers as rs from swh.graphql.backends.archive import Archive from swh.graphql.errors import ObjectNotFoundError -class BaseNode(ABC): +class BaseNode: """ Base resolver for all the nodes """ - def __init__(self, obj, info, node_data=None, **kwargs): - self.obj = obj - self.info = info + def __init__(self, obj, info, node_data: Optional[Any] = None, **kwargs): + self.obj: Optional[Union[BaseNode, rs.base_connection.BaseConnection]] = obj + self.info: GraphQLResolveInfo = info self.kwargs = kwargs + # initialize commonly used vars self.archive = Archive() - self._node = self._get_node(node_data) + self._node: Optional[Any] = self._get_node(node_data) # handle the errors, if any, after _node is set self._handle_node_errors() - def _get_node(self, node_data): + def _get_node(self, node_data: Optional[Any]) -> Optional[Any]: """ Get the node object from the given data if the data (node_data) is none make a function call @@ -34,7 +38,7 @@ node_data = self._get_node_data() return self._get_node_from_data(node_data) - def _get_node_from_data(self, node_data): + def _get_node_from_data(self, node_data: Any) -> Optional[Any]: """ Get the object from node_data In case of a dict, convert it to an object @@ -44,7 +48,7 @@ return namedtuple("NodeObj", node_data.keys())(*node_data.values()) return node_data - def _handle_node_errors(self): + def _handle_node_errors(self) -> None: """ Handle any error related to node data @@ -54,7 +58,7 @@ if self._node is None: raise ObjectNotFoundError("Requested object is not available") - def _get_node_data(self): + def _get_node_data(self) -> Optional[Any]: """ Override for desired behaviour This will be called only when node_data is None @@ -62,14 +66,14 @@ # FIXME, make this call async (not for v1) return None - def __getattr__(self, name): + def __getattr__(self, name: str) -> Any: """ Any property defined in the sub-class will get precedence over the _node attributes """ return getattr(self._node, name) - def is_type_of(self): + def is_type_of(self) -> str: return self.__class__.__name__ diff --git a/swh/graphql/resolvers/resolver_factory.py b/swh/graphql/resolvers/resolver_factory.py --- a/swh/graphql/resolvers/resolver_factory.py +++ b/swh/graphql/resolvers/resolver_factory.py @@ -26,7 +26,7 @@ from .visit_status import LatestVisitStatusNode, VisitStatusConnection -def get_node_resolver(resolver_type): +def get_node_resolver(resolver_type: str): # FIXME, replace with a proper factory method mapping = { "origin": OriginNode, @@ -68,7 +68,7 @@ return mapping[resolver_type] -def get_connection_resolver(resolver_type): +def get_connection_resolver(resolver_type: str): # FIXME, replace with a proper factory method mapping = { "origins": OriginConnection, diff --git a/swh/graphql/resolvers/resolvers.py b/swh/graphql/resolvers/resolvers.py --- a/swh/graphql/resolvers/resolvers.py +++ b/swh/graphql/resolvers/resolvers.py @@ -16,7 +16,7 @@ # Every scalar is expected to resolve this way # - As an attribute/item in the object/dict returned by a backend (eg: Origin.url) -from typing import Optional +from typing import Optional, Union from ariadne import ObjectType, UnionType from graphql.type import GraphQLResolveInfo @@ -57,7 +57,7 @@ @origin.field("latestVisit") def latest_visit_resolver( - obj: rs.origin.OriginNode, info: GraphQLResolveInfo, **kw + obj: rs.origin.BaseOriginNode, info: GraphQLResolveInfo, **kw ) -> rs.visit.LatestVisitNode: """ """ resolver = get_node_resolver("latest-visit") @@ -75,7 +75,7 @@ @visit.field("latestStatus") def latest_visit_status_resolver( - obj, info: GraphQLResolveInfo, **kw + obj: rs.visit.BaseVisitNode, info: GraphQLResolveInfo, **kw ) -> rs.visit_status.LatestVisitStatusNode: """ """ resolver = get_node_resolver("latest-status") @@ -104,7 +104,14 @@ @snapshot_branch.field("target") def snapshot_branch_target_resolver( obj: rs.snapshot_branch.BaseSnapshotBranchNode, info: GraphQLResolveInfo, **kw -): +) -> Union[ + rs.revision.BaseRevisionNode, + rs.release.BaseReleaseNode, + rs.directory.BaseDirectoryNode, + rs.content.BaseContentNode, + rs.snapshot.BaseSnapshotNode, + rs.snapshot_branch.BaseSnapshotBranchNode, +]: """ Snapshot branch target can be a revision, release, directory, content, snapshot or a branch itself (alias type) @@ -124,7 +131,7 @@ @revision.field("directory") def revision_directory_resolver( - obj, info: GraphQLResolveInfo, **kw + obj: rs.revision.BaseRevisionNode, info: GraphQLResolveInfo, **kw ) -> rs.directory.RevisionDirectoryNode: resolver = get_node_resolver("revision-directory") return resolver(obj, info, **kw) @@ -139,7 +146,14 @@ @release.field("target") -def release_target_resolver(obj, info: GraphQLResolveInfo, **kw): +def release_target_resolver( + obj: rs.release.BaseReleaseNode, info: GraphQLResolveInfo, **kw +) -> Union[ + rs.revision.BaseRevisionNode, + rs.release.BaseReleaseNode, + rs.directory.BaseDirectoryNode, + rs.content.BaseContentNode, +]: """ release target can be a release, revision, directory or content @@ -162,7 +176,7 @@ @query.field("directoryEntry") def directory_entry_resolver( obj: None, info: GraphQLResolveInfo, **kw -) -> rs.directory.DirectoryNode: +) -> rs.directory_entry.DirectoryEntryNode: resolver = get_node_resolver("directory-entry") return resolver(obj, info, **kw) @@ -170,9 +184,13 @@ @directory_entry.field("target") def directory_entry_target_resolver( obj: rs.directory_entry.BaseDirectoryEntryNode, info: GraphQLResolveInfo, **kw -): +) -> Union[ + rs.revision.BaseRevisionNode, + rs.directory.BaseDirectoryNode, + rs.content.BaseContentNode, +]: """ - directory entry target can be a directory or a content + directory entry target can be a directory, content or a revision """ resolver_type = f"dir-entry-{obj.type}" resolver = get_node_resolver(resolver_type) @@ -190,7 +208,14 @@ @search_result.field("target") def search_result_target_resolver( obj: rs.search.SearchResultNode, info: GraphQLResolveInfo, **kw -): +) -> Union[ + rs.origin.BaseOriginNode, + rs.snapshot.BaseSnapshotNode, + rs.revision.BaseRevisionNode, + rs.release.BaseReleaseNode, + rs.directory.BaseDirectoryNode, + rs.content.BaseContentNode, +]: resolver_type = f"search-result-{obj.type}" resolver = get_node_resolver(resolver_type) return resolver(obj, info, **kw) @@ -218,7 +243,7 @@ @origin.field("visits") def visits_resolver( - obj: rs.origin.OriginNode, info: GraphQLResolveInfo, **kw + obj: rs.origin.BaseOriginNode, info: GraphQLResolveInfo, **kw ) -> rs.visit.OriginVisitConnection: resolver = get_connection_resolver("origin-visits") return resolver(obj, info, **kw) @@ -226,7 +251,7 @@ @origin.field("snapshots") def origin_snapshots_resolver( - obj: rs.origin.OriginNode, info: GraphQLResolveInfo, **kw + obj: rs.origin.BaseOriginNode, info: GraphQLResolveInfo, **kw ) -> rs.snapshot.OriginSnapshotConnection: """ """ resolver = get_connection_resolver("origin-snapshots") @@ -235,7 +260,7 @@ @visit.field("status") def visitstatus_resolver( - obj, info: GraphQLResolveInfo, **kw + obj: rs.visit.BaseVisitNode, info: GraphQLResolveInfo, **kw ) -> rs.visit_status.VisitStatusConnection: resolver = get_connection_resolver("visit-status") return resolver(obj, info, **kw) @@ -243,7 +268,7 @@ @snapshot.field("branches") def snapshot_branches_resolver( - obj, info: GraphQLResolveInfo, **kw + obj: rs.snapshot.BaseSnapshotNode, info: GraphQLResolveInfo, **kw ) -> rs.snapshot_branch.SnapshotBranchConnection: resolver = get_connection_resolver("snapshot-branches") return resolver(obj, info, **kw) @@ -251,21 +276,23 @@ @revision.field("parents") def revision_parents_resolver( - obj, info: GraphQLResolveInfo, **kw + obj: rs.revision.BaseRevisionNode, info: GraphQLResolveInfo, **kw ) -> rs.revision.ParentRevisionConnection: resolver = get_connection_resolver("revision-parents") return resolver(obj, info, **kw) @revision.field("revisionLog") -def revision_log_resolver(obj, info, **kw): +def revision_log_resolver( + obj: rs.revision.BaseRevisionNode, info: GraphQLResolveInfo, **kw +) -> rs.revision.LogRevisionConnection: resolver = get_connection_resolver("revision-log") return resolver(obj, info, **kw) @directory.field("entries") def directory_entries_resolver( - obj, info: GraphQLResolveInfo, **kw + obj: rs.directory.BaseDirectoryNode, info: GraphQLResolveInfo, **kw ) -> rs.directory_entry.DirectoryEntryConnection: resolver = get_connection_resolver("directory-entries") return resolver(obj, info, **kw) @@ -273,7 +300,7 @@ @query.field("resolveSwhid") def search_swhid_resolver( - obj, info: GraphQLResolveInfo, **kw + obj: None, info: GraphQLResolveInfo, **kw ) -> rs.search.ResolveSwhidConnection: resolver = get_connection_resolver("resolve-swhid") return resolver(obj, info, **kw) @@ -281,8 +308,8 @@ @query.field("search") def search_resolver( - obj, info: GraphQLResolveInfo, **kw -) -> rs.search.ResolveSwhidConnection: + obj: None, info: GraphQLResolveInfo, **kw +) -> rs.search.SearchConnection: resolver = get_connection_resolver("search") return resolver(obj, info, **kw) @@ -294,7 +321,18 @@ @directory_entry_target.type_resolver @branch_target.type_resolver @search_result_target.type_resolver -def union_resolver(obj, *_) -> str: +def union_resolver( + obj: Union[ + rs.origin.BaseOriginNode, + rs.revision.BaseRevisionNode, + rs.release.BaseReleaseNode, + rs.directory.BaseDirectoryNode, + rs.content.BaseContentNode, + rs.snapshot.BaseSnapshotNode, + rs.snapshot_branch.BaseSnapshotBranchNode, + ], + *_, +) -> str: """ Generic resolver for all the union types """ @@ -302,10 +340,10 @@ @binary_string.field("text") -def binary_string_text_resolver(obj, *args, **kw): +def binary_string_text_resolver(obj: bytes, *args, **kw) -> str: return obj.decode(utils.ENCODING, "replace") @binary_string.field("base64") -def binary_string_base64_resolver(obj, *args, **kw): +def binary_string_base64_resolver(obj: bytes, *args, **kw) -> str: return utils.get_b64_string(obj) diff --git a/swh/graphql/utils/utils.py b/swh/graphql/utils/utils.py --- a/swh/graphql/utils/utils.py +++ b/swh/graphql/utils/utils.py @@ -24,7 +24,7 @@ return get_b64_string(cursor) -def get_decoded_cursor(cursor: str) -> str: +def get_decoded_cursor(cursor: Optional[str]) -> Optional[str]: if cursor is None: return None return base64.b64decode(cursor, validate=True).decode()