diff --git a/java/src/main/java/org/softwareheritage/graph/AllowedEdges.java b/java/src/main/java/org/softwareheritage/graph/AllowedEdges.java
index 737148a..c266dbc 100644
--- a/java/src/main/java/org/softwareheritage/graph/AllowedEdges.java
+++ b/java/src/main/java/org/softwareheritage/graph/AllowedEdges.java
@@ -1,74 +1,91 @@
package org.softwareheritage.graph;
import java.util.ArrayList;
/**
* Edge restriction based on node types, used when visiting the graph.
*
* The Software Heritage graph contains multiple node types (contents, directories, revisions, ...),
* and restricting the traversal to specific node types is necessary for many querying use cases.
*
* @author The Software Heritage developers
*/
public class AllowedEdges {
/**
* 2D boolean matrix storing access rights for all combination of src/dst node types (first
* dimension is source, second dimension is destination), when edge restriction is not enforced this
* array is set to null for early bypass.
*/
public boolean[][] restrictedTo;
/**
 * Builds the edge restriction matrix from a textual description.
 * <p>
 * Special values: {@code null} or an empty string leave the matrix all-false (no edge is allowed),
 * and {@code "*"} sets {@link #restrictedTo} to {@code null} (every edge is allowed).
 *
 * @param edgesFmt a formatted string describing allowed edges, as a comma-separated list of
 *            {@code src:dst} node type pairs, e.g. {@code "src1:dst1,src2:dst2"}
 * @throws IllegalArgumentException if an element of the list is not of the form {@code src:dst}
 */
public AllowedEdges(String edgesFmt) {
    int nbNodeTypes = Node.Type.values().length;
    this.restrictedTo = new boolean[nbNodeTypes][nbNodeTypes];
    // Special values (null, empty, "*")
    if (edgesFmt == null || edgesFmt.isEmpty()) {
        return;
    }
    if (edgesFmt.equals("*")) {
        // Allows for quick bypass (with simple null check) when no edge restriction
        restrictedTo = null;
        return;
    }
    // Format: "src1:dst1,src2:dst2,[...]"
    for (String edgeType : edgesFmt.split(",")) {
        String[] nodeTypes = edgeType.split(":");
        if (nodeTypes.length != 2) {
            throw new IllegalArgumentException("Cannot parse edge type: " + edgeType);
        }
        // Typed lists: a raw ArrayList cannot be iterated as Node.Type below
        ArrayList<Node.Type> srcTypes = Node.Type.parse(nodeTypes[0]);
        ArrayList<Node.Type> dstTypes = Node.Type.parse(nodeTypes[1]);
        // Each side may expand to several types; allow every src/dst combination
        for (Node.Type srcType : srcTypes) {
            for (Node.Type dstType : dstTypes) {
                restrictedTo[srcType.ordinal()][dstType.ordinal()] = true;
            }
        }
    }
}
/**
* Checks if a given edge can be followed during graph traversal.
*
* @param srcType edge source type
* @param dstType edge destination type
* @return true if allowed and false otherwise
*/
public boolean isAllowed(Node.Type srcType, Node.Type dstType) {
if (restrictedTo == null)
return true;
return restrictedTo[srcType.ordinal()][dstType.ordinal()];
}
+
+ /**
+  * Return a new AllowedEdges instance with reversed edge restrictions. e.g. "src1:dst1,src2:dst2"
+  * becomes "dst1:src1,dst2:src2"
+  * <p>
+  * When this instance enforces no restriction ({@code restrictedTo == null}), the returned instance
+  * is unrestricted as well.
+  *
+  * @return a new AllowedEdges instance with reversed edge restrictions
+  */
+ public AllowedEdges reverse() {
+     if (restrictedTo == null) {
+         // No matrix to transpose: "*" is its own reverse (previously a NullPointerException)
+         return new AllowedEdges("*");
+     }
+     AllowedEdges reversed = new AllowedEdges(null);
+     reversed.restrictedTo = new boolean[restrictedTo.length][restrictedTo[0].length];
+     // Transpose: an allowed src->dst edge becomes an allowed dst->src edge
+     for (int i = 0; i < restrictedTo.length; i++) {
+         for (int j = 0; j < restrictedTo[0].length; j++) {
+             reversed.restrictedTo[i][j] = restrictedTo[j][i];
+         }
+     }
+     return reversed;
+ }
}
diff --git a/java/src/main/java/org/softwareheritage/graph/rpc/GraphServer.java b/java/src/main/java/org/softwareheritage/graph/rpc/GraphServer.java
index a1a2c20..3137b2b 100644
--- a/java/src/main/java/org/softwareheritage/graph/rpc/GraphServer.java
+++ b/java/src/main/java/org/softwareheritage/graph/rpc/GraphServer.java
@@ -1,225 +1,252 @@
package org.softwareheritage.graph.rpc;
import com.google.protobuf.FieldMask;
import com.martiansoftware.jsap.*;
import io.grpc.Server;
import io.grpc.Status;
import io.grpc.netty.shaded.io.grpc.netty.NettyServerBuilder;
import io.grpc.netty.shaded.io.netty.channel.ChannelOption;
import io.grpc.stub.StreamObserver;
import io.grpc.protobuf.services.ProtoReflectionService;
import it.unimi.dsi.logging.ProgressLogger;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.softwareheritage.graph.SWHID;
import org.softwareheritage.graph.SwhBidirectionalGraph;
import org.softwareheritage.graph.compress.LabelMapBuilder;
import java.io.FileInputStream;
import java.io.IOException;
import java.util.Properties;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
/**
* gRPC server exposing the graph {@code TraversalService}; manages its startup and shutdown.
*/
public class GraphServer {
private final static Logger logger = LoggerFactory.getLogger(GraphServer.class);
private final SwhBidirectionalGraph graph;
private final int port;
private final int threads;
private Server server;
public GraphServer(String graphBasename, int port, int threads) throws IOException {
this.graph = loadGraph(graphBasename);
this.port = port;
this.threads = threads;
}
public static SwhBidirectionalGraph loadGraph(String basename) throws IOException {
// TODO: use loadLabelledMapped() when https://github.com/vigna/webgraph-big/pull/5 is merged
SwhBidirectionalGraph g = SwhBidirectionalGraph.loadLabelled(basename, new ProgressLogger(logger));
g.loadContentLength();
g.loadContentIsSkipped();
g.loadPersonIds();
g.loadAuthorTimestamps();
g.loadCommitterTimestamps();
g.loadMessages();
g.loadTagNames();
g.loadLabelNames();
return g;
}
private void start() throws IOException {
server = NettyServerBuilder.forPort(port).withChildOption(ChannelOption.SO_REUSEADDR, true)
.executor(Executors.newFixedThreadPool(threads)).addService(new TraversalService(graph))
.addService(ProtoReflectionService.newInstance()).build().start();
logger.info("Server started, listening on " + port);
Runtime.getRuntime().addShutdownHook(new Thread(() -> {
try {
GraphServer.this.stop();
} catch (InterruptedException e) {
e.printStackTrace(System.err);
}
}));
}
private void stop() throws InterruptedException {
if (server != null) {
server.shutdown().awaitTermination(30, TimeUnit.SECONDS);
}
}
/**
* Await termination on the main thread since the grpc library uses daemon threads.
*/
private void blockUntilShutdown() throws InterruptedException {
if (server != null) {
server.awaitTermination();
}
}
/**
 * Parses the command-line arguments of the server.
 * <p>
 * Exits the JVM with status 1 when the parser printed a message (usage or parse error).
 *
 * @param args raw command-line arguments
 * @return the parsed configuration (options {@code port}, {@code threads}, {@code graphBasename})
 * @throws IllegalStateException if the argument parser itself cannot be set up
 */
private static JSAPResult parseArgs(String[] args) {
    try {
        // Program name shown in usage output: this class, not LabelMapBuilder
        SimpleJSAP jsap = new SimpleJSAP(GraphServer.class.getName(), "",
                new Parameter[]{
                        new FlaggedOption("port", JSAP.INTEGER_PARSER, "50091", JSAP.NOT_REQUIRED, 'p', "port",
                                "The port on which the server should listen."),
                        new FlaggedOption("threads", JSAP.INTEGER_PARSER, "1", JSAP.NOT_REQUIRED, 't', "threads",
                                "The number of concurrent threads. 0 = number of cores."),
                        new UnflaggedOption("graphBasename", JSAP.STRING_PARSER, JSAP.REQUIRED,
                                "Basename of the graph to load")});
        JSAPResult config = jsap.parse(args);
        if (jsap.messagePrinted()) {
            System.exit(1);
        }
        return config;
    } catch (JSAPException e) {
        // Returning null here only deferred the failure to a NullPointerException in main()
        throw new IllegalStateException("Cannot set up command-line argument parser", e);
    }
}
/**
* Main launches the server from the command line.
*/
public static void main(String[] args) throws IOException, InterruptedException {
JSAPResult config = parseArgs(args);
String graphBasename = config.getString("graphBasename");
int port = config.getInt("port");
int threads = config.getInt("threads");
if (threads == 0) {
threads = Runtime.getRuntime().availableProcessors();
}
final GraphServer server = new GraphServer(graphBasename, port, threads);
server.start();
server.blockUntilShutdown();
}
static class TraversalService extends TraversalServiceGrpc.TraversalServiceImplBase {
SwhBidirectionalGraph graph;
public TraversalService(SwhBidirectionalGraph graph) {
this.graph = graph;
}
@Override
public void checkSwhid(CheckSwhidRequest request, StreamObserver responseObserver) {
CheckSwhidResponse.Builder builder = CheckSwhidResponse.newBuilder().setExists(true);
try {
graph.getNodeId(new SWHID(request.getSwhid()));
} catch (IllegalArgumentException e) {
builder.setExists(false);
builder.setDetails(e.getMessage());
}
responseObserver.onNext(builder.build());
responseObserver.onCompleted();
}
/**
 * Returns node/edge counts of the loaded graph along with the statistics read from the
 * {@code <graph path>.properties} and {@code <graph path>.stats} files.
 *
 * @throws RuntimeException if either file cannot be read
 */
@Override
public void stats(StatsRequest request, StreamObserver responseObserver) {
    StatsResponse.Builder response = StatsResponse.newBuilder();
    response.setNumNodes(graph.numNodes());
    response.setNumEdges(graph.numArcs());
    Properties properties = new Properties();
    // try-with-resources: both streams are closed even if loading fails
    try (FileInputStream propertiesFile = new FileInputStream(graph.getPath() + ".properties");
            FileInputStream statsFile = new FileInputStream(graph.getPath() + ".stats")) {
        properties.load(propertiesFile);
        // Loaded second, so keys from the .stats file override identical keys from .properties
        properties.load(statsFile);
    } catch (IOException e) {
        throw new RuntimeException(e);
    }
    response.setCompression(Double.parseDouble(properties.getProperty("compratio")));
    response.setBitsPerNode(Double.parseDouble(properties.getProperty("bitspernode")));
    response.setBitsPerEdge(Double.parseDouble(properties.getProperty("bitsperlink")));
    response.setAvgLocality(Double.parseDouble(properties.getProperty("avglocality")));
    response.setIndegreeMin(Long.parseLong(properties.getProperty("minindegree")));
    response.setIndegreeMax(Long.parseLong(properties.getProperty("maxindegree")));
    response.setIndegreeAvg(Double.parseDouble(properties.getProperty("avgindegree")));
    response.setOutdegreeMin(Long.parseLong(properties.getProperty("minoutdegree")));
    response.setOutdegreeMax(Long.parseLong(properties.getProperty("maxoutdegree")));
    response.setOutdegreeAvg(Double.parseDouble(properties.getProperty("avgoutdegree")));
    responseObserver.onNext(response.build());
    responseObserver.onCompleted();
}
@Override
public void getNode(GetNodeRequest request, StreamObserver responseObserver) {
long nodeId;
try {
nodeId = graph.getNodeId(new SWHID(request.getSwhid()));
} catch (IllegalArgumentException e) {
responseObserver.onError(Status.INVALID_ARGUMENT.withCause(e).asException());
return;
}
Node.Builder builder = Node.newBuilder();
NodePropertyBuilder.buildNodeProperties(graph.getForwardGraph(),
request.hasMask() ? request.getMask() : null, builder, nodeId);
responseObserver.onNext(builder.build());
responseObserver.onCompleted();
}
@Override
public void traverse(TraversalRequest request, StreamObserver responseObserver) {
SwhBidirectionalGraph g = graph.copy();
- Traversal.simpleTraversal(g, request, responseObserver::onNext);
+ var t = new Traversal.SimpleTraversal(g, request, responseObserver::onNext);
+ t.visit();
responseObserver.onCompleted();
}
+ /**
+  * Finds a path from the request's source nodes to the first node matching its target filter.
+  * Replies with the path, {@code NOT_FOUND} when the traversal found no matching node, or
+  * {@code INVALID_ARGUMENT} when the request is rejected while the traversal is being set up.
+  */
+ @Override
+ public void findPathTo(FindPathToRequest request, StreamObserver responseObserver) {
+     SwhBidirectionalGraph g = graph.copy();
+     Traversal.FindPathTo t;
+     try {
+         t = new Traversal.FindPathTo(g, request);
+     } catch (IllegalArgumentException e) {
+         // Rejected request (e.g. unknown direction; presumably also a bad source SWHID):
+         // report it the same way getNode() does instead of letting it escape the handler
+         responseObserver.onError(Status.INVALID_ARGUMENT.withCause(e).asException());
+         return;
+     }
+     t.visit();
+     Path path = t.getPath();
+     if (path == null) {
+         responseObserver.onError(Status.NOT_FOUND.asException());
+     } else {
+         responseObserver.onNext(path);
+         responseObserver.onCompleted();
+     }
+ }
+
+ /**
+  * Finds a path between the request's source nodes and its destination nodes.
+  * Replies with the path, {@code NOT_FOUND} when the two searches never met, or
+  * {@code INVALID_ARGUMENT} when the request is rejected while the traversal is being set up.
+  */
+ @Override
+ public void findPathBetween(FindPathBetweenRequest request, StreamObserver responseObserver) {
+     SwhBidirectionalGraph g = graph.copy();
+     Traversal.FindPathBetween t;
+     try {
+         t = new Traversal.FindPathBetween(g, request);
+     } catch (IllegalArgumentException e) {
+         // Rejected request (e.g. unknown direction; presumably also a bad SWHID):
+         // report it the same way getNode() does instead of letting it escape the handler
+         responseObserver.onError(Status.INVALID_ARGUMENT.withCause(e).asException());
+         return;
+     }
+     t.visit();
+     Path path = t.getPath();
+     if (path == null) {
+         responseObserver.onError(Status.NOT_FOUND.asException());
+     } else {
+         responseObserver.onNext(path);
+         responseObserver.onCompleted();
+     }
+ }
+
@Override
public void countNodes(TraversalRequest request, StreamObserver responseObserver) {
AtomicInteger count = new AtomicInteger(0);
SwhBidirectionalGraph g = graph.copy();
TraversalRequest fixedReq = TraversalRequest.newBuilder(request)
// Ignore return fields, just count nodes
.setMask(FieldMask.getDefaultInstance()).build();
- Traversal.simpleTraversal(g, fixedReq, (Node node) -> {
- count.incrementAndGet();
- });
+ // Traverse with fixedReq, not the caller's request, so the emptied mask is actually applied
+ var t = new Traversal.SimpleTraversal(g, fixedReq, n -> count.incrementAndGet());
+ t.visit();
CountResponse response = CountResponse.newBuilder().setCount(count.get()).build();
responseObserver.onNext(response);
responseObserver.onCompleted();
}
@Override
public void countEdges(TraversalRequest request, StreamObserver responseObserver) {
AtomicInteger count = new AtomicInteger(0);
SwhBidirectionalGraph g = graph.copy();
TraversalRequest fixedReq = TraversalRequest.newBuilder(request)
// Force return empty successors to count the edges
.setMask(FieldMask.newBuilder().addPaths("successor").build()).build();
- Traversal.simpleTraversal(g, fixedReq, (Node node) -> {
- count.addAndGet(node.getSuccessorCount());
- });
+ // Traverse with fixedReq, not the caller's request: the count relies on the forced
+ // "successor" mask, which the caller's own mask may not include
+ var t = new Traversal.SimpleTraversal(g, fixedReq, n -> count.addAndGet(n.getSuccessorCount()));
+ t.visit();
CountResponse response = CountResponse.newBuilder().setCount(count.get()).build();
responseObserver.onNext(response);
responseObserver.onCompleted();
}
}
}
diff --git a/java/src/main/java/org/softwareheritage/graph/rpc/Traversal.java b/java/src/main/java/org/softwareheritage/graph/rpc/Traversal.java
index c036052..5f50596 100644
--- a/java/src/main/java/org/softwareheritage/graph/rpc/Traversal.java
+++ b/java/src/main/java/org/softwareheritage/graph/rpc/Traversal.java
@@ -1,178 +1,426 @@
package org.softwareheritage.graph.rpc;
-import it.unimi.dsi.big.webgraph.LazyLongIterator;
import it.unimi.dsi.big.webgraph.labelling.ArcLabelledNodeIterator;
import it.unimi.dsi.big.webgraph.labelling.Label;
import org.softwareheritage.graph.*;
import java.util.*;
public class Traversal {
- private static LazyLongIterator filterSuccessors(SwhUnidirectionalGraph g, long nodeId, AllowedEdges allowedEdges) {
- if (allowedEdges.restrictedTo == null) {
- // All edges are allowed, bypass edge check
- return g.successors(nodeId);
- } else {
- LazyLongIterator allSuccessors = g.successors(nodeId);
- return new LazyLongIterator() {
- @Override
- public long nextLong() {
- long neighbor;
- while ((neighbor = allSuccessors.nextLong()) != -1) {
- if (allowedEdges.isAllowed(g.getNodeType(nodeId), g.getNodeType(neighbor))) {
- return neighbor;
- }
- }
- return -1;
- }
-
- @Override
- public long skip(final long n) {
- long i = 0;
- while (i < n && nextLong() != -1)
- i++;
- return i;
- }
- };
- }
- }
-
private static ArcLabelledNodeIterator.LabelledArcIterator filterLabelledSuccessors(SwhUnidirectionalGraph g,
long nodeId, AllowedEdges allowedEdges) {
if (allowedEdges.restrictedTo == null) {
// All edges are allowed, bypass edge check
return g.labelledSuccessors(nodeId);
} else {
ArcLabelledNodeIterator.LabelledArcIterator allSuccessors = g.labelledSuccessors(nodeId);
return new ArcLabelledNodeIterator.LabelledArcIterator() {
@Override
public Label label() {
return allSuccessors.label();
}
@Override
public long nextLong() {
long neighbor;
while ((neighbor = allSuccessors.nextLong()) != -1) {
if (allowedEdges.isAllowed(g.getNodeType(nodeId), g.getNodeType(neighbor))) {
return neighbor;
}
}
return -1;
}
@Override
public long skip(final long n) {
long i = 0;
while (i < n && nextLong() != -1)
i++;
return i;
}
};
}
}
/**
 * Decides whether a node passes a {@code NodeFilter}, based on the node types the filter lists.
 * A {@code null} filter, or a filter without types, accepts every node.
 */
private static class NodeFilterChecker {
    private final SwhUnidirectionalGraph g;
    private final NodeFilter filter;
    private final AllowedNodes allowedNodes;

