diff --git a/java/src/main/java/org/softwareheritage/graph/Entry.java b/java/src/main/java/org/softwareheritage/graph/Entry.java index 63050ee..dfe2d9b 100644 --- a/java/src/main/java/org/softwareheritage/graph/Entry.java +++ b/java/src/main/java/org/softwareheritage/graph/Entry.java @@ -1,184 +1,196 @@ package org.softwareheritage.graph; import java.util.ArrayList; import java.io.DataOutputStream; import java.io.FileOutputStream; import java.io.IOException; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.PropertyNamingStrategy; import org.softwareheritage.graph.algo.NodeIdConsumer; import org.softwareheritage.graph.algo.Stats; import org.softwareheritage.graph.algo.Traversal; public class Entry { private Graph graph; private final long PATH_SEPARATOR_ID = -1; public void load_graph(String graphBasename) throws IOException { System.err.println("Loading graph " + graphBasename + " ..."); this.graph = new Graph(graphBasename); System.err.println("Graph loaded."); } public Graph get_graph() { return graph.copy(); } public String stats() { try { Stats stats = new Stats(graph.getPath()); ObjectMapper objectMapper = new ObjectMapper(); objectMapper.setPropertyNamingStrategy(PropertyNamingStrategy.SNAKE_CASE); return objectMapper.writeValueAsString(stats); } catch (IOException e) { throw new RuntimeException("Cannot read stats: " + e); } } private interface NodeCountVisitor { void accept(long nodeId, NodeIdConsumer consumer); } private int count_visitor(NodeCountVisitor f, long srcNodeId) { int count[] = { 0 }; f.accept(srcNodeId, (node) -> { count[0]++; }); return count[0]; } public int count_leaves(String direction, String edgesFmt, long srcNodeId) { Traversal t = new Traversal(this.graph.copy(), direction, edgesFmt); return count_visitor(t::leavesVisitor, srcNodeId); } public int count_neighbors(String direction, String edgesFmt, long srcNodeId) { Traversal t = new Traversal(this.graph.copy(), direction, edgesFmt); return 
count_visitor(t::neighborsVisitor, srcNodeId); } public int count_visit_nodes(String direction, String edgesFmt, long srcNodeId) { Traversal t = new Traversal(this.graph.copy(), direction, edgesFmt); return count_visitor(t::visitNodesVisitor, srcNodeId); } public QueryHandler get_handler(String clientFIFO) { return new QueryHandler(this.graph.copy(), clientFIFO); } public class QueryHandler { Graph graph; DataOutputStream out; String clientFIFO; public QueryHandler(Graph graph, String clientFIFO) { this.graph = graph; this.clientFIFO = clientFIFO; this.out = null; } public void writeNode(long nodeId) { try { out.writeLong(nodeId); } catch (IOException e) { throw new RuntimeException("Cannot write response to client: " + e); } } + public void writeEdge(long srcId, long dstId) { + writeNode(srcId); + writeNode(dstId); + } + public void writePath(ArrayList path) { for (Long nodeId : path) { writeNode(nodeId); } writeNode(PATH_SEPARATOR_ID); } public void open() { try { FileOutputStream file = new FileOutputStream(this.clientFIFO); this.out = new DataOutputStream(file); } catch (IOException e) { throw new RuntimeException("Cannot open client FIFO: " + e); } } public void close() { try { out.close(); } catch (IOException e) { throw new RuntimeException("Cannot write response to client: " + e); } } public void leaves(String direction, String edgesFmt, long srcNodeId) { open(); Traversal t = new Traversal(this.graph, direction, edgesFmt); t.leavesVisitor(srcNodeId, this::writeNode); close(); } public void neighbors(String direction, String edgesFmt, long srcNodeId) { open(); Traversal t = new Traversal(this.graph, direction, edgesFmt); t.neighborsVisitor(srcNodeId, this::writeNode); close(); } public void visit_nodes(String direction, String edgesFmt, long srcNodeId) { open(); Traversal t = new Traversal(this.graph, direction, edgesFmt); t.visitNodesVisitor(srcNodeId, this::writeNode); close(); } + public void visit_edges(String direction, String edgesFmt, long srcNodeId) 
{ + open(); + Traversal t = new Traversal(this.graph, direction, edgesFmt); + t.visitNodesVisitor(srcNodeId, null, this::writeEdge); + close(); + } + public void visit_paths(String direction, String edgesFmt, long srcNodeId) { open(); Traversal t = new Traversal(this.graph, direction, edgesFmt); t.visitPathsVisitor(srcNodeId, this::writePath); close(); } public void walk(String direction, String edgesFmt, String algorithm, long srcNodeId, long dstNodeId) { open(); Traversal t = new Traversal(this.graph, direction, edgesFmt); for (Long nodeId : t.walk(srcNodeId, dstNodeId, algorithm)) { writeNode(nodeId); } close(); } public void walk_type(String direction, String edgesFmt, String algorithm, long srcNodeId, String dst) { open(); Node.Type dstType = Node.Type.fromStr(dst); Traversal t = new Traversal(this.graph, direction, edgesFmt); for (Long nodeId : t.walk(srcNodeId, dstType, algorithm)) { writeNode(nodeId); } close(); } public void random_walk(String direction, String edgesFmt, int retries, long srcNodeId, long dstNodeId) { open(); Traversal t = new Traversal(this.graph, direction, edgesFmt); for (Long nodeId : t.randomWalk(srcNodeId, dstNodeId, retries)) { writeNode(nodeId); } close(); } public void random_walk_type(String direction, String edgesFmt, int retries, long srcNodeId, String dst) { open(); Node.Type dstType = Node.Type.fromStr(dst); Traversal t = new Traversal(this.graph, direction, edgesFmt); for (Long nodeId : t.randomWalk(srcNodeId, dstType, retries)) { writeNode(nodeId); } close(); } } } diff --git a/java/src/main/java/org/softwareheritage/graph/algo/EdgeIdConsumer.java b/java/src/main/java/org/softwareheritage/graph/algo/EdgeIdConsumer.java new file mode 100644 index 0000000..7c82f0b --- /dev/null +++ b/java/src/main/java/org/softwareheritage/graph/algo/EdgeIdConsumer.java @@ -0,0 +1,9 @@ +package org.softwareheritage.graph.algo; + +public interface EdgeIdConsumer { + + /** Callback for incrementally receiving edge identifiers during a graph + * 
visit. + */ + void accept(long srcId, long dstId); +} diff --git a/java/src/main/java/org/softwareheritage/graph/algo/Traversal.java b/java/src/main/java/org/softwareheritage/graph/algo/Traversal.java index 91adf4e..3fb08f0 100644 --- a/java/src/main/java/org/softwareheritage/graph/algo/Traversal.java +++ b/java/src/main/java/org/softwareheritage/graph/algo/Traversal.java @@ -1,484 +1,494 @@ package org.softwareheritage.graph.algo; import java.util.*; import it.unimi.dsi.bits.LongArrayBitVector; import org.softwareheritage.graph.AllowedEdges; import org.softwareheritage.graph.Endpoint; import org.softwareheritage.graph.Graph; import org.softwareheritage.graph.Neighbors; import org.softwareheritage.graph.Node; /** * Traversal algorithms on the compressed graph. *

* Internal implementation of the traversal API endpoints. These methods only input/output internal * long ids, which are converted in the {@link Endpoint} higher-level class to Software Heritage * PID. * * @author The Software Heritage developers * @see org.softwareheritage.graph.Endpoint */ public class Traversal { /** Graph used in the traversal */ Graph graph; /** Boolean to specify the use of the transposed graph */ boolean useTransposed; /** Graph edge restriction */ AllowedEdges edges; /** Hash set storing if we have visited a node */ HashSet visited; /** Hash map storing parent node id for each nodes during a traversal */ Map parentNode; /** Number of edges accessed during traversal */ long nbEdgesAccessed; /** random number generator, for random walks */ Random rng; /** * Constructor. * * @param graph graph used in the traversal * @param direction a string (either "forward" or "backward") specifying edge orientation * @param edgesFmt a formatted string describing allowed edges */ public Traversal(Graph graph, String direction, String edgesFmt) { if (!direction.matches("forward|backward")) { throw new IllegalArgumentException("Unknown traversal direction: " + direction); } this.graph = graph; this.useTransposed = (direction.equals("backward")); this.edges = new AllowedEdges(graph, edgesFmt); long nbNodes = graph.getNbNodes(); this.visited = new HashSet<>(); this.parentNode = new HashMap<>(); this.nbEdgesAccessed = 0; this.rng = new Random(); } /** * Returns number of accessed edges during traversal. * * @return number of edges accessed in last traversal */ public long getNbEdgesAccessed() { return nbEdgesAccessed; } /** * Returns number of accessed nodes during traversal. * * @return number of nodes accessed in last traversal */ public long getNbNodesAccessed() { return this.visited.size(); } /** * Push version of {@link leaves}: will fire passed callback for each leaf. 
*/ public void leavesVisitor(long srcNodeId, NodeIdConsumer cb) { Stack stack = new Stack(); this.nbEdgesAccessed = 0; stack.push(srcNodeId); visited.add(srcNodeId); while (!stack.isEmpty()) { long currentNodeId = stack.pop(); long neighborsCnt = 0; nbEdgesAccessed += graph.degree(currentNodeId, useTransposed); for (long neighborNodeId : new Neighbors(graph, useTransposed, edges, currentNodeId)) { neighborsCnt++; if (!visited.contains(neighborNodeId)) { stack.push(neighborNodeId); visited.add(neighborNodeId); } } if (neighborsCnt == 0) { cb.accept(currentNodeId); } } } /** * Returns the leaves of a subgraph rooted at the specified source node. * * @param srcNodeId source node * @return list of node ids corresponding to the leaves */ public ArrayList leaves(long srcNodeId) { ArrayList nodeIds = new ArrayList(); leavesVisitor(srcNodeId, (nodeId) -> nodeIds.add(nodeId)); return nodeIds; } /** * Push version of {@link neighbors}: will fire passed callback on each * neighbor. */ public void neighborsVisitor(long srcNodeId, NodeIdConsumer cb) { this.nbEdgesAccessed = graph.degree(srcNodeId, useTransposed); for (long neighborNodeId : new Neighbors(graph, useTransposed, edges, srcNodeId)) { cb.accept(neighborNodeId); } } /** * Returns node direct neighbors (linked with exactly one edge). * * @param srcNodeId source node * @return list of node ids corresponding to the neighbors */ public ArrayList neighbors(long srcNodeId) { ArrayList nodeIds = new ArrayList(); neighborsVisitor(srcNodeId, (nodeId) -> nodeIds.add(nodeId)); return nodeIds; } /** * Push version of {@link visitNodes}: will fire passed callback on each * visited node. 
*/ - public void visitNodesVisitor(long srcNodeId, NodeIdConsumer cb) { + public void visitNodesVisitor(long srcNodeId, NodeIdConsumer nodeCb, EdgeIdConsumer edgeCb) { Stack stack = new Stack(); this.nbEdgesAccessed = 0; stack.push(srcNodeId); visited.add(srcNodeId); while (!stack.isEmpty()) { long currentNodeId = stack.pop(); - cb.accept(currentNodeId); + if (nodeCb != null) { + nodeCb.accept(currentNodeId); + } nbEdgesAccessed += graph.degree(currentNodeId, useTransposed); for (long neighborNodeId : new Neighbors(graph, useTransposed, edges, currentNodeId)) { if (!visited.contains(neighborNodeId)) { stack.push(neighborNodeId); visited.add(neighborNodeId); + if (edgeCb != null) { + edgeCb.accept(currentNodeId, neighborNodeId); + } } } } } + /** One-argument version to handle callbacks properly */ + public void visitNodesVisitor(long srcNodeId, NodeIdConsumer cb) { + visitNodesVisitor(srcNodeId, cb, null); + } + /** * Performs a graph traversal and returns explored nodes. * * @param srcNodeId source node * @return list of explored node ids */ public ArrayList visitNodes(long srcNodeId) { ArrayList nodeIds = new ArrayList(); visitNodesVisitor(srcNodeId, (nodeId) -> nodeIds.add(nodeId)); return nodeIds; } /** * Push version of {@link visitPaths}: will fire passed callback on each * discovered (complete) path. */ public void visitPathsVisitor(long srcNodeId, PathConsumer cb) { Stack currentPath = new Stack(); this.nbEdgesAccessed = 0; visitPathsInternalVisitor(srcNodeId, currentPath, cb); } /** * Performs a graph traversal and returns explored paths. 
* * @param srcNodeId source node * @return list of explored paths (represented as a list of node ids) */ public ArrayList> visitPaths(long srcNodeId) { ArrayList> paths = new ArrayList<>(); visitPathsVisitor(srcNodeId, (path) -> paths.add(path)); return paths; } private void visitPathsInternalVisitor(long currentNodeId, Stack currentPath, PathConsumer cb) { currentPath.push(currentNodeId); long visitedNeighbors = 0; nbEdgesAccessed += graph.degree(currentNodeId, useTransposed); for (long neighborNodeId : new Neighbors(graph, useTransposed, edges, currentNodeId)) { visitPathsInternalVisitor(neighborNodeId, currentPath, cb); visitedNeighbors++; } if (visitedNeighbors == 0) { ArrayList path = new ArrayList(); for (long nodeId : currentPath) { path.add(nodeId); } cb.accept(path); } currentPath.pop(); } /** * Performs a graph traversal with backtracking, and returns the first * found path from source to destination. * * @param srcNodeId source node * @param dst destination (either a node or a node type) * @return found path as a list of node ids */ public ArrayList walk(long srcNodeId, T dst, String visitOrder) { long dstNodeId = -1; if (visitOrder.equals("dfs")) { dstNodeId = walkInternalDFS(srcNodeId, dst); } else if (visitOrder.equals("bfs")) { dstNodeId = walkInternalBFS(srcNodeId, dst); } else { throw new IllegalArgumentException("Unknown visit order: " + visitOrder); } if (dstNodeId == -1) { throw new IllegalArgumentException("Cannot find destination: " + dst); } ArrayList nodeIds = backtracking(srcNodeId, dstNodeId); return nodeIds; } /** * Performs a random walk (picking a random successor at each step) from * source to destination. 
* * @param srcNodeId source node * @param dst destination (either a node or a node type) * @return found path as a list of node ids or an empty path to indicate * that no suitable path have been found */ public ArrayList randomWalk(long srcNodeId, T dst) { return randomWalk(srcNodeId, dst, 0); } /** * Performs a stubborn random walk (picking a random successor at each * step) from source to destination. The walk is "stubborn" in the sense * that it will not give up the first time if a satisfying target node is * found, but it will retry up to a limited amount of times. * * @param srcNodeId source node * @param dst destination (either a node or a node type) * @param retries number of times to retry; 0 means no retries (single walk) * @return found path as a list of node ids or an empty path to indicate * that no suitable path have been found */ public ArrayList randomWalk(long srcNodeId, T dst, int retries) { long curNodeId = srcNodeId; ArrayList path = new ArrayList(); this.nbEdgesAccessed = 0; boolean found; if (retries < 0) { throw new IllegalArgumentException("Negative number of retries given: " + retries); } while (true) { path.add(curNodeId); Neighbors neighbors = new Neighbors(graph, useTransposed, edges, curNodeId); curNodeId = randomPick(neighbors.iterator()); if (curNodeId < 0) { found = false; break; } if (isDstNode(curNodeId, dst)) { path.add(curNodeId); found = true; break; } } if (found) { return path; } else if (retries > 0) { // try again return randomWalk(srcNodeId, dst, retries - 1); } else { // not found and no retries left path.clear(); return path; } } /** * Randomly choose an element from an iterator over Longs using reservoir * sampling * * @param elements iterator over selection domain * @return randomly chosen element or -1 if no suitable element was found */ private long randomPick(Iterator elements) { long curPick = -1; long seenCandidates = 0; while (elements.hasNext()) { seenCandidates++; if (Math.round(rng.nextFloat() * (seenCandidates 
- 1)) == 0) { curPick = elements.next(); } } return curPick; } /** * Internal DFS function of {@link #walk}. * * @param srcNodeId source node * @param dst destination (either a node or a node type) * @return final destination node or -1 if no path found */ private long walkInternalDFS(long srcNodeId, T dst) { Stack stack = new Stack(); this.nbEdgesAccessed = 0; stack.push(srcNodeId); visited.add(srcNodeId); while (!stack.isEmpty()) { long currentNodeId = stack.pop(); if (isDstNode(currentNodeId, dst)) { return currentNodeId; } nbEdgesAccessed += graph.degree(currentNodeId, useTransposed); for (long neighborNodeId : new Neighbors(graph, useTransposed, edges, currentNodeId)) { if (!visited.contains(neighborNodeId)) { stack.push(neighborNodeId); visited.add(neighborNodeId); parentNode.put(neighborNodeId, currentNodeId); } } } return -1; } /** * Internal BFS function of {@link #walk}. * * @param srcNodeId source node * @param dst destination (either a node or a node type) * @return final destination node or -1 if no path found */ private long walkInternalBFS(long srcNodeId, T dst) { Queue queue = new LinkedList(); this.nbEdgesAccessed = 0; queue.add(srcNodeId); visited.add(srcNodeId); while (!queue.isEmpty()) { long currentNodeId = queue.poll(); if (isDstNode(currentNodeId, dst)) { return currentNodeId; } nbEdgesAccessed += graph.degree(currentNodeId, useTransposed); for (long neighborNodeId : new Neighbors(graph, useTransposed, edges, currentNodeId)) { if (!visited.contains(neighborNodeId)) { queue.add(neighborNodeId); visited.add(neighborNodeId); parentNode.put(neighborNodeId, currentNodeId); } } } return -1; } /** * Internal function of {@link #walk} to check if a node corresponds to the destination. 
* * @param nodeId current node * @param dst destination (either a node or a node type) * @return true if the node is a destination, or false otherwise */ private boolean isDstNode(long nodeId, T dst) { if (dst instanceof Long) { long dstNodeId = (Long) dst; return nodeId == dstNodeId; } else if (dst instanceof Node.Type) { Node.Type dstType = (Node.Type) dst; return graph.getNodeType(nodeId) == dstType; } else { return false; } } /** * Internal backtracking function of {@link #walk}. * * @param srcNodeId source node * @param dstNodeId destination node * @return the found path, as a list of node ids */ private ArrayList backtracking(long srcNodeId, long dstNodeId) { ArrayList path = new ArrayList(); long currentNodeId = dstNodeId; while (currentNodeId != srcNodeId) { path.add(currentNodeId); currentNodeId = parentNode.get(currentNodeId); } path.add(srcNodeId); Collections.reverse(path); return path; } public Long findCommonDescendant(long lhsNode, long rhsNode) { Queue lhsStack = new ArrayDeque<>(); Queue rhsStack = new ArrayDeque<>(); HashSet lhsVisited = new HashSet<>(); HashSet rhsVisited = new HashSet<>(); lhsStack.add(lhsNode); rhsStack.add(rhsNode); lhsVisited.add(lhsNode); rhsVisited.add(rhsNode); this.nbEdgesAccessed = 0; Long curNode; while (!lhsStack.isEmpty() || !rhsStack.isEmpty()) { if (!lhsStack.isEmpty()) { curNode = lhsStack.poll(); nbEdgesAccessed += graph.degree(curNode, useTransposed); for (long neighborNodeId : new Neighbors(graph, useTransposed, edges, curNode)) { if (!lhsVisited.contains(neighborNodeId)) { if (rhsVisited.contains(neighborNodeId)) return neighborNodeId; lhsStack.add(neighborNodeId); lhsVisited.add(neighborNodeId); } } } if (!rhsStack.isEmpty()) { curNode = rhsStack.poll(); nbEdgesAccessed += graph.degree(curNode, useTransposed); for (long neighborNodeId : new Neighbors(graph, useTransposed, edges, curNode)) { if (!rhsVisited.contains(neighborNodeId)) { if (lhsVisited.contains(neighborNodeId)) return neighborNodeId; 
rhsStack.add(neighborNodeId); rhsVisited.add(neighborNodeId); } } } } return null; } } diff --git a/swh/graph/backend.py b/swh/graph/backend.py index 370dd33..d838023 100644 --- a/swh/graph/backend.py +++ b/swh/graph/backend.py @@ -1,183 +1,194 @@ # Copyright (C) 2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import asyncio import contextlib import io import os import struct import subprocess import sys import tempfile from py4j.java_gateway import JavaGateway from swh.graph.config import check_config from swh.graph.pid import NodeToPidMap, PidToNodeMap from swh.model.identifiers import PID_TYPES BUF_SIZE = 64 * 1024 BIN_FMT = ">q" # 64 bit integer, big endian PATH_SEPARATOR_ID = -1 NODE2PID_EXT = "node2pid.bin" PID2NODE_EXT = "pid2node.bin" def _get_pipe_stderr(): # Get stderr if possible, or pipe to stdout if running with Jupyter. try: sys.stderr.fileno() except io.UnsupportedOperation: return subprocess.STDOUT else: return sys.stderr class Backend: def __init__(self, graph_path, config=None): self.gateway = None self.entry = None self.graph_path = graph_path self.config = check_config(config or {}) def __enter__(self): self.gateway = JavaGateway.launch_gateway( java_path=None, javaopts=self.config["java_tool_options"].split(), classpath=self.config["classpath"], die_on_exit=True, redirect_stdout=sys.stdout, redirect_stderr=_get_pipe_stderr(), ) self.entry = self.gateway.jvm.org.softwareheritage.graph.Entry() self.entry.load_graph(self.graph_path) self.node2pid = NodeToPidMap(self.graph_path + "." + NODE2PID_EXT) self.pid2node = PidToNodeMap(self.graph_path + "." 
+ PID2NODE_EXT) self.stream_proxy = JavaStreamProxy(self.entry) return self def __exit__(self, exc_type, exc_value, tb): self.gateway.shutdown() def stats(self): return self.entry.stats() def count(self, ttype, direction, edges_fmt, src): method = getattr(self.entry, "count_" + ttype) return method(direction, edges_fmt, src) async def simple_traversal(self, ttype, direction, edges_fmt, src): assert ttype in ("leaves", "neighbors", "visit_nodes") method = getattr(self.stream_proxy, ttype) async for node_id in method(direction, edges_fmt, src): yield node_id async def walk(self, direction, edges_fmt, algo, src, dst): if dst in PID_TYPES: it = self.stream_proxy.walk_type(direction, edges_fmt, algo, src, dst) else: it = self.stream_proxy.walk(direction, edges_fmt, algo, src, dst) async for node_id in it: yield node_id async def random_walk(self, direction, edges_fmt, retries, src, dst): if dst in PID_TYPES: it = self.stream_proxy.random_walk_type( direction, edges_fmt, retries, src, dst ) else: it = self.stream_proxy.random_walk(direction, edges_fmt, retries, src, dst) async for node_id in it: # TODO return 404 if path is empty yield node_id + async def visit_edges(self, direction, edges_fmt, src): + it = self.stream_proxy.visit_edges(direction, edges_fmt, src) + # convert stream a, b, c, d -> (a, b), (c, d) + prevNode = None + async for node in it: + if prevNode is not None: + yield (prevNode, node) + prevNode = None + else: + prevNode = node + async def visit_paths(self, direction, edges_fmt, src): path = [] async for node in self.stream_proxy.visit_paths(direction, edges_fmt, src): if node == PATH_SEPARATOR_ID: yield path path = [] else: path.append(node) class JavaStreamProxy: """A proxy class for the org.softwareheritage.graph.Entry Java class that takes care of the setup and teardown of the named-pipe FIFO communication between Python and Java. 
Initialize JavaStreamProxy using: proxy = JavaStreamProxy(swh_entry_class_instance) Then you can call an Entry method and iterate on the FIFO results like this: async for value in proxy.java_method(arg1, arg2): print(value) """ def __init__(self, entry): self.entry = entry async def read_node_ids(self, fname): loop = asyncio.get_event_loop() open_thread = loop.run_in_executor(None, open, fname, "rb") # Since the open() call on the FIFO is blocking until it is also opened # on the Java side, we await it with a timeout in case there is an # exception that prevents the write-side open(). with (await asyncio.wait_for(open_thread, timeout=2)) as f: while True: data = await loop.run_in_executor(None, f.read, BUF_SIZE) if not data: break for data in struct.iter_unpack(BIN_FMT, data): yield data[0] class _HandlerWrapper: def __init__(self, handler): self._handler = handler def __getattr__(self, name): func = getattr(self._handler, name) async def java_call(*args, **kwargs): loop = asyncio.get_event_loop() await loop.run_in_executor(None, lambda: func(*args, **kwargs)) def java_task(*args, **kwargs): return asyncio.create_task(java_call(*args, **kwargs)) return java_task @contextlib.contextmanager def get_handler(self): with tempfile.TemporaryDirectory(prefix="swh-graph-") as tmpdirname: cli_fifo = os.path.join(tmpdirname, "swh-graph.fifo") os.mkfifo(cli_fifo) reader = self.read_node_ids(cli_fifo) query_handler = self.entry.get_handler(cli_fifo) handler = self._HandlerWrapper(query_handler) yield (handler, reader) def __getattr__(self, name): async def java_call_iterator(*args, **kwargs): with self.get_handler() as (handler, reader): java_task = getattr(handler, name)(*args, **kwargs) try: async for value in reader: yield value except asyncio.TimeoutError: # If the read-side open() timeouts, an exception on the # Java side probably happened that prevented the # write-side open(). We propagate this exception here if # that is the case. 
task_exc = java_task.exception() if task_exc: raise task_exc raise await java_task return java_call_iterator diff --git a/swh/graph/client.py b/swh/graph/client.py index 5821dfe..26f05ed 100644 --- a/swh/graph/client.py +++ b/swh/graph/client.py @@ -1,103 +1,111 @@ # Copyright (C) 2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import json from swh.core.api import RPCClient class GraphAPIError(Exception): """Graph API Error""" def __str__(self): return "An unexpected error occurred in the Graph backend: {}".format(self.args) class RemoteGraphClient(RPCClient): """Client to the Software Heritage Graph.""" def __init__(self, url, timeout=None): super().__init__(api_exception=GraphAPIError, url=url, timeout=timeout) def raw_verb_lines(self, verb, endpoint, **kwargs): response = self.raw_verb(verb, endpoint, stream=True, **kwargs) self.raise_for_status(response) for line in response.iter_lines(): yield line.decode().lstrip("\n") def get_lines(self, endpoint, **kwargs): yield from self.raw_verb_lines("get", endpoint, **kwargs) # Web API endpoints def stats(self): return self.get("stats") def leaves(self, src, edges="*", direction="forward"): return self.get_lines( "leaves/{}".format(src), params={"edges": edges, "direction": direction} ) def neighbors(self, src, edges="*", direction="forward"): return self.get_lines( "neighbors/{}".format(src), params={"edges": edges, "direction": direction} ) def visit_nodes(self, src, edges="*", direction="forward"): return self.get_lines( "visit/nodes/{}".format(src), params={"edges": edges, "direction": direction}, ) + def visit_edges(self, src, edges="*", direction="forward"): + for edge in self.get_lines( + "visit/edges/{}".format(src), + params={"edges": edges, "direction": direction}, + ): + print(edge) + yield tuple(edge.split()) + def 
visit_paths(self, src, edges="*", direction="forward"): def decode_path_wrapper(it): for e in it: yield json.loads(e) return decode_path_wrapper( self.get_lines( "visit/paths/{}".format(src), params={"edges": edges, "direction": direction}, ) ) def walk( self, src, dst, edges="*", traversal="dfs", direction="forward", limit=None ): endpoint = "walk/{}/{}" return self.get_lines( endpoint.format(src, dst), params={ "edges": edges, "traversal": traversal, "direction": direction, "limit": limit, }, ) def random_walk(self, src, dst, edges="*", direction="forward", limit=None): endpoint = "randomwalk/{}/{}" return self.get_lines( endpoint.format(src, dst), params={"edges": edges, "direction": direction, "limit": limit}, ) def count_leaves(self, src, edges="*", direction="forward"): return self.get( "leaves/count/{}".format(src), params={"edges": edges, "direction": direction}, ) def count_neighbors(self, src, edges="*", direction="forward"): return self.get( "neighbors/count/{}".format(src), params={"edges": edges, "direction": direction}, ) def count_visit_nodes(self, src, edges="*", direction="forward"): return self.get( "visit/nodes/count/{}".format(src), params={"edges": edges, "direction": direction}, ) diff --git a/swh/graph/graph.py b/swh/graph/graph.py index f26bb31..fda4c68 100644 --- a/swh/graph/graph.py +++ b/swh/graph/graph.py @@ -1,179 +1,185 @@ # Copyright (C) 2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import asyncio import contextlib import functools from swh.graph.backend import Backend from swh.graph.dot import dot_to_svg, graph_dot, KIND_TO_SHAPE BASE_URL = "https://archive.softwareheritage.org/browse" KIND_TO_URL_FRAGMENT = { "ori": "/origin/{}", "snp": "/snapshot/{}", "rel": "/release/{}", "rev": "/revision/{}", "dir": "/directory/{}", "cnt": "/content/sha1_git:{}/", } 
def _get_event_loop():
    """Return the current event loop, creating and installing one if needed.

    ``asyncio.get_event_loop()`` raises ``RuntimeError`` when no loop is set
    (e.g. in non-main threads, and in the main thread on recent Python
    versions); in that case a fresh loop is created and made current.
    """
    try:
        return asyncio.get_event_loop()
    except RuntimeError:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        return loop


def call_async_gen(generator, *args, **kwargs):
    """Consume an async generator from synchronous code.

    Calls ``generator(*args, **kwargs)``, then repeatedly runs its
    ``__anext__()`` to completion on the event loop, yielding each result
    until ``StopAsyncIteration`` is raised.

    If the consumer stops early (``close()``, ``break`` followed by garbage
    collection, or an exception), the underlying async generator is closed
    explicitly so that its cleanup code runs right away instead of being
    deferred.
    """
    loop = _get_event_loop()
    it = generator(*args, **kwargs).__aiter__()
    try:
        while True:
            try:
                res = loop.run_until_complete(it.__anext__())
            except StopAsyncIteration:
                break
            yield res
    finally:
        # Plain async iterators have no aclose(); only async generators do.
        aclose = getattr(it, "aclose", None)
        if aclose is not None and not loop.is_closed() and not loop.is_running():
            loop.run_until_complete(aclose())


class Neighbors:
    """Neighbor iterator with custom O(1) length method

    Wraps an iterator object exposing ``nextLong()``; the value ``-1`` marks
    the end of the iteration. The length is not computed by iterating but by
    calling ``length_func``.
    """

    def __init__(self, graph, iterator, length_func):
        self.graph = graph
        self.iterator = iterator
        self.length_func = length_func

    def __iter__(self):
        return self

    def __next__(self):
        succ = self.iterator.nextLong()
        if succ == -1:  # end-of-iteration sentinel
            raise StopIteration
        return GraphNode(self.graph, succ)

    def __len__(self):
        return self.length_func()


class GraphNode:
    """Node in the SWH graph"""

    def __init__(self, graph, node_id):
        self.graph = graph
        self.id = node_id

    def children(self):
        """Return a :class:`Neighbors` iterator over the successors of this node."""
        return Neighbors(
            self.graph,
            self.graph.java_graph.successors(self.id),
            lambda: self.graph.java_graph.outdegree(self.id),
        )

    def parents(self):
        """Return a :class:`Neighbors` iterator over the predecessors of this node."""
        return Neighbors(
            self.graph,
            self.graph.java_graph.predecessors(self.id),
            lambda: self.graph.java_graph.indegree(self.id),
        )

    def simple_traversal(self, ttype, direction="forward", edges="*"):
        """Run the backend traversal ``ttype`` from this node, yielding nodes."""
        for node in call_async_gen(
            self.graph.backend.simple_traversal, ttype, direction, edges, self.id
        ):
            yield self.graph[node]

    def leaves(self, *args, **kwargs):
        """Yield the nodes produced by the ``"leaves"`` traversal."""
        yield from self.simple_traversal("leaves", *args, **kwargs)

    def visit_nodes(self, *args, **kwargs):
        """Yield the nodes produced by the ``"visit_nodes"`` traversal."""
        yield from self.simple_traversal("visit_nodes", *args, **kwargs)

    def visit_edges(self, direction="forward", edges="*"):
        """Yield ``(source, destination)`` node pairs from the backend edge visit."""
        for src, dst in call_async_gen(
            self.graph.backend.visit_edges, direction, edges, self.id
        ):
            yield (self.graph[src], self.graph[dst])

    def visit_paths(self, direction="forward", edges="*"):
        """Yield each path reported by the backend as a list of nodes."""
        for path in call_async_gen(
            self.graph.backend.visit_paths, direction, edges, self.id
        ):
            yield [self.graph[node] for node in path]

    def walk(self, dst, direction="forward", edges="*", traversal="dfs"):
        """Yield the nodes of the backend walk from this node towards ``dst``."""
        for node in call_async_gen(
            self.graph.backend.walk, direction, edges, traversal, self.id, dst
        ):
            yield self.graph[node]

    def _count(self, ttype, direction="forward", edges="*"):
        return self.graph.backend.count(ttype, direction, edges, self.id)

    # Explicit wrappers rather than functools.partialmethod(_count, ttype=...):
    # binding ``ttype`` by keyword made positional calls such as
    # ``node.count_leaves("backward")`` fail with "multiple values for ttype".
    def count_leaves(self, direction="forward", edges="*"):
        """Return the backend count for the ``"leaves"`` traversal."""
        return self._count("leaves", direction, edges)

    def count_neighbors(self, direction="forward", edges="*"):
        """Return the backend count for the ``"neighbors"`` traversal."""
        return self._count("neighbors", direction, edges)

    def count_visit_nodes(self, direction="forward", edges="*"):
        """Return the backend count for the ``"visit_nodes"`` traversal."""
        return self._count("visit_nodes", direction, edges)

    @property
    def pid(self):
        return self.graph.node2pid[self.id]

    @property
    def kind(self):
        # The third colon-separated field of the PID, e.g. "rev" or "dir".
        return self.pid.split(":")[2]

    def __str__(self):
        return self.pid

    def __repr__(self):
        return "<{}>".format(self.pid)

    def dot_fragment(self):
        """Return the Graphviz statement describing this node."""
        _swh, _version, kind, hash_ = self.pid.split(":")
        label = "{}:{}..{}".format(kind, hash_[0:2], hash_[-2:])
        url = BASE_URL + KIND_TO_URL_FRAGMENT[kind].format(hash_)
        shape = KIND_TO_SHAPE[kind]
        return '{} [label="{}", href="{}", target="_blank", shape="{}"];'.format(
            self.id, label, url, shape
        )

    def _repr_svg_(self):
        # Render this node together with its direct children and parents.
        nodes = [self, *self.children(), *self.parents()]
        dot = graph_dot(nodes)
        svg = dot_to_svg(dot)
        return svg


class Graph:
    """Graph facade giving access to nodes by integer id or by PID string."""

    def __init__(self, backend, node2pid, pid2node):
        self.backend = backend
        self.java_graph = backend.entry.get_graph()
        self.node2pid = node2pid
        self.pid2node = pid2node

    def stats(self):
        return self.backend.stats()

    @property
    def path(self):
        return self.java_graph.getPath()

    def __len__(self):
        return self.java_graph.getNbNodes()

    def __getitem__(self, node_id):
        """Return the :class:`GraphNode` for an integer node id or a PID string.

        Lookup errors from ``node2pid`` / ``pid2node`` propagate unchanged.
        Raises ``TypeError`` for any other key type instead of silently
        returning ``None``.
        """
        if isinstance(node_id, int):
            self.node2pid[node_id]  # check existence
            return GraphNode(self, node_id)
        elif isinstance(node_id, str):
            node_id = self.pid2node[node_id]
            return GraphNode(self, node_id)
        raise TypeError(
            "graph indices must be int node ids or str PIDs, not {}".format(
                type(node_id).__name__
            )
        )

    def __iter__(self):
        for pid, _pos in self.backend.pid2node:
            yield self[pid]

    def iter_prefix(self, prefix):
        for pid, _pos in self.backend.pid2node.iter_prefix(prefix):
            yield self[pid]

    def iter_type(self, pid_type):
        for pid, _pos in self.backend.pid2node.iter_type(pid_type):
            yield self[pid]


@contextlib.contextmanager
def load(graph_path):
    """Context manager yielding a :class:`Graph` backed by ``Backend(graph_path)``."""
    with Backend(graph_path) as backend:
        yield Graph(backend, backend.node2pid, backend.pid2node)
diff --git a/swh/graph/server/app.py b/swh/graph/server/app.py index 5e99547..b0d0435 100644 --- a/swh/graph/server/app.py +++ b/swh/graph/server/app.py @@ -1,253 +1,272 @@ # Copyright (C) 2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information """ A proxy HTTP server for swh-graph, talking to the Java code via py4j, and using FIFO as a transport to stream integers between the two languages. """ import asyncio import json import aiohttp.web from collections import deque from swh.core.api.asynchronous import RPCServerApp from swh.model.identifiers import PID_TYPES from swh.model.exceptions import ValidationError try: from contextlib import asynccontextmanager except ImportError: # Compatibility with 3.6 backport from async_generator import asynccontextmanager # type: ignore # maximum number of retries for random walks RANDOM_RETRIES = 5 # TODO make this configurable via rpc-serve configuration @asynccontextmanager async def stream_response(request, content_type="text/plain", *args, **kwargs): response = aiohttp.web.StreamResponse(*args, **kwargs) response.content_type = content_type await response.prepare(request) yield response await response.write_eof() async def index(request): return aiohttp.web.Response( content_type="text/html", body=""" Software Heritage storage server

