From a02a50d23605083aaa3bd9e035ba87f16020a10a Mon Sep 17 00:00:00 2001 From: WqyJh <781345688@qq.com> Date: Thu, 4 Jul 2024 16:13:31 +0800 Subject: [PATCH] fix: aws llama3 ratio --- relay/billing/ratio/model.go | 44 ++++++++++++++++++++++++++++-------- relay/controller/audio.go | 9 ++++---- relay/controller/helper.go | 9 ++++---- relay/controller/image.go | 7 +++--- relay/controller/text.go | 7 +++--- 5 files changed, 52 insertions(+), 24 deletions(-) diff --git a/relay/billing/ratio/model.go b/relay/billing/ratio/model.go index 56d31e13..8a7d5743 100644 --- a/relay/billing/ratio/model.go +++ b/relay/billing/ratio/model.go @@ -2,6 +2,7 @@ package ratio import ( "encoding/json" + "fmt" "strings" "github.com/songquanpeng/one-api/common/logger" @@ -169,6 +170,9 @@ var ModelRatio = map[string]float64{ "step-1v-32k": 0.024 * RMB, "step-1-32k": 0.024 * RMB, "step-1-200k": 0.15 * RMB, + // aws llama3 https://aws.amazon.com/cn/bedrock/pricing/ + "llama3-8b-8192(33)": 0.0003 / 0.002, // $0.0003 / 1K tokens + "llama3-70b-8192(33)": 0.00265 / 0.002, // $0.00265 / 1K tokens // https://cohere.com/pricing "command": 0.5, "command-nightly": 0.5, @@ -185,7 +189,11 @@ var ModelRatio = map[string]float64{ "deepl-ja": 25.0 / 1000 * USD, } -var CompletionRatio = map[string]float64{} +var CompletionRatio = map[string]float64{ + // aws llama3 + "llama3-8b-8192(33)": 0.0006 / 0.0003, + "llama3-70b-8192(33)": 0.0035 / 0.00265, +} var DefaultModelRatio map[string]float64 var DefaultCompletionRatio map[string]float64 @@ -234,22 +242,28 @@ func UpdateModelRatioByJSONString(jsonStr string) error { return json.Unmarshal([]byte(jsonStr), &ModelRatio) } -func GetModelRatio(name string) float64 { +func GetModelRatio(name string, channelType int) float64 { if strings.HasPrefix(name, "qwen-") && strings.HasSuffix(name, "-internet") { name = strings.TrimSuffix(name, "-internet") } if strings.HasPrefix(name, "command-") && strings.HasSuffix(name, "-internet") { name = strings.TrimSuffix(name, "-internet") } - ratio, ok := ModelRatio[name] - if !ok { - ratio, ok = DefaultModelRatio[name] + model := fmt.Sprintf("%s(%d)", name, channelType) + if ratio, ok := ModelRatio[model]; ok { + return ratio } - if !ok { - logger.SysError("model ratio not found: " + name) - return 30 + if ratio, ok := DefaultModelRatio[model]; ok { + return ratio } - return ratio + if ratio, ok := ModelRatio[name]; ok { + return ratio + } + if ratio, ok := DefaultModelRatio[name]; ok { + return ratio + } + logger.SysError("model ratio not found: " + name) + return 30 } func CompletionRatio2JSONString() string { @@ -265,7 +279,17 @@ func UpdateCompletionRatioByJSONString(jsonStr string) error { return json.Unmarshal([]byte(jsonStr), &CompletionRatio) } -func GetCompletionRatio(name string) float64 { +func GetCompletionRatio(name string, channelType int) float64 { + if strings.HasPrefix(name, "qwen-") && strings.HasSuffix(name, "-internet") { + name = strings.TrimSuffix(name, "-internet") + } + model := fmt.Sprintf("%s(%d)", name, channelType) + if ratio, ok := CompletionRatio[model]; ok { + return ratio + } + if ratio, ok := DefaultCompletionRatio[model]; ok { + return ratio + } if ratio, ok := CompletionRatio[name]; ok { return ratio } diff --git a/relay/controller/audio.go b/relay/controller/audio.go index 8f9708d0..83040662 100644 --- a/relay/controller/audio.go +++ b/relay/controller/audio.go @@ -7,6 +7,10 @@ import ( "encoding/json" "errors" "fmt" + "io" + "net/http" + "strings" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/client" @@ -21,9 +25,6 @@ import ( "github.com/songquanpeng/one-api/relay/meta" relaymodel "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/relaymode" - "io" - "net/http" - "strings" ) func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode { @@ -53,7 +54,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus } } - modelRatio := billingratio.GetModelRatio(audioModel) + modelRatio := billingratio.GetModelRatio(audioModel, channelType) groupRatio := billingratio.GetGroupRatio(group) ratio := modelRatio * groupRatio var quota int64 diff --git a/relay/controller/helper.go b/relay/controller/helper.go index c47cb558..87d22f13 100644 --- a/relay/controller/helper.go +++ b/relay/controller/helper.go @@ -4,6 +4,10 @@ import ( "context" "errors" "fmt" + "math" + "net/http" + "strings" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" @@ -16,9 +20,6 @@ import ( "github.com/songquanpeng/one-api/relay/meta" relaymodel "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/relaymode" - "math" - "net/http" - "strings" ) func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.GeneralOpenAIRequest, error) { @@ -95,7 +96,7 @@ func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.M return } var quota int64 - completionRatio := billingratio.GetCompletionRatio(textRequest.Model) + completionRatio := billingratio.GetCompletionRatio(textRequest.Model, meta.ChannelType) promptTokens := usage.PromptTokens completionTokens := usage.CompletionTokens quota = int64(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio)) diff --git a/relay/controller/image.go b/relay/controller/image.go index e6245226..1e06e858 100644 --- a/relay/controller/image.go +++ b/relay/controller/image.go @@ -6,6 +6,9 @@ import ( "encoding/json" "errors" "fmt" + "io" + "net/http" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/ctxkey" @@ -17,8 +20,6 @@ import ( "github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/meta" relaymodel "github.com/songquanpeng/one-api/relay/model" - "io" - "net/http" ) func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, error) { @@ -166,7 +167,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus requestBody = bytes.NewBuffer(jsonStr) } - modelRatio := billingratio.GetModelRatio(imageModel) + modelRatio := billingratio.GetModelRatio(imageModel, meta.ChannelType) groupRatio := billingratio.GetGroupRatio(meta.Group) ratio := modelRatio * groupRatio userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId) diff --git a/relay/controller/text.go b/relay/controller/text.go index 6ed19b1d..0d3c56b0 100644 --- a/relay/controller/text.go +++ b/relay/controller/text.go @@ -4,6 +4,9 @@ import ( "bytes" "encoding/json" "fmt" + "io" + "net/http" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/relay" @@ -14,8 +17,6 @@ import ( "github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" - "io" - "net/http" ) func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { @@ -35,7 +36,7 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { textRequest.Model, isModelMapped = getMappedModelName(textRequest.Model, meta.ModelMapping) meta.ActualModelName = textRequest.Model // get model ratio & group ratio - modelRatio := billingratio.GetModelRatio(textRequest.Model) + modelRatio := billingratio.GetModelRatio(textRequest.Model, meta.ChannelType) groupRatio := billingratio.GetGroupRatio(meta.Group) ratio := modelRatio * groupRatio // pre-consume quota