diff --git a/java/src/main/java/org/softwareheritage/graph/AllowedEdges.java b/java/src/main/java/org/softwareheritage/graph/AllowedEdges.java index 737148a..c266dbc 100644 --- a/java/src/main/java/org/softwareheritage/graph/AllowedEdges.java +++ b/java/src/main/java/org/softwareheritage/graph/AllowedEdges.java @@ -1,74 +1,91 @@ package org.softwareheritage.graph; import java.util.ArrayList; /** * Edge restriction based on node types, used when visiting the graph. *

* Software Heritage * graph contains multiple node types (contents, directories, revisions, ...) and restricting * the traversal to specific node types is necessary for many querying operations: * use cases. * * @author The Software Heritage developers */ public class AllowedEdges { /** * 2D boolean matrix storing access rights for all combination of src/dst node types (first * dimension is source, second dimension is destination), when edge restriction is not enforced this * array is set to null for early bypass. */ public boolean[][] restrictedTo; /** * Constructor. * * @param edgesFmt a formatted string describing allowed * edges */ public AllowedEdges(String edgesFmt) { int nbNodeTypes = Node.Type.values().length; this.restrictedTo = new boolean[nbNodeTypes][nbNodeTypes]; // Special values (null, empty, "*") if (edgesFmt == null || edgesFmt.isEmpty()) { return; } if (edgesFmt.equals("*")) { // Allows for quick bypass (with simple null check) when no edge restriction restrictedTo = null; return; } // Format: "src1:dst1,src2:dst2,[...]" String[] edgeTypes = edgesFmt.split(","); for (String edgeType : edgeTypes) { String[] nodeTypes = edgeType.split(":"); if (nodeTypes.length != 2) { throw new IllegalArgumentException("Cannot parse edge type: " + edgeType); } ArrayList srcTypes = Node.Type.parse(nodeTypes[0]); ArrayList dstTypes = Node.Type.parse(nodeTypes[1]); for (Node.Type srcType : srcTypes) { for (Node.Type dstType : dstTypes) { restrictedTo[srcType.ordinal()][dstType.ordinal()] = true; } } } } /** * Checks if a given edge can be followed during graph traversal. * * @param srcType edge source type * @param dstType edge destination type * @return true if allowed and false otherwise */ public boolean isAllowed(Node.Type srcType, Node.Type dstType) { if (restrictedTo == null) return true; return restrictedTo[srcType.ordinal()][dstType.ordinal()]; } + + /** + * Return a new AllowedEdges instance with reversed edge restrictions. e.g. 
"src1:dst1,src2:dst2" + * becomes "dst1:src1,dst2:src2" + * + * @return a new AllowedEdges instance with reversed edge restrictions (unrestricted if this instance is) + */ + public AllowedEdges reverse() { + if (restrictedTo == null) return new AllowedEdges("*"); /* unrestricted instances keep no matrix; their reverse is unrestricted too */ + AllowedEdges reversed = new AllowedEdges(null); /* the constructor already allocates an all-false square matrix */ + for (int i = 0; i < restrictedTo.length; i++) { + for (int j = 0; j < restrictedTo[0].length; j++) { + reversed.restrictedTo[i][j] = restrictedTo[j][i]; + } + } + return reversed; + } } diff --git a/java/src/main/java/org/softwareheritage/graph/rpc/GraphServer.java b/java/src/main/java/org/softwareheritage/graph/rpc/GraphServer.java index a1a2c20..3137b2b 100644 --- a/java/src/main/java/org/softwareheritage/graph/rpc/GraphServer.java +++ b/java/src/main/java/org/softwareheritage/graph/rpc/GraphServer.java @@ -1,225 +1,252 @@ package org.softwareheritage.graph.rpc; import com.google.protobuf.FieldMask; import com.martiansoftware.jsap.*; import io.grpc.Server; import io.grpc.Status; import io.grpc.netty.shaded.io.grpc.netty.NettyServerBuilder; import io.grpc.netty.shaded.io.netty.channel.ChannelOption; import io.grpc.stub.StreamObserver; import io.grpc.protobuf.services.ProtoReflectionService; import it.unimi.dsi.logging.ProgressLogger; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.softwareheritage.graph.SWHID; import org.softwareheritage.graph.SwhBidirectionalGraph; import org.softwareheritage.graph.compress.LabelMapBuilder; import java.io.FileInputStream; import java.io.IOException; import java.util.Properties; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; /** * Server that manages startup/shutdown of a {@code Greeter} server. 
*/ public class GraphServer { private final static Logger logger = LoggerFactory.getLogger(GraphServer.class); private final SwhBidirectionalGraph graph; private final int port; private final int threads; private Server server; public GraphServer(String graphBasename, int port, int threads) throws IOException { this.graph = loadGraph(graphBasename); this.port = port; this.threads = threads; } public static SwhBidirectionalGraph loadGraph(String basename) throws IOException { // TODO: use loadLabelledMapped() when https://github.com/vigna/webgraph-big/pull/5 is merged SwhBidirectionalGraph g = SwhBidirectionalGraph.loadLabelled(basename, new ProgressLogger(logger)); g.loadContentLength(); g.loadContentIsSkipped(); g.loadPersonIds(); g.loadAuthorTimestamps(); g.loadCommitterTimestamps(); g.loadMessages(); g.loadTagNames(); g.loadLabelNames(); return g; } private void start() throws IOException { server = NettyServerBuilder.forPort(port).withChildOption(ChannelOption.SO_REUSEADDR, true) .executor(Executors.newFixedThreadPool(threads)).addService(new TraversalService(graph)) .addService(ProtoReflectionService.newInstance()).build().start(); logger.info("Server started, listening on " + port); Runtime.getRuntime().addShutdownHook(new Thread(() -> { try { GraphServer.this.stop(); } catch (InterruptedException e) { e.printStackTrace(System.err); } })); } private void stop() throws InterruptedException { if (server != null) { server.shutdown().awaitTermination(30, TimeUnit.SECONDS); } } /** * Await termination on the main thread since the grpc library uses daemon threads. 
*/ private void blockUntilShutdown() throws InterruptedException { if (server != null) { server.awaitTermination(); } } private static JSAPResult parseArgs(String[] args) { JSAPResult config = null; try { SimpleJSAP jsap = new SimpleJSAP(LabelMapBuilder.class.getName(), "", new Parameter[]{ new FlaggedOption("port", JSAP.INTEGER_PARSER, "50091", JSAP.NOT_REQUIRED, 'p', "port", "The port on which the server should listen."), new FlaggedOption("threads", JSAP.INTEGER_PARSER, "1", JSAP.NOT_REQUIRED, 't', "threads", "The number of concurrent threads. 0 = number of cores."), new UnflaggedOption("graphBasename", JSAP.STRING_PARSER, JSAP.REQUIRED, "Basename of the output graph")}); config = jsap.parse(args); if (jsap.messagePrinted()) { System.exit(1); } } catch (JSAPException e) { e.printStackTrace(); } return config; } /** * Main launches the server from the command line. */ public static void main(String[] args) throws IOException, InterruptedException { JSAPResult config = parseArgs(args); String graphBasename = config.getString("graphBasename"); int port = config.getInt("port"); int threads = config.getInt("threads"); if (threads == 0) { threads = Runtime.getRuntime().availableProcessors(); } final GraphServer server = new GraphServer(graphBasename, port, threads); server.start(); server.blockUntilShutdown(); } static class TraversalService extends TraversalServiceGrpc.TraversalServiceImplBase { SwhBidirectionalGraph graph; public TraversalService(SwhBidirectionalGraph graph) { this.graph = graph; } @Override public void checkSwhid(CheckSwhidRequest request, StreamObserver responseObserver) { CheckSwhidResponse.Builder builder = CheckSwhidResponse.newBuilder().setExists(true); try { graph.getNodeId(new SWHID(request.getSwhid())); } catch (IllegalArgumentException e) { builder.setExists(false); builder.setDetails(e.getMessage()); } responseObserver.onNext(builder.build()); responseObserver.onCompleted(); } @Override public void stats(StatsRequest request, 
StreamObserver responseObserver) { StatsResponse.Builder response = StatsResponse.newBuilder(); response.setNumNodes(graph.numNodes()); response.setNumEdges(graph.numArcs()); Properties properties = new Properties(); try { properties.load(new FileInputStream(graph.getPath() + ".properties")); properties.load(new FileInputStream(graph.getPath() + ".stats")); } catch (IOException e) { throw new RuntimeException(e); } response.setCompression(Double.parseDouble(properties.getProperty("compratio"))); response.setBitsPerNode(Double.parseDouble(properties.getProperty("bitspernode"))); response.setBitsPerEdge(Double.parseDouble(properties.getProperty("bitsperlink"))); response.setAvgLocality(Double.parseDouble(properties.getProperty("avglocality"))); response.setIndegreeMin(Long.parseLong(properties.getProperty("minindegree"))); response.setIndegreeMax(Long.parseLong(properties.getProperty("maxindegree"))); response.setIndegreeAvg(Double.parseDouble(properties.getProperty("avgindegree"))); response.setOutdegreeMin(Long.parseLong(properties.getProperty("minoutdegree"))); response.setOutdegreeMax(Long.parseLong(properties.getProperty("maxoutdegree"))); response.setOutdegreeAvg(Double.parseDouble(properties.getProperty("avgoutdegree"))); responseObserver.onNext(response.build()); responseObserver.onCompleted(); } @Override public void getNode(GetNodeRequest request, StreamObserver responseObserver) { long nodeId; try { nodeId = graph.getNodeId(new SWHID(request.getSwhid())); } catch (IllegalArgumentException e) { responseObserver.onError(Status.INVALID_ARGUMENT.withCause(e).asException()); return; } Node.Builder builder = Node.newBuilder(); NodePropertyBuilder.buildNodeProperties(graph.getForwardGraph(), request.hasMask() ? 
request.getMask() : null, builder, nodeId); responseObserver.onNext(builder.build()); responseObserver.onCompleted(); } @Override public void traverse(TraversalRequest request, StreamObserver responseObserver) { SwhBidirectionalGraph g = graph.copy(); - Traversal.simpleTraversal(g, request, responseObserver::onNext); + var t = new Traversal.SimpleTraversal(g, request, responseObserver::onNext); + t.visit(); responseObserver.onCompleted(); } + @Override + public void findPathTo(FindPathToRequest request, StreamObserver responseObserver) { + SwhBidirectionalGraph g = graph.copy(); + var t = new Traversal.FindPathTo(g, request); + t.visit(); + Path path = t.getPath(); + if (path == null) { + responseObserver.onError(Status.NOT_FOUND.asException()); + } else { + responseObserver.onNext(path); + responseObserver.onCompleted(); + } + } + + @Override + public void findPathBetween(FindPathBetweenRequest request, StreamObserver responseObserver) { + SwhBidirectionalGraph g = graph.copy(); + var t = new Traversal.FindPathBetween(g, request); + t.visit(); + Path path = t.getPath(); + if (path == null) { + responseObserver.onError(Status.NOT_FOUND.asException()); + } else { + responseObserver.onNext(path); + responseObserver.onCompleted(); + } + } + @Override public void countNodes(TraversalRequest request, StreamObserver responseObserver) { AtomicInteger count = new AtomicInteger(0); SwhBidirectionalGraph g = graph.copy(); TraversalRequest fixedReq = TraversalRequest.newBuilder(request) // Ignore return fields, just count nodes .setMask(FieldMask.getDefaultInstance()).build(); - Traversal.simpleTraversal(g, fixedReq, (Node node) -> { - count.incrementAndGet(); - }); + var t = new Traversal.SimpleTraversal(g, fixedReq, n -> count.incrementAndGet()); /* traverse with the mask-overridden copy, as the replaced call did */ + t.visit(); CountResponse response = CountResponse.newBuilder().setCount(count.get()).build(); responseObserver.onNext(response); responseObserver.onCompleted(); } @Override public void countEdges(TraversalRequest request, 
StreamObserver responseObserver) { AtomicInteger count = new AtomicInteger(0); SwhBidirectionalGraph g = graph.copy(); TraversalRequest fixedReq = TraversalRequest.newBuilder(request) // Force return empty successors to count the edges .setMask(FieldMask.newBuilder().addPaths("successor").build()).build(); - Traversal.simpleTraversal(g, fixedReq, (Node node) -> { - count.addAndGet(node.getSuccessorCount()); - }); + var t = new Traversal.SimpleTraversal(g, fixedReq, n -> count.addAndGet(n.getSuccessorCount())); /* the forced "successor" mask is what makes getSuccessorCount() meaningful */ + t.visit(); CountResponse response = CountResponse.newBuilder().setCount(count.get()).build(); responseObserver.onNext(response); responseObserver.onCompleted(); } } } diff --git a/java/src/main/java/org/softwareheritage/graph/rpc/Traversal.java b/java/src/main/java/org/softwareheritage/graph/rpc/Traversal.java index c036052..5f50596 100644 --- a/java/src/main/java/org/softwareheritage/graph/rpc/Traversal.java +++ b/java/src/main/java/org/softwareheritage/graph/rpc/Traversal.java @@ -1,178 +1,426 @@ package org.softwareheritage.graph.rpc; -import it.unimi.dsi.big.webgraph.LazyLongIterator; import it.unimi.dsi.big.webgraph.labelling.ArcLabelledNodeIterator; import it.unimi.dsi.big.webgraph.labelling.Label; import org.softwareheritage.graph.*; import java.util.*; public class Traversal { - private static LazyLongIterator filterSuccessors(SwhUnidirectionalGraph g, long nodeId, AllowedEdges allowedEdges) { - if (allowedEdges.restrictedTo == null) { - // All edges are allowed, bypass edge check - return g.successors(nodeId); - } else { - LazyLongIterator allSuccessors = g.successors(nodeId); - return new LazyLongIterator() { - @Override - public long nextLong() { - long neighbor; - while ((neighbor = allSuccessors.nextLong()) != -1) { - if (allowedEdges.isAllowed(g.getNodeType(nodeId), g.getNodeType(neighbor))) { - return neighbor; - } - } - return -1; - } - - @Override - public long skip(final long n) { - long i = 0; - while (i < n && nextLong() != -1) - i++; - 
return i; - } - }; - } - } - private static ArcLabelledNodeIterator.LabelledArcIterator filterLabelledSuccessors(SwhUnidirectionalGraph g, long nodeId, AllowedEdges allowedEdges) { if (allowedEdges.restrictedTo == null) { // All edges are allowed, bypass edge check return g.labelledSuccessors(nodeId); } else { ArcLabelledNodeIterator.LabelledArcIterator allSuccessors = g.labelledSuccessors(nodeId); return new ArcLabelledNodeIterator.LabelledArcIterator() { @Override public Label label() { return allSuccessors.label(); } @Override public long nextLong() { long neighbor; while ((neighbor = allSuccessors.nextLong()) != -1) { if (allowedEdges.isAllowed(g.getNodeType(nodeId), g.getNodeType(neighbor))) { return neighbor; } } return -1; } @Override public long skip(final long n) { long i = 0; while (i < n && nextLong() != -1) i++; return i; } }; } } private static class NodeFilterChecker { private final SwhUnidirectionalGraph g; private final NodeFilter filter; private final AllowedNodes allowedNodes; private NodeFilterChecker(SwhUnidirectionalGraph graph, NodeFilter filter) { this.g = graph; this.filter = filter; this.allowedNodes = new AllowedNodes(filter.hasTypes() ? 
filter.getTypes() : "*"); } public boolean allowed(long nodeId) { if (filter == null) { return true; } if (!this.allowedNodes.isAllowed(g.getNodeType(nodeId))) { return false; } return true; } } - public static SwhUnidirectionalGraph getDirectedGraph(SwhBidirectionalGraph g, TraversalRequest request) { - switch (request.getDirection()) { + public static SwhUnidirectionalGraph getDirectedGraph(SwhBidirectionalGraph g, GraphDirection direction) { + switch (direction) { case FORWARD: return g.getForwardGraph(); case BACKWARD: return g.getBackwardGraph(); case BOTH: return new SwhUnidirectionalGraph(g.symmetrize(), g.getProperties()); + default : + throw new IllegalArgumentException("Unknown direction: " + direction); + } + } + + public static GraphDirection reverseDirection(GraphDirection direction) { + switch (direction) { + case FORWARD: + return GraphDirection.BACKWARD; + case BACKWARD: + return GraphDirection.FORWARD; + case BOTH: + return GraphDirection.BOTH; + default : + throw new IllegalArgumentException("Unknown direction: " + direction); } - throw new IllegalArgumentException("Unknown direction: " + request.getDirection()); } - public static void simpleTraversal(SwhBidirectionalGraph bidirectionalGraph, TraversalRequest request, - NodeObserver nodeObserver) { - SwhUnidirectionalGraph g = getDirectedGraph(bidirectionalGraph, request); - NodeFilterChecker nodeReturnChecker = new NodeFilterChecker(g, request.getReturnNodes()); - NodePropertyBuilder.NodeDataMask nodeDataMask = new NodePropertyBuilder.NodeDataMask( - request.hasMask() ? request.getMask() : null); - - AllowedEdges allowedEdges = new AllowedEdges(request.hasEdges() ? 
request.getEdges() : "*"); - - Queue queue = new ArrayDeque<>(); - HashSet visited = new HashSet<>(); - request.getSrcList().forEach(srcSwhid -> { - long srcNodeId = g.getNodeId(new SWHID(srcSwhid)); - queue.add(srcNodeId); - visited.add(srcNodeId); - }); - queue.add(-1L); // Depth sentinel - - long edgesAccessed = 0; - long currentDepth = 0; - while (!queue.isEmpty()) { + static class StopTraversalException extends RuntimeException { + } + + static class BFSVisitor { + protected final SwhUnidirectionalGraph g; + protected long depth = 0; + protected long traversalSuccessors = 0; + protected long edgesAccessed = 0; + + protected HashMap parents = new HashMap<>(); + protected ArrayDeque queue = new ArrayDeque<>(); + private long maxDepth = -1; + private long maxEdges = -1; + + BFSVisitor(SwhUnidirectionalGraph g) { + this.g = g; + } + + public void addSource(long nodeId) { + queue.add(nodeId); + parents.put(nodeId, -1L); + } + + public void setMaxDepth(long depth) { + maxDepth = depth; + } + + public void setMaxEdges(long edges) { + maxEdges = edges; + } + + public void visitSetup() { + edgesAccessed = 0; + depth = 0; + queue.add(-1L); // depth sentinel + } + + public void visit() { + visitSetup(); + try { + while (!queue.isEmpty()) { + visitStep(); + } + } catch (StopTraversalException e) { + // Ignore + } + } + + public void visitStep() { + assert !queue.isEmpty(); long curr = queue.poll(); if (curr == -1L) { - ++currentDepth; + ++depth; if (!queue.isEmpty()) { queue.add(-1L); + visitStep(); } - continue; + return; } - if (request.hasMaxDepth() && currentDepth > request.getMaxDepth()) { - break; + if (maxDepth >= 0 && depth > maxDepth) { + throw new StopTraversalException(); } edgesAccessed += g.outdegree(curr); - if (request.hasMaxEdges() && edgesAccessed >= request.getMaxEdges()) { - break; + if (maxEdges >= 0 && edgesAccessed >= maxEdges) { + throw new StopTraversalException(); } + visitNode(curr); + } - Node.Builder nodeBuilder = null; - if 
(nodeReturnChecker.allowed(curr) && (!request.hasMinDepth() || currentDepth >= request.getMinDepth())) { - nodeBuilder = Node.newBuilder(); - NodePropertyBuilder.buildNodeProperties(g, nodeDataMask, nodeBuilder, curr); - } + protected ArcLabelledNodeIterator.LabelledArcIterator getSuccessors(long nodeId) { + return g.labelledSuccessors(nodeId); + } - ArcLabelledNodeIterator.LabelledArcIterator it = filterLabelledSuccessors(g, curr, allowedEdges); - long traversalSuccessors = 0; + protected void visitNode(long node) { + ArcLabelledNodeIterator.LabelledArcIterator it = getSuccessors(node); + traversalSuccessors = 0; for (long succ; (succ = it.nextLong()) != -1;) { traversalSuccessors++; - if (!visited.contains(succ)) { - queue.add(succ); - visited.add(succ); - } - NodePropertyBuilder.buildSuccessorProperties(g, nodeDataMask, nodeBuilder, curr, succ, it.label()); + visitEdge(node, succ, it.label()); + } + } + + protected void visitEdge(long src, long dst, Label label) { + if (!parents.containsKey(dst)) { + queue.add(dst); + parents.put(dst, src); + } + } + } + + static class SimpleTraversal extends BFSVisitor { + private final NodeFilterChecker nodeReturnChecker; + private final AllowedEdges allowedEdges; + private final TraversalRequest request; + private final NodePropertyBuilder.NodeDataMask nodeDataMask; + private final NodeObserver nodeObserver; + + private Node.Builder nodeBuilder; + + SimpleTraversal(SwhBidirectionalGraph bidirectionalGraph, TraversalRequest request, NodeObserver nodeObserver) { + super(getDirectedGraph(bidirectionalGraph, request.getDirection())); + this.request = request; + this.nodeObserver = nodeObserver; + this.nodeReturnChecker = new NodeFilterChecker(g, request.getReturnNodes()); + this.nodeDataMask = new NodePropertyBuilder.NodeDataMask(request.hasMask() ? request.getMask() : null); + this.allowedEdges = new AllowedEdges(request.hasEdges() ? 
request.getEdges() : "*"); + request.getSrcList().forEach(srcSwhid -> { + long srcNodeId = g.getNodeId(new SWHID(srcSwhid)); + addSource(srcNodeId); + }); + if (request.hasMaxDepth()) { + setMaxDepth(request.getMaxDepth()); + } + if (request.hasMaxEdges()) { + setMaxEdges(request.getMaxEdges()); + } + } + + @Override + protected ArcLabelledNodeIterator.LabelledArcIterator getSuccessors(long nodeId) { + return filterLabelledSuccessors(g, nodeId, allowedEdges); + } + + @Override + public void visitNode(long node) { + nodeBuilder = null; + if (nodeReturnChecker.allowed(node) && (!request.hasMinDepth() || depth >= request.getMinDepth())) { + nodeBuilder = Node.newBuilder(); + NodePropertyBuilder.buildNodeProperties(g, nodeDataMask, nodeBuilder, node); } + super.visitNode(node); if (request.getReturnNodes().hasMinTraversalSuccessors() && traversalSuccessors < request.getReturnNodes().getMinTraversalSuccessors() || request.getReturnNodes().hasMaxTraversalSuccessors() && traversalSuccessors > request.getReturnNodes().getMaxTraversalSuccessors()) { nodeBuilder = null; } if (nodeBuilder != null) { nodeObserver.onNext(nodeBuilder.build()); } } + + @Override + protected void visitEdge(long src, long dst, Label label) { + super.visitEdge(src, dst, label); + NodePropertyBuilder.buildSuccessorProperties(g, nodeDataMask, nodeBuilder, src, dst, label); + } + } + + static class FindPathTo extends BFSVisitor { + private final AllowedEdges allowedEdges; + private final FindPathToRequest request; + private final NodePropertyBuilder.NodeDataMask nodeDataMask; + private final NodeFilterChecker targetChecker; + private Long targetNode = null; + + FindPathTo(SwhBidirectionalGraph bidirectionalGraph, FindPathToRequest request) { + super(getDirectedGraph(bidirectionalGraph, request.getDirection())); + this.request = request; + this.targetChecker = new NodeFilterChecker(g, request.getTarget()); + this.nodeDataMask = new NodePropertyBuilder.NodeDataMask(request.hasMask() ? 
request.getMask() : null); + this.allowedEdges = new AllowedEdges(request.hasEdges() ? request.getEdges() : "*"); + if (request.hasMaxDepth()) { + setMaxDepth(request.getMaxDepth()); + } + if (request.hasMaxEdges()) { + setMaxEdges(request.getMaxEdges()); + } + request.getSrcList().forEach(srcSwhid -> { + long srcNodeId = g.getNodeId(new SWHID(srcSwhid)); + addSource(srcNodeId); + }); + } + + @Override + protected ArcLabelledNodeIterator.LabelledArcIterator getSuccessors(long nodeId) { + return filterLabelledSuccessors(g, nodeId, allowedEdges); + } + + @Override + public void visitNode(long node) { + if (targetChecker.allowed(node)) { + targetNode = node; + throw new StopTraversalException(); + } + super.visitNode(node); + } + + public Path getPath() { + if (targetNode == null) { + return null; + } + Path.Builder pathBuilder = Path.newBuilder(); + long curNode = targetNode; + ArrayList path = new ArrayList<>(); + while (curNode != -1) { + path.add(curNode); + curNode = parents.get(curNode); + } + Collections.reverse(path); + for (long nodeId : path) { + Node.Builder nodeBuilder = Node.newBuilder(); + NodePropertyBuilder.buildNodeProperties(g, nodeDataMask, nodeBuilder, nodeId); + pathBuilder.addNode(nodeBuilder.build()); + } + return pathBuilder.build(); + } + } + + static class FindPathBetween extends BFSVisitor { + private final FindPathBetweenRequest request; + private final NodePropertyBuilder.NodeDataMask nodeDataMask; + private final AllowedEdges allowedEdgesSrc; + private final AllowedEdges allowedEdgesDst; + + private final BFSVisitor srcVisitor; + private final BFSVisitor dstVisitor; + private Long middleNode = null; + + FindPathBetween(SwhBidirectionalGraph bidirectionalGraph, FindPathBetweenRequest request) { + super(getDirectedGraph(bidirectionalGraph, request.getDirection())); + this.request = request; + this.nodeDataMask = new NodePropertyBuilder.NodeDataMask(request.hasMask() ? 
request.getMask() : null); + this.allowedEdgesSrc = new AllowedEdges(request.hasEdges() ? request.getEdges() : "*"); + this.allowedEdgesDst = request.hasEdgesReverse() + ? new AllowedEdges(request.getEdgesReverse()) + : (request.hasEdges() ? new AllowedEdges(request.getEdges()).reverse() : new AllowedEdges("*")); + SwhUnidirectionalGraph srcGraph = getDirectedGraph(bidirectionalGraph, request.getDirection()); + SwhUnidirectionalGraph dstGraph = getDirectedGraph(bidirectionalGraph, + request.hasDirectionReverse() + ? request.getDirectionReverse() + : reverseDirection(request.getDirection())); + + this.srcVisitor = new BFSVisitor(srcGraph) { + @Override + protected ArcLabelledNodeIterator.LabelledArcIterator getSuccessors(long nodeId) { + return filterLabelledSuccessors(g, nodeId, allowedEdgesSrc); + } + + @Override + public void visitNode(long node) { + if (dstVisitor.parents.containsKey(node)) { /* already discovered by the search started from the destinations */ + middleNode = node; + throw new StopTraversalException(); + } + super.visitNode(node); + } + }; + this.dstVisitor = new BFSVisitor(dstGraph) { + @Override + protected ArcLabelledNodeIterator.LabelledArcIterator getSuccessors(long nodeId) { + return filterLabelledSuccessors(g, nodeId, allowedEdgesDst); + } + + @Override + public void visitNode(long node) { + if (srcVisitor.parents.containsKey(node)) { /* must test the opposite search; this visitor's own map always contains the node */ + middleNode = node; + throw new StopTraversalException(); + } + super.visitNode(node); + } + }; + if (request.hasMaxDepth()) { + this.srcVisitor.setMaxDepth(request.getMaxDepth()); + this.dstVisitor.setMaxDepth(request.getMaxDepth()); + } + if (request.hasMaxEdges()) { + this.srcVisitor.setMaxEdges(request.getMaxEdges()); + this.dstVisitor.setMaxEdges(request.getMaxEdges()); + } + request.getSrcList().forEach(srcSwhid -> { + long srcNodeId = g.getNodeId(new SWHID(srcSwhid)); + srcVisitor.addSource(srcNodeId); + }); + request.getDstList().forEach(dstSwhid -> { + long dstNodeId = g.getNodeId(new SWHID(dstSwhid)); + dstVisitor.addSource(dstNodeId); + }); + } + + @Override + 
public void visit() { + srcVisitor.visitSetup(); + dstVisitor.visitSetup(); + while (middleNode == null && (!srcVisitor.queue.isEmpty() || !dstVisitor.queue.isEmpty())) { /* one step per side, until the two searches meet or both run dry */ + try { + if (!srcVisitor.queue.isEmpty()) { + srcVisitor.visitStep(); + } + } catch (StopTraversalException e) { + srcVisitor.queue.clear(); + } + try { + if (middleNode == null && !dstVisitor.queue.isEmpty()) { + dstVisitor.visitStep(); + } + } catch (StopTraversalException e) { + dstVisitor.queue.clear(); + } + } + } + + public Path getPath() { + if (middleNode == null) { + return null; + } + Path.Builder pathBuilder = Path.newBuilder(); + ArrayList path = new ArrayList<>(); + long curNode = middleNode; + while (curNode != -1) { + path.add(curNode); + curNode = srcVisitor.parents.get(curNode); + } + Collections.reverse(path); + curNode = dstVisitor.parents.get(middleNode); + while (curNode != -1) { + path.add(curNode); + curNode = dstVisitor.parents.get(curNode); + } + for (long nodeId : path) { + Node.Builder nodeBuilder = Node.newBuilder(); + NodePropertyBuilder.buildNodeProperties(g, nodeDataMask, nodeBuilder, nodeId); + pathBuilder.addNode(nodeBuilder.build()); + } + return pathBuilder.build(); + } } public interface NodeObserver { void onNext(Node nodeId); } } diff --git a/java/src/test/java/org/softwareheritage/graph/rpc/FindPathBetweenTest.java b/java/src/test/java/org/softwareheritage/graph/rpc/FindPathBetweenTest.java new file mode 100644 index 0000000..0e288bc --- /dev/null +++ b/java/src/test/java/org/softwareheritage/graph/rpc/FindPathBetweenTest.java @@ -0,0 +1,42 @@ +package org.softwareheritage.graph.rpc; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.softwareheritage.graph.SWHID; + +import java.util.ArrayList; +import java.util.List; + +public class FindPathBetweenTest extends TraversalServiceTest { + private FindPathBetweenRequest.Builder getRequestBuilder(SWHID src, SWHID dst) { + return FindPathBetweenRequest.newBuilder().addSrc(src.toString()).addDst(dst.toString()); + } + + 
// Expects findPathBetween(src = TEST_ORIGIN_ID, dst = cnt 4) to return exactly
// ori, snp 20, rev 9, dir 8, dir 6, cnt 4, in that order.
@Test
+    public void forwardRootToLeaf() {
+        ArrayList actual = getSWHIDs(
+                client.findPathBetween(getRequestBuilder(new SWHID(TEST_ORIGIN_ID), fakeSWHID("cnt", 4)).build()));
+        List expected = List.of(new SWHID(TEST_ORIGIN_ID), fakeSWHID("snp", 20), fakeSWHID("rev", 9),
+                fakeSWHID("dir", 8), fakeSWHID("dir", 6), fakeSWHID("cnt", 4));
+        Assertions.assertEquals(expected, actual);
+    }
+
    // Expects findPathBetween(src = rev 18, dst = rev 3) to return rev 18, rev 13, rev 9, rev 3.
+    @Test
+    public void forwardRevToRev() {
+        ArrayList actual = getSWHIDs(
+                client.findPathBetween(getRequestBuilder(fakeSWHID("rev", 18), fakeSWHID("rev", 3)).build()));
+        List expected = List.of(fakeSWHID("rev", 18), fakeSWHID("rev", 13), fakeSWHID("rev", 9),
+                fakeSWHID("rev", 3));
+        Assertions.assertEquals(expected, actual);
+    }
+
    // Same endpoints swapped and direction set to BACKWARD: expects rev 3, rev 9, rev 13, rev 18.
+    @Test
+    public void backwardRevToRev() {
+        ArrayList actual = getSWHIDs(
+                client.findPathBetween(getRequestBuilder(fakeSWHID("rev", 3), fakeSWHID("rev", 18))
+                        .setDirection(GraphDirection.BACKWARD).build()));
+        List expected = List.of(fakeSWHID("rev", 3), fakeSWHID("rev", 9), fakeSWHID("rev", 13),
+                fakeSWHID("rev", 18));
+        Assertions.assertEquals(expected, actual);
+    }
+}
diff --git a/java/src/test/java/org/softwareheritage/graph/rpc/FindPathToTest.java b/java/src/test/java/org/softwareheritage/graph/rpc/FindPathToTest.java
new file mode 100644
index 0000000..182f005
--- /dev/null
+++ b/java/src/test/java/org/softwareheritage/graph/rpc/FindPathToTest.java
@@ -0,0 +1,48 @@
+package org.softwareheritage.graph.rpc;
+
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+import org.softwareheritage.graph.SWHID;
+
+import java.util.ArrayList;
+import java.util.List;
+
// Tests for the findPathTo call of the client stub inherited from TraversalServiceTest.
+public class FindPathToTest extends TraversalServiceTest {
    // Returns a request builder with one src SWHID string and a target NodeFilter whose
    // types field is set to allowedNodes.
+    private FindPathToRequest.Builder getRequestBuilder(SWHID src, String allowedNodes) {
+        return FindPathToRequest.newBuilder().addSrc(src.toString())
+                .setTarget(NodeFilter.newBuilder().setTypes(allowedNodes).build());
+    }
+
    // Source is TEST_ORIGIN_ID with target types "dir" (request and assertion follow below).
+    @Test
+    public void forwardOriToFirstDir() {
+        ArrayList actual = getSWHIDs(
+
client.findPathTo(getRequestBuilder(new SWHID(TEST_ORIGIN_ID), "dir").build()));
        // Expected result for src = TEST_ORIGIN_ID, target types "dir": ori, snp 20, rev 9, dir 8.
+        List expected = List.of(new SWHID(TEST_ORIGIN_ID), fakeSWHID("snp", 20), fakeSWHID("rev", 9),
+                fakeSWHID("dir", 8));
+        Assertions.assertEquals(expected, actual);
+    }
+
    // Expects findPathTo(src = rel 19, target types "cnt") to return rel 19, rev 18, dir 17, cnt 14.
+    @Test
+    public void forwardRelToFirstCnt() {
+        ArrayList actual = getSWHIDs(client.findPathTo(getRequestBuilder(fakeSWHID("rel", 19), "cnt").build()));
+        List expected = List.of(fakeSWHID("rel", 19), fakeSWHID("rev", 18), fakeSWHID("dir", 17),
+                fakeSWHID("cnt", 14));
+        Assertions.assertEquals(expected, actual);
+    }
+
    // Direction BACKWARD, src = dir 16, target types "rel": expects dir 16, dir 17, rev 18, rel 19.
+    @Test
+    public void backwardDirToFirstRel() {
+        ArrayList actual = getSWHIDs(client.findPathTo(
+                getRequestBuilder(fakeSWHID("dir", 16), "rel").setDirection(GraphDirection.BACKWARD).build()));
+        List expected = List.of(fakeSWHID("dir", 16), fakeSWHID("dir", 17), fakeSWHID("rev", 18),
+                fakeSWHID("rel", 19));
+        Assertions.assertEquals(expected, actual);
+    }
+
    // The source cnt 4 already matches target types "cnt": expects the one-node path [cnt 4].
+    @Test
+    public void forwardCntToItself() {
+        ArrayList actual = getSWHIDs(client.findPathTo(getRequestBuilder(fakeSWHID("cnt", 4), "cnt").build()));
+        List expected = List.of(fakeSWHID("cnt", 4));
+        Assertions.assertEquals(expected, actual);
+    }
+}
diff --git a/java/src/test/java/org/softwareheritage/graph/rpc/TraversalServiceTest.java b/java/src/test/java/org/softwareheritage/graph/rpc/TraversalServiceTest.java
index f038327..b11c1fc 100644
--- a/java/src/test/java/org/softwareheritage/graph/rpc/TraversalServiceTest.java
+++ b/java/src/test/java/org/softwareheritage/graph/rpc/TraversalServiceTest.java
@@ -1,50 +1,58 @@
 package org.softwareheritage.graph.rpc;
 import io.grpc.ManagedChannel;
 import io.grpc.Server;
 import io.grpc.inprocess.InProcessChannelBuilder;
 import io.grpc.inprocess.InProcessServerBuilder;
 import io.grpc.testing.GrpcCleanupRule;
 import org.junit.Rule;
 import org.junit.jupiter.api.AfterAll;
 import org.junit.jupiter.api.BeforeAll;
 import org.softwareheritage.graph.GraphTest;
 import org.softwareheritage.graph.SWHID;
 import
org.softwareheritage.graph.SwhBidirectionalGraph; import java.util.ArrayList; import java.util.Iterator; public class TraversalServiceTest extends GraphTest { @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); private static Server server; private static ManagedChannel channel; protected static SwhBidirectionalGraph g; protected static TraversalServiceGrpc.TraversalServiceBlockingStub client; @BeforeAll static void setup() throws Exception { String serverName = InProcessServerBuilder.generateName(); g = GraphServer.loadGraph(getGraphPath().toString()); server = InProcessServerBuilder.forName(serverName).directExecutor() .addService(new GraphServer.TraversalService(g.copy())).build().start(); channel = InProcessChannelBuilder.forName(serverName).directExecutor().build(); client = TraversalServiceGrpc.newBlockingStub(channel); } @AfterAll static void teardown() { channel.shutdownNow(); server.shutdownNow(); } public ArrayList getSWHIDs(Iterator it) { ArrayList res = new ArrayList<>(); it.forEachRemaining((Node n) -> { res.add(new SWHID(n.getSwhid())); }); return res; } + + public ArrayList getSWHIDs(Path p) { + ArrayList res = new ArrayList<>(); + p.getNodeList().forEach((Node n) -> { + res.add(new SWHID(n.getSwhid())); + }); + return res; + } } diff --git a/proto/swhgraph.proto b/proto/swhgraph.proto index 7637941..3687c8e 100644 --- a/proto/swhgraph.proto +++ b/proto/swhgraph.proto @@ -1,129 +1,155 @@ syntax = "proto3"; import "google/protobuf/field_mask.proto"; option java_multiple_files = true; option java_package = "org.softwareheritage.graph.rpc"; option java_outer_classname = "GraphService"; package swh.graph; service TraversalService { rpc Traverse (TraversalRequest) returns (stream Node); + rpc FindPathTo (FindPathToRequest) returns (Path); + rpc FindPathBetween (FindPathBetweenRequest) returns (Path); rpc CountNodes (TraversalRequest) returns (CountResponse); rpc CountEdges (TraversalRequest) returns (CountResponse); rpc Stats 
(StatsRequest) returns (StatsResponse);
     rpc CheckSwhid (CheckSwhidRequest) returns (CheckSwhidResponse);
     rpc GetNode (GetNodeRequest) returns (Node);
 }
 // Direction values accepted by the request messages below.
 enum GraphDirection {
     FORWARD = 0;
     BACKWARD = 1;
     BOTH = 2;
 }
 // Request message of the Traverse, CountNodes and CountEdges RPCs.
 message TraversalRequest {
     repeated string src = 1;
-
-    // Traversal options
     GraphDirection direction = 2;
     optional string edges = 3;
     optional int64 max_edges = 4;
     optional int64 min_depth = 5;
     optional int64 max_depth = 6;
     optional NodeFilter return_nodes = 7;
     optional google.protobuf.FieldMask mask = 8;
 }
 // Request message of the FindPathTo RPC: one or more src strings plus a NodeFilter target.
 // NOTE(review): edges is presumably an edge-restriction string ("src:dst,...") - confirm.
