From cbed0613fe7405a1acbdf5b808400aeb8861ba4b Mon Sep 17 00:00:00 2001 From: NaN Date: Fri, 20 Dec 2024 12:28:47 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=A9=B9=20=E4=BF=AE=E5=A4=8D=E5=BD=93?= =?UTF-8?q?=E4=BD=BF=E7=94=A8=20map=20=E6=9D=A5=E5=88=9B=E5=BB=BA=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E6=97=B6=E4=B8=8D=E4=BC=9A=E5=9B=9E=E5=86=99=E4=B8=BB?= =?UTF-8?q?=E9=94=AE=E7=9A=84=E9=97=AE=E9=A2=98=20(#24)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: iTanken <23544702+iTanken@users.noreply.github.com> Co-authored-by: NaN --- create.go | 40 +++++++++++++++++++++++++++++++++++++++- dameng_test.go | 32 +++++++++++++++++++++++++++++--- 2 files changed, 68 insertions(+), 4 deletions(-) diff --git a/create.go b/create.go index 5705943..aa63ea0 100644 --- a/create.go +++ b/create.go @@ -7,6 +7,7 @@ import ( "gorm.io/gorm" "gorm.io/gorm/callbacks" "gorm.io/gorm/clause" + "gorm.io/gorm/schema" ) func Create(db *gorm.DB) { @@ -169,7 +170,44 @@ func Create(db *gorm.DB) { } } - if updateInsertID { + if !updateInsertID { + return + } + // map insert support return increment id + // https://github.com/go-gorm/gorm/pull/6662 + var pkFieldName = "@id" + if db.Statement.Schema != nil { + if db.Statement.Schema.PrioritizedPrimaryField == nil || !db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { + return + } + pkFieldName = db.Statement.Schema.PrioritizedPrimaryField.DBName + } + // append @id column with value for auto-increment primary key + // the @id value is correct, when: 1. without setting auto-increment primary key, 2. database AutoIncrementIncrement = 1 + switch values := db.Statement.Dest.(type) { + case map[string]interface{}: + values[pkFieldName] = insertID + case *map[string]interface{}: + (*values)[pkFieldName] = insertID + case []map[string]interface{}, *[]map[string]interface{}: + mapValues, ok := values.([]map[string]interface{}) + if !ok { + if v, ok := values.(*[]map[string]interface{}); ok { + if *v != nil { + mapValues = *v + } + } + } + // if config.LastInsertIDReversed { + insertID -= int64(len(mapValues)-1) * schema.DefaultAutoIncrementIncrement + // } + for _, mapValue := range mapValues { + if mapValue != nil { + mapValue[pkFieldName] = insertID + } + insertID += schema.DefaultAutoIncrementIncrement + } + default: switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: //if config.LastInsertIDReversed { diff --git a/dameng_test.go b/dameng_test.go index a5b139a..91a8d74 100644 --- a/dameng_test.go +++ b/dameng_test.go @@ -10,6 +10,7 @@ import ( "github.com/godoes/gorm-dameng/dm8" "gorm.io/gorm" + "gorm.io/gorm/logger" ) var ( @@ -86,7 +87,9 @@ func TestGormConnExample(t *testing.T) { // 参考链接: https://eco.dameng.com/document/dm/zh-cn/pm/go-rogramming-guide.html#11.8%20ORM%20%E6%96%B9%E8%A8%80%E5%8C%85 dialector := New(Config{DSN: dsn, VarcharSizeIsCharLength: true}) - db, err := gorm.Open(dialector, &gorm.Config{}) + db, err := gorm.Open(dialector, &gorm.Config{ + Logger: logger.New(log.Default(), logger.Config{LogLevel: logger.Info}), + }) if err != nil { t.Fatalf("连接数据库 [%s@%s:%d] 失败:%v", dmUsername, dmHost, dmPort, err) } else { @@ -114,6 +117,22 @@ func TestGormConnExample(t *testing.T) { } else { t.Logf("创建数据成功!数据 ID:%d", data.ID) } + // Create - 批量创建 map 型数据 + list := []map[string]any{ + {"code": "M42", "price": 200, "remark": "map1"}, + {"code": "N42", "price": 200, "remark": "map2"}, + } + var listIDs []any + db.Model(&Product{}).Create(list) + if err = db.Error; err != nil { + t.Errorf("批量创建 map 型数据失败:%v", err) + } else { + listIDs = make([]any, len(list)) + for i, item := range list { + listIDs[i] = item["id"] + } + t.Logf("批量创建 map 型数据成功!数据 IDs:%+v", listIDs) + } // Read var product Product @@ -155,15 +174,22 @@ func TestGormConnExample(t *testing.T) { } // Delete - 删除 product - db.Delete(&product, 1) + db.Delete(&product, data.ID) if err = db.Error; err != nil { t.Errorf("删除数据失败:%v", err) } else { t.Log("删除数据成功!") } + // Delete - 批量删除 + db.Delete(&Product{}, "id IN (?)", listIDs) + if err = db.Error; err != nil { + t.Errorf("批量删除数据失败:%v", err) + } else { + t.Log("批量删除数据成功!") + } //goland:noinspection SqlNoDataSourceInspection - //db.Exec(`DROP table "products"`) + db.Exec(`DROP table "products"`) if err = db.Error; err != nil { t.Errorf("删除表结构失败:%v", err) } else {