diff --git a/swh/graph/client.py b/swh/graph/client.py
--- a/swh/graph/client.py
+++ b/swh/graph/client.py
@@ -73,25 +73,39 @@
                 'direction': direction
             }))
 
-    def walk(self, src, dst,
-             edges="*", traversal="dfs", direction="forward", last=False):
-        endpoint = 'walk/last/{}/{}' if last else 'walk/{}/{}'
+    def walk(self, src, dst, edges="*", traversal="dfs",
+             direction="forward", limit=None):
+        """query the `walk/{src}/{dst}` endpoint
+
+        `limit` is sent along as a query parameter; the server keeps the
+        first `limit` nodes when it is positive, the last `-limit` nodes
+        when it is negative and every node when it is 0.
+        """
+        endpoint = 'walk/{}/{}'
         return self.get_lines(
             endpoint.format(src, dst),
             params={
                 'edges': edges,
                 'traversal': traversal,
-                'direction': direction
+                'direction': direction,
+                'limit': limit
             })
 
     def random_walk(self, src, dst,
-                    edges="*", direction="forward", last=False):
-        endpoint = 'randomwalk/last/{}/{}' if last else 'randomwalk/{}/{}'
+                    edges="*", direction="forward", limit=None):
+        """query the `randomwalk/{src}/{dst}` endpoint
+
+        `limit` is sent along as a query parameter; the server keeps the
+        first `limit` nodes when it is positive, the last `-limit` nodes
+        when it is negative and every node when it is 0.
+        """
+        endpoint = 'randomwalk/{}/{}'
         return self.get_lines(
             endpoint.format(src, dst),
             params={
                 'edges': edges,
-                'direction': direction
+                'direction': direction,
+                'limit': limit
             })
 
     def count_leaves(self, src, edges="*", direction="forward"):
diff --git a/swh/graph/server/app.py b/swh/graph/server/app.py
--- a/swh/graph/server/app.py
+++ b/swh/graph/server/app.py
@@ -11,6 +11,7 @@
 import asyncio
+from collections import deque
 import json
 import aiohttp.web
 
 from swh.core.api.asynchronous import RPCServerApp
 from swh.model.identifiers import PID_TYPES
@@ -70,8 +71,8 @@
     """validate HTTP query parameter `edges`, i.e., edge restrictions"""
     s = request.query.get('edges', '*')
     if any([node_type != '*' and node_type not in PID_TYPES
-                for edge in s.split(':')
-                for node_type in edge.split(',', maxsplit=1)]):
+            for edge in s.split(':')
+            for node_type in edge.split(',', maxsplit=1)]):
         raise aiohttp.web.HTTPBadRequest(body=f'invalid edge restriction: {s}')
     return s
 
@@ -84,6 +85,24 @@
     return s
 
 
+def get_limit(request):
+    """validate HTTP query parameter `limit`, i.e., number of results
+
+    The parameter defaults to 0.  The walk handlers read 0 as "no limit",
+    a positive value N as "the first N results" and a negative value -N as
+    "the last N results".
+
+    Raises aiohttp.web.HTTPBadRequest if the value is not an integer.
+    """
+    s = request.query.get('limit', '0')
+    try:
+        return int(s)
+    except ValueError:
+        # the int() traceback adds nothing to the HTTP error, drop it
+        raise aiohttp.web.HTTPBadRequest(
+            body=f'invalid limit value: {s}') from None
+
+
 def node_of_pid(pid, backend):
     """lookup a PID in a pid2node map, failing in an HTTP-nice way if needed"""
     try:
@@ -125,7 +144,7 @@
     return simple_traversal
 
 
-def get_walk_handler(random=False, last=False):
+def get_walk_handler(random=False):
     async def walk(request):
         backend = request.app['backend']
 
@@ -134,6 +153,7 @@
         edges = get_edges(request)
         direction = get_direction(request)
         algo = get_traversal(request)
