From acf61f8b89c7f805f21379e86f729e67420336c0 Mon Sep 17 00:00:00 2001 From: MartialBE Date: Fri, 31 May 2024 14:03:06 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix:=20Tokens=20may=20not=20be?= =?UTF-8?q?=20counted=20in=20the=20stream=20below?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- providers/claude/chat.go | 2 ++ providers/cohere/chat.go | 3 +++ providers/minimax/chat.go | 3 +++ providers/mistral/chat.go | 3 +++ providers/openai/chat.go | 2 +- providers/openai/type.go | 8 -------- providers/zhipu/chat.go | 3 +++ providers/zhipu/type.go | 8 ++++++++ types/chat.go | 8 ++++++++ 9 files changed, 31 insertions(+), 9 deletions(-) diff --git a/providers/claude/chat.go b/providers/claude/chat.go index 9165ab3e..26b6fcdb 100644 --- a/providers/claude/chat.go +++ b/providers/claude/chat.go @@ -242,6 +242,8 @@ func (h *ClaudeStreamHandler) HandlerStream(rawLine *[]byte, dataChan chan strin case "content_block_delta": h.convertToOpenaiStream(&claudeResponse, dataChan) + h.Usage.CompletionTokens += common.CountTokenText(claudeResponse.Delta.Text, h.Request.Model) + h.Usage.TotalTokens = h.Usage.PromptTokens + h.Usage.CompletionTokens default: return diff --git a/providers/cohere/chat.go b/providers/cohere/chat.go index 6b7a76a6..95b3cb27 100644 --- a/providers/cohere/chat.go +++ b/providers/cohere/chat.go @@ -189,6 +189,9 @@ func (h *CohereStreamHandler) convertToOpenaiStream(cohereResponse *CohereStream Role: types.ChatMessageRoleAssistant, Content: cohereResponse.Text, } + + h.Usage.CompletionTokens += common.CountTokenText(cohereResponse.Text, h.Request.Model) + h.Usage.TotalTokens = h.Usage.PromptTokens + h.Usage.CompletionTokens } chatCompletion := types.ChatCompletionStreamResponse{ diff --git a/providers/minimax/chat.go b/providers/minimax/chat.go index 4cef16a9..4a307fac 100644 --- a/providers/minimax/chat.go +++ b/providers/minimax/chat.go @@ -270,6 +270,9 @@ func (h *minimaxStreamHandler) convertToOpenaiStream(miniResponse *MiniMaxChatRe if miniResponse.Usage != nil { h.handleUsage(miniResponse) + } else { + h.Usage.CompletionTokens += common.CountTokenText(miniChoice.Messages[0].Text, h.Request.Model) + h.Usage.TotalTokens = h.Usage.PromptTokens + h.Usage.CompletionTokens } } diff --git a/providers/mistral/chat.go b/providers/mistral/chat.go index a02f30f5..22afa2f6 100644 --- a/providers/mistral/chat.go +++ b/providers/mistral/chat.go @@ -131,6 +131,9 @@ func (h *mistralStreamHandler) handlerStream(rawLine *[]byte, dataChan chan stri if mistralResponse.Usage != nil { *h.Usage = *mistralResponse.Usage + } else { + h.Usage.CompletionTokens += common.CountTokenText(mistralResponse.GetResponseText(), h.Request.Model) + h.Usage.TotalTokens = h.Usage.PromptTokens + h.Usage.CompletionTokens } stop := false diff --git a/providers/openai/chat.go b/providers/openai/chat.go index 7f5b36ae..63dcb8ce 100644 --- a/providers/openai/chat.go +++ b/providers/openai/chat.go @@ -138,7 +138,7 @@ func (h *OpenAIStreamHandler) HandlerChatStream(rawLine *[]byte, dataChan chan s if h.Usage.TotalTokens == 0 { h.Usage.TotalTokens = h.Usage.PromptTokens } - countTokenText := common.CountTokenText(openaiResponse.getResponseText(), h.ModelName) + countTokenText := common.CountTokenText(openaiResponse.GetResponseText(), h.ModelName) h.Usage.CompletionTokens += countTokenText h.Usage.TotalTokens += countTokenText } diff --git a/providers/openai/type.go b/providers/openai/type.go index ce82ba5f..42e0c1ff 100644 --- a/providers/openai/type.go +++ b/providers/openai/type.go @@ -12,14 +12,6 @@ type OpenAIProviderChatStreamResponse struct { types.OpenAIErrorResponse } -func (c *OpenAIProviderChatStreamResponse) getResponseText() (responseText string) { - for _, choice := range c.Choices { - responseText += choice.Delta.Content - } - - return -} - type OpenAIProviderCompletionResponse struct { types.CompletionResponse types.OpenAIErrorResponse diff --git a/providers/zhipu/chat.go b/providers/zhipu/chat.go index 955e491e..15e601fe 100644 --- a/providers/zhipu/chat.go +++ b/providers/zhipu/chat.go @@ -279,5 +279,8 @@ func (h *zhipuStreamHandler) convertToOpenaiStream(zhipuResponse *ZhipuStreamRes if zhipuResponse.Usage != nil { *h.Usage = *zhipuResponse.Usage + } else { + h.Usage.CompletionTokens += common.CountTokenText(zhipuResponse.GetResponseText(), h.Request.Model) + h.Usage.TotalTokens = h.Usage.PromptTokens + h.Usage.CompletionTokens } } diff --git a/providers/zhipu/type.go b/providers/zhipu/type.go index a5779070..31d5beb4 100644 --- a/providers/zhipu/type.go +++ b/providers/zhipu/type.go @@ -57,6 +57,14 @@ type ZhipuStreamResponse struct { ZhipuResponseError } +func (z *ZhipuStreamResponse) GetResponseText() (responseText string) { + for _, choice := range z.Choices { + responseText += choice.Delta.Content + } + + return +} + type ZhipuResponseError struct { Error ZhipuError `json:"error,omitempty"` } diff --git a/types/chat.go b/types/chat.go index 1bb6822d..7a6aeebd 100644 --- a/types/chat.go +++ b/types/chat.go @@ -376,3 +376,11 @@ type ChatCompletionStreamResponse struct { PromptAnnotations any `json:"prompt_annotations,omitempty"` Usage *Usage `json:"usage,omitempty"` } + +func (c *ChatCompletionStreamResponse) GetResponseText() (responseText string) { + for _, choice := range c.Choices { + responseText += choice.Delta.Content + } + + return +}