diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_jerrinot.java b/src/main/java/dev/morling/onebrc/CalculateAverage_jerrinot.java
index df5defe71..6997f4896 100644
--- a/src/main/java/dev/morling/onebrc/CalculateAverage_jerrinot.java
+++ b/src/main/java/dev/morling/onebrc/CalculateAverage_jerrinot.java
@@ -24,6 +24,7 @@
import java.lang.reflect.Field;
import java.nio.channels.FileChannel.MapMode;
import java.util.*;
+import java.util.concurrent.atomic.AtomicLong;
/**
* I figured out it would be very hard to win the main competition of the One Billion Rows Challenge.
@@ -31,17 +32,59 @@
*
* Anyway, if you can make sense out of not exactly idiomatic Java code, and you enjoy pushing performance limits
* then QuestDB - the fastest open-source time-series database - is hiring: https://questdb.io/careers/core-database-engineer/
- *
+ *
+ * Credit
+ *
+ * I stand on shoulders of giants. I wouldn't be able to code this without analyzing and borrowing from solutions of others.
+ * People who helped me the most:
+ *
+ * - Thomas Wuerthinger (thomaswue): The munmap() trick and work-stealing. In both cases, I shameless copy-pasted their code.
+ * Including SWAR for detecting new lines. Thomas also gave me helpful hints on how to detect register spilling issues.
+ * - Quan Anh Mai (merykitty): I borrowed their phenomenal branch-free parser.
+ * - Marko Topolnik (mtopolnik): I use a hashing function I saw in his code. It seems the produce good quality hashes
+ * and it's next-level in speed. Marko joined the challenge before me and our discussions made me to join too!
+ * - Van Phu DO (abeobk): I saw the idea with simple lookup tables instead of complicated bit-twiddling in their code first.
+ * - Roy van Rijn (royvanrijn): I borrowed their SWAR code and initially their hash code impl
+ * - Francesco Nigro (franz1981): For our online discussions about performance. Both before and during this challenge.
+ * Francesco gave me the idea to check register spilling.
+ *
*/
public class CalculateAverage_jerrinot {
private static final Unsafe UNSAFE = unsafe();
private static final String MEASUREMENTS_TXT = "measurements.txt";
// todo: with hyper-threading enable we would be better of with availableProcessors / 2;
// todo: validate the testing env. params.
- private static final int THREAD_COUNT = Runtime.getRuntime().availableProcessors();
- // private static final int THREAD_COUNT = 4;
+ private static final int EXTRA_THREAD_COUNT = Runtime.getRuntime().availableProcessors() - 1;
+ // private static final int THREAD_COUNT = 1;
private static final long SEPARATOR_PATTERN = 0x3B3B3B3B3B3B3B3BL;
+ private static final long NEW_LINE_PATTERN = 0x0A0A0A0A0A0A0A0AL;
+ private static final int SEGMENT_SIZE = 4 * 1024 * 1024;
+
+ // credits for the idea with lookup tables instead of bit-shifting: abeobk
+ private static final long[] HASH_MASKS = new long[]{
+ 0x0000000000000000L, // semicolon is the first char
+ 0x00000000000000ffL,
+ 0x000000000000ffffL,
+ 0x0000000000ffffffL,
+ 0x00000000ffffffffL,
+ 0x000000ffffffffffL,
+ 0x0000ffffffffffffL,
+ 0x00ffffffffffffffL, // semicolon is the last char
+ 0xffffffffffffffffL // there is no semicolon at all
+ };
+
+ private static final long[] ADVANCE_MASKS = new long[]{
+ 0x0000000000000000L,
+ 0x0000000000000000L,
+ 0x0000000000000000L,
+ 0x0000000000000000L,
+ 0x0000000000000000L,
+ 0x0000000000000000L,
+ 0x0000000000000000L,
+ 0x0000000000000000L,
+ 0xffffffffffffffffL,
+ };
private static Unsafe unsafe() {
try {
@@ -81,56 +124,29 @@ private static void spawnWorker() throws IOException {
static void calculate() throws Exception {
final File file = new File(MEASUREMENTS_TXT);
final long length = file.length();
- // final int chunkCount = Runtime.getRuntime().availableProcessors();
- int chunkPerThread = 3;
- final int chunkCount = THREAD_COUNT * chunkPerThread;
- final var chunkStartOffsets = new long[chunkCount + 1];
try (var raf = new RandomAccessFile(file, "r")) {
- // credit - chunking code: mtopolnik
- final var inputBase = raf.getChannel().map(MapMode.READ_ONLY, 0, length, Arena.global()).address();
- for (int i = 1; i < chunkStartOffsets.length - 1; i++) {
- var start = length * i / (chunkStartOffsets.length - 1);
- raf.seek(start);
- while (raf.read() != (byte) '\n') {
- }
- start = raf.getFilePointer();
- chunkStartOffsets[i] = start + inputBase;
- }
- chunkStartOffsets[0] = inputBase;
- chunkStartOffsets[chunkCount] = inputBase + length;
+ long fileStart = raf.getChannel().map(MapMode.READ_ONLY, 0, length, Arena.global()).address();
+ long fileEnd = fileStart + length;
+ var globalCursor = new AtomicLong(fileStart);
- Processor[] processors = new Processor[THREAD_COUNT];
- Thread[] threads = new Thread[THREAD_COUNT];
+ Processor[] processors = new Processor[EXTRA_THREAD_COUNT];
+ Thread[] threads = new Thread[EXTRA_THREAD_COUNT];
- for (int i = 0; i < THREAD_COUNT - 1; i++) {
- long startA = chunkStartOffsets[i * chunkPerThread];
- long endA = chunkStartOffsets[i * chunkPerThread + 1];
- long startB = chunkStartOffsets[i * chunkPerThread + 1];
- long endB = chunkStartOffsets[i * chunkPerThread + 2];
- long startC = chunkStartOffsets[i * chunkPerThread + 2];
- long endC = chunkStartOffsets[i * chunkPerThread + 3];
-
- Processor processor = new Processor(startA, endA, startB, endB, startC, endC);
- processors[i] = processor;
+ for (int i = 0; i < EXTRA_THREAD_COUNT; i++) {
+ Processor processor = new Processor(fileStart, fileEnd, globalCursor);
Thread thread = new Thread(processor);
+ processors[i] = processor;
threads[i] = thread;
thread.start();
}
- int ownIndex = THREAD_COUNT - 1;
- long startA = chunkStartOffsets[ownIndex * chunkPerThread];
- long endA = chunkStartOffsets[ownIndex * chunkPerThread + 1];
- long startB = chunkStartOffsets[ownIndex * chunkPerThread + 1];
- long endB = chunkStartOffsets[ownIndex * chunkPerThread + 2];
- long startC = chunkStartOffsets[ownIndex * chunkPerThread + 2];
- long endC = chunkStartOffsets[ownIndex * chunkPerThread + 3];
- Processor processor = new Processor(startA, endA, startB, endB, startC, endC);
+ Processor processor = new Processor(fileStart, fileEnd, globalCursor);
processor.run();
- var accumulator = new TreeMap();
+ var accumulator = new TreeMap();
processor.accumulateStatus(accumulator);
- for (int i = 0; i < THREAD_COUNT - 1; i++) {
+ for (int i = 0; i < EXTRA_THREAD_COUNT; i++) {
Thread t = threads[i];
t.join();
processors[i].accumulateStatus(accumulator);
@@ -140,10 +156,10 @@ static void calculate() throws Exception {
}
}
- private static void printResults(TreeMap accumulator) {
+ private static void printResults(TreeMap accumulator) {
var sb = new StringBuilder(10000);
boolean first = true;
- for (Map.Entry statsEntry : accumulator.entrySet()) {
+ for (Map.Entry statsEntry : accumulator.entrySet()) {
if (first) {
sb.append("{");
first = false;
@@ -210,20 +226,17 @@ private static class Processor implements Runnable {
private static final int FAST_MAP_SIZE_BYTES = MAPS_SLOT_COUNT * FAST_MAP_ENTRY_SIZE_BYTES;
private static final int SLOW_MAP_MAP_NAMES_BYTES = MAX_UNIQUE_KEYS * STATION_MAX_NAME_BYTES;
private static final int MAP_MASK = MAPS_SLOT_COUNT - 1;
+ private final AtomicLong globalCursor;
private long slowMap;
private long slowMapNamesPtr;
- private long slowMapNamesLo;
- // private long fastMap;
private long cursorA;
private long endA;
private long cursorB;
private long endB;
- private long cursorC;
- private long endC;
- private HashMap stats = new HashMap<>(1000);
-
- // private long maxClusterLen;
+ private HashMap stats = new HashMap<>(1000);
+ private final long fileEnd;
+ private final long fileStart;
// credit: merykitty
private long parseAndStoreTemperature(long startCursor, long baseEntryPtr, long word) {
@@ -264,20 +277,12 @@ private static long getDelimiterMask(final long word) {
return (match - 0x0101010101010101L) & (~match & 0x8080808080808080L);
}
- // todo: immutability cost us in allocations, but that's probably peanuts in the grand scheme of things. still worth checking
- // maybe JVM trusting Final in Records offsets it ..a test is needed
- record StationStats(int min, int max, int count, long sum) {
- StationStats mergeWith(StationStats other) {
- return new StationStats(Math.min(min, other.min), Math.max(max, other.max), count + other.count, sum + other.sum);
- }
- }
-
- void accumulateStatus(TreeMap accumulator) {
- for (Map.Entry entry : stats.entrySet()) {
+ void accumulateStatus(TreeMap accumulator) {
+ for (Map.Entry entry : stats.entrySet()) {
String name = entry.getKey();
- StationStats localStats = entry.getValue();
+ CalculateAverage_jerrinot.StationStats localStats = entry.getValue();
- StationStats globalStats = accumulator.get(name);
+ CalculateAverage_jerrinot.StationStats globalStats = accumulator.get(name);
if (globalStats == null) {
accumulator.put(name, localStats);
}
@@ -287,24 +292,10 @@ void accumulateStatus(TreeMap accumulator) {
}
}
- Processor(long startA, long endA, long startB, long endB, long startC, long endC) {
- this.cursorA = startA;
- this.cursorB = startB;
- this.cursorC = startC;
- this.endA = endA;
- this.endB = endB;
- this.endC = endC;
- }
-
- private void doTail(long fastMAp) {
- doOne(cursorA, endA);
- doOne(cursorB, endB);
- doOne(cursorC, endC);
-
- transferToHeap(fastMAp);
- // UNSAFE.freeMemory(fastMap);
- // UNSAFE.freeMemory(slowMap);
- // UNSAFE.freeMemory(slowMapNamesLo);
+ Processor(long fileStart, long fileEnd, AtomicLong globalCursor) {
+ this.globalCursor = globalCursor;
+ this.fileEnd = fileEnd;
+ this.fileStart = fileStart;
}
private void transferToHeap(long fastMap) {
@@ -324,7 +315,7 @@ private void transferToHeap(long fastMap) {
int count = UNSAFE.getInt(baseAddress + MAP_COUNT_OFFSET);
long sum = UNSAFE.getLong(baseAddress + MAP_SUM_OFFSET);
- stats.put(name, new StationStats(min, max, count, sum));
+ stats.put(name, new CalculateAverage_jerrinot.StationStats(min, max, count, sum));
}
for (long baseAddress = fastMap; baseAddress < fastMap + FAST_MAP_SIZE_BYTES; baseAddress += FAST_MAP_ENTRY_SIZE_BYTES) {
@@ -345,16 +336,21 @@ private void transferToHeap(long fastMap) {
var v = stats.get(name);
if (v == null) {
- stats.put(name, new StationStats(min, max, count, sum));
+ stats.put(name, new CalculateAverage_jerrinot.StationStats(min, max, count, sum));
}
else {
- stats.put(name, new StationStats(Math.min(v.min, min), Math.max(v.max, max), v.count + count, v.sum + sum));
+ stats.put(name, new CalculateAverage_jerrinot.StationStats(Math.min(v.min, min), Math.max(v.max, max), v.count + count, v.sum + sum));
}
}
}
- private void doOne(long cursor, long endA) {
- while (cursor < endA) {
+ private void doOne(long cursor, long end) {
+ while (cursor < end) {
+ // it seems that when pulling just from a single chunk
+ // then bit-twiddling is faster than lookup tables
+ // hypothesis: when processing multiple things at once then LOAD latency is partially hidden
+ // but when processing just one thing then it's better to keep things local as much as possible? maybe:)
+
long start = cursor;
long currentWord = UNSAFE.getLong(cursor);
long mask = getDelimiterMask(currentWord);
@@ -392,135 +388,139 @@ private static int hash(long word) {
return (int) hash;
}
+ private static long nextNewLine(long prev) {
+ // again: credits to @thomaswue for this code, literally copy'n'paste
+ while (true) {
+ long currentWord = UNSAFE.getLong(prev);
+ long input = currentWord ^ NEW_LINE_PATTERN;
+ long pos = (input - 0x0101010101010101L) & ~input & 0x8080808080808080L;
+ if (pos != 0) {
+ prev += Long.numberOfTrailingZeros(pos) >>> 3;
+ break;
+ }
+ else {
+ prev += 8;
+ }
+ }
+ return prev;
+ }
+
@Override
public void run() {
+ long fastMap = allocateMem();
+ for (;;) {
+ long startingPtr = globalCursor.addAndGet(SEGMENT_SIZE) - SEGMENT_SIZE;
+ if (startingPtr >= fileEnd) {
+ break;
+ }
+ setCursors(startingPtr);
+ mainLoop(fastMap);
+ doOne(cursorA, endA);
+ doOne(cursorB, endB);
+ }
+ transferToHeap(fastMap);
+ }
+
+ private long allocateMem() {
this.slowMap = UNSAFE.allocateMemory(SLOW_MAP_SIZE_BYTES);
this.slowMapNamesPtr = UNSAFE.allocateMemory(SLOW_MAP_MAP_NAMES_BYTES);
- this.slowMapNamesLo = slowMapNamesPtr;
long fastMap = UNSAFE.allocateMemory(FAST_MAP_SIZE_BYTES);
UNSAFE.setMemory(slowMap, SLOW_MAP_SIZE_BYTES, (byte) 0);
UNSAFE.setMemory(fastMap, FAST_MAP_SIZE_BYTES, (byte) 0);
UNSAFE.setMemory(slowMapNamesPtr, SLOW_MAP_MAP_NAMES_BYTES, (byte) 0);
+ return fastMap;
+ }
- while (cursorA < endA && cursorB < endB && cursorC < endC) {
+ private void mainLoop(long fastMap) {
+ while (cursorA < endA && cursorB < endB) {
long currentWordA = UNSAFE.getLong(cursorA);
long currentWordB = UNSAFE.getLong(cursorB);
- long currentWordC = UNSAFE.getLong(cursorC);
- long startA = cursorA;
- long startB = cursorB;
- long startC = cursorC;
+ long delimiterMaskA = getDelimiterMask(currentWordA);
+ long delimiterMaskB = getDelimiterMask(currentWordB);
- long maskA = getDelimiterMask(currentWordA);
- long maskB = getDelimiterMask(currentWordB);
- long maskC = getDelimiterMask(currentWordC);
+ long candidateWordA = UNSAFE.getLong(cursorA + 8);
+ long candidateWordB = UNSAFE.getLong(cursorB + 8);
- long maskComplementA = -maskA;
- long maskComplementB = -maskB;
- long maskComplementC = -maskC;
+ long startA = cursorA;
+ long startB = cursorB;
- long maskWithDelimiterA = (maskA ^ (maskA - 1));
- long maskWithDelimiterB = (maskB ^ (maskB - 1));
- long maskWithDelimiterC = (maskC ^ (maskC - 1));
+ int trailingZerosA = Long.numberOfTrailingZeros(delimiterMaskA) >> 3;
+ int trailingZerosB = Long.numberOfTrailingZeros(delimiterMaskB) >> 3;
- long isMaskZeroA = (((maskA | maskComplementA) >>> 63) ^ 1);
- long isMaskZeroB = (((maskB | maskComplementB) >>> 63) ^ 1);
- long isMaskZeroC = (((maskC | maskComplementC) >>> 63) ^ 1);
+ long advanceMaskA = ADVANCE_MASKS[trailingZerosA];
+ long advanceMaskB = ADVANCE_MASKS[trailingZerosB];
- cursorA += isMaskZeroA << 3;
- cursorB += isMaskZeroB << 3;
- cursorC += isMaskZeroC << 3;
+ long wordMaskA = HASH_MASKS[trailingZerosA];
+ long wordMaskB = HASH_MASKS[trailingZerosB];
- long nextWordA = UNSAFE.getLong(cursorA);
- long nextWordB = UNSAFE.getLong(cursorB);
- long nextWordC = UNSAFE.getLong(cursorC);
+ long negAdvanceMaskA = ~advanceMaskA;
+ long negAdvanceMaskB = ~advanceMaskB;
- long firstWordMaskA = maskWithDelimiterA >>> 8;
- long firstWordMaskB = maskWithDelimiterB >>> 8;
- long firstWordMaskC = maskWithDelimiterC >>> 8;
+ cursorA += advanceMaskA & 8;
+ cursorB += advanceMaskB & 8;
- long nextMaskA = getDelimiterMask(nextWordA);
- long nextMaskB = getDelimiterMask(nextWordB);
- long nextMaskC = getDelimiterMask(nextWordC);
+ long nextWordA = (advanceMaskA & candidateWordA) | (negAdvanceMaskA & currentWordA);
+ long nextWordB = (advanceMaskB & candidateWordB) | (negAdvanceMaskB & currentWordB);
- boolean slowA = nextMaskA == 0;
- boolean slowB = nextMaskB == 0;
- boolean slowC = nextMaskC == 0;
- boolean slowSome = (slowA || slowB || slowC);
+ long nextDelimiterMaskA = getDelimiterMask(nextWordA);
+ long nextDelimiterMaskB = getDelimiterMask(nextWordB);
- long extA = -isMaskZeroA;
- long extB = -isMaskZeroB;
- long extC = -isMaskZeroC;
+ boolean slowA = nextDelimiterMaskA == 0;
+ boolean slowB = nextDelimiterMaskB == 0;
+ boolean slowSome = (slowA || slowB);
- long maskedFirstWordA = (extA | firstWordMaskA) & currentWordA;
- long maskedFirstWordB = (extB | firstWordMaskB) & currentWordB;
- long maskedFirstWordC = (extC | firstWordMaskC) & currentWordC;
+ long maskedFirstWordA = wordMaskA & currentWordA;
+ long maskedFirstWordB = wordMaskB & currentWordB;
int hashA = hash(maskedFirstWordA);
int hashB = hash(maskedFirstWordB);
- int hashC = hash(maskedFirstWordC);
currentWordA = nextWordA;
currentWordB = nextWordB;
- currentWordC = nextWordC;
- maskA = nextMaskA;
- maskB = nextMaskB;
- maskC = nextMaskC;
+ delimiterMaskA = nextDelimiterMaskA;
+ delimiterMaskB = nextDelimiterMaskB;
if (slowSome) {
- while (maskA == 0) {
+ while (delimiterMaskA == 0) {
cursorA += 8;
currentWordA = UNSAFE.getLong(cursorA);
- maskA = getDelimiterMask(currentWordA);
+ delimiterMaskA = getDelimiterMask(currentWordA);
}
- while (maskB == 0) {
+ while (delimiterMaskB == 0) {
cursorB += 8;
currentWordB = UNSAFE.getLong(cursorB);
- maskB = getDelimiterMask(currentWordB);
- }
- while (maskC == 0) {
- cursorC += 8;
- currentWordC = UNSAFE.getLong(cursorC);
- maskC = getDelimiterMask(currentWordC);
+ delimiterMaskB = getDelimiterMask(currentWordB);
}
}
- final int delimiterByteA = Long.numberOfTrailingZeros(maskA);
- final int delimiterByteB = Long.numberOfTrailingZeros(maskB);
- final int delimiterByteC = Long.numberOfTrailingZeros(maskC);
+ trailingZerosA = Long.numberOfTrailingZeros(delimiterMaskA) >> 3;
+ trailingZerosB = Long.numberOfTrailingZeros(delimiterMaskB) >> 3;
- final long semicolonA = cursorA + (delimiterByteA >> 3);
- final long semicolonB = cursorB + (delimiterByteB >> 3);
- final long semicolonC = cursorC + (delimiterByteC >> 3);
+ final long semicolonA = cursorA + trailingZerosA;
+ final long semicolonB = cursorB + trailingZerosB;
long digitStartA = semicolonA + 1;
long digitStartB = semicolonB + 1;
- long digitStartC = semicolonC + 1;
+
+ long lastWordMaskA = HASH_MASKS[trailingZerosA];
+ long lastWordMaskB = HASH_MASKS[trailingZerosB];
long temperatureWordA = UNSAFE.getLong(digitStartA);
long temperatureWordB = UNSAFE.getLong(digitStartB);
- long temperatureWordC = UNSAFE.getLong(digitStartC);
-
- long lastWordMaskA = ((maskA - 1) ^ maskA) >>> 8;
- long lastWordMaskB = ((maskB - 1) ^ maskB) >>> 8;
- long lastWordMaskC = ((maskC - 1) ^ maskC) >>> 8;
final long maskedLastWordA = currentWordA & lastWordMaskA;
final long maskedLastWordB = currentWordB & lastWordMaskB;
- final long maskedLastWordC = currentWordC & lastWordMaskC;
int lenA = (int) (semicolonA - startA);
int lenB = (int) (semicolonB - startB);
- int lenC = (int) (semicolonC - startC);
int mapIndexA = hashA & MAP_MASK;
int mapIndexB = hashB & MAP_MASK;
- int mapIndexC = hashC & MAP_MASK;
long baseEntryPtrA;
long baseEntryPtrB;
- long baseEntryPtrC;
if (slowSome) {
if (slowA) {
@@ -537,25 +537,37 @@ public void run() {
baseEntryPtrB = getOrCreateEntryBaseOffsetFast(mapIndexB, lenB, maskedLastWordB, maskedFirstWordB, fastMap);
}
- if (slowC) {
- baseEntryPtrC = getOrCreateEntryBaseOffsetSlow(lenC, startC, hashC, maskedLastWordC);
- }
- else {
- baseEntryPtrC = getOrCreateEntryBaseOffsetFast(mapIndexC, lenC, maskedLastWordC, maskedFirstWordC, fastMap);
- }
}
else {
baseEntryPtrA = getOrCreateEntryBaseOffsetFast(mapIndexA, lenA, maskedLastWordA, maskedFirstWordA, fastMap);
baseEntryPtrB = getOrCreateEntryBaseOffsetFast(mapIndexB, lenB, maskedLastWordB, maskedFirstWordB, fastMap);
- baseEntryPtrC = getOrCreateEntryBaseOffsetFast(mapIndexC, lenC, maskedLastWordC, maskedFirstWordC, fastMap);
}
cursorA = parseAndStoreTemperature(digitStartA, baseEntryPtrA, temperatureWordA);
cursorB = parseAndStoreTemperature(digitStartB, baseEntryPtrB, temperatureWordB);
- cursorC = parseAndStoreTemperature(digitStartC, baseEntryPtrC, temperatureWordC);
}
- doTail(fastMap);
- // System.out.println("Longest chain: " + longestChain);
+ }
+
+ private void setCursors(long current) {
+ // Credit for the whole work-stealing scheme: @thomaswue
+ // I have totally stolen it from him. I changed the order a bit to suite my taste better,
+ // but it's his code
+ long segmentStart;
+ if (current == fileStart) {
+ segmentStart = current;
+ }
+ else {
+ segmentStart = nextNewLine(current) + 1;
+ }
+ long segmentEnd = nextNewLine(Math.min(fileEnd - 1, current + SEGMENT_SIZE));
+
+ long size = (segmentEnd - segmentStart) / 2;
+ long mid = nextNewLine(segmentStart + size);
+
+ cursorA = segmentStart;
+ endA = mid;
+ cursorB = mid + 1;
+ endB = segmentEnd;
}
private static long getOrCreateEntryBaseOffsetFast(int mapIndexA, int lenA, long maskedLastWord, long maskedFirstWord, long fastMap) {
@@ -625,4 +637,9 @@ private static boolean nameMatchSlow(long start, long namePtr, long fullLen, lon
}
}
+ record StationStats(int min, int max, int count, long sum) {
+ StationStats mergeWith(StationStats other) {
+ return new StationStats(Math.min(min, other.min), Math.max(max, other.max), count + other.count, sum + other.sum);
+ }
+ }
}