Skip to content

Commit

Permalink
My Probably last attempt to optimize performance (#693)
Browse files Browse the repository at this point in the history
* CalculateAverage_pdrakatos

* Rename to be valid with rules

* CalculateAverage_pdrakatos

* Rename to be valid with rules

* Changes on scripts execution

* Fixing bugs causing scripts not to be executed

* Changes on prepare make it compatible

* Fixing passing all tests

* Increase direct memory allocation buffer

* Fixing memory problem causes heap space exception

* Fresh solution to optimize performance of the execution

* New Fresh solution with optimized performance with Custom Hashtable

* Increase maxperm size and xmx to avoid heap spaces error
  • Loading branch information
PanagiotisDrakatos authored Feb 1, 2024
1 parent 1e7314d commit 2aed039
Show file tree
Hide file tree
Showing 3 changed files with 233 additions and 55 deletions.
2 changes: 1 addition & 1 deletion calculate_average_PanagiotisDrakatos.sh
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,5 @@
#
source "$HOME/.sdkman/bin/sdkman-init.sh"
sdk use java 21.0.1-graal 1>&2
JAVA_OPTS="--enable-preview -Xmx128m -XX:+UseSerialGC -XX:-TieredCompilation -XX:+UnlockExperimentalVMOptions -XX:+TrustFinalNonStaticFields -dsa -XX:+UseNUMA"
JAVA_OPTS="--enable-preview -Xms1536m -Xmx10536m -XX:NewSize=256m -XX:MaxNewSize=512m -XX:MaxMetaspaceSize=512m -XX:+DisableExplicitGC -XX:+UseSerialGC -XX:-TieredCompilation -XX:+UnlockExperimentalVMOptions -XX:+TrustFinalNonStaticFields -dsa -XX:+UseNUMA"
java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_PanagiotisDrakatos
2 changes: 1 addition & 1 deletion prepare_PanagiotisDrakatos.sh
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,6 @@ source "$HOME/.sdkman/bin/sdkman-init.sh"
sdk use java 21.0.1-graal 1>&2

if [ ! -f target/CalculateAverage_PanagiotisDrakatos_image ]; then
NATIVE_IMAGE_OPTS="--gc=epsilon -O3 -R:MaxHeapSize=64m --initialize-at-build-time=dev.morling.onebrc.CalculateAverage_PanagiotisDrakatos"
NATIVE_IMAGE_OPTS="--gc=epsilon -O3 -R:MaxHeapSize=10536m --initialize-at-build-time=dev.morling.onebrc.CalculateAverage_PanagiotisDrakatos"
native-image $NATIVE_IMAGE_OPTS -cp target/average-1.0.0-SNAPSHOT.jar -o target/CalculateAverage_PanagiotisDrakatos_image dev.morling.onebrc.CalculateAverage_PanagiotisDrakatos
fi
Original file line number Diff line number Diff line change
Expand Up @@ -20,41 +20,38 @@
import java.io.IOException;
import java.io.RandomAccessFile;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.charset.StandardCharsets;
import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

public class CalculateAverage_PanagiotisDrakatos {

private static final String FILE = "./measurements.txt";
private static final long SEGMENT_SIZE = 4 * 1024 * 1024;
private static final long COMMA_PATTERN = 0x3B3B3B3B3B3B3B3BL;
private static final long DOT_BITS = 0x10101000;
private static final long MAGIC_MULTIPLIER = (100 * 0x1000000 + 10 * 0x10000 + 1);

private static final long MAP_SIZE = 1024 * 1024 * 12L;
private static TreeMap<String, MeasurementObject> sortedCities;

public static void main(String[] args) throws IOException {
SeekableByteRead(FILE);
System.out.println(sortedCities);
System.out.println(sortedCities.toString());
boolean DEBUG = true;
}

private static void SeekableByteRead(String path) throws IOException {
FileInputStream fileInputStream = new FileInputStream(new File(FILE));
FileChannel fileChannel = fileInputStream.getChannel();
Optional<Map<String, MeasurementObject>> optimistic = getFileSegments(new File(FILE), fileChannel)
.stream()
.map(CalculateAverage_PanagiotisDrakatos::SplitSeekableByteChannel)
.parallel()
.map(CalculateAverage_PanagiotisDrakatos::MappingByteBufferToData)
.reduce(CalculateAverage_PanagiotisDrakatos::combineMaps);
try {
sortedCities = getFileSegments(new File(FILE), fileChannel).stream()
.map(CalculateAverage_PanagiotisDrakatos::SplitSeekableByteChannel)
.parallel()
.map(CalculateAverage_PanagiotisDrakatos::MappingByteBufferToData)
.flatMap(MeasurementRepository::get)
.collect(Collectors.toMap(e -> e.cityName, MeasurementRepository.Entry::measurement, MeasurementObject::updateWith, TreeMap::new));
}
catch (NullPointerException e) {
}
fileChannel.close();
sortedCities = new TreeMap<>(optimistic.orElseThrow());

}

record FileSegment(long start, long end, FileChannel fileChannel) {
Expand Down Expand Up @@ -95,14 +92,40 @@ private static long findSegment(RandomAccessFile raf, long location, final long
private static ByteBuffer SplitSeekableByteChannel(FileSegment segment) {
try {
MappedByteBuffer buffer = segment.fileChannel.map(FileChannel.MapMode.READ_ONLY, segment.start(), segment.end - segment.start());
int end = buffer.limit() - 1;
while (buffer.get(end) != '\n') {
end--;
}
return buffer.slice(0, end);
return buffer;
}
catch (Exception ex) {
throw new RuntimeException(ex);
long start = segment.start;
long end = 0;
try {
end = segment.fileChannel.size();
}
catch (IOException e) {
throw new RuntimeException(e);
}
MappedByteBuffer buffer = null;
ArrayList<ByteBuffer> list = new ArrayList<>();
while (start < end) {
try {
buffer = segment.fileChannel.map(FileChannel.MapMode.READ_ONLY, start, Math.min(MAP_SIZE, end - start));
// don't split the data in the middle of lines
// find the closest previous newline
int realEnd = buffer.limit() - 1;
while (buffer.get(realEnd) != '\n')
realEnd--;

realEnd++;
buffer.limit(realEnd);
start += realEnd;
list.add(buffer.slice(0, realEnd - 1));
}
catch (Exception e) {
e.printStackTrace();
}
}
sortedCities = list.stream().parallel().map(CalculateAverage_PanagiotisDrakatos::MappingByteBufferToData).flatMap(MeasurementRepository::get)
.collect(Collectors.toMap(e -> e.cityName, MeasurementRepository.Entry::measurement, MeasurementObject::updateWith, TreeMap::new));
return null;
}
}

Expand All @@ -121,38 +144,61 @@ public static ByteBuffer concat(ByteBuffer[] buffers) {
return all;
}

private static Map<String, MeasurementObject> combineMaps(Map<String, MeasurementObject> map1, Map<String, MeasurementObject> map2) {
for (var entry : map2.entrySet()) {
map1.merge(entry.getKey(), entry.getValue(), MeasurementObject::combine);
}
private static TreeMap<String, MeasurementObject> combineMaps(Stream<MeasurementRepository.Entry> stream1, Stream<MeasurementRepository.Entry> stream2) {
Stream<MeasurementRepository.Entry> resultingStream = Stream.concat(stream1, stream2);
return resultingStream.collect(Collectors.toMap(e -> e.cityName, MeasurementRepository.Entry::measurement, MeasurementObject::updateWith, TreeMap::new));
}

private static int longHashStep(final int hash, final long word) {
return 31 * hash + (int) (word ^ (word >>> 32));
}

private static final long SEPARATOR_PATTERN = compilePattern((byte) ';');

return map1;
private static long compilePattern(final byte value) {
return ((long) value << 56) | ((long) value << 48) | ((long) value << 40) | ((long) value << 32) | ((long) value << 24) | ((long) value << 16)
| ((long) value << 8) | (long) value;
}

private static Map<String, MeasurementObject> MappingByteBufferToData(ByteBuffer byteBuffer) {
Map<String, MeasurementObject> cities = new HashMap<>();
private static MeasurementRepository MappingByteBufferToData(ByteBuffer byteBuffer) {
MeasurementRepository measurements = new MeasurementRepository();
ByteBuffer bb = byteBuffer.duplicate();

int start = 0;
int end = 0;
while (start < bb.limit()) {
while (bb.get(end) != ';') {
end++;
int limit = bb.limit();

long[] cityNameAsLongArray = new long[16];
int[] delimiterPointerAndHash = new int[2];

bb.order(ByteOrder.nativeOrder());
final boolean bufferIsBigEndian = bb.order().equals(ByteOrder.BIG_ENDIAN);

while ((start = bb.position()) < limit + 1) {

int delimiterPointer;

findNextDelimiterAndCalculateHash(bb, SEPARATOR_PATTERN, start, limit, delimiterPointerAndHash, cityNameAsLongArray, bufferIsBigEndian);
delimiterPointer = delimiterPointerAndHash[0];
// Simple lookup is faster for '\n' (just three options)
if (delimiterPointer >= limit) {
return measurements;
}
final int cityNameLength = delimiterPointer - start;

int temp_counter = 0;
int temp_end = end;
int temp_end = delimiterPointer + 1;
try {
bb.position(end);
// bb.position(delimiterPointer++);
while (bb.get(temp_end) != '\n') {
temp_counter++;
temp_end++;
}
}
catch (IndexOutOfBoundsException e) {
temp_counter--;
temp_end--;
// temp_counter--;
// temp_end--;
}
ByteBuffer city = bb.slice(start, end - start);
ByteBuffer temp = bb.slice(end + 1, temp_counter);
ByteBuffer temp = bb.duplicate().slice(delimiterPointer + 1, temp_counter);
int tempPointer = 0;
int abs = 1;
if (temp.get(0) == '-') {
Expand All @@ -167,22 +213,141 @@ private static Map<String, MeasurementObject> MappingByteBufferToData(ByteBuffer
measuredValue = abs * (temp.get(tempPointer) * 100 + temp.get(tempPointer + 1) * 10 + temp.get(tempPointer + 3) - 5328);
}

byte[] citybytes = new byte[city.limit()];
city.get(citybytes);
String cityName = new String(citybytes, StandardCharsets.UTF_8);
measurements.update(cityNameAsLongArray, bb, cityNameLength, delimiterPointerAndHash[1]).updateWith(measuredValue);

if (temp_end + 1 > limit)
return measurements;
bb.position(temp_end + 1);
}
return measurements;
}

private static void findNextDelimiterAndCalculateHash(final ByteBuffer bb, final long pattern, final int start, final int limit, final int[] output,
final long[] asLong, final boolean bufferBigEndian) {
int hash = 1;
int i;
int lCnt = 0;
for (i = start; i <= limit - 8; i += 8) {
long word = bb.getLong(i);
if (bufferBigEndian) {
word = Long.reverseBytes(word); // Reversing the bytes is the cheapest way to do this
}
final long match = word ^ pattern;
long mask = ((match - 0x0101010101010101L) & ~match) & 0x8080808080808080L;

if (mask != 0) {
final int index = Long.numberOfTrailingZeros(mask) >> 3;
output[0] = (i + index);

// update the map with the new measurement
MeasurementObject agg = cities.get(cityName);
if (agg == null) {
cities.put(cityName, new MeasurementObject(measuredValue, measuredValue, 0, 0).updateWith(measuredValue));
final long partialHash = word & ((mask >> 7) - 1);
asLong[lCnt] = partialHash;
output[1] = longHashStep(hash, partialHash);
return;
}
else {
cities.put(cityName, agg.updateWith(measuredValue));
asLong[lCnt++] = word;
hash = longHashStep(hash, word);
}
// Handle remaining bytes near the limit of the buffer:
long partialHash = 0;
int len = 0;
for (; i < limit; i++) {
byte read;
if ((read = bb.get(i)) == (byte) pattern) {
asLong[lCnt] = partialHash;
output[0] = i;
output[1] = longHashStep(hash, partialHash);
return;
}
start = temp_end + 1;
end = temp_end;
partialHash = partialHash | ((long) read << (len << 3));
len++;
}
return cities;
output[0] = limit; // delimiter not found
}

static class MeasurementRepository {
private int tableSize = 1 << 20; // can grow in theory, made large enough not to (this is faster)
private int tableMask = (tableSize - 1);
private int tableLimit = (int) (tableSize * LOAD_FACTOR);
private int tableFilled = 0;
private static final float LOAD_FACTOR = 0.8f;

private Entry[] table = new Entry[tableSize];

record Entry(int hash, long[] nameBytesInLong, String cityName, MeasurementObject measurement) {
@Override
public String toString() {
return cityName + "=" + measurement;
}
}

public MeasurementObject update(long[] nameBytesInLong, ByteBuffer bb, int length, int calculatedHash) {

final int nameBytesInLongLength = 1 + (length >>> 3);

int index = calculatedHash & tableMask;
Entry tableEntry;
while ((tableEntry = table[index]) != null
&& (tableEntry.hash != calculatedHash || !arrayEquals(tableEntry.nameBytesInLong, nameBytesInLong, nameBytesInLongLength))) { // search for the right spot
index = (index + 1) & tableMask;
}

if (tableEntry != null) {
return tableEntry.measurement;
}

// --- This is a brand new entry, insert into the hashtable and do the extra calculations (once!) do slower calculations here.
MeasurementObject measurement = new MeasurementObject();

// Now create a string:
byte[] buffer = new byte[length];
bb.get(buffer, 0, length);
String cityName = new String(buffer, 0, length);

// Store the long[] for faster equals:
long[] nameBytesInLongCopy = new long[nameBytesInLongLength];
System.arraycopy(nameBytesInLong, 0, nameBytesInLongCopy, 0, nameBytesInLongLength);

// And add entry:
Entry toAdd = new Entry(calculatedHash, nameBytesInLongCopy, cityName, measurement);
table[index] = toAdd;

// Resize the table if filled too much:
if (++tableFilled > tableLimit) {
resizeTable();
}

return toAdd.measurement;
}

private void resizeTable() {
// Resize the table:
Entry[] oldEntries = table;
table = new Entry[tableSize <<= 2]; // x2
tableMask = (tableSize - 1);
tableLimit = (int) (tableSize * LOAD_FACTOR);

for (Entry entry : oldEntries) {
if (entry != null) {
int updatedTableIndex = entry.hash & tableMask;
while (table[updatedTableIndex] != null) {
updatedTableIndex = (updatedTableIndex + 1) & tableMask;
}
table[updatedTableIndex] = entry;
}
}
}

public Stream<Entry> get() {
return Arrays.stream(table).filter(Objects::nonNull);
}
}

private static boolean arrayEquals(final long[] a, final long[] b, final int length) {
for (int i = 0; i < length; i++) {
if (a[i] != b[i])
return false;
}
return true;
}

private static final class MeasurementObject {
Expand All @@ -202,6 +367,10 @@ public MeasurementObject(int MAX, int MIN, long SUM, int REPEAT) {
}

public MeasurementObject() {
this.MAX = -999;
this.MIN = 9999;
this.SUM = 0;
this.REPEAT = 0;
}

public MeasurementObject(int MAX, int MIN, long SUM) {
Expand All @@ -224,6 +393,15 @@ public static MeasurementObject combine(MeasurementObject m1, MeasurementObject
return mres;
}

public static MeasurementObject updateWith(MeasurementObject m1, MeasurementObject m2) {
var mres = new MeasurementObject();
mres.MIN = MeasurementObject.min(m1.MIN, m2.MIN);
mres.MAX = MeasurementObject.max(m1.MAX, m2.MAX);
mres.SUM = m1.SUM + m2.SUM;
mres.REPEAT = m1.REPEAT + m2.REPEAT;
return mres;
}

public MeasurementObject updateWith(int measurement) {
MIN = MeasurementObject.min(MIN, measurement);
MAX = MeasurementObject.max(MAX, measurement);
Expand Down Expand Up @@ -268,4 +446,4 @@ public String toString() {
return round(MIN) + "/" + round((1.0 * SUM) / REPEAT) + "/" + round(MAX);
}
}
}
}

0 comments on commit 2aed039

Please sign in to comment.