From 69f6a418aaf3dd5d425c7111adc7a1a00c214631 Mon Sep 17 00:00:00 2001 From: mxdlzg Date: Tue, 23 Apr 2024 11:45:58 +0800 Subject: [PATCH] Support Gemini tool_calls. --- relay/adaptor/gemini/main.go | 36 ++++++++++++++++++++++++++++------- relay/adaptor/gemini/model.go | 10 ++++++++-- 2 files changed, 37 insertions(+), 9 deletions(-) diff --git a/relay/adaptor/gemini/main.go b/relay/adaptor/gemini/main.go index 6bf0c6d7..9c2d78b2 100644 --- a/relay/adaptor/gemini/main.go +++ b/relay/adaptor/gemini/main.go @@ -4,6 +4,10 @@ import ( "bufio" "encoding/json" "fmt" + "io" + "net/http" + "strings" + "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/helper" @@ -13,9 +17,6 @@ import ( "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/model" - "io" - "net/http" - "strings" "github.com/gin-gonic/gin" ) @@ -154,6 +155,25 @@ type ChatPromptFeedback struct { SafetyRatings []ChatSafetyRating `json:"safetyRatings"` } +func getToolCalls(candidate *ChatCandidate) []model.Tool { + var toolCalls []model.Tool + + item := candidate.Content.Parts[0] + if item.FunctionCall == nil { + return toolCalls + } + toolCall := model.Tool{ + Id: fmt.Sprintf("call_%s", random.GetUUID()), + Type: "function", + Function: model.Function{ + Arguments: item.FunctionCall.Arguments, + Name: item.FunctionCall.FunctionName, + }, + } + toolCalls = append(toolCalls, toolCall) + return toolCalls +} + func responseGeminiChat2OpenAI(response *ChatResponse) *openai.TextResponse { fullTextResponse := openai.TextResponse{ Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()), @@ -165,13 +185,15 @@ func responseGeminiChat2OpenAI(response *ChatResponse) *openai.TextResponse { choice := openai.TextResponseChoice{ Index: i, Message: model.Message{ - Role: "assistant", - Content: "", + Role: "assistant", }, FinishReason: constant.StopFinishReason, } - if len(candidate.Content.Parts) > 0 { - choice.Message.Content = candidate.Content.Parts[0].Text + if candidate.Content.Parts[i].FunctionCall != nil { + choice.Message.ToolCalls = getToolCalls(&candidate) + } else if len(candidate.Content.Parts) > 0 { + choice.Message.Content = candidate.Content.Parts[i].Text + } fullTextResponse.Choices = append(fullTextResponse.Choices, choice) } diff --git a/relay/adaptor/gemini/model.go b/relay/adaptor/gemini/model.go index ff4b41a0..47b74fbc 100644 --- a/relay/adaptor/gemini/model.go +++ b/relay/adaptor/gemini/model.go @@ -12,9 +12,15 @@ type InlineData struct { Data string `json:"data"` } +type FunctionCall struct { + FunctionName string `json:"name"` + Arguments any `json:"args"` +} + type Part struct { - Text string `json:"text,omitempty"` - InlineData *InlineData `json:"inlineData,omitempty"` + Text string `json:"text,omitempty"` + InlineData *InlineData `json:"inlineData,omitempty"` + FunctionCall *FunctionCall `json:"functionCall,omitempty"` } type ChatContent struct {