diff --git a/swh/graphql/backends/archive.py b/swh/graphql/backends/archive.py --- a/swh/graphql/backends/archive.py +++ b/swh/graphql/backends/archive.py @@ -46,7 +46,9 @@ def is_snapshot_available(self, snapshot_ids): return not self.storage.snapshot_missing(snapshot_ids) - def get_snapshot_branches(self, snapshot, after, first, target_types, name_include): + def get_snapshot_branches( + self, snapshot, after=b"", first=50, target_types=[], name_include=None + ): return self.storage.snapshot_get_branches( snapshot, branches_from=after, diff --git a/swh/graphql/resolvers/content.py b/swh/graphql/resolvers/content.py --- a/swh/graphql/resolvers/content.py +++ b/swh/graphql/resolvers/content.py @@ -10,6 +10,7 @@ from .base_node import BaseSWHNode from .directory_entry import DirectoryEntryNode from .release import BaseReleaseNode +from .snapshot_branch import SnapshotBranchNode class BaseContentNode(BaseSWHNode): @@ -51,7 +52,7 @@ directory entry or from a release target """ - obj: Union[DirectoryEntryNode, BaseReleaseNode] + obj: Union[DirectoryEntryNode, BaseReleaseNode, SnapshotBranchNode] def _get_node_data(self): content_id = self.obj.targetHash diff --git a/swh/graphql/resolvers/directory.py b/swh/graphql/resolvers/directory.py --- a/swh/graphql/resolvers/directory.py +++ b/swh/graphql/resolvers/directory.py @@ -3,11 +3,15 @@ # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information +from typing import Union + from swh.graphql.backends import archive from swh.model.model import Directory from .base_node import BaseSWHNode +from .release import BaseReleaseNode from .revision import BaseRevisionNode +from .snapshot_branch import SnapshotBranchNode class BaseDirectoryNode(BaseSWHNode): @@ -58,5 +62,9 @@ Node resolver for a directory requested as a target """ + from .directory_entry import DirectoryEntryNode + + obj: Union[SnapshotBranchNode, BaseReleaseNode, DirectoryEntryNode] + def _get_node_data(self): return self._get_directory_by_id(self.obj.targetHash) diff --git a/swh/graphql/resolvers/directory_entry.py b/swh/graphql/resolvers/directory_entry.py --- a/swh/graphql/resolvers/directory_entry.py +++ b/swh/graphql/resolvers/directory_entry.py @@ -8,7 +8,6 @@ from .base_connection import BaseConnection from .base_node import BaseNode -from .directory import BaseDirectoryNode class DirectoryEntryNode(BaseNode): @@ -26,6 +25,8 @@ Connection resolver for entries in a directory """ + from .directory import BaseDirectoryNode + obj: BaseDirectoryNode _node_class = DirectoryEntryNode diff --git a/swh/graphql/resolvers/resolver_factory.py b/swh/graphql/resolvers/resolver_factory.py --- a/swh/graphql/resolvers/resolver_factory.py +++ b/swh/graphql/resolvers/resolver_factory.py @@ -14,7 +14,12 @@ RevisionNode, TargetRevisionNode, ) -from .snapshot import OriginSnapshotConnection, SnapshotNode, VisitSnapshotNode +from .snapshot import ( + OriginSnapshotConnection, + SnapshotNode, + TargetSnapshotNode, + VisitSnapshotNode, +) from .snapshot_branch import SnapshotBranchConnection from .visit import LatestVisitNode, OriginVisitConnection, OriginVisitNode from .visit_status import LatestVisitStatusNode, VisitStatusConnection @@ -31,6 +36,9 @@ "snapshot": SnapshotNode, "branch-revision": TargetRevisionNode, "branch-release": TargetReleaseNode, + "branch-directory": TargetDirectoryNode, + "branch-content": TargetContentNode, + "branch-snapshot": TargetSnapshotNode, "revision": RevisionNode, "revision-directory": RevisionDirectoryNode, "release": ReleaseNode, diff --git a/swh/graphql/resolvers/snapshot.py b/swh/graphql/resolvers/snapshot.py --- a/swh/graphql/resolvers/snapshot.py +++ b/swh/graphql/resolvers/snapshot.py @@ -3,6 +3,8 @@ # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information +from typing import Union + from swh.graphql.backends import archive from swh.graphql.utils import utils from swh.model.model import Snapshot @@ -24,6 +26,11 @@ # Same pattern is used in directory return Snapshot(id=snapshot_id, branches={}) + def is_type_of(self): + # is_type_of is required only when resolving a UNION type + # This is for ariadne to return the right type + return "Snapshot" + class SnapshotNode(BaseSnapshotNode): """ @@ -51,6 +58,20 @@ return self._get_snapshot_by_id(snapshot_id) +class TargetSnapshotNode(BaseSnapshotNode): + """ + Node resolver for a snapshot requested as a target + """ + + from .snapshot_branch import SnapshotBranchNode + + obj: Union[BaseVisitStatusNode, SnapshotBranchNode] + + def _get_node_data(self): + snapshot_id = self.obj.targetHash + return self._get_snapshot_by_id(snapshot_id) + + class OriginSnapshotConnection(BaseConnection): """ Connection resolver for the snapshots in an origin diff --git a/swh/graphql/resolvers/snapshot_branch.py b/swh/graphql/resolvers/snapshot_branch.py --- a/swh/graphql/resolvers/snapshot_branch.py +++ b/swh/graphql/resolvers/snapshot_branch.py @@ -11,7 +11,6 @@ from .base_connection import BaseConnection from .base_node import BaseNode -from .snapshot import SnapshotNode class SnapshotBranchNode(BaseNode): @@ -45,6 +44,8 @@ Connection resolver for the branches in a snapshot """ + from .snapshot import SnapshotNode + obj: SnapshotNode _node_class = SnapshotBranchNode @@ -56,12 +57,12 @@ after=self._get_after_arg(), first=self._get_first_arg(), target_types=self.kwargs.get("types"), - name_include=self.kwargs.get("nameInclude"), + name_include=self._get_name_include_arg(), ) # FIXME Cursor must be a hex to be consistent with # the base class, hack to make that work end_cusrsor = ( - result["next_branch"].hex() if result["next_branch"] is not None else None + result["next_branch"] if result["next_branch"] is not None else None ) # FIXME, this pagination is not consistent with other connections # FIX in swh-storage to return PagedResult @@ -72,11 +73,12 @@ def _get_after_arg(self): # Snapshot branch is using a different cursor; logic to handle that - - # FIXME Cursor must be a hex to be consistent with - # the base class, hack to make that work after = utils.get_decoded_cursor(self.kwargs.get("after", "")) - return bytes.fromhex(after) + return after.encode() if after else b"" + + def _get_name_include_arg(self): + name_include = self.kwargs.get("nameInclude", None) + return name_include.encode() if name_include else None def _get_index_cursor(self, index: int, node: SnapshotBranchNode): # Snapshot branch is using a different cursor, hence the override diff --git a/swh/graphql/tests/functional/test_branch_connection.py b/swh/graphql/tests/functional/test_branch_connection.py new file mode 100644 --- /dev/null +++ b/swh/graphql/tests/functional/test_branch_connection.py @@ -0,0 +1,204 @@ +# Copyright (C) 2022 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + +import pytest + +from .utils import get_query_response + + +def test_get(client): + query_str = """ + { + snapshot(swhid: "swh:1:snp:0e7f84ede9a254f2cd55649ad5240783f557e65f") { + branches(first:10) { + nodes { + type + target { + __typename + ...on Revision { + swhid + } + ...on Release { + swhid + } + ...on Content { + swhid + } + ...on Directory { + swhid + } + ...on Snapshot { + swhid + } + } + } + } + } + } + """ + data, errors = get_query_response(client, query_str) + # Alias type is not handled at the moment, hence the error + assert len(errors) == 1 + assert errors[0]["message"] == "Invalid node type: branch-alias" + assert len(data["snapshot"]["branches"]["nodes"]) == 5 + + +@pytest.mark.parametrize( + "filter_type, count, target_type, swhid_pattern", + [ + ("revision", 1, "Revision", "swh:1:rev"), + ("release", 1, "Release", "swh:1:rel"), + ("directory", 1, "Directory", "swh:1:dir"), + ("content", 0, "Content", "swh:1:cnt"), + ("snapshot", 1, "Snapshot", "swh:1:snp"), + ], +) +def test_get_type_filter(client, filter_type, count, target_type, swhid_pattern): + query_str = ( + """ + { + snapshot(swhid: "swh:1:snp:0e7f84ede9a254f2cd55649ad5240783f557e65f") { + branches(first:10, types: [%s]) { + nodes { + type + target { + __typename + ...on Revision { + swhid + } + ...on Release { + swhid + } + ...on Content { + swhid + } + ...on Directory { + swhid + } + ...on Snapshot { + swhid + } + } + } + } + } + } + """ + % filter_type + ) + data, _ = get_query_response(client, query_str) + + assert len(data["snapshot"]["branches"]["nodes"]) == count + for node in data["snapshot"]["branches"]["nodes"]: + assert node["target"]["__typename"] == target_type + assert node["target"]["swhid"].startswith(swhid_pattern) + + +@pytest.mark.parametrize( + "filter_types, count", + [ + ("revision, release", 2), + ("revision, snapshot, release", 3), + ], +) +def test_get_type_filter_multiple(client, filter_types, count): + query_str = ( + """ + { + snapshot(swhid: "swh:1:snp:0e7f84ede9a254f2cd55649ad5240783f557e65f") { + branches(first:10, types: [%s]) { + nodes { + type + } + } + } + }""" + % filter_types + ) + data, _ = get_query_response(client, query_str) + assert len(data["snapshot"]["branches"]["nodes"]) == count + + +@pytest.mark.parametrize("name", ["rel", "rev", "non-exist"]) +def test_get_name_include_filter(client, name): + query_str = ( + """ + { + snapshot(swhid: "swh:1:snp:0e7f84ede9a254f2cd55649ad5240783f557e65f") { + branches(first:10, nameInclude: "%s") { + nodes { + name { + text + } + } + } + } + }""" + % name + ) + data, _ = get_query_response(client, query_str) + for node in data["snapshot"]["branches"]["nodes"]: + assert name in node["name"]["text"] + + +@pytest.mark.parametrize("count", [1, 2]) +def test_get_first_arg(client, count): + query_str = ( + """ + { + snapshot(swhid: "swh:1:snp:0e7f84ede9a254f2cd55649ad5240783f557e65f") { + branches(first: %s) { + nodes { + type + } + } + } + }""" + % count + ) + data, _ = get_query_response(client, query_str) + assert len(data["snapshot"]["branches"]["nodes"]) == count + + +def test_get_after_arg(client): + query_str = """ + { + snapshot(swhid: "swh:1:snp:0e7f84ede9a254f2cd55649ad5240783f557e65f") { + branches(first: 1) { + pageInfo { + endCursor + } + nodes { + name { + text + } + } + } + } + }""" + first_data, _ = get_query_response(client, query_str) + end_cursor = first_data["snapshot"]["branches"]["pageInfo"]["endCursor"] + node_name = first_data["snapshot"]["branches"]["nodes"][0]["name"]["text"] + + query_str = ( + """ + { + snapshot(swhid: "swh:1:snp:0e7f84ede9a254f2cd55649ad5240783f557e65f") { + branches(first: 3, after: "%s") { + nodes { + type + name { + text + } + } + } + } + }""" + % end_cursor + ) + second_data, _ = get_query_response(client, query_str) + assert len(second_data["snapshot"]["branches"]["nodes"]) == 3 + for node in second_data["snapshot"]["branches"]["nodes"]: + assert node["name"]["text"] > node_name