Skip to content

Commit

Permalink
Refactor: add UnsafeSerialization.writeIntsWithPrefix method to simplify a couple LSH models (#593)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexklibisz authored Nov 20, 2023
1 parent 7ccdc10 commit 1a1a561
Show file tree
Hide file tree
Showing 9 changed files with 915 additions and 840 deletions.
1 change: 1 addition & 0 deletions developer-guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ See ann-benchmarks/README.md
- To run Elasticsearch on Linux, you need to increase the `vm.max_map_count` setting. [See the Elasticsearch docs.](https://www.elastic.co/guide/en/elasticsearch/reference/current/vm-max-map-count.html)
- To run ann-benchmarks on MacOS, you might need to `export OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES`. [See this Stackoverflow answer.](https://stackoverflow.com/a/52230415)
- If you're running on MacOS 13.x (Ventura), the operating system's privacy settings might block `task jvmRunLocal` from starting. One solution is to go to System Settings > Privacy & Security > Developer Tools, and add and check your terminal (e.g., iTerm) to the list of developer apps. If that doesn't work, see this thread for more ideas: https://github.com/elastic/elasticsearch/issues/91159.
- When running tests from Intellij, you might need to add `--add-modules jdk.incubator.vector` to the VM options.

## Nearest Neighbors Search

Expand Down
1,495 changes: 766 additions & 729 deletions docs/pages/performance/fashion-mnist/plot.b64

Large diffs are not rendered by default.

Binary file modified docs/pages/performance/fashion-mnist/plot.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
16 changes: 8 additions & 8 deletions docs/pages/performance/fashion-mnist/results.md
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
|Model|Parameters|Recall|Queries per Second|
|---|---|---|---|
|eknn-l2lsh|L=100 k=4 w=1024 candidates=500 probes=0|0.378|351.096|
|eknn-l2lsh|L=100 k=4 w=1024 candidates=1000 probes=0|0.447|291.666|
|eknn-l2lsh|L=100 k=4 w=1024 candidates=500 probes=3|0.634|277.702|
|eknn-l2lsh|L=100 k=4 w=1024 candidates=1000 probes=3|0.716|238.914|
|eknn-l2lsh|L=100 k=4 w=2048 candidates=500 probes=0|0.767|288.441|
|eknn-l2lsh|L=100 k=4 w=2048 candidates=1000 probes=0|0.847|246.201|
|eknn-l2lsh|L=100 k=4 w=2048 candidates=500 probes=3|0.922|192.499|
|eknn-l2lsh|L=100 k=4 w=2048 candidates=1000 probes=3|0.960|177.009|
|eknn-l2lsh|L=100 k=4 w=1024 candidates=500 probes=0|0.378|363.121|
|eknn-l2lsh|L=100 k=4 w=1024 candidates=1000 probes=0|0.446|299.144|
|eknn-l2lsh|L=100 k=4 w=1024 candidates=500 probes=3|0.634|270.522|
|eknn-l2lsh|L=100 k=4 w=1024 candidates=1000 probes=3|0.716|240.419|
|eknn-l2lsh|L=100 k=4 w=2048 candidates=500 probes=0|0.768|280.053|
|eknn-l2lsh|L=100 k=4 w=2048 candidates=1000 probes=0|0.847|240.014|
|eknn-l2lsh|L=100 k=4 w=2048 candidates=500 probes=3|0.922|186.668|
|eknn-l2lsh|L=100 k=4 w=2048 candidates=1000 probes=3|0.960|166.241|
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
import java.util.Arrays;
import java.util.Random;

import static com.klibisz.elastiknn.storage.UnsafeSerialization.writeInt;
import static com.klibisz.elastiknn.storage.UnsafeSerialization.writeInts;
import static com.klibisz.elastiknn.storage.UnsafeSerialization.*;

public class JaccardLshModel implements HashingModel.SparseBool {

Expand Down Expand Up @@ -46,8 +45,7 @@ public HashAndFreq[] hash(int[] trueIndices, int totalIndices) {
} else {
HashAndFreq[] hashes = new HashAndFreq[L];
for (int ixL = 0; ixL < L; ixL++) {
int[] ints = new int[k + 1];
ints[0] = ixL;
int[] ints = new int[k];
for (int ixk = 0; ixk < k; ixk++) {
int a = A[ixL * k + ixk];
int b = B[ixL * k + ixk];
Expand All @@ -56,9 +54,9 @@ public HashAndFreq[] hash(int[] trueIndices, int totalIndices) {
int hash = ((1 + ti) * a + b) % HashingModel.HASH_PRIME;
if (hash < minHash) minHash = hash;
}
ints[ixk + 1] = minHash;
ints[ixk] = minHash;
}
hashes[ixL] = HashAndFreq.once(writeInts(ints));
hashes[ixL] = HashAndFreq.once(writeIntsWithPrefix(ixL, ints));
}
return hashes;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import java.util.*;

import static com.klibisz.elastiknn.storage.UnsafeSerialization.writeInts;
import static com.klibisz.elastiknn.storage.UnsafeSerialization.writeIntsWithPrefix;

public class L2LshModel implements HashingModel.DenseFloat {
private final int L;
Expand Down Expand Up @@ -67,17 +68,15 @@ public HashAndFreq[] hash(float[] values) {
}

private HashAndFreq[] hashNoProbing(float[] values) {
// Can this be panamized?
HashAndFreq[] hashes = new HashAndFreq[L];
for (int ixL = 0; ixL < L; ixL++) {
int[] ints = new int[1 + k];
ints[0] = ixL;
int[] ints = new int[k];
for (int ixk = 0; ixk < k; ixk++) {
float[] a = A[ixL * k + ixk];
float b = B[ixL * k + ixk];
ints[ixk + 1] = (int) Math.floor((floatVectorOps.dotProduct(a, values) + b) / w);
ints[ixk] = (int) Math.floor((floatVectorOps.dotProduct(a, values) + b) / w);
}
hashes[ixL] = HashAndFreq.once(writeInts(ints));
hashes[ixL] = HashAndFreq.once(writeIntsWithPrefix(ixL, ints));
}
return hashes;
}
Expand All @@ -89,8 +88,7 @@ private HashAndFreq[] hashWithProbing(float[] values, int probesPerTable) {
Perturbation[] zeroPerturbations = new Perturbation[L * k];
Perturbation[][] sortedPerturbations = new Perturbation[L][k * 2];
for (int ixL = 0; ixL < L; ixL++) {
int[] ints = new int[k + 1];
ints[0] = ixL;
int[] ints = new int[k];
for (int ixk = 0; ixk < k; ixk++) {
float[] a = A[ixL * k + ixk];
float b = B[ixL * k + ixk];
Expand All @@ -100,9 +98,9 @@ private HashAndFreq[] hashWithProbing(float[] values, int probesPerTable) {
sortedPerturbations[ixL][ixk * 2 + 0] = new Perturbation(ixL, ixk, -1, proj, hash, Math.abs(dneg));
sortedPerturbations[ixL][ixk * 2 + 1] = new Perturbation(ixL, ixk, 1, proj, hash, Math.abs(w - dneg));
zeroPerturbations[ixL * k + ixk] = new Perturbation(ixL, ixk, 0, proj, hash, 0);
ints[ixk + 1] = hash;
ints[ixk] = hash;
}
hashes[ixL] = HashAndFreq.once(writeInts(ints));
hashes[ixL] = HashAndFreq.once(writeIntsWithPrefix(ixL, ints));
}

PriorityQueue<PerturbationSet> heap = new PriorityQueue<>(Comparator.comparingDouble(o -> o.absDistsSum));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,23 @@ public static byte[] writeInts(final int[] iarr) {
return buf;
}

/**
 * Writes ints to a byte array with an integer prefix.
 * This is equivalent to prepending the prefix to the integer array and calling writeInts
 * on the resulting array, but avoids allocating the intermediate (prefix +: iarr) array.
 *
 * @param prefix integer prefix, written into the first 4 bytes of the result
 * @param iarr integer array, written after the prefix
 * @return Array of bytes with length (4 + 4 * iarr.length)
 */
public static byte[] writeIntsWithPrefix(int prefix, final int[] iarr) {
    // Total output length: 4 bytes for the prefix plus 4 bytes per element.
    final int bytesLen = (iarr.length + 1) * numBytesInInt;
    byte[] buf = new byte[bytesLen];
    // Write the prefix into the first 4 bytes of the buffer.
    u.unsafe.putInt(buf, u.byteArrayOffset, prefix);
    // Copy exactly iarr's bytes (bytesLen - numBytesInInt). Copying bytesLen bytes
    // would read 4 bytes past the end of iarr and write 4 bytes past the end of buf;
    // Unsafe.copyMemory performs no bounds checks, so that silently corrupts the heap.
    u.unsafe.copyMemory(iarr, u.intArrayOffset, buf, numBytesInInt + u.byteArrayOffset, bytesLen - numBytesInInt);
    return buf;
}

/**
* Reads ints from a byte array.
*/
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
package com.klibisz.elastiknn.storage

import org.scalatest.freespec.AnyFreeSpec
import org.scalatest.matchers.should.Matchers

import scala.util.Random

/** Randomized round-trip and equivalence tests for UnsafeSerialization.
  * Every randomized test reports its RNG seed via withClue so a failing
  * iteration can be reproduced deterministically.
  */
class UnsafeSerializationSpec extends AnyFreeSpec with Matchers {

  "writeInts and readInts" - {
    "round trip with randomized arrays" in {
      val seed = System.currentTimeMillis()
      val maxLen = 4096
      val rng = new Random(seed)
      for (i <- 0 to 1000) {
        withClue(s"Failed on iteration $i with seed $seed and max length $maxLen") {
          // Generate array of random ints.
          val len = rng.nextInt(maxLen)
          val iarr = (0 until len).map(_ => rng.nextInt(Int.MaxValue) * (if (rng.nextBoolean()) 1 else -1)).toArray

          // Serialize and check serialized length.
          val trimmed = UnsafeSerialization.writeInts(iarr)
          trimmed should have length (iarr.length * UnsafeSerialization.numBytesInInt)

          // Deserialize and check.
          val iarrReadTrimmed = UnsafeSerialization.readInts(trimmed, 0, trimmed.length)
          iarrReadTrimmed shouldBe iarr

          // Place in larger array with random offset.
          val offset = rng.nextInt(maxLen)
          val embedded = new Array[Byte](offset) ++ trimmed ++ new Array[Byte](rng.nextInt(maxLen))

          // Deserialize and check.
          val iarrReadEmbedded = UnsafeSerialization.readInts(embedded, offset, trimmed.length)
          iarrReadEmbedded shouldBe iarr
        }
      }
    }
  }

  "writeIntsWithPrefix" - {
    "equivalent to writeInts with prefix embedded in the array" in {
      val seed = System.currentTimeMillis()
      val maxLen = 4096
      val rng = new Random(seed)
      for (i <- 0 to 1000) {
        // Include the seed in the failure message, consistent with the other randomized tests.
        withClue(s"Failed on iteration $i with seed $seed and max length $maxLen") {
          val len = rng.nextInt(maxLen)
          val prefix = rng.nextInt()
          val iarr = (0 until len).map(_ => rng.nextInt(Int.MaxValue) * (if (rng.nextBoolean()) 1 else -1)).toArray
          // The specialized method must produce byte-identical output to
          // prepending the prefix and serializing with writeInts.
          val iarrWithPrefix = prefix +: iarr
          val writeIntsWithPrefix = UnsafeSerialization.writeIntsWithPrefix(prefix, iarr)
          val writeInts = UnsafeSerialization.writeInts(iarrWithPrefix)
          writeIntsWithPrefix shouldBe writeInts
        }
      }
    }
  }

  "writeFloats and readFloats" - {
    "round trip with randomized arrays" in {
      val seed = System.currentTimeMillis()
      val maxLen = 4096
      val rng = new Random(seed)
      for (i <- 0 to 1000) {
        withClue(s"Failed on iteration $i with seed $seed and max length $maxLen") {
          // Generate array of random floats.
          val len = rng.nextInt(maxLen)
          val farr = (0 until len).map(_ => rng.nextFloat() * (if (rng.nextBoolean()) Float.MaxValue else Float.MinValue)).toArray

          // Serialize and check length.
          val trimmed = UnsafeSerialization.writeFloats(farr)
          trimmed should have length (farr.length * UnsafeSerialization.numBytesInFloat)

          // Deserialize and check.
          val farrTrimmed = UnsafeSerialization.readFloats(trimmed, 0, trimmed.length)
          farrTrimmed shouldBe farr

          // Place in larger array with random offset.
          val offset = rng.nextInt(maxLen)
          val embedded = new Array[Byte](offset) ++ trimmed ++ new Array[Byte](rng.nextInt(maxLen))

          // Deserialize and check.
          val farrReadEmbedded = UnsafeSerialization.readFloats(embedded, offset, trimmed.length)
          farrReadEmbedded shouldBe farr
        }
      }
    }
  }

  "writeInt" - {
    "variable length encoding" in {
      // writeInt emits the minimal number of bytes needed for the value.
      UnsafeSerialization.writeInt(127) should have length 1
      UnsafeSerialization.writeInt(-127) should have length 1
      UnsafeSerialization.writeInt(32767) should have length 2
      UnsafeSerialization.writeInt(-32767) should have length 2
    }
  }

  "writeInt and readInt" - {
    "round trip with randomized ints" in {
      val seed = System.currentTimeMillis()
      val rng = new Random(seed)
      for (i <- 0 to 10000) {
        withClue(s"Failed on iteration $i with seed $seed") {
          // Named v (not i) to avoid shadowing the loop counter used in the clue.
          val v = rng.nextInt(Int.MaxValue) * (if (rng.nextBoolean()) 1 else -1)
          val barr = UnsafeSerialization.writeInt(v)
          val vRead = UnsafeSerialization.readInt(barr)
          vRead shouldBe v
        }
      }
    }
  }
}

This file was deleted.

0 comments on commit 1a1a561

Please sign in to comment.