diff --git a/java/src/main/java/org/softwareheritage/graph/compress/ExtractNodes.java b/java/src/main/java/org/softwareheritage/graph/compress/ExtractNodes.java
--- a/java/src/main/java/org/softwareheritage/graph/compress/ExtractNodes.java
+++ b/java/src/main/java/org/softwareheritage/graph/compress/ExtractNodes.java
@@ -2,12 +2,21 @@
import com.github.luben.zstd.ZstdOutputStream;
import com.martiansoftware.jsap.*;
+import it.unimi.dsi.logging.ProgressLogger;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
import org.softwareheritage.graph.Node;
import org.softwareheritage.graph.utils.Sort;
import java.io.*;
import java.nio.charset.StandardCharsets;
import java.util.*;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ForkJoinPool;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicLong;
+import java.util.concurrent.atomic.AtomicLongArray;
/**
* Read a graph dataset and extract all the unique node SWHIDs it contains, including the ones that
@@ -46,6 +55,14 @@
*
*/
public class ExtractNodes {
+ private final static Logger logger = LoggerFactory.getLogger(ExtractNodes.class);
+
+ // Use one worker thread per available processor.
+ final static int numThreads = Runtime.getRuntime().availableProcessors();
+
+ // Budget up to 20% of the JVM's maximum memory for the sort subprocesses, split evenly across the two sorters (nodes, labels) of each thread.
+ final static long sortBufferSize = (long) (Runtime.getRuntime().maxMemory() * 0.2 / numThreads / 2);
+
private static JSAPResult parseArgs(String[] args) {
JSAPResult config = null;
try {
@@ -56,8 +73,9 @@
new FlaggedOption("format", JSAP.STRING_PARSER, "orc", JSAP.NOT_REQUIRED, 'f', "format",
"Format of the input dataset (orc, csv)"),
- new FlaggedOption("sortBufferSize", JSAP.STRING_PARSER, "30%", JSAP.NOT_REQUIRED, 'S',
- "sort-buffer-size", "Size of the memory buffer used by sort"),
+ new FlaggedOption("sortBufferSize", JSAP.STRING_PARSER, String.valueOf(sortBufferSize),
+ JSAP.NOT_REQUIRED, 'S', "sort-buffer-size",
+ "Size of the memory buffer used by each sort process"),
new FlaggedOption("sortTmpDir", JSAP.STRING_PARSER, null, JSAP.NOT_REQUIRED, 'T', "temp-dir",
"Path to the temporary directory used by sort")});
@@ -98,65 +116,157 @@
public static void extractNodes(GraphDataset dataset, String outputBasename, String sortBufferSize,
String sortTmpDir) throws IOException, InterruptedException {
- // Spawn node sorting process
- Process nodeSort = Sort.spawnSort(sortBufferSize, sortTmpDir);
- BufferedOutputStream nodeSortStdin = new BufferedOutputStream(nodeSort.getOutputStream());
- BufferedInputStream nodeSortStdout = new BufferedInputStream(nodeSort.getInputStream());
+ // Read the dataset and write the nodes and labels to the sorting processes
+ AtomicLong edgeCount = new AtomicLong(0);
+ AtomicLongArray edgeCountByType = new AtomicLongArray(Node.Type.values().length * Node.Type.values().length);
+
+ int numThreads = Runtime.getRuntime().availableProcessors();
+ ForkJoinPool forkJoinPool = new ForkJoinPool(numThreads);
+
+ Process[] nodeSorters = new Process[numThreads];
+ String[] nodeBatchPaths = new String[numThreads];
+ Process[] labelSorters = new Process[numThreads];
+ String[] labelBatchPaths = new String[numThreads];
+ long[] progressCounts = new long[numThreads];
+
+ AtomicInteger nextThreadId = new AtomicInteger(0);
+ ThreadLocal<Integer> threadLocalId = ThreadLocal.withInitial(nextThreadId::getAndIncrement);
+
+ ProgressLogger pl = new ProgressLogger(logger, 10, TimeUnit.SECONDS);
+ pl.itemsName = "edges";
+ pl.start("Reading node/edge files and writing sorted batches.");
+
+ GraphDataset.NodeCallback nodeCallback = (node) -> {
+ int threadId = threadLocalId.get();
+ if (nodeSorters[threadId] == null) {
+ nodeBatchPaths[threadId] = sortTmpDir + "/nodes." + threadId + ".txt";
+ nodeSorters[threadId] = Sort.spawnSort(sortBufferSize, sortTmpDir,
+ List.of("-o", nodeBatchPaths[threadId]));
+ }
+ OutputStream nodeOutputStream = nodeSorters[threadId].getOutputStream();
+ nodeOutputStream.write(node);
+ nodeOutputStream.write('\n');
+ };
+
+ GraphDataset.NodeCallback labelCallback = (label) -> {
+ int threadId = threadLocalId.get();
+ if (labelSorters[threadId] == null) {
+ labelBatchPaths[threadId] = sortTmpDir + "/labels." + threadId + ".txt";
+ labelSorters[threadId] = Sort.spawnSort(sortBufferSize, sortTmpDir,
+ List.of("-o", labelBatchPaths[threadId]));
+ }
+ OutputStream labelOutputStream = labelSorters[threadId].getOutputStream();
+ labelOutputStream.write(label);
+ labelOutputStream.write('\n');
+ };
+
+ try {
+ forkJoinPool.submit(() -> {
+ try {
+ dataset.readEdges((node) -> {
+ nodeCallback.onNode(node);
+ }, (src, dst, label, perm) -> {
+ nodeCallback.onNode(src);
+ nodeCallback.onNode(dst);
+
+ if (label != null) {
+ labelCallback.onNode(label);
+ }
+ edgeCount.incrementAndGet();
+ // Extract type of src and dst from their SWHID: swh:1:XXX
+ byte[] srcTypeBytes = Arrays.copyOfRange(src, 6, 6 + 3);
+ byte[] dstTypeBytes = Arrays.copyOfRange(dst, 6, 6 + 3);
+ int srcType = Node.Type.byteNameToInt(srcTypeBytes);
+ int dstType = Node.Type.byteNameToInt(dstTypeBytes);
+ if (srcType != -1 && dstType != -1) {
+ edgeCountByType.incrementAndGet(srcType * Node.Type.values().length + dstType);
+ } else {
+ System.err.println("Invalid edge type: " + new String(srcTypeBytes) + " -> "
+ + new String(dstTypeBytes));
+ System.exit(1);
+ }
+
+ int threadId = threadLocalId.get();
+ if (++progressCounts[threadId] > 1000) {
+ synchronized (pl) {
+ pl.update(progressCounts[threadId]);
+ }
+ progressCounts[threadId] = 0;
+ }
+ });
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }).get();
+ } catch (ExecutionException e) {
+ throw new RuntimeException(e);
+ }
+
+ // Close the stdin of every sorter to signal the end of its input
+ for (int i = 0; i < numThreads; i++) {
+ if (nodeSorters[i] != null) {
+ nodeSorters[i].getOutputStream().close();
+ }
+ if (labelSorters[i] != null) {
+ labelSorters[i].getOutputStream().close();
+ }
+ }
+
+ // Wait for sorting processes to finish
+ for (int i = 0; i < numThreads; i++) {
+ if (nodeSorters[i] != null) {
+ nodeSorters[i].waitFor();
+ }
+ if (labelSorters[i] != null) {
+ labelSorters[i].waitFor();
+ }
+ }
+ pl.done();
+
+ ArrayList<String> nodeSortMergerOptions = new ArrayList<>(List.of("-m"));
+ ArrayList<String> labelSortMergerOptions = new ArrayList<>(List.of("-m"));
+ for (int i = 0; i < numThreads; i++) {
+ if (nodeBatchPaths[i] != null) {
+ nodeSortMergerOptions.add(nodeBatchPaths[i]);
+ }
+ if (labelBatchPaths[i] != null) {
+ labelSortMergerOptions.add(labelBatchPaths[i]);
+ }
+ }
+
+ // Spawn node merge-sorting process
+ Process nodeSortMerger = Sort.spawnSort(sortBufferSize, sortTmpDir, nodeSortMergerOptions);
+ nodeSortMerger.getOutputStream().close();
OutputStream nodesFileOutputStream = new ZstdOutputStream(
new BufferedOutputStream(new FileOutputStream(outputBasename + ".nodes.csv.zst")));
- NodesOutputThread nodesOutputThread = new NodesOutputThread(nodeSortStdout, nodesFileOutputStream);
+ NodesOutputThread nodesOutputThread = new NodesOutputThread(
+ new BufferedInputStream(nodeSortMerger.getInputStream()), nodesFileOutputStream);
nodesOutputThread.start();
- // Spawn label sorting process
- Process labelSort = Sort.spawnSort(sortBufferSize, sortTmpDir);
- BufferedOutputStream labelSortStdin = new BufferedOutputStream(labelSort.getOutputStream());
- BufferedInputStream labelSortStdout = new BufferedInputStream(labelSort.getInputStream());
+ // Spawn label merge-sorting process
+ Process labelSortMerger = Sort.spawnSort(sortBufferSize, sortTmpDir, labelSortMergerOptions);
+ labelSortMerger.getOutputStream().close();
OutputStream labelsFileOutputStream = new ZstdOutputStream(
new BufferedOutputStream(new FileOutputStream(outputBasename + ".labels.csv.zst")));
- LabelsOutputThread labelsOutputThread = new LabelsOutputThread(labelSortStdout, labelsFileOutputStream);
+ LabelsOutputThread labelsOutputThread = new LabelsOutputThread(
+ new BufferedInputStream(labelSortMerger.getInputStream()), labelsFileOutputStream);
labelsOutputThread.start();
- // Read the dataset and write the nodes and labels to the sorting processes
- long[] edgeCount = {0};
- long[][] edgeCountByType = new long[Node.Type.values().length][Node.Type.values().length];
- dataset.readEdges((node) -> {
- nodeSortStdin.write(node);
- nodeSortStdin.write('\n');
- }, (src, dst, label, perm) -> {
- nodeSortStdin.write(src);
- nodeSortStdin.write('\n');
- nodeSortStdin.write(dst);
- nodeSortStdin.write('\n');
- if (label != null) {
- labelSortStdin.write(label);
- labelSortStdin.write('\n');
- }
- edgeCount[0]++;
- // Extract type of src and dst from their SWHID: swh:1:XXX
- byte[] srcTypeBytes = Arrays.copyOfRange(src, 6, 6 + 3);
- byte[] dstTypeBytes = Arrays.copyOfRange(dst, 6, 6 + 3);
- int srcType = Node.Type.byteNameToInt(srcTypeBytes);
- int dstType = Node.Type.byteNameToInt(dstTypeBytes);
- if (srcType != -1 && dstType != -1) {
- edgeCountByType[srcType][dstType]++;
- } else {
- System.err
- .println("Invalid edge type: " + new String(srcTypeBytes) + " -> " + new String(dstTypeBytes));
- System.exit(1);
- }
- });
-
- // Wait for sorting processes to finish
- nodeSortStdin.close();
- nodeSort.waitFor();
- labelSortStdin.close();
- labelSort.waitFor();
-
+ pl.logger().info("Waiting for merge-sort and writing output files...");
+ nodeSortMerger.waitFor();
+ labelSortMerger.waitFor();
nodesOutputThread.join();
labelsOutputThread.join();
+ long[][] edgeCountByTypeArray = new long[Node.Type.values().length][Node.Type.values().length];
+ for (int i = 0; i < edgeCountByTypeArray.length; i++) {
+ for (int j = 0; j < edgeCountByTypeArray[i].length; j++) {
+ edgeCountByTypeArray[i][j] = edgeCountByType.get(i * Node.Type.values().length + j);
+ }
+ }
+
// Write node, edge and label counts/statistics
- printEdgeCounts(outputBasename, edgeCount[0], edgeCountByType);
+ printEdgeCounts(outputBasename, edgeCount.get(), edgeCountByTypeArray);
printNodeCounts(outputBasename, nodesOutputThread.getNodeCount(), nodesOutputThread.getNodeTypeCounts());
printLabelCounts(outputBasename, labelsOutputThread.getLabelCount());
}
diff --git a/java/src/main/java/org/softwareheritage/graph/utils/Sort.java b/java/src/main/java/org/softwareheritage/graph/utils/Sort.java
--- a/java/src/main/java/org/softwareheritage/graph/utils/Sort.java
+++ b/java/src/main/java/org/softwareheritage/graph/utils/Sort.java
@@ -7,6 +7,10 @@
public class Sort {
public static Process spawnSort(String sortBufferSize, String sortTmpDir) throws IOException {
+ return spawnSort(sortBufferSize, sortTmpDir, null);
+ }
+
+ public static Process spawnSort(String sortBufferSize, String sortTmpDir, List<String> options) throws IOException {
ProcessBuilder sortProcessBuilder = new ProcessBuilder();
sortProcessBuilder.redirectError(ProcessBuilder.Redirect.INHERIT);
ArrayList command = new ArrayList<>(List.of("sort", "-u", "--buffer-size", sortBufferSize));
@@ -14,6 +18,9 @@
command.add("--temporary-directory");
command.add(sortTmpDir);
}
+ if (options != null) {
+ command.addAll(options);
+ }
sortProcessBuilder.command(command);
Map env = sortProcessBuilder.environment();
env.put("LC_ALL", "C");