diff --git a/java/src/main/java/org/softwareheritage/graph/Entry.java b/java/src/main/java/org/softwareheritage/graph/Entry.java --- a/java/src/main/java/org/softwareheritage/graph/Entry.java +++ b/java/src/main/java/org/softwareheritage/graph/Entry.java @@ -42,21 +42,21 @@ return count[0]; } - public int count_leaves(String direction, String edgesFmt, String src) { + public int count_leaves(String direction, String edgesFmt, String src, long maxEdges) { long srcNodeId = graph.getNodeId(new SWHID(src)); - Traversal t = new Traversal(graph.copy(), direction, edgesFmt); + Traversal t = new Traversal(graph.copy(), direction, edgesFmt, maxEdges); return count_visitor(t::leavesVisitor, srcNodeId); } - public int count_neighbors(String direction, String edgesFmt, String src) { + public int count_neighbors(String direction, String edgesFmt, String src, long maxEdges) { long srcNodeId = graph.getNodeId(new SWHID(src)); - Traversal t = new Traversal(graph.copy(), direction, edgesFmt); + Traversal t = new Traversal(graph.copy(), direction, edgesFmt, maxEdges); return count_visitor(t::neighborsVisitor, srcNodeId); } - public int count_visit_nodes(String direction, String edgesFmt, String src) { + public int count_visit_nodes(String direction, String edgesFmt, String src, long maxEdges) { long srcNodeId = graph.getNodeId(new SWHID(src)); - Traversal t = new Traversal(graph.copy(), direction, edgesFmt); + Traversal t = new Traversal(graph.copy(), direction, edgesFmt, maxEdges); return count_visitor(t::visitNodesVisitor, srcNodeId); } @@ -152,17 +152,17 @@ close(); } - public void walk(String direction, String edgesFmt, String algorithm, String src, String dst) { + public void walk(String direction, String edgesFmt, String algorithm, String src, String dst, long maxEdges, + String returnTypes) { long srcNodeId = graph.getNodeId(new SWHID(src)); open(); ArrayList res; + Traversal t = new Traversal(graph, direction, edgesFmt, maxEdges, returnTypes); if (dst.matches("ori|snp|rel|rev|dir|cnt")) { Node.Type dstType = Node.Type.fromStr(dst); - Traversal t = new Traversal(graph, direction, edgesFmt); res = t.walk(srcNodeId, dstType, algorithm); } else { long dstNodeId = graph.getNodeId(new SWHID(dst)); - Traversal t = new Traversal(graph, direction, edgesFmt); res = t.walk(srcNodeId, dstNodeId, algorithm); } for (Long nodeId : res) { @@ -171,17 +171,17 @@ close(); } - public void random_walk(String direction, String edgesFmt, int retries, String src, String dst) { + public void random_walk(String direction, String edgesFmt, int retries, String src, String dst, long maxEdges, + String returnTypes) { long srcNodeId = graph.getNodeId(new SWHID(src)); open(); ArrayList res; + Traversal t = new Traversal(graph, direction, edgesFmt, maxEdges, returnTypes); if (dst.matches("ori|snp|rel|rev|dir|cnt")) { Node.Type dstType = Node.Type.fromStr(dst); - Traversal t = new Traversal(graph, direction, edgesFmt); res = t.randomWalk(srcNodeId, dstType, retries); } else { long dstNodeId = graph.getNodeId(new SWHID(dst)); - Traversal t = new Traversal(graph, direction, edgesFmt); res = t.randomWalk(srcNodeId, dstNodeId, retries); } for (Long nodeId : res) { diff --git a/swh/graph/backend.py b/swh/graph/backend.py --- a/swh/graph/backend.py +++ b/swh/graph/backend.py @@ -75,9 +75,9 @@ raise NameError(f"Unknown SWHID: {m[1]}") raise - def count(self, ttype, direction, edges_fmt, src): + def count(self, ttype, *args): method = getattr(self.entry, "count_" + ttype) - return method(direction, edges_fmt, src) + return method(*args) async def traversal(self, ttype, *args): method = getattr(self.stream_proxy, ttype) diff --git a/swh/graph/server/app.py b/swh/graph/server/app.py --- a/swh/graph/server/app.py +++ b/swh/graph/server/app.py @@ -247,6 +247,8 @@ self.direction = self.get_direction() self.algo = self.get_traversal() self.limit = self.get_limit() + self.max_edges = self.get_max_edges() + self.return_types = self.get_return_types() self.check_swhid(self.src) if self.dst not in EXTENDED_SWHID_TYPES: @@ -254,7 +256,14 @@ async def get_walk_iterator(self): return self.backend.traversal( - "walk", self.direction, self.edges, self.algo, self.src, self.dst + "walk", + self.direction, + self.edges, + self.algo, + self.src, + self.dst, + self.max_edges, + self.return_types, ) async def stream_response(self): @@ -284,6 +293,8 @@ RANDOM_RETRIES, self.src, self.dst, + self.max_edges, + self.return_types, ) @@ -298,6 +309,7 @@ self.edges = self.get_edges() self.direction = self.get_direction() + self.max_edges = self.get_max_edges() loop = asyncio.get_event_loop() cnt = await loop.run_in_executor( @@ -307,6 +319,7 @@ self.direction, self.edges, self.src, + self.max_edges, ) return aiohttp.web.Response(body=str(cnt), content_type="application/json")