diff --git a/swh/graph/client.py b/swh/graph/client.py
index f0f8edc..95379b1 100644
--- a/swh/graph/client.py
+++ b/swh/graph/client.py
@@ -1,119 +1,133 @@
 # Copyright (C) 2019 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
 import json
 
 from swh.core.api import RPCClient
 
 
 class GraphAPIError(Exception):
     """Graph API Error"""
     def __str__(self):
         return ('An unexpected error occurred in the Graph backend: {}'
                 .format(self.args))
 
 
 class RemoteGraphClient(RPCClient):
     """Client to the Software Heritage Graph."""
 
     def __init__(self, url, timeout=None):
         super().__init__(
             api_exception=GraphAPIError, url=url, timeout=timeout)
 
     def raw_verb_lines(self, verb, endpoint, **kwargs):
         response = self.raw_verb(verb, endpoint, stream=True, **kwargs)
         self.raise_for_status(response)
         for line in response.iter_lines():
             yield line.decode().lstrip('\n')
 
     def get_lines(self, endpoint, **kwargs):
         yield from self.raw_verb_lines('get', endpoint, **kwargs)
 
     # Web API endpoints
 
     def stats(self):
         return self.get('stats')
 
     def leaves(self, src, edges="*", direction="forward"):
         return self.get_lines(
             'leaves/{}'.format(src),
             params={
                 'edges': edges,
                 'direction': direction
             })
 
     def neighbors(self, src, edges="*", direction="forward"):
         return self.get_lines(
             'neighbors/{}'.format(src),
             params={
                 'edges': edges,
                 'direction': direction
             })
 
     def visit_nodes(self, src, edges="*", direction="forward"):
         return self.get_lines(
             'visit/nodes/{}'.format(src),
             params={
                 'edges': edges,
                 'direction': direction
             })
 
     def visit_paths(self, src, edges="*", direction="forward"):
         def decode_path_wrapper(it):
             for e in it:
                 yield json.loads(e)
 
         return decode_path_wrapper(
             self.get_lines(
                 'visit/paths/{}'.format(src),
                 params={
                     'edges': edges,
                     'direction': direction
                 }))
 
-    def walk(self, src, dst,
-             edges="*", traversal="dfs", direction="forward", last=False):
-        endpoint = 'walk/last/{}/{}' if last else 'walk/{}/{}'
+    def walk(self, src, dst, edges="*", traversal="dfs",
+             direction="forward", limit=None):
+        """Stream the nodes of a walk from ``src`` to ``dst``, one per line.
+
+        ``limit`` is sent as the ``limit`` query parameter: the server keeps
+        the first N nodes for a positive value, the last N for a negative
+        one, and applies no limit for 0.
+        """
+        endpoint = 'walk/{}/{}'
         return self.get_lines(
             endpoint.format(src, dst),
             params={
                 'edges': edges,
                 'traversal': traversal,
-                'direction': direction
+                'direction': direction,
+                'limit': limit
             })
 
     def random_walk(self, src, dst,
-                    edges="*", direction="forward", last=False):
-        endpoint = 'randomwalk/last/{}/{}' if last else 'randomwalk/{}/{}'
+                    edges="*", direction="forward", limit=None):
+        """Stream the nodes of a random walk from ``src`` to ``dst``.
+
+        ``limit`` is sent as the ``limit`` query parameter: the server keeps
+        the first N nodes for a positive value, the last N for a negative
+        one, and applies no limit for 0.
+        """
+        endpoint = 'randomwalk/{}/{}'
         return self.get_lines(
             endpoint.format(src, dst),
             params={
                 'edges': edges,
-                'direction': direction
+                'direction': direction,
+                'limit': limit
             })
 
     def count_leaves(self, src, edges="*", direction="forward"):
         return self.get(
             'leaves/count/{}'.format(src),
             params={
                 'edges': edges,
                 'direction': direction
             })
 
     def count_neighbors(self, src, edges="*", direction="forward"):
         return self.get(
             'neighbors/count/{}'.format(src),
             params={
                 'edges': edges,
                 'direction': direction
             })
 
     def count_visit_nodes(self, src, edges="*", direction="forward"):
         return self.get(
             'visit/nodes/count/{}'.format(src),
             params={
                 'edges': edges,
                 'direction': direction
             })
