Add a ``max_matching_nodes`` parameter to ``visit_nodes`` and
``count_visit_nodes`` in the HTTP client and in the naive client, document
the new query parameter in docs/api.rst, and add a parametrized client test.
In the naive client a value of 0 (the default) leaves the result unrestricted.

diff --git a/docs/api.rst b/docs/api.rst
--- a/docs/api.rst
+++ b/docs/api.rst
@@ -83,6 +83,8 @@
         default to 0 (not restricted)
     :query string return_types: only return the nodes matching this type;
         default to ``"*"``
+    :query integer max_matching_nodes: how many results to return before stopping;
+        default to 0 (not restricted)
 
     :statuscode 200: success
     :statuscode 400: invalid query string provided
@@ -207,6 +209,8 @@
         default to 0 (not restricted)
     :query string return_types: only return the nodes matching this type;
         default to ``"*"``
+    :query integer max_matching_nodes: how many nodes to return/visit before stopping;
+        default to 0 (not restricted)
 
     :statuscode 200: success
     :statuscode 400: invalid query string provided
diff --git a/swh/graph/http_client.py b/swh/graph/http_client.py
--- a/swh/graph/http_client.py
+++ b/swh/graph/http_client.py
@@ -85,7 +85,13 @@
         )
 
     def visit_nodes(
-        self, src, edges="*", direction="forward", max_edges=0, return_types="*"
+        self,
+        src,
+        edges="*",
+        direction="forward",
+        max_edges=0,
+        return_types="*",
+        max_matching_nodes=0,
     ):
         return self.get_lines(
             "visit/nodes/{}".format(src),
@@ -94,6 +100,7 @@
                 "direction": direction,
                 "max_edges": max_edges,
                 "return_types": return_types,
+                "max_matching_nodes": max_matching_nodes,
             },
         )
 
@@ -160,8 +167,14 @@
             params={"edges": edges, "direction": direction},
         )
 
-    def count_visit_nodes(self, src, edges="*", direction="forward"):
+    def count_visit_nodes(
+        self, src, edges="*", direction="forward", max_matching_nodes=0
+    ):
         return self.get(
             "visit/nodes/count/{}".format(src),
-            params={"edges": edges, "direction": direction},
+            params={
+                "edges": edges,
+                "direction": direction,
+                "max_matching_nodes": max_matching_nodes,
+            },
         )
diff --git a/swh/graph/http_naive_client.py b/swh/graph/http_naive_client.py
--- a/swh/graph/http_naive_client.py
+++ b/swh/graph/http_naive_client.py
@@ -190,11 +190,15 @@
         direction: str = "forward",
         max_edges: int = 0,
         return_types: str = "*",
+        max_matching_nodes: int = 0,
     ) -> Iterator[str]:
         # TODO: max_edges
-        yield from filter_node_types(
+        res = filter_node_types(
             return_types, self.graph.get_subgraph(src, edges, direction)
         )
+        if max_matching_nodes > 0:
+            res = itertools.islice(res, max_matching_nodes)
+        return res
 
     @check_arguments
     def visit_edges(
@@ -279,9 +283,16 @@
 
     @check_arguments
     def count_visit_nodes(
-        self, src: str, edges: str = "*", direction: str = "forward"
+        self,
+        src: str,
+        edges: str = "*",
+        direction: str = "forward",
+        max_matching_nodes: int = 0,
     ) -> int:
-        return len(self.graph.get_subgraph(src, edges, direction))
+        res = len(self.graph.get_subgraph(src, edges, direction))
+        if max_matching_nodes > 0:
+            res = min(max_matching_nodes, res)
+        return res
 
 
 class Graph:
diff --git a/swh/graph/tests/test_http_client.py b/swh/graph/tests/test_http_client.py
--- a/swh/graph/tests/test_http_client.py
+++ b/swh/graph/tests/test_http_client.py
@@ -91,6 +91,27 @@
     assert set(actual) == set(expected)
 
 
+@pytest.mark.parametrize("max_matching_nodes", [0, 1, 2, 3, 4, 5, 10, 1 << 31])
+def test_visit_nodes_limit(graph_client, max_matching_nodes):
+    actual = list(
+        graph_client.visit_nodes(
+            "swh:1:rel:0000000000000000000000000000000000000010",
+            edges="rel:rev,rev:rev",
+            max_matching_nodes=max_matching_nodes,
+        )
+    )
+    expected = [
+        "swh:1:rel:0000000000000000000000000000000000000010",
+        "swh:1:rev:0000000000000000000000000000000000000009",
+        "swh:1:rev:0000000000000000000000000000000000000003",
+    ]
+    if max_matching_nodes == 0:
+        assert set(actual) == set(expected)
+    else:
+        assert set(actual) <= set(expected)
+        assert len(actual) == min(3, max_matching_nodes)
+
+
 def test_visit_nodes_filtered(graph_client):
     actual = list(
         graph_client.visit_nodes(