diff --git a/docs/api.rst b/docs/api.rst --- a/docs/api.rst +++ b/docs/api.rst @@ -129,6 +129,40 @@ ... + +.. http:get:: /graph/randomwalk/:src/:dst + + Performs a graph *random* traversal, i.e., picking one random successor + node at each hop, from source to destination (final destination node + included). + + :param string src: starting node specified as a SWH PID + :param string dst: destination node, either as a node PID or a node type. + The traversal will stop at the first node encountered matching the + desired destination. + + :query string edges: edges types the traversal can follow; default to + ``"*"`` + :query string direction: direction in which graph edges will be followed; + can be either ``forward`` or ``backward``, default to ``forward`` + + :statuscode 200: success + :statuscode 400: invalid query string provided + :statuscode 404: starting node cannot be found + + .. sourcecode:: http + + HTTP/1.1 200 OK + Content-Type: text/plain + Transfer-Encoding: chunked + + swh:1:rev:f39d7d78b70e0f39facb1e4fab77ad3df5c52a35 + swh:1:rev:52c90f2d32bfa7d6eccd66a56c44ace1f78fbadd + swh:1:rev:cea92e843e40452c08ba313abc39f59efbb4c29c + swh:1:rev:8d517bdfb57154b8a11d7f1682ecc0f79abf8e02 + ... 
+
+
 Visit
 -----
 
diff --git a/java/src/main/java/org/softwareheritage/graph/Entry.java b/java/src/main/java/org/softwareheritage/graph/Entry.java
--- a/java/src/main/java/org/softwareheritage/graph/Entry.java
+++ b/java/src/main/java/org/softwareheritage/graph/Entry.java
@@ -159,5 +159,26 @@
             }
             close();
         }
+
+        public void random_walk(String direction, String edgesFmt, int retries,
+                                long srcNodeId, long dstNodeId) {  // destination given as a node id
+            open();
+            Traversal t = new Traversal(this.graph, direction, edgesFmt);
+            for (Long nodeId : t.randomWalk(srcNodeId, dstNodeId, retries)) {  // empty path => nothing is written
+                writeNode(nodeId);
+            }
+            close();
+        }
+
+        public void random_walk_type(String direction, String edgesFmt, int retries,
+                                     long srcNodeId, String dst) {  // destination given as a node type name
+            open();
+            Node.Type dstType = Node.Type.fromStr(dst);  // parse the type name into a Node.Type
+            Traversal t = new Traversal(this.graph, direction, edgesFmt);
+            for (Long nodeId : t.randomWalk(srcNodeId, dstType, retries)) {  // empty path => nothing is written
+                writeNode(nodeId);
+            }
+            close();
+        }
     }
 }
diff --git a/java/src/main/java/org/softwareheritage/graph/algo/Traversal.java b/java/src/main/java/org/softwareheritage/graph/algo/Traversal.java
--- a/java/src/main/java/org/softwareheritage/graph/algo/Traversal.java
+++ b/java/src/main/java/org/softwareheritage/graph/algo/Traversal.java
@@ -4,9 +4,11 @@
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
+import java.util.Iterator;
 import java.util.LinkedList;
 import java.util.Map;
 import java.util.Queue;
+import java.util.Random;
 import java.util.Stack;
 
 import it.unimi.dsi.bits.LongArrayBitVector;
@@ -43,6 +45,9 @@
     /** Number of edges accessed during traversal */
     long nbEdgesAccessed;
 
+    /** Random number generator used to pick successors in random walks; created unseeded by the constructor */
+    Random rng;
+
     /**
      * Constructor.
      *
@@ -64,6 +69,7 @@
         this.visited = new HashSet<>();
         this.parentNode = new HashMap<>();
         this.nbEdgesAccessed = 0;
+        this.rng = new Random();  // no explicit seed, so walks differ from run to run
     }
 
     /**
@@ -231,30 +237,112 @@
     }
 
     /**
-     * Performs a graph traversal and returns the first found path from source to destination.
+     * Performs a graph traversal with backtracking, and returns the first
+     * found path from source to destination.
      *
      * @param srcNodeId source node
      * @param dst destination (either a node or a node type)
      * @return found path as a list of node ids
      */
