diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_artsiomkorzun.java b/src/main/java/dev/morling/onebrc/CalculateAverage_artsiomkorzun.java index d899c3d72..cc6e3b95a 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_artsiomkorzun.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_artsiomkorzun.java @@ -39,6 +39,8 @@ public class CalculateAverage_artsiomkorzun { private static final long LINE_PATTERN = 0x0A0A0A0A0A0A0A0AL; private static final long DOT_BITS = 0x10101000; private static final long MAGIC_MULTIPLIER = (100 * 0x1000000 + 10 * 0x10000 + 1); + private static final long[] WORD_MASK = { 0, 0, 0, 0, 0, 0, 0, 0, -1 }; + private static final int[] LENGTH_MASK = { 0, 0, 0, 0, 0, 0, 0, 0, -1 }; private static final Unsafe UNSAFE; @@ -190,12 +192,6 @@ public Aggregates() { UNSAFE.setMemory(pointer, SIZE, (byte) 0); } - public long find(long word, long hash) { - long address = pointer + offset(hash); - long w = word(address + 24); - return (w == word) ? address : 0; - } - public long find(long word1, long word2, long hash) { long address = pointer + offset(hash); long w1 = word(address + 24); @@ -393,14 +389,20 @@ public void run() { long word1 = word(chunk1.position); long word2 = word(chunk2.position); long word3 = word(chunk3.position); + long word4 = word(chunk1.position + 8); + long word5 = word(chunk2.position + 8); + long word6 = word(chunk3.position + 8); long separator1 = separator(word1); long separator2 = separator(word2); long separator3 = separator(word3); + long separator4 = separator(word4); + long separator5 = separator(word5); + long separator6 = separator(word6); - long pointer1 = find(aggregates, chunk1, word1, separator1); - long pointer2 = find(aggregates, chunk2, word2, separator2); - long pointer3 = find(aggregates, chunk3, word3, separator3); + long pointer1 = find(aggregates, chunk1, word1, word4, separator1, separator4); + long pointer2 = find(aggregates, chunk2, word2, word5, separator2, separator5); + long pointer3 = find(aggregates, chunk3, word3, word6, separator3, separator6); long value1 = value(chunk1); long value2 = value(chunk2); @@ -413,26 +415,41 @@ public void run() { while (chunk1.has()) { long word1 = word(chunk1.position); + long word2 = word(chunk1.position + 8); + long separator1 = separator(word1); - long pointer1 = find(aggregates, chunk1, word1, separator1); - long value1 = value(chunk1); - Aggregates.update(pointer1, value1); + long separator2 = separator(word2); + + long pointer = find(aggregates, chunk1, word1, word2, separator1, separator2); + long value = value(chunk1); + + Aggregates.update(pointer, value); } while (chunk2.has()) { - long word2 = word(chunk2.position); + long word1 = word(chunk2.position); + long word2 = word(chunk2.position + 8); + + long separator1 = separator(word1); long separator2 = separator(word2); - long pointer2 = find(aggregates, chunk2, word2, separator2); - long value2 = value(chunk2); - Aggregates.update(pointer2, value2); + + long pointer = find(aggregates, chunk2, word1, word2, separator1, separator2); + long value = value(chunk2); + + Aggregates.update(pointer, value); } while (chunk3.has()) { - long word3 = word(chunk3.position); - long separator3 = separator(word3); - long pointer3 = find(aggregates, chunk3, word3, separator3); - long value3 = value(chunk3); - Aggregates.update(pointer3, value3); + long word1 = word(chunk3.position); + long word2 = word(chunk3.position + 8); + + long separator1 = separator(word1); + long separator2 = separator(word2); + + long pointer = find(aggregates, chunk3, word1, word2, separator1, separator2); + long value = value(chunk3); + + Aggregates.update(pointer, value); } } @@ -456,60 +473,50 @@ private static long next(long position) { continue; } - return position + (Long.numberOfTrailingZeros(line) >>> 3) + 1; + return position + length(line) + 1; } } - private static long find(Aggregates aggregates, Chunk chunk, long word, long separator) { + private static long find(Aggregates aggregates, Chunk chunk, long word1, long word2, long separator1, long separator2) { + boolean small = (separator1 | separator2) != 0; long start = chunk.position; long hash; + long word; - if (separator != 0) { - word = mask(word, separator); - hash = mix(word); + if (small) { + int length1 = length(separator1); + int length2 = length(separator2); + word1 = mask(word1, separator1); + word2 = mask(word2 & WORD_MASK[length1], separator2); + hash = mix(word1 ^ word2); - chunk.position += length(separator); - long pointer = aggregates.find(word, hash); + chunk.position += length1 + (length2 & LENGTH_MASK[length1]) + 1; + long pointer = aggregates.find(word1, word2, hash); if (pointer != 0) { return pointer; } + + word = (separator1 == 0) ? word2 : word1; } else { - long word0 = word; - word = word(start + 8); - separator = separator(word); - - if (separator != 0) { - word = mask(word, separator); - hash = mix(word ^ word0); + chunk.position += 16; + hash = word1 ^ word2; - chunk.position += length(separator) + 8; - long pointer = aggregates.find(word0, word, hash); + while (true) { + word = word(chunk.position); + long separator = separator(word); - if (pointer != 0) { - return pointer; - } - } - else { - chunk.position += 16; - hash = word ^ word0; - - while (true) { - word = word(chunk.position); - separator = separator(word); - - if (separator == 0) { - chunk.position += 8; - hash ^= word; - continue; - } - - word = mask(word, separator); - hash = mix(hash ^ word); - chunk.position += length(separator); - break; + if (separator == 0) { + chunk.position += 8; + hash ^= word; + continue; } + + word = mask(word, separator); + hash = mix(hash ^ word); + chunk.position += length(separator) + 1; + break; } } @@ -535,8 +542,8 @@ private static long mask(long word, long separator) { return word & mask; } - private static long length(long separator) { - return (Long.numberOfTrailingZeros(separator) >>> 3) + 1; + private static int length(long separator) { + return Long.numberOfTrailingZeros(separator) >>> 3; } private static long mix(long x) {