diff --git a/swh/graphql/backends/archive.py b/swh/graphql/backends/archive.py --- a/swh/graphql/backends/archive.py +++ b/swh/graphql/backends/archive.py @@ -38,7 +38,7 @@ return self.storage.origin_snapshot_get_all(origin_url) def get_snapshot_branches( - self, snapshot, after=b"", first=50, target_types=[], name_include=None + self, snapshot, after=b"", first=50, target_types=None, name_include=None ): return self.storage.snapshot_get_branches( snapshot, diff --git a/swh/graphql/resolvers/content.py b/swh/graphql/resolvers/content.py --- a/swh/graphql/resolvers/content.py +++ b/swh/graphql/resolvers/content.py @@ -10,7 +10,7 @@ from .base_node import BaseSWHNode from .directory_entry import DirectoryEntryNode from .release import BaseReleaseNode -from .snapshot_branch import SnapshotBranchNode +from .snapshot_branch import BaseSnapshotBranchNode class BaseContentNode(BaseSWHNode): @@ -89,7 +89,7 @@ This request could be from directory entry, release or a branch """ - obj: Union[DirectoryEntryNode, BaseReleaseNode, SnapshotBranchNode] + obj: Union[DirectoryEntryNode, BaseReleaseNode, BaseSnapshotBranchNode] def _get_node_data(self): return self._get_content_by_hash(checksums={"sha1_git": self.obj.target_hash}) diff --git a/swh/graphql/resolvers/directory.py b/swh/graphql/resolvers/directory.py --- a/swh/graphql/resolvers/directory.py +++ b/swh/graphql/resolvers/directory.py @@ -12,7 +12,7 @@ from .base_node import BaseSWHNode from .release import BaseReleaseNode from .revision import BaseRevisionNode -from .snapshot_branch import SnapshotBranchNode +from .snapshot_branch import BaseSnapshotBranchNode class BaseDirectoryNode(BaseSWHNode): @@ -69,7 +69,7 @@ from .directory_entry import DirectoryEntryNode - obj: Union[SnapshotBranchNode, BaseReleaseNode, DirectoryEntryNode] + obj: Union[BaseSnapshotBranchNode, BaseReleaseNode, DirectoryEntryNode] def _get_node_data(self): return self._get_directory_by_id(self.obj.target_hash) diff --git a/swh/graphql/resolvers/release.py b/swh/graphql/resolvers/release.py --- a/swh/graphql/resolvers/release.py +++ b/swh/graphql/resolvers/release.py @@ -8,7 +8,7 @@ from swh.graphql.backends import archive from .base_node import BaseSWHNode -from .snapshot_branch import SnapshotBranchNode +from .snapshot_branch import BaseSnapshotBranchNode class BaseReleaseNode(BaseSWHNode): @@ -47,7 +47,7 @@ Node resolver for a release requested as a target """ - obj: Union[SnapshotBranchNode, BaseReleaseNode] + obj: Union[BaseSnapshotBranchNode, BaseReleaseNode] def _get_node_data(self): # self.obj.target_hash is the requested release id diff --git a/swh/graphql/resolvers/resolver_factory.py b/swh/graphql/resolvers/resolver_factory.py --- a/swh/graphql/resolvers/resolver_factory.py +++ b/swh/graphql/resolvers/resolver_factory.py @@ -21,7 +21,7 @@ TargetSnapshotNode, VisitSnapshotNode, ) -from .snapshot_branch import SnapshotBranchConnection +from .snapshot_branch import AliasSnapshotBranchNode, SnapshotBranchConnection from .visit import LatestVisitNode, OriginVisitConnection, OriginVisitNode from .visit_status import LatestVisitStatusNode, VisitStatusConnection @@ -35,6 +35,7 @@ "latest-status": LatestVisitStatusNode, "visit-snapshot": VisitSnapshotNode, "snapshot": SnapshotNode, + "branch-alias": AliasSnapshotBranchNode, "branch-revision": TargetRevisionNode, "branch-release": TargetReleaseNode, "branch-directory": TargetDirectoryNode, diff --git a/swh/graphql/resolvers/resolvers.py b/swh/graphql/resolvers/resolvers.py --- a/swh/graphql/resolvers/resolvers.py +++ b/swh/graphql/resolvers/resolvers.py @@ -99,10 +99,11 @@ @snapshot_branch.field("target") def snapshot_branch_target_resolver( - obj: rs.snapshot_branch.SnapshotBranchNode, info: GraphQLResolveInfo, **kw + obj: rs.snapshot_branch.BaseSnapshotBranchNode, info: GraphQLResolveInfo, **kw ): """ - Snapshot branch target can be a revision or a release + Snapshot branch target can be a revision, release, directory, + content, snapshot or a branch itself (alias type) """ resolver_type = f"branch-{obj.type}" resolver = get_node_resolver(resolver_type) diff --git a/swh/graphql/resolvers/revision.py b/swh/graphql/resolvers/revision.py --- a/swh/graphql/resolvers/revision.py +++ b/swh/graphql/resolvers/revision.py @@ -13,7 +13,7 @@ from .base_connection import BaseConnection from .base_node import BaseSWHNode from .release import BaseReleaseNode -from .snapshot_branch import SnapshotBranchNode +from .snapshot_branch import BaseSnapshotBranchNode class BaseRevisionNode(BaseSWHNode): @@ -59,7 +59,7 @@ Node resolver for a revision requested as a target """ - obj: Union[SnapshotBranchNode, BaseReleaseNode] + obj: Union[BaseSnapshotBranchNode, BaseReleaseNode] def _get_node_data(self): # self.obj.target_hash is the requested revision id diff --git a/swh/graphql/resolvers/snapshot.py b/swh/graphql/resolvers/snapshot.py --- a/swh/graphql/resolvers/snapshot.py +++ b/swh/graphql/resolvers/snapshot.py @@ -70,9 +70,9 @@ Node resolver for a snapshot requested as a target """ - from .snapshot_branch import SnapshotBranchNode + from .snapshot_branch import BaseSnapshotBranchNode - obj: Union[BaseVisitStatusNode, SnapshotBranchNode] + obj: Union[BaseVisitStatusNode, BaseSnapshotBranchNode] def _get_node_data(self): snapshot_id = self.obj.target_hash diff --git a/swh/graphql/resolvers/snapshot_branch.py b/swh/graphql/resolvers/snapshot_branch.py --- a/swh/graphql/resolvers/snapshot_branch.py +++ b/swh/graphql/resolvers/snapshot_branch.py @@ -6,6 +6,7 @@ from collections import namedtuple from swh.graphql.backends import archive +from swh.graphql.errors import ObjectNotFoundError from swh.graphql.utils import utils from swh.storage.interface import PagedResult @@ -13,30 +14,71 @@ from .base_node import BaseNode -class SnapshotBranchNode(BaseNode): - """ - Node resolver for a snapshot branch - """ +class BaseSnapshotBranchNode(BaseNode): - # target field for this Node is a UNION type + # target field for this node is a UNION type # It is resolved in the top level (resolvers.resolvers.py) - def _get_node_from_data(self, node_data): - # node_data is not a dict in this case + def _get_node_from_data(self, node_data: tuple): + # node_data is a tuple as returned by _get_paged_result in + # SnapshotBranchConnection and _get_node_data in AliasSnapshotBranchNode # overriding to support this special data structure - - # STORAGE-TODO; return an object in the normal format branch_name, branch_obj = node_data node = { "name": branch_name, "type": branch_obj.target_type.value, - "target": branch_obj.target, + "target_hash": branch_obj.target, } return namedtuple("NodeObj", node.keys())(*node.values()) - @property - def target_hash(self): - return self._node.target + def is_type_of(self): + return "Branch" + + def snapshot_swhid(self): + raise NotImplementedError("Implement snapshot_swhid") + + +class ConnectionSnapshotBranchNode(BaseSnapshotBranchNode): + """ + Node resolver for a snapshot branch requested from a snapshot branch connection + """ + + # obj: SnapshotBranchConnection + + def snapshot_swhid(self): + # self.obj is SnapshotBranchConnection. + # hence self.obj.obj is always of type BaseSnapshotNode + + # This will fail when this node is used for a connection that directly + # requests snapshot branches with a snapshot SWHID. Create a new node object + # in that case + return self.obj.obj.swhid + + +class AliasSnapshotBranchNode(BaseSnapshotBranchNode): + + obj: ConnectionSnapshotBranchNode + + def _get_node_data(self): + # snapshot_swhid will be provided by the parent object (self.obj) + # As of now ConnectionSnapshotBranchNode is the only possible parent + # implement snapshot_swhid in each of them if you are planning to add more parents. + # eg for another possible parent: A node class that can get a snapshot branch directly + # using snapshot id and branch name, snapshot_swhid will be available in the + # user input (kwargs) in that case + + snapshot_swhid = self.obj.snapshot_swhid() + target_branch = self.obj.target_hash + + alias_branch = archive.Archive().get_snapshot_branches( + snapshot_swhid.object_id, first=1, name_include=target_branch + ) + if target_branch not in alias_branch["branches"]: + raise ObjectNotFoundError( + f"Branch name with {target_branch.decode()} is not available" + ) + # this will be serialized in _get_node_from_data method in the base class + return (target_branch, alias_branch["branches"][target_branch]) class SnapshotBranchConnection(BaseConnection): @@ -44,11 +86,11 @@ Connection resolver for the branches in a snapshot """ - from .snapshot import SnapshotNode + from .snapshot import BaseSnapshotNode - obj: SnapshotNode + obj: BaseSnapshotNode - _node_class = SnapshotBranchNode + _node_class = ConnectionSnapshotBranchNode def _get_paged_result(self) -> PagedResult: result = archive.Archive().get_snapshot_branches( @@ -58,7 +100,6 @@ target_types=self.kwargs.get("types"), name_include=self._get_name_include_arg(), ) - # endCursor is the last branch name, logic for that end_cusrsor = ( result["next_branch"] if result["next_branch"] is not None else None @@ -66,6 +107,8 @@ # FIXME, this pagination is not consistent with other connections # FIX in swh-storage to return PagedResult # STORAGE-TODO + + # this will be serialized in _get_node_from_data method in the node class return PagedResult( results=result["branches"].items(), next_page_token=end_cusrsor ) @@ -79,6 +122,6 @@ name_include = self.kwargs.get("nameInclude", None) return name_include.encode() if name_include else None - def _get_index_cursor(self, index: int, node: SnapshotBranchNode): + def _get_index_cursor(self, index: int, node: ConnectionSnapshotBranchNode): # Snapshot branch is using a different cursor, hence the override return utils.get_encoded_cursor(node.name) diff --git a/swh/graphql/tests/functional/test_branch_connection.py b/swh/graphql/tests/functional/test_branch_connection.py --- a/swh/graphql/tests/functional/test_branch_connection.py +++ b/swh/graphql/tests/functional/test_branch_connection.py @@ -28,6 +28,11 @@ } target { __typename + ...on Branch { + name { + text + } + } ...on Revision { swhid } @@ -55,15 +60,6 @@ return utils.get_query_response(client, query_str) -def test_get(client): - swhid = "swh:1:snp:0e7f84ede9a254f2cd55649ad5240783f557e65f" - data, errors = get_branches(client, swhid, 10) - # Alias type is not handled at the moment, hence the error - assert len(errors) == 1 - assert errors[0]["message"] == "Invalid node type: branch-alias" - assert len(data["snapshot"]["branches"]["nodes"]) == 5 - - def test_get_data(client): swhid = "swh:1:snp:0e7f84ede9a254f2cd55649ad5240783f557e65f" data, errors = get_branches(client, swhid, 10, types="[revision]") @@ -80,6 +76,17 @@ } +def test_get_branches_with_alias(client): + swhid = "swh:1:snp:0e7f84ede9a254f2cd55649ad5240783f557e65f" + data, _ = get_branches(client, swhid, 10, types="[alias]") + node = data["snapshot"]["branches"]["nodes"][0] + assert node == { + "name": {"text": "target/alias"}, + "target": {"__typename": "Branch", "name": {"text": "target/revision"}}, + "type": "alias", + } + + @pytest.mark.parametrize( "filter_type, count, target_type, swhid_pattern", [