chore: pass through error out

This commit is contained in:
JustSong 2023-09-03 21:31:58 +08:00
parent 7e575abb95
commit 621eb91b46
5 changed files with 57 additions and 38 deletions

View File

@ -377,7 +377,6 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent) model.RecordConsumeLog(userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota) model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
model.UpdateChannelUsedQuota(channelId, quota) model.UpdateChannelUsedQuota(channelId, quota)
} }
} }

View File

@ -100,7 +100,18 @@ func TokenAuth() func(c *gin.Context) {
c.Abort() c.Abort()
return return
} }
if !model.CacheIsUserEnabled(token.UserId) { userEnabled, err := model.IsUserEnabled(token.UserId)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"message": err.Error(),
"type": "one_api_error",
},
})
c.Abort()
return
}
if !userEnabled {
c.JSON(http.StatusForbidden, gin.H{ c.JSON(http.StatusForbidden, gin.H{
"error": gin.H{ "error": gin.H{
"message": "用户已被封禁", "message": "用户已被封禁",

View File

@ -103,23 +103,28 @@ func CacheDecreaseUserQuota(id int, quota int) error {
return err return err
} }
func CacheIsUserEnabled(userId int) bool { func CacheIsUserEnabled(userId int) (bool, error) {
if !common.RedisEnabled { if !common.RedisEnabled {
return IsUserEnabled(userId) return IsUserEnabled(userId)
} }
enabled, err := common.RedisGet(fmt.Sprintf("user_enabled:%d", userId)) enabled, err := common.RedisGet(fmt.Sprintf("user_enabled:%d", userId))
if err != nil { if err == nil {
status := common.UserStatusDisabled return enabled == "1", nil
if IsUserEnabled(userId) { }
status = common.UserStatusEnabled
userEnabled, err := IsUserEnabled(userId)
if err != nil {
return false, err
}
enabled = "0"
if userEnabled {
enabled = "1"
} }
enabled = fmt.Sprintf("%d", status)
err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second) err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second)
if err != nil { if err != nil {
common.SysError("Redis set user enabled error: " + err.Error()) common.SysError("Redis set user enabled error: " + err.Error())
} }
} return userEnabled, err
return enabled == "1"
} }
var group2model2channels map[string]map[string][]*Channel var group2model2channels map[string]map[string][]*Channel

View File

@ -39,32 +39,35 @@ func ValidateUserToken(key string) (token *Token, err error) {
} }
token, err = CacheGetTokenByKey(key) token, err = CacheGetTokenByKey(key)
if err == nil { if err == nil {
if token.Status == common.TokenStatusExhausted {
return nil, errors.New("该令牌额度已用尽")
} else if token.Status == common.TokenStatusExpired {
return nil, errors.New("该令牌已过期")
}
if token.Status != common.TokenStatusEnabled { if token.Status != common.TokenStatusEnabled {
return nil, errors.New("该令牌状态不可用") return nil, errors.New("该令牌状态不可用")
} }
if token.ExpiredTime != -1 && token.ExpiredTime < common.GetTimestamp() { if token.ExpiredTime != -1 && token.ExpiredTime < common.GetTimestamp() {
if !common.RedisEnabled {
token.Status = common.TokenStatusExpired token.Status = common.TokenStatusExpired
err := token.SelectUpdate() err := token.SelectUpdate()
if err != nil { if err != nil {
common.SysError("failed to update token status" + err.Error()) common.SysError("failed to update token status" + err.Error())
} }
}
return nil, errors.New("该令牌已过期") return nil, errors.New("该令牌已过期")
} }
if !token.UnlimitedQuota && token.RemainQuota <= 0 { if !token.UnlimitedQuota && token.RemainQuota <= 0 {
if !common.RedisEnabled {
// in this case, we can make sure the token is exhausted
token.Status = common.TokenStatusExhausted token.Status = common.TokenStatusExhausted
err := token.SelectUpdate() err := token.SelectUpdate()
if err != nil { if err != nil {
common.SysError("failed to update token status" + err.Error()) common.SysError("failed to update token status" + err.Error())
} }
}
return nil, errors.New("该令牌额度已用尽") return nil, errors.New("该令牌额度已用尽")
} }
go func() {
token.AccessedTime = common.GetTimestamp()
err := token.SelectUpdate()
if err != nil {
common.SysError("failed to update token" + err.Error())
}
}()
return token, nil return token, nil
} }
return nil, errors.New("无效的令牌") return nil, errors.New("无效的令牌")
@ -143,6 +146,7 @@ func increaseTokenQuota(id int, quota int) (err error) {
map[string]interface{}{ map[string]interface{}{
"remain_quota": gorm.Expr("remain_quota + ?", quota), "remain_quota": gorm.Expr("remain_quota + ?", quota),
"used_quota": gorm.Expr("used_quota - ?", quota), "used_quota": gorm.Expr("used_quota - ?", quota),
"accessed_time": common.GetTimestamp(),
}, },
).Error ).Error
return err return err
@ -164,6 +168,7 @@ func decreaseTokenQuota(id int, quota int) (err error) {
map[string]interface{}{ map[string]interface{}{
"remain_quota": gorm.Expr("remain_quota - ?", quota), "remain_quota": gorm.Expr("remain_quota - ?", quota),
"used_quota": gorm.Expr("used_quota + ?", quota), "used_quota": gorm.Expr("used_quota + ?", quota),
"accessed_time": common.GetTimestamp(),
}, },
).Error ).Error
return err return err

View File

@ -226,17 +226,16 @@ func IsAdmin(userId int) bool {
return user.Role >= common.RoleAdminUser return user.Role >= common.RoleAdminUser
} }
func IsUserEnabled(userId int) bool { func IsUserEnabled(userId int) (bool, error) {
if userId == 0 { if userId == 0 {
return false return false, errors.New("user id is empty")
} }
var user User var user User
err := DB.Where("id = ?", userId).Select("status").Find(&user).Error err := DB.Where("id = ?", userId).Select("status").Find(&user).Error
if err != nil { if err != nil {
common.SysError("no such user " + err.Error()) return false, err
return false
} }
return user.Status == common.UserStatusEnabled return user.Status == common.UserStatusEnabled, nil
} }
func ValidateAccessToken(token string) (user *User) { func ValidateAccessToken(token string) (user *User) {