🐛 fix: gemini tools

This commit is contained in:
Martial BE 2024-04-11 11:54:10 +08:00
parent 8ca239095b
commit abd889c398
No known key found for this signature in database
GPG Key ID: D06C32DF0EDB9084
2 changed files with 209 additions and 92 deletions

View File

@ -5,7 +5,6 @@ import (
"fmt"
"net/http"
"one-api/common"
"one-api/common/image"
"one-api/common/requester"
"one-api/types"
"strings"
@ -113,76 +112,21 @@ func convertFromChatOpenai(request *types.ChatCompletionRequest) (*GeminiChatReq
MaxOutputTokens: request.MaxTokens,
},
}
if request.Functions != nil {
geminiRequest.Tools = []GeminiChatTools{
{
FunctionDeclarations: request.Functions,
},
if request.Tools != nil {
var geminiChatTools GeminiChatTools
for _, tool := range request.Tools {
geminiChatTools.FunctionDeclarations = append(geminiChatTools.FunctionDeclarations, tool.Function)
}
geminiRequest.Tools = append(geminiRequest.Tools, geminiChatTools)
}
shouldAddDummyModelMessage := false
for _, message := range request.Messages {
content := GeminiChatContent{
Role: message.Role,
Parts: []GeminiPart{
{
Text: message.StringContent(),
},
},
}
openaiContent := message.ParseContent()
var parts []GeminiPart
imageNum := 0
for _, part := range openaiContent {
if part.Type == types.ContentTypeText {
parts = append(parts, GeminiPart{
Text: part.Text,
})
} else if part.Type == types.ContentTypeImageURL {
imageNum += 1
if imageNum > GeminiVisionMaxImageNum {
continue
}
mimeType, data, err := image.GetImageFromUrl(part.ImageURL.URL)
if err != nil {
return nil, common.ErrorWrapper(err, "image_url_invalid", http.StatusBadRequest)
}
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
MimeType: mimeType,
Data: data,
},
})
}
}
content.Parts = parts
// there's no assistant role in gemini and API shall vomit if Role is not user or model
if content.Role == "assistant" {
content.Role = "model"
}
// Converting system prompt to prompt from user for the same reason
if content.Role == "system" {
content.Role = "user"
shouldAddDummyModelMessage = true
}
geminiRequest.Contents = append(geminiRequest.Contents, content)
// If a system message is the last message, we need to add a dummy model message to make gemini happy
if shouldAddDummyModelMessage {
geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{
Role: "model",
Parts: []GeminiPart{
{
Text: "Okay",
},
},
})
shouldAddDummyModelMessage = false
}
geminiContent, err := OpenAIToGeminiChatContent(request.Messages)
if err != nil {
return nil, err
}
geminiRequest.Contents = geminiContent
return &geminiRequest, nil
}
@ -207,13 +151,24 @@ func (p *GeminiProvider) convertToChatOpenai(response *GeminiChatResponse, reque
choice := types.ChatCompletionChoice{
Index: i,
Message: types.ChatCompletionMessage{
Role: "assistant",
Content: "",
Role: "assistant",
// Content: "",
},
FinishReason: types.FinishReasonStop,
}
if len(candidate.Content.Parts) > 0 {
choice.Message.Content = candidate.Content.Parts[0].Text
if len(candidate.Content.Parts) == 0 {
choice.Message.Content = ""
openaiResponse.Choices = append(openaiResponse.Choices, choice)
continue
// choice.Message.Content = candidate.Content.Parts[0].Text
}
// Decide whether the first candidate part is a function call or plain text
geminiParts := candidate.Content.Parts[0]
if geminiParts.FunctionCall != nil {
choice.Message.ToolCalls = geminiParts.FunctionCall.ToOpenAITool()
} else {
choice.Message.Content = geminiParts.Text
}
openaiResponse.Choices = append(openaiResponse.Choices, choice)
}
@ -251,34 +206,69 @@ func (h *geminiStreamHandler) handlerStream(rawLine *[]byte, dataChan chan strin
return
}
h.convertToOpenaiStream(&geminiResponse, dataChan, errChan)
h.convertToOpenaiStream(&geminiResponse, dataChan)
}
func (h *geminiStreamHandler) convertToOpenaiStream(geminiResponse *GeminiChatResponse, dataChan chan string, errChan chan error) {
choices := make([]types.ChatCompletionStreamChoice, 0, len(geminiResponse.Candidates))
for i, candidate := range geminiResponse.Candidates {
choice := types.ChatCompletionStreamChoice{
Index: i,
Delta: types.ChatCompletionStreamChoiceDelta{
Content: candidate.Content.Parts[0].Text,
},
FinishReason: types.FinishReasonStop,
}
choices = append(choices, choice)
}
func (h *geminiStreamHandler) convertToOpenaiStream(geminiResponse *GeminiChatResponse, dataChan chan string) {
streamResponse := types.ChatCompletionStreamResponse{
ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: h.Request.Model,
Choices: choices,
// Choices: choices,
}
responseBody, _ := json.Marshal(streamResponse)
dataChan <- string(responseBody)
choices := make([]types.ChatCompletionStreamChoice, 0, len(geminiResponse.Candidates))
for i, candidate := range geminiResponse.Candidates {
parts := candidate.Content.Parts[0]
choice := types.ChatCompletionStreamChoice{
Index: i,
Delta: types.ChatCompletionStreamChoiceDelta{
Role: types.ChatMessageRoleAssistant,
},
FinishReason: types.FinishReasonStop,
}
if parts.FunctionCall != nil {
if parts.FunctionCall.Args == nil {
parts.FunctionCall.Args = map[string]interface{}{}
}
args, _ := json.Marshal(parts.FunctionCall.Args)
choice.Delta.ToolCalls = []*types.ChatCompletionToolCalls{
{
Id: "call_" + common.GetRandomString(24),
Type: types.ChatMessageRoleFunction,
Index: 0,
Function: &types.ChatCompletionToolCallsFunction{
Name: parts.FunctionCall.Name,
Arguments: string(args),
},
},
}
} else {
choice.Delta.Content = parts.Text
}
choices = append(choices, choice)
}
if len(choices) > 0 && choices[0].Delta.ToolCalls != nil {
choices := choices[0].ConvertOpenaiStream()
for _, choice := range choices {
chatCompletionCopy := streamResponse
chatCompletionCopy.Choices = []types.ChatCompletionStreamChoice{choice}
responseBody, _ := json.Marshal(chatCompletionCopy)
dataChan <- string(responseBody)
}
} else {
streamResponse.Choices = choices
responseBody, _ := json.Marshal(streamResponse)
dataChan <- string(responseBody)
}
h.Usage.CompletionTokens += common.CountTokenText(geminiResponse.GetResponseText(), h.Request.Model)
h.Usage.TotalTokens = h.Usage.PromptTokens + h.Usage.CompletionTokens

View File

@ -1,6 +1,12 @@
package gemini
import "one-api/types"
import (
	"encoding/json"
	"errors"
	"net/http"

	"one-api/common"
	"one-api/common/image"
	"one-api/types"
)
type GeminiChatRequest struct {
Contents []GeminiChatContent `json:"contents"`
@ -15,8 +21,41 @@ type GeminiInlineData struct {
}
type GeminiPart struct {
Text string `json:"text,omitempty"`
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
FunctionCall *GeminiFunctionCall `json:"functionCall,omitempty"`
FunctionResponse *GeminiFunctionResponse `json:"functionResponse,omitempty"`
Text string `json:"text,omitempty"`
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
}
// GeminiFunctionCall is the Gemini API "functionCall" part: the model's
// request to invoke a declared function, with its arguments as a JSON object.
type GeminiFunctionCall struct {
	Name string                 `json:"name,omitempty"`
	Args map[string]interface{} `json:"args,omitempty"`
}

// GeminiFunctionResponse is the Gemini API "functionResponse" part, used to
// feed the result of a function/tool invocation back to the model.
type GeminiFunctionResponse struct {
	Name     string                        `json:"name,omitempty"`
	Response GeminiFunctionResponseContent `json:"response,omitempty"`
}

// GeminiFunctionResponseContent carries the invoked function's name and its
// string result inside a GeminiFunctionResponse.
type GeminiFunctionResponseContent struct {
	Name    string `json:"name,omitempty"`
	Content string `json:"content,omitempty"`
}
// ToOpenAITool converts a Gemini functionCall part into the OpenAI tool-call
// shape. A nil Args map is normalized to an empty object so Arguments is
// always a valid JSON object string ("{}") rather than "null", matching the
// streaming handler's behavior. The generated id uses the same
// "call_"-prefixed random form as the streaming path; OpenAI clients expect a
// non-empty tool call id.
func (g *GeminiFunctionCall) ToOpenAITool() []*types.ChatCompletionToolCalls {
	args := g.Args
	if args == nil {
		args = map[string]interface{}{}
	}
	// A map[string]interface{} decoded from JSON always re-marshals cleanly.
	argsJSON, _ := json.Marshal(args)

	return []*types.ChatCompletionToolCalls{
		{
			Id:    "call_" + common.GetRandomString(24),
			Type:  types.ChatMessageRoleFunction,
			Index: 0,
			Function: &types.ChatCompletionToolCallsFunction{
				Name:      g.Name,
				Arguments: string(argsJSON),
			},
		},
	}
}
type GeminiChatContent struct {
@ -30,7 +69,7 @@ type GeminiChatSafetySettings struct {
}
type GeminiChatTools struct {
FunctionDeclarations any `json:"functionDeclarations,omitempty"`
FunctionDeclarations []types.ChatCompletionFunction `json:"functionDeclarations,omitempty"`
}
type GeminiChatGenerationConfig struct {
@ -85,3 +124,91 @@ func (g *GeminiChatResponse) GetResponseText() string {
}
return ""
}
// OpenAIToGeminiChatContent converts OpenAI chat messages into Gemini
// "contents". Function/tool result messages are expanded into a model
// functionCall part immediately followed by a function functionResponse part,
// since Gemini requires the call to precede its response (the original call
// arguments are not retained, so an empty args object stands in for them).
// System messages are mapped to user turns and followed by a dummy "Okay"
// model reply, because Gemini only accepts alternating user/model roles. At
// most GeminiVisionMaxImageNum image parts are forwarded; extras are silently
// dropped. Returns an OpenAI-style error for an unfetchable image URL or a
// function message without a name.
func OpenAIToGeminiChatContent(openaiContents []types.ChatCompletionMessage) ([]GeminiChatContent, *types.OpenAIErrorWithStatusCode) {
	contents := make([]GeminiChatContent, 0, len(openaiContents))
	for _, openaiContent := range openaiContents {
		if openaiContent.Role == types.ChatMessageRoleFunction {
			// Dereferencing a nil Name would panic; reject the request instead.
			if openaiContent.Name == nil {
				return nil, common.ErrorWrapper(errors.New("function message missing name"), "invalid_function_message", http.StatusBadRequest)
			}
			contents = append(contents,
				GeminiChatContent{
					Role: "model",
					Parts: []GeminiPart{
						{
							FunctionCall: &GeminiFunctionCall{
								Name: *openaiContent.Name,
								Args: map[string]interface{}{},
							},
						},
					},
				},
				GeminiChatContent{
					Role: "function",
					Parts: []GeminiPart{
						{
							FunctionResponse: &GeminiFunctionResponse{
								Name: *openaiContent.Name,
								Response: GeminiFunctionResponseContent{
									Name:    *openaiContent.Name,
									Content: openaiContent.StringContent(),
								},
							},
						},
					},
				},
			)
			continue
		}

		content := GeminiChatContent{
			Role:  ConvertRole(openaiContent.Role),
			Parts: make([]GeminiPart, 0),
		}
		imageNum := 0
		for _, openaiPart := range openaiContent.ParseContent() {
			switch openaiPart.Type {
			case types.ContentTypeText:
				content.Parts = append(content.Parts, GeminiPart{
					Text: openaiPart.Text,
				})
			case types.ContentTypeImageURL:
				imageNum++
				// Gemini caps the number of inline images per request.
				if imageNum > GeminiVisionMaxImageNum {
					continue
				}
				mimeType, data, err := image.GetImageFromUrl(openaiPart.ImageURL.URL)
				if err != nil {
					return nil, common.ErrorWrapper(err, "image_url_invalid", http.StatusBadRequest)
				}
				content.Parts = append(content.Parts, GeminiPart{
					InlineData: &GeminiInlineData{
						MimeType: mimeType,
						Data:     data,
					},
				})
			}
		}
		contents = append(contents, content)
		// Keep user/model turns alternating: follow a converted system
		// message with a dummy model acknowledgement.
		if openaiContent.Role == types.ChatMessageRoleSystem {
			contents = append(contents, GeminiChatContent{
				Role: "model",
				Parts: []GeminiPart{
					{
						Text: "Okay",
					},
				},
			})
		}
	}
	return contents, nil
}
// ConvertRole maps an OpenAI chat role onto the role names used by the Gemini
// API: function and tool results keep the function role, assistant turns
// become "model", and every other role (user, system, ...) becomes a user
// turn.
func ConvertRole(roleName string) string {
	if roleName == types.ChatMessageRoleFunction || roleName == types.ChatMessageRoleTool {
		return types.ChatMessageRoleFunction
	}
	if roleName == types.ChatMessageRoleAssistant {
		return "model"
	}
	return types.ChatMessageRoleUser
}