Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dynamic type support for ParquetDenseVectorDocumentGenerator #2667

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.apache.parquet.example.data.Group;
import org.apache.parquet.hadoop.ParquetReader;
import org.apache.parquet.hadoop.example.GroupReadSupport;
import org.apache.parquet.schema.PrimitiveType;

/**
* Collection class for managing Parquet dense vectors
Expand Down Expand Up @@ -83,7 +84,7 @@ public FileSegment<ParquetDenseVectorCollection.Document> createFileSegment(Buff
* Inner class representing a file segment for ParquetDenseVectorCollection.
*/
public static class Segment extends FileSegment<ParquetDenseVectorCollection.Document> {
private List<double[]> vectors; // List to store vectors from the Parquet file
private List<float[]> vectors; // List to store vectors from the Parquet file
private List<String> ids; // List to store document IDs
private ParquetReader<Group> reader;
private boolean readerInitialized;
Expand Down Expand Up @@ -152,19 +153,24 @@ protected synchronized void readNext() throws IOException, NoSuchElementExceptio
throw new NoSuchElementException("End of file reached");
}

// Read each record from the Parquet file
// Extract the docid (String) from the record
String docid = record.getString("docid", 0);
ids.add(docid);

// Extract the vector (double[]) from the record
Group vectorGroup = record.getGroup("vector", 0); // Access the 'vector' field
int vectorSize = vectorGroup.getFieldRepetitionCount(0); // Get the number of elements in the vector
double[] vector = new double[vectorSize];
// Extract the vector (float[]) from the record
Group vectorGroup = record.getGroup("vector", 0);// Access the 'vector' field
int vectorSize = vectorGroup.getFieldRepetitionCount(0);// Get the number of elements in the vector
float[] vector = new float[vectorSize];

// We detect the type from the schema
Group firstElement = vectorGroup.getGroup(0, 0);
boolean isDouble = firstElement.getType().getFields().get(0).asPrimitiveType().getPrimitiveTypeName().equals(PrimitiveType.PrimitiveTypeName.DOUBLE);

// Single-pass read with conditional cast if needed
for (int i = 0; i < vectorSize; i++) {
Group listGroup = vectorGroup.getGroup(0, i); // Access the 'list' group
vector[i] = listGroup.getDouble("element", 0); // Get the double value from the 'element' field
Group listGroup = vectorGroup.getGroup(0, i);
vector[i] = isDouble ? (float) listGroup.getDouble("element", 0) : listGroup.getFloat("element", 0);
}

vectors.add(vector);

// Create a new Document object with the retrieved data
Expand All @@ -177,7 +183,7 @@ protected synchronized void readNext() throws IOException, NoSuchElementExceptio
*/
public static class Document implements SourceDocument {
private final String id;
private final double[] vector;
private final float[] vector;
private final String raw;

/**
Expand All @@ -187,7 +193,7 @@ public static class Document implements SourceDocument {
* @param vector the vector data.
* @param raw the raw data.
*/
public Document(String id, double[] vector, String raw) {
public Document(String id, float[] vector, String raw) {
this.id = id;
this.vector = vector;
this.raw = raw;
Expand Down
21 changes: 21 additions & 0 deletions src/test/java/io/anserini/index/IndexFlatDenseVectorsTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -193,4 +193,25 @@ public void testQuantizedInt8() throws Exception {
assertNotNull(results);
assertEquals(100, results.get("documents"));
}

@Test
public void testParquetFloat() throws Exception {
  // Unique on-disk location so concurrent/repeated runs never collide.
  String path = "target/lucene-test-index.flat." + System.currentTimeMillis();

  // Index the float-typed Parquet sample through the flat dense-vector indexer.
  String[] args = {
      "-collection", "ParquetDenseVectorCollection",
      "-input", "src/test/resources/sample_docs/parquet/msmarco-passage-bge-base-en-v1.5.parquet-float",
      "-index", path,
      "-generator", "ParquetDenseVectorDocumentGenerator",
      "-threads", "1"
  };
  IndexFlatDenseVectors.main(args);

  // The resulting index must open and report all 10 sample vectors.
  IndexReader reader = IndexReaderUtils.getReader(path);
  assertNotNull(reader);

  Map<String, Object> stats = IndexReaderUtils.getIndexStats(reader, Constants.VECTOR);
  assertNotNull(stats);
  assertEquals(10, stats.get("documents"));
}
}
22 changes: 22 additions & 0 deletions src/test/java/io/anserini/index/IndexHnswDenseVectorsTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,28 @@ public void testParquet() throws Exception {
assertEquals(10, results.get("documents"));
}

@Test
public void testParquetFloat() throws Exception {
  // Fix: use the ".hnsw." prefix — this is the HNSW suite; ".flat." was a
  // copy-paste slip from the flat-index test and is inconsistent with the
  // sibling testQuantizedInt8, which uses "lucene-test-index.hnsw.".
  // Timestamp suffix keeps repeated runs from colliding on disk.
  String indexPath = "target/lucene-test-index.hnsw." + System.currentTimeMillis();

  // Index the float-typed Parquet sample through the HNSW dense-vector indexer.
  String[] indexArgs = new String[] {
      "-collection", "ParquetDenseVectorCollection",
      "-input", "src/test/resources/sample_docs/parquet/msmarco-passage-bge-base-en-v1.5.parquet-float",
      "-index", indexPath,
      "-generator", "ParquetDenseVectorDocumentGenerator",
      "-threads", "1",
      "-M", "16", "-efC", "100"
  };
  IndexHnswDenseVectors.main(indexArgs);

  // The resulting index must open and report all 10 sample vectors.
  IndexReader reader = IndexReaderUtils.getReader(indexPath);
  assertNotNull(reader);

  Map<String, Object> results = IndexReaderUtils.getIndexStats(reader, Constants.VECTOR);
  assertNotNull(results);
  assertEquals(10, results.get("documents"));
}

@Test
public void testQuantizedInt8() throws Exception {
String indexPath = "target/lucene-test-index.hnsw." + System.currentTimeMillis();
Expand Down
Binary file not shown.
Loading