You have reached the Software Heritage graph API server.

See its API documentation for more information.

""", ) async def stats(request): stats = request.app["backend"].stats() return aiohttp.web.Response(body=stats, content_type="application/json") def get_direction(request): """validate HTTP query parameter `direction`""" s = request.query.get("direction", "forward") if s not in ("forward", "backward"): raise aiohttp.web.HTTPBadRequest(body=f"invalid direction: {s}") return s def get_edges(request): """validate HTTP query parameter `edges`, i.e., edge restrictions""" s = request.query.get("edges", "*") if any( [ node_type != "*" and node_type not in PID_TYPES for edge in s.split(":") for node_type in edge.split(",", maxsplit=1) ] ): raise aiohttp.web.HTTPBadRequest(body=f"invalid edge restriction: {s}") return s def get_traversal(request): """validate HTTP query parameter `traversal`, i.e., visit order""" s = request.query.get("traversal", "dfs") if s not in ("bfs", "dfs"): raise aiohttp.web.HTTPBadRequest(body=f"invalid traversal order: {s}") return s def get_limit(request): """validate HTTP query parameter `limit`, i.e., number of results""" s = request.query.get("limit", "0") try: return int(s) except ValueError: raise aiohttp.web.HTTPBadRequest(body=f"invalid limit value: {s}") def node_of_pid(pid, backend): """lookup a PID in a pid2node map, failing in an HTTP-nice way if needed""" try: return backend.pid2node[pid] except KeyError: raise aiohttp.web.HTTPNotFound(body=f"PID not found: {pid}") except ValidationError: raise aiohttp.web.HTTPBadRequest(body=f"malformed PID: {pid}") def pid_of_node(node, backend): """lookup a node in a node2pid map, failing in an HTTP-nice way if needed """ try: return backend.node2pid[node] except KeyError: raise aiohttp.web.HTTPInternalServerError( body=f"reverse lookup failed for node id: {node}" ) def get_simple_traversal_handler(ttype): async def simple_traversal(request): backend = request.app["backend"] src = request.match_info["src"] edges = get_edges(request) direction = get_direction(request) src_node = node_of_pid(src, 
backend) async with stream_response(request) as response: async for res_node in backend.simple_traversal( ttype, direction, edges, src_node ): res_pid = pid_of_node(res_node, backend) await response.write("{}\n".format(res_pid).encode()) return response return simple_traversal def get_walk_handler(random=False): async def walk(request): backend = request.app["backend"] src = request.match_info["src"] dst = request.match_info["dst"] edges = get_edges(request) direction = get_direction(request) algo = get_traversal(request) limit = get_limit(request) src_node = node_of_pid(src, backend) if dst not in PID_TYPES: dst = node_of_pid(dst, backend) async with stream_response(request) as response: if random: it = backend.random_walk( direction, edges, RANDOM_RETRIES, src_node, dst ) else: it = backend.walk(direction, edges, algo, src_node, dst) if limit < 0: queue = deque(maxlen=-limit) async for res_node in it: res_pid = pid_of_node(res_node, backend) queue.append("{}\n".format(res_pid).encode()) while queue: await response.write(queue.popleft()) else: count = 0 async for res_node in it: if limit == 0 or count < limit: res_pid = pid_of_node(res_node, backend) await response.write("{}\n".format(res_pid).encode()) count += 1 else: break return response return walk async def visit_paths(request): backend = request.app["backend"] src = request.match_info["src"] edges = get_edges(request) direction = get_direction(request) src_node = node_of_pid(src, backend) it = backend.visit_paths(direction, edges, src_node) async with stream_response( request, content_type="application/x-ndjson" ) as response: async for res_path in it: res_path_pid = [pid_of_node(n, backend) for n in res_path] line = json.dumps(res_path_pid) await response.write("{}\n".format(line).encode()) return response +async def visit_edges(request): + backend = request.app["backend"] + + src = request.match_info["src"] + edges = get_edges(request) + direction = get_direction(request) + + src_node = node_of_pid(src, 
backend) + it = backend.visit_edges(direction, edges, src_node) + print(it) + async with stream_response(request) as response: + async for (res_src, res_dst) in it: + res_src_pid = pid_of_node(res_src, backend) + res_dst_pid = pid_of_node(res_dst, backend) + await response.write("{} {}\n".format(res_src_pid, res_dst_pid).encode()) + return response + + def get_count_handler(ttype): async def count(request): loop = asyncio.get_event_loop() backend = request.app["backend"] src = request.match_info["src"] edges = get_edges(request) direction = get_direction(request) src_node = node_of_pid(src, backend) cnt = await loop.run_in_executor( None, backend.count, ttype, direction, edges, src_node ) return aiohttp.web.Response(body=str(cnt), content_type="application/json") return count def make_app(backend, **kwargs): app = RPCServerApp(**kwargs) app.router.add_get("/", index) app.router.add_get("/graph", index) app.router.add_get("/graph/stats", stats) app.router.add_get("/graph/leaves/{src}", get_simple_traversal_handler("leaves")) app.router.add_get( "/graph/neighbors/{src}", get_simple_traversal_handler("neighbors") ) app.router.add_get( "/graph/visit/nodes/{src}", get_simple_traversal_handler("visit_nodes") ) + app.router.add_get("/graph/visit/edges/{src}", visit_edges) app.router.add_get("/graph/visit/paths/{src}", visit_paths) # temporarily disabled in wait of a proper fix for T1969 # app.router.add_get('/graph/walk/{src}/{dst}', # get_walk_handler(random=False)) # app.router.add_get('/graph/walk/last/{src}/{dst}', # get_walk_handler(random=False, last=True)) app.router.add_get("/graph/randomwalk/{src}/{dst}", get_walk_handler(random=True)) app.router.add_get("/graph/neighbors/count/{src}", get_count_handler("neighbors")) app.router.add_get("/graph/leaves/count/{src}", get_count_handler("leaves")) app.router.add_get( "/graph/visit/nodes/count/{src}", get_count_handler("visit_nodes") ) app["backend"] = backend return app diff --git a/swh/graph/tests/test_api_client.py 
b/swh/graph/tests/test_api_client.py index 7651ba1..642cf6e 100644 --- a/swh/graph/tests/test_api_client.py +++ b/swh/graph/tests/test_api_client.py @@ -1,230 +1,258 @@ import pytest from pytest import raises from swh.core.api import RemoteException def test_stats(graph_client): stats = graph_client.stats() assert set(stats.keys()) == {"counts", "ratios", "indegree", "outdegree"} assert set(stats["counts"].keys()) == {"nodes", "edges"} assert set(stats["ratios"].keys()) == { "compression", "bits_per_node", "bits_per_edge", "avg_locality", } assert set(stats["indegree"].keys()) == {"min", "max", "avg"} assert set(stats["outdegree"].keys()) == {"min", "max", "avg"} assert stats["counts"]["nodes"] == 21 assert stats["counts"]["edges"] == 23 assert isinstance(stats["ratios"]["compression"], float) assert isinstance(stats["ratios"]["bits_per_node"], float) assert isinstance(stats["ratios"]["bits_per_edge"], float) assert isinstance(stats["ratios"]["avg_locality"], float) assert stats["indegree"]["min"] == 0 assert stats["indegree"]["max"] == 3 assert isinstance(stats["indegree"]["avg"], float) assert stats["outdegree"]["min"] == 0 assert stats["outdegree"]["max"] == 3 assert isinstance(stats["outdegree"]["avg"], float) def test_leaves(graph_client): actual = list( graph_client.leaves("swh:1:ori:0000000000000000000000000000000000000021") ) expected = [ "swh:1:cnt:0000000000000000000000000000000000000001", "swh:1:cnt:0000000000000000000000000000000000000004", "swh:1:cnt:0000000000000000000000000000000000000005", "swh:1:cnt:0000000000000000000000000000000000000007", ] assert set(actual) == set(expected) def test_neighbors(graph_client): actual = list( graph_client.neighbors( "swh:1:rev:0000000000000000000000000000000000000009", direction="backward" ) ) expected = [ "swh:1:snp:0000000000000000000000000000000000000020", "swh:1:rel:0000000000000000000000000000000000000010", "swh:1:rev:0000000000000000000000000000000000000013", ] assert set(actual) == set(expected) def 
test_visit_nodes(graph_client): actual = list( graph_client.visit_nodes( "swh:1:rel:0000000000000000000000000000000000000010", edges="rel:rev,rev:rev", ) ) expected = [ "swh:1:rel:0000000000000000000000000000000000000010", "swh:1:rev:0000000000000000000000000000000000000009", "swh:1:rev:0000000000000000000000000000000000000003", ] assert set(actual) == set(expected) +def test_visit_edges(graph_client): + actual = list( + graph_client.visit_edges( + "swh:1:rel:0000000000000000000000000000000000000010", + edges="rel:rev,rev:rev,rev:dir", + ) + ) + expected = [ + ( + "swh:1:rel:0000000000000000000000000000000000000010", + "swh:1:rev:0000000000000000000000000000000000000009", + ), + ( + "swh:1:rev:0000000000000000000000000000000000000009", + "swh:1:rev:0000000000000000000000000000000000000003", + ), + ( + "swh:1:rev:0000000000000000000000000000000000000009", + "swh:1:dir:0000000000000000000000000000000000000008", + ), + ( + "swh:1:rev:0000000000000000000000000000000000000003", + "swh:1:dir:0000000000000000000000000000000000000002", + ), + ] + assert set(actual) == set(expected) + + def test_visit_paths(graph_client): actual = list( graph_client.visit_paths( "swh:1:snp:0000000000000000000000000000000000000020", edges="snp:*,rev:*" ) ) actual = [tuple(path) for path in actual] expected = [ ( "swh:1:snp:0000000000000000000000000000000000000020", "swh:1:rev:0000000000000000000000000000000000000009", "swh:1:rev:0000000000000000000000000000000000000003", "swh:1:dir:0000000000000000000000000000000000000002", ), ( "swh:1:snp:0000000000000000000000000000000000000020", "swh:1:rev:0000000000000000000000000000000000000009", "swh:1:dir:0000000000000000000000000000000000000008", ), ( "swh:1:snp:0000000000000000000000000000000000000020", "swh:1:rel:0000000000000000000000000000000000000010", ), ] assert set(actual) == set(expected) @pytest.mark.skip(reason="currently disabled due to T1969") def test_walk(graph_client): args = ("swh:1:dir:0000000000000000000000000000000000000016", 
"rel") kwargs = { "edges": "dir:dir,dir:rev,rev:*", "direction": "backward", "traversal": "bfs", } actual = list(graph_client.walk(*args, **kwargs)) expected = [ "swh:1:dir:0000000000000000000000000000000000000016", "swh:1:dir:0000000000000000000000000000000000000017", "swh:1:rev:0000000000000000000000000000000000000018", "swh:1:rel:0000000000000000000000000000000000000019", ] assert set(actual) == set(expected) kwargs2 = kwargs.copy() kwargs2["limit"] = -1 actual = list(graph_client.walk(*args, **kwargs2)) expected = ["swh:1:rel:0000000000000000000000000000000000000019"] assert set(actual) == set(expected) kwargs2 = kwargs.copy() kwargs2["limit"] = 2 actual = list(graph_client.walk(*args, **kwargs2)) expected = [ "swh:1:dir:0000000000000000000000000000000000000016", "swh:1:dir:0000000000000000000000000000000000000017", ] assert set(actual) == set(expected) def test_random_walk(graph_client): """as the walk is random, we test a visit from a cnt node to the only origin in the dataset, and only check the final node of the path (i.e., the origin) """ args = ("swh:1:cnt:0000000000000000000000000000000000000001", "ori") kwargs = {"direction": "backward"} expected_root = "swh:1:ori:0000000000000000000000000000000000000021" actual = list(graph_client.random_walk(*args, **kwargs)) assert len(actual) > 1 # no origin directly links to a content assert actual[0] == args[0] assert actual[-1] == expected_root kwargs2 = kwargs.copy() kwargs2["limit"] = -1 actual = list(graph_client.random_walk(*args, **kwargs2)) assert actual == [expected_root] kwargs2["limit"] = -2 actual = list(graph_client.random_walk(*args, **kwargs2)) assert len(actual) == 2 assert actual[-1] == expected_root kwargs2["limit"] = 3 actual = list(graph_client.random_walk(*args, **kwargs2)) assert len(actual) == 3 def test_count(graph_client): actual = graph_client.count_leaves( "swh:1:ori:0000000000000000000000000000000000000021" ) assert actual == 4 actual = graph_client.count_visit_nodes( 
"swh:1:rel:0000000000000000000000000000000000000010", edges="rel:rev,rev:rev" ) assert actual == 3 actual = graph_client.count_neighbors( "swh:1:rev:0000000000000000000000000000000000000009", direction="backward" ) assert actual == 3 def test_param_validation(graph_client): with raises(RemoteException) as exc_info: # PID not found list(graph_client.leaves("swh:1:ori:fff0000000000000000000000000000000000021")) assert exc_info.value.response.status_code == 404 with raises(RemoteException) as exc_info: # malformed PID list( graph_client.neighbors("swh:1:ori:fff000000zzzzzz0000000000000000000000021") ) assert exc_info.value.response.status_code == 400 with raises(RemoteException) as exc_info: # malformed edge specificaiton list( graph_client.visit_nodes( "swh:1:dir:0000000000000000000000000000000000000016", edges="dir:notanodetype,dir:rev,rev:*", direction="backward", ) ) assert exc_info.value.response.status_code == 400 with raises(RemoteException) as exc_info: # malformed direction list( graph_client.visit_nodes( "swh:1:dir:0000000000000000000000000000000000000016", edges="dir:dir,dir:rev,rev:*", direction="notadirection", ) ) assert exc_info.value.response.status_code == 400 @pytest.mark.skip(reason="currently disabled due to T1969") def test_param_validation_walk(graph_client): """test validation of walk-specific parameters only""" with raises(RemoteException) as exc_info: # malformed traversal order list( graph_client.walk( "swh:1:dir:0000000000000000000000000000000000000016", "rel", edges="dir:dir,dir:rev,rev:*", direction="backward", traversal="notatraversalorder", ) ) assert exc_info.value.response.status_code == 400 diff --git a/swh/graph/tests/test_graph.py b/swh/graph/tests/test_graph.py index 264a0ea..e54c576 100644 --- a/swh/graph/tests/test_graph.py +++ b/swh/graph/tests/test_graph.py @@ -1,138 +1,166 @@ import pytest def test_graph(graph): assert len(graph) == 21 obj = "swh:1:dir:0000000000000000000000000000000000000008" node = graph[obj] assert 
str(node) == obj assert len(node.children()) == 3 assert len(node.parents()) == 2 actual = {p.pid for p in node.children()} expected = { "swh:1:cnt:0000000000000000000000000000000000000001", "swh:1:dir:0000000000000000000000000000000000000006", "swh:1:cnt:0000000000000000000000000000000000000007", } assert expected == actual actual = {p.pid for p in node.parents()} expected = { "swh:1:rev:0000000000000000000000000000000000000009", "swh:1:dir:0000000000000000000000000000000000000012", } assert expected == actual def test_invalid_pid(graph): with pytest.raises(IndexError): graph[1337] with pytest.raises(IndexError): graph[len(graph) + 1] with pytest.raises(KeyError): graph["swh:1:dir:0000000000000000000000000000000420000012"] def test_leaves(graph): actual = list(graph["swh:1:ori:0000000000000000000000000000000000000021"].leaves()) actual = [p.pid for p in actual] expected = [ "swh:1:cnt:0000000000000000000000000000000000000001", "swh:1:cnt:0000000000000000000000000000000000000004", "swh:1:cnt:0000000000000000000000000000000000000005", "swh:1:cnt:0000000000000000000000000000000000000007", ] assert set(actual) == set(expected) def test_visit_nodes(graph): actual = list( graph["swh:1:rel:0000000000000000000000000000000000000010"].visit_nodes( edges="rel:rev,rev:rev" ) ) actual = [p.pid for p in actual] expected = [ "swh:1:rel:0000000000000000000000000000000000000010", "swh:1:rev:0000000000000000000000000000000000000009", "swh:1:rev:0000000000000000000000000000000000000003", ] assert set(actual) == set(expected) +def test_visit_edges(graph): + actual = list( + graph["swh:1:rel:0000000000000000000000000000000000000010"].visit_edges( + edges="rel:rev,rev:rev,rev:dir" + ) + ) + actual = [(src.pid, dst.pid) for src, dst in actual] + expected = [ + ( + "swh:1:rel:0000000000000000000000000000000000000010", + "swh:1:rev:0000000000000000000000000000000000000009", + ), + ( + "swh:1:rev:0000000000000000000000000000000000000009", + 
"swh:1:rev:0000000000000000000000000000000000000003", + ), + ( + "swh:1:rev:0000000000000000000000000000000000000009", + "swh:1:dir:0000000000000000000000000000000000000008", + ), + ( + "swh:1:rev:0000000000000000000000000000000000000003", + "swh:1:dir:0000000000000000000000000000000000000002", + ), + ] + assert set(actual) == set(expected) + + def test_visit_paths(graph): actual = list( graph["swh:1:snp:0000000000000000000000000000000000000020"].visit_paths( edges="snp:*,rev:*" ) ) actual = [tuple(n.pid for n in path) for path in actual] expected = [ ( "swh:1:snp:0000000000000000000000000000000000000020", "swh:1:rev:0000000000000000000000000000000000000009", "swh:1:rev:0000000000000000000000000000000000000003", "swh:1:dir:0000000000000000000000000000000000000002", ), ( "swh:1:snp:0000000000000000000000000000000000000020", "swh:1:rev:0000000000000000000000000000000000000009", "swh:1:dir:0000000000000000000000000000000000000008", ), ( "swh:1:snp:0000000000000000000000000000000000000020", "swh:1:rel:0000000000000000000000000000000000000010", ), ] assert set(actual) == set(expected) def test_walk(graph): actual = list( graph["swh:1:dir:0000000000000000000000000000000000000016"].walk( "rel", edges="dir:dir,dir:rev,rev:*", direction="backward", traversal="bfs" ) ) actual = [p.pid for p in actual] expected = [ "swh:1:dir:0000000000000000000000000000000000000016", "swh:1:dir:0000000000000000000000000000000000000017", "swh:1:rev:0000000000000000000000000000000000000018", "swh:1:rel:0000000000000000000000000000000000000019", ] assert set(actual) == set(expected) def test_count(graph): assert ( graph["swh:1:ori:0000000000000000000000000000000000000021"].count_leaves() == 4 ) assert ( graph["swh:1:rel:0000000000000000000000000000000000000010"].count_visit_nodes( edges="rel:rev,rev:rev" ) == 3 ) assert ( graph["swh:1:rev:0000000000000000000000000000000000000009"].count_neighbors( direction="backward" ) == 3 ) def test_iter_type(graph): rev_list = list(graph.iter_type("rev")) 
actual = [n.pid for n in rev_list] expected = [ "swh:1:rev:0000000000000000000000000000000000000003", "swh:1:rev:0000000000000000000000000000000000000009", "swh:1:rev:0000000000000000000000000000000000000013", "swh:1:rev:0000000000000000000000000000000000000018", ] assert expected == actual