chore: reorganize billing related package

This commit is contained in:
JustSong 2024-04-06 01:26:48 +08:00
parent cd2707692f
commit 8f4d78e24d
13 changed files with 77 additions and 76 deletions

View File

@ -2,13 +2,13 @@ package controller
import ( import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/billing" billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
"net/http" "net/http"
) )
func GetGroups(c *gin.Context) { func GetGroups(c *gin.Context) {
groupNames := make([]string, 0) groupNames := make([]string, 0)
for groupName := range billing.GroupRatio { for groupName := range billingratio.GroupRatio {
groupNames = append(groupNames, groupName) groupNames = append(groupNames, groupName)
} }
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{

View File

@ -3,7 +3,7 @@ package model
import ( import (
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/billing" billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@ -66,9 +66,9 @@ func InitOptionMap() {
config.OptionMap["QuotaForInvitee"] = strconv.FormatInt(config.QuotaForInvitee, 10) config.OptionMap["QuotaForInvitee"] = strconv.FormatInt(config.QuotaForInvitee, 10)
config.OptionMap["QuotaRemindThreshold"] = strconv.FormatInt(config.QuotaRemindThreshold, 10) config.OptionMap["QuotaRemindThreshold"] = strconv.FormatInt(config.QuotaRemindThreshold, 10)
config.OptionMap["PreConsumedQuota"] = strconv.FormatInt(config.PreConsumedQuota, 10) config.OptionMap["PreConsumedQuota"] = strconv.FormatInt(config.PreConsumedQuota, 10)
config.OptionMap["ModelRatio"] = billing.ModelRatio2JSONString() config.OptionMap["ModelRatio"] = billingratio.ModelRatio2JSONString()
config.OptionMap["GroupRatio"] = billing.GroupRatio2JSONString() config.OptionMap["GroupRatio"] = billingratio.GroupRatio2JSONString()
config.OptionMap["CompletionRatio"] = billing.CompletionRatio2JSONString() config.OptionMap["CompletionRatio"] = billingratio.CompletionRatio2JSONString()
config.OptionMap["TopUpLink"] = config.TopUpLink config.OptionMap["TopUpLink"] = config.TopUpLink
config.OptionMap["ChatLink"] = config.ChatLink config.OptionMap["ChatLink"] = config.ChatLink
config.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(config.QuotaPerUnit, 'f', -1, 64) config.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(config.QuotaPerUnit, 'f', -1, 64)
@ -82,7 +82,7 @@ func loadOptionsFromDatabase() {
options, _ := AllOption() options, _ := AllOption()
for _, option := range options { for _, option := range options {
if option.Key == "ModelRatio" { if option.Key == "ModelRatio" {
option.Value = billing.AddNewMissingRatio(option.Value) option.Value = billingratio.AddNewMissingRatio(option.Value)
} }
err := updateOptionMap(option.Key, option.Value) err := updateOptionMap(option.Key, option.Value)
if err != nil { if err != nil {
@ -209,11 +209,11 @@ func updateOptionMap(key string, value string) (err error) {
case "RetryTimes": case "RetryTimes":
config.RetryTimes, _ = strconv.Atoi(value) config.RetryTimes, _ = strconv.Atoi(value)
case "ModelRatio": case "ModelRatio":
err = billing.UpdateModelRatioByJSONString(value) err = billingratio.UpdateModelRatioByJSONString(value)
case "GroupRatio": case "GroupRatio":
err = billing.UpdateGroupRatioByJSONString(value) err = billingratio.UpdateGroupRatioByJSONString(value)
case "CompletionRatio": case "CompletionRatio":
err = billing.UpdateCompletionRatioByJSONString(value) err = billingratio.UpdateCompletionRatioByJSONString(value)
case "TopUpLink": case "TopUpLink":
config.TopUpLink = value config.TopUpLink = value
case "ChatLink": case "ChatLink":

42
relay/billing/billing.go Normal file
View File

@ -0,0 +1,42 @@
package billing
import (
"context"
"fmt"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
)
func ReturnPreConsumedQuota(ctx context.Context, preConsumedQuota int64, tokenId int) {
if preConsumedQuota != 0 {
go func(ctx context.Context) {
// return pre-consumed quota
err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota)
if err != nil {
logger.Error(ctx, "error return pre-consumed quota: "+err.Error())
}
}(ctx)
}
}
func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int64, totalQuota int64, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) {
// quotaDelta is remaining quota to be consumed
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
if err != nil {
logger.SysError("error consuming token remain quota: " + err.Error())
}
err = model.CacheUpdateUserQuota(ctx, userId)
if err != nil {
logger.SysError("error update user quota cache: " + err.Error())
}
// totalQuota is total quota consumed
if totalQuota != 0 {
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(ctx, userId, channelId, int(totalQuota), 0, modelName, tokenName, totalQuota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota)
model.UpdateChannelUsedQuota(channelId, totalQuota)
}
if totalQuota <= 0 {
logger.Error(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota))
}
}

View File

@ -1,4 +1,4 @@
package billing package ratio
import ( import (
"encoding/json" "encoding/json"

View File

@ -1,4 +1,4 @@
package billing package ratio
var ImageSizeRatios = map[string]map[string]float64{ var ImageSizeRatios = map[string]map[string]float64{
"dall-e-2": { "dall-e-2": {

View File

@ -1,4 +1,4 @@
package billing package ratio
import ( import (
"encoding/json" "encoding/json"

View File

@ -7,7 +7,7 @@ import (
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/image" "github.com/songquanpeng/one-api/common/image"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/billing" billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"math" "math"
"strings" "strings"
@ -28,7 +28,7 @@ func InitTokenEncoders() {
if err != nil { if err != nil {
logger.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error())) logger.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error()))
} }
for model := range billing.ModelRatio { for model := range billingratio.ModelRatio {
if strings.HasPrefix(model, "gpt-3.5") { if strings.HasPrefix(model, "gpt-3.5") {
tokenEncoderMap[model] = gpt35TokenEncoder tokenEncoderMap[model] = gpt35TokenEncoder
} else if strings.HasPrefix(model, "gpt-4") { } else if strings.HasPrefix(model, "gpt-4") {

View File

@ -13,6 +13,7 @@ import (
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay/billing" "github.com/songquanpeng/one-api/relay/billing"
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
"github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/channeltype"
relaymodel "github.com/songquanpeng/one-api/relay/model" relaymodel "github.com/songquanpeng/one-api/relay/model"
@ -49,8 +50,8 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
} }
} }
modelRatio := billing.GetModelRatio(audioModel) modelRatio := billingratio.GetModelRatio(audioModel)
groupRatio := billing.GetGroupRatio(group) groupRatio := billingratio.GetGroupRatio(group)
ratio := modelRatio * groupRatio ratio := modelRatio * groupRatio
var quota int64 var quota int64
var preConsumedQuota int64 var preConsumedQuota int64
@ -218,7 +219,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
succeed = true succeed = true
quotaDelta := quota - preConsumedQuota quotaDelta := quota - preConsumedQuota
defer func(ctx context.Context) { defer func(ctx context.Context) {
go util.PostConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName) go billing.PostConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName)
}(c.Request.Context()) }(c.Request.Context())
for k, v := range resp.Header { for k, v := range resp.Header {

View File

@ -9,7 +9,7 @@ import (
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay/billing" billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
"github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/channeltype"
relaymodel "github.com/songquanpeng/one-api/relay/model" relaymodel "github.com/songquanpeng/one-api/relay/model"
@ -60,12 +60,12 @@ func isValidImageSize(model string, size string) bool {
if model == "cogview-3" { if model == "cogview-3" {
return true return true
} }
_, ok := billing.ImageSizeRatios[model][size] _, ok := billingratio.ImageSizeRatios[model][size]
return ok return ok
} }
func getImageSizeRatio(model string, size string) float64 { func getImageSizeRatio(model string, size string) float64 {
ratio, ok := billing.ImageSizeRatios[model][size] ratio, ok := billingratio.ImageSizeRatios[model][size]
if !ok { if !ok {
return 1 return 1
} }
@ -82,7 +82,7 @@ func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *util.Rela
if imageRequest.Prompt == "" { if imageRequest.Prompt == "" {
return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest) return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest)
} }
if len(imageRequest.Prompt) > billing.ImagePromptLengthLimitations[imageRequest.Model] { if len(imageRequest.Prompt) > billingratio.ImagePromptLengthLimitations[imageRequest.Model] {
return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest) return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest)
} }
// Number of generated images validation // Number of generated images validation
@ -165,7 +165,7 @@ func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *util.R
return return
} }
var quota int64 var quota int64
completionRatio := billing.GetCompletionRatio(textRequest.Model) completionRatio := billingratio.GetCompletionRatio(textRequest.Model)
promptTokens := usage.PromptTokens promptTokens := usage.PromptTokens
completionTokens := usage.CompletionTokens completionTokens := usage.CompletionTokens
quota = int64(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio)) quota = int64(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio))

View File

@ -9,7 +9,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay/billing" billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
"github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/helper" "github.com/songquanpeng/one-api/relay/helper"
@ -20,11 +20,11 @@ import (
) )
func isWithinRange(element string, value int) bool { func isWithinRange(element string, value int) bool {
if _, ok := billing.ImageGenerationAmounts[element]; !ok { if _, ok := billingratio.ImageGenerationAmounts[element]; !ok {
return false return false
} }
min := billing.ImageGenerationAmounts[element][0] min := billingratio.ImageGenerationAmounts[element][0]
max := billing.ImageGenerationAmounts[element][1] max := billingratio.ImageGenerationAmounts[element][1]
return value >= min && value <= max return value >= min && value <= max
} }
@ -87,8 +87,8 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
requestBody = bytes.NewBuffer(jsonStr) requestBody = bytes.NewBuffer(jsonStr)
} }
modelRatio := billing.GetModelRatio(imageRequest.Model) modelRatio := billingratio.GetModelRatio(imageRequest.Model)
groupRatio := billing.GetGroupRatio(meta.Group) groupRatio := billingratio.GetGroupRatio(meta.Group)
ratio := modelRatio * groupRatio ratio := modelRatio * groupRatio
userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId) userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId)

View File

@ -8,6 +8,7 @@ import (
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/apitype" "github.com/songquanpeng/one-api/relay/apitype"
"github.com/songquanpeng/one-api/relay/billing" "github.com/songquanpeng/one-api/relay/billing"
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
"github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/helper" "github.com/songquanpeng/one-api/relay/helper"
@ -35,8 +36,8 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
textRequest.Model, isModelMapped = util.GetMappedModelName(textRequest.Model, meta.ModelMapping) textRequest.Model, isModelMapped = util.GetMappedModelName(textRequest.Model, meta.ModelMapping)
meta.ActualModelName = textRequest.Model meta.ActualModelName = textRequest.Model
// get model ratio & group ratio // get model ratio & group ratio
modelRatio := billing.GetModelRatio(textRequest.Model) modelRatio := billingratio.GetModelRatio(textRequest.Model)
groupRatio := billing.GetGroupRatio(meta.Group) groupRatio := billingratio.GetGroupRatio(meta.Group)
ratio := modelRatio * groupRatio ratio := modelRatio * groupRatio
// pre-consume quota // pre-consume quota
promptTokens := getPromptTokens(textRequest, meta.Mode) promptTokens := getPromptTokens(textRequest, meta.Mode)
@ -87,7 +88,7 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
} }
errorHappened := (resp.StatusCode != http.StatusOK) || (meta.IsStream && resp.Header.Get("Content-Type") == "application/json") errorHappened := (resp.StatusCode != http.StatusOK) || (meta.IsStream && resp.Header.Get("Content-Type") == "application/json")
if errorHappened { if errorHappened {
util.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId) billing.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)
return util.RelayErrorHandler(resp) return util.RelayErrorHandler(resp)
} }
meta.IsStream = meta.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") meta.IsStream = meta.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
@ -96,7 +97,7 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
usage, respErr := adaptor.DoResponse(c, resp, meta) usage, respErr := adaptor.DoResponse(c, resp, meta)
if respErr != nil { if respErr != nil {
logger.Errorf(ctx, "respErr is not nil: %+v", respErr) logger.Errorf(ctx, "respErr is not nil: %+v", respErr)
util.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId) billing.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)
return respErr return respErr
} }
// post-consume quota // post-consume quota

