diff --git a/java/src/main/java/org/softwareheritage/graph/AllowedNodes.java b/java/src/main/java/org/softwareheritage/graph/AllowedNodes.java --- a/java/src/main/java/org/softwareheritage/graph/AllowedNodes.java +++ b/java/src/main/java/org/softwareheritage/graph/AllowedNodes.java @@ -1,7 +1,7 @@ package org.softwareheritage.graph; /** - * TODO + * Node type restriction, useful to implement filtering of returned nodes during traversal. * * @author The Software Heritage developers */ diff --git a/java/src/main/java/org/softwareheritage/graph/NodesFiltering.java b/java/src/main/java/org/softwareheritage/graph/NodesFiltering.java deleted file mode 100644 --- a/java/src/main/java/org/softwareheritage/graph/NodesFiltering.java +++ /dev/null @@ -1,107 +0,0 @@ -package org.softwareheritage.graph; - -import java.util.ArrayList; - -/** - * NodesFiltering - *

- * class that manages the filtering of nodes that have been returned after a visit of the graph. - * parameterized by a string that represents either no filtering (*) or a set of node types. - *

- * - * - * - * How to use NodesFiltering : - * - *
- * {@code
- *  Long id1 = .... // graph.getNodeType(id1) == CNT
- *  Long id2 = .... // graph.getNodeType(id2) == SNP
- *  Long id3 = .... // graph.getNodeType(id3) == ORI
- *  ArrayList nodeIds = nez ArrayList();
- *  nodeIds.add(id1); nodeIds.add(id2); nodeIds.add(id3);
- *
- *  NodeFiltering nds = new NodesFiltering("snp,ori"); // we allow only snp node types to be shown
- *  System.out.println(nds.filterByNodeTypes(nodeIds,graph)); // will print id2, id3
- *
- *  nds = NodesFiltering("*");
- *  System.out.println(nds.filterByNodeTypes(nodeIds,graph)); // will print id1, id2 id3
- *
- * }
- * 
- */ - -public class NodesFiltering { - - boolean restricted; - ArrayList allowedNodesTypes; - - /** - * Default constructor, in order to handle the * case (all types of nodes are allowed to be - * returned). allowedNodesTypes will contains [SNP,CNT....] all types of nodes. - * - */ - public NodesFiltering() { - restricted = false; - allowedNodesTypes = Node.Type.parse("*"); - } - - /** - * Constructor - * - * @param strTypes a formatted string describing the types of nodes we want to allow to be shown. - * - * NodesFilterind("cnt,snp") will set allowedNodesTypes to [CNT,SNP] - * - */ - public NodesFiltering(String strTypes) { - restricted = true; - allowedNodesTypes = new ArrayList(); - String[] types = strTypes.split(","); - for (String type : types) { - allowedNodesTypes.add(Node.Type.fromStr(type)); - } - } - - /** - * Check if the type given in parameter is in the list of allowed types. - * - * @param typ the type of the node. - */ - public boolean typeIsAllowed(Node.Type typ) { - return this.allowedNodesTypes.contains(typ); - } - - /** - *

- * the function that filters the nodes returned, we browse the list of nodes found after a visit and - * we create a new list with only the nodes that have a type that is contained in the list of - * allowed types (allowedNodesTypes) - *

- * - * @param nodeIds the nodes founded during the visit - * @param g the graph in order to find the types of nodes from their id in nodeIds - * @return a new list with the id of node which have a type in allowedTypes - * - * - */ - public ArrayList filterByNodeTypes(ArrayList nodeIds, SwhBidirectionalGraph g) { - ArrayList filteredNodes = new ArrayList(); - for (Long node : nodeIds) { - if (this.typeIsAllowed(g.getNodeType(node))) { - filteredNodes.add(node); - } - } - return filteredNodes; - } -} diff --git a/java/src/main/java/org/softwareheritage/graph/Traversal.java b/java/src/main/java/org/softwareheritage/graph/Traversal.java --- a/java/src/main/java/org/softwareheritage/graph/Traversal.java +++ b/java/src/main/java/org/softwareheritage/graph/Traversal.java @@ -30,8 +30,10 @@ public class Traversal { /** Graph used in the traversal */ SwhBidirectionalGraph graph; - /** Graph edge restrictions */ - AllowedEdges edges; + /** Type filter on the returned nodes */ + AllowedNodes nodesFilter; + /** Restrictions on which edges can be traversed */ + AllowedEdges edgesRestrictions; /** Hash set storing if we have visited a node */ HashSet visited; @@ -42,8 +44,6 @@ /** The anti Dos limit of edges traversed while a visit */ long maxEdges; - /** The string represent the set of type restriction */ - NodesFiltering ndsfilter; /** random number generator, for random walks */ Random rng; @@ -77,19 +77,14 @@ } else { this.graph = graph; } - this.edges = new AllowedEdges(edgesFmt); + this.nodesFilter = new AllowedNodes(returnTypes); + this.edgesRestrictions = new AllowedEdges(edgesFmt); this.visited = new HashSet<>(); this.parentNode = new HashMap<>(); this.nbEdgesAccessed = 0; this.maxEdges = maxEdges; this.rng = new Random(); - - if (returnTypes.equals("*")) { - this.ndsfilter = new NodesFiltering(); - } else { - this.ndsfilter = new NodesFiltering(returnTypes); - } } /** @@ -172,7 +167,7 @@ break; } } - LazyLongIterator it = filterSuccessors(currentNodeId, edges); + 
LazyLongIterator it = filterSuccessors(currentNodeId, edgesRestrictions); for (long neighborNodeId; (neighborNodeId = it.nextLong()) != -1;) { neighborsCnt++; if (!visited.contains(neighborNodeId)) { @@ -182,7 +177,9 @@ } if (neighborsCnt == 0) { - cb.accept(currentNodeId); + if (nodesFilter.isAllowed(graph.getNodeType(currentNodeId))) { + cb.accept(currentNodeId); + } } } } @@ -196,9 +193,6 @@ public ArrayList leaves(long srcNodeId) { ArrayList nodeIds = new ArrayList(); leavesVisitor(srcNodeId, nodeIds::add); - if (ndsfilter.restricted) { - return ndsfilter.filterByNodeTypes(nodeIds, graph); - } return nodeIds; } @@ -212,9 +206,11 @@ return; } } - LazyLongIterator it = filterSuccessors(srcNodeId, edges); + LazyLongIterator it = filterSuccessors(srcNodeId, edgesRestrictions); for (long neighborNodeId; (neighborNodeId = it.nextLong()) != -1;) { - cb.accept(neighborNodeId); + if (nodesFilter.isAllowed(graph.getNodeType(neighborNodeId))) { + cb.accept(neighborNodeId); + } } } @@ -227,9 +223,6 @@ public ArrayList neighbors(long srcNodeId) { ArrayList nodeIds = new ArrayList<>(); neighborsVisitor(srcNodeId, nodeIds::add); - if (ndsfilter.restricted) { - return ndsfilter.filterByNodeTypes(nodeIds, graph); - } return nodeIds; } @@ -246,7 +239,9 @@ while (!stack.isEmpty()) { long currentNodeId = stack.pop(); if (nodeCb != null) { - nodeCb.accept(currentNodeId); + if (nodesFilter.isAllowed(graph.getNodeType(currentNodeId))) { + nodeCb.accept(currentNodeId); + } } nbEdgesAccessed += graph.outdegree(currentNodeId); if (this.maxEdges > 0) { @@ -254,10 +249,12 @@ break; } } - LazyLongIterator it = filterSuccessors(currentNodeId, edges); + LazyLongIterator it = filterSuccessors(currentNodeId, edgesRestrictions); for (long neighborNodeId; (neighborNodeId = it.nextLong()) != -1;) { if (edgeCb != null) { - edgeCb.accept(currentNodeId, neighborNodeId); + if (nodesFilter.isAllowed(graph.getNodeType(currentNodeId))) { + edgeCb.accept(currentNodeId, neighborNodeId); + } } if 
(!visited.contains(neighborNodeId)) { stack.push(neighborNodeId); @@ -281,9 +278,6 @@ public ArrayList visitNodes(long srcNodeId) { ArrayList nodeIds = new ArrayList<>(); visitNodesVisitor(srcNodeId, nodeIds::add); - if (ndsfilter.restricted) { - return ndsfilter.filterByNodeTypes(nodeIds, graph); - } return nodeIds; } @@ -321,7 +315,7 @@ return; } } - LazyLongIterator it = filterSuccessors(currentNodeId, edges); + LazyLongIterator it = filterSuccessors(currentNodeId, edgesRestrictions); for (long neighborNodeId; (neighborNodeId = it.nextLong()) != -1;) { visitPathsInternalVisitor(neighborNodeId, currentPath, cb); visitedNeighbors++; @@ -395,7 +389,7 @@ while (true) { path.add(curNodeId); - LazyLongIterator successors = filterSuccessors(curNodeId, edges); + LazyLongIterator successors = filterSuccessors(curNodeId, edgesRestrictions); curNodeId = randomPick(successors); if (curNodeId < 0) { found = false; @@ -409,9 +403,6 @@ } if (found) { - if (ndsfilter.restricted) { - return ndsfilter.filterByNodeTypes(path, graph); - } return path; } else if (retries > 0) { // try again return randomWalk(srcNodeId, dst, retries - 1); @@ -462,7 +453,7 @@ } nbEdgesAccessed += graph.outdegree(currentNodeId); - LazyLongIterator it = filterSuccessors(currentNodeId, edges); + LazyLongIterator it = filterSuccessors(currentNodeId, edgesRestrictions); for (long neighborNodeId; (neighborNodeId = it.nextLong()) != -1;) { if (!visited.contains(neighborNodeId)) { stack.push(neighborNodeId); @@ -496,7 +487,7 @@ } nbEdgesAccessed += graph.outdegree(currentNodeId); - LazyLongIterator it = filterSuccessors(currentNodeId, edges); + LazyLongIterator it = filterSuccessors(currentNodeId, edgesRestrictions); for (long neighborNodeId; (neighborNodeId = it.nextLong()) != -1;) { if (!visited.contains(neighborNodeId)) { queue.add(neighborNodeId); @@ -571,7 +562,7 @@ if (!lhsStack.isEmpty()) { curNode = lhsStack.poll(); nbEdgesAccessed += graph.outdegree(curNode); - LazyLongIterator it = 
filterSuccessors(curNode, edges); + LazyLongIterator it = filterSuccessors(curNode, edgesRestrictions); for (long neighborNodeId; (neighborNodeId = it.nextLong()) != -1;) { if (!lhsVisited.contains(neighborNodeId)) { if (rhsVisited.contains(neighborNodeId)) @@ -585,7 +576,7 @@ if (!rhsStack.isEmpty()) { curNode = rhsStack.poll(); nbEdgesAccessed += graph.outdegree(curNode); - LazyLongIterator it = filterSuccessors(curNode, edges); + LazyLongIterator it = filterSuccessors(curNode, edgesRestrictions); for (long neighborNodeId; (neighborNodeId = it.nextLong()) != -1;) { if (!rhsVisited.contains(neighborNodeId)) { if (lhsVisited.contains(neighborNodeId)) diff --git a/java/src/test/java/org/softwareheritage/graph/AllowedNodesTest.java b/java/src/test/java/org/softwareheritage/graph/AllowedNodesTest.java new file mode 100644 --- /dev/null +++ b/java/src/test/java/org/softwareheritage/graph/AllowedNodesTest.java @@ -0,0 +1,54 @@ +package org.softwareheritage.graph; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +import java.util.Set; + +public class AllowedNodesTest extends GraphTest { + /** Asserts that {@code nodes} allows exactly the node types in {@code expectedAllowed}. */ + void assertNodeRestriction(AllowedNodes nodes, Set expectedAllowed) { + Node.Type[] nodeTypes = Node.Type.values(); + for (Node.Type t : nodeTypes) { + boolean isAllowed = nodes.isAllowed(t); + boolean isExpected = expectedAllowed.contains(t); + Assertions.assertEquals(isExpected, isAllowed, "Node type: " + t); + } + } + + @Test + public void dirCntNodes() { + AllowedNodes nodes = new AllowedNodes("dir,cnt"); + Set expected = Set.of(Node.Type.DIR, Node.Type.CNT); + assertNodeRestriction(nodes, expected); + } + + @Test + public void revDirNodes() { + AllowedNodes nodes = new AllowedNodes("rev,dir"); + Set expected = Set.of(Node.Type.DIR, Node.Type.REV); + assertNodeRestriction(nodes, expected); + } + + @Test + public void relSnpCntNodes() { + AllowedNodes nodes = new 
AllowedNodes("rel,snp,cnt"); + Set expected = Set.of(Node.Type.REL, Node.Type.SNP, 
Node.Type.CNT); + assertNodeRestriction(nodes, expected); + } + + @Test + public void allNodes() { + AllowedNodes nodes = new AllowedNodes("*"); + Set expected = Set.of(Node.Type.REL, Node.Type.SNP, Node.Type.CNT, Node.Type.DIR, Node.Type.REV, + Node.Type.ORI); + assertNodeRestriction(nodes, expected); + } + + @Test + public void noNodes() { + AllowedNodes nodes = new AllowedNodes(""); + Set expected = Set.of(); + assertNodeRestriction(nodes, expected); + } +}