diff --git a/java/src/main/java/org/softwareheritage/graph/Entry.java b/java/src/main/java/org/softwareheritage/graph/Entry.java index e110941..a2d3f5a 100644 --- a/java/src/main/java/org/softwareheritage/graph/Entry.java +++ b/java/src/main/java/org/softwareheritage/graph/Entry.java @@ -1,193 +1,193 @@ package org.softwareheritage.graph; import java.io.*; import java.util.ArrayList; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.PropertyNamingStrategy; public class Entry { private Graph graph; public void load_graph(String graphBasename) throws IOException { System.err.println("Loading graph " + graphBasename + " ..."); this.graph = Graph.loadMapped(graphBasename); System.err.println("Graph loaded."); } public Graph get_graph() { return graph.copy(); } public String stats() { try { Stats stats = new Stats(graph.getPath()); ObjectMapper objectMapper = new ObjectMapper(); objectMapper.setPropertyNamingStrategy(PropertyNamingStrategy.SNAKE_CASE); return objectMapper.writeValueAsString(stats); } catch (IOException e) { throw new RuntimeException("Cannot read stats: " + e); } } public void check_swhid(String src) { graph.getNodeId(new SWHID(src)); } private int count_visitor(NodeCountVisitor f, long srcNodeId) { int[] count = {0}; f.accept(srcNodeId, (node) -> { count[0]++; }); return count[0]; } - public int count_leaves(String direction, String edgesFmt, String src) { + public int count_leaves(String direction, String edgesFmt, String src, long maxEdges) { long srcNodeId = graph.getNodeId(new SWHID(src)); - Traversal t = new Traversal(graph.copy(), direction, edgesFmt); + Traversal t = new Traversal(graph.copy(), direction, edgesFmt, maxEdges); return count_visitor(t::leavesVisitor, srcNodeId); } - public int count_neighbors(String direction, String edgesFmt, String src) { + public int count_neighbors(String direction, String edgesFmt, String src, long maxEdges) { long srcNodeId = graph.getNodeId(new SWHID(src)); - Traversal t = new Traversal(graph.copy(), direction, edgesFmt); + Traversal t = new Traversal(graph.copy(), direction, edgesFmt, maxEdges); return count_visitor(t::neighborsVisitor, srcNodeId); } - public int count_visit_nodes(String direction, String edgesFmt, String src) { + public int count_visit_nodes(String direction, String edgesFmt, String src, long maxEdges) { long srcNodeId = graph.getNodeId(new SWHID(src)); - Traversal t = new Traversal(graph.copy(), direction, edgesFmt); + Traversal t = new Traversal(graph.copy(), direction, edgesFmt, maxEdges); return count_visitor(t::visitNodesVisitor, srcNodeId); } public QueryHandler get_handler(String clientFIFO) { return new QueryHandler(graph.copy(), clientFIFO); } private interface NodeCountVisitor { void accept(long nodeId, Traversal.NodeIdConsumer consumer); } public class QueryHandler { Graph graph; BufferedWriter out; String clientFIFO; public QueryHandler(Graph graph, String clientFIFO) { this.graph = graph; this.clientFIFO = clientFIFO; this.out = null; } public void writeNode(SWHID swhid) { try { out.write(swhid.toString() + "\n"); } catch (IOException e) { throw new RuntimeException("Cannot write response to client: " + e); } } public void writeEdge(SWHID src, SWHID dst) { try { out.write(src.toString() + " " + dst.toString() + "\n"); } catch (IOException e) { throw new RuntimeException("Cannot write response to client: " + e); } } public void open() { try { FileOutputStream file = new FileOutputStream(this.clientFIFO); this.out = new BufferedWriter(new OutputStreamWriter(file)); } catch (IOException e) { throw new RuntimeException("Cannot open client FIFO: " + e); } } public void close() { try { out.close(); } catch (IOException e) { throw new RuntimeException("Cannot write response to client: " + e); } } public void leaves(String direction, String edgesFmt, String src, long maxEdges, String returnTypes) { long srcNodeId = graph.getNodeId(new SWHID(src)); open(); Traversal t = new Traversal(graph, direction, edgesFmt, maxEdges, returnTypes); for (Long nodeId : t.leaves(srcNodeId)) { writeNode(graph.getSWHID(nodeId)); } close(); } public void neighbors(String direction, String edgesFmt, String src, long maxEdges, String returnTypes) { long srcNodeId = graph.getNodeId(new SWHID(src)); open(); Traversal t = new Traversal(graph, direction, edgesFmt, maxEdges, returnTypes); for (Long nodeId : t.neighbors(srcNodeId)) { writeNode(graph.getSWHID(nodeId)); } close(); } public void visit_nodes(String direction, String edgesFmt, String src, long maxEdges, String returnTypes) { long srcNodeId = graph.getNodeId(new SWHID(src)); open(); Traversal t = new Traversal(graph, direction, edgesFmt, maxEdges, returnTypes); for (Long nodeId : t.visitNodes(srcNodeId)) { writeNode(graph.getSWHID(nodeId)); } close(); } public void visit_edges(String direction, String edgesFmt, String src, long maxEdges, String returnTypes) { long srcNodeId = graph.getNodeId(new SWHID(src)); open(); Traversal t = new Traversal(graph, direction, edgesFmt, maxEdges); t.visitNodesVisitor(srcNodeId, null, (srcId, dstId) -> { writeEdge(graph.getSWHID(srcId), graph.getSWHID(dstId)); }); close(); } - public void walk(String direction, String edgesFmt, String algorithm, String src, String dst) { + public void walk(String direction, String edgesFmt, String algorithm, String src, String dst, long maxEdges, + String returnTypes) { long srcNodeId = graph.getNodeId(new SWHID(src)); open(); ArrayList res; + Traversal t = new Traversal(graph, direction, edgesFmt, maxEdges, returnTypes); if (dst.matches("ori|snp|rel|rev|dir|cnt")) { Node.Type dstType = Node.Type.fromStr(dst); - Traversal t = new Traversal(graph, direction, edgesFmt); res = t.walk(srcNodeId, dstType, algorithm); } else { long dstNodeId = graph.getNodeId(new SWHID(dst)); - Traversal t = new Traversal(graph, direction, edgesFmt); res = t.walk(srcNodeId, dstNodeId, algorithm); } for (Long nodeId : res) { writeNode(graph.getSWHID(nodeId)); } close(); } - public void random_walk(String direction, String edgesFmt, int retries, String src, String dst) { + public void random_walk(String direction, String edgesFmt, int retries, String src, String dst, long maxEdges, + String returnTypes) { long srcNodeId = graph.getNodeId(new SWHID(src)); open(); ArrayList res; + Traversal t = new Traversal(graph, direction, edgesFmt, maxEdges, returnTypes); if (dst.matches("ori|snp|rel|rev|dir|cnt")) { Node.Type dstType = Node.Type.fromStr(dst); - Traversal t = new Traversal(graph, direction, edgesFmt); res = t.randomWalk(srcNodeId, dstType, retries); } else { long dstNodeId = graph.getNodeId(new SWHID(dst)); - Traversal t = new Traversal(graph, direction, edgesFmt); res = t.randomWalk(srcNodeId, dstNodeId, retries); } for (Long nodeId : res) { writeNode(graph.getSWHID(nodeId)); } close(); } } } diff --git a/swh/graph/backend.py b/swh/graph/backend.py index b123238..5fb82f5 100644 --- a/swh/graph/backend.py +++ b/swh/graph/backend.py @@ -1,176 +1,176 @@ # Copyright (C) 2019-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import asyncio import contextlib import io import os import re import subprocess import sys import tempfile from py4j.java_gateway import JavaGateway from py4j.protocol import Py4JJavaError from swh.graph.config import check_config BUF_LINES = 1024 def _get_pipe_stderr(): # Get stderr if possible, or pipe to stdout if running with Jupyter. try: sys.stderr.fileno() except io.UnsupportedOperation: return subprocess.STDOUT else: return sys.stderr class Backend: def __init__(self, graph_path, config=None): self.gateway = None self.entry = None self.graph_path = graph_path self.config = check_config(config or {}) def start_gateway(self): self.gateway = JavaGateway.launch_gateway( java_path=None, javaopts=self.config["java_tool_options"].split(), classpath=self.config["classpath"], die_on_exit=True, redirect_stdout=sys.stdout, redirect_stderr=_get_pipe_stderr(), ) self.entry = self.gateway.jvm.org.softwareheritage.graph.Entry() self.entry.load_graph(self.graph_path) self.stream_proxy = JavaStreamProxy(self.entry) def stop_gateway(self): self.gateway.shutdown() def __enter__(self): self.start_gateway() return self def __exit__(self, exc_type, exc_value, tb): self.stop_gateway() def stats(self): return self.entry.stats() def check_swhid(self, swhid): try: self.entry.check_swhid(swhid) except Py4JJavaError as e: m = re.search(r"malformed SWHID: (\w+)", str(e)) if m: raise ValueError(f"malformed SWHID: {m[1]}") m = re.search(r"Unknown SWHID: (\w+)", str(e)) if m: raise NameError(f"Unknown SWHID: {m[1]}") raise - def count(self, ttype, direction, edges_fmt, src): + def count(self, ttype, *args): method = getattr(self.entry, "count_" + ttype) - return method(direction, edges_fmt, src) + return method(*args) async def traversal(self, ttype, *args): method = getattr(self.stream_proxy, ttype) async for line in method(*args): yield line.decode().rstrip("\n") class JavaStreamProxy: """A proxy class for the org.softwareheritage.graph.Entry Java class that takes care of the setup and teardown of the named-pipe FIFO communication between Python and Java. Initialize JavaStreamProxy using: proxy = JavaStreamProxy(swh_entry_class_instance) Then you can call an Entry method and iterate on the FIFO results like this: async for value in proxy.java_method(arg1, arg2): print(value) """ def __init__(self, entry): self.entry = entry async def read_node_ids(self, fname): loop = asyncio.get_event_loop() open_thread = loop.run_in_executor(None, open, fname, "rb") # Since the open() call on the FIFO is blocking until it is also opened # on the Java side, we await it with a timeout in case there is an # exception that prevents the write-side open(). with (await asyncio.wait_for(open_thread, timeout=2)) as f: def read_n_lines(f, n): buf = [] for _ in range(n): try: buf.append(next(f)) except StopIteration: break return buf while True: lines = await loop.run_in_executor(None, read_n_lines, f, BUF_LINES) if not lines: break for line in lines: yield line class _HandlerWrapper: def __init__(self, handler): self._handler = handler def __getattr__(self, name): func = getattr(self._handler, name) async def java_call(*args, **kwargs): loop = asyncio.get_event_loop() await loop.run_in_executor(None, lambda: func(*args, **kwargs)) def java_task(*args, **kwargs): return asyncio.create_task(java_call(*args, **kwargs)) return java_task @contextlib.contextmanager def get_handler(self): with tempfile.TemporaryDirectory(prefix="swh-graph-") as tmpdirname: cli_fifo = os.path.join(tmpdirname, "swh-graph.fifo") os.mkfifo(cli_fifo) reader = self.read_node_ids(cli_fifo) query_handler = self.entry.get_handler(cli_fifo) handler = self._HandlerWrapper(query_handler) yield (handler, reader) def __getattr__(self, name): async def java_call_iterator(*args, **kwargs): with self.get_handler() as (handler, reader): java_task = getattr(handler, name)(*args, **kwargs) try: async for value in reader: yield value except asyncio.TimeoutError: # If the read-side open() timeouts, an exception on the # Java side probably happened that prevented the # write-side open(). We propagate this exception here if # that is the case. task_exc = java_task.exception() if task_exc: raise task_exc raise await java_task return java_call_iterator diff --git a/swh/graph/server/app.py b/swh/graph/server/app.py index be07fc0..0128f2a 100644 --- a/swh/graph/server/app.py +++ b/swh/graph/server/app.py @@ -1,360 +1,373 @@ # Copyright (C) 2019-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information """ A proxy HTTP server for swh-graph, talking to the Java code via py4j, and using FIFO as a transport to stream integers between the two languages. """ import asyncio from collections import deque import os from typing import Optional import aiohttp.web from swh.core.api.asynchronous import RPCServerApp from swh.core.config import read as config_read from swh.graph.backend import Backend from swh.model.swhids import EXTENDED_SWHID_TYPES try: from contextlib import asynccontextmanager except ImportError: # Compatibility with 3.6 backport from async_generator import asynccontextmanager # type: ignore # maximum number of retries for random walks RANDOM_RETRIES = 5 # TODO make this configurable via rpc-serve configuration class GraphServerApp(RPCServerApp): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.on_startup.append(self._start_gateway) self.on_shutdown.append(self._stop_gateway) @staticmethod async def _start_gateway(app): # Equivalent to entering `with app["backend"]:` app["backend"].start_gateway() @staticmethod async def _stop_gateway(app): # Equivalent to exiting `with app["backend"]:` with no error app["backend"].stop_gateway() async def index(request): return aiohttp.web.Response( content_type="text/html", body=""" Software Heritage graph server

You have reached the Software Heritage graph API server.

See its API documentation for more information.

""", ) class GraphView(aiohttp.web.View): """Base class for views working on the graph, with utility functions""" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.backend = self.request.app["backend"] def get_direction(self): """Validate HTTP query parameter `direction`""" s = self.request.query.get("direction", "forward") if s not in ("forward", "backward"): raise aiohttp.web.HTTPBadRequest(text=f"invalid direction: {s}") return s def get_edges(self): """Validate HTTP query parameter `edges`, i.e., edge restrictions""" s = self.request.query.get("edges", "*") if any( [ node_type != "*" and node_type not in EXTENDED_SWHID_TYPES for edge in s.split(":") for node_type in edge.split(",", maxsplit=1) ] ): raise aiohttp.web.HTTPBadRequest(text=f"invalid edge restriction: {s}") return s def get_return_types(self): """Validate HTTP query parameter 'return types', i.e, a set of types which we will filter the query results with""" s = self.request.query.get("return_types", "*") if any( node_type != "*" and node_type not in EXTENDED_SWHID_TYPES for node_type in s.split(",") ): raise aiohttp.web.HTTPBadRequest( text=f"invalid type for filtering res: {s}" ) # if the user puts a star, # then we filter nothing, we don't need the other information if "*" in s: return "*" else: return s def get_traversal(self): """Validate HTTP query parameter `traversal`, i.e., visit order""" s = self.request.query.get("traversal", "dfs") if s not in ("bfs", "dfs"): raise aiohttp.web.HTTPBadRequest(text=f"invalid traversal order: {s}") return s def get_limit(self): """Validate HTTP query parameter `limit`, i.e., number of results""" s = self.request.query.get("limit", "0") try: return int(s) except ValueError: raise aiohttp.web.HTTPBadRequest(text=f"invalid limit value: {s}") def get_max_edges(self): """Validate HTTP query parameter 'max_edges', i.e., the limit of the number of edges that can be visited""" s = self.request.query.get("max_edges", "0") try: return int(s) except ValueError: raise aiohttp.web.HTTPBadRequest(text=f"invalid max_edges value: {s}") def check_swhid(self, swhid): """Validate that the given SWHID exists in the graph""" try: self.backend.check_swhid(swhid) except (NameError, ValueError) as e: raise aiohttp.web.HTTPBadRequest(text=str(e)) class StreamingGraphView(GraphView): """Base class for views streaming their response line by line.""" content_type = "text/plain" @asynccontextmanager async def response_streamer(self, *args, **kwargs): """Context manager to prepare then close a StreamResponse""" response = aiohttp.web.StreamResponse(*args, **kwargs) response.content_type = self.content_type await response.prepare(self.request) yield response await response.write_eof() async def get(self): await self.prepare_response() async with self.response_streamer() as self.response_stream: self._buf = [] try: await self.stream_response() finally: await self._flush_buffer() return self.response_stream async def prepare_response(self): """This can be overridden with some setup to be run before the response actually starts streaming. """ pass async def stream_response(self): """Override this to perform the response streaming. Implementations of this should await self.stream_line(line) to write each line. """ raise NotImplementedError async def stream_line(self, line): """Write a line in the response stream.""" self._buf.append(line) if len(self._buf) > 100: await self._flush_buffer() async def _flush_buffer(self): await self.response_stream.write("\n".join(self._buf).encode() + b"\n") self._buf = [] class StatsView(GraphView): """View showing some statistics on the graph""" async def get(self): stats = self.backend.stats() return aiohttp.web.Response(body=stats, content_type="application/json") class SimpleTraversalView(StreamingGraphView): """Base class for views of simple traversals""" simple_traversal_type: Optional[str] = None async def prepare_response(self): self.src = self.request.match_info["src"] self.edges = self.get_edges() self.direction = self.get_direction() self.max_edges = self.get_max_edges() self.return_types = self.get_return_types() self.check_swhid(self.src) async def stream_response(self): async for res_line in self.backend.traversal( self.simple_traversal_type, self.direction, self.edges, self.src, self.max_edges, self.return_types, ): await self.stream_line(res_line) class LeavesView(SimpleTraversalView): simple_traversal_type = "leaves" class NeighborsView(SimpleTraversalView): simple_traversal_type = "neighbors" class VisitNodesView(SimpleTraversalView): simple_traversal_type = "visit_nodes" class VisitEdgesView(SimpleTraversalView): simple_traversal_type = "visit_edges" class WalkView(StreamingGraphView): async def prepare_response(self): self.src = self.request.match_info["src"] self.dst = self.request.match_info["dst"] self.edges = self.get_edges() self.direction = self.get_direction() self.algo = self.get_traversal() self.limit = self.get_limit() + self.max_edges = self.get_max_edges() + self.return_types = self.get_return_types() self.check_swhid(self.src) if self.dst not in EXTENDED_SWHID_TYPES: self.check_swhid(self.dst) async def get_walk_iterator(self): return self.backend.traversal( - "walk", self.direction, self.edges, self.algo, self.src, self.dst + "walk", + self.direction, + self.edges, + self.algo, + self.src, + self.dst, + self.max_edges, + self.return_types, ) async def stream_response(self): it = self.get_walk_iterator() if self.limit < 0: queue = deque(maxlen=-self.limit) async for res_swhid in it: queue.append(res_swhid) while queue: await self.stream_line(queue.popleft()) else: count = 0 async for res_swhid in it: if self.limit == 0 or count < self.limit: await self.stream_line(res_swhid) count += 1 else: break class RandomWalkView(WalkView): def get_walk_iterator(self): return self.backend.traversal( "random_walk", self.direction, self.edges, RANDOM_RETRIES, self.src, self.dst, + self.max_edges, + self.return_types, ) class CountView(GraphView): """Base class for counting views.""" count_type: Optional[str] = None async def get(self): self.src = self.request.match_info["src"] self.check_swhid(self.src) self.edges = self.get_edges() self.direction = self.get_direction() + self.max_edges = self.get_max_edges() loop = asyncio.get_event_loop() cnt = await loop.run_in_executor( None, self.backend.count, self.count_type, self.direction, self.edges, self.src, + self.max_edges, ) return aiohttp.web.Response(body=str(cnt), content_type="application/json") class CountNeighborsView(CountView): count_type = "neighbors" class CountLeavesView(CountView): count_type = "leaves" class CountVisitNodesView(CountView): count_type = "visit_nodes" def make_app(config=None, backend=None, **kwargs): if (config is None) == (backend is None): raise ValueError("make_app() expects exactly one of 'config' or 'backend'") if backend is None: backend = Backend(graph_path=config["graph"]["path"], config=config["graph"]) app = GraphServerApp(**kwargs) app.add_routes( [ aiohttp.web.get("/", index), aiohttp.web.get("/graph", index), aiohttp.web.view("/graph/stats", StatsView), aiohttp.web.view("/graph/leaves/{src}", LeavesView), aiohttp.web.view("/graph/neighbors/{src}", NeighborsView), aiohttp.web.view("/graph/visit/nodes/{src}", VisitNodesView), aiohttp.web.view("/graph/visit/edges/{src}", VisitEdgesView), # temporarily disabled in wait of a proper fix for T1969 # aiohttp.web.view("/graph/walk/{src}/{dst}", WalkView) aiohttp.web.view("/graph/randomwalk/{src}/{dst}", RandomWalkView), aiohttp.web.view("/graph/neighbors/count/{src}", CountNeighborsView), aiohttp.web.view("/graph/leaves/count/{src}", CountLeavesView), aiohttp.web.view("/graph/visit/nodes/count/{src}", CountVisitNodesView), ] ) app["backend"] = backend return app def make_app_from_configfile(): """Load configuration and then build application to run """ config_file = os.environ.get("SWH_CONFIG_FILENAME") config = config_read(config_file) return make_app(config=config)