Changeset View
Changeset View
Standalone View
Standalone View
swh/graph/http_server.py
- This file was moved from swh/graph/server/app.py.
# Copyright (C) 2019-2020 The Software Heritage developers | # Copyright (C) 2019-2020 The Software Heritage developers | ||||
# See the AUTHORS file at the top-level directory of this distribution | # See the AUTHORS file at the top-level directory of this distribution | ||||
# License: GNU General Public License version 3, or any later version | # License: GNU General Public License version 3, or any later version | ||||
# See top-level LICENSE file for more information | # See top-level LICENSE file for more information | ||||
""" | """ | ||||
A proxy HTTP server for swh-graph, talking to the Java code via py4j, and using | A proxy HTTP server for swh-graph, talking to the Java code via py4j, and using | ||||
FIFO as a transport to stream integers between the two languages. | FIFO as a transport to stream integers between the two languages. | ||||
""" | """ | ||||
import asyncio | import json | ||||
from collections import deque | |||||
import os | import os | ||||
from typing import Optional | from typing import Optional | ||||
import aiohttp.test_utils | |||||
import aiohttp.web | import aiohttp.web | ||||
from google.protobuf import json_format | |||||
import grpc | |||||
from swh.core.api.asynchronous import RPCServerApp | from swh.core.api.asynchronous import RPCServerApp | ||||
from swh.core.config import read as config_read | from swh.core.config import read as config_read | ||||
from swh.graph.backend import Backend | from swh.graph.rpc.swhgraph_pb2 import ( | ||||
CheckSwhidRequest, | |||||
NodeFields, | |||||
NodeFilter, | |||||
StatsRequest, | |||||
TraversalRequest, | |||||
) | |||||
from swh.graph.rpc.swhgraph_pb2_grpc import TraversalServiceStub | |||||
from swh.graph.rpc_server import spawn_java_rpc_server | |||||
from swh.model.swhids import EXTENDED_SWHID_TYPES | from swh.model.swhids import EXTENDED_SWHID_TYPES | ||||
try: | try: | ||||
from contextlib import asynccontextmanager | from contextlib import asynccontextmanager | ||||
except ImportError: | except ImportError: | ||||
# Compatibility with 3.6 backport | # Compatibility with 3.6 backport | ||||
from async_generator import asynccontextmanager # type: ignore | from async_generator import asynccontextmanager # type: ignore | ||||
# maximum number of retries for random walks | # maximum number of retries for random walks | ||||
RANDOM_RETRIES = 10 # TODO make this configurable via rpc-serve configuration | RANDOM_RETRIES = 10 # TODO make this configurable via rpc-serve configuration | ||||
class GraphServerApp(RPCServerApp): | class GraphServerApp(RPCServerApp): | ||||
def __init__(self, *args, **kwargs): | def __init__(self, *args, **kwargs): | ||||
super().__init__(*args, **kwargs) | super().__init__(*args, **kwargs) | ||||
self.on_startup.append(self._start_gateway) | self.on_startup.append(self._start) | ||||
self.on_shutdown.append(self._stop_gateway) | self.on_shutdown.append(self._stop) | ||||
@staticmethod | @staticmethod | ||||
async def _start_gateway(app): | async def _start(app): | ||||
# Equivalent to entering `with app["backend"]:` | app["channel"] = grpc.aio.insecure_channel(app["rpc_url"]) | ||||
app["backend"].start_gateway() | await app["channel"].__aenter__() | ||||
app["rpc_client"] = TraversalServiceStub(app["channel"]) | |||||
await app["rpc_client"].Stats(StatsRequest(), wait_for_ready=True) | |||||
@staticmethod | @staticmethod | ||||
async def _stop_gateway(app): | async def _stop(app): | ||||
# Equivalent to exiting `with app["backend"]:` with no error | await app["channel"].__aexit__(None, None, None) | ||||
app["backend"].stop_gateway() | if app.get("local_server"): | ||||
app["local_server"].terminate() | |||||
async def index(request): | async def index(request): | ||||
return aiohttp.web.Response( | return aiohttp.web.Response( | ||||
content_type="text/html", | content_type="text/html", | ||||
body="""<html> | body="""<html> | ||||
<head><title>Software Heritage graph server</title></head> | <head><title>Software Heritage graph server</title></head> | ||||
<body> | <body> | ||||
<p>You have reached the <a href="https://www.softwareheritage.org/"> | <p>You have reached the <a href="https://www.softwareheritage.org/"> | ||||
Software Heritage</a> graph API server.</p> | Software Heritage</a> graph API server.</p> | ||||
<p>See its | <p>See its | ||||
<a href="https://docs.softwareheritage.org/devel/swh-graph/api.html">API | <a href="https://docs.softwareheritage.org/devel/swh-graph/api.html">API | ||||
documentation</a> for more information.</p> | documentation</a> for more information.</p> | ||||
</body> | </body> | ||||
</html>""", | </html>""", | ||||
) | ) | ||||
class GraphView(aiohttp.web.View): | class GraphView(aiohttp.web.View): | ||||
"""Base class for views working on the graph, with utility functions""" | """Base class for views working on the graph, with utility functions""" | ||||
def __init__(self, *args, **kwargs): | def __init__(self, *args, **kwargs): | ||||
super().__init__(*args, **kwargs) | super().__init__(*args, **kwargs) | ||||
self.backend = self.request.app["backend"] | self.rpc_client: TraversalServiceStub = self.request.app["rpc_client"] | ||||
def get_direction(self): | def get_direction(self): | ||||
"""Validate HTTP query parameter `direction`""" | """Validate HTTP query parameter `direction`""" | ||||
s = self.request.query.get("direction", "forward") | s = self.request.query.get("direction", "forward") | ||||
if s not in ("forward", "backward"): | if s not in ("forward", "backward"): | ||||
raise aiohttp.web.HTTPBadRequest(text=f"invalid direction: {s}") | raise aiohttp.web.HTTPBadRequest(text=f"invalid direction: {s}") | ||||
return s | return s.upper() | ||||
def get_edges(self): | def get_edges(self): | ||||
"""Validate HTTP query parameter `edges`, i.e., edge restrictions""" | """Validate HTTP query parameter `edges`, i.e., edge restrictions""" | ||||
s = self.request.query.get("edges", "*") | s = self.request.query.get("edges", "*") | ||||
if any( | if any( | ||||
[ | [ | ||||
node_type != "*" and node_type not in EXTENDED_SWHID_TYPES | node_type != "*" and node_type not in EXTENDED_SWHID_TYPES | ||||
for edge in s.split(":") | for edge in s.split(":") | ||||
Show All 40 Lines | def get_max_edges(self): | ||||
"""Validate HTTP query parameter 'max_edges', i.e., | """Validate HTTP query parameter 'max_edges', i.e., | ||||
the limit of the number of edges that can be visited""" | the limit of the number of edges that can be visited""" | ||||
s = self.request.query.get("max_edges", "0") | s = self.request.query.get("max_edges", "0") | ||||
try: | try: | ||||
return int(s) | return int(s) | ||||
except ValueError: | except ValueError: | ||||
raise aiohttp.web.HTTPBadRequest(text=f"invalid max_edges value: {s}") | raise aiohttp.web.HTTPBadRequest(text=f"invalid max_edges value: {s}") | ||||
def check_swhid(self, swhid): | async def check_swhid(self, swhid): | ||||
"""Validate that the given SWHID exists in the graph""" | """Validate that the given SWHID exists in the graph""" | ||||
try: | r = await self.rpc_client.CheckSwhid(CheckSwhidRequest(swhid=swhid)) | ||||
self.backend.check_swhid(swhid) | if not r.exists: | ||||
except (NameError, ValueError) as e: | raise aiohttp.web.HTTPBadRequest(text=str(r.details)) | ||||
raise aiohttp.web.HTTPBadRequest(text=str(e)) | |||||
class StreamingGraphView(GraphView): | class StreamingGraphView(GraphView): | ||||
"""Base class for views streaming their response line by line.""" | """Base class for views streaming their response line by line.""" | ||||
content_type = "text/plain" | content_type = "text/plain" | ||||
@asynccontextmanager | @asynccontextmanager | ||||
Show All 37 Lines | async def _flush_buffer(self): | ||||
await self.response_stream.write("\n".join(self._buf).encode() + b"\n") | await self.response_stream.write("\n".join(self._buf).encode() + b"\n") | ||||
self._buf = [] | self._buf = [] | ||||
class StatsView(GraphView): | class StatsView(GraphView): | ||||
"""View showing some statistics on the graph""" | """View showing some statistics on the graph""" | ||||
async def get(self): | async def get(self): | ||||
stats = self.backend.stats() | res = await self.rpc_client.Stats(StatsRequest()) | ||||
return aiohttp.web.Response(body=stats, content_type="application/json") | stats = json_format.MessageToDict( | ||||
res, including_default_value_fields=True, preserving_proto_field_name=True | |||||
) | |||||
# Int64 fields are serialized as strings by default. | |||||
for descriptor in res.DESCRIPTOR.fields: | |||||
if descriptor.type == descriptor.TYPE_INT64: | |||||
try: | |||||
stats[descriptor.name] = int(stats[descriptor.name]) | |||||
except KeyError: | |||||
pass | |||||
json_body = json.dumps(stats, indent=4, sort_keys=True) | |||||
return aiohttp.web.Response(body=json_body, content_type="application/json") | |||||
class SimpleTraversalView(StreamingGraphView): | class SimpleTraversalView(StreamingGraphView): | ||||
"""Base class for views of simple traversals""" | """Base class for views of simple traversals""" | ||||
simple_traversal_type: Optional[str] = None | |||||
async def prepare_response(self): | async def prepare_response(self): | ||||
self.src = self.request.match_info["src"] | src = self.request.match_info["src"] | ||||
self.edges = self.get_edges() | self.traversal_request = TraversalRequest( | ||||
self.direction = self.get_direction() | src=[src], | ||||
self.max_edges = self.get_max_edges() | edges=self.get_edges(), | ||||
self.return_types = self.get_return_types() | direction=self.get_direction(), | ||||
self.check_swhid(self.src) | return_nodes=NodeFilter(types=self.get_return_types()), | ||||
return_fields=NodeFields(), | |||||
) | |||||
if self.get_max_edges(): | |||||
self.traversal_request.max_edges = self.get_max_edges() | |||||
await self.check_swhid(src) | |||||
self.configure_request() | |||||
def configure_request(self): | |||||
pass | |||||
async def stream_response(self): | async def stream_response(self): | ||||
async for res_line in self.backend.traversal( | async for node in self.rpc_client.Traverse(self.traversal_request): | ||||
self.simple_traversal_type, | await self.stream_line(node.swhid) | ||||
self.direction, | |||||
self.edges, | |||||
self.src, | |||||
self.max_edges, | |||||
self.return_types, | |||||
): | |||||
await self.stream_line(res_line) | |||||
class LeavesView(SimpleTraversalView): | class LeavesView(SimpleTraversalView): | ||||
simple_traversal_type = "leaves" | def configure_request(self): | ||||
self.traversal_request.return_nodes.max_traversal_successors = 0 | |||||
class NeighborsView(SimpleTraversalView): | class NeighborsView(SimpleTraversalView): | ||||
simple_traversal_type = "neighbors" | def configure_request(self): | ||||
self.traversal_request.min_depth = 1 | |||||
self.traversal_request.max_depth = 1 | |||||
class VisitNodesView(SimpleTraversalView): | class VisitNodesView(SimpleTraversalView): | ||||
simple_traversal_type = "visit_nodes" | pass | ||||
class VisitEdgesView(SimpleTraversalView): | class VisitEdgesView(SimpleTraversalView): | ||||
simple_traversal_type = "visit_edges" | def configure_request(self): | ||||
self.traversal_request.return_fields.successor = True | |||||
class WalkView(StreamingGraphView): | |||||
async def prepare_response(self): | |||||
self.src = self.request.match_info["src"] | |||||
self.dst = self.request.match_info["dst"] | |||||
self.edges = self.get_edges() | |||||
self.direction = self.get_direction() | |||||
self.algo = self.get_traversal() | |||||
self.limit = self.get_limit() | |||||
self.max_edges = self.get_max_edges() | |||||
self.return_types = self.get_return_types() | |||||
self.check_swhid(self.src) | |||||
if self.dst not in EXTENDED_SWHID_TYPES: | |||||
self.check_swhid(self.dst) | |||||
async def get_walk_iterator(self): | |||||
return self.backend.traversal( | |||||
"walk", | |||||
self.direction, | |||||
self.edges, | |||||
self.algo, | |||||
self.src, | |||||
self.dst, | |||||
self.max_edges, | |||||
self.return_types, | |||||
) | |||||
async def stream_response(self): | async def stream_response(self): | ||||
it = self.get_walk_iterator() | async for node in self.rpc_client.Traverse(self.traversal_request): | ||||
if self.limit < 0: | for succ in node.successor: | ||||
queue = deque(maxlen=-self.limit) | await self.stream_line(node.swhid + " " + succ.swhid) | ||||
async for res_swhid in it: | |||||
queue.append(res_swhid) | |||||
while queue: | |||||
await self.stream_line(queue.popleft()) | |||||
else: | |||||
count = 0 | |||||
async for res_swhid in it: | |||||
if self.limit == 0 or count < self.limit: | |||||
await self.stream_line(res_swhid) | |||||
count += 1 | |||||
else: | |||||
break | |||||
class RandomWalkView(WalkView): | |||||
def get_walk_iterator(self): | |||||
return self.backend.traversal( | |||||
"random_walk", | |||||
self.direction, | |||||
self.edges, | |||||
RANDOM_RETRIES, | |||||
self.src, | |||||
self.dst, | |||||
self.max_edges, | |||||
self.return_types, | |||||
) | |||||
class CountView(GraphView): | class CountView(GraphView): | ||||
"""Base class for counting views.""" | """Base class for counting views.""" | ||||
count_type: Optional[str] = None | count_type: Optional[str] = None | ||||
async def get(self): | async def get(self): | ||||
self.src = self.request.match_info["src"] | src = self.request.match_info["src"] | ||||
self.check_swhid(self.src) | self.traversal_request = TraversalRequest( | ||||
src=[src], | |||||
self.edges = self.get_edges() | edges=self.get_edges(), | ||||
self.direction = self.get_direction() | direction=self.get_direction(), | ||||
self.max_edges = self.get_max_edges() | return_nodes=NodeFilter(types=self.get_return_types()), | ||||
return_fields=NodeFields(), | |||||
loop = asyncio.get_event_loop() | |||||
cnt = await loop.run_in_executor( | |||||
None, | |||||
self.backend.count, | |||||
self.count_type, | |||||
self.direction, | |||||
self.edges, | |||||
self.src, | |||||
self.max_edges, | |||||
) | ) | ||||
return aiohttp.web.Response(body=str(cnt), content_type="application/json") | if self.get_max_edges(): | ||||
self.traversal_request.max_edges = self.get_max_edges() | |||||
self.configure_request() | |||||
res = await self.rpc_client.CountNodes(self.traversal_request) | |||||
return aiohttp.web.Response( | |||||
body=str(res.count), content_type="application/json" | |||||
) | |||||
def configure_request(self): | |||||
pass | |||||
class CountNeighborsView(CountView): | class CountNeighborsView(CountView): | ||||
count_type = "neighbors" | def configure_request(self): | ||||
self.traversal_request.min_depth = 1 | |||||
self.traversal_request.max_depth = 1 | |||||
class CountLeavesView(CountView): | class CountLeavesView(CountView): | ||||
count_type = "leaves" | def configure_request(self): | ||||
self.traversal_request.return_nodes.max_traversal_successors = 0 | |||||
class CountVisitNodesView(CountView): | class CountVisitNodesView(CountView): | ||||
count_type = "visit_nodes" | pass | ||||
def make_app(config=None, backend=None, **kwargs): | def make_app(config=None, rpc_url=None, **kwargs): | ||||
if (config is None) == (backend is None): | |||||
raise ValueError("make_app() expects exactly one of 'config' or 'backend'") | |||||
if backend is None: | |||||
backend = Backend(graph_path=config["graph"]["path"], config=config["graph"]) | |||||
app = GraphServerApp(**kwargs) | app = GraphServerApp(**kwargs) | ||||
if rpc_url is None: | |||||
app["local_server"], port = spawn_java_rpc_server(config) | |||||
rpc_url = f"localhost:{port}" | |||||
app.add_routes( | app.add_routes( | ||||
[ | [ | ||||
aiohttp.web.get("/", index), | aiohttp.web.get("/", index), | ||||
aiohttp.web.get("/graph", index), | aiohttp.web.get("/graph", index), | ||||
aiohttp.web.view("/graph/stats", StatsView), | aiohttp.web.view("/graph/stats", StatsView), | ||||
aiohttp.web.view("/graph/leaves/{src}", LeavesView), | aiohttp.web.view("/graph/leaves/{src}", LeavesView), | ||||
aiohttp.web.view("/graph/neighbors/{src}", NeighborsView), | aiohttp.web.view("/graph/neighbors/{src}", NeighborsView), | ||||
aiohttp.web.view("/graph/visit/nodes/{src}", VisitNodesView), | aiohttp.web.view("/graph/visit/nodes/{src}", VisitNodesView), | ||||
aiohttp.web.view("/graph/visit/edges/{src}", VisitEdgesView), | aiohttp.web.view("/graph/visit/edges/{src}", VisitEdgesView), | ||||
# temporarily disabled in wait of a proper fix for T1969 | |||||
# aiohttp.web.view("/graph/walk/{src}/{dst}", WalkView) | |||||
aiohttp.web.view("/graph/randomwalk/{src}/{dst}", RandomWalkView), | |||||
aiohttp.web.view("/graph/neighbors/count/{src}", CountNeighborsView), | aiohttp.web.view("/graph/neighbors/count/{src}", CountNeighborsView), | ||||
aiohttp.web.view("/graph/leaves/count/{src}", CountLeavesView), | aiohttp.web.view("/graph/leaves/count/{src}", CountLeavesView), | ||||
aiohttp.web.view("/graph/visit/nodes/count/{src}", CountVisitNodesView), | aiohttp.web.view("/graph/visit/nodes/count/{src}", CountVisitNodesView), | ||||
] | ] | ||||
) | ) | ||||
app["backend"] = backend | app["rpc_url"] = rpc_url | ||||
return app | return app | ||||
def make_app_from_configfile(): | def make_app_from_configfile(): | ||||
"""Load configuration and then build application to run""" | """Load configuration and then build application to run""" | ||||
config_file = os.environ.get("SWH_CONFIG_FILENAME") | config_file = os.environ.get("SWH_CONFIG_FILENAME") | ||||
config = config_read(config_file) | config = config_read(config_file) | ||||
return make_app(config=config) | return make_app(config=config) |