diff --git a/swh/graph/http_client.py b/swh/graph/http_client.py --- a/swh/graph/http_client.py +++ b/swh/graph/http_client.py @@ -52,7 +52,13 @@ return self.get("stats") def leaves( - self, src, edges="*", direction="forward", max_edges=0, return_types="*" + self, + src, + edges="*", + direction="forward", + max_edges=0, + return_types="*", + max_matching_nodes=0, ): return self.get_lines( "leaves/{}".format(src), @@ -61,6 +67,7 @@ "direction": direction, "max_edges": max_edges, "return_types": return_types, + "max_matching_nodes": max_matching_nodes, }, ) @@ -137,10 +144,14 @@ }, ) - def count_leaves(self, src, edges="*", direction="forward"): + def count_leaves(self, src, edges="*", direction="forward", max_matching_nodes=0): return self.get( "leaves/count/{}".format(src), - params={"edges": edges, "direction": direction}, + params={ + "edges": edges, + "direction": direction, + "max_matching_nodes": max_matching_nodes, + }, ) def count_neighbors(self, src, edges="*", direction="forward"): diff --git a/swh/graph/http_naive_client.py b/swh/graph/http_naive_client.py --- a/swh/graph/http_naive_client.py +++ b/swh/graph/http_naive_client.py @@ -1,10 +1,11 @@ -# Copyright (C) 2021 The Software Heritage developers +# Copyright (C) 2021-2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import functools import inspect +import itertools import re import statistics from typing import ( @@ -150,9 +151,10 @@ direction: str = "forward", max_edges: int = 0, return_types: str = "*", + max_matching_nodes: int = 0, ) -> Iterator[str]: # TODO: max_edges - yield from filter_node_types( + leaves = filter_node_types( return_types, [ node @@ -161,6 +163,11 @@ ], ) + if max_matching_nodes > 0: + leaves = itertools.islice(leaves, max_matching_nodes) + + return leaves + @check_arguments def neighbors( self, @@ -250,9 +257,19 @@ @check_arguments def count_leaves( - self, src: str, edges: str = "*", direction: str = "forward" + self, + src: str, + edges: str = "*", + direction: str = "forward", + max_matching_nodes: int = 0, ) -> int: - return len(list(self.leaves(src, edges, direction))) + return len( + list( + self.leaves( + src, edges, direction, max_matching_nodes=max_matching_nodes + ) + ) + ) @check_arguments def count_neighbors( diff --git a/swh/graph/http_server.py b/swh/graph/http_server.py --- a/swh/graph/http_server.py +++ b/swh/graph/http_server.py @@ -141,13 +141,15 @@ else: return s - def get_limit(self): - """Validate HTTP query parameter `limit`, i.e., number of results""" - s = self.request.query.get("limit", "0") + def get_max_matching_nodes(self): + """Validate HTTP query parameter `max_matching_nodes`, i.e., number of results""" + s = self.request.query.get("max_matching_nodes", "0") try: return int(s) except ValueError: - raise aiohttp.web.HTTPBadRequest(text=f"invalid limit value: {s}") + raise aiohttp.web.HTTPBadRequest( + text=f"invalid max_matching_nodes value: {s}" + ) def get_max_edges(self): """Validate HTTP query parameter 'max_edges', i.e., @@ -246,6 +248,7 @@ direction=self.get_direction(), return_nodes=NodeFilter(types=self.get_return_types()), mask=FieldMask(paths=["swhid"]), + max_matching_nodes=self.get_max_matching_nodes(), ) if self.get_max_edges(): self.traversal_request.max_edges = self.get_max_edges() @@ -304,6 +307,7 @@ direction=self.get_direction(), return_nodes=NodeFilter(types=self.get_return_types()), mask=FieldMask(paths=["swhid"]), + max_matching_nodes=self.get_max_matching_nodes(), ) if self.get_max_edges(): self.traversal_request.max_edges = self.get_max_edges() diff --git a/swh/graph/tests/test_http_client.py b/swh/graph/tests/test_http_client.py --- a/swh/graph/tests/test_http_client.py +++ b/swh/graph/tests/test_http_client.py @@ -43,6 +43,25 @@ assert set(actual) == set(expected) +@pytest.mark.parametrize("max_matching_nodes", [0, 1, 2, 3, 4, 5, 10, 1 << 31]) +def test_leaves_with_limit(graph_client, max_matching_nodes): + actual = list( + graph_client.leaves(TEST_ORIGIN_ID, max_matching_nodes=max_matching_nodes) + ) + expected = [ + "swh:1:cnt:0000000000000000000000000000000000000001", + "swh:1:cnt:0000000000000000000000000000000000000004", + "swh:1:cnt:0000000000000000000000000000000000000005", + "swh:1:cnt:0000000000000000000000000000000000000007", + ] + + if max_matching_nodes == 0: + assert set(actual) == set(expected) + else: + assert set(actual) <= set(expected) + assert len(actual) == min(4, max_matching_nodes) + + def test_neighbors(graph_client): actual = list( graph_client.neighbors( @@ -326,6 +345,17 @@ assert actual == 3 +@pytest.mark.parametrize("max_matching_nodes", [0, 1, 2, 3, 4, 5, 10, 1 << 31]) +def test_count_with_limit(graph_client, max_matching_nodes): + actual = graph_client.count_leaves( + TEST_ORIGIN_ID, max_matching_nodes=max_matching_nodes + ) + if max_matching_nodes == 0: + assert actual == 4 + else: + assert actual == min(4, max_matching_nodes) + + def test_param_validation(graph_client): with raises(GraphArgumentException) as exc_info: # SWHID not found list(graph_client.leaves("swh:1:rel:00ffffffff000000000000000000000000000010"))