-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
15 changed files
with
731 additions
and
743 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,98 +1,93 @@ | ||
package io.qdrant.spark; | ||
|
||
import org.apache.spark.sql.sources.DataSourceRegister; | ||
import org.apache.spark.sql.types.StructType; | ||
import org.apache.spark.sql.util.CaseInsensitiveStringMap; | ||
import org.apache.spark.sql.connector.catalog.TableProvider; | ||
import org.apache.spark.sql.connector.catalog.Table; | ||
import org.apache.spark.sql.connector.expressions.Transform; | ||
|
||
import java.util.Arrays; | ||
import java.util.List; | ||
import java.util.Map; | ||
import org.apache.spark.sql.connector.catalog.Table; | ||
import org.apache.spark.sql.connector.catalog.TableProvider; | ||
import org.apache.spark.sql.connector.expressions.Transform; | ||
import org.apache.spark.sql.sources.DataSourceRegister; | ||
import org.apache.spark.sql.types.StructType; | ||
import org.apache.spark.sql.util.CaseInsensitiveStringMap; | ||
|
||
/** | ||
* A class that implements the TableProvider and DataSourceRegister interfaces. | ||
* Provides methods to infer schema, get table, and check required options. | ||
* A class that implements the TableProvider and DataSourceRegister interfaces. Provides methods to | ||
* infer schema, get table, and check required options. | ||
*/ | ||
public class Qdrant implements TableProvider, DataSourceRegister { | ||
|
||
private final String[] requiredFields = new String[] { | ||
"schema", | ||
"collection_name", | ||
"embedding_field", | ||
"qdrant_url" | ||
}; | ||
private final String[] requiredFields = | ||
new String[] {"schema", "collection_name", "embedding_field", "qdrant_url"}; | ||
|
||
/** | ||
* Returns the short name of the data source. | ||
* | ||
* @return The short name of the data source. | ||
*/ | ||
@Override | ||
public String shortName() { | ||
return "qdrant"; | ||
} | ||
/** | ||
* Returns the short name of the data source. | ||
* | ||
* @return The short name of the data source. | ||
*/ | ||
@Override | ||
public String shortName() { | ||
return "qdrant"; | ||
} | ||
|
||
/** | ||
* Infers the schema of the data source based on the provided options. | ||
* | ||
* @param options The options used to infer the schema. | ||
* @return The inferred schema. | ||
*/ | ||
@Override | ||
public StructType inferSchema(CaseInsensitiveStringMap options) { | ||
/** | ||
* Infers the schema of the data source based on the provided options. | ||
* | ||
* @param options The options used to infer the schema. | ||
* @return The inferred schema. | ||
*/ | ||
@Override | ||
public StructType inferSchema(CaseInsensitiveStringMap options) { | ||
|
||
StructType schema = (StructType) StructType.fromJson(options.get("schema")); | ||
checkRequiredOptions(options, schema); | ||
StructType schema = (StructType) StructType.fromJson(options.get("schema")); | ||
checkRequiredOptions(options, schema); | ||
|
||
return schema; | ||
}; | ||
return schema; | ||
} | ||
; | ||
|
||
/** | ||
* Returns a table for the data source based on the provided schema, | ||
* partitioning, and properties. | ||
* | ||
* @param schema The schema of the table. | ||
* @param partitioning The partitioning of the table. | ||
* @param properties The properties of the table. | ||
* @return The table for the data source. | ||
*/ | ||
@Override | ||
public Table getTable(StructType schema, Transform[] partitioning, Map<String, String> properties) { | ||
QdrantOptions options = new QdrantOptions(properties); | ||
return new QdrantCluster(options, schema); | ||
} | ||
/** | ||
* Returns a table for the data source based on the provided schema, partitioning, and properties. | ||
* | ||
* @param schema The schema of the table. | ||
* @param partitioning The partitioning of the table. | ||
* @param properties The properties of the table. | ||
* @return The table for the data source. | ||
*/ | ||
@Override | ||
public Table getTable( | ||
StructType schema, Transform[] partitioning, Map<String, String> properties) { | ||
QdrantOptions options = new QdrantOptions(properties); | ||
return new QdrantCluster(options, schema); | ||
} | ||
|
||
/** | ||
* Checks if the required options are present in the provided options and if the | ||
* id_field and embedding_field | ||
* options are present in the provided schema. | ||
* | ||
* @param options The options to check. | ||
* @param schema The schema to check. | ||
*/ | ||
void checkRequiredOptions(CaseInsensitiveStringMap options, StructType schema) { | ||
for (String fieldName : requiredFields) { | ||
if (!options.containsKey(fieldName)) { | ||
throw new IllegalArgumentException(fieldName + " option is required"); | ||
} | ||
} | ||
/** | ||
* Checks if the required options are present in the provided options and if the id_field and | ||
* embedding_field options are present in the provided schema. | ||
* | ||
* @param options The options to check. | ||
* @param schema The schema to check. | ||
*/ | ||
void checkRequiredOptions(CaseInsensitiveStringMap options, StructType schema) { | ||
for (String fieldName : requiredFields) { | ||
if (!options.containsKey(fieldName)) { | ||
throw new IllegalArgumentException(fieldName + " option is required"); | ||
} | ||
} | ||
|
||
List<String> fieldNames = Arrays.asList(schema.fieldNames()); | ||
List<String> fieldNames = Arrays.asList(schema.fieldNames()); | ||
|
||
if (options.containsKey("id_field")) { | ||
String idField = options.get("id_field").toString(); | ||
if (options.containsKey("id_field")) { | ||
String idField = options.get("id_field").toString(); | ||
|
||
if (!fieldNames.contains(idField)) { | ||
throw new IllegalArgumentException("id_field option is not present in the schema"); | ||
} | ||
} | ||
if (!fieldNames.contains(idField)) { | ||
throw new IllegalArgumentException("id_field option is not present in the schema"); | ||
} | ||
} | ||
|
||
String embeddingField = options.get("embedding_field").toString(); | ||
String embeddingField = options.get("embedding_field").toString(); | ||
|
||
if (!fieldNames.contains(embeddingField)) { | ||
throw new IllegalArgumentException("embedding_field option is not present in the schema"); | ||
} | ||
if (!fieldNames.contains(embeddingField)) { | ||
throw new IllegalArgumentException("embedding_field option is not present in the schema"); | ||
} | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.