diff --git a/java/src/main/java/org/softwareheritage/graph/Entry.java b/java/src/main/java/org/softwareheritage/graph/Entry.java --- a/java/src/main/java/org/softwareheritage/graph/Entry.java +++ b/java/src/main/java/org/softwareheritage/graph/Entry.java @@ -112,37 +112,37 @@ } } - public void leaves(String direction, String edgesFmt, long srcNodeId) { + public void leaves(String direction, String edgesFmt, long srcNodeId, long maxEdges) { open(); - Traversal t = new Traversal(this.graph, direction, edgesFmt); + Traversal t = new Traversal(this.graph, direction, edgesFmt, maxEdges); t.leavesVisitor(srcNodeId, this::writeNode); close(); } - public void neighbors(String direction, String edgesFmt, long srcNodeId) { + public void neighbors(String direction, String edgesFmt, long srcNodeId, long maxEdges) { open(); - Traversal t = new Traversal(this.graph, direction, edgesFmt); + Traversal t = new Traversal(this.graph, direction, edgesFmt, maxEdges); t.neighborsVisitor(srcNodeId, this::writeNode); close(); } - public void visit_nodes(String direction, String edgesFmt, long srcNodeId) { + public void visit_nodes(String direction, String edgesFmt, long srcNodeId, long maxEdges) { open(); - Traversal t = new Traversal(this.graph, direction, edgesFmt); + Traversal t = new Traversal(this.graph, direction, edgesFmt, maxEdges); t.visitNodesVisitor(srcNodeId, this::writeNode); close(); } - public void visit_edges(String direction, String edgesFmt, long srcNodeId) { + public void visit_edges(String direction, String edgesFmt, long srcNodeId, long maxEdges) { open(); - Traversal t = new Traversal(this.graph, direction, edgesFmt); + Traversal t = new Traversal(this.graph, direction, edgesFmt, maxEdges); t.visitNodesVisitor(srcNodeId, null, this::writeEdge); close(); } - public void visit_paths(String direction, String edgesFmt, long srcNodeId) { + public void visit_paths(String direction, String edgesFmt, long srcNodeId, long maxEdges) { open(); - Traversal t = new 
Traversal(this.graph, direction, edgesFmt); + Traversal t = new Traversal(this.graph, direction, edgesFmt, maxEdges); t.visitPathsVisitor(srcNodeId, this::writePath); close(); } diff --git a/java/src/main/java/org/softwareheritage/graph/Traversal.java b/java/src/main/java/org/softwareheritage/graph/Traversal.java --- a/java/src/main/java/org/softwareheritage/graph/Traversal.java +++ b/java/src/main/java/org/softwareheritage/graph/Traversal.java @@ -30,6 +30,9 @@ /** Number of edges accessed during traversal */ long nbEdgesAccessed; + /** Anti-DoS limit on the number of edges accessed during a visit; a value <= 0 means no limit */ + long maxEdges; + /** random number generator, for random walks */ Random rng; @@ -43,6 +46,10 @@ * edges */ public Traversal(Graph graph, String direction, String edgesFmt) { + this(graph, direction, edgesFmt, 0); + } + + public Traversal(Graph graph, String direction, String edgesFmt, long maxEdges) { if (!direction.matches("forward|backward")) { throw new IllegalArgumentException("Unknown traversal direction: " + direction); } @@ -57,6 +64,7 @@ this.visited = new HashSet<>(); this.parentNode = new HashMap<>(); this.nbEdgesAccessed = 0; + this.maxEdges = maxEdges; this.rng = new Random(); } @@ -93,6 +101,11 @@ long neighborsCnt = 0; nbEdgesAccessed += graph.outdegree(currentNodeId); + if (this.maxEdges > 0) { + if (nbEdgesAccessed >= this.maxEdges) { + break; + } + } LazyLongIterator it = graph.successors(currentNodeId, edges); for (long neighborNodeId; (neighborNodeId = it.nextLong()) != -1;) { neighborsCnt++; @@ -125,6 +138,11 @@ */ public void neighborsVisitor(long srcNodeId, NodeIdConsumer cb) { this.nbEdgesAccessed = graph.outdegree(srcNodeId); + if (this.maxEdges > 0) { + if (nbEdgesAccessed >= this.maxEdges) { + return; + } + } LazyLongIterator it = graph.successors(srcNodeId, edges); for (long neighborNodeId; (neighborNodeId = it.nextLong()) != -1;) { cb.accept(neighborNodeId); @@ -158,8 +176,12 @@ if (nodeCb != null) { nodeCb.accept(currentNodeId); } - 
nbEdgesAccessed += graph.outdegree(currentNodeId); + if (this.maxEdges > 0) { + if (nbEdgesAccessed >= this.maxEdges) { + break; + } + } LazyLongIterator it = graph.successors(currentNodeId, edges); for (long neighborNodeId; (neighborNodeId = it.nextLong()) != -1;) { if (edgeCb != null) { @@ -216,7 +238,14 @@ currentPath.push(currentNodeId); long visitedNeighbors = 0; + nbEdgesAccessed += graph.outdegree(currentNodeId); + if (this.maxEdges > 0) { + if (nbEdgesAccessed >= this.maxEdges) { + currentPath.pop(); + return; + } + } LazyLongIterator it = graph.successors(currentNodeId, edges); for (long neighborNodeId; (neighborNodeId = it.nextLong()) != -1;) { visitPathsInternalVisitor(neighborNodeId, currentPath, cb); diff --git a/swh/graph/backend.py b/swh/graph/backend.py --- a/swh/graph/backend.py +++ b/swh/graph/backend.py @@ -68,10 +68,10 @@ method = getattr(self.entry, "count_" + ttype) return method(direction, edges_fmt, src) - async def simple_traversal(self, ttype, direction, edges_fmt, src): + async def simple_traversal(self, ttype, direction, edges_fmt, src, max_edges): assert ttype in ("leaves", "neighbors", "visit_nodes") method = getattr(self.stream_proxy, ttype) - async for node_id in method(direction, edges_fmt, src): + async for node_id in method(direction, edges_fmt, src, max_edges): yield node_id async def walk(self, direction, edges_fmt, algo, src, dst): @@ -92,8 +92,8 @@ async for node_id in it: # TODO return 404 if path is empty yield node_id - async def visit_edges(self, direction, edges_fmt, src): - it = self.stream_proxy.visit_edges(direction, edges_fmt, src) + async def visit_edges(self, direction, edges_fmt, src, max_edges): + it = self.stream_proxy.visit_edges(direction, edges_fmt, src, max_edges) # convert stream a, b, c, d -> (a, b), (c, d) prevNode = None async for node in it: @@ -103,9 +103,11 @@ else: prevNode = node - async def visit_paths(self, direction, edges_fmt, src): + async def visit_paths(self, direction, edges_fmt, src, 
max_edges): path = [] - async for node in self.stream_proxy.visit_paths(direction, edges_fmt, src): + async for node in self.stream_proxy.visit_paths( + direction, edges_fmt, src, max_edges + ): if node == PATH_SEPARATOR_ID: yield path path = [] diff --git a/swh/graph/client.py b/swh/graph/client.py --- a/swh/graph/client.py +++ b/swh/graph/client.py @@ -35,30 +35,32 @@ def stats(self): return self.get("stats") - def leaves(self, src, edges="*", direction="forward"): + def leaves(self, src, edges="*", direction="forward", max_edges=0): return self.get_lines( - "leaves/{}".format(src), params={"edges": edges, "direction": direction} + "leaves/{}".format(src), + params={"edges": edges, "direction": direction, "max_edges": max_edges}, ) - def neighbors(self, src, edges="*", direction="forward"): + def neighbors(self, src, edges="*", direction="forward", max_edges=0): return self.get_lines( - "neighbors/{}".format(src), params={"edges": edges, "direction": direction} + "neighbors/{}".format(src), + params={"edges": edges, "direction": direction, "max_edges": max_edges}, ) - def visit_nodes(self, src, edges="*", direction="forward"): + def visit_nodes(self, src, edges="*", direction="forward", max_edges=0): return self.get_lines( "visit/nodes/{}".format(src), - params={"edges": edges, "direction": direction}, + params={"edges": edges, "direction": direction, "max_edges": max_edges}, ) - def visit_edges(self, src, edges="*", direction="forward"): + def visit_edges(self, src, edges="*", direction="forward", max_edges=0): for edge in self.get_lines( "visit/edges/{}".format(src), - params={"edges": edges, "direction": direction}, + params={"edges": edges, "direction": direction, "max_edges": max_edges}, ): yield tuple(edge.split()) - def visit_paths(self, src, edges="*", direction="forward"): + def visit_paths(self, src, edges="*", direction="forward", max_edges=0): def decode_path_wrapper(it): for e in it: yield json.loads(e) @@ -66,7 +68,7 @@ return decode_path_wrapper( 
self.get_lines( "visit/paths/{}".format(src), - params={"edges": edges, "direction": direction}, + params={"edges": edges, "direction": direction, "max_edges": max_edges}, ) ) diff --git a/swh/graph/graph.py b/swh/graph/graph.py --- a/swh/graph/graph.py +++ b/swh/graph/graph.py @@ -74,9 +74,14 @@ lambda: self.graph.java_graph.indegree(self.id), ) - def simple_traversal(self, ttype, direction="forward", edges="*"): + def simple_traversal(self, ttype, direction="forward", edges="*", max_edges=0): for node in call_async_gen( - self.graph.backend.simple_traversal, ttype, direction, edges, self.id + self.graph.backend.simple_traversal, + ttype, + direction, + edges, + self.id, + max_edges, ): yield self.graph[node] @@ -86,15 +91,15 @@ def visit_nodes(self, *args, **kwargs): yield from self.simple_traversal("visit_nodes", *args, **kwargs) - def visit_edges(self, direction="forward", edges="*"): + def visit_edges(self, direction="forward", edges="*", max_edges=0): for src, dst in call_async_gen( - self.graph.backend.visit_edges, direction, edges, self.id + self.graph.backend.visit_edges, direction, edges, self.id, max_edges ): yield (self.graph[src], self.graph[dst]) - def visit_paths(self, direction="forward", edges="*"): + def visit_paths(self, direction="forward", edges="*", max_edges=0): for path in call_async_gen( - self.graph.backend.visit_paths, direction, edges, self.id + self.graph.backend.visit_paths, direction, edges, self.id, max_edges ): yield [self.graph[node] for node in path] diff --git a/swh/graph/server/app.py b/swh/graph/server/app.py --- a/swh/graph/server/app.py +++ b/swh/graph/server/app.py @@ -109,6 +109,15 @@ except ValueError: raise aiohttp.web.HTTPBadRequest(text=f"invalid limit value: {s}") + def get_max_edges(self): + """Validate HTTP query parameter 'max_edges', i.e., + the limit of the number of edges that can be visited""" + s = self.request.query.get("max_edges", "0") + try: + return int(s) + except ValueError: + raise 
aiohttp.web.HTTPBadRequest(text=f"invalid max_edges value: {s}") + class StreamingGraphView(GraphView): """Base class for views streaming their response line by line.""" @@ -166,10 +175,15 @@ self.edges = self.get_edges() self.direction = self.get_direction() + self.max_edges = self.get_max_edges() async def stream_response(self): async for res_node in self.backend.simple_traversal( - self.simple_traversal_type, self.direction, self.edges, self.src_node + self.simple_traversal_type, + self.direction, + self.edges, + self.src_node, + self.max_edges, ): res_swhid = self.swhid_of_node(res_node) await self.stream_line(res_swhid) @@ -236,7 +250,9 @@ class VisitEdgesView(SimpleTraversalView): async def stream_response(self): - it = self.backend.visit_edges(self.direction, self.edges, self.src_node) + it = self.backend.visit_edges( + self.direction, self.edges, self.src_node, self.max_edges + ) async for (res_src, res_dst) in it: res_src_swhid = self.swhid_of_node(res_src) res_dst_swhid = self.swhid_of_node(res_dst) @@ -247,7 +263,9 @@ content_type = "application/x-ndjson" async def stream_response(self): - it = self.backend.visit_paths(self.direction, self.edges, self.src_node) + it = self.backend.visit_paths( + self.direction, self.edges, self.src_node, self.max_edges + ) async for res_path in it: res_path_swhid = [self.swhid_of_node(n) for n in res_path] line = json.dumps(res_path_swhid) diff --git a/swh/graph/tests/test_api_client.py b/swh/graph/tests/test_api_client.py --- a/swh/graph/tests/test_api_client.py +++ b/swh/graph/tests/test_api_client.py @@ -103,6 +103,29 @@ assert set(actual) == set(expected) +def test_visit_edges_limited(graph_client): + actual = list( + graph_client.visit_edges( + "swh:1:rel:0000000000000000000000000000000000000010", max_edges=4 + ) + ) + expected = [ + ( + "swh:1:rel:0000000000000000000000000000000000000010", + "swh:1:rev:0000000000000000000000000000000000000009", + ), + ( + "swh:1:rev:0000000000000000000000000000000000000009", + 
"swh:1:dir:0000000000000000000000000000000000000008", + ), + ( + "swh:1:rev:0000000000000000000000000000000000000009", + "swh:1:rev:0000000000000000000000000000000000000003", + ), + ] + assert set(actual) == set(expected) + + def test_visit_edges_diamond_pattern(graph_client): actual = list( graph_client.visit_edges(