diff --git a/java/server/src/main/java/org/softwareheritage/graph/Entry.java b/java/server/src/main/java/org/softwareheritage/graph/Entry.java --- a/java/server/src/main/java/org/softwareheritage/graph/Entry.java +++ b/java/server/src/main/java/org/softwareheritage/graph/Entry.java @@ -1,25 +1,21 @@ package org.softwareheritage.graph; import java.util.ArrayList; -import java.util.Map; import java.io.DataOutputStream; import java.io.FileOutputStream; import java.io.IOException; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.PropertyNamingStrategy; -import py4j.GatewayServer; +import org.softwareheritage.graph.algo.NodeIdConsumer; -import org.softwareheritage.graph.Graph; -import org.softwareheritage.graph.Node; import org.softwareheritage.graph.algo.Stats; -import org.softwareheritage.graph.algo.NodeIdsConsumer; import org.softwareheritage.graph.algo.Traversal; public class Entry { - Graph graph; + private Graph graph; - final long PATH_SEPARATOR_ID = -1; + private static final long PATH_SEPARATOR_ID = -1; public void load_graph(String graphBasename) throws IOException { System.err.println("Loading graph " + graphBasename + " ..."); @@ -36,13 +32,37 @@ Stats stats = new Stats(graph.getPath()); ObjectMapper objectMapper = new ObjectMapper(); objectMapper.setPropertyNamingStrategy(PropertyNamingStrategy.SNAKE_CASE); - String res = objectMapper.writeValueAsString(stats); - return res; + return objectMapper.writeValueAsString(stats); } catch (IOException e) { throw new RuntimeException("Cannot read stats: " + e); } } + private interface NodeCountVisitor { + void accept(long nodeId, NodeIdConsumer consumer); + } + + private long count_visitor(NodeCountVisitor f, long srcNodeId) { + long[] count = { 0 }; + f.accept(srcNodeId, node -> count[0]++); + return count[0]; + } + + public long count_leaves(String direction, String edgesFmt, long srcNodeId) { + Traversal t = new Traversal(this.graph.copy(), direction, edgesFmt); + return
count_visitor(t::leavesVisitor, srcNodeId); + } + + public long count_neighbors(String direction, String edgesFmt, long srcNodeId) { + Traversal t = new Traversal(this.graph.copy(), direction, edgesFmt); + return count_visitor(t::neighborsVisitor, srcNodeId); + } + + public long count_visit_nodes(String direction, String edgesFmt, long srcNodeId) { + Traversal t = new Traversal(this.graph.copy(), direction, edgesFmt); + return count_visitor(t::visitNodesVisitor, srcNodeId); + } + public QueryHandler get_handler(String clientFIFO) { return new QueryHandler(this.graph.copy(), clientFIFO); } diff --git a/swh/graph/backend.py b/swh/graph/backend.py --- a/swh/graph/backend.py +++ b/swh/graph/backend.py @@ -75,6 +75,11 @@ def stats(self): return self.entry.stats() + def count(self, ttype, direction, edges_fmt, src): + assert ttype in ('leaves', 'neighbors', 'visit_nodes') + method = getattr(self.entry, 'count_' + ttype) + return method(direction, edges_fmt, src) + async def simple_traversal(self, ttype, direction, edges_fmt, src): assert ttype in ('leaves', 'neighbors', 'visit_nodes') method = getattr(self.stream_proxy, ttype) diff --git a/swh/graph/client.py b/swh/graph/client.py --- a/swh/graph/client.py +++ b/swh/graph/client.py @@ -80,3 +80,27 @@ 'traversal': traversal, 'direction': direction }) + + def count_leaves(self, src, edges="*", direction="forward"): + return self.get( + 'leaves/count/{}'.format(src), + params={ + 'edges': edges, + 'direction': direction + }) + + def count_neighbors(self, src, edges="*", direction="forward"): + return self.get( + 'neighbors/count/{}'.format(src), + params={ + 'edges': edges, + 'direction': direction + }) + + def count_visit_nodes(self, src, edges="*", direction="forward"): + return self.get( + 'visit/nodes/count/{}'.format(src), + params={ + 'edges': edges, + 'direction': direction + }) diff --git a/swh/graph/graph.py b/swh/graph/graph.py --- a/swh/graph/graph.py +++ b/swh/graph/graph.py @@ -5,17 +5,19 @@ import asyncio import contextlib +import functools from swh.graph.backend import Backend
from swh.graph.dot import dot_to_svg, graph_dot, KIND_TO_SHAPE -KIND_TO_URL = { - 'ori': 'https://archive.softwareheritage.org/browse/origin/{}', - 'snp': 'https://archive.softwareheritage.org/browse/snapshot/{}', - 'rel': 'https://archive.softwareheritage.org/browse/release/{}', - 'rev': 'https://archive.softwareheritage.org/browse/revision/{}', - 'dir': 'https://archive.softwareheritage.org/browse/directory/{}', - 'cnt': 'https://archive.softwareheritage.org/browse/content/sha1_git:{}/', +BASE_URL = 'https://archive.softwareheritage.org/browse' +KIND_TO_URL_FRAGMENT = { + 'ori': '/origin/{}', + 'snp': '/snapshot/{}', + 'rel': '/release/{}', + 'rev': '/revision/{}', + 'dir': '/directory/{}', + 'cnt': '/content/sha1_git:{}/', } @@ -96,6 +98,13 @@ ): yield self.graph[node] + def _count(self, ttype, direction='forward', edges='*'): + return self.graph.backend.count(ttype, direction, edges, self.id) + + count_leaves = functools.partialmethod(_count, 'leaves') + count_neighbors = functools.partialmethod(_count, 'neighbors') + count_visit_nodes = functools.partialmethod(_count, 'visit_nodes') + @property def pid(self): return self.graph.node2pid[self.id] @@ -113,7 +122,7 @@ def dot_fragment(self): swh, version, kind, hash = self.pid.split(':') label = '{}:{}..{}'.format(kind, hash[0:2], hash[-2:]) - url = KIND_TO_URL[kind].format(hash) + url = BASE_URL + KIND_TO_URL_FRAGMENT[kind].format(hash) shape = KIND_TO_SHAPE[kind] return ('{} [label="{}", href="{}", target="_blank", shape="{}"];' .format(self.id, label, url, shape)) diff --git a/swh/graph/server/app.py b/swh/graph/server/app.py --- a/swh/graph/server/app.py +++ b/swh/graph/server/app.py @@ -103,10 +103,27 @@ return response +def get_count_handler(ttype): + async def count(request): + backend = request.app['backend'] + + src = request.match_info['src'] + edges = request.query.get('edges', '*') + direction = request.query.get('direction', 'forward') + + src_node = backend.pid2node[src] + cnt =
backend.count(ttype, direction, edges, src_node) + return aiohttp.web.Response(body=str(cnt), + content_type='application/json') + + return count + + def make_app(backend, **kwargs): app = RPCServerApp(**kwargs) app.router.add_route('GET', '/', index) app.router.add_route('GET', '/graph/stats', stats) + app.router.add_route('GET', '/graph/leaves/{src}', get_simple_traversal_handler('leaves')) app.router.add_route('GET', '/graph/neighbors/{src}', @@ -116,5 +133,12 @@ app.router.add_route('GET', '/graph/visit/paths/{src}', visit_paths) app.router.add_route('GET', '/graph/walk/{src}/{dst}', walk) + app.router.add_route('GET', '/graph/neighbors/count/{src}', + get_count_handler('neighbors')) + app.router.add_route('GET', '/graph/leaves/count/{src}', + get_count_handler('leaves')) + app.router.add_route('GET', '/graph/visit/nodes/count/{src}', + get_count_handler('visit_nodes')) + app['backend'] = backend return app diff --git a/swh/graph/tests/test_api_client.py b/swh/graph/tests/test_api_client.py --- a/swh/graph/tests/test_api_client.py +++ b/swh/graph/tests/test_api_client.py @@ -102,3 +102,20 @@ 'swh:1:rel:0000000000000000000000000000000000000019' ] assert set(actual) == set(expected) + + +def test_count(graph_client): + actual = graph_client.count_leaves( + 'swh:1:ori:0000000000000000000000000000000000000021' + ) + assert actual == 4 + actual = graph_client.count_visit_nodes( + 'swh:1:rel:0000000000000000000000000000000000000010', + edges='rel:rev,rev:rev' + ) + assert actual == 3 + actual = graph_client.count_neighbors( + 'swh:1:rev:0000000000000000000000000000000000000009', + direction='backward' + ) + assert actual == 3 diff --git a/swh/graph/tests/test_graph.py b/swh/graph/tests/test_graph.py --- a/swh/graph/tests/test_graph.py +++ b/swh/graph/tests/test_graph.py @@ -101,3 +101,14 @@ 'swh:1:rel:0000000000000000000000000000000000000019' ] assert set(actual) == set(expected) + + +def test_count(graph): + assert
(graph['swh:1:ori:0000000000000000000000000000000000000021'] + .count_leaves() == 4) + assert (graph['swh:1:rel:0000000000000000000000000000000000000010'] + .count_visit_nodes(edges='rel:rev,rev:rev') == 3) + assert (graph['swh:1:rev:0000000000000000000000000000000000000009'] + .count_neighbors(direction='backward') == 3) + assert (graph['swh:1:rev:0000000000000000000000000000000000000009'] + .count_neighbors('backward') == 3)