+message FindPathToRequest {
+    repeated string src = 1;
+    optional NodeFilter target = 2;
+    GraphDirection direction = 3;
+    optional string edges = 4;
+    optional int64 max_edges = 5;
+    optional int64 max_depth = 6;
+    optional google.protobuf.FieldMask mask = 7;
+}
+
 // Request message of the FindPathBetween RPC: src and dst lists, with separate optional
 // direction_reverse / edges_reverse settings next to direction / edges.
 // NOTE(review): the *_reverse fields presumably configure the dst-side traversal - confirm.
+message FindPathBetweenRequest {
+    repeated string src = 1;
+    repeated string dst = 2;
+    GraphDirection direction = 3;
+    optional GraphDirection direction_reverse = 4;
+    optional string edges = 5;
+    optional string edges_reverse = 6;
+    optional int64 max_edges = 7;
+    optional int64 max_depth = 8;
+    optional google.protobuf.FieldMask mask = 9;
+}
+
 // Node selection criteria; used as TraversalRequest.return_nodes and FindPathToRequest.target.
 message NodeFilter {
     optional string types = 1;
     optional int64 min_traversal_successors = 2;
     optional int64 max_traversal_successors = 3;
 }
 // A graph node: its swhid, its successors, and at most one type-specific data message.
 message Node {
     string swhid = 1;
     repeated Successor successor = 2;
     oneof data {
         ContentData cnt = 3;
         RevisionData rev = 5;
         ReleaseData rel = 6;
         OriginData ori = 8;
     };
 }
 // Sequence of nodes returned by FindPathTo and FindPathBetween.
