Skip to content

Commit

Permalink
Support aggregate
Browse files Browse the repository at this point in the history
  • Loading branch information
lewiszlw committed Feb 19, 2024
1 parent 4968187 commit 8fc6758
Show file tree
Hide file tree
Showing 17 changed files with 398 additions and 72 deletions.
27 changes: 25 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion bustubx/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,5 @@ derive-new = "0.5.9"
tracing = "0.1"
thiserror = "1.0.56"
tempfile = "3"
derive-with = "0.5.0"
derive-with = "0.5.0"
strum = { version = "0.26", features = ["derive"]}
2 changes: 1 addition & 1 deletion bustubx/src/catalog/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ impl Schema {
.enumerate()
.find(|(_, col)| match (relation, &col.relation) {
(Some(rel), Some(col_rel)) => rel.resolved_eq(col_rel) && name == &col.name,
(Some(rel), None) => false,
(Some(_), None) => false,
(None, Some(_)) | (None, None) => name == &col.name,
})
.ok_or_else(|| BustubxError::Plan(format!("Unable to get column named \"{name}\"")))?;
Expand Down
10 changes: 10 additions & 0 deletions bustubx/src/common/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,16 @@ impl ScalarValue {
self, data_type
))),
},
DataType::Float64 => match self {
ScalarValue::Int8(v) => Ok(ScalarValue::Float64(v.map(|v| v as f64))),
ScalarValue::Int32(v) => Ok(ScalarValue::Float64(v.map(|v| v as f64))),
ScalarValue::Int64(v) => Ok(ScalarValue::Float64(v.map(|v| v as f64))),
ScalarValue::Float64(v) => Ok(ScalarValue::Float64(v.map(|v| v))),
_ => Err(BustubxError::NotSupport(format!(
"Failed to cast {} to {} type",
self, data_type
))),
},
_ => Err(BustubxError::NotSupport(format!(
"Not support cast to {} type",
data_type
Expand Down
75 changes: 70 additions & 5 deletions bustubx/src/execution/physical_plan/aggregate.rs
Original file line number Diff line number Diff line change
@@ -1,24 +1,89 @@
use crate::catalog::SchemaRef;
use crate::common::ScalarValue;
use crate::execution::physical_plan::PhysicalPlan;
use crate::execution::{ExecutionContext, VolcanoExecutor};
use crate::expression::Expr;
use crate::{BustubxResult, Tuple};
use crate::expression::{Accumulator, Expr, ExprTrait};
use crate::{BustubxError, BustubxResult, Tuple};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

#[derive(Debug)]
pub struct PhysicalAggregate {
/// The incoming physical plan
pub input: Arc<PhysicalPlan>,
/// Grouping expressions
pub group_expr: Vec<Expr>,
pub group_exprs: Vec<Expr>,
/// Aggregate expressions
pub aggr_expr: Vec<Expr>,
pub aggr_exprs: Vec<Expr>,
/// The schema description of the aggregate output
pub schema: SchemaRef,

// TODO tmp solution
pub output_count: AtomicUsize,
}

impl PhysicalAggregate {
pub fn new(
input: Arc<PhysicalPlan>,
group_exprs: Vec<Expr>,
aggr_exprs: Vec<Expr>,
schema: SchemaRef,
) -> Self {
Self {
input,
group_exprs,
aggr_exprs,
schema,
output_count: AtomicUsize::new(0),
}
}
}

impl VolcanoExecutor for PhysicalAggregate {
fn init(&self, context: &mut ExecutionContext) -> BustubxResult<()> {
self.input.init(context)?;
self.output_count.store(0, Ordering::SeqCst);
Ok(())
}

fn next(&self, context: &mut ExecutionContext) -> BustubxResult<Option<Tuple>> {
todo!()
if self.output_count.load(Ordering::SeqCst) > 0 {
return Ok(None);
}

// TODO support group
let mut accumulators = self
.aggr_exprs
.iter()
.map(|expr| {
if let Expr::AggregateFunction(aggr) = expr {
Ok(aggr.func_kind.create_accumulator())
} else {
Err(BustubxError::Execution(format!(
"aggr expr is not AggregateFunction instead of {}",
expr
)))
}
})
.collect::<BustubxResult<Vec<Box<dyn Accumulator>>>>()?;

loop {
if let Some(tuple) = self.input.next(context)? {
for (idx, acc) in accumulators.iter_mut().enumerate() {
acc.update_value(&self.aggr_exprs[idx].evaluate(&tuple)?)?;
}
} else {
break;
}
}

let values = accumulators
.iter()
.map(|acc| acc.evaluate())
.collect::<BustubxResult<Vec<ScalarValue>>>()?;

self.output_count.fetch_add(1, Ordering::SeqCst);
Ok(Some(Tuple::new(self.schema.clone(), values)))
}

fn output_schema(&self) -> SchemaRef {
Expand Down
7 changes: 7 additions & 0 deletions bustubx/src/execution/physical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ mod seq_scan;
mod sort;
mod values;

pub use aggregate::PhysicalAggregate;
pub use create_index::PhysicalCreateIndex;
pub use create_table::PhysicalCreateTable;
pub use empty::PhysicalEmpty;
Expand Down Expand Up @@ -43,6 +44,7 @@ pub enum PhysicalPlan {
Values(PhysicalValues),
NestedLoopJoin(PhysicalNestedLoopJoin),
Sort(PhysicalSort),
Aggregate(PhysicalAggregate),
}

impl PhysicalPlan {
Expand All @@ -58,6 +60,7 @@ impl PhysicalPlan {
..
}) => vec![left_input, right_input],
PhysicalPlan::Sort(PhysicalSort { input, .. }) => vec![input],
PhysicalPlan::Aggregate(PhysicalAggregate { input, .. }) => vec![input],
PhysicalPlan::Empty(_)
| PhysicalPlan::CreateTable(_)
| PhysicalPlan::CreateIndex(_)
Expand All @@ -81,6 +84,7 @@ impl VolcanoExecutor for PhysicalPlan {
PhysicalPlan::Limit(op) => op.init(context),
PhysicalPlan::NestedLoopJoin(op) => op.init(context),
PhysicalPlan::Sort(op) => op.init(context),
PhysicalPlan::Aggregate(op) => op.init(context),
}
}

Expand All @@ -97,6 +101,7 @@ impl VolcanoExecutor for PhysicalPlan {
PhysicalPlan::Limit(op) => op.next(context),
PhysicalPlan::NestedLoopJoin(op) => op.next(context),
PhysicalPlan::Sort(op) => op.next(context),
PhysicalPlan::Aggregate(op) => op.next(context),
}
}

Expand All @@ -113,6 +118,7 @@ impl VolcanoExecutor for PhysicalPlan {
Self::Limit(op) => op.output_schema(),
Self::NestedLoopJoin(op) => op.output_schema(),
Self::Sort(op) => op.output_schema(),
Self::Aggregate(op) => op.output_schema(),
}
}
}
Expand All @@ -131,6 +137,7 @@ impl std::fmt::Display for PhysicalPlan {
Self::Limit(op) => write!(f, "{op}"),
Self::NestedLoopJoin(op) => write!(f, "{op}"),
Self::Sort(op) => write!(f, "{op}"),
Self::Aggregate(op) => write!(f, "{op}"),
}
}
}
49 changes: 49 additions & 0 deletions bustubx/src/expression/aggr/avg.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
use crate::catalog::DataType;
use crate::common::ScalarValue;
use crate::expression::Accumulator;
use crate::{BustubxError, BustubxResult};
use std::fs::read;

