From bccfdda6cfaf65db263f55a92a73dc6a3ae45f87 Mon Sep 17 00:00:00 2001 From: Hengyu Ai Date: Thu, 23 May 2024 18:12:17 +0800 Subject: [PATCH] fix(queries/user.go): thread safe invitation code creation --- pkg/queries/user.go | 75 +++++++++++++++++++++++++++------------------ 1 file changed, 46 insertions(+), 29 deletions(-) diff --git a/pkg/queries/user.go b/pkg/queries/user.go index 74c1672..0385d7b 100644 --- a/pkg/queries/user.go +++ b/pkg/queries/user.go @@ -25,6 +25,7 @@ import ( "fmt" "math/rand" "strings" + "sync" "time" "unicode" @@ -130,7 +131,7 @@ func Register(db *gorm.DB, u *models.User, invitation_code string) error { u.InvitedByUserID = inviter.ID // TODO: only once for the inviter? inviter.Reward += 100 - db.Save(inviter) + db.Select("reward").Save(inviter) } // 检查邮箱是否已存在 @@ -154,15 +155,19 @@ func Register(db *gorm.DB, u *models.User, invitation_code string) error { u.IsActive = false u.IsAdmin = false - code, err := createInvitationCode(db) + err = db.Transaction(func(tx *gorm.DB) error { + if err := tx.Create(u).Error; err != nil { + return errors.Wrap(err, errors.DatabaseError) + } + if err := createInvitationCode(tx, u); err != nil { + return err + } + return nil + }) + if err != nil { return err } - u.InvitationCode = code - - if err = db.Create(u).Error; err != nil { - return errors.Wrap(err, errors.DatabaseError) - } body := fmt.Sprintf(`

欢迎注册%s

我们已经接收到您的电子邮箱验证申请,请点击以下链接完成注册。

@@ -254,13 +259,7 @@ func Login(db *gorm.DB, email, password string) (*models.User, error) { } if user.InvitationCode == "" { - code, err := createInvitationCode(db) - if err != nil { - return nil, err - } - - user.InvitationCode = code - err = db.Select("invitation_code").Save(user).Error + err := createInvitationCode(db, user) if err != nil { return nil, errors.Wrap(err, errors.DatabaseError) } @@ -468,25 +467,43 @@ func CheckInvitationCode(code string) bool { return true } -func createInvitationCode(db *gorm.DB) (string, error) { - // try a few times before giving up - for i := 0; i < 5; i++ { - codeRunes := make([]rune, 0, 5) - for i := 0; i < 5; i++ { - codeRunes = append(codeRunes, []rune("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789")[rand.Intn(62)]) - } - code := string(codeRunes) +var invitationCodeMutex sync.Mutex - _, err := GetUserByInvitationCode(db, code) - if err != nil { - if errors.Is(err, errors.UserNotExists) { - return code, nil +func createInvitationCode(db *gorm.DB, user *models.User) error { + if db == nil { + db = database.GetDB() + } + if user.InvitationCode != "" { + return nil + } + // try a few times to generate an unique code + return db.Transaction(func(tx *gorm.DB) error { + invitationCodeMutex.Lock() + defer invitationCodeMutex.Unlock() + + for i := 0; i < 10; i++ { + code := generateInvitationCode() + _, err := GetUserByInvitationCode(tx, code) + if err != nil { + if errors.Is(err, errors.UserNotExists) { + user.InvitationCode = code + return tx.Select("invitation_code").Save(user).Error + } + return err } - return "", err } - } - return "", errors.New(errors.InternalServerError) + return errors.New(errors.InternalServerError) + }) +} + +func generateInvitationCode() string { + // genetate random code with length 5, only contains [A-Za-z0-9] + codeRunes := make([]rune, 0, 5) + for i := 0; i < 5; i++ { + codeRunes = append(codeRunes, []rune("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789")[rand.Intn(62)]) + } + return string(codeRunes) } func GetUserByInvitationCode(db *gorm.DB, code string) (*models.User, error) {