From 0c5c22882b6750b0e3603e327694fbf4401c15d8 Mon Sep 17 00:00:00 2001 From: Dr Ian Preston <157221403+ianopolousfast@users.noreply.github.com> Date: Wed, 31 Jan 2024 16:06:05 +0000 Subject: [PATCH] Process two consecutive lines at a time (#651) Use a better hash function Don't return index from temperature parsing extra JVM args Co-authored-by: Ian Preston --- calculate_average_ianopolousfast.sh | 2 + .../CalculateAverage_ianopolousfast.java | 150 ++++++++---------- 2 files changed, 65 insertions(+), 87 deletions(-) diff --git a/calculate_average_ianopolousfast.sh b/calculate_average_ianopolousfast.sh index 06c31d9e5..4ed77c70a 100755 --- a/calculate_average_ianopolousfast.sh +++ b/calculate_average_ianopolousfast.sh @@ -16,4 +16,6 @@ # JAVA_OPTS="--enable-preview --add-modules=jdk.incubator.vector" +#-Djdk.incubator.vector.VECTOR_ACCESS_OOB_CHECK=0 -XX:-UseTransparentHugePages" + java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_ianopolousfast diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_ianopolousfast.java b/src/main/java/dev/morling/onebrc/CalculateAverage_ianopolousfast.java index 417abcfbe..92e2f6ecb 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_ianopolousfast.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_ianopolousfast.java @@ -19,7 +19,6 @@ import jdk.incubator.vector.VectorOperators; import jdk.incubator.vector.VectorSpecies; -import java.io.IOException; import java.lang.foreign.Arena; import java.lang.foreign.MemorySegment; import java.nio.ByteOrder; @@ -39,10 +38,7 @@ * * read chunks in parallel * * minimise allocation * * no unsafe - * - * Timings on 4 core i7-7500U CPU @ 2.70GHz: - * average_baseline: 4m48s - * ianopolous: 13.8s + * * process multiple lines in each thread for better ILP */ public class CalculateAverage_ianopolousfast { @@ -91,11 +87,22 @@ public static boolean matchingStationBytes(long start, long end, MemorySegment b return true; } - private static int hashToIndex(long hash, int len) { - // From Thomas Wuerthinger's entry - int hashAsInt = (int) (hash ^ (hash >>> 28)); - int finalHash = (hashAsInt ^ (hashAsInt >>> 15)); - return (finalHash & (len - 1)); + private static final int GOLDEN_RATIO = 0x9E3779B9; + private static final int HASH_LROTATE = 5; + + // hash from giovannicuccu + private static int hash(MemorySegment memorySegment, long start, int len) { + int x; + int y; + if (len >= Integer.BYTES) { + x = memorySegment.get(JAVA_INT_UNALIGNED, start); + y = memorySegment.get(JAVA_INT_UNALIGNED, start + len - Integer.BYTES); + } + else { + x = memorySegment.get(JAVA_BYTE, start); + y = memorySegment.get(JAVA_BYTE, start + len - Byte.BYTES); + } + return (Integer.rotateLeft(x * GOLDEN_RATIO, HASH_LROTATE) ^ y) * GOLDEN_RATIO; } public static Stat createStation(long start, long end, MemorySegment buffer) { @@ -105,8 +112,9 @@ public static Stat createStation(long start, long end, MemorySegment buffer) { return new Stat(stationBuffer); } - public static Stat dedupeStation(long start, long end, long hash, MemorySegment buffer, Stat[] stations) { - int index = hashToIndex(hash, MAX_STATIONS); + public static Stat dedupeStation(long start, long end, MemorySegment buffer, Stat[] stations) { + int hash = hash(buffer, start, (int) (end - start)); + int index = hash & (MAX_STATIONS - 1); Stat match = stations[index]; while (match != null) { if (matchingStationBytes(start, end, buffer, match)) @@ -119,37 +127,11 @@ public static Stat dedupeStation(long start, long end, long hash, MemorySegment return res; } - static long maskHighBytes(long d, int nbytes) { - return d & (-1L << ((8 - nbytes) * 8)); - } - - public static Stat parseStation(long lineStart, MemorySegment buffer, Stat[] stations) { - ByteVector line = ByteVector.fromMemorySegment(BYTE_SPECIES, buffer, lineStart, ByteOrder.nativeOrder()); - int keySize = line.compare(VectorOperators.EQ, ';').firstTrue(); - - long first8 = buffer.get(LONG_LAYOUT, lineStart); - long second8 = 0; - if (keySize <= 8) { - first8 = maskHighBytes(first8, keySize & 0x07); - } - else if (keySize < 16) { - second8 = maskHighBytes(buffer.get(LONG_LAYOUT, lineStart + 8), keySize & 0x07); - } - else if (keySize == BYTE_SPECIES.vectorByteSize()) { - while (buffer.get(JAVA_BYTE, lineStart + keySize) != ';') { - keySize++; - } - second8 = maskHighBytes(buffer.get(LONG_LAYOUT, lineStart + 8), keySize & 0x07); - } - long hash = first8 ^ second8; // todo include later bytes - return dedupeStation(lineStart, lineStart + keySize, hash, buffer, stations); - } - public static short getMinus(long d) { return ((d & 0xff00000000000000L) ^ 0x2d00000000000000L) != 0 ? 0 : (short) -1; } - public static long processTemperature(long lineSplit, int size, MemorySegment buffer, Stat station) { + public static void processTemperature(long lineSplit, int size, MemorySegment buffer, Stat station) { long d = buffer.get(LONG_LAYOUT, lineSplit); // negative is either 0 or -1 short negative = getMinus(d); @@ -162,10 +144,9 @@ public static long processTemperature(long lineSplit, int size, MemorySegment bu 100 * (((byte) (d >> 24)) - '0')); temperature = (short) ((temperature ^ negative) - negative); // negative treatment inspired by merkitty station.add(temperature); - return lineSplit + size + 1; } - private static long parseLine(long lineStart, MemorySegment buffer, Stat[] stations) { + private static int lineSize(long lineStart, MemorySegment buffer) { ByteVector line = ByteVector.fromMemorySegment(BYTE_SPECIES, buffer, lineStart, ByteOrder.nativeOrder()); int lineSize = line.compare(VectorOperators.EQ, '\n').firstTrue(); int index = lineSize; @@ -174,33 +155,19 @@ private static long parseLine(long lineStart, MemorySegment buffer, Stat[] stati ByteOrder.nativeOrder()).compare(VectorOperators.EQ, '\n').firstTrue(); lineSize += index; } - int keySize = lineSize - 6 + ByteVector.fromMemorySegment(BYTE_SPECIES, buffer, lineStart + lineSize - 6, - ByteOrder.nativeOrder()).compare(VectorOperators.EQ, ';').firstTrue(); + return lineSize; + } - long first8 = buffer.get(LONG_LAYOUT, lineStart); - long second8 = 0; - if (keySize <= 8) { - first8 = maskHighBytes(first8, keySize & 0x07); - } - else if (keySize < 16) { - second8 = maskHighBytes(buffer.get(LONG_LAYOUT, lineStart + 8), keySize & 0x07); - } - else if (keySize == BYTE_SPECIES.vectorByteSize()) { - while (buffer.get(JAVA_BYTE, lineStart + keySize) != ';') { - keySize++; - } - second8 = maskHighBytes(buffer.get(LONG_LAYOUT, lineStart + 8), keySize & 0x07); - } - long hash = first8 ^ second8; // todo include later bytes - Stat station = dedupeStation(lineStart, lineStart + keySize, hash, buffer, stations); - return processTemperature(lineStart + keySize + 1, lineSize - keySize - 1, buffer, station); + private static int keySize(int lineSize, long lineStart, MemorySegment buffer) { + return lineSize - 6 + ByteVector.fromMemorySegment(BYTE_SPECIES, buffer, lineStart + lineSize - 6, + ByteOrder.nativeOrder()).compare(VectorOperators.EQ, ';').firstTrue(); } - public static Stat[] parseStats(long startByte, long endByte, MemorySegment buffer) { + public static Stat[] parseStats(long start1, long end2, MemorySegment buffer) { // read first partial line - if (startByte > 0) { + if (start1 > 0) { for (int i = 0; i < MAX_LINE_LENGTH; i++) { - byte b = buffer.get(JAVA_BYTE, startByte++); + byte b = buffer.get(JAVA_BYTE, start1++); if (b == '\n') { break; } @@ -213,38 +180,47 @@ public static Stat[] parseStats(long startByte, long endByte, MemorySegment buff // this allows us to not worry about reading beyond the end // in the inner loop (reducing branches) // We need at least the vector lane size bytes back - if (endByte == buffer.byteSize()) { + if (end2 == buffer.byteSize()) { // reverse at least vector lane width - endByte = Math.max(buffer.byteSize() - BYTE_SPECIES.vectorByteSize(), 0); - while (endByte > 0 && buffer.get(JAVA_BYTE, endByte) != '\n') - endByte--; + end2 = Math.max(buffer.byteSize() - 2 * BYTE_SPECIES.vectorByteSize(), 0); + while (end2 > 0 && buffer.get(JAVA_BYTE, end2) != '\n') + end2--; - if (endByte > 0) - endByte++; + if (end2 > 0) + end2++; // copy into a larger buffer to avoid reading off end - MemorySegment end = Arena.global().allocate(MAX_LINE_LENGTH + BYTE_SPECIES.vectorByteSize()); - for (long i = endByte; i < buffer.byteSize(); i++) - end.set(JAVA_BYTE, i - endByte, buffer.get(JAVA_BYTE, i)); + MemorySegment end = Arena.global().allocate(MAX_LINE_LENGTH + 2 * BYTE_SPECIES.vectorByteSize()); + for (long i = end2; i < buffer.byteSize(); i++) + end.set(JAVA_BYTE, i - end2, buffer.get(JAVA_BYTE, i)); int index = 0; - while (endByte + index < buffer.byteSize()) { - Stat station = parseStation(index, end, stations); - int tempSize = 3; - if (end.get(JAVA_BYTE, index + station.namelen + 5) == '\n') - tempSize = 4; - if (end.get(JAVA_BYTE, index + station.namelen + 6) == '\n') - tempSize = 5; - index = (int) processTemperature(index + station.namelen + 1, tempSize, end, station); + while (end2 + index < buffer.byteSize()) { + int lineSize1 = lineSize(index, end); + int semiSearchStart = index + Math.max(0, lineSize1 - 6); + int keySize1 = semiSearchStart - index + ByteVector.fromMemorySegment(BYTE_SPECIES, end, semiSearchStart, + ByteOrder.nativeOrder()).compare(VectorOperators.EQ, ';').firstTrue(); + Stat station1 = dedupeStation(index, index + keySize1, end, stations); + processTemperature(index + keySize1 + 1, lineSize1 - keySize1 - 1, end, station1); + index += lineSize1 + 1; } } - innerloop(startByte, endByte, buffer, stations); - return stations; - } - - private static void innerloop(long startByte, long endByte, MemorySegment buffer, Stat[] stations) { - while (startByte < endByte) { - startByte = parseLine(startByte, buffer, stations); + while (start1 < end2) { + int lineSize1 = lineSize(start1, buffer); + long start2 = start1 + lineSize1 + 1; + int lineSize2 = start2 < end2 ? lineSize(start2, buffer) : 0; + int keySize1 = keySize(lineSize1, start1, buffer); + int keySize2 = keySize(lineSize2, start2, buffer); + Stat station1 = dedupeStation(start1, start1 + keySize1, buffer, stations); + processTemperature(start1 + keySize1 + 1, lineSize1 - keySize1 - 1, buffer, station1); + if (start2 < end2) { + Stat station2 = dedupeStation(start2, start2 + keySize2, buffer, stations); + processTemperature(start2 + keySize2 + 1, lineSize2 - keySize2 - 1, buffer, station2); + start1 = start2 + lineSize2 + 1; + } + else + start1 += lineSize1 + 1; } + return stations; } public static class Stat {