🐛 fix: tokens may not be counted in stream mode

MartialBE 2024-05-31 14:03:06 +08:00
parent 05adacefff
commit acf61f8b89
9 changed files with 31 additions and 9 deletions
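
Every file in this commit applies the same fallback: when the upstream provider does not report token usage in its stream, the handler now estimates completion tokens from the streamed text with common.CountTokenText and recomputes the total. A minimal, runnable sketch of that pattern follows; Usage and countTokenText are simplified stand-ins for this repo's types.Usage and common.CountTokenText (the real helper uses a model-specific tokenizer), and the whitespace-based count is only illustrative.

package main

import (
    "fmt"
    "strings"
)

// Usage is a simplified stand-in for types.Usage in this repo.
type Usage struct {
    PromptTokens     int
    CompletionTokens int
    TotalTokens      int
}

// countTokenText is a hypothetical stand-in for common.CountTokenText;
// the real helper counts tokens with a model-specific tokenizer.
func countTokenText(text, model string) int {
    return len(strings.Fields(text))
}

// accumulate mirrors the fallback each handler gains in this commit:
// trust upstream usage when the provider reports it, otherwise estimate
// completion tokens from the streamed chunk and recompute the total.
func accumulate(u *Usage, upstream *Usage, chunkText, model string) {
    if upstream != nil {
        *u = *upstream
        return
    }
    u.CompletionTokens += countTokenText(chunkText, model)
    u.TotalTokens = u.PromptTokens + u.CompletionTokens
}

func main() {
    u := Usage{PromptTokens: 12}
    for _, chunk := range []string{"Hello there", "general Kenobi"} {
        accumulate(&u, nil, chunk, "gpt-3.5-turbo")
    }
    fmt.Printf("%+v\n", u) // {PromptTokens:12 CompletionTokens:4 TotalTokens:16}
}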


@@ -242,6 +242,8 @@ func (h *ClaudeStreamHandler) HandlerStream(rawLine *[]byte, dataChan chan strin
 	case "content_block_delta":
 		h.convertToOpenaiStream(&claudeResponse, dataChan)
+		h.Usage.CompletionTokens += common.CountTokenText(claudeResponse.Delta.Text, h.Request.Model)
+		h.Usage.TotalTokens = h.Usage.PromptTokens + h.Usage.CompletionTokens
 	default:
 		return
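
Claude streams typed SSE events, and only content_block_delta events carry text, which is why the counter hooks into that case. A hedged sketch of the dispatch, using a simplified event struct rather than the real claudeResponse type:

package main

import "fmt"

// claudeStreamEvent is a simplified stand-in for the decoded Claude event.
type claudeStreamEvent struct {
    Type  string
    Delta struct{ Text string }
}

// dispatch mirrors the switch in HandlerStream: only content_block_delta
// events carry delta text worth counting toward completion tokens.
func dispatch(ev claudeStreamEvent, countText func(string)) {
    switch ev.Type {
    case "content_block_delta":
        countText(ev.Delta.Text)
    default:
        // message_start, content_block_stop, ping, ... carry no delta text.
    }
}

func main() {
    total := 0
    count := func(s string) { total += len(s) } // toy counter, not a tokenizer

    dispatch(claudeStreamEvent{Type: "ping"}, count)
    ev := claudeStreamEvent{Type: "content_block_delta"}
    ev.Delta.Text = "Hello"
    dispatch(ev, count)

    fmt.Println(total) // 5
}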


@@ -189,6 +189,9 @@ func (h *CohereStreamHandler) convertToOpenaiStream(cohereResponse *CohereStream
 			Role:    types.ChatMessageRoleAssistant,
 			Content: cohereResponse.Text,
 		}
+		h.Usage.CompletionTokens += common.CountTokenText(cohereResponse.Text, h.Request.Model)
+		h.Usage.TotalTokens = h.Usage.PromptTokens + h.Usage.CompletionTokens
 	}
 	chatCompletion := types.ChatCompletionStreamResponse{


@@ -270,6 +270,9 @@ func (h *minimaxStreamHandler) convertToOpenaiStream(miniResponse *MiniMaxChatRe
 	if miniResponse.Usage != nil {
 		h.handleUsage(miniResponse)
 	} else {
+		h.Usage.CompletionTokens += common.CountTokenText(miniChoice.Messages[0].Text, h.Request.Model)
+		h.Usage.TotalTokens = h.Usage.PromptTokens + h.Usage.CompletionTokens
 	}
 }


@@ -131,6 +131,9 @@ func (h *mistralStreamHandler) handlerStream(rawLine *[]byte, dataChan chan stri
 	if mistralResponse.Usage != nil {
 		*h.Usage = *mistralResponse.Usage
 	} else {
+		h.Usage.CompletionTokens += common.CountTokenText(mistralResponse.GetResponseText(), h.Request.Model)
+		h.Usage.TotalTokens = h.Usage.PromptTokens + h.Usage.CompletionTokens
 	}
 	stop := false


@@ -138,7 +138,7 @@ func (h *OpenAIStreamHandler) HandlerChatStream(rawLine *[]byte, dataChan chan s
 		if h.Usage.TotalTokens == 0 {
 			h.Usage.TotalTokens = h.Usage.PromptTokens
 		}
-		countTokenText := common.CountTokenText(openaiResponse.getResponseText(), h.ModelName)
+		countTokenText := common.CountTokenText(openaiResponse.GetResponseText(), h.ModelName)
 		h.Usage.CompletionTokens += countTokenText
 		h.Usage.TotalTokens += countTokenText
 	}
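
The lowercase getResponseText was defined on the provider-local response type; the exported GetResponseText that replaces it is added to the shared types.ChatCompletionStreamResponse in the final hunk. Because the provider response embeds that struct (as the next hunk suggests), the method is promoted onto it and the local copy can be deleted below. A small sketch of that promotion, with simplified stand-in types:

package main

import "fmt"

// Simplified stand-ins for the shared stream-response types.
type delta struct{ Content string }
type choice struct{ Delta delta }

type ChatCompletionStreamResponse struct{ Choices []choice }

// The exported accessor added at the bottom of this commit.
func (c *ChatCompletionStreamResponse) GetResponseText() (responseText string) {
    for _, ch := range c.Choices {
        responseText += ch.Delta.Content
    }
    return
}

// The provider type embeds the shared response, so GetResponseText is
// promoted onto it and the unexported local copy becomes redundant.
type OpenAIProviderChatStreamResponse struct {
    ChatCompletionStreamResponse
}

func main() {
    r := OpenAIProviderChatStreamResponse{
        ChatCompletionStreamResponse{Choices: []choice{
            {Delta: delta{Content: "Hello, "}},
            {Delta: delta{Content: "world"}},
        }},
    }
    fmt.Println(r.GetResponseText()) // Hello, world
}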


@@ -12,14 +12,6 @@ type OpenAIProviderChatStreamResponse struct {
 	types.OpenAIErrorResponse
 }
 
-func (c *OpenAIProviderChatStreamResponse) getResponseText() (responseText string) {
-	for _, choice := range c.Choices {
-		responseText += choice.Delta.Content
-	}
-	return
-}
-
 type OpenAIProviderCompletionResponse struct {
 	types.CompletionResponse
 	types.OpenAIErrorResponse


@@ -279,5 +279,8 @@ func (h *zhipuStreamHandler) convertToOpenaiStream(zhipuResponse *ZhipuStreamRes
 	if zhipuResponse.Usage != nil {
 		*h.Usage = *zhipuResponse.Usage
 	} else {
+		h.Usage.CompletionTokens += common.CountTokenText(zhipuResponse.GetResponseText(), h.Request.Model)
+		h.Usage.TotalTokens = h.Usage.PromptTokens + h.Usage.CompletionTokens
 	}
 }


@@ -57,6 +57,14 @@ type ZhipuStreamResponse struct {
 	ZhipuResponseError
 }
 
+func (z *ZhipuStreamResponse) GetResponseText() (responseText string) {
+	for _, choice := range z.Choices {
+		responseText += choice.Delta.Content
+	}
+	return
+}
+
 type ZhipuResponseError struct {
 	Error ZhipuError `json:"error,omitempty"`
 }


@@ -376,3 +376,11 @@ type ChatCompletionStreamResponse struct {
 	PromptAnnotations any    `json:"prompt_annotations,omitempty"`
 	Usage             *Usage `json:"usage,omitempty"`
 }
+
+func (c *ChatCompletionStreamResponse) GetResponseText() (responseText string) {
+	for _, choice := range c.Choices {
+		responseText += choice.Delta.Content
+	}
+	return
+}
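
A self-contained sketch of how the new accessor behaves on a decoded stream chunk. The struct below is a trimmed stand-in for the full type above, and the sample data line is illustrative, not taken from a real stream.

package main

import (
    "encoding/json"
    "fmt"
)

// Trimmed-down mirror of ChatCompletionStreamResponse, keeping only
// the fields this example needs.
type streamDelta struct {
    Content string `json:"content"`
}
type streamChoice struct {
    Delta streamDelta `json:"delta"`
}
type chatCompletionStreamResponse struct {
    Choices []streamChoice `json:"choices"`
}

// Same logic as the method added in this hunk: concatenate the delta
// text across all choices in the chunk.
func (c *chatCompletionStreamResponse) GetResponseText() (responseText string) {
    for _, choice := range c.Choices {
        responseText += choice.Delta.Content
    }
    return
}

func main() {
    // A typical OpenAI-style SSE data payload (illustrative).
    raw := []byte(`{"choices":[{"delta":{"content":"Hi"}},{"delta":{"content":" there"}}]}`)

    var chunk chatCompletionStreamResponse
    if err := json.Unmarshal(raw, &chunk); err != nil {
        panic(err)
    }
    fmt.Println(chunk.GetResponseText()) // Hi there
}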