fix: fix quota not consuming

This commit is contained in:
JustSong 2023-05-16 13:29:22 +08:00
parent a9ea1d9d10
commit 8afdc56b11
3 changed files with 57 additions and 8 deletions

View File

@ -128,6 +128,13 @@ func relayHelper(c *gin.Context) error {
model_ = strings.TrimSuffix(model_, "-0314") model_ = strings.TrimSuffix(model_, "-0314")
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task) fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task)
} }
preConsumedQuota := 500 // TODO: make this configurable, take ratio into account
if consumeQuota {
err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
if err != nil {
return err
}
}
req, err := http.NewRequest(c.Request.Method, fullRequestURL, c.Request.Body) req, err := http.NewRequest(c.Request.Method, fullRequestURL, c.Request.Body)
if err != nil { if err != nil {
return err return err
@ -179,7 +186,8 @@ func relayHelper(c *gin.Context) error {
} }
ratio := common.GetModelRatio(textRequest.Model) ratio := common.GetModelRatio(textRequest.Model)
quota = int(float64(quota) * ratio) quota = int(float64(quota) * ratio)
err := model.DecreaseTokenQuota(tokenId, quota) quotaDelta := quota - preConsumedQuota
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
if err != nil { if err != nil {
common.SysError("Error consuming token remain quota: " + err.Error()) common.SysError("Error consuming token remain quota: " + err.Error())
} }

View File

@ -111,7 +111,7 @@ func TokenAuth() func(c *gin.Context) {
c.Set("id", token.UserId) c.Set("id", token.UserId)
c.Set("token_id", token.Id) c.Set("token_id", token.Id)
requestURL := c.Request.URL.String() requestURL := c.Request.URL.String()
consumeQuota := !token.UnlimitedQuota consumeQuota := true
if strings.HasPrefix(requestURL, "/v1/models") { if strings.HasPrefix(requestURL, "/v1/models") {
consumeQuota = false consumeQuota = false
} }

View File

@ -130,7 +130,23 @@ func DeleteTokenById(id int, userId int) (err error) {
return token.Delete() return token.Delete()
} }
func DecreaseTokenQuota(tokenId int, quota int) (err error) { func IncreaseTokenQuota(id int, quota int) (err error) {
if quota < 0 {
return errors.New("quota 不能为负数!")
}
err = DB.Model(&Token{}).Where("id = ?", id).Update("remain_quota", gorm.Expr("remain_quota + ?", quota)).Error
return err
}
func DecreaseTokenQuota(id int, quota int) (err error) {
if quota < 0 {
return errors.New("quota 不能为负数!")
}
err = DB.Model(&Token{}).Where("id = ?", id).Update("remain_quota", gorm.Expr("remain_quota - ?", quota)).Error
return err
}
func PreConsumeTokenQuota(tokenId int, quota int) (err error) {
if quota < 0 { if quota < 0 {
return errors.New("quota 不能为负数!") return errors.New("quota 不能为负数!")
} }
@ -138,7 +154,7 @@ func DecreaseTokenQuota(tokenId int, quota int) (err error) {
if err != nil { if err != nil {
return err return err
} }
if token.RemainQuota < quota { if !token.UnlimitedQuota && token.RemainQuota < quota {
return errors.New("令牌额度不足") return errors.New("令牌额度不足")
} }
userQuota, err := GetUserQuota(token.UserId) userQuota, err := GetUserQuota(token.UserId)
@ -163,17 +179,42 @@ func DecreaseTokenQuota(tokenId int, quota int) (err error) {
if email != "" { if email != "" {
topUpLink := fmt.Sprintf("%s/topup", common.ServerAddress) topUpLink := fmt.Sprintf("%s/topup", common.ServerAddress)
err = common.SendEmail(prompt, email, err = common.SendEmail(prompt, email,
fmt.Sprintf("%s剩余额度为 %d为了不影响您的使用请及时充值。<br/>充值链接:<a href='%s'>%s</a>", prompt, userQuota-quota, topUpLink, topUpLink)) fmt.Sprintf("%s当前剩余额度为 %d为了不影响您的使用请及时充值。<br/>充值链接:<a href='%s'>%s</a>", prompt, userQuota, topUpLink, topUpLink))
if err != nil { if err != nil {
common.SysError("发送邮件失败:" + err.Error()) common.SysError("发送邮件失败:" + err.Error())
} }
} }
}() }()
} }
err = DB.Model(&Token{}).Where("id = ?", tokenId).Update("remain_quota", gorm.Expr("remain_quota - ?", quota)).Error if !token.UnlimitedQuota {
if err != nil { err = DecreaseTokenQuota(tokenId, quota)
return err if err != nil {
return err
}
} }
err = DecreaseUserQuota(token.UserId, quota) err = DecreaseUserQuota(token.UserId, quota)
return err return err
} }
func PostConsumeTokenQuota(tokenId int, quota int) (err error) {
token, err := GetTokenById(tokenId)
if quota > 0 {
err = DecreaseUserQuota(token.UserId, quota)
} else {
err = IncreaseUserQuota(token.UserId, -quota)
}
if err != nil {
return err
}
if !token.UnlimitedQuota {
if quota > 0 {
err = DecreaseTokenQuota(tokenId, quota)
} else {
err = IncreaseTokenQuota(tokenId, -quota)
}
if err != nil {
return err
}
}
return nil
}