diff --git a/java/src/main/java/org/softwareheritage/graph/Entry.java b/java/src/main/java/org/softwareheritage/graph/Entry.java index 78b230e..c911581 100644 --- a/java/src/main/java/org/softwareheritage/graph/Entry.java +++ b/java/src/main/java/org/softwareheritage/graph/Entry.java @@ -1,188 +1,196 @@ package org.softwareheritage.graph; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.databind.PropertyNamingStrategy; - import java.io.DataOutputStream; import java.io.FileOutputStream; import java.io.IOException; import java.util.ArrayList; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.PropertyNamingStrategy; + public class Entry { private final long PATH_SEPARATOR_ID = -1; private Graph graph; public void load_graph(String graphBasename) throws IOException { System.err.println("Loading graph " + graphBasename + " ..."); this.graph = new Graph(graphBasename); System.err.println("Graph loaded."); } public Graph get_graph() { return graph.copy(); } public String stats() { try { Stats stats = new Stats(graph.getPath()); ObjectMapper objectMapper = new ObjectMapper(); objectMapper.setPropertyNamingStrategy(PropertyNamingStrategy.SNAKE_CASE); return objectMapper.writeValueAsString(stats); } catch (IOException e) { throw new RuntimeException("Cannot read stats: " + e); } } private int count_visitor(NodeCountVisitor f, long srcNodeId) { int[] count = {0}; f.accept(srcNodeId, (node) -> { count[0]++; }); return count[0]; } public int count_leaves(String direction, String edgesFmt, long srcNodeId) { Traversal t = new Traversal(this.graph.copy(), direction, edgesFmt); return count_visitor(t::leavesVisitor, srcNodeId); } public int count_neighbors(String direction, String edgesFmt, long srcNodeId) { Traversal t = new Traversal(this.graph.copy(), direction, edgesFmt); return count_visitor(t::neighborsVisitor, srcNodeId); } public int count_visit_nodes(String direction, String edgesFmt, long srcNodeId) { Traversal t = new Traversal(this.graph.copy(), direction, edgesFmt); return count_visitor(t::visitNodesVisitor, srcNodeId); } public QueryHandler get_handler(String clientFIFO) { return new QueryHandler(this.graph.copy(), clientFIFO); } private interface NodeCountVisitor { void accept(long nodeId, Traversal.NodeIdConsumer consumer); } public class QueryHandler { Graph graph; DataOutputStream out; String clientFIFO; public QueryHandler(Graph graph, String clientFIFO) { this.graph = graph; this.clientFIFO = clientFIFO; this.out = null; } public void writeNode(long nodeId) { try { out.writeLong(nodeId); } catch (IOException e) { throw new RuntimeException("Cannot write response to client: " + e); } } public void writeEdge(long srcId, long dstId) { writeNode(srcId); writeNode(dstId); } public void writePath(ArrayList path) { for (Long nodeId : path) { writeNode(nodeId); } writeNode(PATH_SEPARATOR_ID); } public void open() { try { FileOutputStream file = new FileOutputStream(this.clientFIFO); this.out = new DataOutputStream(file); } catch (IOException e) { throw new RuntimeException("Cannot open client FIFO: " + e); } } public void close() { try { out.close(); } catch (IOException e) { throw new RuntimeException("Cannot write response to client: " + e); } } - public void leaves(String direction, String edgesFmt, long srcNodeId, long maxEdges) { + public void leaves(String direction, String edgesFmt, long srcNodeId, long maxEdges, String returnTypes) { open(); - Traversal t = new Traversal(this.graph, direction, edgesFmt, maxEdges); - t.leavesVisitor(srcNodeId, this::writeNode); + Traversal t = new Traversal(this.graph, direction, edgesFmt, maxEdges, returnTypes); + for (Long nodeId : t.leaves(srcNodeId)) { + writeNode(nodeId); + } close(); } - public void neighbors(String direction, String edgesFmt, long srcNodeId, long maxEdges) { + public void neighbors(String direction, String edgesFmt, long srcNodeId, long maxEdges, String returnTypes) { open(); - Traversal t = new Traversal(this.graph, direction, edgesFmt, maxEdges); - t.neighborsVisitor(srcNodeId, this::writeNode); + Traversal t = new Traversal(this.graph, direction, edgesFmt, maxEdges, returnTypes); + for (Long nodeId : t.neighbors(srcNodeId)) { + writeNode(nodeId); + } close(); } - public void visit_nodes(String direction, String edgesFmt, long srcNodeId, long maxEdges) { + public void visit_nodes(String direction, String edgesFmt, long srcNodeId, long maxEdges, String returnTypes) { open(); - Traversal t = new Traversal(this.graph, direction, edgesFmt, maxEdges); - t.visitNodesVisitor(srcNodeId, this::writeNode); + Traversal t = new Traversal(this.graph, direction, edgesFmt, maxEdges, returnTypes); + for (Long nodeId : t.visitNodes(srcNodeId)) { + writeNode(nodeId); + } close(); } public void visit_edges(String direction, String edgesFmt, long srcNodeId, long maxEdges) { open(); Traversal t = new Traversal(this.graph, direction, edgesFmt, maxEdges); t.visitNodesVisitor(srcNodeId, null, this::writeEdge); close(); } public void visit_paths(String direction, String edgesFmt, long srcNodeId, long maxEdges) { open(); Traversal t = new Traversal(this.graph, direction, edgesFmt, maxEdges); t.visitPathsVisitor(srcNodeId, this::writePath); close(); } public void walk(String direction, String edgesFmt, String algorithm, long srcNodeId, long dstNodeId) { open(); Traversal t = new Traversal(this.graph, direction, edgesFmt); for (Long nodeId : t.walk(srcNodeId, dstNodeId, algorithm)) { writeNode(nodeId); } close(); } public void walk_type(String direction, String edgesFmt, String algorithm, long srcNodeId, String dst) { open(); Node.Type dstType = Node.Type.fromStr(dst); Traversal t = new Traversal(this.graph, direction, edgesFmt); for (Long nodeId : t.walk(srcNodeId, dstType, algorithm)) { writeNode(nodeId); } close(); } - public void random_walk(String direction, String edgesFmt, int retries, long srcNodeId, long dstNodeId) { + public void random_walk(String direction, String edgesFmt, int retries, long srcNodeId, long dstNodeId, + String returnTypes) { open(); - Traversal t = new Traversal(this.graph, direction, edgesFmt); + Traversal t = new Traversal(this.graph, direction, edgesFmt, 0, returnTypes); for (Long nodeId : t.randomWalk(srcNodeId, dstNodeId, retries)) { writeNode(nodeId); } close(); } - public void random_walk_type(String direction, String edgesFmt, int retries, long srcNodeId, String dst) { + public void random_walk_type(String direction, String edgesFmt, int retries, long srcNodeId, String dst, + String returnTypes) { open(); Node.Type dstType = Node.Type.fromStr(dst); - Traversal t = new Traversal(this.graph, direction, edgesFmt); + Traversal t = new Traversal(this.graph, direction, edgesFmt, 0, returnTypes); for (Long nodeId : t.randomWalk(srcNodeId, dstType, retries)) { writeNode(nodeId); } close(); } } } diff --git a/java/src/main/java/org/softwareheritage/graph/NodesFiltering.java b/java/src/main/java/org/softwareheritage/graph/NodesFiltering.java new file mode 100644 index 0000000..3f3e7a3 --- /dev/null +++ b/java/src/main/java/org/softwareheritage/graph/NodesFiltering.java @@ -0,0 +1,107 @@ +package org.softwareheritage.graph; + +import java.util.ArrayList; + +/** + *

NodesFiltering

+ *

+ * class that manages the filtering of nodes that have been returned after a visit of the graph. + * parameterized by a string that represents either no filtering (*) or a set of node types. + *

+ * + * + * + * How to use NodesFiltering : + * + *
+ * {@code
+ *  Long id1 = .... // graph.getNodeType(id1) == CNT
+ *  Long id2 = .... // graph.getNodeType(id2) == SNP
+ *  Long id3 = .... // graph.getNodeType(id3) == ORI
+ *  ArrayList nodeIds = nez ArrayList();
+ *  nodeIds.add(id1); nodeIds.add(id2); nodeIds.add(id3);
+ *
+ *  NodeFiltering nds = new NodesFiltering("snp,ori"); // we allow only snp node types to be shown
+ *  System.out.println(nds.filterByNodeTypes(nodeIds,graph)); // will print id2, id3
+ *
+ *  nds = NodesFiltering("*");
+ *  System.out.println(nds.filterByNodeTypes(nodeIds,graph)); // will print id1, id2 id3
+ *
+ * }
+ * 
+ */ + +public class NodesFiltering { + + boolean restricted; + ArrayList allowedNodesTypes; + + /** + * Default constructor, in order to handle the * case (all types of nodes are allowed to be + * returned). allowedNodesTypes will contains [SNP,CNT....] all types of nodes. + * + */ + public NodesFiltering() { + restricted = false; + allowedNodesTypes = Node.Type.parse("*"); + } + + /** + * Constructor + * + * @param strTypes a formatted string describing the types of nodes we want to allow to be shown. + * + * NodesFilterind("cnt,snp") will set allowedNodesTypes to [CNT,SNP] + * + */ + public NodesFiltering(String strTypes) { + restricted = true; + allowedNodesTypes = new ArrayList(); + String[] types = strTypes.split(","); + for (String type : types) { + allowedNodesTypes.add(Node.Type.fromStr(type)); + } + } + + /** + * Check if the type given in parameter is in the list of allowed types. + * + * @param typ the type of the node. + */ + public boolean typeIsAllowed(Node.Type typ) { + return this.allowedNodesTypes.contains(typ); + } + + /** + *

+ * the function that filters the nodes returned, we browse the list of nodes found after a visit and + * we create a new list with only the nodes that have a type that is contained in the list of + * allowed types (allowedNodesTypes) + *

+ * + * @param nodeIds the nodes founded during the visit + * @param g the graph in order to find the types of nodes from their id in nodeIds + * @return a new list with the id of node which have a type in allowedTypes + * + * + */ + public ArrayList filterByNodeTypes(ArrayList nodeIds, Graph g) { + ArrayList filteredNodes = new ArrayList(); + for (Long node : nodeIds) { + if (this.typeIsAllowed(g.getNodeType(node))) { + filteredNodes.add(node); + } + } + return filteredNodes; + } +} diff --git a/java/src/main/java/org/softwareheritage/graph/Traversal.java b/java/src/main/java/org/softwareheritage/graph/Traversal.java index e1a240c..681e5a1 100644 --- a/java/src/main/java/org/softwareheritage/graph/Traversal.java +++ b/java/src/main/java/org/softwareheritage/graph/Traversal.java @@ -1,545 +1,580 @@ package org.softwareheritage.graph; -import it.unimi.dsi.big.webgraph.LazyLongIterator; -import org.softwareheritage.graph.server.Endpoint; - -import java.util.*; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedList; +import java.util.Map; +import java.util.Queue; +import java.util.Random; +import java.util.Stack; import java.util.function.Consumer; import java.util.function.LongConsumer; +import org.softwareheritage.graph.server.Endpoint; + +import it.unimi.dsi.big.webgraph.LazyLongIterator; + /** * Traversal algorithms on the compressed graph. *

* Internal implementation of the traversal API endpoints. These methods only input/output internal * long ids, which are converted in the {@link Endpoint} higher-level class to {@link SWHID}. * * @author The Software Heritage developers * @see Endpoint */ public class Traversal { /** Graph used in the traversal */ Graph graph; /** Graph edge restriction */ AllowedEdges edges; /** Hash set storing if we have visited a node */ HashSet visited; /** Hash map storing parent node id for each nodes during a traversal */ Map parentNode; /** Number of edges accessed during traversal */ long nbEdgesAccessed; /** The anti Dos limit of edges traversed while a visit */ long maxEdges; + /** The string represent the set of type restriction */ + NodesFiltering ndsfilter; /** random number generator, for random walks */ Random rng; /** * Constructor. * * @param graph graph used in the traversal * @param direction a string (either "forward" or "backward") specifying edge orientation * @param edgesFmt a formatted string describing allowed * edges */ + public Traversal(Graph graph, String direction, String edgesFmt) { this(graph, direction, edgesFmt, 0); } public Traversal(Graph graph, String direction, String edgesFmt, long maxEdges) { + this(graph, direction, edgesFmt, 0, "*"); + } + + public Traversal(Graph graph, String direction, String edgesFmt, long maxEdges, String returnTypes) { if (!direction.matches("forward|backward")) { throw new IllegalArgumentException("Unknown traversal direction: " + direction); } if (direction.equals("backward")) { this.graph = graph.transpose(); } else { this.graph = graph; } this.edges = new AllowedEdges(edgesFmt); this.visited = new HashSet<>(); this.parentNode = new HashMap<>(); this.nbEdgesAccessed = 0; this.maxEdges = maxEdges; this.rng = new Random(); + + if (returnTypes.equals("*")) { + this.ndsfilter = new NodesFiltering(); + } else { + this.ndsfilter = new NodesFiltering(returnTypes); + } } /** * Returns number of accessed edges during traversal. * * @return number of edges accessed in last traversal */ public long getNbEdgesAccessed() { return nbEdgesAccessed; } /** * Returns number of accessed nodes during traversal. * * @return number of nodes accessed in last traversal */ public long getNbNodesAccessed() { return this.visited.size(); } /** * Push version of {@link #leaves} will fire passed callback for each leaf. */ public void leavesVisitor(long srcNodeId, NodeIdConsumer cb) { Stack stack = new Stack<>(); this.nbEdgesAccessed = 0; stack.push(srcNodeId); visited.add(srcNodeId); while (!stack.isEmpty()) { long currentNodeId = stack.pop(); long neighborsCnt = 0; nbEdgesAccessed += graph.outdegree(currentNodeId); if (this.maxEdges > 0) { if (nbEdgesAccessed >= this.maxEdges) { break; } } LazyLongIterator it = graph.successors(currentNodeId, edges); for (long neighborNodeId; (neighborNodeId = it.nextLong()) != -1;) { neighborsCnt++; if (!visited.contains(neighborNodeId)) { stack.push(neighborNodeId); visited.add(neighborNodeId); } } if (neighborsCnt == 0) { cb.accept(currentNodeId); } } } /** * Returns the leaves of a subgraph rooted at the specified source node. * * @param srcNodeId source node * @return list of node ids corresponding to the leaves */ public ArrayList leaves(long srcNodeId) { - ArrayList nodeIds = new ArrayList<>(); + ArrayList nodeIds = new ArrayList(); leavesVisitor(srcNodeId, nodeIds::add); + if (ndsfilter.restricted) { + return ndsfilter.filterByNodeTypes(nodeIds, graph); + } return nodeIds; } /** * Push version of {@link #neighbors}: will fire passed callback on each neighbor. */ public void neighborsVisitor(long srcNodeId, NodeIdConsumer cb) { this.nbEdgesAccessed = graph.outdegree(srcNodeId); if (this.maxEdges > 0) { if (nbEdgesAccessed >= this.maxEdges) { return; } } LazyLongIterator it = graph.successors(srcNodeId, edges); for (long neighborNodeId; (neighborNodeId = it.nextLong()) != -1;) { cb.accept(neighborNodeId); } } /** * Returns node direct neighbors (linked with exactly one edge). * * @param srcNodeId source node * @return list of node ids corresponding to the neighbors */ public ArrayList neighbors(long srcNodeId) { ArrayList nodeIds = new ArrayList<>(); neighborsVisitor(srcNodeId, nodeIds::add); + if (ndsfilter.restricted) { + return ndsfilter.filterByNodeTypes(nodeIds, graph); + } return nodeIds; } /** * Push version of {@link #visitNodes}: will fire passed callback on each visited node. */ public void visitNodesVisitor(long srcNodeId, NodeIdConsumer nodeCb, EdgeIdConsumer edgeCb) { Stack stack = new Stack<>(); this.nbEdgesAccessed = 0; stack.push(srcNodeId); visited.add(srcNodeId); while (!stack.isEmpty()) { long currentNodeId = stack.pop(); if (nodeCb != null) { nodeCb.accept(currentNodeId); } nbEdgesAccessed += graph.outdegree(currentNodeId); if (this.maxEdges > 0) { if (nbEdgesAccessed >= this.maxEdges) { break; } } LazyLongIterator it = graph.successors(currentNodeId, edges); for (long neighborNodeId; (neighborNodeId = it.nextLong()) != -1;) { if (edgeCb != null) { edgeCb.accept(currentNodeId, neighborNodeId); } if (!visited.contains(neighborNodeId)) { stack.push(neighborNodeId); visited.add(neighborNodeId); } } } } /** One-argument version to handle callbacks properly */ public void visitNodesVisitor(long srcNodeId, NodeIdConsumer cb) { visitNodesVisitor(srcNodeId, cb, null); } /** * Performs a graph traversal and returns explored nodes. * * @param srcNodeId source node * @return list of explored node ids */ public ArrayList visitNodes(long srcNodeId) { ArrayList nodeIds = new ArrayList<>(); visitNodesVisitor(srcNodeId, nodeIds::add); + if (ndsfilter.restricted) { + return ndsfilter.filterByNodeTypes(nodeIds, graph); + } return nodeIds; } /** * Push version of {@link #visitPaths}: will fire passed callback on each discovered (complete) * path. */ public void visitPathsVisitor(long srcNodeId, PathConsumer cb) { Stack currentPath = new Stack<>(); this.nbEdgesAccessed = 0; visitPathsInternalVisitor(srcNodeId, currentPath, cb); } /** * Performs a graph traversal and returns explored paths. * * @param srcNodeId source node * @return list of explored paths (represented as a list of node ids) */ public ArrayList> visitPaths(long srcNodeId) { ArrayList> paths = new ArrayList<>(); visitPathsVisitor(srcNodeId, paths::add); return paths; } private void visitPathsInternalVisitor(long currentNodeId, Stack currentPath, PathConsumer cb) { currentPath.push(currentNodeId); long visitedNeighbors = 0; nbEdgesAccessed += graph.outdegree(currentNodeId); if (this.maxEdges > 0) { if (nbEdgesAccessed >= this.maxEdges) { currentPath.pop(); return; } } LazyLongIterator it = graph.successors(currentNodeId, edges); for (long neighborNodeId; (neighborNodeId = it.nextLong()) != -1;) { visitPathsInternalVisitor(neighborNodeId, currentPath, cb); visitedNeighbors++; } if (visitedNeighbors == 0) { ArrayList path = new ArrayList<>(currentPath); cb.accept(path); } currentPath.pop(); } /** * Performs a graph traversal with backtracking, and returns the first found path from source to * destination. * * @param srcNodeId source node * @param dst destination (either a node or a node type) * @return found path as a list of node ids */ public ArrayList walk(long srcNodeId, T dst, String visitOrder) { long dstNodeId; if (visitOrder.equals("dfs")) { dstNodeId = walkInternalDFS(srcNodeId, dst); } else if (visitOrder.equals("bfs")) { dstNodeId = walkInternalBFS(srcNodeId, dst); } else { throw new IllegalArgumentException("Unknown visit order: " + visitOrder); } if (dstNodeId == -1) { throw new IllegalArgumentException("Cannot find destination: " + dst); } return backtracking(srcNodeId, dstNodeId); } /** * Performs a random walk (picking a random successor at each step) from source to destination. * * @param srcNodeId source node * @param dst destination (either a node or a node type) * @return found path as a list of node ids or an empty path to indicate that no suitable path have * been found */ public ArrayList randomWalk(long srcNodeId, T dst) { return randomWalk(srcNodeId, dst, 0); } /** * Performs a stubborn random walk (picking a random successor at each step) from source to * destination. The walk is "stubborn" in the sense that it will not give up the first time if a * satisfying target node is found, but it will retry up to a limited amount of times. * * @param srcNodeId source node * @param dst destination (either a node or a node type) * @param retries number of times to retry; 0 means no retries (single walk) * @return found path as a list of node ids or an empty path to indicate that no suitable path have * been found */ public ArrayList randomWalk(long srcNodeId, T dst, int retries) { long curNodeId = srcNodeId; ArrayList path = new ArrayList<>(); this.nbEdgesAccessed = 0; boolean found; if (retries < 0) { throw new IllegalArgumentException("Negative number of retries given: " + retries); } while (true) { path.add(curNodeId); LazyLongIterator successors = graph.successors(curNodeId, edges); curNodeId = randomPick(successors); if (curNodeId < 0) { found = false; break; } if (isDstNode(curNodeId, dst)) { path.add(curNodeId); found = true; break; } } if (found) { + if (ndsfilter.restricted) { + return ndsfilter.filterByNodeTypes(path, graph); + } return path; } else if (retries > 0) { // try again return randomWalk(srcNodeId, dst, retries - 1); } else { // not found and no retries left path.clear(); return path; } } /** * Randomly choose an element from an iterator over Longs using reservoir sampling * * @param elements iterator over selection domain * @return randomly chosen element or -1 if no suitable element was found */ private long randomPick(LazyLongIterator elements) { long curPick = -1; long seenCandidates = 0; for (long element; (element = elements.nextLong()) != -1;) { seenCandidates++; if (Math.round(rng.nextFloat() * (seenCandidates - 1)) == 0) { curPick = element; } } return curPick; } /** * Internal DFS function of {@link #walk}. * * @param srcNodeId source node * @param dst destination (either a node or a node type) * @return final destination node or -1 if no path found */ private long walkInternalDFS(long srcNodeId, T dst) { Stack stack = new Stack<>(); this.nbEdgesAccessed = 0; stack.push(srcNodeId); visited.add(srcNodeId); while (!stack.isEmpty()) { long currentNodeId = stack.pop(); if (isDstNode(currentNodeId, dst)) { return currentNodeId; } nbEdgesAccessed += graph.outdegree(currentNodeId); LazyLongIterator it = graph.successors(currentNodeId, edges); for (long neighborNodeId; (neighborNodeId = it.nextLong()) != -1;) { if (!visited.contains(neighborNodeId)) { stack.push(neighborNodeId); visited.add(neighborNodeId); parentNode.put(neighborNodeId, currentNodeId); } } } return -1; } /** * Internal BFS function of {@link #walk}. * * @param srcNodeId source node * @param dst destination (either a node or a node type) * @return final destination node or -1 if no path found */ private long walkInternalBFS(long srcNodeId, T dst) { Queue queue = new LinkedList<>(); this.nbEdgesAccessed = 0; queue.add(srcNodeId); visited.add(srcNodeId); while (!queue.isEmpty()) { long currentNodeId = queue.poll(); if (isDstNode(currentNodeId, dst)) { return currentNodeId; } nbEdgesAccessed += graph.outdegree(currentNodeId); LazyLongIterator it = graph.successors(currentNodeId, edges); for (long neighborNodeId; (neighborNodeId = it.nextLong()) != -1;) { if (!visited.contains(neighborNodeId)) { queue.add(neighborNodeId); visited.add(neighborNodeId); parentNode.put(neighborNodeId, currentNodeId); } } } return -1; } /** * Internal function of {@link #walk} to check if a node corresponds to the destination. * * @param nodeId current node * @param dst destination (either a node or a node type) * @return true if the node is a destination, or false otherwise */ private boolean isDstNode(long nodeId, T dst) { if (dst instanceof Long) { long dstNodeId = (Long) dst; return nodeId == dstNodeId; } else if (dst instanceof Node.Type) { Node.Type dstType = (Node.Type) dst; return graph.getNodeType(nodeId) == dstType; } else { return false; } } /** * Internal backtracking function of {@link #walk}. * * @param srcNodeId source node * @param dstNodeId destination node * @return the found path, as a list of node ids */ private ArrayList backtracking(long srcNodeId, long dstNodeId) { ArrayList path = new ArrayList<>(); long currentNodeId = dstNodeId; while (currentNodeId != srcNodeId) { path.add(currentNodeId); currentNodeId = parentNode.get(currentNodeId); } path.add(srcNodeId); Collections.reverse(path); return path; } /** * Find a common descendant between two given nodes using two parallel BFS * * @param lhsNode the first node * @param rhsNode the second node * @return the found path, as a list of node ids */ public Long findCommonDescendant(long lhsNode, long rhsNode) { Queue lhsStack = new ArrayDeque<>(); Queue rhsStack = new ArrayDeque<>(); HashSet lhsVisited = new HashSet<>(); HashSet rhsVisited = new HashSet<>(); lhsStack.add(lhsNode); rhsStack.add(rhsNode); lhsVisited.add(lhsNode); rhsVisited.add(rhsNode); this.nbEdgesAccessed = 0; Long curNode; while (!lhsStack.isEmpty() || !rhsStack.isEmpty()) { if (!lhsStack.isEmpty()) { curNode = lhsStack.poll(); nbEdgesAccessed += graph.outdegree(curNode); LazyLongIterator it = graph.successors(curNode, edges); for (long neighborNodeId; (neighborNodeId = it.nextLong()) != -1;) { if (!lhsVisited.contains(neighborNodeId)) { if (rhsVisited.contains(neighborNodeId)) return neighborNodeId; lhsStack.add(neighborNodeId); lhsVisited.add(neighborNodeId); } } } if (!rhsStack.isEmpty()) { curNode = rhsStack.poll(); nbEdgesAccessed += graph.outdegree(curNode); LazyLongIterator it = graph.successors(curNode, edges); for (long neighborNodeId; (neighborNodeId = it.nextLong()) != -1;) { if (!rhsVisited.contains(neighborNodeId)) { if (lhsVisited.contains(neighborNodeId)) return neighborNodeId; rhsStack.add(neighborNodeId); rhsVisited.add(neighborNodeId); } } } } return null; } public interface NodeIdConsumer extends LongConsumer { /** * Callback for incrementally receiving node identifiers during a graph visit. */ void accept(long nodeId); } public interface EdgeIdConsumer { /** * Callback for incrementally receiving edge identifiers during a graph visit. */ void accept(long srcId, long dstId); } public interface PathConsumer extends Consumer> { /** * Callback for incrementally receiving node paths (made of node identifiers) during a graph visit. */ void accept(ArrayList path); } } diff --git a/swh/graph/backend.py b/swh/graph/backend.py index 750849a..f8cbc02 100644 --- a/swh/graph/backend.py +++ b/swh/graph/backend.py @@ -1,196 +1,200 @@ # Copyright (C) 2019-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import asyncio import contextlib import io import os import struct import subprocess import sys import tempfile from py4j.java_gateway import JavaGateway from swh.graph.config import check_config from swh.graph.swhid import NodeToSwhidMap, SwhidToNodeMap from swh.model.identifiers import EXTENDED_SWHID_TYPES BUF_SIZE = 64 * 1024 BIN_FMT = ">q" # 64 bit integer, big endian PATH_SEPARATOR_ID = -1 NODE2SWHID_EXT = "node2swhid.bin" SWHID2NODE_EXT = "swhid2node.bin" def _get_pipe_stderr(): # Get stderr if possible, or pipe to stdout if running with Jupyter. try: sys.stderr.fileno() except io.UnsupportedOperation: return subprocess.STDOUT else: return sys.stderr class Backend: def __init__(self, graph_path, config=None): self.gateway = None self.entry = None self.graph_path = graph_path self.config = check_config(config or {}) def __enter__(self): self.gateway = JavaGateway.launch_gateway( java_path=None, javaopts=self.config["java_tool_options"].split(), classpath=self.config["classpath"], die_on_exit=True, redirect_stdout=sys.stdout, redirect_stderr=_get_pipe_stderr(), ) self.entry = self.gateway.jvm.org.softwareheritage.graph.Entry() self.entry.load_graph(self.graph_path) self.node2swhid = NodeToSwhidMap(self.graph_path + "." + NODE2SWHID_EXT) self.swhid2node = SwhidToNodeMap(self.graph_path + "." + SWHID2NODE_EXT) self.stream_proxy = JavaStreamProxy(self.entry) return self def __exit__(self, exc_type, exc_value, tb): self.gateway.shutdown() def stats(self): return self.entry.stats() def count(self, ttype, direction, edges_fmt, src): method = getattr(self.entry, "count_" + ttype) return method(direction, edges_fmt, src) - async def simple_traversal(self, ttype, direction, edges_fmt, src, max_edges): + async def simple_traversal( + self, ttype, direction, edges_fmt, src, max_edges, return_types + ): assert ttype in ("leaves", "neighbors", "visit_nodes") method = getattr(self.stream_proxy, ttype) - async for node_id in method(direction, edges_fmt, src, max_edges): + async for node_id in method(direction, edges_fmt, src, max_edges, return_types): yield node_id async def walk(self, direction, edges_fmt, algo, src, dst): if dst in EXTENDED_SWHID_TYPES: it = self.stream_proxy.walk_type(direction, edges_fmt, algo, src, dst) else: it = self.stream_proxy.walk(direction, edges_fmt, algo, src, dst) async for node_id in it: yield node_id - async def random_walk(self, direction, edges_fmt, retries, src, dst): + async def random_walk(self, direction, edges_fmt, retries, src, dst, return_types): if dst in EXTENDED_SWHID_TYPES: it = self.stream_proxy.random_walk_type( - direction, edges_fmt, retries, src, dst + direction, edges_fmt, retries, src, dst, return_types ) else: - it = self.stream_proxy.random_walk(direction, edges_fmt, retries, src, dst) + it = self.stream_proxy.random_walk( + direction, edges_fmt, retries, src, dst, return_types + ) async for node_id in it: # TODO return 404 if path is empty yield node_id async def visit_edges(self, direction, edges_fmt, src, max_edges): it = self.stream_proxy.visit_edges(direction, edges_fmt, src, max_edges) # convert stream a, b, c, d -> (a, b), (c, d) prevNode = None async for node in it: if prevNode is not None: yield (prevNode, node) prevNode = None else: prevNode = node async def visit_paths(self, direction, edges_fmt, src, max_edges): path = [] async for node in self.stream_proxy.visit_paths( direction, edges_fmt, src, max_edges ): if node == PATH_SEPARATOR_ID: yield path path = [] else: path.append(node) class JavaStreamProxy: """A proxy class for the org.softwareheritage.graph.Entry Java class that takes care of the setup and teardown of the named-pipe FIFO communication between Python and Java. Initialize JavaStreamProxy using: proxy = JavaStreamProxy(swh_entry_class_instance) Then you can call an Entry method and iterate on the FIFO results like this: async for value in proxy.java_method(arg1, arg2): print(value) """ def __init__(self, entry): self.entry = entry async def read_node_ids(self, fname): loop = asyncio.get_event_loop() open_thread = loop.run_in_executor(None, open, fname, "rb") # Since the open() call on the FIFO is blocking until it is also opened # on the Java side, we await it with a timeout in case there is an # exception that prevents the write-side open(). with (await asyncio.wait_for(open_thread, timeout=2)) as f: while True: data = await loop.run_in_executor(None, f.read, BUF_SIZE) if not data: break for data in struct.iter_unpack(BIN_FMT, data): yield data[0] class _HandlerWrapper: def __init__(self, handler): self._handler = handler def __getattr__(self, name): func = getattr(self._handler, name) async def java_call(*args, **kwargs): loop = asyncio.get_event_loop() await loop.run_in_executor(None, lambda: func(*args, **kwargs)) def java_task(*args, **kwargs): return asyncio.create_task(java_call(*args, **kwargs)) return java_task @contextlib.contextmanager def get_handler(self): with tempfile.TemporaryDirectory(prefix="swh-graph-") as tmpdirname: cli_fifo = os.path.join(tmpdirname, "swh-graph.fifo") os.mkfifo(cli_fifo) reader = self.read_node_ids(cli_fifo) query_handler = self.entry.get_handler(cli_fifo) handler = self._HandlerWrapper(query_handler) yield (handler, reader) def __getattr__(self, name): async def java_call_iterator(*args, **kwargs): with self.get_handler() as (handler, reader): java_task = getattr(handler, name)(*args, **kwargs) try: async for value in reader: yield value except asyncio.TimeoutError: # If the read-side open() timeouts, an exception on the # Java side probably happened that prevented the # write-side open(). We propagate this exception here if # that is the case. task_exc = java_task.exception() if task_exc: raise task_exc raise await java_task return java_call_iterator diff --git a/swh/graph/client.py b/swh/graph/client.py index 2517a48..94730d3 100644 --- a/swh/graph/client.py +++ b/swh/graph/client.py @@ -1,112 +1,143 @@ # Copyright (C) 2019-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import json from swh.core.api import RPCClient class GraphAPIError(Exception): """Graph API Error""" def __str__(self): - return "An unexpected error occurred in the Graph backend: {}".format(self.args) + return """An unexpected error occurred + in the Graph backend: {}""".format( + self.args + ) class RemoteGraphClient(RPCClient): """Client to the Software Heritage Graph.""" def __init__(self, url, timeout=None): super().__init__(api_exception=GraphAPIError, url=url, timeout=timeout) def raw_verb_lines(self, verb, endpoint, **kwargs): response = self.raw_verb(verb, endpoint, stream=True, **kwargs) self.raise_for_status(response) for line in response.iter_lines(): yield line.decode().lstrip("\n") def get_lines(self, endpoint, **kwargs): yield from self.raw_verb_lines("get", endpoint, **kwargs) # Web API endpoints def stats(self): return self.get("stats") - def leaves(self, src, edges="*", direction="forward", max_edges=0): + def leaves( + self, src, edges="*", direction="forward", max_edges=0, return_types="*" + ): return self.get_lines( "leaves/{}".format(src), - params={"edges": edges, "direction": direction, "max_edges": max_edges}, + params={ + "edges": edges, + "direction": direction, + "max_edges": max_edges, + "return_types": return_types, + }, ) - def neighbors(self, src, edges="*", direction="forward", max_edges=0): + def neighbors( + self, src, edges="*", direction="forward", max_edges=0, return_types="*" + ): return self.get_lines( "neighbors/{}".format(src), - params={"edges": edges, "direction": direction, "max_edges": max_edges}, + params={ + "edges": edges, + "direction": direction, + "max_edges": max_edges, + "return_types": return_types, + }, ) - def visit_nodes(self, src, edges="*", direction="forward", max_edges=0): + def visit_nodes( + self, src, edges="*", direction="forward", max_edges=0, return_types="*" + ): return self.get_lines( "visit/nodes/{}".format(src), - params={"edges": edges, "direction": direction, "max_edges": max_edges}, + params={ + "edges": edges, + "direction": direction, + "max_edges": max_edges, + "return_types": return_types, + }, ) def visit_edges(self, src, edges="*", direction="forward", max_edges=0): for edge in self.get_lines( "visit/edges/{}".format(src), params={"edges": edges, "direction": direction, "max_edges": max_edges}, ): yield tuple(edge.split()) def visit_paths(self, src, edges="*", direction="forward", max_edges=0): def decode_path_wrapper(it): for e in it: yield json.loads(e) return decode_path_wrapper( self.get_lines( "visit/paths/{}".format(src), params={"edges": edges, "direction": direction, "max_edges": max_edges}, ) ) def walk( self, src, dst, edges="*", traversal="dfs", direction="forward", limit=None ): endpoint = "walk/{}/{}" return self.get_lines( endpoint.format(src, dst), params={ "edges": edges, "traversal": traversal, "direction": direction, "limit": limit, }, ) - def random_walk(self, src, dst, edges="*", direction="forward", limit=None): + def random_walk( + self, src, dst, edges="*", direction="forward", limit=None, return_types="*" + ): endpoint = "randomwalk/{}/{}" return self.get_lines( endpoint.format(src, dst), - params={"edges": edges, "direction": direction, "limit": limit}, + params={ + "edges": edges, + "direction": direction, + "limit": limit, + "return_types": return_types, + }, ) def count_leaves(self, src, edges="*", direction="forward"): return self.get( "leaves/count/{}".format(src), params={"edges": edges, "direction": direction}, ) def count_neighbors(self, src, edges="*", direction="forward"): return self.get( "neighbors/count/{}".format(src), params={"edges": edges, "direction": direction}, ) def count_visit_nodes(self, src, edges="*", direction="forward"): return self.get( "visit/nodes/count/{}".format(src), params={"edges": edges, "direction": direction}, ) diff --git a/swh/graph/graph.py b/swh/graph/graph.py index c68decb..3fd853b 100644 --- a/swh/graph/graph.py +++ b/swh/graph/graph.py @@ -1,190 +1,193 @@ # Copyright (C) 2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import asyncio import contextlib import functools from swh.graph.backend import Backend from swh.graph.dot import KIND_TO_SHAPE, dot_to_svg, graph_dot BASE_URL = "https://archive.softwareheritage.org/browse" KIND_TO_URL_FRAGMENT = { "ori": "/origin/{}", "snp": "/snapshot/{}", "rel": "/release/{}", "rev": "/revision/{}", "dir": "/directory/{}", "cnt": "/content/sha1_git:{}/", } def call_async_gen(generator, *args, **kwargs): loop = asyncio.get_event_loop() it = generator(*args, **kwargs).__aiter__() while True: try: res = loop.run_until_complete(it.__anext__()) yield res except StopAsyncIteration: break class Neighbors: """Neighbor iterator with custom O(1) length method""" def __init__(self, graph, iterator, length_func): self.graph = graph self.iterator = iterator self.length_func = length_func def __iter__(self): return self def __next__(self): succ = self.iterator.nextLong() if succ == -1: raise StopIteration return GraphNode(self.graph, succ) def __len__(self): return self.length_func() class GraphNode: """Node in the SWH graph""" def __init__(self, graph, node_id): self.graph = graph self.id = node_id def children(self): return Neighbors( self.graph, self.graph.java_graph.successors(self.id), lambda: self.graph.java_graph.outdegree(self.id), ) def parents(self): return Neighbors( self.graph, self.graph.java_graph.predecessors(self.id), lambda: self.graph.java_graph.indegree(self.id), ) - def simple_traversal(self, ttype, direction="forward", edges="*", max_edges=0): + def simple_traversal( + self, ttype, direction="forward", edges="*", max_edges=0, return_types="*" + ): for node in call_async_gen( self.graph.backend.simple_traversal, ttype, direction, edges, self.id, max_edges, + return_types, ): yield self.graph[node] def leaves(self, *args, **kwargs): yield from self.simple_traversal("leaves", *args, **kwargs) def visit_nodes(self, *args, **kwargs): yield from self.simple_traversal("visit_nodes", *args, **kwargs) def visit_edges(self, direction="forward", edges="*", max_edges=0): for src, dst in call_async_gen( self.graph.backend.visit_edges, direction, edges, self.id, max_edges ): yield (self.graph[src], self.graph[dst]) def visit_paths(self, direction="forward", edges="*", max_edges=0): for path in call_async_gen( self.graph.backend.visit_paths, direction, edges, self.id, max_edges ): yield [self.graph[node] for node in path] def walk(self, dst, direction="forward", edges="*", traversal="dfs"): for node in call_async_gen( self.graph.backend.walk, direction, edges, traversal, self.id, dst ): yield self.graph[node] def _count(self, ttype, direction="forward", edges="*"): return self.graph.backend.count(ttype, direction, edges, self.id) count_leaves = functools.partialmethod(_count, ttype="leaves") count_neighbors = functools.partialmethod(_count, ttype="neighbors") count_visit_nodes = functools.partialmethod(_count, ttype="visit_nodes") @property def swhid(self): return self.graph.node2swhid[self.id] @property def kind(self): return self.swhid.split(":")[2] def __str__(self): return self.swhid def __repr__(self): return "<{}>".format(self.swhid) def dot_fragment(self): swh, version, kind, hash = self.swhid.split(":") label = "{}:{}..{}".format(kind, hash[0:2], hash[-2:]) url = BASE_URL + KIND_TO_URL_FRAGMENT[kind].format(hash) shape = KIND_TO_SHAPE[kind] return '{} [label="{}", href="{}", target="_blank", shape="{}"];'.format( self.id, label, url, shape ) def _repr_svg_(self): nodes = [self, *list(self.children()), *list(self.parents())] dot = graph_dot(nodes) svg = dot_to_svg(dot) return svg class Graph: def __init__(self, backend, node2swhid, swhid2node): self.backend = backend self.java_graph = backend.entry.get_graph() self.node2swhid = node2swhid self.swhid2node = swhid2node def stats(self): return self.backend.stats() @property def path(self): return self.java_graph.getPath() def __len__(self): return self.java_graph.numNodes() def __getitem__(self, node_id): if isinstance(node_id, int): self.node2swhid[node_id] # check existence return GraphNode(self, node_id) elif isinstance(node_id, str): node_id = self.swhid2node[node_id] return GraphNode(self, node_id) def __iter__(self): for swhid, pos in self.backend.swhid2node: yield self[swhid] def iter_prefix(self, prefix): for swhid, pos in self.backend.swhid2node.iter_prefix(prefix): yield self[swhid] def iter_type(self, swhid_type): for swhid, pos in self.backend.swhid2node.iter_type(swhid_type): yield self[swhid] @contextlib.contextmanager def load(graph_path): with Backend(graph_path) as backend: yield Graph(backend, backend.node2swhid, backend.swhid2node) diff --git a/swh/graph/server/app.py b/swh/graph/server/app.py index 3d3258b..227524a 100644 --- a/swh/graph/server/app.py +++ b/swh/graph/server/app.py @@ -1,333 +1,359 @@ # Copyright (C) 2019-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information """ A proxy HTTP server for swh-graph, talking to the Java code via py4j, and using FIFO as a transport to stream integers between the two languages. """ import asyncio from collections import deque import json from typing import Optional import aiohttp.web from swh.core.api.asynchronous import RPCServerApp from swh.model.exceptions import ValidationError from swh.model.identifiers import EXTENDED_SWHID_TYPES try: from contextlib import asynccontextmanager except ImportError: # Compatibility with 3.6 backport from async_generator import asynccontextmanager # type: ignore # maximum number of retries for random walks RANDOM_RETRIES = 5 # TODO make this configurable via rpc-serve configuration async def index(request): return aiohttp.web.Response( content_type="text/html", body=""" Software Heritage storage server

You have reached the Software Heritage graph API server.

See its API documentation for more information.

""", ) class GraphView(aiohttp.web.View): """Base class for views working on the graph, with utility functions""" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.backend = self.request.app["backend"] def node_of_swhid(self, swhid): """Lookup a SWHID in a swhid2node map, failing in an HTTP-nice way if needed.""" try: return self.backend.swhid2node[swhid] except KeyError: raise aiohttp.web.HTTPNotFound(text=f"SWHID not found: {swhid}") except ValidationError: raise aiohttp.web.HTTPBadRequest(text=f"malformed SWHID: {swhid}") def swhid_of_node(self, node): """Lookup a node in a node2swhid map, failing in an HTTP-nice way if needed.""" try: return self.backend.node2swhid[node] except KeyError: raise aiohttp.web.HTTPInternalServerError( text=f"reverse lookup failed for node id: {node}" ) def get_direction(self): """Validate HTTP query parameter `direction`""" s = self.request.query.get("direction", "forward") if s not in ("forward", "backward"): raise aiohttp.web.HTTPBadRequest(text=f"invalid direction: {s}") return s def get_edges(self): """Validate HTTP query parameter `edges`, i.e., edge restrictions""" s = self.request.query.get("edges", "*") if any( [ node_type != "*" and node_type not in EXTENDED_SWHID_TYPES for edge in s.split(":") for node_type in edge.split(",", maxsplit=1) ] ): raise aiohttp.web.HTTPBadRequest(text=f"invalid edge restriction: {s}") return s + def get_return_types(self): + """Validate HTTP query parameter 'return types', i.e, + a set of types which we will filter the query results with""" + s = self.request.query.get("return_types", "*") + if any( + node_type != "*" and node_type not in EXTENDED_SWHID_TYPES + for node_type in s.split(",") + ): + raise aiohttp.web.HTTPBadRequest( + text=f"invalid type for filtering res: {s}" + ) + # if the user puts a star, + # then we filter nothing, we don't need the other information + if "*" in s: + return "*" + else: + return s + def get_traversal(self): """Validate HTTP query parameter `traversal`, i.e., visit order""" s = self.request.query.get("traversal", "dfs") if s not in ("bfs", "dfs"): raise aiohttp.web.HTTPBadRequest(text=f"invalid traversal order: {s}") return s def get_limit(self): """Validate HTTP query parameter `limit`, i.e., number of results""" s = self.request.query.get("limit", "0") try: return int(s) except ValueError: raise aiohttp.web.HTTPBadRequest(text=f"invalid limit value: {s}") def get_max_edges(self): """Validate HTTP query parameter 'max_edges', i.e., the limit of the number of edges that can be visited""" s = self.request.query.get("max_edges", "0") try: return int(s) except ValueError: raise aiohttp.web.HTTPBadRequest(text=f"invalid max_edges value: {s}") class StreamingGraphView(GraphView): """Base class for views streaming their response line by line.""" content_type = "text/plain" @asynccontextmanager async def response_streamer(self, *args, **kwargs): """Context manager to prepare then close a StreamResponse""" response = aiohttp.web.StreamResponse(*args, **kwargs) response.content_type = self.content_type await response.prepare(self.request) yield response await response.write_eof() async def get(self): await self.prepare_response() async with self.response_streamer() as self.response_stream: await self.stream_response() return self.response_stream async def prepare_response(self): """This can be overridden with some setup to be run before the response actually starts streaming. """ pass async def stream_response(self): """Override this to perform the response streaming. Implementations of this should await self.stream_line(line) to write each line. """ raise NotImplementedError async def stream_line(self, line): """Write a line in the response stream.""" await self.response_stream.write((line + "\n").encode()) class StatsView(GraphView): """View showing some statistics on the graph""" async def get(self): stats = self.backend.stats() return aiohttp.web.Response(body=stats, content_type="application/json") class SimpleTraversalView(StreamingGraphView): """Base class for views of simple traversals""" simple_traversal_type: Optional[str] = None async def prepare_response(self): src = self.request.match_info["src"] self.src_node = self.node_of_swhid(src) self.edges = self.get_edges() self.direction = self.get_direction() self.max_edges = self.get_max_edges() + self.return_types = self.get_return_types() async def stream_response(self): async for res_node in self.backend.simple_traversal( self.simple_traversal_type, self.direction, self.edges, self.src_node, self.max_edges, + self.return_types, ): res_swhid = self.swhid_of_node(res_node) await self.stream_line(res_swhid) class LeavesView(SimpleTraversalView): simple_traversal_type = "leaves" class NeighborsView(SimpleTraversalView): simple_traversal_type = "neighbors" class VisitNodesView(SimpleTraversalView): simple_traversal_type = "visit_nodes" class WalkView(StreamingGraphView): async def prepare_response(self): src = self.request.match_info["src"] dst = self.request.match_info["dst"] self.src_node = self.node_of_swhid(src) if dst not in EXTENDED_SWHID_TYPES: self.dst_thing = self.node_of_swhid(dst) else: self.dst_thing = dst self.edges = self.get_edges() self.direction = self.get_direction() self.algo = self.get_traversal() self.limit = self.get_limit() + self.return_types = self.get_return_types() async def get_walk_iterator(self): return self.backend.walk( self.direction, self.edges, self.algo, self.src_node, self.dst_thing ) async def stream_response(self): it = self.get_walk_iterator() if self.limit < 0: queue = deque(maxlen=-self.limit) async for res_node in it: res_swhid = self.swhid_of_node(res_node) queue.append(res_swhid) while queue: await self.stream_line(queue.popleft()) else: count = 0 async for res_node in it: if self.limit == 0 or count < self.limit: res_swhid = self.swhid_of_node(res_node) await self.stream_line(res_swhid) count += 1 else: break class RandomWalkView(WalkView): def get_walk_iterator(self): return self.backend.random_walk( - self.direction, self.edges, RANDOM_RETRIES, self.src_node, self.dst_thing + self.direction, + self.edges, + RANDOM_RETRIES, + self.src_node, + self.dst_thing, + self.return_types, ) class VisitEdgesView(SimpleTraversalView): async def stream_response(self): it = self.backend.visit_edges( self.direction, self.edges, self.src_node, self.max_edges ) async for (res_src, res_dst) in it: res_src_swhid = self.swhid_of_node(res_src) res_dst_swhid = self.swhid_of_node(res_dst) await self.stream_line("{} {}".format(res_src_swhid, res_dst_swhid)) class VisitPathsView(SimpleTraversalView): content_type = "application/x-ndjson" async def stream_response(self): it = self.backend.visit_paths( self.direction, self.edges, self.src_node, self.max_edges ) async for res_path in it: res_path_swhid = [self.swhid_of_node(n) for n in res_path] line = json.dumps(res_path_swhid) await self.stream_line(line) class CountView(GraphView): """Base class for counting views.""" count_type: Optional[str] = None async def get(self): src = self.request.match_info["src"] self.src_node = self.node_of_swhid(src) self.edges = self.get_edges() self.direction = self.get_direction() loop = asyncio.get_event_loop() cnt = await loop.run_in_executor( None, self.backend.count, self.count_type, self.direction, self.edges, self.src_node, ) return aiohttp.web.Response(body=str(cnt), content_type="application/json") class CountNeighborsView(CountView): count_type = "neighbors" class CountLeavesView(CountView): count_type = "leaves" class CountVisitNodesView(CountView): count_type = "visit_nodes" def make_app(backend, **kwargs): app = RPCServerApp(**kwargs) app.add_routes( [ aiohttp.web.get("/", index), aiohttp.web.get("/graph", index), aiohttp.web.view("/graph/stats", StatsView), aiohttp.web.view("/graph/leaves/{src}", LeavesView), aiohttp.web.view("/graph/neighbors/{src}", NeighborsView), aiohttp.web.view("/graph/visit/nodes/{src}", VisitNodesView), aiohttp.web.view("/graph/visit/edges/{src}", VisitEdgesView), aiohttp.web.view("/graph/visit/paths/{src}", VisitPathsView), # temporarily disabled in wait of a proper fix for T1969 # aiohttp.web.view("/graph/walk/{src}/{dst}", WalkView) aiohttp.web.view("/graph/randomwalk/{src}/{dst}", RandomWalkView), aiohttp.web.view("/graph/neighbors/count/{src}", CountNeighborsView), aiohttp.web.view("/graph/leaves/count/{src}", CountLeavesView), aiohttp.web.view("/graph/visit/nodes/count/{src}", CountVisitNodesView), ] ) app["backend"] = backend return app diff --git a/swh/graph/tests/test_api_client.py b/swh/graph/tests/test_api_client.py index 9c2fc50..8d23295 100644 --- a/swh/graph/tests/test_api_client.py +++ b/swh/graph/tests/test_api_client.py @@ -1,335 +1,370 @@ import pytest from pytest import raises from swh.core.api import RemoteException def test_stats(graph_client): stats = graph_client.stats() assert set(stats.keys()) == {"counts", "ratios", "indegree", "outdegree"} assert set(stats["counts"].keys()) == {"nodes", "edges"} assert set(stats["ratios"].keys()) == { "compression", "bits_per_node", "bits_per_edge", "avg_locality", } assert set(stats["indegree"].keys()) == {"min", "max", "avg"} assert set(stats["outdegree"].keys()) == {"min", "max", "avg"} assert stats["counts"]["nodes"] == 21 assert stats["counts"]["edges"] == 23 assert isinstance(stats["ratios"]["compression"], float) assert isinstance(stats["ratios"]["bits_per_node"], float) assert isinstance(stats["ratios"]["bits_per_edge"], float) assert isinstance(stats["ratios"]["avg_locality"], float) assert stats["indegree"]["min"] == 0 assert stats["indegree"]["max"] == 3 assert isinstance(stats["indegree"]["avg"], float) assert stats["outdegree"]["min"] == 0 assert stats["outdegree"]["max"] == 3 assert isinstance(stats["outdegree"]["avg"], float) def test_leaves(graph_client): actual = list( graph_client.leaves("swh:1:ori:0000000000000000000000000000000000000021") ) expected = [ "swh:1:cnt:0000000000000000000000000000000000000001", "swh:1:cnt:0000000000000000000000000000000000000004", "swh:1:cnt:0000000000000000000000000000000000000005", "swh:1:cnt:0000000000000000000000000000000000000007", ] assert set(actual) == set(expected) def test_neighbors(graph_client): actual = list( graph_client.neighbors( "swh:1:rev:0000000000000000000000000000000000000009", direction="backward" ) ) expected = [ "swh:1:snp:0000000000000000000000000000000000000020", "swh:1:rel:0000000000000000000000000000000000000010", "swh:1:rev:0000000000000000000000000000000000000013", ] assert set(actual) == set(expected) def test_visit_nodes(graph_client): actual = list( graph_client.visit_nodes( "swh:1:rel:0000000000000000000000000000000000000010", edges="rel:rev,rev:rev", ) ) expected = [ "swh:1:rel:0000000000000000000000000000000000000010", "swh:1:rev:0000000000000000000000000000000000000009", "swh:1:rev:0000000000000000000000000000000000000003", ] assert set(actual) == set(expected) +def test_visit_nodes_filtered(graph_client): + actual = list( + graph_client.visit_nodes( + "swh:1:rel:0000000000000000000000000000000000000010", return_types="dir", + ) + ) + expected = [ + "swh:1:dir:0000000000000000000000000000000000000002", + "swh:1:dir:0000000000000000000000000000000000000008", + "swh:1:dir:0000000000000000000000000000000000000006", + ] + assert set(actual) == set(expected) + + +def test_visit_nodes_filtered_star(graph_client): + actual = list( + graph_client.visit_nodes( + "swh:1:rel:0000000000000000000000000000000000000010", return_types="*", + ) + ) + expected = [ + "swh:1:rel:0000000000000000000000000000000000000010", + "swh:1:rev:0000000000000000000000000000000000000009", + "swh:1:rev:0000000000000000000000000000000000000003", + "swh:1:dir:0000000000000000000000000000000000000002", + "swh:1:cnt:0000000000000000000000000000000000000001", + "swh:1:dir:0000000000000000000000000000000000000008", + "swh:1:cnt:0000000000000000000000000000000000000007", + "swh:1:dir:0000000000000000000000000000000000000006", + "swh:1:cnt:0000000000000000000000000000000000000004", + "swh:1:cnt:0000000000000000000000000000000000000005", + ] + assert set(actual) == set(expected) + + def test_visit_edges(graph_client): actual = list( graph_client.visit_edges( "swh:1:rel:0000000000000000000000000000000000000010", edges="rel:rev,rev:rev,rev:dir", ) ) expected = [ ( "swh:1:rel:0000000000000000000000000000000000000010", "swh:1:rev:0000000000000000000000000000000000000009", ), ( "swh:1:rev:0000000000000000000000000000000000000009", "swh:1:rev:0000000000000000000000000000000000000003", ), ( "swh:1:rev:0000000000000000000000000000000000000009", "swh:1:dir:0000000000000000000000000000000000000008", ), ( "swh:1:rev:0000000000000000000000000000000000000003", "swh:1:dir:0000000000000000000000000000000000000002", ), ] assert set(actual) == set(expected) def test_visit_edges_limited(graph_client): actual = list( graph_client.visit_edges( "swh:1:rel:0000000000000000000000000000000000000010", max_edges=4, edges="rel:rev,rev:rev,rev:dir", ) ) expected = [ ( "swh:1:rel:0000000000000000000000000000000000000010", "swh:1:rev:0000000000000000000000000000000000000009", ), ( "swh:1:rev:0000000000000000000000000000000000000009", "swh:1:rev:0000000000000000000000000000000000000003", ), ( "swh:1:rev:0000000000000000000000000000000000000009", "swh:1:dir:0000000000000000000000000000000000000008", ), ( "swh:1:rev:0000000000000000000000000000000000000003", "swh:1:dir:0000000000000000000000000000000000000002", ), ] # As there are four valid answers (up to reordering), we cannot check for # equality. Instead, we check the client returned all edges but one. assert set(actual).issubset(set(expected)) - assert len(actual) == 3 + assert len(actual) == 4 def test_visit_edges_diamond_pattern(graph_client): actual = list( graph_client.visit_edges( "swh:1:rev:0000000000000000000000000000000000000009", edges="*", ) ) expected = [ ( "swh:1:rev:0000000000000000000000000000000000000009", "swh:1:rev:0000000000000000000000000000000000000003", ), ( "swh:1:rev:0000000000000000000000000000000000000009", "swh:1:dir:0000000000000000000000000000000000000008", ), ( "swh:1:rev:0000000000000000000000000000000000000003", "swh:1:dir:0000000000000000000000000000000000000002", ), ( "swh:1:dir:0000000000000000000000000000000000000002", "swh:1:cnt:0000000000000000000000000000000000000001", ), ( "swh:1:dir:0000000000000000000000000000000000000008", "swh:1:cnt:0000000000000000000000000000000000000001", ), ( "swh:1:dir:0000000000000000000000000000000000000008", "swh:1:cnt:0000000000000000000000000000000000000007", ), ( "swh:1:dir:0000000000000000000000000000000000000008", "swh:1:dir:0000000000000000000000000000000000000006", ), ( "swh:1:dir:0000000000000000000000000000000000000006", "swh:1:cnt:0000000000000000000000000000000000000004", ), ( "swh:1:dir:0000000000000000000000000000000000000006", "swh:1:cnt:0000000000000000000000000000000000000005", ), ] assert set(actual) == set(expected) def test_visit_paths(graph_client): actual = list( graph_client.visit_paths( "swh:1:snp:0000000000000000000000000000000000000020", edges="snp:*,rev:*" ) ) actual = [tuple(path) for path in actual] expected = [ ( "swh:1:snp:0000000000000000000000000000000000000020", "swh:1:rev:0000000000000000000000000000000000000009", "swh:1:rev:0000000000000000000000000000000000000003", "swh:1:dir:0000000000000000000000000000000000000002", ), ( "swh:1:snp:0000000000000000000000000000000000000020", "swh:1:rev:0000000000000000000000000000000000000009", "swh:1:dir:0000000000000000000000000000000000000008", ), ( "swh:1:snp:0000000000000000000000000000000000000020", "swh:1:rel:0000000000000000000000000000000000000010", ), ] assert set(actual) == set(expected) @pytest.mark.skip(reason="currently disabled due to T1969") def test_walk(graph_client): args = ("swh:1:dir:0000000000000000000000000000000000000016", "rel") kwargs = { "edges": "dir:dir,dir:rev,rev:*", "direction": "backward", "traversal": "bfs", } actual = list(graph_client.walk(*args, **kwargs)) expected = [ "swh:1:dir:0000000000000000000000000000000000000016", "swh:1:dir:0000000000000000000000000000000000000017", "swh:1:rev:0000000000000000000000000000000000000018", "swh:1:rel:0000000000000000000000000000000000000019", ] assert set(actual) == set(expected) kwargs2 = kwargs.copy() kwargs2["limit"] = -1 actual = list(graph_client.walk(*args, **kwargs2)) expected = ["swh:1:rel:0000000000000000000000000000000000000019"] assert set(actual) == set(expected) kwargs2 = kwargs.copy() kwargs2["limit"] = 2 actual = list(graph_client.walk(*args, **kwargs2)) expected = [ "swh:1:dir:0000000000000000000000000000000000000016", "swh:1:dir:0000000000000000000000000000000000000017", ] assert set(actual) == set(expected) def test_random_walk(graph_client): - """as the walk is random, we test a visit from a cnt node to the only origin in - the dataset, and only check the final node of the path (i.e., the origin) - + """as the walk is random, we test a visit from a cnt node to the only + origin in the dataset, and only check the final node of the path + (i.e., the origin) """ args = ("swh:1:cnt:0000000000000000000000000000000000000001", "ori") kwargs = {"direction": "backward"} expected_root = "swh:1:ori:0000000000000000000000000000000000000021" actual = list(graph_client.random_walk(*args, **kwargs)) assert len(actual) > 1 # no origin directly links to a content assert actual[0] == args[0] assert actual[-1] == expected_root kwargs2 = kwargs.copy() kwargs2["limit"] = -1 actual = list(graph_client.random_walk(*args, **kwargs2)) assert actual == [expected_root] kwargs2["limit"] = -2 actual = list(graph_client.random_walk(*args, **kwargs2)) assert len(actual) == 2 assert actual[-1] == expected_root kwargs2["limit"] = 3 actual = list(graph_client.random_walk(*args, **kwargs2)) assert len(actual) == 3 def test_count(graph_client): actual = graph_client.count_leaves( "swh:1:ori:0000000000000000000000000000000000000021" ) assert actual == 4 actual = graph_client.count_visit_nodes( "swh:1:rel:0000000000000000000000000000000000000010", edges="rel:rev,rev:rev" ) assert actual == 3 actual = graph_client.count_neighbors( "swh:1:rev:0000000000000000000000000000000000000009", direction="backward" ) assert actual == 3 def test_param_validation(graph_client): with raises(RemoteException) as exc_info: # SWHID not found list(graph_client.leaves("swh:1:ori:fff0000000000000000000000000000000000021")) assert exc_info.value.response.status_code == 404 with raises(RemoteException) as exc_info: # malformed SWHID list( graph_client.neighbors("swh:1:ori:fff000000zzzzzz0000000000000000000000021") ) assert exc_info.value.response.status_code == 400 with raises(RemoteException) as exc_info: # malformed edge specificaiton list( graph_client.visit_nodes( "swh:1:dir:0000000000000000000000000000000000000016", edges="dir:notanodetype,dir:rev,rev:*", direction="backward", ) ) assert exc_info.value.response.status_code == 400 with raises(RemoteException) as exc_info: # malformed direction list( graph_client.visit_nodes( "swh:1:dir:0000000000000000000000000000000000000016", edges="dir:dir,dir:rev,rev:*", direction="notadirection", ) ) assert exc_info.value.response.status_code == 400 @pytest.mark.skip(reason="currently disabled due to T1969") def test_param_validation_walk(graph_client): """test validation of walk-specific parameters only""" with raises(RemoteException) as exc_info: # malformed traversal order list( graph_client.walk( "swh:1:dir:0000000000000000000000000000000000000016", "rel", edges="dir:dir,dir:rev,rev:*", direction="backward", traversal="notatraversalorder", ) ) assert exc_info.value.response.status_code == 400