From 52aa07c38908604062ddd615f6ad9ff8d5d1cc9c Mon Sep 17 00:00:00 2001 From: Linwei Zhang Date: Sun, 28 Jan 2024 23:08:13 +0800 Subject: [PATCH] Improve database run --- bustubx/src/catalog/data_type.rs | 2 +- bustubx/src/database.rs | 66 ++++++++++-------------- bustubx/src/error.rs | 9 +++- bustubx/src/optimizer/heuristic/graph.rs | 32 +++++++++--- bustubx/src/parser/mod.rs | 14 +++-- 5 files changed, 66 insertions(+), 57 deletions(-) diff --git a/bustubx/src/catalog/data_type.rs b/bustubx/src/catalog/data_type.rs index 475f521..07c9986 100644 --- a/bustubx/src/catalog/data_type.rs +++ b/bustubx/src/catalog/data_type.rs @@ -31,7 +31,7 @@ impl TryFrom<&sqlparser::ast::DataType> for DataType { sqlparser::ast::DataType::SmallInt(_) => Ok(DataType::Int16), sqlparser::ast::DataType::Int(_) => Ok(DataType::Int32), sqlparser::ast::DataType::BigInt(_) => Ok(DataType::Int64), - _ => Err(BustubxError::NotImplement(format!( + _ => Err(BustubxError::NotSupport(format!( "Not support datatype {}", value ))), diff --git a/bustubx/src/database.rs b/bustubx/src/database.rs index 362a169..ff1c9cb 100644 --- a/bustubx/src/database.rs +++ b/bustubx/src/database.rs @@ -3,6 +3,7 @@ use tempfile::TempDir; use tracing::span; +use crate::error::{BustubxError, BustubxResult}; use crate::planner::logical_plan::LogicalPlan; use crate::{ buffer::buffer_pool::BufferPoolManager, @@ -47,27 +48,10 @@ impl Database { } } - pub fn run(&mut self, sql: &str) -> Vec { + pub fn run(&mut self, sql: &str) -> BustubxResult> { let _db_run_span = span!(tracing::Level::INFO, "database.run", sql).entered(); - // sql -> ast - let stmts = crate::parser::parse_sql(sql); - if stmts.is_err() { - println!("parse sql error"); - return Vec::new(); - } - let stmts = stmts.unwrap(); - if stmts.len() != 1 { - println!("only support one sql statement"); - return Vec::new(); - } - let stmt = &stmts[0]; - let mut binder = Planner { - context: PlannerContext { - catalog: &self.catalog, - }, - }; - // ast -> logical plan - let logical_plan = binder.plan(&stmt); + + let logical_plan = self.build_logical_plan(sql)?; println!("{:?}", logical_plan); // logical plan -> physical plan @@ -82,29 +66,27 @@ impl Database { let (tuples, schema) = execution_engine.execute(Arc::new(physical_plan)); // println!("execution result: {:?}", tuples); // print_tuples(&tuples, &schema); - tuples + Ok(tuples) } - pub fn build_logical_plan(&mut self, sql: &str) -> LogicalPlan { + pub fn build_logical_plan(&mut self, sql: &str) -> BustubxResult { // sql -> ast - let stmts = crate::parser::parse_sql(sql); - if stmts.is_err() { - panic!("parse sql error") - } - let stmts = stmts.unwrap(); + let stmts = crate::parser::parse_sql(sql)?; if stmts.len() != 1 { - panic!("only support one sql statement") + return Err(BustubxError::NotSupport( + "only support one sql statement".to_string(), + )); } let stmt = &stmts[0]; - let mut binder = Planner { + let mut planner = Planner { context: PlannerContext { catalog: &self.catalog, }, }; - // ast -> statement - let logical_plan = binder.plan(&stmt); + // ast -> logical plan + let logical_plan = planner.plan(&stmt); - logical_plan + Ok(logical_plan) } } @@ -161,7 +143,9 @@ mod tests { pub fn test_insert_sql() { let mut db = super::Database::new_temp(); db.run(&"create table t1 (a int, b int)".to_string()); - let insert_rows = db.run(&"insert into t1 values (1, 1), (2, 3), (5, 4)".to_string()); + let insert_rows = db + .run(&"insert into t1 values (1, 1), (2, 3), (5, 4)".to_string()) + .unwrap(); assert_eq!(insert_rows.len(), 1); let schema = Schema::new(vec![Column::new( @@ -177,12 +161,12 @@ mod tests { let mut db = super::Database::new_temp(); db.run(&"create table t1 (a int, b bigint)".to_string()); - let select_result = db.run(&"select * from t1".to_string()); + let select_result = db.run(&"select * from t1".to_string()).unwrap(); assert_eq!(select_result.len(), 0); db.run(&"insert into t1 values (1, 1), (2, 3), (5, 4)".to_string()); - let select_result = db.run(&"select * from t1".to_string()); + let select_result = db.run(&"select * from t1".to_string()).unwrap(); assert_eq!(select_result.len(), 3); let schema = Schema::new(vec![ @@ -220,7 +204,9 @@ mod tests { let mut db = super::Database::new_temp(); db.run(&"create table t1 (a int, b int)".to_string()); db.run(&"insert into t1 values (1, 1), (2, 3), (5, 4)".to_string()); - let select_result = db.run(&"select a from t1 where a <= b".to_string()); + let select_result = db + .run(&"select a from t1 where a <= b".to_string()) + .unwrap(); assert_eq!(select_result.len(), 2); let schema = Schema::new(vec![Column::new("a".to_string(), DataType::Int32)]); @@ -239,7 +225,9 @@ mod tests { let mut db = super::Database::new_temp(); db.run(&"create table t1 (a int, b int)".to_string()); db.run(&"insert into t1 values (1, 1), (2, 3), (5, 4)".to_string()); - let select_result = db.run(&"select * from t1 limit 1 offset 1".to_string()); + let select_result = db + .run(&"select * from t1 limit 1 offset 1".to_string()) + .unwrap(); assert_eq!(select_result.len(), 1); let schema = Schema::new(vec![ @@ -396,7 +384,9 @@ mod tests { let mut db = super::Database::new_temp(); db.run(&"create table t1 (a int, b int)".to_string()); db.run(&"insert into t1 values (5, 6), (1, 2), (1, 4)".to_string()); - let select_result = db.run(&"select * from t1 order by a, b desc".to_string()); + let select_result = db + .run(&"select * from t1 order by a, b desc".to_string()) + .unwrap(); assert_eq!(select_result.len(), 3); let schema = Schema::new(vec![ diff --git a/bustubx/src/error.rs b/bustubx/src/error.rs index a9a9a95..9f3bcfe 100644 --- a/bustubx/src/error.rs +++ b/bustubx/src/error.rs @@ -4,10 +4,15 @@ pub type BustubxResult = Result; #[derive(Debug, Error)] pub enum BustubxError { - #[error("Not implement: {0}")] - NotImplement(String), + #[error("Not support: {0}")] + NotSupport(String), + #[error("Internal error: {0}")] Internal(String), + #[error("IO error: {0}")] Io(#[from] std::io::Error), + + #[error("Parser error: {0}")] + Parser(#[from] sqlparser::parser::ParserError), } diff --git a/bustubx/src/optimizer/heuristic/graph.rs b/bustubx/src/optimizer/heuristic/graph.rs index f3a7f31..69ed669 100644 --- a/bustubx/src/optimizer/heuristic/graph.rs +++ b/bustubx/src/optimizer/heuristic/graph.rs @@ -182,7 +182,9 @@ mod tests { let mut db = Database::new_temp(); db.run("create table t1(a int, b int)"); db.run("create table t2(a int, b int)"); - let logical_plan = db.build_logical_plan("select * from t1 inner join t2 on t1.a = t2.a"); + let logical_plan = db + .build_logical_plan("select * from t1 inner join t2 on t1.a = t2.a") + .unwrap(); // 0: project // 1: join @@ -231,7 +233,9 @@ mod tests { let mut db = Database::new_temp(); db.run("create table t1(a int, b int)"); db.run("create table t2(a int, b int)"); - let logical_plan = db.build_logical_plan("select * from t1 inner join t2 on t1.a = t2.a"); + let logical_plan = db + .build_logical_plan("select * from t1 inner join t2 on t1.a = t2.a") + .unwrap(); let graph = super::HepGraph::new(Arc::new(logical_plan)); let ids = graph.bfs(graph.root); @@ -247,7 +251,9 @@ mod tests { let mut db = Database::new_temp(); db.run("create table t1(a int, b int)"); db.run("create table t2(a int, b int)"); - let logical_plan = db.build_logical_plan("select * from t1 inner join t2 on t1.a = t2.a"); + let logical_plan = db + .build_logical_plan("select * from t1 inner join t2 on t1.a = t2.a") + .unwrap(); let mut graph = super::HepGraph::new(Arc::new(logical_plan)); let ids = graph.bfs(graph.root); @@ -283,7 +289,9 @@ mod tests { let mut db = Database::new_temp(); db.run("create table t1(a int, b int)"); db.run("create table t2(a int, b int)"); - let logical_plan = db.build_logical_plan("select * from t1 inner join t2 on t1.a = t2.a"); + let logical_plan = db + .build_logical_plan("select * from t1 inner join t2 on t1.a = t2.a") + .unwrap(); let mut graph = super::HepGraph::new(Arc::new(logical_plan)); let ids = graph.bfs(graph.root); @@ -332,7 +340,9 @@ mod tests { let mut db = Database::new_temp(); db.run("create table t1(a int, b int)"); db.run("create table t2(a int, b int)"); - let logical_plan = db.build_logical_plan("select * from t1 inner join t2 on t1.a = t2.a"); + let logical_plan = db + .build_logical_plan("select * from t1 inner join t2 on t1.a = t2.a") + .unwrap(); let mut graph = super::HepGraph::new(Arc::new(logical_plan)); let ids = graph.bfs(graph.root); @@ -384,7 +394,9 @@ mod tests { let mut db = Database::new_temp(); db.run("create table t1(a int, b int)"); db.run("create table t2(a int, b int)"); - let logical_plan = db.build_logical_plan("select * from t1 inner join t2 on t1.a = t2.a"); + let logical_plan = db + .build_logical_plan("select * from t1 inner join t2 on t1.a = t2.a") + .unwrap(); let mut graph = super::HepGraph::new(Arc::new(logical_plan)); let ids = graph.bfs(graph.root); @@ -414,7 +426,9 @@ mod tests { LogicalOperator::Project(_) )); - let logical_plan = db.build_logical_plan("select * from t1 inner join t2 on t1.a = t2.a"); + let logical_plan = db + .build_logical_plan("select * from t1 inner join t2 on t1.a = t2.a") + .unwrap(); let mut graph = super::HepGraph::new(Arc::new(logical_plan)); graph.remove_node(HepNodeId::new(1), false); @@ -439,7 +453,9 @@ mod tests { let mut db = Database::new_temp(); db.run("create table t1(a int, b int)"); db.run("create table t2(a int, b int)"); - let logical_plan = db.build_logical_plan("select * from t1 inner join t2 on t1.a = t2.a"); + let logical_plan = db + .build_logical_plan("select * from t1 inner join t2 on t1.a = t2.a") + .unwrap(); let graph = super::HepGraph::new(Arc::new(logical_plan)); let output_plan = graph.to_plan(); diff --git a/bustubx/src/parser/mod.rs b/bustubx/src/parser/mod.rs index 55d653a..e3e6252 100644 --- a/bustubx/src/parser/mod.rs +++ b/bustubx/src/parser/mod.rs @@ -1,20 +1,18 @@ -use sqlparser::{ - ast::Statement, - dialect::PostgreSqlDialect, - parser::{Parser, ParserError}, -}; +use crate::error::BustubxResult; +use sqlparser::{ast::Statement, dialect::PostgreSqlDialect, parser::Parser}; use tracing::span; -pub fn parse_sql(sql: &str) -> Result, ParserError> { +pub fn parse_sql(sql: &str) -> BustubxResult> { let _parse_sql_span = span!(tracing::Level::INFO, "parse_sql", sql).entered(); - Parser::parse_sql(&PostgreSqlDialect {}, sql) + let stmts = Parser::parse_sql(&PostgreSqlDialect {}, sql)?; + Ok(stmts) } mod tests { #[test] pub fn test_sql() { let sql = "select * from t1, t2, t3 inner join t4 on t3.id = t4.id"; - let stmts = super::parse_sql(sql); + let stmts = super::parse_sql(sql).unwrap(); println!("{:?}", stmts); } }