diff --git a/docs/api.rst b/docs/api.rst --- a/docs/api.rst +++ b/docs/api.rst @@ -170,6 +170,43 @@ } } + +.. http:get:: /graph/randomwalk/:src/:dst + + Performs a graph *random* traversal, i.e., picking one random successor + node at each hop, from source to destination (source node excluded, + final destination node included). + + :param string src: starting node specified as a SWH PID + :param string dst: destination node, either as a node PID or a node type. + The traversal will stop at the first node encountered matching the + desired destination. + + :query string edges: edge types the traversal can follow; defaults to + ``"*"`` + :query string direction: direction in which graph edges will be followed; + can be either ``forward`` or ``backward``, defaults to ``forward`` + + :statuscode 200: success + :statuscode 400: invalid query string provided + :statuscode 404: starting node cannot be found + + .. sourcecode:: http + + HTTP/1.1 200 OK + Content-Type: application/json + + { + "result": [ + "swh:1:rev:f39d7d78b70e0f39facb1e4fab77ad3df5c52a35", + "swh:1:rev:52c90f2d32bfa7d6eccd66a56c44ace1f78fbadd", + "swh:1:rev:cea92e843e40452c08ba313abc39f59efbb4c29c", + "swh:1:rev:8d517bdfb57154b8a11d7f1682ecc0f79abf8e02", + ...
+ ] } + + Visit ----- diff --git a/java/src/main/java/org/softwareheritage/graph/Entry.java b/java/src/main/java/org/softwareheritage/graph/Entry.java --- a/java/src/main/java/org/softwareheritage/graph/Entry.java +++ b/java/src/main/java/org/softwareheritage/graph/Entry.java @@ -159,5 +159,26 @@ } close(); } + + public void random_walk(String direction, String edgesFmt, int retries, + long srcNodeId, long dstNodeId) { + open(); + Traversal t = new Traversal(this.graph, direction, edgesFmt); + for (Long nodeId : t.randomWalk(srcNodeId, dstNodeId, retries)) { + writeNode(nodeId); + } + close(); + } + + public void random_walk_type(String direction, String edgesFmt, int retries, + long srcNodeId, String dst) { + open(); + Node.Type dstType = Node.Type.fromStr(dst); + Traversal t = new Traversal(this.graph, direction, edgesFmt); + for (Long nodeId : t.randomWalk(srcNodeId, dstType, retries)) { + writeNode(nodeId); + } + close(); + } } } diff --git a/java/src/main/java/org/softwareheritage/graph/algo/Traversal.java b/java/src/main/java/org/softwareheritage/graph/algo/Traversal.java --- a/java/src/main/java/org/softwareheritage/graph/algo/Traversal.java +++ b/java/src/main/java/org/softwareheritage/graph/algo/Traversal.java @@ -4,9 +4,11 @@ import java.util.Collections; import java.util.HashMap; import java.util.HashSet; +import java.util.Iterator; import java.util.LinkedList; import java.util.Map; import java.util.Queue; +import java.util.Random; import java.util.Stack; import it.unimi.dsi.bits.LongArrayBitVector; @@ -43,6 +45,9 @@ /** Number of edges accessed during traversal */ long nbEdgesAccessed; + /** random number generator, for random walks */ + Random rng; + /** * Constructor. * @@ -64,6 +69,7 @@ this.visited = new HashSet<>(); this.parentNode = new HashMap<>(); this.nbEdgesAccessed = 0; + this.rng = new Random(); } /** @@ -231,30 +237,130 @@ } /** - * Performs a graph traversal and returns the first found path from source to destination.
+ * Performs a graph traversal with backtracking, and returns the first + * found path from source to destination. * * @param srcNodeId source node * @param dst destination (either a node or a node type) + * @param visitOrder visit order, either "dfs" or "bfs" * @return found path as a list of node ids */ - public ArrayList walk(long srcNodeId, T dst, String algorithm) { + public ArrayList walk(long srcNodeId, T dst, String visitOrder) { long dstNodeId = -1; - if (algorithm.equals("dfs")) { - dstNodeId = walkInternalDfs(srcNodeId, dst); - } else if (algorithm.equals("bfs")) { - dstNodeId = walkInternalBfs(srcNodeId, dst); + if (visitOrder.equals("dfs")) { + dstNodeId = walkInternalDFS(srcNodeId, dst); + } else if (visitOrder.equals("bfs")) { + dstNodeId = walkInternalBFS(srcNodeId, dst); } else { - throw new IllegalArgumentException("Unknown traversal algorithm: " + algorithm); + throw new IllegalArgumentException("Unknown visit order: " + visitOrder); } if (dstNodeId == -1) { - throw new IllegalArgumentException("Unable to find destination point: " + dst); + throw new IllegalArgumentException("Cannot find destination: " + dst); } ArrayList nodeIds = backtracking(srcNodeId, dstNodeId); return nodeIds; } + /** + * Performs a random walk (picking a random successor at each step) from + * source to destination. + * + * @param srcNodeId source node + * @param dst destination (either a node or a node type) + * @return found path as a list of node ids (source node excluded) or an + * empty path to indicate that no suitable path has been found + */ + public ArrayList randomWalk(long srcNodeId, T dst) { + return randomWalk(srcNodeId, dst, 0); + } + + /** + * Performs a stubborn random walk (picking a random successor at each + * step) from source to destination. The walk is "stubborn" in the sense + * that it will not give up the first time if a satisfying target node is + * not found, but it will retry up to a limited number of times.
+ * + * @param srcNodeId source node + * @param dst destination (either a node or a node type) + * @param retries number of times to retry; 0 means no retries (single walk) + * @return found path as a list of node ids (source node excluded) or an + * empty path to indicate that no suitable path has been found + */ + public ArrayList randomWalk(long srcNodeId, T dst, int retries) { + ArrayList path = new ArrayList(); + this.nbEdgesAccessed = 0; + long curNodeId = srcNodeId; + boolean found; + + assert retries >= 0; + + while (true) { + // Count only the successors allowed by the edge restrictions: the raw + // degree may be larger, which would make randomPick run off the end. + long nbNeighbors = 0; + Iterator counter = new Neighbors(graph, useTransposed, edges, curNodeId).iterator(); + while (counter.hasNext()) { + counter.next(); + nbNeighbors++; + } + if (nbNeighbors == 0) { + found = false; + break; + } + Neighbors neighbors = new Neighbors(graph, useTransposed, edges, curNodeId); + Iterator successors = neighbors.iterator(); + + curNodeId = randomPick(successors, nbNeighbors); + path.add(curNodeId); + this.nbEdgesAccessed++; + + if (isDstNode(curNodeId, dst)) { + found = true; + break; + } + } + + if (found) { + return path; + } else if (retries > 0) { // try again + return randomWalk(srcNodeId, dst, retries - 1); + } else { // not found and no retries left + path.clear(); + return path; + } + } + + /** + * Uniformly choose a random element from an iterator + * + * @param elements iterator over selection domain + * @param length total number of elements iterated upon + * @return randomly chosen element + * @throws IllegalStateException if elements runs out before the chosen position + */ + private T randomPick(Iterator elements, long length) { + // nextDouble() is uniform over [0, 1), so its scaled floor is uniform + // over [0, length - 1]; rounding instead would halve the odds of both ends. + long elementsToSkip = (long) (rng.nextDouble() * length); + long skippedElements = -1; + T e; + + while (elements.hasNext()) { + e = elements.next(); + skippedElements++; + if (skippedElements < elementsToSkip) { + continue; + } else { + return e; + } + } + + throw new IllegalStateException( + "randomPick: iterator has fewer than " + length + " elements"); + } + /** * Internal DFS function of {@link #walk}.
* @@ -262,7 +357,7 @@ * @param dst destination (either a node or a node type) * @return final destination node or -1 if no path found */ - private long walkInternalDfs(long srcNodeId, T dst) { + private long walkInternalDFS(long srcNodeId, T dst) { Stack stack = new Stack(); this.nbEdgesAccessed = 0; @@ -295,7 +390,7 @@ * @param dst destination (either a node or a node type) * @return final destination node or -1 if no path found */ - private long walkInternalBfs(long srcNodeId, T dst) { + private long walkInternalBFS(long srcNodeId, T dst) { Queue queue = new LinkedList(); this.nbEdgesAccessed = 0; diff --git a/swh/graph/backend.py b/swh/graph/backend.py --- a/swh/graph/backend.py +++ b/swh/graph/backend.py @@ -119,6 +119,16 @@ async for node_id in it: yield node_id + async def random_walk(self, direction, edges_fmt, retries, src, dst): + if dst in PID_TYPES: + it = self.stream_proxy.random_walk_type(direction, edges_fmt, + retries, src, dst) + else: + it = self.stream_proxy.random_walk(direction, edges_fmt, retries, + src, dst) + async for node_id in it: # TODO return 404 if path is empty + yield node_id + async def visit_paths(self, direction, edges_fmt, src): path = [] async for node in self.stream_proxy.visit_paths( diff --git a/swh/graph/client.py b/swh/graph/client.py --- a/swh/graph/client.py +++ b/swh/graph/client.py @@ -81,6 +81,14 @@ 'direction': direction }) + def random_walk(self, src, dst, edges="*", direction="forward"): + return self.get_lines( + 'randomwalk/{}/{}'.format(src, dst), + params={ + 'edges': edges, + 'direction': direction + }) + def count_leaves(self, src, edges="*", direction="forward"): return self.get( 'leaves/count/{}'.format(src), diff --git a/swh/graph/server/app.py b/swh/graph/server/app.py --- a/swh/graph/server/app.py +++ b/swh/graph/server/app.py @@ -22,6 +22,10 @@ from async_generator import asynccontextmanager # type: ignore +# maximum number of retries for random walks +RANDOM_RETRIES = 5 + + @asynccontextmanager async def
stream_response(request, *args, **kwargs): response = aiohttp.web.StreamResponse(*args, **kwargs) @@ -71,25 +75,32 @@ return simple_traversal -async def walk(request): - backend = request.app['backend'] +def get_walk_handler(random=False): + """Build the /walk handler, or the /randomwalk one if random is true.""" + async def walk(request): + backend = request.app['backend'] - src = request.match_info['src'] - dst = request.match_info['dst'] - edges = request.query.get('edges', '*') - direction = request.query.get('direction', 'forward') - algo = request.query.get('traversal', 'dfs') + src = request.match_info['src'] + dst = request.match_info['dst'] + edges = request.query.get('edges', '*') + direction = request.query.get('direction', 'forward') + algo = request.query.get('traversal', 'dfs') - src_node = backend.pid2node[src] - if dst not in PID_TYPES: - dst = backend.pid2node[dst] - async with stream_response(request) as response: - async for res_node in backend.walk( - direction, edges, algo, src_node, dst - ): - res_pid = backend.node2pid[res_node] - await response.write('{}\n'.format(res_pid).encode()) - return response + src_node = backend.pid2node[src] + if dst not in PID_TYPES: + dst = backend.pid2node[dst] + async with stream_response(request) as response: + if random: + it = backend.random_walk( + direction, edges, RANDOM_RETRIES, src_node, dst) + else: + it = backend.walk(direction, edges, algo, src_node, dst) + async for res_node in it: + res_pid = backend.node2pid[res_node] + await response.write('{}\n'.format(res_pid).encode()) + return response + + return walk async def visit_paths(request): @@ -139,7 +150,10 @@ app.router.add_route('GET', '/graph/visit/nodes/{src}', get_simple_traversal_handler('visit_nodes')) app.router.add_route('GET', '/graph/visit/paths/{src}', visit_paths) - app.router.add_route('GET', '/graph/walk/{src}/{dst}', walk) + app.router.add_route('GET', '/graph/walk/{src}/{dst}',
+ get_walk_handler(random=False)) + app.router.add_route('GET', '/graph/randomwalk/{src}/{dst}', + get_walk_handler(random=True)) app.router.add_route('GET', '/graph/neighbors/count/{src}', get_count_handler('neighbors')) diff --git a/swh/graph/tests/test_api_client.py b/swh/graph/tests/test_api_client.py --- a/swh/graph/tests/test_api_client.py +++ b/swh/graph/tests/test_api_client.py @@ -105,6 +105,20 @@ assert set(actual) == set(expected) +def test_random_walk(graph_client): + """as the walk is random, we test a visit from a cnt node to the only origin in + the dataset, and only check the final node of the path (i.e., the origin) + + """ + actual = list(graph_client.random_walk( + 'swh:1:cnt:0000000000000000000000000000000000000001', 'ori', + direction='backward', + )) + expected_root = 'swh:1:ori:0000000000000000000000000000000000000021' + assert actual, 'random walk returned an empty path' + assert actual[-1] == expected_root + + def test_count(graph_client): actual = graph_client.count_leaves( 'swh:1:ori:0000000000000000000000000000000000000021'