diff --git a/java/src/main/java/org/softwareheritage/graph/compress/ExtractNodes.java b/java/src/main/java/org/softwareheritage/graph/compress/ExtractNodes.java
index 3b6c060..e055f7d 100644
--- a/java/src/main/java/org/softwareheritage/graph/compress/ExtractNodes.java
+++ b/java/src/main/java/org/softwareheritage/graph/compress/ExtractNodes.java
@@ -1,393 +1,404 @@
package org.softwareheritage.graph.compress;
import com.github.luben.zstd.ZstdOutputStream;
import com.martiansoftware.jsap.*;
import it.unimi.dsi.logging.ProgressLogger;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.softwareheritage.graph.Node;
import org.softwareheritage.graph.utils.Sort;
import java.io.*;
import java.nio.charset.StandardCharsets;
import java.util.*;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicLongArray;
/**
* Read a graph dataset and extract all the unique node SWHIDs it contains, including the ones that
* are not stored as actual objects in the graph, but only referred to by the edges.
* Additionally, extract the set of all unique edge labels in the graph.
 *
 * <ul>
 * <li>The set of nodes is written in <code>${outputBasename}.nodes.csv.zst</code>, as a
 * zst-compressed sorted list of SWHIDs, one per line.</li>
 * <li>The set of edge labels is written in <code>${outputBasename}.labels.csv.zst</code>, as a
 * zst-compressed sorted list of labels encoded in base64, one per line.</li>
 * <li>The number of unique nodes referred to in the graph is written in a text file,
 * <code>${outputBasename}.nodes.count.txt</code></li>
 * <li>The number of unique edges referred to in the graph is written in a text file,
 * <code>${outputBasename}.edges.count.txt</code></li>
 * <li>The number of unique edge labels is written in a text file,
 * <code>${outputBasename}.labels.count.txt</code></li>
 * <li>Statistics on the number of nodes of each type are written in a text file,
 * <code>${outputBasename}.nodes.stats.txt</code></li>
 * <li>Statistics on the number of edges of each type are written in a text file,
 * <code>${outputBasename}.edges.stats.txt</code></li>
 * </ul>
 *
 * <p>
 * Rationale: Because the graph can contain holes, loose objects and dangling objects, some nodes
 * that are referred to as destinations in the edge relationships might not actually be stored in
 * the graph itself. However, to compress the graph using a graph compression library, it is
 * necessary to have a list of all the nodes in the graph, including the ones that are only
 * referred to by the edges but not actually stored as concrete objects.
 * </p>
 *
 * <p>
 * This class reads the entire graph dataset, and uses <code>sort -u</code> to extract the set of
 * all the unique nodes and unique labels that will be needed as an input for the compression
 * process.
 * </p>
*/
public class ExtractNodes {
private final static Logger logger = LoggerFactory.getLogger(ExtractNodes.class);
// Create one thread per processor.
final static int numThreads = Runtime.getRuntime().availableProcessors();
// Allocate up to 20% of maximum memory for sorting subprocesses (each thread runs one node sorter and one label sorter).
final static long sortBufferSize = (long) (Runtime.getRuntime().maxMemory() * 0.2 / numThreads / 2);
private static JSAPResult parseArgs(String[] args) {
JSAPResult config = null;
try {
SimpleJSAP jsap = new SimpleJSAP(ExtractNodes.class.getName(), "", new Parameter[]{
new UnflaggedOption("dataset", JSAP.STRING_PARSER, JSAP.REQUIRED, "Path to the edges dataset"),
new UnflaggedOption("outputBasename", JSAP.STRING_PARSER, JSAP.REQUIRED,
"Basename of the output files"),
new FlaggedOption("format", JSAP.STRING_PARSER, "orc", JSAP.NOT_REQUIRED, 'f', "format",
"Format of the input dataset (orc, csv)"),
new FlaggedOption("sortBufferSize", JSAP.STRING_PARSER, String.valueOf(sortBufferSize) + "b",
JSAP.NOT_REQUIRED, 'S', "sort-buffer-size",
"Size of the memory buffer used by each sort process"),
new FlaggedOption("sortTmpDir", JSAP.STRING_PARSER, null, JSAP.NOT_REQUIRED, 'T', "temp-dir",
"Path to the temporary directory used by sort")});
config = jsap.parse(args);
if (jsap.messagePrinted()) {
System.exit(1);
}
} catch (JSAPException e) {
System.err.println("Usage error: " + e.getMessage());
System.exit(1);
}
return config;
}
public static void main(String[] args) throws IOException, InterruptedException {
JSAPResult parsedArgs = parseArgs(args);
String datasetPath = parsedArgs.getString("dataset");
String outputBasename = parsedArgs.getString("outputBasename");
String datasetFormat = parsedArgs.getString("format");
String sortBufferSize = parsedArgs.getString("sortBufferSize");
- String sortTmpDir = parsedArgs.getString("sortTmpDir", null);
+ String sortTmpPath = parsedArgs.getString("sortTmpDir", null);
- (new File(sortTmpDir)).mkdirs();
+ File sortTmpDir = new File(sortTmpPath);
+ sortTmpDir.mkdirs();
// Open edge dataset
GraphDataset dataset;
if (datasetFormat.equals("orc")) {
dataset = new ORCGraphDataset(datasetPath);
} else if (datasetFormat.equals("csv")) {
dataset = new CSVEdgeDataset(datasetPath);
} else {
throw new IllegalArgumentException("Unknown dataset format: " + datasetFormat);
}
extractNodes(dataset, outputBasename, sortBufferSize, sortTmpDir);
}
- public static void extractNodes(GraphDataset dataset, String outputBasename, String sortBufferSize,
- String sortTmpDir) throws IOException, InterruptedException {
+ public static void extractNodes(GraphDataset dataset, String outputBasename, String sortBufferSize, File sortTmpDir)
+ throws IOException, InterruptedException {
// Read the dataset and write the nodes and labels to the sorting processes
AtomicLong edgeCount = new AtomicLong(0);
AtomicLongArray edgeCountByType = new AtomicLongArray(Node.Type.values().length * Node.Type.values().length);
int numThreads = Runtime.getRuntime().availableProcessors();
ForkJoinPool forkJoinPool = new ForkJoinPool(numThreads);
Process[] nodeSorters = new Process[numThreads];
- String[] nodeBatchPaths = new String[numThreads];
+ File[] nodeBatchPaths = new File[numThreads];
Process[] labelSorters = new Process[numThreads];
- String[] labelBatchPaths = new String[numThreads];
+ File[] labelBatches = new File[numThreads];
long[] progressCounts = new long[numThreads];
AtomicInteger nextThreadId = new AtomicInteger(0);
ThreadLocal<Integer> threadLocalId = ThreadLocal.withInitial(nextThreadId::getAndIncrement);
ProgressLogger pl = new ProgressLogger(logger, 10, TimeUnit.SECONDS);
pl.itemsName = "edges";
pl.start("Reading node/edge files and writing sorted batches.");
GraphDataset.NodeCallback nodeCallback = (node) -> {
int threadId = threadLocalId.get();
if (nodeSorters[threadId] == null) {
- nodeBatchPaths[threadId] = sortTmpDir + "/nodes." + threadId + ".txt";
- nodeSorters[threadId] = Sort.spawnSort(sortBufferSize, sortTmpDir,
- List.of("-o", nodeBatchPaths[threadId]));
+ nodeBatchPaths[threadId] = File.createTempFile("nodes", ".txt", sortTmpDir);
+ nodeSorters[threadId] = Sort.spawnSort(sortBufferSize, sortTmpDir.getPath(),
+ List.of("-o", nodeBatchPaths[threadId].getPath()));
}
OutputStream nodeOutputStream = nodeSorters[threadId].getOutputStream();
nodeOutputStream.write(node);
nodeOutputStream.write('\n');
};
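// Label callback: same lazy per-thread sorter pattern, fed with edge labels.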
GraphDataset.NodeCallback labelCallback = (label) -> {
int threadId = threadLocalId.get();
if (labelSorters[threadId] == null) {
- labelBatchPaths[threadId] = sortTmpDir + "/labels." + threadId + ".txt";
- labelSorters[threadId] = Sort.spawnSort(sortBufferSize, sortTmpDir,
- List.of("-o", labelBatchPaths[threadId]));
+ labelBatches[threadId] = File.createTempFile("labels", ".txt", sortTmpDir);
+ labelSorters[threadId] = Sort.spawnSort(sortBufferSize, sortTmpDir.getPath(),
+ List.of("-o", labelBatches[threadId].getPath()));
}
OutputStream labelOutputStream = labelSorters[threadId].getOutputStream();
labelOutputStream.write(label);
labelOutputStream.write('\n');
};
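// Run the read inside a bounded ForkJoinPool so that at most numThreads thread-local IDs
// are handed out, matching the size of the sorter arrays above.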
try {
forkJoinPool.submit(() -> {
try {
dataset.readEdges((node) -> {
nodeCallback.onNode(node);
}, (src, dst, label, perm) -> {
nodeCallback.onNode(src);
nodeCallback.onNode(dst);
if (label != null) {
labelCallback.onNode(label);
}
edgeCount.incrementAndGet();
// Extract type of src and dst from their SWHID: swh:1:XXX
byte[] srcTypeBytes = Arrays.copyOfRange(src, 6, 6 + 3);
byte[] dstTypeBytes = Arrays.copyOfRange(dst, 6, 6 + 3);
int srcType = Node.Type.byteNameToInt(srcTypeBytes);
int dstType = Node.Type.byteNameToInt(dstTypeBytes);
if (srcType != -1 && dstType != -1) {
edgeCountByType.incrementAndGet(srcType * Node.Type.values().length + dstType);
} else {
System.err.println("Invalid edge type: " + new String(srcTypeBytes) + " -> "
+ new String(dstTypeBytes));
System.exit(1);
}
int threadId = threadLocalId.get();
if (++progressCounts[threadId] > 1000) {
synchronized (pl) {
pl.update(progressCounts[threadId]);
}
progressCounts[threadId] = 0;
}
});
} catch (IOException e) {
throw new RuntimeException(e);
}
}).get();
} catch (ExecutionException e) {
throw new RuntimeException(e);
}
// Close each sorter's stdin so the sort processes see EOF and can finish
for (int i = 0; i < numThreads; i++) {
if (nodeSorters[i] != null) {
nodeSorters[i].getOutputStream().close();
}
if (labelSorters[i] != null) {
labelSorters[i].getOutputStream().close();
}
}
// Wait for sorting processes to finish
for (int i = 0; i < numThreads; i++) {
if (nodeSorters[i] != null) {
nodeSorters[i].waitFor();
}
if (labelSorters[i] != null) {
labelSorters[i].waitFor();
}
}
pl.done();
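// Merge the per-thread sorted batches with `sort -m` (merge-only; inputs are already sorted).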
ArrayList<String> nodeSortMergerOptions = new ArrayList<>(List.of("-m"));
ArrayList<String> labelSortMergerOptions = new ArrayList<>(List.of("-m"));
for (int i = 0; i < numThreads; i++) {
if (nodeBatchPaths[i] != null) {
- nodeSortMergerOptions.add(nodeBatchPaths[i]);
+ nodeSortMergerOptions.add(nodeBatchPaths[i].getPath());
}
- if (labelBatchPaths[i] != null) {
- labelSortMergerOptions.add(labelBatchPaths[i]);
+ if (labelBatches[i] != null) {
+ labelSortMergerOptions.add(labelBatches[i].getPath());
}
}
// Spawn node merge-sorting process
- Process nodeSortMerger = Sort.spawnSort(sortBufferSize, sortTmpDir, nodeSortMergerOptions);
+ Process nodeSortMerger = Sort.spawnSort(sortBufferSize, sortTmpDir.getPath(), nodeSortMergerOptions);
nodeSortMerger.getOutputStream().close();
OutputStream nodesFileOutputStream = new ZstdOutputStream(
new BufferedOutputStream(new FileOutputStream(outputBasename + ".nodes.csv.zst")));
NodesOutputThread nodesOutputThread = new NodesOutputThread(
new BufferedInputStream(nodeSortMerger.getInputStream()), nodesFileOutputStream);
nodesOutputThread.start();
// Spawn label merge-sorting process
- Process labelSortMerger = Sort.spawnSort(sortBufferSize, sortTmpDir, labelSortMergerOptions);
+ Process labelSortMerger = Sort.spawnSort(sortBufferSize, sortTmpDir.getPath(), labelSortMergerOptions);
labelSortMerger.getOutputStream().close();
OutputStream labelsFileOutputStream = new ZstdOutputStream(
new BufferedOutputStream(new FileOutputStream(outputBasename + ".labels.csv.zst")));
LabelsOutputThread labelsOutputThread = new LabelsOutputThread(
new BufferedInputStream(labelSortMerger.getInputStream()), labelsFileOutputStream);
labelsOutputThread.start();
pl.logger().info("Waiting for merge-sort and writing output files...");
nodeSortMerger.waitFor();
labelSortMerger.waitFor();
nodesOutputThread.join();
labelsOutputThread.join();
long[][] edgeCountByTypeArray = new long[Node.Type.values().length][Node.Type.values().length];
for (int i = 0; i < edgeCountByTypeArray.length; i++) {
for (int j = 0; j < edgeCountByTypeArray[i].length; j++) {
edgeCountByTypeArray[i][j] = edgeCountByType.get(i * Node.Type.values().length + j);
}
}
// Write node, edge and label counts/statistics
printEdgeCounts(outputBasename, edgeCount.get(), edgeCountByTypeArray);
printNodeCounts(outputBasename, nodesOutputThread.getNodeCount(), nodesOutputThread.getNodeTypeCounts());
printLabelCounts(outputBasename, labelsOutputThread.getLabelCount());
+
+ // Clean up sorted batches
+ for (int i = 0; i < numThreads; i++) {
+ if (nodeBatchPaths[i] != null) {
+ nodeBatchPaths[i].delete();
+ }
+ if (labelBatches[i] != null) {
+ labelBatches[i].delete();
+ }
+ }
}
private static void printEdgeCounts(String basename, long edgeCount, long[][] edgeTypeCounts) throws IOException {
PrintWriter edgeCountWriter = new PrintWriter(basename + ".edges.count.txt");
edgeCountWriter.println(edgeCount);
edgeCountWriter.close();
PrintWriter edgeTypesCountWriter = new PrintWriter(basename + ".edges.stats.txt");
TreeMap<String, Long> edgeTypeCountsMap = new TreeMap<>();
for (Node.Type src : Node.Type.values()) {
for (Node.Type dst : Node.Type.values()) {
long cnt = edgeTypeCounts[Node.Type.toInt(src)][Node.Type.toInt(dst)];
if (cnt > 0)
edgeTypeCountsMap.put(src.toString().toLowerCase() + ":" + dst.toString().toLowerCase(), cnt);
}
}
for (Map.Entry<String, Long> entry : edgeTypeCountsMap.entrySet()) {
edgeTypesCountWriter.println(entry.getKey() + " " + entry.getValue());
}
edgeTypesCountWriter.close();
}
private static void printNodeCounts(String basename, long nodeCount, long[] nodeTypeCounts) throws IOException {
PrintWriter nodeCountWriter = new PrintWriter(basename + ".nodes.count.txt");
nodeCountWriter.println(nodeCount);
nodeCountWriter.close();
PrintWriter nodeTypesCountWriter = new PrintWriter(basename + ".nodes.stats.txt");
TreeMap<String, Long> nodeTypeCountsMap = new TreeMap<>();
for (Node.Type v : Node.Type.values()) {
nodeTypeCountsMap.put(v.toString().toLowerCase(), nodeTypeCounts[Node.Type.toInt(v)]);
}
for (Map.Entry<String, Long> entry : nodeTypeCountsMap.entrySet()) {
nodeTypesCountWriter.println(entry.getKey() + " " + entry.getValue());
}
nodeTypesCountWriter.close();
}
private static void printLabelCounts(String basename, long labelCount) throws IOException {
PrintWriter labelCountWriter = new PrintWriter(basename + ".labels.count.txt");
labelCountWriter.println(labelCount);
labelCountWriter.close();
}
private static class NodesOutputThread extends Thread {
private final InputStream sortedNodesStream;
private final OutputStream nodesOutputStream;
private long nodeCount = 0;
private final long[] nodeTypeCounts = new long[Node.Type.values().length];
NodesOutputThread(InputStream sortedNodesStream, OutputStream nodesOutputStream) {
this.sortedNodesStream = sortedNodesStream;
this.nodesOutputStream = nodesOutputStream;
}
@Override
public void run() {
BufferedReader reader = new BufferedReader(
new InputStreamReader(sortedNodesStream, StandardCharsets.UTF_8));
try {
String line;
while ((line = reader.readLine()) != null) {
nodesOutputStream.write(line.getBytes(StandardCharsets.UTF_8));
nodesOutputStream.write('\n');
nodeCount++;
try {
Node.Type nodeType = Node.Type.fromStr(line.split(":")[2]);
nodeTypeCounts[Node.Type.toInt(nodeType)]++;
} catch (ArrayIndexOutOfBoundsException e) {
System.err.println("Error parsing SWHID: " + line);
System.exit(1);
}
}
nodesOutputStream.close();
} catch (IOException e) {
throw new RuntimeException(e);
}
}
public long getNodeCount() {
return nodeCount;
}
public long[] getNodeTypeCounts() {
return nodeTypeCounts;
}
}
private static class LabelsOutputThread extends Thread {
private final InputStream sortedLabelsStream;
private final OutputStream labelsOutputStream;
private long labelCount = 0;
LabelsOutputThread(InputStream sortedLabelsStream, OutputStream labelsOutputStream) {
this.labelsOutputStream = labelsOutputStream;
this.sortedLabelsStream = sortedLabelsStream;
}
@Override
public void run() {
BufferedReader reader = new BufferedReader(
new InputStreamReader(sortedLabelsStream, StandardCharsets.UTF_8));
try {
String line;
while ((line = reader.readLine()) != null) {
labelsOutputStream.write(line.getBytes(StandardCharsets.UTF_8));
labelsOutputStream.write('\n');
labelCount++;
}
labelsOutputStream.close();
} catch (IOException e) {
throw new RuntimeException(e);
}
}
public long getLabelCount() {
return labelCount;
}
}
}
diff --git a/java/src/main/java/org/softwareheritage/graph/compress/LabelMapBuilder.java b/java/src/main/java/org/softwareheritage/graph/compress/LabelMapBuilder.java
index 00cb9b0..9279c08 100644
--- a/java/src/main/java/org/softwareheritage/graph/compress/LabelMapBuilder.java
+++ b/java/src/main/java/org/softwareheritage/graph/compress/LabelMapBuilder.java
@@ -1,473 +1,480 @@
package org.softwareheritage.graph.compress;
import com.martiansoftware.jsap.*;
import it.unimi.dsi.big.webgraph.LazyLongIterator;
import it.unimi.dsi.big.webgraph.labelling.ArcLabelledImmutableGraph;
import it.unimi.dsi.big.webgraph.labelling.BitStreamArcLabelledImmutableGraph;
import it.unimi.dsi.fastutil.Arrays;
import it.unimi.dsi.fastutil.BigArrays;
import it.unimi.dsi.fastutil.Size64;
import it.unimi.dsi.fastutil.longs.LongBigArrays;
import it.unimi.dsi.fastutil.longs.LongHeapSemiIndirectPriorityQueue;
import it.unimi.dsi.fastutil.objects.Object2LongFunction;
import it.unimi.dsi.fastutil.objects.ObjectArrayList;
import it.unimi.dsi.io.InputBitStream;
import it.unimi.dsi.io.OutputBitStream;
import it.unimi.dsi.logging.ProgressLogger;
import it.unimi.dsi.big.webgraph.ImmutableGraph;
import it.unimi.dsi.big.webgraph.NodeIterator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.softwareheritage.graph.labels.DirEntry;
import org.softwareheritage.graph.labels.SwhLabel;
import org.softwareheritage.graph.maps.NodeIdMap;
import org.softwareheritage.graph.utils.ForkJoinBigQuickSort2;
import org.softwareheritage.graph.utils.ForkJoinQuickSort3;
import java.io.*;
import java.nio.file.Paths;
import java.util.*;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.IntStream;
public class LabelMapBuilder {
final static Logger logger = LoggerFactory.getLogger(LabelMapBuilder.class);
// Create one thread per processor.
final static int numThreads = Runtime.getRuntime().availableProcessors();
// Allocate up to 40% of maximum memory (three long[] arrays of batchSize entries, 8 bytes each, per thread).
- final static int batchSize = Math.min((int) (Runtime.getRuntime().maxMemory() * 0.4 / (numThreads * 8 * 3)),
- Arrays.MAX_ARRAY_SIZE);
+ final static int DEFAULT_BATCH_SIZE = Math
+ .min((int) (Runtime.getRuntime().maxMemory() * 0.4 / (numThreads * 8 * 3)), Arrays.MAX_ARRAY_SIZE);
String orcDatasetPath;
String graphPath;
String outputGraphPath;
String tmpDir;
+ int batchSize;
long numNodes;
long numArcs;
NodeIdMap nodeIdMap;
Object2LongFunction<byte[]> filenameMph;
long numFilenames;
int totalLabelWidth;
public LabelMapBuilder(String orcDatasetPath, String graphPath, String outputGraphPath, int batchSize,
String tmpDir) throws IOException {
this.orcDatasetPath = orcDatasetPath;
this.graphPath = graphPath;
this.outputGraphPath = (outputGraphPath == null) ? graphPath : outputGraphPath;
- // this.batchSize = batchSize;
+ this.batchSize = batchSize;
this.tmpDir = tmpDir;
ImmutableGraph graph = ImmutableGraph.loadOffline(graphPath);
this.numArcs = graph.numArcs();
this.numNodes = graph.numNodes();
this.nodeIdMap = new NodeIdMap(graphPath);
filenameMph = NodeIdMap.loadMph(graphPath + ".labels.mph");
numFilenames = getMPHSize(filenameMph);
totalLabelWidth = DirEntry.labelWidth(numFilenames);
}
private static JSAPResult parse_args(String[] args) {
JSAPResult config = null;
try {
SimpleJSAP jsap = new SimpleJSAP(LabelMapBuilder.class.getName(), "", new Parameter[]{
new UnflaggedOption("dataset", JSAP.STRING_PARSER, JSAP.REQUIRED, "Path to the ORC graph dataset"),
new UnflaggedOption("graphPath", JSAP.STRING_PARSER, JSAP.REQUIRED, "Basename of the output graph"),
new FlaggedOption("outputGraphPath", JSAP.STRING_PARSER, JSAP.NO_DEFAULT, JSAP.NOT_REQUIRED, 'o',
"output-graph", "Basename of the output graph, same as --graph if not specified"),
- new FlaggedOption("batchSize", JSAP.INTEGER_PARSER, "10000000", JSAP.NOT_REQUIRED, 'b',
- "batch-size", "Number of triplets held in memory in each batch"),
+ new FlaggedOption("batchSize", JSAP.INTEGER_PARSER, String.valueOf(DEFAULT_BATCH_SIZE),
+ JSAP.NOT_REQUIRED, 'b', "batch-size", "Number of triplets held in memory in each batch"),
new FlaggedOption("tmpDir", JSAP.STRING_PARSER, "tmp", JSAP.NOT_REQUIRED, 'T', "temp-dir",
"Temporary directory path"),});
config = jsap.parse(args);
if (jsap.messagePrinted()) {
System.exit(1);
}
} catch (JSAPException e) {
e.printStackTrace();
}
return config;
}
public static void main(String[] args) throws IOException, InterruptedException {
JSAPResult config = parse_args(args);
String orcDataset = config.getString("dataset");
String graphPath = config.getString("graphPath");
String outputGraphPath = config.getString("outputGraphPath");
int batchSize = config.getInt("batchSize");
String tmpDir = config.getString("tmpDir");
LabelMapBuilder builder = new LabelMapBuilder(orcDataset, graphPath, outputGraphPath, batchSize, tmpDir);
builder.computeLabelMap();
}
static long getMPHSize(Object2LongFunction<?> mph) {
return (mph instanceof Size64) ? ((Size64) mph).size64() : mph.size();
}
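// Pipeline: hash edges into sorted on-disk batches, then merge-iterate the batches twice to
// write the forward and transposed label bitstreams.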
void computeLabelMap() throws IOException {
File tempDirFile = new File(tmpDir);
ObjectArrayList<File> forwardBatches = new ObjectArrayList<>();
ObjectArrayList<File> backwardBatches = new ObjectArrayList<>();
genSortedBatches(forwardBatches, backwardBatches, tempDirFile);
BatchEdgeLabelLineIterator forwardBatchHeapIterator = new BatchEdgeLabelLineIterator(forwardBatches);
writeLabels(forwardBatchHeapIterator, graphPath, outputGraphPath);
+ for (File batch : forwardBatches) {
+ batch.delete();
+ }
BatchEdgeLabelLineIterator backwardBatchHeapIterator = new BatchEdgeLabelLineIterator(backwardBatches);
writeLabels(backwardBatchHeapIterator, graphPath + "-transposed", outputGraphPath + "-transposed");
+ for (File batch : backwardBatches) {
+ batch.delete();
+ }
logger.info("Done");
}
void genSortedBatches(ObjectArrayList<File> forwardBatches, ObjectArrayList<File> backwardBatches, File tempDirFile)
throws IOException {
logger.info("Initializing batch arrays.");
long[][] srcArrays = new long[numThreads][batchSize];
long[][] dstArrays = new long[numThreads][batchSize];
long[][] labelArrays = new long[numThreads][batchSize];
int[] indexes = new int[numThreads];
long[] progressCounts = new long[numThreads];
ProgressLogger plSortingBatches = new ProgressLogger(logger, 10, TimeUnit.SECONDS);
plSortingBatches.itemsName = "edges";
plSortingBatches.expectedUpdates = this.numArcs;
plSortingBatches.start("Reading edges and writing sorted batches.");
AtomicInteger nextThreadId = new AtomicInteger(0);
ThreadLocal<Integer> threadLocalId = ThreadLocal.withInitial(nextThreadId::getAndIncrement);
readHashedEdgeLabels((src, dst, label, perms) -> {
// System.err.println("0. Input " + src + " " + dst + " " + label + " " + perms);
int threadId = threadLocalId.get();
int idx = indexes[threadId]++;
srcArrays[threadId][idx] = src;
dstArrays[threadId][idx] = dst;
labelArrays[threadId][idx] = DirEntry.toEncoded(label, perms);
if (++progressCounts[threadId] > 1000) {
synchronized (plSortingBatches) {
plSortingBatches.update(progressCounts[threadId]);
}
progressCounts[threadId] = 0;
}
if (idx == batchSize - 1) {
processBidirectionalBatches(batchSize, srcArrays[threadId], dstArrays[threadId], labelArrays[threadId],
tempDirFile, forwardBatches, backwardBatches);
indexes[threadId] = 0;
}
});
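// Flush each thread's partially-filled batch.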
IntStream.range(0, numThreads).parallel().forEach(t -> {
int idx = indexes[t];
if (idx > 0) {
try {
processBidirectionalBatches(idx, srcArrays[t], dstArrays[t], labelArrays[t], tempDirFile,
forwardBatches, backwardBatches);
} catch (IOException e) {
throw new RuntimeException(e);
}
}
});
// Trigger the GC to free up the large arrays
for (int i = 0; i < numThreads; i++) {
srcArrays[i] = null;
dstArrays[i] = null;
labelArrays[i] = null;
}
logger.info("Created " + forwardBatches.size() + " forward batches and " + backwardBatches.size()
+ " backward batches.");
}
void readHashedEdgeLabels(GraphDataset.HashedEdgeCallback cb) throws IOException {
ORCGraphDataset dataset = new ORCGraphDataset(orcDatasetPath);
ForkJoinPool forkJoinPool = new ForkJoinPool(numThreads);
try {
forkJoinPool.submit(() -> {
try {
dataset.readEdges((node) -> {
}, (src, dst, label, perms) -> {
if (label == null) {
return;
}
long srcNode = nodeIdMap.getNodeId(src);
long dstNode = nodeIdMap.getNodeId(dst);
long labelId = filenameMph.getLong(label);
cb.onHashedEdge(srcNode, dstNode, labelId, perms);
});
} catch (IOException e) {
throw new RuntimeException(e);
}
}).get();
} catch (InterruptedException | ExecutionException e) {
throw new RuntimeException(e);
}
}
void processBidirectionalBatches(final int n, final long[] source, final long[] target, final long[] labels,
final File tempDir, final List<File> forwardBatches, final List<File> backwardBatches) throws IOException {
processBatch(n, source, target, labels, tempDir, forwardBatches);
processBatch(n, target, source, labels, tempDir, backwardBatches);
}
void processBatch(final int n, final long[] source, final long[] target, final long[] labels, final File tempDir,
final List<File> batches) throws IOException {
if (n == 0) {
return;
}
ForkJoinQuickSort3.parallelQuickSort(source, target, labels, 0, n);
final File batchFile = File.createTempFile("batch", ".bitstream", tempDir);
batchFile.deleteOnExit();
batches.add(batchFile);
final OutputBitStream batch = new OutputBitStream(batchFile);
// Compute unique triplets
int u = 1;
for (int i = n - 1; i-- != 0;) {
if (source[i] != source[i + 1] || target[i] != target[i + 1] || labels[i] != labels[i + 1]) {
u++;
}
}
batch.writeDelta(u);
// Write batch
long prevSource = source[0];
batch.writeLongDelta(prevSource);
batch.writeLongDelta(target[0]);
batch.writeLongDelta(labels[0]);
// System.err.println("1. Wrote " + prevSource + " " + target[0] + " " + labels[0]);
for (int i = 1; i < n; i++) {
if (source[i] != prevSource) {
// Default case, we write (source - prevsource, target, label)
batch.writeLongDelta(source[i] - prevSource);
batch.writeLongDelta(target[i]);
batch.writeLongDelta(labels[i]);
prevSource = source[i];
} else if (target[i] != target[i - 1] || labels[i] != labels[i - 1]) {
// Case where source is identical with prevsource, but target or label differ.
// We write (0, target - prevtarget, label)
batch.writeLongDelta(0);
batch.writeLongDelta(target[i] - target[i - 1]);
batch.writeLongDelta(labels[i]);
} else {
continue;
}
// System.err.println("1. Wrote " + source[i] + " " + target[i] + " " + labels[i]);
}
batch.close();
}
void writeLabels(EdgeLabelLineIterator mapLines, String graphBasename, String outputGraphBasename)
throws IOException {
// Loading the graph to iterate
ImmutableGraph graph = ImmutableGraph.loadMapped(graphBasename);
// Get the sorted output and write the labels and label offsets
ProgressLogger plLabels = new ProgressLogger(logger, 10, TimeUnit.SECONDS);
plLabels.itemsName = "edges";
plLabels.expectedUpdates = this.numArcs;
plLabels.start("Writing the labels to the label file: " + outputGraphBasename + "-labelled.*");
OutputBitStream labels = new OutputBitStream(
new File(outputGraphBasename + "-labelled" + BitStreamArcLabelledImmutableGraph.LABELS_EXTENSION));
OutputBitStream offsets = new OutputBitStream(new File(
outputGraphBasename + "-labelled" + BitStreamArcLabelledImmutableGraph.LABEL_OFFSETS_EXTENSION));
offsets.writeGamma(0);
EdgeLabelLine line = new EdgeLabelLine(-1, -1, -1, -1);
NodeIterator it = graph.nodeIterator();
boolean started = false;
ArrayList<DirEntry> labelBuffer = new ArrayList<>(128);
while (it.hasNext()) {
long srcNode = it.nextLong();
long bits = 0;
LazyLongIterator s = it.successors();
long dstNode;
while ((dstNode = s.nextLong()) >= 0) {
while (line != null && line.srcNode <= srcNode && line.dstNode <= dstNode) {
if (line.srcNode == srcNode && line.dstNode == dstNode) {
labelBuffer.add(new DirEntry(line.filenameId, line.permission));
}
if (!mapLines.hasNext())
break;
line = mapLines.next();
if (!started) {
plLabels.start("Writing label map to file...");
started = true;
}
}
SwhLabel l = new SwhLabel("edgelabel", totalLabelWidth, labelBuffer.toArray(new DirEntry[0]));
labelBuffer.clear();
bits += l.toBitStream(labels, -1);
plLabels.lightUpdate();
}
offsets.writeLongGamma(bits);
}
labels.close();
offsets.close();
plLabels.done();
graph = null;
PrintWriter pw = new PrintWriter(new FileWriter(outputGraphBasename + "-labelled.properties"));
pw.println(ImmutableGraph.GRAPHCLASS_PROPERTY_KEY + " = " + BitStreamArcLabelledImmutableGraph.class.getName());
pw.println(BitStreamArcLabelledImmutableGraph.LABELSPEC_PROPERTY_KEY + " = " + SwhLabel.class.getName()
+ "(DirEntry," + totalLabelWidth + ")");
pw.println(ArcLabelledImmutableGraph.UNDERLYINGGRAPH_PROPERTY_KEY + " = "
+ Paths.get(outputGraphBasename).getFileName());
pw.close();
}
public static class EdgeLabelLine {
public long srcNode;
public long dstNode;
public long filenameId;
public int permission;
public EdgeLabelLine(long labelSrcNode, long labelDstNode, long labelFilenameId, int labelPermission) {
this.srcNode = labelSrcNode;
this.dstNode = labelDstNode;
this.filenameId = labelFilenameId;
this.permission = labelPermission;
}
}
public abstract static class EdgeLabelLineIterator implements Iterator<EdgeLabelLine> {
@Override
public abstract boolean hasNext();
@Override
public abstract EdgeLabelLine next();
}
public static class BatchEdgeLabelLineIterator extends EdgeLabelLineIterator {
private static final int STD_BUFFER_SIZE = 128 * 1024;
private final InputBitStream[] batchIbs;
private final int[] inputStreamLength;
private final long[] refArray;
private final LongHeapSemiIndirectPriorityQueue queue;
private final long[] prevTarget;
/** The last returned node (-1 if no node has been returned yet). */
private long lastNode;
private long[][] lastNodeSuccessors = LongBigArrays.EMPTY_BIG_ARRAY;
private long[][] lastNodeLabels = LongBigArrays.EMPTY_BIG_ARRAY;
private long lastNodeOutdegree;
private long lastNodeCurrentSuccessor;
public BatchEdgeLabelLineIterator(final List<File> batches) throws IOException {
this.batchIbs = new InputBitStream[batches.size()];
this.refArray = new long[batches.size()];
this.prevTarget = new long[batches.size()];
this.queue = new LongHeapSemiIndirectPriorityQueue(refArray);
this.inputStreamLength = new int[batches.size()];
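// Open each batch, read its triplet count and first source node, and prime the heap.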
for (int i = 0; i < batches.size(); i++) {
batchIbs[i] = new InputBitStream(batches.get(i), STD_BUFFER_SIZE);
this.inputStreamLength[i] = batchIbs[i].readDelta();
this.refArray[i] = batchIbs[i].readLongDelta();
queue.enqueue(i);
}
this.lastNode = -1;
this.lastNodeOutdegree = 0;
this.lastNodeCurrentSuccessor = 0;
}
public boolean hasNextNode() {
return !queue.isEmpty();
}
private void readNextNode() throws IOException {
assert hasNext();
int i;
lastNode++;
lastNodeOutdegree = 0;
lastNodeCurrentSuccessor = 0;
/*
 * We extract elements from the queue as long as their source is equal to lastNode. If during
 * the process we exhaust a batch, we close it.
 */
while (!queue.isEmpty() && refArray[i = queue.first()] == lastNode) {
lastNodeSuccessors = BigArrays.grow(lastNodeSuccessors, lastNodeOutdegree + 1);
lastNodeLabels = BigArrays.grow(lastNodeLabels, lastNodeOutdegree + 1);
long target = prevTarget[i] += batchIbs[i].readLongDelta();
long label = batchIbs[i].readLongDelta();
BigArrays.set(lastNodeSuccessors, lastNodeOutdegree, target);
BigArrays.set(lastNodeLabels, lastNodeOutdegree, label);
// System.err.println("2. Read " + lastNode + " " + target + " " + label);
if (--inputStreamLength[i] == 0) {
queue.dequeue();
batchIbs[i].close();
batchIbs[i] = null;
} else {
// We read a new source and update the queue.
final long sourceDelta = batchIbs[i].readLongDelta();
if (sourceDelta != 0) {
refArray[i] += sourceDelta;
prevTarget[i] = 0;
queue.changed();
}
}
lastNodeOutdegree++;
}
// Neither quicksort nor heaps are stable, so we reestablish order here.
// LongBigArrays.radixSort(lastNodeSuccessors, lastNodeLabels, 0, lastNodeOutdegree);
ForkJoinBigQuickSort2.parallelQuickSort(lastNodeSuccessors, lastNodeLabels, 0, lastNodeOutdegree);
}
@Override
public boolean hasNext() {
return lastNodeCurrentSuccessor < lastNodeOutdegree || hasNextNode();
}
@Override
public EdgeLabelLine next() {
if (lastNode == -1 || lastNodeCurrentSuccessor >= lastNodeOutdegree) {
try {
do {
readNextNode();
} while (hasNextNode() && lastNodeOutdegree == 0);
} catch (IOException e) {
throw new RuntimeException(e);
}
}
long src = lastNode;
long dst = BigArrays.get(lastNodeSuccessors, lastNodeCurrentSuccessor);
long compressedLabel = BigArrays.get(lastNodeLabels, lastNodeCurrentSuccessor);
long labelName = DirEntry.labelNameFromEncoded(compressedLabel);
int permission = DirEntry.permissionFromEncoded(compressedLabel);
// System.err.println("3. Output (encoded): " + src + " " + dst + " " + compressedLabel);
// System.err.println("4. Output (decoded): " + src + " " + dst + " " + labelName + " " +
// permission);
lastNodeCurrentSuccessor++;
return new EdgeLabelLine(src, dst, labelName, permission);
}
}
}
diff --git a/java/src/main/java/org/softwareheritage/graph/compress/ScatteredArcsORCGraph.java b/java/src/main/java/org/softwareheritage/graph/compress/ScatteredArcsORCGraph.java
index 06d9f4b..05531f5 100644
--- a/java/src/main/java/org/softwareheritage/graph/compress/ScatteredArcsORCGraph.java
+++ b/java/src/main/java/org/softwareheritage/graph/compress/ScatteredArcsORCGraph.java
@@ -1,250 +1,252 @@
package org.softwareheritage.graph.compress;
import java.io.File;
import java.io.IOException;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.IntStream;
import it.unimi.dsi.big.webgraph.BVGraph;
import it.unimi.dsi.big.webgraph.ImmutableSequentialGraph;
import it.unimi.dsi.big.webgraph.NodeIterator;
import it.unimi.dsi.big.webgraph.Transform;
import it.unimi.dsi.fastutil.Arrays;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.martiansoftware.jsap.FlaggedOption;
import com.martiansoftware.jsap.JSAP;
import com.martiansoftware.jsap.JSAPException;
import com.martiansoftware.jsap.JSAPResult;
import com.martiansoftware.jsap.Parameter;
import com.martiansoftware.jsap.SimpleJSAP;
import com.martiansoftware.jsap.UnflaggedOption;
import it.unimi.dsi.fastutil.Size64;
import it.unimi.dsi.fastutil.io.BinIO;
import it.unimi.dsi.fastutil.objects.Object2LongFunction;
import it.unimi.dsi.fastutil.objects.ObjectArrayList;
import it.unimi.dsi.logging.ProgressLogger;
public class ScatteredArcsORCGraph extends ImmutableSequentialGraph {
private static final Logger LOGGER = LoggerFactory.getLogger(ScatteredArcsORCGraph.class);
/** The default number of threads. */
public static final int DEFAULT_NUM_THREADS = Runtime.getRuntime().availableProcessors();
/** The default batch size. */
public static final int DEFAULT_BATCH_SIZE = Math
.min((int) (Runtime.getRuntime().maxMemory() * 0.4 / (DEFAULT_NUM_THREADS * 8 * 2)), Arrays.MAX_ARRAY_SIZE);
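// For example, with a 64 GiB heap and 32 threads this works out to roughly 54M arcs per batch
// (0.4 * 64 GiB / (32 threads * 8 bytes * 2 arrays)), capped at Arrays.MAX_ARRAY_SIZE.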
/** The batch graph used to return node iterators. */
private final Transform.BatchGraph batchGraph;
/**
* Creates a scattered-arcs ORC graph.
*
 * @param dataset the Swh ORC Graph dataset
 * @param function an explicitly provided function from string representing nodes to node numbers,
 *            or {@code null} for the standard behaviour.
 * @param n the number of nodes of the graph (used only if {@code function} is not
 *            {@code null}).
 * @param numThreads the number of threads to use.
 * @param batchSize the number of integers in a batch; two arrays of integers of this size will be
 *            allocated by each thread.
 * @param tempDir a temporary directory for the batches, or {@code null} for
 *            {@link File#createTempFile(java.lang.String, java.lang.String)}'s choice.
 * @param pl a progress logger, or {@code null}.
*/
public ScatteredArcsORCGraph(final ORCGraphDataset dataset, final Object2LongFunction<byte[]> function,
final long n, final int numThreads, final int batchSize, final File tempDir, final ProgressLogger pl)
throws IOException {
final ObjectArrayList<File> batches = new ObjectArrayList<>();
ForkJoinPool forkJoinPool = new ForkJoinPool(numThreads);
long[][] srcArrays = new long[numThreads][batchSize];
long[][] dstArrays = new long[numThreads][batchSize];
int[] indexes = new int[numThreads];
long[] progressCounts = new long[numThreads];
AtomicInteger pairs = new AtomicInteger(0);
AtomicInteger nextThreadId = new AtomicInteger(0);
ThreadLocal<Integer> threadLocalId = ThreadLocal.withInitial(nextThreadId::getAndIncrement);
if (pl != null) {
pl.itemsName = "arcs";
pl.start("Creating sorted batches...");
}
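// Hash both endpoints of every arc through `function` and buffer the pairs into per-thread
// arrays; each full array is flushed as a sorted batch file via Transform.processBatch.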
try {
forkJoinPool.submit(() -> {
try {
dataset.readEdges((node) -> {
}, (src, dst, label, perms) -> {
long s = function.getLong(src);
long t = function.getLong(dst);
int threadId = threadLocalId.get();
int idx = indexes[threadId]++;
srcArrays[threadId][idx] = s;
dstArrays[threadId][idx] = t;
if (idx == batchSize - 1) {
pairs.addAndGet(Transform.processBatch(batchSize, srcArrays[threadId], dstArrays[threadId],
tempDir, batches));
indexes[threadId] = 0;
}
- if (++progressCounts[threadId] > 1000) {
+ if (pl != null && ++progressCounts[threadId] > 1000) {
synchronized (pl) {
pl.update(progressCounts[threadId]);
}
progressCounts[threadId] = 0;
}
});
} catch (IOException e) {
throw new RuntimeException(e);
}
}).get();
} catch (InterruptedException | ExecutionException e) {
throw new RuntimeException(e);
}
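// Flush each thread's partially-filled batch.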
IntStream.range(0, numThreads).parallel().forEach(t -> {
int idx = indexes[t];
if (idx > 0) {
try {
pairs.addAndGet(Transform.processBatch(idx, srcArrays[t], dstArrays[t], tempDir, batches));
} catch (IOException e) {
throw new RuntimeException(e);
}
}
});
// Trigger the GC to free up the large arrays
for (int i = 0; i < numThreads; i++) {
srcArrays[i] = null;
dstArrays[i] = null;
}
- pl.done();
- pl.logger().info("Created " + batches.size() + " batches.");
+ if (pl != null) {
+ pl.done();
+ pl.logger().info("Created " + batches.size() + " batches.");
+ }
batchGraph = new Transform.BatchGraph(n, pairs.get(), batches);
}
@Override
public long numNodes() {
if (batchGraph == null)
throw new UnsupportedOperationException(
"The number of nodes is unknown (you need to generate all the batches first).");
return batchGraph.numNodes();
}
@Override
public long numArcs() {
if (batchGraph == null)
throw new UnsupportedOperationException(
"The number of arcs is unknown (you need to generate all the batches first).");
return batchGraph.numArcs();
}
@Override
public NodeIterator nodeIterator(final long from) {
return batchGraph.nodeIterator(from);
}
@Override
public boolean hasCopiableIterators() {
return batchGraph.hasCopiableIterators();
}
@Override
public ScatteredArcsORCGraph copy() {
return this;
}
@SuppressWarnings("unchecked")
public static void main(final String[] args)
throws IllegalArgumentException, SecurityException, IOException, JSAPException, ClassNotFoundException {
final SimpleJSAP jsap = new SimpleJSAP(ScatteredArcsORCGraph.class.getName(),
"Converts a scattered list of arcs from an ORC graph dataset into a BVGraph.",
new Parameter[]{
new FlaggedOption("logInterval", JSAP.LONG_PARSER,
Long.toString(ProgressLogger.DEFAULT_LOG_INTERVAL), JSAP.NOT_REQUIRED, 'l',
"log-interval", "The minimum time interval between activity logs in milliseconds."),
new FlaggedOption("numThreads", JSAP.INTSIZE_PARSER, Integer.toString(DEFAULT_NUM_THREADS),
JSAP.NOT_REQUIRED, 't', "threads", "The number of threads to use."),
new FlaggedOption("batchSize", JSAP.INTSIZE_PARSER, Integer.toString(DEFAULT_BATCH_SIZE),
JSAP.NOT_REQUIRED, 's', "batch-size", "The maximum size of a batch, in arcs."),
new FlaggedOption("tempDir", JSAP.STRING_PARSER, JSAP.NO_DEFAULT, JSAP.NOT_REQUIRED, 'T',
"temp-dir", "A directory for all temporary batch files."),
new FlaggedOption("function", JSAP.STRING_PARSER, JSAP.NO_DEFAULT, JSAP.NOT_REQUIRED, 'f',
"function",
"A serialised function from strings to longs that will be used to translate identifiers to node numbers."),
new FlaggedOption("comp", JSAP.STRING_PARSER, null, JSAP.NOT_REQUIRED, 'c', "comp",
"A compression flag (may be specified several times).")
.setAllowMultipleDeclarations(true),
new FlaggedOption("windowSize", JSAP.INTEGER_PARSER,
String.valueOf(BVGraph.DEFAULT_WINDOW_SIZE), JSAP.NOT_REQUIRED, 'w', "window-size",
"Reference window size (0 to disable)."),
new FlaggedOption("maxRefCount", JSAP.INTEGER_PARSER,
String.valueOf(BVGraph.DEFAULT_MAX_REF_COUNT), JSAP.NOT_REQUIRED, 'm', "max-ref-count",
"Maximum number of backward references (-1 for ∞)."),
new FlaggedOption("minIntervalLength", JSAP.INTEGER_PARSER,
String.valueOf(BVGraph.DEFAULT_MIN_INTERVAL_LENGTH), JSAP.NOT_REQUIRED, 'i',
"min-interval-length", "Minimum length of an interval (0 to disable)."),
new FlaggedOption("zetaK", JSAP.INTEGER_PARSER, String.valueOf(BVGraph.DEFAULT_ZETA_K),
JSAP.NOT_REQUIRED, 'k', "zeta-k", "The k parameter for zeta-k codes."),
new UnflaggedOption("dataset", JSAP.STRING_PARSER, JSAP.NO_DEFAULT, JSAP.REQUIRED,
JSAP.NOT_GREEDY, "The path to the ORC graph dataset."),
new UnflaggedOption("basename", JSAP.STRING_PARSER, JSAP.NO_DEFAULT, JSAP.REQUIRED,
JSAP.NOT_GREEDY, "The basename of the output graph"),});
final JSAPResult jsapResult = jsap.parse(args);
if (jsap.messagePrinted())
System.exit(1);
String basename = jsapResult.getString("basename");
String orcDatasetPath = jsapResult.getString("dataset");
ORCGraphDataset orcDataset = new ORCGraphDataset(orcDatasetPath);
int flags = 0;
for (final String compressionFlag : jsapResult.getStringArray("comp")) {
try {
flags |= BVGraph.class.getField(compressionFlag).getInt(BVGraph.class);
} catch (final Exception notFound) {
throw new JSAPException("Compression method " + compressionFlag + " unknown.");
}
}
final int windowSize = jsapResult.getInt("windowSize");
final int zetaK = jsapResult.getInt("zetaK");
int maxRefCount = jsapResult.getInt("maxRefCount");
if (maxRefCount == -1)
maxRefCount = Integer.MAX_VALUE;
final int minIntervalLength = jsapResult.getInt("minIntervalLength");
if (!jsapResult.userSpecified("function")) {
throw new IllegalArgumentException("Function must be specified.");
}
final Object2LongFunction<byte[]> function = (Object2LongFunction<byte[]>) BinIO
.loadObject(jsapResult.getString("function"));
long n = function instanceof Size64 ? ((Size64) function).size64() : function.size();
File tempDir = null;
if (jsapResult.userSpecified("tempDir")) {
tempDir = new File(jsapResult.getString("tempDir"));
}
final ProgressLogger pl = new ProgressLogger(LOGGER, jsapResult.getLong("logInterval"), TimeUnit.MILLISECONDS);
final int batchSize = jsapResult.getInt("batchSize");
final int numThreads = jsapResult.getInt("numThreads");
final ScatteredArcsORCGraph graph = new ScatteredArcsORCGraph(orcDataset, function, n, numThreads, batchSize,
tempDir, pl);
BVGraph.store(graph, basename, windowSize, maxRefCount, minIntervalLength, zetaK, flags, pl);
}
}
diff --git a/java/src/test/java/org/softwareheritage/graph/compress/ExtractNodesTest.java b/java/src/test/java/org/softwareheritage/graph/compress/ExtractNodesTest.java
index 3f13b9d..d9713f8 100644
--- a/java/src/test/java/org/softwareheritage/graph/compress/ExtractNodesTest.java
+++ b/java/src/test/java/org/softwareheritage/graph/compress/ExtractNodesTest.java
@@ -1,106 +1,106 @@
package org.softwareheritage.graph.compress;
import org.apache.commons.codec.digest.DigestUtils;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.softwareheritage.graph.GraphTest;
import org.softwareheritage.graph.Node;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;
import java.util.TreeSet;
public class ExtractNodesTest extends GraphTest {
/** Generate a fake SWHID for a given node type and numeric ID */
private static byte[] f(String type, int id) {
String hash = DigestUtils.sha1Hex(type + id);
return String.format("swh:1:%s:%s", type, hash).getBytes();
}
static class FakeDataset implements GraphDataset {
@Override
public void readEdges(NodeCallback nodeCb, EdgeCallback edgeCb) throws IOException {
// For each node type, write nodes {1..4} as present in the graph
for (Node.Type type : Node.Type.values()) {
for (int i = 1; i <= 4; i++) {
byte[] node = f(type.toString().toLowerCase(), i);
nodeCb.onNode(node);
}
}
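// Edges; snp 404 and rev 404 exist only as destinations, so they must still appear in the
// extracted node list.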
edgeCb.onEdge(f("ori", 1), f("snp", 1), null, -1);
edgeCb.onEdge(f("ori", 2), f("snp", 2), null, -1);
edgeCb.onEdge(f("ori", 3), f("snp", 3), null, -1);
edgeCb.onEdge(f("ori", 4), f("snp", 404), null, -1);
edgeCb.onEdge(f("snp", 1), f("rev", 1), "dup1".getBytes(), -1);
edgeCb.onEdge(f("snp", 1), f("rev", 1), "dup2".getBytes(), -1);
edgeCb.onEdge(f("snp", 3), f("cnt", 1), "c1".getBytes(), -1);
edgeCb.onEdge(f("snp", 4), f("rel", 1), "r1".getBytes(), -1);
edgeCb.onEdge(f("rel", 1), f("rel", 2), null, -1);
edgeCb.onEdge(f("rel", 2), f("rev", 1), null, -1);
edgeCb.onEdge(f("rel", 3), f("rev", 2), null, -1);
edgeCb.onEdge(f("rel", 4), f("dir", 1), null, -1);
edgeCb.onEdge(f("rev", 1), f("rev", 1), null, -1);
edgeCb.onEdge(f("rev", 1), f("rev", 1), null, -1);
edgeCb.onEdge(f("rev", 1), f("rev", 2), null, -1);
edgeCb.onEdge(f("rev", 2), f("rev", 404), null, -1);
edgeCb.onEdge(f("rev", 3), f("rev", 2), null, -1);
edgeCb.onEdge(f("rev", 4), f("dir", 1), null, -1);
edgeCb.onEdge(f("dir", 1), f("cnt", 1), "c1".getBytes(), 42);
edgeCb.onEdge(f("dir", 1), f("dir", 1), "d1".getBytes(), 1337);
edgeCb.onEdge(f("dir", 1), f("rev", 1), "r1".getBytes(), 0);
}
}
@Test
public void testExtractNodes(@TempDir Path outputDir, @TempDir Path sortTmpDir)
throws IOException, InterruptedException {
FakeDataset dataset = new FakeDataset();
- ExtractNodes.extractNodes(dataset, outputDir.toString() + "/graph", "2M", sortTmpDir.toString());
+ ExtractNodes.extractNodes(dataset, outputDir.toString() + "/graph", "2M", sortTmpDir.toFile());
// Check count files
Long nodeCount = Long.parseLong(Files.readString(outputDir.resolve("graph.nodes.count.txt")).strip());
Long edgeCount = Long.parseLong(Files.readString(outputDir.resolve("graph.edges.count.txt")).strip());
Long labelCount = Long.parseLong(Files.readString(outputDir.resolve("graph.labels.count.txt")).strip());
Assertions.assertEquals(26L, nodeCount);
Assertions.assertEquals(21L, edgeCount);
Assertions.assertEquals(5L, labelCount);
// Check stat files
List<String> nodeStats = Files.readAllLines(outputDir.resolve("graph.nodes.stats.txt"));
List<String> edgeStats = Files.readAllLines(outputDir.resolve("graph.edges.stats.txt"));
Assertions.assertEquals(List.of("cnt 4", "dir 4", "ori 4", "rel 4", "rev 5", "snp 5"), nodeStats);
Assertions.assertEquals(List.of("dir:cnt 1", "dir:dir 1", "dir:rev 1", "ori:snp 4", "rel:dir 1",
"rel:rel 1", "rel:rev 2", "rev:dir 1", "rev:rev 5", "snp:cnt 1", "snp:rel 1", "snp:rev 2"), edgeStats);
// Build ordered set of expected node IDs
TreeSet<String> expectedNodes = new TreeSet<>();
for (Node.Type type : Node.Type.values()) {
for (int i = 1; i <= 4; i++) {
byte[] node = f(type.toString().toLowerCase(), i);
expectedNodes.add(new String(node));
}
}
expectedNodes.add(new String(f("snp", 404)));
expectedNodes.add(new String(f("rev", 404)));
String[] nodeLines = readZstFile(outputDir.resolve("graph.nodes.csv.zst"));
Assertions.assertArrayEquals(expectedNodes.toArray(new String[0]), nodeLines);
// Build ordered set of expected label IDs
TreeSet<String> expectedLabels = new TreeSet<>();
expectedLabels.add("dup1");
expectedLabels.add("dup2");
expectedLabels.add("c1");
expectedLabels.add("r1");
expectedLabels.add("d1");
String[] labelLines = readZstFile(outputDir.resolve("graph.labels.csv.zst"));
Assertions.assertArrayEquals(expectedLabels.toArray(new String[0]), labelLines);
}
}