diff --git a/controller/group.go b/controller/group.go index e1f9d9ff..6f02394f 100644 --- a/controller/group.go +++ b/controller/group.go @@ -2,13 +2,13 @@ package controller import ( "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/relay/billing" + billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" "net/http" ) func GetGroups(c *gin.Context) { groupNames := make([]string, 0) - for groupName := range billing.GroupRatio { + for groupName := range billingratio.GroupRatio { groupNames = append(groupNames, groupName) } c.JSON(http.StatusOK, gin.H{ diff --git a/model/option.go b/model/option.go index ba734c3a..bed8d4c3 100644 --- a/model/option.go +++ b/model/option.go @@ -3,7 +3,7 @@ package model import ( "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/logger" - "github.com/songquanpeng/one-api/relay/billing" + billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" "strconv" "strings" "time" @@ -66,9 +66,9 @@ func InitOptionMap() { config.OptionMap["QuotaForInvitee"] = strconv.FormatInt(config.QuotaForInvitee, 10) config.OptionMap["QuotaRemindThreshold"] = strconv.FormatInt(config.QuotaRemindThreshold, 10) config.OptionMap["PreConsumedQuota"] = strconv.FormatInt(config.PreConsumedQuota, 10) - config.OptionMap["ModelRatio"] = billing.ModelRatio2JSONString() - config.OptionMap["GroupRatio"] = billing.GroupRatio2JSONString() - config.OptionMap["CompletionRatio"] = billing.CompletionRatio2JSONString() + config.OptionMap["ModelRatio"] = billingratio.ModelRatio2JSONString() + config.OptionMap["GroupRatio"] = billingratio.GroupRatio2JSONString() + config.OptionMap["CompletionRatio"] = billingratio.CompletionRatio2JSONString() config.OptionMap["TopUpLink"] = config.TopUpLink config.OptionMap["ChatLink"] = config.ChatLink config.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(config.QuotaPerUnit, 'f', -1, 64) @@ -82,7 +82,7 @@ func loadOptionsFromDatabase() { options, _ := AllOption() for _, option := range options { if option.Key == "ModelRatio" { - option.Value = billing.AddNewMissingRatio(option.Value) + option.Value = billingratio.AddNewMissingRatio(option.Value) } err := updateOptionMap(option.Key, option.Value) if err != nil { @@ -209,11 +209,11 @@ func updateOptionMap(key string, value string) (err error) { case "RetryTimes": config.RetryTimes, _ = strconv.Atoi(value) case "ModelRatio": - err = billing.UpdateModelRatioByJSONString(value) + err = billingratio.UpdateModelRatioByJSONString(value) case "GroupRatio": - err = billing.UpdateGroupRatioByJSONString(value) + err = billingratio.UpdateGroupRatioByJSONString(value) case "CompletionRatio": - err = billing.UpdateCompletionRatioByJSONString(value) + err = billingratio.UpdateCompletionRatioByJSONString(value) case "TopUpLink": config.TopUpLink = value case "ChatLink": diff --git a/relay/billing/billing.go b/relay/billing/billing.go new file mode 100644 index 00000000..a99d37ee --- /dev/null +++ b/relay/billing/billing.go @@ -0,0 +1,42 @@ +package billing + +import ( + "context" + "fmt" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/model" +) + +func ReturnPreConsumedQuota(ctx context.Context, preConsumedQuota int64, tokenId int) { + if preConsumedQuota != 0 { + go func(ctx context.Context) { + // return pre-consumed quota + err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota) + if err != nil { + logger.Error(ctx, "error return pre-consumed quota: "+err.Error()) + } + }(ctx) + } +} + +func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int64, totalQuota int64, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) { + // quotaDelta is remaining quota to be consumed + err := model.PostConsumeTokenQuota(tokenId, quotaDelta) + if err != nil { + logger.SysError("error consuming token remain quota: " + err.Error()) + } + err = model.CacheUpdateUserQuota(ctx, userId) + if err != nil { + logger.SysError("error update user quota cache: " + err.Error()) + } + // totalQuota is total quota consumed + if totalQuota != 0 { + logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) + model.RecordConsumeLog(ctx, userId, channelId, int(totalQuota), 0, modelName, tokenName, totalQuota, logContent) + model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota) + model.UpdateChannelUsedQuota(channelId, totalQuota) + } + if totalQuota <= 0 { + logger.Error(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota)) + } +} diff --git a/relay/billing/group.go b/relay/billing/ratio/group.go similarity index 97% rename from relay/billing/group.go rename to relay/billing/ratio/group.go index d9b0f714..8e9c5b73 100644 --- a/relay/billing/group.go +++ b/relay/billing/ratio/group.go @@ -1,4 +1,4 @@ -package billing +package ratio import ( "encoding/json" diff --git a/relay/billing/image.go b/relay/billing/ratio/image.go similarity index 98% rename from relay/billing/image.go rename to relay/billing/ratio/image.go index 92bfa48c..5a29cddc 100644 --- a/relay/billing/image.go +++ b/relay/billing/ratio/image.go @@ -1,4 +1,4 @@ -package billing +package ratio var ImageSizeRatios = map[string]map[string]float64{ "dall-e-2": { diff --git a/relay/billing/model.go b/relay/billing/ratio/model.go similarity index 99% rename from relay/billing/model.go rename to relay/billing/ratio/model.go index 1556e72b..e98c5be8 100644 --- a/relay/billing/model.go +++ b/relay/billing/ratio/model.go @@ -1,4 +1,4 @@ -package billing +package ratio import ( "encoding/json" diff --git a/relay/channel/openai/token.go b/relay/channel/openai/token.go index fbaf4189..c95a7b5e 100644 --- a/relay/channel/openai/token.go +++ b/relay/channel/openai/token.go @@ -7,7 +7,7 @@ import ( "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/image" "github.com/songquanpeng/one-api/common/logger" - "github.com/songquanpeng/one-api/relay/billing" + billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" "github.com/songquanpeng/one-api/relay/model" "math" "strings" @@ -28,7 +28,7 @@ func InitTokenEncoders() { if err != nil { logger.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error())) } - for model := range billing.ModelRatio { + for model := range billingratio.ModelRatio { if strings.HasPrefix(model, "gpt-3.5") { tokenEncoderMap[model] = gpt35TokenEncoder } else if strings.HasPrefix(model, "gpt-4") { diff --git a/relay/controller/audio.go b/relay/controller/audio.go index 1baf7ce0..912f7a3e 100644 --- a/relay/controller/audio.go +++ b/relay/controller/audio.go @@ -13,6 +13,7 @@ import ( "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/relay/billing" + billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" "github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/relay/channeltype" relaymodel "github.com/songquanpeng/one-api/relay/model" @@ -49,8 +50,8 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus } } - modelRatio := billing.GetModelRatio(audioModel) - groupRatio := billing.GetGroupRatio(group) + modelRatio := billingratio.GetModelRatio(audioModel) + groupRatio := billingratio.GetGroupRatio(group) ratio := modelRatio * groupRatio var quota int64 var preConsumedQuota int64 @@ -218,7 +219,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus succeed = true quotaDelta := quota - preConsumedQuota defer func(ctx context.Context) { - go util.PostConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName) + go billing.PostConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName) }(c.Request.Context()) for k, v := range resp.Header { diff --git a/relay/controller/helper.go b/relay/controller/helper.go index 4691ffcb..4444bd51 100644 --- a/relay/controller/helper.go +++ b/relay/controller/helper.go @@ -9,7 +9,7 @@ import ( "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/model" - "github.com/songquanpeng/one-api/relay/billing" + billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" "github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/relay/channeltype" relaymodel "github.com/songquanpeng/one-api/relay/model" @@ -60,12 +60,12 @@ func isValidImageSize(model string, size string) bool { if model == "cogview-3" { return true } - _, ok := billing.ImageSizeRatios[model][size] + _, ok := billingratio.ImageSizeRatios[model][size] return ok } func getImageSizeRatio(model string, size string) float64 { - ratio, ok := billing.ImageSizeRatios[model][size] + ratio, ok := billingratio.ImageSizeRatios[model][size] if !ok { return 1 } @@ -82,7 +82,7 @@ func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *util.Rela if imageRequest.Prompt == "" { return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest) } - if len(imageRequest.Prompt) > billing.ImagePromptLengthLimitations[imageRequest.Model] { + if len(imageRequest.Prompt) > billingratio.ImagePromptLengthLimitations[imageRequest.Model] { return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest) } // Number of generated images validation @@ -165,7 +165,7 @@ func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *util.R return } var quota int64 - completionRatio := billing.GetCompletionRatio(textRequest.Model) + completionRatio := billingratio.GetCompletionRatio(textRequest.Model) promptTokens := usage.PromptTokens completionTokens := usage.CompletionTokens quota = int64(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio)) diff --git a/relay/controller/image.go b/relay/controller/image.go index 8e5b6092..18e864f5 100644 --- a/relay/controller/image.go +++ b/relay/controller/image.go @@ -9,7 +9,7 @@ import ( "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/model" - "github.com/songquanpeng/one-api/relay/billing" + billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" "github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/helper" @@ -20,11 +20,11 @@ import ( ) func isWithinRange(element string, value int) bool { - if _, ok := billing.ImageGenerationAmounts[element]; !ok { + if _, ok := billingratio.ImageGenerationAmounts[element]; !ok { return false } - min := billing.ImageGenerationAmounts[element][0] - max := billing.ImageGenerationAmounts[element][1] + min := billingratio.ImageGenerationAmounts[element][0] + max := billingratio.ImageGenerationAmounts[element][1] return value >= min && value <= max } @@ -87,8 +87,8 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus requestBody = bytes.NewBuffer(jsonStr) } - modelRatio := billing.GetModelRatio(imageRequest.Model) - groupRatio := billing.GetGroupRatio(meta.Group) + modelRatio := billingratio.GetModelRatio(imageRequest.Model) + groupRatio := billingratio.GetGroupRatio(meta.Group) ratio := modelRatio * groupRatio userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId) diff --git a/relay/controller/text.go b/relay/controller/text.go index 1ddb20fe..068cef8d 100644 --- a/relay/controller/text.go +++ b/relay/controller/text.go @@ -8,6 +8,7 @@ import ( "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/relay/apitype" "github.com/songquanpeng/one-api/relay/billing" + billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" "github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/helper" @@ -35,8 +36,8 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { textRequest.Model, isModelMapped = util.GetMappedModelName(textRequest.Model, meta.ModelMapping) meta.ActualModelName = textRequest.Model // get model ratio & group ratio - modelRatio := billing.GetModelRatio(textRequest.Model) - groupRatio := billing.GetGroupRatio(meta.Group) + modelRatio := billingratio.GetModelRatio(textRequest.Model) + groupRatio := billingratio.GetGroupRatio(meta.Group) ratio := modelRatio * groupRatio // pre-consume quota promptTokens := getPromptTokens(textRequest, meta.Mode) @@ -87,7 +88,7 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { } errorHappened := (resp.StatusCode != http.StatusOK) || (meta.IsStream && resp.Header.Get("Content-Type") == "application/json") if errorHappened { - util.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId) + billing.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId) return util.RelayErrorHandler(resp) } meta.IsStream = meta.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") @@ -96,7 +97,7 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { usage, respErr := adaptor.DoResponse(c, resp, meta) if respErr != nil { logger.Errorf(ctx, "respErr is not nil: %+v", respErr) - util.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId) + billing.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId) return respErr } // post-consume quota diff --git a/relay/util/billing.go b/relay/util/billing.go deleted file mode 100644 index 495d011e..00000000 --- a/relay/util/billing.go +++ /dev/null @@ -1,19 +0,0 @@ -package util - -import ( - "context" - "github.com/songquanpeng/one-api/common/logger" - "github.com/songquanpeng/one-api/model" -) - -func ReturnPreConsumedQuota(ctx context.Context, preConsumedQuota int64, tokenId int) { - if preConsumedQuota != 0 { - go func(ctx context.Context) { - // return pre-consumed quota - err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota) - if err != nil { - logger.Error(ctx, "error return pre-consumed quota: "+err.Error()) - } - }(ctx) - } -} diff --git a/relay/util/common.go b/relay/util/common.go index 3826e67f..315f1253 100644 --- a/relay/util/common.go +++ b/relay/util/common.go @@ -1,13 +1,11 @@ package util import ( - "context" "encoding/json" "fmt" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/logger" - "github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/relay/channeltype" relaymodel "github.com/songquanpeng/one-api/relay/model" "io" @@ -165,28 +163,6 @@ func GetFullRequestURL(baseURL string, requestURL string, channelType int) strin return fullRequestURL } -func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int64, totalQuota int64, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) { - // quotaDelta is remaining quota to be consumed - err := model.PostConsumeTokenQuota(tokenId, quotaDelta) - if err != nil { - logger.SysError("error consuming token remain quota: " + err.Error()) - } - err = model.CacheUpdateUserQuota(ctx, userId) - if err != nil { - logger.SysError("error update user quota cache: " + err.Error()) - } - // totalQuota is total quota consumed - if totalQuota != 0 { - logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) - model.RecordConsumeLog(ctx, userId, channelId, int(totalQuota), 0, modelName, tokenName, totalQuota, logContent) - model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota) - model.UpdateChannelUsedQuota(channelId, totalQuota) - } - if totalQuota <= 0 { - logger.Error(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota)) - } -} - func GetAzureAPIVersion(c *gin.Context) string { query := c.Request.URL.Query() apiVersion := query.Get("api-version")