From 32e7b138bba6de4b37021666422ffad750810b82 Mon Sep 17 00:00:00 2001 From: Richard Knop Date: Sun, 20 Oct 2024 22:11:05 +0100 Subject: [PATCH] feat: implemented basic support for WHERE conditions --- README.md | 4 +- internal/pkg/minisql/row.go | 207 ++++++++++++++++++++++++++++ internal/pkg/minisql/row_test.go | 207 ++++++++++++++++++++++++++++ internal/pkg/minisql/select.go | 61 +------- internal/pkg/minisql/select_test.go | 35 ----- internal/pkg/minisql/stmt.go | 39 +++++- 6 files changed, 454 insertions(+), 99 deletions(-) diff --git a/README.md b/README.md index 1bef7c3..78e1ccd 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ minisql> I plan to implement more features of traditional relational databases in the future as part of this project simply to learn and discovery how various features I have grown acustomed to over the years are implemented under the hood. However, currently only a very small number of features are implemented: -- simple SQL parser (partial support for `CREATE TABLE`, `INSERT`, `SELECT`, `UPDATE`, `DELETE` queries) +- simple SQL parser (partial support for `CREATE TABLE`, `INSERT`, `SELECT` queries) - only tables supported, no indexes (this means all selects are scanning whole tables for now) - only `int4`, `int8` and `varchar` columns supported - no primary key support (tables internally use row ID as key in B tree data structure) @@ -30,6 +30,8 @@ I plan to implement more features of traditional relational databases in the fut ### Planned features: +- support additional basic query types such as `UPDATE`, `DELETE`, `DROP TABLE` +- - support `NULL` values - B+ tree and support indexes (starting with unique and primary) - more column types starting with simpler ones such as `bool` and `timestamp` - support bigger column types such as `text` that can overflow to more pages via linked list data structure diff --git a/internal/pkg/minisql/row.go b/internal/pkg/minisql/row.go index 040dd8a..3bfd5dc 100644 --- 
a/internal/pkg/minisql/row.go +++ b/internal/pkg/minisql/row.go @@ -185,3 +185,210 @@ func UnmarshalRow(buf []byte, aRow *Row) error { return nil } + +// CheckOneOrMore checks whether row satisfies one or more sets of conditions +// (cond1 AND cond2) OR (cond3 and cond4) ... etc +func (r Row) CheckOneOrMore(conditions OneOrMore) (bool, error) { + if len(conditions) == 0 { + return true, nil + } + + for _, aConditionGroup := range conditions { + groupConditionResult := true + for _, aCondition := range aConditionGroup { + ok, err := r.checkCondition(aCondition) + if err != nil { + return false, err + } + + if !ok { + groupConditionResult = false + break + } + } + + if groupConditionResult { + return true, nil + } + } + + return false, nil +} + +func (r Row) checkCondition(aCondition Condition) (bool, error) { + // left side is field, right side is literal value + if aCondition.Operand1.IsField() && !aCondition.Operand2.IsField() { + return r.compareFieldValue(aCondition.Operand1, aCondition.Operand2, aCondition.Operator) + } + + // left side is literal value, right side is field + if aCondition.Operand2.IsField() && !aCondition.Operand1.IsField() { + return r.compareFieldValue(aCondition.Operand2, aCondition.Operand1, aCondition.Operator) + } + + // both left and right are fields, compare 2 row values + if aCondition.Operand1.IsField() && aCondition.Operand2.IsField() { + return r.compareFields(aCondition.Operand1, aCondition.Operand2, aCondition.Operator) + } + + // both left and right are literal values, compare them + return aCondition.Operand1.Value == aCondition.Operand2.Value, nil +} + +func (r Row) compareFieldValue(fieldOperand, valueOperand Operand, operator Operator) (bool, error) { + if fieldOperand.Type != Field { + return false, fmt.Errorf("field operand invalid, type '%d'", fieldOperand.Type) + } + if valueOperand.Type == Field { + return false, fmt.Errorf("cannot compare column value against field operand") + } + name := fmt.Sprint(fieldOperand.Value) 
+	aColumn, ok := r.GetColumn(name)
+	if !ok {
+		return false, fmt.Errorf("row does not contain column '%s'", name)
+	}
+	value, ok := r.GetValue(aColumn.Name)
+	if !ok {
+		return false, fmt.Errorf("row does not have '%s' column", name)
+	}
+	switch aColumn.Kind {
+	case Int4:
+		// Integer literals arrive from the parser as int64 while int4 row data is
+		// int32, so widen the row value; compareInt4 type-checks the literal itself.
+		return compareInt4(int64(value.(int32)), valueOperand.Value, operator)
+	case Int8:
+		return compareInt8(value, valueOperand.Value, operator) // checked casts happen inside
+	case Varchar:
+		return compareVarchar(value, valueOperand.Value, operator)
+	default:
+		return false, fmt.Errorf("unknown column kind '%s'", aColumn.Kind)
+	}
+}
+
+func (r Row) compareFields(field1, field2 Operand, operator Operator) (bool, error) {
+	if !field1.IsField() {
+		return false, fmt.Errorf("field operand invalid, type '%d'", field1.Type)
+	}
+	if !field2.IsField() { // both operands must reference columns
+		return false, fmt.Errorf("field operand invalid, type '%d'", field2.Type)
+	}
+	// A column compared with itself satisfies only =, >= and <=.
+	if field1.Value == field2.Value {
+		return operator == Eq || operator == Gte || operator == Lte, nil
+	}
+
+	name1 := fmt.Sprint(field1.Value)
+	aColumn1, ok := r.GetColumn(name1)
+	if !ok {
+		return false, fmt.Errorf("row does not contain column '%s'", name1)
+	}
+	name2 := fmt.Sprint(field2.Value)
+	aColumn2, ok := r.GetColumn(name2)
+	if !ok {
+		return false, fmt.Errorf("row does not contain column '%s'", name2)
+	}
+	// Columns of different kinds are never considered a match.
+	if aColumn1.Kind != aColumn2.Kind {
+		return false, nil
+	}
+
+	value1, ok := r.GetValue(aColumn1.Name)
+	if !ok {
+		return false, fmt.Errorf("row does not have '%s' column", name1)
+	}
+	value2, ok := r.GetValue(aColumn2.Name)
+	if !ok {
+		return false, fmt.Errorf("row does not have '%s' column", name2)
+	}
+	// Int4 row data is int32 and must be widened, because compareInt4 expects int64.
+	switch aColumn1.Kind {
+	case Int4:
+		return compareInt4(int64(value1.(int32)), int64(value2.(int32)), operator)
+	case Int8:
+		return compareInt8(value1, value2, operator)
+	case Varchar:
+		return compareVarchar(value1, value2, operator)
+	default:
+		return false, fmt.Errorf("unknown column kind '%s'",
aColumn1.Kind)
+	}
+}
+
+func compareInt4(value1, value2 any, operator Operator) (bool, error) {
+	theValue1, ok := value1.(int64)
+	if !ok {
+		return false, fmt.Errorf("value '%v' cannot be cast as int64", value1)
+	}
+	theValue2, ok := value2.(int64)
+	if !ok {
+		return false, fmt.Errorf("operand value '%v' cannot be cast as int64", value2)
+	}
+	switch operator { // compare as int64: narrowing a large literal to int32 would wrap around
+	case Eq:
+		return theValue1 == theValue2, nil
+	case Ne:
+		return theValue1 != theValue2, nil
+	case Gt:
+		return theValue1 > theValue2, nil
+	case Lt:
+		return theValue1 < theValue2, nil
+	case Gte:
+		return theValue1 >= theValue2, nil
+	case Lte:
+		return theValue1 <= theValue2, nil
+	}
+	return false, fmt.Errorf("unknown operator '%s'", operator)
+}
+
+func compareInt8(value1, value2 any, operator Operator) (bool, error) {
+	theValue1, ok := value1.(int64)
+	if !ok {
+		return false, fmt.Errorf("value '%v' cannot be cast as int64", value1)
+	}
+	theValue2, ok := value2.(int64)
+	if !ok {
+		return false, fmt.Errorf("operand value '%v' cannot be cast as int64", value2)
+	}
+	switch operator {
+	case Eq:
+		return theValue1 == theValue2, nil
+	case Ne:
+		return theValue1 != theValue2, nil
+	case Gt:
+		return theValue1 > theValue2, nil
+	case Lt:
+		return theValue1 < theValue2, nil
+	case Gte:
+		return theValue1 >= theValue2, nil
+	case Lte:
+		return theValue1 <= theValue2, nil
+	}
+	return false, fmt.Errorf("unknown operator '%s'", operator)
+}
+
+func compareVarchar(value1, value2 any, operator Operator) (bool, error) {
+	theValue1, ok := value1.(string)
+	if !ok {
+		return false, fmt.Errorf("value '%v' cannot be cast as string", value1)
+	}
+	theValue2, ok := value2.(string)
+	if !ok {
+		return false, fmt.Errorf("operand value '%v' cannot be cast as string", value2)
+	}
+	// From the Go spec (https://go.dev/ref/spec#Comparison_operators):
+	// two string values are compared lexically byte-wise.
+ switch operator { + case Eq: + return theValue1 == theValue2, nil + case Ne: + return theValue1 != theValue2, nil + case Gt: + return theValue1 > theValue2, nil + case Lt: + return theValue1 < theValue2, nil + case Gte: + return theValue1 >= theValue2, nil + case Lte: + return theValue1 <= theValue2, nil + } + return false, fmt.Errorf("unknown operator '%s'", operator) +} diff --git a/internal/pkg/minisql/row_test.go b/internal/pkg/minisql/row_test.go index 2d34c8b..8fca521 100644 --- a/internal/pkg/minisql/row_test.go +++ b/internal/pkg/minisql/row_test.go @@ -23,3 +23,210 @@ func TestRow_Marshal(t *testing.T) { assert.Equal(t, aRow, actual) } + +func TestRow_CheckOneOrMore(t *testing.T) { + t.Parallel() + + var ( + aRow = Row{ + Columns: testColumns, + Values: []any{ + int64(125478), + "john.doe@example.com", + int32(25), + }, + } + idMatch = Condition{ + Operand1: Operand{ + Type: Field, + Value: "id", + }, + Operator: Eq, + Operand2: Operand{ + Type: Integer, + Value: int64(125478), + }, + } + idMismatch = Condition{ + Operand1: Operand{ + Type: Field, + Value: "id", + }, + Operator: Eq, + Operand2: Operand{ + Type: Integer, + Value: int64(678), + }, + } + emailMatch = Condition{ + Operand1: Operand{ + Type: Field, + Value: "email", + }, + Operator: Eq, + Operand2: Operand{ + Type: QuotedString, + Value: "john.doe@example.com", + }, + } + emailMismatch = Condition{ + Operand1: Operand{ + Type: Field, + Value: "email", + }, + Operator: Eq, + Operand2: Operand{ + Type: QuotedString, + Value: "jack.ipsum@example.com", + }, + } + ageMatch = Condition{ + Operand1: Operand{ + Type: Field, + Value: "age", + }, + Operator: Eq, + Operand2: Operand{ + Type: Integer, + Value: int64(25), + }, + } + ageMismatch = Condition{ + Operand1: Operand{ + Type: Field, + Value: "age", + }, + Operator: Eq, + Operand2: Operand{ + Type: Integer, + Value: int64(42), + }, + } + ) + + testCases := []struct { + Name string + Row Row + Conditions OneOrMore + Expected bool + }{ + { + "row 
matches if conditions are empty", + aRow, + OneOrMore{}, + true, + }, + { + "row matches if condition comparing with integer is true", + aRow, + OneOrMore{ + { + idMatch, + }, + }, + true, + }, + { + "row does not match if condition comparing with integer is false", + aRow, + OneOrMore{ + { + idMismatch, + }, + }, + false, + }, + { + "row matches if condition comparing with quoted string is true", + aRow, + OneOrMore{ + { + emailMatch, + }, + }, + true, + }, + { + "row does not match if condition comparing with quoted string is false", + aRow, + OneOrMore{ + { + emailMismatch, + }, + }, + false, + }, + { + "row matches if all conditions are true", + aRow, + OneOrMore{ + { + idMatch, + emailMatch, + }, + }, + true, + }, + { + "row does not match if not all conditions are true", + aRow, + OneOrMore{ + { + idMatch, + emailMismatch, + }, + }, + false, + }, + { + "row matches if all condition groups are true", + aRow, + OneOrMore{ + { + idMatch, + emailMatch, + }, + { + ageMatch, + }, + }, + true, + }, + { + "row matches if at least one of condition groups is true", + aRow, + OneOrMore{ + { + idMatch, + emailMismatch, + }, + { + ageMatch, + }, + }, + true, + }, + { + "row does not match if all condition groups are false", + aRow, + OneOrMore{ + { + idMatch, + emailMismatch, + }, + { + ageMismatch, + }, + }, + false, + }, + } + + for _, aTestCase := range testCases { + t.Run(aTestCase.Name, func(t *testing.T) { + actual, err := aTestCase.Row.CheckOneOrMore(aTestCase.Conditions) + require.NoError(t, err) + assert.Equal(t, aTestCase.Expected, actual) + }) + } +} diff --git a/internal/pkg/minisql/select.go b/internal/pkg/minisql/select.go index e537834..5310f0e 100644 --- a/internal/pkg/minisql/select.go +++ b/internal/pkg/minisql/select.go @@ -60,7 +60,7 @@ func (t *Table) Select(ctx context.Context, stmt Statement) (StatementResult, er out <- aRow continue } - ok, err := rowMatchesConditions(conditions, aRow) + ok, err := aRow.CheckOneOrMore(conditions) if err != nil { 
errorsPipe <- err return @@ -106,62 +106,3 @@ func (t *Table) Select(ctx context.Context, stmt Statement) (StatementResult, er return aResult, nil } - -func rowMatchesConditions(conditions OneOrMore, aRow Row) (bool, error) { - if len(conditions) == 0 { - return true, nil - } - - for _, aConditionGroup := range conditions { - groupConditionResult := true - for _, aCondition := range aConditionGroup { - ok, err := checkConditionOnRow(aCondition, aRow) - if err != nil { - return false, err - } - - if !ok { - groupConditionResult = false - break - } - } - - if groupConditionResult { - return true, nil - } - } - - return false, nil -} - -func checkConditionOnRow(aCondition Condition, aRow Row) (bool, error) { - // left side is field, right side is literal value - if aCondition.Operand1.IsField() && !aCondition.Operand2.IsField() { - value, ok := aRow.GetValue(fmt.Sprint(aCondition.Operand1)) - if !ok { - return false, fmt.Errorf("row does not have '%s' column", aCondition.Operand1.Value) - } - return value == aCondition.Operand2, nil - } - - // left side is literal value, right side is field - if aCondition.Operand2.IsField() && !aCondition.Operand1.IsField() { - value, ok := aRow.GetValue(fmt.Sprint(aCondition.Operand2)) - if !ok { - return false, fmt.Errorf("row does not have '%s' column", aCondition.Operand2.Value) - } - return value == aCondition.Operand1, nil - } - - // both left and right are fields, compare 2 row values - if aCondition.Operand1.IsField() && aCondition.Operand2.IsField() { - - } - - // both left and right are literal values, compare them - if !aCondition.Operand1.IsField() && !aCondition.Operand2.IsField() { - - } - - return false, nil -} diff --git a/internal/pkg/minisql/select_test.go b/internal/pkg/minisql/select_test.go index 8f06612..fcce311 100644 --- a/internal/pkg/minisql/select_test.go +++ b/internal/pkg/minisql/select_test.go @@ -130,38 +130,3 @@ func TestTable_Select_LeafNodeInsert(t *testing.T) { i += 1 } } - -func 
Test_RowMatchesConditions(t *testing.T) { - t.Parallel() - - aRow := Row{ - Columns: testColumns, - Values: []any{ - int64(125478), - "john.doe@example.com", - int32(25), - }, - } - - testCases := []struct { - Name string - aRow Row - Conditions OneOrMore - Expected bool - }{ - { - "row matches if conditions are empty", - aRow, - OneOrMore{}, - true, - }, - } - - for _, aTestCase := range testCases { - t.Run(aTestCase.Name, func(t *testing.T) { - actual, err := rowMatchesConditions(aTestCase.Conditions, aTestCase.aRow) - require.NoError(t, err) - assert.Equal(t, aTestCase.Expected, actual) - }) - } -} diff --git a/internal/pkg/minisql/stmt.go b/internal/pkg/minisql/stmt.go index 76a8172..5389d13 100644 --- a/internal/pkg/minisql/stmt.go +++ b/internal/pkg/minisql/stmt.go @@ -21,6 +21,25 @@ const ( Lte ) +func (o Operator) String() string { + switch o { + case Eq: + return "=" + case Ne: + return "!=" + case Gt: + return ">" + case Lt: + return "<" + case Gte: + return ">=" + case Lte: + return "<=" + default: + return "Unknown" + } +} + type OperandType int const ( @@ -97,10 +116,24 @@ const ( Varchar ) +func (k ColumnKind) String() string { + switch k { + case Int4: + return "Int4" + case Int8: + return "Int8" + case Varchar: + return "Varchar" + default: + return "Unknown" + } +} + type Column struct { - Kind ColumnKind - Size uint32 - Name string + Kind ColumnKind + Size uint32 + Nullable bool + Name string } type Statement struct {