diff --git a/java/src/main/java/org/softwareheritage/graph/Traversal.java b/java/src/main/java/org/softwareheritage/graph/Traversal.java --- a/java/src/main/java/org/softwareheritage/graph/Traversal.java +++ b/java/src/main/java/org/softwareheritage/graph/Traversal.java @@ -45,6 +45,8 @@ /** The string represent the set of type restriction */ NodesFiltering ndsfilter; + long currentEdgeAccessed = 0; + /** random number generator, for random walks */ Random rng; @@ -115,7 +117,6 @@ public void leavesVisitor(long srcNodeId, NodeIdConsumer cb) { Stack stack = new Stack<>(); this.nbEdgesAccessed = 0; - stack.push(srcNodeId); visited.add(srcNodeId); @@ -124,14 +125,15 @@ long neighborsCnt = 0; nbEdgesAccessed += graph.outdegree(currentNodeId); - if (this.maxEdges > 0) { - if (nbEdgesAccessed >= this.maxEdges) { - break; - } - } + LazyLongIterator it = graph.successors(currentNodeId, edges); for (long neighborNodeId; (neighborNodeId = it.nextLong()) != -1;) { neighborsCnt++; + if (this.maxEdges > 0) { + if (neighborsCnt >= this.maxEdges) { + return; + } + } if (!visited.contains(neighborNodeId)) { stack.push(neighborNodeId); visited.add(neighborNodeId); @@ -164,14 +166,15 @@ */ public void neighborsVisitor(long srcNodeId, NodeIdConsumer cb) { this.nbEdgesAccessed = graph.outdegree(srcNodeId); - if (this.maxEdges > 0) { - if (nbEdgesAccessed >= this.maxEdges) { - return; - } - } LazyLongIterator it = graph.successors(srcNodeId, edges); for (long neighborNodeId; (neighborNodeId = it.nextLong()) != -1;) { cb.accept(neighborNodeId); + this.currentEdgeAccessed++; + if (this.maxEdges > 0) { + if (this.currentEdgeAccessed == this.maxEdges) { + return; + } + } } } @@ -196,25 +199,34 @@ public void visitNodesVisitor(long srcNodeId, NodeIdConsumer nodeCb, EdgeIdConsumer edgeCb) { Stack stack = new Stack<>(); this.nbEdgesAccessed = 0; - stack.push(srcNodeId); visited.add(srcNodeId); while (!stack.isEmpty()) { long currentNodeId = stack.pop(); if (nodeCb != null) { + if (this.maxEdges > 0) { + // we can go through n arcs, so at the end we must have + // the source node + n nodes reached through these n arcs + if (this.currentEdgeAccessed > this.maxEdges) { + break; + } + } nodeCb.accept(currentNodeId); + this.currentEdgeAccessed++; } nbEdgesAccessed += graph.outdegree(currentNodeId); - if (this.maxEdges > 0) { - if (nbEdgesAccessed >= this.maxEdges) { - break; - } - } LazyLongIterator it = graph.successors(currentNodeId, edges); for (long neighborNodeId; (neighborNodeId = it.nextLong()) != -1;) { if (edgeCb != null) { + if (this.maxEdges > 0) { + if (this.currentEdgeAccessed == this.maxEdges) { + return; + } + } edgeCb.accept(currentNodeId, neighborNodeId); + this.currentEdgeAccessed++; + } if (!visited.contains(neighborNodeId)) { stack.push(neighborNodeId); @@ -268,18 +280,16 @@ private void visitPathsInternalVisitor(long currentNodeId, Stack currentPath, PathConsumer cb) { currentPath.push(currentNodeId); - long visitedNeighbors = 0; - nbEdgesAccessed += graph.outdegree(currentNodeId); - if (this.maxEdges > 0) { - if (nbEdgesAccessed >= this.maxEdges) { - currentPath.pop(); - return; - } - } LazyLongIterator it = graph.successors(currentNodeId, edges); for (long neighborNodeId; (neighborNodeId = it.nextLong()) != -1;) { + if (this.maxEdges > 0) { + if (this.currentEdgeAccessed == this.maxEdges) { + break; + } + } + this.currentEdgeAccessed++; visitPathsInternalVisitor(neighborNodeId, currentPath, cb); visitedNeighbors++; } diff --git a/swh/graph/naive_client.py b/swh/graph/naive_client.py --- a/swh/graph/naive_client.py +++ b/swh/graph/naive_client.py @@ -174,8 +174,6 @@ ) -> Iterator[Tuple[str, str]]: if max_edges == 0: max_edges = None # type: ignore - else: - max_edges -= 1 yield from list(self.graph.iter_edges_dfs(direction, edges, src))[:max_edges] @check_arguments diff --git a/swh/graph/tests/test_api_client.py b/swh/graph/tests/test_api_client.py --- a/swh/graph/tests/test_api_client.py +++ b/swh/graph/tests/test_api_client.py @@ -139,12 +139,13 @@ assert set(actual) == set(expected) -def test_visit_edges_limited(graph_client): +@pytest.mark.parametrize("max_edges", [1, 2, 3, 4, 5]) +def test_visit_edges_limited(graph_client, max_edges, edges): actual = list( graph_client.visit_edges( "swh:1:rel:0000000000000000000000000000000000000010", - max_edges=4, - edges="rel:rev,rev:rev,rev:dir", + max_edges=max_edges, + edges=edges, ) ) expected = [ @@ -160,15 +161,40 @@ "swh:1:rev:0000000000000000000000000000000000000009", "swh:1:dir:0000000000000000000000000000000000000008", ), + ( + "swh:1:dir:0000000000000000000000000000000000000008", + "swh:1:dir:0000000000000000000000000000000000000006", + ), + ( + "swh:1:dir:0000000000000000000000000000000000000008", + "swh:1:cnt:0000000000000000000000000000000000000007", + ), + ( + "swh:1:dir:0000000000000000000000000000000000000008", + "swh:1:cnt:0000000000000000000000000000000000000001", + ), + ( + "swh:1:dir:0000000000000000000000000000000000000006", + "swh:1:cnt:0000000000000000000000000000000000000005", + ), + ( + "swh:1:dir:0000000000000000000000000000000000000006", + "swh:1:cnt:0000000000000000000000000000000000000004", + ), ( "swh:1:rev:0000000000000000000000000000000000000003", "swh:1:dir:0000000000000000000000000000000000000002", ), + ( + "swh:1:dir:0000000000000000000000000000000000000002", + "swh:1:cnt:0000000000000000000000000000000000000001", + ), ] - # As there are four valid answers (up to reordering), we cannot check for - # equality. Instead, we check the client returned all edges but one. + # As there are multiple valid answers for every value of max_edges (<= 3), + # we cannot check for equality. + # Instead, we check the client returned all edges but one. assert set(actual).issubset(set(expected)) - assert len(actual) == 3 + assert len(actual) == max_edges def test_visit_edges_diamond_pattern(graph_client):