diff --git a/java/src/main/java/org/softwareheritage/graph/compress/ExtractNodes.java b/java/src/main/java/org/softwareheritage/graph/compress/ExtractNodes.java
--- a/java/src/main/java/org/softwareheritage/graph/compress/ExtractNodes.java
+++ b/java/src/main/java/org/softwareheritage/graph/compress/ExtractNodes.java
@@ -2,12 +2,20 @@

 import com.github.luben.zstd.ZstdOutputStream;
 import com.martiansoftware.jsap.*;
+import it.unimi.dsi.logging.ProgressLogger;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 import org.softwareheritage.graph.Node;
 import org.softwareheritage.graph.utils.Sort;

 import java.io.*;
 import java.nio.charset.StandardCharsets;
 import java.util.*;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ForkJoinPool;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicLong;
+import java.util.concurrent.atomic.AtomicLongArray;

 /**
  * Read a graph dataset and extract all the unique node SWHIDs it contains, including the ones that
@@ -47,4 +55,12 @@
  */
 public class ExtractNodes {
+    private final static Logger logger = LoggerFactory.getLogger(ExtractNodes.class);
+
+    // Create one thread per processor.
+    final static int numThreads = Runtime.getRuntime().availableProcessors();
+
+    // Budget up to 20% of the JVM maximum heap size, in bytes, for one node and one label sorter per thread.
+    final static long sortBufferSize = (long) (Runtime.getRuntime().maxMemory() * 0.2 / numThreads / 2);
+
     private static JSAPResult parseArgs(String[] args) {
         JSAPResult config = null;
@@ -56,8 +72,9 @@

                             new FlaggedOption("format", JSAP.STRING_PARSER, "orc", JSAP.NOT_REQUIRED, 'f', "format",
                                     "Format of the input dataset (orc, csv)"),
-                            new FlaggedOption("sortBufferSize", JSAP.STRING_PARSER, "30%", JSAP.NOT_REQUIRED, 'S',
-                                    "sort-buffer-size", "Size of the memory buffer used by sort"),
+                            new FlaggedOption("sortBufferSize", JSAP.STRING_PARSER, sortBufferSize + "b",
+                                    JSAP.NOT_REQUIRED, 'S', "sort-buffer-size",
+                                    "Size of the memory buffer used by each sort process"),
                             new FlaggedOption("sortTmpDir", JSAP.STRING_PARSER, null, JSAP.NOT_REQUIRED, 'T', "temp-dir",
                                     "Path to the temporary directory used by sort")});

