diff --git a/java/src/main/java/org/softwareheritage/graph/AllowedNodes.java b/java/src/main/java/org/softwareheritage/graph/AllowedNodes.java
--- a/java/src/main/java/org/softwareheritage/graph/AllowedNodes.java
+++ b/java/src/main/java/org/softwareheritage/graph/AllowedNodes.java
@@ -1,7 +1,7 @@
package org.softwareheritage.graph;
/**
- * TODO
+ * Node type restriction, useful to implement filtering of returned nodes during traversal.
*
* @author The Software Heritage developers
*/
diff --git a/java/src/main/java/org/softwareheritage/graph/NodesFiltering.java b/java/src/main/java/org/softwareheritage/graph/NodesFiltering.java
deleted file mode 100644
--- a/java/src/main/java/org/softwareheritage/graph/NodesFiltering.java
+++ /dev/null
@@ -1,107 +0,0 @@
-package org.softwareheritage.graph;
-
-import java.util.ArrayList;
-
-/**
- * NodesFiltering
- *
- * class that manages the filtering of nodes that have been returned after a visit of the graph.
- * parameterized by a string that represents either no filtering (*) or a set of node types.
- *
- *
- *
- *
- * - graph/visit/nodes/swh:1:rel:0000000000000000000000000000000000000010 return_types==rev will
- * only return 'rev' nodes.
- *
- * - graph/visit/nodes/swh:1:rel:0000000000000000000000000000000000000010
- * return_types==rev,snp,cnt will only return 'rev' 'snp' 'cnt' nodes.
- *
- * - graph/visit/nodes/swh:1:rel:0000000000000000000000000000000000000010 return_types==* will
- * return all the nodes.
- *
- *
- * How to use NodesFiltering :
- *
- *
- * {@code
- * Long id1 = .... // graph.getNodeType(id1) == CNT
- * Long id2 = .... // graph.getNodeType(id2) == SNP
- * Long id3 = .... // graph.getNodeType(id3) == ORI
- * ArrayList nodeIds = nez ArrayList();
- * nodeIds.add(id1); nodeIds.add(id2); nodeIds.add(id3);
- *
- * NodeFiltering nds = new NodesFiltering("snp,ori"); // we allow only snp node types to be shown
- * System.out.println(nds.filterByNodeTypes(nodeIds,graph)); // will print id2, id3
- *
- * nds = NodesFiltering("*");
- * System.out.println(nds.filterByNodeTypes(nodeIds,graph)); // will print id1, id2 id3
- *
- * }
- *
- */
-
-public class NodesFiltering {
-
- boolean restricted;
- ArrayList allowedNodesTypes;
-
- /**
- * Default constructor, in order to handle the * case (all types of nodes are allowed to be
- * returned). allowedNodesTypes will contains [SNP,CNT....] all types of nodes.
- *
- */
- public NodesFiltering() {
- restricted = false;
- allowedNodesTypes = Node.Type.parse("*");
- }
-
- /**
- * Constructor
- *
- * @param strTypes a formatted string describing the types of nodes we want to allow to be shown.
- *
- * NodesFilterind("cnt,snp") will set allowedNodesTypes to [CNT,SNP]
- *
- */
- public NodesFiltering(String strTypes) {
- restricted = true;
- allowedNodesTypes = new ArrayList();
- String[] types = strTypes.split(",");
- for (String type : types) {
- allowedNodesTypes.add(Node.Type.fromStr(type));
- }
- }
-
- /**
- * Check if the type given in parameter is in the list of allowed types.
- *
- * @param typ the type of the node.
- */
- public boolean typeIsAllowed(Node.Type typ) {
- return this.allowedNodesTypes.contains(typ);
- }
-
- /**
- *
- * the function that filters the nodes returned, we browse the list of nodes found after a visit and
- * we create a new list with only the nodes that have a type that is contained in the list of
- * allowed types (allowedNodesTypes)
- *
- *
- * @param nodeIds the nodes founded during the visit
- * @param g the graph in order to find the types of nodes from their id in nodeIds
- * @return a new list with the id of node which have a type in allowedTypes
- *
- *
- */
- public ArrayList filterByNodeTypes(ArrayList nodeIds, SwhBidirectionalGraph g) {
- ArrayList filteredNodes = new ArrayList();
- for (Long node : nodeIds) {
- if (this.typeIsAllowed(g.getNodeType(node))) {
- filteredNodes.add(node);
- }
- }
- return filteredNodes;
- }
-}
diff --git a/java/src/main/java/org/softwareheritage/graph/Traversal.java b/java/src/main/java/org/softwareheritage/graph/Traversal.java
--- a/java/src/main/java/org/softwareheritage/graph/Traversal.java
+++ b/java/src/main/java/org/softwareheritage/graph/Traversal.java
@@ -30,8 +30,10 @@
public class Traversal {
/** Graph used in the traversal */
SwhBidirectionalGraph graph;
- /** Graph edge restrictions */
- AllowedEdges edges;
+ /** Type filter on the returned nodes */
+ AllowedNodes nodesFilter;
+ /** Restrictions on which edges can be traversed */
+ AllowedEdges edgesRestrictions;
/** Hash set storing if we have visited a node */
HashSet visited;
@@ -42,8 +44,6 @@
/** The anti Dos limit of edges traversed while a visit */
long maxEdges;
- /** The string represent the set of type restriction */
- NodesFiltering ndsfilter;
/** random number generator, for random walks */
Random rng;
@@ -77,19 +77,14 @@
} else {
this.graph = graph;
}
- this.edges = new AllowedEdges(edgesFmt);
+ this.nodesFilter = new AllowedNodes(returnTypes);
+ this.edgesRestrictions = new AllowedEdges(edgesFmt);
this.visited = new HashSet<>();
this.parentNode = new HashMap<>();
this.nbEdgesAccessed = 0;
this.maxEdges = maxEdges;
this.rng = new Random();
-
- if (returnTypes.equals("*")) {
- this.ndsfilter = new NodesFiltering();
- } else {
- this.ndsfilter = new NodesFiltering(returnTypes);
- }
}
/**
@@ -172,7 +167,7 @@
break;
}
}
- LazyLongIterator it = filterSuccessors(currentNodeId, edges);
+ LazyLongIterator it = filterSuccessors(currentNodeId, edgesRestrictions);
for (long neighborNodeId; (neighborNodeId = it.nextLong()) != -1;) {
neighborsCnt++;
if (!visited.contains(neighborNodeId)) {
@@ -182,7 +177,9 @@
}
if (neighborsCnt == 0) {
- cb.accept(currentNodeId);
+ if (nodesFilter.isAllowed(graph.getNodeType(currentNodeId))) {
+ cb.accept(currentNodeId);
+ }
}
}
}
@@ -196,9 +193,6 @@
public ArrayList leaves(long srcNodeId) {
ArrayList nodeIds = new ArrayList();
leavesVisitor(srcNodeId, nodeIds::add);
- if (ndsfilter.restricted) {
- return ndsfilter.filterByNodeTypes(nodeIds, graph);
- }
return nodeIds;
}
@@ -212,9 +206,11 @@
return;
}
}
- LazyLongIterator it = filterSuccessors(srcNodeId, edges);
+ LazyLongIterator it = filterSuccessors(srcNodeId, edgesRestrictions);
for (long neighborNodeId; (neighborNodeId = it.nextLong()) != -1;) {
- cb.accept(neighborNodeId);
+ if (nodesFilter.isAllowed(graph.getNodeType(neighborNodeId))) {
+ cb.accept(neighborNodeId);
+ }
}
}
@@ -227,9 +223,6 @@
public ArrayList neighbors(long srcNodeId) {
ArrayList nodeIds = new ArrayList<>();
neighborsVisitor(srcNodeId, nodeIds::add);
- if (ndsfilter.restricted) {
- return ndsfilter.filterByNodeTypes(nodeIds, graph);
- }
return nodeIds;
}
@@ -246,7 +239,9 @@
while (!stack.isEmpty()) {
long currentNodeId = stack.pop();
if (nodeCb != null) {
- nodeCb.accept(currentNodeId);
+ if (nodesFilter.isAllowed(graph.getNodeType(currentNodeId))) {
+ nodeCb.accept(currentNodeId);
+ }
}
nbEdgesAccessed += graph.outdegree(currentNodeId);
if (this.maxEdges > 0) {
@@ -254,10 +249,12 @@
break;
}
}
- LazyLongIterator it = filterSuccessors(currentNodeId, edges);
+ LazyLongIterator it = filterSuccessors(currentNodeId, edgesRestrictions);
for (long neighborNodeId; (neighborNodeId = it.nextLong()) != -1;) {
if (edgeCb != null) {
- edgeCb.accept(currentNodeId, neighborNodeId);
+ if (nodesFilter.isAllowed(graph.getNodeType(currentNodeId))) {
+ edgeCb.accept(currentNodeId, neighborNodeId);
+ }
}
if (!visited.contains(neighborNodeId)) {
stack.push(neighborNodeId);
@@ -281,9 +278,6 @@
public ArrayList visitNodes(long srcNodeId) {
ArrayList nodeIds = new ArrayList<>();
visitNodesVisitor(srcNodeId, nodeIds::add);
- if (ndsfilter.restricted) {
- return ndsfilter.filterByNodeTypes(nodeIds, graph);
- }
return nodeIds;
}
@@ -321,7 +315,7 @@
return;
}
}
- LazyLongIterator it = filterSuccessors(currentNodeId, edges);
+ LazyLongIterator it = filterSuccessors(currentNodeId, edgesRestrictions);
for (long neighborNodeId; (neighborNodeId = it.nextLong()) != -1;) {
visitPathsInternalVisitor(neighborNodeId, currentPath, cb);
visitedNeighbors++;
@@ -395,7 +389,7 @@
while (true) {
path.add(curNodeId);
- LazyLongIterator successors = filterSuccessors(curNodeId, edges);
+ LazyLongIterator successors = filterSuccessors(curNodeId, edgesRestrictions);
curNodeId = randomPick(successors);
if (curNodeId < 0) {
found = false;
@@ -409,9 +403,6 @@
}
if (found) {
- if (ndsfilter.restricted) {
- return ndsfilter.filterByNodeTypes(path, graph);
- }
return path;
} else if (retries > 0) { // try again
return randomWalk(srcNodeId, dst, retries - 1);
@@ -462,7 +453,7 @@
}
nbEdgesAccessed += graph.outdegree(currentNodeId);
- LazyLongIterator it = filterSuccessors(currentNodeId, edges);
+ LazyLongIterator it = filterSuccessors(currentNodeId, edgesRestrictions);
for (long neighborNodeId; (neighborNodeId = it.nextLong()) != -1;) {
if (!visited.contains(neighborNodeId)) {
stack.push(neighborNodeId);
@@ -496,7 +487,7 @@
}
nbEdgesAccessed += graph.outdegree(currentNodeId);
- LazyLongIterator it = filterSuccessors(currentNodeId, edges);
+ LazyLongIterator it = filterSuccessors(currentNodeId, edgesRestrictions);
for (long neighborNodeId; (neighborNodeId = it.nextLong()) != -1;) {
if (!visited.contains(neighborNodeId)) {
queue.add(neighborNodeId);
@@ -571,7 +562,7 @@
if (!lhsStack.isEmpty()) {
curNode = lhsStack.poll();
nbEdgesAccessed += graph.outdegree(curNode);
- LazyLongIterator it = filterSuccessors(curNode, edges);
+ LazyLongIterator it = filterSuccessors(curNode, edgesRestrictions);
for (long neighborNodeId; (neighborNodeId = it.nextLong()) != -1;) {
if (!lhsVisited.contains(neighborNodeId)) {
if (rhsVisited.contains(neighborNodeId))
@@ -585,7 +576,7 @@
if (!rhsStack.isEmpty()) {
curNode = rhsStack.poll();
nbEdgesAccessed += graph.outdegree(curNode);
- LazyLongIterator it = filterSuccessors(curNode, edges);
+ LazyLongIterator it = filterSuccessors(curNode, edgesRestrictions);
for (long neighborNodeId; (neighborNodeId = it.nextLong()) != -1;) {
if (!rhsVisited.contains(neighborNodeId)) {
if (lhsVisited.contains(neighborNodeId))
diff --git a/java/src/test/java/org/softwareheritage/graph/AllowedNodesTest.java b/java/src/test/java/org/softwareheritage/graph/AllowedNodesTest.java
new file mode 100644
--- /dev/null
+++ b/java/src/test/java/org/softwareheritage/graph/AllowedNodesTest.java
@@ -0,0 +1,68 @@
+package org.softwareheritage.graph;
+
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+
+import java.util.Set;
+
+/**
+ * Unit tests for {@link AllowedNodes}: each test builds a restriction from a type string and checks
+ * which {@link Node.Type} values it reports as allowed.
+ */
+public class AllowedNodesTest extends GraphTest {
+    /**
+     * Asserts that {@code nodes} allows exactly the node types in {@code expectedAllowed}.
+     *
+     * @param nodes the restriction under test
+     * @param expectedAllowed the types that must be allowed; every other type must be rejected
+     */
+    void assertNodeRestriction(AllowedNodes nodes, Set<Node.Type> expectedAllowed) {
+        for (Node.Type t : Node.Type.values()) {
+            boolean isExpected = expectedAllowed.contains(t);
+            boolean isAllowed = nodes.isAllowed(t);
+            // JUnit 5 argument order is (expected, actual, message).
+            Assertions.assertEquals(isExpected, isAllowed, "Node type: " + t);
+        }
+    }
+
+    /** The restriction "dir,cnt" must allow DIR and CNT only. */
+    @Test
+    public void dirCntNodes() {
+        AllowedNodes nodes = new AllowedNodes("dir,cnt");
+        Set<Node.Type> expected = Set.of(Node.Type.DIR, Node.Type.CNT);
+        assertNodeRestriction(nodes, expected);
+    }
+
+    /** The restriction "rev,dir" must allow REV and DIR only. */
+    @Test
+    public void revDirNodes() {
+        AllowedNodes nodes = new AllowedNodes("rev,dir");
+        Set<Node.Type> expected = Set.of(Node.Type.DIR, Node.Type.REV);
+        assertNodeRestriction(nodes, expected);
+    }
+
+    /** The restriction "rel,snp,cnt" must allow REL, SNP and CNT only. */
+    @Test
+    public void relSnpCntNodes() {
+        AllowedNodes nodes = new AllowedNodes("rel,snp,cnt");
+        Set<Node.Type> expected = Set.of(Node.Type.REL, Node.Type.SNP, Node.Type.CNT);
+        assertNodeRestriction(nodes, expected);
+    }
+
+    /** The wildcard restriction "*" must allow every listed node type. */
+    @Test
+    public void allNodes() {
+        AllowedNodes nodes = new AllowedNodes("*");
+        Set<Node.Type> expected = Set.of(Node.Type.REL, Node.Type.SNP, Node.Type.CNT, Node.Type.DIR, Node.Type.REV,
+                Node.Type.ORI);
+        assertNodeRestriction(nodes, expected);
+    }
+
+    /** An empty restriction string is expected to allow no node type at all. */
+    @Test
+    public void noNodes() {
+        AllowedNodes nodes = new AllowedNodes("");
+        Set<Node.Type> expected = Set.of();
+        assertNodeRestriction(nodes, expected);
+    }
+}