From c04dfc735f7a315f6ca856a476a98398c7b764bc Mon Sep 17 00:00:00 2001 From: MartialBE Date: Sun, 19 May 2024 13:11:46 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix:=20function=20call=20error?= =?UTF-8?q?=20(#190)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- providers/baidu/chat.go | 3 ++- providers/gemini/type.go | 13 +++++++++++-- providers/minimax/chat.go | 8 ++++++-- providers/xunfei/chat.go | 17 ++++++++++++++++- providers/zhipu/chat.go | 6 ++++-- types/chat.go | 15 +++++++++++++++ 6 files changed, 54 insertions(+), 8 deletions(-) diff --git a/providers/baidu/chat.go b/providers/baidu/chat.go index e2c8780b..8c74b5e0 100644 --- a/providers/baidu/chat.go +++ b/providers/baidu/chat.go @@ -150,7 +150,7 @@ func convertFromChatOpenai(request *types.ChatCompletionRequest) *BaiduChatReque if message.Role == types.ChatMessageRoleSystem { baiduChatRequest.System = message.StringContent() continue - } else if message.Role == types.ChatMessageRoleFunction || message.Role == types.ChatMessageRoleTool { + } else if message.ToolCalls != nil { baiduChatRequest.Messages = append(baiduChatRequest.Messages, BaiduMessage{ Role: types.ChatMessageRoleAssistant, FunctionCall: &types.ChatCompletionToolCallsFunction{ @@ -158,6 +158,7 @@ func convertFromChatOpenai(request *types.ChatCompletionRequest) *BaiduChatReque Arguments: "{}", }, }) + } else if message.Role == types.ChatMessageRoleFunction || message.Role == types.ChatMessageRoleTool { baiduChatRequest.Messages = append(baiduChatRequest.Messages, BaiduMessage{ Role: types.ChatMessageRoleUser, Content: "这是函数调用返回的内容,请回答之前的问题:\n" + message.StringContent(), diff --git a/providers/gemini/type.go b/providers/gemini/type.go index 192eda4c..efe6208e 100644 --- a/providers/gemini/type.go +++ b/providers/gemini/type.go @@ -133,25 +133,34 @@ func (g *GeminiChatResponse) GetResponseText() string { func OpenAIToGeminiChatContent(openaiContents []types.ChatCompletionMessage) ([]GeminiChatContent, *types.OpenAIErrorWithStatusCode) { contents := make([]GeminiChatContent, 0) + useToolName := "" for _, openaiContent := range openaiContents { content := GeminiChatContent{ Role: ConvertRole(openaiContent.Role), Parts: make([]GeminiPart, 0), } content.Role = ConvertRole(openaiContent.Role) - if openaiContent.ToolCalls != nil { + if openaiContent.ToolCalls != nil || openaiContent.FunctionCall != nil { + if openaiContent.ToolCalls != nil { + useToolName = openaiContent.ToolCalls[0].Function.Name + } else { + useToolName = openaiContent.FunctionCall.Name + } content = GeminiChatContent{ Role: "model", Parts: []GeminiPart{ { FunctionCall: &GeminiFunctionCall{ - Name: openaiContent.ToolCalls[0].Function.Name, + Name: useToolName, Args: map[string]interface{}{}, }, }, }, } } else if openaiContent.Role == types.ChatMessageRoleFunction || openaiContent.Role == types.ChatMessageRoleTool { + if openaiContent.Name == nil { + openaiContent.Name = &useToolName + } content = GeminiChatContent{ Role: "function", Parts: []GeminiPart{ diff --git a/providers/minimax/chat.go b/providers/minimax/chat.go index 9dc69724..6aa47bf0 100644 --- a/providers/minimax/chat.go +++ b/providers/minimax/chat.go @@ -145,7 +145,12 @@ func convertFromChatOpenai(request *types.ChatCompletionRequest) *MiniMaxChatReq } // 如果role为function, 则需要在前面一条记录添加function_call,如果没有消息,则添加一个message - if message.Role == types.ChatMessageRoleFunction || message.Role == types.ChatMessageRoleTool { + if message.ToolCalls != nil { + miniMessage.FunctionCall = &types.ChatCompletionToolCallsFunction{ + Name: message.ToolCalls[0].Function.Name, + Arguments: message.ToolCalls[0].Function.Arguments, + } + } else if message.Role == types.ChatMessageRoleFunction { if len(messges) == 0 { messges = append(messges, MiniMaxChatMessage{ SenderType: "USER", @@ -187,7 +192,6 @@ func convertFromChatOpenai(request *types.ChatCompletionRequest) *MiniMaxChatReq miniRequest.Functions = append(miniRequest.Functions, &tool.Function) } } - return miniRequest } diff --git a/providers/xunfei/chat.go b/providers/xunfei/chat.go index e1e089bf..e22bddaa 100644 --- a/providers/xunfei/chat.go +++ b/providers/xunfei/chat.go @@ -3,6 +3,7 @@ package xunfei import ( "encoding/json" "errors" + "fmt" "io" "net/http" "one-api/common" @@ -75,7 +76,21 @@ func (p *XunfeiProvider) getChatRequest(request *types.ChatCompletionRequest) (* func (p *XunfeiProvider) convertFromChatOpenai(request *types.ChatCompletionRequest) *XunfeiChatRequest { messages := make([]XunfeiMessage, 0, len(request.Messages)) for _, message := range request.Messages { - if message.Role == types.ChatMessageRoleFunction || message.Role == types.ChatMessageRoleTool { + if message.FunctionCall != nil || message.ToolCalls != nil { + useToolName := "" + useToolArgs := "" + if message.ToolCalls != nil { + useToolName = message.ToolCalls[0].Function.Name + useToolArgs = message.ToolCalls[0].Function.Arguments + } else { + useToolName = message.FunctionCall.Name + useToolArgs = message.FunctionCall.Arguments + } + messages = append(messages, XunfeiMessage{ + Role: message.Role, + Content: fmt.Sprintf("使用工具:%s,参数:%s", useToolName, useToolArgs), + }) + } else if message.Role == types.ChatMessageRoleFunction || message.Role == types.ChatMessageRoleTool { messages = append(messages, XunfeiMessage{ Role: types.ChatMessageRoleUser, Content: "这是函数调用返回的内容,请回答之前的问题:\n" + message.StringContent(), diff --git a/providers/zhipu/chat.go b/providers/zhipu/chat.go index eb478a46..45f5ff75 100644 --- a/providers/zhipu/chat.go +++ b/providers/zhipu/chat.go @@ -105,8 +105,11 @@ func (p *ZhipuProvider) convertToChatOpenai(response *ZhipuResponse, request *ty } func (p *ZhipuProvider) convertFromChatOpenai(request *types.ChatCompletionRequest) *ZhipuRequest { - for i := range request.Messages { + for i, _ := range request.Messages { request.Messages[i].Role = convertRole(request.Messages[i].Role) + if request.Messages[i].FunctionCall != nil { + request.Messages[i].FuncToToolCalls() + } } zhipuRequest := &ZhipuRequest{ @@ -167,7 +170,6 @@ func (p *ZhipuProvider) convertFromChatOpenai(request *types.ChatCompletionReque } p.pluginHandle(zhipuRequest) - return zhipuRequest } diff --git a/types/chat.go b/types/chat.go index 3e3668d8..8da252ce 100644 --- a/types/chat.go +++ b/types/chat.go @@ -111,6 +111,21 @@ func (m ChatCompletionMessage) ParseContent() []ChatMessagePart { return nil } +func (m *ChatCompletionMessage) FuncToToolCalls() { + if m.ToolCalls != nil { + return + } + if m.FunctionCall != nil { + m.ToolCalls = []*ChatCompletionToolCalls{ + { + Type: ChatMessageRoleFunction, + Function: m.FunctionCall, + }, + } + m.FunctionCall = nil + } +} + type ChatMessageImageURL struct { URL string `json:"url,omitempty"` Detail string `json:"detail,omitempty"`