From dfaa0183b71274a19eb87fe9f837d70a414737f4 Mon Sep 17 00:00:00 2001 From: glzjin Date: Sat, 19 Aug 2023 17:14:39 +0800 Subject: [PATCH] fix: fix baidu & ali's quota calculation (#444) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 修复阿里计费问题 * 修复百度计费问题 --- controller/relay-ali.go | 8 +++++--- controller/relay-baidu.go | 8 +++++--- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/controller/relay-ali.go b/controller/relay-ali.go index 014f6b84..9dca9a89 100644 --- a/controller/relay-ali.go +++ b/controller/relay-ali.go @@ -177,9 +177,11 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStat common.SysError("error unmarshalling stream response: " + err.Error()) return true } - usage.PromptTokens += aliResponse.Usage.InputTokens - usage.CompletionTokens += aliResponse.Usage.OutputTokens - usage.TotalTokens += aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens + if aliResponse.Usage.OutputTokens != 0 { + usage.PromptTokens = aliResponse.Usage.InputTokens + usage.CompletionTokens = aliResponse.Usage.OutputTokens + usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens + } response := streamResponseAli2OpenAI(&aliResponse) response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText) lastResponseText = aliResponse.Output.Text diff --git a/controller/relay-baidu.go b/controller/relay-baidu.go index ad20d6d6..39f31a9a 100644 --- a/controller/relay-baidu.go +++ b/controller/relay-baidu.go @@ -215,9 +215,11 @@ func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt common.SysError("error unmarshalling stream response: " + err.Error()) return true } - usage.PromptTokens += baiduResponse.Usage.PromptTokens - usage.CompletionTokens += baiduResponse.Usage.CompletionTokens - usage.TotalTokens += baiduResponse.Usage.TotalTokens + if baiduResponse.Usage.TotalTokens != 0 { + usage.TotalTokens = baiduResponse.Usage.TotalTokens + usage.PromptTokens = baiduResponse.Usage.PromptTokens + usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens + } response := streamResponseBaidu2OpenAI(&baiduResponse) jsonResponse, err := json.Marshal(response) if err != nil {