diff --git a/java/server/src/main/java/org/softwareheritage/graph/Entry.java b/java/server/src/main/java/org/softwareheritage/graph/Entry.java index c666333..03aae5e 100644 --- a/java/server/src/main/java/org/softwareheritage/graph/Entry.java +++ b/java/server/src/main/java/org/softwareheritage/graph/Entry.java @@ -1,139 +1,143 @@ package org.softwareheritage.graph; import java.util.ArrayList; import java.util.Map; import java.io.DataOutputStream; import java.io.FileOutputStream; import java.io.IOException; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.PropertyNamingStrategy; import py4j.GatewayServer; import org.softwareheritage.graph.Graph; import org.softwareheritage.graph.Node; import org.softwareheritage.graph.algo.Stats; import org.softwareheritage.graph.algo.NodeIdsConsumer; import org.softwareheritage.graph.algo.Traversal; public class Entry { Graph graph; final long PATH_SEPARATOR_ID = -1; public void load_graph(String graphBasename) throws IOException { System.err.println("Loading graph " + graphBasename + " ..."); this.graph = new Graph(graphBasename); System.err.println("Graph loaded."); } + public Graph get_graph() { + return graph.copy(); + } + public String stats() { try { Stats stats = new Stats(graph.getPath()); ObjectMapper objectMapper = new ObjectMapper(); objectMapper.setPropertyNamingStrategy(PropertyNamingStrategy.SNAKE_CASE); String res = objectMapper.writeValueAsString(stats); return res; } catch (IOException e) { throw new RuntimeException("Cannot read stats: " + e); } } public QueryHandler get_handler(String clientFIFO) { return new QueryHandler(this.graph.copy(), clientFIFO); } public class QueryHandler { Graph graph; DataOutputStream out; String clientFIFO; public QueryHandler(Graph graph, String clientFIFO) { this.graph = graph; this.clientFIFO = clientFIFO; this.out = null; } public void writeNode(long nodeId) { try { out.writeLong(nodeId); } catch (IOException e) { throw new RuntimeException("Cannot write response to client: " + e); } } public void writePath(ArrayList path) { for (Long nodeId : path) { writeNode(nodeId); } writeNode(PATH_SEPARATOR_ID); } public void open() { try { FileOutputStream file = new FileOutputStream(this.clientFIFO); this.out = new DataOutputStream(file); } catch (IOException e) { throw new RuntimeException("Cannot open client FIFO: " + e); } } public void close() { try { out.close(); } catch (IOException e) { throw new RuntimeException("Cannot write response to client: " + e); } } public void leaves(String direction, String edgesFmt, long srcNodeId) { open(); Traversal t = new Traversal(this.graph, direction, edgesFmt); t.leavesVisitor(srcNodeId, this::writeNode); close(); } public void neighbors(String direction, String edgesFmt, long srcNodeId) { open(); Traversal t = new Traversal(this.graph, direction, edgesFmt); t.neighborsVisitor(srcNodeId, this::writeNode); close(); } public void visit_nodes(String direction, String edgesFmt, long srcNodeId) { open(); Traversal t = new Traversal(this.graph, direction, edgesFmt); t.visitNodesVisitor(srcNodeId, this::writeNode); close(); } public void visit_paths(String direction, String edgesFmt, long srcNodeId) { open(); Traversal t = new Traversal(this.graph, direction, edgesFmt); t.visitPathsVisitor(srcNodeId, this::writePath); close(); } public void walk(String direction, String edgesFmt, String algorithm, long srcNodeId, long dstNodeId) { open(); Traversal t = new Traversal(this.graph, direction, edgesFmt); for (Long nodeId : t.walk(srcNodeId, dstNodeId, algorithm)) { writeNode(nodeId); } close(); } public void walk_type(String direction, String edgesFmt, String algorithm, long srcNodeId, String dst) { open(); Node.Type dstType = Node.Type.fromStr(dst); Traversal t = new Traversal(this.graph, direction, edgesFmt); for (Long nodeId : t.walk(srcNodeId, dstType, algorithm)) { writeNode(nodeId); } close(); } } } diff --git a/swh/graph/dot.py b/swh/graph/dot.py new file mode 100644 index 0000000..accdc56 --- /dev/null +++ b/swh/graph/dot.py @@ -0,0 +1,49 @@ +# Copyright (C) 2019 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + +from functools import lru_cache +import subprocess +import collections + + +@lru_cache() +def dot_to_svg(dot): + p = subprocess.run(['dot', '-Tsvg'], input=dot, universal_newlines=True, + capture_output=True, check=True) + return p.stdout + + +def graph_dot(nodes): + ids = {n.id for n in nodes} + + by_kind = collections.defaultdict(list) + for n in nodes: + by_kind[n.kind].append(n) + + forward_edges = [ + (node.id, child.id) + for node in nodes + for child in node.children() + if child.id in ids + ] + backward_edges = [ + (parent.id, node.id) + for node in nodes + for parent in node.parents() + if parent.id in ids + ] + edges = set(forward_edges + backward_edges) + edges_fmt = '\n'.join('{} -> {};'.format(a, b) for a, b in edges) + nodes_fmt = '\n'.join(node.dot_fragment() for node in nodes) + + s = """digraph G {{ + ranksep=1; + nodesep=0.5; + + {nodes} + {edges} + + }}""".format(nodes=nodes_fmt, edges=edges_fmt) + return s diff --git a/swh/graph/graph.py b/swh/graph/graph.py new file mode 100644 index 0000000..f295fb4 --- /dev/null +++ b/swh/graph/graph.py @@ -0,0 +1,113 @@ +# Copyright (C) 2019 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + +import contextlib +from swh.graph.server.backend import Backend +from swh.graph.dot import dot_to_svg, graph_dot + + +KIND_TO_SHAPE = { + 'ori': 'egg', + 'snp': 'doubleoctagon', + 'rel': 'octagon', + 'rev': 'diamond', + 'dir': 'folder', + 'cnt': 'oval', +} + + +class Neighbors: + """Neighbor iterator with custom O(1) length method""" + def __init__(self, parent_graph, iterator, length_func): + self.parent_graph = parent_graph + self.iterator = iterator + self.length_func = length_func + + def __iter__(self): + return self + + def __next__(self): + succ = self.iterator.nextLong() + if succ == -1: + raise StopIteration + return GraphNode(self.parent_graph, succ) + + def __len__(self): + return self.length_func() + + +class GraphNode: + """Node in the SWH graph""" + + def __init__(self, parent_graph, node_id): + self.parent_graph = parent_graph + self.id = node_id + + def children(self): + return Neighbors( + self.parent_graph, + self.parent_graph.java_graph.successors(self.id), + lambda: self.parent_graph.java_graph.outdegree(self.id)) + + def parents(self): + return Neighbors( + self.parent_graph, + self.parent_graph.java_graph.predecessors(self.id), + lambda: self.parent_graph.java_graph.indegree(self.id)) + + @property + def pid(self): + return self.parent_graph.node2pid[self.id] + + @property + def kind(self): + return self.pid.split(':')[2] + + def __str__(self): + return self.pid + + def __repr__(self): + return '<{}>'.format(self.pid) + + def dot_fragment(self): + swh, version, kind, hash = self.pid.split(':') + label = '{}:{}..{}'.format(kind, hash[0:2], hash[-2:]) + shape = KIND_TO_SHAPE[kind] + return '{} [label="{}", shape="{}"];'.format(self.id, label, shape) + + def _repr_svg_(self): + nodes = [self, *list(self.children()), *list(self.parents())] + dot = graph_dot(nodes) + svg = dot_to_svg(dot) + return svg + + +class Graph: + def __init__(self, java_graph, node2pid, pid2node): + self.java_graph = java_graph + self.node2pid = node2pid + self.pid2node = pid2node + + @property + def path(self): + return self.java_graph.getPath() + + def __len__(self): + return self.java_graph.getNbNodes() + + def __getitem__(self, node_id): + if isinstance(node_id, int): + self.node2pid[node_id] # check existence + return GraphNode(self, node_id) + elif isinstance(node_id, str): + node_id = self.pid2node[node_id] + return GraphNode(self, node_id) + + +@contextlib.contextmanager +def load(graph_path): + with Backend(graph_path) as backend: + yield Graph(backend.entry.get_graph(), + backend.node2pid, backend.pid2node) diff --git a/swh/graph/server/backend.py b/swh/graph/server/backend.py index 0aecc14..a0d2ccb 100644 --- a/swh/graph/server/backend.py +++ b/swh/graph/server/backend.py @@ -1,161 +1,162 @@ # Copyright (C) 2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import asyncio import contextlib import json import os import pathlib import struct import sys import tempfile from py4j.java_gateway import JavaGateway from swh.graph.pid import IntToPidMap, PidToIntMap from swh.model.identifiers import PID_TYPES BUF_SIZE = 64*1024 BIN_FMT = '>q' # 64 bit integer, big endian PATH_SEPARATOR_ID = -1 NODE2PID_EXT = 'node2pid.bin' PID2NODE_EXT = 'pid2node.bin' def find_graph_jar(): swh_graph_root = pathlib.Path(__file__).parents[3] try_paths = [ swh_graph_root / 'java/server/target/', pathlib.Path(sys.prefix) / 'share/swh-graph/', ] for path in try_paths: glob = list(path.glob('swh-graph-*.jar')) if glob: return str(glob[0]) raise RuntimeError("swh-graph-*.jar not found. Have you run `make java`?") class Backend: def __init__(self, graph_path): self.gateway = None self.entry = None self.graph_path = graph_path def __enter__(self): self.gateway = JavaGateway.launch_gateway( java_path=None, classpath=find_graph_jar(), die_on_exit=True, redirect_stdout=sys.stdout, redirect_stderr=sys.stderr, ) self.entry = self.gateway.jvm.org.softwareheritage.graph.Entry() self.entry.load_graph(self.graph_path) self.node2pid = IntToPidMap(self.graph_path + '.' + NODE2PID_EXT) self.pid2node = PidToIntMap(self.graph_path + '.' + PID2NODE_EXT) self.stream_proxy = JavaStreamProxy(self.entry) + return self def __exit__(self, exc_type, exc_value, tb): self.gateway.shutdown() def stats(self): return self.entry.stats() async def simple_traversal(self, ttype, direction, edges_fmt, src): assert ttype in ('leaves', 'neighbors', 'visit_nodes', 'visit_paths') src_id = self.pid2node[src] method = getattr(self.stream_proxy, ttype) async for node_id in method(direction, edges_fmt, src_id): if node_id == PATH_SEPARATOR_ID: yield None else: yield self.node2pid[node_id] async def walk(self, direction, edges_fmt, algo, src, dst): src_id = self.pid2node[src] if dst in PID_TYPES: it = self.stream_proxy.walk_type(direction, edges_fmt, algo, src_id, dst) else: dst_id = self.pid2node[dst] it = self.stream_proxy.walk(direction, edges_fmt, algo, src_id, dst_id) async for node_id in it: yield self.node2pid[node_id] async def visit_paths(self, *args): buffer = [] async for res_pid in self.simple_traversal('visit_paths', *args): if res_pid is None: # Path separator, flush yield json.dumps(buffer) buffer = [] else: buffer.append(res_pid) class JavaStreamProxy: """A proxy class for the org.softwareheritage.graph.Entry Java class that takes care of the setup and teardown of the named-pipe FIFO communication between Python and Java. Initialize JavaStreamProxy using: proxy = JavaStreamProxy(swh_entry_class_instance) Then you can call an Entry method and iterate on the FIFO results like this: async for value in proxy.java_method(arg1, arg2): print(value) """ def __init__(self, entry): self.entry = entry async def read_node_ids(self, fname): loop = asyncio.get_event_loop() with (await loop.run_in_executor(None, open, fname, 'rb')) as f: while True: data = await loop.run_in_executor(None, f.read, BUF_SIZE) if not data: break for data in struct.iter_unpack(BIN_FMT, data): yield data[0] class _HandlerWrapper: def __init__(self, handler): self._handler = handler def __getattr__(self, name): func = getattr(self._handler, name) async def java_call(*args, **kwargs): loop = asyncio.get_event_loop() await loop.run_in_executor(None, lambda: func(*args, **kwargs)) def java_task(*args, **kwargs): return asyncio.create_task(java_call(*args, **kwargs)) return java_task @contextlib.contextmanager def get_handler(self): with tempfile.TemporaryDirectory(prefix='swh-graph-') as tmpdirname: cli_fifo = os.path.join(tmpdirname, 'swh-graph.fifo') os.mkfifo(cli_fifo) reader = self.read_node_ids(cli_fifo) query_handler = self.entry.get_handler(cli_fifo) handler = self._HandlerWrapper(query_handler) yield (handler, reader) def __getattr__(self, name): async def java_call_iterator(*args, **kwargs): with self.get_handler() as (handler, reader): java_task = getattr(handler, name)(*args, **kwargs) async for value in reader: yield value await java_task return java_call_iterator diff --git a/swh/graph/tests/conftest.py b/swh/graph/tests/conftest.py index b6b26bc..eff5f0d 100644 --- a/swh/graph/tests/conftest.py +++ b/swh/graph/tests/conftest.py @@ -1,39 +1,46 @@ import multiprocessing import pytest from pathlib import Path from aiohttp.test_utils import TestServer, TestClient, loop_context +from swh.graph.graph import load as graph_load from swh.graph.client import RemoteGraphClient from swh.graph.server.backend import Backend from swh.graph.server.app import make_app SWH_GRAPH_ROOT = Path(__file__).parents[3] TEST_GRAPH_PATH = SWH_GRAPH_ROOT / 'tests/dataset/output/example' class GraphServerProcess(multiprocessing.Process): def __init__(self, q, *args, **kwargs): self.q = q super().__init__(*args, **kwargs) def run(self): backend = Backend(graph_path=str(TEST_GRAPH_PATH)) with backend: with loop_context() as loop: app = make_app(backend=backend) client = TestClient(TestServer(app), loop=loop) loop.run_until_complete(client.start_server()) url = client.make_url('/graph/') self.q.put(url) loop.run_forever() @pytest.fixture(scope="module") def graph_client(): queue = multiprocessing.Queue() server = GraphServerProcess(queue) server.start() url = queue.get() yield RemoteGraphClient(str(url)) server.terminate() + + +@pytest.fixture(scope="module") +def graph(): + with graph_load(str(TEST_GRAPH_PATH)) as g: + yield g diff --git a/swh/graph/tests/test_graph.py b/swh/graph/tests/test_graph.py new file mode 100644 index 0000000..36d1cf1 --- /dev/null +++ b/swh/graph/tests/test_graph.py @@ -0,0 +1,38 @@ +import pytest + + +def test_graph(graph): + assert len(graph) == 21 + + obj = 'swh:1:dir:0000000000000000000000000000000000000008' + node = graph[obj] + + assert str(node) == obj + assert len(node.children()) == 3 + assert len(node.parents()) == 2 + + actual = {p.pid for p in node.children()} + expected = { + 'swh:1:cnt:0000000000000000000000000000000000000001', + 'swh:1:dir:0000000000000000000000000000000000000006', + 'swh:1:cnt:0000000000000000000000000000000000000007' + } + assert expected == actual + + actual = {p.pid for p in node.parents()} + expected = { + 'swh:1:rev:0000000000000000000000000000000000000009', + 'swh:1:dir:0000000000000000000000000000000000000012', + } + assert expected == actual + + +def test_invalid_pid(graph): + with pytest.raises(IndexError): + graph[1337] + + with pytest.raises(IndexError): + graph[len(graph) + 1] + + with pytest.raises(KeyError): + graph['swh:1:dir:0000000000000000000000000000000420000012']