diff --git a/swh/graph/server/app.py b/swh/graph/server/app.py
index 9464c20..f659043 100644
--- a/swh/graph/server/app.py
+++ b/swh/graph/server/app.py
@@ -1,229 +1,253 @@
 # Copyright (C) 2019 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
 """
 A proxy HTTP server for swh-graph, talking to the Java code via py4j, and using
 FIFO as a transport to stream integers between the two languages.
 """
 
 import asyncio
 import json
 import aiohttp.web
+from collections import deque
 
 from swh.core.api.asynchronous import RPCServerApp
 from swh.model.identifiers import PID_TYPES
 from swh.model.exceptions import ValidationError
 
 try:
     from contextlib import asynccontextmanager
 except ImportError:
     # Compatibility with 3.6 backport
     from async_generator import asynccontextmanager  # type: ignore
 
 
 # maximum number of retries for random walks
 RANDOM_RETRIES = 5  # TODO make this configurable via rpc-serve configuration
 
 
 @asynccontextmanager
 async def stream_response(request, content_type='text/plain', *args, **kwargs):
     response = aiohttp.web.StreamResponse(*args, **kwargs)
     response.content_type = content_type
     await response.prepare(request)
     yield response
     await response.write_eof()
 
 
 async def index(request):
     return aiohttp.web.Response(
         content_type='text/html',
         body="""
 You have reached the Software Heritage graph API server.
 See its API documentation for more information.
 """)
 
 
 async def stats(request):
     stats = request.app['backend'].stats()
     return aiohttp.web.Response(body=stats, content_type='application/json')
 
 
 def get_direction(request):
     """validate HTTP query parameter `direction`"""
     s = request.query.get('direction', 'forward')
     if s not in ('forward', 'backward'):
         raise aiohttp.web.HTTPBadRequest(body=f'invalid direction: {s}')
     return s
 
 
 def get_edges(request):
     """validate HTTP query parameter `edges`, i.e., edge restrictions"""
     s = request.query.get('edges', '*')
     if any([node_type != '*' and node_type not in PID_TYPES
-                for edge in s.split(':')
-                for node_type in edge.split(',', maxsplit=1)]):
+            for edge in s.split(':')
+            for node_type in edge.split(',', maxsplit=1)]):
         raise aiohttp.web.HTTPBadRequest(body=f'invalid edge restriction: {s}')
     return s
 
 
 def get_traversal(request):
     """validate HTTP query parameter `traversal`, i.e., visit order"""
     s = request.query.get('traversal', 'dfs')
     if s not in ('bfs', 'dfs'):
         raise aiohttp.web.HTTPBadRequest(body=f'invalid traversal order: {s}')
     return s
 
 
+def get_limit(request):
+    """validate HTTP query parameter `limit`, i.e., number of results
+
+    Returns the limit as an int (0 when the parameter is absent) and raises
+    HTTPBadRequest when the value is not a valid integer.
+    """
+    s = request.query.get('limit', '0')
+    try:
+        return int(s)
+    except ValueError:
+        raise aiohttp.web.HTTPBadRequest(body=f'invalid limit value: {s}')
+
+
 def node_of_pid(pid, backend):
     """lookup a PID in a pid2node map, failing in an HTTP-nice way if needed"""
     try:
         return backend.pid2node[pid]
     except KeyError:
         raise aiohttp.web.HTTPNotFound(body=f'PID not found: {pid}')
     except ValidationError:
         raise aiohttp.web.HTTPBadRequest(body=f'malformed PID: {pid}')
 
 
 def pid_of_node(node, backend):
     """lookup a node in a node2pid map, failing in an HTTP-nice way if needed
     """
     try:
         return backend.node2pid[node]
     except KeyError:
         raise aiohttp.web.HTTPInternalServerError(
             body=f'reverse lookup failed for node id: {node}')
 
 
 def get_simple_traversal_handler(ttype):
     async def simple_traversal(request):
         backend = request.app['backend']
 
         src = request.match_info['src']
         edges = get_edges(request)
         direction = get_direction(request)
 
         src_node = node_of_pid(src, backend)
         async with stream_response(request) as response:
             async for res_node in backend.simple_traversal(
                 ttype, direction, edges, src_node
             ):
                 res_pid = pid_of_node(res_node, backend)
                 await response.write('{}\n'.format(res_pid).encode())
             return response
 
     return simple_traversal
 
 
