From 101993f06d1e63e3d56ab57483ff11a3349c47aa Mon Sep 17 00:00:00 2001 From: Anita SV Date: Thu, 1 Feb 2024 03:15:23 -0800 Subject: [PATCH] CA_vaidhy final changes. (#708) --- .../onebrc/CalculateAverage_vaidhy.java | 367 +++++++++++++----- 1 file changed, 272 insertions(+), 95 deletions(-) diff --git a/src/main/java/dev/morling/onebrc/CalculateAverage_vaidhy.java b/src/main/java/dev/morling/onebrc/CalculateAverage_vaidhy.java index 5795077b3..f63374a10 100644 --- a/src/main/java/dev/morling/onebrc/CalculateAverage_vaidhy.java +++ b/src/main/java/dev/morling/onebrc/CalculateAverage_vaidhy.java @@ -21,6 +21,7 @@ import java.lang.foreign.Arena; import java.lang.reflect.Field; import java.nio.ByteBuffer; +import java.nio.ByteOrder; import java.nio.channels.FileChannel; import java.nio.charset.StandardCharsets; import java.nio.file.Path; @@ -37,69 +38,149 @@ public class CalculateAverage_vaidhy { private static final class HashEntry { private long startAddress; - private long endAddress; + private long keyLength; private long suffix; - private int hash; - + private int next; IntSummaryStatistics value; } private static class PrimitiveHashMap { private final HashEntry[] entries; + private final long[] hashes; + private final int twoPow; + private int next = -1; PrimitiveHashMap(int twoPow) { this.twoPow = twoPow; this.entries = new HashEntry[1 << twoPow]; + this.hashes = new long[1 << twoPow]; for (int i = 0; i < entries.length; i++) { this.entries[i] = new HashEntry(); } } - public HashEntry find(long startAddress, long endAddress, long suffix, int hash) { + public IntSummaryStatistics find(long startAddress, long endAddress, long hash, long suffix) { int len = entries.length; - int i = (hash ^ (hash >> twoPow)) & (len - 1); + int h = Long.hashCode(hash); + int initialIndex = (h ^ (h >> twoPow)) & (len - 1); + int i = initialIndex; + long lookupLength = endAddress - startAddress; - do { + long hashEntry = hashes[i]; + + if (hashEntry == hash) { HashEntry entry = entries[i]; - if (entry.value == null) { - return entry; + if (lookupLength <= 7) { + // This works because + // hash = suffix , when simpleHash is just xor. + // Since length is not 8, suffix will have a 0 at the end. + // Since utf-8 strings can't have 0 in middle of a string this means + // we can stop here. + return entry.value; } - if (entry.hash == hash) { - long entryLength = entry.endAddress - entry.startAddress; - long lookupLength = endAddress - startAddress; - if ((entryLength == lookupLength) && (entry.suffix == suffix)) { - boolean found = compareEntryKeys(startAddress, endAddress, entry); - - if (found) { - return entry; - } + boolean found = (entry.suffix == suffix && + compareEntryKeys(startAddress, endAddress, entry.startAddress)); + if (found) { + return entry.value; + } + } + + if (hashEntry == 0) { + HashEntry entry = entries[i]; + entry.startAddress = startAddress; + entry.keyLength = lookupLength; + hashes[i] = hash; + entry.suffix = suffix; + entry.next = next; + this.next = i; + entry.value = new IntSummaryStatistics(); + return entry.value; + } + + i++; + if (i == len) { + i = 0; + } + + if (i == initialIndex) { + return null; + } + + do { + hashEntry = hashes[i]; + if (hashEntry == hash) { + HashEntry entry = entries[i]; + if (lookupLength <= 7) { + return entry.value; + } + boolean found = (entry.suffix == suffix && + compareEntryKeys(startAddress, endAddress, entry.startAddress)); + if (found) { + return entry.value; } } + if (hashEntry == 0) { + HashEntry entry = entries[i]; + entry.startAddress = startAddress; + entry.keyLength = lookupLength; + hashes[i] = hash; + entry.suffix = suffix; + entry.next = next; + this.next = i; + entry.value = new IntSummaryStatistics(); + return entry.value; + } + i++; if (i == len) { i = 0; } - } while (i != hash); + } while (i != initialIndex); return null; } - private static boolean compareEntryKeys(long startAddress, long endAddress, HashEntry entry) { - long entryIndex = entry.startAddress; + private static boolean compareEntryKeys(long startAddress, long endAddress, long entryStartAddress) { + long entryIndex = entryStartAddress; long lookupIndex = startAddress; + long endAddressStop = endAddress - 7; - for (; (lookupIndex + 7) < endAddress; lookupIndex += 8) { + for (; lookupIndex < endAddressStop; lookupIndex += 8) { if (UNSAFE.getLong(entryIndex) != UNSAFE.getLong(lookupIndex)) { return false; } entryIndex += 8; } + return true; } + + public Iterable entrySet() { + return () -> new Iterator<>() { + int scan = next; + + @Override + public boolean hasNext() { + return scan != -1; + } + + @Override + public HashEntry next() { + HashEntry entry = entries[scan]; + scan = entry.next; + return entry; + } + }; + } } private static final String FILE = "./measurements.txt"; + private static long simpleHash(long hash, long nextData) { + return hash ^ nextData; + // return (hash ^ Long.rotateLeft((nextData * C1), R1)) * C2; + } + private static Unsafe initUnsafe() { try { Field theUnsafe = Unsafe.class.getDeclaredField("theUnsafe"); @@ -145,7 +226,7 @@ private static int parseDouble(long startAddress, long endAddress) { interface MapReduce { - void process(long keyStartAddress, long keyEndAddress, int hash, int temperature, long suffix); + void process(long keyStartAddress, long keyEndAddress, long hash, long suffix, int temperature); I result(); } @@ -173,9 +254,13 @@ static class LineStream { private final long chunkEnd; private long position; - private int hash; + private long hash; + private long suffix; - byte[] b = new byte[4]; + + private final ByteBuffer buf = ByteBuffer + .allocate(8) + .order(ByteOrder.LITTLE_ENDIAN); public LineStream(FileService fileService, long offset, long chunkSize) { long fileStart = fileService.address(); @@ -186,50 +271,38 @@ public LineStream(FileService fileService, long offset, long chunkSize) { } public boolean hasNext() { - return position <= chunkEnd && position < fileEnd; + return position <= chunkEnd; } public long findSemi() { - int h = 0; - long s = 0; - long i = position; - while ((i + 3) < fileEnd) { - // Adding 16 as it is the offset for primitive arrays - ByteBuffer.wrap(b).putInt(UNSAFE.getInt(i)); - - if (b[3] == 0x3B) { - break; - } - i++; - h = ((h << 5) - h) ^ b[3]; - s = (s << 8) ^ b[3]; + long h = 0; + buf.rewind(); - if (b[2] == 0x3B) { - break; + for (long i = position; i < fileEnd; i++) { + byte ch = UNSAFE.getByte(i); + if (ch == ';') { + int discard = buf.remaining(); + buf.rewind(); + long nextData = (buf.getLong() << discard) >>> discard; + this.suffix = nextData; + this.hash = simpleHash(h, nextData); + position = i + 1; + return i; } - i++; - h = ((h << 5) - h) ^ b[2]; - s = (s << 8) ^ b[2]; - - if (b[1] == 0x3B) { - break; + if (buf.hasRemaining()) { + buf.put(ch); } - i++; - h = ((h << 5) - h) ^ b[1]; - s = (s << 8) ^ b[1]; - - if (b[0] == 0x3B) { - break; + else { + buf.flip(); + long nextData = buf.getLong(); + h = simpleHash(h, nextData); + buf.rewind(); } - i++; - h = ((h << 5) - h) ^ b[0]; - s = (s << 8) ^ b[0]; } - this.hash = h; - this.suffix = s; - position = i + 1; - return i; + this.suffix = buf.getLong(); + position = fileEnd; + return fileEnd; } public long skipLine() { @@ -258,7 +331,94 @@ public long findTemperature() { } } - private void worker(long offset, long chunkSize, MapReduce lineConsumer) { + private static final long START_BYTE_INDICATOR = 0x0101_0101_0101_0101L; + private static final long END_BYTE_INDICATOR = START_BYTE_INDICATOR << 7; + + private static final long NEW_LINE_DETECTION = START_BYTE_INDICATOR * '\n'; + + private static final long SEMI_DETECTION = START_BYTE_INDICATOR * ';'; + + private static final long ALL_ONES = 0xffff_ffff_ffff_ffffL; + + private long findByteOctet(long data, long pattern) { + long match = data ^ pattern; + return (match - START_BYTE_INDICATOR) & ((~match) & END_BYTE_INDICATOR); + } + + private void bigWorker(long offset, long chunkSize, MapReduce lineConsumer) { + long chunkStart = offset + fileService.address(); + long chunkEnd = chunkStart + chunkSize; + long fileEnd = fileService.address() + fileService.length(); + long stopPoint = Math.min(chunkEnd + 1, fileEnd); + + boolean skip = offset != 0; + for (long position = chunkStart; position < stopPoint;) { + if (skip) { + long data = UNSAFE.getLong(position); + long newLineMask = findByteOctet(data, NEW_LINE_DETECTION); + if (newLineMask != 0) { + int newLinePosition = Long.numberOfTrailingZeros(newLineMask) >>> 3; + skip = false; + position = position + newLinePosition + 1; + } + else { + position = position + 8; + } + continue; + } + + long stationStart = position; + long stationEnd = -1; + long hash = 0; + long suffix = 0; + do { + long data = UNSAFE.getLong(position); + long semiMask = findByteOctet(data, SEMI_DETECTION); + if (semiMask != 0) { + int semiPosition = Long.numberOfTrailingZeros(semiMask) >>> 3; + stationEnd = position + semiPosition; + position = stationEnd + 1; + + if (semiPosition != 0) { + suffix = data & (ALL_ONES >>> (64 - (semiPosition << 3))); + } + else { + suffix = UNSAFE.getLong(position - 8); + } + hash = simpleHash(hash, suffix); + break; + } + else { + hash = simpleHash(hash, data); + position = position + 8; + } + } while (true); + + int temperature = 0; + { + byte ch = UNSAFE.getByte(position++); + boolean negative = false; + if (ch == '-') { + negative = true; + ch = UNSAFE.getByte(position++); + } + do { + if (ch != '.') { + temperature *= 10; + temperature += (ch ^ '0'); + } + ch = UNSAFE.getByte(position++); + } while (ch != '\n'); + if (negative) { + temperature = -temperature; + } + } + + lineConsumer.process(stationStart, stationEnd, hash, suffix, temperature); + } + } + + private void smallWorker(long offset, long chunkSize, MapReduce lineConsumer) { LineStream lineStream = new LineStream(fileService, offset, chunkSize); if (offset != 0) { @@ -274,29 +434,58 @@ private void worker(long offset, long chunkSize, MapReduce lineConsumer) { while (lineStream.hasNext()) { long keyStartAddress = lineStream.position; long keyEndAddress = lineStream.findSemi(); - long keySuffix = lineStream.suffix; - int keyHash = lineStream.hash; + long keyHash = lineStream.hash; + long suffix = lineStream.suffix; long valueStartAddress = lineStream.position; long valueEndAddress = lineStream.findTemperature(); int temperature = parseDouble(valueStartAddress, valueEndAddress); - lineConsumer.process(keyStartAddress, keyEndAddress, keyHash, temperature, keySuffix); + // System.out.println("Small worker!"); + lineConsumer.process(keyStartAddress, keyEndAddress, keyHash, suffix, temperature); } } - public T master(long chunkSize, ExecutorService executor) { - long len = fileService.length(); + // file size = 7 + // (0,0) (0,0) small chunk= (0,7) + // a;0.1\n + + public T master(int shards, ExecutorService executor) { List> summaries = new ArrayList<>(); + long len = fileService.length(); + + if (len > 128) { + long bigChunk = Math.floorDiv(len, shards); + long bigChunkReAlign = bigChunk & 0xffff_ffff_ffff_fff8L; + + long smallChunkStart = bigChunkReAlign * shards; + long smallChunkSize = len - smallChunkStart; + + for (long offset = 0; offset < smallChunkStart; offset += bigChunkReAlign) { + MapReduce mr = chunkProcessCreator.get(); + final long transferOffset = offset; + Future task = executor.submit(() -> { + bigWorker(transferOffset, bigChunkReAlign, mr); + return mr.result(); + }); + summaries.add(task); + } + + MapReduce mrLast = chunkProcessCreator.get(); + Future lastTask = executor.submit(() -> { + smallWorker(smallChunkStart, smallChunkSize - 1, mrLast); + return mrLast.result(); + }); + summaries.add(lastTask); + } + else { - for (long offset = 0; offset < len; offset += chunkSize) { - long workerLength = Math.min(len, offset + chunkSize) - offset; - MapReduce mr = chunkProcessCreator.get(); - final long transferOffset = offset; - Future task = executor.submit(() -> { - worker(transferOffset, workerLength, mr); - return mr.result(); + MapReduce mrLast = chunkProcessCreator.get(); + Future lastTask = executor.submit(() -> { + smallWorker(0, len - 1, mrLast); + return mrLast.result(); }); - summaries.add(task); + summaries.add(lastTask); } + List summariesDone = summaries.stream() .map(task -> { try { @@ -336,22 +525,12 @@ public long address() { private static class ChunkProcessorImpl implements MapReduce { // 1 << 14 > 10,000 so it works - private final PrimitiveHashMap statistics = new PrimitiveHashMap(14); + private final PrimitiveHashMap statistics = new PrimitiveHashMap(15); @Override - public void process(long keyStartAddress, long keyEndAddress, int hash, int temperature, long suffix) { - HashEntry entry = statistics.find(keyStartAddress, keyEndAddress, suffix, hash); - if (entry == null) { - throw new IllegalStateException("Hash table too small :("); - } - if (entry.value == null) { - entry.startAddress = keyStartAddress; - entry.endAddress = keyEndAddress; - entry.suffix = suffix; - entry.hash = hash; - entry.value = new IntSummaryStatistics(); - } - entry.value.accept(temperature); + public void process(long keyStartAddress, long keyEndAddress, long hash, long suffix, int temperature) { + IntSummaryStatistics stats = statistics.find(keyStartAddress, keyEndAddress, hash, suffix); + stats.accept(temperature); } @Override @@ -368,13 +547,10 @@ public static void main(String[] args) throws IOException { ChunkProcessorImpl::new, CalculateAverage_vaidhy::combineOutputs); - int proc = 2 * Runtime.getRuntime().availableProcessors(); - - long fileSize = diskFileService.length(); - long chunkSize = Math.ceilDiv(fileSize, proc); + int proc = Runtime.getRuntime().availableProcessors(); ExecutorService executor = Executors.newFixedThreadPool(proc); - Map output = calculateAverageVaidhy.master(chunkSize, executor); + Map output = calculateAverageVaidhy.master(2 * proc, executor); executor.shutdown(); Map outputStr = toPrintMap(output); @@ -395,11 +571,12 @@ private static Map toPrintMap(Map private static Map combineOutputs( List list) { - Map output = new HashMap<>(10000); + Map output = HashMap.newHashMap(10000); for (PrimitiveHashMap map : list) { - for (HashEntry entry : map.entries) { + for (HashEntry entry : map.entrySet()) { if (entry.value != null) { - String keyStr = unsafeToString(entry.startAddress, entry.endAddress); + String keyStr = unsafeToString(entry.startAddress, + entry.startAddress + entry.keyLength); output.compute(keyStr, (ignore, val) -> { if (val == null) {