Check warning on line 5 in bustubx/src/expression/aggr/avg.rs

View workflow job for this annotation

GitHub Actions / Test Suite

unused import: `std::fs::read`

/// Running state for an average aggregate: the sum of the values seen so far
/// and how many values contributed to that sum.
#[derive(Debug, Default)]
pub struct AvgAccumulator {
    /// Sum of the accumulated values; `None` until the first value is added.
    sum: Option<f64>,
    /// Number of values that have been added to `sum`.
    count: u64,
}

impl AvgAccumulator {
    /// Creates an empty accumulator (`sum` is `None`, `count` is `0`).
    ///
    /// Equivalent to [`AvgAccumulator::default`].
    pub fn new() -> Self {
        Self::default()
    }
}

impl Accumulator for AvgAccumulator {
    /// Folds one input value into the running sum and count.
    ///
    /// Null values are ignored. Any other value is cast to `Float64` first;
    /// a cast failure is propagated, and a cast that does not yield a
    /// non-null `Float64` is reported as an internal error.
    fn update_value(&mut self, value: &ScalarValue) -> BustubxResult<()> {
        // Nulls contribute neither to the sum nor to the count.
        if value.is_null() {
            return Ok(());
        }

        let addend = match value.cast_to(&DataType::Float64)? {
            ScalarValue::Float64(Some(v)) => v,
            _ => {
                return Err(BustubxError::Internal(format!(
                    "Failed to cast value {} to float64",
                    value
                )))
            }
        };

        // The first value seeds the sum; later values are added to it.
        self.sum = Some(self.sum.map_or(addend, |total| total + addend));
        self.count += 1;
        Ok(())
    }

    /// Returns the average of the accumulated values as a `Float64` scalar.
    ///
    /// The scalar holds `None` when no value has been accumulated yet.
    fn evaluate(&self) -> BustubxResult<ScalarValue> {
        let average = self.sum.map(|total| total / self.count as f64);
        Ok(ScalarValue::Float64(average))
    }
}
18 changes: 14 additions & 4 deletions bustubx/src/expression/aggr/mod.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
mod avg;
mod count;

