diff --git a/java/src/main/java/org/softwareheritage/graph/Entry.java b/java/src/main/java/org/softwareheritage/graph/Entry.java
--- a/java/src/main/java/org/softwareheritage/graph/Entry.java
+++ b/java/src/main/java/org/softwareheritage/graph/Entry.java
@@ -126,24 +126,24 @@
         close();
     }
 
-    public void visit_nodes(String direction, String edgesFmt, long srcNodeId) {
+    public void visit_nodes(String direction, String edgesFmt, long srcNodeId, long maxEdges) {
         open();
         Traversal t = new Traversal(this.graph, direction, edgesFmt);
-        t.visitNodesVisitor(srcNodeId, this::writeNode);
+        t.visitNodesVisitor(srcNodeId, this::writeNode, maxEdges);
         close();
     }
 
-    public void visit_edges(String direction, String edgesFmt, long srcNodeId) {
+    public void visit_edges(String direction, String edgesFmt, long srcNodeId, long maxEdges) {
         open();
         Traversal t = new Traversal(this.graph, direction, edgesFmt);
-        t.visitNodesVisitor(srcNodeId, null, this::writeEdge);
+        t.visitNodesVisitor(srcNodeId, null, this::writeEdge, maxEdges);
         close();
     }
 
-    public void visit_paths(String direction, String edgesFmt, long srcNodeId) {
+    public void visit_paths(String direction, String edgesFmt, long srcNodeId, long maxEdges) {
         open();
         Traversal t = new Traversal(this.graph, direction, edgesFmt);
-        t.visitPathsVisitor(srcNodeId, this::writePath);
+        t.visitPathsVisitor(srcNodeId, this::writePath, maxEdges);
         close();
     }
 
diff --git a/java/src/main/java/org/softwareheritage/graph/Traversal.java b/java/src/main/java/org/softwareheritage/graph/Traversal.java
--- a/java/src/main/java/org/softwareheritage/graph/Traversal.java
+++ b/java/src/main/java/org/softwareheritage/graph/Traversal.java
@@ -146,7 +146,7 @@
     /**
      * Push version of {@link #visitNodes}: will fire passed callback on each visited node.
      */
-    public void visitNodesVisitor(long srcNodeId, NodeIdConsumer nodeCb, EdgeIdConsumer edgeCb) {
+    public void visitNodesVisitor(long srcNodeId, NodeIdConsumer nodeCb, EdgeIdConsumer edgeCb, long maxEdges) {
         Stack<Long> stack = new Stack<>();
         this.nbEdgesAccessed = 0;
 
@@ -158,8 +158,13 @@
             if (nodeCb != null) {
                 nodeCb.accept(currentNodeId);
             }
             nbEdgesAccessed += graph.outdegree(currentNodeId);
+            if (maxEdges > 0) {
+                if (nbEdgesAccessed >= maxEdges) {
+                    break;
+                }
+            }
 
             LazyLongIterator it = graph.successors(currentNodeId, edges);
             for (long neighborNodeId; (neighborNodeId = it.nextLong()) != -1;) {
                 if (edgeCb != null) {
@@ -173,9 +178,14 @@
         }
     }
 
-    /** One-argument version to handle callbacks properly */
+    /** Two-argument variant: no edge callback and no edge limit (e.g. for count visitors). */
     public void visitNodesVisitor(long srcNodeId, NodeIdConsumer cb) {
-        visitNodesVisitor(srcNodeId, cb, null);
+        visitNodesVisitor(srcNodeId, cb, null, 0);
+    }
+
+    /** Variant without edge callback; stops once at least {@code maxEdges} edges were accessed (0 = no limit). */
+    public void visitNodesVisitor(long srcNodeId, NodeIdConsumer cb, long maxEdges) {
+        visitNodesVisitor(srcNodeId, cb, null, maxEdges);
     }
 
     /**
@@ -186,7 +196,7 @@
      */
     public ArrayList<Long> visitNodes(long srcNodeId) {
         ArrayList<Long> nodeIds = new ArrayList<>();
-        visitNodesVisitor(srcNodeId, nodeIds::add);
+        visitNodesVisitor(srcNodeId, nodeIds::add, 0);
         return nodeIds;
     }
 
@@ -194,10 +204,10 @@
      * Push version of {@link #visitPaths}: will fire passed callback on each discovered (complete)
      * path.
      */
-    public void visitPathsVisitor(long srcNodeId, PathConsumer cb) {
+    public void visitPathsVisitor(long srcNodeId, PathConsumer cb, long maxEdges) {
         Stack<Long> currentPath = new Stack<>();
         this.nbEdgesAccessed = 0;
-        visitPathsInternalVisitor(srcNodeId, currentPath, cb);
+        visitPathsInternalVisitor(srcNodeId, currentPath, cb, maxEdges);
     }
 
     /**
@@ -208,18 +218,26 @@
      */
     public ArrayList<ArrayList<Long>> visitPaths(long srcNodeId) {
         ArrayList<ArrayList<Long>> paths = new ArrayList<>();
-        visitPathsVisitor(srcNodeId, paths::add);
+        visitPathsVisitor(srcNodeId, paths::add, 0);
         return paths;
     }
 
-    private void visitPathsInternalVisitor(long currentNodeId, Stack<Long> currentPath, PathConsumer cb) {
+    private void visitPathsInternalVisitor(long currentNodeId, Stack<Long> currentPath, PathConsumer cb,
+            long maxEdges) {
         currentPath.push(currentNodeId);
 
         long visitedNeighbors = 0;
+        nbEdgesAccessed += graph.outdegree(currentNodeId);
+        if (maxEdges > 0) {
+            if (nbEdgesAccessed >= maxEdges) {
+                currentPath.pop();
+                return;
+            }
+        }
 
         LazyLongIterator it = graph.successors(currentNodeId, edges);
         for (long neighborNodeId; (neighborNodeId = it.nextLong()) != -1;) {
-            visitPathsInternalVisitor(neighborNodeId, currentPath, cb);
+            visitPathsInternalVisitor(neighborNodeId, currentPath, cb, maxEdges);
             visitedNeighbors++;
         }
 
diff --git a/swh/graph/backend.py b/swh/graph/backend.py
--- a/swh/graph/backend.py
+++ b/swh/graph/backend.py
@@ -69,7 +69,7 @@
         return method(direction, edges_fmt, src)
 
     async def simple_traversal(self, ttype, direction, edges_fmt, src):
-        assert ttype in ("leaves", "neighbors", "visit_nodes")
+        assert ttype in ("leaves", "neighbors")
         method = getattr(self.stream_proxy, ttype)
         async for node_id in method(direction, edges_fmt, src):
             yield node_id
@@ -92,8 +92,14 @@
         async for node_id in it:  # TODO return 404 if path is empty
             yield node_id
 
-    async def visit_edges(self, direction, edges_fmt, src):
-        it = self.stream_proxy.visit_edges(direction, edges_fmt, src)
+    async def visit_nodes(self, direction, edges_fmt, src, max_edges):
+        async for node_id in self.stream_proxy.visit_nodes(
+            direction, edges_fmt, src, max_edges
+        ):
+            yield node_id
+
+    async def visit_edges(self, direction, edges_fmt, src, max_edges):
+        it = self.stream_proxy.visit_edges(direction, edges_fmt, src, max_edges)
         # convert stream a, b, c, d -> (a, b), (c, d)
         prevNode = None
         async for node in it:
@@ -103,9 +109,11 @@
             else:
                 prevNode = node
 
-    async def visit_paths(self, direction, edges_fmt, src):
+    async def visit_paths(self, direction, edges_fmt, src, max_edges):
         path = []
-        async for node in self.stream_proxy.visit_paths(direction, edges_fmt, src):
+        async for node in self.stream_proxy.visit_paths(
+            direction, edges_fmt, src, max_edges
+        ):
             if node == PATH_SEPARATOR_ID:
                 yield path
                 path = []
diff --git a/swh/graph/client.py b/swh/graph/client.py
--- a/swh/graph/client.py
+++ b/swh/graph/client.py
@@ -45,20 +45,20 @@
             "neighbors/{}".format(src), params={"edges": edges, "direction": direction}
         )
 
-    def visit_nodes(self, src, edges="*", direction="forward"):
+    def visit_nodes(self, src, edges="*", direction="forward", max_edges=0):
         return self.get_lines(
             "visit/nodes/{}".format(src),
-            params={"edges": edges, "direction": direction},
+            params={"edges": edges, "direction": direction, "max_edges": max_edges},
         )
 
-    def visit_edges(self, src, edges="*", direction="forward"):
+    def visit_edges(self, src, edges="*", direction="forward", max_edges=0):
         for edge in self.get_lines(
             "visit/edges/{}".format(src),
-            params={"edges": edges, "direction": direction},
+            params={"edges": edges, "direction": direction, "max_edges": max_edges},
         ):
             yield tuple(edge.split())
 
-    def visit_paths(self, src, edges="*", direction="forward"):
+    def visit_paths(self, src, edges="*", direction="forward", max_edges=0):
         def decode_path_wrapper(it):
             for e in it:
                 yield json.loads(e)
@@ -66,7 +66,7 @@
         return decode_path_wrapper(
             self.get_lines(
                 "visit/paths/{}".format(src),
-                params={"edges": edges, "direction": direction},
+                params={"edges": edges, "direction": direction, "max_edges": max_edges},
             )
         )
 
diff --git a/swh/graph/graph.py b/swh/graph/graph.py
--- a/swh/graph/graph.py
+++ b/swh/graph/graph.py
@@ -83,18 +83,21 @@
     def leaves(self, *args, **kwargs):
         yield from self.simple_traversal("leaves", *args, **kwargs)
 
-    def visit_nodes(self, *args, **kwargs):
-        yield from self.simple_traversal("visit_nodes", *args, **kwargs)
+    def visit_nodes(self, direction="forward", edges="*", max_edges=0):
+        for node in call_async_gen(
+            self.graph.backend.visit_nodes, direction, edges, self.id, max_edges
+        ):
+            yield self.graph[node]
 
-    def visit_edges(self, direction="forward", edges="*"):
+    def visit_edges(self, direction="forward", edges="*", max_edges=0):
         for src, dst in call_async_gen(
-            self.graph.backend.visit_edges, direction, edges, self.id
+            self.graph.backend.visit_edges, direction, edges, self.id, max_edges
         ):
             yield (self.graph[src], self.graph[dst])
 
-    def visit_paths(self, direction="forward", edges="*"):
+    def visit_paths(self, direction="forward", edges="*", max_edges=0):
         for path in call_async_gen(
-            self.graph.backend.visit_paths, direction, edges, self.id
+            self.graph.backend.visit_paths, direction, edges, self.id, max_edges
         ):
             yield [self.graph[node] for node in path]
 
diff --git a/swh/graph/server/app.py b/swh/graph/server/app.py
--- a/swh/graph/server/app.py
+++ b/swh/graph/server/app.py
@@ -109,6 +109,15 @@
         except ValueError:
             raise aiohttp.web.HTTPBadRequest(text=f"invalid limit value: {s}")
 
+    def get_max_edges(self):
+        """Validate HTTP query parameter 'max_edges', i.e., the maximum number of
+        edges that may be accessed during the traversal (0 means no limit)"""
+        s = self.request.query.get("max_edges", "0")
+        try:
+            return int(s)
+        except ValueError:
+            raise aiohttp.web.HTTPBadRequest(text=f"invalid max_edges value: {s}")
+
 
 class StreamingGraphView(GraphView):
     """Base class for views streaming their response line by line."""
@@ -166,6 +175,7 @@
 
         self.edges = self.get_edges()
         self.direction = self.get_direction()
+        self.max_edges = self.get_max_edges()
 
     async def stream_response(self):
         async for res_node in self.backend.simple_traversal(
@@ -183,10 +193,6 @@
     simple_traversal_type = "neighbors"
 
 
-class VisitNodesView(SimpleTraversalView):
-    simple_traversal_type = "visit_nodes"
-
-
 class WalkView(StreamingGraphView):
     async def prepare_response(self):
         src = self.request.match_info["src"]
@@ -234,9 +240,20 @@
         )
 
 
+class VisitNodesView(SimpleTraversalView):
+    async def stream_response(self):
+        async for res_node in self.backend.visit_nodes(
+            self.direction, self.edges, self.src_node, self.max_edges
+        ):
+            res_swhid = self.swhid_of_node(res_node)
+            await self.stream_line(res_swhid)
+
+
 class VisitEdgesView(SimpleTraversalView):
     async def stream_response(self):
-        it = self.backend.visit_edges(self.direction, self.edges, self.src_node)
+        it = self.backend.visit_edges(
+            self.direction, self.edges, self.src_node, self.max_edges
+        )
         async for (res_src, res_dst) in it:
             res_src_swhid = self.swhid_of_node(res_src)
             res_dst_swhid = self.swhid_of_node(res_dst)
@@ -247,7 +264,9 @@
     content_type = "application/x-ndjson"
 
     async def stream_response(self):
-        it = self.backend.visit_paths(self.direction, self.edges, self.src_node)
+        it = self.backend.visit_paths(
+            self.direction, self.edges, self.src_node, self.max_edges
+        )
         async for res_path in it:
             res_path_swhid = [self.swhid_of_node(n) for n in res_path]
             line = json.dumps(res_path_swhid)
diff --git a/swh/graph/tests/test_api_client.py b/swh/graph/tests/test_api_client.py
--- a/swh/graph/tests/test_api_client.py
+++ b/swh/graph/tests/test_api_client.py
@@ -103,6 +103,29 @@
     assert set(actual) == set(expected)
 
 
+def test_visit_edges_limited(graph_client):
+    actual = list(
+        graph_client.visit_edges(
+            "swh:1:rel:0000000000000000000000000000000000000010", max_edges=4
+        )
+    )
+    expected = [
+        (
+            "swh:1:rel:0000000000000000000000000000000000000010",
+            "swh:1:rev:0000000000000000000000000000000000000009",
+        ),
+        (
+            "swh:1:rev:0000000000000000000000000000000000000009",
+            "swh:1:dir:0000000000000000000000000000000000000008",
+        ),
+        (
+            "swh:1:rev:0000000000000000000000000000000000000009",
+            "swh:1:rev:0000000000000000000000000000000000000003",
+        ),
+    ]
+    assert set(actual) == set(expected)
+
+
 def test_visit_edges_diamond_pattern(graph_client):
     actual = list(
         graph_client.visit_edges(