-def get_walk_handler(random=False, last=False):
+def get_walk_handler(random=False):
     async def walk(request):
         backend = request.app['backend']
 
         src = request.match_info['src']
         dst = request.match_info['dst']
         edges = get_edges(request)
         direction = get_direction(request)
         algo = get_traversal(request)
+        limit = get_limit(request)
 
         src_node = node_of_pid(src, backend)
         if dst not in PID_TYPES:
             dst = node_of_pid(dst, backend)
         async with stream_response(request) as response:
             if random:
                 it = backend.random_walk(direction, edges, RANDOM_RETRIES,
                                          src_node, dst)
             else:
                 it = backend.walk(direction, edges, algo, src_node, dst)
-            res_node = None
-            async for res_node in it:
-                if not last:
-                    res_pid = pid_of_node(res_node, backend)
-                    await response.write('{}\n'.format(res_pid).encode())
-            if last and res_node is not None:
-                res_pid = pid_of_node(res_node, backend)
-                await response.write('{}\n'.format(res_pid).encode())
+
+            if limit < 0:
+                # keep only the last -limit nodes; nothing can be sent until
+                # the walk is exhausted, so buffer node ids and resolve PIDs
+                # only for the nodes that are actually returned
+                queue = deque(maxlen=-limit)
+                async for res_node in it:
+                    queue.append(res_node)
+                for res_node in queue:
+                    res_pid = pid_of_node(res_node, backend)
+                    await response.write('{}\n'.format(res_pid).encode())
+            else:
+                # limit == 0 means no limit; otherwise stop right after the
+                # limit-th node instead of pulling one more from the backend
+                count = 0
+                async for res_node in it:
+                    res_pid = pid_of_node(res_node, backend)
+                    await response.write('{}\n'.format(res_pid).encode())
+                    count += 1
+                    if count == limit:
+                        break
             return response
 
     return walk
 
 
 async def visit_paths(request):
     backend = request.app['backend']
 
     src = request.match_info['src']
     edges = get_edges(request)
     direction = get_direction(request)
 
     src_node = node_of_pid(src, backend)
     it = backend.visit_paths(direction, edges, src_node)
     async with stream_response(request,
                                content_type='application/x-ndjson') \
             as response:
         async for res_path in it:
             res_path_pid = [pid_of_node(n, backend) for n in res_path]
             line = json.dumps(res_path_pid)
             await response.write('{}\n'.format(line).encode())
         return response
 
 
 def get_count_handler(ttype):
     async def count(request):
         loop = asyncio.get_event_loop()
         backend = request.app['backend']
 
         src = request.match_info['src']
         edges = get_edges(request)
         direction = get_direction(request)
 
         src_node = node_of_pid(src, backend)
         cnt = await loop.run_in_executor(
             None, backend.count, ttype, direction, edges, src_node)
         return aiohttp.web.Response(body=str(cnt),
                                     content_type='application/json')
 
     return count
 
 
 def make_app(backend, **kwargs):
     app = RPCServerApp(**kwargs)
     app.router.add_get('/', index)
     app.router.add_get('/graph', index)
     app.router.add_get('/graph/stats', stats)
 
     app.router.add_get('/graph/leaves/{src}',
                        get_simple_traversal_handler('leaves'))
     app.router.add_get('/graph/neighbors/{src}',
                        get_simple_traversal_handler('neighbors'))
     app.router.add_get('/graph/visit/nodes/{src}',
                        get_simple_traversal_handler('visit_nodes'))
     app.router.add_get('/graph/visit/paths/{src}', visit_paths)
 
     # temporarily disabled in wait of a proper fix for T1969
     # app.router.add_get('/graph/walk/{src}/{dst}',
     #                    get_walk_handler(random=False))
-    # app.router.add_get('/graph/walk/last/{src}/{dst}',
-    #                    get_walk_handler(random=False, last=True))
 
     app.router.add_get('/graph/randomwalk/{src}/{dst}',
-                       get_walk_handler(random=True, last=False))
-    app.router.add_get('/graph/randomwalk/last/{src}/{dst}',
-                       get_walk_handler(random=True, last=True))
+                       get_walk_handler(random=True))
 
     app.router.add_get('/graph/neighbors/count/{src}',
                        get_count_handler('neighbors'))
     app.router.add_get('/graph/leaves/count/{src}',
                        get_count_handler('leaves'))
     app.router.add_get('/graph/visit/nodes/count/{src}',
                        get_count_handler('visit_nodes'))
 
     app['backend'] = backend
     return app
