diff --git a/java/src/main/java/org/softwareheritage/graph/Entry.java b/java/src/main/java/org/softwareheritage/graph/Entry.java --- a/java/src/main/java/org/softwareheritage/graph/Entry.java +++ b/java/src/main/java/org/softwareheritage/graph/Entry.java @@ -126,21 +126,21 @@ close(); } - public void visit_nodes(String direction, String edgesFmt, long srcNodeId, Object max_edges) { + public void visit_nodes(String direction, String edgesFmt, long srcNodeId, long max_edges) { open(); Traversal t = new Traversal(this.graph, direction, edgesFmt); t.visitNodesVisitor(srcNodeId, this::writeNode, max_edges); close(); } - public void visit_edges(String direction, String edgesFmt, long srcNodeId, Object max_edges) { + public void visit_edges(String direction, String edgesFmt, long srcNodeId, long max_edges) { open(); Traversal t = new Traversal(this.graph, direction, edgesFmt); - t.visitNodesVisitor(srcNodeId, null, this::writeEdge, max_edges, false); + t.visitNodesVisitor(srcNodeId, null, this::writeEdge, max_edges); close(); } - public void visit_paths(String direction, String edgesFmt, long srcNodeId, Object max_edges) { + public void visit_paths(String direction, String edgesFmt, long srcNodeId, long max_edges) { open(); Traversal t = new Traversal(this.graph, direction, edgesFmt); t.visitPathsVisitor(srcNodeId, this::writePath, max_edges); diff --git a/java/src/main/java/org/softwareheritage/graph/Traversal.java b/java/src/main/java/org/softwareheritage/graph/Traversal.java --- a/java/src/main/java/org/softwareheritage/graph/Traversal.java +++ b/java/src/main/java/org/softwareheritage/graph/Traversal.java @@ -146,28 +146,21 @@ /** * Push version of {@link #visitNodes}: will fire passed callback on each visited node. */ - public void visitNodesVisitor(long srcNodeId, NodeIdConsumer nodeCb, EdgeIdConsumer edgeCb, Object max_edges, - boolean limitedVisit) { + public void visitNodesVisitor(long srcNodeId, NodeIdConsumer nodeCb, EdgeIdConsumer edgeCb, long max_edges) { Stack stack = new Stack<>(); this.nbEdgesAccessed = 0; stack.push(srcNodeId); visited.add(srcNodeId); - Long limit_edges = null; - if (!Objects.isNull(max_edges)) { - limitedVisit = true; - limit_edges = Long.valueOf(max_edges.toString()); - } - while (!stack.isEmpty()) { long currentNodeId = stack.pop(); if (nodeCb != null) { nodeCb.accept(currentNodeId); } nbEdgesAccessed += graph.outdegree(currentNodeId); - if (limitedVisit) { - if (limit_edges.compareTo(nbEdgesAccessed) <= 0) { + if (max_edges > 0) { + if (nbEdgesAccessed >= max_edges) { break; } } @@ -186,12 +179,12 @@ /** Two argument version for count_visitor */ public void visitNodesVisitor(long srcNodeId, NodeIdConsumer cb) { - visitNodesVisitor(srcNodeId, cb, null, null, false); + visitNodesVisitor(srcNodeId, cb, null, 0); } /** One-argument version to handle callbacks properly */ - public void visitNodesVisitor(long srcNodeId, NodeIdConsumer cb, Object max_edges) { - visitNodesVisitor(srcNodeId, cb, null, max_edges, false); + public void visitNodesVisitor(long srcNodeId, NodeIdConsumer cb, long max_edges) { + visitNodesVisitor(srcNodeId, cb, null, max_edges); } /** @@ -202,7 +195,7 @@ */ public ArrayList visitNodes(long srcNodeId) { ArrayList nodeIds = new ArrayList<>(); - visitNodesVisitor(srcNodeId, nodeIds::add, null); + visitNodesVisitor(srcNodeId, nodeIds::add, 0); return nodeIds; } @@ -210,17 +203,10 @@ * Push version of {@link #visitPaths}: will fire passed callback on each discovered (complete) * path. */ - public void visitPathsVisitor(long srcNodeId, PathConsumer cb, Object max_edges) { + public void visitPathsVisitor(long srcNodeId, PathConsumer cb, long max_edges) { Stack currentPath = new Stack<>(); this.nbEdgesAccessed = 0; - boolean limitedVisit = false; - if (!Objects.isNull(max_edges)) { - limitedVisit = true; - Long l = Long.valueOf(max_edges.toString()); - visitPathsInternalVisitor(srcNodeId, currentPath, cb, l, limitedVisit); - } else { - visitPathsInternalVisitor(srcNodeId, currentPath, cb, null, limitedVisit); - } + visitPathsInternalVisitor(srcNodeId, currentPath, cb, max_edges); } /** @@ -231,26 +217,26 @@ */ public ArrayList> visitPaths(long srcNodeId) { ArrayList> paths = new ArrayList<>(); - visitPathsVisitor(srcNodeId, paths::add, null); + visitPathsVisitor(srcNodeId, paths::add, 0); return paths; } - private void visitPathsInternalVisitor(long currentNodeId, Stack currentPath, PathConsumer cb, Long max_edges, - boolean limitedVisit) { + private void visitPathsInternalVisitor(long currentNodeId, Stack currentPath, PathConsumer cb, + long max_edges) { currentPath.push(currentNodeId); long visitedNeighbors = 0; nbEdgesAccessed += graph.outdegree(currentNodeId); - if (limitedVisit) { - if (max_edges.compareTo(nbEdgesAccessed) <= 0) { + if (max_edges > 0) { + if (nbEdgesAccessed >= max_edges) { currentPath.pop(); return; } } LazyLongIterator it = graph.successors(currentNodeId, edges); for (long neighborNodeId; (neighborNodeId = it.nextLong()) != -1;) { - visitPathsInternalVisitor(neighborNodeId, currentPath, cb, max_edges, limitedVisit); + visitPathsInternalVisitor(neighborNodeId, currentPath, cb, max_edges); visitedNeighbors++; } diff --git a/swh/graph/client.py b/swh/graph/client.py --- a/swh/graph/client.py +++ b/swh/graph/client.py @@ -45,20 +45,20 @@ "neighbors/{}".format(src), params={"edges": edges, "direction": direction} ) - def visit_nodes(self, src, edges="*", direction="forward", max_edges=None): + def visit_nodes(self, src, edges="*", direction="forward", max_edges=0): return self.get_lines( "visit/nodes/{}".format(src), params={"edges": edges, "direction": direction, "max_edges": max_edges}, ) - def visit_edges(self, src, edges="*", direction="forward", max_edges=None): + def visit_edges(self, src, edges="*", direction="forward", max_edges=0): for edge in self.get_lines( "visit/edges/{}".format(src), params={"edges": edges, "direction": direction, "max_edges": max_edges}, ): yield tuple(edge.split()) - def visit_paths(self, src, edges="*", direction="forward", max_edges=None): + def visit_paths(self, src, edges="*", direction="forward", max_edges=0): def decode_path_wrapper(it): for e in it: yield json.loads(e) diff --git a/swh/graph/server/app.py b/swh/graph/server/app.py --- a/swh/graph/server/app.py +++ b/swh/graph/server/app.py @@ -112,10 +112,11 @@ def get_max_edges(self): """Validate HTTP query parameter 'max_edges', i.e., the limit of the number of edges that can be visited""" - s: Optional[int] = self.request.query.get("max_edges") - if s is not None: + s = self.request.query.get("max_edges", "0") + try: return int(s) - return s + except ValueError: + raise aiohttp.web.HTTPBadRequest(text=f"invalid max_edges value: {s}") class StreamingGraphView(GraphView):