diff --git a/java/src/main/java/org/softwareheritage/graph/Entry.java b/java/src/main/java/org/softwareheritage/graph/Entry.java
--- a/java/src/main/java/org/softwareheritage/graph/Entry.java
+++ b/java/src/main/java/org/softwareheritage/graph/Entry.java
@@ -112,37 +112,43 @@
         }
     }
 
-    public void leaves(String direction, String edgesFmt, long srcNodeId, long maxEdges) {
+    public void leaves(String direction, String edgesFmt, long srcNodeId, long maxEdges, String returnTypes) {
         open();
-        Traversal t = new Traversal(this.graph, direction, edgesFmt, maxEdges);
-        t.leavesVisitor(srcNodeId, this::writeNode);
+        Traversal t = new Traversal(this.graph, direction, edgesFmt, maxEdges, returnTypes);
+        for (Long nodeId : t.leaves(srcNodeId)) {
+            writeNode(nodeId);
+        }
         close();
     }
 
-    public void neighbors(String direction, String edgesFmt, long srcNodeId, long maxEdges) {
+    public void neighbors(String direction, String edgesFmt, long srcNodeId, long maxEdges, String returnTypes) {
         open();
-        Traversal t = new Traversal(this.graph, direction, edgesFmt, maxEdges);
-        t.neighborsVisitor(srcNodeId, this::writeNode);
+        Traversal t = new Traversal(this.graph, direction, edgesFmt, maxEdges, returnTypes);
+        for (Long nodeId : t.neighbors(srcNodeId)) {
+            writeNode(nodeId);
+        }
         close();
     }
 
-    public void visit_nodes(String direction, String edgesFmt, long srcNodeId, long maxEdges) {
+    public void visit_nodes(String direction, String edgesFmt, long srcNodeId, long maxEdges, String returnTypes) {
         open();
-        Traversal t = new Traversal(this.graph, direction, edgesFmt, maxEdges);
-        t.visitNodesVisitor(srcNodeId, this::writeNode);
+        Traversal t = new Traversal(this.graph, direction, edgesFmt, maxEdges, returnTypes);
+        for (Long nodeId : t.visitNodes(srcNodeId)) {
+            writeNode(nodeId);
+        }
         close();
     }
 
     public void visit_edges(String direction, String edgesFmt, long srcNodeId, long maxEdges) {
         open();
-        Traversal t = new Traversal(this.graph, direction, edgesFmt, maxEdges);
+        Traversal t = new Traversal(this.graph, direction, edgesFmt, maxEdges, "*");
         t.visitNodesVisitor(srcNodeId, null, this::writeEdge);
         close();
     }
 
     public void visit_paths(String direction, String edgesFmt, long srcNodeId, long maxEdges) {
         open();
-        Traversal t = new Traversal(this.graph, direction, edgesFmt, maxEdges);
+        Traversal t = new Traversal(this.graph, direction, edgesFmt, maxEdges, "*");
         t.visitPathsVisitor(srcNodeId, this::writePath);
         close();
     }
@@ -166,19 +172,21 @@
         close();
     }
 
-    public void random_walk(String direction, String edgesFmt, int retries, long srcNodeId, long dstNodeId) {
+    public void random_walk(String direction, String edgesFmt, int retries, long srcNodeId, long dstNodeId,
+            String returnTypes) {
         open();
-        Traversal t = new Traversal(this.graph, direction, edgesFmt);
+        Traversal t = new Traversal(this.graph, direction, edgesFmt, returnTypes);
         for (Long nodeId : t.randomWalk(srcNodeId, dstNodeId, retries)) {
             writeNode(nodeId);
         }
         close();
     }
 
-    public void random_walk_type(String direction, String edgesFmt, int retries, long srcNodeId, String dst) {
+    public void random_walk_type(String direction, String edgesFmt, int retries, long srcNodeId, String dst,
+            String returnTypes) {
         open();
         Node.Type dstType = Node.Type.fromStr(dst);
-        Traversal t = new Traversal(this.graph, direction, edgesFmt);
+        Traversal t = new Traversal(this.graph, direction, edgesFmt, returnTypes);
         for (Long nodeId : t.randomWalk(srcNodeId, dstType, retries)) {
             writeNode(nodeId);
         }
diff --git a/java/src/main/java/org/softwareheritage/graph/NodesFiltering.java b/java/src/main/java/org/softwareheritage/graph/NodesFiltering.java
new file mode 100644
--- /dev/null
+++ b/java/src/main/java/org/softwareheritage/graph/NodesFiltering.java
@@ -0,0 +1,101 @@
+package org.softwareheritage.graph;
+
+import java.util.ArrayList;
+
+/**
+ * <h1>NodesFiltering</h1>
+ * <p>
+ * Class that manages the filtering of the nodes returned by a visit of the graph. It is
+ * parameterized by a string that represents either no filtering ({@code "*"}) or a comma-separated
+ * set of node types (e.g. {@code "cnt,snp"}).
+ * </p>
+ *
+ */
+
+public class NodesFiltering {
+
+    /** Whether a type restriction is in effect ({@code false}: every node type is allowed). */
+    final boolean restricted;
+    /** The node types that are allowed to be returned. */
+    final ArrayList<Node.Type> allowedNodesTypes;
+
+    /**
+     * Default constructor, handling the {@code "*"} case: all node types are allowed to be
+     * returned, so allowedNodesTypes contains every node type ([SNP, CNT, ...]).
+     */
+    public NodesFiltering() {
+        restricted = false;
+        allowedNodesTypes = Node.Type.parse("*");
+    }
+
+    /**
+     * Constructor
+     *
+     * @param strTypes a comma-separated list of the node types allowed to be returned:
+     *            {@code new NodesFiltering("cnt,snp")} sets allowedNodesTypes to [CNT, SNP]. A
+     *            {@code "*"} element anywhere in the list allows every node type, like
+     *            {@link #NodesFiltering()}.
+     */
+    public NodesFiltering(String strTypes) {
+        ArrayList<Node.Type> types = new ArrayList<>();
+        boolean allowAll = false;
+        for (String type : strTypes.split(",")) {
+            if (type.equals("*")) {
+                // same semantics as the HTTP server, which treats a star anywhere as "no filtering"
+                allowAll = true;
+                break;
+            }
+            types.add(Node.Type.fromStr(type));
+        }
+        restricted = !allowAll;
+        allowedNodesTypes = allowAll ? Node.Type.parse("*") : types;
+    }
+
+    /**
+     * Checks whether the type given in parameter is in the list of allowed types.
+     *
+     * @param typ the type of the node
+     * @return {@code true} if nodes of this type may be returned
+     */
+    public boolean typeIsAllowed(Node.Type typ) {
+        return this.allowedNodesTypes.contains(typ);
+    }
+
+    /**
+     * <p>
+     * Filters the nodes returned by a visit: browses the list of nodes found and creates a new list
+     * containing, in their original order, only the nodes whose type is in the list of allowed
+     * types (allowedNodesTypes).
+     * </p>
+     *
+     * @param nodeIds the nodes found during the visit
+     * @param g the graph, used to find the type of each node from its id
+     * @return a new list with the ids of the nodes whose type is allowed
+     *
+     *         <pre>
+     * {@code
+     *  Long id1 = .... // graph.getNodeType(id1) == CNT
+     *  Long id2 = .... // graph.getNodeType(id2) == SNP
+     *  Long id3 = .... // graph.getNodeType(id3) == ORI
+     *  ArrayList<Long> nodeIds = new ArrayList<Long>();
+     *  nodeIds.add(id1); nodeIds.add(id2); nodeIds.add(id3);
+     *
+     *  NodesFiltering nds = new NodesFiltering("snp,ori"); // only snp and ori nodes are allowed
+     *  System.out.println(nds.filterByNodeTypes(nodeIds,graph)); // will print id2, id3
+     *
+     *  nds = new NodesFiltering("*");
+     *  System.out.println(nds.filterByNodeTypes(nodeIds,graph)); // will print id1, id2, id3
+     *
+     * }
+     *         </pre>
+     */
+    public ArrayList<Long> filterByNodeTypes(ArrayList<Long> nodeIds, Graph g) {
+        ArrayList<Long> filteredNodes = new ArrayList<Long>();
+        for (Long node : nodeIds) {
+            if (this.typeIsAllowed(g.getNodeType(node))) {
+                filteredNodes.add(node);
+            }
+        }
+        return filteredNodes;
+    }
+}
diff --git a/java/src/main/java/org/softwareheritage/graph/Traversal.java b/java/src/main/java/org/softwareheritage/graph/Traversal.java
--- a/java/src/main/java/org/softwareheritage/graph/Traversal.java
+++ b/java/src/main/java/org/softwareheritage/graph/Traversal.java
@@ -33,6 +33,9 @@
     /** The anti Dos limit of edges traversed while a visit */
     long maxEdges;
 
+    /** Filter restricting the types of the nodes returned by the visits */
+    NodesFiltering ndsfilter;
+
     /** random number generator, for random walks */
     Random rng;
 
@@ -46,10 +49,18 @@
      * edges
      */
     public Traversal(Graph graph, String direction, String edgesFmt) {
-        this(graph, direction, edgesFmt, 0);
+        this(graph, direction, edgesFmt, 0, "*");
+    }
+
+    public Traversal(Graph graph, String direction, String edgesFmt, String returnTypes) {
+        this(graph, direction, edgesFmt, 0, returnTypes);
+    }
+
+    public Traversal(Graph graph, String direction, String edgesFmt, long maxEdges) {
+        this(graph, direction, edgesFmt, maxEdges, "*");
     }
 
-    public Traversal(Graph graph, String direction, String edgesFmt, long maxEdges) {
+    public Traversal(Graph graph, String direction, String edgesFmt, long maxEdges, String returnTypes) {
         if (!direction.matches("forward|backward")) {
             throw new IllegalArgumentException("Unknown traversal direction: " + direction);
         }
@@ -66,6 +77,12 @@
         this.nbEdgesAccessed = 0;
         this.maxEdges = maxEdges;
         this.rng = new Random();
+
+        if (returnTypes.equals("*")) {
+            this.ndsfilter = new NodesFiltering();
+        } else {
+            this.ndsfilter = new NodesFiltering(returnTypes);
+        }
     }
 
     /**
@@ -128,8 +145,11 @@
      * @return list of node ids corresponding to the leaves
      */
     public ArrayList<Long> leaves(long srcNodeId) {
         ArrayList<Long> nodeIds = new ArrayList<>();
         leavesVisitor(srcNodeId, nodeIds::add);
+        if (ndsfilter.restricted) {
+            return ndsfilter.filterByNodeTypes(nodeIds, graph);
+        }
         return nodeIds;
     }
 
@@ -158,6 +178,9 @@
     public ArrayList<Long> neighbors(long srcNodeId) {
         ArrayList<Long> nodeIds = new ArrayList<>();
         neighborsVisitor(srcNodeId, nodeIds::add);
+        if (ndsfilter.restricted) {
+            return ndsfilter.filterByNodeTypes(nodeIds, graph);
+        }
         return nodeIds;
     }
 
@@ -209,6 +232,9 @@
     public ArrayList<Long> visitNodes(long srcNodeId) {
         ArrayList<Long> nodeIds = new ArrayList<>();
         visitNodesVisitor(srcNodeId, nodeIds::add);
+        if (ndsfilter.restricted) {
+            return ndsfilter.filterByNodeTypes(nodeIds, graph);
+        }
         return nodeIds;
     }
 
@@ -334,6 +360,9 @@
         }
 
         if (found) {
+            if (ndsfilter.restricted) {
+                return ndsfilter.filterByNodeTypes(path, graph);
+            }
             return path;
         } else if (retries > 0) { // try again
             return randomWalk(srcNodeId, dst, retries - 1);
diff --git a/swh/graph/backend.py b/swh/graph/backend.py
--- a/swh/graph/backend.py
+++ b/swh/graph/backend.py
@@ -68,10 +68,12 @@
         method = getattr(self.entry, "count_" + ttype)
         return method(direction, edges_fmt, src)
 
-    async def simple_traversal(self, ttype, direction, edges_fmt, src, max_edges):
+    async def simple_traversal(
+        self, ttype, direction, edges_fmt, src, max_edges, return_types="*"
+    ):
         assert ttype in ("leaves", "neighbors", "visit_nodes")
         method = getattr(self.stream_proxy, ttype)
-        async for node_id in method(direction, edges_fmt, src, max_edges):
+        async for node_id in method(direction, edges_fmt, src, max_edges, return_types):
             yield node_id
 
     async def walk(self, direction, edges_fmt, algo, src, dst):
@@ -82,13 +84,17 @@
         async for node_id in it:
             yield node_id
 
-    async def random_walk(self, direction, edges_fmt, retries, src, dst):
+    async def random_walk(
+        self, direction, edges_fmt, retries, src, dst, return_types="*"
+    ):
         if dst in EXTENDED_SWHID_TYPES:
             it = self.stream_proxy.random_walk_type(
-                direction, edges_fmt, retries, src, dst
+                direction, edges_fmt, retries, src, dst, return_types
             )
         else:
-            it = self.stream_proxy.random_walk(direction, edges_fmt, retries, src, dst)
+            it = self.stream_proxy.random_walk(
+                direction, edges_fmt, retries, src, dst, return_types
+            )
         async for node_id in it:  # TODO return 404 if path is empty
             yield node_id
 
diff --git a/swh/graph/client.py b/swh/graph/client.py
--- a/swh/graph/client.py
+++ b/swh/graph/client.py
@@ -12,7 +12,9 @@
     """Graph API Error"""
 
     def __str__(self):
-        return "An unexpected error occurred in the Graph backend: {}".format(self.args)
+        return "An unexpected error occurred in the Graph backend: {}".format(
+            self.args
+        )
 
 
 class RemoteGraphClient(RPCClient):
@@ -35,22 +37,43 @@
     def stats(self):
         return self.get("stats")
 
-    def leaves(self, src, edges="*", direction="forward", max_edges=0):
+    def leaves(
+        self, src, edges="*", direction="forward", max_edges=0, return_types="*"
+    ):
         return self.get_lines(
             "leaves/{}".format(src),
-            params={"edges": edges, "direction": direction, "max_edges": max_edges},
+            params={
+                "edges": edges,
+                "direction": direction,
+                "max_edges": max_edges,
+                "return_types": return_types,
+            },
         )
 
-    def neighbors(self, src, edges="*", direction="forward", max_edges=0):
+    def neighbors(
+        self, src, edges="*", direction="forward", max_edges=0, return_types="*"
+    ):
         return self.get_lines(
             "neighbors/{}".format(src),
-            params={"edges": edges, "direction": direction, "max_edges": max_edges},
+            params={
+                "edges": edges,
+                "direction": direction,
+                "max_edges": max_edges,
+                "return_types": return_types,
+            },
         )
 
-    def visit_nodes(self, src, edges="*", direction="forward", max_edges=0):
+    def visit_nodes(
+        self, src, edges="*", direction="forward", max_edges=0, return_types="*"
+    ):
         return self.get_lines(
             "visit/nodes/{}".format(src),
-            params={"edges": edges, "direction": direction, "max_edges": max_edges},
+            params={
+                "edges": edges,
+                "direction": direction,
+                "max_edges": max_edges,
+                "return_types": return_types,
+            },
         )
 
     def visit_edges(self, src, edges="*", direction="forward", max_edges=0):
@@ -86,11 +109,18 @@
             },
         )
 
-    def random_walk(self, src, dst, edges="*", direction="forward", limit=None):
+    def random_walk(
+        self, src, dst, edges="*", direction="forward", limit=None, return_types="*"
+    ):
         endpoint = "randomwalk/{}/{}"
         return self.get_lines(
             endpoint.format(src, dst),
-            params={"edges": edges, "direction": direction, "limit": limit},
+            params={
+                "edges": edges,
+                "direction": direction,
+                "limit": limit,
+                "return_types": return_types,
+            },
         )
 
     def count_leaves(self, src, edges="*", direction="forward"):
diff --git a/swh/graph/graph.py b/swh/graph/graph.py
--- a/swh/graph/graph.py
+++ b/swh/graph/graph.py
@@ -74,7 +74,9 @@
             lambda: self.graph.java_graph.indegree(self.id),
         )
 
-    def simple_traversal(self, ttype, direction="forward", edges="*", max_edges=0):
+    def simple_traversal(
+        self, ttype, direction="forward", edges="*", max_edges=0, return_types="*"
+    ):
         for node in call_async_gen(
             self.graph.backend.simple_traversal,
             ttype,
@@ -82,6 +84,7 @@
             edges,
             self.id,
             max_edges,
+            return_types,
         ):
             yield self.graph[node]
 
diff --git a/swh/graph/server/app.py b/swh/graph/server/app.py
--- a/swh/graph/server/app.py
+++ b/swh/graph/server/app.py
@@ -118,6 +118,22 @@
         except ValueError:
             raise aiohttp.web.HTTPBadRequest(text=f"invalid max_edges value: {s}")
 
+    def get_return_types(self):
+        """Validate HTTP query parameter 'return_types', i.e.,
+        a comma-separated set of node types used to filter the query results"""
+        s = self.request.query.get("return_types", "*")
+        if any(
+            node_type != "*" and node_type not in EXTENDED_SWHID_TYPES
+            for node_type in s.split(",")
+        ):
+            raise aiohttp.web.HTTPBadRequest(text=f"invalid return_types value: {s}")
+        # a star anywhere in the list means that nothing is filtered out,
+        # so the other requested types are irrelevant
+        if "*" in s:
+            return "*"
+        else:
+            return s
+
 
 class StreamingGraphView(GraphView):
     """Base class for views streaming their response line by line."""
@@ -176,6 +192,7 @@
         self.edges = self.get_edges()
         self.direction = self.get_direction()
         self.max_edges = self.get_max_edges()
+        self.return_types = self.get_return_types()
 
     async def stream_response(self):
         async for res_node in self.backend.simple_traversal(
@@ -184,6 +201,7 @@
             self.edges,
             self.src_node,
             self.max_edges,
+            self.return_types,
         ):
             res_swhid = self.swhid_of_node(res_node)
             await self.stream_line(res_swhid)
@@ -215,6 +233,7 @@
         self.direction = self.get_direction()
         self.algo = self.get_traversal()
         self.limit = self.get_limit()
+        self.return_types = self.get_return_types()
 
     async def get_walk_iterator(self):
         return self.backend.walk(
@@ -244,7 +263,12 @@
 class RandomWalkView(WalkView):
     def get_walk_iterator(self):
         return self.backend.random_walk(
-            self.direction, self.edges, RANDOM_RETRIES, self.src_node, self.dst_thing
+            self.direction,
+            self.edges,
+            RANDOM_RETRIES,
+            self.src_node,
+            self.dst_thing,
+            self.return_types,
         )
 
 
diff --git a/swh/graph/tests/test_api_client.py b/swh/graph/tests/test_api_client.py
--- a/swh/graph/tests/test_api_client.py
+++ b/swh/graph/tests/test_api_client.py
@@ -75,6 +75,56 @@
     assert set(actual) == set(expected)
 
 
+def test_visit_nodes_filtered(graph_client):
+    actual = list(
+        graph_client.visit_nodes(
+            "swh:1:rel:0000000000000000000000000000000000000010", return_types="dir",
+        )
+    )
+    expected = [
+        "swh:1:dir:0000000000000000000000000000000000000002",
+        "swh:1:dir:0000000000000000000000000000000000000008",
+        "swh:1:dir:0000000000000000000000000000000000000006",
+    ]
+    assert set(actual) == set(expected)
+
+
+def test_visit_nodes_filtered_multiple_types(graph_client):
+    actual = list(
+        graph_client.visit_nodes(
+            "swh:1:rel:0000000000000000000000000000000000000010",
+            return_types="rel,rev",
+        )
+    )
+    expected = [
+        "swh:1:rel:0000000000000000000000000000000000000010",
+        "swh:1:rev:0000000000000000000000000000000000000009",
+        "swh:1:rev:0000000000000000000000000000000000000003",
+    ]
+    assert set(actual) == set(expected)
+
+
+def test_visit_nodes_filtered_star(graph_client):
+    actual = list(
+        graph_client.visit_nodes(
+            "swh:1:rel:0000000000000000000000000000000000000010", return_types="*",
+        )
+    )
+    expected = [
+        "swh:1:rel:0000000000000000000000000000000000000010",
+        "swh:1:rev:0000000000000000000000000000000000000009",
+        "swh:1:rev:0000000000000000000000000000000000000003",
+        "swh:1:dir:0000000000000000000000000000000000000002",
+        "swh:1:cnt:0000000000000000000000000000000000000001",
+        "swh:1:dir:0000000000000000000000000000000000000008",
+        "swh:1:cnt:0000000000000000000000000000000000000007",
+        "swh:1:dir:0000000000000000000000000000000000000006",
+        "swh:1:cnt:0000000000000000000000000000000000000004",
+        "swh:1:cnt:0000000000000000000000000000000000000005",
+    ]
+    assert set(actual) == set(expected)
+
+
 def test_visit_edges(graph_client):
     actual = list(
         graph_client.visit_edges(