diff --git a/java/server/src/main/java/org/softwareheritage/graph/Entry.java b/java/server/src/main/java/org/softwareheritage/graph/Entry.java
index 6967f43..cf54321 100644
--- a/java/server/src/main/java/org/softwareheritage/graph/Entry.java
+++ b/java/server/src/main/java/org/softwareheritage/graph/Entry.java
@@ -1,39 +1,144 @@
 package org.softwareheritage.graph;
 
+import java.util.Map;
 import java.io.DataOutputStream;
 import java.io.FileOutputStream;
 import java.io.IOException;
 
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.fasterxml.jackson.databind.PropertyNamingStrategy;
 import py4j.GatewayServer;
 
 import org.softwareheritage.graph.Graph;
+import org.softwareheritage.graph.Node;
+import org.softwareheritage.graph.algo.Stats;
 import org.softwareheritage.graph.algo.NodeIdsConsumer;
 import org.softwareheritage.graph.algo.Traversal;
 
 public class Entry {
     Graph graph;
 
     public void load_graph(String graphBasename) throws IOException {
         System.err.println("Loading graph " + graphBasename + " ...");
         this.graph = new Graph(graphBasename);
         System.err.println("Graph loaded.");
     }
 
-    public void visit(long srcNodeId, String direction, String edgesFmt,
-                      String clientFIFO) {
-        Traversal t = new Traversal(this.graph, direction, edgesFmt);
+    public String stats() {
         try {
-            FileOutputStream file = new FileOutputStream(clientFIFO);
-            DataOutputStream data = new DataOutputStream(file);
-            t.visitNodesVisitor(srcNodeId, (nodeId) -> {
-                try {
-                    data.writeLong(nodeId);
-                } catch (IOException e) {
-                    throw new RuntimeException("cannot write response to client: " + e);
-                }});
-            data.close();
+            Stats stats = new Stats(graph.getPath());
+            ObjectMapper objectMapper = new ObjectMapper();
+            objectMapper.setPropertyNamingStrategy(PropertyNamingStrategy.SNAKE_CASE);
+            String res = objectMapper.writeValueAsString(stats);
+            return res;
         } catch (IOException e) {
-            System.err.println("cannot write response to client: " + e);
+            throw new RuntimeException("Cannot read stats: " + e, e);
+        }
+    }
+
+    public QueryHandler get_handler(String clientFIFO) {
+        return new QueryHandler(this.graph.copy(), clientFIFO);
+    }
+
+    /**
+     * Query handler bound to one client FIFO: streams the node ids produced
+     * by a traversal to the named pipe whose path was given by the client.
+     *
+     * Every query method opens the FIFO, runs the traversal and then closes
+     * the FIFO in a finally block, so the reader gets end-of-file even when
+     * the traversal throws.
+     */
+    public class QueryHandler {
+        Graph graph;
+        DataOutputStream out;
+        String clientFIFO;
+
+        public QueryHandler(Graph graph, String clientFIFO) {
+            this.graph = graph;
+            this.clientFIFO = clientFIFO;
+            this.out = null;
+        }
+
+        public void writeNode(long nodeId) {
+            try {
+                out.writeLong(nodeId);
+            } catch (IOException e) {
+                throw new RuntimeException("Cannot write response to client: " + e, e);
+            }
+        }
+
+        public void open() {
+            try {
+                FileOutputStream file = new FileOutputStream(this.clientFIFO);
+                this.out = new DataOutputStream(file);
+            } catch (IOException e) {
+                throw new RuntimeException("Cannot create FIFO: " + e, e);
+            }
+        }
+
+        public void close() {
+            try {
+                out.close();
+            } catch (IOException e) {
+                throw new RuntimeException("Cannot write response to client: " + e, e);
+            }
+        }
+
+        public void leaves(String direction, String edgesFmt, long srcNodeId) {
+            open();
+            try {
+                Traversal t = new Traversal(this.graph, direction, edgesFmt);
+                t.leavesVisitor(srcNodeId, this::writeNode);
+            } finally {
+                close();
+            }
+        }
+
+        public void neighbors(String direction, String edgesFmt, long srcNodeId) {
+            open();
+            try {
+                Traversal t = new Traversal(this.graph, direction, edgesFmt);
+                t.neighborsVisitor(srcNodeId, this::writeNode);
+            } finally {
+                close();
+            }
+        }
+
+        public void visit_nodes(String direction, String edgesFmt, long srcNodeId) {
+            open();
+            try {
+                Traversal t = new Traversal(this.graph, direction, edgesFmt);
+                t.visitNodesVisitor(srcNodeId, this::writeNode);
+            } finally {
+                close();
+            }
+        }
+
+        public void walk(String direction, String edgesFmt, String algorithm,
+                         long srcNodeId, long dstNodeId) {
+            open();
+            try {
+                Traversal t = new Traversal(this.graph, direction, edgesFmt);
+                for (Long nodeId : t.walk(srcNodeId, dstNodeId, algorithm)) {
+                    writeNode(nodeId);
+                }
+            } finally {
+                close();
+            }
+        }
+
+        public void walk_type(String direction, String edgesFmt, String algorithm,
+                              long srcNodeId, String dst) {
+            open();
+            try {
+                Node.Type dstType = Node.Type.fromStr(dst);
+                Traversal t = new Traversal(this.graph, direction, edgesFmt);
+                for (Long nodeId : t.walk(srcNodeId, dstType, algorithm)) {
+                    writeNode(nodeId);
+                }
+            } finally {
+                close();
+            }
         }
     }
 }
diff --git a/requirements.txt b/requirements.txt
index f0777c3..3f9470a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,4 @@
 aiohttp
 click
 vcversioner
+py4j
diff --git a/swh/graph/cli.py b/swh/graph/cli.py
index 329cea8..4033969 100644
--- a/swh/graph/cli.py
+++ b/swh/graph/cli.py
@@ -1,122 +1,122 @@
 # Copyright (C) 2019 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
 import click
 import sys
 
 from swh.core.cli import CONTEXT_SETTINGS, AliasedGroup
 from swh.graph import client
 from swh.graph.pid import PidToIntMap, IntToPidMap
 
 
 @click.group(name='graph', context_settings=CONTEXT_SETTINGS,
              cls=AliasedGroup)
 @click.pass_context
 def cli(ctx):
     """Software Heritage graph tools."""
     ctx.ensure_object(dict)
 
 
 @cli.command('api-client')
 @click.option('--host', default='localhost', help='Graph server host')
 @click.option('--port', default='5009', help='Graph server port')
 @click.pass_context
 def api_client(ctx, host, port):
     """Client for the Software Heritage Graph REST service
 
     """
     url = 'http://{}:{}'.format(host, port)
     app = client.RemoteGraphClient(url)
 
     # TODO: run web app
     print(app.stats())
 
 
 @cli.group('map')
 @click.pass_context
 def map(ctx):
     """Manage swh-graph on-disk maps"""
     pass
 
 
 def dump_pid2int(filename):
     for (pid, int) in PidToIntMap(filename):
         print('{}\t{}'.format(pid, int))
 
 
 def dump_int2pid(filename):
     for (int, pid) in IntToPidMap(filename):
         print('{}\t{}'.format(int, pid))
 
 
 def restore_pid2int(filename):
     """read a textual PID->int map from stdin and write its binary version to
     filename
 
     """
     with open(filename, 'wb') as dst:
         for line in sys.stdin:
             (str_pid, str_int) = line.split()
             PidToIntMap.write_record(dst, str_pid, int(str_int))
 
 
 def restore_int2pid(filename, length):
     """read a textual int->PID map from stdin and write its binary version to
     filename
 
     """
     int2pid = IntToPidMap(filename, mode='wb', length=length)
     for line in sys.stdin:
         (str_int, str_pid) = line.split()
         int2pid[int(str_int)] = str_pid
     int2pid.close()
 
 
 @map.command('dump')
 @click.option('--type', '-t', 'map_type', required=True,
               type=click.Choice(['pid2int', 'int2pid']),
               help='type of map to dump')
 @click.argument('filename', required=True, type=click.Path(exists=True))
 @click.pass_context
 def dump_map(ctx, map_type, filename):
     """dump a binary PID<->int map to textual format"""
     if map_type == 'pid2int':
         dump_pid2int(filename)
     elif map_type == 'int2pid':
         dump_int2pid(filename)
     else:
         raise ValueError('invalid map type: ' + map_type)
     pass
 
 
 @map.command('restore')
 @click.option('--type', '-t', 'map_type', required=True,
               type=click.Choice(['pid2int', 'int2pid']),
               help='type of map to dump')
 @click.option('--length', '-l', type=int,
               help='''map size in number of logical records
               (required for int2pid maps)''')
 @click.argument('filename', required=True, type=click.Path())
 @click.pass_context
 def restore_map(ctx, map_type, length, filename):
     """restore a binary PID<->int map from textual format"""
     if map_type == 'pid2int':
-        restore_pid2int(filename, length)
+        restore_pid2int(filename)
     elif map_type == 'int2pid':
         if length is None:
             raise click.UsageError(
                 'map length is required when restoring {} maps'.format(
                     map_type), ctx)
         restore_int2pid(filename, length)
     else:
         raise ValueError('invalid map type: ' + map_type)
 
 
 def main():
     return cli(auto_envvar_prefix='SWH_GRAPH')
 
 
 if __name__ == '__main__':
     main()
diff --git a/swh/graph/server/__main__.py b/swh/graph/server/__main__.py
index d65cc6a..7bd3017 100755
--- a/swh/graph/server/__main__.py
+++ b/swh/graph/server/__main__.py
@@ -1,54 +1,102 @@
 # Copyright (C) 2019 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
 import argparse
+import contextlib
 
 import aiohttp.web
 
 from swh.core.api.asynchronous import RPCServerApp
 from swh.graph.server.backend import Backend
 
 
+@contextlib.asynccontextmanager
+async def stream_response(request, *args, **kwargs):
+    """Context manager yielding a prepared StreamResponse, ended on exit."""
+    response = aiohttp.web.StreamResponse(*args, **kwargs)
+    await response.prepare(request)
+    yield response
+    await response.write_eof()
+
+
 async def index(request):
     return aiohttp.web.Response(body="SWH Graph API server")
 
 
+async def stats(request):
+    stats = request.app['backend'].stats()
+    return aiohttp.web.Response(body=stats, content_type='application/json')
+
+
+async def _simple_traversal(request, ttype):
+    assert ttype in ('leaves', 'neighbors', 'visit_nodes')
+    method = getattr(request.app['backend'], ttype)
+
+    src = request.match_info['src']
+    edges = request.query.get('edges', '*')
+    direction = request.query.get('direction', 'forward')
+
+    async with stream_response(request) as response:
+        async for res_pid in method(direction, edges, src):
+            await response.write('{}\n'.format(res_pid).encode())
+        return response
+
+
+async def leaves(request):
+    return (await _simple_traversal(request, 'leaves'))
+
+
+async def neighbors(request):
+    return (await _simple_traversal(request, 'neighbors'))
+
+
 async def visit(request):
-    node_id = int(request.match_info['id'])
-    response = aiohttp.web.StreamResponse(status=200)
-    await response.prepare(request)
-    async for node_id in request.app['backend'].visit(node_id):
-        await response.write('{}\n'.format(node_id).encode())
-    await response.write_eof()
-    return response
+    return (await _simple_traversal(request, 'visit_nodes'))
+
+
+async def walk(request):
+    src = request.match_info['src']
+    dst = request.match_info['dst']
+    edges = request.query.get('edges', '*')
+    direction = request.query.get('direction', 'forward')
+    algo = request.query.get('traversal', 'dfs')
+
+    it = request.app['backend'].walk(direction, edges, algo, src, dst)
+    async with stream_response(request) as response:
+        async for res_pid in it:
+            await response.write('{}\n'.format(res_pid).encode())
+        return response
 
 
 def make_app(backend, **kwargs):
     app = RPCServerApp(**kwargs)
     app.router.add_route('GET', '/', index)
-
-    # Endpoints used by the web API
-    app.router.add_route('GET', '/visit/{id}', visit)
+    app.router.add_route('GET', '/graph/stats', stats)
+    app.router.add_route('GET', '/graph/leaves/{src}', leaves)
+    app.router.add_route('GET', '/graph/neighbors/{src}', neighbors)
+    app.router.add_route('GET', '/graph/walk/{src}/{dst}', walk)
+    app.router.add_route('GET', '/graph/visit/nodes/{src}', visit)
+    # TODO: graph/visit/paths/ ?
 
     app['backend'] = backend
     return app
 
 
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument('--host', default='0.0.0.0')
     parser.add_argument('--port', type=int, default=5009)
     parser.add_argument('--graph', required=True)
     args = parser.parse_args()
 
     backend = Backend(graph_path=args.graph)
     app = make_app(backend=backend)
 
     with backend:
         aiohttp.web.run_app(app, host=args.host, port=args.port)
 
 
 if __name__ == '__main__':
     main()
diff --git a/swh/graph/server/backend.py b/swh/graph/server/backend.py
index 3e13f42..4838f0b 100644
--- a/swh/graph/server/backend.py
+++ b/swh/graph/server/backend.py
@@ -1,60 +1,144 @@
 import asyncio
+import contextlib
 import os
 import struct
 import sys
 import tempfile
 
 from py4j.java_gateway import JavaGateway
 
-GATEWAY_SERVER_PORT = 25335
+from swh.graph.pid import IntToPidMap, PidToIntMap
 
 BUF_SIZE = 64*1024
 BIN_FMT = '>q'  # 64 bit integer, big endian
+NODE2PID_EXT = 'node2pid.bin'
+PID2NODE_EXT = 'pid2node.bin'
-
-
-async def read_node_ids(fname):
-    loop = asyncio.get_event_loop()
-    with open(fname, 'rb') as f:
-        while True:
-            data = await loop.run_in_executor(None, f.read, BUF_SIZE)
-            if not data:
-                break
-            for data in struct.iter_unpack(BIN_FMT, data):
-                yield data[0]
+JAR_PATH = os.path.join(
+    os.path.dirname(__file__), '../../..',
+    'java/server/target/swh-graph-0.0.2-jar-with-dependencies.jar'
+)
 
 
 class Backend:
     def __init__(self, graph_path):
         self.gateway = None
         self.entry = None
         self.graph_path = graph_path
 
     def __enter__(self):
         self.gateway = JavaGateway.launch_gateway(
-            port=GATEWAY_SERVER_PORT,
-            classpath='java/server/target/swh-graph-0.0.2-jar-with-dependencies.jar',
+            java_path=None,
+            classpath=JAR_PATH,  # noqa
             die_on_exit=True,
             redirect_stdout=sys.stdout,
             redirect_stderr=sys.stderr,
         )
         self.entry = self.gateway.jvm.org.softwareheritage.graph.Entry()
         self.entry.load_graph(self.graph_path)
-        # "/home/seirl/swh-graph/sample/big/compressed/swh-graph")
+        self.node2pid = IntToPidMap(self.graph_path + '.' + NODE2PID_EXT)
+        self.pid2node = PidToIntMap(self.graph_path + '.' + PID2NODE_EXT)
+        self.stream_proxy = JavaStreamProxy(self.entry)
 
-    def __exit__(self):
+    def __exit__(self, exc_type, exc_value, tb):
         self.gateway.shutdown()
 
-    async def visit(self, node_id):
+    def stats(self):
+        return self.entry.stats()
+
+    async def _simple_traversal(self, ttype, direction, edges_fmt, src):
+        assert ttype in ('leaves', 'neighbors', 'visit_nodes')
+        src_id = self.pid2node[src]
+        method = getattr(self.stream_proxy, ttype)
+        async for node_id in method(direction, edges_fmt, src_id):
+            yield self.node2pid[node_id]
+
+    async def leaves(self, *args):
+        async for res_pid in self._simple_traversal('leaves', *args):
+            yield res_pid
+
+    async def neighbors(self, *args):
+        async for res_pid in self._simple_traversal('neighbors', *args):
+            yield res_pid
+
+    async def visit_nodes(self, *args):
+        async for res_pid in self._simple_traversal('visit_nodes', *args):
+            yield res_pid
+
+    async def walk(self, direction, edges_fmt, algo, src, dst):
+        src_id = self.pid2node[src]
+        if dst in ('cnt', 'dir', 'rel', 'rev', 'snp', 'ori'):
+            it = self.stream_proxy.walk_type(direction, edges_fmt, algo,
+                                             src_id, dst)
+        else:
+            dst_id = self.pid2node[dst]
+            it = self.stream_proxy.walk(direction, edges_fmt, algo,
+                                        src_id, dst_id)
+
+        async for node_id in it:
+            yield self.node2pid[node_id]
+
+
+class JavaStreamProxy:
+    """A proxy class for the org.softwareheritage.graph.Entry Java class that
+    takes care of the setup and teardown of the named-pipe FIFO communication
+    between Python and Java.
+
+    Initialize JavaStreamProxy using:
+
+        proxy = JavaStreamProxy(swh_entry_class_instance)
+
+    Then you can call an Entry method and iterate on the FIFO results like
+    this:
+
+        async for value in proxy.java_method(arg1, arg2):
+            print(value)
+    """
+
+    def __init__(self, entry):
+        self.entry = entry
+
+    async def read_node_ids(self, fname):
+        """Yield the 64-bit node ids that Java writes to the FIFO fname."""
         loop = asyncio.get_event_loop()
+        with (await loop.run_in_executor(None, open, fname, 'rb')) as f:
+            while True:
+                data = await loop.run_in_executor(None, f.read, BUF_SIZE)
+                if not data:
+                    break
+                for (node_id,) in struct.iter_unpack(BIN_FMT, data):
+                    yield node_id
+
+    class _HandlerWrapper:
+        def __init__(self, handler):
+            self._handler = handler
 
-        with tempfile.TemporaryDirectory() as tmpdirname:
+        def __getattr__(self, name):
+            func = getattr(self._handler, name)
+
+            async def java_call(*args, **kwargs):
+                loop = asyncio.get_event_loop()
+                await loop.run_in_executor(None, lambda: func(*args, **kwargs))
+
+            def java_task(*args, **kwargs):
+                return asyncio.create_task(java_call(*args, **kwargs))
+
+            return java_task
+
+    @contextlib.contextmanager
+    def get_handler(self):
+        with tempfile.TemporaryDirectory(prefix='swh-graph-') as tmpdirname:
             cli_fifo = os.path.join(tmpdirname, 'swh-graph.fifo')
             os.mkfifo(cli_fifo)
+            reader = self.read_node_ids(cli_fifo)
+            query_handler = self.entry.get_handler(cli_fifo)
+            handler = self._HandlerWrapper(query_handler)
+            yield (handler, reader)
 
-            def _visit():
-                return self.entry.visit(node_id, 'forward', '*', cli_fifo)
-
-            java_call = loop.run_in_executor(None, _visit)
-            async for node_id in read_node_ids(cli_fifo):
-                yield node_id
-            await java_call
+    def __getattr__(self, name):
+        async def java_call_iterator(*args, **kwargs):
+            with self.get_handler() as (handler, reader):
+                java_task = getattr(handler, name)(*args, **kwargs)
+                async for value in reader:
+                    yield value
+                await java_task
+        return java_call_iterator