+message Path {
+    repeated Node node = 1;
+}
+
 // Entry of Node.successor: an optional swhid plus edge labels.
 message Successor {
     optional string swhid = 1;
     repeated EdgeLabel label = 2;
 }
 // Payload of the cnt case of Node.data.
 message ContentData {
     optional int64 length = 1;
     optional bool is_skipped = 2;
 }
 // Payload of the rev case of Node.data.
 message RevisionData {
     optional int64 author = 1;
     optional int64 author_date = 2;
     optional int32 author_date_offset = 3;
     optional int64 committer = 4;
     optional int64 committer_date = 5;
     optional int32 committer_date_offset = 6;
     optional bytes message = 7;
 }
 // Payload of the rel case of Node.data.
 message ReleaseData {
     optional int64 author = 1;
     optional int64 author_date = 2;
optional int32 author_date_offset = 3;
     optional bytes name = 4;
     optional bytes message = 5;
 }
 // Payload of the ori case of Node.data.
 message OriginData {
     optional string url = 1;
 }
 // Label attached to a Successor entry (Successor.label).
 message EdgeLabel {
     bytes name = 1;
     int32 permission = 2;
 }
 // Response of the CountNodes and CountEdges RPCs.
 message CountResponse {
     int64 count = 1;
 }
 // Request of the Stats RPC; carries no fields.
 message StatsRequest {
 }
 // Response of the Stats RPC.
 message StatsResponse {
     int64 num_nodes = 1;
     int64 num_edges = 2;
     double compression = 3;
     double bits_per_node = 4;
     double bits_per_edge = 5;
     double avg_locality = 6;
     int64 indegree_min = 7;
     int64 indegree_max = 8;
     double indegree_avg = 9;
     int64 outdegree_min = 10;
     int64 outdegree_max = 11;
     double outdegree_avg = 12;
 }
 // Request of the CheckSwhid RPC.
 message CheckSwhidRequest {
     string swhid = 1;
 }
 // Request of the GetNode RPC.
 // NOTE(review): mask uses field number 8 although the message has only two fields; this
 // matches the number of TraversalRequest.mask - confirm the gap is intentional.
 message GetNodeRequest {
     string swhid = 1;
     optional google.protobuf.FieldMask mask = 8;
 }
 // Response of the CheckSwhid RPC.
 message CheckSwhidResponse {
     bool exists = 1;
     string details = 2;
 }