Page MenuHomeSoftware Heritage

D5501.id19651.diff
No OneTemporary

D5501.id19651.diff

diff --git a/java/src/main/java/org/softwareheritage/graph/Entry.java b/java/src/main/java/org/softwareheritage/graph/Entry.java
--- a/java/src/main/java/org/softwareheritage/graph/Entry.java
+++ b/java/src/main/java/org/softwareheritage/graph/Entry.java
@@ -126,24 +126,24 @@
close();
}
- public void visit_nodes(String direction, String edgesFmt, long srcNodeId) {
+ public void visit_nodes(String direction, String edgesFmt, long srcNodeId, Object max_edges) {
open();
Traversal t = new Traversal(this.graph, direction, edgesFmt);
- t.visitNodesVisitor(srcNodeId, this::writeNode);
+ t.visitNodesVisitor(srcNodeId, this::writeNode, max_edges);
close();
}
- public void visit_edges(String direction, String edgesFmt, long srcNodeId) {
+ public void visit_edges(String direction, String edgesFmt, long srcNodeId, Object max_edges) {
open();
Traversal t = new Traversal(this.graph, direction, edgesFmt);
- t.visitNodesVisitor(srcNodeId, null, this::writeEdge);
+ t.visitNodesVisitor(srcNodeId, null, this::writeEdge, max_edges, false);
close();
}
- public void visit_paths(String direction, String edgesFmt, long srcNodeId) {
+ public void visit_paths(String direction, String edgesFmt, long srcNodeId, Object max_edges) {
open();
Traversal t = new Traversal(this.graph, direction, edgesFmt);
- t.visitPathsVisitor(srcNodeId, this::writePath);
+ t.visitPathsVisitor(srcNodeId, this::writePath, max_edges);
close();
}
diff --git a/java/src/main/java/org/softwareheritage/graph/Traversal.java b/java/src/main/java/org/softwareheritage/graph/Traversal.java
--- a/java/src/main/java/org/softwareheritage/graph/Traversal.java
+++ b/java/src/main/java/org/softwareheritage/graph/Traversal.java
@@ -146,20 +146,31 @@
/**
* Push version of {@link #visitNodes}: will fire passed callback on each visited node.
*/
- public void visitNodesVisitor(long srcNodeId, NodeIdConsumer nodeCb, EdgeIdConsumer edgeCb) {
+ public void visitNodesVisitor(long srcNodeId, NodeIdConsumer nodeCb, EdgeIdConsumer edgeCb, Object max_edges,
+ boolean limitedVisit) {
Stack<Long> stack = new Stack<>();
this.nbEdgesAccessed = 0;
stack.push(srcNodeId);
visited.add(srcNodeId);
+ Long limit_edges = null;
+ if (!Objects.isNull(max_edges)) {
+ limitedVisit = true;
+ limit_edges = Long.valueOf(max_edges.toString());
+ }
+
while (!stack.isEmpty()) {
long currentNodeId = stack.pop();
if (nodeCb != null) {
nodeCb.accept(currentNodeId);
}
-
nbEdgesAccessed += graph.outdegree(currentNodeId);
+ if (limitedVisit) {
+ if (limit_edges.compareTo(nbEdgesAccessed) <= 0) {
+ break;
+ }
+ }
LazyLongIterator it = graph.successors(currentNodeId, edges);
for (long neighborNodeId; (neighborNodeId = it.nextLong()) != -1;) {
if (edgeCb != null) {
@@ -173,9 +184,14 @@
}
}
- /** One-argument version to handle callbacks properly */
+ /** Two-argument version (node callback only, no edge limit) for count_visitor */
public void visitNodesVisitor(long srcNodeId, NodeIdConsumer cb) {
- visitNodesVisitor(srcNodeId, cb, null);
+ visitNodesVisitor(srcNodeId, cb, null, null, false);
+ }
+
+ /** Three-argument version (single node callback plus optional max_edges limit) to handle callbacks properly */
+ public void visitNodesVisitor(long srcNodeId, NodeIdConsumer cb, Object max_edges) {
+ visitNodesVisitor(srcNodeId, cb, null, max_edges, false);
}
/**
@@ -186,7 +202,7 @@
*/
public ArrayList<Long> visitNodes(long srcNodeId) {
ArrayList<Long> nodeIds = new ArrayList<>();
- visitNodesVisitor(srcNodeId, nodeIds::add);
+ visitNodesVisitor(srcNodeId, nodeIds::add, null);
return nodeIds;
}
@@ -194,10 +210,17 @@
* Push version of {@link #visitPaths}: will fire passed callback on each discovered (complete)
* path.
*/
- public void visitPathsVisitor(long srcNodeId, PathConsumer cb) {
+ public void visitPathsVisitor(long srcNodeId, PathConsumer cb, Object max_edges) {
Stack<Long> currentPath = new Stack<>();
this.nbEdgesAccessed = 0;
- visitPathsInternalVisitor(srcNodeId, currentPath, cb);
+ boolean limitedVisit = false;
+ if (!Objects.isNull(max_edges)) {
+ limitedVisit = true;
+ Long l = Long.valueOf(max_edges.toString());
+ visitPathsInternalVisitor(srcNodeId, currentPath, cb, l, limitedVisit);
+ } else {
+ visitPathsInternalVisitor(srcNodeId, currentPath, cb, null, limitedVisit);
+ }
}
/**
@@ -208,18 +231,26 @@
*/
public ArrayList<ArrayList<Long>> visitPaths(long srcNodeId) {
ArrayList<ArrayList<Long>> paths = new ArrayList<>();
- visitPathsVisitor(srcNodeId, paths::add);
+ visitPathsVisitor(srcNodeId, paths::add, null);
return paths;
}
- private void visitPathsInternalVisitor(long currentNodeId, Stack<Long> currentPath, PathConsumer cb) {
+ private void visitPathsInternalVisitor(long currentNodeId, Stack<Long> currentPath, PathConsumer cb, Long max_edges,
+ boolean limitedVisit) {
currentPath.push(currentNodeId);
long visitedNeighbors = 0;
+
nbEdgesAccessed += graph.outdegree(currentNodeId);
+ if (limitedVisit) {
+ if (max_edges.compareTo(nbEdgesAccessed) <= 0) {
+ currentPath.pop();
+ return;
+ }
+ }
LazyLongIterator it = graph.successors(currentNodeId, edges);
for (long neighborNodeId; (neighborNodeId = it.nextLong()) != -1;) {
- visitPathsInternalVisitor(neighborNodeId, currentPath, cb);
+ visitPathsInternalVisitor(neighborNodeId, currentPath, cb, max_edges, limitedVisit);
visitedNeighbors++;
}
diff --git a/swh/graph/backend.py b/swh/graph/backend.py
--- a/swh/graph/backend.py
+++ b/swh/graph/backend.py
@@ -69,7 +69,7 @@
return method(direction, edges_fmt, src)
async def simple_traversal(self, ttype, direction, edges_fmt, src):
- assert ttype in ("leaves", "neighbors", "visit_nodes")
+ assert ttype in ("leaves", "neighbors")
method = getattr(self.stream_proxy, ttype)
async for node_id in method(direction, edges_fmt, src):
yield node_id
@@ -92,8 +92,14 @@
async for node_id in it: # TODO return 404 if path is empty
yield node_id
- async def visit_edges(self, direction, edges_fmt, src):
- it = self.stream_proxy.visit_edges(direction, edges_fmt, src)
+ async def visit_nodes(self, direction, edges_fmt, src, max_edges):
+ async for node_id in self.stream_proxy.visit_nodes(
+ direction, edges_fmt, src, max_edges
+ ):
+ yield node_id
+
+ async def visit_edges(self, direction, edges_fmt, src, max_edges):
+ it = self.stream_proxy.visit_edges(direction, edges_fmt, src, max_edges)
# convert stream a, b, c, d -> (a, b), (c, d)
prevNode = None
async for node in it:
@@ -103,9 +109,11 @@
else:
prevNode = node
- async def visit_paths(self, direction, edges_fmt, src):
+ async def visit_paths(self, direction, edges_fmt, src, max_edges):
path = []
- async for node in self.stream_proxy.visit_paths(direction, edges_fmt, src):
+ async for node in self.stream_proxy.visit_paths(
+ direction, edges_fmt, src, max_edges
+ ):
if node == PATH_SEPARATOR_ID:
yield path
path = []
diff --git a/swh/graph/client.py b/swh/graph/client.py
--- a/swh/graph/client.py
+++ b/swh/graph/client.py
@@ -45,20 +45,20 @@
"neighbors/{}".format(src), params={"edges": edges, "direction": direction}
)
- def visit_nodes(self, src, edges="*", direction="forward"):
+ def visit_nodes(self, src, edges="*", direction="forward", max_edges=None):
return self.get_lines(
"visit/nodes/{}".format(src),
- params={"edges": edges, "direction": direction},
+ params={"edges": edges, "direction": direction, "max_edges": max_edges},
)
- def visit_edges(self, src, edges="*", direction="forward"):
+ def visit_edges(self, src, edges="*", direction="forward", max_edges=None):
for edge in self.get_lines(
"visit/edges/{}".format(src),
- params={"edges": edges, "direction": direction},
+ params={"edges": edges, "direction": direction, "max_edges": max_edges},
):
yield tuple(edge.split())
- def visit_paths(self, src, edges="*", direction="forward"):
+ def visit_paths(self, src, edges="*", direction="forward", max_edges=None):
def decode_path_wrapper(it):
for e in it:
yield json.loads(e)
@@ -66,7 +66,7 @@
return decode_path_wrapper(
self.get_lines(
"visit/paths/{}".format(src),
- params={"edges": edges, "direction": direction},
+ params={"edges": edges, "direction": direction, "max_edges": max_edges},
)
)
diff --git a/swh/graph/graph.py b/swh/graph/graph.py
--- a/swh/graph/graph.py
+++ b/swh/graph/graph.py
@@ -83,18 +83,21 @@
def leaves(self, *args, **kwargs):
yield from self.simple_traversal("leaves", *args, **kwargs)
- def visit_nodes(self, *args, **kwargs):
- yield from self.simple_traversal("visit_nodes", *args, **kwargs)
+ def visit_nodes(self, direction="forward", edges="*", max_edges=None):
+ for node in call_async_gen(
+ self.graph.backend.visit_nodes, direction, edges, self.id, max_edges
+ ):
+ yield self.graph[node]
- def visit_edges(self, direction="forward", edges="*"):
+ def visit_edges(self, direction="forward", edges="*", max_edges=None):
for src, dst in call_async_gen(
- self.graph.backend.visit_edges, direction, edges, self.id
+ self.graph.backend.visit_edges, direction, edges, self.id, max_edges
):
yield (self.graph[src], self.graph[dst])
- def visit_paths(self, direction="forward", edges="*"):
+ def visit_paths(self, direction="forward", edges="*", max_edges=None):
for path in call_async_gen(
- self.graph.backend.visit_paths, direction, edges, self.id
+ self.graph.backend.visit_paths, direction, edges, self.id, max_edges
):
yield [self.graph[node] for node in path]
diff --git a/swh/graph/server/app.py b/swh/graph/server/app.py
--- a/swh/graph/server/app.py
+++ b/swh/graph/server/app.py
@@ -109,6 +109,14 @@
except ValueError:
raise aiohttp.web.HTTPBadRequest(text=f"invalid limit value: {s}")
+ def get_max_edges(self):
+ """Parse HTTP query parameter 'max_edges', i.e., the maximum number of
+ edges that can be visited; returns None when the parameter is absent"""
+ s: Optional[int] = self.request.query.get("max_edges")
+ if s is not None:
+ return int(s)
+ return s
+
class StreamingGraphView(GraphView):
"""Base class for views streaming their response line by line."""
@@ -166,6 +174,7 @@
self.edges = self.get_edges()
self.direction = self.get_direction()
+ self.max_edges = self.get_max_edges()
async def stream_response(self):
async for res_node in self.backend.simple_traversal(
@@ -183,10 +192,6 @@
simple_traversal_type = "neighbors"
-class VisitNodesView(SimpleTraversalView):
- simple_traversal_type = "visit_nodes"
-
-
class WalkView(StreamingGraphView):
async def prepare_response(self):
src = self.request.match_info["src"]
@@ -234,9 +239,20 @@
)
+class VisitNodesView(SimpleTraversalView):
+ async def stream_response(self):
+ async for res_node in self.backend.visit_nodes(
+ self.direction, self.edges, self.src_node, self.max_edges
+ ):
+ res_swhid = self.swhid_of_node(res_node)
+ await self.stream_line(res_swhid)
+
+
class VisitEdgesView(SimpleTraversalView):
async def stream_response(self):
- it = self.backend.visit_edges(self.direction, self.edges, self.src_node)
+ it = self.backend.visit_edges(
+ self.direction, self.edges, self.src_node, self.max_edges
+ )
async for (res_src, res_dst) in it:
res_src_swhid = self.swhid_of_node(res_src)
res_dst_swhid = self.swhid_of_node(res_dst)
@@ -247,7 +263,9 @@
content_type = "application/x-ndjson"
async def stream_response(self):
- it = self.backend.visit_paths(self.direction, self.edges, self.src_node)
+ it = self.backend.visit_paths(
+ self.direction, self.edges, self.src_node, self.max_edges
+ )
async for res_path in it:
res_path_swhid = [self.swhid_of_node(n) for n in res_path]
line = json.dumps(res_path_swhid)
diff --git a/swh/graph/tests/test_api_client.py b/swh/graph/tests/test_api_client.py
--- a/swh/graph/tests/test_api_client.py
+++ b/swh/graph/tests/test_api_client.py
@@ -103,6 +103,29 @@
assert set(actual) == set(expected)
+def test_visit_edges_limited(graph_client):
+ actual = list(
+ graph_client.visit_edges(
+ "swh:1:rel:0000000000000000000000000000000000000010", max_edges=4
+ )
+ )
+ expected = [
+ (
+ "swh:1:rel:0000000000000000000000000000000000000010",
+ "swh:1:rev:0000000000000000000000000000000000000009",
+ ),
+ (
+ "swh:1:rev:0000000000000000000000000000000000000009",
+ "swh:1:dir:0000000000000000000000000000000000000008",
+ ),
+ (
+ "swh:1:rev:0000000000000000000000000000000000000009",
+ "swh:1:rev:0000000000000000000000000000000000000003",
+ ),
+ ]
+ assert set(actual) == set(expected)
+
+
def test_visit_edges_diamond_pattern(graph_client):
actual = list(
graph_client.visit_edges(

File Metadata

Mime Type
text/plain
Expires
Nov 5 2024, 2:22 PM (12 w, 4 d ago)
Storage Engine
blob
Storage Format
Raw Data
Storage Handle
3233629

Event Timeline