pub use avg::AvgAccumulator;
pub use count::CountAccumulator;

use crate::catalog::{Column, DataType, Schema};
use crate::common::ScalarValue;
use crate::expression::{Expr, ExprTrait};
use crate::{BustubxError, BustubxResult, Tuple};
use std::fmt::Debug;
use strum::{EnumIter, IntoEnumIterator};

#[derive(Clone, PartialEq, Eq, Debug)]
pub struct AggregateFunction {
Expand All @@ -22,6 +25,7 @@ impl ExprTrait for AggregateFunction {
fn data_type(&self, _input_schema: &Schema) -> BustubxResult<DataType> {
match self.func_kind {
AggregateFunctionKind::Count => Ok(DataType::Int64),
AggregateFunctionKind::Avg => Ok(DataType::Float64),
}
}

Expand All @@ -31,10 +35,10 @@ impl ExprTrait for AggregateFunction {

fn evaluate(&self, tuple: &Tuple) -> BustubxResult<ScalarValue> {
match self.func_kind {
AggregateFunctionKind::Count => {
AggregateFunctionKind::Count | AggregateFunctionKind::Avg => {
let expr = self.args.get(0).ok_or(BustubxError::Internal(format!(
"COUNT function should have one arg instead of {:?}",
self.args
"aggregate function {} should have one arg instead of {:?}",
self.func_kind, self.args
)))?;
expr.evaluate(tuple)
}
Expand All @@ -56,17 +60,23 @@ impl std::fmt::Display for AggregateFunction {
}
}

#[derive(Clone, PartialEq, Eq, Debug)]
#[derive(Clone, PartialEq, Eq, Debug, EnumIter)]
pub enum AggregateFunctionKind {
Count,
Avg,
}

impl AggregateFunctionKind {
    /// Creates a fresh boxed accumulator matching this aggregate function kind.
    pub fn create_accumulator(&self) -> Box<dyn Accumulator> {
        match self {
            Self::Count => Box::new(CountAccumulator::new()),
            Self::Avg => Box::new(AvgAccumulator::new()),
        }
    }

    /// Looks up an aggregate function kind by name.
    ///
    /// Each variant's `to_string()` output is compared with `name`, ignoring
    /// ASCII case; the first matching variant is returned, or `None` if no
    /// variant matches.
    pub fn find(name: &str) -> Option<Self> {
        Self::iter().find(|candidate| name.eq_ignore_ascii_case(&candidate.to_string()))
    }
}

impl std::fmt::Display for AggregateFunctionKind {
Expand Down
2 changes: 2 additions & 0 deletions bustubx/src/expression/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@ mod binary;
mod cast;
mod column;
mod literal;
mod util;

pub use aggr::*;
pub use alias::Alias;
pub use binary::{BinaryExpr, BinaryOp};

Check warning on line 11 in bustubx/src/expression/mod.rs

View workflow job for this annotation

GitHub Actions / Test Suite

unused import: `BinaryOp`
pub use cast::Cast;
pub use column::ColumnExpr;
pub use literal::Literal;
pub use util::*;

use crate::catalog::Schema;
use crate::catalog::{Column, DataType};
Expand Down
27 changes: 27 additions & 0 deletions bustubx/src/expression/util.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
use crate::catalog::SchemaRef;
use crate::expression::{Alias, Cast, ColumnExpr, Expr};
use crate::BustubxResult;

/// Rewrites an expression into a column reference against `input_schema`.
///
/// * A column expression is returned as a clone of itself.
/// * Alias and cast expressions keep their wrapper (name / target type) and
///   have only their inner expression rewritten recursively.
/// * Any other expression is replaced by a column whose name is the
///   expression's display string; that name is looked up in `input_schema`
///   (without a relation qualifier) and the matching column's relation is
///   carried over.
///
/// # Errors
///
/// Propagates any error from the schema lookups.
pub fn columnize_expr(e: &Expr, input_schema: &SchemaRef) -> BustubxResult<Expr> {
    match e {
        Expr::Column(_) => Ok(e.clone()),
        Expr::Alias(alias) => {
            let inner = columnize_expr(&alias.expr, input_schema)?;
            Ok(Expr::Alias(Alias {
                expr: Box::new(inner),
                name: alias.name.clone(),
            }))
        }
        Expr::Cast(cast) => {
            let inner = columnize_expr(&cast.expr, input_schema)?;
            Ok(Expr::Cast(Cast {
                expr: Box::new(inner),
                data_type: cast.data_type.clone(),
            }))
        }
        other => {
            // The expression's display string doubles as the column name.
            let name = other.to_string();
            let position = input_schema.index_of(None, name.as_str())?;
            let column = input_schema.column_with_index(position)?;
            let relation = column.relation.clone();
            Ok(Expr::Column(ColumnExpr { relation, name }))
        }
    }
}
Loading

0 comments on commit 8fc6758

Please sign in to comment.