diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_ianopolousfast.java b/src/main/java/dev/morling/onebrc/CalculateAverage_ianopolousfast.java index 92e2f6ecb..8d90dfdd3 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_ianopolousfast.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_ianopolousfast.java @@ -26,6 +26,8 @@ import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.StandardOpenOption; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; import java.util.stream.*; import java.util.*; @@ -45,6 +47,7 @@ public class CalculateAverage_ianopolousfast { public static final int MAX_LINE_LENGTH = 107; public static final int MAX_STATIONS = 1 << 14; private static final OfLong LONG_LAYOUT = JAVA_LONG_UNALIGNED.withOrder(ByteOrder.BIG_ENDIAN); + public static final int CHUNK_SIZE = 32 * 1024 * 1024; private static final VectorSpecies BYTE_SPECIES = ByteVector.SPECIES_PREFERRED.length() >= 16 ? ByteVector.SPECIES_128 : ByteVector.SPECIES_64; @@ -55,11 +58,13 @@ public static void main(String[] args) throws Exception { FileChannel channel = (FileChannel) Files.newByteChannel(input, StandardOpenOption.READ); long filesize = Files.size(input); MemorySegment mmap = channel.map(FileChannel.MapMode.READ_ONLY, 0, filesize, arena); - int nChunks = filesize < 4 * 1024 * 1024 ? 1 : Runtime.getRuntime().availableProcessors(); - long chunkSize = (filesize + nChunks - 1) / nChunks; - List allResults = IntStream.range(0, nChunks) + int nThreads = filesize < 4 * 1024 * 1024 ? 1 : Runtime.getRuntime().availableProcessors(); + LinkedBlockingQueue chunks = new LinkedBlockingQueue<>(); + for (long i = 0; i < (filesize + CHUNK_SIZE - 1) / CHUNK_SIZE; i++) + chunks.add(i); + List allResults = IntStream.range(0, nThreads) .parallel() - .mapToObj(i -> parseStats(i * chunkSize, Math.min((i + 1) * chunkSize, filesize), mmap)) + .mapToObj(i -> parseStats(chunks, mmap)) .toList(); TreeMap merged = allResults.stream() @@ -163,47 +168,59 @@ private static int keySize(int lineSize, long lineStart, MemorySegment buffer) { ByteOrder.nativeOrder()).compare(VectorOperators.EQ, ';').firstTrue(); } - public static Stat[] parseStats(long start1, long end2, MemorySegment buffer) { - // read first partial line - if (start1 > 0) { - for (int i = 0; i < MAX_LINE_LENGTH; i++) { - byte b = buffer.get(JAVA_BYTE, start1++); - if (b == '\n') { - break; - } - } - } - + public static Stat[] parseStats(LinkedBlockingQueue chunks, MemorySegment buffer) { Stat[] stations = new Stat[MAX_STATIONS]; - // Handle reading the very last few lines in the file - // this allows us to not worry about reading beyond the end - // in the inner loop (reducing branches) - // We need at least the vector lane size bytes back - if (end2 == buffer.byteSize()) { - // reverse at least vector lane width - end2 = Math.max(buffer.byteSize() - 2 * BYTE_SPECIES.vectorByteSize(), 0); - while (end2 > 0 && buffer.get(JAVA_BYTE, end2) != '\n') - end2--; + while (!chunks.isEmpty()) { + long start1, end2; + try { + long nChunk = chunks.poll(0, TimeUnit.MICROSECONDS); + start1 = nChunk * CHUNK_SIZE; + end2 = Math.min(buffer.byteSize(), (nChunk + 1) * CHUNK_SIZE); + } + catch (InterruptedException e) { + continue; + } + // read first partial line + if (start1 > 0) { + int lineSize1 = lineSize(start1, buffer); + start1 += lineSize1 + 1; + } + + // Handle reading the very last few lines in the file + // this allows us to not worry about reading beyond the end + // in the inner loop (reducing branches) + // We need at least the vector lane size bytes back + if (end2 == buffer.byteSize()) { + // reverse at least vector lane width + end2 = Math.max(buffer.byteSize() - 2 * BYTE_SPECIES.vectorByteSize(), 0); + while (end2 > 0 && buffer.get(JAVA_BYTE, end2) != '\n') + end2--; - if (end2 > 0) - end2++; - // copy into a larger buffer to avoid reading off end - MemorySegment end = Arena.global().allocate(MAX_LINE_LENGTH + 2 * BYTE_SPECIES.vectorByteSize()); - for (long i = end2; i < buffer.byteSize(); i++) - end.set(JAVA_BYTE, i - end2, buffer.get(JAVA_BYTE, i)); - int index = 0; - while (end2 + index < buffer.byteSize()) { - int lineSize1 = lineSize(index, end); - int semiSearchStart = index + Math.max(0, lineSize1 - 6); - int keySize1 = semiSearchStart - index + ByteVector.fromMemorySegment(BYTE_SPECIES, end, semiSearchStart, - ByteOrder.nativeOrder()).compare(VectorOperators.EQ, ';').firstTrue(); - Stat station1 = dedupeStation(index, index + keySize1, end, stations); - processTemperature(index + keySize1 + 1, lineSize1 - keySize1 - 1, end, station1); - index += lineSize1 + 1; + if (end2 > 0) + end2++; + // copy into a larger buffer to avoid reading off end + MemorySegment end = Arena.global().allocate(MAX_LINE_LENGTH + 2 * BYTE_SPECIES.vectorByteSize()); + for (long i = end2; i < buffer.byteSize(); i++) + end.set(JAVA_BYTE, i - end2, buffer.get(JAVA_BYTE, i)); + int index = 0; + while (end2 + index < buffer.byteSize()) { + int lineSize1 = lineSize(index, end); + int semiSearchStart = index + Math.max(0, lineSize1 - 6); + int keySize1 = semiSearchStart - index + ByteVector.fromMemorySegment(BYTE_SPECIES, end, semiSearchStart, + ByteOrder.nativeOrder()).compare(VectorOperators.EQ, ';').firstTrue(); + Stat station1 = dedupeStation(index, index + keySize1, end, stations); + processTemperature(index + keySize1 + 1, lineSize1 - keySize1 - 1, end, station1); + index += lineSize1 + 1; + } } + + innerLoop(start1, end2, buffer, stations); } + return stations; + } + private static void innerLoop(long start1, long end2, MemorySegment buffer, Stat[] stations) { while (start1 < end2) { int lineSize1 = lineSize(start1, buffer); long start2 = start1 + lineSize1 + 1; @@ -220,7 +237,6 @@ public static Stat[] parseStats(long start1, long end2, MemorySegment buffer) { else start1 += lineSize1 + 1; } - return stations; } public static class Stat {