diff --git a/Cargo.lock b/Cargo.lock index 140ca65543e9..76a63bd13079 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4911,12 +4911,14 @@ dependencies = [ name = "databend-query" version = "0.1.0" dependencies = [ + "anyhow", "arrow-array", "arrow-buffer", "arrow-cast", "arrow-flight", "arrow-ipc", "arrow-schema", + "arrow-select", "arrow-udf-js", "arrow-udf-python", "arrow-udf-wasm", diff --git a/src/common/exception/src/exception_code.rs b/src/common/exception/src/exception_code.rs index ef4a30d220e6..c5923dfd08eb 100644 --- a/src/common/exception/src/exception_code.rs +++ b/src/common/exception/src/exception_code.rs @@ -212,6 +212,10 @@ build_exceptions! { // Geometry errors. GeometryError(1801), InvalidGeometryFormat(1802), + + // UDF errors. + UDFRuntimeError(1810), + // Tantivy errors. TantivyError(1901), TantivyOpenReadError(1902), diff --git a/src/meta/app/src/principal/mod.rs b/src/meta/app/src/principal/mod.rs index 568253041164..1159a9df46cb 100644 --- a/src/meta/app/src/principal/mod.rs +++ b/src/meta/app/src/principal/mod.rs @@ -95,6 +95,7 @@ pub use user_auth::AuthType; pub use user_auth::PasswordHashMethod; pub use user_defined_file_format::UserDefinedFileFormat; pub use user_defined_function::LambdaUDF; +pub use user_defined_function::UDAFScript; pub use user_defined_function::UDFDefinition; pub use user_defined_function::UDFScript; pub use user_defined_function::UDFServer; diff --git a/src/meta/app/src/principal/user_defined_function.rs b/src/meta/app/src/principal/user_defined_function.rs index 8da602fdbb66..19b72cf57dc5 100644 --- a/src/meta/app/src/principal/user_defined_function.rs +++ b/src/meta/app/src/principal/user_defined_function.rs @@ -18,6 +18,7 @@ use std::fmt::Formatter; use chrono::DateTime; use chrono::Utc; use databend_common_expression::types::DataType; +use databend_common_expression::DataField; #[derive(Clone, Debug, Eq, PartialEq)] pub struct LambdaUDF { @@ -44,11 +45,24 @@ pub struct UDFScript { pub runtime_version: String, } 
+#[derive(Clone, Debug, Eq, PartialEq)] +pub struct UDAFScript { + pub code: String, + pub language: String, + // aggregate function input types + pub arg_types: Vec<DataType>, + // aggregate function state fields + pub state_fields: Vec<DataField>, + pub return_type: DataType, + pub runtime_version: String, +} + #[derive(Clone, Debug, Eq, PartialEq)] pub enum UDFDefinition { LambdaUDF(LambdaUDF), UDFServer(UDFServer), UDFScript(UDFScript), + UDAFScript(UDAFScript), } #[derive(Clone, Debug, Eq, PartialEq)] @@ -160,7 +174,6 @@ impl Display for UDFDefinition { ") RETURNS {return_type} LANGUAGE {language} HANDLER = {handler} ADDRESS = {address}" )?; } - UDFDefinition::UDFScript(UDFScript { code, arg_types, @@ -180,6 +193,29 @@ impl Display for UDFDefinition { ") RETURNS {return_type} LANGUAGE {language} RUNTIME_VERSION = {runtime_version} HANDLER = {handler} AS $${code}$$" )?; } + UDFDefinition::UDAFScript(UDAFScript { + code, + arg_types, + state_fields, + return_type, + language, + runtime_version, + }) => { + for (i, item) in arg_types.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{item}")?; + } + write!(f, ") STATE {{ ")?; + for (i, item) in state_fields.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{} {}", item.name(), item.data_type())?; + } + write!(f, " }} RETURNS {return_type} LANGUAGE {language} RUNTIME_VERSION = {runtime_version} AS $${code}$$")?; + } } Ok(()) } diff --git a/src/meta/proto-conv/src/udf_from_to_protobuf_impl.rs b/src/meta/proto-conv/src/udf_from_to_protobuf_impl.rs index 8b40c1379832..6eace2056e4d 100644 --- a/src/meta/proto-conv/src/udf_from_to_protobuf_impl.rs +++ b/src/meta/proto-conv/src/udf_from_to_protobuf_impl.rs @@ -17,6 +17,7 @@ use chrono::Utc; use databend_common_expression::infer_schema_type; use databend_common_expression::types::DataType; use databend_common_expression::TableDataType; +use databend_common_expression::TableField; use databend_common_meta_app::principal as mt; use
databend_common_protos::pb; @@ -164,6 +165,89 @@ impl FromToProto for mt::UDFScript { } } +impl FromToProto for mt::UDAFScript { + type PB = pb::UdafScript; + fn get_pb_ver(p: &Self::PB) -> u64 { + p.ver + } + fn from_pb(p: pb::UdafScript) -> Result<Self, Incompatible> { + reader_check_msg(p.ver, p.min_reader_ver)?; + + let arg_types = p + .arg_types + .into_iter() + .map(|arg_type| Ok((&TableDataType::from_pb(arg_type)?).into())) + .collect::<Result<Vec<_>, _>>()?; + + let state_fields = p + .state_fields + .into_iter() + .map(|field| TableField::from_pb(field).map(|field| (&field).into())) + .collect::<Result<Vec<_>, _>>()?; + + let return_type = + (&TableDataType::from_pb(p.return_type.ok_or_else(|| Incompatible { + reason: "UDAFScript.return_type can not be None".to_string(), + })?)?) + .into(); + + Ok(mt::UDAFScript { + code: p.code, + arg_types, + return_type, + language: p.language, + runtime_version: p.runtime_version, + state_fields, + }) + } + + fn to_pb(&self) -> Result<Self::PB, Incompatible> { + let mut arg_types = Vec::with_capacity(self.arg_types.len()); + for arg_type in self.arg_types.iter() { + let arg_type = infer_schema_type(arg_type) + .map_err(|e| Incompatible { + reason: format!("Convert DataType to TableDataType failed: {}", e.message()), + })? + .to_pb()?; + arg_types.push(arg_type); + } + + let state_fields = self + .state_fields + .iter() + .map(|field| { + TableField::new( + field.name(), + infer_schema_type(field.data_type()).map_err(|e| Incompatible { + reason: format!( + "Convert DataType to TableDataType failed: {}", + e.message() + ), + })?, + ) + .to_pb() + }) + .collect::<Result<_, _>>()?; + + let return_type = infer_schema_type(&self.return_type) + .map_err(|e| Incompatible { + reason: format!("Convert DataType to TableDataType failed: {}", e.message()), + })?
+ .to_pb()?; + + Ok(pb::UdafScript { + ver: VER, + min_reader_ver: MIN_READER_VER, + code: self.code.clone(), + language: self.language.clone(), + runtime_version: self.runtime_version.clone(), + arg_types, + state_fields, + return_type: Some(return_type), + }) + } +} + impl FromToProto for mt::UserDefinedFunction { type PB = pb::UserDefinedFunction; fn get_pb_ver(p: &Self::PB) -> u64 { @@ -181,6 +265,9 @@ impl FromToProto for mt::UserDefinedFunction { Some(pb::user_defined_function::Definition::UdfScript(udf_script)) => { mt::UDFDefinition::UDFScript(mt::UDFScript::from_pb(udf_script)?) } + Some(pb::user_defined_function::Definition::UdafScript(udaf_script)) => { + mt::UDFDefinition::UDAFScript(mt::UDAFScript::from_pb(udaf_script)?) + } None => { return Err(Incompatible { reason: "UserDefinedFunction.definition cannot be None".to_string(), @@ -210,6 +297,9 @@ impl FromToProto for mt::UserDefinedFunction { mt::UDFDefinition::UDFScript(udf_script) => { pb::user_defined_function::Definition::UdfScript(udf_script.to_pb()?) } + mt::UDFDefinition::UDAFScript(udaf_script) => { + pb::user_defined_function::Definition::UdafScript(udaf_script.to_pb()?) + } }; Ok(pb::UserDefinedFunction { diff --git a/src/meta/proto-conv/src/util.rs b/src/meta/proto-conv/src/util.rs index 66ad6789451c..ec11d2f2f37b 100644 --- a/src/meta/proto-conv/src/util.rs +++ b/src/meta/proto-conv/src/util.rs @@ -144,6 +144,7 @@ const META_CHANGE_LOG: &[(u64, &str)] = &[ (112, "2024-11-28: Add: virtual_column add data_types field"), (113, "2024-12-10: Add: GrantWarehouseObject"), (114, "2024-12-12: Add: New DataType Interval."), + (115, "2024-12-16: Add: udf.proto: add UDAFScript"), // Dear developer: // If you're gonna add a new metadata version, you'll have to add a test for it.
// You could just copy an existing test file(e.g., `../tests/it/v024_table_meta.rs`) diff --git a/src/meta/proto-conv/tests/it/main.rs b/src/meta/proto-conv/tests/it/main.rs index f2641a1c17d5..4a67d26501a4 100644 --- a/src/meta/proto-conv/tests/it/main.rs +++ b/src/meta/proto-conv/tests/it/main.rs @@ -112,3 +112,4 @@ mod v111_add_glue_as_iceberg_catalog_option; mod v112_virtual_column; mod v113_warehouse_grantobject; mod v114_interval_datatype; +mod v115_add_udaf_script; diff --git a/src/meta/proto-conv/tests/it/v081_udf_script.rs b/src/meta/proto-conv/tests/it/v081_udf_script.rs index f15544ac1e1f..0ac897d1a3eb 100644 --- a/src/meta/proto-conv/tests/it/v081_udf_script.rs +++ b/src/meta/proto-conv/tests/it/v081_udf_script.rs @@ -18,6 +18,7 @@ use databend_common_expression::types::DataType; use databend_common_expression::types::NumberDataType; use databend_common_meta_app::principal::LambdaUDF; use databend_common_meta_app::principal::UDFDefinition; +use databend_common_meta_app::principal::UDFScript; use databend_common_meta_app::principal::UDFServer; use databend_common_meta_app::principal::UserDefinedFunction; use fastrace::func_name; @@ -90,3 +91,33 @@ fn test_decode_v81_udf_sql() -> anyhow::Result<()> { common::test_pb_from_to(func_name!(), want())?; common::test_load_old(func_name!(), bytes.as_slice(), 81, want()) } + +#[test] +fn test_decode_udf_script() -> anyhow::Result<()> { + let bytes: Vec = vec![ + 10, 5, 109, 121, 95, 102, 110, 18, 21, 84, 104, 105, 115, 32, 105, 115, 32, 97, 32, 100, + 101, 115, 99, 114, 105, 112, 116, 105, 111, 110, 50, 78, 10, 9, 115, 111, 109, 101, 32, 99, + 111, 100, 101, 18, 5, 109, 121, 95, 102, 110, 26, 6, 112, 121, 116, 104, 111, 110, 34, 17, + 154, 2, 8, 58, 0, 160, 6, 115, 168, 6, 24, 160, 6, 115, 168, 6, 24, 42, 17, 154, 2, 8, 74, + 0, 160, 6, 115, 168, 6, 24, 160, 6, 115, 168, 6, 24, 50, 6, 51, 46, 49, 50, 46, 50, 160, 6, + 115, 168, 6, 24, 42, 23, 49, 57, 55, 48, 45, 48, 49, 45, 48, 49, 32, 48, 48, 58, 48, 48, + 58, 
48, 48, 32, 85, 84, 67, 160, 6, 115, 168, 6, 24, + ]; + + let want = || UserDefinedFunction { + name: "my_fn".to_string(), + description: "This is a description".to_string(), + definition: UDFDefinition::UDFScript(UDFScript { + code: "some code".to_string(), + handler: "my_fn".to_string(), + language: "python".to_string(), + arg_types: vec![DataType::Number(NumberDataType::Int32)], + return_type: DataType::Number(NumberDataType::Float32), + runtime_version: "3.12.2".to_string(), + }), + created_on: DateTime::::default(), + }; + + common::test_pb_from_to(func_name!(), want())?; + common::test_load_old(func_name!(), bytes.as_slice(), 115, want()) +} diff --git a/src/meta/proto-conv/tests/it/v115_add_udaf_script.rs b/src/meta/proto-conv/tests/it/v115_add_udaf_script.rs new file mode 100644 index 000000000000..c6b49ee0f70c --- /dev/null +++ b/src/meta/proto-conv/tests/it/v115_add_udaf_script.rs @@ -0,0 +1,70 @@ +// Copyright 2021 Datafuse Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +use chrono::DateTime; +use chrono::Utc; +use databend_common_expression::types::DataType; +use databend_common_expression::types::NumberDataType; +use databend_common_expression::DataField; +use databend_common_meta_app::principal::UDAFScript; +use databend_common_meta_app::principal::UDFDefinition; +use databend_common_meta_app::principal::UserDefinedFunction; +use fastrace::func_name; + +use crate::common; + +// These bytes are built when a new version in introduced, +// and are kept for backward compatibility test. +// +// ************************************************************* +// * These messages should never be updated, * +// * only be added when a new version is added, * +// * or be removed when an old version is no longer supported. * +// ************************************************************* +// +// The message bytes are built from the output of `proto_conv::test_build_pb_buf()` + +#[test] +fn test_decode_v115_add_udaf_script() -> anyhow::Result<()> { + let bytes: Vec = vec![ + 10, 5, 109, 121, 95, 102, 110, 18, 21, 84, 104, 105, 115, 32, 105, 115, 32, 97, 32, 100, + 101, 115, 99, 114, 105, 112, 116, 105, 111, 110, 58, 99, 10, 9, 115, 111, 109, 101, 32, 99, + 111, 100, 101, 18, 10, 106, 97, 118, 97, 115, 99, 114, 105, 112, 116, 34, 17, 154, 2, 8, + 74, 0, 160, 6, 115, 168, 6, 24, 160, 6, 115, 168, 6, 24, 42, 17, 154, 2, 8, 58, 0, 160, 6, + 115, 168, 6, 24, 160, 6, 115, 168, 6, 24, 50, 30, 10, 3, 115, 117, 109, 26, 17, 154, 2, 8, + 66, 0, 160, 6, 115, 168, 6, 24, 160, 6, 115, 168, 6, 24, 160, 6, 115, 168, 6, 24, 160, 6, + 115, 168, 6, 24, 42, 23, 49, 57, 55, 48, 45, 48, 49, 45, 48, 49, 32, 48, 48, 58, 48, 48, + 58, 48, 48, 32, 85, 84, 67, 160, 6, 115, 168, 6, 24, + ]; + + let want = || UserDefinedFunction { + name: "my_fn".to_string(), + description: "This is a description".to_string(), + definition: UDFDefinition::UDAFScript(UDAFScript { + code: "some code".to_string(), + language: "javascript".to_string(), + arg_types: 
vec![DataType::Number(NumberDataType::Int32)], + state_fields: vec![DataField::new( + "sum", + DataType::Number(NumberDataType::Int64), + )], + return_type: DataType::Number(NumberDataType::Float32), + runtime_version: "".to_string(), + }), + created_on: DateTime::::default(), + }; + + common::test_pb_from_to(func_name!(), want())?; + common::test_load_old(func_name!(), bytes.as_slice(), 115, want()) +} diff --git a/src/meta/protos/proto/udf.proto b/src/meta/protos/proto/udf.proto index 3ed23b3c4fbb..3a0a7b3d278a 100644 --- a/src/meta/protos/proto/udf.proto +++ b/src/meta/protos/proto/udf.proto @@ -17,6 +17,7 @@ syntax = "proto3"; package databend_proto; import "datatype.proto"; +import "metadata.proto"; message LambdaUDF { uint64 ver = 100; @@ -49,6 +50,17 @@ message UDFScript { string runtime_version = 6; } +message UDAFScript { + uint64 ver = 100; + uint64 min_reader_ver = 101; + + string code = 1; + string language = 2; + string runtime_version = 3; + DataType return_type = 4; + repeated DataType arg_types = 5; + repeated DataField state_fields = 6; +} message UserDefinedFunction { uint64 ver = 100; @@ -60,7 +72,8 @@ message UserDefinedFunction { LambdaUDF lambda_udf = 3; UDFServer udf_server = 4; UDFScript udf_script = 6; + UDAFScript udaf_script = 7; } // The time udf created. 
optional string created_on = 5; -} \ No newline at end of file +} diff --git a/src/query/ast/src/ast/statements/udf.rs b/src/query/ast/src/ast/statements/udf.rs index 4db598cb18de..c1b7887051b8 100644 --- a/src/query/ast/src/ast/statements/udf.rs +++ b/src/query/ast/src/ast/statements/udf.rs @@ -37,7 +37,6 @@ pub enum UDFDefinition { handler: String, language: String, }, - UDFScript { arg_types: Vec, return_type: TypeName, @@ -46,6 +45,21 @@ pub enum UDFDefinition { language: String, runtime_version: String, }, + UDAFServer { + arg_types: Vec, + state_fields: Vec, + return_type: TypeName, + address: String, + language: String, + }, + UDAFScript { + arg_types: Vec, + state_fields: Vec, + return_type: TypeName, + code: String, + language: String, + runtime_version: String, + }, } impl Display for UDFDefinition { @@ -66,11 +80,11 @@ impl Display for UDFDefinition { handler, language, } => { - write!(f, "(")?; + write!(f, "( ")?; write_comma_separated_list(f, arg_types)?; write!( f, - ") RETURNS {return_type} LANGUAGE {language} HANDLER = '{handler}' ADDRESS = '{address}'" + " ) RETURNS {return_type} LANGUAGE {language} HANDLER = '{handler}' ADDRESS = '{address}'" )?; } UDFDefinition::UDFScript { @@ -81,11 +95,44 @@ impl Display for UDFDefinition { language, runtime_version: _, } => { - write!(f, "(")?; + write!(f, "( ")?; write_comma_separated_list(f, arg_types)?; write!( f, - ") RETURNS {return_type} LANGUAGE {language} HANDLER = '{handler}' AS $$\n{code}\n$$" + " ) RETURNS {return_type} LANGUAGE {language} HANDLER = '{handler}' AS $$\n{code}\n$$" + )?; + } + UDFDefinition::UDAFServer { + arg_types, + state_fields: state_types, + return_type, + address, + language, + } => { + write!(f, "( ")?; + write_comma_separated_list(f, arg_types)?; + write!(f, " ) STATE {{ ")?; + write_comma_separated_list(f, state_types)?; + write!( + f, + " }} RETURNS {return_type} LANGUAGE {language} ADDRESS = '{address}'" + )?; + } + UDFDefinition::UDAFScript { + arg_types, + state_fields: 
state_types, + return_type, + code, + language, + runtime_version: _, + } => { + write!(f, "( ")?; + write_comma_separated_list(f, arg_types)?; + write!(f, " ) STATE {{ ")?; + write_comma_separated_list(f, state_types)?; + write!( + f, + " }} RETURNS {return_type} LANGUAGE {language} AS $$\n{code}\n$$" )?; } } @@ -93,6 +140,19 @@ impl Display for UDFDefinition { } } +#[derive(Debug, Clone, PartialEq, Drive, DriveMut)] +pub struct UDAFStateField { + pub name: Identifier, + pub type_name: TypeName, +} + +impl Display for UDAFStateField { + fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { + write!(f, "{} {}", self.name, self.type_name)?; + Ok(()) + } +} + #[derive(Debug, Clone, PartialEq, Drive, DriveMut)] pub struct CreateUDFStmt { pub create_option: CreateOption, diff --git a/src/query/ast/src/parser/statement.rs b/src/query/ast/src/parser/statement.rs index 8f5aa6d31829..fef252153be0 100644 --- a/src/query/ast/src/parser/statement.rs +++ b/src/query/ast/src/parser/statement.rs @@ -2336,9 +2336,9 @@ pub fn statement_body(i: Input) -> IResult { | #show_roles : "`SHOW ROLES`" | #create_role : "`CREATE ROLE [IF NOT EXISTS] `" | #drop_role : "`DROP ROLE [IF EXISTS] `" - | #create_udf : "`CREATE [OR REPLACE] FUNCTION [IF NOT EXISTS] {AS (, ...) -> | (, ...) RETURNS LANGUAGE HANDLER= ADDRESS=} [DESC = ]`" + | #create_udf : "`CREATE [OR REPLACE] FUNCTION [IF NOT EXISTS] [DESC = ]`" | #drop_udf : "`DROP FUNCTION [IF EXISTS] `" - | #alter_udf : "`ALTER FUNCTION (, ...) -> [DESC = ]`" + | #alter_udf : "`ALTER FUNCTION [DESC = ]`" | #set_role: "`SET [DEFAULT] ROLE `" | #set_secondary_roles: "`SET SECONDARY ROLES (ALL | NONE)`" | #show_user_functions : "`SHOW USER FUNCTIONS []`" @@ -4309,15 +4309,35 @@ pub fn update_expr(i: Input) -> IResult { })(i) } -pub fn udf_arg_type(i: Input) -> IResult { +pub fn udaf_state_field(i: Input) -> IResult { map( rule! 
{ - #type_name + #ident + ~ #type_name + : "` `" }, - |type_name| match type_name { - TypeName::Nullable(_) | TypeName::NotNull(_) => type_name, - _ => type_name.wrap_nullable(), + |(name, type_name)| UDAFStateField { name, type_name }, + )(i) +} + +pub fn udf_script_or_address(i: Input) -> IResult<(String, bool)> { + let script = map( + rule! { + AS ~ ^(#code_string | #literal_string) }, + |(_, code)| (code, true), + ); + + let address = map( + rule! { + ADDRESS ~ ^"=" ~ ^#literal_string + }, + |(_, _, address)| (address, false), + ); + + rule!( + #script: "AS " + | #address: "ADDRESS=" )(i) } @@ -4333,51 +4353,75 @@ pub fn udf_definition(i: Input) -> IResult { }, ); - let udf_server = map( + let udf = map( rule! { - "(" ~ #comma_separated_list0(udf_arg_type) ~ ")" - ~ RETURNS ~ #udf_arg_type + "(" ~ #comma_separated_list0(type_name) ~ ")" + ~ RETURNS ~ #type_name ~ LANGUAGE ~ #ident ~ HANDLER ~ ^"=" ~ ^#literal_string - ~ ADDRESS ~ ^"=" ~ ^#literal_string - }, - |(_, arg_types, _, _, return_type, _, language, _, _, handler, _, _, address)| { - UDFDefinition::UDFServer { - arg_types, - return_type, - address, - handler, - language: language.to_string(), + ~ #udf_script_or_address + }, + |(_, arg_types, _, _, return_type, _, language, _, _, handler, address_or_code)| { + if address_or_code.1 { + UDFDefinition::UDFScript { + arg_types, + return_type, + code: address_or_code.0, + handler, + language: language.to_string(), + // TODO inject runtime_version by user + // Now we use fixed runtime version + runtime_version: "".to_string(), + } + } else { + UDFDefinition::UDFServer { + arg_types, + return_type, + address: address_or_code.0, + handler, + language: language.to_string(), + } } }, ); - let udf_script = map( + let udaf = map( rule! 
{ - "(" ~ #comma_separated_list0(udf_arg_type) ~ ")" - ~ RETURNS ~ #udf_arg_type + "(" ~ #comma_separated_list0(type_name) ~ ")" + ~ STATE ~ "{" ~ #comma_separated_list0(udaf_state_field) ~ "}" + ~ RETURNS ~ #type_name ~ LANGUAGE ~ #ident - ~ HANDLER ~ ^"=" ~ ^#literal_string - ~ AS ~ ^(#code_string | #literal_string) - }, - |(_, arg_types, _, _, return_type, _, language, _, _, handler, _, code)| { - UDFDefinition::UDFScript { - arg_types, - return_type, - code, - handler, - language: language.to_string(), - // TODO inject runtime_version by user - // Now we use fixed runtime version - runtime_version: "".to_string(), + ~ #udf_script_or_address + }, + |(_, arg_types, _, _, _, state_types, _, _, return_type, _, language, address_or_code)| { + if address_or_code.1 { + UDFDefinition::UDAFScript { + arg_types, + state_fields: state_types, + return_type, + code: address_or_code.0, + language: language.to_string(), + // TODO inject runtime_version by user + // Now we use fixed runtime version + runtime_version: "".to_string(), + } + } else { + UDFDefinition::UDAFServer { + arg_types, + state_fields: state_types, + return_type, + address: address_or_code.0, + language: language.to_string(), + } } }, ); rule!( - #udf_server: "(, ...) RETURNS LANGUAGE HANDLER= ADDRESS=" - | #lambda_udf: "AS (, ...) -> " - | #udf_script: "(, ...) RETURNS LANGUAGE HANDLER= AS " + #lambda_udf: "AS (, ...) -> " + | #udaf: "(, ...) STATE {, ...} RETURNS LANGUAGE { ADDRESS= | AS } " + | #udf: "(, ...) 
RETURNS LANGUAGE HANDLER= { ADDRESS= | AS } " + )(i) } diff --git a/src/query/ast/src/parser/token.rs b/src/query/ast/src/parser/token.rs index 82974cd017ce..fd2e22d55630 100644 --- a/src/query/ast/src/parser/token.rs +++ b/src/query/ast/src/parser/token.rs @@ -1302,6 +1302,8 @@ pub enum TokenKind { HANDLER, #[token("LANGUAGE", ignore(ascii_case))] LANGUAGE, + #[token("STATE", ignore(ascii_case))] + STATE, #[token("TASK", ignore(ascii_case))] TASK, #[token("TASKS", ignore(ascii_case))] diff --git a/src/query/ast/tests/it/parser.rs b/src/query/ast/tests/it/parser.rs index dfd95ccd7a19..deecd6b6a1c7 100644 --- a/src/query/ast/tests/it/parser.rs +++ b/src/query/ast/tests/it/parser.rs @@ -81,6 +81,7 @@ fn run_parser_with_dialect( } } +// UPDATE_GOLDENFILES=1 cargo test --package databend-common-ast --test it -- parser::test_statement #[test] fn test_statement() { let mut mint = Mint::new("tests/it/testdata"); @@ -772,6 +773,7 @@ fn test_statement() { r#"CREATE FUNCTION IF NOT EXISTS isnotempty AS(p) -> not(is_null(p));"#, r#"CREATE OR REPLACE FUNCTION isnotempty_test_replace AS(p) -> not(is_null(p)) DESC = 'This is a description';"#, r#"CREATE FUNCTION binary_reverse (BINARY) RETURNS BINARY LANGUAGE python HANDLER = 'binary_reverse' ADDRESS = 'http://0.0.0.0:8815';"#, + r#"ALTER FUNCTION binary_reverse (BINARY) RETURNS BINARY LANGUAGE python HANDLER = 'binary_reverse' ADDRESS = 'http://0.0.0.0:8815';"#, r#"CREATE OR REPLACE FUNCTION binary_reverse (BINARY) RETURNS BINARY LANGUAGE python HANDLER = 'binary_reverse' ADDRESS = 'http://0.0.0.0:8815';"#, r#"CREATE file format my_orc type = orc"#, r#"CREATE file format my_orc type = orc missing_field_as=field_default"#, @@ -797,6 +799,9 @@ fn test_statement() { "#, r#"DROP FUNCTION binary_reverse;"#, r#"DROP FUNCTION isnotempty;"#, + r#"CREATE FUNCTION IF NOT EXISTS my_agg (INT) STATE { s STRING } RETURNS BOOLEAN LANGUAGE javascript ADDRESS = 'http://0.0.0.0:8815';"#, + r#"CREATE FUNCTION IF NOT EXISTS my_agg (INT) STATE { s 
STRING, i INT NOT NULL } RETURNS BOOLEAN LANGUAGE javascript AS 'some code';"#, + r#"ALTER FUNCTION my_agg (INT) STATE { s STRING } RETURNS BOOLEAN LANGUAGE javascript AS 'some code';"#, r#" EXECUTE IMMEDIATE $$ @@ -982,6 +987,8 @@ fn test_statement_error() { r#"REVOKE OWNERSHIP ON d20_0014.* FROM ROLE A;"#, r#"GRANT OWNERSHIP ON *.* TO ROLE 'd20_0015_owner';"#, r#"CREATE FUNCTION IF NOT EXISTS isnotempty AS(p) -> not(is_null(p)"#, + r#"CREATE FUNCTION my_agg (INT) STATE { s STRING } RETURNS BOOLEAN LANGUAGE javascript HANDLER = 'my_agg' ADDRESS = 'http://0.0.0.0:8815';"#, + r#"CREATE FUNCTION my_agg (INT) STATE { s STRIN } RETURNS BOOLEAN LANGUAGE javascript ADDRESS = 'http://0.0.0.0:8815';"#, r#"drop table :a"#, r#"drop table IDENTIFIER(a)"#, r#"drop table IDENTIFIER(:a)"#, diff --git a/src/query/ast/tests/it/testdata/stmt-error.txt b/src/query/ast/tests/it/testdata/stmt-error.txt index 9c53c0f64c52..2704153c85a4 100644 --- a/src/query/ast/tests/it/testdata/stmt-error.txt +++ b/src/query/ast/tests/it/testdata/stmt-error.txt @@ -870,7 +870,33 @@ error: | | | | while parsing `( [, ...])` | | | while parsing expression | | while parsing AS (, ...) -> - | while parsing `CREATE [OR REPLACE] FUNCTION [IF NOT EXISTS] {AS (, ...) -> | (, ...) RETURNS LANGUAGE HANDLER= ADDRESS=} [DESC = ]` + | while parsing `CREATE [OR REPLACE] FUNCTION [IF NOT EXISTS] [DESC = ]` + + +---------- Input ---------- +CREATE FUNCTION my_agg (INT) STATE { s STRING } RETURNS BOOLEAN LANGUAGE javascript HANDLER = 'my_agg' ADDRESS = 'http://0.0.0.0:8815'; +---------- Output --------- +error: + --> SQL:1:85 + | +1 | CREATE FUNCTION my_agg (INT) STATE { s STRING } RETURNS BOOLEAN LANGUAGE javascript HANDLER = 'my_agg' ADDRESS = 'http://0.0.0.0:8815'; + | ------ - ^^^^^^^ unexpected `HANDLER`, expecting `ADDRESS` or `AS` + | | | + | | while parsing (, ...) 
STATE {, ...} RETURNS LANGUAGE { ADDRESS= | AS } + | while parsing `CREATE [OR REPLACE] FUNCTION [IF NOT EXISTS] [DESC = ]` + + +---------- Input ---------- +CREATE FUNCTION my_agg (INT) STATE { s STRIN } RETURNS BOOLEAN LANGUAGE javascript ADDRESS = 'http://0.0.0.0:8815'; +---------- Output --------- +error: + --> SQL:1:40 + | +1 | CREATE FUNCTION my_agg (INT) STATE { s STRIN } RETURNS BOOLEAN LANGUAGE javascript ADDRESS = 'http://0.0.0.0:8815'; + | ------ - ^^^^^ unexpected `STRIN`, expecting `STRING`, `SIGNED`, `INTERVAL`, `TINYINT`, `VARIANT`, `SMALLINT`, `TINYBLOB`, `VARBINARY`, `INT8`, `JSON`, `INT16`, `INT32`, `INT64`, `UINT8`, `BIGINT`, `UINT16`, `UINT32`, `UINT64`, `BINARY`, `INTEGER`, `DATETIME`, `TIMESTAMP`, `UNSIGNED`, `DATE`, `CHAR`, `TEXT`, `ARRAY`, `TUPLE`, `BOOLEAN`, `DECIMAL`, `VARCHAR`, `LONGBLOB`, `NULLABLE`, `CHARACTER`, `GEOGRAPHY`, `MEDIUMBLOB`, `BITMAP`, `}`, `BOOL`, `INT`, `FLOAT32`, `FLOAT`, `FLOAT64`, `DOUBLE`, `MAP`, `BLOB`, or `GEOMETRY` + | | | + | | while parsing (, ...) 
STATE {, ...} RETURNS LANGUAGE { ADDRESS= | AS } + | while parsing `CREATE [OR REPLACE] FUNCTION [IF NOT EXISTS] [DESC = ]` ---------- Input ---------- diff --git a/src/query/ast/tests/it/testdata/stmt.txt b/src/query/ast/tests/it/testdata/stmt.txt index 678365bbd6c1..21279f91614c 100644 --- a/src/query/ast/tests/it/testdata/stmt.txt +++ b/src/query/ast/tests/it/testdata/stmt.txt @@ -23254,7 +23254,7 @@ CreateUDF( ---------- Input ---------- CREATE FUNCTION binary_reverse (BINARY) RETURNS BINARY LANGUAGE python HANDLER = 'binary_reverse' ADDRESS = 'http://0.0.0.0:8815'; ---------- Output --------- -CREATE FUNCTION binary_reverse (BINARY NULL) RETURNS BINARY NULL LANGUAGE python HANDLER = 'binary_reverse' ADDRESS = 'http://0.0.0.0:8815' +CREATE FUNCTION binary_reverse ( BINARY ) RETURNS BINARY LANGUAGE python HANDLER = 'binary_reverse' ADDRESS = 'http://0.0.0.0:8815' ---------- AST ------------ CreateUDF( CreateUDFStmt { @@ -23270,13 +23270,38 @@ CreateUDF( description: None, definition: UDFServer { arg_types: [ - Nullable( - Binary, - ), - ], - return_type: Nullable( Binary, + ], + return_type: Binary, + address: "http://0.0.0.0:8815", + handler: "binary_reverse", + language: "python", + }, + }, +) + + +---------- Input ---------- +ALTER FUNCTION binary_reverse (BINARY) RETURNS BINARY LANGUAGE python HANDLER = 'binary_reverse' ADDRESS = 'http://0.0.0.0:8815'; +---------- Output --------- +ALTER FUNCTION binary_reverse ( BINARY ) RETURNS BINARY LANGUAGE python HANDLER = 'binary_reverse' ADDRESS = 'http://0.0.0.0:8815' +---------- AST ------------ +AlterUDF( + AlterUDFStmt { + udf_name: Identifier { + span: Some( + 15..29, ), + name: "binary_reverse", + quote: None, + ident_type: None, + }, + description: None, + definition: UDFServer { + arg_types: [ + Binary, + ], + return_type: Binary, address: "http://0.0.0.0:8815", handler: "binary_reverse", language: "python", @@ -23288,7 +23313,7 @@ CreateUDF( ---------- Input ---------- CREATE OR REPLACE FUNCTION 
binary_reverse (BINARY) RETURNS BINARY LANGUAGE python HANDLER = 'binary_reverse' ADDRESS = 'http://0.0.0.0:8815'; ---------- Output --------- -CREATE OR REPLACE FUNCTION binary_reverse (BINARY NULL) RETURNS BINARY NULL LANGUAGE python HANDLER = 'binary_reverse' ADDRESS = 'http://0.0.0.0:8815' +CREATE OR REPLACE FUNCTION binary_reverse ( BINARY ) RETURNS BINARY LANGUAGE python HANDLER = 'binary_reverse' ADDRESS = 'http://0.0.0.0:8815' ---------- AST ------------ CreateUDF( CreateUDFStmt { @@ -23304,13 +23329,9 @@ CreateUDF( description: None, definition: UDFServer { arg_types: [ - Nullable( - Binary, - ), - ], - return_type: Nullable( Binary, - ), + ], + return_type: Binary, address: "http://0.0.0.0:8815", handler: "binary_reverse", language: "python", @@ -23415,7 +23436,7 @@ def addone_py(i): return i+1 $$; ---------- Output --------- -CREATE OR REPLACE FUNCTION addone (Int32 NULL) RETURNS Int32 NULL LANGUAGE python HANDLER = 'addone_py' AS $$ +CREATE OR REPLACE FUNCTION addone ( Int32 ) RETURNS Int32 LANGUAGE python HANDLER = 'addone_py' AS $$ def addone_py(i): return i+1 $$ @@ -23434,13 +23455,9 @@ CreateUDF( description: None, definition: UDFScript { arg_types: [ - Nullable( - Int32, - ), - ], - return_type: Nullable( Int32, - ), + ], + return_type: Int32, code: "def addone_py(i):\nreturn i+1", handler: "addone_py", language: "python", @@ -23457,7 +23474,7 @@ language python handler = 'addone_py' as '@data/abc/a.py'; ---------- Output --------- -CREATE OR REPLACE FUNCTION addone (Int32 NULL) RETURNS Int32 NULL LANGUAGE python HANDLER = 'addone_py' AS $$ +CREATE OR REPLACE FUNCTION addone ( Int32 ) RETURNS Int32 LANGUAGE python HANDLER = 'addone_py' AS $$ @data/abc/a.py $$ ---------- AST ------------ @@ -23475,13 +23492,9 @@ CreateUDF( description: None, definition: UDFScript { arg_types: [ - Nullable( - Int32, - ), - ], - return_type: Nullable( Int32, - ), + ], + return_type: Int32, code: "@data/abc/a.py", handler: "addone_py", language: "python", @@ -23527,6 
+23540,150 @@ DropUDF { } +---------- Input ---------- +CREATE FUNCTION IF NOT EXISTS my_agg (INT) STATE { s STRING } RETURNS BOOLEAN LANGUAGE javascript ADDRESS = 'http://0.0.0.0:8815'; +---------- Output --------- +CREATE FUNCTION IF NOT EXISTS my_agg ( Int32 ) STATE { s STRING } RETURNS BOOLEAN LANGUAGE javascript ADDRESS = 'http://0.0.0.0:8815' +---------- AST ------------ +CreateUDF( + CreateUDFStmt { + create_option: CreateIfNotExists, + udf_name: Identifier { + span: Some( + 30..36, + ), + name: "my_agg", + quote: None, + ident_type: None, + }, + description: None, + definition: UDAFServer { + arg_types: [ + Int32, + ], + state_fields: [ + UDAFStateField { + name: Identifier { + span: Some( + 51..52, + ), + name: "s", + quote: None, + ident_type: None, + }, + type_name: String, + }, + ], + return_type: Boolean, + address: "http://0.0.0.0:8815", + language: "javascript", + }, + }, +) + + +---------- Input ---------- +CREATE FUNCTION IF NOT EXISTS my_agg (INT) STATE { s STRING, i INT NOT NULL } RETURNS BOOLEAN LANGUAGE javascript AS 'some code'; +---------- Output --------- +CREATE FUNCTION IF NOT EXISTS my_agg ( Int32 ) STATE { s STRING, i Int32 NOT NULL } RETURNS BOOLEAN LANGUAGE javascript AS $$ +some code +$$ +---------- AST ------------ +CreateUDF( + CreateUDFStmt { + create_option: CreateIfNotExists, + udf_name: Identifier { + span: Some( + 30..36, + ), + name: "my_agg", + quote: None, + ident_type: None, + }, + description: None, + definition: UDAFScript { + arg_types: [ + Int32, + ], + state_fields: [ + UDAFStateField { + name: Identifier { + span: Some( + 51..52, + ), + name: "s", + quote: None, + ident_type: None, + }, + type_name: String, + }, + UDAFStateField { + name: Identifier { + span: Some( + 61..62, + ), + name: "i", + quote: None, + ident_type: None, + }, + type_name: NotNull( + Int32, + ), + }, + ], + return_type: Boolean, + code: "some code", + language: "javascript", + runtime_version: "", + }, + }, +) + + +---------- Input ---------- 
+ALTER FUNCTION my_agg (INT) STATE { s STRING } RETURNS BOOLEAN LANGUAGE javascript AS 'some code'; +---------- Output --------- +ALTER FUNCTION my_agg ( Int32 ) STATE { s STRING } RETURNS BOOLEAN LANGUAGE javascript AS $$ +some code +$$ +---------- AST ------------ +AlterUDF( + AlterUDFStmt { + udf_name: Identifier { + span: Some( + 15..21, + ), + name: "my_agg", + quote: None, + ident_type: None, + }, + description: None, + definition: UDAFScript { + arg_types: [ + Int32, + ], + state_fields: [ + UDAFStateField { + name: Identifier { + span: Some( + 36..37, + ), + name: "s", + quote: None, + ident_type: None, + }, + type_name: String, + }, + ], + return_type: Boolean, + code: "some code", + language: "javascript", + runtime_version: "", + }, + }, +) + + ---------- Input ---------- EXECUTE IMMEDIATE $$ diff --git a/src/query/expression/src/types/number.rs b/src/query/expression/src/types/number.rs index 0462dff42aa3..0c2374405008 100644 --- a/src/query/expression/src/types/number.rs +++ b/src/query/expression/src/types/number.rs @@ -699,6 +699,12 @@ impl NumberColumn { ))), } } + + pub fn data_type(&self) -> NumberDataType { + crate::with_number_type!(|NUM_TYPE| match self { + NumberColumn::NUM_TYPE(_) => NumberDataType::NUM_TYPE, + }) + } } impl NumberColumnBuilder { @@ -765,8 +771,8 @@ impl NumberColumnBuilder { } (this, other) => unreachable!( "unable append column(data type: {:?}) into builder(data type: {:?})", - type_name_of(other), - type_name_of(this) + other.data_type(), + this.data_type() ), }) } @@ -790,6 +796,12 @@ impl NumberColumnBuilder { NumberColumnBuilder::NUM_TYPE(builder) => builder.pop().map(NumberScalar::NUM_TYPE), }) } + + pub fn data_type(&self) -> NumberDataType { + crate::with_number_type!(|NUM_TYPE| match self { + NumberColumnBuilder::NUM_TYPE(_) => NumberDataType::NUM_TYPE, + }) + } } impl SimpleDomain { @@ -822,10 +834,6 @@ fn overflow_cast_with_minmax(src: T, min: U, max: U) -> Op Some((dest, overflowing)) } -fn type_name_of(_: T) -> 
&'static str { - std::any::type_name::() -} - #[macro_export] macro_rules! with_number_type { ( | $t:tt | $($tail:tt)* ) => { diff --git a/src/query/expression/src/utils/udf_client.rs b/src/query/expression/src/utils/udf_client.rs index 66cb3b061c5c..fa17c88576cf 100644 --- a/src/query/expression/src/utils/udf_client.rs +++ b/src/query/expression/src/utils/udf_client.rs @@ -189,7 +189,7 @@ impl UDFFlightClient { } let (input_fields, output_fields) = schema.fields().split_at(fields_num - 1); - let expect_arg_types = input_fields + let remote_arg_types = input_fields .iter() .map(|f| f.data_type().clone()) .collect::>(); @@ -197,10 +197,15 @@ impl UDFFlightClient { .iter() .map(|f| f.data_type().clone()) .collect::>(); - if expect_arg_types != arg_types { + if remote_arg_types != arg_types { return Err(ErrorCode::UDFSchemaMismatch(format!( - "UDF arg types mismatch, actual arg types: ({:?})", - expect_arg_types + "UDF arg types mismatch, remote arg types: ({:?}), defined arg types: ({:?})", + remote_arg_types + .iter() + .map(ToString::to_string) + .collect::>() + .join(", "), + arg_types .iter() .map(ToString::to_string) .collect::>() diff --git a/src/query/service/Cargo.toml b/src/query/service/Cargo.toml index 78ed377acc17..5a3836b8a22a 100644 --- a/src/query/service/Cargo.toml +++ b/src/query/service/Cargo.toml @@ -26,11 +26,13 @@ io-uring = [ enable_queries_executor = [] [dependencies] +anyhow = { workspace = true } arrow-array = { workspace = true } arrow-buffer = { workspace = true } arrow-flight = { workspace = true } arrow-ipc = { workspace = true, features = ["lz4", "zstd"] } arrow-schema = { workspace = true } +arrow-select = { workspace = true } arrow-udf-js = { workspace = true } arrow-udf-python = { workspace = true, optional = true } arrow-udf-wasm = { workspace = true } diff --git a/src/query/service/src/builtin/builtin_udfs.rs b/src/query/service/src/builtin/builtin_udfs.rs index 8a091e80e7ee..e03ea042b7a5 100644 --- 
a/src/query/service/src/builtin/builtin_udfs.rs +++ b/src/query/service/src/builtin/builtin_udfs.rs @@ -24,7 +24,7 @@ use databend_common_exception::ErrorCode; use databend_common_exception::Result; use databend_common_expression::types::DataType; use databend_common_meta_app::principal::UserDefinedFunction; -use databend_common_sql::resolve_type_name; +use databend_common_sql::resolve_type_name_udf; use log::error; pub struct BuiltinUDFs { @@ -52,9 +52,9 @@ impl BuiltinUDFs { } => { let mut arg_datatypes = Vec::with_capacity(arg_types.len()); for arg_type in arg_types { - arg_datatypes.push(DataType::from(&resolve_type_name(&arg_type, true)?)); + arg_datatypes.push(DataType::from(&resolve_type_name_udf(&arg_type)?)); } - let return_type = DataType::from(&resolve_type_name(&return_type, true)?); + let return_type = DataType::from(&resolve_type_name_udf(&return_type)?); let udf = UserDefinedFunction::create_udf_server( name, &address, diff --git a/src/query/service/src/pipelines/builders/builder_aggregate.rs b/src/query/service/src/pipelines/builders/builder_aggregate.rs index 1bd90e015149..3270a1138ac6 100644 --- a/src/query/service/src/pipelines/builders/builder_aggregate.rs +++ b/src/query/service/src/pipelines/builders/builder_aggregate.rs @@ -17,6 +17,7 @@ use std::sync::Arc; use databend_common_catalog::table_context::TableContext; use databend_common_exception::Result; use databend_common_expression::AggregateFunctionRef; +use databend_common_expression::DataField; use databend_common_expression::DataSchemaRef; use databend_common_expression::HashTableConfig; use databend_common_expression::LimitType; @@ -30,10 +31,12 @@ use databend_common_sql::executor::physical_plans::AggregateFinal; use databend_common_sql::executor::physical_plans::AggregateFunctionDesc; use databend_common_sql::executor::physical_plans::AggregatePartial; use databend_common_sql::executor::PhysicalPlan; +use databend_common_sql::plans::UDFType; use databend_common_sql::IndexType; use 
databend_common_storage::DataOperator; use crate::pipelines::processors::transforms::aggregator::build_partition_bucket; +use crate::pipelines::processors::transforms::aggregator::create_udaf_script_function; use crate::pipelines::processors::transforms::aggregator::AggregateInjector; use crate::pipelines::processors::transforms::aggregator::AggregatorParams; use crate::pipelines::processors::transforms::aggregator::FinalSingleStateAggregator; @@ -221,7 +224,7 @@ impl PipelineBuilder { build_partition_bucket(&mut self.main_pipeline, params.clone()) } - pub fn build_aggregator_params( + fn build_aggregator_params( input_schema: DataSchemaRef, group_by: &[IndexType], agg_funcs: &[AggregateFunctionDesc], @@ -247,17 +250,37 @@ impl PipelineBuilder { let args = agg_func .arg_indices .iter() - .map(|i| { - let index = input_schema.index_of(&i.to_string())?; - Ok(index) - }) + .map(|i| input_schema.index_of(&i.to_string())) .collect::>>()?; agg_args.push(args); - AggregateFunctionFactory::instance().get( - agg_func.sig.name.as_str(), - agg_func.sig.params.clone(), - agg_func.sig.args.clone(), - ) + + match &agg_func.sig.udaf { + None => AggregateFunctionFactory::instance().get( + agg_func.sig.name.as_str(), + agg_func.sig.params.clone(), + agg_func.sig.args.clone(), + ), + Some((UDFType::Script(code), state_fields)) => create_udaf_script_function( + code, + agg_func.sig.name.clone(), + agg_func.display.clone(), + state_fields + .iter() + .map(|f| DataField::new(&f.name, f.data_type.clone())) + .collect(), + agg_func + .sig + .args + .iter() + .enumerate() + .map(|(i, data_type)| { + DataField::new(&format!("arg_{}", i), data_type.clone()) + }) + .collect(), + agg_func.sig.return_type.clone(), + ), + Some((UDFType::Server(_), _state_fields)) => unimplemented!(), + } }) .collect::>()?; diff --git a/src/query/service/src/pipelines/processors/transforms/aggregator/mod.rs b/src/query/service/src/pipelines/processors/transforms/aggregator/mod.rs index 0c6998e3bcd4..a4dd56da1e1c 
100644 --- a/src/query/service/src/pipelines/processors/transforms/aggregator/mod.rs +++ b/src/query/service/src/pipelines/processors/transforms/aggregator/mod.rs @@ -21,6 +21,7 @@ mod transform_aggregate_expand; mod transform_aggregate_final; mod transform_aggregate_partial; mod transform_single_key; +mod udaf_script; mod utils; pub use aggregate_exchange_injector::AggregateInjector; @@ -32,6 +33,7 @@ pub use transform_aggregate_final::TransformFinalAggregate; pub use transform_aggregate_partial::TransformPartialAggregate; pub use transform_single_key::FinalSingleStateAggregator; pub use transform_single_key::PartialSingleStateAggregator; +pub use udaf_script::*; pub use utils::*; pub use self::serde::*; diff --git a/src/query/service/src/pipelines/processors/transforms/aggregator/udaf_script.rs b/src/query/service/src/pipelines/processors/transforms/aggregator/udaf_script.rs new file mode 100644 index 000000000000..9cef5d96932d --- /dev/null +++ b/src/query/service/src/pipelines/processors/transforms/aggregator/udaf_script.rs @@ -0,0 +1,600 @@ +// Copyright 2021 Datafuse Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +use std::alloc::Layout; +use std::fmt; +use std::io::BufRead; +use std::io::Cursor; +use std::sync::Arc; +use std::sync::Mutex; + +use arrow_array::Array; +use arrow_array::RecordBatch; +use arrow_schema::ArrowError; +use arrow_schema::DataType as ArrowType; +use databend_common_exception::ErrorCode; +use databend_common_exception::Result; +use databend_common_expression::converts::arrow::ARROW_EXT_TYPE_VARIANT; +use databend_common_expression::converts::arrow::EXTENSION_KEY; +use databend_common_expression::types::Bitmap; +use databend_common_expression::types::DataType; +use databend_common_expression::Column; +use databend_common_expression::ColumnBuilder; +use databend_common_expression::DataBlock; +use databend_common_expression::DataField; +use databend_common_expression::DataSchema; +use databend_common_expression::InputColumns; +use databend_common_expression::StateAddr; +use databend_common_functions::aggregates::AggregateFunction; +use databend_common_sql::plans::UDFLanguage; +use databend_common_sql::plans::UDFScriptCode; + +#[cfg(feature = "python-udf")] +use super::super::python_udf::GLOBAL_PYTHON_RUNTIME; + +pub struct AggregateUdfScript { + display_name: String, + runtime: UDAFRuntime, + argument_schema: DataSchema, + init_state: UdfAggState, +} + +impl AggregateFunction for AggregateUdfScript { + fn name(&self) -> &str { + self.runtime.name() + } + + fn return_type(&self) -> Result { + Ok(self.runtime.return_type()) + } + + fn init_state(&self, place: StateAddr) { + place.write_state(UdfAggState(self.init_state.0.clone())); + } + + fn state_layout(&self) -> Layout { + Layout::new::() + } + + fn accumulate( + &self, + place: StateAddr, + columns: InputColumns, + validity: Option<&Bitmap>, + _input_rows: usize, + ) -> Result<()> { + let input_batch = self.create_input_batch(columns, validity)?; + let state = place.get::(); + let state = self + .runtime + .accumulate(state, &input_batch) + .map_err(|e| ErrorCode::UDFRuntimeError(format!("failed to 
accumulate: {e}")))?; + place.write_state(state); + Ok(()) + } + + fn accumulate_row(&self, place: StateAddr, columns: InputColumns, row: usize) -> Result<()> { + let input_batch = self.create_input_batch_row(columns, row)?; + let state = place.get::(); + let state = self + .runtime + .accumulate(state, &input_batch) + .map_err(|e| ErrorCode::UDFRuntimeError(format!("failed to accumulate_row: {e}")))?; + place.write_state(state); + Ok(()) + } + + fn serialize(&self, place: StateAddr, writer: &mut Vec) -> Result<()> { + let state = place.get::(); + state + .serialize(writer) + .map_err(|e| ErrorCode::Internal(format!("state failed to serialize: {e}"))) + } + + fn merge(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { + let state = place.get::(); + let rhs = + UdfAggState::deserialize(reader).map_err(|e| ErrorCode::Internal(e.to_string()))?; + let states = arrow_select::concat::concat(&[&state.0, &rhs.0])?; + let state = self + .runtime + .merge(&states) + .map_err(|e| ErrorCode::UDFRuntimeError(format!("failed to merge: {e}")))?; + place.write_state(state); + Ok(()) + } + + fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { + let state = place.get::(); + let other = rhs.get::(); + let states = arrow_select::concat::concat(&[&state.0, &other.0]) + .map_err(|e| ErrorCode::Internal(e.to_string()))?; + let state = self + .runtime + .merge(&states) + .map_err(|e| ErrorCode::UDFRuntimeError(format!("failed to merge_states: {e}")))?; + place.write_state(state); + Ok(()) + } + + fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> { + let state = place.get::(); + let array = self + .runtime + .finish(state) + .map_err(|e| ErrorCode::UDFRuntimeError(format!("failed to merge_result: {e}")))?; + let result = Column::from_arrow_rs(array, &self.runtime.return_type())?; + builder.append_column(&result); + Ok(()) + } +} + +impl fmt::Display for AggregateUdfScript { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> 
fmt::Result { + write!(f, "{}", self.display_name) + } +} + +impl AggregateUdfScript { + #[cfg(debug_assertions)] + fn check_columns(&self, columns: InputColumns) { + let fields = self.argument_schema.fields(); + assert_eq!(columns.len(), fields.len()); + for (i, (col, field)) in columns.iter().zip(fields).enumerate() { + assert_eq!(&col.data_type(), field.data_type(), "args {}", i) + } + } + + fn create_input_batch( + &self, + columns: InputColumns, + validity: Option<&Bitmap>, + ) -> Result { + #[cfg(debug_assertions)] + self.check_columns(columns); + + let num_columns = columns.len(); + + let columns = columns.iter().cloned().collect(); + match validity { + Some(bitmap) => DataBlock::new_from_columns(columns).filter_with_bitmap(bitmap)?, + None => DataBlock::new_from_columns(columns), + } + .to_record_batch_with_dataschema(&self.argument_schema) + .map_err(|err| { + ErrorCode::UDFDataError(format!( + "Failed to create input batch with {} columns: {}", + num_columns, err + )) + }) + } + + fn create_input_batch_row(&self, columns: InputColumns, row: usize) -> Result { + #[cfg(debug_assertions)] + self.check_columns(columns); + + let num_columns = columns.len(); + + let columns = columns.iter().cloned().collect(); + DataBlock::new_from_columns(columns) + .slice(row..row + 1) + .to_record_batch_with_dataschema(&self.argument_schema) + .map_err(|err| { + ErrorCode::UDFDataError(format!( + "Failed to create input batch with {} columns: {}", + num_columns, err + )) + }) + } +} + +#[derive(Debug)] +pub struct UdfAggState(Arc); + +impl UdfAggState { + fn serialize(&self, writer: &mut Vec) -> std::result::Result<(), ArrowError> { + let schema = arrow_schema::Schema::new(vec![arrow_schema::Field::new( + "state", + self.0.data_type().clone(), + true, + )]); + let mut writer = arrow_ipc::writer::FileWriter::try_new_with_options( + writer, + &schema, + arrow_ipc::writer::IpcWriteOptions::default(), + )?; + let batch = RecordBatch::try_new(Arc::new(schema), 
vec![self.0.clone()])?; + writer.write(&batch)?; + writer.finish() + } + + fn deserialize(bytes: &mut &[u8]) -> std::result::Result { + let mut cursor = Cursor::new(&bytes); + let mut reader = arrow_ipc::reader::FileReaderBuilder::new().build(&mut cursor)?; + let array = reader + .next() + .ok_or(ArrowError::ComputeError( + "expected one arrow array".to_string(), + ))?? + .remove_column(0); + bytes.consume(cursor.position() as usize); + Ok(Self(array)) + } +} + +pub fn create_udaf_script_function( + code: &UDFScriptCode, + name: String, + display_name: String, + state_fields: Vec, + arguments: Vec, + output_type: DataType, +) -> Result> { + let UDFScriptCode { language, code, .. } = code; + let runtime = match language { + UDFLanguage::JavaScript => { + let pool = JsRuntimePool::new( + name, + String::from_utf8(code.to_vec())?, + ArrowType::Struct( + state_fields + .iter() + .map(|f| f.into()) + .collect::>() + .into(), + ), + output_type, + ); + UDAFRuntime::JavaScript(pool) + } + UDFLanguage::WebAssembly => unimplemented!(), + #[cfg(not(feature = "python-udf"))] + UDFLanguage::Python => { + return Err(ErrorCode::EnterpriseFeatureNotEnable( + "Failed to create python script udf", + )); + } + #[cfg(feature = "python-udf")] + UDFLanguage::Python => { + let mut runtime = GLOBAL_PYTHON_RUNTIME.write(); + let code = String::from_utf8(code.to_vec())?; + runtime.add_aggregate( + &name, + ArrowType::Struct( + state_fields + .iter() + .map(|f| f.into()) + .collect::>() + .into(), + ), + ArrowType::from(&output_type), + arrow_udf_python::CallMode::CalledOnNullInput, + &code, + )?; + UDAFRuntime::Python(PythonInfo { name, output_type }) + } + }; + let init_state = runtime + .create_state() + .map_err(|e| ErrorCode::UDFRuntimeError(format!("failed to create state: {e}")))?; + + Ok(Arc::new(AggregateUdfScript { + display_name, + runtime, + argument_schema: DataSchema::new(arguments), + init_state, + })) +} + +struct JsRuntimePool { + name: String, + code: String, + state_type: 
ArrowType, + output_type: DataType, + + runtimes: Mutex>, +} + +impl JsRuntimePool { + fn new(name: String, code: String, state_type: ArrowType, output_type: DataType) -> Self { + Self { + name, + code, + state_type, + output_type, + runtimes: Mutex::new(vec![]), + } + } + + fn create(&self) -> Result { + let mut runtime = match arrow_udf_js::Runtime::new() { + Ok(runtime) => runtime, + Err(e) => { + return Err(ErrorCode::UDFDataError(format!( + "Cannot create js runtime: {e}" + ))) + } + }; + + let converter = runtime.converter_mut(); + converter.set_arrow_extension_key(EXTENSION_KEY); + converter.set_json_extension_name(ARROW_EXT_TYPE_VARIANT); + + let output_type: ArrowType = (&self.output_type).into(); + runtime + .add_aggregate( + &self.name, + self.state_type.clone(), + output_type, + arrow_udf_js::CallMode::CalledOnNullInput, + &self.code, + ) + .map_err(|e| ErrorCode::UDFDataError(format!("Cannot add aggregate: {e}")))?; + + Ok(runtime) + } + + fn call(&self, op: F) -> anyhow::Result + where F: FnOnce(&arrow_udf_js::Runtime) -> anyhow::Result { + let mut runtimes = self.runtimes.lock().unwrap(); + let runtime = match runtimes.pop() { + Some(runtime) => runtime, + None => self.create()?, + }; + drop(runtimes); + + let result = op(&runtime)?; + + let mut runtimes = self.runtimes.lock().unwrap(); + runtimes.push(runtime); + + Ok(result) + } +} + +enum UDAFRuntime { + JavaScript(JsRuntimePool), + #[expect(unused)] + WebAssembly, + #[cfg(feature = "python-udf")] + Python(PythonInfo), +} + +#[cfg(feature = "python-udf")] +struct PythonInfo { + name: String, + output_type: DataType, +} + +impl UDAFRuntime { + fn name(&self) -> &str { + match self { + UDAFRuntime::JavaScript(pool) => &pool.name, + #[cfg(feature = "python-udf")] + UDAFRuntime::Python(info) => &info.name, + _ => unimplemented!(), + } + } + + fn return_type(&self) -> DataType { + match self { + UDAFRuntime::JavaScript(pool) => pool.output_type.clone(), + #[cfg(feature = "python-udf")] + 
UDAFRuntime::Python(info) => info.output_type.clone(), + _ => unimplemented!(), + } + } + + fn create_state(&self) -> anyhow::Result { + let state = match self { + UDAFRuntime::JavaScript(pool) => pool.call(|runtime| runtime.create_state(&pool.name)), + #[cfg(feature = "python-udf")] + UDAFRuntime::Python(info) => { + let runtime = GLOBAL_PYTHON_RUNTIME.read(); + runtime.create_state(&info.name) + } + _ => unimplemented!(), + }?; + Ok(UdfAggState(state)) + } + + fn accumulate(&self, state: &UdfAggState, input: &RecordBatch) -> anyhow::Result { + let state = match self { + UDAFRuntime::JavaScript(pool) => { + pool.call(|runtime| runtime.accumulate(&pool.name, &state.0, input)) + } + #[cfg(feature = "python-udf")] + UDAFRuntime::Python(info) => { + let runtime = GLOBAL_PYTHON_RUNTIME.read(); + runtime.accumulate(&info.name, &state.0, input) + } + _ => unimplemented!(), + }?; + Ok(UdfAggState(state)) + } + + fn merge(&self, states: &Arc) -> anyhow::Result { + let state = match self { + UDAFRuntime::JavaScript(pool) => pool.call(|runtime| runtime.merge(&pool.name, states)), + #[cfg(feature = "python-udf")] + UDAFRuntime::Python(info) => { + let runtime = GLOBAL_PYTHON_RUNTIME.read(); + runtime.merge(&info.name, states) + } + _ => unimplemented!(), + }?; + Ok(UdfAggState(state)) + } + + fn finish(&self, state: &UdfAggState) -> anyhow::Result> { + match self { + UDAFRuntime::JavaScript(pool) => { + pool.call(|runtime| runtime.finish(&pool.name, &state.0)) + } + #[cfg(feature = "python-udf")] + UDAFRuntime::Python(info) => { + let runtime = GLOBAL_PYTHON_RUNTIME.read(); + runtime.finish(&info.name, &state.0) + } + _ => unimplemented!(), + } + } +} + +#[cfg(test)] +mod tests { + use arrow_array::Array; + use arrow_array::Int32Array; + use arrow_array::Int64Array; + use arrow_array::StructArray; + use arrow_schema::DataType as ArrowType; + use arrow_schema::Field; + use databend_common_expression::types::ArgType; + use databend_common_expression::types::Float32Type; + + use 
super::*; + + #[test] + fn test_serialize() { + let want: Arc = Arc::new(StructArray::new( + vec![ + Field::new("a", ArrowType::Int32, false), + Field::new("b", ArrowType::Int64, false), + ] + .into(), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(Int64Array::from(vec![4, 5, 6])), + ], + None, + )); + + let state = UdfAggState(want.clone()); + let mut buf = Vec::new(); + state.serialize(&mut buf).unwrap(); + + let state = UdfAggState::deserialize(&mut buf.as_slice()).unwrap(); + assert_eq!(&want, &state.0); + } + + #[test] + fn test_js_pool() -> Result<()> { + let agg_name = "weighted_avg".to_string(); + let fields = vec![ + Field::new("sum", ArrowType::Int64, false), + Field::new("weight", ArrowType::Int64, false), + ]; + let pool = JsRuntimePool::new( + agg_name.clone(), + r#" +export function create_state() { + return {sum: 0, weight: 0}; +} +export function accumulate(state, value, weight) { + state.sum += value * weight; + state.weight += weight; + return state; +} +export function retract(state, value, weight) { + state.sum -= value * weight; + state.weight -= weight; + return state; +} +export function merge(state1, state2) { + state1.sum += state2.sum; + state1.weight += state2.weight; + return state1; +} +export function finish(state) { + return state.sum / state.weight; +} + "# + .to_string(), + ArrowType::Struct(fields.clone().into()), + Float32Type::data_type(), + ); + + let state = pool.call(|runtime| runtime.create_state(&agg_name))?; + + let want: Arc = Arc::new(StructArray::new( + fields.into(), + vec![ + Arc::new(Int64Array::from(vec![0])), + Arc::new(Int64Array::from(vec![0])), + ], + None, + )); + + assert_eq!(&want, &state); + Ok(()) + } + + #[cfg(feature = "python-udf")] + #[test] + fn test_python_runtime() -> Result<()> { + use databend_common_expression::types::Int32Type; + + let code = Vec::from( + r#" +class State: + def __init__(self): + self.sum = 0 + self.weight = 0 + +def create_state(): + return State() + +def 
accumulate(state, value, weight): + state.sum += value * weight + state.weight += weight + return state + +def merge(state1, state2): + state1.sum += state2.sum + state1.weight += state2.weight + return state1 + +def finish(state): + if state.weight == 0: + return None + else: + return state.sum / state.weight +"#, + ) + .into_boxed_slice(); + + let script = UDFScriptCode { + language: UDFLanguage::Python, + code: code.into(), + runtime_version: "3.12".to_string(), + }; + let name = "test".to_string(); + let display_name = "test".to_string(); + let state_fields = vec![ + DataField::new("sum", Int32Type::data_type()), + DataField::new("weight", Int32Type::data_type()), + ]; + let arguments = vec![DataField::new("value", Int32Type::data_type())]; + let output_type = Float32Type::data_type(); + create_udaf_script_function( + &script, + name, + display_name, + state_fields, + arguments, + output_type, + )?; + Ok(()) + } +} diff --git a/src/query/service/src/pipelines/processors/transforms/mod.rs b/src/query/service/src/pipelines/processors/transforms/mod.rs index 519fab4ad6d9..55078470553b 100644 --- a/src/query/service/src/pipelines/processors/transforms/mod.rs +++ b/src/query/service/src/pipelines/processors/transforms/mod.rs @@ -66,3 +66,16 @@ pub use transform_stream_sort_spill::*; pub use transform_udf_script::TransformUdfScript; pub use transform_udf_server::TransformUdfServer; pub use window::*; + +#[cfg(feature = "python-udf")] +mod python_udf { + use std::sync::Arc; + use std::sync::LazyLock; + + use arrow_udf_python::Runtime; + use parking_lot::RwLock; + + /// python runtime should be only initialized once by gil lock, see: https://github.com/python/cpython/blob/main/Python/pystate.c + pub static GLOBAL_PYTHON_RUNTIME: LazyLock>> = + LazyLock::new(|| Arc::new(RwLock::new(Runtime::new().unwrap()))); +} diff --git a/src/query/service/src/pipelines/processors/transforms/transform_udf_script.rs 
b/src/query/service/src/pipelines/processors/transforms/transform_udf_script.rs index caf34b7c5ac2..9c499496806f 100644 --- a/src/query/service/src/pipelines/processors/transforms/transform_udf_script.rs +++ b/src/query/service/src/pipelines/processors/transforms/transform_udf_script.rs @@ -33,24 +33,25 @@ use databend_common_expression::DataSchema; use databend_common_expression::FunctionContext; use databend_common_pipeline_transforms::processors::Transform; use databend_common_sql::executor::physical_plans::UdfFunctionDesc; +use databend_common_sql::plans::UDFLanguage; +use databend_common_sql::plans::UDFScriptCode; use databend_common_sql::plans::UDFType; use parking_lot::RwLock; -/// python runtime should be only initialized once by gil lock, see: https://github.com/python/cpython/blob/main/Python/pystate.c #[cfg(feature = "python-udf")] -static GLOBAL_PYTHON_RUNTIME: std::sync::LazyLock>> = - std::sync::LazyLock::new(|| Arc::new(RwLock::new(arrow_udf_python::Runtime::new().unwrap()))); +use super::python_udf::GLOBAL_PYTHON_RUNTIME; pub enum ScriptRuntime { JavaScript(Vec>>), WebAssembly(Arc>), + #[cfg(feature = "python-udf")] Python, } impl ScriptRuntime { - pub fn try_create(lang: &str, code: Option<&[u8]>, runtime_num: usize) -> Result { + pub fn try_create(lang: UDFLanguage, code: Option<&[u8]>, runtime_num: usize) -> Result { match lang { - "javascript" => { + UDFLanguage::JavaScript => { // Create multiple runtimes to execute in parallel to avoid blocking caused by js udf runtime locks. 
let runtimes = (0..runtime_num) .map(|_| { @@ -65,21 +66,21 @@ impl ScriptRuntime { Arc::new(RwLock::new(runtime)) }) .map_err(|err| { - ErrorCode::UDFDataError(format!( - "Cannot create js runtime: {}", - err + ErrorCode::UDFRuntimeError(format!( + "Cannot create js runtime: {err}", )) }) }) .collect::>>>>()?; Ok(Self::JavaScript(runtimes)) } - "wasm" => Self::create_wasm_runtime(code), - "python" => Ok(Self::Python), - _ => Err(ErrorCode::from_string(format!( - "Invalid {} lang Runtime not supported", - lang - ))), + UDFLanguage::WebAssembly => Self::create_wasm_runtime(code), + #[cfg(feature = "python-udf")] + UDFLanguage::Python => Ok(Self::Python), + #[cfg(not(feature = "python-udf"))] + UDFLanguage::Python => Err(ErrorCode::EnterpriseFeatureNotEnable( + "Failed to create python script udf", + )), } } @@ -88,7 +89,7 @@ impl ScriptRuntime { .ok_or_else(|| ErrorCode::UDFDataError("WASM module not provided".to_string()))?; let runtime = arrow_udf_wasm::Runtime::new(decoded_code_blob).map_err(|err| { - ErrorCode::UDFDataError(format!("Failed to create WASM runtime for module: {}", err)) + ErrorCode::UDFRuntimeError(format!("Failed to create WASM runtime for module: {err}")) })?; Ok(ScriptRuntime::WebAssembly(Arc::new(RwLock::new(runtime)))) @@ -128,12 +129,6 @@ impl ScriptRuntime { &func.func_name, )?; } - #[cfg(not(feature = "python-udf"))] - ScriptRuntime::Python => { - return Err(ErrorCode::EnterpriseFeatureNotEnable( - "Failed to create python script udf", - )); - } // Ignore the execution for WASM context ScriptRuntime::WebAssembly(_) => {} } @@ -154,9 +149,9 @@ impl ScriptRuntime { let runtime = &runtimes[idx]; let runtime = runtime.read(); runtime.call(&func.name, input_batch).map_err(|err| { - ErrorCode::UDFDataError(format!( - "JavaScript UDF '{}' execution failed: {}", - func.name, err + ErrorCode::UDFRuntimeError(format!( + "JavaScript UDF {:?} execution failed: {err}", + func.name )) })? 
} @@ -164,24 +159,18 @@ impl ScriptRuntime { ScriptRuntime::Python => { let runtime = GLOBAL_PYTHON_RUNTIME.read(); runtime.call(&func.name, input_batch).map_err(|err| { - ErrorCode::UDFDataError(format!( - "Python UDF '{}' execution failed: {}", - func.name, err + ErrorCode::UDFRuntimeError(format!( + "Python UDF {:?} execution failed: {err}", + func.name )) })? } - #[cfg(not(feature = "python-udf"))] - ScriptRuntime::Python => { - return Err(ErrorCode::EnterpriseFeatureNotEnable( - "Failed to execute python script udf", - )); - } ScriptRuntime::WebAssembly(runtime) => { let runtime = runtime.read(); runtime.call(&func.func_name, input_batch).map_err(|err| { - ErrorCode::UDFDataError(format!( - "WASM UDF '{}' execution failed: {}", - func.func_name, err + ErrorCode::UDFRuntimeError(format!( + "WASM UDF {:?} execution failed: {err}", + func.func_name )) })? } @@ -235,7 +224,7 @@ impl Transform for TransformUdfScript { self.update_datablock(func, result_batch, &mut data_block)?; } else { return Err(ErrorCode::UDFDataError(format!( - "Failed to find runtime for function '{}' with key: {}", + "Failed to find runtime for function {:?} with key: {:?}", func.name, runtime_key ))); } @@ -247,16 +236,16 @@ impl Transform for TransformUdfScript { impl TransformUdfScript { fn get_runtime_key(func: &UdfFunctionDesc) -> Result { let (lang, func_name) = match &func.udf_type { - UDFType::Script((lang, _, _)) => (lang, &func.func_name), + UDFType::Script(UDFScriptCode { language: lang, .. 
}) => (lang, &func.func_name), _ => { return Err(ErrorCode::UDFDataError(format!( - "Unsupported UDFType variant for function '{}'", + "Unsupported UDFType variant for function {:?}", func.name ))); } }; - let runtime_key = format!("{}-{}", lang.trim(), func_name.trim()); + let runtime_key = format!("{}-{}", lang, func_name.trim()); Ok(runtime_key) } @@ -268,8 +257,10 @@ impl TransformUdfScript { let start = std::time::Instant::now(); for func in funcs { - let (lang, code_opt) = match &func.udf_type { - UDFType::Script((lang, _, code)) => (lang, Some(code.as_ref())), + let (&lang, code_opt) = match &func.udf_type { + UDFType::Script(UDFScriptCode { language, code, .. }) => { + (language, Some(code.as_ref().as_ref())) + } _ => continue, }; @@ -277,20 +268,19 @@ impl TransformUdfScript { let runtime = match script_runtimes.entry(runtime_key.clone()) { Entry::Occupied(entry) => entry.into_mut().clone(), Entry::Vacant(entry) => { - let new_runtime = ScriptRuntime::try_create(lang.trim(), code_opt, runtime_num) + let new_runtime = ScriptRuntime::try_create(lang, code_opt, runtime_num) .map(Arc::new) .map_err(|err| { ErrorCode::UDFDataError(format!( - "Failed to create UDF runtime for language '{}' with error: {}", - lang, err + "Failed to create UDF runtime for language {lang:?} with error: {err}", )) })?; entry.insert(new_runtime).clone() } }; - if let UDFType::Script((_, _, code)) = &func.udf_type { - runtime.add_function_with_handler(func, code)?; + if let UDFType::Script(UDFScriptCode { code, .. 
}) = &func.udf_type { + runtime.add_function_with_handler(func, code.as_ref().as_ref())?; } } diff --git a/src/query/service/tests/it/sql/planner/builders/binder.rs b/src/query/service/tests/it/sql/planner/builders/binder.rs index a3829436c27a..f40abdf7c881 100644 --- a/src/query/service/tests/it/sql/planner/builders/binder.rs +++ b/src/query/service/tests/it/sql/planner/builders/binder.rs @@ -26,3 +26,16 @@ async fn test_query_kind() -> Result<()> { assert_eq!(kind, QueryKind::CopyIntoTable); Ok(()) } + +#[tokio::test(flavor = "multi_thread")] +async fn test_planner() -> Result<()> { + let fixture = TestFixture::setup().await?; + + let ctx = fixture.new_query_ctx().await?; + let mut planner = Planner::new(ctx.clone()); + planner + .plan_sql("SELECT avg(number) from numbers_mt(10000)") + .await?; + + Ok(()) +} diff --git a/src/query/sql/src/executor/physical_plans/common.rs b/src/query/sql/src/executor/physical_plans/common.rs index 2670ca5e93ea..230b5f2af50d 100644 --- a/src/query/sql/src/executor/physical_plans/common.rs +++ b/src/query/sql/src/executor/physical_plans/common.rs @@ -15,28 +15,22 @@ use std::fmt::Display; use std::fmt::Formatter; -use databend_common_exception::Result; use databend_common_expression::types::DataType; use databend_common_expression::Scalar; -use databend_common_functions::aggregates::AggregateFunctionFactory; +use crate::plans::UDFField; +use crate::plans::UDFType; use crate::IndexType; #[derive(Clone, Debug, Eq, PartialEq, serde::Serialize, serde::Deserialize)] pub struct AggregateFunctionSignature { pub name: String, + pub udaf: Option<(UDFType, Vec)>, + pub return_type: DataType, pub params: Vec, pub args: Vec, } -impl AggregateFunctionSignature { - pub fn return_type(&self) -> Result { - AggregateFunctionFactory::instance() - .get(&self.name, self.params.clone(), self.args.clone())? 
- .return_type() - } -} - #[derive(Clone, Debug, Eq, PartialEq, serde::Serialize, serde::Deserialize)] pub struct AggregateFunctionDesc { pub sig: AggregateFunctionSignature, diff --git a/src/query/sql/src/executor/physical_plans/physical_aggregate_final.rs b/src/query/sql/src/executor/physical_plans/physical_aggregate_final.rs index fdab0c47810c..8f1d00ed51dc 100644 --- a/src/query/sql/src/executor/physical_plans/physical_aggregate_final.rs +++ b/src/query/sql/src/executor/physical_plans/physical_aggregate_final.rs @@ -55,7 +55,7 @@ impl AggregateFinal { pub fn output_schema(&self) -> Result { let mut fields = Vec::with_capacity(self.agg_funcs.len() + self.group_by.len()); for agg in self.agg_funcs.iter() { - let data_type = agg.sig.return_type()?; + let data_type = agg.sig.return_type.clone(); fields.push(DataField::new(&agg.output_column.to_string(), data_type)); } for id in self.group_by.iter() { @@ -120,38 +120,90 @@ impl PhysicalPlanBuilder { .map(|item| Ok(item.scalar.as_expr()?.sql_display())) .collect::>>()?; - let mut agg_funcs: Vec = agg.aggregate_functions.iter().map(|v| { - if let ScalarExpr::AggregateFunction(agg) = &v.scalar { - Ok(AggregateFunctionDesc { - sig: AggregateFunctionSignature { - name: agg.func_name.clone(), - args: agg.args.iter().map(|s| { - if let ScalarExpr::BoundColumnRef(col) = s { - Ok(input_schema.field_with_name(&col.column.index.to_string())?.data_type().clone()) + let mut agg_funcs: Vec = agg + .aggregate_functions + .iter() + .map(|v| match &v.scalar { + ScalarExpr::AggregateFunction(agg) => { + let arg_indices = agg + .args + .iter() + .map(|arg| { + if let ScalarExpr::BoundColumnRef(col) = arg { + Ok(col.column.index) } else { Err(ErrorCode::Internal( - "Aggregate function argument must be a BoundColumnRef".to_string() + "Aggregate function argument must be a BoundColumnRef" + .to_string(), )) } - }).collect::>()?, - params: agg.params.clone(), - }, - output_column: v.index, - arg_indices: agg.args.iter().map(|arg| { - if 
let ScalarExpr::BoundColumnRef(col) = arg { - Ok(col.column.index) - } else { - Err(ErrorCode::Internal( - "Aggregate function argument must be a BoundColumnRef".to_string() - )) - } - }).collect::>()?, - display: v.scalar.as_expr()?.sql_display(), - }) - } else { - Err(ErrorCode::Internal("Expected aggregate function".to_string())) - } - }).collect::>()?; + }) + .collect::>>()?; + let args = arg_indices + .iter() + .map(|i| { + Ok(input_schema + .field_with_name(&i.to_string())? + .data_type() + .clone()) + }) + .collect::>()?; + Ok(AggregateFunctionDesc { + sig: AggregateFunctionSignature { + name: agg.func_name.clone(), + udaf: None, + return_type: *agg.return_type.clone(), + args, + params: agg.params.clone(), + }, + output_column: v.index, + arg_indices, + display: v.scalar.as_expr()?.sql_display(), + }) + } + ScalarExpr::UDAFCall(udaf) => { + let arg_indices = udaf + .arguments + .iter() + .map(|arg| { + if let ScalarExpr::BoundColumnRef(col) = arg { + Ok(col.column.index) + } else { + Err(ErrorCode::Internal( + "Aggregate function argument must be a BoundColumnRef" + .to_string(), + )) + } + }) + .collect::>>()?; + let args = arg_indices + .iter() + .map(|i| { + Ok(input_schema + .field_with_name(&i.to_string())? 
+ .data_type() + .clone()) + }) + .collect::>()?; + + Ok(AggregateFunctionDesc { + sig: AggregateFunctionSignature { + name: udaf.name.clone(), + udaf: Some((udaf.udf_type.clone(), udaf.state_fields.clone())), + return_type: *udaf.return_type.clone(), + args, + params: vec![], + }, + output_column: v.index, + arg_indices, + display: v.scalar.as_expr()?.sql_display(), + }) + } + _ => Err(ErrorCode::Internal( + "Expected aggregate function".to_string(), + )), + }) + .collect::>()?; let settings = self.ctx.get_settings(); let group_by_shuffle_mode = settings.get_group_by_shuffle_mode()?; @@ -302,38 +354,90 @@ impl PhysicalPlanBuilder { } }; - let mut agg_funcs: Vec = agg.aggregate_functions.iter().map(|v| { - if let ScalarExpr::AggregateFunction(agg) = &v.scalar { - Ok(AggregateFunctionDesc { - sig: AggregateFunctionSignature { - name: agg.func_name.clone(), - args: agg.args.iter().map(|s| { - if let ScalarExpr::BoundColumnRef(col) = s { - Ok(input_schema.field_with_name(&col.column.index.to_string())?.data_type().clone()) + let mut agg_funcs: Vec = agg + .aggregate_functions + .iter() + .map(|v| match &v.scalar { + ScalarExpr::AggregateFunction(agg) => { + let arg_indices = agg + .args + .iter() + .map(|arg| { + if let ScalarExpr::BoundColumnRef(col) = arg { + Ok(col.column.index) } else { Err(ErrorCode::Internal( - "Aggregate function argument must be a BoundColumnRef".to_string() + "Aggregate function argument must be a BoundColumnRef" + .to_string(), )) } - }).collect::>()?, - params: agg.params.clone(), - }, - output_column: v.index, - arg_indices: agg.args.iter().map(|arg| { - if let ScalarExpr::BoundColumnRef(col) = arg { - Ok(col.column.index) - } else { - Err(ErrorCode::Internal( - "Aggregate function argument must be a BoundColumnRef".to_string() - )) - } - }).collect::>()?, - display: v.scalar.as_expr()?.sql_display(), - }) - } else { - Err(ErrorCode::Internal("Expected aggregate function".to_string())) - } - }).collect::>()?; + }) + .collect::>>()?; + let 
args = arg_indices + .iter() + .map(|i| { + Ok(input_schema + .field_with_name(&i.to_string())? + .data_type() + .clone()) + }) + .collect::>()?; + Ok(AggregateFunctionDesc { + sig: AggregateFunctionSignature { + name: agg.func_name.clone(), + udaf: None, + return_type: *agg.return_type.clone(), + args, + params: agg.params.clone(), + }, + output_column: v.index, + arg_indices, + display: v.scalar.as_expr()?.sql_display(), + }) + } + ScalarExpr::UDAFCall(udaf) => { + let arg_indices = udaf + .arguments + .iter() + .map(|arg| { + if let ScalarExpr::BoundColumnRef(col) = arg { + Ok(col.column.index) + } else { + Err(ErrorCode::Internal( + "Aggregate function argument must be a BoundColumnRef" + .to_string(), + )) + } + }) + .collect::>>()?; + let args = arg_indices + .iter() + .map(|i| { + Ok(input_schema + .field_with_name(&i.to_string())? + .data_type() + .clone()) + }) + .collect::>()?; + + Ok(AggregateFunctionDesc { + sig: AggregateFunctionSignature { + name: udaf.name.clone(), + udaf: Some((udaf.udf_type.clone(), udaf.state_fields.clone())), + return_type: *udaf.return_type.clone(), + args, + params: vec![], + }, + output_column: v.index, + arg_indices, + display: v.scalar.as_expr()?.sql_display(), + }) + } + _ => Err(ErrorCode::Internal( + "Expected aggregate function".to_string(), + )), + }) + .collect::>()?; if let Some(grouping_sets) = agg.grouping_sets.as_ref() { // The argument types are wrapped nullable due to `AggregateExpand` plan. We should recover them to original types. 
diff --git a/src/query/sql/src/executor/physical_plans/physical_udf.rs b/src/query/sql/src/executor/physical_plans/physical_udf.rs index 5ae129c7296b..af379e6895e0 100644 --- a/src/query/sql/src/executor/physical_plans/physical_udf.rs +++ b/src/query/sql/src/executor/physical_plans/physical_udf.rs @@ -128,7 +128,7 @@ impl PhysicalPlanBuilder { let udf_func = UdfFunctionDesc { name: func.name.clone(), - func_name: func.func_name.clone(), + func_name: func.handler.clone(), output_column: item.index, arg_indices, arg_exprs, diff --git a/src/query/sql/src/executor/physical_plans/physical_window.rs b/src/query/sql/src/executor/physical_plans/physical_window.rs index f8eeb530c4a1..4a43b3190697 100644 --- a/src/query/sql/src/executor/physical_plans/physical_window.rs +++ b/src/query/sql/src/executor/physical_plans/physical_window.rs @@ -85,18 +85,19 @@ pub enum WindowFunction { impl WindowFunction { fn data_type(&self) -> Result { - match self { - WindowFunction::Aggregate(agg) => agg.sig.return_type(), + let return_type = match self { + WindowFunction::Aggregate(agg) => agg.sig.return_type.clone(), WindowFunction::RowNumber | WindowFunction::Rank | WindowFunction::DenseRank => { - Ok(DataType::Number(NumberDataType::UInt64)) + DataType::Number(NumberDataType::UInt64) } WindowFunction::PercentRank | WindowFunction::CumeDist => { - Ok(DataType::Number(NumberDataType::Float64)) + DataType::Number(NumberDataType::Float64) } - WindowFunction::LagLead(f) => Ok(f.return_type.clone()), - WindowFunction::NthValue(f) => Ok(f.return_type.clone()), - WindowFunction::Ntile(f) => Ok(f.return_type.clone()), - } + WindowFunction::LagLead(f) => f.return_type.clone(), + WindowFunction::NthValue(f) => f.return_type.clone(), + WindowFunction::Ntile(f) => f.return_type.clone(), + }; + Ok(return_type) } } @@ -262,6 +263,8 @@ impl PhysicalPlanBuilder { WindowFuncType::Aggregate(agg) => WindowFunction::Aggregate(AggregateFunctionDesc { sig: AggregateFunctionSignature { name: 
agg.func_name.clone(), + udaf: None, + return_type: *agg.return_type.clone(), args: agg .args .iter() diff --git a/src/query/sql/src/planner/binder/aggregate.rs b/src/query/sql/src/planner/binder/aggregate.rs index 151608950d4e..7b8811e8f44f 100644 --- a/src/query/sql/src/planner/binder/aggregate.rs +++ b/src/query/sql/src/planner/binder/aggregate.rs @@ -52,6 +52,7 @@ use crate::plans::FunctionCall; use crate::plans::GroupingSets; use crate::plans::ScalarExpr; use crate::plans::ScalarItem; +use crate::plans::UDAFCall; use crate::plans::Visitor; use crate::plans::VisitorMut; use crate::BindContext; @@ -140,7 +141,7 @@ pub struct AggregateInfo { /// Mapping: (aggregate function display name) -> (index of agg func in `aggregate_functions`) /// This is used to find a aggregate function in current context. - pub aggregate_functions_map: HashMap, + aggregate_functions_map: HashMap, /// Mapping: (group item) -> (index of group item in `group_items`) /// This is used to check if a scalar expression is a group item. @@ -153,6 +154,20 @@ pub struct AggregateInfo { pub grouping_sets: Option, } +impl AggregateInfo { + fn push_aggregate_function(&mut self, item: ScalarItem, display_name: String) { + self.aggregate_functions.push(item); + self.aggregate_functions_map + .insert(display_name, self.aggregate_functions.len() - 1); + } + + pub fn get_aggregate_function(&self, display_name: &str) -> Option<&ScalarItem> { + self.aggregate_functions_map + .get(display_name) + .map(|index| &self.aggregate_functions[*index]) + } +} + pub(super) struct AggregateRewriter<'a> { pub bind_context: &'a mut BindContext, pub metadata: MetadataRef, @@ -169,73 +184,16 @@ impl<'a> AggregateRewriter<'a> { /// Replace the arguments of aggregate function with a BoundColumnRef, and /// add the replaced aggregate function and the arguments into `AggregateInfo`. 
fn replace_aggregate_function(&mut self, aggregate: &AggregateFunction) -> Result { - let agg_info = &mut self.bind_context.aggregate_info; - - if let Some(column) = - find_replaced_aggregate_function(agg_info, aggregate, &aggregate.display_name) - { + if let Some(column) = find_replaced_aggregate_function( + &self.bind_context.aggregate_info, + &aggregate.display_name, + &aggregate.return_type, + &aggregate.display_name, + ) { return Ok(BoundColumnRef { span: None, column }.into()); } - let mut replaced_args: Vec = Vec::with_capacity(aggregate.args.len()); - - for (i, arg) in aggregate.args.iter().enumerate() { - let name = format!("{}_arg_{}", &aggregate.func_name, i); - if let ScalarExpr::BoundColumnRef(column_ref) = arg { - replaced_args.push(column_ref.clone().into()); - agg_info.aggregate_arguments.push(ScalarItem { - index: column_ref.column.index, - scalar: arg.clone(), - }); - } else if let Some(item) = agg_info - .group_items - .iter() - .chain(agg_info.aggregate_arguments.iter()) - .find(|x| &x.scalar == arg) - { - // check if the arg is in group items - // we can reuse the index - let column_binding = ColumnBindingBuilder::new( - name, - item.index, - Box::new(arg.data_type()?), - Visibility::Visible, - ) - .build(); - - replaced_args.push(ScalarExpr::BoundColumnRef(BoundColumnRef { - span: arg.span(), - column: column_binding, - })); - } else { - let index = self.metadata.write().add_derived_column( - name.clone(), - arg.data_type()?, - Some(arg.clone()), - ); - - // Generate a ColumnBinding for each argument of aggregates - let column_binding = ColumnBindingBuilder::new( - name, - index, - Box::new(arg.data_type()?), - Visibility::Visible, - ) - .build(); - - replaced_args.push( - BoundColumnRef { - span: arg.span(), - column: column_binding.clone(), - } - .into(), - ); - agg_info.aggregate_arguments.push(ScalarItem { - index, - scalar: arg.clone(), - }); - } - } + let replaced_args = self.replace_function_args(&aggregate.args, &aggregate.func_name)?; 
let index = self.metadata.write().add_derived_column( aggregate.display_name.clone(), @@ -253,18 +211,131 @@ impl<'a> AggregateRewriter<'a> { return_type: aggregate.return_type.clone(), }; - agg_info.aggregate_functions.push(ScalarItem { - scalar: replaced_agg.clone().into(), - index, - }); - agg_info.aggregate_functions_map.insert( + self.bind_context.aggregate_info.push_aggregate_function( + ScalarItem { + scalar: replaced_agg.clone().into(), + index, + }, replaced_agg.display_name.clone(), - agg_info.aggregate_functions.len() - 1, ); Ok(replaced_agg.into()) } + fn replace_udaf_call(&mut self, udaf: &UDAFCall) -> Result { + if let Some(column) = find_replaced_aggregate_function( + &self.bind_context.aggregate_info, + &udaf.display_name, + &udaf.return_type, + &udaf.display_name, + ) { + return Ok(BoundColumnRef { span: None, column }.into()); + } + + let replaced_args = self.replace_function_args(&udaf.arguments, &udaf.name)?; + + let index = self.metadata.write().add_derived_column( + udaf.display_name.clone(), + *udaf.return_type.clone(), + Some(ScalarExpr::UDAFCall(udaf.clone())), + ); + + let replaced_udaf = UDAFCall { + span: udaf.span, + name: udaf.name.clone(), + display_name: udaf.display_name.clone(), + arg_types: udaf.arg_types.clone(), + state_fields: udaf.state_fields.clone(), + return_type: udaf.return_type.clone(), + arguments: replaced_args, + udf_type: udaf.udf_type.clone(), + }; + + self.bind_context.aggregate_info.push_aggregate_function( + ScalarItem { + scalar: replaced_udaf.clone().into(), + index, + }, + replaced_udaf.display_name.clone(), + ); + + Ok(replaced_udaf.into()) + } + + fn replace_function_args( + &mut self, + args: &[ScalarExpr], + func_name: &str, + ) -> Result> { + let AggregateInfo { + ref mut aggregate_arguments, + ref group_items, + .. 
+ } = self.bind_context.aggregate_info; + + args.iter() + .enumerate() + .map(|(i, arg)| { + let name = format!("{}_arg_{}", func_name, i); + let data_type = arg.data_type()?; + if let ScalarExpr::BoundColumnRef(column_ref) = arg { + aggregate_arguments.push(ScalarItem { + index: column_ref.column.index, + scalar: arg.clone(), + }); + return Ok(column_ref.clone()); + } + + if let Some(item) = group_items + .iter() + .chain(aggregate_arguments.iter()) + .find(|x| &x.scalar == arg) + { + // check if the arg is in group items + // we can reuse the index + let column_binding = ColumnBindingBuilder::new( + name, + item.index, + Box::new(data_type), + Visibility::Visible, + ) + .build(); + + return Ok(BoundColumnRef { + span: arg.span(), + column: column_binding, + }); + } + + let index = self.metadata.write().add_derived_column( + name.clone(), + data_type.clone(), + Some(arg.clone()), + ); + + // Generate a ColumnBinding for each argument of aggregates + let column_binding = ColumnBindingBuilder::new( + name, + index, + Box::new(data_type), + Visibility::Visible, + ) + .build(); + + aggregate_arguments.push(ScalarItem { + index, + scalar: arg.clone(), + }); + + Ok(BoundColumnRef { + span: arg.span(), + column: column_binding.clone(), + }) + }) + .map(|x| x.map(|x| x.into())) + .collect() + } + fn replace_grouping(&mut self, function: &FunctionCall) -> Result { let agg_info = &mut self.bind_context.aggregate_info; if agg_info.grouping_sets.is_none() { @@ -314,12 +385,17 @@ impl<'a> AggregateRewriter<'a> { impl<'a> VisitorMut<'a> for AggregateRewriter<'a> { fn visit(&mut self, expr: &'a mut ScalarExpr) -> Result<()> { - if let ScalarExpr::AggregateFunction(aggregate) = expr { - *expr = self.replace_aggregate_function(aggregate)?; - return Ok(()); + match expr { + ScalarExpr::AggregateFunction(aggregate) => { + *expr = self.replace_aggregate_function(aggregate)?; + Ok(()) + } + ScalarExpr::UDAFCall(udaf) => { + *expr = self.replace_udaf_call(udaf)?; + Ok(()) + } + _ => 
walk_expr_mut(self, expr), } - - walk_expr_mut(self, expr) } fn visit_function_call(&mut self, func: &'a mut FunctionCall) -> Result<()> { @@ -941,20 +1017,19 @@ impl Binder { /// Replace [`AggregateFunction`] with a [`ColumnBinding`] if the function is already replaced. pub fn find_replaced_aggregate_function( agg_info: &AggregateInfo, - agg: &AggregateFunction, + display_name: &str, + return_type: &DataType, new_name: &str, ) -> Option { agg_info - .aggregate_functions_map - .get(&agg.display_name) - .map(|i| { + .get_aggregate_function(display_name) + .map(|scalar_item| { // This expression is already replaced. - let scalar_item = &agg_info.aggregate_functions[*i]; - debug_assert_eq!(scalar_item.scalar.data_type().unwrap(), *agg.return_type); + debug_assert_eq!(&scalar_item.scalar.data_type().unwrap(), return_type); ColumnBindingBuilder::new( new_name.to_string(), scalar_item.index, - agg.return_type.clone(), + Box::new(return_type.clone()), Visibility::Visible, ) .build() diff --git a/src/query/sql/src/planner/binder/project.rs b/src/query/sql/src/planner/binder/project.rs index d06b058a4ae5..0210567367c8 100644 --- a/src/query/sql/src/planner/binder/project.rs +++ b/src/query/sql/src/planner/binder/project.rs @@ -102,7 +102,13 @@ impl Binder { ScalarExpr::AggregateFunction(agg) => { // Replace to bound column to reduce duplicate derived column bindings. 
debug_assert!(!is_grouping_sets_item); - find_replaced_aggregate_function(agg_info, agg, &item.alias).unwrap() + find_replaced_aggregate_function( + agg_info, + &agg.display_name, + &agg.return_type, + &item.alias, + ) + .unwrap() } ScalarExpr::WindowFunction(win) => { find_replaced_window_function(window_info, win, &item.alias).unwrap() diff --git a/src/query/sql/src/planner/binder/project_set.rs b/src/query/sql/src/planner/binder/project_set.rs index 29818a3162d6..1d00ca883d0b 100644 --- a/src/query/sql/src/planner/binder/project_set.rs +++ b/src/query/sql/src/planner/binder/project_set.rs @@ -176,13 +176,11 @@ impl<'a> VisitorMut<'a> for SetReturningRewriter<'a> { if let ScalarExpr::AggregateFunction(agg_func) = expr { self.is_lazy_srf = true; - if let Some(index) = self + if let Some(agg_item) = self .bind_context .aggregate_info - .aggregate_functions_map - .get(&agg_func.display_name) + .get_aggregate_function(&agg_func.display_name) { - let agg_item = &self.bind_context.aggregate_info.aggregate_functions[*index]; let column_binding = ColumnBindingBuilder::new( agg_func.display_name.clone(), agg_item.index, diff --git a/src/query/sql/src/planner/binder/qualify.rs b/src/query/sql/src/planner/binder/qualify.rs index 2ddc083e22b3..f3fa1bf348e8 100644 --- a/src/query/sql/src/planner/binder/qualify.rs +++ b/src/query/sql/src/planner/binder/qualify.rs @@ -145,29 +145,27 @@ impl VisitorMut<'_> for QualifyChecker<'_> { } if let ScalarExpr::AggregateFunction(agg) = expr { - if let Some(column) = self + let Some(agg_func) = self .bind_context .aggregate_info - .aggregate_functions_map - .get(&agg.display_name) - { - let agg_func = &self.bind_context.aggregate_info.aggregate_functions[*column]; - let column_binding = ColumnBindingBuilder::new( - agg.display_name.clone(), - agg_func.index, - Box::new(agg_func.scalar.data_type()?), - Visibility::Visible, - ) - .build(); - *expr = BoundColumnRef { - span: None, - column: column_binding, - } - .into(); - return Ok(()); + 
.get_aggregate_function(&agg.display_name) + else { + return Err(ErrorCode::Internal("Invalid aggregate function")); + }; + + let column_binding = ColumnBindingBuilder::new( + agg.display_name.clone(), + agg_func.index, + Box::new(agg_func.scalar.data_type()?), + Visibility::Visible, + ) + .build(); + *expr = BoundColumnRef { + span: None, + column: column_binding, } - - return Err(ErrorCode::Internal("Invalid aggregate function")); + .into(); + return Ok(()); } walk_expr_mut(self, expr) diff --git a/src/query/sql/src/planner/binder/sort.rs b/src/query/sql/src/planner/binder/sort.rs index a036e42939b0..86e463acfeef 100644 --- a/src/query/sql/src/planner/binder/sort.rs +++ b/src/query/sql/src/planner/binder/sort.rs @@ -312,7 +312,7 @@ impl Binder { Ok(UDFCall { span: udf.span, name: udf.name.clone(), - func_name: udf.func_name.clone(), + handler: udf.handler.clone(), display_name: udf.display_name.clone(), udf_type: udf.udf_type.clone(), arg_types: udf.arg_types.clone(), diff --git a/src/query/sql/src/planner/binder/udf.rs b/src/query/sql/src/planner/binder/udf.rs index f17e27d2cb94..bb49195f3fbd 100644 --- a/src/query/sql/src/planner/binder/udf.rs +++ b/src/query/sql/src/planner/binder/udf.rs @@ -18,12 +18,16 @@ use chrono::Utc; use databend_common_ast::ast::AlterUDFStmt; use databend_common_ast::ast::CreateUDFStmt; use databend_common_ast::ast::Identifier; +use databend_common_ast::ast::TypeName; +use databend_common_ast::ast::UDAFStateField; use databend_common_ast::ast::UDFDefinition; use databend_common_exception::ErrorCode; use databend_common_exception::Result; use databend_common_expression::types::DataType; use databend_common_expression::udf_client::UDFFlightClient; +use databend_common_expression::DataField; use databend_common_meta_app::principal::LambdaUDF; +use databend_common_meta_app::principal::UDAFScript; use databend_common_meta_app::principal::UDFDefinition as PlanUDFDefinition; use databend_common_meta_app::principal::UDFScript; use 
databend_common_meta_app::principal::UDFServer; @@ -31,23 +35,18 @@ use databend_common_meta_app::principal::UserDefinedFunction; use crate::normalize_identifier; use crate::optimizer::SExpr; -use crate::planner::resolve_type_name; +use crate::planner::resolve_type_name_udf; use crate::planner::udf_validator::UDFValidator; use crate::plans::AlterUDFPlan; use crate::plans::CreateUDFPlan; use crate::plans::DropUDFPlan; use crate::plans::Plan; +use crate::plans::UDFLanguage; use crate::BindContext; use crate::Binder; use crate::UdfRewriter; impl Binder { - fn is_allowed_language(language: &str) -> bool { - let allowed_languages: HashSet<&str> = - ["javascript", "wasm", "python"].iter().cloned().collect(); - allowed_languages.contains(&language.to_lowercase().as_str()) - } - pub(in crate::planner::binder) async fn bind_udf_definition( &mut self, udf_name: &Identifier, @@ -55,6 +54,7 @@ impl Binder { udf_definition: &UDFDefinition, ) -> Result { let name = normalize_identifier(udf_name, &self.name_resolution_ctx).to_string(); + let description = udf_description.clone().unwrap_or_default(); match udf_definition { UDFDefinition::LambdaUDF { parameters, @@ -68,7 +68,7 @@ impl Binder { validator.verify_definition_expr(definition)?; Ok(UserDefinedFunction { name: validator.name, - description: udf_description.clone().unwrap_or_default(), + description, definition: PlanUDFDefinition::LambdaUDF(LambdaUDF { parameters: validator.parameters, definition: definition.to_string(), @@ -87,9 +87,9 @@ impl Binder { let mut arg_datatypes = Vec::with_capacity(arg_types.len()); for arg_type in arg_types { - arg_datatypes.push(DataType::from(&resolve_type_name(arg_type, true)?)); + arg_datatypes.push(DataType::from(&resolve_type_name_udf(arg_type)?)); } - let return_type = DataType::from(&resolve_type_name(return_type, true)?); + let return_type = DataType::from(&resolve_type_name_udf(return_type)?); let mut client = UDFFlightClient::connect( address, @@ -110,7 +110,7 @@ impl Binder { 
Ok(UserDefinedFunction { name, - description: udf_description.clone().unwrap_or_default(), + description, definition: PlanUDFDefinition::UDFServer(UDFServer { address: address.clone(), arg_types: arg_datatypes, @@ -121,6 +121,7 @@ impl Binder { created_on: Utc::now(), }) } + UDFDefinition::UDAFServer { .. } => unimplemented!(), UDFDefinition::UDFScript { arg_types, return_type, @@ -129,34 +130,43 @@ impl Binder { language, runtime_version, } => { - let mut arg_datatypes = Vec::with_capacity(arg_types.len()); - for arg_type in arg_types { - arg_datatypes.push(DataType::from(&resolve_type_name(arg_type, true)?)); - } - let return_type = DataType::from(&resolve_type_name(return_type, true)?); - - if !Self::is_allowed_language(language) { - return Err(ErrorCode::InvalidArgument(format!( - "Unallowed UDF language '{language}', must be python, javascript or wasm" - ))); - } - - let mut runtime_version = runtime_version.to_string(); - if runtime_version.is_empty() && language.to_lowercase() == "python" { - runtime_version = "3.12.2".to_string(); - } - + let definition = create_udf_definition_script( + arg_types, + None, + return_type, + runtime_version, + handler, + language, + code, + )?; Ok(UserDefinedFunction { name, - description: udf_description.clone().unwrap_or_default(), - definition: PlanUDFDefinition::UDFScript(UDFScript { - code: code.clone(), - arg_types: arg_datatypes, - return_type, - handler: handler.clone(), - language: language.clone(), - runtime_version, - }), + description, + definition, + created_on: Utc::now(), + }) + } + UDFDefinition::UDAFScript { + arg_types, + state_fields, + return_type, + code, + language, + runtime_version, + } => { + let definition = create_udf_definition_script( + arg_types, + Some(state_fields), + return_type, + runtime_version, + "", + language, + code, + )?; + Ok(UserDefinedFunction { + name, + description, + definition, created_on: Utc::now(), }) } @@ -221,3 +231,72 @@ impl Binder { Ok(s_expr) } } + +fn 
create_udf_definition_script( + arg_types: &[TypeName], + state_fields: Option<&[UDAFStateField]>, + return_type: &TypeName, + runtime_version: &str, + handler: &str, + language: &str, + code: &str, +) -> Result { + let Ok(language) = language.parse::() else { + return Err(ErrorCode::InvalidArgument(format!( + "Unallowed UDF language {language:?}, must be python, javascript or wasm" + ))); + }; + + let arg_types = arg_types + .iter() + .map(|arg_type| Ok(DataType::from(&resolve_type_name_udf(arg_type)?))) + .collect::>>()?; + + let return_type = DataType::from(&resolve_type_name_udf(return_type)?); + + let mut runtime_version = runtime_version.to_string(); + if runtime_version.is_empty() && language == UDFLanguage::Python { + runtime_version = "3.12.2".to_string(); + } + + match state_fields { + Some(fields) => { + let state_fields = fields + .iter() + .map(|field| { + Ok(DataField::new( + &field.name.name, + DataType::from(&resolve_type_name_udf(&field.type_name)?), + )) + }) + .collect::>>()?; + + let state_field_names = state_fields + .iter() + .map(|f| f.name()) + .collect::>(); + if state_field_names.len() != state_fields.len() { + return Err(ErrorCode::InvalidArgument(format!( + "Duplicate state field name in UDAF script" + ))); + } + + Ok(PlanUDFDefinition::UDAFScript(UDAFScript { + code: code.to_string(), + arg_types, + state_fields, + return_type, + language: language.to_string(), + runtime_version, + })) + } + None => Ok(PlanUDFDefinition::UDFScript(UDFScript { + code: code.to_string(), + arg_types, + return_type, + handler: handler.to_string(), + language: language.to_string(), + runtime_version, + })), + } +} diff --git a/src/query/sql/src/planner/binder/window.rs b/src/query/sql/src/planner/binder/window.rs index e43511df31b4..3fcba8d215f6 100644 --- a/src/query/sql/src/planner/binder/window.rs +++ b/src/query/sql/src/planner/binder/window.rs @@ -590,31 +590,29 @@ pub struct WindowAggregateRewriter<'a> { impl<'a> VisitorMut<'a> for 
WindowAggregateRewriter<'a> { fn visit(&mut self, expr: &'a mut ScalarExpr) -> Result<()> { if let ScalarExpr::AggregateFunction(agg_func) = expr { - if let Some(index) = self + let Some(agg) = self .bind_context .aggregate_info - .aggregate_functions_map - .get(&agg_func.display_name) - { - let agg = &self.bind_context.aggregate_info.aggregate_functions[*index]; - let column_binding = ColumnBindingBuilder::new( - agg_func.display_name.clone(), - agg.index, - agg_func.return_type.clone(), - Visibility::Visible, - ) - .build(); + .get_aggregate_function(&agg_func.display_name) + else { + return Err(ErrorCode::BadArguments("Invalid window function argument")); + }; - *expr = BoundColumnRef { - span: None, - column: column_binding, - } - .into(); + let column_binding = ColumnBindingBuilder::new( + agg_func.display_name.clone(), + agg.index, + agg_func.return_type.clone(), + Visibility::Visible, + ) + .build(); - return Ok(()); - } else { - return Err(ErrorCode::BadArguments("Invalid window function argument")); + *expr = BoundColumnRef { + span: None, + column: column_binding, } + .into(); + + return Ok(()); } walk_expr_mut(self, expr) diff --git a/src/query/sql/src/planner/format/display_rel_operator.rs b/src/query/sql/src/planner/format/display_rel_operator.rs index 23e7c4329132..bf8de3f49d8f 100644 --- a/src/query/sql/src/planner/format/display_rel_operator.rs +++ b/src/query/sql/src/planner/format/display_rel_operator.rs @@ -102,7 +102,7 @@ pub fn format_scalar(scalar: &ScalarExpr) -> String { ScalarExpr::UDFCall(udf) => { format!( "{}({})", - &udf.func_name, + &udf.handler, udf.arguments .iter() .map(format_scalar) @@ -113,6 +113,17 @@ pub fn format_scalar(scalar: &ScalarExpr) -> String { ScalarExpr::UDFLambdaCall(udf) => { format!("{}({})", &udf.func_name, format_scalar(&udf.scalar)) } + ScalarExpr::UDAFCall(udaf) => { + format!( + "{}({})", + &udaf.name, + udaf.arguments + .iter() + .map(format_scalar) + .collect::>() + .join(", ") + ) + } 
ScalarExpr::AsyncFunctionCall(async_func) => async_func.display_name.clone(), } } diff --git a/src/query/sql/src/planner/optimizer/decorrelate/flatten_scalar.rs b/src/query/sql/src/planner/optimizer/decorrelate/flatten_scalar.rs index ff242dc062e4..58edbe9906ad 100644 --- a/src/query/sql/src/planner/optimizer/decorrelate/flatten_scalar.rs +++ b/src/query/sql/src/planner/optimizer/decorrelate/flatten_scalar.rs @@ -98,7 +98,7 @@ impl SubqueryRewriter { Ok(ScalarExpr::UDFCall(UDFCall { span: udf.span, name: udf.name.clone(), - func_name: udf.func_name.clone(), + handler: udf.handler.clone(), display_name: udf.display_name.clone(), udf_type: udf.udf_type.clone(), arg_types: udf.arg_types.clone(), diff --git a/src/query/sql/src/planner/optimizer/decorrelate/subquery_rewriter.rs b/src/query/sql/src/planner/optimizer/decorrelate/subquery_rewriter.rs index dd945f90b290..e46de1f1289c 100644 --- a/src/query/sql/src/planner/optimizer/decorrelate/subquery_rewriter.rs +++ b/src/query/sql/src/planner/optimizer/decorrelate/subquery_rewriter.rs @@ -47,6 +47,7 @@ use crate::plans::ScalarExpr; use crate::plans::ScalarItem; use crate::plans::SubqueryExpr; use crate::plans::SubqueryType; +use crate::plans::UDAFCall; use crate::plans::UDFCall; use crate::plans::UDFLambdaCall; use crate::plans::WindowFuncType; @@ -388,7 +389,7 @@ impl SubqueryRewriter { let expr: ScalarExpr = UDFCall { span: udf.span, name: udf.name.clone(), - func_name: udf.func_name.clone(), + handler: udf.handler.clone(), display_name: udf.display_name.clone(), udf_type: udf.udf_type.clone(), arg_types: udf.arg_types.clone(), @@ -399,7 +400,6 @@ impl SubqueryRewriter { Ok((expr, s_expr)) } - ScalarExpr::UDFLambdaCall(udf) => { let mut s_expr = s_expr.clone(); let res = self.try_rewrite_subquery(&udf.scalar, &s_expr, false)?; @@ -412,6 +412,29 @@ impl SubqueryRewriter { } .into(); + Ok((expr, s_expr)) + } + ScalarExpr::UDAFCall(udaf) => { + let mut args = vec![]; + let mut s_expr = s_expr.clone(); + for arg in 
udaf.arguments.iter() { + let res = self.try_rewrite_subquery(arg, &s_expr, false)?; + s_expr = res.1; + args.push(res.0); + } + + let expr: ScalarExpr = UDAFCall { + span: udaf.span, + name: udaf.name.clone(), + display_name: udaf.display_name.clone(), + udf_type: udaf.udf_type.clone(), + arg_types: udaf.arg_types.clone(), + state_fields: udaf.state_fields.clone(), + return_type: udaf.return_type.clone(), + arguments: args, + } + .into(); + Ok((expr, s_expr)) } } diff --git a/src/query/sql/src/planner/optimizer/filter/infer_filter.rs b/src/query/sql/src/planner/optimizer/filter/infer_filter.rs index 5227b86b612e..637b03a2f680 100644 --- a/src/query/sql/src/planner/optimizer/filter/infer_filter.rs +++ b/src/query/sql/src/planner/optimizer/filter/infer_filter.rs @@ -525,6 +525,7 @@ impl<'a> InferFilterOptimizer<'a> { | ScalarExpr::SubqueryExpr(_) | ScalarExpr::UDFCall(_) | ScalarExpr::UDFLambdaCall(_) + | ScalarExpr::UDAFCall(_) | ScalarExpr::AsyncFunctionCall(_) => { // Can not replace `BoundColumnRef` or can not replace unsupported ScalarExpr. 
self.can_replace = false; diff --git a/src/query/sql/src/planner/optimizer/rule/rewrite/agg_index/query_rewrite.rs b/src/query/sql/src/planner/optimizer/rule/rewrite/agg_index/query_rewrite.rs index c07ca1e32efc..e1ce43bc053b 100644 --- a/src/query/sql/src/planner/optimizer/rule/rewrite/agg_index/query_rewrite.rs +++ b/src/query/sql/src/planner/optimizer/rule/rewrite/agg_index/query_rewrite.rs @@ -1290,7 +1290,7 @@ fn format_scalar(scalar: &ScalarExpr, column_map: &HashMap format!( "{}({})", - &udf.func_name, + &udf.handler, udf.arguments .iter() .map(|arg| { format_scalar(arg, column_map) }) diff --git a/src/query/sql/src/planner/plans/mod.rs b/src/query/sql/src/planner/plans/mod.rs index e1876f090ba8..2704f874b049 100644 --- a/src/query/sql/src/planner/plans/mod.rs +++ b/src/query/sql/src/planner/plans/mod.rs @@ -48,6 +48,7 @@ mod set; mod set_priority; mod sort; mod system; +mod udaf; mod udf; mod union_all; mod window; @@ -93,6 +94,7 @@ pub use set::*; pub use set_priority::SetPriorityPlan; pub use sort::*; pub use system::*; +pub use udaf::*; pub use udf::*; pub use union_all::UnionAll; pub use window::*; diff --git a/src/query/sql/src/planner/plans/operator.rs b/src/query/sql/src/planner/plans/operator.rs index 2836f89bfd9b..93ba7863db66 100644 --- a/src/query/sql/src/planner/plans/operator.rs +++ b/src/query/sql/src/planner/plans/operator.rs @@ -111,6 +111,7 @@ pub enum RelOp { ExpressionScan, CacheScan, Udf, + Udaf, AsyncFunction, RecursiveCteScan, MergeInto, diff --git a/src/query/sql/src/planner/plans/scalar_expr.rs b/src/query/sql/src/planner/plans/scalar_expr.rs index 36ea6e959ad8..14c88669fb1c 100644 --- a/src/query/sql/src/planner/plans/scalar_expr.rs +++ b/src/query/sql/src/planner/plans/scalar_expr.rs @@ -17,6 +17,7 @@ use std::fmt::Display; use std::fmt::Formatter; use std::hash::Hash; use std::hash::Hasher; +use std::str::FromStr; use std::sync::Arc; use databend_common_ast::ast::BinaryOperator; @@ -57,6 +58,7 @@ pub enum ScalarExpr { 
CastExpr(CastExpr), SubqueryExpr(SubqueryExpr), UDFCall(UDFCall), + UDAFCall(UDAFCall), UDFLambdaCall(UDFLambdaCall), AsyncFunctionCall(AsyncFunctionCall), } @@ -75,6 +77,7 @@ impl Clone for ScalarExpr { ScalarExpr::SubqueryExpr(v) => ScalarExpr::SubqueryExpr(v.clone()), ScalarExpr::UDFCall(v) => ScalarExpr::UDFCall(v.clone()), ScalarExpr::UDFLambdaCall(v) => ScalarExpr::UDFLambdaCall(v.clone()), + ScalarExpr::UDAFCall(v) => ScalarExpr::UDAFCall(v.clone()), ScalarExpr::AsyncFunctionCall(v) => ScalarExpr::AsyncFunctionCall(v.clone()), } } @@ -98,6 +101,7 @@ impl PartialEq for ScalarExpr { (ScalarExpr::SubqueryExpr(l), ScalarExpr::SubqueryExpr(r)) => l.eq(r), (ScalarExpr::UDFCall(l), ScalarExpr::UDFCall(r)) => l.eq(r), (ScalarExpr::UDFLambdaCall(l), ScalarExpr::UDFLambdaCall(r)) => l.eq(r), + (ScalarExpr::UDAFCall(l), ScalarExpr::UDAFCall(r)) => l.eq(r), (ScalarExpr::AsyncFunctionCall(l), ScalarExpr::AsyncFunctionCall(r)) => l.eq(r), _ => false, } @@ -121,6 +125,7 @@ impl Hash for ScalarExpr { ScalarExpr::SubqueryExpr(v) => v.hash(state), ScalarExpr::UDFCall(v) => v.hash(state), ScalarExpr::UDFLambdaCall(v) => v.hash(state), + ScalarExpr::UDAFCall(v) => v.hash(state), ScalarExpr::AsyncFunctionCall(v) => v.hash(state), } } @@ -201,6 +206,7 @@ impl ScalarExpr { ScalarExpr::SubqueryExpr(expr) => expr.span, ScalarExpr::UDFCall(expr) => expr.span, ScalarExpr::UDFLambdaCall(expr) => expr.span, + ScalarExpr::UDAFCall(expr) => expr.span, ScalarExpr::AsyncFunctionCall(expr) => expr.span, } } @@ -517,6 +523,23 @@ impl TryFrom for UDFLambdaCall { } } +impl From for ScalarExpr { + fn from(v: UDAFCall) -> Self { + Self::UDAFCall(v) + } +} + +impl TryFrom for UDAFCall { + type Error = ErrorCode; + fn try_from(value: ScalarExpr) -> Result { + if let ScalarExpr::UDAFCall(value) = value { + Ok(value) + } else { + Err(ErrorCode::Internal("Cannot downcast Scalar to UDAFCall")) + } + } +} + impl From for ScalarExpr { fn from(v: AsyncFunctionCall) -> Self { Self::AsyncFunctionCall(v) @@ 
-771,7 +794,7 @@ pub struct UDFCall { // name in meta pub name: String, // name in handler - pub func_name: String, + pub handler: String, pub display_name: String, pub arg_types: Vec, pub return_type: Box, @@ -779,10 +802,70 @@ pub struct UDFCall { pub udf_type: UDFType, } +#[derive(Clone, Debug, Educe)] +#[educe(PartialEq, Eq, Hash)] +pub struct UDAFCall { + #[educe(Hash(ignore), PartialEq(ignore), Eq(ignore))] + pub span: Span, + pub name: String, // name in meta + pub display_name: String, + pub arg_types: Vec, + pub state_fields: Vec, + pub return_type: Box, + pub arguments: Vec, + pub udf_type: UDFType, +} + +#[derive(Clone, Debug, Educe, serde::Serialize, serde::Deserialize)] +#[educe(PartialEq, Eq, Hash)] +pub struct UDFField { + pub name: String, + pub data_type: DataType, +} + +#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq, serde::Serialize, serde::Deserialize)] +pub enum UDFLanguage { + JavaScript, + WebAssembly, + Python, +} + +impl FromStr for UDFLanguage { + type Err = ErrorCode; + + fn from_str(s: &str) -> Result { + match s.trim().to_lowercase().as_str() { + "javascript" => Ok(Self::JavaScript), + "wasm" => Ok(Self::WebAssembly), + "python" => Ok(Self::Python), + _ => Err(ErrorCode::BadArguments(format!( + "Unsupported script language: {s}" + ))), + } + } +} + +impl Display for UDFLanguage { + fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { + match self { + UDFLanguage::JavaScript => write!(f, "javascript"), + UDFLanguage::WebAssembly => write!(f, "wasm"), + UDFLanguage::Python => write!(f, "python"), + } + } +} + +#[derive(Clone, Debug, Hash, Eq, PartialEq, serde::Serialize, serde::Deserialize)] +pub struct UDFScriptCode { + pub language: UDFLanguage, + pub runtime_version: String, + pub code: Arc>, +} + #[derive(Clone, Debug, Hash, Eq, PartialEq, serde::Serialize, serde::Deserialize, EnumAsInner)] pub enum UDFType { - Server(String), // server_addr - Script((String, String, Vec)), // Lang, Version, Code + Server(String), // server_addr + 
Script(UDFScriptCode), } impl UDFType { @@ -952,6 +1035,13 @@ pub trait Visitor<'a>: Sized { self.visit(&udf.scalar) } + fn visit_udaf_call(&mut self, udaf: &'a UDAFCall) -> Result<()> { + for expr in &udaf.arguments { + self.visit(expr)?; + } + Ok(()) + } + fn visit_async_function_call(&mut self, async_func: &'a AsyncFunctionCall) -> Result<()> { for expr in &async_func.arguments { self.visit(expr)?; @@ -972,6 +1062,7 @@ pub fn walk_expr<'a, V: Visitor<'a>>(visitor: &mut V, expr: &'a ScalarExpr) -> R ScalarExpr::SubqueryExpr(expr) => visitor.visit_subquery(expr), ScalarExpr::UDFCall(expr) => visitor.visit_udf_call(expr), ScalarExpr::UDFLambdaCall(expr) => visitor.visit_udf_lambda_call(expr), + ScalarExpr::UDAFCall(expr) => visitor.visit_udaf_call(expr), ScalarExpr::AsyncFunctionCall(expr) => visitor.visit_async_function_call(expr), } } @@ -1054,6 +1145,13 @@ pub trait VisitorMut<'a>: Sized { self.visit(&mut udf.scalar) } + fn visit_udaf_call(&mut self, udaf: &'a mut UDAFCall) -> Result<()> { + for expr in &mut udaf.arguments { + self.visit(expr)?; + } + Ok(()) + } + fn visit_async_function_call(&mut self, async_func: &'a mut AsyncFunctionCall) -> Result<()> { for expr in &mut async_func.arguments { self.visit(expr)?; @@ -1077,6 +1175,7 @@ pub fn walk_expr_mut<'a, V: VisitorMut<'a>>( ScalarExpr::SubqueryExpr(expr) => visitor.visit_subquery_expr(expr), ScalarExpr::UDFCall(expr) => visitor.visit_udf_call(expr), ScalarExpr::UDFLambdaCall(expr) => visitor.visit_udf_lambda_call(expr), + ScalarExpr::UDAFCall(expr) => visitor.visit_udaf_call(expr), ScalarExpr::AsyncFunctionCall(expr) => visitor.visit_async_function_call(expr), } } diff --git a/src/query/sql/src/planner/plans/udaf.rs b/src/query/sql/src/planner/plans/udaf.rs new file mode 100644 index 000000000000..1c5e98dadf32 --- /dev/null +++ b/src/query/sql/src/planner/plans/udaf.rs @@ -0,0 +1,102 @@ +// Copyright 2021 Datafuse Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not 
use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::sync::Arc; + +use databend_common_catalog::table_context::TableContext; +use databend_common_exception::Result; + +use crate::optimizer::ColumnSet; +use crate::optimizer::RelExpr; +use crate::optimizer::RelationalProperty; +use crate::optimizer::RequiredProperty; +use crate::optimizer::StatInfo; +use crate::plans::Operator; +use crate::plans::RelOp; +use crate::plans::ScalarItem; + +/// `Udaf` is a plan that evaluate a series of udaf functions. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct Udaf { + pub items: Vec, + pub script_udf: bool, +} + +impl Udaf { + pub fn used_columns(&self) -> Result { + let mut used_columns = ColumnSet::new(); + for item in self.items.iter() { + used_columns.insert(item.index); + used_columns.extend(item.scalar.used_columns()); + } + Ok(used_columns) + } +} + +impl Operator for Udaf { + fn rel_op(&self) -> RelOp { + RelOp::Udaf + } + + fn derive_relational_prop(&self, rel_expr: &RelExpr) -> Result> { + let input_prop = rel_expr.derive_relational_prop_child(0)?; + + // Derive output columns + let mut output_columns = input_prop.output_columns.clone(); + for item in self.items.iter() { + output_columns.insert(item.index); + } + + // Derive outer columns + let mut outer_columns = input_prop.outer_columns.clone(); + for item in self.items.iter() { + let used_columns = item.scalar.used_columns(); + let outer = used_columns + .difference(&output_columns) + .cloned() + .collect::(); + outer_columns = outer_columns.union(&outer).cloned().collect(); + } 
+ outer_columns = outer_columns.difference(&output_columns).cloned().collect(); + + // Derive used columns + let mut used_columns = self.used_columns()?; + used_columns.extend(input_prop.used_columns.clone()); + + // Derive orderings + let orderings = input_prop.orderings.clone(); + let partition_orderings = input_prop.partition_orderings.clone(); + + Ok(Arc::new(RelationalProperty { + output_columns, + outer_columns, + used_columns, + orderings, + partition_orderings, + })) + } + + fn derive_stats(&self, rel_expr: &RelExpr) -> Result> { + rel_expr.derive_cardinality_child(0) + } + + fn compute_required_prop_children( + &self, + _ctx: Arc, + _rel_expr: &RelExpr, + required: &RequiredProperty, + ) -> Result>> { + Ok(vec![vec![required.clone()]]) + } +} diff --git a/src/query/sql/src/planner/semantic/grouping_check.rs b/src/query/sql/src/planner/semantic/grouping_check.rs index 2a9780d49ebe..21ff5b7bc06a 100644 --- a/src/query/sql/src/planner/semantic/grouping_check.rs +++ b/src/query/sql/src/planner/semantic/grouping_check.rs @@ -113,29 +113,50 @@ impl VisitorMut<'_> for GroupingChecker<'_> { return Err(ErrorCode::Internal("Group Check: Invalid window function")); } ScalarExpr::AggregateFunction(agg) => { - if let Some(column) = self + let Some(agg_func) = self .bind_context .aggregate_info - .aggregate_functions_map - .get(&agg.display_name) - { - let agg_func = &self.bind_context.aggregate_info.aggregate_functions[*column]; - let column_binding = ColumnBindingBuilder::new( - agg.display_name.clone(), - agg_func.index, - Box::new(agg_func.scalar.data_type()?), - Visibility::Visible, - ) - .build(); - *expr = BoundColumnRef { - span: None, - column: column_binding, - } - .into(); - return Ok(()); + .get_aggregate_function(&agg.display_name) + else { + return Err(ErrorCode::Internal("Invalid aggregate function")); + }; + + let column_binding = ColumnBindingBuilder::new( + agg.display_name.clone(), + agg_func.index, + Box::new(agg_func.scalar.data_type()?), + 
Visibility::Visible, + ) + .build(); + *expr = BoundColumnRef { + span: None, + column: column_binding, } + .into(); + return Ok(()); + } + ScalarExpr::UDAFCall(udaf) => { + let Some(agg_func) = self + .bind_context + .aggregate_info + .get_aggregate_function(&udaf.display_name) + else { + return Err(ErrorCode::Internal("Invalid udaf function")); + }; - return Err(ErrorCode::Internal("Invalid aggregate function")); + let column_binding = ColumnBindingBuilder::new( + udaf.display_name.clone(), + agg_func.index, + Box::new(agg_func.scalar.data_type()?), + Visibility::Visible, + ) + .build(); + *expr = BoundColumnRef { + span: None, + column: column_binding, + } + .into(); + return Ok(()); } ScalarExpr::BoundColumnRef(column_ref) => { if let Some(index) = self @@ -165,8 +186,8 @@ impl VisitorMut<'_> for GroupingChecker<'_> { if self .bind_context .aggregate_info - .aggregate_functions_map - .contains_key(&column.column.column_name) + .get_aggregate_function(&column.column.column_name) + .is_some() { // Be replaced by `WindowRewriter`. 
return Ok(()); diff --git a/src/query/sql/src/planner/semantic/lowering.rs b/src/query/sql/src/planner/semantic/lowering.rs index f2a1fab7a4e8..c2fe31bca55b 100644 --- a/src/query/sql/src/planner/semantic/lowering.rs +++ b/src/query/sql/src/planner/semantic/lowering.rs @@ -260,6 +260,17 @@ impl ScalarExpr { scalar.as_raw_expr() } + ScalarExpr::UDAFCall(udaf) => RawExpr::ColumnRef { + span: None, + id: ColumnBinding::new_dummy_column( + udaf.display_name.clone(), + Box::new(udaf.return_type.as_ref().clone()), + DummyColumnType::UDF, + ), + data_type: udaf.return_type.as_ref().clone(), + display_name: udaf.display_name.clone(), + }, + ScalarExpr::AsyncFunctionCall(async_func) => RawExpr::ColumnRef { span: None, id: ColumnBinding::new_dummy_column( diff --git a/src/query/sql/src/planner/semantic/mod.rs b/src/query/sql/src/planner/semantic/mod.rs index bdc761c70760..987e753f170a 100644 --- a/src/query/sql/src/planner/semantic/mod.rs +++ b/src/query/sql/src/planner/semantic/mod.rs @@ -43,6 +43,7 @@ pub use name_resolution::NameResolutionSuggest; pub use name_resolution::VariableNormalizer; pub use type_check::resolve_type_name; pub use type_check::resolve_type_name_by_str; +pub use type_check::resolve_type_name_udf; pub use type_check::validate_function_arg; pub use type_check::TypeChecker; pub(crate) use udf_rewriter::UdfRewriter; diff --git a/src/query/sql/src/planner/semantic/type_check.rs b/src/query/sql/src/planner/semantic/type_check.rs index ae48c35f55a9..b6c338c01bd8 100644 --- a/src/query/sql/src/planner/semantic/type_check.rs +++ b/src/query/sql/src/planner/semantic/type_check.rs @@ -93,6 +93,7 @@ use databend_common_functions::GENERAL_SEARCH_FUNCTIONS; use databend_common_functions::GENERAL_WINDOW_FUNCTIONS; use databend_common_functions::RANK_WINDOW_FUNCTIONS; use databend_common_meta_app::principal::LambdaUDF; +use databend_common_meta_app::principal::UDAFScript; use databend_common_meta_app::principal::UDFDefinition; use 
databend_common_meta_app::principal::UDFScript; use databend_common_meta_app::principal::UDFServer; @@ -148,8 +149,11 @@ use crate::plans::ScalarItem; use crate::plans::SqlSource; use crate::plans::SubqueryExpr; use crate::plans::SubqueryType; +use crate::plans::UDAFCall; use crate::plans::UDFCall; +use crate::plans::UDFField; use crate::plans::UDFLambdaCall; +use crate::plans::UDFScriptCode; use crate::plans::UDFType; use crate::plans::Visitor as ScalarVisitor; use crate::plans::WindowFunc; @@ -732,43 +736,43 @@ impl<'a> TypeChecker<'a> { { if let Some(udf) = self.resolve_udf(*span, func_name, args)? { return Ok(udf); + } + + // Function not found, try to find and suggest similar function name. + let all_funcs = BUILTIN_FUNCTIONS + .all_function_names() + .into_iter() + .chain(AggregateFunctionFactory::instance().registered_names()) + .chain(GENERAL_WINDOW_FUNCTIONS.iter().cloned().map(str::to_string)) + .chain(GENERAL_LAMBDA_FUNCTIONS.iter().cloned().map(str::to_string)) + .chain(GENERAL_SEARCH_FUNCTIONS.iter().cloned().map(str::to_string)) + .chain(ASYNC_FUNCTIONS.iter().cloned().map(str::to_string)) + .chain( + Self::all_sugar_functions() + .iter() + .cloned() + .map(str::to_string), + ); + let mut engine: SimSearch = SimSearch::new(); + for func_name in all_funcs { + engine.insert(func_name.clone(), &func_name); + } + let possible_funcs = engine + .search(func_name) + .iter() + .map(|name| format!("'{name}'")) + .collect::>(); + if possible_funcs.is_empty() { + return Err(ErrorCode::UnknownFunction(format!( + "no function matches the given name: {func_name}" + )) + .set_span(*span)); } else { - // Function not found, try to find and suggest similar function name. 
- let all_funcs = BUILTIN_FUNCTIONS - .all_function_names() - .into_iter() - .chain(AggregateFunctionFactory::instance().registered_names()) - .chain(GENERAL_WINDOW_FUNCTIONS.iter().cloned().map(str::to_string)) - .chain(GENERAL_LAMBDA_FUNCTIONS.iter().cloned().map(str::to_string)) - .chain(GENERAL_SEARCH_FUNCTIONS.iter().cloned().map(str::to_string)) - .chain(ASYNC_FUNCTIONS.iter().cloned().map(str::to_string)) - .chain( - Self::all_sugar_functions() - .iter() - .cloned() - .map(str::to_string), - ); - let mut engine: SimSearch = SimSearch::new(); - for func_name in all_funcs { - engine.insert(func_name.clone(), &func_name); - } - let possible_funcs = engine - .search(func_name) - .iter() - .map(|name| format!("'{name}'")) - .collect::>(); - if possible_funcs.is_empty() { - return Err(ErrorCode::UnknownFunction(format!( - "no function matches the given name: {func_name}" - )) - .set_span(*span)); - } else { - return Err(ErrorCode::UnknownFunction(format!( - "no function matches the given name: '{func_name}', do you mean {}?", - possible_funcs.join(", ") - )) - .set_span(*span)); - } + return Err(ErrorCode::UnknownFunction(format!( + "no function matches the given name: '{func_name}', do you mean {}?", + possible_funcs.join(", ") + )) + .set_span(*span)); } } @@ -3715,6 +3719,9 @@ impl<'a> TypeChecker<'a> { UDFDefinition::UDFScript(udf_def) => Ok(Some( self.resolve_udf_script(span, name, arguments, udf_def)?, )), + UDFDefinition::UDAFScript(udf_def) => Ok(Some( + self.resolve_udaf_script(span, name, arguments, udf_def)?, + )), } } @@ -3754,7 +3761,7 @@ impl<'a> TypeChecker<'a> { UDFCall { span, name, - func_name: udf_definition.handler, + handler: udf_definition.handler, display_name, udf_type: UDFType::Server(udf_definition.address.clone()), arg_types: udf_definition.arg_types, @@ -3766,21 +3773,17 @@ impl<'a> TypeChecker<'a> { ))) } - async fn resolve_udf_with_stage(&mut self, udf_definition: &UDFScript) -> Result { - let file_location = match 
udf_definition.code.strip_prefix('@') { + async fn resolve_udf_with_stage(&mut self, code: String) -> Result> { + let file_location = match code.strip_prefix('@') { Some(location) => FileLocation::Stage(location.to_string()), None => { - let uri = UriLocation::from_uri(udf_definition.code.clone(), BTreeMap::default()); + let uri = UriLocation::from_uri(code.clone(), BTreeMap::default()); match uri { Ok(uri) => FileLocation::Uri(uri), Err(_) => { // fallback to use the code as real code - return Ok(UDFType::Script(( - udf_definition.language.clone(), - udf_definition.runtime_version.clone(), - udf_definition.code.clone().into(), - ))); + return Ok(code.into()); } } } @@ -3791,7 +3794,7 @@ impl<'a> TypeChecker<'a> { .map_err(|err| { ErrorCode::SemanticError(format!( "Failed to resolve code location {:?}: {}", - &udf_definition.code, err + code, err )) })?; @@ -3827,35 +3830,45 @@ impl<'a> TypeChecker<'a> { None => code_blob, }; - Ok(UDFType::Script(( - udf_definition.language.clone(), - udf_definition.runtime_version.clone(), - code_blob, - ))) + Ok(code_blob) } fn resolve_udf_script( &mut self, span: Span, name: String, - arguments: &[Expr], + args: &[Expr], udf_definition: UDFScript, ) -> Result> { - let mut args = Vec::with_capacity(arguments.len()); - for (argument, dest_type) in arguments.iter().zip(udf_definition.arg_types.iter()) { + let UDFScript { + code, + handler, + language, + arg_types, + return_type, + runtime_version, + } = udf_definition; + let language = language.parse()?; + let mut arguments = Vec::with_capacity(args.len()); + for (argument, dest_type) in args.iter().zip(arg_types.iter()) { let box (arg, ty) = self.resolve(argument)?; if ty != *dest_type { - args.push(wrap_cast(&arg, dest_type)); + arguments.push(wrap_cast(&arg, dest_type)); } else { - args.push(arg); + arguments.push(arg); } } - let const_udf_type = - databend_common_base::runtime::block_on(self.resolve_udf_with_stage(&udf_definition))?; + let code_blob = 
databend_common_base::runtime::block_on(self.resolve_udf_with_stage(code))? + .into_boxed_slice(); + let udf_type = UDFType::Script(UDFScriptCode { + language, + runtime_version, + code: code_blob.into(), + }); - let arg_names = arguments.iter().map(|arg| format!("{}", arg)).join(", "); - let display_name = format!("{}({})", udf_definition.handler, arg_names); + let arg_names = args.iter().map(|arg| format!("{arg}")).join(", "); + let display_name = format!("{}({})", &handler, arg_names); self.bind_context.have_udf_script = true; self.ctx.set_cacheable(false); @@ -3863,15 +3876,78 @@ impl<'a> TypeChecker<'a> { UDFCall { span, name, - func_name: udf_definition.handler, + handler, display_name, - arg_types: udf_definition.arg_types, - return_type: Box::new(udf_definition.return_type.clone()), - udf_type: const_udf_type, - arguments: args, + arg_types, + return_type: Box::new(return_type.clone()), + udf_type, + arguments, } .into(), - udf_definition.return_type.clone(), + return_type, + ))) + } + + fn resolve_udaf_script( + &mut self, + span: Span, + name: String, + args: &[Expr], + udf_definition: UDAFScript, + ) -> Result> { + let UDAFScript { + code, + language, + arg_types, + state_fields, + return_type, + runtime_version, + } = udf_definition; + let language = language.parse()?; + let code_blob = databend_common_base::runtime::block_on(self.resolve_udf_with_stage(code))? 
+ .into_boxed_slice(); + let udf_type = UDFType::Script(UDFScriptCode { + language, + runtime_version, + code: code_blob.into(), + }); + + let mut arguments = Vec::with_capacity(arg_types.len()); + for (argument, dest_type) in args.iter().zip(arg_types.iter()) { + let box (arg, ty) = self.resolve(argument)?; + if ty != *dest_type { + arguments.push(wrap_cast(&arg, dest_type)); + } else { + arguments.push(arg); + } + } + + let display_name = format!( + "{name}({})", + arg_types.iter().map(|arg| format!("{arg}")).join(", ") + ); + + self.bind_context.have_udf_script = true; + self.ctx.set_cacheable(false); + Ok(Box::new(( + UDAFCall { + span, + name, + display_name, + arg_types, + state_fields: state_fields + .iter() + .map(|f| UDFField { + name: f.name().to_string(), + data_type: f.data_type().clone(), + }) + .collect(), + return_type: Box::new(return_type.clone()), + udf_type, + arguments, + } + .into(), + return_type, ))) } @@ -5195,6 +5271,14 @@ pub fn resolve_type_name(type_name: &TypeName, not_null: bool) -> Result Result { + let type_name = match type_name { + name @ TypeName::Nullable(_) | name @ TypeName::NotNull(_) => name, + name => &name.clone().wrap_nullable(), + }; + resolve_type_name(type_name, true) +} + pub fn validate_function_arg( name: &str, args_len: usize, diff --git a/src/query/storages/system/src/user_functions_table.rs b/src/query/storages/system/src/user_functions_table.rs index dc6ca994ed05..8ffb0d9b99d6 100644 --- a/src/query/storages/system/src/user_functions_table.rs +++ b/src/query/storages/system/src/user_functions_table.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+use std::collections::BTreeMap; use std::sync::Arc; use chrono::DateTime; @@ -79,7 +80,7 @@ impl AsyncSystemTable for UserFunctionsTable { for user_function in &user_functions { names.push(user_function.name.as_str()); - is_aggregate.push(None); + is_aggregate.push(Some(user_function.is_aggregate)); languages.push(user_function.language.as_str()); descriptions.push(user_function.description.as_str()); arguments.push(serde_json::to_vec(&user_function.arguments)?); @@ -107,6 +108,8 @@ pub struct UserFunctionArguments { return_type: Option, #[serde(skip_serializing_if = "std::vec::Vec::is_empty")] parameters: Vec, + #[serde(skip_serializing_if = "std::collections::BTreeMap::is_empty")] + states: BTreeMap, } #[derive(serde::Serialize)] @@ -159,12 +162,17 @@ impl UserFunctionsTable { .into_iter() .map(|user_function| UserFunction { name: user_function.name, - is_aggregate: false, + is_aggregate: match user_function.definition { + UDFDefinition::LambdaUDF(_) => false, + UDFDefinition::UDFServer(_) | UDFDefinition::UDFScript(_) => false, + UDFDefinition::UDAFScript(_) => true, + }, description: user_function.description, language: match &user_function.definition { UDFDefinition::LambdaUDF(_) => String::from("SQL"), UDFDefinition::UDFServer(x) => x.language.clone(), - UDFDefinition::UDFScript(x) => x.language.clone(), + UDFDefinition::UDFScript(x) => x.language.to_string(), + UDFDefinition::UDAFScript(x) => x.language.to_string(), }, definition: user_function.definition.to_string(), created_on: user_function.created_on, @@ -173,16 +181,29 @@ impl UserFunctionsTable { return_type: None, arg_types: vec![], parameters: x.parameters.clone(), + states: BTreeMap::new(), }, UDFDefinition::UDFServer(x) => UserFunctionArguments { parameters: vec![], return_type: Some(x.return_type.to_string()), arg_types: x.arg_types.iter().map(ToString::to_string).collect(), + states: BTreeMap::new(), }, UDFDefinition::UDFScript(x) => UserFunctionArguments { parameters: vec![], return_type: 
Some(x.return_type.to_string()), arg_types: x.arg_types.iter().map(ToString::to_string).collect(), + states: BTreeMap::new(), + }, + UDFDefinition::UDAFScript(x) => UserFunctionArguments { + parameters: vec![], + return_type: Some(x.return_type.to_string()), + arg_types: x.arg_types.iter().map(ToString::to_string).collect(), + states: x + .state_fields + .iter() + .map(|f| (f.name().to_string(), f.data_type().to_string())) + .collect(), }, }, }) diff --git a/tests/sqllogictests/suites/base/03_common/03_0047_select_udaf.test b/tests/sqllogictests/suites/base/03_common/03_0047_select_udaf.test new file mode 100644 index 000000000000..40fa9b8026ca --- /dev/null +++ b/tests/sqllogictests/suites/base/03_common/03_0047_select_udaf.test @@ -0,0 +1,40 @@ +statement ok +CREATE or replace FUNCTION weighted_avg (INT, INT) STATE {sum INT, weight INT} RETURNS FLOAT +LANGUAGE javascript AS $$ +export function create_state() { + return {sum: 0, weight: 0}; +} +export function accumulate(state, value, weight) { + state.sum += value * weight; + state.weight += weight; + return state; +} +export function retract(state, value, weight) { + state.sum -= value * weight; + state.weight -= weight; + return state; +} +export function merge(state1, state2) { + state1.sum += state2.sum; + state1.weight += state2.weight; + return state1; +} +export function finish(state) { + return state.sum / state.weight; +} +$$; + +query R +select weighted_avg(number+1, number*2) from numbers(10); +---- +7.3333335 + +query RIR +select weighted_avg(number+1, number*2), sum(number), avg(number) from numbers(10); +---- +7.3333335 45 4.5 + +query R +select a + b from ( select weighted_avg(number+1, number*2) a, avg(number) b from numbers(10) ); +---- +11.833333492279053 diff --git a/tests/suites/0_stateless/20+_others/20_0016_udf_timestamp.result b/tests/suites/0_stateless/20+_others/20_0016_udf_timestamp.result index 6f4ccbf2c45c..ca3f9000a1b0 100644 --- 
a/tests/suites/0_stateless/20+_others/20_0016_udf_timestamp.result +++ b/tests/suites/0_stateless/20+_others/20_0016_udf_timestamp.result @@ -1,6 +1,6 @@ ==TEST SHOW USER FUNCTIONS== -isnotempty NULL {"parameters":["p"]} SQL yyyy-mm-dd HH:MM:SS.ssssss -ping NULL Built-in UDF {"arg_types":["String NULL"],"return_type":"String NULL"} python yyyy-mm-dd HH:MM:SS.ssssss +isnotempty 0 {"parameters":["p"]} SQL yyyy-mm-dd HH:MM:SS.ssssss +ping 0 Built-in UDF {"arg_types":["String NULL"],"return_type":"String NULL"} python yyyy-mm-dd HH:MM:SS.ssssss ==TEST SELECT * FROM SYSTEM.USER_FUNCTIONS== -isnotempty NULL {"parameters":["p"]} SQL (p) -> NOT is_null(p) yyyy-mm-dd HH:MM:SS.ssssss -ping NULL Built-in UDF {"arg_types":["String NULL"],"return_type":"String NULL"} python (String NULL) RETURNS String NULL LANGUAGE python HANDLER = ping ADDRESS = http://0.0.0.0:8815 yyyy-mm-dd HH:MM:SS.ssssss +isnotempty 0 {"parameters":["p"]} SQL (p) -> NOT is_null(p) yyyy-mm-dd HH:MM:SS.ssssss +ping 0 Built-in UDF {"arg_types":["String NULL"],"return_type":"String NULL"} python (String NULL) RETURNS String NULL LANGUAGE python HANDLER = ping ADDRESS = http://0.0.0.0:8815 yyyy-mm-dd HH:MM:SS.ssssss