diff --git a/swh/graph/server/app.py b/swh/graph/server/app.py --- a/swh/graph/server/app.py +++ b/swh/graph/server/app.py @@ -12,6 +12,7 @@ import json import aiohttp.web from collections import deque +from typing import Optional from swh.core.api.asynchronous import RPCServerApp from swh.model.identifiers import PID_TYPES @@ -28,15 +29,6 @@ RANDOM_RETRIES = 5 # TODO make this configurable via rpc-serve configuration -@asynccontextmanager -async def stream_response(request, content_type="text/plain", *args, **kwargs): - response = aiohttp.web.StreamResponse(*args, **kwargs) - response.content_type = content_type - await response.prepare(request) - yield response - await response.write_eof() - - async def index(request): return aiohttp.web.Response( content_type="text/html", @@ -54,218 +46,266 @@ ) -async def stats(request): - stats = request.app["backend"].stats() - return aiohttp.web.Response(body=stats, content_type="application/json") +class GraphView(aiohttp.web.View): + """Base class for views working on the graph, with utility functions""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.backend = self.request.app["backend"] + + def node_of_pid(self, pid): + """Lookup a PID in a pid2node map, failing in an HTTP-nice way if needed.""" + try: + return self.backend.pid2node[pid] + except KeyError: + raise aiohttp.web.HTTPNotFound(body=f"PID not found: {pid}") + except ValidationError: + raise aiohttp.web.HTTPBadRequest(body=f"malformed PID: {pid}") + + def pid_of_node(self, node): + """Lookup a node in a node2pid map, failing in an HTTP-nice way if needed.""" + try: + return self.backend.node2pid[node] + except KeyError: + raise aiohttp.web.HTTPInternalServerError( + body=f"reverse lookup failed for node id: {node}" + ) + + def get_direction(self): + """Validate HTTP query parameter `direction`""" + s = self.request.query.get("direction", "forward") + if s not in ("forward", "backward"): + raise 
aiohttp.web.HTTPBadRequest(body=f"invalid direction: {s}") + return s + + def get_edges(self): + """Validate HTTP query parameter `edges`, i.e., edge restrictions""" + s = self.request.query.get("edges", "*") + if any( + [ + node_type != "*" and node_type not in PID_TYPES + for edge in s.split(":") + for node_type in edge.split(",", maxsplit=1) + ] + ): + raise aiohttp.web.HTTPBadRequest(body=f"invalid edge restriction: {s}") + return s + + def get_traversal(self): + """Validate HTTP query parameter `traversal`, i.e., visit order""" + s = self.request.query.get("traversal", "dfs") + if s not in ("bfs", "dfs"): + raise aiohttp.web.HTTPBadRequest(body=f"invalid traversal order: {s}") + return s + + def get_limit(self): + """Validate HTTP query parameter `limit`, i.e., number of results""" + s = self.request.query.get("limit", "0") + try: + return int(s) + except ValueError: + raise aiohttp.web.HTTPBadRequest(body=f"invalid limit value: {s}") + + +class StreamingGraphView(GraphView): + """Base class for views streaming their response line by line.""" + + content_type = "text/plain" + + @asynccontextmanager + async def response_streamer(self, *args, **kwargs): + """Context manager to prepare then close a StreamResponse""" + response = aiohttp.web.StreamResponse(*args, **kwargs) + response.content_type = self.content_type + await response.prepare(self.request) + yield response + await response.write_eof() + + async def get(self): + await self.prepare_response() + async with self.response_streamer() as self.response_stream: + await self.stream_response() + return self.response_stream + + async def prepare_response(self): + """This can be overridden with some setup to be run before the response + actually starts streaming. + """ + pass + + async def stream_response(self): + """Override this to perform the response streaming. Implementations of + this should await self.stream_line(line) to write each line. 
+ """ + raise NotImplementedError + + async def stream_line(self, line): + """Write a line in the response stream.""" + await self.response_stream.write((line + "\n").encode()) + + +class StatsView(GraphView): + """View showing some statistics on the graph""" + + async def get(self): + stats = self.backend.stats() + return aiohttp.web.Response(body=stats, content_type="application/json") + + +class SimpleTraversalView(StreamingGraphView): + """Base class for views of simple traversals""" + + simple_traversal_type: Optional[str] = None + + async def prepare_response(self): + src = self.request.match_info["src"] + self.src_node = self.node_of_pid(src) + + self.edges = self.get_edges() + self.direction = self.get_direction() + + async def stream_response(self): + async for res_node in self.backend.simple_traversal( + self.simple_traversal_type, self.direction, self.edges, self.src_node + ): + res_pid = self.pid_of_node(res_node) + await self.stream_line(res_pid) -def get_direction(request): - """validate HTTP query parameter `direction`""" - s = request.query.get("direction", "forward") - if s not in ("forward", "backward"): - raise aiohttp.web.HTTPBadRequest(body=f"invalid direction: {s}") - return s +class LeavesView(SimpleTraversalView): + simple_traversal_type = "leaves" + +class NeighborsView(SimpleTraversalView): + simple_traversal_type = "neighbors" -def get_edges(request): - """validate HTTP query parameter `edges`, i.e., edge restrictions""" - s = request.query.get("edges", "*") - if any( - [ - node_type != "*" and node_type not in PID_TYPES - for edge in s.split(":") - for node_type in edge.split(",", maxsplit=1) - ] - ): - raise aiohttp.web.HTTPBadRequest(body=f"invalid edge restriction: {s}") - return s - - -def get_traversal(request): - """validate HTTP query parameter `traversal`, i.e., visit order""" - s = request.query.get("traversal", "dfs") - if s not in ("bfs", "dfs"): - raise aiohttp.web.HTTPBadRequest(body=f"invalid traversal order: {s}") - return 
s - - -def get_limit(request): - """validate HTTP query parameter `limit`, i.e., number of results""" - s = request.query.get("limit", "0") - try: - return int(s) - except ValueError: - raise aiohttp.web.HTTPBadRequest(body=f"invalid limit value: {s}") - - -def node_of_pid(pid, backend): - """lookup a PID in a pid2node map, failing in an HTTP-nice way if needed""" - try: - return backend.pid2node[pid] - except KeyError: - raise aiohttp.web.HTTPNotFound(body=f"PID not found: {pid}") - except ValidationError: - raise aiohttp.web.HTTPBadRequest(body=f"malformed PID: {pid}") - - -def pid_of_node(node, backend): - """lookup a node in a node2pid map, failing in an HTTP-nice way if needed - - """ - try: - return backend.node2pid[node] - except KeyError: - raise aiohttp.web.HTTPInternalServerError( - body=f"reverse lookup failed for node id: {node}" - ) +class VisitNodesView(SimpleTraversalView): + simple_traversal_type = "visit_nodes" -def get_simple_traversal_handler(ttype): - async def simple_traversal(request): - backend = request.app["backend"] - src = request.match_info["src"] - edges = get_edges(request) - direction = get_direction(request) +class WalkView(StreamingGraphView): + async def prepare_response(self): + src = self.request.match_info["src"] + dst = self.request.match_info["dst"] + self.src_node = self.node_of_pid(src) + if dst not in PID_TYPES: + self.dst_thing = self.node_of_pid(dst) + else: + self.dst_thing = dst + + self.edges = self.get_edges() + self.direction = self.get_direction() + self.algo = self.get_traversal() + self.limit = self.get_limit() + + def get_walk_iterator(self): + return self.backend.walk( + self.direction, self.edges, self.algo, self.src_node, self.dst_thing + ) - src_node = node_of_pid(src, backend) - async with stream_response(request) as response: - async for res_node in backend.simple_traversal( - ttype, direction, edges, src_node - ): - res_pid = pid_of_node(res_node, backend) - await 
response.write("{}\n".format(res_pid).encode()) - return response + async def stream_response(self): + it = self.get_walk_iterator() + if self.limit < 0: + queue = deque(maxlen=-self.limit) + async for res_node in it: + res_pid = self.pid_of_node(res_node) + queue.append(res_pid) + while queue: + await self.stream_line(queue.popleft()) + else: + count = 0 + async for res_node in it: + if self.limit == 0 or count < self.limit: + res_pid = self.pid_of_node(res_node) + await self.stream_line(res_pid) + count += 1 + else: + break + + +class RandomWalkView(WalkView): + def get_walk_iterator(self): + return self.backend.random_walk( + self.direction, self.edges, RANDOM_RETRIES, self.src_node, self.dst_thing + ) - return simple_traversal +class VisitEdgesView(SimpleTraversalView): + async def stream_response(self): + it = self.backend.visit_edges(self.direction, self.edges, self.src_node) + async for (res_src, res_dst) in it: + res_src_pid = self.pid_of_node(res_src) + res_dst_pid = self.pid_of_node(res_dst) + await self.stream_line("{} {}".format(res_src_pid, res_dst_pid)) -def get_walk_handler(random=False): - async def walk(request): - backend = request.app["backend"] - src = request.match_info["src"] - dst = request.match_info["dst"] - edges = get_edges(request) - direction = get_direction(request) - algo = get_traversal(request) - limit = get_limit(request) +class VisitPathsView(SimpleTraversalView): + content_type = "application/x-ndjson" - src_node = node_of_pid(src, backend) - if dst not in PID_TYPES: - dst = node_of_pid(dst, backend) - async with stream_response(request) as response: - if random: - it = backend.random_walk( - direction, edges, RANDOM_RETRIES, src_node, dst - ) - else: - it = backend.walk(direction, edges, algo, src_node, dst) - - if limit < 0: - queue = deque(maxlen=-limit) - async for res_node in it: - res_pid = pid_of_node(res_node, backend) - queue.append("{}\n".format(res_pid).encode()) - while queue: - await response.write(queue.popleft()) - 
else: - count = 0 - async for res_node in it: - if limit == 0 or count < limit: - res_pid = pid_of_node(res_node, backend) - await response.write("{}\n".format(res_pid).encode()) - count += 1 - else: - break - return response - - return walk - - -async def visit_paths(request): - backend = request.app["backend"] - - src = request.match_info["src"] - edges = get_edges(request) - direction = get_direction(request) - - src_node = node_of_pid(src, backend) - it = backend.visit_paths(direction, edges, src_node) - async with stream_response( - request, content_type="application/x-ndjson" - ) as response: + async def stream_response(self): + it = self.backend.visit_paths(self.direction, self.edges, self.src_node) async for res_path in it: - res_path_pid = [pid_of_node(n, backend) for n in res_path] + res_path_pid = [self.pid_of_node(n) for n in res_path] line = json.dumps(res_path_pid) - await response.write("{}\n".format(line).encode()) - return response + await self.stream_line(line) -async def visit_edges(request): - backend = request.app["backend"] +class CountView(GraphView): + """Base class for counting views.""" - src = request.match_info["src"] - edges = get_edges(request) - direction = get_direction(request) + count_type: Optional[str] = None - src_node = node_of_pid(src, backend) - it = backend.visit_edges(direction, edges, src_node) - print(it) - async with stream_response(request) as response: - async for (res_src, res_dst) in it: - res_src_pid = pid_of_node(res_src, backend) - res_dst_pid = pid_of_node(res_dst, backend) - await response.write("{} {}\n".format(res_src_pid, res_dst_pid).encode()) - return response + async def get(self): + src = self.request.match_info["src"] + self.src_node = self.node_of_pid(src) + self.edges = self.get_edges() + self.direction = self.get_direction() -def get_count_handler(ttype): - async def count(request): loop = asyncio.get_event_loop() - backend = request.app["backend"] - - src = request.match_info["src"] - edges = 
get_edges(request) - direction = get_direction(request) - - src_node = node_of_pid(src, backend) cnt = await loop.run_in_executor( - None, backend.count, ttype, direction, edges, src_node + None, + self.backend.count, + self.count_type, + self.direction, + self.edges, + self.src_node, ) return aiohttp.web.Response(body=str(cnt), content_type="application/json") - return count +class CountNeighborsView(CountView): + count_type = "neighbors" -def make_app(backend, **kwargs): - app = RPCServerApp(**kwargs) - app.router.add_get("/", index) - app.router.add_get("/graph", index) - app.router.add_get("/graph/stats", stats) - app.router.add_get("/graph/leaves/{src}", get_simple_traversal_handler("leaves")) - app.router.add_get( - "/graph/neighbors/{src}", get_simple_traversal_handler("neighbors") - ) - app.router.add_get( - "/graph/visit/nodes/{src}", get_simple_traversal_handler("visit_nodes") - ) - app.router.add_get("/graph/visit/edges/{src}", visit_edges) - app.router.add_get("/graph/visit/paths/{src}", visit_paths) +class CountLeavesView(CountView): + count_type = "leaves" - # temporarily disabled in wait of a proper fix for T1969 - # app.router.add_get('/graph/walk/{src}/{dst}', - # get_walk_handler(random=False)) - # app.router.add_get('/graph/walk/last/{src}/{dst}', - # get_walk_handler(random=False, last=True)) - app.router.add_get("/graph/randomwalk/{src}/{dst}", get_walk_handler(random=True)) +class CountVisitNodesView(CountView): + count_type = "visit_nodes" - app.router.add_get("/graph/neighbors/count/{src}", get_count_handler("neighbors")) - app.router.add_get("/graph/leaves/count/{src}", get_count_handler("leaves")) - app.router.add_get( - "/graph/visit/nodes/count/{src}", get_count_handler("visit_nodes") + +def make_app(backend, **kwargs): + app = RPCServerApp(**kwargs) + app.add_routes( + [ + aiohttp.web.get("/", index), + aiohttp.web.get("/graph", index), + aiohttp.web.view("/graph/stats", StatsView), + aiohttp.web.view("/graph/leaves/{src}", LeavesView), 
+ aiohttp.web.view("/graph/neighbors/{src}", NeighborsView), + aiohttp.web.view("/graph/visit/nodes/{src}", VisitNodesView), + aiohttp.web.view("/graph/visit/edges/{src}", VisitEdgesView), + aiohttp.web.view("/graph/visit/paths/{src}", VisitPathsView), + # temporarily disabled in wait of a proper fix for T1969 + # aiohttp.web.view("/graph/walk/{src}/{dst}", WalkView) + aiohttp.web.view("/graph/randomwalk/{src}/{dst}", RandomWalkView), + aiohttp.web.view("/graph/neighbors/count/{src}", CountNeighborsView), + aiohttp.web.view("/graph/leaves/count/{src}", CountLeavesView), + aiohttp.web.view("/graph/visit/nodes/count/{src}", CountVisitNodesView), + ] ) app["backend"] = backend