diff --git a/java/src/main/java/org/softwareheritage/graph/algo/Traversal.java b/java/src/main/java/org/softwareheritage/graph/algo/Traversal.java
index 3fb08f0..0420a41 100644
--- a/java/src/main/java/org/softwareheritage/graph/algo/Traversal.java
+++ b/java/src/main/java/org/softwareheritage/graph/algo/Traversal.java
@@ -1,494 +1,494 @@
package org.softwareheritage.graph.algo;
import java.util.*;
import it.unimi.dsi.bits.LongArrayBitVector;
import org.softwareheritage.graph.AllowedEdges;
import org.softwareheritage.graph.Endpoint;
import org.softwareheritage.graph.Graph;
import org.softwareheritage.graph.Neighbors;
import org.softwareheritage.graph.Node;
/**
* Traversal algorithms on the compressed graph.
*
* Internal implementation of the traversal API endpoints. These methods only input/output internal
* long ids, which are converted in the {@link Endpoint} higher-level class to Software Heritage
* PID.
*
* @author The Software Heritage developers
* @see org.softwareheritage.graph.Endpoint
*/
public class Traversal {
/** Graph used in the traversal */
Graph graph;
/** Boolean to specify the use of the transposed graph */
boolean useTransposed;
/** Graph edge restriction */
AllowedEdges edges;
/** Hash set storing if we have visited a node */
HashSet visited;
/** Hash map storing parent node id for each nodes during a traversal */
Map parentNode;
/** Number of edges accessed during traversal */
long nbEdgesAccessed;
/** random number generator, for random walks */
Random rng;
/**
* Constructor.
*
* @param graph graph used in the traversal
* @param direction a string (either "forward" or "backward") specifying edge orientation
* @param edgesFmt a formatted string describing allowed edges
*/
public Traversal(Graph graph, String direction, String edgesFmt) {
if (!direction.matches("forward|backward")) {
throw new IllegalArgumentException("Unknown traversal direction: " + direction);
}
this.graph = graph;
this.useTransposed = (direction.equals("backward"));
this.edges = new AllowedEdges(graph, edgesFmt);
long nbNodes = graph.getNbNodes();
this.visited = new HashSet<>();
this.parentNode = new HashMap<>();
this.nbEdgesAccessed = 0;
this.rng = new Random();
}
/**
* Returns number of accessed edges during traversal.
*
* @return number of edges accessed in last traversal
*/
public long getNbEdgesAccessed() {
return nbEdgesAccessed;
}
/**
* Returns number of accessed nodes during traversal.
*
* @return number of nodes accessed in last traversal
*/
public long getNbNodesAccessed() {
return this.visited.size();
}
/**
* Push version of {@link leaves}: will fire passed callback for each leaf.
*/
public void leavesVisitor(long srcNodeId, NodeIdConsumer cb) {
Stack stack = new Stack();
this.nbEdgesAccessed = 0;
stack.push(srcNodeId);
visited.add(srcNodeId);
while (!stack.isEmpty()) {
long currentNodeId = stack.pop();
long neighborsCnt = 0;
nbEdgesAccessed += graph.degree(currentNodeId, useTransposed);
for (long neighborNodeId : new Neighbors(graph, useTransposed, edges, currentNodeId)) {
neighborsCnt++;
if (!visited.contains(neighborNodeId)) {
stack.push(neighborNodeId);
visited.add(neighborNodeId);
}
}
if (neighborsCnt == 0) {
cb.accept(currentNodeId);
}
}
}
/**
* Returns the leaves of a subgraph rooted at the specified source node.
*
* @param srcNodeId source node
* @return list of node ids corresponding to the leaves
*/
public ArrayList leaves(long srcNodeId) {
ArrayList nodeIds = new ArrayList();
leavesVisitor(srcNodeId, (nodeId) -> nodeIds.add(nodeId));
return nodeIds;
}
/**
* Push version of {@link neighbors}: will fire passed callback on each
* neighbor.
*/
public void neighborsVisitor(long srcNodeId, NodeIdConsumer cb) {
this.nbEdgesAccessed = graph.degree(srcNodeId, useTransposed);
for (long neighborNodeId : new Neighbors(graph, useTransposed, edges, srcNodeId)) {
cb.accept(neighborNodeId);
}
}
/**
* Returns node direct neighbors (linked with exactly one edge).
*
* @param srcNodeId source node
* @return list of node ids corresponding to the neighbors
*/
public ArrayList neighbors(long srcNodeId) {
ArrayList nodeIds = new ArrayList();
neighborsVisitor(srcNodeId, (nodeId) -> nodeIds.add(nodeId));
return nodeIds;
}
/**
* Push version of {@link visitNodes}: will fire passed callback on each
* visited node.
*/
public void visitNodesVisitor(long srcNodeId, NodeIdConsumer nodeCb, EdgeIdConsumer edgeCb) {
Stack stack = new Stack();
this.nbEdgesAccessed = 0;
stack.push(srcNodeId);
visited.add(srcNodeId);
while (!stack.isEmpty()) {
long currentNodeId = stack.pop();
if (nodeCb != null) {
nodeCb.accept(currentNodeId);
}
nbEdgesAccessed += graph.degree(currentNodeId, useTransposed);
for (long neighborNodeId : new Neighbors(graph, useTransposed, edges, currentNodeId)) {
+ if (edgeCb != null) {
+ edgeCb.accept(currentNodeId, neighborNodeId);
+ }
if (!visited.contains(neighborNodeId)) {
stack.push(neighborNodeId);
visited.add(neighborNodeId);
- if (edgeCb != null) {
- edgeCb.accept(currentNodeId, neighborNodeId);
- }
}
}
}
}
/** One-argument version to handle callbacks properly */
public void visitNodesVisitor(long srcNodeId, NodeIdConsumer cb) {
visitNodesVisitor(srcNodeId, cb, null);
}
/**
* Performs a graph traversal and returns explored nodes.
*
* @param srcNodeId source node
* @return list of explored node ids
*/
public ArrayList visitNodes(long srcNodeId) {
ArrayList nodeIds = new ArrayList();
visitNodesVisitor(srcNodeId, (nodeId) -> nodeIds.add(nodeId));
return nodeIds;
}
/**
* Push version of {@link visitPaths}: will fire passed callback on each
* discovered (complete) path.
*/
public void visitPathsVisitor(long srcNodeId, PathConsumer cb) {
Stack currentPath = new Stack();
this.nbEdgesAccessed = 0;
visitPathsInternalVisitor(srcNodeId, currentPath, cb);
}
/**
* Performs a graph traversal and returns explored paths.
*
* @param srcNodeId source node
* @return list of explored paths (represented as a list of node ids)
*/
public ArrayList> visitPaths(long srcNodeId) {
ArrayList> paths = new ArrayList<>();
visitPathsVisitor(srcNodeId, (path) -> paths.add(path));
return paths;
}
private void visitPathsInternalVisitor(long currentNodeId,
Stack currentPath,
PathConsumer cb) {
currentPath.push(currentNodeId);
long visitedNeighbors = 0;
nbEdgesAccessed += graph.degree(currentNodeId, useTransposed);
for (long neighborNodeId : new Neighbors(graph, useTransposed, edges, currentNodeId)) {
visitPathsInternalVisitor(neighborNodeId, currentPath, cb);
visitedNeighbors++;
}
if (visitedNeighbors == 0) {
ArrayList path = new ArrayList();
for (long nodeId : currentPath) {
path.add(nodeId);
}
cb.accept(path);
}
currentPath.pop();
}
/**
* Performs a graph traversal with backtracking, and returns the first
* found path from source to destination.
*
* @param srcNodeId source node
* @param dst destination (either a node or a node type)
* @return found path as a list of node ids
*/
public ArrayList walk(long srcNodeId, T dst, String visitOrder) {
long dstNodeId = -1;
if (visitOrder.equals("dfs")) {
dstNodeId = walkInternalDFS(srcNodeId, dst);
} else if (visitOrder.equals("bfs")) {
dstNodeId = walkInternalBFS(srcNodeId, dst);
} else {
throw new IllegalArgumentException("Unknown visit order: " + visitOrder);
}
if (dstNodeId == -1) {
throw new IllegalArgumentException("Cannot find destination: " + dst);
}
ArrayList nodeIds = backtracking(srcNodeId, dstNodeId);
return nodeIds;
}
/**
* Performs a random walk (picking a random successor at each step) from
* source to destination.
*
* @param srcNodeId source node
* @param dst destination (either a node or a node type)
* @return found path as a list of node ids or an empty path to indicate
* that no suitable path have been found
*/
public ArrayList randomWalk(long srcNodeId, T dst) {
return randomWalk(srcNodeId, dst, 0);
}
/**
* Performs a stubborn random walk (picking a random successor at each
* step) from source to destination. The walk is "stubborn" in the sense
* that it will not give up the first time if a satisfying target node is
* found, but it will retry up to a limited amount of times.
*
* @param srcNodeId source node
* @param dst destination (either a node or a node type)
* @param retries number of times to retry; 0 means no retries (single walk)
* @return found path as a list of node ids or an empty path to indicate
* that no suitable path have been found
*/
public ArrayList randomWalk(long srcNodeId, T dst, int retries) {
long curNodeId = srcNodeId;
ArrayList path = new ArrayList();
this.nbEdgesAccessed = 0;
boolean found;
if (retries < 0) {
throw new IllegalArgumentException("Negative number of retries given: " + retries);
}
while (true) {
path.add(curNodeId);
Neighbors neighbors = new Neighbors(graph, useTransposed, edges, curNodeId);
curNodeId = randomPick(neighbors.iterator());
if (curNodeId < 0) {
found = false;
break;
}
if (isDstNode(curNodeId, dst)) {
path.add(curNodeId);
found = true;
break;
}
}
if (found) {
return path;
} else if (retries > 0) { // try again
return randomWalk(srcNodeId, dst, retries - 1);
} else { // not found and no retries left
path.clear();
return path;
}
}
/**
* Randomly choose an element from an iterator over Longs using reservoir
* sampling
*
* @param elements iterator over selection domain
* @return randomly chosen element or -1 if no suitable element was found
*/
private long randomPick(Iterator elements) {
long curPick = -1;
long seenCandidates = 0;
while (elements.hasNext()) {
seenCandidates++;
if (Math.round(rng.nextFloat() * (seenCandidates - 1)) == 0) {
curPick = elements.next();
}
}
return curPick;
}
/**
* Internal DFS function of {@link #walk}.
*
* @param srcNodeId source node
* @param dst destination (either a node or a node type)
* @return final destination node or -1 if no path found
*/
private long walkInternalDFS(long srcNodeId, T dst) {
Stack stack = new Stack();
this.nbEdgesAccessed = 0;
stack.push(srcNodeId);
visited.add(srcNodeId);
while (!stack.isEmpty()) {
long currentNodeId = stack.pop();
if (isDstNode(currentNodeId, dst)) {
return currentNodeId;
}
nbEdgesAccessed += graph.degree(currentNodeId, useTransposed);
for (long neighborNodeId : new Neighbors(graph, useTransposed, edges, currentNodeId)) {
if (!visited.contains(neighborNodeId)) {
stack.push(neighborNodeId);
visited.add(neighborNodeId);
parentNode.put(neighborNodeId, currentNodeId);
}
}
}
return -1;
}
/**
* Internal BFS function of {@link #walk}.
*
* @param srcNodeId source node
* @param dst destination (either a node or a node type)
* @return final destination node or -1 if no path found
*/
private long walkInternalBFS(long srcNodeId, T dst) {
Queue queue = new LinkedList();
this.nbEdgesAccessed = 0;
queue.add(srcNodeId);
visited.add(srcNodeId);
while (!queue.isEmpty()) {
long currentNodeId = queue.poll();
if (isDstNode(currentNodeId, dst)) {
return currentNodeId;
}
nbEdgesAccessed += graph.degree(currentNodeId, useTransposed);
for (long neighborNodeId : new Neighbors(graph, useTransposed, edges, currentNodeId)) {
if (!visited.contains(neighborNodeId)) {
queue.add(neighborNodeId);
visited.add(neighborNodeId);
parentNode.put(neighborNodeId, currentNodeId);
}
}
}
return -1;
}
/**
* Internal function of {@link #walk} to check if a node corresponds to the destination.
*
* @param nodeId current node
* @param dst destination (either a node or a node type)
* @return true if the node is a destination, or false otherwise
*/
private boolean isDstNode(long nodeId, T dst) {
if (dst instanceof Long) {
long dstNodeId = (Long) dst;
return nodeId == dstNodeId;
} else if (dst instanceof Node.Type) {
Node.Type dstType = (Node.Type) dst;
return graph.getNodeType(nodeId) == dstType;
} else {
return false;
}
}
/**
* Internal backtracking function of {@link #walk}.
*
* @param srcNodeId source node
* @param dstNodeId destination node
* @return the found path, as a list of node ids
*/
private ArrayList backtracking(long srcNodeId, long dstNodeId) {
ArrayList path = new ArrayList();
long currentNodeId = dstNodeId;
while (currentNodeId != srcNodeId) {
path.add(currentNodeId);
currentNodeId = parentNode.get(currentNodeId);
}
path.add(srcNodeId);
Collections.reverse(path);
return path;
}
public Long findCommonDescendant(long lhsNode, long rhsNode) {
Queue lhsStack = new ArrayDeque<>();
Queue rhsStack = new ArrayDeque<>();
HashSet lhsVisited = new HashSet<>();
HashSet rhsVisited = new HashSet<>();
lhsStack.add(lhsNode);
rhsStack.add(rhsNode);
lhsVisited.add(lhsNode);
rhsVisited.add(rhsNode);
this.nbEdgesAccessed = 0;
Long curNode;
while (!lhsStack.isEmpty() || !rhsStack.isEmpty()) {
if (!lhsStack.isEmpty()) {
curNode = lhsStack.poll();
nbEdgesAccessed += graph.degree(curNode, useTransposed);
for (long neighborNodeId : new Neighbors(graph, useTransposed, edges, curNode)) {
if (!lhsVisited.contains(neighborNodeId)) {
if (rhsVisited.contains(neighborNodeId))
return neighborNodeId;
lhsStack.add(neighborNodeId);
lhsVisited.add(neighborNodeId);
}
}
}
if (!rhsStack.isEmpty()) {
curNode = rhsStack.poll();
nbEdgesAccessed += graph.degree(curNode, useTransposed);
for (long neighborNodeId : new Neighbors(graph, useTransposed, edges, curNode)) {
if (!rhsVisited.contains(neighborNodeId)) {
if (lhsVisited.contains(neighborNodeId))
return neighborNodeId;
rhsStack.add(neighborNodeId);
rhsVisited.add(neighborNodeId);
}
}
}
}
return null;
}
}
diff --git a/swh/graph/tests/test_api_client.py b/swh/graph/tests/test_api_client.py
index 642cf6e..96279d0 100644
--- a/swh/graph/tests/test_api_client.py
+++ b/swh/graph/tests/test_api_client.py
@@ -1,258 +1,305 @@
import pytest
from pytest import raises
from swh.core.api import RemoteException
def test_stats(graph_client):
stats = graph_client.stats()
assert set(stats.keys()) == {"counts", "ratios", "indegree", "outdegree"}
assert set(stats["counts"].keys()) == {"nodes", "edges"}
assert set(stats["ratios"].keys()) == {
"compression",
"bits_per_node",
"bits_per_edge",
"avg_locality",
}
assert set(stats["indegree"].keys()) == {"min", "max", "avg"}
assert set(stats["outdegree"].keys()) == {"min", "max", "avg"}
assert stats["counts"]["nodes"] == 21
assert stats["counts"]["edges"] == 23
assert isinstance(stats["ratios"]["compression"], float)
assert isinstance(stats["ratios"]["bits_per_node"], float)
assert isinstance(stats["ratios"]["bits_per_edge"], float)
assert isinstance(stats["ratios"]["avg_locality"], float)
assert stats["indegree"]["min"] == 0
assert stats["indegree"]["max"] == 3
assert isinstance(stats["indegree"]["avg"], float)
assert stats["outdegree"]["min"] == 0
assert stats["outdegree"]["max"] == 3
assert isinstance(stats["outdegree"]["avg"], float)
def test_leaves(graph_client):
actual = list(
graph_client.leaves("swh:1:ori:0000000000000000000000000000000000000021")
)
expected = [
"swh:1:cnt:0000000000000000000000000000000000000001",
"swh:1:cnt:0000000000000000000000000000000000000004",
"swh:1:cnt:0000000000000000000000000000000000000005",
"swh:1:cnt:0000000000000000000000000000000000000007",
]
assert set(actual) == set(expected)
def test_neighbors(graph_client):
actual = list(
graph_client.neighbors(
"swh:1:rev:0000000000000000000000000000000000000009", direction="backward"
)
)
expected = [
"swh:1:snp:0000000000000000000000000000000000000020",
"swh:1:rel:0000000000000000000000000000000000000010",
"swh:1:rev:0000000000000000000000000000000000000013",
]
assert set(actual) == set(expected)
def test_visit_nodes(graph_client):
actual = list(
graph_client.visit_nodes(
"swh:1:rel:0000000000000000000000000000000000000010",
edges="rel:rev,rev:rev",
)
)
expected = [
"swh:1:rel:0000000000000000000000000000000000000010",
"swh:1:rev:0000000000000000000000000000000000000009",
"swh:1:rev:0000000000000000000000000000000000000003",
]
assert set(actual) == set(expected)
def test_visit_edges(graph_client):
actual = list(
graph_client.visit_edges(
"swh:1:rel:0000000000000000000000000000000000000010",
edges="rel:rev,rev:rev,rev:dir",
)
)
expected = [
(
"swh:1:rel:0000000000000000000000000000000000000010",
"swh:1:rev:0000000000000000000000000000000000000009",
),
(
"swh:1:rev:0000000000000000000000000000000000000009",
"swh:1:rev:0000000000000000000000000000000000000003",
),
(
"swh:1:rev:0000000000000000000000000000000000000009",
"swh:1:dir:0000000000000000000000000000000000000008",
),
(
"swh:1:rev:0000000000000000000000000000000000000003",
"swh:1:dir:0000000000000000000000000000000000000002",
),
]
assert set(actual) == set(expected)
+def test_visit_edges_diamond_pattern(graph_client):
+ actual = list(
+ graph_client.visit_edges(
+ "swh:1:rev:0000000000000000000000000000000000000009", edges="*",
+ )
+ )
+ expected = [
+ (
+ "swh:1:rev:0000000000000000000000000000000000000009",
+ "swh:1:rev:0000000000000000000000000000000000000003",
+ ),
+ (
+ "swh:1:rev:0000000000000000000000000000000000000009",
+ "swh:1:dir:0000000000000000000000000000000000000008",
+ ),
+ (
+ "swh:1:rev:0000000000000000000000000000000000000003",
+ "swh:1:dir:0000000000000000000000000000000000000002",
+ ),
+ (
+ "swh:1:dir:0000000000000000000000000000000000000002",
+ "swh:1:cnt:0000000000000000000000000000000000000001",
+ ),
+ (
+ "swh:1:dir:0000000000000000000000000000000000000008",
+ "swh:1:cnt:0000000000000000000000000000000000000001",
+ ),
+ (
+ "swh:1:dir:0000000000000000000000000000000000000008",
+ "swh:1:cnt:0000000000000000000000000000000000000007",
+ ),
+ (
+ "swh:1:dir:0000000000000000000000000000000000000008",
+ "swh:1:dir:0000000000000000000000000000000000000006",
+ ),
+ (
+ "swh:1:dir:0000000000000000000000000000000000000006",
+ "swh:1:cnt:0000000000000000000000000000000000000004",
+ ),
+ (
+ "swh:1:dir:0000000000000000000000000000000000000006",
+ "swh:1:cnt:0000000000000000000000000000000000000005",
+ ),
+ ]
+ assert set(actual) == set(expected)
+
+
def test_visit_paths(graph_client):
actual = list(
graph_client.visit_paths(
"swh:1:snp:0000000000000000000000000000000000000020", edges="snp:*,rev:*"
)
)
actual = [tuple(path) for path in actual]
expected = [
(
"swh:1:snp:0000000000000000000000000000000000000020",
"swh:1:rev:0000000000000000000000000000000000000009",
"swh:1:rev:0000000000000000000000000000000000000003",
"swh:1:dir:0000000000000000000000000000000000000002",
),
(
"swh:1:snp:0000000000000000000000000000000000000020",
"swh:1:rev:0000000000000000000000000000000000000009",
"swh:1:dir:0000000000000000000000000000000000000008",
),
(
"swh:1:snp:0000000000000000000000000000000000000020",
"swh:1:rel:0000000000000000000000000000000000000010",
),
]
assert set(actual) == set(expected)
@pytest.mark.skip(reason="currently disabled due to T1969")
def test_walk(graph_client):
args = ("swh:1:dir:0000000000000000000000000000000000000016", "rel")
kwargs = {
"edges": "dir:dir,dir:rev,rev:*",
"direction": "backward",
"traversal": "bfs",
}
actual = list(graph_client.walk(*args, **kwargs))
expected = [
"swh:1:dir:0000000000000000000000000000000000000016",
"swh:1:dir:0000000000000000000000000000000000000017",
"swh:1:rev:0000000000000000000000000000000000000018",
"swh:1:rel:0000000000000000000000000000000000000019",
]
assert set(actual) == set(expected)
kwargs2 = kwargs.copy()
kwargs2["limit"] = -1
actual = list(graph_client.walk(*args, **kwargs2))
expected = ["swh:1:rel:0000000000000000000000000000000000000019"]
assert set(actual) == set(expected)
kwargs2 = kwargs.copy()
kwargs2["limit"] = 2
actual = list(graph_client.walk(*args, **kwargs2))
expected = [
"swh:1:dir:0000000000000000000000000000000000000016",
"swh:1:dir:0000000000000000000000000000000000000017",
]
assert set(actual) == set(expected)
def test_random_walk(graph_client):
"""as the walk is random, we test a visit from a cnt node to the only origin in
the dataset, and only check the final node of the path (i.e., the origin)
"""
args = ("swh:1:cnt:0000000000000000000000000000000000000001", "ori")
kwargs = {"direction": "backward"}
expected_root = "swh:1:ori:0000000000000000000000000000000000000021"
actual = list(graph_client.random_walk(*args, **kwargs))
assert len(actual) > 1 # no origin directly links to a content
assert actual[0] == args[0]
assert actual[-1] == expected_root
kwargs2 = kwargs.copy()
kwargs2["limit"] = -1
actual = list(graph_client.random_walk(*args, **kwargs2))
assert actual == [expected_root]
kwargs2["limit"] = -2
actual = list(graph_client.random_walk(*args, **kwargs2))
assert len(actual) == 2
assert actual[-1] == expected_root
kwargs2["limit"] = 3
actual = list(graph_client.random_walk(*args, **kwargs2))
assert len(actual) == 3
def test_count(graph_client):
actual = graph_client.count_leaves(
"swh:1:ori:0000000000000000000000000000000000000021"
)
assert actual == 4
actual = graph_client.count_visit_nodes(
"swh:1:rel:0000000000000000000000000000000000000010", edges="rel:rev,rev:rev"
)
assert actual == 3
actual = graph_client.count_neighbors(
"swh:1:rev:0000000000000000000000000000000000000009", direction="backward"
)
assert actual == 3
def test_param_validation(graph_client):
with raises(RemoteException) as exc_info: # PID not found
list(graph_client.leaves("swh:1:ori:fff0000000000000000000000000000000000021"))
assert exc_info.value.response.status_code == 404
with raises(RemoteException) as exc_info: # malformed PID
list(
graph_client.neighbors("swh:1:ori:fff000000zzzzzz0000000000000000000000021")
)
assert exc_info.value.response.status_code == 400
with raises(RemoteException) as exc_info: # malformed edge specificaiton
list(
graph_client.visit_nodes(
"swh:1:dir:0000000000000000000000000000000000000016",
edges="dir:notanodetype,dir:rev,rev:*",
direction="backward",
)
)
assert exc_info.value.response.status_code == 400
with raises(RemoteException) as exc_info: # malformed direction
list(
graph_client.visit_nodes(
"swh:1:dir:0000000000000000000000000000000000000016",
edges="dir:dir,dir:rev,rev:*",
direction="notadirection",
)
)
assert exc_info.value.response.status_code == 400
@pytest.mark.skip(reason="currently disabled due to T1969")
def test_param_validation_walk(graph_client):
"""test validation of walk-specific parameters only"""
with raises(RemoteException) as exc_info: # malformed traversal order
list(
graph_client.walk(
"swh:1:dir:0000000000000000000000000000000000000016",
"rel",
edges="dir:dir,dir:rev,rev:*",
direction="backward",
traversal="notatraversalorder",
)
)
assert exc_info.value.response.status_code == 400