Skip to content

Commit

Permalink
🩹 修复当使用 map 来创建数据时不会回写主键的问题 (#24)
Browse files Browse the repository at this point in the history
Signed-off-by: iTanken <[email protected]>
Co-authored-by: NaN <[email protected]>
  • Loading branch information
websoe authored and iTanken committed Dec 21, 2024
1 parent 23f6d00 commit cbed061
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 4 deletions.
40 changes: 39 additions & 1 deletion create.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"gorm.io/gorm"
"gorm.io/gorm/callbacks"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
)

func Create(db *gorm.DB) {
Expand Down Expand Up @@ -169,7 +170,44 @@ func Create(db *gorm.DB) {
}
}

if updateInsertID {
if !updateInsertID {
return
}
// map insert support return increment id
// https://github.com/go-gorm/gorm/pull/6662
var pkFieldName = "@id"
if db.Statement.Schema != nil {
if db.Statement.Schema.PrioritizedPrimaryField == nil || !db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue {
return
}
pkFieldName = db.Statement.Schema.PrioritizedPrimaryField.DBName
}
// append @id column with value for auto-increment primary key
// the @id value is correct, when: 1. without setting auto-increment primary key, 2. database AutoIncrementIncrement = 1
switch values := db.Statement.Dest.(type) {
case map[string]interface{}:
values[pkFieldName] = insertID
case *map[string]interface{}:
(*values)[pkFieldName] = insertID
case []map[string]interface{}, *[]map[string]interface{}:
mapValues, ok := values.([]map[string]interface{})
if !ok {
if v, ok := values.(*[]map[string]interface{}); ok {
if *v != nil {
mapValues = *v
}
}
}
// if config.LastInsertIDReversed {
insertID -= int64(len(mapValues)-1) * schema.DefaultAutoIncrementIncrement
// }
for _, mapValue := range mapValues {
if mapValue != nil {
mapValue[pkFieldName] = insertID
}
insertID += schema.DefaultAutoIncrementIncrement
}
default:
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
//if config.LastInsertIDReversed {
Expand Down
32 changes: 29 additions & 3 deletions dameng_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (

"github.com/godoes/gorm-dameng/dm8"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)

var (
Expand Down Expand Up @@ -86,7 +87,9 @@ func TestGormConnExample(t *testing.T) {

// 参考链接: https://eco.dameng.com/document/dm/zh-cn/pm/go-rogramming-guide.html#11.8%20ORM%20%E6%96%B9%E8%A8%80%E5%8C%85
dialector := New(Config{DSN: dsn, VarcharSizeIsCharLength: true})
db, err := gorm.Open(dialector, &gorm.Config{})
db, err := gorm.Open(dialector, &gorm.Config{
Logger: logger.New(log.Default(), logger.Config{LogLevel: logger.Info}),
})
if err != nil {
t.Fatalf("连接数据库 [%s@%s:%d] 失败:%v", dmUsername, dmHost, dmPort, err)
} else {
Expand Down Expand Up @@ -114,6 +117,22 @@ func TestGormConnExample(t *testing.T) {
} else {
t.Logf("创建数据成功!数据 ID:%d", data.ID)
}
// Create - 批量创建 map 型数据
list := []map[string]any{
{"code": "M42", "price": 200, "remark": "map1"},
{"code": "N42", "price": 200, "remark": "map2"},
}
var listIDs []any
db.Model(&Product{}).Create(list)
if err = db.Error; err != nil {
t.Errorf("批量创建 map 型数据失败:%v", err)
} else {
listIDs = make([]any, len(list))
for i, item := range list {
listIDs[i] = item["id"]
}
t.Logf("批量创建 map 型数据成功!数据 IDs:%+v", listIDs)
}

// Read
var product Product
Expand Down Expand Up @@ -155,15 +174,22 @@ func TestGormConnExample(t *testing.T) {
}

// Delete - 删除 product
db.Delete(&product, 1)
db.Delete(&product, data.ID)
if err = db.Error; err != nil {
t.Errorf("删除数据失败:%v", err)
} else {
t.Log("删除数据成功!")
}
// Delete - 批量删除
db.Delete(&Product{}, "id IN (?)", listIDs)
if err = db.Error; err != nil {
t.Errorf("批量删除数据失败:%v", err)
} else {
t.Log("批量删除数据成功!")
}

//goland:noinspection SqlNoDataSourceInspection
//db.Exec(`DROP table "products"`)
db.Exec(`DROP table "products"`)
if err = db.Error; err != nil {
t.Errorf("删除表结构失败:%v", err)
} else {
Expand Down

0 comments on commit cbed061

Please sign in to comment.