From 1ec6fe6986db682d51fe1f6e8309e9dd7543fe11 Mon Sep 17 00:00:00 2001 From: Jia-Xuan Liu Date: Thu, 12 Dec 2024 00:51:17 +0800 Subject: [PATCH] plan compound_field_access syntax --- Cargo.toml | 2 +- datafusion/sql/src/expr/mod.rs | 221 +++++++++++++++++----------- datafusion/sql/src/unparser/expr.rs | 10 +- 3 files changed, 145 insertions(+), 88 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 2b854c670349..bae417bcee57 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -147,7 +147,7 @@ recursive = "0.1.1" regex = "1.8" rstest = "0.23.0" serde_json = "1" -sqlparser = { version = "0.53.0", features = ["visitor"] } +sqlparser = { git = "https://github.com/apache/datafusion-sqlparser-rs.git", rev ="df3c5652b10493df4db484f358514bb210673744", features = ["visitor"] } tempfile = "3" tokio = { version = "1.36", features = ["macros", "rt", "sync"] } url = "2.2" diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index 9b40ebdaf6a5..39d678fdadac 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -21,14 +21,15 @@ use datafusion_expr::planner::{ PlannerResult, RawBinaryExpr, RawDictionaryExpr, RawFieldAccessExpr, }; use sqlparser::ast::{ - BinaryOperator, CastFormat, CastKind, DataType as SQLDataType, DictionaryField, - Expr as SQLExpr, ExprWithAlias as SQLExprWithAlias, MapEntry, StructField, Subscript, - TrimWhereField, Value, + AccessExpr, BinaryOperator, CastFormat, CastKind, DataType as SQLDataType, + DictionaryField, Expr as SQLExpr, ExprWithAlias as SQLExprWithAlias, MapEntry, StructField, Subscript, + TrimWhereField, + Value, }; use datafusion_common::{ - internal_datafusion_err, internal_err, not_impl_err, plan_err, DFSchema, Result, - ScalarValue, + internal_datafusion_err, internal_err, not_impl_err, plan_err, Column, DFSchema, + Result, ScalarValue, }; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::expr::{InList, WildcardOptions}; @@ -235,15 +236,14 @@ impl SqlToRel<'_, S> { SQLExpr::Identifier(id) => { self.sql_identifier_to_expr(id, schema, planner_context) } - - SQLExpr::MapAccess { .. } => { - not_impl_err!("Map Access") - } - // ["foo"], [4] or [4:5] - SQLExpr::Subscript { expr, subscript } => { - self.sql_subscript_to_expr(*expr, subscript, schema, planner_context) - } + SQLExpr::CompoundFieldAccess { root, access_chain } => self + .sql_compound_field_access_to_expr( + *root, + access_chain, + schema, + planner_context, + ), SQLExpr::CompoundIdentifier(ids) => { self.sql_compound_identifier_to_expr(ids, schema, planner_context) @@ -984,84 +984,141 @@ impl SqlToRel<'_, S> { Ok(Expr::Cast(Cast::new(Box::new(expr), dt))) } - fn sql_subscript_to_expr( + fn sql_compound_field_access_to_expr( &self, - expr: SQLExpr, - subscript: Box, + root: SQLExpr, + access_chain: Vec, schema: &DFSchema, planner_context: &mut PlannerContext, ) -> Result { - let expr = self.sql_expr_to_logical_expr(expr, schema, planner_context)?; - - let field_access = match *subscript { - Subscript::Index { index } => { - // index can be a name, in which case it is a named field access - match index { - SQLExpr::Value( - Value::SingleQuotedString(s) | Value::DoubleQuotedString(s), - ) => GetFieldAccess::NamedStructField { - name: ScalarValue::from(s), - }, - SQLExpr::JsonAccess { .. } => { - return not_impl_err!("JsonAccess"); + let mut root = self.sql_expr_to_logical_expr(root, schema, planner_context)?; + let fields = access_chain + .into_iter() + .map(|field| match field { + AccessExpr::Subscript(subscript) => { + match subscript { + Subscript::Index { index } => { + // index can be a name, in which case it is a named field access + match index { + SQLExpr::Value( + Value::SingleQuotedString(s) + | Value::DoubleQuotedString(s), + ) => Ok(Some(GetFieldAccess::NamedStructField { + name: ScalarValue::from(s), + })), + SQLExpr::JsonAccess { .. } => { + not_impl_err!("JsonAccess") + } + // otherwise treat like a list index + _ => Ok(Some(GetFieldAccess::ListIndex { + key: Box::new(self.sql_expr_to_logical_expr( + index, + schema, + planner_context, + )?), + })), + } + } + Subscript::Slice { + lower_bound, + upper_bound, + stride, + } => { + // Means access like [:2] + let lower_bound = if let Some(lower_bound) = lower_bound { + self.sql_expr_to_logical_expr( + lower_bound, + schema, + planner_context, + ) + } else { + not_impl_err!("Slice subscript requires a lower bound") + }?; + + // means access like [2:] + let upper_bound = if let Some(upper_bound) = upper_bound { + self.sql_expr_to_logical_expr( + upper_bound, + schema, + planner_context, + ) + } else { + not_impl_err!("Slice subscript requires an upper bound") + }?; + + // stride, default to 1 + let stride = if let Some(stride) = stride { + self.sql_expr_to_logical_expr( + stride, + schema, + planner_context, + )? + } else { + lit(1i64) + }; + + Ok(Some(GetFieldAccess::ListRange { + start: Box::new(lower_bound), + stop: Box::new(upper_bound), + stride: Box::new(stride), + })) + } } - // otherwise treat like a list index - _ => GetFieldAccess::ListIndex { - key: Box::new(self.sql_expr_to_logical_expr( - index, - schema, - planner_context, - )?), - }, } - } - Subscript::Slice { - lower_bound, - upper_bound, - stride, - } => { - // Means access like [:2] - let lower_bound = if let Some(lower_bound) = lower_bound { - self.sql_expr_to_logical_expr(lower_bound, schema, planner_context) - } else { - not_impl_err!("Slice subscript requires a lower bound") - }?; - - // means access like [2:] - let upper_bound = if let Some(upper_bound) = upper_bound { - self.sql_expr_to_logical_expr(upper_bound, schema, planner_context) - } else { - not_impl_err!("Slice subscript requires an upper bound") - }?; - - // stride, default to 1 - let stride = if let Some(stride) = stride { - self.sql_expr_to_logical_expr(stride, schema, planner_context)? - } else { - lit(1i64) - }; - - GetFieldAccess::ListRange { - start: Box::new(lower_bound), - stop: Box::new(upper_bound), - stride: Box::new(stride), + AccessExpr::Dot(expr) => { + let expr = + self.sql_expr_to_logical_expr(expr, schema, planner_context)?; + match expr { + Expr::Column(Column { name, relation }) => { + if let Some(relation) = &relation { + // If the first part of the dot access is a column reference, we should + // check if the column is from the same table as the root expression. + // If it is, we should replace the root expression with the column reference. + // Otherwise, we should treat the dot access as a named field access. + if relation.table() == root.schema_name().to_string() { + root = Expr::Column(Column { + name, + relation: Some(relation.clone()), + }); + Ok(None) + } else { + plan_err!( + "table name mismatch: {} != {}", + relation.table(), + root.schema_name() + ) + } + } else { + Ok(Some(GetFieldAccess::NamedStructField { + name: ScalarValue::from(name), + })) + } + } + _ => not_impl_err!( + "Dot access not supported for non-column expr: {expr:?}" + ), + } } - } - }; + }) + .collect::>>()?; - let mut field_access_expr = RawFieldAccessExpr { expr, field_access }; - for planner in self.context_provider.get_expr_planners() { - match planner.plan_field_access(field_access_expr, schema)? { - PlannerResult::Planned(expr) => return Ok(expr), - PlannerResult::Original(expr) => { - field_access_expr = expr; + fields + .into_iter() + .flatten() + .try_fold(root, |expr, field_access| { + let mut field_access_expr = RawFieldAccessExpr { expr, field_access }; + for planner in self.context_provider.get_expr_planners() { + match planner.plan_field_access(field_access_expr, schema)? { + PlannerResult::Planned(expr) => return Ok(expr), + PlannerResult::Original(expr) => { + field_access_expr = expr; + } + } } - } - } - - not_impl_err!( - "GetFieldAccess not supported by ExprPlanner: {field_access_expr:?}" - ) + not_impl_err!( + "GetFieldAccess not supported by ExprPlanner: {field_access_expr:?}" + ) + }) } } diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index d012d3437720..2b8e53c4243d 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -18,8 +18,8 @@ use datafusion_expr::expr::Unnest; use sqlparser::ast::Value::SingleQuotedString; use sqlparser::ast::{ - self, Array, BinaryOperator, Expr as AstExpr, Function, Ident, Interval, ObjectName, - Subscript, TimezoneInfo, UnaryOperator, + self, AccessExpr, Array, BinaryOperator, Expr as AstExpr, Function, Ident, Interval, + ObjectName, Subscript, TimezoneInfo, UnaryOperator, }; use std::sync::Arc; use std::vec; @@ -523,9 +523,9 @@ impl Unparser<'_> { } let array = self.expr_to_sql(&args[0])?; let index = self.expr_to_sql(&args[1])?; - Ok(ast::Expr::Subscript { - expr: Box::new(array), - subscript: Box::new(Subscript::Index { index }), + Ok(ast::Expr::CompoundFieldAccess { + root: Box::new(array), + access_chain: vec![AccessExpr::Subscript(Subscript::Index { index })], }) }