diff --git a/src/main/java/io/qdrant/spark/ObjectFactory.java b/src/main/java/io/qdrant/spark/ObjectFactory.java new file mode 100644 index 0000000..71e0749 --- /dev/null +++ b/src/main/java/io/qdrant/spark/ObjectFactory.java @@ -0,0 +1,78 @@ +package io.qdrant.spark; + +import java.util.HashMap; +import java.util.Map; + +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.util.ArrayData; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.types.ArrayType; + +class ObjectFactory { + public static Object object(InternalRow record, StructField field, int fieldIndex) { + DataType dataType = field.dataType(); + + switch (dataType.typeName()) { + case "integer": + return record.getInt(fieldIndex); + case "float": + return record.getFloat(fieldIndex); + case "double": + return record.getDouble(fieldIndex); + case "long": + return record.getLong(fieldIndex); + case "boolean": + return record.getBoolean(fieldIndex); + case "string": + return record.getString(fieldIndex); + case "array": + ArrayType arrayType = (ArrayType) dataType; + ArrayData arrayData = record.getArray(fieldIndex); + return object(arrayData, arrayType.elementType()); + case "struct": + StructType structType = (StructType) dataType; + InternalRow structData = record.getStruct(fieldIndex, structType.fields().length); + return object(structData, structType); + default: + return null; + } + } + + public static Object object(ArrayData arrayData, DataType elementType) { + + switch (elementType.typeName()) { + case "string": { + int length = arrayData.numElements(); + String[] result = new String[length]; + for (int i = 0; i < length; i++) { + result[i] = arrayData.getUTF8String(i).toString(); + } + return result; + } + + case "struct": { + StructType structType = (StructType) elementType; + int length = arrayData.numElements(); + Object[] result = new Object[length]; + for (int i = 0; i < length; i++) { + InternalRow structData = arrayData.getStruct(i, structType.fields().length); + result[i] = object(structData, structType); + } + return result; + } + default: + return arrayData.toObjectArray(elementType); + } + } + + public static Object object(InternalRow structData, StructType structType) { + Map result = new HashMap<>(); + for (int i = 0; i < structType.fields().length; i++) { + StructField structField = structType.fields()[i]; + result.put(structField.name(), object(structData, structField, i)); + } + return result; + } +} \ No newline at end of file diff --git a/src/main/java/io/qdrant/spark/QdrantDataWriter.java b/src/main/java/io/qdrant/spark/QdrantDataWriter.java index aaa8325..b7f23bb 100644 --- a/src/main/java/io/qdrant/spark/QdrantDataWriter.java +++ b/src/main/java/io/qdrant/spark/QdrantDataWriter.java @@ -3,20 +3,18 @@ import java.io.Serializable; import java.util.ArrayList; import java.util.HashMap; -import java.util.Map; import java.util.UUID; import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.catalyst.util.ArrayData; import org.apache.spark.sql.connector.write.DataWriter; import org.apache.spark.sql.connector.write.WriterCommitMessage; -import org.apache.spark.sql.types.ArrayType; import org.apache.spark.sql.types.DataType; -import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import static io.qdrant.spark.ObjectFactory.object; + /** * A DataWriter implementation that writes data to Qdrant, a vector search * engine. This class takes @@ -53,12 +51,24 @@ public void write(InternalRow record) { for (StructField field : this.schema.fields()) { int fieldIndex = this.schema.fieldIndex(field.name()); if (this.options.idField != null && field.name().equals(this.options.idField)) { - point.id = record.get(fieldIndex, field.dataType()).toString(); + + DataType dataType = field.dataType(); + switch (dataType.typeName()) { + case "string": + point.id = record.getString(fieldIndex); + break; + + case "integer": + point.id = record.getInt(fieldIndex); + break; + default: + throw new IllegalArgumentException("Point ID should be of type string or integer"); + } } else if (field.name().equals(this.options.embeddingField)) { - float[] vector = record.getArray(fieldIndex).toFloatArray(); - point.vector = vector; + point.vector = record.getArray(fieldIndex).toFloatArray(); + } else { - payload.put(field.name(), convertToJavaType(record, field, fieldIndex)); + payload.put(field.name(), object(record, field, fieldIndex)); } } @@ -107,62 +117,10 @@ public void abort() { @Override public void close() { } - - private Object convertToJavaType(InternalRow record, StructField field, int fieldIndex) { - DataType dataType = field.dataType(); - - if (dataType == DataTypes.StringType) { - return record.getString(fieldIndex); - } else if (dataType == DataTypes.DateType || dataType == DataTypes.TimestampType) { - return record.getString(fieldIndex); - } else if (dataType instanceof ArrayType) { - ArrayType arrayType = (ArrayType) dataType; - ArrayData arrayData = record.getArray(fieldIndex); - return convertArrayToJavaType(arrayData, arrayType.elementType()); - } else if (dataType instanceof StructType) { - StructType structType = (StructType) dataType; - InternalRow structData = record.getStruct(fieldIndex, structType.fields().length); - return convertStructToJavaType(structData, structType); - } - - // Fall back to the generic get method - return record.get(fieldIndex, dataType); - } - - private Object convertArrayToJavaType(ArrayData arrayData, DataType elementType) { - if (elementType == DataTypes.StringType) { - int length = arrayData.numElements(); - String[] result = new String[length]; - for (int i = 0; i < length; i++) { - result[i] = arrayData.getUTF8String(i).toString(); - } - return result; - } else if (elementType instanceof StructType) { - StructType structType = (StructType) elementType; - int length = arrayData.numElements(); - Object[] result = new Object[length]; - for (int i = 0; i < length; i++) { - InternalRow structData = arrayData.getStruct(i, structType.fields().length); - result[i] = convertStructToJavaType(structData, structType); - } - return result; - } else { - return arrayData.toObjectArray(elementType); - } - } - - private Object convertStructToJavaType(InternalRow structData, StructType structType) { - Map result = new HashMap<>(); - for (int i = 0; i < structType.fields().length; i++) { - StructField structField = structType.fields()[i]; - result.put(structField.name(), convertToJavaType(structData, structField, i)); - } - return result; - } } class Point implements Serializable { - public String id; + public Object id; public float[] vector; public HashMap payload; } diff --git a/src/test/java/io/qdrant/spark/TestQdrant.java b/src/test/java/io/qdrant/spark/TestQdrant.java index 42bbeed..a1d383e 100644 --- a/src/test/java/io/qdrant/spark/TestQdrant.java +++ b/src/test/java/io/qdrant/spark/TestQdrant.java @@ -54,8 +54,7 @@ public void testGetTable() { options.put("embedding_field", "embedding"); options.put("qdrant_url", "http://localhost:8080"); CaseInsensitiveStringMap dataSourceOptions = new CaseInsensitiveStringMap(options); - var reader = qdrant.getTable(schema, null, dataSourceOptions); - Assert.assertTrue(reader instanceof QdrantCluster); + Assert.assertTrue(qdrant.getTable(schema, null, dataSourceOptions) instanceof QdrantCluster); } @Test()