diff --git a/java/server/src/main/java/org/softwareheritage/graph/Entry.java b/java/server/src/main/java/org/softwareheritage/graph/Entry.java
--- a/java/server/src/main/java/org/softwareheritage/graph/Entry.java
+++ b/java/server/src/main/java/org/softwareheritage/graph/Entry.java
@@ -27,6 +27,11 @@
     System.err.println("Graph loaded.");
   }
 
+  /** Returns graph.copy() of the loaded graph; used by the Python API (swh.graph.graph.load). */
+  public Graph get_graph() {
+    return graph.copy();
+  }
+
   public String stats() {
     try {
       Stats stats = new Stats(graph.getPath());
diff --git a/swh/graph/dot.py b/swh/graph/dot.py
new file mode 100644
--- /dev/null
+++ b/swh/graph/dot.py
@@ -0,0 +1,65 @@
+# Copyright (C) 2019 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+from functools import lru_cache
+import subprocess
+
+
+@lru_cache()
+def dot_to_svg(dot):
+    """Render a Graphviz DOT source string to SVG.
+
+    Runs the external ``dot -Tsvg`` command. Results are memoized with a
+    bounded ``lru_cache`` so recently rendered inputs do not spawn ``dot``.
+
+    Raises:
+        subprocess.CalledProcessError: if ``dot`` exits with an error.
+    """
+    p = subprocess.run(['dot', '-Tsvg'], input=dot, universal_newlines=True,
+                       capture_output=True, check=True)
+    return p.stdout
+
+
+def graph_dot(nodes):
+    """Return the DOT source of the subgraph induced by ``nodes``.
+
+    Only edges whose two endpoints both belong to ``nodes`` are drawn. Each
+    node must provide ``id``, ``children()``, ``parents()`` and
+    ``dot_fragment()``.
+    """
+    # Materialize once: ``nodes`` is traversed several times below, and a
+    # generator would silently be empty after the first pass.
+    nodes = list(nodes)
+    ids = {n.id for n in nodes}
+
+    # Edges are discovered from both endpoints; sets drop the duplicates
+    # found when both ends of an edge are in ``nodes``.
+    forward_edges = {
+        (node.id, child.id)
+        for node in nodes
+        for child in node.children()
+        if child.id in ids
+    }
+    backward_edges = {
+        (parent.id, node.id)
+        for node in nodes
+        for parent in node.parents()
+        if parent.id in ids
+    }
+    # Sort so a given subgraph always yields the same DOT text (and thus
+    # hits the dot_to_svg cache).
+    edges = sorted(forward_edges | backward_edges)
+    edges_fmt = '\n'.join('{} -> {};'.format(a, b) for a, b in edges)
+    nodes_fmt = '\n'.join(node.dot_fragment() for node in nodes)
+
+    s = """digraph G {{
+        ranksep=1;
+        nodesep=0.5;
+
+        {nodes}
+        {edges}
+
+    }}""".format(nodes=nodes_fmt, edges=edges_fmt)
+    return s
diff --git a/swh/graph/graph.py b/swh/graph/graph.py
new file mode 100644
--- /dev/null
+++ b/swh/graph/graph.py
@@ -0,0 +1,145 @@
+# Copyright (C) 2019 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+import contextlib
+from swh.graph.server.backend import Backend
+from swh.graph.dot import dot_to_svg, graph_dot
+
+
+# Graphviz node shape used to draw each kind of SWH object.
+KIND_TO_SHAPE = {
+    'ori': 'egg',
+    'snp': 'doubleoctagon',
+    'rel': 'octagon',
+    'rev': 'diamond',
+    'dir': 'folder',
+    'cnt': 'oval',
+}
+
+
+class Neighbors:
+    """Neighbor iterator with custom O(1) length method.
+
+    Wraps a Java node iterator and yields one GraphNode per neighbor. Since
+    ``__iter__`` returns ``self``, it can be consumed only once; ``len()``
+    is delegated to ``length_func`` and does not consume it.
+    """
+
+    def __init__(self, parent_graph, iterator, length_func):
+        self.parent_graph = parent_graph
+        self.iterator = iterator
+        self.length_func = length_func
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        succ = self.iterator.nextLong()
+        # The Java iterator signals exhaustion by returning -1.
+        if succ == -1:
+            raise StopIteration
+        return GraphNode(self.parent_graph, succ)
+
+    def __len__(self):
+        return self.length_func()
+
+
+class GraphNode:
+    """Node in the SWH graph, identified by its integer node id."""
+
+    def __init__(self, parent_graph, node_id):
+        self.parent_graph = parent_graph
+        self.id = node_id
+
+    def children(self):
+        """Return the successors of this node (single-use iterable)."""
+        return Neighbors(
+            self.parent_graph,
+            self.parent_graph.java_graph.successors(self.id),
+            lambda: self.parent_graph.java_graph.outdegree(self.id))
+
+    def parents(self):
+        """Return the predecessors of this node (single-use iterable)."""
+        return Neighbors(
+            self.parent_graph,
+            self.parent_graph.java_graph.predecessors(self.id),
+            lambda: self.parent_graph.java_graph.indegree(self.id))
+
+    @property
+    def pid(self):
+        """Persistent identifier of this node, e.g. ``swh:1:dir:<hash>``."""
+        return self.parent_graph.node2pid[self.id]
+
+    @property
+    def kind(self):
+        """Object kind, i.e. the third PID field (e.g. ``'dir'``)."""
+        return self.pid.split(':')[2]
+
+    def __str__(self):
+        return self.pid
+
+    def __repr__(self):
+        return '<{}>'.format(self.pid)
+
+    def dot_fragment(self):
+        """Return the DOT statement declaring this node.
+
+        The label abbreviates the object hash to its first and last two
+        characters; the shape is chosen from KIND_TO_SHAPE.
+        """
+        _, _, kind, obj_hash = self.pid.split(':')
+        label = '{}:{}..{}'.format(kind, obj_hash[0:2], obj_hash[-2:])
+        shape = KIND_TO_SHAPE[kind]
+        return '{} [label="{}", shape="{}"];'.format(self.id, label, shape)
+
+    def _repr_svg_(self):
+        """Render this node and its direct neighbors as SVG (IPython hook)."""
+        nodes = [self, *self.children(), *self.parents()]
+        return dot_to_svg(graph_dot(nodes))
+
+
+class Graph:
+    """Python view of a Java graph and its node id <-> PID maps."""
+
+    def __init__(self, java_graph, node2pid, pid2node):
+        self.java_graph = java_graph
+        self.node2pid = node2pid
+        self.pid2node = pid2node
+
+    @property
+    def path(self):
+        return self.java_graph.getPath()
+
+    def __len__(self):
+        return self.java_graph.getNbNodes()
+
+    def __getitem__(self, node_id):
+        """Return the GraphNode for an integer node id or a PID string.
+
+        Raises:
+            IndexError: for an unknown integer id (from node2pid).
+            KeyError: for an unknown PID (from pid2node).
+            TypeError: if ``node_id`` is neither an int nor a str.
+        """
+        if isinstance(node_id, int):
+            self.node2pid[node_id]  # check existence
+            return GraphNode(self, node_id)
+        if isinstance(node_id, str):
+            return GraphNode(self, self.pid2node[node_id])
+        # Any other key type is a caller error; don't silently return None.
+        raise TypeError('graph indices must be int or str, not {}'.format(
+            type(node_id).__name__))
+
+
+@contextlib.contextmanager
+def load(graph_path):
+    """Start a Backend on ``graph_path`` and yield a Graph backed by it.
+
+    The backend is exited (shutting down its gateway) when the context
+    closes.
+    """
+    with Backend(graph_path) as backend:
+        yield Graph(backend.entry.get_graph(),
+                    backend.node2pid, backend.pid2node)
diff --git a/swh/graph/server/backend.py b/swh/graph/server/backend.py
--- a/swh/graph/server/backend.py
+++ b/swh/graph/server/backend.py
@@ -56,6 +56,7 @@
         self.node2pid = IntToPidMap(self.graph_path + '.' + NODE2PID_EXT)
         self.pid2node = PidToIntMap(self.graph_path + '.' + PID2NODE_EXT)
         self.stream_proxy = JavaStreamProxy(self.entry)
+        return self
 
     def __exit__(self, exc_type, exc_value, tb):
         self.gateway.shutdown()
diff --git a/swh/graph/tests/conftest.py b/swh/graph/tests/conftest.py
--- a/swh/graph/tests/conftest.py
+++ b/swh/graph/tests/conftest.py
@@ -4,6 +4,7 @@
 
 from aiohttp.test_utils import TestServer, TestClient, loop_context
 
+from swh.graph.graph import load as graph_load
 from swh.graph.client import RemoteGraphClient
 from swh.graph.server.backend import Backend
 from swh.graph.server.app import make_app
@@ -38,3 +39,9 @@
     url = queue.get()
     yield RemoteGraphClient(str(url))
     server.terminate()
+
+
+@pytest.fixture(scope="module")
+def graph():
+    with graph_load(str(TEST_GRAPH_PATH)) as g:
+        yield g
diff --git a/swh/graph/tests/test_graph.py b/swh/graph/tests/test_graph.py
new file mode 100644
--- /dev/null
+++ b/swh/graph/tests/test_graph.py
@@ -0,0 +1,41 @@
+import pytest
+
+
+def test_graph(graph):
+    assert len(graph) == 21
+
+    obj = 'swh:1:dir:0000000000000000000000000000000000000008'
+    node = graph[obj]
+
+    assert str(node) == obj
+    assert len(node.children()) == 3
+    assert len(node.parents()) == 2
+
+    actual = {p.pid for p in node.children()}
+    expected = {
+        'swh:1:cnt:0000000000000000000000000000000000000001',
+        'swh:1:dir:0000000000000000000000000000000000000006',
+        'swh:1:cnt:0000000000000000000000000000000000000007'
+    }
+    assert expected == actual
+
+    actual = {p.pid for p in node.parents()}
+    expected = {
+        'swh:1:rev:0000000000000000000000000000000000000009',
+        'swh:1:dir:0000000000000000000000000000000000000012',
+    }
+    assert expected == actual
+
+
+def test_invalid_pid(graph):
+    with pytest.raises(IndexError):
+        graph[1337]
+
+    with pytest.raises(IndexError):
+        graph[len(graph) + 1]
+
+    with pytest.raises(KeyError):
+        graph['swh:1:dir:0000000000000000000000000000000420000012']
+
+    with pytest.raises(TypeError):
+        graph[1.5]