Skip to content

Commit

Permalink
Added SWAR (SIMD Within A Register) code to increase bytebuffer proce…
Browse files Browse the repository at this point in the history
…ssing/throughput
  • Loading branch information
royvanrijn committed Jan 2, 2024
1 parent 0c22481 commit feb0b8f
Showing 1 changed file with 53 additions and 31 deletions.
84 changes: 53 additions & 31 deletions src/main/java/dev/morling/onebrc/CalculateAverage_royvanrijn.java
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,9 @@ public void run() throws Exception {
}
}

private static final long SEPARATOR_PATTERN = compilePattern((byte) ';');
private static final long EOL_PATTERN = compilePattern((byte) '\n');

private BitTwiddledMap processBuffer(ByteBuffer bb) {

BitTwiddledMap measurements = new BitTwiddledMap();
Expand All @@ -134,49 +137,68 @@ private BitTwiddledMap processBuffer(ByteBuffer bb) {

while (bb.position() < limit) {

// Find the correct positions in the bytebuffer:

// Start:
final int startPointer = bb.position();

// Separator:
int separatorPointer = startPointer + 3; // key is at least 3 long
while (separatorPointer != limit && bb.get(separatorPointer) != ';') {
separatorPointer++;
}
int separatorPointer = findNextSWAR(bb, SEPARATOR_PATTERN, startPointer + 3, limit);
int endPointer = findNextSWAR(bb, EOL_PATTERN, separatorPointer + 3, limit);

// EOL:
int endPointer = separatorPointer + 3; // temperature is at least 3 long
while (endPointer != limit && bb.get(endPointer) != '\n')
endPointer++;
final int entryLength = endPointer - startPointer;

// Extract the name of the key and move the bytebuffer:
// Read the entry in a single get():
bb.get(buffer, 0, entryLength);
bb.position(endPointer + 1); // skip the separator

// Extract the name of the key:
final int nameLength = separatorPointer - startPointer;
bb.get(buffer, 0, nameLength);
final String key = new String(buffer, 0, nameLength);

bb.get(); // skip the separator

// Extract the measurement value (10x), skip making a String altogether:
final int valueLength = endPointer - separatorPointer - 1;
bb.get(buffer, 0, valueLength);
final int measured = branchlessParseInt(buffer, nameLength + 1, valueLength);

// and get rid of the new line (handle both kinds)
byte newline = bb.get();
if (newline == '\r')
bb.get();

int measured = branchlessParseInt(buffer, valueLength);

// Update the map, computeIfAbsent has the least amount of branches I think, compared to get()/put() or merge() or compute():
measurements.getOrCreate(key).updateWith(measured);
}

return measurements;
}

/**
* Thanks to bjhara for the idea of using memory mapped files, TIL.
* -------- This section contains SWAR code (SIMD Within A Register) which processes a bytebuffer as longs to find values:
*/

private int findNextSWAR(ByteBuffer bb, long pattern, int start, int limit) {
int i;
for (i = start; i <= limit - 8; i += 8) {
long word = bb.getLong(i);
int index = firstAnyPattern(word, pattern);
if (index < Long.BYTES) {
return i + index;
}
}
// Handle remaining bytes
for (; i < limit; i++) {
if (bb.get(i) == (byte) pattern) {
return i;
}
}
return limit; // delimiter not found
}

private static long compilePattern(byte value) {
return ((long) value << 56) | ((long) value << 48) | ((long) value << 40) | ((long) value << 32) |
((long) value << 24) | ((long) value << 16) | ((long) value << 8) | (long) value;
}

private static int firstAnyPattern(long word, long pattern) {
final long match = word ^ pattern;
long mask = match - 0x0101010101010101L;
mask &= ~match;
mask &= 0x8080808080808080L;
return Long.numberOfLeadingZeros(mask) >> 3;
}

/**
* -------- Thanks to bjhara for the idea of using memory mapped files.
* @param fileChannel
* @return
* @throws IOException
Expand Down Expand Up @@ -224,15 +246,15 @@ public boolean hasNext() {
* @param input
* @return int value x10
*/
private static int branchlessParseInt(final byte[] input, int length) {
private static int branchlessParseInt(final byte[] input, int start, int length) {
// 0 if positive, 1 if negative
final int negative = ~(input[0] >> 4) & 1;
final int negative = ~(input[start] >> 4) & 1;
// 0 if nr length is 3, 1 if length is 4
final int has4 = ((length - negative) >> 2) & 1;

final int digit1 = input[negative] - '0';
final int digit2 = input[negative + has4] - '0';
final int digit3 = input[2 + negative + has4] - '0';
final int digit1 = input[start + negative] - '0';
final int digit2 = input[start + negative + has4] - '0';
final int digit3 = input[start + 2 + negative + has4] - '0';

return (-negative ^ (has4 * (digit1 * 100) + digit2 * 10 + digit3) - negative);
}
Expand Down

0 comments on commit feb0b8f

Please sign in to comment.