Page Menu | Home | Software Heritage

D3604.diff
No OneTemporary

D3604.diff

diff --git a/swh/graph/server/app.py b/swh/graph/server/app.py
--- a/swh/graph/server/app.py
+++ b/swh/graph/server/app.py
@@ -12,6 +12,7 @@
import json
import aiohttp.web
from collections import deque
+from typing import Optional
from swh.core.api.asynchronous import RPCServerApp
from swh.model.identifiers import PID_TYPES
@@ -28,15 +29,6 @@
RANDOM_RETRIES = 5 # TODO make this configurable via rpc-serve configuration
-@asynccontextmanager
-async def stream_response(request, content_type="text/plain", *args, **kwargs):
- response = aiohttp.web.StreamResponse(*args, **kwargs)
- response.content_type = content_type
- await response.prepare(request)
- yield response
- await response.write_eof()
-
-
async def index(request):
return aiohttp.web.Response(
content_type="text/html",
@@ -54,217 +46,266 @@
)
-async def stats(request):
- stats = request.app["backend"].stats()
- return aiohttp.web.Response(body=stats, content_type="application/json")
+class GraphView(aiohttp.web.View):
+ """Base class for views working on the graph, with utility functions"""
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.backend = self.request.app["backend"]
+
+ def node_of_pid(self, pid):
+ """Lookup a PID in a pid2node map, failing in an HTTP-nice way if needed."""
+ try:
+ return self.backend.pid2node[pid]
+ except KeyError:
+ raise aiohttp.web.HTTPNotFound(body=f"PID not found: {pid}")
+ except ValidationError:
+ raise aiohttp.web.HTTPBadRequest(body=f"malformed PID: {pid}")
+
+ def pid_of_node(self, node):
+ """Lookup a node in a node2pid map, failing in an HTTP-nice way if needed."""
+ try:
+ return self.backend.node2pid[node]
+ except KeyError:
+ raise aiohttp.web.HTTPInternalServerError(
+ body=f"reverse lookup failed for node id: {node}"
+ )
+
+ def get_direction(self):
+ """Validate HTTP query parameter `direction`"""
+ s = self.request.query.get("direction", "forward")
+ if s not in ("forward", "backward"):
+ raise aiohttp.web.HTTPBadRequest(body=f"invalid direction: {s}")
+ return s
+
+ def get_edges(self):
+ """Validate HTTP query parameter `edges`, i.e., edge restrictions"""
+ s = self.request.query.get("edges", "*")
+ if any(
+ [
+ node_type != "*" and node_type not in PID_TYPES
+ for edge in s.split(":")
+ for node_type in edge.split(",", maxsplit=1)
+ ]
+ ):
+ raise aiohttp.web.HTTPBadRequest(body=f"invalid edge restriction: {s}")
+ return s
+
+ def get_traversal(self):
+ """Validate HTTP query parameter `traversal`, i.e., visit order"""
+ s = self.request.query.get("traversal", "dfs")
+ if s not in ("bfs", "dfs"):
+ raise aiohttp.web.HTTPBadRequest(body=f"invalid traversal order: {s}")
+ return s
+
+ def get_limit(self):
+ """Validate HTTP query parameter `limit`, i.e., number of results"""
+ s = self.request.query.get("limit", "0")
+ try:
+ return int(s)
+ except ValueError:
+ raise aiohttp.web.HTTPBadRequest(body=f"invalid limit value: {s}")
+
+
+class StreamingGraphView(GraphView):
+ """Base class for views streaming their response line by line."""
+
+ content_type = "text/plain"
+
+ @asynccontextmanager
+ async def response_streamer(self, *args, **kwargs):
+ """Context manager to prepare then close a StreamResponse"""
+ response = aiohttp.web.StreamResponse(*args, **kwargs)
+ response.content_type = self.content_type
+ await response.prepare(self.request)
+ yield response
+ await response.write_eof()
+
+ async def get(self):
+ await self.prepare_response()
+ async with self.response_streamer() as self.response_stream:
+ await self.stream_response()
+ return self.response_stream
+
+ async def prepare_response(self):
+ """This can be overridden with some setup to be run before the response
+ actually starts streaming.
+ """
+ pass
+
+ async def stream_response(self):
+ """Override this to perform the response streaming. Implementations of
+ this should await self.stream_line(line) to write each line.
+ """
+ raise NotImplementedError
+
+ async def stream_line(self, line):
+ """Write a line in the response stream."""
+ await self.response_stream.write((line + "\n").encode())
+
+
+class StatsView(GraphView):
+ """View showing some statistics on the graph"""
+
+ async def get(self):
+ stats = self.backend.stats()
+ return aiohttp.web.Response(body=stats, content_type="application/json")
+
+
+class SimpleTraversalView(StreamingGraphView):
+ """Base class for views of simple traversals"""
+
+ simple_traversal_type: Optional[str] = None
+
+ async def prepare_response(self):
+ src = self.request.match_info["src"]
+ self.src_node = self.node_of_pid(src)
+
+ self.edges = self.get_edges()
+ self.direction = self.get_direction()
+
+ async def stream_response(self):
+ async for res_node in self.backend.simple_traversal(
+ self.simple_traversal_type, self.direction, self.edges, self.src_node
+ ):
+ res_pid = self.pid_of_node(res_node)
+ await self.stream_line(res_pid)
-def get_direction(request):
- """validate HTTP query parameter `direction`"""
- s = request.query.get("direction", "forward")
- if s not in ("forward", "backward"):
- raise aiohttp.web.HTTPBadRequest(body=f"invalid direction: {s}")
- return s
+class LeavesView(SimpleTraversalView):
+ simple_traversal_type = "leaves"
+
+class NeighborsView(SimpleTraversalView):
+ simple_traversal_type = "neighbors"
-def get_edges(request):
- """validate HTTP query parameter `edges`, i.e., edge restrictions"""
- s = request.query.get("edges", "*")
- if any(
- [
- node_type != "*" and node_type not in PID_TYPES
- for edge in s.split(":")
- for node_type in edge.split(",", maxsplit=1)
- ]
- ):
- raise aiohttp.web.HTTPBadRequest(body=f"invalid edge restriction: {s}")
- return s
-
-
-def get_traversal(request):
- """validate HTTP query parameter `traversal`, i.e., visit order"""
- s = request.query.get("traversal", "dfs")
- if s not in ("bfs", "dfs"):
- raise aiohttp.web.HTTPBadRequest(body=f"invalid traversal order: {s}")
- return s
-
-
-def get_limit(request):
- """validate HTTP query parameter `limit`, i.e., number of results"""
- s = request.query.get("limit", "0")
- try:
- return int(s)
- except ValueError:
- raise aiohttp.web.HTTPBadRequest(body=f"invalid limit value: {s}")
-
-
-def node_of_pid(pid, backend):
- """lookup a PID in a pid2node map, failing in an HTTP-nice way if needed"""
- try:
- return backend.pid2node[pid]
- except KeyError:
- raise aiohttp.web.HTTPNotFound(body=f"PID not found: {pid}")
- except ValidationError:
- raise aiohttp.web.HTTPBadRequest(body=f"malformed PID: {pid}")
-
-
-def pid_of_node(node, backend):
- """lookup a node in a node2pid map, failing in an HTTP-nice way if needed
-
- """
- try:
- return backend.node2pid[node]
- except KeyError:
- raise aiohttp.web.HTTPInternalServerError(
- body=f"reverse lookup failed for node id: {node}"
- )
+class VisitNodesView(SimpleTraversalView):
+ simple_traversal_type = "visit_nodes"
-def get_simple_traversal_handler(ttype):
- async def simple_traversal(request):
- backend = request.app["backend"]
- src = request.match_info["src"]
- edges = get_edges(request)
- direction = get_direction(request)
+class WalkView(StreamingGraphView):
+ async def prepare_response(self):
+ src = self.request.match_info["src"]
+ dst = self.request.match_info["dst"]
+ self.src_node = self.node_of_pid(src)
+ if dst not in PID_TYPES:
+ self.dst_thing = self.node_of_pid(dst)
+ else:
+ self.dst_thing = dst
+
+ self.edges = self.get_edges()
+ self.direction = self.get_direction()
+ self.algo = self.get_traversal()
+ self.limit = self.get_limit()
+
+ def get_walk_iterator(self):  # not async: result is iterated, not awaited
+ return self.backend.walk(
+ self.direction, self.edges, self.algo, self.src_node, self.dst_thing
+ )
- src_node = node_of_pid(src, backend)
- async with stream_response(request) as response:
- async for res_node in backend.simple_traversal(
- ttype, direction, edges, src_node
- ):
- res_pid = pid_of_node(res_node, backend)
- await response.write("{}\n".format(res_pid).encode())
- return response
+ async def stream_response(self):
+ it = self.get_walk_iterator()
+ if self.limit < 0:
+ queue = deque(maxlen=-self.limit)
+ async for res_node in it:
+ res_pid = self.pid_of_node(res_node)
+ queue.append(res_pid)
+ while queue:
+ await self.stream_line(queue.popleft())
+ else:
+ count = 0
+ async for res_node in it:
+ if self.limit == 0 or count < self.limit:
+ res_pid = self.pid_of_node(res_node)
+ await self.stream_line(res_pid)
+ count += 1
+ else:
+ break
+
+
+class RandomWalkView(WalkView):
+ def get_walk_iterator(self):
+ return self.backend.random_walk(
+ self.direction, self.edges, RANDOM_RETRIES, self.src_node, self.dst_thing
+ )
- return simple_traversal
+class VisitEdgesView(SimpleTraversalView):
+ async def stream_response(self):
+ it = self.backend.visit_edges(self.direction, self.edges, self.src_node)
+ async for (res_src, res_dst) in it:
+ res_src_pid = self.pid_of_node(res_src)
+ res_dst_pid = self.pid_of_node(res_dst)
+ await self.stream_line("{} {}".format(res_src_pid, res_dst_pid))
-def get_walk_handler(random=False):
- async def walk(request):
- backend = request.app["backend"]
- src = request.match_info["src"]
- dst = request.match_info["dst"]
- edges = get_edges(request)
- direction = get_direction(request)
- algo = get_traversal(request)
- limit = get_limit(request)
+class VisitPathsView(SimpleTraversalView):
+ content_type = "application/x-ndjson"
- src_node = node_of_pid(src, backend)
- if dst not in PID_TYPES:
- dst = node_of_pid(dst, backend)
- async with stream_response(request) as response:
- if random:
- it = backend.random_walk(
- direction, edges, RANDOM_RETRIES, src_node, dst
- )
- else:
- it = backend.walk(direction, edges, algo, src_node, dst)
-
- if limit < 0:
- queue = deque(maxlen=-limit)
- async for res_node in it:
- res_pid = pid_of_node(res_node, backend)
- queue.append("{}\n".format(res_pid).encode())
- while queue:
- await response.write(queue.popleft())
- else:
- count = 0
- async for res_node in it:
- if limit == 0 or count < limit:
- res_pid = pid_of_node(res_node, backend)
- await response.write("{}\n".format(res_pid).encode())
- count += 1
- else:
- break
- return response
-
- return walk
-
-
-async def visit_paths(request):
- backend = request.app["backend"]
-
- src = request.match_info["src"]
- edges = get_edges(request)
- direction = get_direction(request)
-
- src_node = node_of_pid(src, backend)
- it = backend.visit_paths(direction, edges, src_node)
- async with stream_response(
- request, content_type="application/x-ndjson"
- ) as response:
+ async def stream_response(self):
+ it = self.backend.visit_paths(self.direction, self.edges, self.src_node)
async for res_path in it:
- res_path_pid = [pid_of_node(n, backend) for n in res_path]
+ res_path_pid = [self.pid_of_node(n) for n in res_path]
line = json.dumps(res_path_pid)
- await response.write("{}\n".format(line).encode())
- return response
+ await self.stream_line(line)
-async def visit_edges(request):
- backend = request.app["backend"]
+class CountView(GraphView):
+ """Base class for counting views."""
- src = request.match_info["src"]
- edges = get_edges(request)
- direction = get_direction(request)
+ count_type: Optional[str] = None
- src_node = node_of_pid(src, backend)
- it = backend.visit_edges(direction, edges, src_node)
- async with stream_response(request) as response:
- async for (res_src, res_dst) in it:
- res_src_pid = pid_of_node(res_src, backend)
- res_dst_pid = pid_of_node(res_dst, backend)
- await response.write("{} {}\n".format(res_src_pid, res_dst_pid).encode())
- return response
+ async def get(self):
+ src = self.request.match_info["src"]
+ self.src_node = self.node_of_pid(src)
+ self.edges = self.get_edges()
+ self.direction = self.get_direction()
-def get_count_handler(ttype):
- async def count(request):
loop = asyncio.get_event_loop()
- backend = request.app["backend"]
-
- src = request.match_info["src"]
- edges = get_edges(request)
- direction = get_direction(request)
-
- src_node = node_of_pid(src, backend)
cnt = await loop.run_in_executor(
- None, backend.count, ttype, direction, edges, src_node
+ None,
+ self.backend.count,
+ self.count_type,
+ self.direction,
+ self.edges,
+ self.src_node,
)
return aiohttp.web.Response(body=str(cnt), content_type="application/json")
- return count
+class CountNeighborsView(CountView):
+ count_type = "neighbors"
-def make_app(backend, **kwargs):
- app = RPCServerApp(**kwargs)
- app.router.add_get("/", index)
- app.router.add_get("/graph", index)
- app.router.add_get("/graph/stats", stats)
- app.router.add_get("/graph/leaves/{src}", get_simple_traversal_handler("leaves"))
- app.router.add_get(
- "/graph/neighbors/{src}", get_simple_traversal_handler("neighbors")
- )
- app.router.add_get(
- "/graph/visit/nodes/{src}", get_simple_traversal_handler("visit_nodes")
- )
- app.router.add_get("/graph/visit/edges/{src}", visit_edges)
- app.router.add_get("/graph/visit/paths/{src}", visit_paths)
+class CountLeavesView(CountView):
+ count_type = "leaves"
- # temporarily disabled in wait of a proper fix for T1969
- # app.router.add_get('/graph/walk/{src}/{dst}',
- # get_walk_handler(random=False))
- # app.router.add_get('/graph/walk/last/{src}/{dst}',
- # get_walk_handler(random=False, last=True))
- app.router.add_get("/graph/randomwalk/{src}/{dst}", get_walk_handler(random=True))
+class CountVisitNodesView(CountView):
+ count_type = "visit_nodes"
- app.router.add_get("/graph/neighbors/count/{src}", get_count_handler("neighbors"))
- app.router.add_get("/graph/leaves/count/{src}", get_count_handler("leaves"))
- app.router.add_get(
- "/graph/visit/nodes/count/{src}", get_count_handler("visit_nodes")
+
+def make_app(backend, **kwargs):
+ app = RPCServerApp(**kwargs)
+ app.add_routes(
+ [
+ aiohttp.web.get("/", index),
+ aiohttp.web.get("/graph", index),
+ aiohttp.web.view("/graph/stats", StatsView),
+ aiohttp.web.view("/graph/leaves/{src}", LeavesView),
+ aiohttp.web.view("/graph/neighbors/{src}", NeighborsView),
+ aiohttp.web.view("/graph/visit/nodes/{src}", VisitNodesView),
+ aiohttp.web.view("/graph/visit/edges/{src}", VisitEdgesView),
+ aiohttp.web.view("/graph/visit/paths/{src}", VisitPathsView),
+ # temporarily disabled in wait of a proper fix for T1969
+ # aiohttp.web.view("/graph/walk/{src}/{dst}", WalkView)
+ aiohttp.web.view("/graph/randomwalk/{src}/{dst}", RandomWalkView),
+ aiohttp.web.view("/graph/neighbors/count/{src}", CountNeighborsView),
+ aiohttp.web.view("/graph/leaves/count/{src}", CountLeavesView),
+ aiohttp.web.view("/graph/visit/nodes/count/{src}", CountVisitNodesView),
+ ]
)
app["backend"] = backend

File Metadata

Mime Type
text/plain
Expires
Mon, Aug 18, 12:01 AM (1 w, 5 d ago)
Storage Engine
blob
Storage Format
Raw Data
Storage Handle
3222861

Event Timeline