From bb7e0ae80f91e43dcd0ec8145402d69d1908d787 Mon Sep 17 00:00:00 2001 From: MartialBE Date: Tue, 21 May 2024 01:36:39 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix:=20Gemini=20only=20returns?= =?UTF-8?q?=20a=20single=20tools=5Fcall=20#197?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- providers/gemini/chat.go | 72 +++--------------- providers/gemini/type.go | 92 ++++++++++++++++++++--- providers/zhipu/chat.go | 16 ++-- types/chat.go | 158 +++++++++++++++++++++++++++------------ 4 files changed, 213 insertions(+), 125 deletions(-) diff --git a/providers/gemini/chat.go b/providers/gemini/chat.go index d66299b7..5875d86b 100644 --- a/providers/gemini/chat.go +++ b/providers/gemini/chat.go @@ -112,10 +112,13 @@ func convertFromChatOpenai(request *types.ChatCompletionRequest) (*GeminiChatReq MaxOutputTokens: request.MaxTokens, }, } - if request.Tools != nil { + + functions := request.GetFunctions() + + if functions != nil { var geminiChatTools GeminiChatTools - for _, tool := range request.Tools { - geminiChatTools.FunctionDeclarations = append(geminiChatTools.FunctionDeclarations, tool.Function) + for _, function := range functions { + geminiChatTools.FunctionDeclarations = append(geminiChatTools.FunctionDeclarations, *function) } geminiRequest.Tools = append(geminiRequest.Tools, geminiChatTools) } @@ -147,30 +150,8 @@ func (p *GeminiProvider) convertToChatOpenai(response *GeminiChatResponse, reque Model: request.Model, Choices: make([]types.ChatCompletionChoice, 0, len(response.Candidates)), } - for i, candidate := range response.Candidates { - choice := types.ChatCompletionChoice{ - Index: i, - Message: types.ChatCompletionMessage{ - Role: "assistant", - // Content: "", - }, - FinishReason: types.FinishReasonStop, - } - if len(candidate.Content.Parts) == 0 { - choice.Message.Content = "" - openaiResponse.Choices = append(openaiResponse.Choices, choice) - continue - // choice.Message.Content = candidate.Content.Parts[0].Text - } - // 开始判断 - geminiParts := candidate.Content.Parts[0] - - if geminiParts.FunctionCall != nil { - choice.Message.ToolCalls = geminiParts.FunctionCall.ToOpenAITool() - } else { - choice.Message.Content = geminiParts.Text - } - openaiResponse.Choices = append(openaiResponse.Choices, choice) + for _, candidate := range response.Candidates { + openaiResponse.Choices = append(openaiResponse.Choices, candidate.ToOpenAIChoice(request)) } *p.Usage = convertOpenAIUsage(request.Model, response.UsageMetadata) @@ -218,42 +199,11 @@ func (h *geminiStreamHandler) convertToOpenaiStream(geminiResponse *GeminiChatRe choices := make([]types.ChatCompletionStreamChoice, 0, len(geminiResponse.Candidates)) - for i, candidate := range geminiResponse.Candidates { - parts := candidate.Content.Parts[0] - - choice := types.ChatCompletionStreamChoice{ - Index: i, - Delta: types.ChatCompletionStreamChoiceDelta{ - Role: types.ChatMessageRoleAssistant, - }, - FinishReason: types.FinishReasonStop, - } - - if parts.FunctionCall != nil { - if parts.FunctionCall.Args == nil { - parts.FunctionCall.Args = map[string]interface{}{} - } - args, _ := json.Marshal(parts.FunctionCall.Args) - - choice.Delta.ToolCalls = []*types.ChatCompletionToolCalls{ - { - Id: "call_" + common.GetRandomString(24), - Type: types.ChatMessageRoleFunction, - Index: 0, - Function: &types.ChatCompletionToolCallsFunction{ - Name: parts.FunctionCall.Name, - Arguments: string(args), - }, - }, - } - } else { - choice.Delta.Content = parts.Text - } - - choices = append(choices, choice) + for _, candidate := range geminiResponse.Candidates { + choices = append(choices, candidate.ToOpenAIStreamChoice(h.Request)) } - if len(choices) > 0 && choices[0].Delta.ToolCalls != nil { + if len(choices) > 0 && (choices[0].Delta.ToolCalls != nil || choices[0].Delta.FunctionCall != nil) { choices := choices[0].ConvertOpenaiStream() for _, choice := range choices { chatCompletionCopy := streamResponse diff --git a/providers/gemini/type.go b/providers/gemini/type.go index efe6208e..798d39da 100644 --- a/providers/gemini/type.go +++ b/providers/gemini/type.go @@ -32,6 +32,80 @@ type GeminiFunctionCall struct { Args map[string]interface{} `json:"args,omitempty"` } +func (candidate *GeminiChatCandidate) ToOpenAIStreamChoice(request *types.ChatCompletionRequest) types.ChatCompletionStreamChoice { + choice := types.ChatCompletionStreamChoice{ + Index: int(candidate.Index), + Delta: types.ChatCompletionStreamChoiceDelta{ + Role: types.ChatMessageRoleAssistant, + }, + FinishReason: types.FinishReasonStop, + } + + content := "" + isTools := false + + for _, part := range candidate.Content.Parts { + if part.FunctionCall != nil { + if choice.Delta.ToolCalls == nil { + choice.Delta.ToolCalls = make([]*types.ChatCompletionToolCalls, 0) + } + isTools = true + choice.Delta.ToolCalls = append(choice.Delta.ToolCalls, part.FunctionCall.ToOpenAITool()) + } else { + content += part.Text + } + } + + choice.Delta.Content = content + + if isTools { + choice.FinishReason = types.FinishReasonToolCalls + } + choice.CheckChoice(request) + + return choice +} + +func (candidate *GeminiChatCandidate) ToOpenAIChoice(request *types.ChatCompletionRequest) types.ChatCompletionChoice { + choice := types.ChatCompletionChoice{ + Index: int(candidate.Index), + Message: types.ChatCompletionMessage{ + Role: "assistant", + }, + FinishReason: types.FinishReasonStop, + } + + if len(candidate.Content.Parts) == 0 { + choice.Message.Content = "" + return choice + } + + content := "" + useTools := false + + for _, part := range candidate.Content.Parts { + if part.FunctionCall != nil { + if choice.Message.ToolCalls == nil { + choice.Message.ToolCalls = make([]*types.ChatCompletionToolCalls, 0) + } + useTools = true + choice.Message.ToolCalls = append(choice.Message.ToolCalls, part.FunctionCall.ToOpenAITool()) + } else { + content += part.Text + } + } + + choice.Message.Content = content + + if useTools { + choice.FinishReason = types.FinishReasonToolCalls + } + + choice.CheckChoice(request) + + return choice +} + type GeminiFunctionResponse struct { Name string `json:"name,omitempty"` Response GeminiFunctionResponseContent `json:"response,omitempty"` @@ -42,18 +116,16 @@ type GeminiFunctionResponseContent struct { Content string `json:"content,omitempty"` } -func (g *GeminiFunctionCall) ToOpenAITool() []*types.ChatCompletionToolCalls { +func (g *GeminiFunctionCall) ToOpenAITool() *types.ChatCompletionToolCalls { args, _ := json.Marshal(g.Args) - return []*types.ChatCompletionToolCalls{ - { - Id: "", - Type: types.ChatMessageRoleFunction, - Index: 0, - Function: &types.ChatCompletionToolCallsFunction{ - Name: g.Name, - Arguments: string(args), - }, + return &types.ChatCompletionToolCalls{ + Id: "call_" + common.GetRandomString(24), + Type: types.ChatMessageRoleFunction, + Index: 0, + Function: &types.ChatCompletionToolCallsFunction{ + Name: g.Name, + Arguments: string(args), }, } } diff --git a/providers/zhipu/chat.go b/providers/zhipu/chat.go index 45f5ff75..03981e21 100644 --- a/providers/zhipu/chat.go +++ b/providers/zhipu/chat.go @@ -8,7 +8,6 @@ import ( "one-api/common/requester" "one-api/types" "strings" - "time" ) type zhipuStreamHandler struct { @@ -99,6 +98,12 @@ func (p *ZhipuProvider) convertToChatOpenai(response *ZhipuResponse, request *ty Usage: response.Usage, } + if len(openaiResponse.Choices) > 0 && openaiResponse.Choices[0].Message.ToolCalls != nil && request.Functions != nil { + for i, _ := range openaiResponse.Choices { + openaiResponse.Choices[i].CheckChoice(request) + } + } + *p.Usage = *response.Usage return @@ -254,9 +259,9 @@ func (h *zhipuStreamHandler) convertToOpenaiStream(zhipuResponse *ZhipuStreamRes Model: h.Request.Model, } - choice := zhipuResponse.Choices[0] - - if choice.Delta.ToolCalls != nil { + if zhipuResponse.Choices[0].Delta.ToolCalls != nil { + choice := zhipuResponse.Choices[0] + choice.CheckChoice(h.Request) choices := choice.ConvertOpenaiStream() for _, choice := range choices { chatCompletionCopy := streamResponse @@ -265,10 +270,9 @@ func (h *zhipuStreamHandler) convertToOpenaiStream(zhipuResponse *ZhipuStreamRes dataChan <- string(responseBody) } } else { - streamResponse.Choices = []types.ChatCompletionStreamChoice{choice} + streamResponse.Choices = zhipuResponse.Choices responseBody, _ := json.Marshal(streamResponse) dataChan <- string(responseBody) - time.Sleep(20 * time.Millisecond) } if zhipuResponse.Usage != nil { diff --git a/types/chat.go b/types/chat.go index 8da252ce..0618b41a 100644 --- a/types/chat.go +++ b/types/chat.go @@ -111,6 +111,7 @@ func (m ChatCompletionMessage) ParseContent() []ChatMessagePart { return nil } +// 将FunctionCall转换为ToolCalls func (m *ChatCompletionMessage) FuncToToolCalls() { if m.ToolCalls != nil { return @@ -126,6 +127,21 @@ func (m *ChatCompletionMessage) FuncToToolCalls() { } } +// 将ToolCalls转换为FunctionCall +func (m *ChatCompletionMessage) ToolToFuncCalls() { + + if m.FunctionCall != nil { + return + } + if m.ToolCalls != nil { + m.FunctionCall = &ChatCompletionToolCallsFunction{ + Name: m.ToolCalls[0].Function.Name, + Arguments: m.ToolCalls[0].Function.Arguments, + } + m.ToolCalls = nil + } +} + type ChatMessageImageURL struct { URL string `json:"url,omitempty"` Detail string `json:"detail,omitempty"` @@ -173,6 +189,22 @@ func (r ChatCompletionRequest) GetFunctionCate() string { return "" } +func (r *ChatCompletionRequest) GetFunctions() []*ChatCompletionFunction { + if r.Tools == nil && r.Functions == nil { + return nil + } + + if r.Tools != nil { + var functions []*ChatCompletionFunction + for _, tool := range r.Tools { + functions = append(functions, &tool.Function) + } + return functions + } + + return r.Functions +} + type ChatCompletionFunction struct { Name string `json:"name"` Description string `json:"description"` @@ -193,6 +225,13 @@ type ChatCompletionChoice struct { FinishDetails any `json:"finish_details,omitempty"` } +func (c *ChatCompletionChoice) CheckChoice(request *ChatCompletionRequest) { + if request.Functions != nil && c.Message.ToolCalls != nil { + c.Message.ToolToFuncCalls() + c.FinishReason = FinishReasonFunctionCall + } +} + type ChatCompletionResponse struct { ID string `json:"id"` Object string `json:"object"` @@ -213,61 +252,16 @@ func (cc *ChatCompletionResponse) GetContent() string { } func (c ChatCompletionStreamChoice) ConvertOpenaiStream() []ChatCompletionStreamChoice { - var function *ChatCompletionToolCallsFunction - var functions []*ChatCompletionToolCallsFunction var choices []ChatCompletionStreamChoice var stopFinish string if c.Delta.FunctionCall != nil { - function = c.Delta.FunctionCall stopFinish = FinishReasonFunctionCall + choices = c.Delta.FunctionCall.Split(&c, stopFinish, 0) } else { - function = c.Delta.ToolCalls[0].Function stopFinish = FinishReasonToolCalls - } - - if function.Name == "" { - c.FinishReason = stopFinish - choices = append(choices, c) - return choices - } - - functions = append(functions, &ChatCompletionToolCallsFunction{ - Name: function.Name, - Arguments: "", - }) - - if function.Arguments == "" || function.Arguments == "{}" { - functions = append(functions, &ChatCompletionToolCallsFunction{ - Arguments: "{}", - }) - } else { - functions = append(functions, &ChatCompletionToolCallsFunction{ - Arguments: function.Arguments, - }) - } - - // 循环functions, 生成choices - for _, function := range functions { - choice := ChatCompletionStreamChoice{ - Index: 0, - Delta: ChatCompletionStreamChoiceDelta{ - Role: c.Delta.Role, - }, + for index, tool := range c.Delta.ToolCalls { + choices = append(choices, tool.Function.Split(&c, stopFinish, index)...) } - if stopFinish == FinishReasonFunctionCall { - choice.Delta.FunctionCall = function - } else { - choice.Delta.ToolCalls = []*ChatCompletionToolCalls{ - { - Id: c.Delta.ToolCalls[0].Id, - Index: 0, - Type: "function", - Function: function, - }, - } - } - - choices = append(choices, choice) } choices = append(choices, ChatCompletionStreamChoice{ @@ -279,6 +273,53 @@ func (c ChatCompletionStreamChoice) ConvertOpenaiStream() []ChatCompletionStream return choices } +func (f *ChatCompletionToolCallsFunction) Split(c *ChatCompletionStreamChoice, stopFinish string, index int) []ChatCompletionStreamChoice { + var functions []*ChatCompletionToolCallsFunction + var choices []ChatCompletionStreamChoice + functions = append(functions, &ChatCompletionToolCallsFunction{ + Name: f.Name, + Arguments: "", + }) + + if f.Arguments == "" || f.Arguments == "{}" { + functions = append(functions, &ChatCompletionToolCallsFunction{ + Arguments: "{}", + }) + } else { + functions = append(functions, &ChatCompletionToolCallsFunction{ + Arguments: f.Arguments, + }) + } + + for fIndex, function := range functions { + choice := ChatCompletionStreamChoice{ + Index: c.Index, + Delta: ChatCompletionStreamChoiceDelta{ + Role: c.Delta.Role, + }, + } + if stopFinish == FinishReasonFunctionCall { + choice.Delta.FunctionCall = function + } else { + toolCalls := &ChatCompletionToolCalls{ + // Id: c.Delta.ToolCalls[0].Id, + Index: index, + Type: ChatMessageRoleFunction, + Function: function, + } + + if fIndex == 0 { + toolCalls.Id = c.Delta.ToolCalls[0].Id + } + choice.Delta.ToolCalls = []*ChatCompletionToolCalls{toolCalls} + } + + choices = append(choices, choice) + } + + return choices +} + type ChatCompletionStreamChoiceDelta struct { Content string `json:"content,omitempty"` Role string `json:"role,omitempty"` @@ -286,6 +327,20 @@ type ChatCompletionStreamChoiceDelta struct { ToolCalls []*ChatCompletionToolCalls `json:"tool_calls,omitempty"` } +func (m *ChatCompletionStreamChoiceDelta) ToolToFuncCalls() { + + if m.FunctionCall != nil { + return + } + if m.ToolCalls != nil { + m.FunctionCall = &ChatCompletionToolCallsFunction{ + Name: m.ToolCalls[0].Function.Name, + Arguments: m.ToolCalls[0].Function.Arguments, + } + m.ToolCalls = nil + } +} + type ChatCompletionStreamChoice struct { Index int `json:"index"` Delta ChatCompletionStreamChoiceDelta `json:"delta"` @@ -293,6 +348,13 @@ type ChatCompletionStreamChoice struct { ContentFilterResults any `json:"content_filter_results,omitempty"` } +func (c *ChatCompletionStreamChoice) CheckChoice(request *ChatCompletionRequest) { + if request.Functions != nil && c.Delta.ToolCalls != nil { + c.Delta.ToolToFuncCalls() + c.FinishReason = FinishReasonToolCalls + } +} + type ChatCompletionStreamResponse struct { ID string `json:"id"` Object string `json:"object"`