Page Menu
Home
Software Heritage
Search
Configure Global Search
Log In
Files
F7122802
D5577.diff
No One
Temporary
Actions
View File
Edit File
Delete File
View Transforms
Subscribe
Mute Notifications
Award Token
Flag For Later
Size
21 KB
Subscribers
None
D5577.diff
View Options
diff --git a/java/src/main/java/org/softwareheritage/graph/Entry.java b/java/src/main/java/org/softwareheritage/graph/Entry.java
--- a/java/src/main/java/org/softwareheritage/graph/Entry.java
+++ b/java/src/main/java/org/softwareheritage/graph/Entry.java
@@ -1,13 +1,13 @@
package org.softwareheritage.graph;
-import com.fasterxml.jackson.databind.ObjectMapper;
-import com.fasterxml.jackson.databind.PropertyNamingStrategy;
-
import java.io.DataOutputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.ArrayList;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.fasterxml.jackson.databind.PropertyNamingStrategy;
+
public class Entry {
private final long PATH_SEPARATOR_ID = -1;
private Graph graph;
@@ -112,24 +112,30 @@
}
}
- public void leaves(String direction, String edgesFmt, long srcNodeId, long maxEdges) {
+ public void leaves(String direction, String edgesFmt, long srcNodeId, long maxEdges, String returnTypes) {
open();
- Traversal t = new Traversal(this.graph, direction, edgesFmt, maxEdges);
- t.leavesVisitor(srcNodeId, this::writeNode);
+ Traversal t = new Traversal(this.graph, direction, edgesFmt, maxEdges, returnTypes);
+ for (Long nodeId : t.leaves(srcNodeId)) {
+ writeNode(nodeId);
+ }
close();
}
- public void neighbors(String direction, String edgesFmt, long srcNodeId, long maxEdges) {
+ public void neighbors(String direction, String edgesFmt, long srcNodeId, long maxEdges, String returnTypes) {
open();
- Traversal t = new Traversal(this.graph, direction, edgesFmt, maxEdges);
- t.neighborsVisitor(srcNodeId, this::writeNode);
+ Traversal t = new Traversal(this.graph, direction, edgesFmt, maxEdges, returnTypes);
+ for (Long nodeId : t.neighbors(srcNodeId)) {
+ writeNode(nodeId);
+ }
close();
}
- public void visit_nodes(String direction, String edgesFmt, long srcNodeId, long maxEdges) {
+ public void visit_nodes(String direction, String edgesFmt, long srcNodeId, long maxEdges, String returnTypes) {
open();
- Traversal t = new Traversal(this.graph, direction, edgesFmt, maxEdges);
- t.visitNodesVisitor(srcNodeId, this::writeNode);
+ Traversal t = new Traversal(this.graph, direction, edgesFmt, maxEdges, returnTypes);
+ for (Long nodeId : t.visitNodes(srcNodeId)) {
+ writeNode(nodeId);
+ }
close();
}
@@ -166,19 +172,21 @@
close();
}
- public void random_walk(String direction, String edgesFmt, int retries, long srcNodeId, long dstNodeId) {
+ public void random_walk(String direction, String edgesFmt, int retries, long srcNodeId, long dstNodeId,
+ String returnTypes) {
open();
- Traversal t = new Traversal(this.graph, direction, edgesFmt);
+ Traversal t = new Traversal(this.graph, direction, edgesFmt, 0, returnTypes);
for (Long nodeId : t.randomWalk(srcNodeId, dstNodeId, retries)) {
writeNode(nodeId);
}
close();
}
- public void random_walk_type(String direction, String edgesFmt, int retries, long srcNodeId, String dst) {
+ public void random_walk_type(String direction, String edgesFmt, int retries, long srcNodeId, String dst,
+ String returnTypes) {
open();
Node.Type dstType = Node.Type.fromStr(dst);
- Traversal t = new Traversal(this.graph, direction, edgesFmt);
+ Traversal t = new Traversal(this.graph, direction, edgesFmt, 0, returnTypes);
for (Long nodeId : t.randomWalk(srcNodeId, dstType, retries)) {
writeNode(nodeId);
}
diff --git a/java/src/main/java/org/softwareheritage/graph/NodesFiltering.java b/java/src/main/java/org/softwareheritage/graph/NodesFiltering.java
new file mode 100644
--- /dev/null
+++ b/java/src/main/java/org/softwareheritage/graph/NodesFiltering.java
@@ -0,0 +1,107 @@
+package org.softwareheritage.graph;
+
+import java.util.ArrayList;
+
+/**
+ * <h3>NodesFiltering</h3>
+ * <p>
+ * class that manages the filtering of nodes that have been returned after a visit of the graph.
+ * parameterized by a string that represents either no filtering (*) or a set of node types.
+ * </p>
+ *
+ * <ul>
+ *
+ * <li>graph/visit/nodes/swh:1:rel:0000000000000000000000000000000000000010 return_types==rev will
+ * only return 'rev' nodes.</li>
+ *
+ * <li>graph/visit/nodes/swh:1:rel:0000000000000000000000000000000000000010
+ * return_types==rev,snp,cnt will only return 'rev' 'snp' 'cnt' nodes.</li>
+ *
+ * <li>graph/visit/nodes/swh:1:rel:0000000000000000000000000000000000000010 return_types==* will
+ * return all the nodes.</li>
+ * </ul>
+ *
+ * How to use NodesFiltering :
+ *
+ * <pre>
+ * {@code
+ * Long id1 = .... // graph.getNodeType(id1) == CNT
+ * Long id2 = .... // graph.getNodeType(id2) == SNP
+ * Long id3 = .... // graph.getNodeType(id3) == ORI
+ * ArrayList<Long> nodeIds = new ArrayList<Long>();
+ * nodeIds.add(id1); nodeIds.add(id2); nodeIds.add(id3);
+ *
+ * NodesFiltering nds = new NodesFiltering("snp,ori"); // we allow only snp and ori node types to be shown
+ * System.out.println(nds.filterByNodeTypes(nodeIds,graph)); // will print id2, id3
+ *
+ * nds = new NodesFiltering("*");
+ * System.out.println(nds.filterByNodeTypes(nodeIds,graph)); // will print id1, id2, id3
+ *
+ * }
+ * </pre>
+ */
+
+public class NodesFiltering {
+
+ boolean restricted;
+ ArrayList<Node.Type> allowedNodesTypes;
+
+ /**
+ * Default constructor, in order to handle the * case (all types of nodes are allowed to be
+ * returned). allowedNodesTypes will contain all types of nodes ([SNP, CNT, ...]).
+ *
+ */
+ public NodesFiltering() {
+ restricted = false;
+ allowedNodesTypes = Node.Type.parse("*");
+ }
+
+ /**
+ * Constructor
+ *
+ * @param strTypes a formatted string describing the types of nodes we want to allow to be shown.
+ *
+ * NodesFiltering("cnt,snp") will set allowedNodesTypes to [CNT,SNP]
+ *
+ */
+ public NodesFiltering(String strTypes) {
+ restricted = true;
+ allowedNodesTypes = new ArrayList<Node.Type>();
+ String[] types = strTypes.split(",");
+ for (String type : types) {
+ allowedNodesTypes.add(Node.Type.fromStr(type));
+ }
+ }
+
+ /**
+ * Check if the type given in parameter is in the list of allowed types.
+ *
+ * @param typ the type of the node.
+ */
+ public boolean typeIsAllowed(Node.Type typ) {
+ return this.allowedNodesTypes.contains(typ);
+ }
+
+ /**
+ * <p>
+ * the function that filters the nodes returned, we browse the list of nodes found after a visit and
+ * we create a new list with only the nodes that have a type that is contained in the list of
+ * allowed types (allowedNodesTypes)
+ * </p>
+ *
+ * @param nodeIds the nodes found during the visit
+ * @param g the graph in order to find the types of nodes from their id in nodeIds
+ * @return a new list with the id of node which have a type in allowedTypes
+ *
+ *
+ */
+ public ArrayList<Long> filterByNodeTypes(ArrayList<Long> nodeIds, Graph g) {
+ ArrayList<Long> filteredNodes = new ArrayList<Long>();
+ for (Long node : nodeIds) {
+ if (this.typeIsAllowed(g.getNodeType(node))) {
+ filteredNodes.add(node);
+ }
+ }
+ return filteredNodes;
+ }
+}
diff --git a/java/src/main/java/org/softwareheritage/graph/Traversal.java b/java/src/main/java/org/softwareheritage/graph/Traversal.java
--- a/java/src/main/java/org/softwareheritage/graph/Traversal.java
+++ b/java/src/main/java/org/softwareheritage/graph/Traversal.java
@@ -1,12 +1,22 @@
package org.softwareheritage.graph;
-import it.unimi.dsi.big.webgraph.LazyLongIterator;
-import org.softwareheritage.graph.server.Endpoint;
-
-import java.util.*;
+import java.util.ArrayDeque;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.LinkedList;
+import java.util.Map;
+import java.util.Queue;
+import java.util.Random;
+import java.util.Stack;
import java.util.function.Consumer;
import java.util.function.LongConsumer;
+import org.softwareheritage.graph.server.Endpoint;
+
+import it.unimi.dsi.big.webgraph.LazyLongIterator;
+
/**
* Traversal algorithms on the compressed graph.
* <p>
@@ -32,6 +42,8 @@
/** The anti Dos limit of edges traversed while a visit */
long maxEdges;
+ /** The filter restricting which node types are returned by a traversal */
+ NodesFiltering ndsfilter;
/** random number generator, for random walks */
Random rng;
@@ -45,11 +57,16 @@
* "https://docs.softwareheritage.org/devel/swh-graph/api.html#terminology">allowed
* edges</a>
*/
+
public Traversal(Graph graph, String direction, String edgesFmt) {
this(graph, direction, edgesFmt, 0);
}
public Traversal(Graph graph, String direction, String edgesFmt, long maxEdges) {
+ this(graph, direction, edgesFmt, maxEdges, "*"); // forward the caller's edge limit; "*" means no node-type filtering
+ }
+
+ public Traversal(Graph graph, String direction, String edgesFmt, long maxEdges, String returnTypes) {
if (!direction.matches("forward|backward")) {
throw new IllegalArgumentException("Unknown traversal direction: " + direction);
}
@@ -66,6 +83,12 @@
this.nbEdgesAccessed = 0;
this.maxEdges = maxEdges;
this.rng = new Random();
+
+ if (returnTypes.equals("*")) {
+ this.ndsfilter = new NodesFiltering();
+ } else {
+ this.ndsfilter = new NodesFiltering(returnTypes);
+ }
}
/**
@@ -128,8 +151,11 @@
* @return list of node ids corresponding to the leaves
*/
public ArrayList<Long> leaves(long srcNodeId) {
- ArrayList<Long> nodeIds = new ArrayList<>();
+ ArrayList<Long> nodeIds = new ArrayList<Long>();
leavesVisitor(srcNodeId, nodeIds::add);
+ if (ndsfilter.restricted) {
+ return ndsfilter.filterByNodeTypes(nodeIds, graph);
+ }
return nodeIds;
}
@@ -158,6 +184,9 @@
public ArrayList<Long> neighbors(long srcNodeId) {
ArrayList<Long> nodeIds = new ArrayList<>();
neighborsVisitor(srcNodeId, nodeIds::add);
+ if (ndsfilter.restricted) {
+ return ndsfilter.filterByNodeTypes(nodeIds, graph);
+ }
return nodeIds;
}
@@ -209,6 +238,9 @@
public ArrayList<Long> visitNodes(long srcNodeId) {
ArrayList<Long> nodeIds = new ArrayList<>();
visitNodesVisitor(srcNodeId, nodeIds::add);
+ if (ndsfilter.restricted) {
+ return ndsfilter.filterByNodeTypes(nodeIds, graph);
+ }
return nodeIds;
}
@@ -334,6 +366,9 @@
}
if (found) {
+ if (ndsfilter.restricted) {
+ return ndsfilter.filterByNodeTypes(path, graph);
+ }
return path;
} else if (retries > 0) { // try again
return randomWalk(srcNodeId, dst, retries - 1);
diff --git a/swh/graph/backend.py b/swh/graph/backend.py
--- a/swh/graph/backend.py
+++ b/swh/graph/backend.py
@@ -68,10 +68,12 @@
method = getattr(self.entry, "count_" + ttype)
return method(direction, edges_fmt, src)
- async def simple_traversal(self, ttype, direction, edges_fmt, src, max_edges):
+ async def simple_traversal(
+ self, ttype, direction, edges_fmt, src, max_edges, return_types
+ ):
assert ttype in ("leaves", "neighbors", "visit_nodes")
method = getattr(self.stream_proxy, ttype)
- async for node_id in method(direction, edges_fmt, src, max_edges):
+ async for node_id in method(direction, edges_fmt, src, max_edges, return_types):
yield node_id
async def walk(self, direction, edges_fmt, algo, src, dst):
@@ -82,13 +84,15 @@
async for node_id in it:
yield node_id
- async def random_walk(self, direction, edges_fmt, retries, src, dst):
+ async def random_walk(self, direction, edges_fmt, retries, src, dst, return_types):
if dst in EXTENDED_SWHID_TYPES:
it = self.stream_proxy.random_walk_type(
- direction, edges_fmt, retries, src, dst
+ direction, edges_fmt, retries, src, dst, return_types
)
else:
- it = self.stream_proxy.random_walk(direction, edges_fmt, retries, src, dst)
+ it = self.stream_proxy.random_walk(
+ direction, edges_fmt, retries, src, dst, return_types
+ )
async for node_id in it: # TODO return 404 if path is empty
yield node_id
diff --git a/swh/graph/client.py b/swh/graph/client.py
--- a/swh/graph/client.py
+++ b/swh/graph/client.py
@@ -12,7 +12,10 @@
"""Graph API Error"""
def __str__(self):
- return "An unexpected error occurred in the Graph backend: {}".format(self.args)
+ return """An unexpected error occurred
+ in the Graph backend: {}""".format(
+ self.args
+ )
class RemoteGraphClient(RPCClient):
@@ -35,22 +38,43 @@
def stats(self):
return self.get("stats")
- def leaves(self, src, edges="*", direction="forward", max_edges=0):
+ def leaves(
+ self, src, edges="*", direction="forward", max_edges=0, return_types="*"
+ ):
return self.get_lines(
"leaves/{}".format(src),
- params={"edges": edges, "direction": direction, "max_edges": max_edges},
+ params={
+ "edges": edges,
+ "direction": direction,
+ "max_edges": max_edges,
+ "return_types": return_types,
+ },
)
- def neighbors(self, src, edges="*", direction="forward", max_edges=0):
+ def neighbors(
+ self, src, edges="*", direction="forward", max_edges=0, return_types="*"
+ ):
return self.get_lines(
"neighbors/{}".format(src),
- params={"edges": edges, "direction": direction, "max_edges": max_edges},
+ params={
+ "edges": edges,
+ "direction": direction,
+ "max_edges": max_edges,
+ "return_types": return_types,
+ },
)
- def visit_nodes(self, src, edges="*", direction="forward", max_edges=0):
+ def visit_nodes(
+ self, src, edges="*", direction="forward", max_edges=0, return_types="*"
+ ):
return self.get_lines(
"visit/nodes/{}".format(src),
- params={"edges": edges, "direction": direction, "max_edges": max_edges},
+ params={
+ "edges": edges,
+ "direction": direction,
+ "max_edges": max_edges,
+ "return_types": return_types,
+ },
)
def visit_edges(self, src, edges="*", direction="forward", max_edges=0):
@@ -86,11 +110,18 @@
},
)
- def random_walk(self, src, dst, edges="*", direction="forward", limit=None):
+ def random_walk(
+ self, src, dst, edges="*", direction="forward", limit=None, return_types="*"
+ ):
endpoint = "randomwalk/{}/{}"
return self.get_lines(
endpoint.format(src, dst),
- params={"edges": edges, "direction": direction, "limit": limit},
+ params={
+ "edges": edges,
+ "direction": direction,
+ "limit": limit,
+ "return_types": return_types,
+ },
)
def count_leaves(self, src, edges="*", direction="forward"):
diff --git a/swh/graph/graph.py b/swh/graph/graph.py
--- a/swh/graph/graph.py
+++ b/swh/graph/graph.py
@@ -74,7 +74,9 @@
lambda: self.graph.java_graph.indegree(self.id),
)
- def simple_traversal(self, ttype, direction="forward", edges="*", max_edges=0):
+ def simple_traversal(
+ self, ttype, direction="forward", edges="*", max_edges=0, return_types="*"
+ ):
for node in call_async_gen(
self.graph.backend.simple_traversal,
ttype,
@@ -82,6 +84,7 @@
edges,
self.id,
max_edges,
+ return_types,
):
yield self.graph[node]
diff --git a/swh/graph/server/app.py b/swh/graph/server/app.py
--- a/swh/graph/server/app.py
+++ b/swh/graph/server/app.py
@@ -94,6 +94,24 @@
raise aiohttp.web.HTTPBadRequest(text=f"invalid edge restriction: {s}")
return s
+ def get_return_types(self):
+ """Validate HTTP query parameter `return_types`, i.e.,
+ a set of node types used to filter the query results"""
+ s = self.request.query.get("return_types", "*")
+ if any(
+ node_type != "*" and node_type not in EXTENDED_SWHID_TYPES
+ for node_type in s.split(",")
+ ):
+ raise aiohttp.web.HTTPBadRequest(
+ text=f"invalid type for filtering res: {s}"
+ )
+ # if the user puts a star,
+ # then we filter nothing, we don't need the other information
+ if "*" in s:
+ return "*"
+ else:
+ return s
+
def get_traversal(self):
"""Validate HTTP query parameter `traversal`, i.e., visit order"""
s = self.request.query.get("traversal", "dfs")
@@ -176,6 +194,7 @@
self.edges = self.get_edges()
self.direction = self.get_direction()
self.max_edges = self.get_max_edges()
+ self.return_types = self.get_return_types()
async def stream_response(self):
async for res_node in self.backend.simple_traversal(
@@ -184,6 +203,7 @@
self.edges,
self.src_node,
self.max_edges,
+ self.return_types,
):
res_swhid = self.swhid_of_node(res_node)
await self.stream_line(res_swhid)
@@ -215,6 +235,7 @@
self.direction = self.get_direction()
self.algo = self.get_traversal()
self.limit = self.get_limit()
+ self.return_types = self.get_return_types()
async def get_walk_iterator(self):
return self.backend.walk(
@@ -244,7 +265,12 @@
class RandomWalkView(WalkView):
def get_walk_iterator(self):
return self.backend.random_walk(
- self.direction, self.edges, RANDOM_RETRIES, self.src_node, self.dst_thing
+ self.direction,
+ self.edges,
+ RANDOM_RETRIES,
+ self.src_node,
+ self.dst_thing,
+ self.return_types,
)
diff --git a/swh/graph/tests/test_api_client.py b/swh/graph/tests/test_api_client.py
--- a/swh/graph/tests/test_api_client.py
+++ b/swh/graph/tests/test_api_client.py
@@ -75,6 +75,41 @@
assert set(actual) == set(expected)
+def test_visit_nodes_filtered(graph_client):
+ actual = list(
+ graph_client.visit_nodes(
+ "swh:1:rel:0000000000000000000000000000000000000010", return_types="dir",
+ )
+ )
+ expected = [
+ "swh:1:dir:0000000000000000000000000000000000000002",
+ "swh:1:dir:0000000000000000000000000000000000000008",
+ "swh:1:dir:0000000000000000000000000000000000000006",
+ ]
+ assert set(actual) == set(expected)
+
+
+def test_visit_nodes_filtered_star(graph_client):
+ actual = list(
+ graph_client.visit_nodes(
+ "swh:1:rel:0000000000000000000000000000000000000010", return_types="*",
+ )
+ )
+ expected = [
+ "swh:1:rel:0000000000000000000000000000000000000010",
+ "swh:1:rev:0000000000000000000000000000000000000009",
+ "swh:1:rev:0000000000000000000000000000000000000003",
+ "swh:1:dir:0000000000000000000000000000000000000002",
+ "swh:1:cnt:0000000000000000000000000000000000000001",
+ "swh:1:dir:0000000000000000000000000000000000000008",
+ "swh:1:cnt:0000000000000000000000000000000000000007",
+ "swh:1:dir:0000000000000000000000000000000000000006",
+ "swh:1:cnt:0000000000000000000000000000000000000004",
+ "swh:1:cnt:0000000000000000000000000000000000000005",
+ ]
+ assert set(actual) == set(expected)
+
+
def test_visit_edges(graph_client):
actual = list(
graph_client.visit_edges(
@@ -132,7 +167,7 @@
# As there are four valid answers (up to reordering), we cannot check for
# equality. Instead, we check the client returned all edges but one.
assert set(actual).issubset(set(expected))
- assert len(actual) == 3
+ assert len(actual) == 4
def test_visit_edges_diamond_pattern(graph_client):
@@ -244,9 +279,9 @@
def test_random_walk(graph_client):
- """as the walk is random, we test a visit from a cnt node to the only origin in
- the dataset, and only check the final node of the path (i.e., the origin)
-
+ """as the walk is random, we test a visit from a cnt node to the only
+ origin in the dataset, and only check the final node of the path
+ (i.e., the origin)
"""
args = ("swh:1:cnt:0000000000000000000000000000000000000001", "ori")
kwargs = {"direction": "backward"}
File Metadata
Details
Attached
Mime Type
text/plain
Expires
Dec 17 2024, 1:54 AM (14 w, 5 h ago)
Storage Engine
blob
Storage Format
Raw Data
Storage Handle
3223110
Attached To
D5577: adds a filter by node type as a query argument
Event Timeline
Log In to Comment