diff --git a/java/src/main/java/org/softwareheritage/graph/compress/ExtractNodes.java b/java/src/main/java/org/softwareheritage/graph/compress/ExtractNodes.java index c4da735..3b6c060 100644 --- a/java/src/main/java/org/softwareheritage/graph/compress/ExtractNodes.java +++ b/java/src/main/java/org/softwareheritage/graph/compress/ExtractNodes.java @@ -1,315 +1,393 @@ package org.softwareheritage.graph.compress; import com.github.luben.zstd.ZstdOutputStream; import com.martiansoftware.jsap.*; +import it.unimi.dsi.logging.ProgressLogger; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import org.softwareheritage.graph.Node; import org.softwareheritage.graph.utils.Sort; import java.io.*; import java.nio.charset.StandardCharsets; import java.util.*; import java.util.concurrent.ExecutionException; import java.util.concurrent.ForkJoinPool; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicLongArray; /** * Read a graph dataset and extract all the unique node SWHIDs it contains, including the ones that * are not stored as actual objects in the graph, but only referred to by the edges. * Additionally, extract the set of all unique edge labels in the graph. * * * *

* Rationale: Because the graph can contain holes, loose objects and dangling * objects, some nodes that are referred to as destinations in the edge relationships might not * actually be stored in the graph itself. However, to compress the graph using a graph compression * library, it is necessary to have a list of all the nodes in the graph, including the * ones that are simply referred to by the edges but not actually stored as concrete objects. *

* *

* This class reads the entire graph dataset, and uses sort -u to extract the set of * all the unique nodes and unique labels that will be needed as an input for the compression * process. *

*/ public class ExtractNodes { + private final static Logger logger = LoggerFactory.getLogger(ExtractNodes.class); + + // Create one thread per processor. + final static int numThreads = Runtime.getRuntime().availableProcessors(); + + // Allocate up to 20% of maximum memory for sorting subprocesses. + final static long sortBufferSize = (long) (Runtime.getRuntime().maxMemory() * 0.2 / numThreads / 2); + private static JSAPResult parseArgs(String[] args) { JSAPResult config = null; try { SimpleJSAP jsap = new SimpleJSAP(ComposePermutations.class.getName(), "", new Parameter[]{ new UnflaggedOption("dataset", JSAP.STRING_PARSER, JSAP.REQUIRED, "Path to the edges dataset"), new UnflaggedOption("outputBasename", JSAP.STRING_PARSER, JSAP.REQUIRED, "Basename of the output files"), new FlaggedOption("format", JSAP.STRING_PARSER, "orc", JSAP.NOT_REQUIRED, 'f', "format", "Format of the input dataset (orc, csv)"), - new FlaggedOption("sortBufferSize", JSAP.STRING_PARSER, "30%", JSAP.NOT_REQUIRED, 'S', - "sort-buffer-size", "Size of the memory buffer used by sort"), + new FlaggedOption("sortBufferSize", JSAP.STRING_PARSER, String.valueOf(sortBufferSize) + "b", + JSAP.NOT_REQUIRED, 'S', "sort-buffer-size", + "Size of the memory buffer used by each sort process"), new FlaggedOption("sortTmpDir", JSAP.STRING_PARSER, null, JSAP.NOT_REQUIRED, 'T', "temp-dir", "Path to the temporary directory used by sort")}); config = jsap.parse(args); if (jsap.messagePrinted()) { System.exit(1); } } catch (JSAPException e) { System.err.println("Usage error: " + e.getMessage()); System.exit(1); } return config; } public static void main(String[] args) throws IOException, InterruptedException { JSAPResult parsedArgs = parseArgs(args); String datasetPath = parsedArgs.getString("dataset"); String outputBasename = parsedArgs.getString("outputBasename"); String datasetFormat = parsedArgs.getString("format"); String sortBufferSize = parsedArgs.getString("sortBufferSize"); String sortTmpDir = parsedArgs.getString("sortTmpDir", null); (new File(sortTmpDir)).mkdirs(); // Open edge dataset GraphDataset dataset; if (datasetFormat.equals("orc")) { dataset = new ORCGraphDataset(datasetPath); } else if (datasetFormat.equals("csv")) { dataset = new CSVEdgeDataset(datasetPath); } else { throw new IllegalArgumentException("Unknown dataset format: " + datasetFormat); } extractNodes(dataset, outputBasename, sortBufferSize, sortTmpDir); } public static void extractNodes(GraphDataset dataset, String outputBasename, String sortBufferSize, String sortTmpDir) throws IOException, InterruptedException { - // Spawn node sorting process - Process nodeSort = Sort.spawnSort(sortBufferSize, sortTmpDir); - BufferedOutputStream nodeSortStdin = new BufferedOutputStream(nodeSort.getOutputStream()); - BufferedInputStream nodeSortStdout = new BufferedInputStream(nodeSort.getInputStream()); - OutputStream nodesFileOutputStream = new ZstdOutputStream( - new BufferedOutputStream(new FileOutputStream(outputBasename + ".nodes.csv.zst"))); - NodesOutputThread nodesOutputThread = new NodesOutputThread(nodeSortStdout, nodesFileOutputStream); - nodesOutputThread.start(); - - // Spawn label sorting process - Process labelSort = Sort.spawnSort(sortBufferSize, sortTmpDir); - BufferedOutputStream labelSortStdin = new BufferedOutputStream(labelSort.getOutputStream()); - BufferedInputStream labelSortStdout = new BufferedInputStream(labelSort.getInputStream()); - OutputStream labelsFileOutputStream = new ZstdOutputStream( - new BufferedOutputStream(new FileOutputStream(outputBasename + ".labels.csv.zst"))); - LabelsOutputThread labelsOutputThread = new LabelsOutputThread(labelSortStdout, labelsFileOutputStream); - labelsOutputThread.start(); - // Read the dataset and write the nodes and labels to the sorting processes AtomicLong edgeCount = new AtomicLong(0); AtomicLongArray edgeCountByType = new AtomicLongArray(Node.Type.values().length * Node.Type.values().length); - // long[][] edgeCountByType = new long[Node.Type.values().length][Node.Type.values().length]; - ForkJoinPool forkJoinPool = new ForkJoinPool(Runtime.getRuntime().availableProcessors()); + + int numThreads = Runtime.getRuntime().availableProcessors(); + ForkJoinPool forkJoinPool = new ForkJoinPool(numThreads); + + Process[] nodeSorters = new Process[numThreads]; + String[] nodeBatchPaths = new String[numThreads]; + Process[] labelSorters = new Process[numThreads]; + String[] labelBatchPaths = new String[numThreads]; + long[] progressCounts = new long[numThreads]; + + AtomicInteger nextThreadId = new AtomicInteger(0); + ThreadLocal threadLocalId = ThreadLocal.withInitial(nextThreadId::getAndIncrement); + + ProgressLogger pl = new ProgressLogger(logger, 10, TimeUnit.SECONDS); + pl.itemsName = "edges"; + pl.start("Reading node/edge files and writing sorted batches."); + + GraphDataset.NodeCallback nodeCallback = (node) -> { + int threadId = threadLocalId.get(); + if (nodeSorters[threadId] == null) { + nodeBatchPaths[threadId] = sortTmpDir + "/nodes." + threadId + ".txt"; + nodeSorters[threadId] = Sort.spawnSort(sortBufferSize, sortTmpDir, + List.of("-o", nodeBatchPaths[threadId])); + } + OutputStream nodeOutputStream = nodeSorters[threadId].getOutputStream(); + nodeOutputStream.write(node); + nodeOutputStream.write('\n'); + }; + + GraphDataset.NodeCallback labelCallback = (label) -> { + int threadId = threadLocalId.get(); + if (labelSorters[threadId] == null) { + labelBatchPaths[threadId] = sortTmpDir + "/labels." + threadId + ".txt"; + labelSorters[threadId] = Sort.spawnSort(sortBufferSize, sortTmpDir, + List.of("-o", labelBatchPaths[threadId])); + } + OutputStream labelOutputStream = labelSorters[threadId].getOutputStream(); + labelOutputStream.write(label); + labelOutputStream.write('\n'); + }; + try { forkJoinPool.submit(() -> { try { dataset.readEdges((node) -> { - synchronized (nodeSortStdin) { - nodeSortStdin.write(node); - nodeSortStdin.write('\n'); - } + nodeCallback.onNode(node); }, (src, dst, label, perm) -> { - synchronized (nodeSortStdin) { - nodeSortStdin.write(src); - nodeSortStdin.write('\n'); - nodeSortStdin.write(dst); - nodeSortStdin.write('\n'); - } + nodeCallback.onNode(src); + nodeCallback.onNode(dst); + if (label != null) { - synchronized (labelSortStdin) { - labelSortStdin.write(label); - labelSortStdin.write('\n'); - } + labelCallback.onNode(label); } edgeCount.incrementAndGet(); - // edgeCount[0]++; // Extract type of src and dst from their SWHID: swh:1:XXX byte[] srcTypeBytes = Arrays.copyOfRange(src, 6, 6 + 3); byte[] dstTypeBytes = Arrays.copyOfRange(dst, 6, 6 + 3); int srcType = Node.Type.byteNameToInt(srcTypeBytes); int dstType = Node.Type.byteNameToInt(dstTypeBytes); if (srcType != -1 && dstType != -1) { edgeCountByType.incrementAndGet(srcType * Node.Type.values().length + dstType); - // edgeCountByType[srcType][dstType].incrementAndGet(); - // edgeCountByType[srcType][dstType]++; } else { System.err.println("Invalid edge type: " + new String(srcTypeBytes) + " -> " + new String(dstTypeBytes)); System.exit(1); } + + int threadId = threadLocalId.get(); + if (++progressCounts[threadId] > 1000) { + synchronized (pl) { + pl.update(progressCounts[threadId]); + } + progressCounts[threadId] = 0; + } }); } catch (IOException e) { throw new RuntimeException(e); } }).get(); } catch (ExecutionException e) { throw new RuntimeException(e); } - long[][] edgeCountByTypeArray = new long[Node.Type.values().length][Node.Type.values().length]; - for (int i = 0; i < edgeCountByTypeArray.length; i++) { - for (int j = 0; j < edgeCountByTypeArray[i].length; j++) { - edgeCountByTypeArray[i][j] = edgeCountByType.get(i * Node.Type.values().length + j); + // Close all the sorters stdin + for (int i = 0; i < numThreads; i++) { + if (nodeSorters[i] != null) { + nodeSorters[i].getOutputStream().close(); + } + if (labelSorters[i] != null) { + labelSorters[i].getOutputStream().close(); } } // Wait for sorting processes to finish - nodeSortStdin.close(); - nodeSort.waitFor(); - labelSortStdin.close(); - labelSort.waitFor(); + for (int i = 0; i < numThreads; i++) { + if (nodeSorters[i] != null) { + nodeSorters[i].waitFor(); + } + if (labelSorters[i] != null) { + labelSorters[i].waitFor(); + } + } + pl.done(); + + ArrayList nodeSortMergerOptions = new ArrayList<>(List.of("-m")); + ArrayList labelSortMergerOptions = new ArrayList<>(List.of("-m")); + for (int i = 0; i < numThreads; i++) { + if (nodeBatchPaths[i] != null) { + nodeSortMergerOptions.add(nodeBatchPaths[i]); + } + if (labelBatchPaths[i] != null) { + labelSortMergerOptions.add(labelBatchPaths[i]); + } + } + + // Spawn node merge-sorting process + Process nodeSortMerger = Sort.spawnSort(sortBufferSize, sortTmpDir, nodeSortMergerOptions); + nodeSortMerger.getOutputStream().close(); + OutputStream nodesFileOutputStream = new ZstdOutputStream( + new BufferedOutputStream(new FileOutputStream(outputBasename + ".nodes.csv.zst"))); + NodesOutputThread nodesOutputThread = new NodesOutputThread( + new BufferedInputStream(nodeSortMerger.getInputStream()), nodesFileOutputStream); + nodesOutputThread.start(); + // Spawn label merge-sorting process + Process labelSortMerger = Sort.spawnSort(sortBufferSize, sortTmpDir, labelSortMergerOptions); + labelSortMerger.getOutputStream().close(); + OutputStream labelsFileOutputStream = new ZstdOutputStream( + new BufferedOutputStream(new FileOutputStream(outputBasename + ".labels.csv.zst"))); + LabelsOutputThread labelsOutputThread = new LabelsOutputThread( + new BufferedInputStream(labelSortMerger.getInputStream()), labelsFileOutputStream); + labelsOutputThread.start(); + + pl.logger().info("Waiting for merge-sort and writing output files..."); + nodeSortMerger.waitFor(); + labelSortMerger.waitFor(); nodesOutputThread.join(); labelsOutputThread.join(); + long[][] edgeCountByTypeArray = new long[Node.Type.values().length][Node.Type.values().length]; + for (int i = 0; i < edgeCountByTypeArray.length; i++) { + for (int j = 0; j < edgeCountByTypeArray[i].length; j++) { + edgeCountByTypeArray[i][j] = edgeCountByType.get(i * Node.Type.values().length + j); + } + } + // Write node, edge and label counts/statistics printEdgeCounts(outputBasename, edgeCount.get(), edgeCountByTypeArray); printNodeCounts(outputBasename, nodesOutputThread.getNodeCount(), nodesOutputThread.getNodeTypeCounts()); printLabelCounts(outputBasename, labelsOutputThread.getLabelCount()); } private static void printEdgeCounts(String basename, long edgeCount, long[][] edgeTypeCounts) throws IOException { PrintWriter nodeCountWriter = new PrintWriter(basename + ".edges.count.txt"); nodeCountWriter.println(edgeCount); nodeCountWriter.close(); PrintWriter nodeTypesCountWriter = new PrintWriter(basename + ".edges.stats.txt"); TreeMap edgeTypeCountsMap = new TreeMap<>(); for (Node.Type src : Node.Type.values()) { for (Node.Type dst : Node.Type.values()) { long cnt = edgeTypeCounts[Node.Type.toInt(src)][Node.Type.toInt(dst)]; if (cnt > 0) edgeTypeCountsMap.put(src.toString().toLowerCase() + ":" + dst.toString().toLowerCase(), cnt); } } for (Map.Entry entry : edgeTypeCountsMap.entrySet()) { nodeTypesCountWriter.println(entry.getKey() + " " + entry.getValue()); } nodeTypesCountWriter.close(); } private static void printNodeCounts(String basename, long nodeCount, long[] nodeTypeCounts) throws IOException { PrintWriter nodeCountWriter = new PrintWriter(basename + ".nodes.count.txt"); nodeCountWriter.println(nodeCount); nodeCountWriter.close(); PrintWriter nodeTypesCountWriter = new PrintWriter(basename + ".nodes.stats.txt"); TreeMap nodeTypeCountsMap = new TreeMap<>(); for (Node.Type v : Node.Type.values()) { nodeTypeCountsMap.put(v.toString().toLowerCase(), nodeTypeCounts[Node.Type.toInt(v)]); } for (Map.Entry entry : nodeTypeCountsMap.entrySet()) { nodeTypesCountWriter.println(entry.getKey() + " " + entry.getValue()); } nodeTypesCountWriter.close(); } private static void printLabelCounts(String basename, long labelCount) throws IOException { PrintWriter nodeCountWriter = new PrintWriter(basename + ".labels.count.txt"); nodeCountWriter.println(labelCount); nodeCountWriter.close(); } private static class NodesOutputThread extends Thread { private final InputStream sortedNodesStream; private final OutputStream nodesOutputStream; private long nodeCount = 0; private final long[] nodeTypeCounts = new long[Node.Type.values().length]; NodesOutputThread(InputStream sortedNodesStream, OutputStream nodesOutputStream) { this.sortedNodesStream = sortedNodesStream; this.nodesOutputStream = nodesOutputStream; } @Override public void run() { BufferedReader reader = new BufferedReader( new InputStreamReader(sortedNodesStream, StandardCharsets.UTF_8)); try { String line; while ((line = reader.readLine()) != null) { nodesOutputStream.write(line.getBytes(StandardCharsets.UTF_8)); nodesOutputStream.write('\n'); nodeCount++; try { Node.Type nodeType = Node.Type.fromStr(line.split(":")[2]); nodeTypeCounts[Node.Type.toInt(nodeType)]++; } catch (ArrayIndexOutOfBoundsException e) { System.err.println("Error parsing SWHID: " + line); System.exit(1); } } nodesOutputStream.close(); } catch (IOException e) { throw new RuntimeException(e); } } public long getNodeCount() { return nodeCount; } public long[] getNodeTypeCounts() { return nodeTypeCounts; } } private static class LabelsOutputThread extends Thread { private final InputStream sortedLabelsStream; private final OutputStream labelsOutputStream; private long labelCount = 0; LabelsOutputThread(InputStream sortedLabelsStream, OutputStream labelsOutputStream) { this.labelsOutputStream = labelsOutputStream; this.sortedLabelsStream = sortedLabelsStream; } @Override public void run() { BufferedReader reader = new BufferedReader( new InputStreamReader(sortedLabelsStream, StandardCharsets.UTF_8)); try { String line; while ((line = reader.readLine()) != null) { labelsOutputStream.write(line.getBytes(StandardCharsets.UTF_8)); labelsOutputStream.write('\n'); labelCount++; } labelsOutputStream.close(); } catch (IOException e) { throw new RuntimeException(e); } } public long getLabelCount() { return labelCount; } } } diff --git a/java/src/main/java/org/softwareheritage/graph/utils/Sort.java b/java/src/main/java/org/softwareheritage/graph/utils/Sort.java index 4ece391..2181a53 100644 --- a/java/src/main/java/org/softwareheritage/graph/utils/Sort.java +++ b/java/src/main/java/org/softwareheritage/graph/utils/Sort.java @@ -1,25 +1,32 @@ package org.softwareheritage.graph.utils; import java.io.IOException; import java.util.ArrayList; import java.util.List; import java.util.Map; public class Sort { public static Process spawnSort(String sortBufferSize, String sortTmpDir) throws IOException { + return spawnSort(sortBufferSize, sortTmpDir, null); + } + + public static Process spawnSort(String sortBufferSize, String sortTmpDir, List options) throws IOException { ProcessBuilder sortProcessBuilder = new ProcessBuilder(); sortProcessBuilder.redirectError(ProcessBuilder.Redirect.INHERIT); ArrayList command = new ArrayList<>(List.of("sort", "-u", "--buffer-size", sortBufferSize)); if (sortTmpDir != null) { command.add("--temporary-directory"); command.add(sortTmpDir); } + if (options != null) { + command.addAll(options); + } sortProcessBuilder.command(command); Map env = sortProcessBuilder.environment(); env.put("LC_ALL", "C"); env.put("LC_COLLATE", "C"); env.put("LANG", "C"); return sortProcessBuilder.start(); } }