    private NodeFilterChecker(SwhUnidirectionalGraph graph, NodeFilter filter) {
        this.g = graph;
        this.filter = filter;
        // Guard against a null filter here too; otherwise the null check in allowed() can never
        // be reached because this line would already have thrown
        this.allowedNodes = new AllowedNodes(filter != null && filter.hasTypes() ? filter.getTypes() : "*");
    }

    /**
     * @param nodeId identifier of the node to test
     * @return true if there is no filter or if the node's type is allowed by the filter
     */
    public boolean allowed(long nodeId) {
        if (filter == null) {
            return true;
        }
        return this.allowedNodes.isAllowed(g.getNodeType(nodeId));
    }
}
- public static SwhUnidirectionalGraph getDirectedGraph(SwhBidirectionalGraph g, TraversalRequest request) {
- switch (request.getDirection()) {
+ public static SwhUnidirectionalGraph getDirectedGraph(SwhBidirectionalGraph g, GraphDirection direction) {
+ switch (direction) {
case FORWARD:
return g.getForwardGraph();
case BACKWARD:
return g.getBackwardGraph();
case BOTH:
return new SwhUnidirectionalGraph(g.symmetrize(), g.getProperties());
+ default :
+ throw new IllegalArgumentException("Unknown direction: " + direction);
+ }
+ }
+
+ public static GraphDirection reverseDirection(GraphDirection direction) {
+ switch (direction) {
+ case FORWARD:
+ return GraphDirection.BACKWARD;
+ case BACKWARD:
+ return GraphDirection.FORWARD;
+ case BOTH:
+ return GraphDirection.BOTH;
+ default :
+ throw new IllegalArgumentException("Unknown direction: " + direction);
}
- throw new IllegalArgumentException("Unknown direction: " + request.getDirection());
}
- public static void simpleTraversal(SwhBidirectionalGraph bidirectionalGraph, TraversalRequest request,
- NodeObserver nodeObserver) {
- SwhUnidirectionalGraph g = getDirectedGraph(bidirectionalGraph, request);
- NodeFilterChecker nodeReturnChecker = new NodeFilterChecker(g, request.getReturnNodes());
- NodePropertyBuilder.NodeDataMask nodeDataMask = new NodePropertyBuilder.NodeDataMask(
- request.hasMask() ? request.getMask() : null);
-
- AllowedEdges allowedEdges = new AllowedEdges(request.hasEdges() ? request.getEdges() : "*");
-
- Queue queue = new ArrayDeque<>();
- HashSet visited = new HashSet<>();
- request.getSrcList().forEach(srcSwhid -> {
- long srcNodeId = g.getNodeId(new SWHID(srcSwhid));
- queue.add(srcNodeId);
- visited.add(srcNodeId);
- });
- queue.add(-1L); // Depth sentinel
-
- long edgesAccessed = 0;
- long currentDepth = 0;
- while (!queue.isEmpty()) {
+ static class StopTraversalException extends RuntimeException {
+ }
+
+ /**
+ * Breadth-first visitor over a unidirectional graph. Subclasses hook into
+ * {@link #getSuccessors}, {@link #visitNode} and {@link #visitEdge}; throwing
+ * {@link StopTraversalException} from a hook ends {@link #visit()} early.
+ */
+ static class BFSVisitor {
+ protected final SwhUnidirectionalGraph g;
+ protected long depth = 0;
+ protected long traversalSuccessors = 0;
+ protected long edgesAccessed = 0;
+
+ // Maps every discovered node to the node it was discovered from (-1 for sources)
+ protected HashMap<Long, Long> parents = new HashMap<>();
+ // Pending nodes, with -1 used as a sentinel separating depth levels
+ protected ArrayDeque<Long> queue = new ArrayDeque<>();
+ // Negative values mean "no limit"
+ private long maxDepth = -1;
+ private long maxEdges = -1;
+
+ BFSVisitor(SwhUnidirectionalGraph g) {
+ this.g = g;
+ }
+
+ public void addSource(long nodeId) {
+ queue.add(nodeId);
+ parents.put(nodeId, -1L);
+ }
+
+ public void setMaxDepth(long depth) {
+ maxDepth = depth;
+ }
+
+ public void setMaxEdges(long edges) {
+ maxEdges = edges;
+ }
+
+ public void visitSetup() {
+ edgesAccessed = 0;
+ depth = 0;
+ queue.add(-1L); // depth sentinel
+ }
+
+ public void visit() {
+ visitSetup();
+ try {
+ while (!queue.isEmpty()) {
+ visitStep();
+ }
+ } catch (StopTraversalException e) {
+ // Ignore
+ }
+ }
+
+ public void visitStep() {
+ assert !queue.isEmpty();
long curr = queue.poll();
if (curr == -1L) {
- ++currentDepth;
+ ++depth;
if (!queue.isEmpty()) {
queue.add(-1L);
+ visitStep();
}
- continue;
+ return;
}
- if (request.hasMaxDepth() && currentDepth > request.getMaxDepth()) {
- break;
+ if (maxDepth >= 0 && depth > maxDepth) {
+ throw new StopTraversalException();
}
edgesAccessed += g.outdegree(curr);
- if (request.hasMaxEdges() && edgesAccessed >= request.getMaxEdges()) {
- break;
+ if (maxEdges >= 0 && edgesAccessed >= maxEdges) {
+ throw new StopTraversalException();
}
+ visitNode(curr);
+ }
- Node.Builder nodeBuilder = null;
- if (nodeReturnChecker.allowed(curr) && (!request.hasMinDepth() || currentDepth >= request.getMinDepth())) {
- nodeBuilder = Node.newBuilder();
- NodePropertyBuilder.buildNodeProperties(g, nodeDataMask, nodeBuilder, curr);
- }
+ protected ArcLabelledNodeIterator.LabelledArcIterator getSuccessors(long nodeId) {
+ return g.labelledSuccessors(nodeId);
+ }
- ArcLabelledNodeIterator.LabelledArcIterator it = filterLabelledSuccessors(g, curr, allowedEdges);
- long traversalSuccessors = 0;
+ protected void visitNode(long node) {
+ ArcLabelledNodeIterator.LabelledArcIterator it = getSuccessors(node);
+ traversalSuccessors = 0;
for (long succ; (succ = it.nextLong()) != -1;) {
traversalSuccessors++;
- if (!visited.contains(succ)) {
- queue.add(succ);
- visited.add(succ);
- }
- NodePropertyBuilder.buildSuccessorProperties(g, nodeDataMask, nodeBuilder, curr, succ, it.label());
+ visitEdge(node, succ, it.label());
+ }
+ }
+
+ protected void visitEdge(long src, long dst, Label label) {
+ if (!parents.containsKey(dst)) {
+ queue.add(dst);
+ parents.put(dst, src);
+ }
+ }
+ }
+
+ static class SimpleTraversal extends BFSVisitor {
+ private final NodeFilterChecker nodeReturnChecker;
+ private final AllowedEdges allowedEdges;
+ private final TraversalRequest request;
+ private final NodePropertyBuilder.NodeDataMask nodeDataMask;
+ private final NodeObserver nodeObserver;
+
+ private Node.Builder nodeBuilder;
+
+ SimpleTraversal(SwhBidirectionalGraph bidirectionalGraph, TraversalRequest request, NodeObserver nodeObserver) {
+ super(getDirectedGraph(bidirectionalGraph, request.getDirection()));
+ this.request = request;
+ this.nodeObserver = nodeObserver;
+ this.nodeReturnChecker = new NodeFilterChecker(g, request.getReturnNodes());
+ this.nodeDataMask = new NodePropertyBuilder.NodeDataMask(request.hasMask() ? request.getMask() : null);
+ this.allowedEdges = new AllowedEdges(request.hasEdges() ? request.getEdges() : "*");
+ request.getSrcList().forEach(srcSwhid -> {
+ long srcNodeId = g.getNodeId(new SWHID(srcSwhid));
+ addSource(srcNodeId);
+ });
+ if (request.hasMaxDepth()) {
+ setMaxDepth(request.getMaxDepth());
+ }
+ if (request.hasMaxEdges()) {
+ setMaxEdges(request.getMaxEdges());
+ }
+ }
+
+ @Override
+ protected ArcLabelledNodeIterator.LabelledArcIterator getSuccessors(long nodeId) {
+ return filterLabelledSuccessors(g, nodeId, allowedEdges);
+ }
+
+ @Override
+ public void visitNode(long node) {
+ nodeBuilder = null;
+ if (nodeReturnChecker.allowed(node) && (!request.hasMinDepth() || depth >= request.getMinDepth())) {
+ nodeBuilder = Node.newBuilder();
+ NodePropertyBuilder.buildNodeProperties(g, nodeDataMask, nodeBuilder, node);
}
+ super.visitNode(node);
if (request.getReturnNodes().hasMinTraversalSuccessors()
&& traversalSuccessors < request.getReturnNodes().getMinTraversalSuccessors()
|| request.getReturnNodes().hasMaxTraversalSuccessors()
&& traversalSuccessors > request.getReturnNodes().getMaxTraversalSuccessors()) {
nodeBuilder = null;
}
if (nodeBuilder != null) {
nodeObserver.onNext(nodeBuilder.build());
}
}
+
+ @Override
+ protected void visitEdge(long src, long dst, Label label) {
+ super.visitEdge(src, dst, label);
+ NodePropertyBuilder.buildSuccessorProperties(g, nodeDataMask, nodeBuilder, src, dst, label);
+ }
+ }
+
+ /**
+  * Breadth-first search from the request's source nodes that stops at the first visited node
+  * accepted by the request's target filter, and can then rebuild the path leading to it.
+  */
+ static class FindPathTo extends BFSVisitor {
+     private final AllowedEdges allowedEdges;
+     private final FindPathToRequest request;
+     private final NodePropertyBuilder.NodeDataMask nodeDataMask;
+     private final NodeFilterChecker targetChecker;
+     // Node at which the search stopped, or null while no target has been reached
+     private Long targetNode = null;
+
+     FindPathTo(SwhBidirectionalGraph bidirectionalGraph, FindPathToRequest request) {
+         super(getDirectedGraph(bidirectionalGraph, request.getDirection()));
+         this.request = request;
+         this.targetChecker = new NodeFilterChecker(g, request.getTarget());
+         this.nodeDataMask = new NodePropertyBuilder.NodeDataMask(request.hasMask() ? request.getMask() : null);
+         this.allowedEdges = new AllowedEdges(request.hasEdges() ? request.getEdges() : "*");
+         if (request.hasMaxDepth()) {
+             setMaxDepth(request.getMaxDepth());
+         }
+         if (request.hasMaxEdges()) {
+             setMaxEdges(request.getMaxEdges());
+         }
+         request.getSrcList().forEach(srcSwhid -> {
+             long srcNodeId = g.getNodeId(new SWHID(srcSwhid));
+             addSource(srcNodeId);
+         });
+     }
+
+     @Override
+     protected ArcLabelledNodeIterator.LabelledArcIterator getSuccessors(long nodeId) {
+         return filterLabelledSuccessors(g, nodeId, allowedEdges);
+     }
+
+     @Override
+     public void visitNode(long node) {
+         if (targetChecker.allowed(node)) {
+             // First matching node in BFS order: remember it and end the traversal
+             targetNode = node;
+             throw new StopTraversalException();
+         }
+         super.visitNode(node);
+     }
+
+     /**
+      * @return the path from a source node to the target node (both included), or null if the
+      *         traversal did not reach any target
+      */
+     public Path getPath() {
+         if (targetNode == null) {
+             return null;
+         }
+         Path.Builder pathBuilder = Path.newBuilder();
+         long curNode = targetNode;
+         // Typed list: a raw ArrayList cannot be iterated as long below
+         ArrayList<Long> path = new ArrayList<>();
+         // Walk the parent links back to a source (whose parent is -1)
+         while (curNode != -1) {
+             path.add(curNode);
+             curNode = parents.get(curNode);
+         }
+         Collections.reverse(path);
+         for (long nodeId : path) {
+             Node.Builder nodeBuilder = Node.newBuilder();
+             NodePropertyBuilder.buildNodeProperties(g, nodeDataMask, nodeBuilder, nodeId);
+             pathBuilder.addNode(nodeBuilder.build());
+         }
+         return pathBuilder.build();
+     }
+ }
+
+ /**
+  * Bidirectional breadth-first search: one visitor expands from the request's source nodes, the
+  * other from its destination nodes, one step each in turn, until a node being visited by one
+  * side has already been discovered by the other. That node is the "middle" of the path.
+  */
+ static class FindPathBetween extends BFSVisitor {
+     private final FindPathBetweenRequest request;
+     private final NodePropertyBuilder.NodeDataMask nodeDataMask;
+     private final AllowedEdges allowedEdgesSrc;
+     private final AllowedEdges allowedEdgesDst;
+
+     private final BFSVisitor srcVisitor;
+     private final BFSVisitor dstVisitor;
+     // Node where the two searches met, or null while they have not
+     private Long middleNode = null;
+
+     FindPathBetween(SwhBidirectionalGraph bidirectionalGraph, FindPathBetweenRequest request) {
+         super(getDirectedGraph(bidirectionalGraph, request.getDirection()));
+         this.request = request;
+         this.nodeDataMask = new NodePropertyBuilder.NodeDataMask(request.hasMask() ? request.getMask() : null);
+         this.allowedEdgesSrc = new AllowedEdges(request.hasEdges() ? request.getEdges() : "*");
+         if (request.hasEdgesReverse()) {
+             this.allowedEdgesDst = new AllowedEdges(request.getEdgesReverse());
+         } else if (allowedEdgesSrc.restrictedTo == null) {
+             // Unrestricted edges stay unrestricted when reversed; calling reverse() on an
+             // unrestricted instance would dereference its null matrix
+             this.allowedEdgesDst = new AllowedEdges("*");
+         } else {
+             this.allowedEdgesDst = allowedEdgesSrc.reverse();
+         }
+         SwhUnidirectionalGraph srcGraph = getDirectedGraph(bidirectionalGraph, request.getDirection());
+         SwhUnidirectionalGraph dstGraph = getDirectedGraph(bidirectionalGraph,
+                 request.hasDirectionReverse()
+                         ? request.getDirectionReverse()
+                         : reverseDirection(request.getDirection()));
+
+         this.srcVisitor = new BFSVisitor(srcGraph) {
+             @Override
+             protected ArcLabelledNodeIterator.LabelledArcIterator getSuccessors(long nodeId) {
+                 return filterLabelledSuccessors(g, nodeId, allowedEdgesSrc);
+             }
+
+             @Override
+             public void visitNode(long node) {
+                 // Met the search coming from the destinations
+                 if (dstVisitor.parents.containsKey(node)) {
+                     middleNode = node;
+                     throw new StopTraversalException();
+                 }
+                 super.visitNode(node);
+             }
+         };
+         this.dstVisitor = new BFSVisitor(dstGraph) {
+             @Override
+             protected ArcLabelledNodeIterator.LabelledArcIterator getSuccessors(long nodeId) {
+                 return filterLabelledSuccessors(g, nodeId, allowedEdgesDst);
+             }
+
+             @Override
+             public void visitNode(long node) {
+                 // Must test the *source* side: every node this visitor dequeues is already in
+                 // its own parents map, so testing dstVisitor.parents stopped it on its first node
+                 if (srcVisitor.parents.containsKey(node)) {
+                     middleNode = node;
+                     throw new StopTraversalException();
+                 }
+                 super.visitNode(node);
+             }
+         };
+         if (request.hasMaxDepth()) {
+             this.srcVisitor.setMaxDepth(request.getMaxDepth());
+             this.dstVisitor.setMaxDepth(request.getMaxDepth());
+         }
+         if (request.hasMaxEdges()) {
+             this.srcVisitor.setMaxEdges(request.getMaxEdges());
+             this.dstVisitor.setMaxEdges(request.getMaxEdges());
+         }
+         request.getSrcList().forEach(srcSwhid -> {
+             long srcNodeId = g.getNodeId(new SWHID(srcSwhid));
+             srcVisitor.addSource(srcNodeId);
+         });
+         request.getDstList().forEach(dstSwhid -> {
+             long dstNodeId = g.getNodeId(new SWHID(dstSwhid));
+             dstVisitor.addSource(dstNodeId);
+         });
+     }
+
+     /**
+      * Runs one step of the given visitor if it still has pending nodes. A visitor that asks to
+      * stop (meeting point found, or depth/edge limit reached) has its queue emptied so that it
+      * is not stepped again.
+      */
+     private static void step(BFSVisitor visitor) {
+         if (visitor.queue.isEmpty()) {
+             return;
+         }
+         try {
+             visitor.visitStep();
+         } catch (StopTraversalException e) {
+             visitor.queue.clear();
+         }
+     }
+
+     @Override
+     public void visit() {
+         srcVisitor.visitSetup();
+         dstVisitor.visitSetup();
+         // Alternate between both sides, and stop as soon as they have met: continuing the other
+         // side would only waste work and could overwrite the meeting node
+         while (middleNode == null && (!srcVisitor.queue.isEmpty() || !dstVisitor.queue.isEmpty())) {
+             step(srcVisitor);
+             if (middleNode == null) {
+                 step(dstVisitor);
+             }
+         }
+     }
+
+     /**
+      * @return the path from a source node to a destination node (both included) going through
+      *         the meeting node, or null if the two searches never met
+      */
+     public Path getPath() {
+         if (middleNode == null) {
+             return null;
+         }
+         Path.Builder pathBuilder = Path.newBuilder();
+         // Typed list: a raw ArrayList cannot be iterated as long below
+         ArrayList<Long> path = new ArrayList<>();
+         // First half: meeting node back to a source, then flipped into source -> meeting node
+         long curNode = middleNode;
+         while (curNode != -1) {
+             path.add(curNode);
+             curNode = srcVisitor.parents.get(curNode);
+         }
+         Collections.reverse(path);
+         // Second half: nodes after the meeting node, following the destination-side parents
+         curNode = dstVisitor.parents.get(middleNode);
+         while (curNode != -1) {
+             path.add(curNode);
+             curNode = dstVisitor.parents.get(curNode);
+         }
+         for (long nodeId : path) {
+             Node.Builder nodeBuilder = Node.newBuilder();
+             NodePropertyBuilder.buildNodeProperties(g, nodeDataMask, nodeBuilder, nodeId);
+             pathBuilder.addNode(nodeBuilder.build());
+         }
+         return pathBuilder.build();
+     }
+ }
public interface NodeObserver {
void onNext(Node nodeId);
}
}
diff --git a/java/src/test/java/org/softwareheritage/graph/rpc/FindPathBetweenTest.java b/java/src/test/java/org/softwareheritage/graph/rpc/FindPathBetweenTest.java
new file mode 100644
index 0000000..0e288bc
--- /dev/null
+++ b/java/src/test/java/org/softwareheritage/graph/rpc/FindPathBetweenTest.java
@@ -0,0 +1,42 @@
+package org.softwareheritage.graph.rpc;
+
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+import org.softwareheritage.graph.SWHID;
+
+import java.util.ArrayList;
+import java.util.List;
+
+public class FindPathBetweenTest extends TraversalServiceTest {
+ private FindPathBetweenRequest.Builder getRequestBuilder(SWHID src, SWHID dst) {
+ return FindPathBetweenRequest.newBuilder().addSrc(src.toString()).addDst(dst.toString());
+ }
+
+ @Test
+ public void forwardRootToLeaf() {
+ ArrayList actual = getSWHIDs(
+ client.findPathBetween(getRequestBuilder(new SWHID(TEST_ORIGIN_ID), fakeSWHID("cnt", 4)).build()));
+ List expected = List.of(new SWHID(TEST_ORIGIN_ID), fakeSWHID("snp", 20), fakeSWHID("rev", 9),
+ fakeSWHID("dir", 8), fakeSWHID("dir", 6), fakeSWHID("cnt", 4));
+ Assertions.assertEquals(expected, actual);
+ }
+
+ @Test
+ public void forwardRevToRev() {
+ ArrayList actual = getSWHIDs(
+ client.findPathBetween(getRequestBuilder(fakeSWHID("rev", 18), fakeSWHID("rev", 3)).build()));
+ List expected = List.of(fakeSWHID("rev", 18), fakeSWHID("rev", 13), fakeSWHID("rev", 9),
+ fakeSWHID("rev", 3));
+ Assertions.assertEquals(expected, actual);
+ }
+
+ @Test
+ public void backwardRevToRev() {
+ ArrayList actual = getSWHIDs(
+ client.findPathBetween(getRequestBuilder(fakeSWHID("rev", 3), fakeSWHID("rev", 18))
+ .setDirection(GraphDirection.BACKWARD).build()));
+ List expected = List.of(fakeSWHID("rev", 3), fakeSWHID("rev", 9), fakeSWHID("rev", 13),
+ fakeSWHID("rev", 18));
+ Assertions.assertEquals(expected, actual);
+ }
+}
diff --git a/java/src/test/java/org/softwareheritage/graph/rpc/FindPathToTest.java b/java/src/test/java/org/softwareheritage/graph/rpc/FindPathToTest.java
new file mode 100644
index 0000000..182f005
--- /dev/null
+++ b/java/src/test/java/org/softwareheritage/graph/rpc/FindPathToTest.java
@@ -0,0 +1,48 @@
+package org.softwareheritage.graph.rpc;
+
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+import org.softwareheritage.graph.SWHID;
+
+import java.util.ArrayList;
+import java.util.List;
+
+public class FindPathToTest extends TraversalServiceTest {
+ private FindPathToRequest.Builder getRequestBuilder(SWHID src, String allowedNodes) {
+ return FindPathToRequest.newBuilder().addSrc(src.toString())
+ .setTarget(NodeFilter.newBuilder().setTypes(allowedNodes).build());
+ }
+
+ @Test
+ public void forwardOriToFirstDir() {
+ ArrayList actual = getSWHIDs(
+ client.findPathTo(getRequestBuilder(new SWHID(TEST_ORIGIN_ID), "dir").build()));
+ List expected = List.of(new SWHID(TEST_ORIGIN_ID), fakeSWHID("snp", 20), fakeSWHID("rev", 9),
+ fakeSWHID("dir", 8));
+ Assertions.assertEquals(expected, actual);
+ }
+
+ @Test
+ public void forwardRelToFirstCnt() {
+ ArrayList actual = getSWHIDs(client.findPathTo(getRequestBuilder(fakeSWHID("rel", 19), "cnt").build()));
+ List expected = List.of(fakeSWHID("rel", 19), fakeSWHID("rev", 18), fakeSWHID("dir", 17),
+ fakeSWHID("cnt", 14));
+ Assertions.assertEquals(expected, actual);
+ }
+
+ @Test
+ public void backwardDirToFirstRel() {
+ ArrayList actual = getSWHIDs(client.findPathTo(
+ getRequestBuilder(fakeSWHID("dir", 16), "rel").setDirection(GraphDirection.BACKWARD).build()));
+ List expected = List.of(fakeSWHID("dir", 16), fakeSWHID("dir", 17), fakeSWHID("rev", 18),
+ fakeSWHID("rel", 19));
+ Assertions.assertEquals(expected, actual);
+ }
+
+ @Test
+ public void forwardCntToItself() {
+ ArrayList actual = getSWHIDs(client.findPathTo(getRequestBuilder(fakeSWHID("cnt", 4), "cnt").build()));
+ List expected = List.of(fakeSWHID("cnt", 4));
+ Assertions.assertEquals(expected, actual);
+ }
+}
diff --git a/java/src/test/java/org/softwareheritage/graph/rpc/TraversalServiceTest.java b/java/src/test/java/org/softwareheritage/graph/rpc/TraversalServiceTest.java
index f038327..b11c1fc 100644
--- a/java/src/test/java/org/softwareheritage/graph/rpc/TraversalServiceTest.java
+++ b/java/src/test/java/org/softwareheritage/graph/rpc/TraversalServiceTest.java
@@ -1,50 +1,58 @@
package org.softwareheritage.graph.rpc;
import io.grpc.ManagedChannel;
import io.grpc.Server;
import io.grpc.inprocess.InProcessChannelBuilder;
import io.grpc.inprocess.InProcessServerBuilder;
import io.grpc.testing.GrpcCleanupRule;
import org.junit.Rule;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.softwareheritage.graph.GraphTest;
import org.softwareheritage.graph.SWHID;
import org.softwareheritage.graph.SwhBidirectionalGraph;
import java.util.ArrayList;
import java.util.Iterator;
public class TraversalServiceTest extends GraphTest {
@Rule
public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule();
private static Server server;
private static ManagedChannel channel;
protected static SwhBidirectionalGraph g;
protected static TraversalServiceGrpc.TraversalServiceBlockingStub client;
@BeforeAll
static void setup() throws Exception {
String serverName = InProcessServerBuilder.generateName();
g = GraphServer.loadGraph(getGraphPath().toString());
server = InProcessServerBuilder.forName(serverName).directExecutor()
.addService(new GraphServer.TraversalService(g.copy())).build().start();
channel = InProcessChannelBuilder.forName(serverName).directExecutor().build();
client = TraversalServiceGrpc.newBlockingStub(channel);
}
@AfterAll
static void teardown() {
channel.shutdownNow();
server.shutdownNow();
}
public ArrayList getSWHIDs(Iterator it) {
ArrayList res = new ArrayList<>();
it.forEachRemaining((Node n) -> {
res.add(new SWHID(n.getSwhid()));
});
return res;
}
+
+ public ArrayList getSWHIDs(Path p) {
+ ArrayList res = new ArrayList<>();
+ p.getNodeList().forEach((Node n) -> {
+ res.add(new SWHID(n.getSwhid()));
+ });
+ return res;
+ }
}
diff --git a/proto/swhgraph.proto b/proto/swhgraph.proto
index 7637941..3687c8e 100644
--- a/proto/swhgraph.proto
+++ b/proto/swhgraph.proto
@@ -1,129 +1,155 @@
syntax = "proto3";
import "google/protobuf/field_mask.proto";
option java_multiple_files = true;
option java_package = "org.softwareheritage.graph.rpc";
option java_outer_classname = "GraphService";
package swh.graph;
service TraversalService { // RPC entry points: traversal, path finding, counting, statistics and node lookup
rpc Traverse (TraversalRequest) returns (stream Node); // server-streaming: nodes are sent one message at a time
+ rpc FindPathTo (FindPathToRequest) returns (Path); // unary: one Path message is returned
+ rpc FindPathBetween (FindPathBetweenRequest) returns (Path); // unary: one Path message is returned
rpc CountNodes (TraversalRequest) returns (CountResponse); // takes the same request type as Traverse
rpc CountEdges (TraversalRequest) returns (CountResponse); // takes the same request type as Traverse
rpc Stats (StatsRequest) returns (StatsResponse);
rpc CheckSwhid (CheckSwhidRequest) returns (CheckSwhidResponse);
rpc GetNode (GetNodeRequest) returns (Node);
}
enum GraphDirection { // value type of the `direction` fields of the request messages
FORWARD = 0; // zero value, hence what proto3 reports when `direction` is left unset
BACKWARD = 1;
BOTH = 2;
}
message TraversalRequest { // request type shared by the Traverse, CountNodes and CountEdges RPCs
repeated string src = 1; // one or more starting points; presumably SWHID strings - confirm against the server
-
- // Traversal options
GraphDirection direction = 2;
optional string edges = 3; // NOTE(review): looks like the "src:dst,..." / "*" format parsed by AllowedEdges - confirm
optional int64 max_edges = 4;
optional int64 min_depth = 5;
optional int64 max_depth = 6;
optional NodeFilter return_nodes = 7;
optional google.protobuf.FieldMask mask = 8; // presumably selects which Node fields are filled in - confirm
}
+message FindPathToRequest { // request of the FindPathTo RPC, which answers with a single Path
+ repeated string src = 1;
+ optional NodeFilter target = 2; // describes the node the path should end on; the tests only set its `types`
+ GraphDirection direction = 3;
+ optional string edges = 4;
+ optional int64 max_edges = 5;
+ optional int64 max_depth = 6; // note: unlike TraversalRequest there is no min_depth field here
+ optional google.protobuf.FieldMask mask = 7;
+}
+
+message FindPathBetweenRequest { // request of the FindPathBetween RPC, which answers with a single Path
+ repeated string src = 1;
+ repeated string dst = 2;
+ GraphDirection direction = 3;
+ optional GraphDirection direction_reverse = 4; // NOTE(review): presumably the direction used when searching from `dst` - confirm
+ optional string edges = 5;
+ optional string edges_reverse = 6; // NOTE(review): presumably the edge restriction for the search from `dst` - confirm
+ optional int64 max_edges = 7;
+ optional int64 max_depth = 8;
+ optional google.protobuf.FieldMask mask = 9;
+}
+
message NodeFilter { // used as TraversalRequest.return_nodes and FindPathToRequest.target
optional string types = 1; // node type restriction; the FindPathTo tests use values such as "dir", "cnt", "rel"
optional int64 min_traversal_successors = 2;
optional int64 max_traversal_successors = 3;
}
message Node { // a graph node, returned by Traverse and GetNode and embedded in Path
string swhid = 1;
repeated Successor successor = 2;
oneof data { // at most one type-specific payload is set; field numbers 4 and 7 are not used here
ContentData cnt = 3;
RevisionData rev = 5;
ReleaseData rel = 6;
OriginData ori = 8;
};
}
+message Path { // response of the FindPathTo and FindPathBetween RPCs
+ repeated Node node = 1; // the FindPathTo tests expect the nodes in path order, starting at the source
+}
+
message Successor { // one entry of Node.successor
optional string swhid = 1;
repeated EdgeLabel label = 2; // zero or more labels attached to this successor entry
}
message ContentData { // payload of Node.data when the `cnt` case is set
optional int64 length = 1; // unit not stated here; presumably bytes - confirm
optional bool is_skipped = 2;
}
message RevisionData { // payload of Node.data when the `rev` case is set
optional int64 author = 1; // integer, not a name; presumably a person identifier - confirm
optional int64 author_date = 2; // NOTE(review): epoch and unit are not stated in this file - confirm
optional int32 author_date_offset = 3;
optional int64 committer = 4; // integer, not a name; presumably a person identifier - confirm
optional int64 committer_date = 5;
optional int32 committer_date_offset = 6;
optional bytes message = 7; // raw bytes, so no text encoding is implied by the schema
}
message ReleaseData { // payload of Node.data when the `rel` case is set
optional int64 author = 1; // integer, not a name; presumably a person identifier - confirm
optional int64 author_date = 2; // NOTE(review): epoch and unit are not stated in this file - confirm
optional int32 author_date_offset = 3;
optional bytes name = 4; // raw bytes, so no text encoding is implied by the schema
optional bytes message = 5; // raw bytes, so no text encoding is implied by the schema
}
message OriginData { // payload of Node.data when the `ori` case is set
optional string url = 1;
}
message EdgeLabel { // label carried by a Successor entry
bytes name = 1; // raw bytes, so no text encoding is implied by the schema
int32 permission = 2; // NOTE(review): presumably a file mode / permission code - confirm
}
message CountResponse { // response of the CountNodes and CountEdges RPCs
int64 count = 1;
}
message StatsRequest { // request of the Stats RPC; it declares no fields
}
message StatsResponse { // response of the Stats RPC: graph-wide counters and degree statistics
int64 num_nodes = 1;
int64 num_edges = 2;
double compression = 3;
double bits_per_node = 4;
double bits_per_edge = 5;
double avg_locality = 6;
int64 indegree_min = 7; // in-degree statistics: fields 7-9
int64 indegree_max = 8;
double indegree_avg = 9;
int64 outdegree_min = 10; // out-degree statistics: fields 10-12
int64 outdegree_max = 11;
double outdegree_avg = 12;
}
message CheckSwhidRequest { // request of the CheckSwhid RPC
string swhid = 1;
}
message GetNodeRequest { // request of the GetNode RPC, which answers with a single Node
string swhid = 1;
optional google.protobuf.FieldMask mask = 8; // same field number as TraversalRequest.mask; numbers 2-7 are unused in this message
}
message CheckSwhidResponse { // response of the CheckSwhid RPC
bool exists = 1;
string details = 2; // NOTE(review): presumably a human-readable explanation when the check fails - confirm
}