diff --git a/controller/relay-text.go b/controller/relay-text.go index 6f410f96..c6659799 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -377,7 +377,6 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) model.RecordConsumeLog(userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent) model.UpdateUserUsedQuotaAndRequestCount(userId, quota) - model.UpdateChannelUsedQuota(channelId, quota) } } diff --git a/middleware/auth.go b/middleware/auth.go index 060e005c..95516d6e 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -100,7 +100,18 @@ func TokenAuth() func(c *gin.Context) { c.Abort() return } - if !model.CacheIsUserEnabled(token.UserId) { + userEnabled, err := model.IsUserEnabled(token.UserId) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "error": gin.H{ + "message": err.Error(), + "type": "one_api_error", + }, + }) + c.Abort() + return + } + if !userEnabled { c.JSON(http.StatusForbidden, gin.H{ "error": gin.H{ "message": "用户已被封禁", diff --git a/model/cache.go b/model/cache.go index 55fbba9b..c28952b5 100644 --- a/model/cache.go +++ b/model/cache.go @@ -103,23 +103,28 @@ func CacheDecreaseUserQuota(id int, quota int) error { return err } -func CacheIsUserEnabled(userId int) bool { +func CacheIsUserEnabled(userId int) (bool, error) { if !common.RedisEnabled { return IsUserEnabled(userId) } enabled, err := common.RedisGet(fmt.Sprintf("user_enabled:%d", userId)) - if err != nil { - status := common.UserStatusDisabled - if IsUserEnabled(userId) { - status = common.UserStatusEnabled - } - enabled = fmt.Sprintf("%d", status) - err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second) - if err != nil { - common.SysError("Redis set user enabled error: " + err.Error()) - } + if err == nil { + return enabled == "1", nil } - return enabled == "1" + + userEnabled, err := IsUserEnabled(userId) + if err != nil { + return false, err + } + enabled = "0" + if userEnabled { + enabled = "1" + } + err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second) + if err != nil { + common.SysError("Redis set user enabled error: " + err.Error()) + } + return userEnabled, err } var group2model2channels map[string]map[string][]*Channel diff --git a/model/token.go b/model/token.go index dfda27e3..0fa984d3 100644 --- a/model/token.go +++ b/model/token.go @@ -39,32 +39,35 @@ func ValidateUserToken(key string) (token *Token, err error) { } token, err = CacheGetTokenByKey(key) if err == nil { + if token.Status == common.TokenStatusExhausted { + return nil, errors.New("该令牌额度已用尽") + } else if token.Status == common.TokenStatusExpired { + return nil, errors.New("该令牌已过期") + } if token.Status != common.TokenStatusEnabled { return nil, errors.New("该令牌状态不可用") } if token.ExpiredTime != -1 && token.ExpiredTime < common.GetTimestamp() { - token.Status = common.TokenStatusExpired - err := token.SelectUpdate() - if err != nil { - common.SysError("failed to update token status" + err.Error()) + if !common.RedisEnabled { + token.Status = common.TokenStatusExpired + err := token.SelectUpdate() + if err != nil { + common.SysError("failed to update token status" + err.Error()) + } } return nil, errors.New("该令牌已过期") } if !token.UnlimitedQuota && token.RemainQuota <= 0 { - token.Status = common.TokenStatusExhausted - err := token.SelectUpdate() - if err != nil { - common.SysError("failed to update token status" + err.Error()) + if !common.RedisEnabled { + // in this case, we can make sure the token is exhausted + token.Status = common.TokenStatusExhausted + err := token.SelectUpdate() + if err != nil { + common.SysError("failed to update token status" + err.Error()) + } } return nil, errors.New("该令牌额度已用尽") } - go func() { - token.AccessedTime = common.GetTimestamp() - err := token.SelectUpdate() - if err != nil { - common.SysError("failed to update token" + err.Error()) - } - }() return token, nil } return nil, errors.New("无效的令牌") @@ -141,8 +144,9 @@ func IncreaseTokenQuota(id int, quota int) (err error) { func increaseTokenQuota(id int, quota int) (err error) { err = DB.Model(&Token{}).Where("id = ?", id).Updates( map[string]interface{}{ - "remain_quota": gorm.Expr("remain_quota + ?", quota), - "used_quota": gorm.Expr("used_quota - ?", quota), + "remain_quota": gorm.Expr("remain_quota + ?", quota), + "used_quota": gorm.Expr("used_quota - ?", quota), + "accessed_time": common.GetTimestamp(), }, ).Error return err @@ -162,8 +166,9 @@ func DecreaseTokenQuota(id int, quota int) (err error) { func decreaseTokenQuota(id int, quota int) (err error) { err = DB.Model(&Token{}).Where("id = ?", id).Updates( map[string]interface{}{ - "remain_quota": gorm.Expr("remain_quota - ?", quota), - "used_quota": gorm.Expr("used_quota + ?", quota), + "remain_quota": gorm.Expr("remain_quota - ?", quota), + "used_quota": gorm.Expr("used_quota + ?", quota), + "accessed_time": common.GetTimestamp(), }, ).Error return err diff --git a/model/user.go b/model/user.go index 67511267..cee4b023 100644 --- a/model/user.go +++ b/model/user.go @@ -226,17 +226,16 @@ func IsAdmin(userId int) bool { return user.Role >= common.RoleAdminUser } -func IsUserEnabled(userId int) bool { +func IsUserEnabled(userId int) (bool, error) { if userId == 0 { - return false + return false, errors.New("user id is empty") } var user User err := DB.Where("id = ?", userId).Select("status").Find(&user).Error if err != nil { - common.SysError("no such user " + err.Error()) - return false + return false, err } - return user.Status == common.UserStatusEnabled + return user.Status == common.UserStatusEnabled, nil } func ValidateAccessToken(token string) (user *User) {