-
Notifications
You must be signed in to change notification settings - Fork 2k
Commit
* Combine <8 and 8-16 cases into one case. * Adopt mask-based approach for the <16 length city fast path (idea of Van Phu Do). * Slightly improved code layout. * Update perf number.
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,11 +27,14 @@ | |
* split into 3 parts and cursors for each of those parts are processing the segment simultaneously in the same thread. | ||
* Results are accumulated into {@link Result} objects and a tree map is used to sequentially accumulate the results in | ||
* the end. | ||
* Runs in 0.39s on an Intel i9-13900K. | ||
* Runs in 0.31 on an Intel i9-13900K while the reference implementation takes 120.37s. | ||
* Credit: | ||
* Quan Anh Mai for branchless number parsing code | ||
* Alfonso² Peterssen for suggesting memory mapping with unsafe and the subprocess idea | ||
* Artsiom Korzun for showing the benefits of work stealing at 2MB segments instead of equal split between workers | ||
* Jaromir Hamala for showing that avoiding the branch misprediction between <8 and 8-16 cases is a big win even if | ||
* more work is performed | ||
* Van Phu DO for demonstrating the lookup tables based on masks instead of bit shifting | ||
*/ | ||
public class CalculateAverage_thomaswue { | ||
private static final String FILE = "./measurements.txt"; | ||
|
@@ -141,9 +144,15 @@ private static void parseLoop(AtomicLong counter, long fileEnd, long fileStart, | |
long delimiterMask1 = findDelimiter(word1); | ||
long delimiterMask2 = findDelimiter(word2); | ||
long delimiterMask3 = findDelimiter(word3); | ||
Result existingResult1 = findResult(word1, delimiterMask1, scanner1, results, collectedResults); | ||
Result existingResult2 = findResult(word2, delimiterMask2, scanner2, results, collectedResults); | ||
Result existingResult3 = findResult(word3, delimiterMask3, scanner3, results, collectedResults); | ||
long word1b = scanner1.getLongAt(scanner1.pos() + 8); | ||
This comment has been minimized.
Sorry, something went wrong.
This comment has been minimized.
Sorry, something went wrong.
gunnarmorling
Owner
|
||
long word2b = scanner2.getLongAt(scanner2.pos() + 8); | ||
long word3b = scanner3.getLongAt(scanner3.pos() + 8); | ||
long delimiterMask1b = findDelimiter(word1b); | ||
long delimiterMask2b = findDelimiter(word2b); | ||
long delimiterMask3b = findDelimiter(word3b); | ||
Result existingResult1 = findResult(word1, delimiterMask1, word1b, delimiterMask1b, scanner1, results, collectedResults); | ||
Result existingResult2 = findResult(word2, delimiterMask2, word2b, delimiterMask2b, scanner2, results, collectedResults); | ||
Result existingResult3 = findResult(word3, delimiterMask3, word3b, delimiterMask3b, scanner3, results, collectedResults); | ||
long number1 = scanNumber(scanner1); | ||
long number2 = scanNumber(scanner2); | ||
long number3 = scanNumber(scanner3); | ||
|
@@ -155,76 +164,70 @@ private static void parseLoop(AtomicLong counter, long fileEnd, long fileStart, | |
while (scanner1.hasNext()) { | ||
long word = scanner1.getLong(); | ||
long pos = findDelimiter(word); | ||
record(findResult(word, pos, scanner1, results, collectedResults), scanNumber(scanner1)); | ||
long wordB = scanner1.getLongAt(scanner1.pos() + 8); | ||
long posB = findDelimiter(wordB); | ||
record(findResult(word, pos, wordB, posB, scanner1, results, collectedResults), scanNumber(scanner1)); | ||
} | ||
while (scanner2.hasNext()) { | ||
long word = scanner2.getLong(); | ||
long pos = findDelimiter(word); | ||
record(findResult(word, pos, scanner2, results, collectedResults), scanNumber(scanner2)); | ||
long wordB = scanner2.getLongAt(scanner2.pos() + 8); | ||
long posB = findDelimiter(wordB); | ||
record(findResult(word, pos, wordB, posB, scanner2, results, collectedResults), scanNumber(scanner2)); | ||
} | ||
while (scanner3.hasNext()) { | ||
long word = scanner3.getLong(); | ||
long pos = findDelimiter(word); | ||
record(findResult(word, pos, scanner3, results, collectedResults), scanNumber(scanner3)); | ||
long wordB = scanner3.getLongAt(scanner3.pos() + 8); | ||
long posB = findDelimiter(wordB); | ||
record(findResult(word, pos, wordB, posB, scanner3, results, collectedResults), scanNumber(scanner3)); | ||
} | ||
} | ||
} | ||
|
||
private static Result findResult(long initialWord, long initialDelimiterMask, Scanner scanner, Result[] results, List<Result> collectedResults) { | ||
private static final long[] MASK1 = new long[]{ 0xFFL, 0xFFFFL, 0xFFFFFFL, 0xFFFFFFFFL, 0xFFFFFFFFFFL, 0xFFFFFFFFFFFFL, 0xFFFFFFFFFFFFFFL, 0xFFFFFFFFFFFFFFFFL, | ||
0xFFFFFFFFFFFFFFFFL }; | ||
private static final long[] MASK2 = new long[]{ 0x00L, 0x00L, 0x00L, 0x00L, 0x00L, 0x00L, 0x00L, 0x00L, 0xFFFFFFFFFFFFFFFFL }; | ||
|
||
private static Result findResult(long initialWord, long initialDelimiterMask, long wordB, long delimiterMaskB, Scanner scanner, Result[] results, | ||
List<Result> collectedResults) { | ||
Result existingResult; | ||
long word = initialWord; | ||
long delimiterMask = initialDelimiterMask; | ||
long hash; | ||
long nameAddress = scanner.pos(); | ||
|
||
// Search for ';', one long at a time. There are two common cases that a specially treated: | ||
// (b) the ';' is found in the first 16 bytes | ||
if (delimiterMask != 0) { | ||
// Special case for when the ';' is found in the first 8 bytes. | ||
int trailingZeros = Long.numberOfTrailingZeros(delimiterMask); | ||
word = (word << (63 - trailingZeros)); | ||
scanner.add(trailingZeros >>> 3); | ||
hash = word; | ||
long word2 = wordB; | ||
long delimiterMask2 = delimiterMaskB; | ||
if ((delimiterMask | delimiterMask2) != 0) { | ||
int letterCount1 = Long.numberOfTrailingZeros(delimiterMask) >>> 3; // value between 1 and 8 | ||
int letterCount2 = Long.numberOfTrailingZeros(delimiterMask2) >>> 3; // value between 0 and 8 | ||
long mask = MASK2[letterCount1]; | ||
word = word & MASK1[letterCount1]; | ||
word2 = mask & word2 & MASK1[letterCount2]; | ||
hash = word ^ word2; | ||
existingResult = results[hashToIndex(hash, results)]; | ||
if (existingResult != null && existingResult.lastNameLong == word) { | ||
scanner.add(letterCount1 + (letterCount2 & mask)); | ||
if (existingResult != null && existingResult.firstNameWord == word && existingResult.secondNameWord == word2) { | ||
return existingResult; | ||
} | ||
} | ||
else { | ||
// Special case for when the ';' is found in bytes 9-16. | ||
hash = word; | ||
long prevWord = word; | ||
scanner.add(8); | ||
word = scanner.getLong(); | ||
delimiterMask = findDelimiter(word); | ||
if (delimiterMask != 0) { | ||
int trailingZeros = Long.numberOfTrailingZeros(delimiterMask); | ||
word = (word << (63 - trailingZeros)); | ||
scanner.add(trailingZeros >>> 3); | ||
hash ^= word; | ||
existingResult = results[hashToIndex(hash, results)]; | ||
if (existingResult != null && existingResult.lastNameLong == word && existingResult.secondLastNameLong == prevWord) { | ||
return existingResult; | ||
// Slow-path for when the ';' could not be found in the first 16 bytes. | ||
hash = word ^ word2; | ||
scanner.add(16); | ||
while (true) { | ||
word = scanner.getLong(); | ||
delimiterMask = findDelimiter(word); | ||
if (delimiterMask != 0) { | ||
int trailingZeros = Long.numberOfTrailingZeros(delimiterMask); | ||
word = (word << (63 - trailingZeros)); | ||
scanner.add(trailingZeros >>> 3); | ||
hash ^= word; | ||
break; | ||
} | ||
} | ||
else { | ||
// Slow-path for when the ';' could not be found in the first 16 bytes. | ||
scanner.add(8); | ||
hash ^= word; | ||
while (true) { | ||
word = scanner.getLong(); | ||
delimiterMask = findDelimiter(word); | ||
if (delimiterMask != 0) { | ||
int trailingZeros = Long.numberOfTrailingZeros(delimiterMask); | ||
word = (word << (63 - trailingZeros)); | ||
scanner.add(trailingZeros >>> 3); | ||
hash ^= word; | ||
break; | ||
} | ||
else { | ||
scanner.add(8); | ||
hash ^= word; | ||
} | ||
else { | ||
scanner.add(8); | ||
hash ^= word; | ||
} | ||
} | ||
} | ||
|
@@ -249,8 +252,8 @@ private static Result findResult(long initialWord, long initialDelimiterMask, Sc | |
} | ||
} | ||
|
||
int remainingShift = (64 - (nameLength + 1 - i) << 3); | ||
if (existingResult.lastNameLong == (scanner.getLongAt(nameAddress + i) << remainingShift)) { | ||
int remainingShift = (64 - ((nameLength + 1 - i) << 3)); | ||
if (((scanner.getLongAt(existingResult.nameAddress + i) ^ (scanner.getLongAt(nameAddress + i))) << remainingShift) == 0) { | ||
break; | ||
} | ||
else { | ||
|
@@ -297,7 +300,7 @@ private static void record(Result existingResult, long number) { | |
} | ||
|
||
private static int hashToIndex(long hash, Result[] results) { | ||
long hashAsInt = hash ^ (hash >>> 37) ^ (hash >>> 17); | ||
long hashAsInt = hash ^ (hash >>> 33) ^ (hash >>> 15); | ||
return (int) (hashAsInt & (results.length - 1)); | ||
} | ||
|
||
|
@@ -324,21 +327,23 @@ private static long findDelimiter(long word) { | |
private static Result newEntry(Result[] results, long nameAddress, int hash, int nameLength, Scanner scanner, List<Result> collectedResults) { | ||
Result r = new Result(); | ||
results[hash] = r; | ||
int i = 0; | ||
for (; i < nameLength + 1 - Long.BYTES; i += Long.BYTES) { | ||
int totalLength = nameLength + 1; | ||
r.firstNameWord = scanner.getLongAt(nameAddress); | ||
r.secondNameWord = scanner.getLongAt(nameAddress + 8); | ||
if (totalLength <= 8) { | ||
r.firstNameWord = r.firstNameWord & MASK1[totalLength - 1]; | ||
r.secondNameWord = 0; | ||
} | ||
if (nameLength + 1 > 8) { | ||
r.secondLastNameLong = scanner.getLongAt(nameAddress + i - 8); | ||
else if (totalLength < 16) { | ||
r.secondNameWord = r.secondNameWord & MASK1[totalLength - 9]; | ||
} | ||
int remainingShift = (64 - (nameLength + 1 - i) << 3); | ||
r.lastNameLong = (scanner.getLongAt(nameAddress + i) << remainingShift); | ||
r.nameAddress = nameAddress; | ||
collectedResults.add(r); | ||
return r; | ||
} | ||
|
||
private static final class Result { | ||
long lastNameLong, secondLastNameLong; | ||
long firstNameWord, secondNameWord; | ||
short min, max; | ||
int count; | ||
long sum; | ||
|
@thomaswue could please teach me the knowledge about why dividing the data into 3 parts can improve the performance. Thank you. My email is [email protected]