diff --git a/java/src/main/java/org/softwareheritage/graph/compress/ExtractNodes.java b/java/src/main/java/org/softwareheritage/graph/compress/ExtractNodes.java --- a/java/src/main/java/org/softwareheritage/graph/compress/ExtractNodes.java +++ b/java/src/main/java/org/softwareheritage/graph/compress/ExtractNodes.java @@ -8,6 +8,10 @@ import java.io.*; import java.nio.charset.StandardCharsets; import java.util.*; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ForkJoinPool; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicLongArray; /** * Read a graph dataset and extract all the unique node SWHIDs it contains, including the ones that @@ -117,34 +121,62 @@ labelsOutputThread.start(); // Read the dataset and write the nodes and labels to the sorting processes - long[] edgeCount = {0}; - long[][] edgeCountByType = new long[Node.Type.values().length][Node.Type.values().length]; - dataset.readEdges((node) -> { - nodeSortStdin.write(node); - nodeSortStdin.write('\n'); - }, (src, dst, label, perm) -> { - nodeSortStdin.write(src); - nodeSortStdin.write('\n'); - nodeSortStdin.write(dst); - nodeSortStdin.write('\n'); - if (label != null) { - labelSortStdin.write(label); - labelSortStdin.write('\n'); - } - edgeCount[0]++; - // Extract type of src and dst from their SWHID: swh:1:XXX - byte[] srcTypeBytes = Arrays.copyOfRange(src, 6, 6 + 3); - byte[] dstTypeBytes = Arrays.copyOfRange(dst, 6, 6 + 3); - int srcType = Node.Type.byteNameToInt(srcTypeBytes); - int dstType = Node.Type.byteNameToInt(dstTypeBytes); - if (srcType != -1 && dstType != -1) { - edgeCountByType[srcType][dstType]++; - } else { - System.err - .println("Invalid edge type: " + new String(srcTypeBytes) + " -> " + new String(dstTypeBytes)); - System.exit(1); + AtomicLong edgeCount = new AtomicLong(0); + AtomicLongArray edgeCountByType = new AtomicLongArray(Node.Type.values().length * Node.Type.values().length); + // long[][] edgeCountByType = new long[Node.Type.values().length][Node.Type.values().length]; + ForkJoinPool forkJoinPool = new ForkJoinPool(Runtime.getRuntime().availableProcessors()); + try { + forkJoinPool.submit(() -> { + try { + dataset.readEdges((node) -> { + synchronized (nodeSortStdin) { + nodeSortStdin.write(node); + nodeSortStdin.write('\n'); + } + }, (src, dst, label, perm) -> { + synchronized (nodeSortStdin) { + nodeSortStdin.write(src); + nodeSortStdin.write('\n'); + nodeSortStdin.write(dst); + nodeSortStdin.write('\n'); + } + if (label != null) { + synchronized (labelSortStdin) { + labelSortStdin.write(label); + labelSortStdin.write('\n'); + } + } + edgeCount.incrementAndGet(); + // edgeCount[0]++; + // Extract type of src and dst from their SWHID: swh:1:XXX + byte[] srcTypeBytes = Arrays.copyOfRange(src, 6, 6 + 3); + byte[] dstTypeBytes = Arrays.copyOfRange(dst, 6, 6 + 3); + int srcType = Node.Type.byteNameToInt(srcTypeBytes); + int dstType = Node.Type.byteNameToInt(dstTypeBytes); + if (srcType != -1 && dstType != -1) { + edgeCountByType.incrementAndGet(srcType * Node.Type.values().length + dstType); + // edgeCountByType[srcType][dstType].incrementAndGet(); + // edgeCountByType[srcType][dstType]++; + } else { + System.err.println("Invalid edge type: " + new String(srcTypeBytes) + " -> " + + new String(dstTypeBytes)); + System.exit(1); + } + }); + } catch (IOException e) { + throw new RuntimeException(e); + } + }).get(); + } catch (ExecutionException e) { + throw new RuntimeException(e); + } + + long[][] edgeCountByTypeArray = new long[Node.Type.values().length][Node.Type.values().length]; + for (int i = 0; i < edgeCountByTypeArray.length; i++) { + for (int j = 0; j < edgeCountByTypeArray[i].length; j++) { + edgeCountByTypeArray[i][j] = edgeCountByType.get(i * Node.Type.values().length + j); } - }); + } // Wait for sorting processes to finish nodeSortStdin.close(); @@ -156,7 +188,7 @@ labelsOutputThread.join(); // Write node, edge and label counts/statistics - printEdgeCounts(outputBasename, edgeCount[0], edgeCountByType); + printEdgeCounts(outputBasename, edgeCount.get(), edgeCountByTypeArray); printNodeCounts(outputBasename, nodesOutputThread.getNodeCount(), nodesOutputThread.getNodeTypeCounts()); printLabelCounts(outputBasename, labelsOutputThread.getLabelCount()); }