diff --git a/relay/adaptor/gemini/main.go b/relay/adaptor/gemini/main.go
index 6bf0c6d7..8b934d30 100644
--- a/relay/adaptor/gemini/main.go
+++ b/relay/adaptor/gemini/main.go
@@ -4,6 +4,10 @@ import (
 	"bufio"
 	"encoding/json"
 	"fmt"
+	"io"
+	"net/http"
+	"strings"
+
 	"github.com/songquanpeng/one-api/common"
 	"github.com/songquanpeng/one-api/common/config"
 	"github.com/songquanpeng/one-api/common/helper"
@@ -13,9 +17,6 @@ import (
 	"github.com/songquanpeng/one-api/relay/adaptor/openai"
 	"github.com/songquanpeng/one-api/relay/constant"
 	"github.com/songquanpeng/one-api/relay/model"
-	"io"
-	"net/http"
-	"strings"
 
 	"github.com/gin-gonic/gin"
 )
@@ -54,7 +55,17 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
 			MaxOutputTokens: textRequest.MaxTokens,
 		},
 	}
-	if textRequest.Functions != nil {
+	if textRequest.Tools != nil {
+		functions := make([]model.Function, 0, len(textRequest.Tools))
+		for _, tool := range textRequest.Tools {
+			functions = append(functions, tool.Function)
+		}
+		geminiRequest.Tools = []ChatTools{
+			{
+				FunctionDeclarations: functions,
+			},
+		}
+	} else if textRequest.Functions != nil {
 		geminiRequest.Tools = []ChatTools{
 			{
 				FunctionDeclarations: textRequest.Functions,
@@ -154,6 +165,30 @@ type ChatPromptFeedback struct {
 	SafetyRatings []ChatSafetyRating `json:"safetyRatings"`
 }
 
+func getToolCalls(candidate *ChatCandidate) []model.Tool {
+	var toolCalls []model.Tool
+
+	item := candidate.Content.Parts[0]
+	if item.FunctionCall == nil {
+		return toolCalls
+	}
+	argsBytes, err := json.Marshal(item.FunctionCall.Arguments)
+	if err != nil {
+		logger.FatalLog("getToolCalls failed: " + err.Error())
+		return toolCalls
+	}
+	toolCall := model.Tool{
+		Id:   fmt.Sprintf("call_%s", random.GetUUID()),
+		Type: "function",
+		Function: model.Function{
+			Arguments: string(argsBytes),
+			Name:      item.FunctionCall.FunctionName,
+		},
+	}
+	toolCalls = append(toolCalls, toolCall)
+	return toolCalls
+}
+
 func responseGeminiChat2OpenAI(response *ChatResponse) *openai.TextResponse {
 	fullTextResponse := openai.TextResponse{
 		Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
@@ -165,13 +200,19 @@ func responseGeminiChat2OpenAI(response *ChatResponse) *openai.TextResponse {
 		choice := openai.TextResponseChoice{
 			Index: i,
 			Message: model.Message{
-				Role:    "assistant",
-				Content: "",
+				Role: "assistant",
 			},
 			FinishReason: constant.StopFinishReason,
 		}
 		if len(candidate.Content.Parts) > 0 {
-			choice.Message.Content = candidate.Content.Parts[0].Text
+			if candidate.Content.Parts[0].FunctionCall != nil {
+				choice.Message.ToolCalls = getToolCalls(&candidate)
+			} else {
+				choice.Message.Content = candidate.Content.Parts[0].Text
+			}
+		} else {
+			choice.Message.Content = ""
+			choice.FinishReason = candidate.FinishReason
+		}
 		fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
 	}
diff --git a/relay/adaptor/gemini/model.go b/relay/adaptor/gemini/model.go
index d1e3c4fd..47b74fbc 100644
--- a/relay/adaptor/gemini/model.go
+++ b/relay/adaptor/gemini/model.go
@@ -12,9 +12,15 @@ type InlineData struct {
 	Data string `json:"data"`
 }
 
+type FunctionCall struct {
+	FunctionName string `json:"name"`
+	Arguments    any    `json:"args"`
+}
+
 type Part struct {
-	Text       string      `json:"text,omitempty"`
-	InlineData *InlineData `json:"inlineData,omitempty"`
+	Text         string        `json:"text,omitempty"`
+	InlineData   *InlineData   `json:"inlineData,omitempty"`
+	FunctionCall *FunctionCall `json:"functionCall,omitempty"`
 }
 
 type ChatContent struct {
@@ -28,7 +34,7 @@ type ChatSafetySettings struct {
 	Threshold string `json:"threshold"`
 }
 
 type ChatTools struct {
-	FunctionDeclarations any `json:"functionDeclarations,omitempty"`
+	FunctionDeclarations any `json:"function_declarations,omitempty"`
 }
 
type ChatGenerationConfig struct {