Skip to content

Commit

Permalink
netrunnereve: more optimizations (#485)
Browse files Browse the repository at this point in the history
  • Loading branch information
netrunnereve authored Jan 19, 2024
1 parent ce8fe41 commit 144a6af
Showing 1 changed file with 67 additions and 53 deletions.
120 changes: 67 additions & 53 deletions src/main/java/dev/morling/onebrc/CalculateAverage_netrunnereve.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,18 @@
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.charset.StandardCharsets;
import java.lang.Math;
import java.util.Map;
import java.util.TreeMap;
import java.util.concurrent.CountDownLatch;
import java.lang.Math;

public class CalculateAverage_netrunnereve {

// Input path, relative to the working directory (presumably the measurements file to parse).
private static final String FILE = "./measurements.txt";
// Hard-coded thread count, chosen for the test machine rather than detected at runtime.
private static final int NUM_THREADS = 8; // test machine
// Extra bytes added beyond a chunk boundary; per the original note this guarantees a newline
// is included. NOTE(review): the code that consumes this value is not visible here - confirm.
private static final int LEN_EXTEND = 200; // guarantees a newline
// Number of buckets in each per-thread hash table (16384 = 2^14).
private static final int HASHT_SIZE = 16384; // size of hash table, adjust to trade off between collisions and cache utilization
// Starting value of the djb2-style hash, used by calc_hash and by the inline hash in the parser loop.
// NOTE(review): the commonly published djb2 seed is 5381; 5831 is preserved here as written.
private static final int DJB2_INIT = 5831;

private static class MeasurementAggregator { // min, max, sum stored as 0.1/unit
private MeasurementAggregator next = null; // linked list of entries for handling hash colisions
Expand All @@ -48,77 +51,51 @@ private static class ThreadCalcs {

/**
 * Computes the hash-table bucket index for a station name with a djb2-style hash.
 *
 * <p>Starting from {@code DJB2_INIT}, every byte in {@code [0, len)} is folded in as
 * {@code hash * 33 + unsignedByte}; {@code int} overflow wraps silently. The result is
 * reduced with the same {@code Math.abs(hash % HASHT_SIZE)} step that the parser loop
 * applies to its inline hash, so both paths select the same bucket for the same bytes.
 *
 * @param input buffer holding the station name bytes; only indices {@code [0, len)} are read
 * @param len   number of leading bytes of {@code input} to hash
 * @return a bucket index in {@code [0, HASHT_SIZE)}
 */
private static int calc_hash(byte[] input, int len) {
    int hash = DJB2_INIT;
    for (int i = 0; i < len; i++) {
        // (hash << 5) + hash == hash * 33; the byte is widened unsigned so values >= 0x80 stay positive.
        hash = ((hash << 5) + hash) + Byte.toUnsignedInt(input[i]);
    }
    // The remainder lies strictly between -HASHT_SIZE and HASHT_SIZE, so Math.abs cannot hit Integer.MIN_VALUE.
    return Math.abs(hash % HASHT_SIZE);
}

private static class ThreadedParser extends Thread {
private MappedByteBuffer mbuf;
private int mbs;
private ThreadCalcs[] threadOut;
private int threadID;
private CountDownLatch tpLatch;

/**
 * Creates a parser thread for one slice of the input.
 *
 * @param mbuf      buffer the thread reads its bytes from
 * @param mbs       size bound of the slice; {@code run()} stops once the start of the next
 *                  line reaches this offset
 * @param threadOut shared result array; {@code run()} stores this thread's results at
 *                  index {@code threadID}
 * @param threadID  index of this thread's slot in {@code threadOut}
 * @param tpLatch   latch that {@code run()} counts down once the results have been stored
 */
private ThreadedParser(MappedByteBuffer mbuf, int mbs, ThreadCalcs[] threadOut, int threadID, CountDownLatch tpLatch) {
    this.mbuf = mbuf;
    this.mbs = mbs;
    this.threadOut = threadOut;
    this.threadID = threadID;
    this.tpLatch = tpLatch;
}

public void run() {
MeasurementAggregator[] hashSpace = new MeasurementAggregator[16384]; // 14-bit hash
MeasurementAggregator[] hashSpace = new MeasurementAggregator[HASHT_SIZE]; // hash table
byte[] scratch = new byte[100]; // <= 100 characters in station name
String[] staArr = new String[10000]; // max 10000 station names
MeasurementAggregator ma = null;

int numStations = 0;
boolean state = false; // 0 for station pickup, 1 for measurement pickup
int negMul = 1;
int head = 0;
int tempCnt = -1; // 0 if 1 digit measurement, 1 if 2 digit
int hash = DJB2_INIT; // do calc_hash manually in loop

for (int i = 0; i < mbs; i++) {
int i = 0; // byte by byte iterator
while (true) {
byte cur = mbuf.get(i);
if (state == true) {
if (cur == 46) { // .
int tempa = mbuf.get(i + 1) - 48;
tempa += (scratch[0] - 48) * (10 + 90 * tempCnt) + (scratch[1] - 48) * (10 * tempCnt); // branchless
tempa *= negMul;

if (tempa < ma.min) {
ma.min = tempa;
}
if (tempa > ma.max) {
ma.max = tempa;
}
ma.sum += tempa;
ma.count++;

i += 2; // go to start of new line
state = false;
negMul = 1;
head = i + 1;
tempCnt = -1;
}
else if (cur == 45) { // ascii -
negMul = -1;
}
else {
scratch[tempCnt + 1] = cur;
tempCnt++;
}
}
else if (cur == 59) { // ;
int len = i - head;
if (cur == 59) { // ;
hash = Math.abs(hash % HASHT_SIZE);

// this is faster than filling scratch immediately after each byte is read
int len = i - head;
mbuf.position(head);
mbuf.get(scratch, 0, len);

int hash = calc_hash(scratch, len);
ma = hashSpace[hash];
MeasurementAggregator prev = null;

Expand Down Expand Up @@ -146,14 +123,53 @@ else if ((len != ma.station.length) || (Arrays.compare(scratch, 0, len, ma.stati
break;
}
}
state = true;
head = i + 1;

i++;
while (true) {
cur = mbuf.get(i);
if (cur == 46) { // .
int tempa = (negMul) * ((10 + 90 * tempCnt) * (scratch[0] - 48) + (10 * tempCnt) * (scratch[1] - 48) + (mbuf.get(i + 1) - 48)); // branchless

if (tempa < ma.min) {
ma.min = tempa;
}
if (tempa > ma.max) {
ma.max = tempa;
}
ma.sum += tempa;
ma.count++;

// this line is finished!
i += 2; // newline char
hash = DJB2_INIT;
negMul = 1;
head = i + 1; // start of next line
tempCnt = -1;
break;
}
else if (cur == 45) { // ascii -
negMul = -1;
}
else {
scratch[tempCnt + 1] = cur;
tempCnt++;
}
i++;
}
if (head >= mbs) {
break;
}
}
else {
hash = ((hash << 5) + hash) + Byte.toUnsignedInt(cur);
}
i++;
}
threadOut[threadID] = new ThreadCalcs();
threadOut[threadID].hashSpace = hashSpace;
threadOut[threadID].staArr = staArr;
threadOut[threadID].numStations = numStations;
tpLatch.countDown();
}
}

Expand All @@ -175,8 +191,8 @@ public static void main(String[] args) {
bufSize = Integer.MAX_VALUE;
}

ThreadedParser[] myThreads = new ThreadedParser[(int) threadNum];
ThreadCalcs[] threadOut = new ThreadCalcs[(int) threadNum];
CountDownLatch tpLatch = new CountDownLatch((int) threadNum);
int threadID = 0;

long h = 0;
Expand Down Expand Up @@ -206,27 +222,25 @@ public static void main(String[] args) {
}
}

myThreads[threadID] = new ThreadedParser(mbuf, mbs, threadOut, threadID);
myThreads[threadID].start();
ThreadedParser tpThr = new ThreadedParser(mbuf, mbs, threadOut, threadID, tpLatch);
tpThr.start();

h += mbs;
threadID++;
}

for (int i = 0; i < threadID; i++) {
try {
myThreads[i].join();
}
catch (InterruptedException ex) {
System.exit(1);
}
try {
tpLatch.await();
}
catch (InterruptedException ex) {
System.exit(1);
}

// use treemap to sort and uniquify
Map<String, Integer> staMap = new TreeMap<>();
Map<String, Boolean> staMap = new TreeMap<>();
for (int i = 0; i < threadID; i++) {
for (int j = 0; j < threadOut[i].numStations; j++) {
staMap.put(threadOut[i].staArr[j], 0);
staMap.put(threadOut[i].staArr[j], false);
}
}

Expand Down

0 comments on commit 144a6af

Please sign in to comment.