Page Menu
Home
Software Heritage
Search
Configure Global Search
Log In
Files
F7066525
D5501.id19651.diff
No One
Temporary
Actions
View File
Edit File
Delete File
View Transforms
Subscribe
Mute Notifications
Award Token
Flag For Later
Size
14 KB
Subscribers
None
D5501.id19651.diff
View Options
diff --git a/java/src/main/java/org/softwareheritage/graph/Entry.java b/java/src/main/java/org/softwareheritage/graph/Entry.java
--- a/java/src/main/java/org/softwareheritage/graph/Entry.java
+++ b/java/src/main/java/org/softwareheritage/graph/Entry.java
@@ -126,24 +126,24 @@
close();
}
- public void visit_nodes(String direction, String edgesFmt, long srcNodeId) {
+ public void visit_nodes(String direction, String edgesFmt, long srcNodeId, Object max_edges) {
open();
Traversal t = new Traversal(this.graph, direction, edgesFmt);
- t.visitNodesVisitor(srcNodeId, this::writeNode);
+ t.visitNodesVisitor(srcNodeId, this::writeNode, max_edges);
close();
}
- public void visit_edges(String direction, String edgesFmt, long srcNodeId) {
+ public void visit_edges(String direction, String edgesFmt, long srcNodeId, Object max_edges) {
open();
Traversal t = new Traversal(this.graph, direction, edgesFmt);
- t.visitNodesVisitor(srcNodeId, null, this::writeEdge);
+ t.visitNodesVisitor(srcNodeId, null, this::writeEdge, max_edges, false);
close();
}
- public void visit_paths(String direction, String edgesFmt, long srcNodeId) {
+ public void visit_paths(String direction, String edgesFmt, long srcNodeId, Object max_edges) {
open();
Traversal t = new Traversal(this.graph, direction, edgesFmt);
- t.visitPathsVisitor(srcNodeId, this::writePath);
+ t.visitPathsVisitor(srcNodeId, this::writePath, max_edges);
close();
}
diff --git a/java/src/main/java/org/softwareheritage/graph/Traversal.java b/java/src/main/java/org/softwareheritage/graph/Traversal.java
--- a/java/src/main/java/org/softwareheritage/graph/Traversal.java
+++ b/java/src/main/java/org/softwareheritage/graph/Traversal.java
@@ -146,20 +146,31 @@
/**
* Push version of {@link #visitNodes}: will fire passed callback on each visited node.
*/
- public void visitNodesVisitor(long srcNodeId, NodeIdConsumer nodeCb, EdgeIdConsumer edgeCb) {
+ public void visitNodesVisitor(long srcNodeId, NodeIdConsumer nodeCb, EdgeIdConsumer edgeCb, Object max_edges,
+ boolean limitedVisit) {
Stack<Long> stack = new Stack<>();
this.nbEdgesAccessed = 0;
stack.push(srcNodeId);
visited.add(srcNodeId);
+ Long limit_edges = null;
+ if (!Objects.isNull(max_edges)) {
+ limitedVisit = true;
+ limit_edges = Long.valueOf(max_edges.toString());
+ }
+
while (!stack.isEmpty()) {
long currentNodeId = stack.pop();
if (nodeCb != null) {
nodeCb.accept(currentNodeId);
}
-
nbEdgesAccessed += graph.outdegree(currentNodeId);
+ if (limitedVisit) {
+ if (limit_edges.compareTo(nbEdgesAccessed) <= 0) {
+ break;
+ }
+ }
LazyLongIterator it = graph.successors(currentNodeId, edges);
for (long neighborNodeId; (neighborNodeId = it.nextLong()) != -1;) {
if (edgeCb != null) {
@@ -173,9 +184,14 @@
}
}
- /** One-argument version to handle callbacks properly */
+ /** Two argument version for count_visitor */
public void visitNodesVisitor(long srcNodeId, NodeIdConsumer cb) {
- visitNodesVisitor(srcNodeId, cb, null);
+ visitNodesVisitor(srcNodeId, cb, null, null, false);
+ }
+
+ /** One-argument version to handle callbacks properly */
+ public void visitNodesVisitor(long srcNodeId, NodeIdConsumer cb, Object max_edges) {
+ visitNodesVisitor(srcNodeId, cb, null, max_edges, false);
}
/**
@@ -186,7 +202,7 @@
*/
public ArrayList<Long> visitNodes(long srcNodeId) {
ArrayList<Long> nodeIds = new ArrayList<>();
- visitNodesVisitor(srcNodeId, nodeIds::add);
+ visitNodesVisitor(srcNodeId, nodeIds::add, null);
return nodeIds;
}
@@ -194,10 +210,17 @@
* Push version of {@link #visitPaths}: will fire passed callback on each discovered (complete)
* path.
*/
- public void visitPathsVisitor(long srcNodeId, PathConsumer cb) {
+ public void visitPathsVisitor(long srcNodeId, PathConsumer cb, Object max_edges) {
Stack<Long> currentPath = new Stack<>();
this.nbEdgesAccessed = 0;
- visitPathsInternalVisitor(srcNodeId, currentPath, cb);
+ boolean limitedVisit = false;
+ if (!Objects.isNull(max_edges)) {
+ limitedVisit = true;
+ Long l = Long.valueOf(max_edges.toString());
+ visitPathsInternalVisitor(srcNodeId, currentPath, cb, l, limitedVisit);
+ } else {
+ visitPathsInternalVisitor(srcNodeId, currentPath, cb, null, limitedVisit);
+ }
}
/**
@@ -208,18 +231,26 @@
*/
public ArrayList<ArrayList<Long>> visitPaths(long srcNodeId) {
ArrayList<ArrayList<Long>> paths = new ArrayList<>();
- visitPathsVisitor(srcNodeId, paths::add);
+ visitPathsVisitor(srcNodeId, paths::add, null);
return paths;
}
- private void visitPathsInternalVisitor(long currentNodeId, Stack<Long> currentPath, PathConsumer cb) {
+ private void visitPathsInternalVisitor(long currentNodeId, Stack<Long> currentPath, PathConsumer cb, Long max_edges,
+ boolean limitedVisit) {
currentPath.push(currentNodeId);
long visitedNeighbors = 0;
+
nbEdgesAccessed += graph.outdegree(currentNodeId);
+ if (limitedVisit) {
+ if (max_edges.compareTo(nbEdgesAccessed) <= 0) {
+ currentPath.pop();
+ return;
+ }
+ }
LazyLongIterator it = graph.successors(currentNodeId, edges);
for (long neighborNodeId; (neighborNodeId = it.nextLong()) != -1;) {
- visitPathsInternalVisitor(neighborNodeId, currentPath, cb);
+ visitPathsInternalVisitor(neighborNodeId, currentPath, cb, max_edges, limitedVisit);
visitedNeighbors++;
}
diff --git a/swh/graph/backend.py b/swh/graph/backend.py
--- a/swh/graph/backend.py
+++ b/swh/graph/backend.py
@@ -69,7 +69,7 @@
return method(direction, edges_fmt, src)
async def simple_traversal(self, ttype, direction, edges_fmt, src):
- assert ttype in ("leaves", "neighbors", "visit_nodes")
+ assert ttype in ("leaves", "neighbors")
method = getattr(self.stream_proxy, ttype)
async for node_id in method(direction, edges_fmt, src):
yield node_id
@@ -92,8 +92,14 @@
async for node_id in it: # TODO return 404 if path is empty
yield node_id
- async def visit_edges(self, direction, edges_fmt, src):
- it = self.stream_proxy.visit_edges(direction, edges_fmt, src)
+ async def visit_nodes(self, direction, edges_fmt, src, max_edges):
+ async for node_id in self.stream_proxy.visit_nodes(
+ direction, edges_fmt, src, max_edges
+ ):
+ yield node_id
+
+ async def visit_edges(self, direction, edges_fmt, src, max_edges):
+ it = self.stream_proxy.visit_edges(direction, edges_fmt, src, max_edges)
# convert stream a, b, c, d -> (a, b), (c, d)
prevNode = None
async for node in it:
@@ -103,9 +109,11 @@
else:
prevNode = node
- async def visit_paths(self, direction, edges_fmt, src):
+ async def visit_paths(self, direction, edges_fmt, src, max_edges):
path = []
- async for node in self.stream_proxy.visit_paths(direction, edges_fmt, src):
+ async for node in self.stream_proxy.visit_paths(
+ direction, edges_fmt, src, max_edges
+ ):
if node == PATH_SEPARATOR_ID:
yield path
path = []
diff --git a/swh/graph/client.py b/swh/graph/client.py
--- a/swh/graph/client.py
+++ b/swh/graph/client.py
@@ -45,20 +45,20 @@
"neighbors/{}".format(src), params={"edges": edges, "direction": direction}
)
- def visit_nodes(self, src, edges="*", direction="forward"):
+ def visit_nodes(self, src, edges="*", direction="forward", max_edges=None):
return self.get_lines(
"visit/nodes/{}".format(src),
- params={"edges": edges, "direction": direction},
+ params={"edges": edges, "direction": direction, "max_edges": max_edges},
)
- def visit_edges(self, src, edges="*", direction="forward"):
+ def visit_edges(self, src, edges="*", direction="forward", max_edges=None):
for edge in self.get_lines(
"visit/edges/{}".format(src),
- params={"edges": edges, "direction": direction},
+ params={"edges": edges, "direction": direction, "max_edges": max_edges},
):
yield tuple(edge.split())
- def visit_paths(self, src, edges="*", direction="forward"):
+ def visit_paths(self, src, edges="*", direction="forward", max_edges=None):
def decode_path_wrapper(it):
for e in it:
yield json.loads(e)
@@ -66,7 +66,7 @@
return decode_path_wrapper(
self.get_lines(
"visit/paths/{}".format(src),
- params={"edges": edges, "direction": direction},
+ params={"edges": edges, "direction": direction, "max_edges": max_edges},
)
)
diff --git a/swh/graph/graph.py b/swh/graph/graph.py
--- a/swh/graph/graph.py
+++ b/swh/graph/graph.py
@@ -83,18 +83,21 @@
def leaves(self, *args, **kwargs):
yield from self.simple_traversal("leaves", *args, **kwargs)
- def visit_nodes(self, *args, **kwargs):
- yield from self.simple_traversal("visit_nodes", *args, **kwargs)
+ def visit_nodes(self, direction="forward", edges="*", max_edges=None):
+ for node in call_async_gen(
+ self.graph.backend.visit_nodes, direction, edges, self.id, max_edges
+ ):
+ yield self.graph[node]
- def visit_edges(self, direction="forward", edges="*"):
+ def visit_edges(self, direction="forward", edges="*", max_edges=None):
for src, dst in call_async_gen(
- self.graph.backend.visit_edges, direction, edges, self.id
+ self.graph.backend.visit_edges, direction, edges, self.id, max_edges
):
yield (self.graph[src], self.graph[dst])
- def visit_paths(self, direction="forward", edges="*"):
+ def visit_paths(self, direction="forward", edges="*", max_edges=None):
for path in call_async_gen(
- self.graph.backend.visit_paths, direction, edges, self.id
+ self.graph.backend.visit_paths, direction, edges, self.id, max_edges
):
yield [self.graph[node] for node in path]
diff --git a/swh/graph/server/app.py b/swh/graph/server/app.py
--- a/swh/graph/server/app.py
+++ b/swh/graph/server/app.py
@@ -109,6 +109,14 @@
except ValueError:
raise aiohttp.web.HTTPBadRequest(text=f"invalid limit value: {s}")
+ def get_max_edges(self):
+ """Validate HTTP query parameter 'max_edges', i.e.,
+ the limit of the number of edges that can be visited"""
+ s: Optional[int] = self.request.query.get("max_edges")
+ if s is not None:
+ return int(s)
+ return s
+
class StreamingGraphView(GraphView):
"""Base class for views streaming their response line by line."""
@@ -166,6 +174,7 @@
self.edges = self.get_edges()
self.direction = self.get_direction()
+ self.max_edges = self.get_max_edges()
async def stream_response(self):
async for res_node in self.backend.simple_traversal(
@@ -183,10 +192,6 @@
simple_traversal_type = "neighbors"
-class VisitNodesView(SimpleTraversalView):
- simple_traversal_type = "visit_nodes"
-
-
class WalkView(StreamingGraphView):
async def prepare_response(self):
src = self.request.match_info["src"]
@@ -234,9 +239,20 @@
)
+class VisitNodesView(SimpleTraversalView):
+ async def stream_response(self):
+ async for res_node in self.backend.visit_nodes(
+ self.direction, self.edges, self.src_node, self.max_edges
+ ):
+ res_swhid = self.swhid_of_node(res_node)
+ await self.stream_line(res_swhid)
+
+
class VisitEdgesView(SimpleTraversalView):
async def stream_response(self):
- it = self.backend.visit_edges(self.direction, self.edges, self.src_node)
+ it = self.backend.visit_edges(
+ self.direction, self.edges, self.src_node, self.max_edges
+ )
async for (res_src, res_dst) in it:
res_src_swhid = self.swhid_of_node(res_src)
res_dst_swhid = self.swhid_of_node(res_dst)
@@ -247,7 +263,9 @@
content_type = "application/x-ndjson"
async def stream_response(self):
- it = self.backend.visit_paths(self.direction, self.edges, self.src_node)
+ it = self.backend.visit_paths(
+ self.direction, self.edges, self.src_node, self.max_edges
+ )
async for res_path in it:
res_path_swhid = [self.swhid_of_node(n) for n in res_path]
line = json.dumps(res_path_swhid)
diff --git a/swh/graph/tests/test_api_client.py b/swh/graph/tests/test_api_client.py
--- a/swh/graph/tests/test_api_client.py
+++ b/swh/graph/tests/test_api_client.py
@@ -103,6 +103,29 @@
assert set(actual) == set(expected)
+def test_visit_edges_limited(graph_client):
+ actual = list(
+ graph_client.visit_edges(
+ "swh:1:rel:0000000000000000000000000000000000000010", max_edges=4
+ )
+ )
+ expected = [
+ (
+ "swh:1:rel:0000000000000000000000000000000000000010",
+ "swh:1:rev:0000000000000000000000000000000000000009",
+ ),
+ (
+ "swh:1:rev:0000000000000000000000000000000000000009",
+ "swh:1:dir:0000000000000000000000000000000000000008",
+ ),
+ (
+ "swh:1:rev:0000000000000000000000000000000000000009",
+ "swh:1:rev:0000000000000000000000000000000000000003",
+ ),
+ ]
+ assert set(actual) == set(expected)
+
+
def test_visit_edges_diamond_pattern(graph_client):
actual = list(
graph_client.visit_edges(
File Metadata
Details
Attached
Mime Type
text/plain
Expires
Nov 5 2024, 2:32 PM (12 w, 4 d ago)
Storage Engine
blob
Storage Format
Raw Data
Storage Handle
3233629
Attached To
D5501: add an anti-DoS limit on the number of edges traversed, set as a query parameter
Event Timeline
Log In to Comment