diff --git a/java/src/main/java/org/softwareheritage/graph/NodesFiltering.java b/java/src/main/java/org/softwareheritage/graph/NodesFiltering.java
--- a/java/src/main/java/org/softwareheritage/graph/NodesFiltering.java
+++ b/java/src/main/java/org/softwareheritage/graph/NodesFiltering.java
@@ -1,43 +1,104 @@
package org.softwareheritage.graph;
import java.util.ArrayList;
-import java.util.Iterator;
/**
+ *
NodesFiltering
+ *
* class that manages the filtering of nodes that have been returned after a visit of the graph.
- * parameterized by a string that represents either no filtering (*) or a set of node types
+ * parameterized by a string that represents either no filtering (*) or a set of node types.
+ *
*
+ *
+ *
+ * - graph/visit/nodes/swh:1:rel:0000000000000000000000000000000000000010 return_types==rev will
+ * only return 'rev' nodes.
+ *
+ * - graph/visit/nodes/swh:1:rel:0000000000000000000000000000000000000010
+ * return_types==rev,snp,cnt will only return 'rev' 'snp' 'cnt' nodes.
+ *
+ * - graph/visit/nodes/swh:1:rel:0000000000000000000000000000000000000010 return_types==* will
+ * return all the nodes.
+ *
*/
public class NodesFiltering {
boolean restricted;
- ArrayList restrictedNodesTypes;
-
- public NodesFiltering(String t) {
- if (t.equals("*")) {
- restricted = false;
- restrictedNodesTypes = Node.Type.parse("*"); // contains all types, no restriction
- } else {
- restricted = true;
- restrictedNodesTypes = new ArrayList();
- String[] types = t.split(",");
- for (int i = 0; i < types.length; i++) {
- restrictedNodesTypes.add(Node.Type.fromStr(types[i]));
- }
+ ArrayList allowedNodesTypes;
+
+ /**
+ * Default constructor, in order to handle the * case (all types of nodes are allowed to be
+ * returned). allowedNodesTypes will contains [SNP,CNT....] all types of nodes.
+ *
+ */
+ public NodesFiltering() {
+ restricted = false;
+ allowedNodesTypes = Node.Type.parse("*");
+ }
+
+ /**
+ * Constructor
+ *
+ * @param strTypes a formatted string describing the types of nodes we want to allow to be shown.
+ *
+ * NodesFilterind("cnt,snp") will set allowedNodesTypes to [CNT,SNP]
+ *
+ */
+ public NodesFiltering(String strTypes) {
+ restricted = true;
+ allowedNodesTypes = new ArrayList();
+ String[] types = strTypes.split(",");
+ for (String type : types) {
+ allowedNodesTypes.add(Node.Type.fromStr(type));
}
}
+ /**
+ * Check if the type given in parameter is in the list of allowed types.
+ *
+ * @param typ the type of the node.
+ */
public boolean typeIsAllowed(Node.Type typ) {
- return this.restrictedNodesTypes.contains(typ);
+ return this.allowedNodesTypes.contains(typ);
}
+ /**
+ *
+ * the function that filters the nodes returned, we browse the list of nodes found after a visit and
+ * we create a new list with only the nodes that have a type that is contained in the list of
+ * allowed types (allowedNodesTypes)
+ *
+ *
+ * @param nodeIds the nodes founded during the visit
+ * @param g the graph in order to find the types of nodes from their id in nodeIds
+ * @return a new list with the id of node which have a type in allowedTypes
+ *
+ *
+ *
+ * {@code
+ * Long id1 = .... // graph.getNodeType(id1) == CNT
+ * Long id2 = .... // graph.getNodeType(id2) == SNP
+ * Long id3 = .... // graph.getNodeType(id3) == ORI
+ * ArrayList nodeIds = nez ArrayList();
+ * nodeIds.add(id1); nodeIds.add(id2); nodeIds.add(id3);
+ *
+ * NodeFiltering nds = new NodesFiltering("snp,ori"); // we allow only snp node types to be shown
+ * System.out.println(nds.filterByNodeTypes(nodeIds,graph)); // will print id2, id3
+ *
+ * nds = NodesFilterind("*");
+ * System.out.println(nds.filterByNodeTypes(nodeIds,graph)); // will print id1, id2 id3
+ *
+ * }
+ *
+ */
public ArrayList filterByNodeTypes(ArrayList nodeIds, Graph g) {
- for (Iterator it = nodeIds.iterator(); it.hasNext();) {
- if (this.typeIsAllowed(g.getNodeType(it.next()))) {
- it.remove();
+ ArrayList filteredNodes = new ArrayList();
+ for (Long node : nodeIds) {
+ if (this.typeIsAllowed(g.getNodeType(node))) {
+ filteredNodes.add(node);
}
}
- return nodeIds;
+ return filteredNodes;
}
}
diff --git a/java/src/main/java/org/softwareheritage/graph/Traversal.java b/java/src/main/java/org/softwareheritage/graph/Traversal.java
--- a/java/src/main/java/org/softwareheritage/graph/Traversal.java
+++ b/java/src/main/java/org/softwareheritage/graph/Traversal.java
@@ -75,8 +75,13 @@
this.visited = new HashSet<>();
this.parentNode = new HashMap<>();
this.nbEdgesAccessed = 0;
- this.ndsfilter = new NodesFiltering(returnTypes);
this.rng = new Random();
+
+ if (returnTypes.equals("*")) {
+ this.ndsfilter = new NodesFiltering();
+ } else {
+ this.ndsfilter = new NodesFiltering(returnTypes);
+ }
}
/**
diff --git a/swh/graph/server/app.py b/swh/graph/server/app.py
--- a/swh/graph/server/app.py
+++ b/swh/graph/server/app.py
@@ -99,16 +99,14 @@
a set of types which we will filter the query results with"""
s = self.request.query.get("return_types", "*")
if any(
- [
- node_type != "*" and node_type not in EXTENDED_SWHID_TYPES
- for node_type in s.split(",")
- ]
+ node_type != "*" and node_type not in EXTENDED_SWHID_TYPES
+ for node_type in s.split(",")
):
raise aiohttp.web.HTTPBadRequest(
text=f"invalid type for filtering res: {s}"
)
- """ if the user puts a star,
- then we filter nothing, we don't need the other information"""
+ # if the user puts a star,
+ # then we filter nothing, we don't need the other information
if "*" in s:
return "*"
else:
diff --git a/swh/graph/tests/test_api_client.py b/swh/graph/tests/test_api_client.py
--- a/swh/graph/tests/test_api_client.py
+++ b/swh/graph/tests/test_api_client.py
@@ -75,6 +75,41 @@
assert set(actual) == set(expected)
+def test_visit_nodes_filtered(graph_client):
+ actual = list(
+ graph_client.visit_nodes(
+ "swh:1:rel:0000000000000000000000000000000000000010", return_types="dir",
+ )
+ )
+ expected = [
+ "swh:1:dir:0000000000000000000000000000000000000002",
+ "swh:1:dir:0000000000000000000000000000000000000008",
+ "swh:1:dir:0000000000000000000000000000000000000006",
+ ]
+ assert set(actual) == set(expected)
+
+
+def test_visit_nodes_filtered_star(graph_client):
+ actual = list(
+ graph_client.visit_nodes(
+ "swh:1:rel:0000000000000000000000000000000000000010", return_types="*",
+ )
+ )
+ expected = [
+ "swh:1:rel:0000000000000000000000000000000000000010",
+ "swh:1:rev:0000000000000000000000000000000000000009",
+ "swh:1:rev:0000000000000000000000000000000000000003",
+ "swh:1:dir:0000000000000000000000000000000000000002",
+ "swh:1:cnt:0000000000000000000000000000000000000001",
+ "swh:1:dir:0000000000000000000000000000000000000008",
+ "swh:1:cnt:0000000000000000000000000000000000000007",
+ "swh:1:dir:0000000000000000000000000000000000000006",
+ "swh:1:cnt:0000000000000000000000000000000000000004",
+ "swh:1:cnt:0000000000000000000000000000000000000005",
+ ]
+ assert set(actual) == set(expected)
+
+
def test_visit_edges(graph_client):
actual = list(
graph_client.visit_edges(
@@ -212,9 +247,9 @@
def test_random_walk(graph_client):
- """as the walk is random, we test a visit from a cnt node to the only origin in
- the dataset, and only check the final node of the path (i.e., the origin)
-
+ """as the walk is random, we test a visit from a cnt node to the only
+ origin in the dataset, and only check the final node of the path
+ (i.e., the origin)
"""
args = ("swh:1:cnt:0000000000000000000000000000000000000001", "ori")
kwargs = {"direction": "backward"}