From 21ef43591f6c6bd0218ebe4bbfd8889d75931e7f Mon Sep 17 00:00:00 2001 From: Samyak Sarnayak Date: Thu, 26 Jan 2023 23:58:43 +0530 Subject: [PATCH 01/26] wip(codegen,ic): groupby and aggregates --- src/codegen.rs | 234 +++++++++++++++++++++++++++++++++++++----------- src/expr/agg.rs | 7 ++ src/expr/mod.rs | 93 +------------------ src/ic.rs | 16 +++- src/vm.rs | 6 ++ 5 files changed, 212 insertions(+), 144 deletions(-) create mode 100644 src/expr/agg.rs diff --git a/src/codegen.rs b/src/codegen.rs index 5864ba3..aad2646 100644 --- a/src/codegen.rs +++ b/src/codegen.rs @@ -7,7 +7,7 @@ use sqlparser::{ use std::{error::Error, fmt::Display}; use crate::{ - expr::{Expr, ExprError}, + expr::{BinOp, Expr, ExprError, UnOp}, ic::{Instruction, IntermediateCode}, identifier::IdentifierError, parser::parse, @@ -54,11 +54,40 @@ pub fn codegen_str(code: &str) -> Result, ParserOrCodegenE .collect::, CodegenError>>()?) } +/// Context passed around to any func that needs codegen. +struct CodegenContext { + pub instrs: Vec, + current_reg: RegisterIndex, +} + +impl CodegenContext { + pub fn new() -> Self { + Self { + instrs: Vec::new(), + current_reg: RegisterIndex::default(), + } + } + + pub fn get_and_increment_reg(&mut self) -> RegisterIndex { + let reg = self.current_reg; + self.current_reg = self.current_reg.next_index(); + reg + } + + pub fn last_used_reg(&self) -> RegisterIndex { + self.current_reg + } +} + +impl Default for CodegenContext { + fn default() -> Self { + Self::new() + } +} + /// Generates intermediate code from the AST. pub fn codegen_ast(ast: &Statement) -> Result { - let mut instrs = Vec::::new(); - - let mut current_reg = RegisterIndex::default(); + let mut ctx = CodegenContext::default(); match ast { Statement::CreateTable { @@ -87,35 +116,33 @@ pub fn codegen_ast(ast: &Statement) -> Result { collation: _, on_commit: _, } => { - let table_reg_index = current_reg; - instrs.push(Instruction::Empty { + let table_reg_index = ctx.get_and_increment_reg(); + ctx.instrs.push(Instruction::Empty { index: table_reg_index, }); - current_reg = current_reg.next_index(); - let col_reg_index = current_reg; - current_reg = current_reg.next_index(); + let col_reg_index = ctx.get_and_increment_reg(); for col in columns { - instrs.push(Instruction::ColumnDef { + ctx.instrs.push(Instruction::ColumnDef { index: col_reg_index, name: col.name.value.as_str().into(), data_type: col.data_type.clone(), }); for option in col.options.iter() { - instrs.push(Instruction::AddColumnOption { + ctx.instrs.push(Instruction::AddColumnOption { index: col_reg_index, option: option.clone(), }); } - instrs.push(Instruction::AddColumn { + ctx.instrs.push(Instruction::AddColumn { table_reg_index, col_index: col_reg_index, }); } - instrs.push(Instruction::NewTable { + ctx.instrs.push(Instruction::NewTable { index: table_reg_index, name: name.0.clone().try_into()?, exists_ok: *if_not_exists, @@ -134,22 +161,20 @@ pub fn codegen_ast(ast: &Statement) -> Result { table: _, on: _, } => { - let table_reg_index = current_reg; - instrs.push(Instruction::Source { + let table_reg_index = ctx.get_and_increment_reg(); + ctx.instrs.push(Instruction::Source { index: table_reg_index, name: table_name.0.clone().try_into()?, }); - current_reg = current_reg.next_index(); - let insert_reg_index = current_reg; - instrs.push(Instruction::InsertDef { + let insert_reg_index = ctx.get_and_increment_reg(); + ctx.instrs.push(Instruction::InsertDef { table_reg_index, index: insert_reg_index, }); - current_reg = current_reg.next_index(); for col in columns { - instrs.push(Instruction::ColumnInsertDef { + ctx.instrs.push(Instruction::ColumnInsertDef { insert_index: insert_reg_index, col_name: col.value.as_str().into(), }) @@ -162,18 +187,18 @@ pub fn codegen_ast(ast: &Statement) -> Result { .. } => { for row in values.0.clone() { - let row_reg = current_reg; - current_reg = current_reg.next_index(); + let row_reg = ctx.get_and_increment_reg(); - instrs.push(Instruction::RowDef { + ctx.instrs.push(Instruction::RowDef { insert_index: insert_reg_index, row_index: row_reg, }); for value in row { - instrs.push(Instruction::AddValue { + let value = codegen_expr(value, &mut ctx)?; + ctx.instrs.push(Instruction::AddValue { row_index: row_reg, - expr: value.try_into()?, + expr: value, }); } } @@ -185,7 +210,7 @@ pub fn codegen_ast(ast: &Statement) -> Result { )), }?; - instrs.push(Instruction::Insert { + ctx.instrs.push(Instruction::Insert { index: insert_reg_index, }); @@ -193,8 +218,7 @@ pub fn codegen_ast(ast: &Statement) -> Result { } Statement::Query(query) => { // TODO: support CTEs - let mut table_reg_index = current_reg; - current_reg = current_reg.next_index(); + let mut table_reg_index = ctx.get_and_increment_reg(); match &query.body { SetExpr::Select(select) => { @@ -215,7 +239,7 @@ pub fn codegen_ast(ast: &Statement) -> Result { alias: _, args: _, with_hints: _, - } => instrs.push(Instruction::Source { + } => ctx.instrs.push(Instruction::Source { index: table_reg_index, name: name.0.clone().try_into()?, }), @@ -249,7 +273,7 @@ pub fn codegen_ast(ast: &Statement) -> Result { } } } - &[] => instrs.push(Instruction::NonExistent { + &[] => ctx.instrs.push(Instruction::NonExistent { index: table_reg_index, }), _ => { @@ -262,43 +286,47 @@ pub fn codegen_ast(ast: &Statement) -> Result { } if let Some(expr) = select.selection.clone() { - instrs.push(Instruction::Filter { + let expr = codegen_expr(expr, &mut ctx)?; + ctx.instrs.push(Instruction::Filter { index: table_reg_index, - expr: expr.try_into()?, + expr, }) } for group_by in select.group_by.clone() { - instrs.push(Instruction::GroupBy { + let group_by = codegen_expr(group_by, &mut ctx)?; + ctx.instrs.push(Instruction::GroupBy { index: table_reg_index, - expr: group_by.try_into()?, + expr: group_by, }); } if let Some(expr) = select.having.clone() { - instrs.push(Instruction::Filter { + let expr = codegen_expr(expr, &mut ctx)?; + ctx.instrs.push(Instruction::Filter { index: table_reg_index, - expr: expr.try_into()?, + expr, }) } if !select.projection.is_empty() { let original_table_reg_index = table_reg_index; - table_reg_index = current_reg; - current_reg = current_reg.next_index(); + table_reg_index = ctx.get_and_increment_reg(); - instrs.push(Instruction::Empty { + ctx.instrs.push(Instruction::Empty { index: table_reg_index, }); for projection in select.projection.clone() { - instrs.push(Instruction::Project { + let projection = Instruction::Project { input: original_table_reg_index, output: table_reg_index, expr: match projection { - SelectItem::UnnamedExpr(ref expr) => expr.clone().try_into()?, + SelectItem::UnnamedExpr(ref expr) => { + codegen_expr(expr.clone(), &mut ctx)? + } SelectItem::ExprWithAlias { ref expr, .. } => { - expr.clone().try_into()? + codegen_expr(expr.clone(), &mut ctx)? } SelectItem::QualifiedWildcard(_) => Expr::Wildcard, SelectItem::Wildcard => Expr::Wildcard, @@ -316,7 +344,8 @@ pub fn codegen_ast(ast: &Statement) -> Result { } SelectItem::Wildcard => None, }, - }) + }; + ctx.instrs.push(projection) } if select.distinct { @@ -329,8 +358,8 @@ pub fn codegen_ast(ast: &Statement) -> Result { } SetExpr::Values(exprs) => { if exprs.0.len() == 1 && exprs.0[0].len() == 1 { - let expr: Expr = exprs.0[0][0].clone().try_into()?; - instrs.push(Instruction::Expr { + let expr: Expr = codegen_expr(exprs.0[0][0].clone(), &mut ctx)?; + ctx.instrs.push(Instruction::Expr { index: table_reg_index, expr, }); @@ -376,9 +405,10 @@ pub fn codegen_ast(ast: &Statement) -> Result { }; for order_by in query.order_by.clone() { - instrs.push(Instruction::Order { + let order_by_expr = codegen_expr(order_by.expr, &mut ctx)?; + ctx.instrs.push(Instruction::Order { index: table_reg_index, - expr: order_by.expr.try_into()?, + expr: order_by_expr, ascending: order_by.asc.unwrap_or(true), }); // TODO: support NULLS FIRST/NULLS LAST @@ -387,7 +417,7 @@ pub fn codegen_ast(ast: &Statement) -> Result { if let Some(limit) = query.limit.clone() { if let ast::Expr::Value(val) = limit.clone() { if let Value::Int64(limit) = val.clone().try_into()? { - instrs.push(Instruction::Limit { + ctx.instrs.push(Instruction::Limit { index: table_reg_index, limit: limit as u64, }); @@ -407,7 +437,7 @@ pub fn codegen_ast(ast: &Statement) -> Result { } } - instrs.push(Instruction::Return { + ctx.instrs.push(Instruction::Return { index: table_reg_index, }); @@ -417,7 +447,7 @@ pub fn codegen_ast(ast: &Statement) -> Result { schema_name, if_not_exists, } => { - instrs.push(Instruction::NewSchema { + ctx.instrs.push(Instruction::NewSchema { schema_name: schema_name.0.clone().try_into()?, exists_ok: *if_not_exists, }); @@ -426,7 +456,109 @@ pub fn codegen_ast(ast: &Statement) -> Result { _ => Err(CodegenError::UnsupportedStatement(ast.to_string())), }?; - Ok(IntermediateCode { instrs }) + Ok(IntermediateCode { instrs: ctx.instrs }) +} + +fn codegen_expr(expr_ast: ast::Expr, ctx: &mut CodegenContext) -> Result { + match expr_ast { + ast::Expr::Identifier(i) => Ok(Expr::ColumnRef(vec![i].try_into()?)), + ast::Expr::CompoundIdentifier(i) => Ok(Expr::ColumnRef(i.try_into()?)), + ast::Expr::IsFalse(e) => Ok(Expr::Unary { + op: UnOp::IsFalse, + operand: Box::new(codegen_expr(*e, ctx)?), + }), + ast::Expr::IsTrue(e) => Ok(Expr::Unary { + op: UnOp::IsTrue, + operand: Box::new(codegen_expr(*e, ctx)?), + }), + ast::Expr::IsNull(e) => Ok(Expr::Unary { + op: UnOp::IsNull, + operand: Box::new(codegen_expr(*e, ctx)?), + }), + ast::Expr::IsNotNull(e) => Ok(Expr::Unary { + op: UnOp::IsNotNull, + operand: Box::new(codegen_expr(*e, ctx)?), + }), + ast::Expr::Between { + expr, + negated, + low, + high, + } => { + let expr: Box = Box::new(codegen_expr(*expr, ctx)?); + let left = Box::new(codegen_expr(*low, ctx)?); + let right = Box::new(codegen_expr(*high, ctx)?); + let between = Expr::Binary { + left: Box::new(Expr::Binary { + left, + op: BinOp::LessThanOrEqual, + right: expr.clone(), + }), + op: BinOp::And, + right: Box::new(Expr::Binary { + left: expr, + op: BinOp::LessThanOrEqual, + right, + }), + }; + if negated { + Ok(Expr::Unary { + op: UnOp::Not, + operand: Box::new(between), + }) + } else { + Ok(between) + } + } + ast::Expr::BinaryOp { left, op, right } => Ok(Expr::Binary { + left: Box::new(codegen_expr(*left, ctx)?), + op: op.try_into()?, + right: Box::new(codegen_expr(*right, ctx)?), + }), + ast::Expr::UnaryOp { op, expr } => Ok(Expr::Unary { + op: op.try_into()?, + operand: Box::new(codegen_expr(*expr, ctx)?), + }), + ast::Expr::Value(v) => Ok(Expr::Value(v.try_into()?)), + ast::Expr::Function(ref f) => { + match f.args { + &[ast::Expr::Identifier(i)] => { + ctx.push(Instruction::Aggregate { + input: ctx.last_used_reg(), + output: ctx.get_and_increment_reg(), + func: f.name.to_string(), + col_name: (), + }) + // Ok(Expr::ColumnRef(vec![i].try_into()?)) + } + } + Ok(Expr::Function { + name: f.name.to_string().as_str().into(), + args: f + .args + .iter() + .map(|arg| match arg { + ast::FunctionArg::Unnamed(arg_expr) => match arg_expr { + ast::FunctionArgExpr::Expr(e) => Ok(codegen_expr(e.clone(), instrs)?), + ast::FunctionArgExpr::Wildcard => Ok(Expr::Wildcard), + ast::FunctionArgExpr::QualifiedWildcard(_) => Err(ExprError::Expr { + reason: "Qualified wildcards are not supported yet", + expr: expr_ast.clone(), + }), + }, + ast::FunctionArg::Named { .. } => Err(ExprError::Expr { + reason: "Named function arguments are not supported", + expr: expr_ast.clone(), + }), + }) + .collect::, _>>()?, + }) + } + _ => Err(ExprError::Expr { + reason: "Unsupported expression", + expr: expr_ast, + }), + } } /// Error while generating an intermediate code from the AST. diff --git a/src/expr/agg.rs b/src/expr/agg.rs new file mode 100644 index 0000000..4eb2990 --- /dev/null +++ b/src/expr/agg.rs @@ -0,0 +1,7 @@ +#[derive(Clone, PartialEq, Eq)] +/// Functions that reduce an entire column to a single value. +pub enum AggregateFunction { + Count, + Max, + Sum, +} diff --git a/src/expr/mod.rs b/src/expr/mod.rs index 21e97fd..5d6e84a 100644 --- a/src/expr/mod.rs +++ b/src/expr/mod.rs @@ -11,6 +11,7 @@ use crate::{ }; pub mod eval; +pub mod agg; /// An expression #[derive(Debug, Clone, PartialEq)] @@ -130,98 +131,6 @@ impl Display for UnOp { } } -impl TryFrom for Expr { - type Error = ExprError; - fn try_from(expr_ast: ast::Expr) -> Result { - match expr_ast { - ast::Expr::Identifier(i) => Ok(Expr::ColumnRef(vec![i].try_into()?)), - ast::Expr::CompoundIdentifier(i) => Ok(Expr::ColumnRef(i.try_into()?)), - ast::Expr::IsFalse(e) => Ok(Expr::Unary { - op: UnOp::IsFalse, - operand: Box::new((*e).try_into()?), - }), - ast::Expr::IsTrue(e) => Ok(Expr::Unary { - op: UnOp::IsTrue, - operand: Box::new((*e).try_into()?), - }), - ast::Expr::IsNull(e) => Ok(Expr::Unary { - op: UnOp::IsNull, - operand: Box::new((*e).try_into()?), - }), - ast::Expr::IsNotNull(e) => Ok(Expr::Unary { - op: UnOp::IsNotNull, - operand: Box::new((*e).try_into()?), - }), - ast::Expr::Between { - expr, - negated, - low, - high, - } => { - let expr: Box = Box::new((*expr).try_into()?); - let left = Box::new((*low).try_into()?); - let right = Box::new((*high).try_into()?); - let between = Expr::Binary { - left: Box::new(Expr::Binary { - left, - op: BinOp::LessThanOrEqual, - right: expr.clone(), - }), - op: BinOp::And, - right: Box::new(Expr::Binary { - left: expr, - op: BinOp::LessThanOrEqual, - right, - }), - }; - if negated { - Ok(Expr::Unary { - op: UnOp::Not, - operand: Box::new(between), - }) - } else { - Ok(between) - } - } - ast::Expr::BinaryOp { left, op, right } => Ok(Expr::Binary { - left: Box::new((*left).try_into()?), - op: op.try_into()?, - right: Box::new((*right).try_into()?), - }), - ast::Expr::UnaryOp { op, expr } => Ok(Expr::Unary { - op: op.try_into()?, - operand: Box::new((*expr).try_into()?), - }), - ast::Expr::Value(v) => Ok(Expr::Value(v.try_into()?)), - ast::Expr::Function(ref f) => Ok(Expr::Function { - name: f.name.to_string().as_str().into(), - args: f - .args - .iter() - .map(|arg| match arg { - ast::FunctionArg::Unnamed(arg_expr) => match arg_expr { - ast::FunctionArgExpr::Expr(e) => Ok(e.clone().try_into()?), - ast::FunctionArgExpr::Wildcard => Ok(Expr::Wildcard), - ast::FunctionArgExpr::QualifiedWildcard(_) => Err(ExprError::Expr { - reason: "Qualified wildcards are not supported yet", - expr: expr_ast.clone(), - }), - }, - ast::FunctionArg::Named { .. } => Err(ExprError::Expr { - reason: "Named function arguments are not supported", - expr: expr_ast.clone(), - }), - }) - .collect::, _>>()?, - }), - _ => Err(ExprError::Expr { - reason: "Unsupported expression", - expr: expr_ast, - }), - } - } -} - impl TryFrom for BinOp { type Error = ExprError; fn try_from(op: ast::BinaryOperator) -> Result { diff --git a/src/ic.rs b/src/ic.rs index 3d0fc1a..658acf1 100644 --- a/src/ic.rs +++ b/src/ic.rs @@ -5,7 +5,7 @@ use fmt_derive::{Debug, Display}; use sqlparser::ast::{ColumnOptionDef, DataType}; use crate::{ - expr::Expr, + expr::{Expr, agg::AggregateFunction}, identifier::{SchemaRef, TableRef}, value::Value, vm::RegisterIndex, @@ -68,6 +68,20 @@ pub enum Instruction { alias: Option, }, + Aggregate { + input: RegisterIndex, + output: RegisterIndex, + func: AggregateFunction, + #[display( + "{}", + match col_name { + None => "None".to_owned(), + Some(col_name) => format!("{}", col_name) + } + )] + col_name: Option, + }, + /// Group the [`Register::TableRef`](`crate::vm::Register::TableRef`) at `index` by the given expression. /// /// This will result in a [`Register::GroupedTable`](`crate::vm::Register::GroupedTable`) being stored at the `index` register. diff --git a/src/vm.rs b/src/vm.rs index 8396995..32e211c 100644 --- a/src/vm.rs +++ b/src/vm.rs @@ -328,6 +328,12 @@ impl VirtualMachine { return Err(RuntimeError::RegisterNotATable("project", reg.clone())) } }, + Instruction::Aggregate { + input: _, + output: _, + func: _, + col_name: _, + } => todo!("aggregate is not implemented yet"), Instruction::GroupBy { index: _, expr: _ } => todo!("group by is not implemented yet"), Instruction::Order { index, From c061302ff4d48376b5b040ed196f32742442c514 Mon Sep 17 00:00:00 2001 From: Samyak Sarnayak Date: Wed, 1 Feb 2023 23:59:20 +0530 Subject: [PATCH 02/26] feat(ic): remove Value. add inp, out regs to GroupBy. --- src/codegen.rs | 5 ++++- src/ic.rs | 16 ++++++++-------- src/vm.rs | 4 ---- 3 files changed, 12 insertions(+), 13 deletions(-) diff --git a/src/codegen.rs b/src/codegen.rs index aad2646..f553cf4 100644 --- a/src/codegen.rs +++ b/src/codegen.rs @@ -295,10 +295,13 @@ pub fn codegen_ast(ast: &Statement) -> Result { for group_by in select.group_by.clone() { let group_by = codegen_expr(group_by, &mut ctx)?; + let grouped_reg_index = ctx.get_and_increment_reg(); ctx.instrs.push(Instruction::GroupBy { - index: table_reg_index, + input: table_reg_index, + output: grouped_reg_index, expr: group_by, }); + table_reg_index = grouped_reg_index } if let Some(expr) = select.having.clone() { diff --git a/src/ic.rs b/src/ic.rs index 658acf1..feb0184 100644 --- a/src/ic.rs +++ b/src/ic.rs @@ -5,9 +5,8 @@ use fmt_derive::{Debug, Display}; use sqlparser::ast::{ColumnOptionDef, DataType}; use crate::{ - expr::{Expr, agg::AggregateFunction}, + expr::{agg::AggregateFunction, Expr}, identifier::{SchemaRef, TableRef}, - value::Value, vm::RegisterIndex, BoundedString, }; @@ -21,9 +20,6 @@ pub struct IntermediateCode { /// The instruction set of OtterSQL. #[derive(Display, Debug, Clone, PartialEq)] pub enum Instruction { - /// Load a [`Value`] into a register. - Value { index: RegisterIndex, value: Value }, - /// Load a [`Expr`] into a register. Expr { index: RegisterIndex, expr: Expr }, @@ -82,12 +78,16 @@ pub enum Instruction { col_name: Option, }, - /// Group the [`Register::TableRef`](`crate::vm::Register::TableRef`) at `index` by the given expression. + /// Group the [`Register::TableRef`](`crate::vm::Register::TableRef`) at `input` by the given expression. /// - /// This will result in a [`Register::GroupedTable`](`crate::vm::Register::GroupedTable`) being stored at the `index` register. + /// This will result in a [`Register::GroupedTable`](`crate::vm::Register::GroupedTable`) being stored at the `output` register. /// /// Must be added before any projections so as to catch errors in column selections. - GroupBy { index: RegisterIndex, expr: Expr }, + GroupBy { + input: RegisterIndex, + output: RegisterIndex, + expr: Expr, + }, /// Order the [`Register::TableRef`](`crate::vm::Register::TableRef`) at `index` by the given expression. /// diff --git a/src/vm.rs b/src/vm.rs index 32e211c..18ec7ef 100644 --- a/src/vm.rs +++ b/src/vm.rs @@ -133,10 +133,6 @@ impl VirtualMachine { fn execute_instr(&mut self, instr: &Instruction) -> Result, RuntimeError> { let _ = &self.database; match instr { - Instruction::Value { index, value } => { - self.registers - .insert(*index, Register::Value(value.clone())); - } Instruction::Expr { index, expr } => { self.registers.insert(*index, Register::Expr(expr.clone())); } From b2f6c6c56adf839a795467c46d57a7a67d435dd7 Mon Sep 17 00:00:00 2001 From: Samyak Sarnayak Date: Thu, 2 Feb 2023 00:00:45 +0530 Subject: [PATCH 03/26] refactor: rename ic to ir --- src/codegen.rs | 4 ++-- src/{ic.rs => ir.rs} | 0 src/lib.rs | 4 ++-- src/vm.rs | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) rename src/{ic.rs => ir.rs} (100%) diff --git a/src/codegen.rs b/src/codegen.rs index f553cf4..d8c14a5 100644 --- a/src/codegen.rs +++ b/src/codegen.rs @@ -8,7 +8,7 @@ use std::{error::Error, fmt::Display}; use crate::{ expr::{BinOp, Expr, ExprError, UnOp}, - ic::{Instruction, IntermediateCode}, + ir::{Instruction, IntermediateCode}, identifier::IdentifierError, parser::parse, value::{Value, ValueError}, @@ -617,7 +617,7 @@ mod tests { use crate::{ codegen::codegen_ast, expr::{BinOp, Expr}, - ic::Instruction, + ir::Instruction, identifier::{ColumnRef, SchemaRef, TableRef}, parser::parse, value::Value, diff --git a/src/ic.rs b/src/ir.rs similarity index 100% rename from src/ic.rs rename to src/ir.rs diff --git a/src/lib.rs b/src/lib.rs index bbafd36..1ab0c9f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,7 +8,7 @@ pub mod codegen; pub mod column; pub mod database; pub mod expr; -pub mod ic; +pub mod ir; pub mod identifier; pub mod parser; pub mod schema; @@ -18,7 +18,7 @@ pub mod vm; pub use column::Column; pub use database::Database; -pub use ic::{Instruction, IntermediateCode}; +pub use ir::{Instruction, IntermediateCode}; pub use identifier::BoundedString; pub use table::Table; pub use value::Value; diff --git a/src/vm.rs b/src/vm.rs index 18ec7ef..90268f9 100644 --- a/src/vm.rs +++ b/src/vm.rs @@ -13,7 +13,7 @@ use crate::codegen::{codegen_ast, CodegenError}; use crate::column::Column; use crate::expr::eval::ExprExecError; use crate::expr::Expr; -use crate::ic::{Instruction, IntermediateCode}; +use crate::ir::{Instruction, IntermediateCode}; use crate::identifier::{ColumnRef, TableRef}; use crate::parser::parse; use crate::schema::Schema; From 5dfadc501ca737f03d69ba8dfb67ec307c0e5ee0 Mon Sep 17 00:00:00 2001 From: Samyak Sarnayak Date: Fri, 3 Feb 2023 00:25:26 +0530 Subject: [PATCH 04/26] wip: special codegen for agg functions --- Cargo.toml | 1 + src/codegen.rs | 145 ++++++++++++++++++++++++++++++++++++------------- 2 files changed, 107 insertions(+), 39 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 6d153aa..90e7d6f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,6 +23,7 @@ permutation = "0.4.1" ordered-float = "3.1.0" tabled = { version = "0.10.0", optional = true } fmt-derive = "0.0.5" +phf = { version = "0.11.1", features = ["macros"] } [features] default = ["terminal-output"] diff --git a/src/codegen.rs b/src/codegen.rs index d8c14a5..334cbff 100644 --- a/src/codegen.rs +++ b/src/codegen.rs @@ -1,18 +1,23 @@ //! Intermediate code generation from the AST. use sqlparser::{ - ast::{self, SelectItem, SetExpr, Statement, TableFactor, TableWithJoins}, + ast::{ + self, Function, FunctionArg, SelectItem, SetExpr, Statement, TableFactor, TableWithJoins, + }, parser::ParserError, }; +use phf::{phf_set, Set}; + use std::{error::Error, fmt::Display}; use crate::{ expr::{BinOp, Expr, ExprError, UnOp}, - ir::{Instruction, IntermediateCode}, identifier::IdentifierError, + ir::{Instruction, IntermediateCode}, parser::parse, value::{Value, ValueError}, vm::RegisterIndex, + BoundedString, }; /// Represents either a parser error or a codegen error. @@ -54,10 +59,14 @@ pub fn codegen_str(code: &str) -> Result, ParserOrCodegenE .collect::, CodegenError>>()?) } +const TEMP_COL_NAME_PREFIX: &'static str = "__otter_temp_col"; + /// Context passed around to any func that needs codegen. struct CodegenContext { pub instrs: Vec, current_reg: RegisterIndex, + last_temp_col_num: usize, + is_inside_agg_fn: bool, } impl CodegenContext { @@ -65,6 +74,8 @@ impl CodegenContext { Self { instrs: Vec::new(), current_reg: RegisterIndex::default(), + last_temp_col_num: 0, + is_inside_agg_fn: false, } } @@ -77,6 +88,13 @@ impl CodegenContext { pub fn last_used_reg(&self) -> RegisterIndex { self.current_reg } + + pub fn get_new_temp_col(&mut self) -> BoundedString { + self.last_temp_col_num += 1; + format!("{TEMP_COL_NAME_PREFIX}_{}", self.last_temp_col_num) + .as_str() + .into() + } } impl Default for CodegenContext { @@ -85,6 +103,13 @@ impl Default for CodegenContext { } } +static AGGREGATE_FUNCTIONS: Set<&'static str> = phf_set! { + "count", + "max", + "min", + "sum", +}; + /// Generates intermediate code from the AST. pub fn codegen_ast(ast: &Statement) -> Result { let mut ctx = CodegenContext::default(); @@ -321,19 +346,23 @@ pub fn codegen_ast(ast: &Statement) -> Result { }); for projection in select.projection.clone() { + // TODO(now): call codegen_fn_agg here and use new register index as + // the input for that project + let expr = match projection { + SelectItem::UnnamedExpr(ref expr) => { + codegen_expr(expr.clone(), &mut ctx)? + } + SelectItem::ExprWithAlias { ref expr, .. } => { + codegen_expr(expr.clone(), &mut ctx)? + } + SelectItem::QualifiedWildcard(_) => Expr::Wildcard, + SelectItem::Wildcard => Expr::Wildcard, + }; + let projection = Instruction::Project { input: original_table_reg_index, output: table_reg_index, - expr: match projection { - SelectItem::UnnamedExpr(ref expr) => { - codegen_expr(expr.clone(), &mut ctx)? - } - SelectItem::ExprWithAlias { ref expr, .. } => { - codegen_expr(expr.clone(), &mut ctx)? - } - SelectItem::QualifiedWildcard(_) => Expr::Wildcard, - SelectItem::Wildcard => Expr::Wildcard, - }, + expr, alias: match projection { SelectItem::UnnamedExpr(_) => None, SelectItem::ExprWithAlias { alias, .. } => { @@ -524,36 +553,13 @@ fn codegen_expr(expr_ast: ast::Expr, ctx: &mut CodegenContext) -> Result Ok(Expr::Value(v.try_into()?)), ast::Expr::Function(ref f) => { - match f.args { - &[ast::Expr::Identifier(i)] => { - ctx.push(Instruction::Aggregate { - input: ctx.last_used_reg(), - output: ctx.get_and_increment_reg(), - func: f.name.to_string(), - col_name: (), - }) - // Ok(Expr::ColumnRef(vec![i].try_into()?)) - } - } + let fn_name = f.name.to_string(); Ok(Expr::Function { - name: f.name.to_string().as_str().into(), + name: fn_name.as_str().into(), args: f .args .iter() - .map(|arg| match arg { - ast::FunctionArg::Unnamed(arg_expr) => match arg_expr { - ast::FunctionArgExpr::Expr(e) => Ok(codegen_expr(e.clone(), instrs)?), - ast::FunctionArgExpr::Wildcard => Ok(Expr::Wildcard), - ast::FunctionArgExpr::QualifiedWildcard(_) => Err(ExprError::Expr { - reason: "Qualified wildcards are not supported yet", - expr: expr_ast.clone(), - }), - }, - ast::FunctionArg::Named { .. } => Err(ExprError::Expr { - reason: "Named function arguments are not supported", - expr: expr_ast.clone(), - }), - }) + .map(|arg| codegen_fn_arg(&expr_ast, arg, ctx)) .collect::, _>>()?, }) } @@ -564,6 +570,67 @@ fn codegen_expr(expr_ast: ast::Expr, ctx: &mut CodegenContext) -> Result Result<(RegisterIndex, Expr), ExprError> { + let fn_name = f.name.to_string(); + if AGGREGATE_FUNCTIONS.contains(&fn_name) { + // TODO: support aggregate functions that take multiple args + if f.args.len() != 1 { + return Err(ExprError::Expr { + reason: "Aggregate functions that take more or less than one argument are not supported yet", + expr: expr_ast.clone(), + }); + } + + if ctx.is_inside_agg_fn { + return Err(ExprError::Expr { + reason: "Aggregate functions cannot be nested", + expr: expr_ast.clone(), + }); + } + ctx.is_inside_agg_fn = true; + + let orig_table_reg = ctx.last_used_reg(); + let temp_table_reg = ctx.get_and_increment_reg(); + let projected_col_name = ctx.get_new_temp_col(); + + ctx.instrs.push(Instruction::Project { + input: orig_table_reg, + output: temp_table_reg, + expr: codegen_fn_arg(&expr_ast, &f.args[0], ctx)?, + alias: Some(projected_col_name), + }); + + // TODO(now): insert Aggregate instruction here + + ctx.is_inside_agg_fn = false; + } +} + +fn codegen_fn_arg( + expr_ast: &ast::Expr, + arg: &FunctionArg, + ctx: &mut CodegenContext, +) -> Result { + match arg { + ast::FunctionArg::Unnamed(arg_expr) => match arg_expr { + ast::FunctionArgExpr::Expr(e) => Ok(codegen_expr(e.clone(), ctx)?), + ast::FunctionArgExpr::Wildcard => Ok(Expr::Wildcard), + ast::FunctionArgExpr::QualifiedWildcard(_) => Err(ExprError::Expr { + reason: "Qualified wildcards are not supported yet", + expr: expr_ast.clone(), + }), + }, + ast::FunctionArg::Named { .. } => Err(ExprError::Expr { + reason: "Named function arguments are not supported", + expr: expr_ast.clone(), + }), + } +} + /// Error while generating an intermediate code from the AST. #[derive(Debug)] pub enum CodegenError { @@ -617,8 +684,8 @@ mod tests { use crate::{ codegen::codegen_ast, expr::{BinOp, Expr}, - ir::Instruction, identifier::{ColumnRef, SchemaRef, TableRef}, + ir::Instruction, parser::parse, value::Value, vm::RegisterIndex, From fab2a3381efe5dd581ebbc27659c101f886c7790 Mon Sep 17 00:00:00 2001 From: Samyak Sarnayak Date: Wed, 22 Feb 2023 00:18:24 +0530 Subject: [PATCH 05/26] feat: a special path for codegen of agg under group bys --- src/codegen.rs | 117 +++++++++++++++++++++++++++++++++++++----------- src/expr/agg.rs | 14 ++++++ src/expr/mod.rs | 6 ++- src/ir.rs | 8 ++-- 4 files changed, 115 insertions(+), 30 deletions(-) diff --git a/src/codegen.rs b/src/codegen.rs index 334cbff..9386db8 100644 --- a/src/codegen.rs +++ b/src/codegen.rs @@ -11,7 +11,7 @@ use phf::{phf_set, Set}; use std::{error::Error, fmt::Display}; use crate::{ - expr::{BinOp, Expr, ExprError, UnOp}, + expr::{agg::AggregateFunction, BinOp, Expr, ExprError, UnOp}, identifier::IdentifierError, ir::{Instruction, IntermediateCode}, parser::parse, @@ -345,9 +345,25 @@ pub fn codegen_ast(ast: &Statement) -> Result { index: table_reg_index, }); + let has_aggs = select.projection.iter().any(|p| is_projection_agg(p)); + + let intermediate_reg_index; + if has_aggs { + intermediate_reg_index = ctx.get_and_increment_reg(); + + ctx.instrs.push(Instruction::Empty { + index: intermediate_reg_index, + }); + } + for projection in select.projection.clone() { - // TODO(now): call codegen_fn_agg here and use new register index as - // the input for that project + let expr_ast = match projection { + SelectItem::UnnamedExpr(ref expr) + | SelectItem::ExprWithAlias { ref expr, .. } => Some(expr), + SelectItem::QualifiedWildcard(_) => None, + SelectItem::Wildcard => None, + }; + let expr = match projection { SelectItem::UnnamedExpr(ref expr) => { codegen_expr(expr.clone(), &mut ctx)? @@ -359,25 +375,45 @@ pub fn codegen_ast(ast: &Statement) -> Result { SelectItem::Wildcard => Expr::Wildcard, }; - let projection = Instruction::Project { - input: original_table_reg_index, - output: table_reg_index, - expr, - alias: match projection { - SelectItem::UnnamedExpr(_) => None, - SelectItem::ExprWithAlias { alias, .. } => { - Some(alias.value.as_str().into()) - } - SelectItem::QualifiedWildcard(name) => { - return Err(CodegenError::UnsupportedStatementForm( - "Qualified wildcards are not supported yet", - name.to_string(), - )) - } - SelectItem::Wildcard => None, - }, + let alias = match projection { + SelectItem::UnnamedExpr(_) => None, + SelectItem::ExprWithAlias { alias, .. } => { + Some(alias.value.as_str().into()) + } + SelectItem::QualifiedWildcard(name) => { + return Err(CodegenError::UnsupportedStatementForm( + "Qualified wildcards are not supported yet", + name.to_string(), + )) + } + SelectItem::Wildcard => None, }; - ctx.instrs.push(projection) + + match expr_ast { + // an aggregate operation + Some(expr_ast) if is_expr_agg(expr_ast) => match expr_ast { + ast::Expr::Function(ref f) => codegen_fn_agg( + expr_ast, + f, + alias, + original_table_reg_index, + intermediate_reg_index, + table_reg_index, + &mut ctx, + )?, + _ => {} + }, + // a non-aggregate operation + _ => { + let projection = Instruction::Project { + input: original_table_reg_index, + output: table_reg_index, + expr, + alias, + }; + ctx.instrs.push(projection) + } + } } if select.distinct { @@ -570,11 +606,35 @@ fn codegen_expr(expr_ast: ast::Expr, ctx: &mut CodegenContext) -> Result bool { + AGGREGATE_FUNCTIONS.contains(fn_name) +} + +fn is_expr_agg(e: &ast::Expr) -> bool { + match e { + ast::Expr::Function(ref f) => is_fn_name_aggregate(&f.name.to_string().to_lowercase()), + _ => false, + } +} + +fn is_projection_agg(p: &SelectItem) -> bool { + match p { + SelectItem::UnnamedExpr(ref expr) => is_expr_agg(expr), + SelectItem::ExprWithAlias { ref expr, .. } => is_expr_agg(expr), + SelectItem::QualifiedWildcard(_) => false, + SelectItem::Wildcard => false, + } +} + fn codegen_fn_agg( expr_ast: &ast::Expr, f: &Function, + alias: Option, + orig_table_reg: RegisterIndex, + intermediate_reg_index: RegisterIndex, + output_reg_index: RegisterIndex, ctx: &mut CodegenContext, -) -> Result<(RegisterIndex, Expr), ExprError> { +) -> Result<(), ExprError> { let fn_name = f.name.to_string(); if AGGREGATE_FUNCTIONS.contains(&fn_name) { // TODO: support aggregate functions that take multiple args @@ -593,21 +653,26 @@ fn codegen_fn_agg( } ctx.is_inside_agg_fn = true; - let orig_table_reg = ctx.last_used_reg(); - let temp_table_reg = ctx.get_and_increment_reg(); let projected_col_name = ctx.get_new_temp_col(); ctx.instrs.push(Instruction::Project { input: orig_table_reg, - output: temp_table_reg, + output: intermediate_reg_index, expr: codegen_fn_arg(&expr_ast, &f.args[0], ctx)?, alias: Some(projected_col_name), }); - // TODO(now): insert Aggregate instruction here + ctx.instrs.push(Instruction::Aggregate { + input: intermediate_reg_index, + output: output_reg_index, + func: AggregateFunction::from_name(fn_name.as_str())?, + col_name: projected_col_name, + alias, + }); ctx.is_inside_agg_fn = false; } + Ok(()) } fn codegen_fn_arg( diff --git a/src/expr/agg.rs b/src/expr/agg.rs index 4eb2990..a11501f 100644 --- a/src/expr/agg.rs +++ b/src/expr/agg.rs @@ -1,3 +1,5 @@ +use super::ExprError; + #[derive(Clone, PartialEq, Eq)] /// Functions that reduce an entire column to a single value. pub enum AggregateFunction { @@ -5,3 +7,15 @@ pub enum AggregateFunction { Max, Sum, } + +impl AggregateFunction { + /// Get an aggregation function by name. + pub fn from_name(name: &str) -> Result { + match name.to_lowercase().as_str() { + "count" => Ok(Self::Count), + "max" => Ok(Self::Max), + "sum" => Ok(Self::Sum), + _ => Err(ExprError::UnknownAggregateFunction(name.to_owned())), + } + } +} diff --git a/src/expr/mod.rs b/src/expr/mod.rs index 5d6e84a..c3ea682 100644 --- a/src/expr/mod.rs +++ b/src/expr/mod.rs @@ -10,8 +10,8 @@ use crate::{ BoundedString, }; -pub mod eval; pub mod agg; +pub mod eval; /// An expression #[derive(Debug, Clone, PartialEq)] @@ -193,6 +193,7 @@ pub enum ExprError { }, Value(ValueError), Identifier(IdentifierError), + UnknownAggregateFunction(String), } impl Display for ExprError { @@ -209,6 +210,9 @@ impl Display for ExprError { } ExprError::Value(v) => write!(f, "{}", v), ExprError::Identifier(v) => write!(f, "{}", v), + ExprError::UnknownAggregateFunction(agg) => { + write!(f, "Unsupported Aggregate Function: {}", agg) + } } } } diff --git a/src/ir.rs b/src/ir.rs index feb0184..4b4d0b9 100644 --- a/src/ir.rs +++ b/src/ir.rs @@ -68,14 +68,16 @@ pub enum Instruction { input: RegisterIndex, output: RegisterIndex, func: AggregateFunction, + /// Column in input to aggregate. + col_name: BoundedString, #[display( "{}", - match col_name { + match alias { None => "None".to_owned(), - Some(col_name) => format!("{}", col_name) + Some(alias) => format!("{}", alias) } )] - col_name: Option, + alias: Option, }, /// Group the [`Register::TableRef`](`crate::vm::Register::TableRef`) at `input` by the given expression. From 82f3f0e9c79f331dd6bc8c5b51707366b658e414 Mon Sep 17 00:00:00 2001 From: Samyak Sarnayak Date: Thu, 2 Mar 2023 00:00:03 +0530 Subject: [PATCH 06/26] fix: build errors --- src/codegen.rs | 8 +++++--- src/vm.rs | 9 +++++++-- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/src/codegen.rs b/src/codegen.rs index 9386db8..f2a387b 100644 --- a/src/codegen.rs +++ b/src/codegen.rs @@ -347,7 +347,7 @@ pub fn codegen_ast(ast: &Statement) -> Result { let has_aggs = select.projection.iter().any(|p| is_projection_agg(p)); - let intermediate_reg_index; + let mut intermediate_reg_index = table_reg_index; if has_aggs { intermediate_reg_index = ctx.get_and_increment_reg(); @@ -655,10 +655,11 @@ fn codegen_fn_agg( let projected_col_name = ctx.get_new_temp_col(); + let expr = codegen_fn_arg(&expr_ast, &f.args[0], ctx)?; ctx.instrs.push(Instruction::Project { input: orig_table_reg, output: intermediate_reg_index, - expr: codegen_fn_arg(&expr_ast, &f.args[0], ctx)?, + expr, alias: Some(projected_col_name), }); @@ -1406,7 +1407,8 @@ mod tests { }, }, Instruction::GroupBy { - index: RegisterIndex::default(), + input: RegisterIndex::default(), + output: RegisterIndex::default().next_index(), expr: Expr::ColumnRef(ColumnRef { schema_name: None, table_name: None, diff --git a/src/vm.rs b/src/vm.rs index 90268f9..946d720 100644 --- a/src/vm.rs +++ b/src/vm.rs @@ -13,8 +13,8 @@ use crate::codegen::{codegen_ast, CodegenError}; use crate::column::Column; use crate::expr::eval::ExprExecError; use crate::expr::Expr; -use crate::ir::{Instruction, IntermediateCode}; use crate::identifier::{ColumnRef, TableRef}; +use crate::ir::{Instruction, IntermediateCode}; use crate::parser::parse; use crate::schema::Schema; use crate::table::{Row, RowShared, Table}; @@ -329,8 +329,13 @@ impl VirtualMachine { output: _, func: _, col_name: _, + alias: _, } => todo!("aggregate is not implemented yet"), - Instruction::GroupBy { index: _, expr: _ } => todo!("group by is not implemented yet"), + Instruction::GroupBy { + input: _, + output: _, + expr: _, + } => todo!("group by is not implemented yet"), Instruction::Order { index, expr, From cc208c25665ecf40bb61add3f1cf084cfbb878c3 Mon Sep 17 00:00:00 2001 From: Samyak Sarnayak Date: Thu, 2 Mar 2023 23:39:20 +0530 Subject: [PATCH 07/26] fix: make tests build and some refactoring --- src/codegen.rs | 528 +++++++++++++++++++++++++++++++++++++++++++++-- src/expr/agg.rs | 15 ++ src/expr/eval.rs | 332 ----------------------------- src/expr/mod.rs | 147 ------------- 4 files changed, 525 insertions(+), 497 deletions(-) diff --git a/src/codegen.rs b/src/codegen.rs index f2a387b..3a4aee6 100644 --- a/src/codegen.rs +++ b/src/codegen.rs @@ -110,6 +110,14 @@ static AGGREGATE_FUNCTIONS: Set<&'static str> = phf_set! { "sum", }; +fn extract_expr_ast_from_project(projection: SelectItem) -> Option { + match projection { + SelectItem::UnnamedExpr(expr) | SelectItem::ExprWithAlias { expr, .. } => Some(expr), + SelectItem::QualifiedWildcard(_) => None, + SelectItem::Wildcard => None, + } +} + /// Generates intermediate code from the AST. pub fn codegen_ast(ast: &Statement) -> Result { let mut ctx = CodegenContext::default(); @@ -321,6 +329,9 @@ pub fn codegen_ast(ast: &Statement) -> Result { for group_by in select.group_by.clone() { let group_by = codegen_expr(group_by, &mut ctx)?; let grouped_reg_index = ctx.get_and_increment_reg(); + ctx.instrs.push(Instruction::Empty { + index: grouped_reg_index, + }); ctx.instrs.push(Instruction::GroupBy { input: table_reg_index, output: grouped_reg_index, @@ -329,14 +340,6 @@ pub fn codegen_ast(ast: &Statement) -> Result { table_reg_index = grouped_reg_index } - if let Some(expr) = select.having.clone() { - let expr = codegen_expr(expr, &mut ctx)?; - ctx.instrs.push(Instruction::Filter { - index: table_reg_index, - expr, - }) - } - if !select.projection.is_empty() { let original_table_reg_index = table_reg_index; table_reg_index = ctx.get_and_increment_reg(); @@ -357,12 +360,7 @@ pub fn codegen_ast(ast: &Statement) -> Result { } for projection in select.projection.clone() { - let expr_ast = match projection { - SelectItem::UnnamedExpr(ref expr) - | SelectItem::ExprWithAlias { ref expr, .. } => Some(expr), - SelectItem::QualifiedWildcard(_) => None, - SelectItem::Wildcard => None, - }; + let expr_ast = extract_expr_ast_from_project(projection.clone()); let expr = match projection { SelectItem::UnnamedExpr(ref expr) => { @@ -391,9 +389,9 @@ pub fn codegen_ast(ast: &Statement) -> Result { match expr_ast { // an aggregate operation - Some(expr_ast) if is_expr_agg(expr_ast) => match expr_ast { + Some(expr_ast) if is_expr_agg(&expr_ast) => match expr_ast { ast::Expr::Function(ref f) => codegen_fn_agg( - expr_ast, + &expr_ast, f, alias, original_table_reg_index, @@ -423,6 +421,14 @@ pub fn codegen_ast(ast: &Statement) -> Result { )); } } + + if let Some(expr) = select.having.clone() { + let expr = codegen_expr(expr, &mut ctx)?; + ctx.instrs.push(Instruction::Filter { + index: table_reg_index, + expr, + }) + } } SetExpr::Values(exprs) => { if exprs.0.len() == 1 && exprs.0[0].len() == 1 { @@ -635,7 +641,7 @@ fn codegen_fn_agg( output_reg_index: RegisterIndex, ctx: &mut CodegenContext, ) -> Result<(), ExprError> { - let fn_name = f.name.to_string(); + let fn_name = f.name.to_string().to_lowercase(); if AGGREGATE_FUNCTIONS.contains(&fn_name) { // TODO: support aggregate functions that take multiple args if f.args.len() != 1 { @@ -742,7 +748,7 @@ impl From for CodegenError { impl Error for CodegenError {} #[cfg(test)] -mod tests { +mod codegen_tests { use sqlparser::ast::{ColumnOption, ColumnOptionDef, DataType}; use pretty_assertions::assert_eq; @@ -1465,3 +1471,489 @@ mod tests { ); } } + +#[cfg(test)] +mod expr_codegen_tests { + use sqlparser::{ast, dialect::GenericDialect, parser::Parser, tokenizer::Tokenizer}; + + use crate::{ + codegen::{codegen_expr, CodegenContext}, + expr::{BinOp, Expr, ExprError, UnOp}, + identifier::ColumnRef, + value::Value, + }; + + #[test] + fn conversion_from_ast() { + fn parse_expr(s: &str) -> ast::Expr { + let dialect = GenericDialect {}; + let mut tokenizer = Tokenizer::new(&dialect, s); + let tokens = tokenizer.tokenize().unwrap(); + let mut parser = Parser::new(tokens, &dialect); + parser.parse_expr().unwrap() + } + + fn codegen_expr_wrapper(expr_ast: ast::Expr) -> Result { + let mut ctx = CodegenContext::new(); + codegen_expr(expr_ast, &mut ctx) + } + + assert_eq!( + codegen_expr_wrapper(parse_expr("abc")), + Ok(Expr::ColumnRef(ColumnRef { + schema_name: None, + table_name: None, + col_name: "abc".into() + })) + ); + + assert_ne!( + codegen_expr_wrapper(parse_expr("abc")), + Ok(Expr::ColumnRef(ColumnRef { + schema_name: None, + table_name: None, + col_name: "cab".into() + })) + ); + + assert_eq!( + codegen_expr_wrapper(parse_expr("table1.col1")), + Ok(Expr::ColumnRef(ColumnRef { + schema_name: None, + table_name: Some("table1".into()), + col_name: "col1".into() + })) + ); + + assert_eq!( + codegen_expr_wrapper(parse_expr("schema1.table1.col1")), + Ok(Expr::ColumnRef(ColumnRef { + schema_name: Some("schema1".into()), + table_name: Some("table1".into()), + col_name: "col1".into() + })) + ); + + assert_eq!( + codegen_expr_wrapper(parse_expr("5 IS NULL")), + Ok(Expr::Unary { + op: UnOp::IsNull, + operand: Box::new(Expr::Value(Value::Int64(5))) + }) + ); + + assert_eq!( + codegen_expr_wrapper(parse_expr("1 IS TRUE")), + Ok(Expr::Unary { + op: UnOp::IsTrue, + operand: Box::new(Expr::Value(Value::Int64(1))) + }) + ); + + assert_eq!( + codegen_expr_wrapper(parse_expr("4 BETWEEN 3 AND 5")), + Ok(Expr::Binary { + left: Box::new(Expr::Binary { + left: Box::new(Expr::Value(Value::Int64(3))), + op: BinOp::LessThanOrEqual, + right: Box::new(Expr::Value(Value::Int64(4))) + }), + op: BinOp::And, + right: Box::new(Expr::Binary { + left: Box::new(Expr::Value(Value::Int64(4))), + op: BinOp::LessThanOrEqual, + right: Box::new(Expr::Value(Value::Int64(5))) + }) + }) + ); + + assert_eq!( + codegen_expr_wrapper(parse_expr("4 NOT BETWEEN 3 AND 5")), + Ok(Expr::Unary { + op: UnOp::Not, + operand: Box::new(Expr::Binary { + left: Box::new(Expr::Binary { + left: Box::new(Expr::Value(Value::Int64(3))), + op: BinOp::LessThanOrEqual, + right: Box::new(Expr::Value(Value::Int64(4))) + }), + op: BinOp::And, + right: Box::new(Expr::Binary { + left: Box::new(Expr::Value(Value::Int64(4))), + op: BinOp::LessThanOrEqual, + right: Box::new(Expr::Value(Value::Int64(5))) + }) + }) + }) + ); + + assert_eq!( + codegen_expr_wrapper(parse_expr("MAX(col1)")), + Ok(Expr::Function { + name: "MAX".into(), + args: vec![Expr::ColumnRef(ColumnRef { + schema_name: None, + table_name: None, + col_name: "col1".into() + })] + }) + ); + + assert_eq!( + codegen_expr_wrapper(parse_expr("some_func(col1, 1, 'abc')")), + Ok(Expr::Function { + name: "some_func".into(), + args: vec![ + Expr::ColumnRef(ColumnRef { + schema_name: None, + table_name: None, + col_name: "col1".into() + }), + Expr::Value(Value::Int64(1)), + Expr::Value(Value::String("abc".to_owned())) + ] + }) + ); + + assert_eq!( + codegen_expr_wrapper(parse_expr("COUNT(*)")), + Ok(Expr::Function { + name: "COUNT".into(), + args: vec![Expr::Wildcard] + }) + ); + } +} + +#[cfg(test)] +mod expr_eval_tests { + use sqlparser::{ + ast::{ColumnOption, ColumnOptionDef, DataType}, + dialect::GenericDialect, + parser::Parser, + tokenizer::Tokenizer, + }; + + use crate::{ + column::Column, + expr::{eval::ExprExecError, BinOp, Expr, UnOp}, + table::{Row, Table}, + value::{Value, ValueBinaryOpError, ValueUnaryOpError}, + }; + + use super::{codegen_expr, CodegenContext}; + + fn str_to_expr(s: &str) -> Expr { + let dialect = GenericDialect {}; + let mut tokenizer = Tokenizer::new(&dialect, s); + let tokens = tokenizer.tokenize().unwrap(); + let mut parser = Parser::new(tokens, &dialect); + let mut ctx = CodegenContext::new(); + codegen_expr(parser.parse_expr().unwrap(), &mut ctx).unwrap() + } + + fn exec_expr_no_context(expr: Expr) -> Result { + let mut table = Table::new_temp(0); + table.new_row(vec![]); + Expr::execute(&expr, &table, table.all_data()[0].to_shared()) + } + + fn exec_str_no_context(s: &str) -> Result { + let expr = str_to_expr(s); + exec_expr_no_context(expr) + } + + fn exec_str_with_context(s: &str, table: &Table, row: &Row) -> Result { + let expr = str_to_expr(s); + Expr::execute(&expr, table, row.to_shared()) + } + + #[test] + fn exec_value() { + assert_eq!(exec_str_no_context("NULL"), Ok(Value::Null)); + + assert_eq!(exec_str_no_context("true"), Ok(Value::Bool(true))); + + assert_eq!(exec_str_no_context("1"), Ok(Value::Int64(1))); + + assert_eq!(exec_str_no_context("1.1"), Ok(Value::Float64(1.1.into()))); + + assert_eq!(exec_str_no_context(".1"), Ok(Value::Float64(0.1.into()))); + + assert_eq!( + exec_str_no_context("'str'"), + Ok(Value::String("str".to_owned())) + ); + } + + #[test] + fn exec_logical() { + assert_eq!(exec_str_no_context("true and true"), Ok(Value::Bool(true))); + assert_eq!( + exec_str_no_context("true and false"), + Ok(Value::Bool(false)) + ); + assert_eq!( + exec_str_no_context("false and true"), + Ok(Value::Bool(false)) + ); + assert_eq!( + exec_str_no_context("false and false"), + Ok(Value::Bool(false)) + ); + assert_eq!( + exec_str_no_context("false and 10"), + Err(ValueBinaryOpError { + operator: BinOp::And, + values: (Value::Bool(false), Value::Int64(10)) + } + .into()) + ); + assert_eq!( + exec_str_no_context("10 and false"), + Err(ValueBinaryOpError { + operator: BinOp::And, + values: (Value::Int64(10), Value::Bool(false)) + } + .into()) + ); + + assert_eq!(exec_str_no_context("true or true"), Ok(Value::Bool(true))); + assert_eq!(exec_str_no_context("true or false"), Ok(Value::Bool(true))); + assert_eq!(exec_str_no_context("false or true"), Ok(Value::Bool(true))); + assert_eq!( + exec_str_no_context("false or false"), + Ok(Value::Bool(false)) + ); + assert_eq!( + exec_str_no_context("true or 10"), + Err(ValueBinaryOpError { + operator: BinOp::Or, + values: (Value::Bool(true), Value::Int64(10)) + } + .into()) + ); + assert_eq!( + exec_str_no_context("10 or true"), + Err(ValueBinaryOpError { + operator: BinOp::Or, + values: (Value::Int64(10), Value::Bool(true)) + } + .into()) + ); + } + + #[test] + fn exec_arithmetic() { + assert_eq!(exec_str_no_context("1 + 1"), Ok(Value::Int64(2))); + assert_eq!( + exec_str_no_context("1.1 + 1.1"), + Ok(Value::Float64(2.2.into())) + ); + + // this applies to all binary ops + assert_eq!( + exec_str_no_context("1 + 1.1"), + Err(ValueBinaryOpError { + operator: BinOp::Plus, + values: (Value::Int64(1), Value::Float64(1.1.into())) + } + .into()) + ); + + assert_eq!(exec_str_no_context("4 - 2"), Ok(Value::Int64(2))); + assert_eq!(exec_str_no_context("4 - 6"), Ok(Value::Int64(-2))); + assert_eq!( + exec_str_no_context("4.5 - 2.2"), + Ok(Value::Float64(2.3.into())) + ); + + assert_eq!(exec_str_no_context("4 * 2"), Ok(Value::Int64(8))); + assert_eq!( + exec_str_no_context("0.5 * 2.2"), + Ok(Value::Float64(1.1.into())) + ); + + assert_eq!(exec_str_no_context("4 / 2"), Ok(Value::Int64(2))); + assert_eq!(exec_str_no_context("4 / 3"), Ok(Value::Int64(1))); + assert_eq!( + exec_str_no_context("4.0 / 2.0"), + Ok(Value::Float64(2.0.into())) + ); + assert_eq!( + exec_str_no_context("5.1 / 2.5"), + Ok(Value::Float64(2.04.into())) + ); + + assert_eq!(exec_str_no_context("5 % 2"), Ok(Value::Int64(1))); + assert_eq!( + exec_str_no_context("5.5 % 2.5"), + Ok(Value::Float64(0.5.into())) + ); + } + + #[test] + fn exec_comparison() { + assert_eq!(exec_str_no_context("1 = 1"), Ok(Value::Bool(true))); + assert_eq!(exec_str_no_context("1 = 2"), Ok(Value::Bool(false))); + assert_eq!(exec_str_no_context("1 != 2"), Ok(Value::Bool(true))); + assert_eq!(exec_str_no_context("1.1 = 1.1"), Ok(Value::Bool(true))); + assert_eq!(exec_str_no_context("1.2 = 1.22"), Ok(Value::Bool(false))); + assert_eq!(exec_str_no_context("1.2 != 1.22"), Ok(Value::Bool(true))); + + assert_eq!(exec_str_no_context("1 < 2"), Ok(Value::Bool(true))); + assert_eq!(exec_str_no_context("1 < 1"), Ok(Value::Bool(false))); + assert_eq!(exec_str_no_context("1 <= 2"), Ok(Value::Bool(true))); + assert_eq!(exec_str_no_context("1 <= 1"), Ok(Value::Bool(true))); + assert_eq!(exec_str_no_context("3 > 2"), Ok(Value::Bool(true))); + assert_eq!(exec_str_no_context("3 > 3"), Ok(Value::Bool(false))); + assert_eq!(exec_str_no_context("3 >= 2"), Ok(Value::Bool(true))); + assert_eq!(exec_str_no_context("3 >= 3"), Ok(Value::Bool(true))); + } + + #[test] + fn exec_pattern_match() { + assert_eq!( + exec_str_no_context("'my name is yoshikage kira' LIKE 'kira'"), + Ok(Value::Bool(true)) + ); + assert_eq!( + exec_str_no_context("'my name is yoshikage kira' LIKE 'KIRA'"), + Ok(Value::Bool(false)) + ); + assert_eq!( + exec_str_no_context("'my name is yoshikage kira' LIKE 'kira yoshikage'"), + Ok(Value::Bool(false)) + ); + + assert_eq!( + exec_str_no_context("'my name is Yoshikage Kira' ILIKE 'kira'"), + Ok(Value::Bool(true)) + ); + assert_eq!( + exec_str_no_context("'my name is Yoshikage Kira' ILIKE 'KIRA'"), + Ok(Value::Bool(true)) + ); + assert_eq!( + exec_str_no_context("'my name is Yoshikage Kira' ILIKE 'KIRAA'"), + Ok(Value::Bool(false)) + ); + } + + #[test] + fn exec_unary() { + assert_eq!(exec_str_no_context("+1"), Ok(Value::Int64(1))); + assert_eq!(exec_str_no_context("+ -1"), Ok(Value::Int64(-1))); + assert_eq!(exec_str_no_context("-1"), Ok(Value::Int64(-1))); + assert_eq!(exec_str_no_context("- -1"), Ok(Value::Int64(1))); + assert_eq!(exec_str_no_context("not true"), Ok(Value::Bool(false))); + assert_eq!(exec_str_no_context("not false"), Ok(Value::Bool(true))); + + assert_eq!(exec_str_no_context("true is true"), Ok(Value::Bool(true))); + assert_eq!(exec_str_no_context("false is false"), Ok(Value::Bool(true))); + assert_eq!(exec_str_no_context("false is true"), Ok(Value::Bool(false))); + assert_eq!(exec_str_no_context("true is false"), Ok(Value::Bool(false))); + assert_eq!( + exec_str_no_context("1 is true"), + Err(ValueUnaryOpError { + operator: UnOp::IsTrue, + value: Value::Int64(1) + } + .into()) + ); + + assert_eq!(exec_str_no_context("NULL is NULL"), Ok(Value::Bool(true))); + assert_eq!( + exec_str_no_context("NULL is not NULL"), + Ok(Value::Bool(false)) + ); + assert_eq!(exec_str_no_context("1 is NULL"), Ok(Value::Bool(false))); + assert_eq!(exec_str_no_context("1 is not NULL"), Ok(Value::Bool(true))); + assert_eq!(exec_str_no_context("0 is not NULL"), Ok(Value::Bool(true))); + assert_eq!(exec_str_no_context("'' is not NULL"), Ok(Value::Bool(true))); + } + + #[test] + fn exec_wildcard() { + assert_eq!( + exec_expr_no_context(Expr::Wildcard), + Err(ExprExecError::CannotExecute(Expr::Wildcard)) + ); + } + + #[test] + fn exec_column_ref() { + let mut table = Table::new( + "table1".into(), + vec![ + Column::new( + "col1".into(), + DataType::Int(None), + vec![ColumnOptionDef { + name: None, + option: ColumnOption::Unique { is_primary: true }, + }], + false, + ), + Column::new( + "col2".into(), + DataType::Int(None), + vec![ColumnOptionDef { + name: None, + option: ColumnOption::Unique { is_primary: false }, + }], + false, + ), + Column::new("col3".into(), DataType::String, vec![], false), + ], + ); + table.new_row(vec![ + Value::Int64(4), + Value::Int64(10), + Value::String("brr".to_owned()), + ]); + + assert_eq!( + table.all_data(), + vec![Row::new(vec![ + Value::Int64(4), + Value::Int64(10), + Value::String("brr".to_owned()) + ])] + ); + + assert_eq!( + exec_str_with_context("col1", &table, &table.all_data()[0]), + Ok(Value::Int64(4)) + ); + + assert_eq!( + exec_str_with_context("col3", &table, &table.all_data()[0]), + Ok(Value::String("brr".to_owned())) + ); + + assert_eq!( + exec_str_with_context("col1 = 4", &table, &table.all_data()[0]), + Ok(Value::Bool(true)) + ); + + assert_eq!( + exec_str_with_context("col1 + 1", &table, &table.all_data()[0]), + Ok(Value::Int64(5)) + ); + + assert_eq!( + exec_str_with_context("col1 + col2", &table, &table.all_data()[0]), + Ok(Value::Int64(14)) + ); + + assert_eq!( + exec_str_with_context( + "col1 + col2 = 10 or col1 * col2 = 40", + &table, + &table.all_data()[0] + ), + Ok(Value::Bool(true)) + ); + } +} diff --git a/src/expr/agg.rs b/src/expr/agg.rs index a11501f..85419aa 100644 --- a/src/expr/agg.rs +++ b/src/expr/agg.rs @@ -1,4 +1,5 @@ use super::ExprError; +use std::fmt::Display; #[derive(Clone, PartialEq, Eq)] /// Functions that reduce an entire column to a single value. @@ -19,3 +20,17 @@ impl AggregateFunction { } } } + +impl Display for AggregateFunction { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "Agg({})", + match self { + AggregateFunction::Count => "count", + AggregateFunction::Max => "max", + AggregateFunction::Sum => "sum", + } + ) + } +} diff --git a/src/expr/eval.rs b/src/expr/eval.rs index 7739186..1ae01d3 100644 --- a/src/expr/eval.rs +++ b/src/expr/eval.rs @@ -153,335 +153,3 @@ impl Display for ExprExecError { } impl Error for ExprExecError {} - -#[cfg(test)] -mod test { - use sqlparser::{ - ast::{ColumnOption, ColumnOptionDef, DataType}, - dialect::GenericDialect, - parser::Parser, - tokenizer::Tokenizer, - }; - - use crate::{ - column::Column, - expr::{BinOp, Expr, UnOp}, - table::{Row, Table}, - value::{Value, ValueBinaryOpError, ValueUnaryOpError}, - }; - - use super::ExprExecError; - - fn str_to_expr(s: &str) -> Expr { - let dialect = GenericDialect {}; - let mut tokenizer = Tokenizer::new(&dialect, s); - let tokens = tokenizer.tokenize().unwrap(); - let mut parser = Parser::new(tokens, &dialect); - parser.parse_expr().unwrap().try_into().unwrap() - } - - fn exec_expr_no_context(expr: Expr) -> Result { - let mut table = Table::new_temp(0); - table.new_row(vec![]); - Expr::execute(&expr, &table, table.all_data()[0].to_shared()) - } - - fn exec_str_no_context(s: &str) -> Result { - let expr = str_to_expr(s); - exec_expr_no_context(expr) - } - - fn exec_str_with_context(s: &str, table: &Table, row: &Row) -> Result { - let expr = str_to_expr(s); - Expr::execute(&expr, table, row.to_shared()) - } - - #[test] - fn exec_value() { - assert_eq!(exec_str_no_context("NULL"), Ok(Value::Null)); - - assert_eq!(exec_str_no_context("true"), Ok(Value::Bool(true))); - - assert_eq!(exec_str_no_context("1"), Ok(Value::Int64(1))); - - assert_eq!(exec_str_no_context("1.1"), Ok(Value::Float64(1.1.into()))); - - assert_eq!(exec_str_no_context(".1"), Ok(Value::Float64(0.1.into()))); - - assert_eq!( - exec_str_no_context("'str'"), - Ok(Value::String("str".to_owned())) - ); - } - - #[test] - fn exec_logical() { - assert_eq!(exec_str_no_context("true and true"), Ok(Value::Bool(true))); - assert_eq!( - exec_str_no_context("true and false"), - Ok(Value::Bool(false)) - ); - assert_eq!( - exec_str_no_context("false and true"), - Ok(Value::Bool(false)) - ); - assert_eq!( - exec_str_no_context("false and false"), - Ok(Value::Bool(false)) - ); - assert_eq!( - exec_str_no_context("false and 10"), - Err(ValueBinaryOpError { - operator: BinOp::And, - values: (Value::Bool(false), Value::Int64(10)) - } - .into()) - ); - assert_eq!( - exec_str_no_context("10 and false"), - Err(ValueBinaryOpError { - operator: BinOp::And, - values: (Value::Int64(10), Value::Bool(false)) - } - .into()) - ); - - assert_eq!(exec_str_no_context("true or true"), Ok(Value::Bool(true))); - assert_eq!(exec_str_no_context("true or false"), Ok(Value::Bool(true))); - assert_eq!(exec_str_no_context("false or true"), Ok(Value::Bool(true))); - assert_eq!( - exec_str_no_context("false or false"), - Ok(Value::Bool(false)) - ); - assert_eq!( - exec_str_no_context("true or 10"), - Err(ValueBinaryOpError { - operator: BinOp::Or, - values: (Value::Bool(true), Value::Int64(10)) - } - .into()) - ); - assert_eq!( - exec_str_no_context("10 or true"), - Err(ValueBinaryOpError { - operator: BinOp::Or, - values: (Value::Int64(10), Value::Bool(true)) - } - .into()) - ); - } - - #[test] - fn exec_arithmetic() { - assert_eq!(exec_str_no_context("1 + 1"), Ok(Value::Int64(2))); - assert_eq!( - exec_str_no_context("1.1 + 1.1"), - Ok(Value::Float64(2.2.into())) - ); - - // this applies to all binary ops - assert_eq!( - exec_str_no_context("1 + 1.1"), - Err(ValueBinaryOpError { - operator: BinOp::Plus, - values: (Value::Int64(1), Value::Float64(1.1.into())) - } - .into()) - ); - - assert_eq!(exec_str_no_context("4 - 2"), Ok(Value::Int64(2))); - assert_eq!(exec_str_no_context("4 - 6"), Ok(Value::Int64(-2))); - assert_eq!( - exec_str_no_context("4.5 - 2.2"), - Ok(Value::Float64(2.3.into())) - ); - - assert_eq!(exec_str_no_context("4 * 2"), Ok(Value::Int64(8))); - assert_eq!( - exec_str_no_context("0.5 * 2.2"), - Ok(Value::Float64(1.1.into())) - ); - - assert_eq!(exec_str_no_context("4 / 2"), Ok(Value::Int64(2))); - assert_eq!(exec_str_no_context("4 / 3"), Ok(Value::Int64(1))); - assert_eq!( - exec_str_no_context("4.0 / 2.0"), - Ok(Value::Float64(2.0.into())) - ); - assert_eq!( - exec_str_no_context("5.1 / 2.5"), - Ok(Value::Float64(2.04.into())) - ); - - assert_eq!(exec_str_no_context("5 % 2"), Ok(Value::Int64(1))); - assert_eq!( - exec_str_no_context("5.5 % 2.5"), - Ok(Value::Float64(0.5.into())) - ); - } - - #[test] - fn exec_comparison() { - assert_eq!(exec_str_no_context("1 = 1"), Ok(Value::Bool(true))); - assert_eq!(exec_str_no_context("1 = 2"), Ok(Value::Bool(false))); - assert_eq!(exec_str_no_context("1 != 2"), Ok(Value::Bool(true))); - assert_eq!(exec_str_no_context("1.1 = 1.1"), Ok(Value::Bool(true))); - assert_eq!(exec_str_no_context("1.2 = 1.22"), Ok(Value::Bool(false))); - assert_eq!(exec_str_no_context("1.2 != 1.22"), Ok(Value::Bool(true))); - - assert_eq!(exec_str_no_context("1 < 2"), Ok(Value::Bool(true))); - assert_eq!(exec_str_no_context("1 < 1"), Ok(Value::Bool(false))); - assert_eq!(exec_str_no_context("1 <= 2"), Ok(Value::Bool(true))); - assert_eq!(exec_str_no_context("1 <= 1"), Ok(Value::Bool(true))); - assert_eq!(exec_str_no_context("3 > 2"), Ok(Value::Bool(true))); - assert_eq!(exec_str_no_context("3 > 3"), Ok(Value::Bool(false))); - assert_eq!(exec_str_no_context("3 >= 2"), Ok(Value::Bool(true))); - assert_eq!(exec_str_no_context("3 >= 3"), Ok(Value::Bool(true))); - } - - #[test] - fn exec_pattern_match() { - assert_eq!( - exec_str_no_context("'my name is yoshikage kira' LIKE 'kira'"), - Ok(Value::Bool(true)) - ); - assert_eq!( - exec_str_no_context("'my name is yoshikage kira' LIKE 'KIRA'"), - Ok(Value::Bool(false)) - ); - assert_eq!( - exec_str_no_context("'my name is yoshikage kira' LIKE 'kira yoshikage'"), - Ok(Value::Bool(false)) - ); - - assert_eq!( - exec_str_no_context("'my name is Yoshikage Kira' ILIKE 'kira'"), - Ok(Value::Bool(true)) - ); - assert_eq!( - exec_str_no_context("'my name is Yoshikage Kira' ILIKE 'KIRA'"), - Ok(Value::Bool(true)) - ); - assert_eq!( - exec_str_no_context("'my name is Yoshikage Kira' ILIKE 'KIRAA'"), - Ok(Value::Bool(false)) - ); - } - - #[test] - fn exec_unary() { - assert_eq!(exec_str_no_context("+1"), Ok(Value::Int64(1))); - assert_eq!(exec_str_no_context("+ -1"), Ok(Value::Int64(-1))); - assert_eq!(exec_str_no_context("-1"), Ok(Value::Int64(-1))); - assert_eq!(exec_str_no_context("- -1"), Ok(Value::Int64(1))); - assert_eq!(exec_str_no_context("not true"), Ok(Value::Bool(false))); - assert_eq!(exec_str_no_context("not false"), Ok(Value::Bool(true))); - - assert_eq!(exec_str_no_context("true is true"), Ok(Value::Bool(true))); - assert_eq!(exec_str_no_context("false is false"), Ok(Value::Bool(true))); - assert_eq!(exec_str_no_context("false is true"), Ok(Value::Bool(false))); - assert_eq!(exec_str_no_context("true is false"), Ok(Value::Bool(false))); - assert_eq!( - exec_str_no_context("1 is true"), - Err(ValueUnaryOpError { - operator: UnOp::IsTrue, - value: Value::Int64(1) - } - .into()) - ); - - assert_eq!(exec_str_no_context("NULL is NULL"), Ok(Value::Bool(true))); - assert_eq!( - exec_str_no_context("NULL is not NULL"), - Ok(Value::Bool(false)) - ); - assert_eq!(exec_str_no_context("1 is NULL"), Ok(Value::Bool(false))); - assert_eq!(exec_str_no_context("1 is not NULL"), Ok(Value::Bool(true))); - assert_eq!(exec_str_no_context("0 is not NULL"), Ok(Value::Bool(true))); - assert_eq!(exec_str_no_context("'' is not NULL"), Ok(Value::Bool(true))); - } - - #[test] - fn exec_wildcard() { - assert_eq!( - exec_expr_no_context(Expr::Wildcard), - Err(ExprExecError::CannotExecute(Expr::Wildcard)) - ); - } - - #[test] - fn exec_column_ref() { - let mut table = Table::new( - "table1".into(), - vec![ - Column::new( - "col1".into(), - DataType::Int(None), - vec![ColumnOptionDef { - name: None, - option: ColumnOption::Unique { is_primary: true }, - }], - false, - ), - Column::new( - "col2".into(), - DataType::Int(None), - vec![ColumnOptionDef { - name: None, - option: ColumnOption::Unique { is_primary: false }, - }], - false, - ), - Column::new("col3".into(), DataType::String, vec![], false), - ], - ); - table.new_row(vec![ - Value::Int64(4), - Value::Int64(10), - Value::String("brr".to_owned()), - ]); - - assert_eq!( - table.all_data(), - vec![Row::new(vec![ - Value::Int64(4), - Value::Int64(10), - Value::String("brr".to_owned()) - ])] - ); - - assert_eq!( - exec_str_with_context("col1", &table, &table.all_data()[0]), - Ok(Value::Int64(4)) - ); - - assert_eq!( - exec_str_with_context("col3", &table, &table.all_data()[0]), - Ok(Value::String("brr".to_owned())) - ); - - assert_eq!( - exec_str_with_context("col1 = 4", &table, &table.all_data()[0]), - Ok(Value::Bool(true)) - ); - - assert_eq!( - exec_str_with_context("col1 + 1", &table, &table.all_data()[0]), - Ok(Value::Int64(5)) - ); - - assert_eq!( - exec_str_with_context("col1 + col2", &table, &table.all_data()[0]), - Ok(Value::Int64(14)) - ); - - assert_eq!( - exec_str_with_context( - "col1 + col2 = 10 or col1 * col2 = 40", - &table, - &table.all_data()[0] - ), - Ok(Value::Bool(true)) - ); - } -} diff --git a/src/expr/mod.rs b/src/expr/mod.rs index c3ea682..c6151a6 100644 --- a/src/expr/mod.rs +++ b/src/expr/mod.rs @@ -230,150 +230,3 @@ impl From for ExprError { } impl Error for ExprError {} - -#[cfg(test)] -mod tests { - use sqlparser::{ast, dialect::GenericDialect, parser::Parser, tokenizer::Tokenizer}; - - use crate::{ - expr::{BinOp, Expr, UnOp}, - identifier::ColumnRef, - value::Value, - }; - - #[test] - fn conversion_from_ast() { - fn parse_expr(s: &str) -> ast::Expr { - let dialect = GenericDialect {}; - let mut tokenizer = Tokenizer::new(&dialect, s); - let tokens = tokenizer.tokenize().unwrap(); - let mut parser = Parser::new(tokens, &dialect); - parser.parse_expr().unwrap() - } - - assert_eq!( - parse_expr("abc").try_into(), - Ok(Expr::ColumnRef(ColumnRef { - schema_name: None, - table_name: None, - col_name: "abc".into() - })) - ); - - assert_ne!( - parse_expr("abc").try_into(), - Ok(Expr::ColumnRef(ColumnRef { - schema_name: None, - table_name: None, - col_name: "cab".into() - })) - ); - - assert_eq!( - parse_expr("table1.col1").try_into(), - Ok(Expr::ColumnRef(ColumnRef { - schema_name: None, - table_name: Some("table1".into()), - col_name: "col1".into() - })) - ); - - assert_eq!( - parse_expr("schema1.table1.col1").try_into(), - Ok(Expr::ColumnRef(ColumnRef { - schema_name: Some("schema1".into()), - table_name: Some("table1".into()), - col_name: "col1".into() - })) - ); - - assert_eq!( - parse_expr("5 IS NULL").try_into(), - Ok(Expr::Unary { - op: UnOp::IsNull, - operand: Box::new(Expr::Value(Value::Int64(5))) - }) - ); - - assert_eq!( - parse_expr("1 IS TRUE").try_into(), - Ok(Expr::Unary { - op: UnOp::IsTrue, - operand: Box::new(Expr::Value(Value::Int64(1))) - }) - ); - - assert_eq!( - parse_expr("4 BETWEEN 3 AND 5").try_into(), - Ok(Expr::Binary { - left: Box::new(Expr::Binary { - left: Box::new(Expr::Value(Value::Int64(3))), - op: BinOp::LessThanOrEqual, - right: Box::new(Expr::Value(Value::Int64(4))) - }), - op: BinOp::And, - right: Box::new(Expr::Binary { - left: Box::new(Expr::Value(Value::Int64(4))), - op: BinOp::LessThanOrEqual, - right: Box::new(Expr::Value(Value::Int64(5))) - }) - }) - ); - - assert_eq!( - parse_expr("4 NOT BETWEEN 3 AND 5").try_into(), - Ok(Expr::Unary { - op: UnOp::Not, - operand: Box::new(Expr::Binary { - left: Box::new(Expr::Binary { - left: Box::new(Expr::Value(Value::Int64(3))), - op: BinOp::LessThanOrEqual, - right: Box::new(Expr::Value(Value::Int64(4))) - }), - op: BinOp::And, - right: Box::new(Expr::Binary { - left: Box::new(Expr::Value(Value::Int64(4))), - op: BinOp::LessThanOrEqual, - right: Box::new(Expr::Value(Value::Int64(5))) - }) - }) - }) - ); - - assert_eq!( - parse_expr("MAX(col1)").try_into(), - Ok(Expr::Function { - name: "MAX".into(), - args: vec![Expr::ColumnRef(ColumnRef { - schema_name: None, - table_name: None, - col_name: "col1".into() - })] - }) - ); - - assert_eq!( - parse_expr("some_func(col1, 1, 'abc')").try_into(), - Ok(Expr::Function { - name: "some_func".into(), - args: vec![ - Expr::ColumnRef(ColumnRef { - schema_name: None, - table_name: None, - col_name: "col1".into() - }), - Expr::Value(Value::Int64(1)), - Expr::Value(Value::String("abc".to_owned())) - ] - }) - ); - - assert_eq!( - parse_expr("COUNT(*)").try_into(), - Ok(Expr::Function { - name: "COUNT".into(), - args: vec![Expr::Wildcard] - }) - ); - } -} From 519f1e51d99997f72db984ff576e34c4094dd437 Mon Sep 17 00:00:00 2001 From: Samyak Sarnayak Date: Sun, 5 Mar 2023 20:00:21 +0530 Subject: [PATCH 08/26] fix(codegen): separate pre and post groupby projections --- src/codegen.rs | 166 ++++++++++++++++++++++++++++++++----------------- 1 file changed, 108 insertions(+), 58 deletions(-) diff --git a/src/codegen.rs b/src/codegen.rs index 3a4aee6..5aa74d1 100644 --- a/src/codegen.rs +++ b/src/codegen.rs @@ -118,6 +118,20 @@ fn extract_expr_ast_from_project(projection: SelectItem) -> Option { } } +fn extract_alias_from_project( + projection: &SelectItem, +) -> Result, CodegenError> { + match projection { + SelectItem::UnnamedExpr(_) => Ok(None), + SelectItem::ExprWithAlias { alias, .. } => Ok(Some(alias.value.as_str().into())), + SelectItem::QualifiedWildcard(name) => Err(CodegenError::UnsupportedStatementForm( + "Qualified wildcards are not supported yet", + name.to_string(), + )), + SelectItem::Wildcard => Ok(None), + } +} + /// Generates intermediate code from the AST. pub fn codegen_ast(ast: &Statement) -> Result { let mut ctx = CodegenContext::default(); @@ -326,6 +340,72 @@ pub fn codegen_ast(ast: &Statement) -> Result { }) } + // if there are groupby + aggregations, we project all operations within an + // aggregation to another table first. for example, `SUM(col * col)` would be + // evaluated as `Project (col * col)` into `%2` and then apply the group by on + // `%2`. + // TODO: possible idea for refactor: make an intermediate representation of + // Projection that separates non-agg and agg projections. + let pre_grouped_reg_index = table_reg_index; + let mut agg_intermediate_cols = Vec::new(); + if !select.projection.is_empty() { + let has_aggs = select.projection.iter().any(|p| is_projection_agg(p)); + + if has_aggs { + let grouped_reg_index = ctx.get_and_increment_reg(); + ctx.instrs.push(Instruction::Empty { + index: grouped_reg_index, + }); + + table_reg_index = grouped_reg_index; + + for projection in &select.projection { + let alias = extract_alias_from_project(&projection)?; + + let expr_ast = extract_expr_ast_from_project(projection.clone()); + + match expr_ast { + // an aggregate operation + Some(expr_ast) if is_expr_agg(&expr_ast) => match expr_ast { + ast::Expr::Function(ref f) => { + let expr = + codegen_fn_arg(&expr_ast, &f.args[0], &mut ctx)?; + let projected_col_name = ctx.get_new_temp_col(); + agg_intermediate_cols.push(projected_col_name); + ctx.instrs.push(Instruction::Project { + input: pre_grouped_reg_index, + output: grouped_reg_index, + expr, + alias: Some(projected_col_name), + }); + } + _ => {} + }, + // a non-aggregate operation + _ => { + let expr = match projection { + SelectItem::UnnamedExpr(ref expr) => { + codegen_expr(expr.clone(), &mut ctx)? + } + SelectItem::ExprWithAlias { ref expr, .. } => { + codegen_expr(expr.clone(), &mut ctx)? + } + SelectItem::QualifiedWildcard(_) => Expr::Wildcard, + SelectItem::Wildcard => Expr::Wildcard, + }; + let projection = Instruction::Project { + input: pre_grouped_reg_index, + output: grouped_reg_index, + expr, + alias, + }; + ctx.instrs.push(projection) + } + } + } + } + } + for group_by in select.group_by.clone() { let group_by = codegen_expr(group_by, &mut ctx)?; let grouped_reg_index = ctx.get_and_increment_reg(); @@ -340,6 +420,9 @@ pub fn codegen_ast(ast: &Statement) -> Result { table_reg_index = grouped_reg_index } + // this is only for aggregations. + // aggs are applied on the grouped table created by the `GroupBy` instructions + // generated above. if !select.projection.is_empty() { let original_table_reg_index = table_reg_index; table_reg_index = ctx.get_and_increment_reg(); @@ -350,66 +433,33 @@ pub fn codegen_ast(ast: &Statement) -> Result { let has_aggs = select.projection.iter().any(|p| is_projection_agg(p)); - let mut intermediate_reg_index = table_reg_index; if has_aggs { - intermediate_reg_index = ctx.get_and_increment_reg(); - - ctx.instrs.push(Instruction::Empty { - index: intermediate_reg_index, - }); - } - - for projection in select.projection.clone() { - let expr_ast = extract_expr_ast_from_project(projection.clone()); - - let expr = match projection { - SelectItem::UnnamedExpr(ref expr) => { - codegen_expr(expr.clone(), &mut ctx)? - } - SelectItem::ExprWithAlias { ref expr, .. } => { - codegen_expr(expr.clone(), &mut ctx)? - } - SelectItem::QualifiedWildcard(_) => Expr::Wildcard, - SelectItem::Wildcard => Expr::Wildcard, - }; - - let alias = match projection { - SelectItem::UnnamedExpr(_) => None, - SelectItem::ExprWithAlias { alias, .. } => { - Some(alias.value.as_str().into()) - } - SelectItem::QualifiedWildcard(name) => { - return Err(CodegenError::UnsupportedStatementForm( - "Qualified wildcards are not supported yet", - name.to_string(), - )) - } - SelectItem::Wildcard => None, - }; - - match expr_ast { - // an aggregate operation - Some(expr_ast) if is_expr_agg(&expr_ast) => match expr_ast { - ast::Expr::Function(ref f) => codegen_fn_agg( - &expr_ast, - f, - alias, - original_table_reg_index, - intermediate_reg_index, - table_reg_index, - &mut ctx, - )?, + let mut agg_index = 0; + for projection in &select.projection { + let alias = extract_alias_from_project(&projection)?; + + let expr_ast = extract_expr_ast_from_project(projection.clone()); + + match expr_ast { + // an aggregate operation + Some(expr_ast) if is_expr_agg(&expr_ast) => match expr_ast { + ast::Expr::Function(ref f) => { + ctx.instrs.push(Instruction::Aggregate { + input: original_table_reg_index, + output: table_reg_index, + func: AggregateFunction::from_name( + f.name.to_string().to_lowercase().as_str(), + )?, + col_name: agg_intermediate_cols[agg_index], + alias, + }); + agg_index += 1; + } + _ => unreachable!( + "check for fn is already done. this should not happen." + ), + }, _ => {} - }, - // a non-aggregate operation - _ => { - let projection = Instruction::Project { - input: original_table_reg_index, - output: table_reg_index, - expr, - alias, - }; - ctx.instrs.push(projection) } } } From 33a5577c1cda674e066f8030ec89caa12d752b28 Mon Sep 17 00:00:00 2001 From: Samyak Sarnayak Date: Sun, 12 Mar 2023 23:10:57 +0530 Subject: [PATCH 09/26] fix(codegen): incorrect agg checks. Non-groupby tests pass now. --- src/codegen.rs | 97 ++++++++++++++++++++++++-------------------------- 1 file changed, 46 insertions(+), 51 deletions(-) diff --git a/src/codegen.rs b/src/codegen.rs index 5aa74d1..afb1062 100644 --- a/src/codegen.rs +++ b/src/codegen.rs @@ -349,58 +349,53 @@ pub fn codegen_ast(ast: &Statement) -> Result { let pre_grouped_reg_index = table_reg_index; let mut agg_intermediate_cols = Vec::new(); if !select.projection.is_empty() { - let has_aggs = select.projection.iter().any(|p| is_projection_agg(p)); - - if has_aggs { - let grouped_reg_index = ctx.get_and_increment_reg(); - ctx.instrs.push(Instruction::Empty { - index: grouped_reg_index, - }); + let grouped_reg_index = ctx.get_and_increment_reg(); + ctx.instrs.push(Instruction::Empty { + index: grouped_reg_index, + }); - table_reg_index = grouped_reg_index; + table_reg_index = grouped_reg_index; - for projection in &select.projection { - let alias = extract_alias_from_project(&projection)?; + for projection in &select.projection { + let alias = extract_alias_from_project(&projection)?; - let expr_ast = extract_expr_ast_from_project(projection.clone()); + let expr_ast = extract_expr_ast_from_project(projection.clone()); - match expr_ast { - // an aggregate operation - Some(expr_ast) if is_expr_agg(&expr_ast) => match expr_ast { - ast::Expr::Function(ref f) => { - let expr = - codegen_fn_arg(&expr_ast, &f.args[0], &mut ctx)?; - let projected_col_name = ctx.get_new_temp_col(); - agg_intermediate_cols.push(projected_col_name); - ctx.instrs.push(Instruction::Project { - input: pre_grouped_reg_index, - output: grouped_reg_index, - expr, - alias: Some(projected_col_name), - }); - } - _ => {} - }, - // a non-aggregate operation - _ => { - let expr = match projection { - SelectItem::UnnamedExpr(ref expr) => { - codegen_expr(expr.clone(), &mut ctx)? - } - SelectItem::ExprWithAlias { ref expr, .. } => { - codegen_expr(expr.clone(), &mut ctx)? - } - SelectItem::QualifiedWildcard(_) => Expr::Wildcard, - SelectItem::Wildcard => Expr::Wildcard, - }; - let projection = Instruction::Project { + match expr_ast { + // an aggregate operation + Some(expr_ast) if is_expr_agg(&expr_ast) => match expr_ast { + ast::Expr::Function(ref f) => { + let expr = codegen_fn_arg(&expr_ast, &f.args[0], &mut ctx)?; + let projected_col_name = ctx.get_new_temp_col(); + agg_intermediate_cols.push(projected_col_name); + ctx.instrs.push(Instruction::Project { input: pre_grouped_reg_index, output: grouped_reg_index, expr, - alias, - }; - ctx.instrs.push(projection) + alias: Some(projected_col_name), + }); } + _ => {} + }, + // a non-aggregate operation + _ => { + let expr = match projection { + SelectItem::UnnamedExpr(ref expr) => { + codegen_expr(expr.clone(), &mut ctx)? + } + SelectItem::ExprWithAlias { ref expr, .. } => { + codegen_expr(expr.clone(), &mut ctx)? + } + SelectItem::QualifiedWildcard(_) => Expr::Wildcard, + SelectItem::Wildcard => Expr::Wildcard, + }; + let projection = Instruction::Project { + input: pre_grouped_reg_index, + output: grouped_reg_index, + expr, + alias, + }; + ctx.instrs.push(projection) } } } @@ -424,16 +419,16 @@ pub fn codegen_ast(ast: &Statement) -> Result { // aggs are applied on the grouped table created by the `GroupBy` instructions // generated above. if !select.projection.is_empty() { - let original_table_reg_index = table_reg_index; - table_reg_index = ctx.get_and_increment_reg(); - - ctx.instrs.push(Instruction::Empty { - index: table_reg_index, - }); - let has_aggs = select.projection.iter().any(|p| is_projection_agg(p)); if has_aggs { + let original_table_reg_index = table_reg_index; + table_reg_index = ctx.get_and_increment_reg(); + + ctx.instrs.push(Instruction::Empty { + index: table_reg_index, + }); + let mut agg_index = 0; for projection in &select.projection { let alias = extract_alias_from_project(&projection)?; From 35d41d9ee7a739c5e48e9df0bfa992be9f8677e0 Mon Sep 17 00:00:00 2001 From: Samyak Sarnayak Date: Sun, 12 Mar 2023 23:11:30 +0530 Subject: [PATCH 10/26] fix(codegen): remove unused fn --- src/codegen.rs | 54 +------------------------------------------------- 1 file changed, 1 insertion(+), 53 deletions(-) diff --git a/src/codegen.rs b/src/codegen.rs index afb1062..4e19bef 100644 --- a/src/codegen.rs +++ b/src/codegen.rs @@ -1,8 +1,6 @@ //! Intermediate code generation from the AST. use sqlparser::{ - ast::{ - self, Function, FunctionArg, SelectItem, SetExpr, Statement, TableFactor, TableWithJoins, - }, + ast::{self, FunctionArg, SelectItem, SetExpr, Statement, TableFactor, TableWithJoins}, parser::ParserError, }; @@ -677,56 +675,6 @@ fn is_projection_agg(p: &SelectItem) -> bool { } } -fn codegen_fn_agg( - expr_ast: &ast::Expr, - f: &Function, - alias: Option, - orig_table_reg: RegisterIndex, - intermediate_reg_index: RegisterIndex, - output_reg_index: RegisterIndex, - ctx: &mut CodegenContext, -) -> Result<(), ExprError> { - let fn_name = f.name.to_string().to_lowercase(); - if AGGREGATE_FUNCTIONS.contains(&fn_name) { - // TODO: support aggregate functions that take multiple args - if f.args.len() != 1 { - return Err(ExprError::Expr { - reason: "Aggregate functions that take more or less than one argument are not supported yet", - expr: expr_ast.clone(), - }); - } - - if ctx.is_inside_agg_fn { - return Err(ExprError::Expr { - reason: "Aggregate functions cannot be nested", - expr: expr_ast.clone(), - }); - } - ctx.is_inside_agg_fn = true; - - let projected_col_name = ctx.get_new_temp_col(); - - let expr = codegen_fn_arg(&expr_ast, &f.args[0], ctx)?; - ctx.instrs.push(Instruction::Project { - input: orig_table_reg, - output: intermediate_reg_index, - expr, - alias: Some(projected_col_name), - }); - - ctx.instrs.push(Instruction::Aggregate { - input: intermediate_reg_index, - output: output_reg_index, - func: AggregateFunction::from_name(fn_name.as_str())?, - col_name: projected_col_name, - alias, - }); - - ctx.is_inside_agg_fn = false; - } - Ok(()) -} - fn codegen_fn_arg( expr_ast: &ast::Expr, arg: &FunctionArg, From 8637feff75d4b5ecb12ebada7c1070fba9b5944e Mon Sep 17 00:00:00 2001 From: Samyak Sarnayak Date: Sun, 12 Mar 2023 23:11:49 +0530 Subject: [PATCH 11/26] test(codegen): fix groupby test output --- src/codegen.rs | 85 ++++++++++++++++++++++++++++++-------------------- 1 file changed, 52 insertions(+), 33 deletions(-) diff --git a/src/codegen.rs b/src/codegen.rs index 4e19bef..554917c 100644 --- a/src/codegen.rs +++ b/src/codegen.rs @@ -748,7 +748,7 @@ mod codegen_tests { use crate::{ codegen::codegen_ast, - expr::{BinOp, Expr}, + expr::{agg::AggregateFunction, BinOp, Expr}, identifier::{ColumnRef, SchemaRef, TableRef}, ir::Instruction, parser::parse, @@ -1380,7 +1380,7 @@ mod codegen_tests { FROM table1 WHERE col1 = 1 GROUP BY col2 - HAVING MAX(col3) > 10 + HAVING max_col3 > 10 ", |instrs| { assert_eq!( @@ -1405,30 +1405,6 @@ mod codegen_tests { right: Box::new(Expr::Value(Value::Int64(1))) }, }, - Instruction::GroupBy { - input: RegisterIndex::default(), - output: RegisterIndex::default().next_index(), - expr: Expr::ColumnRef(ColumnRef { - schema_name: None, - table_name: None, - col_name: "col2".into(), - }) - }, - Instruction::Filter { - index: RegisterIndex::default(), - expr: Expr::Binary { - left: Box::new(Expr::Function { - name: "MAX".into(), - args: vec![Expr::ColumnRef(ColumnRef { - schema_name: None, - table_name: None, - col_name: "col3".into(), - })] - }), - op: BinOp::GreaterThan, - right: Box::new(Expr::Value(Value::Int64(10))) - }, - }, Instruction::Empty { index: RegisterIndex::default().next_index() }, @@ -1445,18 +1421,61 @@ mod codegen_tests { Instruction::Project { input: RegisterIndex::default(), output: RegisterIndex::default().next_index(), - expr: Expr::Function { - name: "MAX".into(), - args: vec![Expr::ColumnRef(ColumnRef { + expr: Expr::ColumnRef(ColumnRef { + schema_name: None, + table_name: None, + col_name: "col3".into(), + }), + alias: Some("__otter_temp_col_1".into()) + }, + Instruction::Empty { + index: RegisterIndex::default().next_index().next_index() + }, + Instruction::GroupBy { + input: RegisterIndex::default().next_index(), + output: RegisterIndex::default().next_index().next_index(), + expr: Expr::ColumnRef(ColumnRef { + schema_name: None, + table_name: None, + col_name: "col2".into(), + }) + }, + Instruction::Empty { + index: RegisterIndex::default() + .next_index() + .next_index() + .next_index() + }, + Instruction::Aggregate { + input: RegisterIndex::default().next_index().next_index(), + output: RegisterIndex::default() + .next_index() + .next_index() + .next_index(), + func: AggregateFunction::Max, + col_name: "__otter_temp_col_1".into(), + alias: Some("max_col3".into()), + }, + Instruction::Filter { + index: RegisterIndex::default() + .next_index() + .next_index() + .next_index(), + expr: Expr::Binary { + left: Box::new(Expr::ColumnRef(ColumnRef { schema_name: None, table_name: None, - col_name: "col3".into(), - })] + col_name: "max_col3".into(), + })), + op: BinOp::GreaterThan, + right: Box::new(Expr::Value(Value::Int64(10))) }, - alias: Some("max_col3".into()) }, Instruction::Return { - index: RegisterIndex::default().next_index(), + index: RegisterIndex::default() + .next_index() + .next_index() + .next_index(), } ] ) From f22dc6cd47aee14583123bbe802643993a26e92a Mon Sep 17 00:00:00 2001 From: Samyak Sarnayak Date: Sun, 12 Mar 2023 23:12:49 +0530 Subject: [PATCH 12/26] chore(codegen): remove unused field in context --- src/codegen.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/codegen.rs b/src/codegen.rs index 554917c..e1ecb30 100644 --- a/src/codegen.rs +++ b/src/codegen.rs @@ -64,7 +64,6 @@ struct CodegenContext { pub instrs: Vec, current_reg: RegisterIndex, last_temp_col_num: usize, - is_inside_agg_fn: bool, } impl CodegenContext { @@ -73,7 +72,6 @@ impl CodegenContext { instrs: Vec::new(), current_reg: RegisterIndex::default(), last_temp_col_num: 0, - is_inside_agg_fn: false, } } From 0c25910c02600a57f3eae049aeeb3cba3638bc77 Mon Sep 17 00:00:00 2001 From: Samyak Sarnayak Date: Tue, 14 Mar 2023 18:38:07 +0530 Subject: [PATCH 13/26] test(codegen): add case of multiple groupbys --- src/codegen.rs | 133 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 133 insertions(+) diff --git a/src/codegen.rs b/src/codegen.rs index e1ecb30..a16d8ea 100644 --- a/src/codegen.rs +++ b/src/codegen.rs @@ -1479,6 +1479,139 @@ mod codegen_tests { ) }, ); + + check_single_statement( + "SELECT col2, col3, SUM(col4 * col4) AS sos + FROM table1 + WHERE col1 = 1 + GROUP BY col2, col3 + ", + |instrs| { + assert_eq!( + &[ + Instruction::Source { + index: RegisterIndex::default(), + name: TableRef { + schema_name: None, + table_name: "table1".into() + } + }, + Instruction::Filter { + index: RegisterIndex::default(), + expr: Expr::Binary { + left: Box::new(Expr::ColumnRef(ColumnRef { + schema_name: None, + table_name: None, + col_name: "col1".into(), + },)), + op: BinOp::Equal, + right: Box::new(Expr::Value(Value::Int64(1))) + }, + }, + Instruction::Empty { + index: RegisterIndex::default().next_index() + }, + Instruction::Project { + input: RegisterIndex::default(), + output: RegisterIndex::default().next_index(), + expr: Expr::ColumnRef(ColumnRef { + schema_name: None, + table_name: None, + col_name: "col2".into(), + }), + alias: None + }, + Instruction::Project { + input: RegisterIndex::default(), + output: RegisterIndex::default().next_index(), + expr: Expr::ColumnRef(ColumnRef { + schema_name: None, + table_name: None, + col_name: "col3".into(), + }), + alias: None + }, + Instruction::Project { + input: RegisterIndex::default(), + output: RegisterIndex::default().next_index(), + expr: Expr::Binary { + left: Box::new(Expr::ColumnRef(ColumnRef { + schema_name: None, + table_name: None, + col_name: "col4".into(), + })), + op: BinOp::Multiply, + right: Box::new(Expr::ColumnRef(ColumnRef { + schema_name: None, + table_name: None, + col_name: "col4".into(), + })) + }, + alias: Some("__otter_temp_col_1".into()) + }, + Instruction::Empty { + index: RegisterIndex::default().next_index().next_index() + }, + Instruction::GroupBy { + input: RegisterIndex::default().next_index(), + output: RegisterIndex::default().next_index().next_index(), + expr: Expr::ColumnRef(ColumnRef { + schema_name: None, + table_name: None, + col_name: "col2".into(), + }) + }, + Instruction::Empty { + index: RegisterIndex::default() + .next_index() + .next_index() + .next_index() + }, + Instruction::GroupBy { + input: RegisterIndex::default().next_index().next_index(), + output: RegisterIndex::default() + .next_index() + .next_index() + .next_index(), + expr: Expr::ColumnRef(ColumnRef { + schema_name: None, + table_name: None, + col_name: "col3".into(), + }) + }, + Instruction::Empty { + index: RegisterIndex::default() + .next_index() + .next_index() + .next_index() + .next_index() + }, + Instruction::Aggregate { + input: RegisterIndex::default() + .next_index() + .next_index() + .next_index(), + output: RegisterIndex::default() + .next_index() + .next_index() + .next_index() + .next_index(), + func: AggregateFunction::Sum, + col_name: "__otter_temp_col_1".into(), + alias: Some("sos".into()), + }, + Instruction::Return { + index: RegisterIndex::default() + .next_index() + .next_index() + .next_index() + .next_index(), + } + ], + instrs + ) + }, + ); } } From 1acd25e809ab932a2d2aea58563abeb075e11e5d Mon Sep 17 00:00:00 2001 From: Samyak Sarnayak Date: Tue, 14 Mar 2023 18:38:24 +0530 Subject: [PATCH 14/26] fix(codegen): derive Debug for AggregateFunction --- src/expr/agg.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/expr/agg.rs b/src/expr/agg.rs index 85419aa..0f3d9c7 100644 --- a/src/expr/agg.rs +++ b/src/expr/agg.rs @@ -1,7 +1,7 @@ use super::ExprError; use std::fmt::Display; -#[derive(Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq)] /// Functions that reduce an entire column to a single value. pub enum AggregateFunction { Count, From 12101259498549fcf627bf32378977e8716d9c26 Mon Sep 17 00:00:00 2001 From: Samyak Sarnayak Date: Sun, 9 Apr 2023 16:53:19 +0530 Subject: [PATCH 15/26] fix(codegen): throw error when having clause has inline aggs --- src/codegen.rs | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/codegen.rs b/src/codegen.rs index a16d8ea..c5d0c35 100644 --- a/src/codegen.rs +++ b/src/codegen.rs @@ -464,6 +464,16 @@ pub fn codegen_ast(ast: &Statement) -> Result { } if let Some(expr) = select.having.clone() { + if is_expr_agg(&expr) { + return Err(CodegenError::UnsupportedStatementForm( + concat!( + "HAVING clause does not support inline aggregations.", + " Select the expression `AS some_col_name` ", + "and then use `HAVING` on `some_col_name`." + ), + select.to_string(), + )); + } let expr = codegen_expr(expr, &mut ctx)?; ctx.instrs.push(Instruction::Filter { index: table_reg_index, From 2984aec18fe8db75cd2279a9cc373d65b3d32e8d Mon Sep 17 00:00:00 2001 From: Samyak Sarnayak Date: Sun, 9 Apr 2023 16:54:00 +0530 Subject: [PATCH 16/26] wip(codegen): IntermediateExpr and codegen for it Not yet done for Function --- src/codegen.rs | 329 +++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 292 insertions(+), 37 deletions(-) diff --git a/src/codegen.rs b/src/codegen.rs index c5d0c35..e87b00a 100644 --- a/src/codegen.rs +++ b/src/codegen.rs @@ -584,36 +584,114 @@ pub fn codegen_ast(ast: &Statement) -> Result { Ok(IntermediateCode { instrs: ctx.instrs }) } -fn codegen_expr(expr_ast: ast::Expr, ctx: &mut CodegenContext) -> Result { +#[derive(Debug)] +struct IntermediateExprAgg { + pub pre_agg: Vec<(Expr, BoundedString)>, + pub agg: Vec<(AggregateFunction, BoundedString, BoundedString)>, + pub post_agg: Vec, + last_alias: Option, + last_expr: (Expr, BoundedString), +} + +#[derive(Debug)] +enum IntermediateExpr { + NonAgg(Expr), + Agg(IntermediateExprAgg), +} + +impl IntermediateExpr { + pub fn new_non_agg(expr: Expr) -> Self { + Self::NonAgg(expr) + } + + /// The last expression that was generated. + pub fn last_expr(&self) -> &Expr { + match self { + Self::NonAgg(e) => e, + Self::Agg(agg) => &agg.last_expr.0, + } + } + + pub fn combine(self, new: IntermediateExpr) -> Self { + match self { + Self::NonAgg(sel) => new, + Self::Agg(sel) => match new { + Self::NonAgg(new) => { + // TODO: last_expr may need updating here? + if sel.post_agg.len() <= 1 { + sel.post_agg = vec![new]; + } else { + *sel.post_agg.last_mut().unwrap() = new; + } + Self::Agg(sel) + } + Self::Agg(new) => { + // TODO: last_expr may need updating here? + sel.pre_agg.extend_from_slice(&new.pre_agg); + sel.agg.extend_from_slice(&new.agg); + sel.post_agg.extend_from_slice(&new.post_agg); + Self::Agg(sel) + } + }, + } + } +} + +fn codegen_expr( + expr_ast: ast::Expr, + ctx: &mut CodegenContext, +) -> Result { match expr_ast { - ast::Expr::Identifier(i) => Ok(Expr::ColumnRef(vec![i].try_into()?)), - ast::Expr::CompoundIdentifier(i) => Ok(Expr::ColumnRef(i.try_into()?)), - ast::Expr::IsFalse(e) => Ok(Expr::Unary { - op: UnOp::IsFalse, - operand: Box::new(codegen_expr(*e, ctx)?), - }), - ast::Expr::IsTrue(e) => Ok(Expr::Unary { - op: UnOp::IsTrue, - operand: Box::new(codegen_expr(*e, ctx)?), - }), - ast::Expr::IsNull(e) => Ok(Expr::Unary { - op: UnOp::IsNull, - operand: Box::new(codegen_expr(*e, ctx)?), - }), - ast::Expr::IsNotNull(e) => Ok(Expr::Unary { - op: UnOp::IsNotNull, - operand: Box::new(codegen_expr(*e, ctx)?), - }), + ast::Expr::Identifier(i) => Ok(IntermediateExpr::new_non_agg(Expr::ColumnRef( + vec![i].try_into()?, + ))), + ast::Expr::CompoundIdentifier(i) => Ok(IntermediateExpr::new_non_agg(Expr::ColumnRef( + i.try_into()?, + ))), + ast::Expr::IsFalse(e) => { + let inner = codegen_expr(*e, ctx)?; + Ok(inner.combine(IntermediateExpr::new_non_agg(Expr::Unary { + op: UnOp::IsFalse, + operand: Box::new(inner.last_expr().clone()), + }))) + } + ast::Expr::IsTrue(e) => { + let inner = codegen_expr(*e, ctx)?; + Ok(inner.combine(IntermediateExpr::new_non_agg(Expr::Unary { + op: UnOp::IsTrue, + operand: Box::new(inner.last_expr().clone()), + }))) + } + ast::Expr::IsNull(e) => { + let inner = codegen_expr(*e, ctx)?; + Ok(inner.combine(IntermediateExpr::new_non_agg(Expr::Unary { + op: UnOp::IsNull, + operand: Box::new(inner.last_expr().clone()), + }))) + } + ast::Expr::IsNotNull(e) => { + let inner = codegen_expr(*e, ctx)?; + Ok(inner.combine(IntermediateExpr::new_non_agg(Expr::Unary { + op: UnOp::IsNotNull, + operand: Box::new(inner.last_expr().clone()), + }))) + } ast::Expr::Between { expr, negated, low, high, } => { - let expr: Box = Box::new(codegen_expr(*expr, ctx)?); - let left = Box::new(codegen_expr(*low, ctx)?); - let right = Box::new(codegen_expr(*high, ctx)?); - let between = Expr::Binary { + let expr_gen = codegen_expr(*expr, ctx)?; + let expr: Box = Box::new(expr_gen.last_expr().clone()); + + let left_gen = codegen_expr(*low, ctx)?; + let left = Box::new(left_gen.last_expr().clone()); + + let right_gen = codegen_expr(*high, ctx)?; + let right = Box::new(right_gen.last_expr().clone()); + + let between_gen = IntermediateExpr::new_non_agg(Expr::Binary { left: Box::new(Expr::Binary { left, op: BinOp::LessThanOrEqual, @@ -625,26 +703,46 @@ fn codegen_expr(expr_ast: ast::Expr, ctx: &mut CodegenContext) -> Result Ok(Expr::Binary { - left: Box::new(codegen_expr(*left, ctx)?), - op: op.try_into()?, - right: Box::new(codegen_expr(*right, ctx)?), - }), - ast::Expr::UnaryOp { op, expr } => Ok(Expr::Unary { - op: op.try_into()?, - operand: Box::new(codegen_expr(*expr, ctx)?), - }), - ast::Expr::Value(v) => Ok(Expr::Value(v.try_into()?)), + ast::Expr::BinaryOp { left, op, right } => { + let left = codegen_expr(*left, ctx)?; + let left_operand = Box::new(left.last_expr().clone()); + let right = codegen_expr(*right, ctx)?; + let right_operand = Box::new(right.last_expr().clone()); + + let binary_expr = Expr::Binary { + left: left_operand, + op: op.try_into()?, + right: right_operand, + }; + + Ok(left + .combine(right) + .combine(IntermediateExpr::new_non_agg(binary_expr))) + } + ast::Expr::UnaryOp { op, expr } => { + let inner = codegen_expr(*expr, ctx)?; + Ok(inner.combine(IntermediateExpr::new_non_agg(Expr::Unary { + op: op.try_into()?, + operand: Box::new(inner.last_expr().clone()), + }))) + } + ast::Expr::Value(v) => Ok(IntermediateExpr::new_non_agg(Expr::Value(v.try_into()?))), ast::Expr::Function(ref f) => { let fn_name = f.name.to_string(); Ok(Expr::Function { @@ -663,6 +761,85 @@ fn codegen_expr(expr_ast: ast::Expr, ctx: &mut CodegenContext) -> Result Result { +// match expr_ast { +// ast::Expr::Identifier(i) => Ok(Expr::ColumnRef(vec![i].try_into()?)), +// ast::Expr::CompoundIdentifier(i) => Ok(Expr::ColumnRef(i.try_into()?)), +// ast::Expr::IsFalse(e) => Ok(Expr::Unary { +// op: UnOp::IsFalse, +// operand: Box::new(codegen_expr(*e, ctx)?), +// }), +// ast::Expr::IsTrue(e) => Ok(Expr::Unary { +// op: UnOp::IsTrue, +// operand: Box::new(codegen_expr(*e, ctx)?), +// }), +// ast::Expr::IsNull(e) => Ok(Expr::Unary { +// op: UnOp::IsNull, +// operand: Box::new(codegen_expr(*e, ctx)?), +// }), +// ast::Expr::IsNotNull(e) => Ok(Expr::Unary { +// op: UnOp::IsNotNull, +// operand: Box::new(codegen_expr(*e, ctx)?), +// }), +// ast::Expr::Between { +// expr, +// negated, +// low, +// high, +// } => { +// let expr: Box = Box::new(codegen_expr(*expr, ctx)?); +// let left = Box::new(codegen_expr(*low, ctx)?); +// let right = Box::new(codegen_expr(*high, ctx)?); +// let between = Expr::Binary { +// left: Box::new(Expr::Binary { +// left, +// op: BinOp::LessThanOrEqual, +// right: expr.clone(), +// }), +// op: BinOp::And, +// right: Box::new(Expr::Binary { +// left: expr, +// op: BinOp::LessThanOrEqual, +// right, +// }), +// }; +// if negated { +// Ok(Expr::Unary { +// op: UnOp::Not, +// operand: Box::new(between), +// }) +// } else { +// Ok(between) +// } +// } +// ast::Expr::BinaryOp { left, op, right } => Ok(Expr::Binary { +// left: Box::new(codegen_expr(*left, ctx)?), +// op: op.try_into()?, +// right: Box::new(codegen_expr(*right, ctx)?), +// }), +// ast::Expr::UnaryOp { op, expr } => Ok(Expr::Unary { +// op: op.try_into()?, +// operand: Box::new(codegen_expr(*expr, ctx)?), +// }), +// ast::Expr::Value(v) => Ok(Expr::Value(v.try_into()?)), +// ast::Expr::Function(ref f) => { +// let fn_name = f.name.to_string(); +// Ok(Expr::Function { +// name: fn_name.as_str().into(), +// args: f +// .args +// .iter() +// .map(|arg| codegen_fn_arg(&expr_ast, arg, ctx)) +// .collect::, _>>()?, +// }) +// } +// _ => Err(ExprError::Expr { +// reason: "Unsupported expression", +// expr: expr_ast, +// }), +// } +// } + fn is_fn_name_aggregate(fn_name: &str) -> bool { AGGREGATE_FUNCTIONS.contains(fn_name) } @@ -1490,6 +1667,84 @@ mod codegen_tests { }, ); + check_single_statement( + "SELECT col2, MAX(col3) + 1 AS max_col3 + FROM table1 + GROUP BY col2 + ", + |instrs| { + assert_eq!( + &[ + Instruction::Source { + index: RegisterIndex::default(), + name: TableRef { + schema_name: None, + table_name: "table1".into() + } + }, + Instruction::Empty { + index: RegisterIndex::default().next_index() + }, + Instruction::Project { + input: RegisterIndex::default(), + output: RegisterIndex::default().next_index(), + expr: Expr::ColumnRef(ColumnRef { + schema_name: None, + table_name: None, + col_name: "col2".into(), + }), + alias: None + }, + Instruction::Project { + input: RegisterIndex::default(), + output: RegisterIndex::default().next_index(), + expr: Expr::ColumnRef(ColumnRef { + schema_name: None, + table_name: None, + col_name: "col3".into(), + }), + alias: Some("__otter_temp_col_1".into()) + }, + Instruction::Empty { + index: RegisterIndex::default().next_index().next_index() + }, + Instruction::GroupBy { + input: RegisterIndex::default().next_index(), + output: RegisterIndex::default().next_index().next_index(), + expr: Expr::ColumnRef(ColumnRef { + schema_name: None, + table_name: None, + col_name: "col2".into(), + }) + }, + Instruction::Empty { + index: RegisterIndex::default() + .next_index() + .next_index() + .next_index() + }, + Instruction::Aggregate { + input: RegisterIndex::default().next_index().next_index(), + output: RegisterIndex::default() + .next_index() + .next_index() + .next_index(), + func: AggregateFunction::Max, + col_name: "__otter_temp_col_1".into(), + alias: Some("max_col3".into()), + }, + Instruction::Return { + index: RegisterIndex::default() + .next_index() + .next_index() + .next_index(), + } + ], + instrs, + ) + }, + ); + check_single_statement( "SELECT col2, col3, SUM(col4 * col4) AS sos FROM table1 From 6e6511a063598d89cce625d84de483ec1b3fd072 Mon Sep 17 00:00:00 2001 From: Samyak Sarnayak Date: Fri, 5 May 2023 23:28:56 +0530 Subject: [PATCH 17/26] feat(codegen): impl IntermediateExpr-based codegen for function --- src/codegen.rs | 76 ++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 65 insertions(+), 11 deletions(-) diff --git a/src/codegen.rs b/src/codegen.rs index e87b00a..9bb8808 100644 --- a/src/codegen.rs +++ b/src/codegen.rs @@ -10,7 +10,7 @@ use std::{error::Error, fmt::Display}; use crate::{ expr::{agg::AggregateFunction, BinOp, Expr, ExprError, UnOp}, - identifier::IdentifierError, + identifier::{ColumnRef, IdentifierError}, ir::{Instruction, IntermediateCode}, parser::parse, value::{Value, ValueError}, @@ -745,14 +745,68 @@ fn codegen_expr( ast::Expr::Value(v) => Ok(IntermediateExpr::new_non_agg(Expr::Value(v.try_into()?))), ast::Expr::Function(ref f) => { let fn_name = f.name.to_string(); - Ok(Expr::Function { - name: fn_name.as_str().into(), - args: f - .args - .iter() - .map(|arg| codegen_fn_arg(&expr_ast, arg, ctx)) - .collect::, _>>()?, - }) + let args = f + .args + .iter() + .map(|arg| { + let ie = codegen_fn_arg(&expr_ast, arg, ctx)?; + Ok::<(IntermediateExpr, Expr), ExprError>((ie, ie.last_expr().clone())) + }) + .collect::, _>>()?; + if is_fn_name_aggregate(&fn_name.to_lowercase()) { + if args.len() > 1 { + Err(ExprError::Expr { + reason: "Aggregates with more than one arguments are not supported yet.", + expr: expr_ast, + }) + } else { + let args = args + .into_iter() + .map(|a| match a { + (IntermediateExpr::Agg(_), _) => Err(ExprError::Expr { + reason: "Aggregates within aggregates are not supported yet", + expr: expr_ast, + }), + (IntermediateExpr::NonAgg(e), last_expr) => { + Ok((e, ctx.get_new_temp_col())) + } + }) + .collect::, _>>()?; + let agg_result_col = ctx.get_new_temp_col(); + let agg_col_res = Expr::ColumnRef(ColumnRef { + schema_name: None, + table_name: None, + col_name: agg_result_col, + }); + Ok(IntermediateExpr::Agg(IntermediateExprAgg { + pre_agg: args, + agg: vec![( + AggregateFunction::from_name(&fn_name.to_lowercase())?, + args[0].1, + agg_result_col, + )], + post_agg: vec![agg_col_res], + last_alias: Some(agg_result_col), + last_expr: (agg_col_res, agg_result_col), + })) + } + } else { + if args.is_empty() { + Ok(IntermediateExpr::new_non_agg(Expr::Function { + name: fn_name.as_str().into(), + args: args.iter().map(|ie| ie.1).collect(), + })) + } else { + let start = args[0].0; + let combined = args.iter().skip(1).fold(start, |acc, ie| acc.combine(ie.0)); + Ok( + combined.combine(IntermediateExpr::new_non_agg(Expr::Function { + name: fn_name.as_str().into(), + args: args.iter().map(|ie| ie.1).collect(), + })), + ) + } + } } _ => Err(ExprError::Expr { reason: "Unsupported expression", @@ -864,11 +918,11 @@ fn codegen_fn_arg( expr_ast: &ast::Expr, arg: &FunctionArg, ctx: &mut CodegenContext, -) -> Result { +) -> Result { match arg { ast::FunctionArg::Unnamed(arg_expr) => match arg_expr { ast::FunctionArgExpr::Expr(e) => Ok(codegen_expr(e.clone(), ctx)?), - ast::FunctionArgExpr::Wildcard => Ok(Expr::Wildcard), + ast::FunctionArgExpr::Wildcard => Ok(IntermediateExpr::new_non_agg(Expr::Wildcard)), ast::FunctionArgExpr::QualifiedWildcard(_) => Err(ExprError::Expr { reason: "Qualified wildcards are not supported yet", expr: expr_ast.clone(), From 58c8a75b702f608a1dd000163ac7243d7602164a Mon Sep 17 00:00:00 2001 From: Samyak Sarnayak Date: Fri, 5 May 2023 23:44:55 +0530 Subject: [PATCH 18/26] feat(codegen): handle IntermediateExpr in non agg cases --- src/codegen.rs | 62 ++++++++++++++++++++++++++++++++++---------------- 1 file changed, 43 insertions(+), 19 deletions(-) diff --git a/src/codegen.rs b/src/codegen.rs index 9bb8808..7458765 100644 --- a/src/codegen.rs +++ b/src/codegen.rs @@ -237,12 +237,15 @@ pub fn codegen_ast(ast: &Statement) -> Result { row_index: row_reg, }); - for value in row { - let value = codegen_expr(value, &mut ctx)?; + for value_ast in row { + let value = codegen_expr(value_ast, &mut ctx)?.get_non_agg( + "Aggregate expressions are not supported in values", + value_ast, + )?; ctx.instrs.push(Instruction::AddValue { row_index: row_reg, expr: value, - }); + }) } } Ok(()) @@ -328,8 +331,11 @@ pub fn codegen_ast(ast: &Statement) -> Result { } } - if let Some(expr) = select.selection.clone() { - let expr = codegen_expr(expr, &mut ctx)?; + if let Some(expr_ast) = select.selection.clone() { + let expr = codegen_expr(expr_ast, &mut ctx)?.get_non_agg( + "Aggregate expressions are not supported in WHERE clause. Use HAVING clause instead", + expr_ast, + )?; ctx.instrs.push(Instruction::Filter { index: table_reg_index, expr, @@ -463,18 +469,15 @@ pub fn codegen_ast(ast: &Statement) -> Result { } } - if let Some(expr) = select.having.clone() { - if is_expr_agg(&expr) { - return Err(CodegenError::UnsupportedStatementForm( - concat!( - "HAVING clause does not support inline aggregations.", - " Select the expression `AS some_col_name` ", - "and then use `HAVING` on `some_col_name`." - ), - select.to_string(), - )); - } - let expr = codegen_expr(expr, &mut ctx)?; + if let Some(expr_ast) = select.having.clone() { + let expr = codegen_expr(expr_ast, &mut ctx)?.get_non_agg( + concat!( + "HAVING clause does not support inline aggregations.", + " Select the expression `AS some_col_name` ", + "and then use `HAVING` on `some_col_name`." + ), + expr_ast, + )?; ctx.instrs.push(Instruction::Filter { index: table_reg_index, expr, @@ -483,7 +486,11 @@ pub fn codegen_ast(ast: &Statement) -> Result { } SetExpr::Values(exprs) => { if exprs.0.len() == 1 && exprs.0[0].len() == 1 { - let expr: Expr = codegen_expr(exprs.0[0][0].clone(), &mut ctx)?; + let expr_ast = exprs.0[0][0].clone(); + let expr = codegen_expr(expr_ast, &mut ctx)?.get_non_agg( + "Aggregate expressions are not supported in values", + expr_ast, + )?; ctx.instrs.push(Instruction::Expr { index: table_reg_index, expr, @@ -530,7 +537,10 @@ pub fn codegen_ast(ast: &Statement) -> Result { }; for order_by in query.order_by.clone() { - let order_by_expr = codegen_expr(order_by.expr, &mut ctx)?; + let order_by_expr = codegen_expr(order_by.expr, &mut ctx)?.get_non_agg( + "Aggregate expressions are not supported in ORDER BY", + order_by.expr, + )?; ctx.instrs.push(Instruction::Order { index: table_reg_index, expr: order_by_expr, @@ -635,6 +645,20 @@ impl IntermediateExpr { }, } } + + pub fn get_non_agg( + self, + err_reason: &'static str, + expr_ast: ast::Expr, + ) -> Result { + match self { + IntermediateExpr::Agg(_) => Err(ExprError::Expr { + reason: err_reason, + expr: expr_ast, + }), + IntermediateExpr::NonAgg(e) => Ok(e), + } + } } fn codegen_expr( From a19f22ea750183d5ffed6d6dd1719f55917d1cd8 Mon Sep 17 00:00:00 2001 From: Samyak Sarnayak Date: Sat, 6 May 2023 16:56:33 +0530 Subject: [PATCH 19/26] fix(codegen): lifetime/borrowing issues --- src/codegen.rs | 79 ++++++++++++++++++++++++++++---------------------- 1 file changed, 44 insertions(+), 35 deletions(-) diff --git a/src/codegen.rs b/src/codegen.rs index 7458765..875d67a 100644 --- a/src/codegen.rs +++ b/src/codegen.rs @@ -238,7 +238,7 @@ pub fn codegen_ast(ast: &Statement) -> Result { }); for value_ast in row { - let value = codegen_expr(value_ast, &mut ctx)?.get_non_agg( + let value = codegen_expr(value_ast.clone(), &mut ctx)?.get_non_agg( "Aggregate expressions are not supported in values", value_ast, )?; @@ -332,7 +332,7 @@ pub fn codegen_ast(ast: &Statement) -> Result { } if let Some(expr_ast) = select.selection.clone() { - let expr = codegen_expr(expr_ast, &mut ctx)?.get_non_agg( + let expr = codegen_expr(expr_ast.clone(), &mut ctx)?.get_non_agg( "Aggregate expressions are not supported in WHERE clause. Use HAVING clause instead", expr_ast, )?; @@ -470,7 +470,7 @@ pub fn codegen_ast(ast: &Statement) -> Result { } if let Some(expr_ast) = select.having.clone() { - let expr = codegen_expr(expr_ast, &mut ctx)?.get_non_agg( + let expr = codegen_expr(expr_ast.clone(), &mut ctx)?.get_non_agg( concat!( "HAVING clause does not support inline aggregations.", " Select the expression `AS some_col_name` ", @@ -487,7 +487,7 @@ pub fn codegen_ast(ast: &Statement) -> Result { SetExpr::Values(exprs) => { if exprs.0.len() == 1 && exprs.0[0].len() == 1 { let expr_ast = exprs.0[0][0].clone(); - let expr = codegen_expr(expr_ast, &mut ctx)?.get_non_agg( + let expr = codegen_expr(expr_ast.clone(), &mut ctx)?.get_non_agg( "Aggregate expressions are not supported in values", expr_ast, )?; @@ -537,7 +537,7 @@ pub fn codegen_ast(ast: &Statement) -> Result { }; for order_by in query.order_by.clone() { - let order_by_expr = codegen_expr(order_by.expr, &mut ctx)?.get_non_agg( + let order_by_expr = codegen_expr(order_by.expr.clone(), &mut ctx)?.get_non_agg( "Aggregate expressions are not supported in ORDER BY", order_by.expr, )?; @@ -594,7 +594,7 @@ pub fn codegen_ast(ast: &Statement) -> Result { Ok(IntermediateCode { instrs: ctx.instrs }) } -#[derive(Debug)] +#[derive(Debug, Clone)] struct IntermediateExprAgg { pub pre_agg: Vec<(Expr, BoundedString)>, pub agg: Vec<(AggregateFunction, BoundedString, BoundedString)>, @@ -603,7 +603,7 @@ struct IntermediateExprAgg { last_expr: (Expr, BoundedString), } -#[derive(Debug)] +#[derive(Debug, Clone)] enum IntermediateExpr { NonAgg(Expr), Agg(IntermediateExprAgg), @@ -624,8 +624,8 @@ impl IntermediateExpr { pub fn combine(self, new: IntermediateExpr) -> Self { match self { - Self::NonAgg(sel) => new, - Self::Agg(sel) => match new { + Self::NonAgg(_) => new, + Self::Agg(mut sel) => match new { Self::NonAgg(new) => { // TODO: last_expr may need updating here? if sel.post_agg.len() <= 1 { @@ -674,31 +674,35 @@ fn codegen_expr( ))), ast::Expr::IsFalse(e) => { let inner = codegen_expr(*e, ctx)?; - Ok(inner.combine(IntermediateExpr::new_non_agg(Expr::Unary { + let new_expr = Expr::Unary { op: UnOp::IsFalse, operand: Box::new(inner.last_expr().clone()), - }))) + }; + Ok(inner.combine(IntermediateExpr::new_non_agg(new_expr))) } ast::Expr::IsTrue(e) => { let inner = codegen_expr(*e, ctx)?; - Ok(inner.combine(IntermediateExpr::new_non_agg(Expr::Unary { + let new_expr = Expr::Unary { op: UnOp::IsTrue, operand: Box::new(inner.last_expr().clone()), - }))) + }; + Ok(inner.combine(IntermediateExpr::new_non_agg(new_expr))) } ast::Expr::IsNull(e) => { let inner = codegen_expr(*e, ctx)?; - Ok(inner.combine(IntermediateExpr::new_non_agg(Expr::Unary { + let new_expr = Expr::Unary { op: UnOp::IsNull, operand: Box::new(inner.last_expr().clone()), - }))) + }; + Ok(inner.combine(IntermediateExpr::new_non_agg(new_expr))) } ast::Expr::IsNotNull(e) => { let inner = codegen_expr(*e, ctx)?; - Ok(inner.combine(IntermediateExpr::new_non_agg(Expr::Unary { + let new_expr = Expr::Unary { op: UnOp::IsNotNull, operand: Box::new(inner.last_expr().clone()), - }))) + }; + Ok(inner.combine(IntermediateExpr::new_non_agg(new_expr))) } ast::Expr::Between { expr, @@ -728,6 +732,7 @@ fn codegen_expr( right, }), }); + let between_last_expr = between_gen.last_expr().clone(); let between = expr_gen .combine(left_gen) @@ -737,7 +742,7 @@ fn codegen_expr( if negated { Ok(between.combine(IntermediateExpr::new_non_agg(Expr::Unary { op: UnOp::Not, - operand: Box::new(between_gen.last_expr().clone()), + operand: Box::new(between_last_expr), }))) } else { Ok(between) @@ -761,10 +766,11 @@ fn codegen_expr( } ast::Expr::UnaryOp { op, expr } => { let inner = codegen_expr(*expr, ctx)?; - Ok(inner.combine(IntermediateExpr::new_non_agg(Expr::Unary { + let new_expr = Expr::Unary { op: op.try_into()?, operand: Box::new(inner.last_expr().clone()), - }))) + }; + Ok(inner.combine(IntermediateExpr::new_non_agg(new_expr))) } ast::Expr::Value(v) => Ok(IntermediateExpr::new_non_agg(Expr::Value(v.try_into()?))), ast::Expr::Function(ref f) => { @@ -774,7 +780,8 @@ fn codegen_expr( .iter() .map(|arg| { let ie = codegen_fn_arg(&expr_ast, arg, ctx)?; - Ok::<(IntermediateExpr, Expr), ExprError>((ie, ie.last_expr().clone())) + let last_expr = ie.last_expr().clone(); + Ok::<(IntermediateExpr, Expr), ExprError>((ie, last_expr)) }) .collect::, _>>()?; if is_fn_name_aggregate(&fn_name.to_lowercase()) { @@ -789,11 +796,9 @@ fn codegen_expr( .map(|a| match a { (IntermediateExpr::Agg(_), _) => Err(ExprError::Expr { reason: "Aggregates within aggregates are not supported yet", - expr: expr_ast, + expr: expr_ast.clone(), }), - (IntermediateExpr::NonAgg(e), last_expr) => { - Ok((e, ctx.get_new_temp_col())) - } + (IntermediateExpr::NonAgg(e), _) => Ok((e, ctx.get_new_temp_col())), }) .collect::, _>>()?; let agg_result_col = ctx.get_new_temp_col(); @@ -802,14 +807,15 @@ fn codegen_expr( table_name: None, col_name: agg_result_col, }); + let agg = vec![( + AggregateFunction::from_name(&fn_name.to_lowercase())?, + args[0].1, + agg_result_col, + )]; Ok(IntermediateExpr::Agg(IntermediateExprAgg { pre_agg: args, - agg: vec![( - AggregateFunction::from_name(&fn_name.to_lowercase())?, - args[0].1, - agg_result_col, - )], - post_agg: vec![agg_col_res], + agg, + post_agg: vec![agg_col_res.clone()], last_alias: Some(agg_result_col), last_expr: (agg_col_res, agg_result_col), })) @@ -818,15 +824,18 @@ fn codegen_expr( if args.is_empty() { Ok(IntermediateExpr::new_non_agg(Expr::Function { name: fn_name.as_str().into(), - args: args.iter().map(|ie| ie.1).collect(), + args: args.iter().map(|ie| ie.1.clone()).collect(), })) } else { - let start = args[0].0; - let combined = args.iter().skip(1).fold(start, |acc, ie| acc.combine(ie.0)); + let start = args[0].0.clone(); + let combined = args + .iter() + .skip(1) + .fold(start, |acc, ie| acc.combine(ie.0.clone())); Ok( combined.combine(IntermediateExpr::new_non_agg(Expr::Function { name: fn_name.as_str().into(), - args: args.iter().map(|ie| ie.1).collect(), + args: args.iter().map(|ie| ie.1.clone()).collect(), })), ) } From bb61b8dd7c32cc5284ab807f12f5a336f82cc387 Mon Sep 17 00:00:00 2001 From: Samyak Sarnayak Date: Sun, 7 May 2023 00:08:50 +0530 Subject: [PATCH 20/26] feat(codegen): use IntermediateExpr in actual codegen for select commented out the original version for reference --- src/codegen.rs | 266 +++++++++++++++++++++++++++++++++++++------------ 1 file changed, 204 insertions(+), 62 deletions(-) diff --git a/src/codegen.rs b/src/codegen.rs index 875d67a..4faaea1 100644 --- a/src/codegen.rs +++ b/src/codegen.rs @@ -342,15 +342,21 @@ pub fn codegen_ast(ast: &Statement) -> Result { }) } + let inter_exprs = select + .projection + .iter() + .cloned() + .map(|projection| codegen_selectitem(&projection, &mut ctx)) + .collect::, _>>()?; + // if there are groupby + aggregations, we project all operations within an // aggregation to another table first. for example, `SUM(col * col)` would be // evaluated as `Project (col * col)` into `%2` and then apply the group by on // `%2`. - // TODO: possible idea for refactor: make an intermediate representation of - // Projection that separates non-agg and agg projections. let pre_grouped_reg_index = table_reg_index; - let mut agg_intermediate_cols = Vec::new(); - if !select.projection.is_empty() { + // let mut agg_intermediate_cols = Vec::new(); + // if !select.projection.is_empty() { + if !inter_exprs.is_empty() { let grouped_reg_index = ctx.get_and_increment_reg(); ctx.instrs.push(Instruction::Empty { index: grouped_reg_index, @@ -358,18 +364,12 @@ pub fn codegen_ast(ast: &Statement) -> Result { table_reg_index = grouped_reg_index; - for projection in &select.projection { - let alias = extract_alias_from_project(&projection)?; - - let expr_ast = extract_expr_ast_from_project(projection.clone()); - - match expr_ast { - // an aggregate operation - Some(expr_ast) if is_expr_agg(&expr_ast) => match expr_ast { - ast::Expr::Function(ref f) => { - let expr = codegen_fn_arg(&expr_ast, &f.args[0], &mut ctx)?; - let projected_col_name = ctx.get_new_temp_col(); - agg_intermediate_cols.push(projected_col_name); + for (projection, inter_expr) in + select.projection.iter().zip(inter_exprs.iter()) + { + match inter_expr { + IntermediateExpr::Agg(agg) => { + for (expr, projected_col_name) in agg.pre_agg.clone() { ctx.instrs.push(Instruction::Project { input: pre_grouped_reg_index, output: grouped_reg_index, @@ -377,24 +377,13 @@ pub fn codegen_ast(ast: &Statement) -> Result { alias: Some(projected_col_name), }); } - _ => {} - }, - // a non-aggregate operation - _ => { - let expr = match projection { - SelectItem::UnnamedExpr(ref expr) => { - codegen_expr(expr.clone(), &mut ctx)? - } - SelectItem::ExprWithAlias { ref expr, .. } => { - codegen_expr(expr.clone(), &mut ctx)? - } - SelectItem::QualifiedWildcard(_) => Expr::Wildcard, - SelectItem::Wildcard => Expr::Wildcard, - }; + } + IntermediateExpr::NonAgg(expr) => { + let alias = extract_alias_from_project(&projection)?; let projection = Instruction::Project { input: pre_grouped_reg_index, output: grouped_reg_index, - expr, + expr: expr.clone(), alias, }; ctx.instrs.push(projection) @@ -403,8 +392,11 @@ pub fn codegen_ast(ast: &Statement) -> Result { } } - for group_by in select.group_by.clone() { - let group_by = codegen_expr(group_by, &mut ctx)?; + for group_by_ast in select.group_by.clone() { + let group_by = codegen_expr(group_by_ast.clone(), &mut ctx)?.get_non_agg( + "Aggregate expressions are not supported in the GROUP BY clause", + group_by_ast, + )?; let grouped_reg_index = ctx.get_and_increment_reg(); ctx.instrs.push(Instruction::Empty { index: grouped_reg_index, @@ -420,10 +412,13 @@ pub fn codegen_ast(ast: &Statement) -> Result { // this is only for aggregations. // aggs are applied on the grouped table created by the `GroupBy` instructions // generated above. - if !select.projection.is_empty() { - let has_aggs = select.projection.iter().any(|p| is_projection_agg(p)); + if !inter_exprs.is_empty() { + let has_aggs = inter_exprs + .iter() + .any(|ie| matches!(ie, IntermediateExpr::Agg(_))); if has_aggs { + // codegen the aggregations themselves to an intermediate table let original_table_reg_index = table_reg_index; table_reg_index = ctx.get_and_increment_reg(); @@ -431,44 +426,179 @@ pub fn codegen_ast(ast: &Statement) -> Result { index: table_reg_index, }); - let mut agg_index = 0; - for projection in &select.projection { - let alias = extract_alias_from_project(&projection)?; - - let expr_ast = extract_expr_ast_from_project(projection.clone()); - - match expr_ast { - // an aggregate operation - Some(expr_ast) if is_expr_agg(&expr_ast) => match expr_ast { - ast::Expr::Function(ref f) => { + for inter_expr in &inter_exprs { + match inter_expr { + IntermediateExpr::Agg(agg) => { + for (agg_fn, col_name, alias) in &agg.agg { ctx.instrs.push(Instruction::Aggregate { input: original_table_reg_index, output: table_reg_index, - func: AggregateFunction::from_name( - f.name.to_string().to_lowercase().as_str(), - )?, - col_name: agg_intermediate_cols[agg_index], - alias, + func: agg_fn.clone(), + col_name: *col_name, + alias: Some(*alias), }); - agg_index += 1; } - _ => unreachable!( - "check for fn is already done. this should not happen." - ), - }, - _ => {} + } + IntermediateExpr::NonAgg(_) => {} } } - } - if select.distinct { - return Err(CodegenError::UnsupportedStatementForm( - "DISTINCT is not supported yet", - select.to_string(), - )); + let original_table_reg_index = table_reg_index; + table_reg_index = ctx.get_and_increment_reg(); + + ctx.instrs.push(Instruction::Empty { + index: table_reg_index, + }); + + for (projection, inter_expr) in + select.projection.iter().zip(inter_exprs.iter()) + { + let alias = extract_alias_from_project(&projection)?; + + match inter_expr { + IntermediateExpr::Agg(agg) => { + for expr in agg.post_agg.clone() { + ctx.instrs.push(Instruction::Project { + input: original_table_reg_index, + output: table_reg_index, + expr, + alias, + }) + } + } + IntermediateExpr::NonAgg(_) => {} + } + } } } + // // if there are groupby + aggregations, we project all operations within an + // // aggregation to another table first. for example, `SUM(col * col)` would be + // // evaluated as `Project (col * col)` into `%2` and then apply the group by on + // // `%2`. + // // TODO: possible idea for refactor: make an intermediate representation of + // // Projection that separates non-agg and agg projections. + // let pre_grouped_reg_index = table_reg_index; + // let mut agg_intermediate_cols = Vec::new(); + // if !select.projection.is_empty() { + // let grouped_reg_index = ctx.get_and_increment_reg(); + // ctx.instrs.push(Instruction::Empty { + // index: grouped_reg_index, + // }); + + // table_reg_index = grouped_reg_index; + + // for projection in &select.projection { + // let alias = extract_alias_from_project(&projection)?; + + // let expr_ast = extract_expr_ast_from_project(projection.clone()); + + // match expr_ast { + // // an aggregate operation + // Some(expr_ast) if is_expr_agg(&expr_ast) => match expr_ast { + // ast::Expr::Function(ref f) => { + // let expr = codegen_fn_arg(&expr_ast, &f.args[0], &mut ctx)?; + // let projected_col_name = ctx.get_new_temp_col(); + // agg_intermediate_cols.push(projected_col_name); + // ctx.instrs.push(Instruction::Project { + // input: pre_grouped_reg_index, + // output: grouped_reg_index, + // expr, + // alias: Some(projected_col_name), + // }); + // } + // _ => {} + // }, + // // a non-aggregate operation + // _ => { + // let expr = match projection { + // SelectItem::UnnamedExpr(ref expr) => { + // codegen_expr(expr.clone(), &mut ctx)? + // } + // SelectItem::ExprWithAlias { ref expr, .. } => { + // codegen_expr(expr.clone(), &mut ctx)? + // } + // SelectItem::QualifiedWildcard(_) => Expr::Wildcard, + // SelectItem::Wildcard => Expr::Wildcard, + // }; + // let projection = Instruction::Project { + // input: pre_grouped_reg_index, + // output: grouped_reg_index, + // expr, + // alias, + // }; + // ctx.instrs.push(projection) + // } + // } + // } + // } + + // for group_by in select.group_by.clone() { + // let group_by = codegen_expr(group_by, &mut ctx)?; + // let grouped_reg_index = ctx.get_and_increment_reg(); + // ctx.instrs.push(Instruction::Empty { + // index: grouped_reg_index, + // }); + // ctx.instrs.push(Instruction::GroupBy { + // input: table_reg_index, + // output: grouped_reg_index, + // expr: group_by, + // }); + // table_reg_index = grouped_reg_index + // } + + // // this is only for aggregations. + // // aggs are applied on the grouped table created by the `GroupBy` instructions + // // generated above. + // if !select.projection.is_empty() { + // let has_aggs = select.projection.iter().any(|p| is_projection_agg(p)); + + // if has_aggs { + // let original_table_reg_index = table_reg_index; + // table_reg_index = ctx.get_and_increment_reg(); + + // ctx.instrs.push(Instruction::Empty { + // index: table_reg_index, + // }); + + // let mut agg_index = 0; + // for projection in &select.projection { + // let alias = extract_alias_from_project(&projection)?; + + // let expr_ast = extract_expr_ast_from_project(projection.clone()); + + // match expr_ast { + // // an aggregate operation + // Some(expr_ast) if is_expr_agg(&expr_ast) => match expr_ast { + // ast::Expr::Function(ref f) => { + // ctx.instrs.push(Instruction::Aggregate { + // input: original_table_reg_index, + // output: table_reg_index, + // func: AggregateFunction::from_name( + // f.name.to_string().to_lowercase().as_str(), + // )?, + // col_name: agg_intermediate_cols[agg_index], + // alias, + // }); + // agg_index += 1; + // } + // _ => unreachable!( + // "check for fn is already done. this should not happen." + // ), + // }, + // _ => {} + // } + // } + // } + + // if select.distinct { + // return Err(CodegenError::UnsupportedStatementForm( + // "DISTINCT is not supported yet", + // select.to_string(), + // )); + // } + // } + if let Some(expr_ast) = select.having.clone() { let expr = codegen_expr(expr_ast.clone(), &mut ctx)?.get_non_agg( concat!( @@ -661,6 +791,18 @@ impl IntermediateExpr { } } +fn codegen_selectitem( + projection: &SelectItem, + ctx: &mut CodegenContext, +) -> Result { + match projection { + SelectItem::UnnamedExpr(ref expr) => codegen_expr(expr.clone(), ctx), + SelectItem::ExprWithAlias { ref expr, .. } => codegen_expr(expr.clone(), ctx), + SelectItem::QualifiedWildcard(_) => Ok(IntermediateExpr::new_non_agg(Expr::Wildcard)), + SelectItem::Wildcard => Ok(IntermediateExpr::new_non_agg(Expr::Wildcard)), + } +} + fn codegen_expr( expr_ast: ast::Expr, ctx: &mut CodegenContext, From 4835a9ba61bffe42453da8203ad624661c345f44 Mon Sep 17 00:00:00 2001 From: Samyak Sarnayak Date: Sun, 7 May 2023 16:36:41 +0530 Subject: [PATCH 21/26] test(codegen): all tests are now passing but are they correct? --- src/codegen.rs | 198 +++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 168 insertions(+), 30 deletions(-) diff --git a/src/codegen.rs b/src/codegen.rs index 4faaea1..b5340b7 100644 --- a/src/codegen.rs +++ b/src/codegen.rs @@ -724,7 +724,7 @@ pub fn codegen_ast(ast: &Statement) -> Result { Ok(IntermediateCode { instrs: ctx.instrs }) } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] struct IntermediateExprAgg { pub pre_agg: Vec<(Expr, BoundedString)>, pub agg: Vec<(AggregateFunction, BoundedString, BoundedString)>, @@ -1798,7 +1798,6 @@ mod codegen_tests { ", |instrs| { assert_eq!( - instrs, &[ Instruction::Source { index: RegisterIndex::default(), @@ -1868,10 +1867,35 @@ mod codegen_tests { .next_index(), func: AggregateFunction::Max, col_name: "__otter_temp_col_1".into(), - alias: Some("max_col3".into()), + alias: Some("__otter_temp_col_2".into()), + }, + Instruction::Empty { + index: RegisterIndex::default() + .next_index() + .next_index() + .next_index() + .next_index() + }, + Instruction::Project { + input: RegisterIndex::default() + .next_index() + .next_index() + .next_index(), + output: RegisterIndex::default() + .next_index() + .next_index() + .next_index() + .next_index(), + expr: Expr::ColumnRef(ColumnRef { + schema_name: None, + table_name: None, + col_name: "__otter_temp_col_2".into(), + }), + alias: Some("max_col3".into()) }, Instruction::Filter { index: RegisterIndex::default() + .next_index() .next_index() .next_index() .next_index(), @@ -1887,11 +1911,13 @@ mod codegen_tests { }, Instruction::Return { index: RegisterIndex::default() + .next_index() .next_index() .next_index() .next_index(), } - ] + ], + instrs, ) }, ); @@ -1960,10 +1986,39 @@ mod codegen_tests { .next_index(), func: AggregateFunction::Max, col_name: "__otter_temp_col_1".into(), + alias: Some("__otter_temp_col_2".into()), + }, + Instruction::Empty { + index: RegisterIndex::default() + .next_index() + .next_index() + .next_index() + .next_index() + }, + Instruction::Project { + input: RegisterIndex::default() + .next_index() + .next_index() + .next_index(), + output: RegisterIndex::default() + .next_index() + .next_index() + .next_index() + .next_index(), + expr: Expr::Binary { + left: Box::new(Expr::ColumnRef(ColumnRef { + schema_name: None, + table_name: None, + col_name: "__otter_temp_col_2".into(), + })), + op: BinOp::Plus, + right: Box::new(Expr::Value(Value::Int64(1))), + }, alias: Some("max_col3".into()), }, Instruction::Return { index: RegisterIndex::default() + .next_index() .next_index() .next_index() .next_index(), @@ -2092,13 +2147,41 @@ mod codegen_tests { .next_index(), func: AggregateFunction::Sum, col_name: "__otter_temp_col_1".into(), - alias: Some("sos".into()), + alias: Some("__otter_temp_col_2".into()), + }, + Instruction::Empty { + index: RegisterIndex::default() + .next_index() + .next_index() + .next_index() + .next_index() + .next_index() + }, + Instruction::Project { + input: RegisterIndex::default() + .next_index() + .next_index() + .next_index() + .next_index(), + output: RegisterIndex::default() + .next_index() + .next_index() + .next_index() + .next_index() + .next_index(), + expr: Expr::ColumnRef(ColumnRef { + schema_name: None, + table_name: None, + col_name: "__otter_temp_col_2".into(), + }), + alias: Some("sos".into()) }, Instruction::Return { index: RegisterIndex::default() .next_index() .next_index() .next_index() + .next_index() .next_index(), } ], @@ -2113,9 +2196,11 @@ mod codegen_tests { mod expr_codegen_tests { use sqlparser::{ast, dialect::GenericDialect, parser::Parser, tokenizer::Tokenizer}; + use pretty_assertions::{assert_eq, assert_ne}; + use crate::{ - codegen::{codegen_expr, CodegenContext}, - expr::{BinOp, Expr, ExprError, UnOp}, + codegen::{codegen_expr, CodegenContext, IntermediateExpr, IntermediateExprAgg}, + expr::{agg::AggregateFunction, BinOp, Expr, ExprError, UnOp}, identifier::ColumnRef, value::Value, }; @@ -2130,13 +2215,24 @@ mod expr_codegen_tests { parser.parse_expr().unwrap() } - fn codegen_expr_wrapper(expr_ast: ast::Expr) -> Result { + fn codegen_expr_wrapper_no_agg(expr_ast: ast::Expr) -> Result { let mut ctx = CodegenContext::new(); - codegen_expr(expr_ast, &mut ctx) + match codegen_expr(expr_ast, &mut ctx)? { + IntermediateExpr::Agg(_) => panic!("Expected unaggregated expression"), + IntermediateExpr::NonAgg(expr) => Ok(expr), + } + } + + fn codegen_expr_wrapper_agg(expr_ast: ast::Expr) -> Result { + let mut ctx = CodegenContext::new(); + match codegen_expr(expr_ast, &mut ctx)? { + IntermediateExpr::Agg(agg) => Ok(agg), + IntermediateExpr::NonAgg(expr) => panic!("Expected aggregated expression"), + } } assert_eq!( - codegen_expr_wrapper(parse_expr("abc")), + codegen_expr_wrapper_no_agg(parse_expr("abc")), Ok(Expr::ColumnRef(ColumnRef { schema_name: None, table_name: None, @@ -2145,7 +2241,7 @@ mod expr_codegen_tests { ); assert_ne!( - codegen_expr_wrapper(parse_expr("abc")), + codegen_expr_wrapper_no_agg(parse_expr("abc")), Ok(Expr::ColumnRef(ColumnRef { schema_name: None, table_name: None, @@ -2154,7 +2250,7 @@ mod expr_codegen_tests { ); assert_eq!( - codegen_expr_wrapper(parse_expr("table1.col1")), + codegen_expr_wrapper_no_agg(parse_expr("table1.col1")), Ok(Expr::ColumnRef(ColumnRef { schema_name: None, table_name: Some("table1".into()), @@ -2163,7 +2259,7 @@ mod expr_codegen_tests { ); assert_eq!( - codegen_expr_wrapper(parse_expr("schema1.table1.col1")), + codegen_expr_wrapper_no_agg(parse_expr("schema1.table1.col1")), Ok(Expr::ColumnRef(ColumnRef { schema_name: Some("schema1".into()), table_name: Some("table1".into()), @@ -2172,7 +2268,7 @@ mod expr_codegen_tests { ); assert_eq!( - codegen_expr_wrapper(parse_expr("5 IS NULL")), + codegen_expr_wrapper_no_agg(parse_expr("5 IS NULL")), Ok(Expr::Unary { op: UnOp::IsNull, operand: Box::new(Expr::Value(Value::Int64(5))) @@ -2180,7 +2276,7 @@ mod expr_codegen_tests { ); assert_eq!( - codegen_expr_wrapper(parse_expr("1 IS TRUE")), + codegen_expr_wrapper_no_agg(parse_expr("1 IS TRUE")), Ok(Expr::Unary { op: UnOp::IsTrue, operand: Box::new(Expr::Value(Value::Int64(1))) @@ -2188,7 +2284,7 @@ mod expr_codegen_tests { ); assert_eq!( - codegen_expr_wrapper(parse_expr("4 BETWEEN 3 AND 5")), + codegen_expr_wrapper_no_agg(parse_expr("4 BETWEEN 3 AND 5")), Ok(Expr::Binary { left: Box::new(Expr::Binary { left: Box::new(Expr::Value(Value::Int64(3))), @@ -2205,7 +2301,7 @@ mod expr_codegen_tests { ); assert_eq!( - codegen_expr_wrapper(parse_expr("4 NOT BETWEEN 3 AND 5")), + codegen_expr_wrapper_no_agg(parse_expr("4 NOT BETWEEN 3 AND 5")), Ok(Expr::Unary { op: UnOp::Not, operand: Box::new(Expr::Binary { @@ -2225,19 +2321,40 @@ mod expr_codegen_tests { ); assert_eq!( - codegen_expr_wrapper(parse_expr("MAX(col1)")), - Ok(Expr::Function { - name: "MAX".into(), - args: vec![Expr::ColumnRef(ColumnRef { + codegen_expr_wrapper_agg(parse_expr("MAX(col1)")), + Ok(IntermediateExprAgg { + pre_agg: vec![( + Expr::ColumnRef(ColumnRef { + schema_name: None, + table_name: None, + col_name: "col1".into() + }), + "__otter_temp_col_1".into() + )], + agg: vec![( + AggregateFunction::Max, + "__otter_temp_col_1".into(), + "__otter_temp_col_2".into() + )], + post_agg: vec![Expr::ColumnRef(ColumnRef { schema_name: None, table_name: None, - col_name: "col1".into() - })] + col_name: "__otter_temp_col_2".into() + })], + last_alias: Some("__otter_temp_col_2".into()), + last_expr: ( + Expr::ColumnRef(ColumnRef { + schema_name: None, + table_name: None, + col_name: "__otter_temp_col_2".into() + }), + "__otter_temp_col_2".into() + ) }) ); assert_eq!( - codegen_expr_wrapper(parse_expr("some_func(col1, 1, 'abc')")), + codegen_expr_wrapper_no_agg(parse_expr("some_func(col1, 1, 'abc')")), Ok(Expr::Function { name: "some_func".into(), args: vec![ @@ -2253,10 +2370,28 @@ mod expr_codegen_tests { ); assert_eq!( - codegen_expr_wrapper(parse_expr("COUNT(*)")), - Ok(Expr::Function { - name: "COUNT".into(), - args: vec![Expr::Wildcard] + codegen_expr_wrapper_agg(parse_expr("COUNT(*)")), + Ok(IntermediateExprAgg { + pre_agg: vec![(Expr::Wildcard, "__otter_temp_col_1".into())], + agg: vec![( + AggregateFunction::Count, + "__otter_temp_col_1".into(), + "__otter_temp_col_2".into() + )], + post_agg: vec![Expr::ColumnRef(ColumnRef { + schema_name: None, + table_name: None, + col_name: "__otter_temp_col_2".into() + })], + last_alias: Some("__otter_temp_col_2".into()), + last_expr: ( + Expr::ColumnRef(ColumnRef { + schema_name: None, + table_name: None, + col_name: "__otter_temp_col_2".into() + }), + "__otter_temp_col_2".into() + ) }) ); } @@ -2278,7 +2413,7 @@ mod expr_eval_tests { value::{Value, ValueBinaryOpError, ValueUnaryOpError}, }; - use super::{codegen_expr, CodegenContext}; + use super::{codegen_expr, CodegenContext, IntermediateExpr}; fn str_to_expr(s: &str) -> Expr { let dialect = GenericDialect {}; @@ -2286,7 +2421,10 @@ mod expr_eval_tests { let tokens = tokenizer.tokenize().unwrap(); let mut parser = Parser::new(tokens, &dialect); let mut ctx = CodegenContext::new(); - codegen_expr(parser.parse_expr().unwrap(), &mut ctx).unwrap() + match codegen_expr(parser.parse_expr().unwrap(), &mut ctx).unwrap() { + IntermediateExpr::NonAgg(expr) => expr, + IntermediateExpr::Agg(_) => panic!("Did not expect aggregate expression here"), + } } fn exec_expr_no_context(expr: Expr) -> Result { From d8f1f379734bc9aa4a8501dbe6381539da3df412 Mon Sep 17 00:00:00 2001 From: Samyak Sarnayak Date: Sun, 7 May 2023 16:39:15 +0530 Subject: [PATCH 22/26] chore(codegen): remove unused code and add distinct check back --- src/codegen.rs | 164 +++---------------------------------------------- 1 file changed, 8 insertions(+), 156 deletions(-) diff --git a/src/codegen.rs b/src/codegen.rs index b5340b7..bdec388 100644 --- a/src/codegen.rs +++ b/src/codegen.rs @@ -81,10 +81,6 @@ impl CodegenContext { reg } - pub fn last_used_reg(&self) -> RegisterIndex { - self.current_reg - } - pub fn get_new_temp_col(&mut self) -> BoundedString { self.last_temp_col_num += 1; format!("{TEMP_COL_NAME_PREFIX}_{}", self.last_temp_col_num) @@ -106,14 +102,6 @@ static AGGREGATE_FUNCTIONS: Set<&'static str> = phf_set! { "sum", }; -fn extract_expr_ast_from_project(projection: SelectItem) -> Option { - match projection { - SelectItem::UnnamedExpr(expr) | SelectItem::ExprWithAlias { expr, .. } => Some(expr), - SelectItem::QualifiedWildcard(_) => None, - SelectItem::Wildcard => None, - } -} - fn extract_alias_from_project( projection: &SelectItem, ) -> Result, CodegenError> { @@ -470,134 +458,14 @@ pub fn codegen_ast(ast: &Statement) -> Result { } } } - } - // // if there are groupby + aggregations, we project all operations within an - // // aggregation to another table first. for example, `SUM(col * col)` would be - // // evaluated as `Project (col * col)` into `%2` and then apply the group by on - // // `%2`. - // // TODO: possible idea for refactor: make an intermediate representation of - // // Projection that separates non-agg and agg projections. - // let pre_grouped_reg_index = table_reg_index; - // let mut agg_intermediate_cols = Vec::new(); - // if !select.projection.is_empty() { - // let grouped_reg_index = ctx.get_and_increment_reg(); - // ctx.instrs.push(Instruction::Empty { - // index: grouped_reg_index, - // }); - - // table_reg_index = grouped_reg_index; - - // for projection in &select.projection { - // let alias = extract_alias_from_project(&projection)?; - - // let expr_ast = extract_expr_ast_from_project(projection.clone()); - - // match expr_ast { - // // an aggregate operation - // Some(expr_ast) if is_expr_agg(&expr_ast) => match expr_ast { - // ast::Expr::Function(ref f) => { - // let expr = codegen_fn_arg(&expr_ast, &f.args[0], &mut ctx)?; - // let projected_col_name = ctx.get_new_temp_col(); - // agg_intermediate_cols.push(projected_col_name); - // ctx.instrs.push(Instruction::Project { - // input: pre_grouped_reg_index, - // output: grouped_reg_index, - // expr, - // alias: Some(projected_col_name), - // }); - // } - // _ => {} - // }, - // // a non-aggregate operation - // _ => { - // let expr = match projection { - // SelectItem::UnnamedExpr(ref expr) => { - // codegen_expr(expr.clone(), &mut ctx)? - // } - // SelectItem::ExprWithAlias { ref expr, .. } => { - // codegen_expr(expr.clone(), &mut ctx)? - // } - // SelectItem::QualifiedWildcard(_) => Expr::Wildcard, - // SelectItem::Wildcard => Expr::Wildcard, - // }; - // let projection = Instruction::Project { - // input: pre_grouped_reg_index, - // output: grouped_reg_index, - // expr, - // alias, - // }; - // ctx.instrs.push(projection) - // } - // } - // } - // } - - // for group_by in select.group_by.clone() { - // let group_by = codegen_expr(group_by, &mut ctx)?; - // let grouped_reg_index = ctx.get_and_increment_reg(); - // ctx.instrs.push(Instruction::Empty { - // index: grouped_reg_index, - // }); - // ctx.instrs.push(Instruction::GroupBy { - // input: table_reg_index, - // output: grouped_reg_index, - // expr: group_by, - // }); - // table_reg_index = grouped_reg_index - // } - - // // this is only for aggregations. - // // aggs are applied on the grouped table created by the `GroupBy` instructions - // // generated above. - // if !select.projection.is_empty() { - // let has_aggs = select.projection.iter().any(|p| is_projection_agg(p)); - - // if has_aggs { - // let original_table_reg_index = table_reg_index; - // table_reg_index = ctx.get_and_increment_reg(); - - // ctx.instrs.push(Instruction::Empty { - // index: table_reg_index, - // }); - - // let mut agg_index = 0; - // for projection in &select.projection { - // let alias = extract_alias_from_project(&projection)?; - - // let expr_ast = extract_expr_ast_from_project(projection.clone()); - - // match expr_ast { - // // an aggregate operation - // Some(expr_ast) if is_expr_agg(&expr_ast) => match expr_ast { - // ast::Expr::Function(ref f) => { - // ctx.instrs.push(Instruction::Aggregate { - // input: original_table_reg_index, - // output: table_reg_index, - // func: AggregateFunction::from_name( - // f.name.to_string().to_lowercase().as_str(), - // )?, - // col_name: agg_intermediate_cols[agg_index], - // alias, - // }); - // agg_index += 1; - // } - // _ => unreachable!( - // "check for fn is already done. this should not happen." - // ), - // }, - // _ => {} - // } - // } - // } - - // if select.distinct { - // return Err(CodegenError::UnsupportedStatementForm( - // "DISTINCT is not supported yet", - // select.to_string(), - // )); - // } - // } + if select.distinct { + return Err(CodegenError::UnsupportedStatementForm( + "DISTINCT is not supported yet", + select.to_string(), + )); + } + } if let Some(expr_ast) = select.having.clone() { let expr = codegen_expr(expr_ast.clone(), &mut ctx)?.get_non_agg( @@ -1073,22 +941,6 @@ fn is_fn_name_aggregate(fn_name: &str) -> bool { AGGREGATE_FUNCTIONS.contains(fn_name) } -fn is_expr_agg(e: &ast::Expr) -> bool { - match e { - ast::Expr::Function(ref f) => is_fn_name_aggregate(&f.name.to_string().to_lowercase()), - _ => false, - } -} - -fn is_projection_agg(p: &SelectItem) -> bool { - match p { - SelectItem::UnnamedExpr(ref expr) => is_expr_agg(expr), - SelectItem::ExprWithAlias { ref expr, .. } => is_expr_agg(expr), - SelectItem::QualifiedWildcard(_) => false, - SelectItem::Wildcard => false, - } -} - fn codegen_fn_arg( expr_ast: &ast::Expr, arg: &FunctionArg, @@ -2227,7 +2079,7 @@ mod expr_codegen_tests { let mut ctx = CodegenContext::new(); match codegen_expr(expr_ast, &mut ctx)? { IntermediateExpr::Agg(agg) => Ok(agg), - IntermediateExpr::NonAgg(expr) => panic!("Expected aggregated expression"), + IntermediateExpr::NonAgg(_) => panic!("Expected aggregated expression"), } } From da4761417fdf399ce91724617fe98a937764c565 Mon Sep 17 00:00:00 2001 From: Samyak S Sarnayak Date: Sun, 7 May 2023 22:53:02 +0530 Subject: [PATCH 23/26] chore(codegen): remove more unused code --- src/codegen.rs | 79 -------------------------------------------------- 1 file changed, 79 deletions(-) diff --git a/src/codegen.rs b/src/codegen.rs index bdec388..8abaf0d 100644 --- a/src/codegen.rs +++ b/src/codegen.rs @@ -858,85 +858,6 @@ fn codegen_expr( } } -// fn codegen_expr(expr_ast: ast::Expr, ctx: &mut CodegenContext) -> Result { -// match expr_ast { -// ast::Expr::Identifier(i) => Ok(Expr::ColumnRef(vec![i].try_into()?)), -// ast::Expr::CompoundIdentifier(i) => Ok(Expr::ColumnRef(i.try_into()?)), -// ast::Expr::IsFalse(e) => Ok(Expr::Unary { -// op: UnOp::IsFalse, -// operand: Box::new(codegen_expr(*e, ctx)?), -// }), -// ast::Expr::IsTrue(e) => Ok(Expr::Unary { -// op: UnOp::IsTrue, -// operand: Box::new(codegen_expr(*e, ctx)?), -// }), -// ast::Expr::IsNull(e) => Ok(Expr::Unary { -// op: UnOp::IsNull, -// operand: Box::new(codegen_expr(*e, ctx)?), -// }), -// ast::Expr::IsNotNull(e) => Ok(Expr::Unary { -// op: UnOp::IsNotNull, -// operand: Box::new(codegen_expr(*e, ctx)?), -// }), -// ast::Expr::Between { -// expr, -// negated, -// low, -// high, -// } => { -// let expr: Box = Box::new(codegen_expr(*expr, ctx)?); -// let left = Box::new(codegen_expr(*low, ctx)?); -// let right = Box::new(codegen_expr(*high, ctx)?); -// let between = Expr::Binary { -// left: Box::new(Expr::Binary { -// left, -// op: BinOp::LessThanOrEqual, -// right: expr.clone(), -// }), -// op: BinOp::And, -// right: Box::new(Expr::Binary { -// left: expr, -// op: BinOp::LessThanOrEqual, -// right, -// }), -// }; -// if negated { -// Ok(Expr::Unary { -// op: UnOp::Not, -// operand: Box::new(between), -// }) -// } else { -// Ok(between) -// } -// } -// ast::Expr::BinaryOp { left, op, right } => Ok(Expr::Binary { -// left: Box::new(codegen_expr(*left, ctx)?), -// op: op.try_into()?, -// right: Box::new(codegen_expr(*right, ctx)?), -// }), -// ast::Expr::UnaryOp { op, expr } => Ok(Expr::Unary { -// op: op.try_into()?, -// operand: Box::new(codegen_expr(*expr, ctx)?), -// }), -// ast::Expr::Value(v) => Ok(Expr::Value(v.try_into()?)), -// ast::Expr::Function(ref f) => { -// let fn_name = f.name.to_string(); -// Ok(Expr::Function { -// name: fn_name.as_str().into(), -// args: f -// .args -// .iter() -// .map(|arg| codegen_fn_arg(&expr_ast, arg, ctx)) -// .collect::, _>>()?, -// }) -// } -// _ => Err(ExprError::Expr { -// reason: "Unsupported expression", -// expr: expr_ast, -// }), -// } -// } - fn is_fn_name_aggregate(fn_name: &str) -> bool { AGGREGATE_FUNCTIONS.contains(fn_name) } From cc5e3c2c8ad11d520c4c7a0fecf09e2fb34cda5a Mon Sep 17 00:00:00 2001 From: Samyak S Sarnayak Date: Sun, 7 May 2023 23:34:19 +0530 Subject: [PATCH 24/26] fix(codegen): move non-agg projects to final table too --- src/codegen.rs | 59 +++++++++++++++++++++++++++++++++----------------- 1 file changed, 39 insertions(+), 20 deletions(-) diff --git a/src/codegen.rs b/src/codegen.rs index 8abaf0d..2473b63 100644 --- a/src/codegen.rs +++ b/src/codegen.rs @@ -342,40 +342,30 @@ pub fn codegen_ast(ast: &Statement) -> Result { // evaluated as `Project (col * col)` into `%2` and then apply the group by on // `%2`. let pre_grouped_reg_index = table_reg_index; + let mut pre_grouped_inter_reg_index = pre_grouped_reg_index; // let mut agg_intermediate_cols = Vec::new(); // if !select.projection.is_empty() { if !inter_exprs.is_empty() { - let grouped_reg_index = ctx.get_and_increment_reg(); + pre_grouped_inter_reg_index = ctx.get_and_increment_reg(); ctx.instrs.push(Instruction::Empty { - index: grouped_reg_index, + index: pre_grouped_inter_reg_index, }); - table_reg_index = grouped_reg_index; + table_reg_index = pre_grouped_inter_reg_index; - for (projection, inter_expr) in - select.projection.iter().zip(inter_exprs.iter()) - { + for inter_expr in inter_exprs.iter() { match inter_expr { IntermediateExpr::Agg(agg) => { for (expr, projected_col_name) in agg.pre_agg.clone() { ctx.instrs.push(Instruction::Project { input: pre_grouped_reg_index, - output: grouped_reg_index, + output: pre_grouped_inter_reg_index, expr, alias: Some(projected_col_name), }); } } - IntermediateExpr::NonAgg(expr) => { - let alias = extract_alias_from_project(&projection)?; - let projection = Instruction::Project { - input: pre_grouped_reg_index, - output: grouped_reg_index, - expr: expr.clone(), - alias, - }; - ctx.instrs.push(projection) - } + IntermediateExpr::NonAgg(_) => {} } } } @@ -431,7 +421,7 @@ pub fn codegen_ast(ast: &Statement) -> Result { } } - let original_table_reg_index = table_reg_index; + let last_grouped_reg_index = table_reg_index; table_reg_index = ctx.get_and_increment_reg(); ctx.instrs.push(Instruction::Empty { @@ -447,14 +437,43 @@ pub fn codegen_ast(ast: &Statement) -> Result { IntermediateExpr::Agg(agg) => { for expr in agg.post_agg.clone() { ctx.instrs.push(Instruction::Project { - input: original_table_reg_index, + input: last_grouped_reg_index, output: table_reg_index, expr, alias, }) } } - IntermediateExpr::NonAgg(_) => {} + IntermediateExpr::NonAgg(expr) => { + let alias = extract_alias_from_project(&projection)?; + let projection = Instruction::Project { + input: pre_grouped_inter_reg_index, + output: table_reg_index, + expr: expr.clone(), + alias, + }; + ctx.instrs.push(projection) + } + } + } + } else { + for (projection, inter_expr) in + select.projection.iter().zip(inter_exprs.iter()) + { + match inter_expr { + IntermediateExpr::NonAgg(expr) => { + let alias = extract_alias_from_project(&projection)?; + let projection = Instruction::Project { + input: pre_grouped_inter_reg_index, + output: table_reg_index, + expr: expr.clone(), + alias, + }; + ctx.instrs.push(projection) + } + IntermediateExpr::Agg(_) => { + unreachable!("already checked for aggregates") + } } } } From 76d72a3d8de71caf15233541f036f572b4dfd7ed Mon Sep 17 00:00:00 2001 From: Samyak S Sarnayak Date: Sun, 7 May 2023 23:40:44 +0530 Subject: [PATCH 25/26] fix(codegen): use pre-project table for non-aggs --- src/codegen.rs | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/codegen.rs b/src/codegen.rs index 2473b63..e4f968f 100644 --- a/src/codegen.rs +++ b/src/codegen.rs @@ -342,11 +342,10 @@ pub fn codegen_ast(ast: &Statement) -> Result { // evaluated as `Project (col * col)` into `%2` and then apply the group by on // `%2`. let pre_grouped_reg_index = table_reg_index; - let mut pre_grouped_inter_reg_index = pre_grouped_reg_index; // let mut agg_intermediate_cols = Vec::new(); // if !select.projection.is_empty() { if !inter_exprs.is_empty() { - pre_grouped_inter_reg_index = ctx.get_and_increment_reg(); + let pre_grouped_inter_reg_index = ctx.get_and_increment_reg(); ctx.instrs.push(Instruction::Empty { index: pre_grouped_inter_reg_index, }); @@ -447,7 +446,7 @@ pub fn codegen_ast(ast: &Statement) -> Result { IntermediateExpr::NonAgg(expr) => { let alias = extract_alias_from_project(&projection)?; let projection = Instruction::Project { - input: pre_grouped_inter_reg_index, + input: pre_grouped_reg_index, output: table_reg_index, expr: expr.clone(), alias, @@ -464,7 +463,7 @@ pub fn codegen_ast(ast: &Statement) -> Result { IntermediateExpr::NonAgg(expr) => { let alias = extract_alias_from_project(&projection)?; let projection = Instruction::Project { - input: pre_grouped_inter_reg_index, + input: pre_grouped_reg_index, output: table_reg_index, expr: expr.clone(), alias, From ee82df5c2bcd1b7ffcde9b4224658672f51bc4e1 Mon Sep 17 00:00:00 2001 From: Samyak S Sarnayak Date: Sun, 7 May 2023 23:46:55 +0530 Subject: [PATCH 26/26] test(codegen): fix order of non-aggs --- src/codegen.rs | 98 +++++++++++++++++++++++++++++--------------------- 1 file changed, 58 insertions(+), 40 deletions(-) diff --git a/src/codegen.rs b/src/codegen.rs index e4f968f..a5e94e1 100644 --- a/src/codegen.rs +++ b/src/codegen.rs @@ -1612,16 +1612,6 @@ mod codegen_tests { Instruction::Empty { index: RegisterIndex::default().next_index() }, - Instruction::Project { - input: RegisterIndex::default(), - output: RegisterIndex::default().next_index(), - expr: Expr::ColumnRef(ColumnRef { - schema_name: None, - table_name: None, - col_name: "col2".into(), - }), - alias: None - }, Instruction::Project { input: RegisterIndex::default(), output: RegisterIndex::default().next_index(), @@ -1667,6 +1657,20 @@ mod codegen_tests { .next_index() .next_index() }, + Instruction::Project { + input: RegisterIndex::default(), + output: RegisterIndex::default() + .next_index() + .next_index() + .next_index() + .next_index(), + expr: Expr::ColumnRef(ColumnRef { + schema_name: None, + table_name: None, + col_name: "col2".into(), + }), + alias: None + }, Instruction::Project { input: RegisterIndex::default() .next_index() @@ -1731,16 +1735,6 @@ mod codegen_tests { Instruction::Empty { index: RegisterIndex::default().next_index() }, - Instruction::Project { - input: RegisterIndex::default(), - output: RegisterIndex::default().next_index(), - expr: Expr::ColumnRef(ColumnRef { - schema_name: None, - table_name: None, - col_name: "col2".into(), - }), - alias: None - }, Instruction::Project { input: RegisterIndex::default(), output: RegisterIndex::default().next_index(), @@ -1786,6 +1780,20 @@ mod codegen_tests { .next_index() .next_index() }, + Instruction::Project { + input: RegisterIndex::default(), + output: RegisterIndex::default() + .next_index() + .next_index() + .next_index() + .next_index(), + expr: Expr::ColumnRef(ColumnRef { + schema_name: None, + table_name: None, + col_name: "col2".into(), + }), + alias: None + }, Instruction::Project { input: RegisterIndex::default() .next_index() @@ -1851,26 +1859,6 @@ mod codegen_tests { Instruction::Empty { index: RegisterIndex::default().next_index() }, - Instruction::Project { - input: RegisterIndex::default(), - output: RegisterIndex::default().next_index(), - expr: Expr::ColumnRef(ColumnRef { - schema_name: None, - table_name: None, - col_name: "col2".into(), - }), - alias: None - }, - Instruction::Project { - input: RegisterIndex::default(), - output: RegisterIndex::default().next_index(), - expr: Expr::ColumnRef(ColumnRef { - schema_name: None, - table_name: None, - col_name: "col3".into(), - }), - alias: None - }, Instruction::Project { input: RegisterIndex::default(), output: RegisterIndex::default().next_index(), @@ -1948,6 +1936,36 @@ mod codegen_tests { .next_index() .next_index() }, + Instruction::Project { + input: RegisterIndex::default(), + output: RegisterIndex::default() + .next_index() + .next_index() + .next_index() + .next_index() + .next_index(), + expr: Expr::ColumnRef(ColumnRef { + schema_name: None, + table_name: None, + col_name: "col2".into(), + }), + alias: None + }, + Instruction::Project { + input: RegisterIndex::default(), + output: RegisterIndex::default() + .next_index() + .next_index() + .next_index() + .next_index() + .next_index(), + expr: Expr::ColumnRef(ColumnRef { + schema_name: None, + table_name: None, + col_name: "col3".into(), + }), + alias: None + }, Instruction::Project { input: RegisterIndex::default() .next_index()