feat: do not access database before response return (close #158)
This commit is contained in:
parent
ba54c71948
commit
6d961064d2
@ -16,6 +16,7 @@ import (
|
|||||||
func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||||
channelType := c.GetInt("channel")
|
channelType := c.GetInt("channel")
|
||||||
tokenId := c.GetInt("token_id")
|
tokenId := c.GetInt("token_id")
|
||||||
|
userId := c.GetInt("id")
|
||||||
consumeQuota := c.GetBool("consume_quota")
|
consumeQuota := c.GetBool("consume_quota")
|
||||||
group := c.GetString("group")
|
group := c.GetString("group")
|
||||||
var textRequest GeneralOpenAIRequest
|
var textRequest GeneralOpenAIRequest
|
||||||
@ -73,7 +74,16 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
groupRatio := common.GetGroupRatio(group)
|
groupRatio := common.GetGroupRatio(group)
|
||||||
ratio := modelRatio * groupRatio
|
ratio := modelRatio * groupRatio
|
||||||
preConsumedQuota := int(float64(preConsumedTokens) * ratio)
|
preConsumedQuota := int(float64(preConsumedTokens) * ratio)
|
||||||
if consumeQuota {
|
userQuota, err := model.CacheGetUserQuota(userId)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "get_user_quota_failed", http.StatusOK)
|
||||||
|
}
|
||||||
|
if userQuota > 10*preConsumedQuota {
|
||||||
|
// in this case, we do not pre-consume quota
|
||||||
|
// because the user has enough quota
|
||||||
|
preConsumedQuota = 0
|
||||||
|
}
|
||||||
|
if consumeQuota && preConsumedQuota > 0 {
|
||||||
err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
|
err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusOK)
|
return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusOK)
|
||||||
@ -133,7 +143,6 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
common.SysError("Error consuming token remain quota: " + err.Error())
|
common.SysError("Error consuming token remain quota: " + err.Error())
|
||||||
}
|
}
|
||||||
tokenName := c.GetString("token_name")
|
tokenName := c.GetString("token_name")
|
||||||
userId := c.GetInt("id")
|
|
||||||
model.RecordLog(userId, model.LogTypeConsume, fmt.Sprintf("通过令牌「%s」使用模型 %s 消耗 %s(模型倍率 %.2f,分组倍率 %.2f)", tokenName, textRequest.Model, common.LogQuota(quota), modelRatio, groupRatio))
|
model.RecordLog(userId, model.LogTypeConsume, fmt.Sprintf("通过令牌「%s」使用模型 %s 消耗 %s(模型倍率 %.2f,分组倍率 %.2f)", tokenName, textRequest.Model, common.LogQuota(quota), modelRatio, groupRatio))
|
||||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||||
channelId := c.GetInt("channel_id")
|
channelId := c.GetInt("channel_id")
|
||||||
|
@ -100,7 +100,7 @@ func TokenAuth() func(c *gin.Context) {
|
|||||||
c.Abort()
|
c.Abort()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !model.IsUserEnabled(token.UserId) {
|
if !model.CacheIsUserEnabled(token.UserId) {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"error": gin.H{
|
"error": gin.H{
|
||||||
"message": "用户已被封禁",
|
"message": "用户已被封禁",
|
||||||
|
@ -6,14 +6,17 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
TokenCacheSeconds = 60 * 60
|
TokenCacheSeconds = 60 * 60
|
||||||
UserId2GroupCacheSeconds = 60 * 60
|
UserId2GroupCacheSeconds = 60 * 60
|
||||||
|
UserId2QuotaCacheSeconds = 10 * 60
|
||||||
|
UserId2StatusCacheSeconds = 60 * 60
|
||||||
)
|
)
|
||||||
|
|
||||||
func CacheGetTokenByKey(key string) (*Token, error) {
|
func CacheGetTokenByKey(key string) (*Token, error) {
|
||||||
@ -60,6 +63,45 @@ func CacheGetUserGroup(id int) (group string, err error) {
|
|||||||
return group, err
|
return group, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func CacheGetUserQuota(id int) (quota int, err error) {
|
||||||
|
if !common.RedisEnabled {
|
||||||
|
return GetUserQuota(id)
|
||||||
|
}
|
||||||
|
quotaString, err := common.RedisGet(fmt.Sprintf("user_quota:%d", id))
|
||||||
|
if err != nil {
|
||||||
|
quota, err = GetUserQuota(id)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), UserId2QuotaCacheSeconds*time.Second)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("Redis set user quota error: " + err.Error())
|
||||||
|
}
|
||||||
|
return quota, err
|
||||||
|
}
|
||||||
|
quota, err = strconv.Atoi(quotaString)
|
||||||
|
return quota, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func CacheIsUserEnabled(userId int) bool {
|
||||||
|
if !common.RedisEnabled {
|
||||||
|
return IsUserEnabled(userId)
|
||||||
|
}
|
||||||
|
enabled, err := common.RedisGet(fmt.Sprintf("user_enabled:%d", userId))
|
||||||
|
if err != nil {
|
||||||
|
status := common.UserStatusDisabled
|
||||||
|
if IsUserEnabled(userId) {
|
||||||
|
status = common.UserStatusEnabled
|
||||||
|
}
|
||||||
|
enabled = fmt.Sprintf("%d", status)
|
||||||
|
err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, UserId2StatusCacheSeconds*time.Second)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("Redis set user enabled error: " + err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return enabled == "1"
|
||||||
|
}
|
||||||
|
|
||||||
var group2model2channels map[string]map[string][]*Channel
|
var group2model2channels map[string]map[string][]*Channel
|
||||||
var channelSyncLock sync.RWMutex
|
var channelSyncLock sync.RWMutex
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user