From 4cf3c1c5270e9a87161549d83aeba7e1a027792a Mon Sep 17 00:00:00 2001 From: Jiayu Liu Date: Sat, 9 Oct 2021 23:27:15 +0800 Subject: [PATCH] add approx_distinct function --- README.md | 2 + ballista/rust/core/proto/ballista.proto | 1 + .../core/src/serde/logical_plan/to_proto.rs | 4 + ballista/rust/core/src/serde/mod.rs | 3 + datafusion/src/logical_plan/expr.rs | 15 + datafusion/src/logical_plan/mod.rs | 22 +- datafusion/src/physical_plan/aggregates.rs | 14 +- .../expressions/approx_distinct.rs | 321 ++++++++++++++++++ .../src/physical_plan/expressions/mod.rs | 2 + .../src/physical_plan/hyperloglog/mod.rs | 67 +++- datafusion/tests/sql.rs | 17 + 11 files changed, 451 insertions(+), 17 deletions(-) create mode 100644 datafusion/src/physical_plan/expressions/approx_distinct.rs diff --git a/README.md b/README.md index 00d868c457c1..8b129177deda 100644 --- a/README.md +++ b/README.md @@ -190,6 +190,8 @@ DataFusion also includes a simple command-line interactive SQL utility. See the - [x] trim - Miscellaneous/Boolean functions - [x] nullif +- Approximation functions + - [ ] approx_distinct - Common date/time functions - [ ] Basic date functions - [ ] Basic time functions diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto index 8175156e3051..9a2ec710411b 100644 --- a/ballista/rust/core/proto/ballista.proto +++ b/ballista/rust/core/proto/ballista.proto @@ -167,6 +167,7 @@ enum AggregateFunction { SUM = 2; AVG = 3; COUNT = 4; + APPROX_DISTINCT = 5; } message AggregateExprNode { diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs b/ballista/rust/core/src/serde/logical_plan/to_proto.rs index c3ffb1a2022e..402422adb205 100644 --- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/to_proto.rs @@ -1137,6 +1137,9 @@ impl TryInto for &Expr { ref fun, ref args, .. 
} => { let aggr_function = match fun { + AggregateFunction::ApproxDistinct => { + protobuf::AggregateFunction::ApproxDistinct + } AggregateFunction::Min => protobuf::AggregateFunction::Min, AggregateFunction::Max => protobuf::AggregateFunction::Max, AggregateFunction::Sum => protobuf::AggregateFunction::Sum, @@ -1370,6 +1373,7 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction { AggregateFunction::Sum => Self::Sum, AggregateFunction::Avg => Self::Avg, AggregateFunction::Count => Self::Count, + AggregateFunction::ApproxDistinct => Self::ApproxDistinct, } } } diff --git a/ballista/rust/core/src/serde/mod.rs b/ballista/rust/core/src/serde/mod.rs index 1383ba89685c..a4df5a45555d 100644 --- a/ballista/rust/core/src/serde/mod.rs +++ b/ballista/rust/core/src/serde/mod.rs @@ -114,6 +114,9 @@ impl From for AggregateFunction { protobuf::AggregateFunction::Sum => AggregateFunction::Sum, protobuf::AggregateFunction::Avg => AggregateFunction::Avg, protobuf::AggregateFunction::Count => AggregateFunction::Count, + protobuf::AggregateFunction::ApproxDistinct => { + AggregateFunction::ApproxDistinct + } } } } diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs index f61ed832c6b4..8ef69e9b0cfe 100644 --- a/datafusion/src/logical_plan/expr.rs +++ b/datafusion/src/logical_plan/expr.rs @@ -1495,6 +1495,21 @@ pub fn random() -> Expr { } } +/// Returns the approximate number of distinct input values. +/// This function provides an approximation of count(DISTINCT x). +/// Zero is returned if all input values are null. +/// This function should produce a standard error of 0.81%, +/// which is the standard deviation of the (approximately normal) +/// error distribution over all possible sets. +/// It does not guarantee an upper bound on the error for any specific input set. 
+pub fn approx_distinct(expr: Expr) -> Expr { + Expr::AggregateFunction { + fun: aggregates::AggregateFunction::ApproxDistinct, + distinct: false, + args: vec![expr], + } +} + /// Create an convenience function representing a unary scalar function macro_rules! unary_scalar_expr { ($ENUM:ident, $FUNC:ident) => { diff --git a/datafusion/src/logical_plan/mod.rs b/datafusion/src/logical_plan/mod.rs index 3f0c7d253c93..8569b35196df 100644 --- a/datafusion/src/logical_plan/mod.rs +++ b/datafusion/src/logical_plan/mod.rs @@ -36,17 +36,17 @@ pub use builder::{ pub use dfschema::{DFField, DFSchema, DFSchemaRef, ToDFSchema}; pub use display::display_schema; pub use expr::{ - abs, acos, and, array, ascii, asin, atan, avg, binary_expr, bit_length, btrim, case, - ceil, character_length, chr, col, columnize_expr, combine_filters, concat, concat_ws, - cos, count, count_distinct, create_udaf, create_udf, date_part, date_trunc, digest, - exp, exprlist_to_fields, floor, in_list, initcap, left, length, lit, - lit_timestamp_nano, ln, log10, log2, lower, lpad, ltrim, max, md5, min, - normalize_col, normalize_cols, now, octet_length, or, random, regexp_match, - regexp_replace, repeat, replace, replace_col, reverse, right, round, rpad, rtrim, - sha224, sha256, sha384, sha512, signum, sin, split_part, sqrt, starts_with, strpos, - substr, sum, tan, to_hex, translate, trim, trunc, unnormalize_col, unnormalize_cols, - upper, when, Column, Expr, ExprRewriter, ExpressionVisitor, Literal, Recursion, - RewriteRecursion, + abs, acos, and, approx_distinct, array, ascii, asin, atan, avg, binary_expr, + bit_length, btrim, case, ceil, character_length, chr, col, columnize_expr, + combine_filters, concat, concat_ws, cos, count, count_distinct, create_udaf, + create_udf, date_part, date_trunc, digest, exp, exprlist_to_fields, floor, in_list, + initcap, left, length, lit, lit_timestamp_nano, ln, log10, log2, lower, lpad, ltrim, + max, md5, min, normalize_col, normalize_cols, now, octet_length, or, 
random, + regexp_match, regexp_replace, repeat, replace, replace_col, reverse, right, round, + rpad, rtrim, sha224, sha256, sha384, sha512, signum, sin, split_part, sqrt, + starts_with, strpos, substr, sum, tan, to_hex, translate, trim, trunc, + unnormalize_col, unnormalize_cols, upper, when, Column, Expr, ExprRewriter, + ExpressionVisitor, Literal, Recursion, RewriteRecursion, }; pub use extension::UserDefinedLogicalNode; pub use operators::Operator; diff --git a/datafusion/src/physical_plan/aggregates.rs b/datafusion/src/physical_plan/aggregates.rs index aad43cc0f8b9..eb3f6ca409a4 100644 --- a/datafusion/src/physical_plan/aggregates.rs +++ b/datafusion/src/physical_plan/aggregates.rs @@ -59,6 +59,8 @@ pub enum AggregateFunction { Max, /// avg Avg, + /// Approximate aggregate function + ApproxDistinct, } impl fmt::Display for AggregateFunction { @@ -77,6 +79,7 @@ impl FromStr for AggregateFunction { "count" => AggregateFunction::Count, "avg" => AggregateFunction::Avg, "sum" => AggregateFunction::Sum, + "approx_distinct" => AggregateFunction::ApproxDistinct, _ => { return Err(DataFusionError::Plan(format!( "There is no built-in function named {}", @@ -96,7 +99,9 @@ pub fn return_type(fun: &AggregateFunction, arg_types: &[DataType]) -> Result Ok(DataType::UInt64), + AggregateFunction::Count | AggregateFunction::ApproxDistinct => { + Ok(DataType::UInt64) + } AggregateFunction::Max | AggregateFunction::Min => Ok(arg_types[0].clone()), AggregateFunction::Sum => sum_return_type(&arg_types[0]), AggregateFunction::Avg => avg_return_type(&arg_types[0]), @@ -149,6 +154,9 @@ pub fn create_aggregate_expr( "SUM(DISTINCT) aggregations are not available".to_string(), )); } + (AggregateFunction::ApproxDistinct, _) => Arc::new( + expressions::ApproxDistinct::new(arg, name, arg_types[0].clone()), + ), (AggregateFunction::Min, _) => { Arc::new(expressions::Min::new(arg, name, return_type)) } @@ -194,7 +202,9 @@ static DATES: &[DataType] = &[DataType::Date32, DataType::Date64]; pub 
fn signature(fun: &AggregateFunction) -> Signature { // note: the physical expression must accept the type returned by this function or the execution panics. match fun { - AggregateFunction::Count => Signature::any(1, Volatility::Immutable), + AggregateFunction::Count | AggregateFunction::ApproxDistinct => { + Signature::any(1, Volatility::Immutable) + } AggregateFunction::Min | AggregateFunction::Max => { let valid = STRINGS .iter() diff --git a/datafusion/src/physical_plan/expressions/approx_distinct.rs b/datafusion/src/physical_plan/expressions/approx_distinct.rs new file mode 100644 index 000000000000..7a19b6c9d16f --- /dev/null +++ b/datafusion/src/physical_plan/expressions/approx_distinct.rs @@ -0,0 +1,321 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Defines physical expressions that can be evaluated at runtime during query execution + +use super::format_state_name; +use crate::error::{DataFusionError, Result}; +use crate::physical_plan::{ + hyperloglog::HyperLogLog, Accumulator, AggregateExpr, PhysicalExpr, +}; +use crate::scalar::ScalarValue; +use arrow::array::{ + ArrayRef, BinaryArray, BinaryOffsetSizeTrait, GenericBinaryArray, GenericStringArray, + PrimitiveArray, StringOffsetSizeTrait, +}; +use arrow::datatypes::{ + ArrowPrimitiveType, DataType, Field, Int16Type, Int32Type, Int64Type, Int8Type, + UInt16Type, UInt32Type, UInt64Type, UInt8Type, +}; +use std::any::type_name; +use std::any::Any; +use std::convert::TryFrom; +use std::convert::TryInto; +use std::hash::Hash; +use std::marker::PhantomData; +use std::sync::Arc; + +/// APPROX_DISTINCT aggregate expression +#[derive(Debug)] +pub struct ApproxDistinct { + name: String, + input_data_type: DataType, + expr: Arc<dyn PhysicalExpr>, +} + +impl ApproxDistinct { + /// Create a new ApproxDistinct aggregate function.
+ pub fn new( + expr: Arc<dyn PhysicalExpr>, + name: impl Into<String>, + input_data_type: DataType, + ) -> Self { + Self { + name: name.into(), + input_data_type, + expr, + } + } +} + +impl AggregateExpr for ApproxDistinct { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn field(&self) -> Result<Field> { + Ok(Field::new(&self.name, DataType::UInt64, false)) + } + + fn state_fields(&self) -> Result<Vec<Field>> { + Ok(vec![Field::new( + &format_state_name(&self.name, "hll_registers"), + DataType::Binary, + false, + )]) + } + + fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> { + vec![self.expr.clone()] + } + + fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> { + let accumulator: Box<dyn Accumulator> = match &self.input_data_type { + // TODO u8, i8, u16, i16 should really use a bitmap rather than HLL + // TODO support for boolean (trivial case) + DataType::UInt8 => Box::new(NumericHLLAccumulator::<UInt8Type>::new()), + DataType::UInt16 => Box::new(NumericHLLAccumulator::<UInt16Type>::new()), + DataType::UInt32 => Box::new(NumericHLLAccumulator::<UInt32Type>::new()), + DataType::UInt64 => Box::new(NumericHLLAccumulator::<UInt64Type>::new()), + DataType::Int8 => Box::new(NumericHLLAccumulator::<Int8Type>::new()), + DataType::Int16 => Box::new(NumericHLLAccumulator::<Int16Type>::new()), + DataType::Int32 => Box::new(NumericHLLAccumulator::<Int32Type>::new()), + DataType::Int64 => Box::new(NumericHLLAccumulator::<Int64Type>::new()), + DataType::Utf8 => Box::new(StringHLLAccumulator::<i32>::new()), + DataType::LargeUtf8 => Box::new(StringHLLAccumulator::<i64>::new()), + DataType::Binary => Box::new(BinaryHLLAccumulator::<i32>::new()), + DataType::LargeBinary => Box::new(BinaryHLLAccumulator::<i64>::new()), + other => { + return Err(DataFusionError::NotImplemented(format!( + "Support for approx_distinct for data type {} is not implemented", + other + ))) + } + }; + Ok(accumulator) + } + + fn name(&self) -> &str { + &self.name + } +} + +#[derive(Debug)] +struct BinaryHLLAccumulator<T> +where + T: BinaryOffsetSizeTrait, +{ + hll: HyperLogLog<Vec<u8>>, + phantom_data: PhantomData<T>, +} + +impl<T> BinaryHLLAccumulator<T>
+where + T: BinaryOffsetSizeTrait, +{ + /// Creates a new, empty approx_distinct accumulator + pub fn new() -> Self { + Self { + hll: HyperLogLog::new(), + phantom_data: PhantomData, + } + } +} + +#[derive(Debug)] +struct StringHLLAccumulator<T> +where + T: StringOffsetSizeTrait, +{ + hll: HyperLogLog<String>, + phantom_data: PhantomData<T>, +} + +impl<T> StringHLLAccumulator<T> +where + T: StringOffsetSizeTrait, +{ + /// Creates a new, empty approx_distinct accumulator + pub fn new() -> Self { + Self { + hll: HyperLogLog::new(), + phantom_data: PhantomData, + } + } +} + +#[derive(Debug)] +struct NumericHLLAccumulator<T> +where + T: ArrowPrimitiveType, + T::Native: Hash, +{ + hll: HyperLogLog<T::Native>, +} + +impl<T> NumericHLLAccumulator<T> +where + T: ArrowPrimitiveType, + T::Native: Hash, +{ + /// Creates a new, empty approx_distinct accumulator + pub fn new() -> Self { + Self { + hll: HyperLogLog::new(), + } + } +} + +impl<T: Hash> From<&HyperLogLog<T>> for ScalarValue { + fn from(v: &HyperLogLog<T>) -> ScalarValue { + let values = v.as_ref().to_vec(); + ScalarValue::Binary(Some(values)) + } +} + +impl<T: Hash> TryFrom<&[u8]> for HyperLogLog<T> { + type Error = DataFusionError; + fn try_from(v: &[u8]) -> Result<HyperLogLog<T>> { + let arr: [u8; 16384] = v.try_into().map_err(|_| { + DataFusionError::Internal( + "Impossibly got invalid binary array from states".into(), + ) + })?; + Ok(HyperLogLog::<T>::new_with_registers(arr)) + } +} + +impl<T: Hash> TryFrom<&ScalarValue> for HyperLogLog<T> { + type Error = DataFusionError; + fn try_from(v: &ScalarValue) -> Result<HyperLogLog<T>> { + if let ScalarValue::Binary(Some(slice)) = v { + slice.as_slice().try_into() + } else { + Err(DataFusionError::Internal( + "Impossibly got invalid scalar value while converting to HyperLogLog" + .into(), + )) + } + } +} + +macro_rules!
default_accumulator_impl { + () => { + fn update(&mut self, values: &[ScalarValue]) -> Result<()> { + self.update_batch( + values + .iter() + .map(|s| s.to_array() as ArrayRef) + .collect::<Vec<_>>() + .as_slice(), + ) + } + + fn merge(&mut self, states: &[ScalarValue]) -> Result<()> { + assert_eq!(1, states.len(), "expect only 1 element in the states"); + let other = HyperLogLog::try_from(&states[0])?; + self.hll.merge(&other); + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + assert_eq!(1, states.len(), "expect only 1 element in the states"); + let binary_array = states[0].as_any().downcast_ref::<BinaryArray>().unwrap(); + for v in binary_array.iter() { + let v = v.ok_or_else(|| { + DataFusionError::Internal( + "Impossibly got empty binary array from states".into(), + ) + })?; + let other = v.try_into()?; + self.hll.merge(&other); + } + Ok(()) + } + + fn state(&self) -> Result<Vec<ScalarValue>> { + let value = ScalarValue::from(&self.hll); + Ok(vec![value]) + } + + fn evaluate(&self) -> Result<ScalarValue> { + Ok(ScalarValue::UInt64(Some(self.hll.count() as u64))) + } + }; +} + +macro_rules! downcast_value { + ($Value: expr, $Type: ident, $T: tt) => {{ + $Value[0] + .as_any() + .downcast_ref::<$Type<T>>() + .ok_or_else(|| { + DataFusionError::Internal(format!( + "could not cast value to {}", + type_name::<$Type<T>>() + )) + })?
+ }}; +} + +impl<T> Accumulator for BinaryHLLAccumulator<T> +where + T: BinaryOffsetSizeTrait, +{ + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let array: &GenericBinaryArray<T> = + downcast_value!(values, GenericBinaryArray, T); + // flatten so that null entries are skipped + self.hll + .extend(array.into_iter().flatten().map(|v| v.to_vec())); + Ok(()) + } + + default_accumulator_impl!(); +} + +impl<T> Accumulator for StringHLLAccumulator<T> +where + T: StringOffsetSizeTrait, +{ + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let array: &GenericStringArray<T> = + downcast_value!(values, GenericStringArray, T); + // flatten so that null entries are skipped + self.hll + .extend(array.into_iter().flatten().map(|i| i.to_string())); + Ok(()) + } + + default_accumulator_impl!(); +} + +impl<T> Accumulator for NumericHLLAccumulator<T> +where + T: ArrowPrimitiveType + std::fmt::Debug, + T::Native: Hash, +{ + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let array: &PrimitiveArray<T> = downcast_value!(values, PrimitiveArray, T); + // flatten so that null entries are skipped + self.hll.extend(array.into_iter().flatten()); + Ok(()) + } + + default_accumulator_impl!(); +} diff --git a/datafusion/src/physical_plan/expressions/mod.rs b/datafusion/src/physical_plan/expressions/mod.rs index 9f7a6cc6b5fb..4ca00367e7fe 100644 --- a/datafusion/src/physical_plan/expressions/mod.rs +++ b/datafusion/src/physical_plan/expressions/mod.rs @@ -25,6 +25,7 @@ use crate::physical_plan::PhysicalExpr; use arrow::compute::kernels::sort::{SortColumn, SortOptions}; use arrow::record_batch::RecordBatch; +mod approx_distinct; mod average; #[macro_use] mod binary; @@ -55,6 +56,7 @@ pub mod helpers { pub use super::min_max::{max, min}; } +pub use approx_distinct::ApproxDistinct; pub use average::{avg_return_type, Avg, AvgAccumulator}; pub use binary::{binary, binary_operator_data_type, BinaryExpr}; pub use case::{case, CaseExpr}; diff --git
a/datafusion/src/physical_plan/hyperloglog/mod.rs b/datafusion/src/physical_plan/hyperloglog/mod.rs index 25e521311079..3b91d3039e2c 100644 --- a/datafusion/src/physical_plan/hyperloglog/mod.rs +++ b/datafusion/src/physical_plan/hyperloglog/mod.rs @@ -34,9 +34,6 @@ //! //! This module also borrows some code structure from [pdatastructs.rs](https://github.com/crepererum/pdatastructs.rs/blob/3997ed50f6b6871c9e53c4c5e0f48f431405fc63/src/hyperloglog.rs). -// TODO remove this when hooked up with the rest -#![allow(dead_code)] - use ahash::{AHasher, RandomState}; use std::hash::{BuildHasher, Hash, Hasher}; use std::marker::PhantomData; @@ -58,7 +55,12 @@ where phantom: PhantomData, } -/// fixed seed for the hashing so that values are consistent across runs +/// Fixed seed for the hashing so that values are consistent across runs +/// +/// Note that once serialized HLL register binaries are shared across a +/// cluster, this SEED will have to be consistent across all parties, +/// otherwise merged registers would be corrupted. Ideally the seed should +/// be part of the serialized form (or stay unchanged across versions). const SEED: RandomState = RandomState::with_seeds( 0x885f6cab121d01a3_u64, 0x71e4379f2976ad8f_u64, @@ -73,6 +75,13 @@ where /// Creates a new, empty HyperLogLog. pub fn new() -> Self { let registers = [0; NUM_REGISTERS]; + Self::new_with_registers(registers) + } + + /// Creates a HyperLogLog from already populated registers. + /// Note that this method should not be invoked in an untrusted environment + /// because the internal structure of the registers is not examined.
+ pub(crate) fn new_with_registers(registers: [u8; NUM_REGISTERS]) -> Self { Self { registers, phantom: PhantomData, @@ -109,6 +118,19 @@ where histogram } + /// Merge the other [`HyperLogLog`] into this one + pub fn merge(&mut self, other: &HyperLogLog) { + assert!( + self.registers.len() == other.registers.len(), + "unexpected got unequal register size, expect {}, got {}", + self.registers.len(), + other.registers.len() + ); + for i in 0..self.registers.len() { + self.registers[i] = self.registers[i].max(other.registers[i]); + } + } + /// Guess the number of unique elements seen by the HyperLogLog. pub fn count(&self) -> usize { let histogram = self.get_histogram(); @@ -171,6 +193,15 @@ fn hll_tau(x: f64) -> f64 { } } +impl AsRef<[u8]> for HyperLogLog +where + T: Hash + ?Sized, +{ + fn as_ref(&self) -> &[u8] { + &self.registers + } +} + impl Extend for HyperLogLog where T: Hash, @@ -300,4 +331,32 @@ mod tests { hll.extend((0..1000).map(|i| i.to_string())); compare_with_delta(hll.count(), 1000); } + + #[test] + fn test_empty_merge() { + let mut hll = HyperLogLog::::new(); + hll.merge(&HyperLogLog::::new()); + assert_eq!(hll.count(), 0); + } + + #[test] + fn test_merge_overlapped() { + let mut hll = HyperLogLog::::new(); + hll.extend((0..1000).map(|i| i.to_string())); + + let mut other = HyperLogLog::::new(); + other.extend((0..1000).map(|i| i.to_string())); + + hll.merge(&other); + compare_with_delta(hll.count(), 1000); + } + + #[test] + fn test_repetition() { + let mut hll = HyperLogLog::::new(); + for i in 0..1_000_000 { + hll.add(&(i % 1000)); + } + compare_with_delta(hll.count(), 1000); + } } diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs index 801451f81d86..e82254203af0 100644 --- a/datafusion/tests/sql.rs +++ b/datafusion/tests/sql.rs @@ -983,6 +983,23 @@ async fn csv_query_count() -> Result<()> { Ok(()) } +#[tokio::test] +async fn csv_query_approx_count() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut 
ctx)?; + let sql = "SELECT approx_distinct(c9) count_c9, approx_distinct(cast(c9 as varchar)) count_c9_str FROM aggregate_test_100"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+----------+--------------+", + "| count_c9 | count_c9_str |", + "+----------+--------------+", + "| 100 | 99 |", + "+----------+--------------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + /// for window functions without order by the first, last, and nth function call does not make sense #[tokio::test] async fn csv_query_window_with_empty_over() -> Result<()> {