diff --git a/java/src/main/java/org/softwareheritage/graph/Entry.java b/java/src/main/java/org/softwareheritage/graph/Entry.java
--- a/java/src/main/java/org/softwareheritage/graph/Entry.java
+++ b/java/src/main/java/org/softwareheritage/graph/Entry.java
@@ -112,37 +112,43 @@ leaves/neighbors/visit_nodes now filter their output by returnTypes; visit_edges/visit_paths use "*"
}
}
- public void leaves(String direction, String edgesFmt, long srcNodeId, long maxEdges) {
+ public void leaves(String direction, String edgesFmt, long srcNodeId, long maxEdges, String returnTypes) {
open();
- Traversal t = new Traversal(this.graph, direction, edgesFmt, maxEdges);
- t.leavesVisitor(srcNodeId, this::writeNode);
+ Traversal t = new Traversal(this.graph, direction, edgesFmt, maxEdges, returnTypes);
+ for (Long nodeId : t.leaves(srcNodeId)) {
+ writeNode(nodeId);
+ }
close();
}
- public void neighbors(String direction, String edgesFmt, long srcNodeId, long maxEdges) {
+ public void neighbors(String direction, String edgesFmt, long srcNodeId, long maxEdges, String returnTypes) {
open();
- Traversal t = new Traversal(this.graph, direction, edgesFmt, maxEdges);
- t.neighborsVisitor(srcNodeId, this::writeNode);
+ Traversal t = new Traversal(this.graph, direction, edgesFmt, maxEdges, returnTypes);
+ for (Long nodeId : t.neighbors(srcNodeId)) {
+ writeNode(nodeId);
+ }
close();
}
- public void visit_nodes(String direction, String edgesFmt, long srcNodeId, long maxEdges) {
+ public void visit_nodes(String direction, String edgesFmt, long srcNodeId, long maxEdges, String returnTypes) {
open();
- Traversal t = new Traversal(this.graph, direction, edgesFmt, maxEdges);
- t.visitNodesVisitor(srcNodeId, this::writeNode);
+ Traversal t = new Traversal(this.graph, direction, edgesFmt, maxEdges, returnTypes);
+ for (Long nodeId : t.visitNodes(srcNodeId)) {
+ writeNode(nodeId);
+ }
close();
}
public void visit_edges(String direction, String edgesFmt, long srcNodeId, long maxEdges) {
open();
- Traversal t = new Traversal(this.graph, direction, edgesFmt, maxEdges);
+ Traversal t = new Traversal(this.graph, direction, edgesFmt, maxEdges, "*");
t.visitNodesVisitor(srcNodeId, null, this::writeEdge);
close();
}
public void visit_paths(String direction, String edgesFmt, long srcNodeId, long maxEdges) {
open();
- Traversal t = new Traversal(this.graph, direction, edgesFmt, maxEdges);
+ Traversal t = new Traversal(this.graph, direction, edgesFmt, maxEdges, "*");
t.visitPathsVisitor(srcNodeId, this::writePath);
close();
}
@@ -166,19 +172,21 @@ random_walk/random_walk_type now filter the returned path by returnTypes
close();
}
- public void random_walk(String direction, String edgesFmt, int retries, long srcNodeId, long dstNodeId) {
+ public void random_walk(String direction, String edgesFmt, int retries, long srcNodeId, long dstNodeId,
+ String returnTypes) {
open();
- Traversal t = new Traversal(this.graph, direction, edgesFmt);
+ Traversal t = new Traversal(this.graph, direction, edgesFmt, returnTypes);
for (Long nodeId : t.randomWalk(srcNodeId, dstNodeId, retries)) {
writeNode(nodeId);
}
close();
}
- public void random_walk_type(String direction, String edgesFmt, int retries, long srcNodeId, String dst) {
+ public void random_walk_type(String direction, String edgesFmt, int retries, long srcNodeId, String dst,
+ String returnTypes) {
open();
Node.Type dstType = Node.Type.fromStr(dst);
- Traversal t = new Traversal(this.graph, direction, edgesFmt);
+ Traversal t = new Traversal(this.graph, direction, edgesFmt, returnTypes);
for (Long nodeId : t.randomWalk(srcNodeId, dstType, retries)) {
writeNode(nodeId);
}
diff --git a/java/src/main/java/org/softwareheritage/graph/NodesFiltering.java b/java/src/main/java/org/softwareheritage/graph/NodesFiltering.java
new file mode 100644
--- /dev/null
+++ b/java/src/main/java/org/softwareheritage/graph/NodesFiltering.java
@@ -0,0 +1,105 @@
+package org.softwareheritage.graph;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+
+/**
+ * NodesFiltering
+ *
+ * Filters the nodes returned after a visit of the graph according to their type. An instance is
+ * parameterized by a string that is either "*" (no filtering) or a comma-separated set of node
+ * types.
+ *
+ * Examples of queries relying on it:
+ *
+ * - graph/visit/nodes/swh:1:rel:0000000000000000000000000000000000000010?return_types=rev will
+ * only return 'rev' nodes.
+ *
+ * - graph/visit/nodes/swh:1:rel:0000000000000000000000000000000000000010?return_types=rev,snp,cnt
+ * will only return 'rev', 'snp' and 'cnt' nodes.
+ *
+ * - graph/visit/nodes/swh:1:rel:0000000000000000000000000000000000000010?return_types=* will
+ * return all the nodes.
+ */
+public class NodesFiltering {
+
+    /** Whether filtering is active; when false every node type is allowed. */
+    final boolean restricted;
+    /** Node types that are allowed to be returned. */
+    final ArrayList<Node.Type> allowedNodesTypes;
+
+    /**
+     * Default constructor, handling the "*" case: all node types are allowed to be returned, so
+     * allowedNodesTypes contains every node type.
+     */
+    public NodesFiltering() {
+        restricted = false;
+        allowedNodesTypes = Node.Type.parse("*");
+    }
+
+    /**
+     * Constructor.
+     *
+     * @param strTypes a comma-separated list of the node types allowed to be returned:
+     *            {@code new NodesFiltering("cnt,snp")} sets allowedNodesTypes to [CNT, SNP]. If one
+     *            of the elements is "*", no filtering is applied, as with the default constructor.
+     */
+    public NodesFiltering(String strTypes) {
+        String[] types = strTypes.split(",");
+        if (Arrays.asList(types).contains("*")) {
+            // a wildcard anywhere in the list allows everything (mirrors the HTTP API validation)
+            restricted = false;
+            allowedNodesTypes = Node.Type.parse("*");
+        } else {
+            restricted = true;
+            allowedNodesTypes = new ArrayList<>();
+            for (String type : types) {
+                allowedNodesTypes.add(Node.Type.fromStr(type));
+            }
+        }
+    }
+
+    /**
+     * Checks whether the given node type is in the list of allowed types.
+     *
+     * @param typ the type of the node
+     * @return true if nodes of this type may be returned
+     */
+    public boolean typeIsAllowed(Node.Type typ) {
+        return allowedNodesTypes.contains(typ);
+    }
+
+    /**
+     * Filters the nodes found after a visit: builds a new list containing only the nodes whose
+     * type is in allowedNodesTypes, in their original order. The input list is not modified.
+     *
+     * Example:
+     *
+     * {@code
+     * Long id1 = .... // graph.getNodeType(id1) == CNT
+     * Long id2 = .... // graph.getNodeType(id2) == SNP
+     * Long id3 = .... // graph.getNodeType(id3) == ORI
+     * ArrayList<Long> nodeIds = new ArrayList<>();
+     * nodeIds.add(id1); nodeIds.add(id2); nodeIds.add(id3);
+     *
+     * NodesFiltering nds = new NodesFiltering("snp,ori"); // only snp and ori nodes are allowed
+     * System.out.println(nds.filterByNodeTypes(nodeIds, graph)); // will print id2, id3
+     *
+     * nds = new NodesFiltering("*");
+     * System.out.println(nds.filterByNodeTypes(nodeIds, graph)); // will print id1, id2, id3
+     * }
+     *
+     * @param nodeIds the nodes found during the visit
+     * @param g the graph, used to find the type of each node from its id
+     * @return a new list with the ids of the nodes whose type is allowed
+     */
+    public ArrayList<Long> filterByNodeTypes(ArrayList<Long> nodeIds, Graph g) {
+        ArrayList<Long> filteredNodes = new ArrayList<>();
+        for (Long node : nodeIds) {
+            if (typeIsAllowed(g.getNodeType(node))) {
+                filteredNodes.add(node);
+            }
+        }
+        return filteredNodes;
+    }
+}
diff --git a/java/src/main/java/org/softwareheritage/graph/Traversal.java b/java/src/main/java/org/softwareheritage/graph/Traversal.java
--- a/java/src/main/java/org/softwareheritage/graph/Traversal.java
+++ b/java/src/main/java/org/softwareheritage/graph/Traversal.java
@@ -33,6 +33,9 @@
/** The anti Dos limit of edges traversed while a visit */
long maxEdges;
+ /** Filter restricting the node types returned by the traversal methods */
+ NodesFiltering ndsfilter;
+
/** random number generator, for random walks */
Random rng;
@@ -46,10 +49,32 @@
* edges
*/
public Traversal(Graph graph, String direction, String edgesFmt) {
- this(graph, direction, edgesFmt, 0);
+ this(graph, direction, edgesFmt, 0, "*");
+ }
+
+ /**
+ * Constructor with maxEdges = 0 and a node-type filter on the returned nodes.
+ *
+ * @param returnTypes comma-separated node types to return (e.g. "rev,dir"), or "*" for all
+ */
+ public Traversal(Graph graph, String direction, String edgesFmt, String returnTypes) {
+ this(graph, direction, edgesFmt, 0, returnTypes);
+ }
+
+ /**
+ * Constructor kept for backward compatibility: no node-type filtering ("*").
+ */
+ public Traversal(Graph graph, String direction, String edgesFmt, long maxEdges) {
+ this(graph, direction, edgesFmt, maxEdges, "*");
}
- public Traversal(Graph graph, String direction, String edgesFmt, long maxEdges) {
+ /**
+ * Main constructor.
+ *
+ * @param maxEdges anti-DoS limit on the number of edges traversed during a visit
+ * @param returnTypes comma-separated node types to return, or "*" (or null) for all
+ */
+ public Traversal(Graph graph, String direction, String edgesFmt, long maxEdges, String returnTypes) {
if (!direction.matches("forward|backward")) {
throw new IllegalArgumentException("Unknown traversal direction: " + direction);
}
@@ -66,6 +91,13 @@
this.nbEdgesAccessed = 0;
this.maxEdges = maxEdges;
this.rng = new Random();
+
+ // "*" (or a null value) means no node-type filtering
+ if (returnTypes == null || returnTypes.equals("*")) {
+ this.ndsfilter = new NodesFiltering();
+ } else {
+ this.ndsfilter = new NodesFiltering(returnTypes);
+ }
}
/**
@@ -128,8 +160,11 @@
* @return list of node ids corresponding to the leaves
*/
public ArrayList leaves(long srcNodeId) {
ArrayList nodeIds = new ArrayList<>();
leavesVisitor(srcNodeId, nodeIds::add);
+ if (ndsfilter.restricted) {
+ return ndsfilter.filterByNodeTypes(nodeIds, graph);
+ }
return nodeIds;
}
@@ -158,6 +193,9 @@
public ArrayList neighbors(long srcNodeId) {
ArrayList nodeIds = new ArrayList<>();
neighborsVisitor(srcNodeId, nodeIds::add);
+ if (ndsfilter.restricted) {
+ return ndsfilter.filterByNodeTypes(nodeIds, graph);
+ }
return nodeIds;
}
@@ -209,6 +247,9 @@
public ArrayList visitNodes(long srcNodeId) {
ArrayList nodeIds = new ArrayList<>();
visitNodesVisitor(srcNodeId, nodeIds::add);
+ if (ndsfilter.restricted) {
+ return ndsfilter.filterByNodeTypes(nodeIds, graph);
+ }
return nodeIds;
}
@@ -334,6 +375,9 @@
}
if (found) {
+ if (ndsfilter.restricted) {
+ return ndsfilter.filterByNodeTypes(path, graph);
+ }
return path;
} else if (retries > 0) { // try again
return randomWalk(srcNodeId, dst, retries - 1);
diff --git a/swh/graph/backend.py b/swh/graph/backend.py
--- a/swh/graph/backend.py
+++ b/swh/graph/backend.py
@@ -68,10 +68,17 @@
method = getattr(self.entry, "count_" + ttype)
return method(direction, edges_fmt, src)
- async def simple_traversal(self, ttype, direction, edges_fmt, src, max_edges):
+ async def simple_traversal(
+ self, ttype, direction, edges_fmt, src, max_edges, return_types="*"
+ ):
+ """Stream the node ids reached by a ``ttype`` traversal from ``src``.
+
+ ``return_types`` is a comma-separated list of the node types to keep in
+ the output, or ``"*"`` (the default) to keep all of them.
+ """
assert ttype in ("leaves", "neighbors", "visit_nodes")
method = getattr(self.stream_proxy, ttype)
- async for node_id in method(direction, edges_fmt, src, max_edges):
+ async for node_id in method(direction, edges_fmt, src, max_edges, return_types):
yield node_id
async def walk(self, direction, edges_fmt, algo, src, dst):
@@ -82,13 +89,22 @@
async for node_id in it:
yield node_id
- async def random_walk(self, direction, edges_fmt, retries, src, dst):
+ async def random_walk(
+ self, direction, edges_fmt, retries, src, dst, return_types="*"
+ ):
+ """Stream the node ids of a random walk from ``src`` to ``dst``.
+
+ ``dst`` is either a node id or a node type; ``return_types`` filters the
+ returned nodes as in :meth:`simple_traversal`.
+ """
if dst in EXTENDED_SWHID_TYPES:
it = self.stream_proxy.random_walk_type(
- direction, edges_fmt, retries, src, dst
+ direction, edges_fmt, retries, src, dst, return_types
)
else:
- it = self.stream_proxy.random_walk(direction, edges_fmt, retries, src, dst)
+ it = self.stream_proxy.random_walk(
+ direction, edges_fmt, retries, src, dst, return_types
+ )
async for node_id in it: # TODO return 404 if path is empty
yield node_id
diff --git a/swh/graph/client.py b/swh/graph/client.py
--- a/swh/graph/client.py
+++ b/swh/graph/client.py
@@ -35,22 +35,43 @@ leaves/neighbors/visit_nodes: forward return_types as a query parameter
def stats(self):
return self.get("stats")
- def leaves(self, src, edges="*", direction="forward", max_edges=0):
+ def leaves(
+ self, src, edges="*", direction="forward", max_edges=0, return_types="*"
+ ):
return self.get_lines(
"leaves/{}".format(src),
- params={"edges": edges, "direction": direction, "max_edges": max_edges},
+ params={
+ "edges": edges,
+ "direction": direction,
+ "max_edges": max_edges,
+ "return_types": return_types,
+ },
)
- def neighbors(self, src, edges="*", direction="forward", max_edges=0):
+ def neighbors(
+ self, src, edges="*", direction="forward", max_edges=0, return_types="*"
+ ):
return self.get_lines(
"neighbors/{}".format(src),
- params={"edges": edges, "direction": direction, "max_edges": max_edges},
+ params={
+ "edges": edges,
+ "direction": direction,
+ "max_edges": max_edges,
+ "return_types": return_types,
+ },
)
- def visit_nodes(self, src, edges="*", direction="forward", max_edges=0):
+ def visit_nodes(
+ self, src, edges="*", direction="forward", max_edges=0, return_types="*"
+ ):
return self.get_lines(
"visit/nodes/{}".format(src),
- params={"edges": edges, "direction": direction, "max_edges": max_edges},
+ params={
+ "edges": edges,
+ "direction": direction,
+ "max_edges": max_edges,
+ "return_types": return_types,
+ },
)
def visit_edges(self, src, edges="*", direction="forward", max_edges=0):
@@ -86,11 +107,18 @@ random_walk: forward return_types as a query parameter
},
)
- def random_walk(self, src, dst, edges="*", direction="forward", limit=None):
+ def random_walk(
+ self, src, dst, edges="*", direction="forward", limit=None, return_types="*"
+ ):
endpoint = "randomwalk/{}/{}"
return self.get_lines(
endpoint.format(src, dst),
- params={"edges": edges, "direction": direction, "limit": limit},
+ params={
+ "edges": edges,
+ "direction": direction,
+ "limit": limit,
+ "return_types": return_types,
+ },
)
def count_leaves(self, src, edges="*", direction="forward"):
diff --git a/swh/graph/graph.py b/swh/graph/graph.py
--- a/swh/graph/graph.py
+++ b/swh/graph/graph.py
@@ -74,7 +74,9 @@ simple_traversal: new return_types argument (default "*") forwarded to the backend
lambda: self.graph.java_graph.indegree(self.id),
)
- def simple_traversal(self, ttype, direction="forward", edges="*", max_edges=0):
+ def simple_traversal(
+ self, ttype, direction="forward", edges="*", max_edges=0, return_types="*"
+ ):
for node in call_async_gen(
self.graph.backend.simple_traversal,
ttype,
@@ -82,6 +84,7 @@ return_types is passed positionally, right after max_edges
edges,
self.id,
max_edges,
+ return_types,
):
yield self.graph[node]
diff --git a/swh/graph/server/app.py b/swh/graph/server/app.py
--- a/swh/graph/server/app.py
+++ b/swh/graph/server/app.py
@@ -118,6 +118,20 @@
except ValueError:
raise aiohttp.web.HTTPBadRequest(text=f"invalid max_edges value: {s}")
+ def get_return_types(self):
+ """Validate HTTP query parameter `return_types`, i.e. a comma-separated
+ set of node types used to filter the query results ("*" disables
+ filtering)."""
+ s = self.request.query.get("return_types", "*")
+ types = s.split(",")
+ if any(t != "*" and t not in EXTENDED_SWHID_TYPES for t in types):
+ raise aiohttp.web.HTTPBadRequest(text=f"invalid return_types value: {s}")
+ # a star anywhere in the list means no filtering: the other types are
+ # irrelevant
+ if "*" in types:
+ return "*"
+ return s
+
class StreamingGraphView(GraphView):
"""Base class for views streaming their response line by line."""
@@ -176,6 +190,7 @@
self.edges = self.get_edges()
self.direction = self.get_direction()
self.max_edges = self.get_max_edges()
+ self.return_types = self.get_return_types()
async def stream_response(self):
async for res_node in self.backend.simple_traversal(
@@ -184,6 +199,7 @@
self.edges,
self.src_node,
self.max_edges,
+ self.return_types,
):
res_swhid = self.swhid_of_node(res_node)
await self.stream_line(res_swhid)
@@ -215,6 +231,7 @@
self.direction = self.get_direction()
self.algo = self.get_traversal()
self.limit = self.get_limit()
+ self.return_types = self.get_return_types()
async def get_walk_iterator(self):
return self.backend.walk(
@@ -244,7 +261,12 @@
class RandomWalkView(WalkView):
def get_walk_iterator(self):
return self.backend.random_walk(
- self.direction, self.edges, RANDOM_RETRIES, self.src_node, self.dst_thing
+ self.direction,
+ self.edges,
+ RANDOM_RETRIES,
+ self.src_node,
+ self.dst_thing,
+ self.return_types,
)
diff --git a/swh/graph/tests/test_api_client.py b/swh/graph/tests/test_api_client.py
--- a/swh/graph/tests/test_api_client.py
+++ b/swh/graph/tests/test_api_client.py
@@ -75,6 +75,41 @@ visit_nodes with return_types="dir" (only dir nodes) and return_types="*" (all nodes)
assert set(actual) == set(expected)
+def test_visit_nodes_filtered(graph_client):
+ actual = list(
+ graph_client.visit_nodes(
+ "swh:1:rel:0000000000000000000000000000000000000010", return_types="dir",
+ )
+ )
+ expected = [
+ "swh:1:dir:0000000000000000000000000000000000000002",
+ "swh:1:dir:0000000000000000000000000000000000000008",
+ "swh:1:dir:0000000000000000000000000000000000000006",
+ ]
+ assert set(actual) == set(expected)
+
+
+def test_visit_nodes_filtered_star(graph_client):
+ actual = list(
+ graph_client.visit_nodes(
+ "swh:1:rel:0000000000000000000000000000000000000010", return_types="*",
+ )
+ )
+ expected = [
+ "swh:1:rel:0000000000000000000000000000000000000010",
+ "swh:1:rev:0000000000000000000000000000000000000009",
+ "swh:1:rev:0000000000000000000000000000000000000003",
+ "swh:1:dir:0000000000000000000000000000000000000002",
+ "swh:1:cnt:0000000000000000000000000000000000000001",
+ "swh:1:dir:0000000000000000000000000000000000000008",
+ "swh:1:cnt:0000000000000000000000000000000000000007",
+ "swh:1:dir:0000000000000000000000000000000000000006",
+ "swh:1:cnt:0000000000000000000000000000000000000004",
+ "swh:1:cnt:0000000000000000000000000000000000000005",
+ ]
+ assert set(actual) == set(expected)
+
+
def test_visit_edges(graph_client):
actual = list(
graph_client.visit_edges(