diff --git a/java/src/main/java/org/softwareheritage/graph/Traversal.java b/java/src/main/java/org/softwareheritage/graph/Traversal.java --- a/java/src/main/java/org/softwareheritage/graph/Traversal.java +++ b/java/src/main/java/org/softwareheritage/graph/Traversal.java @@ -45,6 +45,8 @@ /** The string represent the set of type restriction */ NodesFiltering ndsfilter; + long currentEdgeAccessed = 0; + /** random number generator, for random walks */ Random rng; @@ -115,7 +117,6 @@ public void leavesVisitor(long srcNodeId, NodeIdConsumer cb) { Stack stack = new Stack<>(); this.nbEdgesAccessed = 0; - stack.push(srcNodeId); visited.add(srcNodeId); @@ -124,14 +125,15 @@ long neighborsCnt = 0; nbEdgesAccessed += graph.outdegree(currentNodeId); - if (this.maxEdges > 0) { - if (nbEdgesAccessed >= this.maxEdges) { - break; - } - } + LazyLongIterator it = graph.successors(currentNodeId, edges); for (long neighborNodeId; (neighborNodeId = it.nextLong()) != -1;) { neighborsCnt++; + if (this.maxEdges > 0) { + if (neighborsCnt >= this.maxEdges) { + return; + } + } if (!visited.contains(neighborNodeId)) { stack.push(neighborNodeId); visited.add(neighborNodeId); @@ -164,14 +166,17 @@ */ public void neighborsVisitor(long srcNodeId, NodeIdConsumer cb) { this.nbEdgesAccessed = graph.outdegree(srcNodeId); - if (this.maxEdges > 0) { - if (nbEdgesAccessed >= this.maxEdges) { - return; - } - } + int currentEdgeAccessed = 0; LazyLongIterator it = graph.successors(srcNodeId, edges); for (long neighborNodeId; (neighborNodeId = it.nextLong()) != -1;) { cb.accept(neighborNodeId); + currentEdgeAccessed++; + if (this.maxEdges > 0) { + if (currentEdgeAccessed >= this.maxEdges) { + currentEdgeAccessed = 0; + return; + } + } } } @@ -196,25 +201,36 @@ public void visitNodesVisitor(long srcNodeId, NodeIdConsumer nodeCb, EdgeIdConsumer edgeCb) { Stack stack = new Stack<>(); this.nbEdgesAccessed = 0; - stack.push(srcNodeId); visited.add(srcNodeId); while (!stack.isEmpty()) { long currentNodeId = stack.pop(); if (nodeCb != null) { + if (this.maxEdges > 0) { + // we can go through n arcs, so at the end we must have + // the source node + n nodes reached through these n arcs + if (currentEdgeAccessed > this.maxEdges) { + currentEdgeAccessed = 0; + break; + } + } nodeCb.accept(currentNodeId); + currentEdgeAccessed++; } nbEdgesAccessed += graph.outdegree(currentNodeId); - if (this.maxEdges > 0) { - if (nbEdgesAccessed >= this.maxEdges) { - break; - } - } LazyLongIterator it = graph.successors(currentNodeId, edges); for (long neighborNodeId; (neighborNodeId = it.nextLong()) != -1;) { if (edgeCb != null) { + if (this.maxEdges > 0) { + if (currentEdgeAccessed >= this.maxEdges) { + currentEdgeAccessed = 0; + return; + } + } edgeCb.accept(currentNodeId, neighborNodeId); + currentEdgeAccessed++; + } if (!visited.contains(neighborNodeId)) { stack.push(neighborNodeId); @@ -268,18 +284,17 @@ private void visitPathsInternalVisitor(long currentNodeId, Stack currentPath, PathConsumer cb) { currentPath.push(currentNodeId); - long visitedNeighbors = 0; - nbEdgesAccessed += graph.outdegree(currentNodeId); - if (this.maxEdges > 0) { - if (nbEdgesAccessed >= this.maxEdges) { - currentPath.pop(); - return; - } - } LazyLongIterator it = graph.successors(currentNodeId, edges); for (long neighborNodeId; (neighborNodeId = it.nextLong()) != -1;) { + if (this.maxEdges > 0) { + if (currentEdgeAccessed >= this.maxEdges) { + currentEdgeAccessed = 0; + break; + } + } + currentEdgeAccessed++; visitPathsInternalVisitor(neighborNodeId, currentPath, cb); visitedNeighbors++; } diff --git a/swh/graph/naive_client.py b/swh/graph/naive_client.py --- a/swh/graph/naive_client.py +++ b/swh/graph/naive_client.py @@ -175,7 +175,7 @@ if max_edges == 0: max_edges = None # type: ignore else: - max_edges -= 1 + max_edges yield from list(self.graph.iter_edges_dfs(direction, edges, src))[:max_edges] @check_arguments diff --git a/swh/graph/tests/test_api_client.py b/swh/graph/tests/test_api_client.py --- a/swh/graph/tests/test_api_client.py +++ b/swh/graph/tests/test_api_client.py @@ -139,36 +139,17 @@ assert set(actual) == set(expected) -def test_visit_edges_limited(graph_client): +@pytest.mark.parametrize("max_edges", [1, 2, 3, 4, 5]) +def test_visit_edges_limited(graph_client, max_edges): actual = list( graph_client.visit_edges( - "swh:1:rel:0000000000000000000000000000000000000010", - max_edges=4, - edges="rel:rev,rev:rev,rev:dir", + "swh:1:rel:0000000000000000000000000000000000000010", max_edges=max_edges ) ) - expected = [ - ( - "swh:1:rel:0000000000000000000000000000000000000010", - "swh:1:rev:0000000000000000000000000000000000000009", - ), - ( - "swh:1:rev:0000000000000000000000000000000000000009", - "swh:1:rev:0000000000000000000000000000000000000003", - ), - ( - "swh:1:rev:0000000000000000000000000000000000000009", - "swh:1:dir:0000000000000000000000000000000000000008", - ), - ( - "swh:1:rev:0000000000000000000000000000000000000003", - "swh:1:dir:0000000000000000000000000000000000000002", - ), - ] - # As there are four valid answers (up to reordering), we cannot check for - # equality. Instead, we check the client returned all edges but one. - assert set(actual).issubset(set(expected)) - assert len(actual) == 3 + # As there are multiple valid answers for every value of max_edges (<= 3), + # we cannot check for equality. + # Instead, we check the client returned all edges but one. + assert len(actual) == max_edges def test_visit_edges_diamond_pattern(graph_client):