diff --git a/swh/graphql/resolvers/base_connection.py b/swh/graphql/resolvers/base_connection.py index b8df117..beac11d 100644 --- a/swh/graphql/resolvers/base_connection.py +++ b/swh/graphql/resolvers/base_connection.py @@ -1,185 +1,184 @@ # Copyright (C) 2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from abc import ABC, abstractmethod import binascii from dataclasses import dataclass from typing import Any, List, Optional, Type, Union from graphql.type import GraphQLResolveInfo from swh.graphql.backends.archive import Archive from swh.graphql.backends.search import Search from swh.graphql.errors import PaginationError from swh.graphql.utils import utils from swh.storage.interface import PagedResult from .base_node import BaseNode @dataclass class PageInfo: hasNextPage: bool endCursor: Optional[str] @dataclass class ConnectionEdge: node: Any cursor: Optional[str] class BaseConnection(ABC): """ Base resolver for all the connections """ _node_class: Optional[Type[BaseNode]] = None _page_size: int = 50 # default page size (default value for the first arg) _max_page_size: int = 1000 # maximum page size(max value for the first arg) def __init__(self, obj, info, paged_data=None, **kwargs): self.obj: Optional[BaseNode] = obj self.info: GraphQLResolveInfo = info self.kwargs = kwargs # initialize commonly used vars self.archive = Archive() self.search = Search() self._paged_data: PagedResult = paged_data @property def edges(self) -> List[ConnectionEdge]: """ Return the list of connection edges, each with a cursor """ return [ ConnectionEdge(node=node, cursor=self._get_index_cursor(index, node)) for (index, node) in enumerate(self.nodes) ] @property def nodes(self) -> List[Union[BaseNode, object]]: """ Override if needed; return a list of objects If a node class is set, return a list of its (Node) instances else a list of raw results """ if self._node_class is not None: return [ self._node_class( obj=self, info=self.info, node_data=result, **self.kwargs ) for result in self.get_paged_data().results ] return self.get_paged_data().results @property def pageInfo(self) -> PageInfo: # To support the schema naming convention # FIXME, add more details like startCursor return PageInfo( hasNextPage=bool(self.get_paged_data().next_page_token), endCursor=utils.get_encoded_cursor(self.get_paged_data().next_page_token), ) @property def totalCount(self) -> Optional[int]: # To support the schema naming convention """ Will be None for most of the connections override if needed/possible """ return None def get_paged_data(self) -> PagedResult: """ Cache to avoid multiple calls to the backend :meth:`_get_paged_result` return a PagedResult object """ if self._paged_data is None: # FIXME, make this call async (not for v1) self._paged_data = self._get_paged_result() return self._paged_data @abstractmethod def _get_paged_result(self): """ Override for desired behaviour return a PagedResult object """ # FIXME, make this call async (not for v1) - return None def _get_after_arg(self): """ Return the decoded next page token. Override to support a different cursor type """ # different implementation is used in SnapshotBranchConnection try: cursor = utils.get_decoded_cursor(self.kwargs.get("after")) except (UnicodeDecodeError, binascii.Error) as e: raise PaginationError("Invalid value for argument 'after'", errors=e) return cursor def _get_first_arg(self) -> int: """ """ # page_size is set to 50 by default # Input type check is not required; It is defined in schema as an int first = self.kwargs.get("first", self._page_size) if first < 0 or first > self._max_page_size: raise PaginationError( f"Value for argument 'first' is invalid; it must be between 0 and {self._max_page_size}" # noqa: B950 ) return first def _get_index_cursor(self, index: int, node: Any) -> Optional[str]: """ Get the cursor to the given item index """ # default implementation which works with swh-storage pagaination # override this function to support other types (eg: SnapshotBranchConnection) offset_index = self._get_after_arg() or "0" index_cursor = int(offset_index) + index return utils.get_encoded_cursor(str(index_cursor)) class BaseList(ABC): """ Base class to be used for simple lists that do not require pagination; eg resolveSWHID entrypoint """ _node_class: Optional[Type[BaseNode]] = None def __init__(self, obj, info, results=None, **kwargs): self.obj: Optional[BaseNode] = obj self.info: GraphQLResolveInfo = info self.kwargs = kwargs self._results: List = results self.archive = Archive() def get_results(self) -> List: if self._results is None: # To avoid multiple calls to the backend self._results = self._get_results() if self._node_class is not None: # convert list items to node objects return [ self._node_class( obj=self.obj, info=self.info, node_data=result, **self.kwargs ) for result in self._results ] return self._results @abstractmethod def _get_results(self) -> List: """ Override for desired behaviour return a list of objects """ diff --git a/swh/graphql/resolvers/base_node.py b/swh/graphql/resolvers/base_node.py index 0195654..9341fb2 100644 --- a/swh/graphql/resolvers/base_node.py +++ b/swh/graphql/resolvers/base_node.py @@ -1,93 +1,92 @@ # Copyright (C) 2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from collections import namedtuple from typing import Any, ClassVar, Optional, Union from graphql.type import GraphQLResolveInfo from swh.graphql import resolvers as rs from swh.graphql.backends.archive import Archive from swh.graphql.errors import NullableObjectError, ObjectNotFoundError class BaseNode: """ Base resolver for all the nodes """ _can_be_null: ClassVar[bool] = False def __init__(self, obj, info, node_data: Optional[Any] = None, **kwargs) -> None: self.obj: Optional[Union[BaseNode, rs.base_connection.BaseConnection]] = obj self.info: GraphQLResolveInfo = info self.kwargs = kwargs # initialize commonly used vars self.archive = Archive() self._node: Optional[Any] = self._get_node(node_data) # handle the errors, if any, after _node is set self._handle_node_errors() def _get_node(self, node_data: Optional[Any]) -> Optional[Any]: """ Get the node object from the given data if the data (node_data) is none make a function call to get data from backend """ if node_data is None: node_data = self._get_node_data() return self._get_node_from_data(node_data) def _get_node_from_data(self, node_data: Any) -> Optional[Any]: """ Get the object from node_data In case of a dict, convert it to an object Override to support different data structures """ if type(node_data) is dict: return namedtuple("NodeObj", node_data.keys())(*node_data.values()) return node_data def _handle_node_errors(self) -> None: """ Handle any error related to node data raise an error in case the object returned is None override for specific behaviour """ if self._node is None and self._can_be_null: # fail silently raise NullableObjectError() elif self._node is None: # This will send this error to the client raise ObjectNotFoundError("Requested object is not available") def _get_node_data(self) -> Optional[Any]: """ Override for desired behaviour This will be called only when node_data is None """ # FIXME, make this call async (not for v1) - return None def __getattr__(self, name: str) -> Any: """ Any property defined in the sub-class will get precedence over the _node attributes """ return getattr(self._node, name) def is_type_of(self) -> str: return self.__class__.__name__ class BaseSWHNode(BaseNode): """ Base resolver for all the nodes with a SWHID field """ @property def swhid(self): return self._node.swhid() diff --git a/swh/graphql/tests/unit/resolvers/test_base_node.py b/swh/graphql/tests/unit/resolvers/test_base_node.py deleted file mode 100644 index 94b2e29..0000000 --- a/swh/graphql/tests/unit/resolvers/test_base_node.py +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright (C) 2022 The Software Heritage developers -# See the AUTHORS file at the top-level directory of this distribution -# License: GNU General Public License version 3, or any later version -# See top-level LICENSE file for more information - -""" -""" - - -class TestBaseNode: - def test_init(self): - pass - - def test_get_node(self): - pass - - def test_get_node_from_data(self): - pass - - def test_handle_node_errors(self): - pass - - def test_get_node_data(self): - pass - - def test_getattr(self): - pass - - def test_is_type_of(sellf): - pass diff --git a/swh/graphql/tests/unit/resolvers/test_origin.py b/swh/graphql/tests/unit/resolvers/test_origin.py deleted file mode 100644 index 2ae2839..0000000 --- a/swh/graphql/tests/unit/resolvers/test_origin.py +++ /dev/null @@ -1 +0,0 @@ -pass diff --git a/swh/graphql/tests/unit/resolvers/test_resolver_factory.py b/swh/graphql/tests/unit/resolvers/test_resolver_factory.py index 691e6b6..88d2a32 100644 --- a/swh/graphql/tests/unit/resolvers/test_resolver_factory.py +++ b/swh/graphql/tests/unit/resolvers/test_resolver_factory.py @@ -1,18 +1,22 @@ # Copyright (C) 2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import pytest from swh.graphql.resolvers import resolver_factory class TestFactory: def test_get_node_resolver_invalid_type(self): with pytest.raises(AttributeError): resolver_factory.NodeObjectFactory().create("invalid", None, None) def test_get_connection_resolver_invalid_type(self): with pytest.raises(AttributeError): resolver_factory.ConnectionObjectFactory().create("invalid", None, None) + + def test_get_list_resolver_invalid_type(self): + with pytest.raises(AttributeError): + resolver_factory.SimpleListFactory().create("invalid", None, None)