Skip to content

Commit

Permalink
Spread work more evenly between threads
Browse files Browse the repository at this point in the history
Vectorize skipping first partial line
  • Loading branch information
ianopolous committed Jan 31, 2024
1 parent f1fd7b7 commit 3bfc538
Showing 1 changed file with 56 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.stream.*;
import java.util.*;

Expand All @@ -45,6 +47,7 @@ public class CalculateAverage_ianopolousfast {
public static final int MAX_LINE_LENGTH = 107;
public static final int MAX_STATIONS = 1 << 14;
private static final OfLong LONG_LAYOUT = JAVA_LONG_UNALIGNED.withOrder(ByteOrder.BIG_ENDIAN);
public static final int CHUNK_SIZE = 32 * 1024 * 1024;
private static final VectorSpecies<Byte> BYTE_SPECIES = ByteVector.SPECIES_PREFERRED.length() >= 16
? ByteVector.SPECIES_128
: ByteVector.SPECIES_64;
Expand All @@ -55,11 +58,13 @@ public static void main(String[] args) throws Exception {
FileChannel channel = (FileChannel) Files.newByteChannel(input, StandardOpenOption.READ);
long filesize = Files.size(input);
MemorySegment mmap = channel.map(FileChannel.MapMode.READ_ONLY, 0, filesize, arena);
int nChunks = filesize < 4 * 1024 * 1024 ? 1 : Runtime.getRuntime().availableProcessors();
long chunkSize = (filesize + nChunks - 1) / nChunks;
List<Stat[]> allResults = IntStream.range(0, nChunks)
int nThreads = filesize < 4 * 1024 * 1024 ? 1 : Runtime.getRuntime().availableProcessors();
LinkedBlockingQueue<Long> chunks = new LinkedBlockingQueue<>();
for (long i = 0; i < (filesize + CHUNK_SIZE - 1) / CHUNK_SIZE; i++)
chunks.add(i);
List<Stat[]> allResults = IntStream.range(0, nThreads)
.parallel()
.mapToObj(i -> parseStats(i * chunkSize, Math.min((i + 1) * chunkSize, filesize), mmap))
.mapToObj(i -> parseStats(chunks, mmap))
.toList();

TreeMap<String, Stat> merged = allResults.stream()
Expand Down Expand Up @@ -163,47 +168,59 @@ private static int keySize(int lineSize, long lineStart, MemorySegment buffer) {
ByteOrder.nativeOrder()).compare(VectorOperators.EQ, ';').firstTrue();
}

public static Stat[] parseStats(long start1, long end2, MemorySegment buffer) {
// read first partial line
if (start1 > 0) {
for (int i = 0; i < MAX_LINE_LENGTH; i++) {
byte b = buffer.get(JAVA_BYTE, start1++);
if (b == '\n') {
break;
}
}
}

public static Stat[] parseStats(LinkedBlockingQueue<Long> chunks, MemorySegment buffer) {
Stat[] stations = new Stat[MAX_STATIONS];

// Handle reading the very last few lines in the file
// this allows us to not worry about reading beyond the end
// in the inner loop (reducing branches)
// We need at least the vector lane size bytes back
if (end2 == buffer.byteSize()) {
// reverse at least vector lane width
end2 = Math.max(buffer.byteSize() - 2 * BYTE_SPECIES.vectorByteSize(), 0);
while (end2 > 0 && buffer.get(JAVA_BYTE, end2) != '\n')
end2--;
while (!chunks.isEmpty()) {
long start1, end2;
try {
long nChunk = chunks.poll(0, TimeUnit.MICROSECONDS);
start1 = nChunk * CHUNK_SIZE;
end2 = Math.min(buffer.byteSize(), (nChunk + 1) * CHUNK_SIZE);
}
catch (InterruptedException e) {
continue;
}
// read first partial line
if (start1 > 0) {
int lineSize1 = lineSize(start1, buffer);
start1 += lineSize1 + 1;
}

// Handle reading the very last few lines in the file
// this allows us to not worry about reading beyond the end
// in the inner loop (reducing branches)
// We need at least the vector lane size bytes back
if (end2 == buffer.byteSize()) {
// reverse at least vector lane width
end2 = Math.max(buffer.byteSize() - 2 * BYTE_SPECIES.vectorByteSize(), 0);
while (end2 > 0 && buffer.get(JAVA_BYTE, end2) != '\n')
end2--;

if (end2 > 0)
end2++;
// copy into a larger buffer to avoid reading off end
MemorySegment end = Arena.global().allocate(MAX_LINE_LENGTH + 2 * BYTE_SPECIES.vectorByteSize());
for (long i = end2; i < buffer.byteSize(); i++)
end.set(JAVA_BYTE, i - end2, buffer.get(JAVA_BYTE, i));
int index = 0;
while (end2 + index < buffer.byteSize()) {
int lineSize1 = lineSize(index, end);
int semiSearchStart = index + Math.max(0, lineSize1 - 6);
int keySize1 = semiSearchStart - index + ByteVector.fromMemorySegment(BYTE_SPECIES, end, semiSearchStart,
ByteOrder.nativeOrder()).compare(VectorOperators.EQ, ';').firstTrue();
Stat station1 = dedupeStation(index, index + keySize1, end, stations);
processTemperature(index + keySize1 + 1, lineSize1 - keySize1 - 1, end, station1);
index += lineSize1 + 1;
if (end2 > 0)
end2++;
// copy into a larger buffer to avoid reading off end
MemorySegment end = Arena.global().allocate(MAX_LINE_LENGTH + 2 * BYTE_SPECIES.vectorByteSize());
for (long i = end2; i < buffer.byteSize(); i++)
end.set(JAVA_BYTE, i - end2, buffer.get(JAVA_BYTE, i));
int index = 0;
while (end2 + index < buffer.byteSize()) {
int lineSize1 = lineSize(index, end);
int semiSearchStart = index + Math.max(0, lineSize1 - 6);
int keySize1 = semiSearchStart - index + ByteVector.fromMemorySegment(BYTE_SPECIES, end, semiSearchStart,
ByteOrder.nativeOrder()).compare(VectorOperators.EQ, ';').firstTrue();
Stat station1 = dedupeStation(index, index + keySize1, end, stations);
processTemperature(index + keySize1 + 1, lineSize1 - keySize1 - 1, end, station1);
index += lineSize1 + 1;
}
}

innerLoop(start1, end2, buffer, stations);
}
return stations;
}

private static void innerLoop(long start1, long end2, MemorySegment buffer, Stat[] stations) {
while (start1 < end2) {
int lineSize1 = lineSize(start1, buffer);
long start2 = start1 + lineSize1 + 1;
Expand All @@ -220,7 +237,6 @@ public static Stat[] parseStats(long start1, long end2, MemorySegment buffer) {
else
start1 += lineSize1 + 1;
}
return stations;
}

public static class Stat {
Expand Down

0 comments on commit 3bfc538

Please sign in to comment.