From abd889c39820429dd8e8ebfc80e807d500abae09 Mon Sep 17 00:00:00 2001 From: Martial BE Date: Thu, 11 Apr 2024 11:54:10 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix:=20gemini=20tools?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- providers/gemini/chat.go | 166 ++++++++++++++++++--------------------- providers/gemini/type.go | 135 ++++++++++++++++++++++++++++++- 2 files changed, 209 insertions(+), 92 deletions(-) diff --git a/providers/gemini/chat.go b/providers/gemini/chat.go index 36389220..310160ba 100644 --- a/providers/gemini/chat.go +++ b/providers/gemini/chat.go @@ -5,7 +5,6 @@ import ( "fmt" "net/http" "one-api/common" - "one-api/common/image" "one-api/common/requester" "one-api/types" "strings" @@ -113,76 +112,21 @@ func convertFromChatOpenai(request *types.ChatCompletionRequest) (*GeminiChatReq MaxOutputTokens: request.MaxTokens, }, } - if request.Functions != nil { - geminiRequest.Tools = []GeminiChatTools{ - { - FunctionDeclarations: request.Functions, - }, + if request.Tools != nil { + var geminiChatTools GeminiChatTools + for _, tool := range request.Tools { + geminiChatTools.FunctionDeclarations = append(geminiChatTools.FunctionDeclarations, tool.Function) } + geminiRequest.Tools = append(geminiRequest.Tools, geminiChatTools) } - shouldAddDummyModelMessage := false - for _, message := range request.Messages { - content := GeminiChatContent{ - Role: message.Role, - Parts: []GeminiPart{ - { - Text: message.StringContent(), - }, - }, - } - openaiContent := message.ParseContent() - var parts []GeminiPart - imageNum := 0 - for _, part := range openaiContent { - if part.Type == types.ContentTypeText { - parts = append(parts, GeminiPart{ - Text: part.Text, - }) - } else if part.Type == types.ContentTypeImageURL { - imageNum += 1 - if imageNum > GeminiVisionMaxImageNum { - continue - } - mimeType, data, err := image.GetImageFromUrl(part.ImageURL.URL) - if err != nil { - return nil, common.ErrorWrapper(err, "image_url_invalid", http.StatusBadRequest) - } - parts = append(parts, GeminiPart{ - InlineData: &GeminiInlineData{ - MimeType: mimeType, - Data: data, - }, - }) - } - } - content.Parts = parts - - // there's no assistant role in gemini and API shall vomit if Role is not user or model - if content.Role == "assistant" { - content.Role = "model" - } - // Converting system prompt to prompt from user for the same reason - if content.Role == "system" { - content.Role = "user" - shouldAddDummyModelMessage = true - } - geminiRequest.Contents = append(geminiRequest.Contents, content) - - // If a system message is the last message, we need to add a dummy model message to make gemini happy - if shouldAddDummyModelMessage { - geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{ - Role: "model", - Parts: []GeminiPart{ - { - Text: "Okay", - }, - }, - }) - shouldAddDummyModelMessage = false - } + geminiContent, err := OpenAIToGeminiChatContent(request.Messages) + if err != nil { + return nil, err } + geminiRequest.Contents = geminiContent + return &geminiRequest, nil } @@ -207,13 +151,24 @@ func (p *GeminiProvider) convertToChatOpenai(response *GeminiChatResponse, reque choice := types.ChatCompletionChoice{ Index: i, Message: types.ChatCompletionMessage{ - Role: "assistant", - Content: "", + Role: "assistant", + // Content: "", }, FinishReason: types.FinishReasonStop, } - if len(candidate.Content.Parts) > 0 { - choice.Message.Content = candidate.Content.Parts[0].Text + if len(candidate.Content.Parts) == 0 { + choice.Message.Content = "" + openaiResponse.Choices = append(openaiResponse.Choices, choice) + continue + // choice.Message.Content = candidate.Content.Parts[0].Text + } + // 开始判断 + geminiParts := candidate.Content.Parts[0] + + if geminiParts.FunctionCall != nil { + choice.Message.ToolCalls = geminiParts.FunctionCall.ToOpenAITool() + } else { + choice.Message.Content = geminiParts.Text } openaiResponse.Choices = append(openaiResponse.Choices, choice) } @@ -251,34 +206,69 @@ func (h *geminiStreamHandler) handlerStream(rawLine *[]byte, dataChan chan strin return } - h.convertToOpenaiStream(&geminiResponse, dataChan, errChan) + h.convertToOpenaiStream(&geminiResponse, dataChan) } -func (h *geminiStreamHandler) convertToOpenaiStream(geminiResponse *GeminiChatResponse, dataChan chan string, errChan chan error) { - choices := make([]types.ChatCompletionStreamChoice, 0, len(geminiResponse.Candidates)) - - for i, candidate := range geminiResponse.Candidates { - choice := types.ChatCompletionStreamChoice{ - Index: i, - Delta: types.ChatCompletionStreamChoiceDelta{ - Content: candidate.Content.Parts[0].Text, - }, - FinishReason: types.FinishReasonStop, - } - choices = append(choices, choice) - } - +func (h *geminiStreamHandler) convertToOpenaiStream(geminiResponse *GeminiChatResponse, dataChan chan string) { streamResponse := types.ChatCompletionStreamResponse{ ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), Object: "chat.completion.chunk", Created: common.GetTimestamp(), Model: h.Request.Model, - Choices: choices, + // Choices: choices, } - responseBody, _ := json.Marshal(streamResponse) - dataChan <- string(responseBody) + choices := make([]types.ChatCompletionStreamChoice, 0, len(geminiResponse.Candidates)) + + for i, candidate := range geminiResponse.Candidates { + parts := candidate.Content.Parts[0] + + choice := types.ChatCompletionStreamChoice{ + Index: i, + Delta: types.ChatCompletionStreamChoiceDelta{ + Role: types.ChatMessageRoleAssistant, + }, + FinishReason: types.FinishReasonStop, + } + + if parts.FunctionCall != nil { + if parts.FunctionCall.Args == nil { + parts.FunctionCall.Args = map[string]interface{}{} + } + args, _ := json.Marshal(parts.FunctionCall.Args) + + choice.Delta.ToolCalls = []*types.ChatCompletionToolCalls{ + { + Id: "call_" + common.GetRandomString(24), + Type: types.ChatMessageRoleFunction, + Index: 0, + Function: &types.ChatCompletionToolCallsFunction{ + Name: parts.FunctionCall.Name, + Arguments: string(args), + }, + }, + } + } else { + choice.Delta.Content = parts.Text + } + + choices = append(choices, choice) + } + + if len(choices) > 0 && choices[0].Delta.ToolCalls != nil { + choices := choices[0].ConvertOpenaiStream() + for _, choice := range choices { + chatCompletionCopy := streamResponse + chatCompletionCopy.Choices = []types.ChatCompletionStreamChoice{choice} + responseBody, _ := json.Marshal(chatCompletionCopy) + dataChan <- string(responseBody) + } + } else { + streamResponse.Choices = choices + responseBody, _ := json.Marshal(streamResponse) + dataChan <- string(responseBody) + } h.Usage.CompletionTokens += common.CountTokenText(geminiResponse.GetResponseText(), h.Request.Model) h.Usage.TotalTokens = h.Usage.PromptTokens + h.Usage.CompletionTokens diff --git a/providers/gemini/type.go b/providers/gemini/type.go index 9e515172..46933cb8 100644 --- a/providers/gemini/type.go +++ b/providers/gemini/type.go @@ -1,6 +1,12 @@ package gemini -import "one-api/types" +import ( + "encoding/json" + "net/http" + "one-api/common" + "one-api/common/image" + "one-api/types" +) type GeminiChatRequest struct { Contents []GeminiChatContent `json:"contents"` @@ -15,8 +21,41 @@ type GeminiInlineData struct { } type GeminiPart struct { - Text string `json:"text,omitempty"` - InlineData *GeminiInlineData `json:"inlineData,omitempty"` + FunctionCall *GeminiFunctionCall `json:"functionCall,omitempty"` + FunctionResponse *GeminiFunctionResponse `json:"functionResponse,omitempty"` + Text string `json:"text,omitempty"` + InlineData *GeminiInlineData `json:"inlineData,omitempty"` +} + +type GeminiFunctionCall struct { + Name string `json:"name,omitempty"` + Args map[string]interface{} `json:"args,omitempty"` +} + +type GeminiFunctionResponse struct { + Name string `json:"name,omitempty"` + Response GeminiFunctionResponseContent `json:"response,omitempty"` +} + +type GeminiFunctionResponseContent struct { + Name string `json:"name,omitempty"` + Content string `json:"content,omitempty"` +} + +func (g *GeminiFunctionCall) ToOpenAITool() []*types.ChatCompletionToolCalls { + args, _ := json.Marshal(g.Args) + + return []*types.ChatCompletionToolCalls{ + { + Id: "", + Type: types.ChatMessageRoleFunction, + Index: 0, + Function: &types.ChatCompletionToolCallsFunction{ + Name: g.Name, + Arguments: string(args), + }, + }, + } } type GeminiChatContent struct { @@ -30,7 +69,7 @@ type GeminiChatSafetySettings struct { } type GeminiChatTools struct { - FunctionDeclarations any `json:"functionDeclarations,omitempty"` + FunctionDeclarations []types.ChatCompletionFunction `json:"functionDeclarations,omitempty"` } type GeminiChatGenerationConfig struct { @@ -85,3 +124,91 @@ func (g *GeminiChatResponse) GetResponseText() string { } return "" } + +func OpenAIToGeminiChatContent(openaiContents []types.ChatCompletionMessage) ([]GeminiChatContent, *types.OpenAIErrorWithStatusCode) { + contents := make([]GeminiChatContent, 0) + for _, openaiContent := range openaiContents { + content := GeminiChatContent{ + Role: ConvertRole(openaiContent.Role), + Parts: make([]GeminiPart, 0), + } + content.Role = ConvertRole(openaiContent.Role) + if openaiContent.Role == types.ChatMessageRoleFunction { + contents = append(contents, GeminiChatContent{ + Role: "model", + Parts: []GeminiPart{ + { + FunctionCall: &GeminiFunctionCall{ + Name: *openaiContent.Name, + Args: map[string]interface{}{}, + }, + }, + }, + }) + content = GeminiChatContent{ + Role: "function", + Parts: []GeminiPart{ + { + FunctionResponse: &GeminiFunctionResponse{ + Name: *openaiContent.Name, + Response: GeminiFunctionResponseContent{ + Name: *openaiContent.Name, + Content: openaiContent.StringContent(), + }, + }, + }, + }, + } + } else { + openaiMessagePart := openaiContent.ParseContent() + imageNum := 0 + for _, openaiPart := range openaiMessagePart { + if openaiPart.Type == types.ContentTypeText { + content.Parts = append(content.Parts, GeminiPart{ + Text: openaiPart.Text, + }) + } else if openaiPart.Type == types.ContentTypeImageURL { + imageNum += 1 + if imageNum > GeminiVisionMaxImageNum { + continue + } + mimeType, data, err := image.GetImageFromUrl(openaiPart.ImageURL.URL) + if err != nil { + return nil, common.ErrorWrapper(err, "image_url_invalid", http.StatusBadRequest) + } + content.Parts = append(content.Parts, GeminiPart{ + InlineData: &GeminiInlineData{ + MimeType: mimeType, + Data: data, + }, + }) + } + } + } + contents = append(contents, content) + if openaiContent.Role == types.ChatMessageRoleSystem { + contents = append(contents, GeminiChatContent{ + Role: "model", + Parts: []GeminiPart{ + { + Text: "Okay", + }, + }, + }) + } + + } + + return contents, nil +} + +func ConvertRole(roleName string) string { + switch roleName { + case types.ChatMessageRoleFunction, types.ChatMessageRoleTool: + return types.ChatMessageRoleFunction + case types.ChatMessageRoleAssistant: + return "model" + default: + return types.ChatMessageRoleUser + } +}