diff --git a/swh/graph/client.py b/swh/graph/client.py
--- a/swh/graph/client.py
+++ b/swh/graph/client.py
@@ -73,25 +73,27 @@
                 'direction': direction
             }))
 
-    def walk(self, src, dst,
-             edges="*", traversal="dfs", direction="forward", last=False):
+    def walk(self, src, dst, edges="*", traversal="dfs",
+             direction="forward", last=False, limit=None):
         endpoint = 'walk/last/{}/{}' if last else 'walk/{}/{}'
         return self.get_lines(
             endpoint.format(src, dst),
             params={
                 'edges': edges,
                 'traversal': traversal,
-                'direction': direction
+                'direction': direction,
+                'limit': limit
             })
 
     def random_walk(self, src, dst,
-                    edges="*", direction="forward", last=False):
+                    edges="*", direction="forward", last=False, limit=None):
         endpoint = 'randomwalk/last/{}/{}' if last else 'randomwalk/{}/{}'
         return self.get_lines(
             endpoint.format(src, dst),
             params={
                 'edges': edges,
-                'direction': direction
+                'direction': direction,
+                'limit': limit
             })
 
     def count_leaves(self, src, edges="*", direction="forward"):
diff --git a/swh/graph/server/app.py b/swh/graph/server/app.py
--- a/swh/graph/server/app.py
+++ b/swh/graph/server/app.py
@@ -70,8 +70,8 @@
     """validate HTTP query parameter `edges`, i.e., edge restrictions"""
     s = request.query.get('edges', '*')
     if any([node_type != '*' and node_type not in PID_TYPES
-           for edge in s.split(':')
-           for node_type in edge.split(',', maxsplit=1)]):
+            for edge in s.split(':')
+            for node_type in edge.split(',', maxsplit=1)]):
         raise aiohttp.web.HTTPBadRequest(body=f'invalid edge restriction: {s}')
     return s
 
@@ -84,6 +84,16 @@
     return s
 
 
+def get_limit(request):
+    """validate HTTP query parameter `limit`, i.e., maximum number of results
+    (a negative value, the default, means no limit)"""
+    s = request.query.get('limit', -1)
+    try:
+        return int(s)
+    except ValueError:
+        raise aiohttp.web.HTTPBadRequest(body=f'invalid limit value: {s}')
+
+
 def node_of_pid(pid, backend):
     """lookup a PID in a pid2node map, failing in an HTTP-nice way if needed"""
     try:
@@ -134,6 +144,7 @@
         edges = get_edges(request)
         direction = get_direction(request)
         algo = get_traversal(request)
+        limit = get_limit(request)
 
         src_node = node_of_pid(src, backend)
         if dst not in PID_TYPES:
@@ -145,10 +156,15 @@
             else:
                 it = backend.walk(direction, edges, algo, src_node, dst)
             res_node = None
+            count = 0
+
             async for res_node in it:
-                if not last:
+                # NOTE: the iterator is drained past `limit` (no early break),
+                # presumably so the backend stream always runs to completion
+                if not last and (limit < 0 or count < limit):
                     res_pid = pid_of_node(res_node, backend)
                     await response.write('{}\n'.format(res_pid).encode())
+                    count += 1
             if last and res_node is not None:
                 res_pid = pid_of_node(res_node, backend)
                 await response.write('{}\n'.format(res_pid).encode())
diff --git a/swh/graph/tests/test_api_client.py b/swh/graph/tests/test_api_client.py
--- a/swh/graph/tests/test_api_client.py
+++ b/swh/graph/tests/test_api_client.py
@@ -122,6 +122,22 @@
     ]
     assert set(actual) == set(expected)
 
+    kwargs2 = kwargs.copy()
+    kwargs2['limit'] = 2
+    actual = list(graph_client.walk(*args, **kwargs2))
+    expected = [
+        'swh:1:dir:0000000000000000000000000000000000000016',
+        'swh:1:dir:0000000000000000000000000000000000000017'
+    ]
+    assert set(actual) == set(expected)
+
+    # a limit larger than the number of results must not be an error
+    kwargs3 = kwargs.copy()
+    kwargs3['limit'] = 100
+    actual_unlimited = list(graph_client.walk(*args, **kwargs))
+    actual = list(graph_client.walk(*args, **kwargs3))
+    assert set(actual) == set(actual_unlimited)
+
 
 def test_random_walk(graph_client):
     """as the walk is random, we test a visit from a cnt node to the only origin in
@@ -142,6 +158,11 @@
     actual = list(graph_client.random_walk(*args, **kwargs2))
     assert actual == [expected_root]
 
+    kwargs3 = kwargs.copy()
+    kwargs3['limit'] = 3
+    actual = list(graph_client.random_walk(*args, **kwargs3))
+    assert len(actual) == 3
+
 
 def test_count(graph_client):
     actual = graph_client.count_leaves(