From 6f97db7d510ef4978dce4223a87ed793393943ca Mon Sep 17 00:00:00 2001 From: Zihao Xu Date: Mon, 4 Mar 2024 04:07:44 -0500 Subject: [PATCH] feat(binder): support sql udf with unnamed parameters (#835) * support sql udf with unnamed parameters Signed-off-by: Michael Xu * fix format Signed-off-by: Michael Xu * improve comment Signed-off-by: Michael Xu * update sql udf tests Signed-off-by: Michael Xu --------- Signed-off-by: Michael Xu --- src/binder/expr.rs | 14 ++- src/binder/mod.rs | 2 +- tests/sql/sql_udf.slt | 233 ++++++++++++++++++++++++++++++++++++++---- 3 files changed, 228 insertions(+), 21 deletions(-) diff --git a/src/binder/expr.rs b/src/binder/expr.rs index 9c40e9e0..acffddf0 100644 --- a/src/binder/expr.rs +++ b/src/binder/expr.rs @@ -15,7 +15,19 @@ impl Binder { /// Bind an expression. pub fn bind_expr(&mut self, expr: Expr) -> Result { let id = match expr { - Expr::Value(v) => Ok(self.egraph.add(Node::Constant(v.into()))), + Expr::Value(v) => { + // This is okay since only sql udf relies on + // parameter-like (i.e., `$1`) values at present + // TODO: consider formally `bind_parameter` in the future + // e.g., lambda function support, etc. + if let Value::Placeholder(key) = v { + self.udf_context + .get_expr(&key) + .map_or_else(|| Err(BindError::InvalidSQL), |&e| Ok(e)) + } else { + Ok(self.egraph.add(Node::Constant(v.into()))) + } + } Expr::Identifier(ident) => self.bind_ident([ident]), Expr::CompoundIdentifier(idents) => self.bind_ident(idents), Expr::BinaryOp { left, op, right } => self.bind_binary_op(*left, op, *right), diff --git a/src/binder/mod.rs b/src/binder/mod.rs index 644e08c9..f2ed4bcd 100644 --- a/src/binder/mod.rs +++ b/src/binder/mod.rs @@ -202,7 +202,7 @@ impl UdfContext { return Err(BindError::InvalidExpression("invalid syntax".to_string())); }; if catalog.arg_names[i].is_empty() { - todo!("anonymous parameters not yet supported"); + ret.insert(format!("${}", i + 1), e.clone()); } else { // The index mapping here is accurate // So that we could directly use the index diff --git a/tests/sql/sql_udf.slt b/tests/sql/sql_udf.slt index 3f7fa818..057c4e44 100644 --- a/tests/sql/sql_udf.slt +++ b/tests/sql/sql_udf.slt @@ -1,42 +1,237 @@ -# Create a named sql udf with single quote definition -statement ok -create function add_named(a INT, b INT) returns int language sql as 'select a + b'; +############################################################# +# Basic tests for sql udf with [unnamed / named] parameters # +############################################################# -# Create another named sql udf with double dollar definition +# Create a sql udf function with unnamed parameters with double dollar as clause statement ok -create function sub_named(a INT, b INT) returns int language sql as $$select a - b$$; +create function add(INT, INT) returns int language sql as $$select $1 + $2$$; -# Named sql udf with corner case +query I +select add(1, -1); +---- +0 + +# Create a sql udf function with unnamed parameters with single quote as clause statement ok -create function corner_case(a INT) returns varchar language sql as $$select '$1 + a + $3'$$; +create function sub(INT, INT) returns int language sql as 'select $1 - $2'; -# Call sql udf inside named sql udf +query I +select sub(1, 1); +---- +0 + +# Create a sql udf function with unamed parameters that calls other pre-defined sql udfs statement ok -create function add_named_wrapper(a INT, b INT) returns int language sql as 'select add_named(a, b)'; +create function add_sub_binding() returns int language sql as 'select add(1, 1) + sub(2, 2)'; -# TODO: add checks -# Named sql udf with invalid parameter in body definition -# Will be rejected at creation time -# statement error failed to find named parameter aa -# create function unknown_parameter(a INT) returns int language sql as 'select a + aa + a'; +query I +select add_sub_binding(); +---- +2 + +# Use them all together +query III +select add(1, -1), sub(1, 1), add_sub_binding(); +---- +0 0 2 + +# Create a sql udf with named parameters with single quote as clause +statement ok +create function add_named(a INT, b INT) returns int language sql as 'select a + b'; -# Call the defined named sql udfs query I select add_named(1, -1); ---- 0 +# Create another sql udf with named parameters with double dollar as clause +statement ok +create function sub_named(a INT, b INT) returns int language sql as $$select a - b$$; + query I select sub_named(1, 1); ---- 0 -query T -select corner_case(1); +# Mixed with named / unnamed parameters +statement ok +create function add_sub_mix(INT, a INT, INT) returns int language sql as 'select $1 - a + $3'; + +query I +select add_sub_mix(1, 2, 3); ---- -$1 + a + $3 +2 + +# Call sql udf with unnamed parameters inside sql udf with named parameters +statement ok +create function add_named_wrapper(a INT, b INT) returns int language sql as 'select add(a, b)'; query I select add_named_wrapper(1, -1); ---- -0 \ No newline at end of file +0 + +# Create a sql udf with unnamed parameters with return expression +statement ok +create function add_return(INT, INT) returns int language sql return $1 + $2; + +query I +select add_return(1, 1); +---- +2 + +statement ok +create function add_return_binding() returns int language sql return add_return(1, 1) + add_return(1, 1); + +query I +select add_return_binding(); +---- +4 + +statement ok +create function print(INT) returns int language sql as 'select $1'; + +query T +select print(114514); +---- +114514 + +# Multiple type interleaving sql udf +statement ok +create function add_sub(INT, FLOAT, INT) returns float language sql as $$select -$1 + $2 - $3$$; + +query I +select add_sub(1, 5.1415926, 1); +---- +3.1415926 + +query III +select add(1, -1), sub(1, 1), add_sub(1, 5.1415926, 1); +---- +0 0 3.1415926 + +# TODO: `Real` is not supported yet +# Complex types interleaving +# statement ok +# create function add_sub_types(INT, BIGINT, FLOAT, DECIMAL, REAL) returns double language sql as 'select $1 + $2 - $3 + $4 + $5'; + +# query I +# select add_sub_types(1, 1919810114514, 3.1415926, 1.123123, 101010.191919); +# ---- +# 1919810215523.1734494 + +statement ok +create function add_sub_return(INT, FLOAT, INT) returns float language sql return -$1 + $2 - $3; + +query I +select add_sub_return(1, 5.1415926, 1); +---- +3.1415926 + +# Create a wrapper function for `add` & `sub` +statement ok +create function add_sub_wrapper(INT, INT) returns int language sql as 'select add($1, $2) + sub($1, $2) + 114512'; + +query I +select add_sub_wrapper(1, 1); +---- +114514 + +########################################################## +# Basic sql udfs integrated with the use of mock tables # +# P.S. This is also a simulation of real world use cases # +########################################################## + +statement ok +create table t1 (c1 INT, c2 INT); + +statement ok +create table t2 (c1 INT, c2 FLOAT, c3 INT); + +# Special table for named sql udf +statement ok +create table t3 (a INT, b INT); + +statement ok +insert into t1 values (1, 1), (2, 2), (3, 3), (4, 4), (5, 5); + +statement ok +insert into t2 values (1, 3.14, 2), (2, 4.44, 5), (20, 10.30, 02); + +statement ok +insert into t3 values (1, 1), (2, 2), (3, 3), (4, 4), (5, 5); + +query I +select c1, c2, add_return(c1, c2) from t1 order by c1 asc; +---- +1 1 2 +2 2 4 +3 3 6 +4 4 8 +5 5 10 + +query III +select sub(c1, c2), c1, c2, add(c1, c2) from t1 order by c1 asc; +---- +0 1 1 2 +0 2 2 4 +0 3 3 6 +0 4 4 8 +0 5 5 10 + +query IIIIII +select c1, c2, c3, add(c1, c3), sub(c1, c3), add_sub(c1, c2, c3) from t2 order by c1 asc; +---- +1 3.14 2 3 -1 0.14000000000000012 +2 4.44 5 7 -3 -2.5599999999999996 +20 10.3 2 22 18 -11.7 + +query IIIIII +select c1, c2, c3, add(c1, c3), sub(c1, c3), add_sub_return(c1, c2, c3) from t2 order by c1 asc; +---- +1 3.14 2 3 -1 0.14000000000000012 +2 4.44 5 7 -3 -2.5599999999999996 +20 10.3 2 22 18 -11.7 + +query I +select add_named(a, b) from t3 order by a asc; +---- +2 +4 +6 +8 +10 + +################################ +# Corner & Special cases tests # +################################ + +# Mixed parameter with calling inner sql udfs +statement ok +create function add_sub_mix_wrapper(INT, a INT, INT) returns int language sql as 'select add($1, a) + a + sub(a, $3)'; + +query I +select add_sub_mix_wrapper(1, 2, 3); +---- +4 + +# Named sql udf with corner case +statement ok +create function corner_case(INT, a INT, INT) returns varchar language sql as $$select '$1 + a + $3'$$; + +query T +select corner_case(1, 2, 3); +---- +$1 + a + $3 + +# Adjust the input value of the calling function (i.e., `print` here) with the actual input parameter +statement ok +create function print_add_one(INT) returns int language sql as 'select print($1 + 1)'; + +statement ok +create function print_add_two(INT) returns int language sql as 'select print($1 + $1)'; + +query III +select print_add_one(1), print_add_one(114513), print_add_two(2); +---- +2 114514 4 \ No newline at end of file