Skip to content

Commit

Permalink
One last improvement for thomaswue (#702)
Browse files Browse the repository at this point in the history
* Combine <8 and 8-16 cases into one case.

* Adopt mask-based approach for the <16 length city fast path (idea of Van Phu Do).

* Slightly improved code layout.

* Update perf number.
  • Loading branch information
thomaswue authored Feb 1, 2024
1 parent 4debc7c commit 241d42c
Showing 1 changed file with 66 additions and 61 deletions.
127 changes: 66 additions & 61 deletions src/main/java/dev/morling/onebrc/CalculateAverage_thomaswue.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,14 @@
* split into 3 parts and cursors for each of those parts are processing the segment simultaneously in the same thread.
* Results are accumulated into {@link Result} objects and a tree map is used to sequentially accumulate the results in
* the end.
* Runs in 0.39s on an Intel i9-13900K.
* Runs in 0.31 on an Intel i9-13900K while the reference implementation takes 120.37s.
* Credit:
* Quan Anh Mai for branchless number parsing code
* Alfonso² Peterssen for suggesting memory mapping with unsafe and the subprocess idea
* Artsiom Korzun for showing the benefits of work stealing at 2MB segments instead of equal split between workers
* Jaromir Hamala for showing that avoiding the branch misprediction between <8 and 8-16 cases is a big win even if
* more work is performed
* Van Phu DO for demonstrating the lookup tables based on masks instead of bit shifting
*/
public class CalculateAverage_thomaswue {
private static final String FILE = "./measurements.txt";
Expand Down Expand Up @@ -141,9 +144,15 @@ private static void parseLoop(AtomicLong counter, long fileEnd, long fileStart,
long delimiterMask1 = findDelimiter(word1);
long delimiterMask2 = findDelimiter(word2);
long delimiterMask3 = findDelimiter(word3);
Result existingResult1 = findResult(word1, delimiterMask1, scanner1, results, collectedResults);
Result existingResult2 = findResult(word2, delimiterMask2, scanner2, results, collectedResults);
Result existingResult3 = findResult(word3, delimiterMask3, scanner3, results, collectedResults);
long word1b = scanner1.getLongAt(scanner1.pos() + 8);

This comment has been minimized.

Copy link
@Alex-Cheng

Alex-Cheng Feb 23, 2024

@thomaswue could please teach me the knowledge about why dividing the data into 3 parts can improve the performance. Thank you. My email is [email protected]

This comment has been minimized.

Copy link
@gunnarmorling

gunnarmorling Feb 23, 2024

Owner

Hey @Alex-Cheng, I think it's asking a bit too much of folks for sending you this kind of advice via 1:1 email. I suggest you take a look at the blog posts listed here, they dive into many of the techniques employed. As for this particular one, we've discussed it in this live stream. Essentially, as there's no data dependency between these different chunks of the file, a modern CPU can parallelize the processing, thus reducing overall wallclock time.

This comment has been minimized.

Copy link
@Alex-Cheng

Alex-Cheng Feb 29, 2024

Thank you. I have read some of the resources you gave me and have followed you in twitter. Hope more such a great event. :-)

long word2b = scanner2.getLongAt(scanner2.pos() + 8);
long word3b = scanner3.getLongAt(scanner3.pos() + 8);
long delimiterMask1b = findDelimiter(word1b);
long delimiterMask2b = findDelimiter(word2b);
long delimiterMask3b = findDelimiter(word3b);
Result existingResult1 = findResult(word1, delimiterMask1, word1b, delimiterMask1b, scanner1, results, collectedResults);
Result existingResult2 = findResult(word2, delimiterMask2, word2b, delimiterMask2b, scanner2, results, collectedResults);
Result existingResult3 = findResult(word3, delimiterMask3, word3b, delimiterMask3b, scanner3, results, collectedResults);
long number1 = scanNumber(scanner1);
long number2 = scanNumber(scanner2);
long number3 = scanNumber(scanner3);
Expand All @@ -155,76 +164,70 @@ private static void parseLoop(AtomicLong counter, long fileEnd, long fileStart,
while (scanner1.hasNext()) {
long word = scanner1.getLong();
long pos = findDelimiter(word);
record(findResult(word, pos, scanner1, results, collectedResults), scanNumber(scanner1));
long wordB = scanner1.getLongAt(scanner1.pos() + 8);
long posB = findDelimiter(wordB);
record(findResult(word, pos, wordB, posB, scanner1, results, collectedResults), scanNumber(scanner1));
}
while (scanner2.hasNext()) {
long word = scanner2.getLong();
long pos = findDelimiter(word);
record(findResult(word, pos, scanner2, results, collectedResults), scanNumber(scanner2));
long wordB = scanner2.getLongAt(scanner2.pos() + 8);
long posB = findDelimiter(wordB);
record(findResult(word, pos, wordB, posB, scanner2, results, collectedResults), scanNumber(scanner2));
}
while (scanner3.hasNext()) {
long word = scanner3.getLong();
long pos = findDelimiter(word);
record(findResult(word, pos, scanner3, results, collectedResults), scanNumber(scanner3));
long wordB = scanner3.getLongAt(scanner3.pos() + 8);
long posB = findDelimiter(wordB);
record(findResult(word, pos, wordB, posB, scanner3, results, collectedResults), scanNumber(scanner3));
}
}
}

private static Result findResult(long initialWord, long initialDelimiterMask, Scanner scanner, Result[] results, List<Result> collectedResults) {
private static final long[] MASK1 = new long[]{ 0xFFL, 0xFFFFL, 0xFFFFFFL, 0xFFFFFFFFL, 0xFFFFFFFFFFL, 0xFFFFFFFFFFFFL, 0xFFFFFFFFFFFFFFL, 0xFFFFFFFFFFFFFFFFL,
0xFFFFFFFFFFFFFFFFL };
private static final long[] MASK2 = new long[]{ 0x00L, 0x00L, 0x00L, 0x00L, 0x00L, 0x00L, 0x00L, 0x00L, 0xFFFFFFFFFFFFFFFFL };

private static Result findResult(long initialWord, long initialDelimiterMask, long wordB, long delimiterMaskB, Scanner scanner, Result[] results,
List<Result> collectedResults) {
Result existingResult;
long word = initialWord;
long delimiterMask = initialDelimiterMask;
long hash;
long nameAddress = scanner.pos();

// Search for ';', one long at a time. There are two common cases that a specially treated:
// (b) the ';' is found in the first 16 bytes
if (delimiterMask != 0) {
// Special case for when the ';' is found in the first 8 bytes.
int trailingZeros = Long.numberOfTrailingZeros(delimiterMask);
word = (word << (63 - trailingZeros));
scanner.add(trailingZeros >>> 3);
hash = word;
long word2 = wordB;
long delimiterMask2 = delimiterMaskB;
if ((delimiterMask | delimiterMask2) != 0) {
int letterCount1 = Long.numberOfTrailingZeros(delimiterMask) >>> 3; // value between 1 and 8
int letterCount2 = Long.numberOfTrailingZeros(delimiterMask2) >>> 3; // value between 0 and 8
long mask = MASK2[letterCount1];
word = word & MASK1[letterCount1];
word2 = mask & word2 & MASK1[letterCount2];
hash = word ^ word2;
existingResult = results[hashToIndex(hash, results)];
if (existingResult != null && existingResult.lastNameLong == word) {
scanner.add(letterCount1 + (letterCount2 & mask));
if (existingResult != null && existingResult.firstNameWord == word && existingResult.secondNameWord == word2) {
return existingResult;
}
}
else {
// Special case for when the ';' is found in bytes 9-16.
hash = word;
long prevWord = word;
scanner.add(8);
word = scanner.getLong();
delimiterMask = findDelimiter(word);
if (delimiterMask != 0) {
int trailingZeros = Long.numberOfTrailingZeros(delimiterMask);
word = (word << (63 - trailingZeros));
scanner.add(trailingZeros >>> 3);
hash ^= word;
existingResult = results[hashToIndex(hash, results)];
if (existingResult != null && existingResult.lastNameLong == word && existingResult.secondLastNameLong == prevWord) {
return existingResult;
// Slow-path for when the ';' could not be found in the first 16 bytes.
hash = word ^ word2;
scanner.add(16);
while (true) {
word = scanner.getLong();
delimiterMask = findDelimiter(word);
if (delimiterMask != 0) {
int trailingZeros = Long.numberOfTrailingZeros(delimiterMask);
word = (word << (63 - trailingZeros));
scanner.add(trailingZeros >>> 3);
hash ^= word;
break;
}
}
else {
// Slow-path for when the ';' could not be found in the first 16 bytes.
scanner.add(8);
hash ^= word;
while (true) {
word = scanner.getLong();
delimiterMask = findDelimiter(word);
if (delimiterMask != 0) {
int trailingZeros = Long.numberOfTrailingZeros(delimiterMask);
word = (word << (63 - trailingZeros));
scanner.add(trailingZeros >>> 3);
hash ^= word;
break;
}
else {
scanner.add(8);
hash ^= word;
}
else {
scanner.add(8);
hash ^= word;
}
}
}
Expand All @@ -249,8 +252,8 @@ private static Result findResult(long initialWord, long initialDelimiterMask, Sc
}
}

int remainingShift = (64 - (nameLength + 1 - i) << 3);
if (existingResult.lastNameLong == (scanner.getLongAt(nameAddress + i) << remainingShift)) {
int remainingShift = (64 - ((nameLength + 1 - i) << 3));
if (((scanner.getLongAt(existingResult.nameAddress + i) ^ (scanner.getLongAt(nameAddress + i))) << remainingShift) == 0) {
break;
}
else {
Expand Down Expand Up @@ -297,7 +300,7 @@ private static void record(Result existingResult, long number) {
}

private static int hashToIndex(long hash, Result[] results) {
long hashAsInt = hash ^ (hash >>> 37) ^ (hash >>> 17);
long hashAsInt = hash ^ (hash >>> 33) ^ (hash >>> 15);
return (int) (hashAsInt & (results.length - 1));
}

Expand All @@ -324,21 +327,23 @@ private static long findDelimiter(long word) {
private static Result newEntry(Result[] results, long nameAddress, int hash, int nameLength, Scanner scanner, List<Result> collectedResults) {
Result r = new Result();
results[hash] = r;
int i = 0;
for (; i < nameLength + 1 - Long.BYTES; i += Long.BYTES) {
int totalLength = nameLength + 1;
r.firstNameWord = scanner.getLongAt(nameAddress);
r.secondNameWord = scanner.getLongAt(nameAddress + 8);
if (totalLength <= 8) {
r.firstNameWord = r.firstNameWord & MASK1[totalLength - 1];
r.secondNameWord = 0;
}
if (nameLength + 1 > 8) {
r.secondLastNameLong = scanner.getLongAt(nameAddress + i - 8);
else if (totalLength < 16) {
r.secondNameWord = r.secondNameWord & MASK1[totalLength - 9];
}
int remainingShift = (64 - (nameLength + 1 - i) << 3);
r.lastNameLong = (scanner.getLongAt(nameAddress + i) << remainingShift);
r.nameAddress = nameAddress;
collectedResults.add(r);
return r;
}

private static final class Result {
long lastNameLong, secondLastNameLong;
long firstNameWord, secondNameWord;
short min, max;
int count;
long sum;
Expand Down

0 comments on commit 241d42c

Please sign in to comment.