diff --git a/controller/relay-text.go b/controller/relay-text.go index b7b0ed63..5d5e52cf 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -16,6 +16,7 @@ import ( func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { channelType := c.GetInt("channel") tokenId := c.GetInt("token_id") + userId := c.GetInt("id") consumeQuota := c.GetBool("consume_quota") group := c.GetString("group") var textRequest GeneralOpenAIRequest @@ -73,7 +74,16 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { groupRatio := common.GetGroupRatio(group) ratio := modelRatio * groupRatio preConsumedQuota := int(float64(preConsumedTokens) * ratio) - if consumeQuota { + userQuota, err := model.CacheGetUserQuota(userId) + if err != nil { + return errorWrapper(err, "get_user_quota_failed", http.StatusOK) + } + if userQuota > 10*preConsumedQuota { + // in this case, we do not pre-consume quota + // because the user has enough quota + preConsumedQuota = 0 + } + if consumeQuota && preConsumedQuota > 0 { err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) if err != nil { return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusOK) @@ -133,7 +143,6 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { common.SysError("Error consuming token remain quota: " + err.Error()) } tokenName := c.GetString("token_name") - userId := c.GetInt("id") model.RecordLog(userId, model.LogTypeConsume, fmt.Sprintf("通过令牌「%s」使用模型 %s 消耗 %s(模型倍率 %.2f,分组倍率 %.2f)", tokenName, textRequest.Model, common.LogQuota(quota), modelRatio, groupRatio)) model.UpdateUserUsedQuotaAndRequestCount(userId, quota) channelId := c.GetInt("channel_id") diff --git a/middleware/auth.go b/middleware/auth.go index c172078c..65596a16 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -100,7 +100,7 @@ func TokenAuth() func(c *gin.Context) { c.Abort() return } - if !model.IsUserEnabled(token.UserId) { + if !model.CacheIsUserEnabled(token.UserId) { c.JSON(http.StatusOK, gin.H{ "error": gin.H{ "message": "用户已被封禁", diff --git a/model/cache.go b/model/cache.go index 6cf4808a..8caa2cfe 100644 --- a/model/cache.go +++ b/model/cache.go @@ -6,14 +6,17 @@ import ( "fmt" "math/rand" "one-api/common" + "strconv" "strings" "sync" "time" ) const ( - TokenCacheSeconds = 60 * 60 - UserId2GroupCacheSeconds = 60 * 60 + TokenCacheSeconds = 60 * 60 + UserId2GroupCacheSeconds = 60 * 60 + UserId2QuotaCacheSeconds = 10 * 60 + UserId2StatusCacheSeconds = 60 * 60 ) func CacheGetTokenByKey(key string) (*Token, error) { @@ -60,6 +63,45 @@ func CacheGetUserGroup(id int) (group string, err error) { return group, err } +func CacheGetUserQuota(id int) (quota int, err error) { + if !common.RedisEnabled { + return GetUserQuota(id) + } + quotaString, err := common.RedisGet(fmt.Sprintf("user_quota:%d", id)) + if err != nil { + quota, err = GetUserQuota(id) + if err != nil { + return 0, err + } + err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), UserId2QuotaCacheSeconds*time.Second) + if err != nil { + common.SysError("Redis set user quota error: " + err.Error()) + } + return quota, err + } + quota, err = strconv.Atoi(quotaString) + return quota, err +} + +func CacheIsUserEnabled(userId int) bool { + if !common.RedisEnabled { + return IsUserEnabled(userId) + } + enabled, err := common.RedisGet(fmt.Sprintf("user_enabled:%d", userId)) + if err != nil { + status := common.UserStatusDisabled + if IsUserEnabled(userId) { + status = common.UserStatusEnabled + } + enabled = fmt.Sprintf("%d", status) + err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, UserId2StatusCacheSeconds*time.Second) + if err != nil { + common.SysError("Redis set user enabled error: " + err.Error()) + } + } + return enabled == "1" +} + var group2model2channels map[string]map[string][]*Channel var channelSyncLock sync.RWMutex