From bebe7330cf26a1155199894de94ef1006327037f Mon Sep 17 00:00:00 2001 From: Richard Knop Date: Sat, 26 Oct 2024 00:21:40 +0100 Subject: [PATCH] feat: simple UPDATE implementation --- README.md | 15 +- internal/pkg/database/database.go | 14 +- internal/pkg/minisql/cursor.go | 28 ++-- internal/pkg/minisql/delete.go | 4 +- internal/pkg/minisql/insert_test.go | 29 ++-- internal/pkg/minisql/minisql_test.go | 47 ++++-- .../pkg/minisql/minisqltest/minisqltest.go | 70 ++++++++- internal/pkg/minisql/row.go | 21 +++ internal/pkg/minisql/select.go | 2 +- internal/pkg/minisql/select_test.go | 146 ++++++++---------- internal/pkg/minisql/update.go | 89 ++++++++++- internal/pkg/minisql/update_test.go | 131 ++++++++++++++++ internal/pkg/parser/parser.go | 36 +++-- internal/pkg/parser/select.go | 5 +- internal/pkg/parser/update.go | 12 +- internal/pkg/parser/update_test.go | 28 +++- 16 files changed, 510 insertions(+), 167 deletions(-) create mode 100644 internal/pkg/minisql/update_test.go diff --git a/README.md b/README.md index bb6f13a..99a4016 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ minisql> I plan to implement more features of traditional relational databases in the future as part of this project simply to learn and discovery how various features I have grown acustomed to over the years are implemented under the hood. However, currently only a very small number of features are implemented: -- simple SQL parser (partial support for `CREATE TABLE`, `INSERT`, `SELECT` queries) +- simple SQL parser (partial support for `CREATE TABLE`, `INSERT`, `SELECT`, `UPDATE` queries) - only tables supported, no indexes (this means all selects are scanning whole tables for now) - only `int4`, `int8` and `varchar` columns supported - no primary key support (tables internally use row ID as key in B tree data structure) @@ -30,13 +30,13 @@ I plan to implement more features of traditional relational databases in the fut ### Planned features: -- support additional basic query types such as `UPDATE`, `DELETE`, `DROP TABLE` +- support additional basic query types such as `DELETE`, `DROP TABLE` - support `NULL` values - B+ tree and support indexes (starting with unique and primary) - more column types starting with simpler ones such as `bool` and `timestamp` - support bigger column types such as `text` that can overflow to more pages via linked list data structure - joins such as `INNER`, `LEFT`, `RIGHT` -- support `ORDER BY`, `GROUP BY` +- support `ORDER BY`, `LIMIT`, `GROUP BY` - dedicate first 100B of root page for config similar to how sqlite does it - support altering tables - transactions @@ -85,7 +85,6 @@ Insert more than a single page worth of data: ```sh minisql> insert into foo(id, email, age) values(1, 'john@example.com', 35), (2, 'jane@example.com', 32), (3, 'jack@example.com', 27), (4, 'jane@example.com', 32), (5, 'jack@example.com', 27), (6, 'jane@example.com', 32), (7, 'jack@example.com', 27), (8, 'jane@example.com', 32), (9, 'jack@example.com', 27), (10, 'jane@example.com', 32), (11, 'jack@example.com', 27), (12, 'jane@example.com', 32), (13, 'jack@example.com', 27), (14, 'jack@example.com', 27), (15, 'jack@example.com', 27) Rows affected: 15 - minisql> ``` @@ -101,3 +100,11 @@ minisql> select * from foo minisql> ``` +Update rows: + +```sh +minisql> update foo set id = 45 where id = 75 +Rows affected: 0 +minisql> +``` + diff --git a/internal/pkg/database/database.go b/internal/pkg/database/database.go index 8228bd6..1a493d3 100644 --- a/internal/pkg/database/database.go +++ b/internal/pkg/database/database.go @@ -324,12 +324,7 @@ func (d *Database) executeUpdate(ctx context.Context, stmt minisql.Statement) (m return minisql.StatementResult{}, errTableDoesNotExist } - if err := aTable.Update(ctx, stmt); err != nil { - return minisql.StatementResult{}, err - } - - // TODO - calculate rows affected properly - return minisql.StatementResult{RowsAffected: 0}, nil + return aTable.Update(ctx, stmt) } func (d *Database) executeDelete(ctx context.Context, stmt minisql.Statement) (minisql.StatementResult, error) { @@ -338,10 +333,5 @@ func (d *Database) executeDelete(ctx context.Context, stmt minisql.Statement) (m return minisql.StatementResult{}, errTableDoesNotExist } - if err := aTable.Delete(ctx, stmt); err != nil { - return minisql.StatementResult{}, err - } - - // TODO - calculate rows affected properly - return minisql.StatementResult{RowsAffected: 0}, nil + return aTable.Delete(ctx, stmt) } diff --git a/internal/pkg/minisql/cursor.go b/internal/pkg/minisql/cursor.go index 2fc75d6..d9829ad 100644 --- a/internal/pkg/minisql/cursor.go +++ b/internal/pkg/minisql/cursor.go @@ -35,7 +35,7 @@ func (c *Cursor) LeafNodeInsert(ctx context.Context, key uint64, aRow *Row) erro } aPage.LeafNode.Header.Cells += 1 - err = saveToCell(ctx, &aPage.LeafNode.Cells[c.CellIdx], key, aRow) + err = saveToCell(&aPage.LeafNode.Cells[c.CellIdx], key, aRow) return err } @@ -98,7 +98,7 @@ func (c *Cursor) LeafNodeSplitInsert(ctx context.Context, key uint64, aRow *Row) destCell := &destPage.LeafNode.Cells[cellIdx] if i == c.CellIdx { - if err := saveToCell(ctx, destCell, key, aRow); err != nil { + if err := saveToCell(destCell, key, aRow); err != nil { return err } } else if i > c.CellIdx { @@ -134,7 +134,7 @@ func (c *Cursor) LeafNodeSplitInsert(ctx context.Context, key uint64, aRow *Row) return c.Table.InternalNodeInsert(ctx, parentPageIdx, newPageIdx) } -func saveToCell(ctx context.Context, cell *Cell, key uint64, aRow *Row) error { +func saveToCell(cell *Cell, key uint64, aRow *Row) error { rowBuf, err := aRow.Marshal() if err != nil { return err @@ -144,32 +144,42 @@ func saveToCell(ctx context.Context, cell *Cell, key uint64, aRow *Row) error { return nil } -func (c *Cursor) fetchRow(ctx context.Context) (Row, error) { +func updateCell(cell *Cell, aRow *Row) error { + rowBuf, err := aRow.Marshal() + if err != nil { + return err + } + copy(cell.Value[:], rowBuf) + return nil +} + +func (c *Cursor) fetchRow(ctx context.Context) (Row, *Cell, error) { aPage, err := c.Table.pager.GetPage(ctx, c.Table, c.PageIdx) if err != nil { - return Row{}, err + return Row{}, nil, err } aRow := NewRow(c.Table.Columns) if err := UnmarshalRow(aPage.LeafNode.Cells[c.CellIdx].Value[:], &aRow); err != nil { - return Row{}, err + return Row{}, nil, err } + destCell := &aPage.LeafNode.Cells[c.CellIdx] // There are still more cells in the page, move cursor to next cell and return if c.CellIdx < aPage.LeafNode.Header.Cells-1 { c.CellIdx += 1 - return aRow, nil + return aRow, destCell, nil } // If there is no leaf page to the right, set end of table flag and return if aPage.LeafNode.Header.NextLeaf == 0 { c.EndOfTable = true - return aRow, nil + return aRow, destCell, nil } // Otherwise, we try to move the cursor to the next leaf page c.PageIdx = aPage.LeafNode.Header.NextLeaf c.CellIdx = 0 - return aRow, nil + return aRow, destCell, nil } diff --git a/internal/pkg/minisql/delete.go b/internal/pkg/minisql/delete.go index a1ae3e7..a5819d8 100644 --- a/internal/pkg/minisql/delete.go +++ b/internal/pkg/minisql/delete.go @@ -5,7 +5,7 @@ import ( "fmt" ) -func (t *Table) Delete(ctx context.Context, stmt Statement) error { +func (t *Table) Delete(ctx context.Context, stmt Statement) (StatementResult, error) { fmt.Println("TODO - implement DELETE") - return fmt.Errorf("not implemented") + return StatementResult{}, fmt.Errorf("not implemented") } diff --git a/internal/pkg/minisql/insert_test.go b/internal/pkg/minisql/insert_test.go index 458a233..492c5fa 100644 --- a/internal/pkg/minisql/insert_test.go +++ b/internal/pkg/minisql/insert_test.go @@ -26,7 +26,7 @@ func TestTable_Insert(t *testing.T) { stmt := Statement{ Kind: Insert, TableName: "foo", - Fields: []string{"id", "email", "age"}, + Fields: columnNames(testColumns...), Inserts: [][]any{aRow.Values}, } @@ -58,7 +58,7 @@ func TestTable_Insert_MultiInsert(t *testing.T) { stmt := Statement{ Kind: Insert, TableName: "foo", - Fields: []string{"id", "email", "age"}, + Fields: columnNames(testColumns...), Inserts: [][]any{aRow.Values, aRow2.Values, aRow3.Values}, } @@ -116,7 +116,7 @@ func TestTable_Insert_SplitRootLeaf(t *testing.T) { stmt := Statement{ Kind: Insert, TableName: "foo", - Fields: []string{"id", "email", "age"}, + Fields: columnNames(testColumns...), Inserts: [][]any{aRow.Values}, } @@ -198,19 +198,20 @@ func TestTable_Insert_SplitLeaf(t *testing.T) { return old }, nil) - // Insert test rows + // Batch insert test rows + stmt := Statement{ + Kind: Insert, + TableName: "foo", + Fields: columnNames(testBigColumns...), + Inserts: [][]any{}, + } for _, aRow := range rows { - stmt := Statement{ - Kind: Insert, - TableName: "foo", - Fields: []string{"id", "email", "name", "description"}, - Inserts: [][]any{aRow.Values}, - } - - err := aTable.Insert(ctx, stmt) - require.NoError(t, err) + stmt.Inserts = append(stmt.Inserts, aRow.Values) } + err := aTable.Insert(ctx, stmt) + require.NoError(t, err) + // Assert root node assert.Equal(t, 3, int(aRootPage.InternalNode.Header.KeysNum)) assert.True(t, aRootPage.InternalNode.Header.IsRoot) @@ -293,7 +294,7 @@ func TestTable_Insert_SplitInternalNode_CreateNewRoot(t *testing.T) { stmt := Statement{ Kind: Insert, TableName: "foo", - Fields: []string{"id", "email", "name", "description"}, + Fields: columnNames(testBigColumns...), Inserts: [][]any{}, } for _, aRow := range rows { diff --git a/internal/pkg/minisql/minisql_test.go b/internal/pkg/minisql/minisql_test.go index e5d4fec..108099b 100644 --- a/internal/pkg/minisql/minisql_test.go +++ b/internal/pkg/minisql/minisql_test.go @@ -65,6 +65,14 @@ func init() { } } +func columnNames(columns ...Column) []string { + names := make([]string, 0, len(columns)) + for _, aColumn := range columns { + names = append(names, aColumn.Name) + } + return names +} + type dataGen struct { *gofakeit.Faker } @@ -77,14 +85,6 @@ func newDataGen(seed uint64) *dataGen { return &g } -func (g *dataGen) Rows(number int) []Row { - rows := make([]Row, 0, number) - for i := 0; i < number; i++ { - rows = append(rows, g.Row()) - } - return rows -} - func (g *dataGen) Row() Row { return Row{ Columns: testColumns, @@ -96,10 +96,20 @@ func (g *dataGen) Row() Row { } } -func (g *dataGen) BigRows(number int) []Row { +func (g *dataGen) Rows(number int) []Row { + // Make sure all rows will have unique ID, this is important in some tests + idMap := map[int64]struct{}{} rows := make([]Row, 0, number) for i := 0; i < number; i++ { - rows = append(rows, g.BigRow()) + aRow := g.Row() + _, ok := idMap[aRow.Values[0].(int64)] + for ok { + aRow = g.Row() + _, ok = idMap[aRow.Values[0].(int64)] + } + rows = append(rows, aRow) + idMap[aRow.Values[0].(int64)] = struct{}{} + } return rows } @@ -116,6 +126,23 @@ func (g *dataGen) BigRow() Row { } } +func (g *dataGen) BigRows(number int) []Row { + // Make sure all rows will have unique ID, this is important in some tests + idMap := map[int64]struct{}{} + rows := make([]Row, 0, number) + for i := 0; i < number; i++ { + aRow := g.BigRow() + _, ok := idMap[aRow.Values[0].(int64)] + for ok { + aRow = g.BigRow() + _, ok = idMap[aRow.Values[0].(int64)] + } + rows = append(rows, aRow) + idMap[aRow.Values[0].(int64)] = struct{}{} + } + return rows +} + func newInternalPageWithCells(iCells []ICell, rightChildIdx uint32) *Page { aRoot := NewInternalNode() aRoot.Header.KeysNum = uint32(len(iCells)) diff --git a/internal/pkg/minisql/minisqltest/minisqltest.go b/internal/pkg/minisql/minisqltest/minisqltest.go index 49e8abb..74b194f 100644 --- a/internal/pkg/minisql/minisqltest/minisqltest.go +++ b/internal/pkg/minisql/minisqltest/minisqltest.go @@ -29,6 +29,29 @@ var ( Name: "age", }, } + + testBigColumns = []minisql.Column{ + { + Kind: minisql.Int8, + Size: 8, + Name: "id", + }, + { + Kind: minisql.Varchar, + Size: 255, + Name: "name", + }, + { + Kind: minisql.Varchar, + Size: 255, + Name: "email", + }, + { + Kind: minisql.Varchar, + Size: minisql.PageSize - 6 - 8 - 4*8 - 8 - 255 - 255, + Name: "description", + }, + } ) type DataGen struct { @@ -43,25 +66,64 @@ func NewDataGen(seed uint64) *DataGen { return &g } +func (g *DataGen) Row() minisql.Row { + return minisql.Row{ + Columns: testColumns, + Values: []any{ + g.Int64(), + g.Email(), + int32(g.IntRange(18, 100)), + }, + } +} + func (g *DataGen) Rows(number int) []minisql.Row { + // Make sure all rows will have unique ID, this is important in some tests + idMap := map[int64]struct{}{} rows := make([]minisql.Row, 0, number) for i := 0; i < number; i++ { - rows = append(rows, g.Row()) + aRow := g.Row() + _, ok := idMap[aRow.Values[0].(int64)] + for ok { + aRow = g.Row() + _, ok = idMap[aRow.Values[0].(int64)] + } + rows = append(rows, aRow) + idMap[aRow.Values[0].(int64)] = struct{}{} + } return rows } -func (g *DataGen) Row() minisql.Row { +func (g *DataGen) BigRow() minisql.Row { return minisql.Row{ - Columns: testColumns, + Columns: testBigColumns, Values: []any{ g.Int64(), g.Email(), - int32(g.IntRange(18, 100)), + g.Name(), + g.Sentence(15), }, } } +func (g *DataGen) BigRows(number int) []minisql.Row { + // Make sure all rows will have unique ID, this is important in some tests + idMap := map[int64]struct{}{} + rows := make([]minisql.Row, 0, number) + for i := 0; i < number; i++ { + aRow := g.BigRow() + _, ok := idMap[aRow.Values[0].(int64)] + for ok { + aRow = g.BigRow() + _, ok = idMap[aRow.Values[0].(int64)] + } + rows = append(rows, aRow) + idMap[aRow.Values[0].(int64)] = struct{}{} + } + return rows +} + func (g *DataGen) NewRootLeafPageWithCells(cells, rowSize int) *minisql.Page { aRootLeaf := minisql.NewLeafNode(uint64(rowSize)) aRootLeaf.Header.Header.IsRoot = true diff --git a/internal/pkg/minisql/row.go b/internal/pkg/minisql/row.go index 0d623c3..8b39cb4 100644 --- a/internal/pkg/minisql/row.go +++ b/internal/pkg/minisql/row.go @@ -8,6 +8,8 @@ import ( type Row struct { Columns []Column Values []any + // store internal pointer to cell so we can update the row + cell *Cell } // MaxCells returns how many rows can be stored in a single page @@ -61,6 +63,25 @@ func (r Row) GetValue(name string) (any, bool) { return r.Values[columnIdx], true } +func (r Row) SetValue(name string, value any) bool { + var ( + found bool + columnIdx = 0 + ) + for i, aColumn := range r.Columns { + if aColumn.Name == name { + found = true + columnIdx = i + break + } + } + if !found { + return false + } + r.Values[columnIdx] = value + return true +} + func (r Row) columnOffset(idx int) uint32 { offset := uint32(0) for i := 0; i < idx; i++ { diff --git a/internal/pkg/minisql/select.go b/internal/pkg/minisql/select.go index 5310f0e..f6fe3e4 100644 --- a/internal/pkg/minisql/select.go +++ b/internal/pkg/minisql/select.go @@ -37,7 +37,7 @@ func (t *Table) Select(ctx context.Context, stmt Statement) (StatementResult, er go func(out chan<- Row) { defer close(out) for aCursor.EndOfTable == false { - aRow, err := aCursor.fetchRow(ctx) + aRow, _, err := aCursor.fetchRow(ctx) if err != nil { errorsPipe <- err return diff --git a/internal/pkg/minisql/select_test.go b/internal/pkg/minisql/select_test.go index fcce311..7be97ab 100644 --- a/internal/pkg/minisql/select_test.go +++ b/internal/pkg/minisql/select_test.go @@ -9,66 +9,7 @@ import ( "github.com/stretchr/testify/require" ) -func TestTable_Select_SplitRootLeaf(t *testing.T) { - t.Parallel() - - var ( - ctx = context.Background() - pagerMock = new(MockPager) - rows = gen.Rows(15) - cells, rowSize = 0, rows[0].Size() - aRootPage = newRootLeafPageWithCells(cells, int(rowSize)) - rightChild = &Page{LeafNode: NewLeafNode(rowSize)} - leftChild = &Page{LeafNode: NewLeafNode(rowSize)} - aTable = NewTable(testLogger, "foo", testColumns, pagerMock, 0) - ) - - pagerMock.On("GetPage", mock.Anything, aTable, uint32(0)).Return(aRootPage, nil) - pagerMock.On("GetPage", mock.Anything, aTable, uint32(1)).Return(rightChild, nil) - pagerMock.On("GetPage", mock.Anything, aTable, uint32(2)).Return(leftChild, nil) - - // TotalPages is called twice, let's make sure the second time it's called, - // it will return incremented value since we have created a new page already - totalPages := uint32(1) - pagerMock.On("TotalPages").Return(func() uint32 { - old := totalPages - totalPages += 1 - return old - }, nil) - - // Insert test rows - for _, aRow := range rows { - stmt := Statement{ - Kind: Insert, - TableName: "foo", - Fields: []string{"id", "email", "age"}, - Inserts: [][]any{aRow.Values}, - } - - err := aTable.Insert(ctx, stmt) - require.NoError(t, err) - } - - // Select all rows - stmt := Statement{ - Kind: Select, - TableName: "foo", - Fields: []string{"id", "email", "age"}, - } - aResult, err := aTable.Select(ctx, stmt) - - require.NoError(t, err) - assert.Equal(t, aTable.Columns, aResult.Columns) - - aRow, err := aResult.Rows(ctx) - i := 0 - for ; err == nil; aRow, err = aResult.Rows(ctx) { - assert.Equal(t, rows[i], aRow) - i += 1 - } -} - -func TestTable_Select_LeafNodeInsert(t *testing.T) { +func TestTable_Select(t *testing.T) { t.Parallel() var ( @@ -99,34 +40,73 @@ func TestTable_Select_LeafNodeInsert(t *testing.T) { return old }, nil) - // Insert test rows + // Batch insert test rows + insertStmt := Statement{ + Kind: Insert, + TableName: "foo", + Fields: columnNames(testColumns...), + Inserts: [][]any{}, + } for _, aRow := range rows { - stmt := Statement{ - Kind: Insert, - TableName: "foo", - Fields: []string{"id", "email", "age"}, - Inserts: [][]any{aRow.Values}, - } - - err := aTable.Insert(ctx, stmt) - require.NoError(t, err) + insertStmt.Inserts = append(insertStmt.Inserts, aRow.Values) } - // Select all rows - stmt := Statement{ - Kind: Select, - TableName: "foo", - Fields: []string{"id", "email", "age"}, + err := aTable.Insert(ctx, insertStmt) + require.NoError(t, err) + + testCases := []struct { + Name string + Stmt Statement + Expected []Row + }{ + { + "Select all rows", + Statement{ + Kind: Select, + TableName: "foo", + Fields: columnNames(testColumns...), + }, + rows, + }, + { + "Select single row", + Statement{ + Kind: Select, + TableName: "foo", + Fields: columnNames(testColumns...), + Conditions: OneOrMore{ + { + { + Operand1: Operand{ + Type: Field, + Value: "id", + }, + Operator: Eq, + Operand2: Operand{ + Type: Integer, + Value: rows[5].Values[0].(int64), + }, + }, + }, + }, + }, + []Row{rows[5]}, + }, } - aResult, err := aTable.Select(ctx, stmt) - require.NoError(t, err) - assert.Equal(t, aTable.Columns, aResult.Columns) + for _, aTestCase := range testCases { + t.Run(aTestCase.Name, func(t *testing.T) { + aResult, err := aTable.Select(ctx, aTestCase.Stmt) + require.NoError(t, err) + + // Use iterator to collect all rows + actual := []Row{} + aRow, err := aResult.Rows(ctx) + for ; err == nil; aRow, err = aResult.Rows(ctx) { + actual = append(actual, aRow) + } - aRow, err := aResult.Rows(ctx) - i := 0 - for ; err == nil; aRow, err = aResult.Rows(ctx) { - assert.Equal(t, rows[i], aRow) - i += 1 + assert.Equal(t, aTestCase.Expected, actual) + }) } } diff --git a/internal/pkg/minisql/update.go b/internal/pkg/minisql/update.go index 36a27ee..7cc66a5 100644 --- a/internal/pkg/minisql/update.go +++ b/internal/pkg/minisql/update.go @@ -5,7 +5,90 @@ import ( "fmt" ) -func (t *Table) Update(ctx context.Context, stmt Statement) error { - fmt.Println("TODO - implement UPDATE") - return fmt.Errorf("not implemented") +func (t *Table) Update(ctx context.Context, stmt Statement) (StatementResult, error) { + aCursor, err := t.Seek(ctx, uint64(0)) + if err != nil { + return StatementResult{}, err + } + aPage, err := t.pager.GetPage(ctx, t, aCursor.PageIdx) + if err != nil { + return StatementResult{}, err + } + aCursor.EndOfTable = aPage.LeafNode.Header.Cells == 0 + + t.logger.Sugar().Debug("updating rows") + + var ( + unfilteredPipe = make(chan Row) + filteredPipe = make(chan Row) + errorsPipe = make(chan error, 1) + stopChan = make(chan bool) + ) + + go func(out chan<- Row) { + defer close(out) + for aCursor.EndOfTable == false { + aRow, destCell, err := aCursor.fetchRow(ctx) + if err != nil { + errorsPipe <- err + return + } + aRow.cell = destCell + + select { + case <-stopChan: + return + case out <- aRow: + continue + } + } + }(unfilteredPipe) + + // Filter rows according the WHERE conditions + go func(in <-chan Row, out chan<- Row, conditions OneOrMore) { + defer close(out) + for aRow := range in { + if len(conditions) == 0 { + out <- aRow + continue + } + ok, err := aRow.CheckOneOrMore(conditions) + if err != nil { + errorsPipe <- err + return + } + if ok { + out <- aRow + } + } + }(unfilteredPipe, filteredPipe, stmt.Conditions) + + aResult := StatementResult{ + Columns: t.Columns, + } + + go func(in <-chan Row) { + defer close(stopChan) + for aRow := range in { + for name, value := range stmt.Updates { + aRow.SetValue(name, value) + } + + if err := updateCell(aRow.cell, &aRow); err != nil { + errorsPipe <- err + return + } + + aResult.RowsAffected += 1 + } + }(filteredPipe) + + select { + case <-ctx.Done(): + return aResult, fmt.Errorf("context done: %w", ctx.Err()) + case err := <-errorsPipe: + return aResult, err + case <-stopChan: + return aResult, nil + } } diff --git a/internal/pkg/minisql/update_test.go b/internal/pkg/minisql/update_test.go new file mode 100644 index 0000000..11d96b0 --- /dev/null +++ b/internal/pkg/minisql/update_test.go @@ -0,0 +1,131 @@ +package minisql + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +func TestTable_Update(t *testing.T) { + t.Parallel() + + var ( + ctx = context.Background() + pagerMock = new(MockPager) + rows = gen.Rows(38) + cells, rowSize = 0, rows[0].Size() + aRootPage = newRootLeafPageWithCells(cells, int(rowSize)) + leaf1 = &Page{LeafNode: NewLeafNode(rowSize)} + leaf2 = &Page{LeafNode: NewLeafNode(rowSize)} + leaf3 = &Page{LeafNode: NewLeafNode(rowSize)} + leaf4 = &Page{LeafNode: NewLeafNode(rowSize)} + aTable = NewTable(testLogger, "foo", testColumns, pagerMock, 0) + ) + + pagerMock.On("GetPage", mock.Anything, aTable, uint32(0)).Return(aRootPage, nil) + pagerMock.On("GetPage", mock.Anything, aTable, uint32(1)).Return(leaf2, nil) + pagerMock.On("GetPage", mock.Anything, aTable, uint32(2)).Return(leaf1, nil) + pagerMock.On("GetPage", mock.Anything, aTable, uint32(3)).Return(leaf3, nil) + pagerMock.On("GetPage", mock.Anything, aTable, uint32(4)).Return(leaf4, nil) + + // TotalPages is called 3 times, let's make sure each time it's called, it returns + // an incremented value since we have created a new page in the meantime + totalPages := uint32(1) + pagerMock.On("TotalPages").Return(func() uint32 { + old := totalPages + totalPages += 1 + return old + }, nil) + + // Batch insert test rows + insertStmt := Statement{ + Kind: Insert, + TableName: "foo", + Fields: columnNames(testColumns...), + Inserts: [][]any{}, + } + for _, aRow := range rows { + insertStmt.Inserts = append(insertStmt.Inserts, aRow.Values) + } + + err := aTable.Insert(ctx, insertStmt) + require.NoError(t, err) + + // Update singe row + stmt := Statement{ + Kind: Update, + TableName: "foo", + Updates: map[string]any{ + "email": "updatedsingle@foo.bar", + }, + Conditions: OneOrMore{ + { + { + Operand1: Operand{ + Type: Field, + Value: "id", + }, + Operator: Eq, + Operand2: Operand{ + Type: Integer, + Value: rows[5].Values[0].(int64), + }, + }, + }, + }, + } + aResult, err := aTable.Update(ctx, stmt) + require.NoError(t, err) + assert.Equal(t, 1, aResult.RowsAffected) + + // Select all rows and check that email is updated for all + aResult, err = aTable.Select(ctx, Statement{ + Kind: Select, + TableName: "foo", + Fields: columnNames(testColumns...), + }) + require.NoError(t, err) + i := 0 + aRow, err := aResult.Rows(ctx) + for ; err == nil; aRow, err = aResult.Rows(ctx) { + assert.Equal(t, rows[i].Values[0].(int64), aRow.Values[0].(int64)) + if i == 5 { + assert.Equal(t, "updatedsingle@foo.bar", aRow.Values[1].(string)) + } else { + assert.Equal(t, rows[i].Values[1].(string), aRow.Values[1].(string)) + } + assert.Equal(t, rows[i].Values[2].(int32), aRow.Values[2].(int32)) + i += 1 + } + + // Update all rows + stmt = Statement{ + Kind: Update, + TableName: "foo", + Updates: map[string]any{ + "email": "updatedall@foo.bar", + }, + } + aResult, err = aTable.Update(ctx, stmt) + require.NoError(t, err) + assert.Equal(t, 38, aResult.RowsAffected) + + // Select all rows and check that email is updated for all + aResult, err = aTable.Select(ctx, Statement{ + Kind: Select, + TableName: "foo", + Fields: columnNames(testColumns...), + }) + require.NoError(t, err) + i = 0 + aRow, err = aResult.Rows(ctx) + for ; err == nil; aRow, err = aResult.Rows(ctx) { + assert.Equal(t, rows[i].Values[0].(int64), aRow.Values[0].(int64)) + assert.Equal(t, "updatedall@foo.bar", aRow.Values[1].(string)) + assert.Equal(t, rows[i].Values[2].(int32), aRow.Values[2].(int32)) + i += 1 + } +} diff --git a/internal/pkg/parser/parser.go b/internal/pkg/parser/parser.go index f52c274..cc0f7e8 100644 --- a/internal/pkg/parser/parser.go +++ b/internal/pkg/parser/parser.go @@ -12,12 +12,16 @@ import ( ) var ( - errInvalidStatementKind = fmt.Errorf("invalid statement kind") - errEmptyStatementKind = fmt.Errorf("statement kind cannot be empty") - errEmptyTableName = fmt.Errorf("table name cannot be empty") - errEmptyWhereClause = fmt.Errorf("at WHERE: empty WHERE clause") - errWhereWithoutOperator = fmt.Errorf("at WHERE: condition without operator") - errWhereRequiredForUpdateDelete = fmt.Errorf("at WHERE: WHERE clause is mandatory for UPDATE & DELETE") + errInvalidStatementKind = fmt.Errorf("invalid statement kind") + errEmptyStatementKind = fmt.Errorf("statement kind cannot be empty") + errEmptyTableName = fmt.Errorf("table name cannot be empty") + errEmptyWhereClause = fmt.Errorf("at WHERE: empty WHERE clause") + errWhereWithoutOperator = fmt.Errorf("at WHERE: condition without operator") + errWhereRequiredForUpdateDelete = fmt.Errorf("at WHERE: WHERE clause is mandatory for UPDATE & DELETE") + errWhereExpectedField = fmt.Errorf("at WHERE: expected field") + errWhereExpectedAndOr = fmt.Errorf("expected one of AND / OR") + errWhereExpectedQuotedValueOrInt = fmt.Errorf("at WHERE: expected quoted value or int value") + errWhereUnknownOperator = fmt.Errorf("at WHERE: unknown operator") ) var reservedWords = []string{ @@ -268,7 +272,7 @@ func (p *parser) doParseWhere() (bool, error) { case stepWhereConditionField: identifier := p.peek() if !isIdentifier(identifier) { - return false, fmt.Errorf("at WHERE: expected field") + return false, errWhereExpectedField } p.Statement.Conditions = p.Statement.Conditions.Append(minisql.Condition{ Operand1: minisql.Operand{ @@ -297,7 +301,7 @@ func (p *parser) doParseWhere() (bool, error) { case "!=": currentCondition.Operator = minisql.Ne default: - return false, fmt.Errorf("at WHERE: unknown operator") + return false, errWhereUnknownOperator } p.Conditions.UpdateLast(currentCondition) p.pop() @@ -313,9 +317,9 @@ func (p *parser) doParseWhere() (bool, error) { Value: identifier, } } else { - value, err := p.peekIntOrQuotedStringWithLength() - if err != nil { - return false, fmt.Errorf("at WHERE: expected quoted value or int value") + value, ln := p.peekIntOrQuotedStringWithLength() + if ln == 0 { + return false, errWhereExpectedQuotedValueOrInt } currentCondition.Operand2 = minisql.Operand{ Type: minisql.QuotedString, @@ -331,7 +335,7 @@ func (p *parser) doParseWhere() (bool, error) { case stepWhereOperator: anOperator := strings.ToUpper(p.peek()) if anOperator != "AND" && anOperator != "OR" { - return false, fmt.Errorf("expected one of AND / OR") + return false, errWhereExpectedAndOr } if anOperator == "OR" { p.Conditions = append(p.Conditions, make(minisql.Conditions, 0, 1)) @@ -408,16 +412,16 @@ func (p *parser) peepIntWithLength() (int64, int) { return int64(intValue), len(p.sql[p.i:len(p.sql)]) } -func (p *parser) peekIntOrQuotedStringWithLength() (any, error) { +func (p *parser) peekIntOrQuotedStringWithLength() (any, int) { intValue, ln := p.peepIntWithLength() if ln > 0 { - return intValue, nil + return intValue, ln } quotedValue, ln := p.peekQuotedStringWithLength() if ln > 0 { - return quotedValue, nil + return quotedValue, ln } - return nil, fmt.Errorf("neither int not quoted value found") + return nil, 0 } func (p *parser) peekIdentifierWithLength() (string, int) { diff --git a/internal/pkg/parser/select.go b/internal/pkg/parser/select.go index b5a9104..ad42ee3 100644 --- a/internal/pkg/parser/select.go +++ b/internal/pkg/parser/select.go @@ -6,7 +6,8 @@ import ( ) var ( - errSelectWithoutFields = fmt.Errorf("at SELECT: expected field to SELECT") + errSelectWithoutFields = fmt.Errorf("at SELECT: expected field to SELECT") + errSelectExpectedTableName = fmt.Errorf("at SELECT: expected quoted table name") ) func (p *parser) doParseSelect() (bool, error) { @@ -54,7 +55,7 @@ func (p *parser) doParseSelect() (bool, error) { case stepSelectFromTable: tableName := p.peek() if len(tableName) == 0 { - return false, fmt.Errorf("at SELECT: expected quoted table name") + return false, errSelectExpectedTableName } p.TableName = tableName p.pop() diff --git a/internal/pkg/parser/update.go b/internal/pkg/parser/update.go index f05e6c1..6503142 100644 --- a/internal/pkg/parser/update.go +++ b/internal/pkg/parser/update.go @@ -6,8 +6,8 @@ import ( ) var ( - errUpdateExpectedEquals = fmt.Errorf("at UPDATE: expected '='") - errUpdateExpectedQuotedValue = fmt.Errorf("at UPDATE: expected quoted value") + errUpdateExpectedEquals = fmt.Errorf("at UPDATE: expected '='") + errUpdateExpectedQuotedValueOrInt = fmt.Errorf("at UPDATE: expected quoted value or int") ) func (p *parser) doParseUpdate() (bool, error) { @@ -43,11 +43,11 @@ func (p *parser) doParseUpdate() (bool, error) { p.pop() p.step = stepUpdateValue case stepUpdateValue: - quotedValue, ln := p.peekQuotedStringWithLength() + value, ln := p.peekIntOrQuotedStringWithLength() if ln == 0 { - return false, errUpdateExpectedQuotedValue + return false, errUpdateExpectedQuotedValueOrInt } - p.setUpdate(p.nextUpdateField, quotedValue) + p.setUpdate(p.nextUpdateField, value) p.nextUpdateField = "" p.pop() maybeWhere := p.peek() @@ -67,7 +67,7 @@ func (p *parser) doParseUpdate() (bool, error) { return false, nil } -func (p *parser) setUpdate(field, value string) { +func (p *parser) setUpdate(field string, value any) { if p.Updates == nil { p.Updates = make(map[string]any) } diff --git a/internal/pkg/parser/update_test.go b/internal/pkg/parser/update_test.go index 5bd916b..39889b5 100644 --- a/internal/pkg/parser/update_test.go +++ b/internal/pkg/parser/update_test.go @@ -54,7 +54,7 @@ func TestParse_Update(t *testing.T) { Kind: minisql.Update, TableName: "a", }, - Err: errUpdateExpectedQuotedValue, + Err: errUpdateExpectedQuotedValueOrInt, }, { Name: "Incomplete UPDATE due to no WHERE clause fails", @@ -116,6 +116,32 @@ func TestParse_Update(t *testing.T) { }, }, }, + { + Name: "UPDATE works with int value being set", + SQL: "UPDATE 'a' SET b = 25 WHERE a = '1'", + Expected: minisql.Statement{ + Kind: minisql.Update, + TableName: "a", + Updates: map[string]any{ + "b": int64(25), + }, + Conditions: minisql.OneOrMore{ + { + { + Operand1: minisql.Operand{ + Type: minisql.Field, + Value: "a", + }, + Operator: minisql.Eq, + Operand2: minisql.Operand{ + Type: minisql.QuotedString, + Value: "1", + }, + }, + }, + }, + }, + }, { Name: "UPDATE works with simple quote inside", SQL: "UPDATE 'a' SET b = 'hello\\'world' WHERE a = '1'",