🐛 fix: gemini tools

This commit is contained in:
Martial BE 2024-04-11 11:54:10 +08:00
parent 8ca239095b
commit abd889c398
No known key found for this signature in database
GPG Key ID: D06C32DF0EDB9084
2 changed files with 209 additions and 92 deletions

View File

@ -5,7 +5,6 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/common/image"
"one-api/common/requester" "one-api/common/requester"
"one-api/types" "one-api/types"
"strings" "strings"
@ -113,76 +112,21 @@ func convertFromChatOpenai(request *types.ChatCompletionRequest) (*GeminiChatReq
MaxOutputTokens: request.MaxTokens, MaxOutputTokens: request.MaxTokens,
}, },
} }
if request.Functions != nil { if request.Tools != nil {
geminiRequest.Tools = []GeminiChatTools{ var geminiChatTools GeminiChatTools
{ for _, tool := range request.Tools {
FunctionDeclarations: request.Functions, geminiChatTools.FunctionDeclarations = append(geminiChatTools.FunctionDeclarations, tool.Function)
},
} }
geminiRequest.Tools = append(geminiRequest.Tools, geminiChatTools)
} }
shouldAddDummyModelMessage := false
for _, message := range request.Messages {
content := GeminiChatContent{
Role: message.Role,
Parts: []GeminiPart{
{
Text: message.StringContent(),
},
},
}
openaiContent := message.ParseContent() geminiContent, err := OpenAIToGeminiChatContent(request.Messages)
var parts []GeminiPart if err != nil {
imageNum := 0 return nil, err
for _, part := range openaiContent {
if part.Type == types.ContentTypeText {
parts = append(parts, GeminiPart{
Text: part.Text,
})
} else if part.Type == types.ContentTypeImageURL {
imageNum += 1
if imageNum > GeminiVisionMaxImageNum {
continue
}
mimeType, data, err := image.GetImageFromUrl(part.ImageURL.URL)
if err != nil {
return nil, common.ErrorWrapper(err, "image_url_invalid", http.StatusBadRequest)
}
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
MimeType: mimeType,
Data: data,
},
})
}
}
content.Parts = parts
// there's no assistant role in gemini and API shall vomit if Role is not user or model
if content.Role == "assistant" {
content.Role = "model"
}
// Converting system prompt to prompt from user for the same reason
if content.Role == "system" {
content.Role = "user"
shouldAddDummyModelMessage = true
}
geminiRequest.Contents = append(geminiRequest.Contents, content)
// If a system message is the last message, we need to add a dummy model message to make gemini happy
if shouldAddDummyModelMessage {
geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{
Role: "model",
Parts: []GeminiPart{
{
Text: "Okay",
},
},
})
shouldAddDummyModelMessage = false
}
} }
geminiRequest.Contents = geminiContent
return &geminiRequest, nil return &geminiRequest, nil
} }
@ -207,13 +151,24 @@ func (p *GeminiProvider) convertToChatOpenai(response *GeminiChatResponse, reque
choice := types.ChatCompletionChoice{ choice := types.ChatCompletionChoice{
Index: i, Index: i,
Message: types.ChatCompletionMessage{ Message: types.ChatCompletionMessage{
Role: "assistant", Role: "assistant",
Content: "", // Content: "",
}, },
FinishReason: types.FinishReasonStop, FinishReason: types.FinishReasonStop,
} }
if len(candidate.Content.Parts) > 0 { if len(candidate.Content.Parts) == 0 {
choice.Message.Content = candidate.Content.Parts[0].Text choice.Message.Content = ""
openaiResponse.Choices = append(openaiResponse.Choices, choice)
continue
// choice.Message.Content = candidate.Content.Parts[0].Text
}
// 开始判断
geminiParts := candidate.Content.Parts[0]
if geminiParts.FunctionCall != nil {
choice.Message.ToolCalls = geminiParts.FunctionCall.ToOpenAITool()
} else {
choice.Message.Content = geminiParts.Text
} }
openaiResponse.Choices = append(openaiResponse.Choices, choice) openaiResponse.Choices = append(openaiResponse.Choices, choice)
} }
@ -251,34 +206,69 @@ func (h *geminiStreamHandler) handlerStream(rawLine *[]byte, dataChan chan strin
return return
} }
h.convertToOpenaiStream(&geminiResponse, dataChan, errChan) h.convertToOpenaiStream(&geminiResponse, dataChan)
} }
func (h *geminiStreamHandler) convertToOpenaiStream(geminiResponse *GeminiChatResponse, dataChan chan string, errChan chan error) { func (h *geminiStreamHandler) convertToOpenaiStream(geminiResponse *GeminiChatResponse, dataChan chan string) {
choices := make([]types.ChatCompletionStreamChoice, 0, len(geminiResponse.Candidates))
for i, candidate := range geminiResponse.Candidates {
choice := types.ChatCompletionStreamChoice{
Index: i,
Delta: types.ChatCompletionStreamChoiceDelta{
Content: candidate.Content.Parts[0].Text,
},
FinishReason: types.FinishReasonStop,
}
choices = append(choices, choice)
}
streamResponse := types.ChatCompletionStreamResponse{ streamResponse := types.ChatCompletionStreamResponse{
ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
Object: "chat.completion.chunk", Object: "chat.completion.chunk",
Created: common.GetTimestamp(), Created: common.GetTimestamp(),
Model: h.Request.Model, Model: h.Request.Model,
Choices: choices, // Choices: choices,
} }
responseBody, _ := json.Marshal(streamResponse) choices := make([]types.ChatCompletionStreamChoice, 0, len(geminiResponse.Candidates))
dataChan <- string(responseBody)
for i, candidate := range geminiResponse.Candidates {
parts := candidate.Content.Parts[0]
choice := types.ChatCompletionStreamChoice{
Index: i,
Delta: types.ChatCompletionStreamChoiceDelta{
Role: types.ChatMessageRoleAssistant,
},
FinishReason: types.FinishReasonStop,
}
if parts.FunctionCall != nil {
if parts.FunctionCall.Args == nil {
parts.FunctionCall.Args = map[string]interface{}{}
}
args, _ := json.Marshal(parts.FunctionCall.Args)
choice.Delta.ToolCalls = []*types.ChatCompletionToolCalls{
{
Id: "call_" + common.GetRandomString(24),
Type: types.ChatMessageRoleFunction,
Index: 0,
Function: &types.ChatCompletionToolCallsFunction{
Name: parts.FunctionCall.Name,
Arguments: string(args),
},
},
}
} else {
choice.Delta.Content = parts.Text
}
choices = append(choices, choice)
}
if len(choices) > 0 && choices[0].Delta.ToolCalls != nil {
choices := choices[0].ConvertOpenaiStream()
for _, choice := range choices {
chatCompletionCopy := streamResponse
chatCompletionCopy.Choices = []types.ChatCompletionStreamChoice{choice}
responseBody, _ := json.Marshal(chatCompletionCopy)
dataChan <- string(responseBody)
}
} else {
streamResponse.Choices = choices
responseBody, _ := json.Marshal(streamResponse)
dataChan <- string(responseBody)
}
h.Usage.CompletionTokens += common.CountTokenText(geminiResponse.GetResponseText(), h.Request.Model) h.Usage.CompletionTokens += common.CountTokenText(geminiResponse.GetResponseText(), h.Request.Model)
h.Usage.TotalTokens = h.Usage.PromptTokens + h.Usage.CompletionTokens h.Usage.TotalTokens = h.Usage.PromptTokens + h.Usage.CompletionTokens

View File

@ -1,6 +1,12 @@
package gemini package gemini
import "one-api/types" import (
"encoding/json"
"net/http"
"one-api/common"
"one-api/common/image"
"one-api/types"
)
type GeminiChatRequest struct { type GeminiChatRequest struct {
Contents []GeminiChatContent `json:"contents"` Contents []GeminiChatContent `json:"contents"`
@ -15,8 +21,41 @@ type GeminiInlineData struct {
} }
type GeminiPart struct { type GeminiPart struct {
Text string `json:"text,omitempty"` FunctionCall *GeminiFunctionCall `json:"functionCall,omitempty"`
InlineData *GeminiInlineData `json:"inlineData,omitempty"` FunctionResponse *GeminiFunctionResponse `json:"functionResponse,omitempty"`
Text string `json:"text,omitempty"`
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
}
type GeminiFunctionCall struct {
Name string `json:"name,omitempty"`
Args map[string]interface{} `json:"args,omitempty"`
}
type GeminiFunctionResponse struct {
Name string `json:"name,omitempty"`
Response GeminiFunctionResponseContent `json:"response,omitempty"`
}
type GeminiFunctionResponseContent struct {
Name string `json:"name,omitempty"`
Content string `json:"content,omitempty"`
}
func (g *GeminiFunctionCall) ToOpenAITool() []*types.ChatCompletionToolCalls {
args, _ := json.Marshal(g.Args)
return []*types.ChatCompletionToolCalls{
{
Id: "",
Type: types.ChatMessageRoleFunction,
Index: 0,
Function: &types.ChatCompletionToolCallsFunction{
Name: g.Name,
Arguments: string(args),
},
},
}
} }
type GeminiChatContent struct { type GeminiChatContent struct {
@ -30,7 +69,7 @@ type GeminiChatSafetySettings struct {
} }
type GeminiChatTools struct { type GeminiChatTools struct {
FunctionDeclarations any `json:"functionDeclarations,omitempty"` FunctionDeclarations []types.ChatCompletionFunction `json:"functionDeclarations,omitempty"`
} }
type GeminiChatGenerationConfig struct { type GeminiChatGenerationConfig struct {
@ -85,3 +124,91 @@ func (g *GeminiChatResponse) GetResponseText() string {
} }
return "" return ""
} }
func OpenAIToGeminiChatContent(openaiContents []types.ChatCompletionMessage) ([]GeminiChatContent, *types.OpenAIErrorWithStatusCode) {
contents := make([]GeminiChatContent, 0)
for _, openaiContent := range openaiContents {
content := GeminiChatContent{
Role: ConvertRole(openaiContent.Role),
Parts: make([]GeminiPart, 0),
}
content.Role = ConvertRole(openaiContent.Role)
if openaiContent.Role == types.ChatMessageRoleFunction {
contents = append(contents, GeminiChatContent{
Role: "model",
Parts: []GeminiPart{
{
FunctionCall: &GeminiFunctionCall{
Name: *openaiContent.Name,
Args: map[string]interface{}{},
},
},
},
})
content = GeminiChatContent{
Role: "function",
Parts: []GeminiPart{
{
FunctionResponse: &GeminiFunctionResponse{
Name: *openaiContent.Name,
Response: GeminiFunctionResponseContent{
Name: *openaiContent.Name,
Content: openaiContent.StringContent(),
},
},
},
},
}
} else {
openaiMessagePart := openaiContent.ParseContent()
imageNum := 0
for _, openaiPart := range openaiMessagePart {
if openaiPart.Type == types.ContentTypeText {
content.Parts = append(content.Parts, GeminiPart{
Text: openaiPart.Text,
})
} else if openaiPart.Type == types.ContentTypeImageURL {
imageNum += 1
if imageNum > GeminiVisionMaxImageNum {
continue
}
mimeType, data, err := image.GetImageFromUrl(openaiPart.ImageURL.URL)
if err != nil {
return nil, common.ErrorWrapper(err, "image_url_invalid", http.StatusBadRequest)
}
content.Parts = append(content.Parts, GeminiPart{
InlineData: &GeminiInlineData{
MimeType: mimeType,
Data: data,
},
})
}
}
}
contents = append(contents, content)
if openaiContent.Role == types.ChatMessageRoleSystem {
contents = append(contents, GeminiChatContent{
Role: "model",
Parts: []GeminiPart{
{
Text: "Okay",
},
},
})
}
}
return contents, nil
}
func ConvertRole(roleName string) string {
switch roleName {
case types.ChatMessageRoleFunction, types.ChatMessageRoleTool:
return types.ChatMessageRoleFunction
case types.ChatMessageRoleAssistant:
return "model"
default:
return types.ChatMessageRoleUser
}
}