diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_armandino.java b/src/main/java/dev/morling/onebrc/CalculateAverage_armandino.java index d825e77f9..0e9125337 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_armandino.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_armandino.java @@ -24,15 +24,12 @@ import java.nio.channels.FileChannel; import java.nio.file.Path; import java.nio.file.StandardOpenOption; -import java.util.Arrays; import java.util.Collection; -import java.util.Objects; +import java.util.Map; import java.util.TreeMap; -import java.util.stream.Stream; import static java.nio.channels.FileChannel.MapMode.READ_ONLY; import static java.nio.charset.StandardCharsets.UTF_8; -import static java.util.stream.Collectors.toMap; public class CalculateAverage_armandino { @@ -42,19 +39,59 @@ public class CalculateAverage_armandino { private static final int INITIAL_MAP_CAPACITY = 8192; private static final byte SEMICOLON = 59; private static final byte NL = 10; - private static final byte DOT = 46; - private static final byte MINUS = 45; - private static final byte ZERO_DIGIT = 48; private static final int PRIME = 1117; + + private static final int KEY_OFFSET = 0, // 100b + HASH_OFFSET = 100, // int + KEY_LENGTH_OFFSET = 104, // short + MIN_OFFSET = 106, // short + MAX_OFFSET = 108, // short + COUNT_OFFSET = 110, // int + SUM_OFFSET = 114; // long + + private static final long ENTRY_SIZE = 100 // key: offset=0 + + 4 // keyHash: offset=100 + + 2 // keyLength: offset=104 + + 2 // min: 108; offset=106 + + 2 // max: 110; offset=108 + + 4 // count: 114; offset=110 + + 8; // sum: 122; offset=118 + private static final Unsafe UNSAFE = getUnsafe(); public static void main(String[] args) throws Exception { var channel = FileChannel.open(FILE, StandardOpenOption.READ); - var results = Arrays.stream(split(channel)).parallel() - .map(chunk -> new ChunkProcessor().process(chunk.start, chunk.end)) - .flatMap(SimpleMap::stream) - .collect(toMap(Stats::getKey, s -> s, CalculateAverage_armandino::mergeStats, TreeMap::new)); + Chunk[] chunks = split(channel); + ChunkProcessor[] processors = new ChunkProcessor[chunks.length]; + + for (int i = 0; i < processors.length; i++) { + processors[i] = new ChunkProcessor(chunks[i].start, chunks[i].end); + processors[i].start(); + } + + Map results = new TreeMap<>(); + + for (int i = 0; i < processors.length; i++) { + processors[i].join(); + final long end = processors[i].map.mapEnd; + + for (long addr = processors[i].map.mapStart; addr < end; addr += ENTRY_SIZE) { + final short keyLength = UNSAFE.getShort(addr + KEY_LENGTH_OFFSET); + + if (keyLength == 0) + continue; + + final byte[] keyBytes = new byte[keyLength]; + UNSAFE.copyMemory(null, addr, keyBytes, Unsafe.ARRAY_BYTE_BASE_OFFSET, keyLength); + final short min = UNSAFE.getShort(addr + MIN_OFFSET); + final short max = UNSAFE.getShort(addr + MAX_OFFSET); + final int count = UNSAFE.getInt(addr + COUNT_OFFSET); + final long sum = UNSAFE.getLong(addr + SUM_OFFSET); + final Stats s = new Stats(new String(keyBytes, 0, keyLength, UTF_8), min, max, count, sum); + results.merge(s.key, s, CalculateAverage_armandino::mergeStats); + } + } print(results.values()); } @@ -67,87 +104,69 @@ private static Stats mergeStats(final Stats x, final Stats y) { return x; } - private static class ChunkProcessor { - private final SimpleMap map = new SimpleMap(INITIAL_MAP_CAPACITY); + private static class ChunkProcessor extends Thread { + private final UnsafeMap map = new UnsafeMap(INITIAL_MAP_CAPACITY); + + final long chunkStart; + final long chunkEnd; - private SimpleMap process(final long chunkStart, final long chunkEnd) { + private ChunkProcessor(long chunkStart, long chunkEnd) { + this.chunkStart = chunkStart; + this.chunkEnd = chunkEnd; + } + + @Override + public void run() { long i = chunkStart; while (i < chunkEnd) { final long keyAddress = i; int keyHash = 0; - int measurement = 0; byte b; while ((b = UNSAFE.getByte(i++)) != SEMICOLON) { keyHash = PRIME * keyHash + b; } - final int keyLength = (int) (i - keyAddress - 1); + final short keyLength = (short) (i - keyAddress - 1); + final long numberWord = UNSAFE.getLong(i); + final int decimalSepPos = Long.numberOfTrailingZeros(~numberWord & 0x10101000); + final short measurement = parseNumber(decimalSepPos, numberWord); + final int addOffset = (decimalSepPos >>> 3) + 3; + i += addOffset; - if ((b = UNSAFE.getByte(i++)) == MINUS) { - while ((b = UNSAFE.getByte(i++)) != DOT) { - measurement = measurement * 10 + b - ZERO_DIGIT; - } - - b = UNSAFE.getByte(i); - measurement = measurement * 10 + b - ZERO_DIGIT; - measurement = -measurement; - i += 2; - } - else { - measurement = b - ZERO_DIGIT; // D1 - b = UNSAFE.getByte(i); // dot or D2 - - if (b == DOT) { - measurement = measurement * 10 + UNSAFE.getByte(i + 1) - ZERO_DIGIT; // F - i += 3; - } - else { - measurement = measurement * 10 + b - ZERO_DIGIT; // D2 - measurement = measurement * 10 + UNSAFE.getByte(i + 2) - ZERO_DIGIT; // F - i += 4; // skip NL - } - } - - final Stats stats = map.putStats(keyHash, keyAddress, keyLength); - stats.min = Math.min(stats.min, measurement); - stats.max = Math.max(stats.max, measurement); - stats.sum += measurement; - stats.count++; + map.addEntry(keyHash, keyAddress, keyLength, measurement); } + } - return map; + // credit: merykitty + private static short parseNumber(int decimalSepPos, long numberWord) { + int shift = 28 - decimalSepPos; + // signed is -1 if negative, 0 otherwise + long signed = (~numberWord << 59) >> 63; + long designMask = ~(signed & 0xFF); + // Align the number to a specific position and transform the ascii to digit value + long digits = ((numberWord & designMask) << shift) & 0x0F000F0F00L; + // Now digits is in the form 0xUU00TTHH00 (UU: units digit, TT: tens digit, HH: hundreds digit) + // 0xUU00TTHH00 * (100 * 0x1000000 + 10 * 0x10000 + 1) = + // 0x000000UU00TTHH00 + 0x00UU00TTHH000000 * 10 + 0xUU00TTHH00000000 * 100 + long absValue = ((digits * 0x640a0001) >>> 32) & 0x3FF; + return (short) ((absValue ^ signed) - signed); } } - private static class Stats implements Comparable { - private String key; - private final long keyAddress; - private final int keyLength; - private final int keyHash; - private int min = Integer.MAX_VALUE; - private int max = Integer.MIN_VALUE; + private static class Stats { + private final String key; + private int min; + private int max; private int count; private long sum; - private Stats(long keyAddress, int keyLength, int keyHash) { - this.keyAddress = keyAddress; - this.keyLength = keyLength; - this.keyHash = keyHash; - } - - String getKey() { - if (key == null) { - var keyBytes = new byte[keyLength]; - UNSAFE.copyMemory(null, keyAddress, keyBytes, Unsafe.ARRAY_BYTE_BASE_OFFSET, keyLength); - key = new String(keyBytes, 0, keyLength, UTF_8); - } - return key; - } - - @Override - public int compareTo(final Stats o) { - return getKey().compareTo(o.getKey()); + Stats(final String key, final int min, final int max, final int count, final long sum) { + this.min = min; + this.max = max; + this.count = count; + this.sum = sum; + this.key = key; } void print(final PrintStream out) { @@ -219,90 +238,114 @@ private static Unsafe getUnsafe() { } } - private static class SimpleMap { - private Stats[] table; + private static class UnsafeMap { - SimpleMap(int initialCapacity) { - table = new Stats[initialCapacity]; - } + long mapStart; + long mapEnd; + int capacity; // num entries - Stream stream() { - return Arrays.stream(table).filter(Objects::nonNull); + UnsafeMap(int numEntries) { + capacity = numEntries; + final long size = ENTRY_SIZE * numEntries; + mapStart = UNSAFE.allocateMemory(size); + mapEnd = mapStart + size; + UNSAFE.setMemory(mapStart, size, (byte) 0); } - Stats putStats(final int keyHash, final long keyAddress, final int keyLength) { - final int pos = (table.length - 1) & keyHash; - - Stats stats = table[pos]; - if (stats == null) - return createAt(table, keyAddress, keyLength, keyHash, pos); - if (stats.keyHash == keyHash && keysEqual(stats, keyAddress, keyLength)) - return stats; - - int i = pos; - while (++i < table.length) { - stats = table[i]; - if (stats == null) - return createAt(table, keyAddress, keyLength, keyHash, i); - if (keyHash == stats.keyHash && keysEqual(stats, keyAddress, keyLength)) - return stats; + void addEntry(final int keyHash, final long keyAddress, final short keyLength, final short measurement) { + final int pos = (capacity - 1) & keyHash; + + long addr = mapStart + pos * ENTRY_SIZE; + int hash = UNSAFE.getInt(addr + HASH_OFFSET); + + if (hash == 0) { // new entry + initEntry(addr, keyAddress, keyLength, measurement, keyHash); + return; + } + if (hash == keyHash && keysEqual(addr, keyAddress, keyLength)) { + updateEntry(addr, measurement); + return; + } + + // this can be improved to avoid clustering at the start. + // should only affect the 10k test + addr = mapStart; + + while (addr < mapEnd) { + addr += ENTRY_SIZE; + hash = UNSAFE.getInt(addr + HASH_OFFSET); + + if (hash == 0) { + initEntry(addr, keyAddress, keyLength, measurement, keyHash); + return; + } + if (hash == keyHash && keysEqual(addr, keyAddress, keyLength)) { + updateEntry(addr, measurement); + return; + } } - i = pos; - while (i-- > 0) { - stats = table[i]; - if (stats == null) - return createAt(table, keyAddress, keyLength, keyHash, i); - if (keyHash == stats.keyHash && keysEqual(stats, keyAddress, keyLength)) - return stats; + resize(keyHash, keyAddress, keyLength, measurement); + } + + private void resize(final int keyHash, final long keyAddress, final short keyLength, final short measurement) { + UnsafeMap newMap = new UnsafeMap(capacity * 2); + + for (long addr = mapStart; addr < mapEnd; addr += ENTRY_SIZE) { + final short oKeyLength = UNSAFE.getShort(addr + KEY_LENGTH_OFFSET); + final int oKeyHsh = UNSAFE.getInt(addr + HASH_OFFSET); + final short oMin = UNSAFE.getShort(addr + MIN_OFFSET); + final short oMax = UNSAFE.getShort(addr + MAX_OFFSET); + final int oCount = UNSAFE.getInt(addr + COUNT_OFFSET); + final long oSum = UNSAFE.getLong(addr + SUM_OFFSET); + + final int newPos = (newMap.capacity - 1) & oKeyHsh; + long newAddr = newMap.mapStart + newPos * ENTRY_SIZE; + + UNSAFE.putShort(newAddr + KEY_LENGTH_OFFSET, oKeyLength); + UNSAFE.putInt(newAddr + HASH_OFFSET, oKeyHsh); + UNSAFE.putShort(newAddr + MIN_OFFSET, oMin); + UNSAFE.putShort(newAddr + MAX_OFFSET, oMax); + UNSAFE.putInt(newAddr + COUNT_OFFSET, oCount); + UNSAFE.putLong(newAddr + SUM_OFFSET, oSum); } - resize(); - return putStats(keyHash, keyAddress, keyLength); + + newMap.addEntry(keyHash, keyAddress, keyLength, measurement); + + this.mapStart = newMap.mapStart; + this.mapEnd = newMap.mapEnd; + this.capacity = newMap.capacity; } - private static Stats createAt(Stats[] table, long keyAddress, int keyLength, int key, int i) { - Stats stats = new Stats(keyAddress, keyLength, key); - table[i] = stats; - return stats; + private static void initEntry(final long entry, final long keyAddress, final short keyLength, final short measurement, final int keyHash) { + UNSAFE.copyMemory(keyAddress, entry, keyLength); + UNSAFE.putInt(entry + HASH_OFFSET, keyHash); + UNSAFE.putShort(entry + KEY_LENGTH_OFFSET, keyLength); + UNSAFE.putShort(entry + MIN_OFFSET, Short.MAX_VALUE); + UNSAFE.putShort(entry + MAX_OFFSET, Short.MIN_VALUE); + + updateEntry(entry, measurement); } - private static boolean keysEqual(Stats stats, long keyAddress, final int keyLength) { - // credit: abeobk - long xsum = 0; - int n = keyLength & 0xF8; - for (int i = 0; i < n; i += 8) { - xsum |= (UNSAFE.getLong(stats.keyAddress + i) ^ UNSAFE.getLong(keyAddress + i)); - } - return xsum == 0; + private static void updateEntry(final long entry, final short measurement) { + UNSAFE.putShort(entry + MIN_OFFSET, + (short) Math.min(UNSAFE.getShort(entry + MIN_OFFSET), measurement)); + UNSAFE.putShort(entry + MAX_OFFSET, + (short) Math.max(UNSAFE.getShort(entry + MAX_OFFSET), measurement)); + UNSAFE.putInt(entry + COUNT_OFFSET, + UNSAFE.getInt(entry + COUNT_OFFSET) + 1); + UNSAFE.putLong(entry + SUM_OFFSET, + UNSAFE.getLong(entry + SUM_OFFSET) + measurement); } + } - private void resize() { - var copy = new SimpleMap(table.length * 2); - for (Stats s : table) { - if (s != null) { - final int pos = (copy.table.length - 1) & s.keyHash; - int i = pos; - if (copy.table[i] == null) { - copy.table[i] = s; - continue; - } - while (i < copy.table.length && copy.table[i] != null) { - i++; - } - if (i == copy.table.length) { - i = pos; - while (i >= 0 && copy.table[i] != null) { - i--; - } - } - if (i < 0) { - // if we reach here it's a bug! - throw new IllegalStateException("table is full"); - } - copy.table[i] = s; - } - } - table = copy.table; + private static boolean keysEqual(long key1Address, long key2Address, final int keyLength) { + // credit: abeobk + long xsum = 0; + int n = keyLength & 0xF8; + for (int i = 0; i < n; i += 8) { + xsum |= (UNSAFE.getLong(key1Address + i) ^ UNSAFE.getLong(key2Address + i)); } + return xsum == 0; } }