feat: add function and tools support for Gemini (#1358)
* Update model.go * Support Gemini tool_calls. * Fix gemini tool calls (also keep support functions). * Fixed the problem of arguments not being stringified. Fix panic: candidate.Content.Parts out of range
This commit is contained in:
parent
3d149fedf4
commit
779b747e9e
@ -4,6 +4,10 @@ import (
|
|||||||
"bufio"
|
"bufio"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/songquanpeng/one-api/common"
|
"github.com/songquanpeng/one-api/common"
|
||||||
"github.com/songquanpeng/one-api/common/config"
|
"github.com/songquanpeng/one-api/common/config"
|
||||||
"github.com/songquanpeng/one-api/common/helper"
|
"github.com/songquanpeng/one-api/common/helper"
|
||||||
@ -13,9 +17,6 @@ import (
|
|||||||
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
||||||
"github.com/songquanpeng/one-api/relay/constant"
|
"github.com/songquanpeng/one-api/relay/constant"
|
||||||
"github.com/songquanpeng/one-api/relay/model"
|
"github.com/songquanpeng/one-api/relay/model"
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
@ -54,7 +55,17 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
|
|||||||
MaxOutputTokens: textRequest.MaxTokens,
|
MaxOutputTokens: textRequest.MaxTokens,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
if textRequest.Functions != nil {
|
if textRequest.Tools != nil {
|
||||||
|
functions := make([]model.Function, 0, len(textRequest.Tools))
|
||||||
|
for _, tool := range textRequest.Tools {
|
||||||
|
functions = append(functions, tool.Function)
|
||||||
|
}
|
||||||
|
geminiRequest.Tools = []ChatTools{
|
||||||
|
{
|
||||||
|
FunctionDeclarations: functions,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
} else if textRequest.Functions != nil {
|
||||||
geminiRequest.Tools = []ChatTools{
|
geminiRequest.Tools = []ChatTools{
|
||||||
{
|
{
|
||||||
FunctionDeclarations: textRequest.Functions,
|
FunctionDeclarations: textRequest.Functions,
|
||||||
@ -154,6 +165,30 @@ type ChatPromptFeedback struct {
|
|||||||
SafetyRatings []ChatSafetyRating `json:"safetyRatings"`
|
SafetyRatings []ChatSafetyRating `json:"safetyRatings"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getToolCalls(candidate *ChatCandidate) []model.Tool {
|
||||||
|
var toolCalls []model.Tool
|
||||||
|
|
||||||
|
item := candidate.Content.Parts[0]
|
||||||
|
if item.FunctionCall == nil {
|
||||||
|
return toolCalls
|
||||||
|
}
|
||||||
|
argsBytes, err := json.Marshal(item.FunctionCall.Arguments)
|
||||||
|
if err != nil {
|
||||||
|
logger.FatalLog("getToolCalls failed: " + err.Error())
|
||||||
|
return toolCalls
|
||||||
|
}
|
||||||
|
toolCall := model.Tool{
|
||||||
|
Id: fmt.Sprintf("call_%s", random.GetUUID()),
|
||||||
|
Type: "function",
|
||||||
|
Function: model.Function{
|
||||||
|
Arguments: string(argsBytes),
|
||||||
|
Name: item.FunctionCall.FunctionName,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
toolCalls = append(toolCalls, toolCall)
|
||||||
|
return toolCalls
|
||||||
|
}
|
||||||
|
|
||||||
func responseGeminiChat2OpenAI(response *ChatResponse) *openai.TextResponse {
|
func responseGeminiChat2OpenAI(response *ChatResponse) *openai.TextResponse {
|
||||||
fullTextResponse := openai.TextResponse{
|
fullTextResponse := openai.TextResponse{
|
||||||
Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
|
Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
|
||||||
@ -165,13 +200,19 @@ func responseGeminiChat2OpenAI(response *ChatResponse) *openai.TextResponse {
|
|||||||
choice := openai.TextResponseChoice{
|
choice := openai.TextResponseChoice{
|
||||||
Index: i,
|
Index: i,
|
||||||
Message: model.Message{
|
Message: model.Message{
|
||||||
Role: "assistant",
|
Role: "assistant",
|
||||||
Content: "",
|
|
||||||
},
|
},
|
||||||
FinishReason: constant.StopFinishReason,
|
FinishReason: constant.StopFinishReason,
|
||||||
}
|
}
|
||||||
if len(candidate.Content.Parts) > 0 {
|
if len(candidate.Content.Parts) > 0 {
|
||||||
choice.Message.Content = candidate.Content.Parts[0].Text
|
if candidate.Content.Parts[0].FunctionCall != nil {
|
||||||
|
choice.Message.ToolCalls = getToolCalls(&candidate)
|
||||||
|
} else {
|
||||||
|
choice.Message.Content = candidate.Content.Parts[0].Text
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
choice.Message.Content = ""
|
||||||
|
choice.FinishReason = candidate.FinishReason
|
||||||
}
|
}
|
||||||
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
|
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
|
||||||
}
|
}
|
||||||
|
@ -12,9 +12,15 @@ type InlineData struct {
|
|||||||
Data string `json:"data"`
|
Data string `json:"data"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type FunctionCall struct {
|
||||||
|
FunctionName string `json:"name"`
|
||||||
|
Arguments any `json:"args"`
|
||||||
|
}
|
||||||
|
|
||||||
type Part struct {
|
type Part struct {
|
||||||
Text string `json:"text,omitempty"`
|
Text string `json:"text,omitempty"`
|
||||||
InlineData *InlineData `json:"inlineData,omitempty"`
|
InlineData *InlineData `json:"inlineData,omitempty"`
|
||||||
|
FunctionCall *FunctionCall `json:"functionCall,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type ChatContent struct {
|
type ChatContent struct {
|
||||||
@ -28,7 +34,7 @@ type ChatSafetySettings struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type ChatTools struct {
|
type ChatTools struct {
|
||||||
FunctionDeclarations any `json:"functionDeclarations,omitempty"`
|
FunctionDeclarations any `json:"function_declarations,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type ChatGenerationConfig struct {
|
type ChatGenerationConfig struct {
|
||||||
|
Loading…
Reference in New Issue
Block a user