-    public ArrayList walk(long srcNodeId, T dst, String algorithm) {
+    public ArrayList walk(long srcNodeId, T dst, String visitOrder) {
         long dstNodeId = -1;
-        if (algorithm.equals("dfs")) {
-            dstNodeId = walkInternalDfs(srcNodeId, dst);
-        } else if (algorithm.equals("bfs")) {
-            dstNodeId = walkInternalBfs(srcNodeId, dst);
+        if (visitOrder.equals("dfs")) {
+            dstNodeId = walkInternalDFS(srcNodeId, dst);
+        } else if (visitOrder.equals("bfs")) {
+            dstNodeId = walkInternalBFS(srcNodeId, dst);
         } else {
-            throw new IllegalArgumentException("Unknown traversal algorithm: " + algorithm);
+            throw new IllegalArgumentException("Unknown visit order: " + visitOrder);  // only "dfs" and "bfs" are accepted
         }
         if (dstNodeId == -1) {
-            throw new IllegalArgumentException("Unable to find destination point: " + dst);
+            throw new IllegalArgumentException("Cannot find destination: " + dst);  // -1 still set: neither visit reached dst
         }
 
         ArrayList nodeIds = backtracking(srcNodeId, dstNodeId);
         return nodeIds;
     }
 
+    /**
+     * Performs a random walk (picking a random successor at each step) from
+     * source to destination, making a single attempt (no retries).
+     *
+     * @param srcNodeId source node
+     * @param dst destination (either a node or a node type)
+     * @return found path as a list of node ids or an empty path to indicate
+     * that no suitable path has been found
+     */
+    public ArrayList randomWalk(long srcNodeId, T dst) {
+        return randomWalk(srcNodeId, dst, 0);  // 0 retries == one walk
+    }
+
+    /**
+     * Performs a stubborn random walk (picking a random successor at each
+     * step) from source to destination. The walk is "stubborn" in the sense
+     * that it will not give up the first time no satisfying target node is
+     * found, but it will retry up to a limited number of times.
+     *
+     * @param srcNodeId source node
+     * @param dst destination (either a node or a node type)
+     * @param retries number of times to retry; 0 means no retries (single walk)
+     * @return found path as a list of node ids or an empty path to indicate
+     * that no suitable path has been found
+     */
+    public ArrayList randomWalk(long srcNodeId, T dst, int retries) {
+        long curNodeId = srcNodeId;
+        ArrayList path = new ArrayList();
+        this.nbEdgesAccessed = 0;
+        boolean found;
+
+        if (retries < 0) {
+            throw new IllegalArgumentException("Negative number of retries given: " + retries);
+        }
+
+        while (true) {  // NOTE(review): no cycle guard; presumably the graph is acyclic -- confirm
+            path.add(curNodeId);  // the source itself is recorded but never tested against dst
+            Neighbors neighbors = new Neighbors(graph, useTransposed, edges, curNodeId);
+            curNodeId = randomPick(neighbors.iterator());
+            if (curNodeId < 0) {  // dead end: randomPick saw an empty neighbor iterator
+                found = false;
+                break;
+            }
+            if (isDstNode(curNodeId, dst)) {
+                path.add(curNodeId);  // destination is included in the returned path
+                found = true;
+                break;
+            }
+        }
+
+        if (found) {
+            return path;
+        } else if (retries > 0) {  // try again
+            return randomWalk(srcNodeId, dst, retries - 1);
+        } else {  // not found and no retries left
+            path.clear();
+            return path;
+        }
+    }
+
+    /**
+     * Uniformly pick one element from an iterator over Longs using single-item
+     * reservoir sampling: the k-th element replaces the pick with probability 1/k
+     *
+     * @param elements iterator over selection domain; it is fully consumed
+     * @return randomly chosen element or -1 if the iterator was empty
+     */
+    private long randomPick(Iterator<Long> elements) {
+        long curPick = -1;
+        int seenCandidates = 0;
+
+        while (elements.hasNext()) {
+            long candidate = elements.next();  // always advance, even when the candidate is rejected
+            if (rng.nextInt(++seenCandidates) == 0) {  // true with probability 1/seenCandidates
+                curPick = candidate;
+            }
+        }
+
+        return curPick;
+    }
+
     /**
      * Internal DFS function of {@link #walk}.
      *
@@ -262,7 +350,7 @@
      * @param dst destination (either a node or a node type)
      * @return final destination node or -1 if no path found
      */
-    private long walkInternalDfs(long srcNodeId, T dst) {
+    private long walkInternalDFS(long srcNodeId, T dst) {
         Stack stack = new Stack();
         this.nbEdgesAccessed = 0;
 
@@ -295,7 +383,7 @@
      * @param dst destination (either a node or a node type)
      * @return final destination node or -1 if no path found
      */
-    private long walkInternalBfs(long srcNodeId, T dst) {
+    private long walkInternalBFS(long srcNodeId, T dst) {
         Queue queue = new LinkedList();
         this.nbEdgesAccessed = 0;
 
diff --git a/swh/graph/backend.py b/swh/graph/backend.py
--- a/swh/graph/backend.py
+++ b/swh/graph/backend.py
@@ -120,6 +120,16 @@
         async for node_id in it:
             yield node_id
 
+    async def random_walk(self, direction, edges_fmt, retries, src, dst):
+        if dst in PID_TYPES:  # dst names a node type: use the by-type entry point
+            it = self.stream_proxy.random_walk_type(direction, edges_fmt,
+                                                    retries, src, dst)
+        else:  # otherwise dst is passed through as a destination node
+            it = self.stream_proxy.random_walk(direction, edges_fmt, retries,
+                                               src, dst)
+        async for node_id in it:  # TODO return 404 if path is empty
+            yield node_id
+
     async def visit_paths(self, direction, edges_fmt, src):
         path = []
         async for node in self.stream_proxy.visit_paths(
diff --git a/swh/graph/client.py b/swh/graph/client.py
--- a/swh/graph/client.py
+++ b/swh/graph/client.py
@@ -81,6 +81,14 @@
                 'direction': direction
             })
 
+    def random_walk(self, src, dst, edges="*", direction="forward"):  # dst: node PID or node type
+        return self.get_lines(
+            'randomwalk/{}/{}'.format(src, dst),
+            params={
+                'edges': edges,
+                'direction': direction
+            })
+
     def count_leaves(self, src, edges="*", direction="forward"):
         return self.get(
             'leaves/count/{}'.format(src),
diff --git a/swh/graph/server/app.py b/swh/graph/server/app.py
--- a/swh/graph/server/app.py
+++ b/swh/graph/server/app.py
@@ -23,6 +23,10 @@
 from async_generator import asynccontextmanager  # type: ignore
 
 
+# maximum number of retries for random walks (extra attempts after the first walk)
+RANDOM_RETRIES = 5  # TODO make this configurable via
rpc-serve configuration
+
+
 @asynccontextmanager
 async def stream_response(request, content_type='text/plain',
                           *args, **kwargs):
@@ -121,25 +125,31 @@
     return simple_traversal
 
 
-async def walk(request):
-    backend = request.app['backend']
+def get_walk_handler(random=False):  # factory: random=False serves /walk, random=True /randomwalk
+    async def walk(request):
+        backend = request.app['backend']
 
-    src = request.match_info['src']
-    dst = request.match_info['dst']
-    edges = get_edges(request)
-    direction = get_direction(request)
-    algo = get_traversal(request)
+        src = request.match_info['src']
+        dst = request.match_info['dst']
+        edges = get_edges(request)
+        direction = get_direction(request)
+        algo = get_traversal(request)  # NOTE(review): also evaluated when random=True, where it is unused
 
-    src_node = node_of_pid(src, backend)
-    if dst not in PID_TYPES:
-        dst = node_of_pid(dst, backend)
-    async with stream_response(request) as response:
-        async for res_node in backend.walk(
-            direction, edges, algo, src_node, dst
-        ):
-            res_pid = pid_of_node(res_node, backend)
-            await response.write('{}\n'.format(res_pid).encode())
-    return response
+        src_node = node_of_pid(src, backend)
+        if dst not in PID_TYPES:  # a PID destination is resolved to a node; a type name is kept as is
+            dst = node_of_pid(dst, backend)
+        async with stream_response(request) as response:
+            if random:
+                it = backend.random_walk(direction, edges, RANDOM_RETRIES,
+                                         src_node, dst)
+            else:
+                it = backend.walk(direction, edges, algo, src_node, dst)
+            async for res_node in it:  # one PID per line, written as nodes arrive
+                res_pid = pid_of_node(res_node, backend)
+                await response.write('{}\n'.format(res_pid).encode())
+            return response
+
+    return walk
 
 
 async def visit_paths(request):
@@ -190,7 +200,10 @@
     app.router.add_get('/graph/visit/nodes/{src}',
                        get_simple_traversal_handler('visit_nodes'))
     app.router.add_get('/graph/visit/paths/{src}', visit_paths)
-    app.router.add_get('/graph/walk/{src}/{dst}', walk)
+    app.router.add_get('/graph/walk/{src}/{dst}',
+                       get_walk_handler(random=False))
+    app.router.add_get('/graph/randomwalk/{src}/{dst}',
+                       get_walk_handler(random=True))
     app.router.add_get('/graph/neighbors/count/{src}',
                        get_count_handler('neighbors'))
diff --git a/swh/graph/tests/test_api_client.py b/swh/graph/tests/test_api_client.py
--- a/swh/graph/tests/test_api_client.py
+++ b/swh/graph/tests/test_api_client.py
@@ -105,6 +105,18 @@
     assert set(actual) == set(expected)
 
 
+def test_random_walk(graph_client):
+    """as the walk is random, we walk backward from a cnt node to the only origin
+    in the dataset and check only the path endpoints (source first, origin last)
+
+    """
+    src = 'swh:1:cnt:0000000000000000000000000000000000000001'
+    actual = list(graph_client.random_walk(src, 'ori', direction='backward'))
+    expected_root = 'swh:1:ori:0000000000000000000000000000000000000021'
+    assert actual[0] == src
+    assert actual[-1] == expected_root
+
+
 def test_count(graph_client):
     actual = graph_client.count_leaves(
         'swh:1:ori:0000000000000000000000000000000000000021'