@@ -98,65 +115,172 @@

     public static void extractNodes(GraphDataset dataset, String outputBasename, String sortBufferSize,
             String sortTmpDir) throws IOException, InterruptedException {
-        // Spawn node sorting process
-        Process nodeSort = Sort.spawnSort(sortBufferSize, sortTmpDir);
-        BufferedOutputStream nodeSortStdin = new BufferedOutputStream(nodeSort.getOutputStream());
-        BufferedInputStream nodeSortStdout = new BufferedInputStream(nodeSort.getInputStream());
+        // Read the dataset and write the nodes and labels to the sorting processes
+        AtomicLong edgeCount = new AtomicLong(0);
+        AtomicLongArray edgeCountByType = new AtomicLongArray(Node.Type.values().length * Node.Type.values().length);
+
+        ForkJoinPool forkJoinPool = new ForkJoinPool(numThreads);
+
+        // Each thread running the callbacks lazily spawns its own node and label sorters, each writing one
+        // sorted batch file. Sorters are attached to threads through ThreadLocals instead of an index bounded
+        // by numThreads: a ForkJoinPool may retire idle workers and start replacements, so the callbacks may
+        // run on more than numThreads distinct threads.
+        List<Process> sorters = Collections.synchronizedList(new ArrayList<>());
+        List<String> nodeBatchPaths = Collections.synchronizedList(new ArrayList<>());
+        List<String> labelBatchPaths = Collections.synchronizedList(new ArrayList<>());
+        ThreadLocal<OutputStream> threadNodeSorterStdin = new ThreadLocal<>();
+        ThreadLocal<OutputStream> threadLabelSorterStdin = new ThreadLocal<>();
+        ThreadLocal<long[]> threadProgressCount = ThreadLocal.withInitial(() -> new long[1]);
+        // Unique temporary names avoid clobbering the batches of concurrent runs; a null directory (e.g. no
+        // --temp-dir given) makes File.createTempFile use java.io.tmpdir.
+        File batchDir = sortTmpDir != null ? new File(sortTmpDir) : null;
+
+        ProgressLogger pl = new ProgressLogger(logger, 10, TimeUnit.SECONDS);
+        pl.itemsName = "edges";
+        pl.start("Reading node/edge files and writing sorted batches.");
+
+        GraphDataset.NodeCallback nodeCallback = (node) -> {
+            OutputStream sorterStdin = threadNodeSorterStdin.get();
+            if (sorterStdin == null) {
+                String batchPath = File.createTempFile("nodes.", ".txt", batchDir).getPath();
+                Process sorter = Sort.spawnSort(sortBufferSize, sortTmpDir, List.of("-o", batchPath));
+                sorters.add(sorter);
+                nodeBatchPaths.add(batchPath);
+                sorterStdin = sorter.getOutputStream();
+                threadNodeSorterStdin.set(sorterStdin);
+            }
+            sorterStdin.write(node);
+            sorterStdin.write('\n');
+        };
+
+        GraphDataset.NodeCallback labelCallback = (label) -> {
+            OutputStream sorterStdin = threadLabelSorterStdin.get();
+            if (sorterStdin == null) {
+                String batchPath = File.createTempFile("labels.", ".txt", batchDir).getPath();
+                Process sorter = Sort.spawnSort(sortBufferSize, sortTmpDir, List.of("-o", batchPath));
+                sorters.add(sorter);
+                labelBatchPaths.add(batchPath);
+                sorterStdin = sorter.getOutputStream();
+                threadLabelSorterStdin.set(sorterStdin);
+            }
+            sorterStdin.write(label);
+            sorterStdin.write('\n');
+        };
+
+        try {
+            forkJoinPool.submit(() -> {
+                try {
+                    dataset.readEdges((node) -> {
+                        nodeCallback.onNode(node);
+                    }, (src, dst, label, perm) -> {
+                        nodeCallback.onNode(src);
+                        nodeCallback.onNode(dst);
+
+                        if (label != null) {
+                            labelCallback.onNode(label);
+                        }
+                        edgeCount.incrementAndGet();
+                        // Extract type of src and dst from their SWHID: swh:1:XXX
+                        byte[] srcTypeBytes = Arrays.copyOfRange(src, 6, 6 + 3);
+                        byte[] dstTypeBytes = Arrays.copyOfRange(dst, 6, 6 + 3);
+                        int srcType = Node.Type.byteNameToInt(srcTypeBytes);
+                        int dstType = Node.Type.byteNameToInt(dstTypeBytes);
+                        if (srcType != -1 && dstType != -1) {
+                            edgeCountByType.incrementAndGet(srcType * Node.Type.values().length + dstType);
+                        } else {
+                            System.err.println("Invalid edge type: " + new String(srcTypeBytes) + " -> "
+                                    + new String(dstTypeBytes));
+                            System.exit(1);
+                        }
+
+                        // Batch progress updates per thread to limit contention on the shared logger
+                        long[] progressCount = threadProgressCount.get();
+                        if (++progressCount[0] > 1000) {
+                            synchronized (pl) {
+                                pl.update(progressCount[0]);
+                            }
+                            progressCount[0] = 0;
+                        }
+                    });
+                } catch (IOException e) {
+                    throw new UncheckedIOException(e);
+                }
+            }).get();
+        } catch (ExecutionException e) {
+            // Surface I/O failures as the IOException this method declares, not as a wrapped RuntimeException
+            if (e.getCause() instanceof UncheckedIOException) {
+                throw ((UncheckedIOException) e.getCause()).getCause();
+            }
+            throw new RuntimeException(e.getCause());
+        } finally {
+            forkJoinPool.shutdown();
+        }
+
+        // Close all the sorters stdin, so that they write their sorted batch and exit
+        for (Process sorter : sorters) {
+            sorter.getOutputStream().close();
+        }
+
+        // Wait for sorting processes to finish; a failed batch would otherwise silently lose nodes or labels
+        for (Process sorter : sorters) {
+            int status = sorter.waitFor();
+            if (status != 0) {
+                throw new IOException("Batch sort process exited with status " + status);
+            }
+        }
+        pl.done();
+
+        List<String> nodeSortMergerOptions = new ArrayList<>(List.of("-m"));
+        nodeSortMergerOptions.addAll(nodeBatchPaths);
+        List<String> labelSortMergerOptions = new ArrayList<>(List.of("-m"));
+        labelSortMergerOptions.addAll(labelBatchPaths);
+
+        // Spawn node merge-sorting process
+        Process nodeSortMerger = Sort.spawnSort(sortBufferSize, sortTmpDir, nodeSortMergerOptions);
+        nodeSortMerger.getOutputStream().close();
         OutputStream nodesFileOutputStream = new ZstdOutputStream(
                 new BufferedOutputStream(new FileOutputStream(outputBasename + ".nodes.csv.zst")));
-        NodesOutputThread nodesOutputThread = new NodesOutputThread(nodeSortStdout, nodesFileOutputStream);
+        NodesOutputThread nodesOutputThread = new NodesOutputThread(
+                new BufferedInputStream(nodeSortMerger.getInputStream()), nodesFileOutputStream);
         nodesOutputThread.start();

-        // Spawn label sorting process
-        Process labelSort = Sort.spawnSort(sortBufferSize, sortTmpDir);
-        BufferedOutputStream labelSortStdin = new BufferedOutputStream(labelSort.getOutputStream());
-        BufferedInputStream labelSortStdout = new BufferedInputStream(labelSort.getInputStream());
+        // Spawn label merge-sorting process
+        Process labelSortMerger = Sort.spawnSort(sortBufferSize, sortTmpDir, labelSortMergerOptions);
+        labelSortMerger.getOutputStream().close();
         OutputStream labelsFileOutputStream = new ZstdOutputStream(
                 new BufferedOutputStream(new FileOutputStream(outputBasename + ".labels.csv.zst")));
-        LabelsOutputThread labelsOutputThread = new LabelsOutputThread(labelSortStdout, labelsFileOutputStream);
+        LabelsOutputThread labelsOutputThread = new LabelsOutputThread(
+                new BufferedInputStream(labelSortMerger.getInputStream()), labelsFileOutputStream);
         labelsOutputThread.start();

-        // Read the dataset and write the nodes and labels to the sorting processes
-        long[] edgeCount = {0};
-        long[][] edgeCountByType = new long[Node.Type.values().length][Node.Type.values().length];
-        dataset.readEdges((node) -> {
-            nodeSortStdin.write(node);
-            nodeSortStdin.write('\n');
-        }, (src, dst, label, perm) -> {
-            nodeSortStdin.write(src);
-            nodeSortStdin.write('\n');
-            nodeSortStdin.write(dst);
-            nodeSortStdin.write('\n');
-            if (label != null) {
-                labelSortStdin.write(label);
-                labelSortStdin.write('\n');
-            }
-            edgeCount[0]++;
-            // Extract type of src and dst from their SWHID: swh:1:XXX
-            byte[] srcTypeBytes = Arrays.copyOfRange(src, 6, 6 + 3);
-            byte[] dstTypeBytes = Arrays.copyOfRange(dst, 6, 6 + 3);
-            int srcType = Node.Type.byteNameToInt(srcTypeBytes);
-            int dstType = Node.Type.byteNameToInt(dstTypeBytes);
-            if (srcType != -1 && dstType != -1) {
-                edgeCountByType[srcType][dstType]++;
-            } else {
-                System.err
-                        .println("Invalid edge type: " + new String(srcTypeBytes) + " -> " + new String(dstTypeBytes));
-                System.exit(1);
-            }
-        });
-
-        // Wait for sorting processes to finish
-        nodeSortStdin.close();
-        nodeSort.waitFor();
-        labelSortStdin.close();
-        labelSort.waitFor();
-
+        pl.logger().info("Waiting for merge-sort and writing output files...");
+        int nodeMergeStatus = nodeSortMerger.waitFor();
+        int labelMergeStatus = labelSortMerger.waitFor();
         nodesOutputThread.join();
         labelsOutputThread.join();

+        if (nodeMergeStatus != 0 || labelMergeStatus != 0) {
+            throw new IOException("Merge sort process failed (nodes: exit status " + nodeMergeStatus
+                    + ", labels: exit status " + labelMergeStatus + ")");
+        }
+        // The batches are now merged into the output files and can be reclaimed
+        List<String> batchPaths = new ArrayList<>(nodeBatchPaths);
+        batchPaths.addAll(labelBatchPaths);
+        for (String batchPath : batchPaths) {
+            if (!new File(batchPath).delete()) {
+                logger.warn("Could not delete temporary batch file {}", batchPath);
+            }
+        }
+
+        long[][] edgeCountByTypeArray = new long[Node.Type.values().length][Node.Type.values().length];
+        for (int i = 0; i < edgeCountByTypeArray.length; i++) {
+            for (int j = 0; j < edgeCountByTypeArray[i].length; j++) {
+                edgeCountByTypeArray[i][j] = edgeCountByType.get(i * Node.Type.values().length + j);
+            }
+        }
+
         // Write node, edge and label counts/statistics
-        printEdgeCounts(outputBasename, edgeCount[0], edgeCountByType);
+        printEdgeCounts(outputBasename, edgeCount.get(), edgeCountByTypeArray);
         printNodeCounts(outputBasename, nodesOutputThread.getNodeCount(), nodesOutputThread.getNodeTypeCounts());
         printLabelCounts(outputBasename, labelsOutputThread.getLabelCount());
     }
