diff --git a/java/src/main/java/org/softwareheritage/graph/NodesFiltering.java b/java/src/main/java/org/softwareheritage/graph/NodesFiltering.java --- a/java/src/main/java/org/softwareheritage/graph/NodesFiltering.java +++ b/java/src/main/java/org/softwareheritage/graph/NodesFiltering.java @@ -1,43 +1,104 @@ package org.softwareheritage.graph; import java.util.ArrayList; -import java.util.Iterator; /** + *

NodesFiltering

+ *

* class that manages the filtering of nodes that have been returned after a visit of the graph. - * parameterized by a string that represents either no filtering (*) or a set of node types + * parameterized by a string that represents either no filtering (*) or a set of node types. + *

* + * */ public class NodesFiltering { boolean restricted; - ArrayList restrictedNodesTypes; - - public NodesFiltering(String t) { - if (t.equals("*")) { - restricted = false; - restrictedNodesTypes = Node.Type.parse("*"); // contains all types, no restriction - } else { - restricted = true; - restrictedNodesTypes = new ArrayList(); - String[] types = t.split(","); - for (int i = 0; i < types.length; i++) { - restrictedNodesTypes.add(Node.Type.fromStr(types[i])); - } + ArrayList allowedNodesTypes; + + /** + * Default constructor, in order to handle the * case (all types of nodes are allowed to be + * returned). allowedNodesTypes will contains [SNP,CNT....] all types of nodes. + * + */ + public NodesFiltering() { + restricted = false; + allowedNodesTypes = Node.Type.parse("*"); + } + + /** + * Constructor + * + * @param strTypes a formatted string describing the types of nodes we want to allow to be shown. + * + * NodesFilterind("cnt,snp") will set allowedNodesTypes to [CNT,SNP] + * + */ + public NodesFiltering(String strTypes) { + restricted = true; + allowedNodesTypes = new ArrayList(); + String[] types = strTypes.split(","); + for (String type : types) { + allowedNodesTypes.add(Node.Type.fromStr(type)); } } + /** + * Check if the type given in parameter is in the list of allowed types. + * + * @param typ the type of the node. + */ public boolean typeIsAllowed(Node.Type typ) { - return this.restrictedNodesTypes.contains(typ); + return this.allowedNodesTypes.contains(typ); } + /** + *

+ * the function that filters the nodes returned, we browse the list of nodes found after a visit and + * we create a new list with only the nodes that have a type that is contained in the list of + * allowed types (allowedNodesTypes) + *

+ * + * @param nodeIds the nodes founded during the visit + * @param g the graph in order to find the types of nodes from their id in nodeIds + * @return a new list with the id of node which have a type in allowedTypes + * + * + *
+     * {@code
+     *  Long id1 = .... // graph.getNodeType(id1) == CNT
+     *  Long id2 = .... // graph.getNodeType(id2) == SNP
+     *  Long id3 = .... // graph.getNodeType(id3) == ORI
+     *  ArrayList nodeIds = nez ArrayList();
+     *  nodeIds.add(id1); nodeIds.add(id2); nodeIds.add(id3);
+     *
+     *  NodeFiltering nds = new NodesFiltering("snp,ori"); // we allow only snp node types to be shown
+     *  System.out.println(nds.filterByNodeTypes(nodeIds,graph)); // will print id2, id3
+     *
+     *  nds = NodesFilterind("*");
+     *  System.out.println(nds.filterByNodeTypes(nodeIds,graph)); // will print id1, id2 id3
+     *
+     * }
+     * 
+ */ public ArrayList filterByNodeTypes(ArrayList nodeIds, Graph g) { - for (Iterator it = nodeIds.iterator(); it.hasNext();) { - if (this.typeIsAllowed(g.getNodeType(it.next()))) { - it.remove(); + ArrayList filteredNodes = new ArrayList(); + for (Long node : nodeIds) { + if (this.typeIsAllowed(g.getNodeType(node))) { + filteredNodes.add(node); } } - return nodeIds; + return filteredNodes; } } diff --git a/java/src/main/java/org/softwareheritage/graph/Traversal.java b/java/src/main/java/org/softwareheritage/graph/Traversal.java --- a/java/src/main/java/org/softwareheritage/graph/Traversal.java +++ b/java/src/main/java/org/softwareheritage/graph/Traversal.java @@ -75,8 +75,13 @@ this.visited = new HashSet<>(); this.parentNode = new HashMap<>(); this.nbEdgesAccessed = 0; - this.ndsfilter = new NodesFiltering(returnTypes); this.rng = new Random(); + + if (returnTypes.equals("*")) { + this.ndsfilter = new NodesFiltering(); + } else { + this.ndsfilter = new NodesFiltering(returnTypes); + } } /** diff --git a/swh/graph/server/app.py b/swh/graph/server/app.py --- a/swh/graph/server/app.py +++ b/swh/graph/server/app.py @@ -99,16 +99,14 @@ a set of types which we will filter the query results with""" s = self.request.query.get("return_types", "*") if any( - [ - node_type != "*" and node_type not in EXTENDED_SWHID_TYPES - for node_type in s.split(",") - ] + node_type != "*" and node_type not in EXTENDED_SWHID_TYPES + for node_type in s.split(",") ): raise aiohttp.web.HTTPBadRequest( text=f"invalid type for filtering res: {s}" ) - """ if the user puts a star, - then we filter nothing, we don't need the other information""" + # if the user puts a star, + # then we filter nothing, we don't need the other information if "*" in s: return "*" else: diff --git a/swh/graph/tests/test_api_client.py b/swh/graph/tests/test_api_client.py --- a/swh/graph/tests/test_api_client.py +++ b/swh/graph/tests/test_api_client.py @@ -75,6 +75,41 @@ assert set(actual) == set(expected) +def test_visit_nodes_filtered(graph_client): + actual = list( + graph_client.visit_nodes( + "swh:1:rel:0000000000000000000000000000000000000010", return_types="dir", + ) + ) + expected = [ + "swh:1:dir:0000000000000000000000000000000000000002", + "swh:1:dir:0000000000000000000000000000000000000008", + "swh:1:dir:0000000000000000000000000000000000000006", + ] + assert set(actual) == set(expected) + + +def test_visit_nodes_filtered_star(graph_client): + actual = list( + graph_client.visit_nodes( + "swh:1:rel:0000000000000000000000000000000000000010", return_types="*", + ) + ) + expected = [ + "swh:1:rel:0000000000000000000000000000000000000010", + "swh:1:rev:0000000000000000000000000000000000000009", + "swh:1:rev:0000000000000000000000000000000000000003", + "swh:1:dir:0000000000000000000000000000000000000002", + "swh:1:cnt:0000000000000000000000000000000000000001", + "swh:1:dir:0000000000000000000000000000000000000008", + "swh:1:cnt:0000000000000000000000000000000000000007", + "swh:1:dir:0000000000000000000000000000000000000006", + "swh:1:cnt:0000000000000000000000000000000000000004", + "swh:1:cnt:0000000000000000000000000000000000000005", + ] + assert set(actual) == set(expected) + + def test_visit_edges(graph_client): actual = list( graph_client.visit_edges( @@ -212,9 +247,9 @@ def test_random_walk(graph_client): - """as the walk is random, we test a visit from a cnt node to the only origin in - the dataset, and only check the final node of the path (i.e., the origin) - + """as the walk is random, we test a visit from a cnt node to the only + origin in the dataset, and only check the final node of the path + (i.e., the origin) """ args = ("swh:1:cnt:0000000000000000000000000000000000000001", "ori") kwargs = {"direction": "backward"}