🐛 fix: gemini tools
This commit is contained in:
parent
8ca239095b
commit
abd889c398
@ -5,7 +5,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/common/image"
|
|
||||||
"one-api/common/requester"
|
"one-api/common/requester"
|
||||||
"one-api/types"
|
"one-api/types"
|
||||||
"strings"
|
"strings"
|
||||||
@ -113,76 +112,21 @@ func convertFromChatOpenai(request *types.ChatCompletionRequest) (*GeminiChatReq
|
|||||||
MaxOutputTokens: request.MaxTokens,
|
MaxOutputTokens: request.MaxTokens,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
if request.Functions != nil {
|
if request.Tools != nil {
|
||||||
geminiRequest.Tools = []GeminiChatTools{
|
var geminiChatTools GeminiChatTools
|
||||||
{
|
for _, tool := range request.Tools {
|
||||||
FunctionDeclarations: request.Functions,
|
geminiChatTools.FunctionDeclarations = append(geminiChatTools.FunctionDeclarations, tool.Function)
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
geminiRequest.Tools = append(geminiRequest.Tools, geminiChatTools)
|
||||||
}
|
}
|
||||||
shouldAddDummyModelMessage := false
|
|
||||||
for _, message := range request.Messages {
|
|
||||||
content := GeminiChatContent{
|
|
||||||
Role: message.Role,
|
|
||||||
Parts: []GeminiPart{
|
|
||||||
{
|
|
||||||
Text: message.StringContent(),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
openaiContent := message.ParseContent()
|
geminiContent, err := OpenAIToGeminiChatContent(request.Messages)
|
||||||
var parts []GeminiPart
|
if err != nil {
|
||||||
imageNum := 0
|
return nil, err
|
||||||
for _, part := range openaiContent {
|
|
||||||
if part.Type == types.ContentTypeText {
|
|
||||||
parts = append(parts, GeminiPart{
|
|
||||||
Text: part.Text,
|
|
||||||
})
|
|
||||||
} else if part.Type == types.ContentTypeImageURL {
|
|
||||||
imageNum += 1
|
|
||||||
if imageNum > GeminiVisionMaxImageNum {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
mimeType, data, err := image.GetImageFromUrl(part.ImageURL.URL)
|
|
||||||
if err != nil {
|
|
||||||
return nil, common.ErrorWrapper(err, "image_url_invalid", http.StatusBadRequest)
|
|
||||||
}
|
|
||||||
parts = append(parts, GeminiPart{
|
|
||||||
InlineData: &GeminiInlineData{
|
|
||||||
MimeType: mimeType,
|
|
||||||
Data: data,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
content.Parts = parts
|
|
||||||
|
|
||||||
// there's no assistant role in gemini and API shall vomit if Role is not user or model
|
|
||||||
if content.Role == "assistant" {
|
|
||||||
content.Role = "model"
|
|
||||||
}
|
|
||||||
// Converting system prompt to prompt from user for the same reason
|
|
||||||
if content.Role == "system" {
|
|
||||||
content.Role = "user"
|
|
||||||
shouldAddDummyModelMessage = true
|
|
||||||
}
|
|
||||||
geminiRequest.Contents = append(geminiRequest.Contents, content)
|
|
||||||
|
|
||||||
// If a system message is the last message, we need to add a dummy model message to make gemini happy
|
|
||||||
if shouldAddDummyModelMessage {
|
|
||||||
geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{
|
|
||||||
Role: "model",
|
|
||||||
Parts: []GeminiPart{
|
|
||||||
{
|
|
||||||
Text: "Okay",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
shouldAddDummyModelMessage = false
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
geminiRequest.Contents = geminiContent
|
||||||
|
|
||||||
return &geminiRequest, nil
|
return &geminiRequest, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -207,13 +151,24 @@ func (p *GeminiProvider) convertToChatOpenai(response *GeminiChatResponse, reque
|
|||||||
choice := types.ChatCompletionChoice{
|
choice := types.ChatCompletionChoice{
|
||||||
Index: i,
|
Index: i,
|
||||||
Message: types.ChatCompletionMessage{
|
Message: types.ChatCompletionMessage{
|
||||||
Role: "assistant",
|
Role: "assistant",
|
||||||
Content: "",
|
// Content: "",
|
||||||
},
|
},
|
||||||
FinishReason: types.FinishReasonStop,
|
FinishReason: types.FinishReasonStop,
|
||||||
}
|
}
|
||||||
if len(candidate.Content.Parts) > 0 {
|
if len(candidate.Content.Parts) == 0 {
|
||||||
choice.Message.Content = candidate.Content.Parts[0].Text
|
choice.Message.Content = ""
|
||||||
|
openaiResponse.Choices = append(openaiResponse.Choices, choice)
|
||||||
|
continue
|
||||||
|
// choice.Message.Content = candidate.Content.Parts[0].Text
|
||||||
|
}
|
||||||
|
// 开始判断
|
||||||
|
geminiParts := candidate.Content.Parts[0]
|
||||||
|
|
||||||
|
if geminiParts.FunctionCall != nil {
|
||||||
|
choice.Message.ToolCalls = geminiParts.FunctionCall.ToOpenAITool()
|
||||||
|
} else {
|
||||||
|
choice.Message.Content = geminiParts.Text
|
||||||
}
|
}
|
||||||
openaiResponse.Choices = append(openaiResponse.Choices, choice)
|
openaiResponse.Choices = append(openaiResponse.Choices, choice)
|
||||||
}
|
}
|
||||||
@ -251,34 +206,69 @@ func (h *geminiStreamHandler) handlerStream(rawLine *[]byte, dataChan chan strin
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
h.convertToOpenaiStream(&geminiResponse, dataChan, errChan)
|
h.convertToOpenaiStream(&geminiResponse, dataChan)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *geminiStreamHandler) convertToOpenaiStream(geminiResponse *GeminiChatResponse, dataChan chan string, errChan chan error) {
|
func (h *geminiStreamHandler) convertToOpenaiStream(geminiResponse *GeminiChatResponse, dataChan chan string) {
|
||||||
choices := make([]types.ChatCompletionStreamChoice, 0, len(geminiResponse.Candidates))
|
|
||||||
|
|
||||||
for i, candidate := range geminiResponse.Candidates {
|
|
||||||
choice := types.ChatCompletionStreamChoice{
|
|
||||||
Index: i,
|
|
||||||
Delta: types.ChatCompletionStreamChoiceDelta{
|
|
||||||
Content: candidate.Content.Parts[0].Text,
|
|
||||||
},
|
|
||||||
FinishReason: types.FinishReasonStop,
|
|
||||||
}
|
|
||||||
choices = append(choices, choice)
|
|
||||||
}
|
|
||||||
|
|
||||||
streamResponse := types.ChatCompletionStreamResponse{
|
streamResponse := types.ChatCompletionStreamResponse{
|
||||||
ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
||||||
Object: "chat.completion.chunk",
|
Object: "chat.completion.chunk",
|
||||||
Created: common.GetTimestamp(),
|
Created: common.GetTimestamp(),
|
||||||
Model: h.Request.Model,
|
Model: h.Request.Model,
|
||||||
Choices: choices,
|
// Choices: choices,
|
||||||
}
|
}
|
||||||
|
|
||||||
responseBody, _ := json.Marshal(streamResponse)
|
choices := make([]types.ChatCompletionStreamChoice, 0, len(geminiResponse.Candidates))
|
||||||
dataChan <- string(responseBody)
|
|
||||||
|
for i, candidate := range geminiResponse.Candidates {
|
||||||
|
parts := candidate.Content.Parts[0]
|
||||||
|
|
||||||
|
choice := types.ChatCompletionStreamChoice{
|
||||||
|
Index: i,
|
||||||
|
Delta: types.ChatCompletionStreamChoiceDelta{
|
||||||
|
Role: types.ChatMessageRoleAssistant,
|
||||||
|
},
|
||||||
|
FinishReason: types.FinishReasonStop,
|
||||||
|
}
|
||||||
|
|
||||||
|
if parts.FunctionCall != nil {
|
||||||
|
if parts.FunctionCall.Args == nil {
|
||||||
|
parts.FunctionCall.Args = map[string]interface{}{}
|
||||||
|
}
|
||||||
|
args, _ := json.Marshal(parts.FunctionCall.Args)
|
||||||
|
|
||||||
|
choice.Delta.ToolCalls = []*types.ChatCompletionToolCalls{
|
||||||
|
{
|
||||||
|
Id: "call_" + common.GetRandomString(24),
|
||||||
|
Type: types.ChatMessageRoleFunction,
|
||||||
|
Index: 0,
|
||||||
|
Function: &types.ChatCompletionToolCallsFunction{
|
||||||
|
Name: parts.FunctionCall.Name,
|
||||||
|
Arguments: string(args),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
choice.Delta.Content = parts.Text
|
||||||
|
}
|
||||||
|
|
||||||
|
choices = append(choices, choice)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(choices) > 0 && choices[0].Delta.ToolCalls != nil {
|
||||||
|
choices := choices[0].ConvertOpenaiStream()
|
||||||
|
for _, choice := range choices {
|
||||||
|
chatCompletionCopy := streamResponse
|
||||||
|
chatCompletionCopy.Choices = []types.ChatCompletionStreamChoice{choice}
|
||||||
|
responseBody, _ := json.Marshal(chatCompletionCopy)
|
||||||
|
dataChan <- string(responseBody)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
streamResponse.Choices = choices
|
||||||
|
responseBody, _ := json.Marshal(streamResponse)
|
||||||
|
dataChan <- string(responseBody)
|
||||||
|
}
|
||||||
|
|
||||||
h.Usage.CompletionTokens += common.CountTokenText(geminiResponse.GetResponseText(), h.Request.Model)
|
h.Usage.CompletionTokens += common.CountTokenText(geminiResponse.GetResponseText(), h.Request.Model)
|
||||||
h.Usage.TotalTokens = h.Usage.PromptTokens + h.Usage.CompletionTokens
|
h.Usage.TotalTokens = h.Usage.PromptTokens + h.Usage.CompletionTokens
|
||||||
|
@ -1,6 +1,12 @@
|
|||||||
package gemini
|
package gemini
|
||||||
|
|
||||||
import (
	"encoding/json"
	"errors"
	"net/http"
	"one-api/common"
	"one-api/common/image"
	"one-api/types"
)
|
||||||
|
|
||||||
type GeminiChatRequest struct {
|
type GeminiChatRequest struct {
|
||||||
Contents []GeminiChatContent `json:"contents"`
|
Contents []GeminiChatContent `json:"contents"`
|
||||||
@ -15,8 +21,41 @@ type GeminiInlineData struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type GeminiPart struct {
|
type GeminiPart struct {
|
||||||
Text string `json:"text,omitempty"`
|
FunctionCall *GeminiFunctionCall `json:"functionCall,omitempty"`
|
||||||
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
|
FunctionResponse *GeminiFunctionResponse `json:"functionResponse,omitempty"`
|
||||||
|
Text string `json:"text,omitempty"`
|
||||||
|
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GeminiFunctionCall describes a function invocation inside a GeminiPart.
type GeminiFunctionCall struct {
	// Name is the name of the function being called.
	Name string `json:"name,omitempty"`
	// Args holds the call arguments as a decoded JSON object. Because of
	// omitempty, a nil or empty map is left out of the encoded JSON.
	Args map[string]any `json:"args,omitempty"`
}
|
||||||
|
|
||||||
|
type GeminiFunctionResponse struct {
|
||||||
|
Name string `json:"name,omitempty"`
|
||||||
|
Response GeminiFunctionResponseContent `json:"response,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GeminiFunctionResponseContent is the body of a GeminiFunctionResponse.
type GeminiFunctionResponseContent struct {
	// Name is the function name, repeated inside the response body.
	Name string `json:"name,omitempty"`
	// Content is the function's textual result.
	Content string `json:"content,omitempty"`
}
|
||||||
|
|
||||||
|
func (g *GeminiFunctionCall) ToOpenAITool() []*types.ChatCompletionToolCalls {
|
||||||
|
args, _ := json.Marshal(g.Args)
|
||||||
|
|
||||||
|
return []*types.ChatCompletionToolCalls{
|
||||||
|
{
|
||||||
|
Id: "",
|
||||||
|
Type: types.ChatMessageRoleFunction,
|
||||||
|
Index: 0,
|
||||||
|
Function: &types.ChatCompletionToolCallsFunction{
|
||||||
|
Name: g.Name,
|
||||||
|
Arguments: string(args),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type GeminiChatContent struct {
|
type GeminiChatContent struct {
|
||||||
@ -30,7 +69,7 @@ type GeminiChatSafetySettings struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type GeminiChatTools struct {
|
type GeminiChatTools struct {
|
||||||
FunctionDeclarations any `json:"functionDeclarations,omitempty"`
|
FunctionDeclarations []types.ChatCompletionFunction `json:"functionDeclarations,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type GeminiChatGenerationConfig struct {
|
type GeminiChatGenerationConfig struct {
|
||||||
@ -85,3 +124,91 @@ func (g *GeminiChatResponse) GetResponseText() string {
|
|||||||
}
|
}
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func OpenAIToGeminiChatContent(openaiContents []types.ChatCompletionMessage) ([]GeminiChatContent, *types.OpenAIErrorWithStatusCode) {
|
||||||
|
contents := make([]GeminiChatContent, 0)
|
||||||
|
for _, openaiContent := range openaiContents {
|
||||||
|
content := GeminiChatContent{
|
||||||
|
Role: ConvertRole(openaiContent.Role),
|
||||||
|
Parts: make([]GeminiPart, 0),
|
||||||
|
}
|
||||||
|
content.Role = ConvertRole(openaiContent.Role)
|
||||||
|
if openaiContent.Role == types.ChatMessageRoleFunction {
|
||||||
|
contents = append(contents, GeminiChatContent{
|
||||||
|
Role: "model",
|
||||||
|
Parts: []GeminiPart{
|
||||||
|
{
|
||||||
|
FunctionCall: &GeminiFunctionCall{
|
||||||
|
Name: *openaiContent.Name,
|
||||||
|
Args: map[string]interface{}{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
content = GeminiChatContent{
|
||||||
|
Role: "function",
|
||||||
|
Parts: []GeminiPart{
|
||||||
|
{
|
||||||
|
FunctionResponse: &GeminiFunctionResponse{
|
||||||
|
Name: *openaiContent.Name,
|
||||||
|
Response: GeminiFunctionResponseContent{
|
||||||
|
Name: *openaiContent.Name,
|
||||||
|
Content: openaiContent.StringContent(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
openaiMessagePart := openaiContent.ParseContent()
|
||||||
|
imageNum := 0
|
||||||
|
for _, openaiPart := range openaiMessagePart {
|
||||||
|
if openaiPart.Type == types.ContentTypeText {
|
||||||
|
content.Parts = append(content.Parts, GeminiPart{
|
||||||
|
Text: openaiPart.Text,
|
||||||
|
})
|
||||||
|
} else if openaiPart.Type == types.ContentTypeImageURL {
|
||||||
|
imageNum += 1
|
||||||
|
if imageNum > GeminiVisionMaxImageNum {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
mimeType, data, err := image.GetImageFromUrl(openaiPart.ImageURL.URL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, common.ErrorWrapper(err, "image_url_invalid", http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
content.Parts = append(content.Parts, GeminiPart{
|
||||||
|
InlineData: &GeminiInlineData{
|
||||||
|
MimeType: mimeType,
|
||||||
|
Data: data,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
contents = append(contents, content)
|
||||||
|
if openaiContent.Role == types.ChatMessageRoleSystem {
|
||||||
|
contents = append(contents, GeminiChatContent{
|
||||||
|
Role: "model",
|
||||||
|
Parts: []GeminiPart{
|
||||||
|
{
|
||||||
|
Text: "Okay",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
return contents, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ConvertRole(roleName string) string {
|
||||||
|
switch roleName {
|
||||||
|
case types.ChatMessageRoleFunction, types.ChatMessageRoleTool:
|
||||||
|
return types.ChatMessageRoleFunction
|
||||||
|
case types.ChatMessageRoleAssistant:
|
||||||
|
return "model"
|
||||||
|
default:
|
||||||
|
return types.ChatMessageRoleUser
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user