View File

@ -1,19 +0,0 @@
package util
import (
"context"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
)
func ReturnPreConsumedQuota(ctx context.Context, preConsumedQuota int64, tokenId int) {
if preConsumedQuota != 0 {
go func(ctx context.Context) {
// return pre-consumed quota
err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota)
if err != nil {
logger.Error(ctx, "error return pre-consumed quota: "+err.Error())
}
}(ctx)
}
}

View File

@ -1,13 +1,11 @@
package util package util
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/channeltype"
relaymodel "github.com/songquanpeng/one-api/relay/model" relaymodel "github.com/songquanpeng/one-api/relay/model"
"io" "io"
@ -165,28 +163,6 @@ func GetFullRequestURL(baseURL string, requestURL string, channelType int) strin
return fullRequestURL return fullRequestURL
} }
func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int64, totalQuota int64, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) {
// quotaDelta is remaining quota to be consumed
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
if err != nil {
logger.SysError("error consuming token remain quota: " + err.Error())
}
err = model.CacheUpdateUserQuota(ctx, userId)
if err != nil {
logger.SysError("error update user quota cache: " + err.Error())
}
// totalQuota is total quota consumed
if totalQuota != 0 {
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(ctx, userId, channelId, int(totalQuota), 0, modelName, tokenName, totalQuota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota)
model.UpdateChannelUsedQuota(channelId, totalQuota)
}
if totalQuota <= 0 {
logger.Error(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota))
}
}
func GetAzureAPIVersion(c *gin.Context) string { func GetAzureAPIVersion(c *gin.Context) string {
query := c.Request.URL.Query() query := c.Request.URL.Query()
apiVersion := query.Get("api-version") apiVersion := query.Get("api-version")