diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_jerrinot.java b/src/main/java/dev/morling/onebrc/CalculateAverage_jerrinot.java
index df5defe71..6997f4896 100644
--- a/src/main/java/dev/morling/onebrc/CalculateAverage_jerrinot.java
+++ b/src/main/java/dev/morling/onebrc/CalculateAverage_jerrinot.java
@@ -24,6 +24,7 @@
 import java.lang.reflect.Field;
 import java.nio.channels.FileChannel.MapMode;
 import java.util.*;
+import java.util.concurrent.atomic.AtomicLong;
 
 /**
  * I figured out it would be very hard to win the main competition of the One Billion Rows Challenge.
@@ -31,17 +32,59 @@
  *
  * Anyway, if you can make sense out of not exactly idiomatic Java code, and you enjoy pushing performance limits
  * then QuestDB - the fastest open-source time-series database - is hiring: https://questdb.io/careers/core-database-engineer/
- *
+ *
+ * Credit
+ *
+ * I stand on shoulders of giants. I wouldn't be able to code this without analyzing and borrowing from solutions of others.
+ * People who helped me the most:
+ *
  */
 public class CalculateAverage_jerrinot {
     private static final Unsafe UNSAFE = unsafe();
     private static final String MEASUREMENTS_TXT = "measurements.txt";
     // todo: with hyper-threading enable we would be better of with availableProcessors / 2;
     // todo: validate the testing env. params.
-    private static final int THREAD_COUNT = Runtime.getRuntime().availableProcessors();
-    // private static final int THREAD_COUNT = 4;
+    private static final int EXTRA_THREAD_COUNT = Runtime.getRuntime().availableProcessors() - 1;
+    // private static final int THREAD_COUNT = 1;
     private static final long SEPARATOR_PATTERN = 0x3B3B3B3B3B3B3B3BL;
+    private static final long NEW_LINE_PATTERN = 0x0A0A0A0A0A0A0A0AL;
+    private static final int SEGMENT_SIZE = 4 * 1024 * 1024;
+
+    // credits for the idea with lookup tables instead of bit-shifting: abeobk
+    private static final long[] HASH_MASKS = new long[]{
+            0x0000000000000000L, // semicolon is the first char
+            0x00000000000000ffL,
+            0x000000000000ffffL,
+            0x0000000000ffffffL,
+            0x00000000ffffffffL,
+            0x000000ffffffffffL,
+            0x0000ffffffffffffL,
+            0x00ffffffffffffffL, // semicolon is the last char
+            0xffffffffffffffffL // there is no semicolon at all
+    };
+
+    private static final long[] ADVANCE_MASKS = new long[]{
+            0x0000000000000000L,
+            0x0000000000000000L,
+            0x0000000000000000L,
+            0x0000000000000000L,
+            0x0000000000000000L,
+            0x0000000000000000L,
+            0x0000000000000000L,
+            0x0000000000000000L,
+            0xffffffffffffffffL,
+    };
 
     private static Unsafe unsafe() {
         try {
@@ -81,56 +124,29 @@ private static void spawnWorker() throws IOException {
     static void calculate() throws Exception {
         final File file = new File(MEASUREMENTS_TXT);
         final long length = file.length();
-        // final int chunkCount = Runtime.getRuntime().availableProcessors();
-        int chunkPerThread = 3;
-        final int chunkCount = THREAD_COUNT * chunkPerThread;
-        final var chunkStartOffsets = new long[chunkCount + 1];
         try (var raf = new RandomAccessFile(file, "r")) {
-            // credit - chunking code: mtopolnik
-            final var inputBase = raf.getChannel().map(MapMode.READ_ONLY, 0, length, Arena.global()).address();
-            for (int i = 1; i < chunkStartOffsets.length - 1; i++) {
-                var start = length * i / (chunkStartOffsets.length - 1);
-                raf.seek(start);
-                while (raf.read() != (byte) '\n') {
-                }
-                start = raf.getFilePointer();
-                chunkStartOffsets[i] = start + inputBase;
-            }
-            chunkStartOffsets[0] = inputBase;
-            chunkStartOffsets[chunkCount] = inputBase + length;
+            long fileStart = raf.getChannel().map(MapMode.READ_ONLY, 0, length, Arena.global()).address();
+            long fileEnd = fileStart + length;
+            var globalCursor = new AtomicLong(fileStart);
 
-            Processor[] processors = new Processor[THREAD_COUNT];
-            Thread[] threads = new Thread[THREAD_COUNT];
+            Processor[] processors = new Processor[EXTRA_THREAD_COUNT];
+            Thread[] threads = new Thread[EXTRA_THREAD_COUNT];
 
-            for (int i = 0; i < THREAD_COUNT - 1; i++) {
-                long startA = chunkStartOffsets[i * chunkPerThread];
-                long endA = chunkStartOffsets[i * chunkPerThread + 1];
-                long startB = chunkStartOffsets[i * chunkPerThread + 1];
-                long endB = chunkStartOffsets[i * chunkPerThread + 2];
-                long startC = chunkStartOffsets[i * chunkPerThread + 2];
-                long endC = chunkStartOffsets[i * chunkPerThread + 3];
-
-                Processor processor = new Processor(startA, endA, startB, endB, startC, endC);
-                processors[i] = processor;
+            for (int i = 0; i < EXTRA_THREAD_COUNT; i++) {
+                Processor processor = new Processor(fileStart, fileEnd, globalCursor);
                 Thread thread = new Thread(processor);
+                processors[i] = processor;
                 threads[i] = thread;
                 thread.start();
             }
 
-            int ownIndex = THREAD_COUNT - 1;
-            long startA = chunkStartOffsets[ownIndex * chunkPerThread];
-            long endA = chunkStartOffsets[ownIndex * chunkPerThread + 1];
-            long startB = chunkStartOffsets[ownIndex * chunkPerThread + 1];
-            long endB = chunkStartOffsets[ownIndex * chunkPerThread + 2];
-            long startC = chunkStartOffsets[ownIndex * chunkPerThread + 2];
-            long endC = chunkStartOffsets[ownIndex * chunkPerThread + 3];
-            Processor processor = new Processor(startA, endA, startB, endB, startC, endC);
+            Processor processor = new Processor(fileStart, fileEnd, globalCursor);
             processor.run();
 
-            var accumulator = new TreeMap();
+            var accumulator = new TreeMap();
             processor.accumulateStatus(accumulator);
 
-            for (int i = 0; i < THREAD_COUNT - 1; i++) {
+            for (int i = 0; i < EXTRA_THREAD_COUNT; i++) {
                 Thread t = threads[i];
                 t.join();
                 processors[i].accumulateStatus(accumulator);
@@ -140,10 +156,10 @@ static void calculate() throws Exception {
         }
     }
 
-    private static void printResults(TreeMap accumulator) {
+    private static void printResults(TreeMap accumulator) {
         var sb = new StringBuilder(10000);
         boolean first = true;
-        for (Map.Entry statsEntry : accumulator.entrySet()) {
+        for (Map.Entry statsEntry : accumulator.entrySet()) {
            if (first) {
                 sb.append("{");
                 first = false;
@@ -210,20 +226,17 @@ private static class Processor implements Runnable {
         private static final int FAST_MAP_SIZE_BYTES = MAPS_SLOT_COUNT * FAST_MAP_ENTRY_SIZE_BYTES;
         private static final int SLOW_MAP_MAP_NAMES_BYTES = MAX_UNIQUE_KEYS * STATION_MAX_NAME_BYTES;
         private static final int MAP_MASK = MAPS_SLOT_COUNT - 1;
+        private final AtomicLong globalCursor;
 
         private long slowMap;
         private long slowMapNamesPtr;
-        private long slowMapNamesLo;
-        // private long fastMap;
         private long cursorA;
         private long endA;
         private long cursorB;
         private long endB;
-        private long cursorC;
-        private long endC;
-        private HashMap stats = new HashMap<>(1000);
-
-        // private long maxClusterLen;
+        private HashMap stats = new HashMap<>(1000);
+        private final long fileEnd;
+        private final long fileStart;
 
         // credit: merykitty
         private long parseAndStoreTemperature(long startCursor, long baseEntryPtr, long word) {
@@ -264,20 +277,12 @@ private static long getDelimiterMask(final long word) {
             return (match - 0x0101010101010101L) & (~match & 0x8080808080808080L);
         }
 
-        // todo: immutability cost us in allocations, but that's probably peanuts in the grand scheme of things. still worth checking
-        // maybe JVM trusting Final in Records offsets it ..a test is needed
-        record StationStats(int min, int max, int count, long sum) {
-            StationStats mergeWith(StationStats other) {
-                return new StationStats(Math.min(min, other.min), Math.max(max, other.max), count + other.count, sum + other.sum);
-            }
-        }
-
-        void accumulateStatus(TreeMap accumulator) {
-            for (Map.Entry entry : stats.entrySet()) {
+        void accumulateStatus(TreeMap accumulator) {
+            for (Map.Entry entry : stats.entrySet()) {
                 String name = entry.getKey();
-                StationStats localStats = entry.getValue();
+                CalculateAverage_jerrinot.StationStats localStats = entry.getValue();
 
-                StationStats globalStats = accumulator.get(name);
+                CalculateAverage_jerrinot.StationStats globalStats = accumulator.get(name);
                 if (globalStats == null) {
                     accumulator.put(name, localStats);
                 }
@@ -287,24 +292,10 @@ void accumulateStatus(TreeMap accumulator) {
             }
         }
 
-        Processor(long startA, long endA, long startB, long endB, long startC, long endC) {
-            this.cursorA = startA;
-            this.cursorB = startB;
-            this.cursorC = startC;
-            this.endA = endA;
-            this.endB = endB;
-            this.endC = endC;
-        }
-
-        private void doTail(long fastMAp) {
-            doOne(cursorA, endA);
-            doOne(cursorB, endB);
-            doOne(cursorC, endC);
-
-            transferToHeap(fastMAp);
-            // UNSAFE.freeMemory(fastMap);
-            // UNSAFE.freeMemory(slowMap);
-            // UNSAFE.freeMemory(slowMapNamesLo);
+        Processor(long fileStart, long fileEnd, AtomicLong globalCursor) {
+            this.globalCursor = globalCursor;
+            this.fileEnd = fileEnd;
+            this.fileStart = fileStart;
         }
 
         private void transferToHeap(long fastMap) {
@@ -324,7 +315,7 @@ private void transferToHeap(long fastMap) {
                 int count = UNSAFE.getInt(baseAddress + MAP_COUNT_OFFSET);
                 long sum = UNSAFE.getLong(baseAddress + MAP_SUM_OFFSET);
 
-                stats.put(name, new StationStats(min, max, count, sum));
+                stats.put(name, new CalculateAverage_jerrinot.StationStats(min, max, count, sum));
             }
 
             for (long baseAddress = fastMap; baseAddress < fastMap + FAST_MAP_SIZE_BYTES; baseAddress += FAST_MAP_ENTRY_SIZE_BYTES) {
@@ -345,16 +336,21 @@ private void transferToHeap(long fastMap) {
                 var v = stats.get(name);
                 if (v == null) {
-                    stats.put(name, new StationStats(min, max, count, sum));
+                    stats.put(name, new CalculateAverage_jerrinot.StationStats(min, max, count, sum));
                 }
                 else {
-                    stats.put(name, new StationStats(Math.min(v.min, min), Math.max(v.max, max), v.count + count, v.sum + sum));
+                    stats.put(name, new CalculateAverage_jerrinot.StationStats(Math.min(v.min, min), Math.max(v.max, max), v.count + count, v.sum + sum));
                 }
             }
         }
 
-        private void doOne(long cursor, long endA) {
-            while (cursor < endA) {
+        private void doOne(long cursor, long end) {
+            while (cursor < end) {
+                // it seems that when pulling just from a single chunk
+                // then bit-twiddling is faster than lookup tables
+                // hypothesis: when processing multiple things at once then LOAD latency is partially hidden
+                // but when processing just one thing then it's better to keep things local as much as possible? maybe:)
+                long start = cursor;
                 long currentWord = UNSAFE.getLong(cursor);
                 long mask = getDelimiterMask(currentWord);
@@ -392,135 +388,139 @@ private static int hash(long word) {
             return (int) hash;
         }
 
+        private static long nextNewLine(long prev) {
+            // again: credits to @thomaswue for this code, literally copy'n'paste
+            while (true) {
+                long currentWord = UNSAFE.getLong(prev);
+                long input = currentWord ^ NEW_LINE_PATTERN;
+                long pos = (input - 0x0101010101010101L) & ~input & 0x8080808080808080L;
+                if (pos != 0) {
+                    prev += Long.numberOfTrailingZeros(pos) >>> 3;
+                    break;
+                }
+                else {
+                    prev += 8;
+                }
+            }
+            return prev;
+        }
+
         @Override
         public void run() {
+            long fastMap = allocateMem();
+            for (;;) {
+                long startingPtr = globalCursor.addAndGet(SEGMENT_SIZE) - SEGMENT_SIZE;
+                if (startingPtr >= fileEnd) {
+                    break;
+                }
+                setCursors(startingPtr);
+                mainLoop(fastMap);
+                doOne(cursorA, endA);
+                doOne(cursorB, endB);
+            }
+            transferToHeap(fastMap);
+        }
+
+        private long allocateMem() {
             this.slowMap = UNSAFE.allocateMemory(SLOW_MAP_SIZE_BYTES);
             this.slowMapNamesPtr = UNSAFE.allocateMemory(SLOW_MAP_MAP_NAMES_BYTES);
-            this.slowMapNamesLo = slowMapNamesPtr;
             long fastMap = UNSAFE.allocateMemory(FAST_MAP_SIZE_BYTES);
             UNSAFE.setMemory(slowMap, SLOW_MAP_SIZE_BYTES, (byte) 0);
             UNSAFE.setMemory(fastMap, FAST_MAP_SIZE_BYTES, (byte) 0);
             UNSAFE.setMemory(slowMapNamesPtr, SLOW_MAP_MAP_NAMES_BYTES, (byte) 0);
+            return fastMap;
+        }
 
-            while (cursorA < endA && cursorB < endB && cursorC < endC) {
+        private void mainLoop(long fastMap) {
+            while (cursorA < endA && cursorB < endB) {
                 long currentWordA = UNSAFE.getLong(cursorA);
                 long currentWordB = UNSAFE.getLong(cursorB);
-                long currentWordC = UNSAFE.getLong(cursorC);
 
-                long startA = cursorA;
-                long startB = cursorB;
-                long startC = cursorC;
+                long delimiterMaskA = getDelimiterMask(currentWordA);
+                long delimiterMaskB = getDelimiterMask(currentWordB);
 
-                long maskA = getDelimiterMask(currentWordA);
-                long maskB = getDelimiterMask(currentWordB);
-                long maskC = getDelimiterMask(currentWordC);
+                long candidateWordA = UNSAFE.getLong(cursorA + 8);
+                long candidateWordB = UNSAFE.getLong(cursorB + 8);
 
-                long maskComplementA = -maskA;
-                long maskComplementB = -maskB;
-                long maskComplementC = -maskC;
+                long startA = cursorA;
+                long startB = cursorB;
 
-                long maskWithDelimiterA = (maskA ^ (maskA - 1));
-                long maskWithDelimiterB = (maskB ^ (maskB - 1));
-                long maskWithDelimiterC = (maskC ^ (maskC - 1));
+                int trailingZerosA = Long.numberOfTrailingZeros(delimiterMaskA) >> 3;
+                int trailingZerosB = Long.numberOfTrailingZeros(delimiterMaskB) >> 3;
 
-                long isMaskZeroA = (((maskA | maskComplementA) >>> 63) ^ 1);
-                long isMaskZeroB = (((maskB | maskComplementB) >>> 63) ^ 1);
-                long isMaskZeroC = (((maskC | maskComplementC) >>> 63) ^ 1);
+                long advanceMaskA = ADVANCE_MASKS[trailingZerosA];
+                long advanceMaskB = ADVANCE_MASKS[trailingZerosB];
 
-                cursorA += isMaskZeroA << 3;
-                cursorB += isMaskZeroB << 3;
-                cursorC += isMaskZeroC << 3;
+                long wordMaskA = HASH_MASKS[trailingZerosA];
+                long wordMaskB = HASH_MASKS[trailingZerosB];
 
-                long nextWordA = UNSAFE.getLong(cursorA);
-                long nextWordB = UNSAFE.getLong(cursorB);
-                long nextWordC = UNSAFE.getLong(cursorC);
+                long negAdvanceMaskA = ~advanceMaskA;
+                long negAdvanceMaskB = ~advanceMaskB;
 
-                long firstWordMaskA = maskWithDelimiterA >>> 8;
-                long firstWordMaskB = maskWithDelimiterB >>> 8;
-                long firstWordMaskC = maskWithDelimiterC >>> 8;
+                cursorA += advanceMaskA & 8;
+                cursorB += advanceMaskB & 8;
 
-                long nextMaskA = getDelimiterMask(nextWordA);
-                long nextMaskB = getDelimiterMask(nextWordB);
-                long nextMaskC = getDelimiterMask(nextWordC);
+                long nextWordA = (advanceMaskA & candidateWordA) | (negAdvanceMaskA & currentWordA);
+                long nextWordB = (advanceMaskB & candidateWordB) | (negAdvanceMaskB & currentWordB);
 
-                boolean slowA = nextMaskA == 0;
-                boolean slowB = nextMaskB == 0;
-                boolean slowC = nextMaskC == 0;
-                boolean slowSome = (slowA || slowB || slowC);
+                long nextDelimiterMaskA = getDelimiterMask(nextWordA);
+                long nextDelimiterMaskB = getDelimiterMask(nextWordB);
 
-                long extA = -isMaskZeroA;
-                long extB = -isMaskZeroB;
-                long extC = -isMaskZeroC;
+                boolean slowA = nextDelimiterMaskA == 0;
+                boolean slowB = nextDelimiterMaskB == 0;
+                boolean slowSome = (slowA || slowB);
 
-                long maskedFirstWordA = (extA | firstWordMaskA) & currentWordA;
-                long maskedFirstWordB = (extB | firstWordMaskB) & currentWordB;
-                long maskedFirstWordC = (extC | firstWordMaskC) & currentWordC;
+                long maskedFirstWordA = wordMaskA & currentWordA;
+                long maskedFirstWordB = wordMaskB & currentWordB;
 
                 int hashA = hash(maskedFirstWordA);
                 int hashB = hash(maskedFirstWordB);
-                int hashC = hash(maskedFirstWordC);
 
                 currentWordA = nextWordA;
                 currentWordB = nextWordB;
-                currentWordC = nextWordC;
-                maskA = nextMaskA;
-                maskB = nextMaskB;
-                maskC = nextMaskC;
+                delimiterMaskA = nextDelimiterMaskA;
+                delimiterMaskB = nextDelimiterMaskB;
 
                 if (slowSome) {
-                    while (maskA == 0) {
+                    while (delimiterMaskA == 0) {
                         cursorA += 8;
                         currentWordA = UNSAFE.getLong(cursorA);
-                        maskA = getDelimiterMask(currentWordA);
+                        delimiterMaskA = getDelimiterMask(currentWordA);
                     }
-                    while (maskB == 0) {
+                    while (delimiterMaskB == 0) {
                         cursorB += 8;
                         currentWordB = UNSAFE.getLong(cursorB);
-                        maskB = getDelimiterMask(currentWordB);
-                    }
-                    while (maskC == 0) {
-                        cursorC += 8;
-                        currentWordC = UNSAFE.getLong(cursorC);
-                        maskC = getDelimiterMask(currentWordC);
+                        delimiterMaskB = getDelimiterMask(currentWordB);
                     }
                 }
 
-                final int delimiterByteA = Long.numberOfTrailingZeros(maskA);
-                final int delimiterByteB = Long.numberOfTrailingZeros(maskB);
-                final int delimiterByteC = Long.numberOfTrailingZeros(maskC);
+                trailingZerosA = Long.numberOfTrailingZeros(delimiterMaskA) >> 3;
+                trailingZerosB = Long.numberOfTrailingZeros(delimiterMaskB) >> 3;
 
-                final long semicolonA = cursorA + (delimiterByteA >> 3);
-                final long semicolonB = cursorB + (delimiterByteB >> 3);
-                final long semicolonC = cursorC + (delimiterByteC >> 3);
+                final long semicolonA = cursorA + trailingZerosA;
+                final long semicolonB = cursorB + trailingZerosB;
 
                 long digitStartA = semicolonA + 1;
                 long digitStartB = semicolonB + 1;
-                long digitStartC = semicolonC + 1;
+
+                long lastWordMaskA = HASH_MASKS[trailingZerosA];
+                long lastWordMaskB = HASH_MASKS[trailingZerosB];
 
                 long temperatureWordA = UNSAFE.getLong(digitStartA);
                 long temperatureWordB = UNSAFE.getLong(digitStartB);
-                long temperatureWordC = UNSAFE.getLong(digitStartC);
-
-                long lastWordMaskA = ((maskA - 1) ^ maskA) >>> 8;
-                long lastWordMaskB = ((maskB - 1) ^ maskB) >>> 8;
-                long lastWordMaskC = ((maskC - 1) ^ maskC) >>> 8;
 
                 final long maskedLastWordA = currentWordA & lastWordMaskA;
                 final long maskedLastWordB = currentWordB & lastWordMaskB;
-                final long maskedLastWordC = currentWordC & lastWordMaskC;
 
                 int lenA = (int) (semicolonA - startA);
                 int lenB = (int) (semicolonB - startB);
-                int lenC = (int) (semicolonC - startC);
 
                 int mapIndexA = hashA & MAP_MASK;
                 int mapIndexB = hashB & MAP_MASK;
-                int mapIndexC = hashC & MAP_MASK;
 
                 long baseEntryPtrA;
                 long baseEntryPtrB;
-                long baseEntryPtrC;
 
                 if (slowSome) {
                     if (slowA) {
@@ -537,25 +537,37 @@ public void run() {
                         baseEntryPtrB = getOrCreateEntryBaseOffsetFast(mapIndexB, lenB, maskedLastWordB, maskedFirstWordB, fastMap);
                     }
-                    if (slowC) {
-                        baseEntryPtrC = getOrCreateEntryBaseOffsetSlow(lenC, startC, hashC, maskedLastWordC);
-                    }
-                    else {
-                        baseEntryPtrC = getOrCreateEntryBaseOffsetFast(mapIndexC, lenC, maskedLastWordC, maskedFirstWordC, fastMap);
-                    }
                 }
                 else {
                     baseEntryPtrA = getOrCreateEntryBaseOffsetFast(mapIndexA, lenA, maskedLastWordA, maskedFirstWordA, fastMap);
                     baseEntryPtrB = getOrCreateEntryBaseOffsetFast(mapIndexB, lenB, maskedLastWordB, maskedFirstWordB, fastMap);
-                    baseEntryPtrC = getOrCreateEntryBaseOffsetFast(mapIndexC, lenC, maskedLastWordC, maskedFirstWordC, fastMap);
                 }
 
                 cursorA = parseAndStoreTemperature(digitStartA, baseEntryPtrA, temperatureWordA);
                 cursorB = parseAndStoreTemperature(digitStartB, baseEntryPtrB, temperatureWordB);
-                cursorC = parseAndStoreTemperature(digitStartC, baseEntryPtrC, temperatureWordC);
             }
-            doTail(fastMap);
-            // System.out.println("Longest chain: " + longestChain);
+        }
+
+        private void setCursors(long current) {
+            // Credit for the whole work-stealing scheme: @thomaswue
+            // I have totally stolen it from him. I changed the order a bit to suite my taste better,
+            // but it's his code
+            long segmentStart;
+            if (current == fileStart) {
+                segmentStart = current;
+            }
+            else {
+                segmentStart = nextNewLine(current) + 1;
+            }
+            long segmentEnd = nextNewLine(Math.min(fileEnd - 1, current + SEGMENT_SIZE));
+
+            long size = (segmentEnd - segmentStart) / 2;
+            long mid = nextNewLine(segmentStart + size);
+
+            cursorA = segmentStart;
+            endA = mid;
+            cursorB = mid + 1;
+            endB = segmentEnd;
         }
 
         private static long getOrCreateEntryBaseOffsetFast(int mapIndexA, int lenA, long maskedLastWord, long maskedFirstWord, long fastMap) {
@@ -625,4 +637,9 @@ private static boolean nameMatchSlow(long start, long namePtr, long fullLen, lon
         }
     }
 
+    record StationStats(int min, int max, int count, long sum) {
+        StationStats mergeWith(StationStats other) {
+            return new StationStats(Math.min(min, other.min), Math.max(max, other.max), count + other.count, sum + other.sum);
+        }
+    }
 }
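
The core of the patch is still the SWAR scan for the ';' separator, now combined with abeobk's lookup-table idea: the byte offset of the first semicolon (8 when the word has none) indexes HASH_MASKS to zero out everything from the separator onwards before hashing. The standalone sketch below reproduces that interplay without Unsafe; the helper names (delimiterMask, wordFrom) and the sample input are mine, not part of the patch.

```java
// Minimal sketch of the SWAR semicolon scan + HASH_MASKS lookup used above.
public class SwarSketch {
    private static final long SEPARATOR_PATTERN = 0x3B3B3B3B3B3B3B3BL; // ';' in every byte

    // A byte of 'match' becomes 0x00 where the input byte equals ';'; the classic
    // (x - 0x01..01) & ~x & 0x80..80 trick then sets bit 7 of that byte.
    static long delimiterMask(long word) {
        long match = word ^ SEPARATOR_PATTERN;
        return (match - 0x0101010101010101L) & (~match & 0x8080808080808080L);
    }

    private static final long[] HASH_MASKS = {
            0x0000000000000000L, // ';' is byte 0 -> keep nothing
            0x00000000000000ffL,
            0x000000000000ffffL,
            0x0000000000ffffffL,
            0x00000000ffffffffL,
            0x000000ffffffffffL,
            0x0000ffffffffffffL,
            0x00ffffffffffffffL, // ';' is byte 7 -> keep 7 name bytes
            0xffffffffffffffffL  // no ';' in this word -> keep all 8 bytes
    };

    // Little-endian load of up to 8 bytes, so string byte i lands in byte i of the long.
    static long wordFrom(byte[] bytes, int offset) {
        long word = 0;
        for (int i = 0; i < 8 && offset + i < bytes.length; i++) {
            word |= (bytes[offset + i] & 0xFFL) << (8 * i);
        }
        return word;
    }

    public static void main(String[] args) {
        long word = wordFrom("Hamburg;12.3".getBytes(java.nio.charset.StandardCharsets.UTF_8), 0);
        long mask = delimiterMask(word);
        int semicolonByte = Long.numberOfTrailingZeros(mask) >> 3; // 8 when mask == 0 (tz == 64)
        long nameBytes = word & HASH_MASKS[semicolonByte];
        System.out.println("semicolon at byte " + semicolonByte);       // 7 for "Hamburg;"
        System.out.printf("masked name word: 0x%016x%n", nameBytes);    // only "Hamburg" survives
    }
}
```

The 9th table slot is exactly why Long.numberOfTrailingZeros is divided by 8 without any clamping: a zero mask yields 64 trailing zeros, i.e. index 8, the "no semicolon yet" case.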
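mainLoop() also drops the old isMaskZero/ext bit-twiddling in favour of ADVANCE_MASKS: the trailing-zero byte count of the delimiter mask selects either an all-ones or an all-zeros mask, which both bumps the cursor by 8 and picks between the speculatively loaded second word and the first one, all without a branch. The toy below isolates just that step; the inputs and the step() helper are invented for illustration.

```java
// Sketch of the branch-free "advance by 8 only if the first word had no ';'" step.
public class AdvanceMaskSketch {
    private static final long[] ADVANCE_MASKS = new long[9];
    static {
        ADVANCE_MASKS[8] = 0xffffffffffffffffL; // only the "no ';' in this word" slot is all-ones
    }

    static long step(long currentWord, long candidateWord, long delimiterMask, long cursor) {
        int trailingZeros = Long.numberOfTrailingZeros(delimiterMask) >> 3; // 0..7, or 8 when mask == 0
        long advanceMask = ADVANCE_MASKS[trailingZeros];
        long nextCursor = cursor + (advanceMask & 8);                                  // +8 or +0, no branch
        long nextWord = (advanceMask & candidateWord) | (~advanceMask & currentWord);  // select, no branch
        System.out.println("advanced by " + (nextCursor - cursor) + ", next word is "
                + (nextWord == candidateWord ? "the second word" : "the first word"));
        return nextCursor;
    }

    public static void main(String[] args) {
        // delimiterMask == 0 -> no ';' in the first 8 bytes -> advance and use the second word
        step(0x1111111111111111L, 0x2222222222222222L, 0L, 100);
        // ';' detected in byte 3 (bit 7 of that byte set) -> stay on the first word
        step(0x1111111111111111L, 0x2222222222222222L, 0x0000000080000000L, 100);
    }
}
```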
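run()/setCursors() replace the old fixed per-thread chunking with work-stealing over 4 MiB segments (credited to @thomaswue in the comments): every worker claims the next segment with a single AtomicLong.addAndGet and then snaps its start and end to newline boundaries, so no record is ever split between owners. A self-contained sketch of that claiming scheme, using a byte[] and a scalar newline scan instead of the mapped file and the SWAR nextNewLine(), might look like this; the segment size, input and class names are made up for the example.

```java
import java.util.concurrent.atomic.AtomicLong;

public class WorkStealingSketch {
    static final int SEGMENT_SIZE = 16; // tiny on purpose; the real patch uses 4 MiB

    public static void main(String[] args) throws InterruptedException {
        byte[] input = "a;1.0\nbb;2.0\nccc;3.0\ndddd;4.0\neeeee;5.0\n".getBytes();
        AtomicLong globalCursor = new AtomicLong(0);

        Runnable worker = () -> {
            for (;;) {
                // Claim the next segment: fetch-and-add hands out disjoint [start, start + SEGMENT_SIZE) ranges.
                long start = globalCursor.addAndGet(SEGMENT_SIZE) - SEGMENT_SIZE;
                if (start >= input.length) {
                    break;
                }
                // Skip the partial record at the front (the previous segment's owner finishes it)
                // and run past the nominal end until the newline that completes the last record.
                int from = (start == 0) ? 0 : nextNewLine(input, (int) start) + 1;
                int to = nextNewLine(input, (int) Math.min(input.length - 1, start + SEGMENT_SIZE));
                if (from <= to) {
                    System.out.println(Thread.currentThread().getName() + " -> \""
                            + new String(input, from, to - from + 1).replace("\n", "\\n") + "\"");
                }
            }
        };

        Thread extra = new Thread(worker, "extra-worker");
        extra.start();
        worker.run(); // the calling thread joins in as well, as the main thread does in calculate()
        extra.join();
    }

    // Scalar stand-in for the SWAR nextNewLine() in the patch: index of the next '\n' at or after pos.
    static int nextNewLine(byte[] input, int pos) {
        while (pos < input.length - 1 && input[pos] != '\n') {
            pos++;
        }
        return pos;
    }
}
```

Because adjacent claims compute the same newline position from opposite sides (one as its end, the next as its start), the claimed ranges cover the input exactly once, which is the property the patch relies on when it later splits each segment in half for the interleaved A/B cursors.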