Skip to content

Commit

Permalink
feat(binder): support sql udf with unnamed parameters (#835)
Browse files Browse the repository at this point in the history
* support sql udf with unnamed parameters

Signed-off-by: Michael Xu <[email protected]>

* fix format

Signed-off-by: Michael Xu <[email protected]>

* improve comment

Signed-off-by: Michael Xu <[email protected]>

* update sql udf tests

Signed-off-by: Michael Xu <[email protected]>

---------

Signed-off-by: Michael Xu <[email protected]>
  • Loading branch information
xzhseh authored Mar 4, 2024
1 parent 0e4a846 commit 6f97db7
Show file tree
Hide file tree
Showing 3 changed files with 228 additions and 21 deletions.
14 changes: 13 additions & 1 deletion src/binder/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,19 @@ impl Binder {
/// Bind an expression.
pub fn bind_expr(&mut self, expr: Expr) -> Result {
let id = match expr {
Expr::Value(v) => Ok(self.egraph.add(Node::Constant(v.into()))),
Expr::Value(v) => {
// This is okay since only sql udf relies on
// parameter-like (i.e., `$1`) values at present
// TODO: consider formally `bind_parameter` in the future
// e.g., lambda function support, etc.
if let Value::Placeholder(key) = v {
self.udf_context
.get_expr(&key)
.map_or_else(|| Err(BindError::InvalidSQL), |&e| Ok(e))
} else {
Ok(self.egraph.add(Node::Constant(v.into())))
}
}
Expr::Identifier(ident) => self.bind_ident([ident]),
Expr::CompoundIdentifier(idents) => self.bind_ident(idents),
Expr::BinaryOp { left, op, right } => self.bind_binary_op(*left, op, *right),
Expand Down
2 changes: 1 addition & 1 deletion src/binder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ impl UdfContext {
return Err(BindError::InvalidExpression("invalid syntax".to_string()));
};
if catalog.arg_names[i].is_empty() {
todo!("anonymous parameters not yet supported");
ret.insert(format!("${}", i + 1), e.clone());
} else {
// The index mapping here is accurate
// So that we could directly use the index
Expand Down
233 changes: 214 additions & 19 deletions tests/sql/sql_udf.slt
Original file line number Diff line number Diff line change
@@ -1,42 +1,237 @@
# Create a named sql udf with single quote definition
statement ok
create function add_named(a INT, b INT) returns int language sql as 'select a + b';
#############################################################
# Basic tests for sql udf with [unnamed / named] parameters #
#############################################################

# Create another named sql udf with double dollar definition
# Create a sql udf function with unnamed parameters with double dollar as clause
statement ok
create function sub_named(a INT, b INT) returns int language sql as $$select a - b$$;
create function add(INT, INT) returns int language sql as $$select $1 + $2$$;

# Named sql udf with corner case
query I
select add(1, -1);
----
0

# Create a sql udf function with unnamed parameters with single quote as clause
statement ok
create function corner_case(a INT) returns varchar language sql as $$select '$1 + a + $3'$$;
create function sub(INT, INT) returns int language sql as 'select $1 - $2';

# Call sql udf inside named sql udf
query I
select sub(1, 1);
----
0

# Create a sql udf function with unamed parameters that calls other pre-defined sql udfs
statement ok
create function add_named_wrapper(a INT, b INT) returns int language sql as 'select add_named(a, b)';
create function add_sub_binding() returns int language sql as 'select add(1, 1) + sub(2, 2)';

# TODO: add checks
# Named sql udf with invalid parameter in body definition
# Will be rejected at creation time
# statement error failed to find named parameter aa
# create function unknown_parameter(a INT) returns int language sql as 'select a + aa + a';
query I
select add_sub_binding();
----
2

# Use them all together
query III
select add(1, -1), sub(1, 1), add_sub_binding();
----
0 0 2

# Create a sql udf with named parameters with single quote as clause
statement ok
create function add_named(a INT, b INT) returns int language sql as 'select a + b';

# Call the defined named sql udfs
query I
select add_named(1, -1);
----
0

# Create another sql udf with named parameters with double dollar as clause
statement ok
create function sub_named(a INT, b INT) returns int language sql as $$select a - b$$;

query I
select sub_named(1, 1);
----
0

query T
select corner_case(1);
# Mixed with named / unnamed parameters
statement ok
create function add_sub_mix(INT, a INT, INT) returns int language sql as 'select $1 - a + $3';

query I
select add_sub_mix(1, 2, 3);
----
$1 + a + $3
2

# Call sql udf with unnamed parameters inside sql udf with named parameters
statement ok
create function add_named_wrapper(a INT, b INT) returns int language sql as 'select add(a, b)';

query I
select add_named_wrapper(1, -1);
----
0
0

# Create a sql udf with unnamed parameters with return expression
statement ok
create function add_return(INT, INT) returns int language sql return $1 + $2;

query I
select add_return(1, 1);
----
2

statement ok
create function add_return_binding() returns int language sql return add_return(1, 1) + add_return(1, 1);

query I
select add_return_binding();
----
4

statement ok
create function print(INT) returns int language sql as 'select $1';

query T
select print(114514);
----
114514

# Multiple type interleaving sql udf
statement ok
create function add_sub(INT, FLOAT, INT) returns float language sql as $$select -$1 + $2 - $3$$;

query I
select add_sub(1, 5.1415926, 1);
----
3.1415926

query III
select add(1, -1), sub(1, 1), add_sub(1, 5.1415926, 1);
----
0 0 3.1415926

# TODO: `Real` is not supported yet
# Complex types interleaving
# statement ok
# create function add_sub_types(INT, BIGINT, FLOAT, DECIMAL, REAL) returns double language sql as 'select $1 + $2 - $3 + $4 + $5';

# query I
# select add_sub_types(1, 1919810114514, 3.1415926, 1.123123, 101010.191919);
# ----
# 1919810215523.1734494

statement ok
create function add_sub_return(INT, FLOAT, INT) returns float language sql return -$1 + $2 - $3;

query I
select add_sub_return(1, 5.1415926, 1);
----
3.1415926

# Create a wrapper function for `add` & `sub`
statement ok
create function add_sub_wrapper(INT, INT) returns int language sql as 'select add($1, $2) + sub($1, $2) + 114512';

query I
select add_sub_wrapper(1, 1);
----
114514

##########################################################
# Basic sql udfs integrated with the use of mock tables #
# P.S. This is also a simulation of real world use cases #
##########################################################

statement ok
create table t1 (c1 INT, c2 INT);

statement ok
create table t2 (c1 INT, c2 FLOAT, c3 INT);

# Special table for named sql udf
statement ok
create table t3 (a INT, b INT);

statement ok
insert into t1 values (1, 1), (2, 2), (3, 3), (4, 4), (5, 5);

statement ok
insert into t2 values (1, 3.14, 2), (2, 4.44, 5), (20, 10.30, 02);

statement ok
insert into t3 values (1, 1), (2, 2), (3, 3), (4, 4), (5, 5);

query I
select c1, c2, add_return(c1, c2) from t1 order by c1 asc;
----
1 1 2
2 2 4
3 3 6
4 4 8
5 5 10

query III
select sub(c1, c2), c1, c2, add(c1, c2) from t1 order by c1 asc;
----
0 1 1 2
0 2 2 4
0 3 3 6
0 4 4 8
0 5 5 10

query IIIIII
select c1, c2, c3, add(c1, c3), sub(c1, c3), add_sub(c1, c2, c3) from t2 order by c1 asc;
----
1 3.14 2 3 -1 0.14000000000000012
2 4.44 5 7 -3 -2.5599999999999996
20 10.3 2 22 18 -11.7

query IIIIII
select c1, c2, c3, add(c1, c3), sub(c1, c3), add_sub_return(c1, c2, c3) from t2 order by c1 asc;
----
1 3.14 2 3 -1 0.14000000000000012
2 4.44 5 7 -3 -2.5599999999999996
20 10.3 2 22 18 -11.7

query I
select add_named(a, b) from t3 order by a asc;
----
2
4
6
8
10

################################
# Corner & Special cases tests #
################################

# Mixed parameter with calling inner sql udfs
statement ok
create function add_sub_mix_wrapper(INT, a INT, INT) returns int language sql as 'select add($1, a) + a + sub(a, $3)';

query I
select add_sub_mix_wrapper(1, 2, 3);
----
4

# Named sql udf with corner case
statement ok
create function corner_case(INT, a INT, INT) returns varchar language sql as $$select '$1 + a + $3'$$;

query T
select corner_case(1, 2, 3);
----
$1 + a + $3

# Adjust the input value of the calling function (i.e., `print` here) with the actual input parameter
statement ok
create function print_add_one(INT) returns int language sql as 'select print($1 + 1)';

statement ok
create function print_add_two(INT) returns int language sql as 'select print($1 + $1)';

query III
select print_add_one(1), print_add_one(114513), print_add_two(2);
----
2 114514 4

0 comments on commit 6f97db7

Please sign in to comment.