feat: add function and tools support for Gemini (#1358)

* Update model.go

* Support Gemini tool_calls.

* Fix gemini tool calls (also keep support functions).

* Fixed the problem of arguments not being stringified.

Fix panic: candidate.Content.Parts out of range
This commit is contained in:
Wei Tingjiang 2024-04-24 21:26:45 +08:00 committed by GitHub
parent 3d149fedf4
commit 779b747e9e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 57 additions and 10 deletions

View File

@ -4,6 +4,10 @@ import (
"bufio" "bufio"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io"
"net/http"
"strings"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/helper"
@ -13,9 +17,6 @@ import (
"github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@ -54,7 +55,17 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
MaxOutputTokens: textRequest.MaxTokens, MaxOutputTokens: textRequest.MaxTokens,
}, },
} }
if textRequest.Functions != nil { if textRequest.Tools != nil {
functions := make([]model.Function, 0, len(textRequest.Tools))
for _, tool := range textRequest.Tools {
functions = append(functions, tool.Function)
}
geminiRequest.Tools = []ChatTools{
{
FunctionDeclarations: functions,
},
}
} else if textRequest.Functions != nil {
geminiRequest.Tools = []ChatTools{ geminiRequest.Tools = []ChatTools{
{ {
FunctionDeclarations: textRequest.Functions, FunctionDeclarations: textRequest.Functions,
@ -154,6 +165,30 @@ type ChatPromptFeedback struct {
SafetyRatings []ChatSafetyRating `json:"safetyRatings"` SafetyRatings []ChatSafetyRating `json:"safetyRatings"`
} }
func getToolCalls(candidate *ChatCandidate) []model.Tool {
var toolCalls []model.Tool
item := candidate.Content.Parts[0]
if item.FunctionCall == nil {
return toolCalls
}
argsBytes, err := json.Marshal(item.FunctionCall.Arguments)
if err != nil {
logger.FatalLog("getToolCalls failed: " + err.Error())
return toolCalls
}
toolCall := model.Tool{
Id: fmt.Sprintf("call_%s", random.GetUUID()),
Type: "function",
Function: model.Function{
Arguments: string(argsBytes),
Name: item.FunctionCall.FunctionName,
},
}
toolCalls = append(toolCalls, toolCall)
return toolCalls
}
func responseGeminiChat2OpenAI(response *ChatResponse) *openai.TextResponse { func responseGeminiChat2OpenAI(response *ChatResponse) *openai.TextResponse {
fullTextResponse := openai.TextResponse{ fullTextResponse := openai.TextResponse{
Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()), Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
@ -165,13 +200,19 @@ func responseGeminiChat2OpenAI(response *ChatResponse) *openai.TextResponse {
choice := openai.TextResponseChoice{ choice := openai.TextResponseChoice{
Index: i, Index: i,
Message: model.Message{ Message: model.Message{
Role: "assistant", Role: "assistant",
Content: "",
}, },
FinishReason: constant.StopFinishReason, FinishReason: constant.StopFinishReason,
} }
if len(candidate.Content.Parts) > 0 { if len(candidate.Content.Parts) > 0 {
choice.Message.Content = candidate.Content.Parts[0].Text if candidate.Content.Parts[0].FunctionCall != nil {
choice.Message.ToolCalls = getToolCalls(&candidate)
} else {
choice.Message.Content = candidate.Content.Parts[0].Text
}
} else {
choice.Message.Content = ""
choice.FinishReason = candidate.FinishReason
} }
fullTextResponse.Choices = append(fullTextResponse.Choices, choice) fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
} }

View File

@ -12,9 +12,15 @@ type InlineData struct {
Data string `json:"data"` Data string `json:"data"`
} }
type FunctionCall struct {
FunctionName string `json:"name"`
Arguments any `json:"args"`
}
type Part struct { type Part struct {
Text string `json:"text,omitempty"` Text string `json:"text,omitempty"`
InlineData *InlineData `json:"inlineData,omitempty"` InlineData *InlineData `json:"inlineData,omitempty"`
FunctionCall *FunctionCall `json:"functionCall,omitempty"`
} }
type ChatContent struct { type ChatContent struct {
@ -28,7 +34,7 @@ type ChatSafetySettings struct {
} }
type ChatTools struct { type ChatTools struct {
FunctionDeclarations any `json:"functionDeclarations,omitempty"` FunctionDeclarations any `json:"function_declarations,omitempty"`
} }
type ChatGenerationConfig struct { type ChatGenerationConfig struct {