# Patch summary and reviewer notes. This text sits before the first
# "diff --git" line, so it is not part of any hunk.
#
# What the patch does:
# - resolver_factory.py: replaces get_node_resolver(), which returned a
#   node class looked up in a local dict, with a NodeObjectFactory class.
#   NodeObjectFactory.create(node_type, *args, **kw) looks node_type up in
#   self.mapping, raises AttributeError for an unknown node_type,
#   instantiates the class with *args/**kw, and returns None instead of
#   the node when the constructor raises NullableObjectError.
# - resolvers.py: every node resolver now calls the module-level
#   node_object_factory.create(...). The ad-hoc handling this replaces is
#   removed: the try/except NullableObjectError in the "latest-status"
#   resolver (marked FIXME) and the `obj.snapshotSWHID is None` check in
#   visit_snapshot_resolver.
# - snapshot.py: _get_node_data now raises NullableObjectError when
#   self.obj.snapshotSWHID is None (the check moved out of
#   visit_snapshot_resolver).
# - directory.py: the node resolved from self.obj.directory_hash gains
#   _handle_node_errors(), raising NullableObjectError when self._node is
#   None.
# - visit.py: the node resolved via archive.get_origin_latest_visit gains
#   the same _handle_node_errors().
# - visit_status.py: NullableObjectError("") becomes NullableObjectError().
# - test_resolver_factory.py: the parametrized test_get_node_resolver is
#   deleted. The invalid-type test now calls
#   NodeObjectFactory().create("invalid", None, None).
#
# NOTE(review): create() may now return None, but some resolvers keep
#   non-Optional return annotations for nodes this patch makes nullable.
#   Examples: revision_directory_resolver -> rs.directory.RevisionDirectoryNode
#   (presumably the directory.py class changed here), and the "latest-visit"
#   resolver -> rs.visit.LatestVisitNode (presumably the visit.py class
#   changed here). Consider Optional[...] for those; confirm the class names.
# NOTE(review): with test_get_node_resolver deleted, nothing checks that
#   each key of NodeObjectFactory().mapping maps to the intended class. The
#   deleted test also never covered "branch-alias", "directory-entry",
#   "content-by-hash" or "search-result-origin". Consider a parametrized
#   test over the mapping dict itself, which needs no node construction.
# NOTE(review): self.mapping is rebuilt on every NodeObjectFactory()
#   instantiation, so a class-level constant would do. This is minor,
#   because resolvers.py creates a single module-level instance.
diff --git a/swh/graphql/resolvers/directory.py b/swh/graphql/resolvers/directory.py --- a/swh/graphql/resolvers/directory.py +++ b/swh/graphql/resolvers/directory.py @@ -5,6 +5,7 @@ from typing import Union +from swh.graphql.errors import NullableObjectError from swh.model.model import Directory from swh.model.swhids import ObjectType @@ -58,6 +59,11 @@ # self.obj.directory_hash is the requested directory Id return self._get_directory_by_id(self.obj.directory_hash) + def _handle_node_errors(self) -> None: + # This object can be null + if self._node is None: + raise NullableObjectError() + class TargetDirectoryNode(BaseDirectoryNode): """ diff --git a/swh/graphql/resolvers/resolver_factory.py b/swh/graphql/resolvers/resolver_factory.py --- a/swh/graphql/resolvers/resolver_factory.py +++ b/swh/graphql/resolvers/resolver_factory.py @@ -3,6 +3,8 @@ # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information +from swh.graphql.errors import NullableObjectError + from .content import ContentNode, HashContentNode, TargetContentNode from .directory import DirectoryNode, RevisionDirectoryNode, TargetDirectoryNode from .directory_entry import DirectoryEntryConnection, DirectoryEntryNode @@ -26,45 +28,54 @@ from .visit_status import LatestVisitStatusNode, VisitStatusConnection -def get_node_resolver(resolver_type: str): - # FIXME, replace with a proper factory method - mapping = { - "origin": OriginNode, - "visit": OriginVisitNode, - "latest-visit": LatestVisitNode, - "latest-status": LatestVisitStatusNode, - "visit-snapshot": VisitSnapshotNode, - "snapshot": SnapshotNode, - "branch-alias": AliasSnapshotBranchNode, - "branch-revision": TargetRevisionNode, - "branch-release": TargetReleaseNode, - "branch-directory": TargetDirectoryNode, - "branch-content": TargetContentNode, - "branch-snapshot": TargetSnapshotNode, - "revision": RevisionNode, - "revision-directory": RevisionDirectoryNode, - "release": ReleaseNode, - 
"release-revision": TargetRevisionNode, - "release-release": TargetReleaseNode, - "release-directory": TargetDirectoryNode, - "release-content": TargetContentNode, - "directory": DirectoryNode, - "directory-entry": DirectoryEntryNode, - "content": ContentNode, - "content-by-hash": HashContentNode, - "dir-entry-content": TargetContentNode, - "dir-entry-directory": TargetDirectoryNode, - "dir-entry-revision": TargetRevisionNode, - "search-result-origin": TargetOriginNode, - "search-result-snapshot": TargetSnapshotNode, - "search-result-revision": TargetRevisionNode, - "search-result-release": TargetReleaseNode, - "search-result-directory": TargetDirectoryNode, - "search-result-content": TargetContentNode, - } - if resolver_type not in mapping: - raise AttributeError(f"Invalid node type: {resolver_type}") - return mapping[resolver_type] +class NodeObjectFactory: + def __init__(self): + self.mapping = { + "origin": OriginNode, + "visit": OriginVisitNode, + "latest-visit": LatestVisitNode, + "latest-status": LatestVisitStatusNode, + "visit-snapshot": VisitSnapshotNode, + "snapshot": SnapshotNode, + "branch-alias": AliasSnapshotBranchNode, + "branch-revision": TargetRevisionNode, + "branch-release": TargetReleaseNode, + "branch-directory": TargetDirectoryNode, + "branch-content": TargetContentNode, + "branch-snapshot": TargetSnapshotNode, + "revision": RevisionNode, + "revision-directory": RevisionDirectoryNode, + "release": ReleaseNode, + "release-revision": TargetRevisionNode, + "release-release": TargetReleaseNode, + "release-directory": TargetDirectoryNode, + "release-content": TargetContentNode, + "directory": DirectoryNode, + "directory-entry": DirectoryEntryNode, + "content": ContentNode, + "content-by-hash": HashContentNode, + "dir-entry-content": TargetContentNode, + "dir-entry-directory": TargetDirectoryNode, + "dir-entry-revision": TargetRevisionNode, + "search-result-origin": TargetOriginNode, + "search-result-snapshot": TargetSnapshotNode, + 
"search-result-revision": TargetRevisionNode, + "search-result-release": TargetReleaseNode, + "search-result-directory": TargetDirectoryNode, + "search-result-content": TargetContentNode, + } + + def create(self, node_type: str, *args, **kw): + resolver = self.mapping.get(node_type) + if not resolver: + raise AttributeError(f"Invalid node type: {node_type}") + try: + node_obj = resolver(*args, **kw) + except NullableObjectError: + # This exception will not create 'missing field' + # errors for object attributes + node_obj = None + return node_obj def get_connection_resolver(resolver_type: str): diff --git a/swh/graphql/resolvers/resolvers.py b/swh/graphql/resolvers/resolvers.py --- a/swh/graphql/resolvers/resolvers.py +++ b/swh/graphql/resolvers/resolvers.py @@ -22,10 +22,9 @@ from graphql.type import GraphQLResolveInfo from swh.graphql import resolvers as rs -from swh.graphql.errors import NullableObjectError from swh.graphql.utils import utils -from .resolver_factory import get_connection_resolver, get_node_resolver +from .resolver_factory import NodeObjectFactory, get_connection_resolver query: ObjectType = ObjectType("Query") origin: ObjectType = ObjectType("Origin") @@ -48,12 +47,13 @@ # Node resolvers # A node resolver should return an instance of BaseNode +node_object_factory = NodeObjectFactory() + @query.field("origin") def origin_resolver(obj: None, info: GraphQLResolveInfo, **kw) -> rs.origin.OriginNode: """ """ - resolver = get_node_resolver("origin") - return resolver(obj, info, **kw) + return node_object_factory.create("origin", obj, info, **kw) @origin.field("latestVisit") @@ -61,8 +61,7 @@ obj: rs.origin.BaseOriginNode, info: GraphQLResolveInfo, **kw ) -> rs.visit.LatestVisitNode: """ """ - resolver = get_node_resolver("latest-visit") - return resolver(obj, info, **kw) + return node_object_factory.create("latest-visit", obj, info, **kw) @query.field("visit") @@ -70,8 +69,7 @@ obj: None, info: GraphQLResolveInfo, **kw ) -> rs.visit.OriginVisitNode: 
""" """ - resolver = get_node_resolver("visit") - return resolver(obj, info, **kw) + return node_object_factory.create("visit", obj, info, **kw) @visit.field("latestStatus") @@ -79,12 +77,7 @@ obj: rs.visit.BaseVisitNode, info: GraphQLResolveInfo, **kw ) -> Optional[rs.visit_status.LatestVisitStatusNode]: """ """ - resolver = get_node_resolver("latest-status") - try: - return resolver(obj, info, **kw) - except NullableObjectError: - # FIXME, make this pattern generic for all the resolvers - return None + return node_object_factory.create("latest-status", obj, info, **kw) @query.field("snapshot") @@ -92,18 +85,14 @@ obj: None, info: GraphQLResolveInfo, **kw ) -> rs.snapshot.SnapshotNode: """ """ - resolver = get_node_resolver("snapshot") - return resolver(obj, info, **kw) + return node_object_factory.create("snapshot", obj, info, **kw) @visit_status.field("snapshot") def visit_snapshot_resolver( obj: rs.visit_status.BaseVisitStatusNode, info: GraphQLResolveInfo, **kw ) -> Optional[rs.snapshot.VisitSnapshotNode]: - if obj.snapshotSWHID is None: - return None - resolver = get_node_resolver("visit-snapshot") - return resolver(obj, info, **kw) + return node_object_factory.create("visit-snapshot", obj, info, **kw) @snapshot_branch.field("target") @@ -121,33 +110,28 @@ Snapshot branch target can be a revision, release, directory, content, snapshot or a branch itself (alias type) """ - resolver_type = f"branch-{obj.type}" - resolver = get_node_resolver(resolver_type) - return resolver(obj, info, **kw) + return node_object_factory.create(f"branch-{obj.type}", obj, info, **kw) @query.field("revision") def revision_resolver( obj: None, info: GraphQLResolveInfo, **kw ) -> rs.revision.RevisionNode: - resolver = get_node_resolver("revision") - return resolver(obj, info, **kw) + return node_object_factory.create("revision", obj, info, **kw) @revision.field("directory") def revision_directory_resolver( obj: rs.revision.BaseRevisionNode, info: GraphQLResolveInfo, **kw ) -> 
rs.directory.RevisionDirectoryNode: - resolver = get_node_resolver("revision-directory") - return resolver(obj, info, **kw) + return node_object_factory.create("revision-directory", obj, info, **kw) @query.field("release") def release_resolver( obj: None, info: GraphQLResolveInfo, **kw ) -> rs.release.ReleaseNode: - resolver = get_node_resolver("release") - return resolver(obj, info, **kw) + return node_object_factory.create("release", obj, info, **kw) @release.field("target") @@ -165,25 +149,23 @@ obj is release here, target type is obj.target_type """ - resolver_type = f"release-{obj.target_type.value}" - resolver = get_node_resolver(resolver_type) - return resolver(obj, info, **kw) + return node_object_factory.create( + f"release-{obj.target_type.value}", obj, info, **kw + ) @query.field("directory") def directory_resolver( obj: None, info: GraphQLResolveInfo, **kw ) -> rs.directory.DirectoryNode: - resolver = get_node_resolver("directory") - return resolver(obj, info, **kw) + return node_object_factory.create("directory", obj, info, **kw) @query.field("directoryEntry") def directory_entry_resolver( obj: None, info: GraphQLResolveInfo, **kw ) -> rs.directory_entry.DirectoryEntryNode: - resolver = get_node_resolver("directory-entry") - return resolver(obj, info, **kw) + return node_object_factory.create("directory-entry", obj, info, **kw) @directory_entry.field("target") @@ -197,17 +179,14 @@ """ directory entry target can be a directory, content or a revision """ - resolver_type = f"dir-entry-{obj.targetType}" - resolver = get_node_resolver(resolver_type) - return resolver(obj, info, **kw) + return node_object_factory.create(f"dir-entry-{obj.targetType}", obj, info, **kw) @query.field("content") def content_resolver( obj: None, info: GraphQLResolveInfo, **kw ) -> rs.content.ContentNode: - resolver = get_node_resolver("content") - return resolver(obj, info, **kw) + return node_object_factory.create("content", obj, info, **kw) @search_result.field("target") @@ 
-221,17 +200,14 @@ rs.directory.BaseDirectoryNode, rs.content.BaseContentNode, ]: - resolver_type = f"search-result-{obj.type}" - resolver = get_node_resolver(resolver_type) - return resolver(obj, info, **kw) + return node_object_factory.create(f"search-result-{obj.type}", obj, info, **kw) @query.field("contentByHash") def content_by_hash_resolver( obj: None, info: GraphQLResolveInfo, **kw ) -> rs.content.ContentNode: - resolver = get_node_resolver("content-by-hash") - return resolver(obj, info, **kw) + return node_object_factory.create("content-by-hash", obj, info, **kw) # Connection resolvers diff --git a/swh/graphql/resolvers/snapshot.py b/swh/graphql/resolvers/snapshot.py --- a/swh/graphql/resolvers/snapshot.py +++ b/swh/graphql/resolvers/snapshot.py @@ -5,6 +5,7 @@ from typing import Union +from swh.graphql.errors import NullableObjectError from swh.graphql.utils import utils from swh.model.model import Snapshot from swh.model.swhids import ObjectType @@ -58,6 +59,8 @@ obj: BaseVisitStatusNode def _get_node_data(self): + if self.obj.snapshotSWHID is None: + raise NullableObjectError() snapshot_id = self.obj.snapshotSWHID.object_id return self._get_snapshot_by_id(snapshot_id) diff --git a/swh/graphql/resolvers/visit.py b/swh/graphql/resolvers/visit.py --- a/swh/graphql/resolvers/visit.py +++ b/swh/graphql/resolvers/visit.py @@ -3,6 +3,7 @@ # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information +from swh.graphql.errors import NullableObjectError from swh.graphql.utils import utils from swh.storage.interface import PagedResult @@ -49,6 +50,11 @@ # self.obj.url is the origin URL return self.archive.get_origin_latest_visit(self.obj.url) + def _handle_node_errors(self) -> None: + # This object can be null + if self._node is None: + raise NullableObjectError() + class OriginVisitConnection(BaseConnection): """ diff --git a/swh/graphql/resolvers/visit_status.py b/swh/graphql/resolvers/visit_status.py --- 
a/swh/graphql/resolvers/visit_status.py +++ b/swh/graphql/resolvers/visit_status.py @@ -43,7 +43,7 @@ def _handle_node_errors(self) -> None: # This object can be null if self._node is None: - raise NullableObjectError("") + raise NullableObjectError() class VisitStatusConnection(BaseConnection): diff --git a/swh/graphql/tests/unit/resolvers/test_resolver_factory.py b/swh/graphql/tests/unit/resolvers/test_resolver_factory.py --- a/swh/graphql/tests/unit/resolvers/test_resolver_factory.py +++ b/swh/graphql/tests/unit/resolvers/test_resolver_factory.py @@ -9,46 +9,9 @@ class TestFactory: - @pytest.mark.parametrize( - "input_type, expected", - [ - ("origin", "OriginNode"), - ("visit", "OriginVisitNode"), - ("latest-visit", "LatestVisitNode"), - ("latest-status", "LatestVisitStatusNode"), - ("visit-snapshot", "VisitSnapshotNode"), - ("snapshot", "SnapshotNode"), - ("branch-revision", "TargetRevisionNode"), - ("branch-release", "TargetReleaseNode"), - ("branch-directory", "TargetDirectoryNode"), - ("branch-content", "TargetContentNode"), - ("branch-snapshot", "TargetSnapshotNode"), - ("revision", "RevisionNode"), - ("revision-directory", "RevisionDirectoryNode"), - ("release", "ReleaseNode"), - ("release-revision", "TargetRevisionNode"), - ("release-release", "TargetReleaseNode"), - ("release-directory", "TargetDirectoryNode"), - ("release-content", "TargetContentNode"), - ("directory", "DirectoryNode"), - ("content", "ContentNode"), - ("dir-entry-directory", "TargetDirectoryNode"), - ("dir-entry-content", "TargetContentNode"), - ("dir-entry-revision", "TargetRevisionNode"), - ("search-result-snapshot", "TargetSnapshotNode"), - ("search-result-revision", "TargetRevisionNode"), - ("search-result-release", "TargetReleaseNode"), - ("search-result-directory", "TargetDirectoryNode"), - ("search-result-content", "TargetContentNode"), - ], - ) - def test_get_node_resolver(self, input_type, expected): - response = resolver_factory.get_node_resolver(input_type) - assert 
response.__name__ == expected - def test_get_node_resolver_invalid_type(self): with pytest.raises(AttributeError): - resolver_factory.get_node_resolver("invalid") + resolver_factory.NodeObjectFactory().create("invalid", None, None) @pytest.mark.parametrize( "input_type, expected",