From 1e7314d5fb4ec948461ff1e5b49a610efbab25e6 Mon Sep 17 00:00:00 2001 From: gonix Date: Thu, 1 Feb 2024 12:53:46 +0200 Subject: [PATCH] CalculateAverage_gonix update (#706) Backported some of the optimizations from unsafe solution. Co-authored-by: Giedrius D --- calculate_average_gonix.sh | 4 +- .../onebrc/CalculateAverage_gonix.java | 509 +++++++++++------- 2 files changed, 312 insertions(+), 201 deletions(-) diff --git a/calculate_average_gonix.sh b/calculate_average_gonix.sh index a6f91655f..c3f00893c 100755 --- a/calculate_average_gonix.sh +++ b/calculate_average_gonix.sh @@ -1,4 +1,4 @@ -#!/bin/sh +#!/bin/bash # # Copyright 2023 The original authors # @@ -17,4 +17,4 @@ JAVA_OPTS="--enable-preview" -java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_gonix +exec cat < <(exec java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_gonix) diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_gonix.java b/src/main/java/dev/morling/onebrc/CalculateAverage_gonix.java index 572c272ca..cbc1127ae 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_gonix.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_gonix.java @@ -46,6 +46,7 @@ public static void main(String[] args) throws IOException { TreeMap::new)); System.out.println(res); + System.out.close(); } private static List buildChunks(RandomAccessFile file) throws IOException { @@ -75,248 +76,358 @@ private static List buildChunks(RandomAccessFile file) throws } return chunks; } -} -class Aggregator { - private static final int MAX_STATIONS = 10_000; - private static final int MAX_STATION_SIZE = Math.ceilDiv(100, 8) + 5; - private static final int INDEX_SIZE = 1024 * 1024; - private static final int INDEX_MASK = INDEX_SIZE - 1; - private static final int FLD_COUNT = 0; - private static final int FLD_SUM = 1; - private static final int FLD_MIN = 2; - private static final int FLD_MAX = 3; - - // Poor man's hash map: hash code to offset in `mem`. - private final int[] index; - - // Contiguous storage of key (station name) and stats fields of all - // unique stations. - // The idea here is to improve locality so that stats fields would - // possibly be already in the CPU cache after we are done comparing - // the key. - private final long[] mem; - private int memUsed; - - Aggregator() { - assert ((INDEX_SIZE & (INDEX_SIZE - 1)) == 0) : "INDEX_SIZE must be power of 2"; - assert (INDEX_SIZE > MAX_STATIONS) : "INDEX_SIZE must be greater than MAX_STATIONS"; - - index = new int[INDEX_SIZE]; - mem = new long[1 + (MAX_STATIONS * MAX_STATION_SIZE)]; - memUsed = 1; - } + private static class Aggregator { + private static final int MAX_STATIONS = 10_000; + private static final int MAX_STATION_SIZE = Math.ceilDiv(100, 8) + 5; + private static final int INDEX_SIZE = 1024 * 1024; + private static final int INDEX_MASK = INDEX_SIZE - 1; + private static final int FLD_COUNT = 0; + private static final int FLD_SUM = 1; + private static final int FLD_MIN = 2; + private static final int FLD_MAX = 3; + + // Poor man's hash map: hash code to offset in `mem`. + private final int[] index; + + // Contiguous storage of key (station name) and stats fields of all + // unique stations. + // The idea here is to improve locality so that stats fields would + // possibly be already in the CPU cache after we are done comparing + // the key. + private final long[] mem; + private int memUsed; - Aggregator processChunk(MappedByteBuffer buf) { - // To avoid checking if it is safe to read a whole long near the - // end of a chunk, we copy last couple of lines to a padded buffer - // and process that part separately. - int limit = buf.limit(); - int pos = Math.max(limit - 16, -1); - while (pos >= 0 && buf.get(pos) != '\n') { - pos--; + Aggregator() { + assert ((INDEX_SIZE & (INDEX_SIZE - 1)) == 0) : "INDEX_SIZE must be power of 2"; + assert (INDEX_SIZE > MAX_STATIONS) : "INDEX_SIZE must be greater than MAX_STATIONS"; + + index = new int[INDEX_SIZE]; + mem = new long[1 + (MAX_STATIONS * MAX_STATION_SIZE)]; + memUsed = 1; } - pos++; - if (pos > 0) { - processChunkLongs(buf, pos); + + Aggregator processChunk(MappedByteBuffer buf) { + // To avoid checking if it is safe to read a whole long near the + // end of a chunk, we copy last couple of lines to a padded buffer + // and process that part separately. + int limit = buf.limit(); + int pos = Math.max(limit - 16, -1); + while (pos >= 0 && buf.get(pos) != '\n') { + pos--; + } + pos++; + if (pos > 0) { + processChunkLongs(buf, pos); + } + int tailLen = limit - pos; + var tailBuf = ByteBuffer.allocate(tailLen + 8).order(ByteOrder.nativeOrder()); + buf.get(pos, tailBuf.array(), 0, tailLen); + processChunkLongs(tailBuf, tailLen); + return this; } - int tailLen = limit - pos; - var tailBuf = ByteBuffer.allocate(tailLen + 8).order(ByteOrder.nativeOrder()); - buf.get(pos, tailBuf.array(), 0, tailLen); - processChunkLongs(tailBuf, tailLen); - return this; - } - Aggregator processChunkLongs(ByteBuffer buf, int limit) { - int pos = 0; - while (pos < limit) { - - int start = pos; - int hash = 0; - long tail = 0; - while (true) { - // Seen this trick used in multiple other solutions. - // Nice breakdown here: https://graphics.stanford.edu/~seander/bithacks.html#ZeroInWord - long tmpLong = buf.getLong(pos); - long match = tmpLong ^ 0x3B3B3B3B_3B3B3B3BL; // 3B == ';' - match = ((match - 0x01010101_01010101L) & (~match & 0x80808080_80808080L)); - if (match == 0) { - hash = ((33 * hash) ^ (int) (tmpLong & 0xFFFFFFFF)) + (int) ((tmpLong >>> 33) & 0xFFFFFFFF); - pos += 8; + Aggregator processChunkLongs(ByteBuffer buf, int limit) { + int pos = 0; + while (pos < limit) { + + int start = pos; + long keyLong = buf.getLong(pos); + long valueSepMark = valueSepMark(keyLong); + if (valueSepMark != 0) { + int tailBits = tailBits(valueSepMark); + pos += valueOffset(tailBits); + // assert (UNSAFE.getByte(pos - 1) == ';') : "Expected ';' (1), pos=" + (pos - startAddr); + long tailAndLen = tailAndLen(tailBits, keyLong, pos - start - 1); + + long valueLong = buf.getLong(pos); + int decimalSepMark = decimalSepMark(valueLong); + pos += nextKeyOffset(decimalSepMark); + // assert (UNSAFE.getByte(pos - 1) == '\n') : "Expected '\\n' (1), pos=" + (pos - startAddr); + int measurement = decimalValue(decimalSepMark, valueLong); + + add1(buf, start, tailAndLen, hash(hash1(tailAndLen)), measurement); continue; } - int tailBits = Long.numberOfTrailingZeros(match >>> 7); - long tailMask = ~(-1L << tailBits); - tail = tmpLong & tailMask; - hash = ((33 * hash) ^ (int) (tail & 0xFFFFFFFF)) + (int) ((tail >>> 33) & 0xFFFFFFFF); - pos += tailBits >> 3; - break; - } - hash = (33 * hash) ^ (hash >>> 15); - int lenInLongs = (pos - start) >> 3; - long tailAndLen = (tail << 8) | (lenInLongs & 0xFF); - // assert (buf.get(pos) == ';') : "Expected ';'"; - pos++; + pos += 8; + long keyLong1 = keyLong; + keyLong = buf.getLong(pos); + valueSepMark = valueSepMark(keyLong); + if (valueSepMark != 0) { + int tailBits = tailBits(valueSepMark); + pos += valueOffset(tailBits); + // assert (UNSAFE.getByte(pos - 1) == ';') : "Expected ';' (2), pos=" + (pos - startAddr); + long tailAndLen = tailAndLen(tailBits, keyLong, pos - start - 1); + + long valueLong = buf.getLong(pos); + int decimalSepMark = decimalSepMark(valueLong); + pos += nextKeyOffset(decimalSepMark); + // assert (UNSAFE.getByte(pos - 1) == '\n') : "Expected '\\n' (2), pos=" + (pos - startAddr); + int measurement = decimalValue(decimalSepMark, valueLong); + + add2(buf, start, keyLong1, tailAndLen, hash(hash(hash1(keyLong1), tailAndLen)), measurement); + continue; + } - int measurement; - { - // Seen this trick used in multiple other solutions. - // Looks like the original author is @merykitty. - long tmpLong = buf.getLong(pos); - - // The 4th binary digit of the ascii of a digit is 1 while - // that of the '.' is 0. This finds the decimal separator - // The value can be 12, 20, 28 - int decimalSepPos = Long.numberOfTrailingZeros(~tmpLong & 0x10101000); - int shift = 28 - decimalSepPos; - // signed is -1 if negative, 0 otherwise - long signed = (~tmpLong << 59) >> 63; - long designMask = ~(signed & 0xFF); - // Align the number to a specific position and transform the ascii code - // to actual digit value in each byte - long digits = ((tmpLong & designMask) << shift) & 0x0F000F0F00L; - - // Now digits is in the form 0xUU00TTHH00 (UU: units digit, TT: tens digit, HH: hundreds digit) - // 0xUU00TTHH00 * (100 * 0x1000000 + 10 * 0x10000 + 1) = - // 0x000000UU00TTHH00 + - // 0x00UU00TTHH000000 * 10 + - // 0xUU00TTHH00000000 * 100 - // Now TT * 100 has 2 trailing zeroes and HH * 100 + TT * 10 + UU < 0x400 - // This results in our value lies in the bit 32 to 41 of this product - // That was close :) - long absValue = ((digits * 0x640a0001) >>> 32) & 0x3FF; - measurement = (int) ((absValue ^ signed) - signed); - pos += (decimalSepPos >>> 3) + 3; + long hash = hash1(keyLong1); + do { + pos += 8; + hash = hash(hash, keyLong); + keyLong = buf.getLong(pos); + valueSepMark = valueSepMark(keyLong); + } while (valueSepMark == 0); + int tailBits = tailBits(valueSepMark); + pos += valueOffset(tailBits); + // assert (UNSAFE.getByte(pos - 1) == ';') : "Expected ';' (N), pos=" + (pos - startAddr); + long tailAndLen = tailAndLen(tailBits, keyLong, pos - start - 1); + hash = hash(hash, tailAndLen); + + long valueLong = buf.getLong(pos); + int decimalSepMark = decimalSepMark(valueLong); + pos += nextKeyOffset(decimalSepMark); + // assert (UNSAFE.getByte(pos - 1) == '\n') : "Expected '\\n' (N), pos=" + (pos - startAddr); + int measurement = decimalValue(decimalSepMark, valueLong); + + addN(buf, start, tailAndLen, hash(hash), measurement); } - // assert (buf.get(pos - 1) == '\n') : "Expected '\\n'"; - add(buf, start, tailAndLen, hash, measurement); + return this; } - return this; - } + public Stream stream() { + return Arrays.stream(index) + .filter(offset -> offset != 0) + .mapToObj(offset -> new Entry(mem, offset)); + } - public Stream stream() { - return Arrays.stream(index) - .filter(offset -> offset != 0) - .mapToObj(offset -> new Entry(mem, offset)); - } + private static long hash1(long value) { + return value; + } - private void add(ByteBuffer buf, int start, long tailAndLen, int hash, int measurement) { - int idx = hash & INDEX_MASK; - for (; index[idx] != 0; idx = (idx + 1) & INDEX_MASK) { - if (update(index[idx], buf, start, tailAndLen, measurement)) { - return; - } + private static long hash(long hash, long value) { + return hash ^ value; + } + + private static int hash(long hash) { + hash *= 0x9E3779B97F4A7C15L; // Fibonacci hashing multiplier + return (int) (hash >>> 39); } - index[idx] = create(buf, start, tailAndLen, measurement); - } - private int create(ByteBuffer buf, int start, long tailAndLen, int measurement) { - int offset = memUsed; + private static long valueSepMark(long keyLong) { + // Seen this trick used in multiple other solutions. + // Nice breakdown here: https://graphics.stanford.edu/~seander/bithacks.html#ZeroInWord + long match = keyLong ^ 0x3B3B3B3B_3B3B3B3BL; // 3B == ';' + match = (match - 0x01010101_01010101L) & (~match & 0x80808080_80808080L); + return match; + } - mem[offset] = tailAndLen; + private static int tailBits(long valueSepMark) { + return Long.numberOfTrailingZeros(valueSepMark >>> 7); + } - int memPos = offset + 1; - int memEnd = memPos + (int) (tailAndLen & 0xFF); - int bufPos = start; - while (memPos < memEnd) { - mem[memPos] = buf.getLong(bufPos); - memPos += 1; - bufPos += 8; + private static int valueOffset(int tailBits) { + return (int) (tailBits >>> 3) + 1; } - mem[memPos + FLD_MIN] = measurement; - mem[memPos + FLD_MAX] = measurement; - mem[memPos + FLD_SUM] = measurement; - mem[memPos + FLD_COUNT] = 1; - memUsed = memPos + 4; + private static long tailAndLen(int tailBits, long keyLong, long keyLen) { + long tailMask = ~(-1L << tailBits); + long tail = keyLong & tailMask; + return (tail << 8) | ((keyLen >> 3) & 0xFF); + } - return offset; - } + private static int decimalSepMark(long value) { + // Seen this trick used in multiple other solutions. + // Looks like the original author is @merykitty. - private boolean update(int offset, ByteBuffer buf, int start, long tailAndLen, int measurement) { - var mem = this.mem; - if (mem[offset] != tailAndLen) { - return false; + // The 4th binary digit of the ascii of a digit is 1 while + // that of the '.' is 0. This finds the decimal separator + // The value can be 12, 20, 28 + return Long.numberOfTrailingZeros(~value & 0x10101000); } - int memPos = offset + 1; - int memEnd = memPos + (int) (tailAndLen & 0xFF); - int bufPos = start; - while (memPos < memEnd) { - if (mem[memPos] != buf.getLong(bufPos)) { - return false; + + private static int decimalValue(int decimalSepMark, long value) { + // Seen this trick used in multiple other solutions. + // Looks like the original author is @merykitty. + + int shift = 28 - decimalSepMark; + // signed is -1 if negative, 0 otherwise + long signed = (~value << 59) >> 63; + long designMask = ~(signed & 0xFF); + // Align the number to a specific position and transform the ascii code + // to actual digit value in each byte + long digits = ((value & designMask) << shift) & 0x0F000F0F00L; + + // Now digits is in the form 0xUU00TTHH00 (UU: units digit, TT: tens digit, HH: hundreds digit) + // 0xUU00TTHH00 * (100 * 0x1000000 + 10 * 0x10000 + 1) = + // 0x000000UU00TTHH00 + + // 0x00UU00TTHH000000 * 10 + + // 0xUU00TTHH00000000 * 100 + // Now TT * 100 has 2 trailing zeroes and HH * 100 + TT * 10 + UU < 0x400 + // This results in our value lies in the bit 32 to 41 of this product + // That was close :) + long absValue = ((digits * 0x640a0001) >>> 32) & 0x3FF; + return (int) ((absValue ^ signed) - signed); + } + + private static int nextKeyOffset(int decimalSepMark) { + return (decimalSepMark >>> 3) + 3; + } + + private void add1(ByteBuffer buf, int start, long tailAndLen, int hash, int measurement) { + int idx = hash & INDEX_MASK; + for (; index[idx] != 0; idx = (idx + 1) & INDEX_MASK) { + if (update1(index[idx], tailAndLen, measurement)) { + return; + } } - memPos += 1; - bufPos += 8; + index[idx] = create(buf, start, tailAndLen, measurement); } - mem[memPos + FLD_COUNT] += 1; - mem[memPos + FLD_SUM] += measurement; - if (measurement < mem[memPos + FLD_MIN]) { - mem[memPos + FLD_MIN] = measurement; + private void add2(ByteBuffer buf, int start, long keyLong, long tailAndLen, int hash, int measurement) { + int idx = hash & INDEX_MASK; + for (; index[idx] != 0; idx = (idx + 1) & INDEX_MASK) { + if (update2(index[idx], keyLong, tailAndLen, measurement)) { + return; + } + } + index[idx] = create(buf, start, tailAndLen, measurement); } - if (measurement > mem[memPos + FLD_MAX]) { - mem[memPos + FLD_MAX] = measurement; + + private void addN(ByteBuffer buf, int start, long tailAndLen, int hash, int measurement) { + int idx = hash & INDEX_MASK; + for (; index[idx] != 0; idx = (idx + 1) & INDEX_MASK) { + if (updateN(index[idx], buf, start, tailAndLen, measurement)) { + return; + } + } + index[idx] = create(buf, start, tailAndLen, measurement); } - return true; - } + private int create(ByteBuffer buf, int start, long tailAndLen, int measurement) { + int offset = memUsed; - public static class Entry { - private final long[] mem; - private final int offset; - private String key; + mem[offset] = tailAndLen; - Entry(long[] mem, int offset) { - this.mem = mem; - this.offset = offset; + int memPos = offset + 1; + int memEnd = memPos + (int) (tailAndLen & 0xFF); + int bufPos = start; + while (memPos < memEnd) { + mem[memPos] = buf.getLong(bufPos); + memPos += 1; + bufPos += 8; + } + + mem[memPos + FLD_MIN] = measurement; + mem[memPos + FLD_MAX] = measurement; + mem[memPos + FLD_SUM] = measurement; + mem[memPos + FLD_COUNT] = 1; + memUsed = memPos + 4; + + return offset; } - public String getKey() { - if (key == null) { - int pos = this.offset; - long tailAndLen = mem[pos++]; - int keyLen = (int) (tailAndLen & 0xFF); - var tmpBuf = ByteBuffer.allocate((keyLen << 3) + 8).order(ByteOrder.nativeOrder()); - for (int i = 0; i < keyLen; i++) { - tmpBuf.putLong(mem[pos++]); - } - long tail = tailAndLen >>> 8; - tmpBuf.putLong(tail); - int keyLenBytes = (keyLen << 3) + 8 - (Long.numberOfLeadingZeros(tail) >> 3); - key = new String(tmpBuf.array(), 0, keyLenBytes, StandardCharsets.UTF_8); + private boolean update1(int offset, long tailAndLen, int measurement) { + if (mem[offset] != tailAndLen) { + return false; } - return key; + updateStats(offset + 1, measurement); + return true; } - public Entry add(Entry other) { - int fldOffset = (int) (mem[offset] & 0xFF) + 1; - int pos = offset + fldOffset; - int otherPos = other.offset + fldOffset; - long[] otherMem = other.mem; - mem[pos + FLD_MIN] = Math.min((int) mem[pos + FLD_MIN], (int) otherMem[otherPos + FLD_MIN]); - mem[pos + FLD_MAX] = Math.max((int) mem[pos + FLD_MAX], (int) otherMem[otherPos + FLD_MAX]); - mem[pos + FLD_SUM] += otherMem[otherPos + FLD_SUM]; - mem[pos + FLD_COUNT] += otherMem[otherPos + FLD_COUNT]; - return this; + private boolean update2(int offset, long keyLong, long tailAndLen, int measurement) { + if (mem[offset] != tailAndLen || mem[offset + 1] != keyLong) { + return false; + } + updateStats(offset + 2, measurement); + return true; } - public Entry getValue() { - return this; + private boolean updateN(int offset, ByteBuffer buf, int start, long tailAndLen, int measurement) { + var mem = this.mem; + if (mem[offset] != tailAndLen) { + return false; + } + int memPos = offset + 1; + int memEnd = memPos + (int) (tailAndLen & 0xFF); + int bufPos = start; + while (memPos < memEnd) { + if (mem[memPos] != buf.getLong(bufPos)) { + return false; + } + memPos += 1; + bufPos += 8; + } + updateStats(memPos, measurement); + return true; } - @Override - public String toString() { - int pos = offset + (int) (mem[offset] & 0xFF) + 1; - return round(mem[pos + FLD_MIN]) - + "/" + round(((double) mem[pos + FLD_SUM]) / mem[pos + FLD_COUNT]) - + "/" + round(mem[pos + FLD_MAX]); + private void updateStats(int memPos, int measurement) { + mem[memPos + FLD_COUNT] += 1; + mem[memPos + FLD_SUM] += measurement; + if (measurement < mem[memPos + FLD_MIN]) { + mem[memPos + FLD_MIN] = measurement; + } + if (measurement > mem[memPos + FLD_MAX]) { + mem[memPos + FLD_MAX] = measurement; + } } - private static double round(double value) { - return Math.round(value) / 10.0; + public static class Entry { + private final long[] mem; + private final int offset; + private String key; + + Entry(long[] mem, int offset) { + this.mem = mem; + this.offset = offset; + } + + public String getKey() { + if (key == null) { + int pos = this.offset; + long tailAndLen = mem[pos++]; + int keyLen = (int) (tailAndLen & 0xFF); + var tmpBuf = ByteBuffer.allocate((keyLen << 3) + 8).order(ByteOrder.nativeOrder()); + for (int i = 0; i < keyLen; i++) { + tmpBuf.putLong(mem[pos++]); + } + long tail = tailAndLen >>> 8; + tmpBuf.putLong(tail); + int keyLenBytes = (keyLen << 3) + 8 - (Long.numberOfLeadingZeros(tail) >> 3); + key = new String(tmpBuf.array(), 0, keyLenBytes, StandardCharsets.UTF_8); + } + return key; + } + + public Entry add(Entry other) { + int fldOffset = (int) (mem[offset] & 0xFF) + 1; + int pos = offset + fldOffset; + int otherPos = other.offset + fldOffset; + long[] otherMem = other.mem; + mem[pos + FLD_MIN] = Math.min((int) mem[pos + FLD_MIN], (int) otherMem[otherPos + FLD_MIN]); + mem[pos + FLD_MAX] = Math.max((int) mem[pos + FLD_MAX], (int) otherMem[otherPos + FLD_MAX]); + mem[pos + FLD_SUM] += otherMem[otherPos + FLD_SUM]; + mem[pos + FLD_COUNT] += otherMem[otherPos + FLD_COUNT]; + return this; + } + + public Entry getValue() { + return this; + } + + @Override + public String toString() { + int pos = offset + (int) (mem[offset] & 0xFF) + 1; + return round(mem[pos + FLD_MIN]) + + "/" + round(((double) mem[pos + FLD_SUM]) / mem[pos + FLD_COUNT]) + + "/" + round(mem[pos + FLD_MAX]); + } + + private static double round(double value) { + return Math.round(value) / 10.0; + } } } + }