diff --git a/src/main/java/io/qdrant/spark/QdrantVectorHandler.java b/src/main/java/io/qdrant/spark/QdrantVectorHandler.java index ec793fd..e1ef987 100644 --- a/src/main/java/io/qdrant/spark/QdrantVectorHandler.java +++ b/src/main/java/io/qdrant/spark/QdrantVectorHandler.java @@ -13,6 +13,7 @@ import java.util.Map; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.util.ArrayData; +import org.apache.spark.sql.types.ArrayType; import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; @@ -98,6 +99,12 @@ private static float[] extractFloatArray(InternalRow record, int fieldIndex, Dat throw new IllegalArgumentException("Vector field must be of type ArrayType"); } + ArrayType arrayType = (ArrayType) dataType; + + if (!arrayType.elementType().typeName().equalsIgnoreCase("float")) { + throw new IllegalArgumentException("Expected array elements to be of FloatType"); + } + return record.getArray(fieldIndex).toFloatArray(); } @@ -107,6 +114,12 @@ private static int[] extractIntArray(InternalRow record, int fieldIndex, DataTyp throw new IllegalArgumentException("Vector field must be of type ArrayType"); } + ArrayType arrayType = (ArrayType) dataType; + + if (!arrayType.elementType().typeName().equalsIgnoreCase("integer")) { + throw new IllegalArgumentException("Expected array elements to be of IntegerType"); + } + return record.getArray(fieldIndex).toIntArray(); } @@ -114,7 +127,13 @@ private static float[][] extractMultiVecArray( InternalRow record, int fieldIndex, DataType dataType) { if (!dataType.typeName().equalsIgnoreCase("array")) { - throw new IllegalArgumentException("Vector field must be of type ArrayType"); + throw new IllegalArgumentException("Multi Vector field must be of type ArrayType"); + } + + ArrayType arrayType = (ArrayType) dataType; + + if (!arrayType.elementType().typeName().equalsIgnoreCase("array")) { + throw new IllegalArgumentException("Multi Vector elements must be of type ArrayType"); } ArrayData arrayData = record.getArray(fieldIndex);