diff --git a/swh/graph/server/app.py b/swh/graph/server/app.py
--- a/swh/graph/server/app.py
+++ b/swh/graph/server/app.py
@@ -14,6 +14,7 @@
 
 from swh.core.api.asynchronous import RPCServerApp
 from swh.model.identifiers import PID_TYPES
+from swh.model.exceptions import ValidationError
 
 try:
     from contextlib import asynccontextmanager
@@ -53,20 +54,67 @@
     return aiohttp.web.Response(body=stats, content_type='application/json')
 
 
+def get_direction(request):
+    """validate HTTP query parameter `direction`"""
+    s = request.query.get('direction', 'forward')
+    if s not in ('forward', 'backward'):
+        raise aiohttp.web.HTTPBadRequest(body=f'invalid direction: {s}')
+    return s
+
+
+def get_edges(request):
+    """validate HTTP query parameter `edges`, i.e., edge restrictions"""
+    s = request.query.get('edges', '*')
+    if any([node_type != '*' and node_type not in PID_TYPES
+            for edge in s.split(':')
+            for node_type in edge.split(',', maxsplit=1)]):
+        raise aiohttp.web.HTTPBadRequest(body=f'invalid edge restriction: {s}')
+    return s
+
+
+def get_traversal(request):
+    """validate HTTP query parameter `traversal`, i.e., visit order"""
+    s = request.query.get('traversal', 'dfs')
+    if s not in ('bfs', 'dfs'):
+        raise aiohttp.web.HTTPBadRequest(body=f'invalid traversal order: {s}')
+    return s
+
+
+def node_of_pid(pid, backend):
+    """lookup a PID in a pid2node map, failing in an HTTP-nice way if needed"""
+    try:
+        return backend.pid2node[pid]
+    except KeyError:
+        raise aiohttp.web.HTTPNotFound(body=f'PID not found: {pid}')
+    except ValidationError:
+        raise aiohttp.web.HTTPBadRequest(body=f'malformed PID: {pid}')
+
+
+def pid_of_node(node, backend):
+    """lookup a node in a node2pid map, failing in an HTTP-nice way if needed
+
+    """
+    try:
+        return backend.node2pid[node]
+    except KeyError:
+        raise aiohttp.web.HTTPInternalServerError(
+            body=f'reverse lookup failed for node id: {node}')
+
+
 def get_simple_traversal_handler(ttype):
     async def simple_traversal(request):
         backend = request.app['backend']
 
         src = request.match_info['src']
-        edges = request.query.get('edges', '*')
-        direction = request.query.get('direction', 'forward')
+        edges = get_edges(request)
+        direction = get_direction(request)
 
-        src_node = backend.pid2node[src]
+        src_node = node_of_pid(src, backend)
         async with stream_response(request) as response:
             async for res_node in backend.simple_traversal(
                 ttype, direction, edges, src_node
             ):
-                res_pid = backend.node2pid[res_node]
+                res_pid = pid_of_node(res_node, backend)
                 await response.write('{}\n'.format(res_pid).encode())
             return response
 
@@ -78,18 +126,18 @@
 
     src = request.match_info['src']
     dst = request.match_info['dst']
-    edges = request.query.get('edges', '*')
-    direction = request.query.get('direction', 'forward')
-    algo = request.query.get('traversal', 'dfs')
+    edges = get_edges(request)
+    direction = get_direction(request)
+    algo = get_traversal(request)
 
-    src_node = backend.pid2node[src]
+    src_node = node_of_pid(src, backend)
     if dst not in PID_TYPES:
-        dst = backend.pid2node[dst]
+        dst = node_of_pid(dst, backend)
     async with stream_response(request) as response:
         async for res_node in backend.walk(
             direction, edges, algo, src_node, dst
         ):
-            res_pid = backend.node2pid[res_node]
+            res_pid = pid_of_node(res_node, backend)
             await response.write('{}\n'.format(res_pid).encode())
         return response
 
@@ -98,15 +146,15 @@
     backend = request.app['backend']
 
     src = request.match_info['src']
-    edges = request.query.get('edges', '*')
-    direction = request.query.get('direction', 'forward')
+    edges = get_edges(request)
+    direction = get_direction(request)
 
-    src_node = backend.pid2node[src]
+    src_node = node_of_pid(src, backend)
     it = backend.visit_paths(direction, edges, src_node)
     async with stream_response(request, content_type='application/x-ndjson') \
             as response:
         async for res_path in it:
-            res_path_pid = [backend.node2pid[n] for n in res_path]
+            res_path_pid = [pid_of_node(n, backend) for n in res_path]
             line = json.dumps(res_path_pid)
             await response.write('{}\n'.format(line).encode())
         return response
@@ -118,10 +166,10 @@
         backend = request.app['backend']
 
         src = request.match_info['src']
-        edges = request.query.get('edges', '*')
-        direction = request.query.get('direction', 'forward')
+        edges = get_edges(request)
+        direction = get_direction(request)
 
-        src_node = backend.pid2node[src]
+        src_node = node_of_pid(src, backend)
         cnt = await loop.run_in_executor(
             None, backend.count, ttype, direction, edges, src_node)
         return aiohttp.web.Response(body=str(cnt),