Skip to content

Commit

Permalink
fix: point id
Browse files Browse the repository at this point in the history
  • Loading branch information
Anush008 committed Feb 17, 2024
1 parent b2ac52d commit 62b7c50
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 63 deletions.
78 changes: 78 additions & 0 deletions src/main/java/io/qdrant/spark/ObjectFactory.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
package io.qdrant.spark;

import java.util.HashMap;
import java.util.Map;

import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.util.ArrayData;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.types.ArrayType;

class ObjectFactory {
public static Object object(InternalRow record, StructField field, int fieldIndex) {
DataType dataType = field.dataType();

switch (dataType.typeName()) {
case "integer":
return record.getInt(fieldIndex);
case "float":
return record.getFloat(fieldIndex);
case "double":
return record.getDouble(fieldIndex);
case "long":
return record.getLong(fieldIndex);
case "boolean":
return record.getBoolean(fieldIndex);
case "string":
return record.getString(fieldIndex);
case "array":
ArrayType arrayType = (ArrayType) dataType;
ArrayData arrayData = record.getArray(fieldIndex);
return object(arrayData, arrayType.elementType());
case "struct":
StructType structType = (StructType) dataType;
InternalRow structData = record.getStruct(fieldIndex, structType.fields().length);
return object(structData, structType);
default:
return null;
}
}

public static Object object(ArrayData arrayData, DataType elementType) {

switch (elementType.typeName()) {
case "string": {
int length = arrayData.numElements();
String[] result = new String[length];
for (int i = 0; i < length; i++) {
result[i] = arrayData.getUTF8String(i).toString();
}
return result;
}

case "struct": {
StructType structType = (StructType) elementType;
int length = arrayData.numElements();
Object[] result = new Object[length];
for (int i = 0; i < length; i++) {
InternalRow structData = arrayData.getStruct(i, structType.fields().length);
result[i] = object(structData, structType);
}
return result;
}
default:
return arrayData.toObjectArray(elementType);
}
}

public static Object object(InternalRow structData, StructType structType) {
Map<String, Object> result = new HashMap<>();
for (int i = 0; i < structType.fields().length; i++) {
StructField structField = structType.fields()[i];
result.put(structField.name(), object(structData, structField, i));
}
return result;
}
}
80 changes: 19 additions & 61 deletions src/main/java/io/qdrant/spark/QdrantDataWriter.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,18 @@
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import java.util.UUID;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.util.ArrayData;
import org.apache.spark.sql.connector.write.DataWriter;
import org.apache.spark.sql.connector.write.WriterCommitMessage;
import org.apache.spark.sql.types.ArrayType;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import static io.qdrant.spark.ObjectFactory.object;

/**
* A DataWriter implementation that writes data to Qdrant, a vector search
* engine. This class takes
Expand Down Expand Up @@ -53,12 +51,24 @@ public void write(InternalRow record) {
for (StructField field : this.schema.fields()) {
int fieldIndex = this.schema.fieldIndex(field.name());
if (this.options.idField != null && field.name().equals(this.options.idField)) {
point.id = record.get(fieldIndex, field.dataType()).toString();

DataType dataType = field.dataType();
switch (dataType.typeName()) {
case "string":
point.id = record.getString(fieldIndex);
break;

case "integer":
point.id = record.getInt(fieldIndex);
break;
default:
throw new IllegalArgumentException("Point ID should be of type string or integer");
}
} else if (field.name().equals(this.options.embeddingField)) {
float[] vector = record.getArray(fieldIndex).toFloatArray();
point.vector = vector;
point.vector = record.getArray(fieldIndex).toFloatArray();

} else {
payload.put(field.name(), convertToJavaType(record, field, fieldIndex));
payload.put(field.name(), object(record, field, fieldIndex));
}
}

Expand Down Expand Up @@ -107,62 +117,10 @@ public void abort() {
@Override
public void close() {
}

private Object convertToJavaType(InternalRow record, StructField field, int fieldIndex) {
DataType dataType = field.dataType();

if (dataType == DataTypes.StringType) {
return record.getString(fieldIndex);
} else if (dataType == DataTypes.DateType || dataType == DataTypes.TimestampType) {
return record.getString(fieldIndex);
} else if (dataType instanceof ArrayType) {
ArrayType arrayType = (ArrayType) dataType;
ArrayData arrayData = record.getArray(fieldIndex);
return convertArrayToJavaType(arrayData, arrayType.elementType());
} else if (dataType instanceof StructType) {
StructType structType = (StructType) dataType;
InternalRow structData = record.getStruct(fieldIndex, structType.fields().length);
return convertStructToJavaType(structData, structType);
}

// Fall back to the generic get method
return record.get(fieldIndex, dataType);
}

private Object convertArrayToJavaType(ArrayData arrayData, DataType elementType) {
if (elementType == DataTypes.StringType) {
int length = arrayData.numElements();
String[] result = new String[length];
for (int i = 0; i < length; i++) {
result[i] = arrayData.getUTF8String(i).toString();
}
return result;
} else if (elementType instanceof StructType) {
StructType structType = (StructType) elementType;
int length = arrayData.numElements();
Object[] result = new Object[length];
for (int i = 0; i < length; i++) {
InternalRow structData = arrayData.getStruct(i, structType.fields().length);
result[i] = convertStructToJavaType(structData, structType);
}
return result;
} else {
return arrayData.toObjectArray(elementType);
}
}

private Object convertStructToJavaType(InternalRow structData, StructType structType) {
Map<String, Object> result = new HashMap<>();
for (int i = 0; i < structType.fields().length; i++) {
StructField structField = structType.fields()[i];
result.put(structField.name(), convertToJavaType(structData, structField, i));
}
return result;
}
}

class Point implements Serializable {
public String id;
public Object id;
public float[] vector;
public HashMap<String, Object> payload;
}
3 changes: 1 addition & 2 deletions src/test/java/io/qdrant/spark/TestQdrant.java
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,7 @@ public void testGetTable() {
options.put("embedding_field", "embedding");
options.put("qdrant_url", "http://localhost:8080");
CaseInsensitiveStringMap dataSourceOptions = new CaseInsensitiveStringMap(options);
var reader = qdrant.getTable(schema, null, dataSourceOptions);
Assert.assertTrue(reader instanceof QdrantCluster);
Assert.assertTrue(qdrant.getTable(schema, null, dataSourceOptions) instanceof QdrantCluster);
}

@Test()
Expand Down

0 comments on commit 62b7c50

Please sign in to comment.