diff --git a/contrib/drivers/oracle/oracle_do_insert.go b/contrib/drivers/oracle/oracle_do_insert.go index d52496166f5..2dd36408f68 100644 --- a/contrib/drivers/oracle/oracle_do_insert.go +++ b/contrib/drivers/oracle/oracle_do_insert.go @@ -10,6 +10,8 @@ import ( "context" "database/sql" "fmt" + "github.com/gogf/gf/v2/container/gset" + "github.com/gogf/gf/v2/text/gstr" "strings" "github.com/gogf/gf/v2/database/gdb" @@ -24,10 +26,7 @@ func (d *Driver) DoInsert( ) (result sql.Result, err error) { switch option.InsertOption { case gdb.InsertOptionSave: - return nil, gerror.NewCode( - gcode.CodeNotSupported, - `Save operation is not supported by oracle driver`, - ) + return d.doSave(ctx, link, table, list, option) case gdb.InsertOptionReplace: return nil, gerror.NewCode( @@ -93,3 +92,114 @@ func (d *Driver) DoInsert( } return batchResult, nil } + +// doSave support upsert for Oracle +func (d *Driver) doSave(ctx context.Context, + link gdb.Link, table string, list gdb.List, option gdb.DoInsertOption, +) (result sql.Result, err error) { + if len(option.OnConflict) == 0 { + return nil, gerror.NewCode( + gcode.CodeMissingParameter, `Please specify conflict columns`, + ) + } + + if len(list) == 0 { + return nil, gerror.NewCode( + gcode.CodeInvalidRequest, `Save operation list is empty by oracle driver`, + ) + } + + var ( + one = list[0] + charL, charR = d.GetChars() + valueCharL, valueCharR = "'", "'" + + conflictKeys = option.OnConflict + conflictKeySet = gset.New(false) + + // insertKeys: Handle valid keys that need to be inserted + // insertValues: Handle values that need to be inserted + // updateValues: Handle values that need to be updated + // queryValues: Handle data that need to be upsert + queryValues, insertKeys, insertValues, updateValues []string + ) + + // conflictKeys slice type conv to set type + for _, conflictKey := range conflictKeys { + conflictKeySet.Add(gstr.ToUpper(conflictKey)) + } + + for key, value := range one { + saveValue := gconv.String(value) + queryValues = append( + queryValues, + fmt.Sprintf( + valueCharL+"%s"+valueCharR+" AS "+charL+"%s"+charR, + saveValue, key, + ), + ) + + insertKeys = append(insertKeys, charL+key+charR) + insertValues = append(insertValues, "T2."+charL+key+charR) + + // filter conflict keys in updateValues + if !conflictKeySet.Contains(key) { + updateValues = append( + updateValues, + fmt.Sprintf(`T1.%s = T2.%s`, charL+key+charR, charL+key+charR), + ) + } + } + + batchResult := new(gdb.SqlResult) + sqlStr := parseSqlForUpsert(table, queryValues, insertKeys, insertValues, updateValues, conflictKeys) + r, err := d.DoExec(ctx, link, sqlStr) + if err != nil { + return r, err + } + if n, err := r.RowsAffected(); err != nil { + return r, err + } else { + batchResult.Result = r + batchResult.Affected += n + } + return batchResult, nil +} + +// parseSqlForUpsert +// MERGE INTO {{table}} T1 +// USING ( SELECT {{queryValues}} FROM DUAL T2 +// ON (T1.{{duplicateKey}} = T2.{{duplicateKey}} AND ...) +// WHEN NOT MATCHED THEN +// INSERT {{insertKeys}} VALUES {{insertValues}} +// WHEN MATCHED THEN +// UPDATE SET {{updateValues}} +func parseSqlForUpsert(table string, + queryValues, insertKeys, insertValues, updateValues, duplicateKey []string, +) (sqlStr string) { + var ( + queryValueStr = strings.Join(queryValues, ",") + insertKeyStr = strings.Join(insertKeys, ",") + insertValueStr = strings.Join(insertValues, ",") + updateValueStr = strings.Join(updateValues, ",") + duplicateKeyStr string + pattern = gstr.Trim(`MERGE INTO %s T1 USING (SELECT %s FROM DUAL) T2 ON (%s) WHEN NOT MATCHED THEN INSERT(%s) VALUES (%s) WHEN MATCHED THEN UPDATE SET %s`) + ) + + for index, keys := range duplicateKey { + if index != 0 { + duplicateKeyStr += " AND " + } + duplicateTmp := fmt.Sprintf("T1.%s = T2.%s", keys, keys) + duplicateKeyStr += duplicateTmp + } + + return fmt.Sprintf(pattern, + table, + queryValueStr, + duplicateKeyStr, + insertKeyStr, + insertValueStr, + updateValueStr, + ) +} diff --git a/contrib/drivers/oracle/oracle_z_unit_basic_test.go b/contrib/drivers/oracle/oracle_z_unit_basic_test.go index fd040d259b1..4b97fb79b54 100644 --- a/contrib/drivers/oracle/oracle_z_unit_basic_test.go +++ b/contrib/drivers/oracle/oracle_z_unit_basic_test.go @@ -19,7 +19,7 @@ import ( "github.com/gogf/gf/v2/test/gtest" ) -func TestTables(t *testing.T) { +func Test_Tables(t *testing.T) { gtest.C(t, func(t *gtest.T) { tables := []string{"t_user1", "pop", "haha"} @@ -60,7 +60,7 @@ func TestTables(t *testing.T) { }) } -func TestTableFields(t *testing.T) { +func Test_Table_Fields(t *testing.T) { gtest.C(t, func(t *gtest.T) { createTable("t_user") defer dropTable("t_user") @@ -107,7 +107,7 @@ func TestTableFields(t *testing.T) { }) } -func TestDoInsert(t *testing.T) { +func Test_Do_Insert(t *testing.T) { gtest.C(t, func(t *gtest.T) { createTable("t_user") defer dropTable("t_user") diff --git a/contrib/drivers/oracle/oracle_z_unit_model_test.go b/contrib/drivers/oracle/oracle_z_unit_model_test.go index 712939cfaf4..5de12788dbd 100644 --- a/contrib/drivers/oracle/oracle_z_unit_model_test.go +++ b/contrib/drivers/oracle/oracle_z_unit_model_test.go @@ -128,7 +128,7 @@ func Test_Model_RightJoin(t *testing.T) { }) } -func TestPage(t *testing.T) { +func Test_Page(t *testing.T) { table := createInitTable() defer dropTable(table) result, err := db.Model(table).Page(1, 2).Order("ID").All() @@ -162,7 +162,6 @@ func TestPage(t *testing.T) { func Test_Model_Insert(t *testing.T) { table := createTable() defer dropTable(table) - // db.SetDebug(true) gtest.C(t, func(t *gtest.T) { user := db.Model(table) result, err := user.Data(g.Map{ @@ -1101,6 +1100,83 @@ func Test_Model_WhereOrNotLike(t *testing.T) { }) } +func Test_Model_Save(t *testing.T) { + table := createTable("test") + defer dropTable(table) + gtest.C(t, func(t *gtest.T) { + type User struct { + Id int + Passport string + Password string + NickName string + CreateTime *gtime.Time + } + var ( + user User + count int + result sql.Result + createTime = gtime.Now().Format("Y-m-d") + err error + ) + + result, err = db.Model(table).Data(g.Map{ + "id": 1, + "passport": "p1", + "password": "15d55ad283aa400af464c76d713c07ad", + "nickname": "n1", + "create_time": createTime, + }).OnConflict("id").Save() + + t.AssertNil(err) + n, _ := result.RowsAffected() + t.Assert(n, 1) + + err = db.Model(table).Scan(&user) + t.AssertNil(err) + t.Assert(user.Id, 1) + t.Assert(user.Passport, "p1") + t.Assert(user.Password, "15d55ad283aa400af464c76d713c07ad") + t.Assert(user.NickName, "n1") + t.Assert(user.CreateTime.Format("Y-m-d"), createTime) + + _, err = db.Model(table).Data(g.Map{ + "id": 1, + "passport": "p1", + "password": "25d55ad283aa400af464c76d713c07ad", + "nickname": "n2", + "create_time": createTime, + }).OnConflict("id").Save() + t.AssertNil(err) + + err = db.Model(table).Scan(&user) + t.AssertNil(err) + t.Assert(user.Passport, "p1") + t.Assert(user.Password, "25d55ad283aa400af464c76d713c07ad") + t.Assert(user.NickName, "n2") + t.Assert(user.CreateTime.Format("Y-m-d"), createTime) + + count, err = db.Model(table).Count() + t.AssertNil(err) + t.Assert(count, 1) + }) +} + +func Test_Model_Replace(t *testing.T) { + table := createTable() + defer dropTable(table) + + gtest.C(t, func(t *gtest.T) { + _, err := db.Model(table).Data(g.Map{ + "id": 1, + "passport": "t11", + "password": "25d55ad283aa400af464c76d713c07ad", + "nickname": "T11", + "create_time": "2018-10-24 10:00:00", + }).Replace() + t.Assert(err, "Replace operation is not supported by oracle driver") + }) +} + /* not support the "AS" func Test_Model_Raw(t *testing.T) { table := createInitTable() diff --git a/contrib/drivers/pgsql/pgsql_format_upsert.go b/contrib/drivers/pgsql/pgsql_format_upsert.go index bc082243c2d..c4c8af91122 100644 --- a/contrib/drivers/pgsql/pgsql_format_upsert.go +++ b/contrib/drivers/pgsql/pgsql_format_upsert.go @@ -10,6 +10,7 @@ import ( "fmt" "github.com/gogf/gf/v2/database/gdb" + "github.com/gogf/gf/v2/errors/gcode" "github.com/gogf/gf/v2/errors/gerror" "github.com/gogf/gf/v2/text/gstr" "github.com/gogf/gf/v2/util/gconv" @@ -19,7 +20,9 @@ import ( // For example: ON CONFLICT (id) DO UPDATE SET ... func (d *Driver) FormatUpsert(columns []string, list gdb.List, option gdb.DoInsertOption) (string, error) { if len(option.OnConflict) == 0 { - return "", gerror.New("Please specify conflict columns") + return "", gerror.NewCode( + gcode.CodeMissingParameter, `Please specify conflict columns`, + ) } var onDuplicateStr string diff --git a/contrib/drivers/sqlite/sqlite_format_upsert.go b/contrib/drivers/sqlite/sqlite_format_upsert.go index 34fc3ccce64..5821144a13e 100644 --- a/contrib/drivers/sqlite/sqlite_format_upsert.go +++ b/contrib/drivers/sqlite/sqlite_format_upsert.go @@ -10,6 +10,7 @@ import ( "fmt" "github.com/gogf/gf/v2/database/gdb" + "github.com/gogf/gf/v2/errors/gcode" "github.com/gogf/gf/v2/errors/gerror" "github.com/gogf/gf/v2/text/gstr" "github.com/gogf/gf/v2/util/gconv" @@ -19,7 +20,9 @@ import ( // For example: ON CONFLICT (id) DO UPDATE SET ... func (d *Driver) FormatUpsert(columns []string, list gdb.List, option gdb.DoInsertOption) (string, error) { if len(option.OnConflict) == 0 { - return "", gerror.New("Please specify conflict columns") + return "", gerror.NewCode( + gcode.CodeMissingParameter, `Please specify conflict columns`, + ) } var onDuplicateStr string diff --git a/database/gdb/gdb_core_underlying.go b/database/gdb/gdb_core_underlying.go index d3b5b5b88aa..1d212975f28 100644 --- a/database/gdb/gdb_core_underlying.go +++ b/database/gdb/gdb_core_underlying.go @@ -396,6 +396,7 @@ func (c *Core) FormatUpsert(columns []string, list List, option DoInsertOption) ) } } + return InsertOnDuplicateKeyUpdate + " " + onDuplicateStr, nil }