diff --git a/model/redemption.go b/model/redemption.go index b821fd7c..369e1f69 100644 --- a/model/redemption.go +++ b/model/redemption.go @@ -3,6 +3,7 @@ package model import ( "errors" "fmt" + "gorm.io/gorm" "one-api/common" ) @@ -48,26 +49,27 @@ func Redeem(key string, userId int) (quota int, err error) { return 0, errors.New("无效的 user id") } redemption := &Redemption{} - err = DB.Where("`key` = ?", key).First(redemption).Error - if err != nil { - return 0, errors.New("无效的兑换码") - } - if redemption.Status != common.RedemptionCodeStatusEnabled { - return 0, errors.New("该兑换码已被使用") - } - err = IncreaseUserQuota(userId, redemption.Quota) - if err != nil { - return 0, err - } - go func() { + + err = DB.Transaction(func(tx *gorm.DB) error { + err := DB.Where("`key` = ?", key).First(redemption).Error + if err != nil { + return errors.New("无效的兑换码") + } + if redemption.Status != common.RedemptionCodeStatusEnabled { + return errors.New("该兑换码已被使用") + } + err = DB.Model(&User{}).Where("id = ?", userId).Update("quota", gorm.Expr("quota + ?", redemption.Quota)).Error + if err != nil { + return err + } redemption.RedeemedTime = common.GetTimestamp() redemption.Status = common.RedemptionCodeStatusUsed - err := redemption.SelectUpdate() - if err != nil { - common.SysError("failed to update redemption status: " + err.Error()) - } - RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s", common.LogQuota(redemption.Quota))) - }() + return redemption.SelectUpdate() + }) + if err != nil { + return 0, errors.New("兑换失败," + err.Error()) + } + RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s", common.LogQuota(redemption.Quota))) return redemption.Quota, nil }