diff --git a/swh/graphql/resolvers/base_node.py b/swh/graphql/resolvers/base_node.py index da828b1..0195654 100644 --- a/swh/graphql/resolvers/base_node.py +++ b/swh/graphql/resolvers/base_node.py @@ -1,87 +1,93 @@ # Copyright (C) 2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from collections import namedtuple -from typing import Any, Optional, Union +from typing import Any, ClassVar, Optional, Union from graphql.type import GraphQLResolveInfo from swh.graphql import resolvers as rs from swh.graphql.backends.archive import Archive -from swh.graphql.errors import ObjectNotFoundError +from swh.graphql.errors import NullableObjectError, ObjectNotFoundError class BaseNode: """ Base resolver for all the nodes """ - def __init__(self, obj, info, node_data: Optional[Any] = None, **kwargs): + _can_be_null: ClassVar[bool] = False + + def __init__(self, obj, info, node_data: Optional[Any] = None, **kwargs) -> None: self.obj: Optional[Union[BaseNode, rs.base_connection.BaseConnection]] = obj self.info: GraphQLResolveInfo = info self.kwargs = kwargs # initialize commonly used vars self.archive = Archive() self._node: Optional[Any] = self._get_node(node_data) # handle the errors, if any, after _node is set self._handle_node_errors() def _get_node(self, node_data: Optional[Any]) -> Optional[Any]: """ Get the node object from the given data if the data (node_data) is none make a function call to get data from backend """ if node_data is None: node_data = self._get_node_data() return self._get_node_from_data(node_data) def _get_node_from_data(self, node_data: Any) -> Optional[Any]: """ Get the object from node_data In case of a dict, convert it to an object Override to support different data structures """ if type(node_data) is dict: return namedtuple("NodeObj", node_data.keys())(*node_data.values()) return node_data def _handle_node_errors(self) -> None: """ Handle any error related to node data raise an error in case the object returned is None override for specific behaviour """ - if self._node is None: + if self._node is None and self._can_be_null: + # fail silently + raise NullableObjectError() + elif self._node is None: + # This will send this error to the client raise ObjectNotFoundError("Requested object is not available") def _get_node_data(self) -> Optional[Any]: """ Override for desired behaviour This will be called only when node_data is None """ # FIXME, make this call async (not for v1) return None def __getattr__(self, name: str) -> Any: """ Any property defined in the sub-class will get precedence over the _node attributes """ return getattr(self._node, name) def is_type_of(self) -> str: return self.__class__.__name__ class BaseSWHNode(BaseNode): """ Base resolver for all the nodes with a SWHID field """ @property def swhid(self): return self._node.swhid() diff --git a/swh/graphql/resolvers/content.py b/swh/graphql/resolvers/content.py index e11a1c5..a85e0f7 100644 --- a/swh/graphql/resolvers/content.py +++ b/swh/graphql/resolvers/content.py @@ -1,98 +1,99 @@ # Copyright (C) 2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from typing import Union from .base_node import BaseSWHNode from .directory_entry import BaseDirectoryEntryNode from .release import BaseReleaseNode from .search import SearchResultNode from .snapshot_branch import BaseSnapshotBranchNode class BaseContentNode(BaseSWHNode): """ Base resolver for all the content nodes """ def _get_content_by_hash(self, checksums: dict): content = self.archive.get_contents(checksums) # in case of a conflict, return the first element return content[0] if content else None @property def checksum(self): # FIXME, use a Node instead return {k: v.hex() for (k, v) in self._node.hashes().items()} @property def id(self): return self._node.sha1_git @property def data(self): # FIXME, return a Node object # FIXME, add more ways to retrieve data like binary string archive_url = "https://archive.softwareheritage.org/api/1/" content_sha1 = self._node.hashes()["sha1"] return { "url": f"{archive_url}content/sha1:{content_sha1.hex()}/raw/", } @property def fileType(self): # FIXME, fetch data from the indexers return None @property def language(self): # FIXME, fetch data from the indexers return None @property def license(self): # FIXME, fetch data from the indexers return None def is_type_of(self): # is_type_of is required only when resolving a UNION type # This is for ariadne to return the right type return "Content" class ContentNode(BaseContentNode): """ Node resolver for a content requested directly with its SWHID """ def _get_node_data(self): checksums = {"sha1_git": self.kwargs.get("swhid").object_id} return self._get_content_by_hash(checksums) class HashContentNode(BaseContentNode): """ Node resolver for a content requested with one or more checksums """ def _get_node_data(self): checksums = dict(self.kwargs.get("checksums")) return self._get_content_by_hash(checksums) class TargetContentNode(BaseContentNode): """ Node resolver for a content requested as a target """ + _can_be_null = True obj: Union[ SearchResultNode, BaseDirectoryEntryNode, BaseReleaseNode, BaseSnapshotBranchNode, ] def _get_node_data(self): return self._get_content_by_hash(checksums={"sha1_git": self.obj.target_hash}) diff --git a/swh/graphql/resolvers/directory.py b/swh/graphql/resolvers/directory.py index 1846f4c..53291f6 100644 --- a/swh/graphql/resolvers/directory.py +++ b/swh/graphql/resolvers/directory.py @@ -1,77 +1,79 @@ # Copyright (C) 2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from typing import Union from swh.model.model import Directory from swh.model.swhids import ObjectType from .base_node import BaseSWHNode from .release import BaseReleaseNode from .revision import BaseRevisionNode from .search import SearchResultNode from .snapshot_branch import BaseSnapshotBranchNode class BaseDirectoryNode(BaseSWHNode): """ Base resolver for all the directory nodes """ def _get_directory_by_id(self, directory_id): # Return a Directory model object # entries is initialized as empty # Same pattern is used in snapshot return Directory(id=directory_id, entries=()) def is_type_of(self): return "Directory" class DirectoryNode(BaseDirectoryNode): """ Node resolver for a directory requested directly with its SWHID """ def _get_node_data(self): swhid = self.kwargs.get("swhid") if ( swhid.object_type == ObjectType.DIRECTORY and self.archive.is_object_available(swhid.object_id, swhid.object_type) ): # _get_directory_by_id is not making any backend call # hence the is_directory_available validation return self._get_directory_by_id(swhid.object_id) return None class RevisionDirectoryNode(BaseDirectoryNode): """ Node resolver for a directory requested from a revision """ + _can_be_null = True obj: BaseRevisionNode def _get_node_data(self): # self.obj.directory_hash is the requested directory Id return self._get_directory_by_id(self.obj.directory_hash) class TargetDirectoryNode(BaseDirectoryNode): """ Node resolver for a directory requested as a target """ from .directory_entry import BaseDirectoryEntryNode + _can_be_null = True obj: Union[ BaseSnapshotBranchNode, BaseReleaseNode, BaseDirectoryEntryNode, SearchResultNode, ] def _get_node_data(self): return self._get_directory_by_id(self.obj.target_hash) diff --git a/swh/graphql/resolvers/release.py b/swh/graphql/resolvers/release.py index 3e507bb..f3ecc63 100644 --- a/swh/graphql/resolvers/release.py +++ b/swh/graphql/resolvers/release.py @@ -1,53 +1,54 @@ # Copyright (C) 2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from typing import Union from .base_node import BaseSWHNode from .search import SearchResultNode from .snapshot_branch import BaseSnapshotBranchNode class BaseReleaseNode(BaseSWHNode): """ Base resolver for all the release nodes """ def _get_release_by_id(self, release_id): return self.archive.get_releases([release_id])[0] @property def target_hash(self): return self._node.target @property def targetType(self): # To support the schema naming convention return self._node.target_type.value def is_type_of(self): # is_type_of is required only when resolving a UNION type # This is for ariadne to return the right type return "Release" class ReleaseNode(BaseReleaseNode): """ Node resolver for a release requested directly with its SWHID """ def _get_node_data(self): return self._get_release_by_id(self.kwargs.get("swhid").object_id) class TargetReleaseNode(BaseReleaseNode): """ Node resolver for a release requested as a target """ + _can_be_null = True obj: Union[BaseSnapshotBranchNode, BaseReleaseNode, SearchResultNode] def _get_node_data(self): # self.obj.target_hash is the requested release id return self._get_release_by_id(self.obj.target_hash) diff --git a/swh/graphql/resolvers/resolver_factory.py b/swh/graphql/resolvers/resolver_factory.py index 04efae7..1bca961 100644 --- a/swh/graphql/resolvers/resolver_factory.py +++ b/swh/graphql/resolvers/resolver_factory.py @@ -1,86 +1,103 @@ # Copyright (C) 2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information +from typing import ClassVar, Dict, Type + +from swh.graphql.errors import NullableObjectError + +from .base_connection import BaseConnection +from .base_node import BaseNode from .content import ContentNode, HashContentNode, TargetContentNode from .directory import DirectoryNode, RevisionDirectoryNode, TargetDirectoryNode from .directory_entry import DirectoryEntryConnection, DirectoryEntryNode from .origin import OriginConnection, OriginNode, TargetOriginNode from .release import ReleaseNode, TargetReleaseNode from .revision import ( LogRevisionConnection, ParentRevisionConnection, RevisionNode, TargetRevisionNode, ) from .search import ResolveSwhidConnection, SearchConnection from .snapshot import ( OriginSnapshotConnection, SnapshotNode, TargetSnapshotNode, VisitSnapshotNode, ) from .snapshot_branch import AliasSnapshotBranchNode, SnapshotBranchConnection from .visit import LatestVisitNode, OriginVisitConnection, OriginVisitNode from .visit_status import LatestVisitStatusNode, VisitStatusConnection -def get_node_resolver(resolver_type: str): - # FIXME, replace with a proper factory method - mapping = { +class NodeObjectFactory: + mapping: ClassVar[Dict[str, Type[BaseNode]]] = { "origin": OriginNode, "visit": OriginVisitNode, "latest-visit": LatestVisitNode, "latest-status": LatestVisitStatusNode, "visit-snapshot": VisitSnapshotNode, "snapshot": SnapshotNode, "branch-alias": AliasSnapshotBranchNode, "branch-revision": TargetRevisionNode, "branch-release": TargetReleaseNode, "branch-directory": TargetDirectoryNode, "branch-content": TargetContentNode, "branch-snapshot": TargetSnapshotNode, "revision": RevisionNode, "revision-directory": RevisionDirectoryNode, "release": ReleaseNode, "release-revision": TargetRevisionNode, "release-release": TargetReleaseNode, "release-directory": TargetDirectoryNode, "release-content": TargetContentNode, "directory": DirectoryNode, "directory-entry": DirectoryEntryNode, "content": ContentNode, "content-by-hash": HashContentNode, "dir-entry-content": TargetContentNode, "dir-entry-directory": TargetDirectoryNode, "dir-entry-revision": TargetRevisionNode, "search-result-origin": TargetOriginNode, "search-result-snapshot": TargetSnapshotNode, "search-result-revision": TargetRevisionNode, "search-result-release": TargetReleaseNode, "search-result-directory": TargetDirectoryNode, "search-result-content": TargetContentNode, } - if resolver_type not in mapping: - raise AttributeError(f"Invalid node type: {resolver_type}") - return mapping[resolver_type] + @classmethod + def create(cls, node_type: str, obj, info, *args, **kw): + resolver = cls.mapping.get(node_type) + if not resolver: + raise AttributeError(f"Invalid node type: {node_type}") + try: + node_obj = resolver(obj, info, *args, **kw) + except NullableObjectError: + # Return None instead of the object + node_obj = None + return node_obj -def get_connection_resolver(resolver_type: str): - # FIXME, replace with a proper factory method - mapping = { + +class ConnectionObjectFactory: + mapping: ClassVar[Dict[str, Type[BaseConnection]]] = { "origins": OriginConnection, "origin-visits": OriginVisitConnection, "origin-snapshots": OriginSnapshotConnection, "visit-status": VisitStatusConnection, "snapshot-branches": SnapshotBranchConnection, "revision-parents": ParentRevisionConnection, "revision-log": LogRevisionConnection, "directory-entries": DirectoryEntryConnection, "resolve-swhid": ResolveSwhidConnection, "search": SearchConnection, } - if resolver_type not in mapping: - raise AttributeError(f"Invalid connection type: {resolver_type}") - return mapping[resolver_type] + + @classmethod + def create(cls, connection_type: str, obj, info, *args, **kw): + resolver = cls.mapping.get(connection_type) + if not resolver: + raise AttributeError(f"Invalid connection type: {connection_type}") + return resolver(obj, info, *args, **kw) diff --git a/swh/graphql/resolvers/resolvers.py b/swh/graphql/resolvers/resolvers.py index 959ce4e..80087df 100644 --- a/swh/graphql/resolvers/resolvers.py +++ b/swh/graphql/resolvers/resolvers.py @@ -1,354 +1,314 @@ # Copyright (C) 2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information """ High level resolvers """ # Any schema attribute can be resolved by any of the following ways # and in the following priority order # - In this module using a decorator (eg: @visit_status.field("snapshot")) # Every object (type) is expected to resolve this way as they can accept arguments # eg: origin.visits takes arguments to paginate # - As a property in the Node object (eg: resolvers.visit.BaseVisitNode.id) # Every scalar is expected to resolve this way # - As an attribute/item in the object/dict returned by a backend (eg: Origin.url) from typing import Optional, Union from ariadne import ObjectType, UnionType from graphql.type import GraphQLResolveInfo from swh.graphql import resolvers as rs -from swh.graphql.errors import NullableObjectError from swh.graphql.utils import utils -from .resolver_factory import get_connection_resolver, get_node_resolver +from .resolver_factory import ConnectionObjectFactory, NodeObjectFactory query: ObjectType = ObjectType("Query") origin: ObjectType = ObjectType("Origin") visit: ObjectType = ObjectType("Visit") visit_status: ObjectType = ObjectType("VisitStatus") snapshot: ObjectType = ObjectType("Snapshot") snapshot_branch: ObjectType = ObjectType("Branch") revision: ObjectType = ObjectType("Revision") release: ObjectType = ObjectType("Release") directory: ObjectType = ObjectType("Directory") directory_entry: ObjectType = ObjectType("DirectoryEntry") search_result: ObjectType = ObjectType("SearchResult") binary_string: ObjectType = ObjectType("BinaryString") branch_target: UnionType = UnionType("BranchTarget") release_target: UnionType = UnionType("ReleaseTarget") directory_entry_target: UnionType = UnionType("DirectoryEntryTarget") search_result_target: UnionType = UnionType("SearchResultTarget") # Node resolvers -# A node resolver should return an instance of BaseNode +# A node resolver will return either an instance of a BaseNode subclass or None @query.field("origin") def origin_resolver(obj: None, info: GraphQLResolveInfo, **kw) -> rs.origin.OriginNode: - """ """ - resolver = get_node_resolver("origin") - return resolver(obj, info, **kw) + return NodeObjectFactory.create("origin", obj, info, **kw) @origin.field("latestVisit") def latest_visit_resolver( obj: rs.origin.BaseOriginNode, info: GraphQLResolveInfo, **kw -) -> rs.visit.LatestVisitNode: - """ """ - resolver = get_node_resolver("latest-visit") - return resolver(obj, info, **kw) +) -> Optional[rs.visit.LatestVisitNode]: + return NodeObjectFactory.create("latest-visit", obj, info, **kw) @query.field("visit") def visit_resolver( obj: None, info: GraphQLResolveInfo, **kw ) -> rs.visit.OriginVisitNode: - """ """ - resolver = get_node_resolver("visit") - return resolver(obj, info, **kw) + return NodeObjectFactory.create("visit", obj, info, **kw) @visit.field("latestStatus") def latest_visit_status_resolver( obj: rs.visit.BaseVisitNode, info: GraphQLResolveInfo, **kw ) -> Optional[rs.visit_status.LatestVisitStatusNode]: - """ """ - resolver = get_node_resolver("latest-status") - try: - return resolver(obj, info, **kw) - except NullableObjectError: - # FIXME, make this pattern generic for all the resolvers - return None + return NodeObjectFactory.create("latest-status", obj, info, **kw) @query.field("snapshot") def snapshot_resolver( obj: None, info: GraphQLResolveInfo, **kw ) -> rs.snapshot.SnapshotNode: - """ """ - resolver = get_node_resolver("snapshot") - return resolver(obj, info, **kw) + return NodeObjectFactory.create("snapshot", obj, info, **kw) @visit_status.field("snapshot") def visit_snapshot_resolver( obj: rs.visit_status.BaseVisitStatusNode, info: GraphQLResolveInfo, **kw ) -> Optional[rs.snapshot.VisitSnapshotNode]: - if obj.snapshotSWHID is None: - return None - resolver = get_node_resolver("visit-snapshot") - return resolver(obj, info, **kw) + return NodeObjectFactory.create("visit-snapshot", obj, info, **kw) @snapshot_branch.field("target") def snapshot_branch_target_resolver( obj: rs.snapshot_branch.BaseSnapshotBranchNode, info: GraphQLResolveInfo, **kw ) -> Union[ rs.revision.BaseRevisionNode, rs.release.BaseReleaseNode, rs.directory.BaseDirectoryNode, rs.content.BaseContentNode, rs.snapshot.BaseSnapshotNode, rs.snapshot_branch.BaseSnapshotBranchNode, ]: """ Snapshot branch target can be a revision, release, directory, content, snapshot or a branch itself (alias type) """ - resolver_type = f"branch-{obj.type}" - resolver = get_node_resolver(resolver_type) - return resolver(obj, info, **kw) + return NodeObjectFactory.create(f"branch-{obj.targetType}", obj, info, **kw) @query.field("revision") def revision_resolver( obj: None, info: GraphQLResolveInfo, **kw ) -> rs.revision.RevisionNode: - resolver = get_node_resolver("revision") - return resolver(obj, info, **kw) + return NodeObjectFactory.create("revision", obj, info, **kw) @revision.field("directory") def revision_directory_resolver( obj: rs.revision.BaseRevisionNode, info: GraphQLResolveInfo, **kw -) -> rs.directory.RevisionDirectoryNode: - resolver = get_node_resolver("revision-directory") - return resolver(obj, info, **kw) +) -> Optional[rs.directory.RevisionDirectoryNode]: + return NodeObjectFactory.create("revision-directory", obj, info, **kw) @query.field("release") def release_resolver( obj: None, info: GraphQLResolveInfo, **kw ) -> rs.release.ReleaseNode: - resolver = get_node_resolver("release") - return resolver(obj, info, **kw) + return NodeObjectFactory.create("release", obj, info, **kw) @release.field("target") def release_target_resolver( obj: rs.release.BaseReleaseNode, info: GraphQLResolveInfo, **kw ) -> Union[ rs.revision.BaseRevisionNode, rs.release.BaseReleaseNode, rs.directory.BaseDirectoryNode, rs.content.BaseContentNode, ]: """ - release target can be a release, revision, - directory or content - obj is release here, target type is - obj.target_type + Release target can be a release, revision, directory or a content """ - resolver_type = f"release-{obj.target_type.value}" - resolver = get_node_resolver(resolver_type) - return resolver(obj, info, **kw) + return NodeObjectFactory.create(f"release-{obj.targetType}", obj, info, **kw) @query.field("directory") def directory_resolver( obj: None, info: GraphQLResolveInfo, **kw ) -> rs.directory.DirectoryNode: - resolver = get_node_resolver("directory") - return resolver(obj, info, **kw) + return NodeObjectFactory.create("directory", obj, info, **kw) @query.field("directoryEntry") def directory_entry_resolver( obj: None, info: GraphQLResolveInfo, **kw ) -> rs.directory_entry.DirectoryEntryNode: - resolver = get_node_resolver("directory-entry") - return resolver(obj, info, **kw) + return NodeObjectFactory.create("directory-entry", obj, info, **kw) @directory_entry.field("target") def directory_entry_target_resolver( obj: rs.directory_entry.BaseDirectoryEntryNode, info: GraphQLResolveInfo, **kw ) -> Union[ rs.revision.BaseRevisionNode, rs.directory.BaseDirectoryNode, rs.content.BaseContentNode, ]: """ - directory entry target can be a directory, content or a revision + DirectoryEntry target can be a directory, content or a revision """ - resolver_type = f"dir-entry-{obj.targetType}" - resolver = get_node_resolver(resolver_type) - return resolver(obj, info, **kw) + return NodeObjectFactory.create(f"dir-entry-{obj.targetType}", obj, info, **kw) @query.field("content") def content_resolver( obj: None, info: GraphQLResolveInfo, **kw ) -> rs.content.ContentNode: - resolver = get_node_resolver("content") - return resolver(obj, info, **kw) + return NodeObjectFactory.create("content", obj, info, **kw) @search_result.field("target") def search_result_target_resolver( obj: rs.search.SearchResultNode, info: GraphQLResolveInfo, **kw ) -> Union[ rs.origin.BaseOriginNode, rs.snapshot.BaseSnapshotNode, rs.revision.BaseRevisionNode, rs.release.BaseReleaseNode, rs.directory.BaseDirectoryNode, rs.content.BaseContentNode, ]: - resolver_type = f"search-result-{obj.type}" - resolver = get_node_resolver(resolver_type) - return resolver(obj, info, **kw) + """ + SearchResult target can be an origin, snapshot, revision, release + directory or a content + """ + return NodeObjectFactory.create(f"search-result-{obj.targetType}", obj, info, **kw) @query.field("contentByHash") def content_by_hash_resolver( obj: None, info: GraphQLResolveInfo, **kw ) -> rs.content.ContentNode: - resolver = get_node_resolver("content-by-hash") - return resolver(obj, info, **kw) + return NodeObjectFactory.create("content-by-hash", obj, info, **kw) # Connection resolvers # A connection resolver should return an instance of BaseConnection @query.field("origins") def origins_resolver( obj: None, info: GraphQLResolveInfo, **kw ) -> rs.origin.OriginConnection: - resolver = get_connection_resolver("origins") - return resolver(obj, info, **kw) + return ConnectionObjectFactory.create("origins", obj, info, **kw) @origin.field("visits") def visits_resolver( obj: rs.origin.BaseOriginNode, info: GraphQLResolveInfo, **kw ) -> rs.visit.OriginVisitConnection: - resolver = get_connection_resolver("origin-visits") - return resolver(obj, info, **kw) + return ConnectionObjectFactory.create("origin-visits", obj, info, **kw) @origin.field("snapshots") def origin_snapshots_resolver( obj: rs.origin.BaseOriginNode, info: GraphQLResolveInfo, **kw ) -> rs.snapshot.OriginSnapshotConnection: - """ """ - resolver = get_connection_resolver("origin-snapshots") - return resolver(obj, info, **kw) + return ConnectionObjectFactory.create("origin-snapshots", obj, info, **kw) @visit.field("statuses") def visitstatus_resolver( obj: rs.visit.BaseVisitNode, info: GraphQLResolveInfo, **kw ) -> rs.visit_status.VisitStatusConnection: - resolver = get_connection_resolver("visit-status") - return resolver(obj, info, **kw) + return ConnectionObjectFactory.create("visit-status", obj, info, **kw) @snapshot.field("branches") def snapshot_branches_resolver( obj: rs.snapshot.BaseSnapshotNode, info: GraphQLResolveInfo, **kw ) -> rs.snapshot_branch.SnapshotBranchConnection: - resolver = get_connection_resolver("snapshot-branches") - return resolver(obj, info, **kw) + return ConnectionObjectFactory.create("snapshot-branches", obj, info, **kw) @revision.field("parents") def revision_parents_resolver( obj: rs.revision.BaseRevisionNode, info: GraphQLResolveInfo, **kw ) -> rs.revision.ParentRevisionConnection: - resolver = get_connection_resolver("revision-parents") - return resolver(obj, info, **kw) + return ConnectionObjectFactory.create("revision-parents", obj, info, **kw) @revision.field("revisionLog") def revision_log_resolver( obj: rs.revision.BaseRevisionNode, info: GraphQLResolveInfo, **kw ) -> rs.revision.LogRevisionConnection: - resolver = get_connection_resolver("revision-log") - return resolver(obj, info, **kw) + return ConnectionObjectFactory.create("revision-log", obj, info, **kw) @directory.field("entries") def directory_entries_resolver( obj: rs.directory.BaseDirectoryNode, info: GraphQLResolveInfo, **kw ) -> rs.directory_entry.DirectoryEntryConnection: - resolver = get_connection_resolver("directory-entries") - return resolver(obj, info, **kw) + return ConnectionObjectFactory.create("directory-entries", obj, info, **kw) @query.field("resolveSwhid") def search_swhid_resolver( obj: None, info: GraphQLResolveInfo, **kw ) -> rs.search.ResolveSwhidConnection: - resolver = get_connection_resolver("resolve-swhid") - return resolver(obj, info, **kw) + return ConnectionObjectFactory.create("resolve-swhid", obj, info, **kw) @query.field("search") def search_resolver( obj: None, info: GraphQLResolveInfo, **kw ) -> rs.search.SearchConnection: - resolver = get_connection_resolver("search") - return resolver(obj, info, **kw) + return ConnectionObjectFactory.create("search", obj, info, **kw) -# Any other type of resolver +# Other resolvers @release_target.type_resolver @directory_entry_target.type_resolver @branch_target.type_resolver @search_result_target.type_resolver def union_resolver( obj: Union[ rs.origin.BaseOriginNode, rs.revision.BaseRevisionNode, rs.release.BaseReleaseNode, rs.directory.BaseDirectoryNode, rs.content.BaseContentNode, rs.snapshot.BaseSnapshotNode, rs.snapshot_branch.BaseSnapshotBranchNode, ], *_, ) -> str: """ Generic resolver for all the union types """ return obj.is_type_of() +# BinaryString resolvers + + @binary_string.field("text") def binary_string_text_resolver(obj: bytes, *args, **kw) -> str: return obj.decode(utils.ENCODING, "replace") @binary_string.field("base64") def binary_string_base64_resolver(obj: bytes, *args, **kw) -> str: return utils.get_b64_string(obj) diff --git a/swh/graphql/resolvers/revision.py b/swh/graphql/resolvers/revision.py index 0fd8bd2..fa5dad4 100644 --- a/swh/graphql/resolvers/revision.py +++ b/swh/graphql/resolvers/revision.py @@ -1,116 +1,117 @@ # Copyright (C) 2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from typing import Union from swh.graphql.utils import utils from swh.model.model import Revision from swh.model.swhids import CoreSWHID, ObjectType from swh.storage.interface import PagedResult from .base_connection import BaseConnection from .base_node import BaseSWHNode from .directory_entry import BaseDirectoryEntryNode from .release import BaseReleaseNode from .search import SearchResultNode from .snapshot_branch import BaseSnapshotBranchNode class BaseRevisionNode(BaseSWHNode): """ Base resolver for all the revision nodes """ def _get_revision_by_id(self, revision_id): return self.archive.get_revisions([revision_id])[0] @property def parent_swhids(self): # for ParentRevisionConnection resolver return [ CoreSWHID(object_type=ObjectType.REVISION, object_id=parent_id) for parent_id in self._node.parents ] @property def directory_hash(self): # for RevisionDirectoryNode resolver return self._node.directory @property def type(self): return self._node.type.value def is_type_of(self): # is_type_of is required only when resolving a UNION type # This is for ariadne to return the right type return "Revision" class RevisionNode(BaseRevisionNode): """ Node resolver for a revision requested directly with its SWHID """ def _get_node_data(self): return self._get_revision_by_id(self.kwargs.get("swhid").object_id) class TargetRevisionNode(BaseRevisionNode): """ Node resolver for a revision requested as a target """ + _can_be_null = True obj: Union[ BaseSnapshotBranchNode, BaseReleaseNode, BaseDirectoryEntryNode, SearchResultNode, ] def _get_node_data(self): # self.obj.target_hash is the requested revision id return self._get_revision_by_id(self.obj.target_hash) class ParentRevisionConnection(BaseConnection): """ Connection resolver for parent revisions in a revision """ obj: BaseRevisionNode _node_class = BaseRevisionNode def _get_paged_result(self) -> PagedResult: # self.obj is the current(child) revision # self.obj.parent_swhids is the list of parent SWHIDs # FIXME, using dummy(local) pagination, move pagination to backend # To remove localpagination, just drop the paginated call # STORAGE-TODO (pagination) parents = self.archive.get_revisions( [x.object_id for x in self.obj.parent_swhids] ) return utils.paginated(parents, self._get_first_arg(), self._get_after_arg()) class LogRevisionConnection(BaseConnection): """ Connection resolver for the log (list of revisions) in a revision """ obj: BaseRevisionNode _node_class = BaseRevisionNode def _get_paged_result(self) -> PagedResult: log = self.archive.get_revision_log([self.obj.swhid.object_id]) # Storage is returning a list of dicts instead of model objects # Following loop is to reverse that operation # STORAGE-TODO; remove to_dict from storage.revision_log log = [Revision.from_dict(rev) for rev in log] # FIXME, using dummy(local) pagination, move pagination to backend # To remove localpagination, just drop the paginated call # STORAGE-TODO (pagination) return utils.paginated(log, self._get_first_arg(), self._get_after_arg()) diff --git a/swh/graphql/resolvers/snapshot.py b/swh/graphql/resolvers/snapshot.py index 011aab7..dd0a02f 100644 --- a/swh/graphql/resolvers/snapshot.py +++ b/swh/graphql/resolvers/snapshot.py @@ -1,94 +1,99 @@ # Copyright (C) 2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from typing import Union +from swh.graphql.errors import NullableObjectError from swh.graphql.utils import utils from swh.model.model import Snapshot from swh.model.swhids import ObjectType from swh.storage.interface import PagedResult from .base_connection import BaseConnection from .base_node import BaseSWHNode from .origin import OriginNode from .search import SearchResultNode from .visit_status import BaseVisitStatusNode class BaseSnapshotNode(BaseSWHNode): """ Base resolver for all the snapshot nodes """ def _get_snapshot_by_id(self, snapshot_id): # Return a Snapshot model object # branches is initialized as empty # Same pattern is used in directory return Snapshot(id=snapshot_id, branches={}) def is_type_of(self): # is_type_of is required only when resolving a UNION type # This is for ariadne to return the right type return "Snapshot" class SnapshotNode(BaseSnapshotNode): """ Node resolver for a snapshot requested directly with its SWHID """ def _get_node_data(self): """ """ swhid = self.kwargs.get("swhid") if ( swhid.object_type == ObjectType.SNAPSHOT and self.archive.is_object_available(swhid.object_id, swhid.object_type) ): return self._get_snapshot_by_id(swhid.object_id) return None class VisitSnapshotNode(BaseSnapshotNode): """ Node resolver for a snapshot requested from a visit-status """ + _can_be_null = True obj: BaseVisitStatusNode def _get_node_data(self): + if self.obj.snapshotSWHID is None: + raise NullableObjectError() snapshot_id = self.obj.snapshotSWHID.object_id return self._get_snapshot_by_id(snapshot_id) class TargetSnapshotNode(BaseSnapshotNode): """ Node resolver for a snapshot requested as a target """ from .snapshot_branch import BaseSnapshotBranchNode + _can_be_null = True obj: Union[SearchResultNode, BaseSnapshotBranchNode] def _get_node_data(self): snapshot_id = self.obj.target_hash return self._get_snapshot_by_id(snapshot_id) class OriginSnapshotConnection(BaseConnection): """ Connection resolver for the snapshots in an origin """ obj: OriginNode _node_class = BaseSnapshotNode def _get_paged_result(self) -> PagedResult: results = self.archive.get_origin_snapshots(self.obj.url) snapshots = [Snapshot(id=snapshot, branches={}) for snapshot in results] # FIXME, using dummy(local) pagination, move pagination to backend # To remove localpagination, just drop the paginated call # STORAGE-TODO return utils.paginated(snapshots, self._get_first_arg(), self._get_after_arg()) diff --git a/swh/graphql/resolvers/visit.py b/swh/graphql/resolvers/visit.py index 19d24b3..27b30fd 100644 --- a/swh/graphql/resolvers/visit.py +++ b/swh/graphql/resolvers/visit.py @@ -1,66 +1,67 @@ # Copyright (C) 2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information from swh.graphql.utils import utils from swh.storage.interface import PagedResult from .base_connection import BaseConnection from .base_node import BaseNode from .origin import OriginNode class BaseVisitNode(BaseNode): """ Base resolver for all the visit nodes """ @property def id(self): # FIXME, use a better id return utils.get_b64_string(f"{self.origin}-{str(self.visit)}") @property def visitId(self): # To support the schema naming convention return self._node.visit class OriginVisitNode(BaseVisitNode): """ Node resolver for a visit requested directly with an origin URL and a visit ID """ def _get_node_data(self): return self.archive.get_origin_visit( self.kwargs.get("originUrl"), int(self.kwargs.get("visitId")) ) class LatestVisitNode(BaseVisitNode): """ Node resolver for the latest visit in an origin """ + _can_be_null = True obj: OriginNode def _get_node_data(self): # self.obj.url is the origin URL return self.archive.get_origin_latest_visit(self.obj.url) class OriginVisitConnection(BaseConnection): """ Connection resolver for the visit objects in an origin """ obj: OriginNode _node_class = BaseVisitNode def _get_paged_result(self) -> PagedResult: # self.obj.url is the origin URL return self.archive.get_origin_visits( self.obj.url, after=self._get_after_arg(), first=self._get_first_arg() ) diff --git a/swh/graphql/resolvers/visit_status.py b/swh/graphql/resolvers/visit_status.py index cfcf116..a637215 100644 --- a/swh/graphql/resolvers/visit_status.py +++ b/swh/graphql/resolvers/visit_status.py @@ -1,69 +1,64 @@ # Copyright (C) 2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information -from swh.graphql.errors import NullableObjectError from swh.graphql.utils import utils from swh.model.swhids import CoreSWHID, ObjectType from swh.storage.interface import PagedResult from .base_connection import BaseConnection from .base_node import BaseNode from .visit import BaseVisitNode class BaseVisitStatusNode(BaseNode): """ Base resolver for all the visit-status nodes """ @property def snapshotSWHID(self): # To support the schema naming convention if self._node.snapshot is None: return None return CoreSWHID(object_type=ObjectType.SNAPSHOT, object_id=self._node.snapshot) class LatestVisitStatusNode(BaseVisitStatusNode): """ Node resolver for a visit-status requested from a visit """ + _can_be_null = True obj: BaseVisitNode def _get_node_data(self): # self.obj.origin is the origin URL return self.archive.get_latest_visit_status( origin_url=self.obj.origin, visit_id=self.obj.visitId, allowed_statuses=self.kwargs.get("allowedStatuses"), require_snapshot=self.kwargs.get("requireSnapshot"), ) - def _handle_node_errors(self) -> None: - # This object can be null - if self._node is None: - raise NullableObjectError("") - class VisitStatusConnection(BaseConnection): """ Connection resolver for the visit-status objects in a visit """ obj: BaseVisitNode _node_class = BaseVisitStatusNode def _get_paged_result(self) -> PagedResult: # self.obj.origin is the origin URL return self.archive.get_visit_status( self.obj.origin, self.obj.visitId, after=self._get_after_arg(), first=self._get_first_arg(), ) def _get_index_cursor(self, index: int, node: BaseVisitStatusNode): # Visit status is using a different cursor, hence the override return utils.get_encoded_cursor(utils.get_formatted_date(node.date)) diff --git a/swh/graphql/tests/unit/resolvers/test_resolver_factory.py b/swh/graphql/tests/unit/resolvers/test_resolver_factory.py index 32b5bfc..ed6cb9f 100644 --- a/swh/graphql/tests/unit/resolvers/test_resolver_factory.py +++ b/swh/graphql/tests/unit/resolvers/test_resolver_factory.py @@ -1,73 +1,18 @@ # Copyright (C) 2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import pytest from swh.graphql.resolvers import resolver_factory class TestFactory: - @pytest.mark.parametrize( - "input_type, expected", - [ - ("origin", "OriginNode"), - ("visit", "OriginVisitNode"), - ("latest-visit", "LatestVisitNode"), - ("latest-status", "LatestVisitStatusNode"), - ("visit-snapshot", "VisitSnapshotNode"), - ("snapshot", "SnapshotNode"), - ("branch-revision", "TargetRevisionNode"), - ("branch-release", "TargetReleaseNode"), - ("branch-directory", "TargetDirectoryNode"), - ("branch-content", "TargetContentNode"), - ("branch-snapshot", "TargetSnapshotNode"), - ("revision", "RevisionNode"), - ("revision-directory", "RevisionDirectoryNode"), - ("release", "ReleaseNode"), - ("release-revision", "TargetRevisionNode"), - ("release-release", "TargetReleaseNode"), - ("release-directory", "TargetDirectoryNode"), - ("release-content", "TargetContentNode"), - ("directory", "DirectoryNode"), - ("content", "ContentNode"), - ("dir-entry-directory", "TargetDirectoryNode"), - ("dir-entry-content", "TargetContentNode"), - ("dir-entry-revision", "TargetRevisionNode"), - ("search-result-snapshot", "TargetSnapshotNode"), - ("search-result-revision", "TargetRevisionNode"), - ("search-result-release", "TargetReleaseNode"), - ("search-result-directory", "TargetDirectoryNode"), - ("search-result-content", "TargetContentNode"), - ], - ) - def test_get_node_resolver(self, input_type, expected): - response = resolver_factory.get_node_resolver(input_type) - assert response.__name__ == expected - def test_get_node_resolver_invalid_type(self): with pytest.raises(AttributeError): - resolver_factory.get_node_resolver("invalid") - - @pytest.mark.parametrize( - "input_type, expected", - [ - ("origins", "OriginConnection"), - ("origin-visits", "OriginVisitConnection"), - ("origin-snapshots", "OriginSnapshotConnection"), - ("visit-status", "VisitStatusConnection"), - ("snapshot-branches", "SnapshotBranchConnection"), - ("revision-parents", "ParentRevisionConnection"), - ("revision-log", "LogRevisionConnection"), - ("directory-entries", "DirectoryEntryConnection"), - ("resolve-swhid", "ResolveSwhidConnection"), - ], - ) - def test_get_connection_resolver(self, input_type, expected): - response = resolver_factory.get_connection_resolver(input_type) - assert response.__name__ == expected + resolver_factory.NodeObjectFactory().create("invalid", None, None) def test_get_connection_resolver_invalid_type(self): with pytest.raises(AttributeError): - resolver_factory.get_connection_resolver("invalid") + resolver_factory.get_connection_resolver("invalid", None, None) diff --git a/swh/graphql/tests/unit/resolvers/test_resolvers.py b/swh/graphql/tests/unit/resolvers/test_resolvers.py index ef54d7f..00ba594 100644 --- a/swh/graphql/tests/unit/resolvers/test_resolvers.py +++ b/swh/graphql/tests/unit/resolvers/test_resolvers.py @@ -1,131 +1,131 @@ # Copyright (C) 2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import pytest from swh.graphql import resolvers from swh.graphql.resolvers import resolvers as rs class TestResolvers: """ """ @pytest.fixture def dummy_node(self): return {"test": "test"} @pytest.mark.parametrize( "resolver_func, node_cls", [ (rs.origin_resolver, resolvers.origin.OriginNode), (rs.visit_resolver, resolvers.visit.OriginVisitNode), (rs.latest_visit_resolver, resolvers.visit.LatestVisitNode), ( rs.latest_visit_status_resolver, resolvers.visit_status.LatestVisitStatusNode, ), (rs.snapshot_resolver, resolvers.snapshot.SnapshotNode), (rs.revision_resolver, resolvers.revision.RevisionNode), (rs.revision_directory_resolver, resolvers.directory.RevisionDirectoryNode), (rs.release_resolver, resolvers.release.ReleaseNode), (rs.directory_resolver, resolvers.directory.DirectoryNode), (rs.content_resolver, resolvers.content.ContentNode), ], ) def test_node_resolver(self, mocker, dummy_node, resolver_func, node_cls): mock_get = mocker.patch.object(node_cls, "_get_node", return_value=dummy_node) node_obj = resolver_func(None, None) # assert the _get_node method is called on the right object assert isinstance(node_obj, node_cls) assert mock_get.assert_called @pytest.mark.parametrize( "resolver_func, connection_cls", [ (rs.origins_resolver, resolvers.origin.OriginConnection), (rs.visits_resolver, resolvers.visit.OriginVisitConnection), (rs.origin_snapshots_resolver, resolvers.snapshot.OriginSnapshotConnection), (rs.visitstatus_resolver, resolvers.visit_status.VisitStatusConnection), ( rs.snapshot_branches_resolver, resolvers.snapshot_branch.SnapshotBranchConnection, ), (rs.revision_parents_resolver, resolvers.revision.ParentRevisionConnection), - # (rs.revision_log_resolver, resolvers.revision.LogRevisionConnection), + (rs.revision_log_resolver, resolvers.revision.LogRevisionConnection), ( rs.directory_entries_resolver, resolvers.directory_entry.DirectoryEntryConnection, ), ], ) def test_connection_resolver(self, resolver_func, connection_cls): connection_obj = resolver_func(None, None) # assert the right object is returned assert isinstance(connection_obj, connection_cls) @pytest.mark.parametrize( "branch_type, node_cls", [ ("revision", resolvers.revision.TargetRevisionNode), ("release", resolvers.release.TargetReleaseNode), ("directory", resolvers.directory.TargetDirectoryNode), ("content", resolvers.content.TargetContentNode), ("snapshot", resolvers.snapshot.TargetSnapshotNode), ], ) def test_snapshot_branch_target_resolver( self, mocker, dummy_node, branch_type, node_cls ): - obj = mocker.Mock(type=branch_type) + obj = mocker.Mock(targetType=branch_type) mock_get = mocker.patch.object(node_cls, "_get_node", return_value=dummy_node) node_obj = rs.snapshot_branch_target_resolver(obj, None) assert isinstance(node_obj, node_cls) assert mock_get.assert_called @pytest.mark.parametrize( "target_type, node_cls", [ ("revision", resolvers.revision.TargetRevisionNode), ("release", resolvers.release.TargetReleaseNode), ("directory", resolvers.directory.TargetDirectoryNode), ("content", resolvers.content.TargetContentNode), ], ) def test_release_target_resolver(self, mocker, dummy_node, target_type, node_cls): - obj = mocker.Mock(target_type=(mocker.Mock(value=target_type))) + obj = mocker.Mock(targetType=target_type) mock_get = mocker.patch.object(node_cls, "_get_node", return_value=dummy_node) node_obj = rs.release_target_resolver(obj, None) assert isinstance(node_obj, node_cls) assert mock_get.assert_called @pytest.mark.parametrize( "target_type, node_cls", [ ("directory", resolvers.directory.TargetDirectoryNode), ("content", resolvers.content.TargetContentNode), ("revision", resolvers.revision.TargetRevisionNode), ], ) def test_directory_entry_target_resolver( self, mocker, dummy_node, target_type, node_cls ): obj = mocker.Mock(targetType=target_type) mock_get = mocker.patch.object(node_cls, "_get_node", return_value=dummy_node) node_obj = rs.directory_entry_target_resolver(obj, None) assert isinstance(node_obj, node_cls) assert mock_get.assert_called def test_union_resolver(self, mocker): obj = mocker.Mock() obj.is_type_of.return_value = "test" assert rs.union_resolver(obj) == "test" def test_binary_string_text_resolver(self): text = rs.binary_string_text_resolver(b"test", None) assert text == "test" def test_binary_string_base64_resolver(self): b64string = rs.binary_string_base64_resolver(b"test", None) assert b64string == "dGVzdA=="