Skip to content

Commit

Permalink
Merge pull request #14 from hasura/range-query-elasticsearch
Browse files Browse the repository at this point in the history
Range query elasticsearch
  • Loading branch information
gneeri authored Jun 28, 2024
2 parents f5938f8 + 5ae2f4f commit fb9b8cc
Show file tree
Hide file tree
Showing 12 changed files with 571 additions and 209 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

- Support for Elasticsearch Range Queries.

## [0.2.0]

### Added
Expand Down
13 changes: 1 addition & 12 deletions connector/aggregate.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,7 @@ func prepareAggregateColumnCount(ctx context.Context, field string, path string,
// If the field is nested, it generates a nested query to perform the specified function on the field in the nested document.
func prepareAggregateSingleColumn(ctx context.Context, function, field string, path string, aggName string) (map[string]interface{}, error) {
// Validate the function
validFunctions := []string{"sum", "min", "max", "avg", "value_count", "cardinality", "stats", "string_stats"}
if !contains(validFunctions, function) {
if !internal.Contains(validFunctions, function) {
return nil, schema.UnprocessableContentError("invalid aggregate function", map[string]any{
"value": function,
})
Expand Down Expand Up @@ -148,13 +147,3 @@ func prepareNestedAggregate(ctx context.Context, aggName string, aggregation map

return aggregation
}

// contains checks if a string slice contains a specific element.
func contains(s []string, e string) bool {
for _, a := range s {
if a == e {
return true
}
}
return false
}
5 changes: 5 additions & 0 deletions connector/connector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,11 @@ var testCases = []struct {
requestFile: "../testdata/query/filter/unary_predicate_on_nested_type_request.json",
responseFile: "../testdata/query/filter/unary_predicate_on_nested_type_response.json",
},
{
name: "range_query",
requestFile: "../testdata/query/filter/range_query_request.json",
responseFile: "../testdata/query/filter/range_query_response.json",
},
{
name: "star_count_aggregation",
requestFile: "../testdata/query/aggregation/star_count_request.json",
Expand Down
108 changes: 86 additions & 22 deletions connector/filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func handleExpressionUnaryComparisonOperator(expr *schema.ExpressionUnaryCompari
filter := map[string]interface{}{"bool": map[string]interface{}{"must_not": map[string]interface{}{"exists": value}}}
if nestedFields, ok := state.NestedFields[collection]; ok {
if _, ok := nestedFields.(map[string]string)[expr.Column.Name]; ok {
filter = joinNestedFieldPath(state, "bool.must_not.exists", value, fieldName, len(expr.Column.FieldPath), collection)
filter = prepareNestedQuery(state, "bool.must_not.exists", value, fieldName, len(expr.Column.FieldPath), collection)
}
}
return filter, nil
Expand All @@ -79,35 +79,42 @@ func handleExpressionUnaryComparisonOperator(expr *schema.ExpressionUnaryCompari
}

// handleExpressionBinaryComparisonOperator processes the binary comparison operator expression.
func handleExpressionBinaryComparisonOperator(expr *schema.ExpressionBinaryComparisonOperator, state *types.State, collection string) (map[string]interface{}, error) {
var filter map[string]interface{}
fieldName, _ := joinFieldPath(state, expr.Column.FieldPath, expr.Column.Name, collection)
func handleExpressionBinaryComparisonOperator(
expr *schema.ExpressionBinaryComparisonOperator,
state *types.State,
collection string,
) (map[string]interface{}, error) {
fieldPath, nestedPath := joinFieldPath(state, expr.Column.FieldPath, expr.Column.Name, collection)
var bestSubField string

switch expr.Operator {
case "match", "match_phrase", "match_phrase_prefix", "match_bool_prefix":
bestSubField = getTextFieldFromState(state, fieldName, collection)
bestSubField = getTextFieldFromState(state, fieldPath, collection)
case "term", "prefix", "terms":
bestSubField = getKeywordFieldFromState(state, fieldName, collection)
bestSubField = getKeywordFieldFromState(state, fieldPath, collection)
case "wildcard", "regexp":
bestSubField = getWildcardFieldFromState(state, fieldName, collection)
bestSubField = getWildcardFieldFromState(state, fieldPath, collection)
case "range":
bestSubField = getNumericFieldFromState(state, fieldPath, collection)
default:
return nil, schema.UnprocessableContentError("invalid binary comaparison operator", map[string]any{
"expression": expr.Operator,
})
}

value, err := evalComparisonValue(expr.Value, bestSubField)
value, err := evalComparisonValue(expr.Value, bestSubField, expr.Operator)
if err != nil {
return nil, err
}
filter = map[string]interface{}{

filter := map[string]interface{}{
expr.Operator: value,
}
if nestedFields, ok := state.NestedFields[collection]; ok {
if _, ok := nestedFields.(map[string]string)[expr.Column.Name]; ok {
filter = joinNestedFieldPath(state, expr.Operator, value, fieldName, len(expr.Column.FieldPath), collection)
}

if nestedPath != "" {
filter = prepareNestedQuery(state, expr.Operator, value, fieldPath, len(expr.Column.FieldPath), collection)
}

return filter, nil
}

Expand Down Expand Up @@ -138,8 +145,15 @@ func joinFieldPath(state *types.State, fieldPath []string, columnName string, co
return joinedPath, nestedPath
}

// joinNestedFieldPath creates a Elasticsearch's nested query based on field_path.
func joinNestedFieldPath(state *types.State, operator string, value map[string]interface{}, fieldName string, nestedLevel int, collection string) map[string]interface{} {
// prepareNestedQuery creates a Elasticsearch's nested query based on field_path.
func prepareNestedQuery(
state *types.State,
operator string,
value map[string]interface{},
fieldName string,
nestedLevel int,
collection string,
) map[string]interface{} {
// Create the innermost query
operators := strings.Split(operator, ".")
query := value
Expand Down Expand Up @@ -205,6 +219,30 @@ func getTextFieldFromState(state *types.State, columnName string, collection str
return columnName
}

// getNumericFieldFromState retrieves the best matching field for range queries
// from the state. If the field is found, it returns the corresponding field
// name; otherwise, it returns the original columnName.
func getNumericFieldFromState(state *types.State, columnName string, collection string) string {
if collectionField, ok := state.SupportedFilterFields[collection]; ok {
if numericFields, ok := collectionField.(map[string]interface{})["range_queries"].(map[string]string); ok {
if numericField, ok := numericFields[columnName]; ok {
return numericField
}
}
if keywordFields, ok := collectionField.(map[string]interface{})["term_level_queries"].(map[string]string); ok {
if keywordField, ok := keywordFields[columnName]; ok {
return keywordField
}
}
if wildcardFields, ok := collectionField.(map[string]interface{})["unstructured_text"].(map[string]string); ok {
if wildcardField, ok := wildcardFields[columnName]; ok {
return wildcardField
}
}
}
return columnName
}

// getWildcardFieldFromState retrieves the best matching field for wildcard and regexp
// queries from the state. If the field is found, it returns the corresponding field
// name; otherwise, it returns the original columnName.
Expand All @@ -225,19 +263,45 @@ func getWildcardFieldFromState(state *types.State, columnName string, collection
}

// evalComparisonValue evaluates the comparison value for scalar and variable type.
func evalComparisonValue(comparisonValue schema.ComparisonValue, columnName string) (map[string]interface{}, error) {
func evalComparisonValue(comparisonValue schema.ComparisonValue, columnName string, operator string) (map[string]interface{}, error) {
switch compValue := comparisonValue.Interface().(type) {
case *schema.ComparisonValueScalar:
return map[string]interface{}{
columnName: compValue.Value,
}, nil
if operator == "range" {
validValue, err := processRangeValue(compValue.Value)
if err != nil {
return nil, err
}
return map[string]interface{}{columnName: validValue}, nil
}
return map[string]interface{}{columnName: compValue.Value}, nil
case *schema.ComparisonValueVariable:
return map[string]interface{}{
columnName: types.Variable(compValue.Name),
}, nil
return map[string]interface{}{columnName: types.Variable(compValue.Name)}, nil
default:
return nil, schema.UnprocessableContentError("invalid type of comparison value", map[string]any{
"value": comparisonValue["type"],
})
}
}

// processRangeValue processes the range value for a range comparison.
// It checks if the range value is valid and returns the valid range value.
// If the range value is invalid, it returns an error.
func processRangeValue(rangeValue interface{}) (map[string]interface{}, error) {
if rangeValue == nil {
return nil, schema.UnprocessableContentError("invalid range value", nil)
}

rangeMap, ok := rangeValue.(map[string]interface{})
if !ok {
return nil, schema.UnprocessableContentError("invalid range value", nil)
}

// Remove empty range values
for key, value := range rangeMap {
if valueStr, ok := value.(string); ok && valueStr == "" {
delete(rangeMap, key)
}
}

return rangeMap, nil
}
Loading

0 comments on commit fb9b8cc

Please sign in to comment.