From a8891c0f7286c9edcc996371063a6c66f136a400 Mon Sep 17 00:00:00 2001 From: Martial BE Date: Tue, 2 Apr 2024 12:02:00 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix:=20repair=20the=20error=20ca?= =?UTF-8?q?used=20by=20incomplete=20parameters=20in=20third-party=20OpenAI?= =?UTF-8?q?=20interface=20(#135)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- providers/openai/chat.go | 11 +++++++++++ types/chat.go | 8 ++++++++ 2 files changed, 19 insertions(+) diff --git a/providers/openai/chat.go b/providers/openai/chat.go index 7e892b38..bd1bea6b 100644 --- a/providers/openai/chat.go +++ b/providers/openai/chat.go @@ -41,6 +41,17 @@ func (p *OpenAIProvider) CreateChatCompletion(request *types.ChatCompletionReque return nil, errWithCode } + if response.Usage == nil { + response.Usage = &types.Usage{ + PromptTokens: p.Usage.PromptTokens, + CompletionTokens: 0, + TotalTokens: 0, + } + // 那么需要计算 + response.Usage.CompletionTokens = common.CountTokenText(response.GetContent(), request.Model) + response.Usage.TotalTokens = response.Usage.PromptTokens + response.Usage.CompletionTokens + } + *p.Usage = *response.Usage return &response.ChatCompletionResponse, nil diff --git a/types/chat.go b/types/chat.go index 718b47a0..3e3668d8 100644 --- a/types/chat.go +++ b/types/chat.go @@ -189,6 +189,14 @@ type ChatCompletionResponse struct { PromptFilterResults any `json:"prompt_filter_results,omitempty"` } +func (cc *ChatCompletionResponse) GetContent() string { + var content string + for _, choice := range cc.Choices { + content += choice.Message.StringContent() + } + return content +} + func (c ChatCompletionStreamChoice) ConvertOpenaiStream() []ChatCompletionStreamChoice { var function *ChatCompletionToolCallsFunction var functions []*ChatCompletionToolCallsFunction