diff --git a/swh/graph/tests/test_api_client.py b/swh/graph/tests/test_api_client.py
index e0edbe3..610aaeb 100644
--- a/swh/graph/tests/test_api_client.py
+++ b/swh/graph/tests/test_api_client.py
@@ -1,201 +1,219 @@
 import pytest
 
 from pytest import raises
 
 from swh.core.api import RemoteException
 
 
 def test_stats(graph_client):
     stats = graph_client.stats()
 
     assert set(stats.keys()) == {'counts', 'ratios', 'indegree', 'outdegree'}
     assert set(stats['counts'].keys()) == {'nodes', 'edges'}
     assert set(stats['ratios'].keys()) == {'compression', 'bits_per_node',
                                            'bits_per_edge', 'avg_locality'}
     assert set(stats['indegree'].keys()) == {'min', 'max', 'avg'}
     assert set(stats['outdegree'].keys()) == {'min', 'max', 'avg'}
 
     assert stats['counts']['nodes'] == 21
     assert stats['counts']['edges'] == 23
 
     assert isinstance(stats['ratios']['compression'], float)
     assert isinstance(stats['ratios']['bits_per_node'], float)
     assert isinstance(stats['ratios']['bits_per_edge'], float)
     assert isinstance(stats['ratios']['avg_locality'], float)
 
     assert stats['indegree']['min'] == 0
     assert stats['indegree']['max'] == 3
     assert isinstance(stats['indegree']['avg'], float)
 
     assert stats['outdegree']['min'] == 0
     assert stats['outdegree']['max'] == 3
     assert isinstance(stats['outdegree']['avg'], float)
 
 
 def test_leaves(graph_client):
     actual = list(graph_client.leaves(
         'swh:1:ori:0000000000000000000000000000000000000021'
     ))
     expected = [
         'swh:1:cnt:0000000000000000000000000000000000000001',
         'swh:1:cnt:0000000000000000000000000000000000000004',
         'swh:1:cnt:0000000000000000000000000000000000000005',
         'swh:1:cnt:0000000000000000000000000000000000000007'
     ]
     assert set(actual) == set(expected)
 
 
 def test_neighbors(graph_client):
     actual = list(graph_client.neighbors(
         'swh:1:rev:0000000000000000000000000000000000000009',
         direction='backward'
     ))
     expected = [
         'swh:1:snp:0000000000000000000000000000000000000020',
         'swh:1:rel:0000000000000000000000000000000000000010',
         'swh:1:rev:0000000000000000000000000000000000000013'
     ]
     assert set(actual) == set(expected)
 
 
 def test_visit_nodes(graph_client):
     actual = list(graph_client.visit_nodes(
         'swh:1:rel:0000000000000000000000000000000000000010',
         edges='rel:rev,rev:rev'
     ))
     expected = [
         'swh:1:rel:0000000000000000000000000000000000000010',
         'swh:1:rev:0000000000000000000000000000000000000009',
         'swh:1:rev:0000000000000000000000000000000000000003'
     ]
     assert set(actual) == set(expected)
 
 
 def test_visit_paths(graph_client):
     actual = list(graph_client.visit_paths(
         'swh:1:snp:0000000000000000000000000000000000000020',
         edges='snp:*,rev:*'))
     actual = [tuple(path) for path in actual]
     expected = [
         (
             'swh:1:snp:0000000000000000000000000000000000000020',
             'swh:1:rev:0000000000000000000000000000000000000009',
             'swh:1:rev:0000000000000000000000000000000000000003',
             'swh:1:dir:0000000000000000000000000000000000000002'
         ),
         (
             'swh:1:snp:0000000000000000000000000000000000000020',
             'swh:1:rev:0000000000000000000000000000000000000009',
             'swh:1:dir:0000000000000000000000000000000000000008'
         ),
         (
             'swh:1:snp:0000000000000000000000000000000000000020',
             'swh:1:rel:0000000000000000000000000000000000000010'
         )
     ]
     assert set(actual) == set(expected)
 
 
 @pytest.mark.skip(reason='currently disabled due to T1969')
 def test_walk(graph_client):
     args = ('swh:1:dir:0000000000000000000000000000000000000016', 'rel')
     kwargs = {
         'edges': 'dir:dir,dir:rev,rev:*',
         'direction': 'backward',
         'traversal': 'bfs',
     }
     actual = list(graph_client.walk(*args, **kwargs))
     expected = [
         'swh:1:dir:0000000000000000000000000000000000000016',
         'swh:1:dir:0000000000000000000000000000000000000017',
         'swh:1:rev:0000000000000000000000000000000000000018',
         'swh:1:rel:0000000000000000000000000000000000000019'
     ]
     assert set(actual) == set(expected)
 
     kwargs2 = kwargs.copy()