diff --git a/java/src/main/java/org/softwareheritage/graph/utils/Sort.java b/java/src/main/java/org/softwareheritage/graph/utils/Sort.java
--- a/java/src/main/java/org/softwareheritage/graph/utils/Sort.java
+++ b/java/src/main/java/org/softwareheritage/graph/utils/Sort.java
@@ -7,6 +7,15 @@

 public class Sort {
     public static Process spawnSort(String sortBufferSize, String sortTmpDir) throws IOException {
+        return spawnSort(sortBufferSize, sortTmpDir, null);
+    }
+
+    /**
+     * Like {@link #spawnSort(String, String)}, but appends the given extra arguments (e.g. {@code -o FILE} or
+     * {@code -m FILE...}) to the sort command line. {@code options} may be null.
+     */
+    public static Process spawnSort(String sortBufferSize, String sortTmpDir, List<String> options)
+            throws IOException {
         ProcessBuilder sortProcessBuilder = new ProcessBuilder();
         sortProcessBuilder.redirectError(ProcessBuilder.Redirect.INHERIT);
         ArrayList<String> command = new ArrayList<>(List.of("sort", "-u", "--buffer-size", sortBufferSize));
@@ -14,6 +23,9 @@
             command.add("--temporary-directory");
             command.add(sortTmpDir);
         }
+        if (options != null) {
+            command.addAll(options);
+        }
         sortProcessBuilder.command(command);
         Map<String, String> env = sortProcessBuilder.environment();
         env.put("LC_ALL", "C");