Support Gemini tool_calls.

This commit is contained in:
mxdlzg 2024-04-23 11:45:58 +08:00
parent 265ee27a0c
commit 69f6a418aa
2 changed files with 37 additions and 9 deletions

View File

@ -4,6 +4,10 @@ import (
"bufio" "bufio"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io"
"net/http"
"strings"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/helper"
@ -13,9 +17,6 @@ import (
"github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@ -154,6 +155,25 @@ type ChatPromptFeedback struct {
SafetyRatings []ChatSafetyRating `json:"safetyRatings"` SafetyRatings []ChatSafetyRating `json:"safetyRatings"`
} }
// getToolCalls converts the Gemini function call carried in the candidate's
// first content part into an OpenAI-style tool-call list.
// It returns a nil slice when the candidate has no parts or the first part
// carries no function call; callers may range over the result safely.
func getToolCalls(candidate *ChatCandidate) []model.Tool {
	var toolCalls []model.Tool
	// Guard against candidates with an empty Parts slice: indexing
	// Parts[0] unconditionally would panic with index out of range.
	if len(candidate.Content.Parts) == 0 {
		return toolCalls
	}
	item := candidate.Content.Parts[0]
	if item.FunctionCall == nil {
		return toolCalls
	}
	toolCall := model.Tool{
		// Synthesize an OpenAI-compatible call id; Gemini does not supply one.
		Id:   fmt.Sprintf("call_%s", random.GetUUID()),
		Type: "function",
		Function: model.Function{
			Arguments: item.FunctionCall.Arguments,
			Name:      item.FunctionCall.FunctionName,
		},
	}
	toolCalls = append(toolCalls, toolCall)
	return toolCalls
}
func responseGeminiChat2OpenAI(response *ChatResponse) *openai.TextResponse { func responseGeminiChat2OpenAI(response *ChatResponse) *openai.TextResponse {
fullTextResponse := openai.TextResponse{ fullTextResponse := openai.TextResponse{
Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()), Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
@ -165,13 +185,15 @@ func responseGeminiChat2OpenAI(response *ChatResponse) *openai.TextResponse {
choice := openai.TextResponseChoice{ choice := openai.TextResponseChoice{
Index: i, Index: i,
Message: model.Message{ Message: model.Message{
Role: "assistant", Role: "assistant",
Content: "",
}, },
FinishReason: constant.StopFinishReason, FinishReason: constant.StopFinishReason,
} }
if len(candidate.Content.Parts) > 0 { if candidate.Content.Parts[i].FunctionCall != nil {
choice.Message.Content = candidate.Content.Parts[0].Text choice.Message.ToolCalls = getToolCalls(&candidate)
} else if len(candidate.Content.Parts) > 0 {
choice.Message.Content = candidate.Content.Parts[i].Text
} }
fullTextResponse.Choices = append(fullTextResponse.Choices, choice) fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
} }

View File

@ -12,9 +12,15 @@ type InlineData struct {
Data string `json:"data"` Data string `json:"data"`
} }
// FunctionCall models the Gemini API "functionCall" part of a response.
// FunctionName maps to the JSON field "name"; Arguments maps to "args",
// which Gemini sends as an arbitrary JSON object (hence `any`).
type FunctionCall struct {
FunctionName string `json:"name"`
Arguments any `json:"args"`
}
type Part struct { type Part struct {
Text string `json:"text,omitempty"` Text string `json:"text,omitempty"`
InlineData *InlineData `json:"inlineData,omitempty"` InlineData *InlineData `json:"inlineData,omitempty"`
FunctionCall *FunctionCall `json:"functionCall,omitempty"`
} }
type ChatContent struct { type ChatContent struct {