+        limit = get_limit(request)
 
         src_node = node_of_pid(src, backend)
         if dst not in PID_TYPES:
@@ -144,14 +164,30 @@
                                          src_node, dst)
             else:
                 it = backend.walk(direction, edges, algo, src_node, dst)
-            res_node = None
-            async for res_node in it:
-                if not last:
-                    res_pid = pid_of_node(res_node, backend)
-                    await response.write('{}\n'.format(res_pid).encode())
-            if last and res_node is not None:
-                res_pid = pid_of_node(res_node, backend)
-                await response.write('{}\n'.format(res_pid).encode())
+
+            # `limit` picks the part of the walk sent to the client:
+            # negative -> last -limit nodes, positive -> first limit
+            # nodes, 0 -> every node
+            if limit < 0:
+                # the tail is only known once the walk is over; keep a
+                # bounded window of nodes and convert just the survivors
+                queue = deque(maxlen=-limit)
+                async for res_node in it:
+                    queue.append(res_node)
+                for res_node in queue:
+                    res_pid = pid_of_node(res_node, backend)
+                    await response.write('{}\n'.format(res_pid).encode())
+            else:
+                count = 0
+                async for res_node in it:
+                    res_pid = pid_of_node(res_node, backend)
+                    await response.write('{}\n'.format(res_pid).encode())
+                    count += 1
+                    # stop right after the limit-th node instead of
+                    # pulling one more node from the walk; count is
+                    # never 0 here, so limit == 0 keeps streaming
+                    if count == limit:
+                        break
             return response
 
     return walk
@@ -214,9 +250,7 @@
     #                    get_walk_handler(random=False, last=True))
 
     app.router.add_get('/graph/randomwalk/{src}/{dst}',
-                       get_walk_handler(random=True, last=False))
-    app.router.add_get('/graph/randomwalk/last/{src}/{dst}',
-                       get_walk_handler(random=True, last=True))
+                       get_walk_handler(random=True))
 
     app.router.add_get('/graph/neighbors/count/{src}',
                        get_count_handler('neighbors'))
diff --git a/swh/graph/tests/test_api_client.py b/swh/graph/tests/test_api_client.py
--- a/swh/graph/tests/test_api_client.py
+++ b/swh/graph/tests/test_api_client.py
@@ -115,13 +115,22 @@
     assert set(actual) == set(expected)
 
     kwargs2 = kwargs.copy()
-    kwargs2['last'] = True
+    kwargs2['limit'] = -1
     actual = list(graph_client.walk(*args, **kwargs2))
     expected = [
         'swh:1:rel:0000000000000000000000000000000000000019'
     ]
     assert set(actual) == set(expected)
 
+    kwargs2 = kwargs.copy()
+    kwargs2['limit'] = 2
+    actual = list(graph_client.walk(*args, **kwargs2))
+    expected = [
+        'swh:1:dir:0000000000000000000000000000000000000016',
+        'swh:1:dir:0000000000000000000000000000000000000017'
+    ]
+    assert set(actual) == set(expected)
+
 
 def test_random_walk(graph_client):
     """as the walk is random, we test a visit from a cnt node to the only origin in
@@ -138,10 +147,19 @@
     assert actual[-1] == expected_root
 
     kwargs2 = kwargs.copy()
-    kwargs2['last'] = True
+    kwargs2['limit'] = -1
     actual = list(graph_client.random_walk(*args, **kwargs2))
     assert actual == [expected_root]
 
+    kwargs2['limit'] = -2
+    actual = list(graph_client.random_walk(*args, **kwargs2))
+    assert len(actual) == 2
+    assert actual[-1] == expected_root
+
+    kwargs2['limit'] = 3
+    actual = list(graph_client.random_walk(*args, **kwargs2))
+    assert len(actual) == 3
+
 
 def test_count(graph_client):
     actual = graph_client.count_leaves(