diff --git a/java/src/main/java/org/softwareheritage/graph/Entry.java b/java/src/main/java/org/softwareheritage/graph/Entry.java --- a/java/src/main/java/org/softwareheritage/graph/Entry.java +++ b/java/src/main/java/org/softwareheritage/graph/Entry.java @@ -1,13 +1,13 @@ package org.softwareheritage.graph; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.databind.PropertyNamingStrategy; - import java.io.DataOutputStream; import java.io.FileOutputStream; import java.io.IOException; import java.util.ArrayList; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.PropertyNamingStrategy; + public class Entry { private final long PATH_SEPARATOR_ID = -1; private Graph graph; @@ -112,24 +112,30 @@ } } - public void leaves(String direction, String edgesFmt, long srcNodeId) { + public void leaves(String direction, String edgesFmt, long srcNodeId, String returnTypes) { open(); - Traversal t = new Traversal(this.graph, direction, edgesFmt); - t.leavesVisitor(srcNodeId, this::writeNode); + Traversal t = new Traversal(this.graph, direction, edgesFmt, returnTypes); + for (Long nodeId : t.leaves(srcNodeId)) { + writeNode(nodeId); + } close(); } - public void neighbors(String direction, String edgesFmt, long srcNodeId) { + public void neighbors(String direction, String edgesFmt, long srcNodeId, String returnTypes) { open(); - Traversal t = new Traversal(this.graph, direction, edgesFmt); - t.neighborsVisitor(srcNodeId, this::writeNode); + Traversal t = new Traversal(this.graph, direction, edgesFmt, returnTypes); + for (Long nodeId : t.neighbors(srcNodeId)) { + writeNode(nodeId); + } close(); } - public void visit_nodes(String direction, String edgesFmt, long srcNodeId) { + public void visit_nodes(String direction, String edgesFmt, long srcNodeId, String returnTypes) { open(); - Traversal t = new Traversal(this.graph, direction, edgesFmt); - t.visitNodesVisitor(srcNodeId, this::writeNode); + Traversal t = new 
Traversal(this.graph, direction, edgesFmt, returnTypes); + for (Long nodeId : t.visitNodes(srcNodeId)) { + writeNode(nodeId); + } close(); } @@ -166,19 +172,21 @@ close(); } - public void random_walk(String direction, String edgesFmt, int retries, long srcNodeId, long dstNodeId) { + public void random_walk(String direction, String edgesFmt, int retries, long srcNodeId, long dstNodeId, + String returnTypes) { open(); - Traversal t = new Traversal(this.graph, direction, edgesFmt); + Traversal t = new Traversal(this.graph, direction, edgesFmt, returnTypes); for (Long nodeId : t.randomWalk(srcNodeId, dstNodeId, retries)) { writeNode(nodeId); } close(); } - public void random_walk_type(String direction, String edgesFmt, int retries, long srcNodeId, String dst) { + public void random_walk_type(String direction, String edgesFmt, int retries, long srcNodeId, String dst, + String returnTypes) { open(); Node.Type dstType = Node.Type.fromStr(dst); - Traversal t = new Traversal(this.graph, direction, edgesFmt); + Traversal t = new Traversal(this.graph, direction, edgesFmt, returnTypes); for (Long nodeId : t.randomWalk(srcNodeId, dstType, retries)) { writeNode(nodeId); } diff --git a/java/src/main/java/org/softwareheritage/graph/NodesFiltering.java b/java/src/main/java/org/softwareheritage/graph/NodesFiltering.java new file mode 100644 --- /dev/null +++ b/java/src/main/java/org/softwareheritage/graph/NodesFiltering.java @@ -0,0 +1,107 @@ +package org.softwareheritage.graph; + +import java.util.ArrayList; + +/** + *

NodesFiltering

+ *

+ * Class that manages the filtering of the nodes returned by a visit of the graph. It is + * parameterized by a string that represents either no filtering (*) or a comma-separated set of node types. + *

+ * + * + * + * How to use NodesFiltering: + * + *
+ * {@code
+ *  Long id1 = .... // graph.getNodeType(id1) == CNT
+ *  Long id2 = .... // graph.getNodeType(id2) == SNP
+ *  Long id3 = .... // graph.getNodeType(id3) == ORI
 *  ArrayList nodeIds = new ArrayList();
+ *  nodeIds.add(id1); nodeIds.add(id2); nodeIds.add(id3);
+ *
 *  NodesFiltering nds = new NodesFiltering("snp,ori"); // only snp and ori node types are allowed to be shown
+ *  System.out.println(nds.filterByNodeTypes(nodeIds,graph)); // will print id2, id3
+ *
 *  nds = new NodesFiltering(); // no restriction: all node types are allowed (the "*" case)
 *  System.out.println(nds.filterByNodeTypes(nodeIds,graph)); // will print id1, id2, id3
+ *
+ * }
+ * 
+ */ + +public class NodesFiltering { + + boolean restricted; + ArrayList allowedNodesTypes; + + /** + * Default constructor, in order to handle the * case (all types of nodes are allowed to be + * returned). allowedNodesTypes will contains [SNP,CNT....] all types of nodes. + * + */ + public NodesFiltering() { + restricted = false; + allowedNodesTypes = Node.Type.parse("*"); + } + + /** + * Constructor + * + * @param strTypes a formatted string describing the types of nodes we want to allow to be shown. + * + * NodesFilterind("cnt,snp") will set allowedNodesTypes to [CNT,SNP] + * + */ + public NodesFiltering(String strTypes) { + restricted = true; + allowedNodesTypes = new ArrayList(); + String[] types = strTypes.split(","); + for (String type : types) { + allowedNodesTypes.add(Node.Type.fromStr(type)); + } + } + + /** + * Check if the type given in parameter is in the list of allowed types. + * + * @param typ the type of the node. + */ + public boolean typeIsAllowed(Node.Type typ) { + return this.allowedNodesTypes.contains(typ); + } + + /** + *

+ * Filters the nodes returned by a visit: it browses the list of nodes found after the visit and + * builds a new list containing only the nodes whose type is contained in the list of + * allowed types (allowedNodesTypes). + *

+ * + * @param nodeIds the nodes founded during the visit + * @param g the graph in order to find the types of nodes from their id in nodeIds + * @return a new list with the id of node which have a type in allowedTypes + * + * + */ + public ArrayList filterByNodeTypes(ArrayList nodeIds, Graph g) { + ArrayList filteredNodes = new ArrayList(); + for (Long node : nodeIds) { + if (this.typeIsAllowed(g.getNodeType(node))) { + filteredNodes.add(node); + } + } + return filteredNodes; + } +} diff --git a/java/src/main/java/org/softwareheritage/graph/Traversal.java b/java/src/main/java/org/softwareheritage/graph/Traversal.java --- a/java/src/main/java/org/softwareheritage/graph/Traversal.java +++ b/java/src/main/java/org/softwareheritage/graph/Traversal.java @@ -1,12 +1,22 @@ package org.softwareheritage.graph; -import it.unimi.dsi.big.webgraph.LazyLongIterator; -import org.softwareheritage.graph.server.Endpoint; - -import java.util.*; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedList; +import java.util.Map; +import java.util.Queue; +import java.util.Random; +import java.util.Stack; import java.util.function.Consumer; import java.util.function.LongConsumer; +import org.softwareheritage.graph.server.Endpoint; + +import it.unimi.dsi.big.webgraph.LazyLongIterator; + /** * Traversal algorithms on the compressed graph. *

@@ -30,6 +40,9 @@ /** Number of edges accessed during traversal */ long nbEdgesAccessed; + /** The string represent the set of type restriction */ + NodesFiltering ndsfilter; + /** random number generator, for random walks */ Random rng; @@ -42,7 +55,12 @@ * "https://docs.softwareheritage.org/devel/swh-graph/api.html#terminology">allowed * edges */ + public Traversal(Graph graph, String direction, String edgesFmt) { + this(graph, direction, edgesFmt, "*"); + } + + public Traversal(Graph graph, String direction, String edgesFmt, String returnTypes) { if (!direction.matches("forward|backward")) { throw new IllegalArgumentException("Unknown traversal direction: " + direction); } @@ -58,6 +76,12 @@ this.parentNode = new HashMap<>(); this.nbEdgesAccessed = 0; this.rng = new Random(); + + if (returnTypes.equals("*")) { + this.ndsfilter = new NodesFiltering(); + } else { + this.ndsfilter = new NodesFiltering(returnTypes); + } } /** @@ -115,8 +139,11 @@ * @return list of node ids corresponding to the leaves */ public ArrayList leaves(long srcNodeId) { - ArrayList nodeIds = new ArrayList<>(); + ArrayList nodeIds = new ArrayList(); leavesVisitor(srcNodeId, nodeIds::add); + if (ndsfilter.restricted) { + return ndsfilter.filterByNodeTypes(nodeIds, graph); + } return nodeIds; } @@ -140,6 +167,9 @@ public ArrayList neighbors(long srcNodeId) { ArrayList nodeIds = new ArrayList<>(); neighborsVisitor(srcNodeId, nodeIds::add); + if (ndsfilter.restricted) { + return ndsfilter.filterByNodeTypes(nodeIds, graph); + } return nodeIds; } @@ -187,6 +217,9 @@ public ArrayList visitNodes(long srcNodeId) { ArrayList nodeIds = new ArrayList<>(); visitNodesVisitor(srcNodeId, nodeIds::add); + if (ndsfilter.restricted) { + return ndsfilter.filterByNodeTypes(nodeIds, graph); + } return nodeIds; } @@ -305,6 +338,9 @@ } if (found) { + if (ndsfilter.restricted) { + return ndsfilter.filterByNodeTypes(path, graph); + } return path; } else if (retries > 0) { // try again return randomWalk(srcNodeId, dst, 
retries - 1); diff --git a/swh/graph/backend.py b/swh/graph/backend.py --- a/swh/graph/backend.py +++ b/swh/graph/backend.py @@ -68,10 +68,10 @@ method = getattr(self.entry, "count_" + ttype) return method(direction, edges_fmt, src) - async def simple_traversal(self, ttype, direction, edges_fmt, src): + async def simple_traversal(self, ttype, direction, edges_fmt, src, return_types): assert ttype in ("leaves", "neighbors", "visit_nodes") method = getattr(self.stream_proxy, ttype) - async for node_id in method(direction, edges_fmt, src): + async for node_id in method(direction, edges_fmt, src, return_types): yield node_id async def walk(self, direction, edges_fmt, algo, src, dst): @@ -82,13 +82,15 @@ async for node_id in it: yield node_id - async def random_walk(self, direction, edges_fmt, retries, src, dst): + async def random_walk(self, direction, edges_fmt, retries, src, dst, return_types): if dst in EXTENDED_SWHID_TYPES: it = self.stream_proxy.random_walk_type( - direction, edges_fmt, retries, src, dst + direction, edges_fmt, retries, src, dst, return_types ) else: - it = self.stream_proxy.random_walk(direction, edges_fmt, retries, src, dst) + it = self.stream_proxy.random_walk( + direction, edges_fmt, retries, src, dst, return_types + ) async for node_id in it: # TODO return 404 if path is empty yield node_id diff --git a/swh/graph/client.py b/swh/graph/client.py --- a/swh/graph/client.py +++ b/swh/graph/client.py @@ -12,7 +12,10 @@ """Graph API Error""" def __str__(self): - return "An unexpected error occurred in the Graph backend: {}".format(self.args) + return """An unexpected error occurred + in the Graph backend: {}""".format( + self.args + ) class RemoteGraphClient(RPCClient): @@ -35,20 +38,34 @@ def stats(self): return self.get("stats") - def leaves(self, src, edges="*", direction="forward"): + def leaves(self, src, edges="*", direction="forward", return_types="*"): return self.get_lines( - "leaves/{}".format(src), params={"edges": edges, "direction": 
direction} + "leaves/{}".format(src), + params={ + "edges": edges, + "direction": direction, + "return_types": return_types, + }, ) - def neighbors(self, src, edges="*", direction="forward"): + def neighbors(self, src, edges="*", direction="forward", return_types="*"): return self.get_lines( - "neighbors/{}".format(src), params={"edges": edges, "direction": direction} + "neighbors/{}".format(src), + params={ + "edges": edges, + "direction": direction, + "return_types": return_types, + }, ) - def visit_nodes(self, src, edges="*", direction="forward"): + def visit_nodes(self, src, edges="*", direction="forward", return_types="*"): return self.get_lines( "visit/nodes/{}".format(src), - params={"edges": edges, "direction": direction}, + params={ + "edges": edges, + "direction": direction, + "return_types": return_types, + }, ) def visit_edges(self, src, edges="*", direction="forward"): @@ -84,11 +101,18 @@ }, ) - def random_walk(self, src, dst, edges="*", direction="forward", limit=None): + def random_walk( + self, src, dst, edges="*", direction="forward", limit=None, return_types="*" + ): endpoint = "randomwalk/{}/{}" return self.get_lines( endpoint.format(src, dst), - params={"edges": edges, "direction": direction, "limit": limit}, + params={ + "edges": edges, + "direction": direction, + "limit": limit, + "return_types": return_types, + }, ) def count_leaves(self, src, edges="*", direction="forward"): diff --git a/swh/graph/graph.py b/swh/graph/graph.py --- a/swh/graph/graph.py +++ b/swh/graph/graph.py @@ -74,9 +74,14 @@ lambda: self.graph.java_graph.indegree(self.id), ) - def simple_traversal(self, ttype, direction="forward", edges="*"): + def simple_traversal(self, ttype, direction="forward", edges="*", return_types="*"): for node in call_async_gen( - self.graph.backend.simple_traversal, ttype, direction, edges, self.id + self.graph.backend.simple_traversal, + ttype, + direction, + edges, + self.id, + return_types, ): yield self.graph[node] diff --git 
a/swh/graph/server/app.py b/swh/graph/server/app.py --- a/swh/graph/server/app.py +++ b/swh/graph/server/app.py @@ -94,6 +94,24 @@ raise aiohttp.web.HTTPBadRequest(text=f"invalid edge restriction: {s}") return s + def get_return_types(self): + """Validate HTTP query parameter 'return types', i.e, + a set of types which we will filter the query results with""" + s = self.request.query.get("return_types", "*") + if any( + node_type != "*" and node_type not in EXTENDED_SWHID_TYPES + for node_type in s.split(",") + ): + raise aiohttp.web.HTTPBadRequest( + text=f"invalid type for filtering res: {s}" + ) + # if the user puts a star, + # then we filter nothing, we don't need the other information + if "*" in s: + return "*" + else: + return s + def get_traversal(self): """Validate HTTP query parameter `traversal`, i.e., visit order""" s = self.request.query.get("traversal", "dfs") @@ -166,10 +184,15 @@ self.edges = self.get_edges() self.direction = self.get_direction() + self.return_types = self.get_return_types() async def stream_response(self): async for res_node in self.backend.simple_traversal( - self.simple_traversal_type, self.direction, self.edges, self.src_node + self.simple_traversal_type, + self.direction, + self.edges, + self.src_node, + self.return_types, ): res_swhid = self.swhid_of_node(res_node) await self.stream_line(res_swhid) @@ -201,6 +224,7 @@ self.direction = self.get_direction() self.algo = self.get_traversal() self.limit = self.get_limit() + self.return_types = self.get_return_types() async def get_walk_iterator(self): return self.backend.walk( @@ -230,7 +254,12 @@ class RandomWalkView(WalkView): def get_walk_iterator(self): return self.backend.random_walk( - self.direction, self.edges, RANDOM_RETRIES, self.src_node, self.dst_thing + self.direction, + self.edges, + RANDOM_RETRIES, + self.src_node, + self.dst_thing, + self.return_types, ) diff --git a/swh/graph/tests/test_api_client.py b/swh/graph/tests/test_api_client.py --- 
a/swh/graph/tests/test_api_client.py +++ b/swh/graph/tests/test_api_client.py @@ -75,6 +75,41 @@ assert set(actual) == set(expected) +def test_visit_nodes_filtered(graph_client): + actual = list( + graph_client.visit_nodes( + "swh:1:rel:0000000000000000000000000000000000000010", return_types="dir", + ) + ) + expected = [ + "swh:1:dir:0000000000000000000000000000000000000002", + "swh:1:dir:0000000000000000000000000000000000000008", + "swh:1:dir:0000000000000000000000000000000000000006", + ] + assert set(actual) == set(expected) + + +def test_visit_nodes_filtered_star(graph_client): + actual = list( + graph_client.visit_nodes( + "swh:1:rel:0000000000000000000000000000000000000010", return_types="*", + ) + ) + expected = [ + "swh:1:rel:0000000000000000000000000000000000000010", + "swh:1:rev:0000000000000000000000000000000000000009", + "swh:1:rev:0000000000000000000000000000000000000003", + "swh:1:dir:0000000000000000000000000000000000000002", + "swh:1:cnt:0000000000000000000000000000000000000001", + "swh:1:dir:0000000000000000000000000000000000000008", + "swh:1:cnt:0000000000000000000000000000000000000007", + "swh:1:dir:0000000000000000000000000000000000000006", + "swh:1:cnt:0000000000000000000000000000000000000004", + "swh:1:cnt:0000000000000000000000000000000000000005", + ] + assert set(actual) == set(expected) + + def test_visit_edges(graph_client): actual = list( graph_client.visit_edges( @@ -212,9 +247,9 @@ def test_random_walk(graph_client): - """as the walk is random, we test a visit from a cnt node to the only origin in - the dataset, and only check the final node of the path (i.e., the origin) - + """as the walk is random, we test a visit from a cnt node to the only + origin in the dataset, and only check the final node of the path + (i.e., the origin) """ args = ("swh:1:cnt:0000000000000000000000000000000000000001", "ori") kwargs = {"direction": "backward"}