chore: reorganize billing related package
This commit is contained in:
parent
cd2707692f
commit
8f4d78e24d
@ -2,13 +2,13 @@ package controller
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/relay/billing"
|
||||
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func GetGroups(c *gin.Context) {
|
||||
groupNames := make([]string, 0)
|
||||
for groupName := range billing.GroupRatio {
|
||||
for groupName := range billingratio.GroupRatio {
|
||||
groupNames = append(groupNames, groupName)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
|
@ -3,7 +3,7 @@ package model
|
||||
import (
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/relay/billing"
|
||||
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@ -66,9 +66,9 @@ func InitOptionMap() {
|
||||
config.OptionMap["QuotaForInvitee"] = strconv.FormatInt(config.QuotaForInvitee, 10)
|
||||
config.OptionMap["QuotaRemindThreshold"] = strconv.FormatInt(config.QuotaRemindThreshold, 10)
|
||||
config.OptionMap["PreConsumedQuota"] = strconv.FormatInt(config.PreConsumedQuota, 10)
|
||||
config.OptionMap["ModelRatio"] = billing.ModelRatio2JSONString()
|
||||
config.OptionMap["GroupRatio"] = billing.GroupRatio2JSONString()
|
||||
config.OptionMap["CompletionRatio"] = billing.CompletionRatio2JSONString()
|
||||
config.OptionMap["ModelRatio"] = billingratio.ModelRatio2JSONString()
|
||||
config.OptionMap["GroupRatio"] = billingratio.GroupRatio2JSONString()
|
||||
config.OptionMap["CompletionRatio"] = billingratio.CompletionRatio2JSONString()
|
||||
config.OptionMap["TopUpLink"] = config.TopUpLink
|
||||
config.OptionMap["ChatLink"] = config.ChatLink
|
||||
config.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(config.QuotaPerUnit, 'f', -1, 64)
|
||||
@ -82,7 +82,7 @@ func loadOptionsFromDatabase() {
|
||||
options, _ := AllOption()
|
||||
for _, option := range options {
|
||||
if option.Key == "ModelRatio" {
|
||||
option.Value = billing.AddNewMissingRatio(option.Value)
|
||||
option.Value = billingratio.AddNewMissingRatio(option.Value)
|
||||
}
|
||||
err := updateOptionMap(option.Key, option.Value)
|
||||
if err != nil {
|
||||
@ -209,11 +209,11 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
case "RetryTimes":
|
||||
config.RetryTimes, _ = strconv.Atoi(value)
|
||||
case "ModelRatio":
|
||||
err = billing.UpdateModelRatioByJSONString(value)
|
||||
err = billingratio.UpdateModelRatioByJSONString(value)
|
||||
case "GroupRatio":
|
||||
err = billing.UpdateGroupRatioByJSONString(value)
|
||||
err = billingratio.UpdateGroupRatioByJSONString(value)
|
||||
case "CompletionRatio":
|
||||
err = billing.UpdateCompletionRatioByJSONString(value)
|
||||
err = billingratio.UpdateCompletionRatioByJSONString(value)
|
||||
case "TopUpLink":
|
||||
config.TopUpLink = value
|
||||
case "ChatLink":
|
||||
|
42
relay/billing/billing.go
Normal file
42
relay/billing/billing.go
Normal file
@ -0,0 +1,42 @@
|
||||
package billing
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/model"
|
||||
)
|
||||
|
||||
func ReturnPreConsumedQuota(ctx context.Context, preConsumedQuota int64, tokenId int) {
|
||||
if preConsumedQuota != 0 {
|
||||
go func(ctx context.Context) {
|
||||
// return pre-consumed quota
|
||||
err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota)
|
||||
if err != nil {
|
||||
logger.Error(ctx, "error return pre-consumed quota: "+err.Error())
|
||||
}
|
||||
}(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int64, totalQuota int64, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) {
|
||||
// quotaDelta is remaining quota to be consumed
|
||||
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
|
||||
if err != nil {
|
||||
logger.SysError("error consuming token remain quota: " + err.Error())
|
||||
}
|
||||
err = model.CacheUpdateUserQuota(ctx, userId)
|
||||
if err != nil {
|
||||
logger.SysError("error update user quota cache: " + err.Error())
|
||||
}
|
||||
// totalQuota is total quota consumed
|
||||
if totalQuota != 0 {
|
||||
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
||||
model.RecordConsumeLog(ctx, userId, channelId, int(totalQuota), 0, modelName, tokenName, totalQuota, logContent)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota)
|
||||
model.UpdateChannelUsedQuota(channelId, totalQuota)
|
||||
}
|
||||
if totalQuota <= 0 {
|
||||
logger.Error(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota))
|
||||
}
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
package billing
|
||||
package ratio
|
||||
|
||||
import (
|
||||
"encoding/json"
|
@ -1,4 +1,4 @@
|
||||
package billing
|
||||
package ratio
|
||||
|
||||
var ImageSizeRatios = map[string]map[string]float64{
|
||||
"dall-e-2": {
|
@ -1,4 +1,4 @@
|
||||
package billing
|
||||
package ratio
|
||||
|
||||
import (
|
||||
"encoding/json"
|
@ -7,7 +7,7 @@ import (
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"github.com/songquanpeng/one-api/common/image"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/relay/billing"
|
||||
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
"math"
|
||||
"strings"
|
||||
@ -28,7 +28,7 @@ func InitTokenEncoders() {
|
||||
if err != nil {
|
||||
logger.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error()))
|
||||
}
|
||||
for model := range billing.ModelRatio {
|
||||
for model := range billingratio.ModelRatio {
|
||||
if strings.HasPrefix(model, "gpt-3.5") {
|
||||
tokenEncoderMap[model] = gpt35TokenEncoder
|
||||
} else if strings.HasPrefix(model, "gpt-4") {
|
||||
|
@ -13,6 +13,7 @@ import (
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/model"
|
||||
"github.com/songquanpeng/one-api/relay/billing"
|
||||
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
|
||||
"github.com/songquanpeng/one-api/relay/channel/openai"
|
||||
"github.com/songquanpeng/one-api/relay/channeltype"
|
||||
relaymodel "github.com/songquanpeng/one-api/relay/model"
|
||||
@ -49,8 +50,8 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
|
||||
}
|
||||
}
|
||||
|
||||
modelRatio := billing.GetModelRatio(audioModel)
|
||||
groupRatio := billing.GetGroupRatio(group)
|
||||
modelRatio := billingratio.GetModelRatio(audioModel)
|
||||
groupRatio := billingratio.GetGroupRatio(group)
|
||||
ratio := modelRatio * groupRatio
|
||||
var quota int64
|
||||
var preConsumedQuota int64
|
||||
@ -218,7 +219,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
|
||||
succeed = true
|
||||
quotaDelta := quota - preConsumedQuota
|
||||
defer func(ctx context.Context) {
|
||||
go util.PostConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName)
|
||||
go billing.PostConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName)
|
||||
}(c.Request.Context())
|
||||
|
||||
for k, v := range resp.Header {
|
||||
|
@ -9,7 +9,7 @@ import (
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/model"
|
||||
"github.com/songquanpeng/one-api/relay/billing"
|
||||
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
|
||||
"github.com/songquanpeng/one-api/relay/channel/openai"
|
||||
"github.com/songquanpeng/one-api/relay/channeltype"
|
||||
relaymodel "github.com/songquanpeng/one-api/relay/model"
|
||||
@ -60,12 +60,12 @@ func isValidImageSize(model string, size string) bool {
|
||||
if model == "cogview-3" {
|
||||
return true
|
||||
}
|
||||
_, ok := billing.ImageSizeRatios[model][size]
|
||||
_, ok := billingratio.ImageSizeRatios[model][size]
|
||||
return ok
|
||||
}
|
||||
|
||||
func getImageSizeRatio(model string, size string) float64 {
|
||||
ratio, ok := billing.ImageSizeRatios[model][size]
|
||||
ratio, ok := billingratio.ImageSizeRatios[model][size]
|
||||
if !ok {
|
||||
return 1
|
||||
}
|
||||
@ -82,7 +82,7 @@ func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *util.Rela
|
||||
if imageRequest.Prompt == "" {
|
||||
return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest)
|
||||
}
|
||||
if len(imageRequest.Prompt) > billing.ImagePromptLengthLimitations[imageRequest.Model] {
|
||||
if len(imageRequest.Prompt) > billingratio.ImagePromptLengthLimitations[imageRequest.Model] {
|
||||
return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest)
|
||||
}
|
||||
// Number of generated images validation
|
||||
@ -165,7 +165,7 @@ func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *util.R
|
||||
return
|
||||
}
|
||||
var quota int64
|
||||
completionRatio := billing.GetCompletionRatio(textRequest.Model)
|
||||
completionRatio := billingratio.GetCompletionRatio(textRequest.Model)
|
||||
promptTokens := usage.PromptTokens
|
||||
completionTokens := usage.CompletionTokens
|
||||
quota = int64(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio))
|
||||
|
@ -9,7 +9,7 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/model"
|
||||
"github.com/songquanpeng/one-api/relay/billing"
|
||||
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
|
||||
"github.com/songquanpeng/one-api/relay/channel/openai"
|
||||
"github.com/songquanpeng/one-api/relay/channeltype"
|
||||
"github.com/songquanpeng/one-api/relay/helper"
|
||||
@ -20,11 +20,11 @@ import (
|
||||
)
|
||||
|
||||
func isWithinRange(element string, value int) bool {
|
||||
if _, ok := billing.ImageGenerationAmounts[element]; !ok {
|
||||
if _, ok := billingratio.ImageGenerationAmounts[element]; !ok {
|
||||
return false
|
||||
}
|
||||
min := billing.ImageGenerationAmounts[element][0]
|
||||
max := billing.ImageGenerationAmounts[element][1]
|
||||
min := billingratio.ImageGenerationAmounts[element][0]
|
||||
max := billingratio.ImageGenerationAmounts[element][1]
|
||||
return value >= min && value <= max
|
||||
}
|
||||
|
||||
@ -87,8 +87,8 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
|
||||
requestBody = bytes.NewBuffer(jsonStr)
|
||||
}
|
||||
|
||||
modelRatio := billing.GetModelRatio(imageRequest.Model)
|
||||
groupRatio := billing.GetGroupRatio(meta.Group)
|
||||
modelRatio := billingratio.GetModelRatio(imageRequest.Model)
|
||||
groupRatio := billingratio.GetGroupRatio(meta.Group)
|
||||
ratio := modelRatio * groupRatio
|
||||
userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId)
|
||||
|
||||
|
@ -8,6 +8,7 @@ import (
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/relay/apitype"
|
||||
"github.com/songquanpeng/one-api/relay/billing"
|
||||
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
|
||||
"github.com/songquanpeng/one-api/relay/channel/openai"
|
||||
"github.com/songquanpeng/one-api/relay/channeltype"
|
||||
"github.com/songquanpeng/one-api/relay/helper"
|
||||
@ -35,8 +36,8 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
|
||||
textRequest.Model, isModelMapped = util.GetMappedModelName(textRequest.Model, meta.ModelMapping)
|
||||
meta.ActualModelName = textRequest.Model
|
||||
// get model ratio & group ratio
|
||||
modelRatio := billing.GetModelRatio(textRequest.Model)
|
||||
groupRatio := billing.GetGroupRatio(meta.Group)
|
||||
modelRatio := billingratio.GetModelRatio(textRequest.Model)
|
||||
groupRatio := billingratio.GetGroupRatio(meta.Group)
|
||||
ratio := modelRatio * groupRatio
|
||||
// pre-consume quota
|
||||
promptTokens := getPromptTokens(textRequest, meta.Mode)
|
||||
@ -87,7 +88,7 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
|
||||
}
|
||||
errorHappened := (resp.StatusCode != http.StatusOK) || (meta.IsStream && resp.Header.Get("Content-Type") == "application/json")
|
||||
if errorHappened {
|
||||
util.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)
|
||||
billing.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)
|
||||
return util.RelayErrorHandler(resp)
|
||||
}
|
||||
meta.IsStream = meta.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
|
||||
@ -96,7 +97,7 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
|
||||
usage, respErr := adaptor.DoResponse(c, resp, meta)
|
||||
if respErr != nil {
|
||||
logger.Errorf(ctx, "respErr is not nil: %+v", respErr)
|
||||
util.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)
|
||||
billing.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)
|
||||
return respErr
|
||||
}
|
||||
// post-consume quota
|
||||
|
@ -1,19 +0,0 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/model"
|
||||
)
|
||||
|
||||
func ReturnPreConsumedQuota(ctx context.Context, preConsumedQuota int64, tokenId int) {
|
||||
if preConsumedQuota != 0 {
|
||||
go func(ctx context.Context) {
|
||||
// return pre-consumed quota
|
||||
err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota)
|
||||
if err != nil {
|
||||
logger.Error(ctx, "error return pre-consumed quota: "+err.Error())
|
||||
}
|
||||
}(ctx)
|
||||
}
|
||||
}
|
@ -1,13 +1,11 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/model"
|
||||
"github.com/songquanpeng/one-api/relay/channeltype"
|
||||
relaymodel "github.com/songquanpeng/one-api/relay/model"
|
||||
"io"
|
||||
@ -165,28 +163,6 @@ func GetFullRequestURL(baseURL string, requestURL string, channelType int) strin
|
||||
return fullRequestURL
|
||||
}
|
||||
|
||||
func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int64, totalQuota int64, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) {
|
||||
// quotaDelta is remaining quota to be consumed
|
||||
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
|
||||
if err != nil {
|
||||
logger.SysError("error consuming token remain quota: " + err.Error())
|
||||
}
|
||||
err = model.CacheUpdateUserQuota(ctx, userId)
|
||||
if err != nil {
|
||||
logger.SysError("error update user quota cache: " + err.Error())
|
||||
}
|
||||
// totalQuota is total quota consumed
|
||||
if totalQuota != 0 {
|
||||
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
||||
model.RecordConsumeLog(ctx, userId, channelId, int(totalQuota), 0, modelName, tokenName, totalQuota, logContent)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota)
|
||||
model.UpdateChannelUsedQuota(channelId, totalQuota)
|
||||
}
|
||||
if totalQuota <= 0 {
|
||||
logger.Error(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota))
|
||||
}
|
||||
}
|
||||
|
||||
func GetAzureAPIVersion(c *gin.Context) string {
|
||||
query := c.Request.URL.Query()
|
||||
apiVersion := query.Get("api-version")
|
||||
|
Loading…
Reference in New Issue
Block a user