Changeset View
Changeset View
Standalone View
Standalone View
swh/graph/server/app.py
# Copyright (C) 2019 The Software Heritage developers | # Copyright (C) 2019 The Software Heritage developers | ||||
# See the AUTHORS file at the top-level directory of this distribution | # See the AUTHORS file at the top-level directory of this distribution | ||||
# License: GNU General Public License version 3, or any later version | # License: GNU General Public License version 3, or any later version | ||||
# See top-level LICENSE file for more information | # See top-level LICENSE file for more information | ||||
""" | """ | ||||
A proxy HTTP server for swh-graph, talking to the Java code via py4j, and using | A proxy HTTP server for swh-graph, talking to the Java code via py4j, and using | ||||
FIFO as a transport to stream integers between the two languages. | FIFO as a transport to stream integers between the two languages. | ||||
""" | """ | ||||
import asyncio | import asyncio | ||||
import json | import json | ||||
import aiohttp.web | import aiohttp.web | ||||
from collections import deque | from collections import deque | ||||
from typing import Optional | |||||
from swh.core.api.asynchronous import RPCServerApp | from swh.core.api.asynchronous import RPCServerApp | ||||
from swh.model.identifiers import PID_TYPES | from swh.model.identifiers import PID_TYPES | ||||
from swh.model.exceptions import ValidationError | from swh.model.exceptions import ValidationError | ||||
try: | try: | ||||
from contextlib import asynccontextmanager | from contextlib import asynccontextmanager | ||||
except ImportError: | except ImportError: | ||||
# Compatibility with 3.6 backport | # Compatibility with 3.6 backport | ||||
from async_generator import asynccontextmanager # type: ignore | from async_generator import asynccontextmanager # type: ignore | ||||
# maximum number of retries for random walks | # maximum number of retries for random walks | ||||
RANDOM_RETRIES = 5 # TODO make this configurable via rpc-serve configuration | RANDOM_RETRIES = 5 # TODO make this configurable via rpc-serve configuration | ||||
@asynccontextmanager | |||||
async def stream_response(request, content_type="text/plain", *args, **kwargs): | |||||
response = aiohttp.web.StreamResponse(*args, **kwargs) | |||||
response.content_type = content_type | |||||
await response.prepare(request) | |||||
yield response | |||||
await response.write_eof() | |||||
async def index(request): | async def index(request): | ||||
return aiohttp.web.Response( | return aiohttp.web.Response( | ||||
content_type="text/html", | content_type="text/html", | ||||
body="""<html> | body="""<html> | ||||
<head><title>Software Heritage storage server</title></head> | <head><title>Software Heritage storage server</title></head> | ||||
<body> | <body> | ||||
<p>You have reached the <a href="https://www.softwareheritage.org/"> | <p>You have reached the <a href="https://www.softwareheritage.org/"> | ||||
Software Heritage</a> graph API server.</p> | Software Heritage</a> graph API server.</p> | ||||
<p>See its | <p>See its | ||||
<a href="https://docs.softwareheritage.org/devel/swh-graph/api.html">API | <a href="https://docs.softwareheritage.org/devel/swh-graph/api.html">API | ||||
documentation</a> for more information.</p> | documentation</a> for more information.</p> | ||||
</body> | </body> | ||||
</html>""", | </html>""", | ||||
) | ) | ||||
async def stats(request): | class GraphView(aiohttp.web.View): | ||||
stats = request.app["backend"].stats() | """Base class for views working on the graph, with utility functions""" | ||||
return aiohttp.web.Response(body=stats, content_type="application/json") | |||||
def __init__(self, *args, **kwargs): | |||||
super().__init__(*args, **kwargs) | |||||
self.backend = self.request.app["backend"] | |||||
def node_of_pid(self, pid): | |||||
"""Lookup a PID in a pid2node map, failing in an HTTP-nice way if needed.""" | |||||
try: | |||||
return self.backend.pid2node[pid] | |||||
except KeyError: | |||||
raise aiohttp.web.HTTPNotFound(body=f"PID not found: {pid}") | |||||
except ValidationError: | |||||
raise aiohttp.web.HTTPBadRequest(body=f"malformed PID: {pid}") | |||||
def get_direction(request): | def pid_of_node(self, node): | ||||
"""validate HTTP query parameter `direction`""" | """Lookup a node in a node2pid map, failing in an HTTP-nice way if needed.""" | ||||
s = request.query.get("direction", "forward") | try: | ||||
return self.backend.node2pid[node] | |||||
except KeyError: | |||||
raise aiohttp.web.HTTPInternalServerError( | |||||
body=f"reverse lookup failed for node id: {node}" | |||||
) | |||||
def get_direction(self): | |||||
"""Validate HTTP query parameter `direction`""" | |||||
s = self.request.query.get("direction", "forward") | |||||
if s not in ("forward", "backward"): | if s not in ("forward", "backward"): | ||||
raise aiohttp.web.HTTPBadRequest(body=f"invalid direction: {s}") | raise aiohttp.web.HTTPBadRequest(body=f"invalid direction: {s}") | ||||
return s | return s | ||||
def get_edges(self): | |||||
def get_edges(request): | """Validate HTTP query parameter `edges`, i.e., edge restrictions""" | ||||
"""validate HTTP query parameter `edges`, i.e., edge restrictions""" | s = self.request.query.get("edges", "*") | ||||
s = request.query.get("edges", "*") | |||||
if any( | if any( | ||||
[ | [ | ||||
node_type != "*" and node_type not in PID_TYPES | node_type != "*" and node_type not in PID_TYPES | ||||
for edge in s.split(":") | for edge in s.split(":") | ||||
for node_type in edge.split(",", maxsplit=1) | for node_type in edge.split(",", maxsplit=1) | ||||
] | ] | ||||
): | ): | ||||
raise aiohttp.web.HTTPBadRequest(body=f"invalid edge restriction: {s}") | raise aiohttp.web.HTTPBadRequest(body=f"invalid edge restriction: {s}") | ||||
return s | return s | ||||
def get_traversal(self): | |||||
def get_traversal(request): | """Validate HTTP query parameter `traversal`, i.e., visit order""" | ||||
"""validate HTTP query parameter `traversal`, i.e., visit order""" | s = self.request.query.get("traversal", "dfs") | ||||
s = request.query.get("traversal", "dfs") | |||||
if s not in ("bfs", "dfs"): | if s not in ("bfs", "dfs"): | ||||
raise aiohttp.web.HTTPBadRequest(body=f"invalid traversal order: {s}") | raise aiohttp.web.HTTPBadRequest(body=f"invalid traversal order: {s}") | ||||
return s | return s | ||||
def get_limit(self): | |||||
def get_limit(request): | """Validate HTTP query parameter `limit`, i.e., number of results""" | ||||
"""validate HTTP query parameter `limit`, i.e., number of results""" | s = self.request.query.get("limit", "0") | ||||
s = request.query.get("limit", "0") | |||||
try: | try: | ||||
return int(s) | return int(s) | ||||
except ValueError: | except ValueError: | ||||
raise aiohttp.web.HTTPBadRequest(body=f"invalid limit value: {s}") | raise aiohttp.web.HTTPBadRequest(body=f"invalid limit value: {s}") | ||||
def node_of_pid(pid, backend): | class StreamingGraphView(GraphView): | ||||
"""lookup a PID in a pid2node map, failing in an HTTP-nice way if needed""" | """Base class for views streaming their response line by line.""" | ||||
try: | |||||
return backend.pid2node[pid] | |||||
except KeyError: | |||||
raise aiohttp.web.HTTPNotFound(body=f"PID not found: {pid}") | |||||
except ValidationError: | |||||
raise aiohttp.web.HTTPBadRequest(body=f"malformed PID: {pid}") | |||||
content_type = "text/plain" | |||||
def pid_of_node(node, backend): | @asynccontextmanager | ||||
"""lookup a node in a node2pid map, failing in an HTTP-nice way if needed | async def response_streamer(self, *args, **kwargs): | ||||
"""Context manager to prepare then close a StreamResponse""" | |||||
response = aiohttp.web.StreamResponse(*args, **kwargs) | |||||
response.content_type = self.content_type | |||||
await response.prepare(self.request) | |||||
yield response | |||||
await response.write_eof() | |||||
async def get(self): | |||||
await self.prepare_response() | |||||
async with self.response_streamer() as self.response_stream: | |||||
await self.stream_response() | |||||
return self.response_stream | |||||
async def prepare_response(self): | |||||
"""This can be overridden with some setup to be run before the response | |||||
actually starts streaming. | |||||
""" | """ | ||||
try: | pass | ||||
return backend.node2pid[node] | |||||
except KeyError: | async def stream_response(self): | ||||
raise aiohttp.web.HTTPInternalServerError( | """Override this to perform the response streaming. Implementations of | ||||
body=f"reverse lookup failed for node id: {node}" | this should await self.stream_line(line) to write each line. | ||||
) | """ | ||||
raise NotImplementedError | |||||
async def stream_line(self, line): | |||||
"""Write a line in the response stream.""" | |||||
await self.response_stream.write((line + "\n").encode()) | |||||
class StatsView(GraphView): | |||||
"""View showing some statistics on the graph""" | |||||
async def get(self): | |||||
stats = self.backend.stats() | |||||
return aiohttp.web.Response(body=stats, content_type="application/json") | |||||
class SimpleTraversalView(StreamingGraphView): | |||||
"""Base class for views of simple traversals""" | |||||
def get_simple_traversal_handler(ttype): | simple_traversal_type: Optional[str] = None | ||||
async def simple_traversal(request): | |||||
backend = request.app["backend"] | async def prepare_response(self): | ||||
src = self.request.match_info["src"] | |||||
src = request.match_info["src"] | self.src_node = self.node_of_pid(src) | ||||
edges = get_edges(request) | |||||
direction = get_direction(request) | self.edges = self.get_edges() | ||||
self.direction = self.get_direction() | |||||
src_node = node_of_pid(src, backend) | |||||
async with stream_response(request) as response: | async def stream_response(self): | ||||
async for res_node in backend.simple_traversal( | async for res_node in self.backend.simple_traversal( | ||||
ttype, direction, edges, src_node | self.simple_traversal_type, self.direction, self.edges, self.src_node | ||||
): | ): | ||||
res_pid = pid_of_node(res_node, backend) | res_pid = self.pid_of_node(res_node) | ||||
await response.write("{}\n".format(res_pid).encode()) | await self.stream_line(res_pid) | ||||
return response | |||||
class LeavesView(SimpleTraversalView): | |||||
simple_traversal_type = "leaves" | |||||
return simple_traversal | |||||
class NeighborsView(SimpleTraversalView): | |||||
simple_traversal_type = "neighbors" | |||||
def get_walk_handler(random=False): | |||||
async def walk(request): | |||||
backend = request.app["backend"] | |||||
src = request.match_info["src"] | class VisitNodesView(SimpleTraversalView): | ||||
dst = request.match_info["dst"] | simple_traversal_type = "visit_nodes" | ||||
edges = get_edges(request) | |||||
direction = get_direction(request) | |||||
algo = get_traversal(request) | |||||
limit = get_limit(request) | |||||
src_node = node_of_pid(src, backend) | |||||
class WalkView(StreamingGraphView): | |||||
async def prepare_response(self): | |||||
src = self.request.match_info["src"] | |||||
dst = self.request.match_info["dst"] | |||||
self.src_node = self.node_of_pid(src) | |||||
if dst not in PID_TYPES: | if dst not in PID_TYPES: | ||||
dst = node_of_pid(dst, backend) | self.dst_thing = self.node_of_pid(dst) | ||||
async with stream_response(request) as response: | |||||
if random: | |||||
it = backend.random_walk( | |||||
direction, edges, RANDOM_RETRIES, src_node, dst | |||||
) | |||||
else: | else: | ||||
it = backend.walk(direction, edges, algo, src_node, dst) | self.dst_thing = dst | ||||
self.edges = self.get_edges() | |||||
self.direction = self.get_direction() | |||||
self.algo = self.get_traversal() | |||||
self.limit = self.get_limit() | |||||
async def get_walk_iterator(self): | |||||
return self.backend.walk( | |||||
self.direction, self.edges, self.algo, self.src_node, self.dst_thing | |||||
) | |||||
if limit < 0: | async def stream_response(self): | ||||
queue = deque(maxlen=-limit) | it = self.get_walk_iterator() | ||||
if self.limit < 0: | |||||
queue = deque(maxlen=-self.limit) | |||||
async for res_node in it: | async for res_node in it: | ||||
res_pid = pid_of_node(res_node, backend) | res_pid = self.pid_of_node(res_node) | ||||
queue.append("{}\n".format(res_pid).encode()) | queue.append(res_pid) | ||||
while queue: | while queue: | ||||
await response.write(queue.popleft()) | await self.stream_line(queue.popleft()) | ||||
else: | else: | ||||
count = 0 | count = 0 | ||||
async for res_node in it: | async for res_node in it: | ||||
if limit == 0 or count < limit: | if self.limit == 0 or count < self.limit: | ||||
res_pid = pid_of_node(res_node, backend) | res_pid = self.pid_of_node(res_node) | ||||
await response.write("{}\n".format(res_pid).encode()) | await self.stream_line(res_pid) | ||||
count += 1 | count += 1 | ||||
else: | else: | ||||
break | break | ||||
return response | |||||
return walk | |||||
class RandomWalkView(WalkView): | |||||
def get_walk_iterator(self): | |||||
return self.backend.random_walk( | |||||
self.direction, self.edges, RANDOM_RETRIES, self.src_node, self.dst_thing | |||||
) | |||||
async def visit_paths(request): | class VisitEdgesView(SimpleTraversalView): | ||||
backend = request.app["backend"] | async def stream_response(self): | ||||
it = self.backend.visit_edges(self.direction, self.edges, self.src_node) | |||||
async for (res_src, res_dst) in it: | |||||
res_src_pid = self.pid_of_node(res_src) | |||||
res_dst_pid = self.pid_of_node(res_dst) | |||||
await self.stream_line("{} {}".format(res_src_pid, res_dst_pid)) | |||||
src = request.match_info["src"] | |||||
edges = get_edges(request) | |||||
direction = get_direction(request) | |||||
src_node = node_of_pid(src, backend) | class VisitPathsView(SimpleTraversalView): | ||||
it = backend.visit_paths(direction, edges, src_node) | content_type = "application/x-ndjson" | ||||
async with stream_response( | |||||
request, content_type="application/x-ndjson" | async def stream_response(self): | ||||
) as response: | it = self.backend.visit_paths(self.direction, self.edges, self.src_node) | ||||
async for res_path in it: | async for res_path in it: | ||||
res_path_pid = [pid_of_node(n, backend) for n in res_path] | res_path_pid = [self.pid_of_node(n) for n in res_path] | ||||
line = json.dumps(res_path_pid) | line = json.dumps(res_path_pid) | ||||
await response.write("{}\n".format(line).encode()) | await self.stream_line(line) | ||||
return response | |||||
async def visit_edges(request): | class CountView(GraphView): | ||||
backend = request.app["backend"] | """Base class for counting views.""" | ||||
src = request.match_info["src"] | count_type: Optional[str] = None | ||||
edges = get_edges(request) | |||||
direction = get_direction(request) | |||||
src_node = node_of_pid(src, backend) | async def get(self): | ||||
it = backend.visit_edges(direction, edges, src_node) | src = self.request.match_info["src"] | ||||
async with stream_response(request) as response: | self.src_node = self.node_of_pid(src) | ||||
async for (res_src, res_dst) in it: | |||||
res_src_pid = pid_of_node(res_src, backend) | |||||
res_dst_pid = pid_of_node(res_dst, backend) | |||||
await response.write("{} {}\n".format(res_src_pid, res_dst_pid).encode()) | |||||
return response | |||||
self.edges = self.get_edges() | |||||
self.direction = self.get_direction() | |||||
def get_count_handler(ttype): | |||||
async def count(request): | |||||
loop = asyncio.get_event_loop() | loop = asyncio.get_event_loop() | ||||
backend = request.app["backend"] | |||||
src = request.match_info["src"] | |||||
edges = get_edges(request) | |||||
direction = get_direction(request) | |||||
src_node = node_of_pid(src, backend) | |||||
cnt = await loop.run_in_executor( | cnt = await loop.run_in_executor( | ||||
None, backend.count, ttype, direction, edges, src_node | None, | ||||
self.backend.count, | |||||
self.count_type, | |||||
self.direction, | |||||
self.edges, | |||||
self.src_node, | |||||
) | ) | ||||
return aiohttp.web.Response(body=str(cnt), content_type="application/json") | return aiohttp.web.Response(body=str(cnt), content_type="application/json") | ||||
return count | |||||
class CountNeighborsView(CountView): | |||||
count_type = "neighbors" | |||||
class CountLeavesView(CountView): | |||||
count_type = "leaves" | |||||
class CountVisitNodesView(CountView): | |||||
count_type = "visit_nodes" | |||||
def make_app(backend, **kwargs): | def make_app(backend, **kwargs): | ||||
app = RPCServerApp(**kwargs) | app = RPCServerApp(**kwargs) | ||||
app.router.add_get("/", index) | app.add_routes( | ||||
app.router.add_get("/graph", index) | [ | ||||
app.router.add_get("/graph/stats", stats) | aiohttp.web.get("/", index), | ||||
aiohttp.web.get("/graph", index), | |||||
app.router.add_get("/graph/leaves/{src}", get_simple_traversal_handler("leaves")) | aiohttp.web.view("/graph/stats", StatsView), | ||||
app.router.add_get( | aiohttp.web.view("/graph/leaves/{src}", LeavesView), | ||||
"/graph/neighbors/{src}", get_simple_traversal_handler("neighbors") | aiohttp.web.view("/graph/neighbors/{src}", NeighborsView), | ||||
) | aiohttp.web.view("/graph/visit/nodes/{src}", VisitNodesView), | ||||
app.router.add_get( | aiohttp.web.view("/graph/visit/edges/{src}", VisitEdgesView), | ||||
"/graph/visit/nodes/{src}", get_simple_traversal_handler("visit_nodes") | aiohttp.web.view("/graph/visit/paths/{src}", VisitPathsView), | ||||
) | |||||
app.router.add_get("/graph/visit/edges/{src}", visit_edges) | |||||
app.router.add_get("/graph/visit/paths/{src}", visit_paths) | |||||
# temporarily disabled in wait of a proper fix for T1969 | # temporarily disabled in wait of a proper fix for T1969 | ||||
# app.router.add_get('/graph/walk/{src}/{dst}', | # aiohttp.web.view("/graph/walk/{src}/{dst}", WalkView) | ||||
# get_walk_handler(random=False)) | aiohttp.web.view("/graph/randomwalk/{src}/{dst}", RandomWalkView), | ||||
# app.router.add_get('/graph/walk/last/{src}/{dst}', | aiohttp.web.view("/graph/neighbors/count/{src}", CountNeighborsView), | ||||
# get_walk_handler(random=False, last=True)) | aiohttp.web.view("/graph/leaves/count/{src}", CountLeavesView), | ||||
aiohttp.web.view("/graph/visit/nodes/count/{src}", CountVisitNodesView), | |||||
app.router.add_get("/graph/randomwalk/{src}/{dst}", get_walk_handler(random=True)) | ] | ||||
app.router.add_get("/graph/neighbors/count/{src}", get_count_handler("neighbors")) | |||||
app.router.add_get("/graph/leaves/count/{src}", get_count_handler("leaves")) | |||||
app.router.add_get( | |||||
"/graph/visit/nodes/count/{src}", get_count_handler("visit_nodes") | |||||
) | ) | ||||
app["backend"] = backend | app["backend"] = backend | ||||
return app | return app |