diff --git a/calculate_average_PanagiotisDrakatos.sh b/calculate_average_PanagiotisDrakatos.sh index e6c936578..699ebdb28 100755 --- a/calculate_average_PanagiotisDrakatos.sh +++ b/calculate_average_PanagiotisDrakatos.sh @@ -32,5 +32,5 @@ # source "$HOME/.sdkman/bin/sdkman-init.sh" sdk use java 21.0.1-graal 1>&2 -JAVA_OPTS="--enable-preview -Xmx128m -XX:+UseSerialGC -XX:-TieredCompilation -XX:+UnlockExperimentalVMOptions -XX:+TrustFinalNonStaticFields -dsa -XX:+UseNUMA" +JAVA_OPTS="--enable-preview -Xms1536m -Xmx10536m -XX:NewSize=256m -XX:MaxNewSize=512m -XX:MaxMetaspaceSize=512m -XX:+DisableExplicitGC -XX:+UseSerialGC -XX:-TieredCompilation -XX:+UnlockExperimentalVMOptions -XX:+TrustFinalNonStaticFields -dsa -XX:+UseNUMA" java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_PanagiotisDrakatos diff --git a/prepare_PanagiotisDrakatos.sh b/prepare_PanagiotisDrakatos.sh index c322486c9..35fadfcb5 100755 --- a/prepare_PanagiotisDrakatos.sh +++ b/prepare_PanagiotisDrakatos.sh @@ -18,6 +18,6 @@ source "$HOME/.sdkman/bin/sdkman-init.sh" sdk use java 21.0.1-graal 1>&2 if [ ! -f target/CalculateAverage_PanagiotisDrakatos_image ]; then - NATIVE_IMAGE_OPTS="--gc=epsilon -O3 -R:MaxHeapSize=64m --initialize-at-build-time=dev.morling.onebrc.CalculateAverage_PanagiotisDrakatos" + NATIVE_IMAGE_OPTS="--gc=epsilon -O3 -R:MaxHeapSize=10536m --initialize-at-build-time=dev.morling.onebrc.CalculateAverage_PanagiotisDrakatos" native-image $NATIVE_IMAGE_OPTS -cp target/average-1.0.0-SNAPSHOT.jar -o target/CalculateAverage_PanagiotisDrakatos_image dev.morling.onebrc.CalculateAverage_PanagiotisDrakatos fi \ No newline at end of file diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_PanagiotisDrakatos.java b/src/main/java/dev/morling/onebrc/CalculateAverage_PanagiotisDrakatos.java index 9ab7a2264..04633948f 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_PanagiotisDrakatos.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_PanagiotisDrakatos.java @@ -20,41 +20,38 @@ import java.io.IOException; import java.io.RandomAccessFile; import java.nio.ByteBuffer; +import java.nio.ByteOrder; import java.nio.MappedByteBuffer; import java.nio.channels.FileChannel; -import java.nio.charset.StandardCharsets; import java.util.*; +import java.util.stream.Collectors; import java.util.stream.Stream; -import java.util.stream.StreamSupport; public class CalculateAverage_PanagiotisDrakatos { - private static final String FILE = "./measurements.txt"; - private static final long SEGMENT_SIZE = 4 * 1024 * 1024; - private static final long COMMA_PATTERN = 0x3B3B3B3B3B3B3B3BL; - private static final long DOT_BITS = 0x10101000; - private static final long MAGIC_MULTIPLIER = (100 * 0x1000000 + 10 * 0x10000 + 1); - + private static final long MAP_SIZE = 1024 * 1024 * 12L; private static TreeMap sortedCities; public static void main(String[] args) throws IOException { SeekableByteRead(FILE); - System.out.println(sortedCities); + System.out.println(sortedCities.toString()); boolean DEBUG = true; } private static void SeekableByteRead(String path) throws IOException { FileInputStream fileInputStream = new FileInputStream(new File(FILE)); FileChannel fileChannel = fileInputStream.getChannel(); - Optional> optimistic = getFileSegments(new File(FILE), fileChannel) - .stream() - .map(CalculateAverage_PanagiotisDrakatos::SplitSeekableByteChannel) - .parallel() - .map(CalculateAverage_PanagiotisDrakatos::MappingByteBufferToData) - .reduce(CalculateAverage_PanagiotisDrakatos::combineMaps); + try { + sortedCities = getFileSegments(new File(FILE), fileChannel).stream() + .map(CalculateAverage_PanagiotisDrakatos::SplitSeekableByteChannel) + .parallel() + .map(CalculateAverage_PanagiotisDrakatos::MappingByteBufferToData) + .flatMap(MeasurementRepository::get) + .collect(Collectors.toMap(e -> e.cityName, MeasurementRepository.Entry::measurement, MeasurementObject::updateWith, TreeMap::new)); + } + catch (NullPointerException e) { + } fileChannel.close(); - sortedCities = new TreeMap<>(optimistic.orElseThrow()); - } record FileSegment(long start, long end, FileChannel fileChannel) { @@ -95,14 +92,40 @@ private static long findSegment(RandomAccessFile raf, long location, final long private static ByteBuffer SplitSeekableByteChannel(FileSegment segment) { try { MappedByteBuffer buffer = segment.fileChannel.map(FileChannel.MapMode.READ_ONLY, segment.start(), segment.end - segment.start()); - int end = buffer.limit() - 1; - while (buffer.get(end) != '\n') { - end--; - } - return buffer.slice(0, end); + return buffer; } catch (Exception ex) { - throw new RuntimeException(ex); + long start = segment.start; + long end = 0; + try { + end = segment.fileChannel.size(); + } + catch (IOException e) { + throw new RuntimeException(e); + } + MappedByteBuffer buffer = null; + ArrayList list = new ArrayList<>(); + while (start < end) { + try { + buffer = segment.fileChannel.map(FileChannel.MapMode.READ_ONLY, start, Math.min(MAP_SIZE, end - start)); + // don't split the data in the middle of lines + // find the closest previous newline + int realEnd = buffer.limit() - 1; + while (buffer.get(realEnd) != '\n') + realEnd--; + + realEnd++; + buffer.limit(realEnd); + start += realEnd; + list.add(buffer.slice(0, realEnd - 1)); + } + catch (Exception e) { + e.printStackTrace(); + } + } + sortedCities = list.stream().parallel().map(CalculateAverage_PanagiotisDrakatos::MappingByteBufferToData).flatMap(MeasurementRepository::get) + .collect(Collectors.toMap(e -> e.cityName, MeasurementRepository.Entry::measurement, MeasurementObject::updateWith, TreeMap::new)); + return null; } } @@ -121,38 +144,61 @@ public static ByteBuffer concat(ByteBuffer[] buffers) { return all; } - private static Map combineMaps(Map map1, Map map2) { - for (var entry : map2.entrySet()) { - map1.merge(entry.getKey(), entry.getValue(), MeasurementObject::combine); - } + private static TreeMap combineMaps(Stream stream1, Stream stream2) { + Stream resultingStream = Stream.concat(stream1, stream2); + return resultingStream.collect(Collectors.toMap(e -> e.cityName, MeasurementRepository.Entry::measurement, MeasurementObject::updateWith, TreeMap::new)); + } + + private static int longHashStep(final int hash, final long word) { + return 31 * hash + (int) (word ^ (word >>> 32)); + } + + private static final long SEPARATOR_PATTERN = compilePattern((byte) ';'); - return map1; + private static long compilePattern(final byte value) { + return ((long) value << 56) | ((long) value << 48) | ((long) value << 40) | ((long) value << 32) | ((long) value << 24) | ((long) value << 16) + | ((long) value << 8) | (long) value; } - private static Map MappingByteBufferToData(ByteBuffer byteBuffer) { - Map cities = new HashMap<>(); + private static MeasurementRepository MappingByteBufferToData(ByteBuffer byteBuffer) { + MeasurementRepository measurements = new MeasurementRepository(); ByteBuffer bb = byteBuffer.duplicate(); + int start = 0; - int end = 0; - while (start < bb.limit()) { - while (bb.get(end) != ';') { - end++; + int limit = bb.limit(); + + long[] cityNameAsLongArray = new long[16]; + int[] delimiterPointerAndHash = new int[2]; + + bb.order(ByteOrder.nativeOrder()); + final boolean bufferIsBigEndian = bb.order().equals(ByteOrder.BIG_ENDIAN); + + while ((start = bb.position()) < limit + 1) { + + int delimiterPointer; + + findNextDelimiterAndCalculateHash(bb, SEPARATOR_PATTERN, start, limit, delimiterPointerAndHash, cityNameAsLongArray, bufferIsBigEndian); + delimiterPointer = delimiterPointerAndHash[0]; + // Simple lookup is faster for '\n' (just three options) + if (delimiterPointer >= limit) { + return measurements; } + final int cityNameLength = delimiterPointer - start; + int temp_counter = 0; - int temp_end = end; + int temp_end = delimiterPointer + 1; try { - bb.position(end); + // bb.position(delimiterPointer++); while (bb.get(temp_end) != '\n') { temp_counter++; temp_end++; } } catch (IndexOutOfBoundsException e) { - temp_counter--; - temp_end--; + // temp_counter--; + // temp_end--; } - ByteBuffer city = bb.slice(start, end - start); - ByteBuffer temp = bb.slice(end + 1, temp_counter); + ByteBuffer temp = bb.duplicate().slice(delimiterPointer + 1, temp_counter); int tempPointer = 0; int abs = 1; if (temp.get(0) == '-') { @@ -167,22 +213,141 @@ private static Map MappingByteBufferToData(ByteBuffer measuredValue = abs * (temp.get(tempPointer) * 100 + temp.get(tempPointer + 1) * 10 + temp.get(tempPointer + 3) - 5328); } - byte[] citybytes = new byte[city.limit()]; - city.get(citybytes); - String cityName = new String(citybytes, StandardCharsets.UTF_8); + measurements.update(cityNameAsLongArray, bb, cityNameLength, delimiterPointerAndHash[1]).updateWith(measuredValue); + + if (temp_end + 1 > limit) + return measurements; + bb.position(temp_end + 1); + } + return measurements; + } + + private static void findNextDelimiterAndCalculateHash(final ByteBuffer bb, final long pattern, final int start, final int limit, final int[] output, + final long[] asLong, final boolean bufferBigEndian) { + int hash = 1; + int i; + int lCnt = 0; + for (i = start; i <= limit - 8; i += 8) { + long word = bb.getLong(i); + if (bufferBigEndian) { + word = Long.reverseBytes(word); // Reversing the bytes is the cheapest way to do this + } + final long match = word ^ pattern; + long mask = ((match - 0x0101010101010101L) & ~match) & 0x8080808080808080L; + + if (mask != 0) { + final int index = Long.numberOfTrailingZeros(mask) >> 3; + output[0] = (i + index); - // update the map with the new measurement - MeasurementObject agg = cities.get(cityName); - if (agg == null) { - cities.put(cityName, new MeasurementObject(measuredValue, measuredValue, 0, 0).updateWith(measuredValue)); + final long partialHash = word & ((mask >> 7) - 1); + asLong[lCnt] = partialHash; + output[1] = longHashStep(hash, partialHash); + return; } - else { - cities.put(cityName, agg.updateWith(measuredValue)); + asLong[lCnt++] = word; + hash = longHashStep(hash, word); + } + // Handle remaining bytes near the limit of the buffer: + long partialHash = 0; + int len = 0; + for (; i < limit; i++) { + byte read; + if ((read = bb.get(i)) == (byte) pattern) { + asLong[lCnt] = partialHash; + output[0] = i; + output[1] = longHashStep(hash, partialHash); + return; } - start = temp_end + 1; - end = temp_end; + partialHash = partialHash | ((long) read << (len << 3)); + len++; } - return cities; + output[0] = limit; // delimiter not found + } + + static class MeasurementRepository { + private int tableSize = 1 << 20; // can grow in theory, made large enough not to (this is faster) + private int tableMask = (tableSize - 1); + private int tableLimit = (int) (tableSize * LOAD_FACTOR); + private int tableFilled = 0; + private static final float LOAD_FACTOR = 0.8f; + + private Entry[] table = new Entry[tableSize]; + + record Entry(int hash, long[] nameBytesInLong, String cityName, MeasurementObject measurement) { + @Override + public String toString() { + return cityName + "=" + measurement; + } + } + + public MeasurementObject update(long[] nameBytesInLong, ByteBuffer bb, int length, int calculatedHash) { + + final int nameBytesInLongLength = 1 + (length >>> 3); + + int index = calculatedHash & tableMask; + Entry tableEntry; + while ((tableEntry = table[index]) != null + && (tableEntry.hash != calculatedHash || !arrayEquals(tableEntry.nameBytesInLong, nameBytesInLong, nameBytesInLongLength))) { // search for the right spot + index = (index + 1) & tableMask; + } + + if (tableEntry != null) { + return tableEntry.measurement; + } + + // --- This is a brand new entry, insert into the hashtable and do the extra calculations (once!) do slower calculations here. + MeasurementObject measurement = new MeasurementObject(); + + // Now create a string: + byte[] buffer = new byte[length]; + bb.get(buffer, 0, length); + String cityName = new String(buffer, 0, length); + + // Store the long[] for faster equals: + long[] nameBytesInLongCopy = new long[nameBytesInLongLength]; + System.arraycopy(nameBytesInLong, 0, nameBytesInLongCopy, 0, nameBytesInLongLength); + + // And add entry: + Entry toAdd = new Entry(calculatedHash, nameBytesInLongCopy, cityName, measurement); + table[index] = toAdd; + + // Resize the table if filled too much: + if (++tableFilled > tableLimit) { + resizeTable(); + } + + return toAdd.measurement; + } + + private void resizeTable() { + // Resize the table: + Entry[] oldEntries = table; + table = new Entry[tableSize <<= 2]; // x2 + tableMask = (tableSize - 1); + tableLimit = (int) (tableSize * LOAD_FACTOR); + + for (Entry entry : oldEntries) { + if (entry != null) { + int updatedTableIndex = entry.hash & tableMask; + while (table[updatedTableIndex] != null) { + updatedTableIndex = (updatedTableIndex + 1) & tableMask; + } + table[updatedTableIndex] = entry; + } + } + } + + public Stream get() { + return Arrays.stream(table).filter(Objects::nonNull); + } + } + + private static boolean arrayEquals(final long[] a, final long[] b, final int length) { + for (int i = 0; i < length; i++) { + if (a[i] != b[i]) + return false; + } + return true; } private static final class MeasurementObject { @@ -202,6 +367,10 @@ public MeasurementObject(int MAX, int MIN, long SUM, int REPEAT) { } public MeasurementObject() { + this.MAX = -999; + this.MIN = 9999; + this.SUM = 0; + this.REPEAT = 0; } public MeasurementObject(int MAX, int MIN, long SUM) { @@ -224,6 +393,15 @@ public static MeasurementObject combine(MeasurementObject m1, MeasurementObject return mres; } + public static MeasurementObject updateWith(MeasurementObject m1, MeasurementObject m2) { + var mres = new MeasurementObject(); + mres.MIN = MeasurementObject.min(m1.MIN, m2.MIN); + mres.MAX = MeasurementObject.max(m1.MAX, m2.MAX); + mres.SUM = m1.SUM + m2.SUM; + mres.REPEAT = m1.REPEAT + m2.REPEAT; + return mres; + } + public MeasurementObject updateWith(int measurement) { MIN = MeasurementObject.min(MIN, measurement); MAX = MeasurementObject.max(MAX, measurement); @@ -268,4 +446,4 @@ public String toString() { return round(MIN) + "/" + round((1.0 * SUM) / REPEAT) + "/" + round(MAX); } } -} +} \ No newline at end of file