-    kwargs2['last'] = True
+    kwargs2['limit'] = -1
     actual = list(graph_client.walk(*args, **kwargs2))
     expected = [
         'swh:1:rel:0000000000000000000000000000000000000019'
     ]
     assert set(actual) == set(expected)
 
+    kwargs2 = kwargs.copy()
+    kwargs2['limit'] = 2
+    actual = list(graph_client.walk(*args, **kwargs2))
+    expected = [
+        'swh:1:dir:0000000000000000000000000000000000000016',
+        'swh:1:dir:0000000000000000000000000000000000000017'
+    ]
+    assert set(actual) == set(expected)
+
 
 def test_random_walk(graph_client):
     """as the walk is random, we test a visit from a cnt node to the only
     origin in the dataset, and only check the final node of the path
     (i.e., the origin)
     """
     args = ('swh:1:cnt:0000000000000000000000000000000000000001', 'ori')
     kwargs = {'direction': 'backward'}
     expected_root = 'swh:1:ori:0000000000000000000000000000000000000021'
     actual = list(graph_client.random_walk(*args, **kwargs))
     assert len(actual) > 1  # no origin directly links to a content
     assert actual[0] == args[0]
     assert actual[-1] == expected_root
 
     kwargs2 = kwargs.copy()
-    kwargs2['last'] = True
+    kwargs2['limit'] = -1
     actual = list(graph_client.random_walk(*args, **kwargs2))
     assert actual == [expected_root]
 
+    kwargs2['limit'] = -2
+    actual = list(graph_client.random_walk(*args, **kwargs2))
+    assert len(actual) == 2
+    assert actual[-1] == expected_root
+
+    kwargs2['limit'] = 3
+    actual = list(graph_client.random_walk(*args, **kwargs2))
+    assert len(actual) == 3
+
 
 def test_count(graph_client):
     actual = graph_client.count_leaves(
         'swh:1:ori:0000000000000000000000000000000000000021'
     )
     assert actual == 4
 
     actual = graph_client.count_visit_nodes(
         'swh:1:rel:0000000000000000000000000000000000000010',
         edges='rel:rev,rev:rev'
     )
     assert actual == 3
 
     actual = graph_client.count_neighbors(
         'swh:1:rev:0000000000000000000000000000000000000009',
         direction='backward'
     )
     assert actual == 3
 
 
 def test_param_validation(graph_client):
     with raises(RemoteException) as exc_info:  # PID not found
         list(graph_client.leaves(
             'swh:1:ori:fff0000000000000000000000000000000000021'))
     assert exc_info.value.response.status_code == 404
 
     with raises(RemoteException) as exc_info:  # malformed PID
         list(graph_client.neighbors(
             'swh:1:ori:fff000000zzzzzz0000000000000000000000021'))
     assert exc_info.value.response.status_code == 400
 
-    with raises(RemoteException) as exc_info:  # malformed edge specificaiton
+    with raises(RemoteException) as exc_info:  # malformed edge specification
         list(graph_client.visit_nodes(
             'swh:1:dir:0000000000000000000000000000000000000016',
             edges='dir:notanodetype,dir:rev,rev:*',
             direction='backward',
         ))
     assert exc_info.value.response.status_code == 400
 
     with raises(RemoteException) as exc_info:  # malformed direction
         list(graph_client.visit_nodes(
             'swh:1:dir:0000000000000000000000000000000000000016',
             edges='dir:dir,dir:rev,rev:*',
             direction='notadirection',
         ))
     assert exc_info.value.response.status_code == 400
 
 
 @pytest.mark.skip(reason='currently disabled due to T1969')
 def test_param_validation_walk(graph_client):
     """test validation of walk-specific parameters only"""
     with raises(RemoteException) as exc_info:  # malformed traversal order
         list(graph_client.walk(
             'swh:1:dir:0000000000000000000000000000000000000016', 'rel',
             edges='dir:dir,dir:rev,rev:*',
             direction='backward',
             traversal='notatraversalorder',
         ))
     assert exc_info.value.response.status_code == 400