diff --git a/common/constants.go b/common/constants.go
index 45fc9535..60700ec8 100644
--- a/common/constants.go
+++ b/common/constants.go
@@ -187,7 +187,7 @@ const (
 	ChannelTypeAIProxyLibrary = 21
 	ChannelTypeFastGPT        = 22
 	ChannelTypeTencent        = 23
-	ChannelTypeGeminiChat     = 24
+	ChannelTypeGemini         = 24
 )
 
 var ChannelBaseURLs = []string{
diff --git a/common/model-ratio.go b/common/model-ratio.go
index a9d5bd51..c054fa5f 100644
--- a/common/model-ratio.go
+++ b/common/model-ratio.go
@@ -83,7 +83,7 @@ var ModelRatio = map[string]float64{
 	"ERNIE-Bot-4":   8.572,  // ¥0.12 / 1k tokens
 	"Embedding-V1":  0.1429, // ¥0.002 / 1k tokens
 	"PaLM-2":        1,
-	"gemini-pro":    1,
+	"gemini-pro":    1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
 	"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
 	"chatglm_pro":   0.7143, // ¥0.01 / 1k tokens
 	"chatglm_std":   0.3572, // ¥0.005 / 1k tokens
diff --git a/controller/channel-test.go b/controller/channel-test.go
index a9c1a1b4..3aaa4897 100644
--- a/controller/channel-test.go
+++ b/controller/channel-test.go
@@ -20,7 +20,7 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai
 	switch channel.Type {
 	case common.ChannelTypePaLM:
 		fallthrough
-	case common.ChannelTypeGeminiChat:
+	case common.ChannelTypeGemini:
 		fallthrough
 	case common.ChannelTypeAnthropic:
 		fallthrough
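A note on the new `gemini-pro` comment in `common/model-ratio.go`: elsewhere in that file a ratio of 1 corresponds to $0.002 per 1K tokens, so the comment records how Google's character-based price was approximated in that scheme. A minimal sketch of the arithmetic, assuming that convention (the helper below is illustrative, not code from this patch):

```go
package main

import "fmt"

// dollarsPer1kTokens is a hypothetical helper, not part of this patch.
// It assumes one-api's convention that a model ratio of 1 means $0.002 / 1K tokens.
func dollarsPer1kTokens(ratio float64) float64 {
	return ratio * 0.002
}

func main() {
	// The diff keeps gemini-pro at ratio 1; the new comment records Google's
	// $0.00025 / 1k characters price, estimated upstream as $0.001 / 1k tokens.
	fmt.Printf("gemini-pro charged as $%.3f per 1K tokens\n", dollarsPer1kTokens(1))
}
```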
"HARM_CATEGORY_DANGEROUS_CONTENT", - Threshold: "BLOCK_ONLY_HIGH", - }, - }, + Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)), + //SafetySettings: []GeminiChatSafetySettings{ + // { + // Category: "HARM_CATEGORY_HARASSMENT", + // Threshold: "BLOCK_ONLY_HIGH", + // }, + // { + // Category: "HARM_CATEGORY_HATE_SPEECH", + // Threshold: "BLOCK_ONLY_HIGH", + // }, + // { + // Category: "HARM_CATEGORY_SEXUALLY_EXPLICIT", + // Threshold: "BLOCK_ONLY_HIGH", + // }, + // { + // Category: "HARM_CATEGORY_DANGEROUS_CONTENT", + // Threshold: "BLOCK_ONLY_HIGH", + // }, + //}, GenerationConfig: GeminiChatGenerationConfig{ Temperature: textRequest.Temperature, TopP: textRequest.TopP, MaxOutputTokens: textRequest.MaxTokens, }, - Tools: []GeminiChatTools{ + } + if textRequest.Functions != nil { + geminiRequest.Tools = []GeminiChatTools{ { FunctionDeclarations: textRequest.Functions, }, - }, + } } shouldAddDummyModelMessage := false for _, message := range textRequest.Messages { - content := GeminiChatContents{ + content := GeminiChatContent{ Role: message.Role, - Parts: []GeminiChatParts{ + Parts: []GeminiPart{ { Text: message.StringContent(), }, @@ -95,9 +108,9 @@ func requestOpenAI2GeminiChat(textRequest GeneralOpenAIRequest) *GeminiChatReque // If a system message is the last message, we need to add a dummy model message to make gemini happy if shouldAddDummyModelMessage { - geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContents{ + geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{ Role: "model", - Parts: []GeminiChatParts{ + Parts: []GeminiPart{ { Text: "ok", }, @@ -122,15 +135,6 @@ type GeminiChatCandidate struct { SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"` } -type GeminiChatContent struct { - Parts []GeminiChatPart `json:"parts"` - Role string `json:"role"` -} - -type GeminiChatPart struct { - Text string `json:"text"` -} - type GeminiChatSafetyRating struct { Category string `json:"category"` Probability string `json:"probability"` @@ -142,6 +146,9 @@ type GeminiChatPromptFeedback struct { func responseGeminiChat2OpenAI(response *GeminiChatResponse) *OpenAITextResponse { fullTextResponse := OpenAITextResponse{ + Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), + Object: "chat.completion", + Created: common.GetTimestamp(), Choices: make([]OpenAITextResponseChoice, 0, len(response.Candidates)), } for i, candidate := range response.Candidates { @@ -151,7 +158,7 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *OpenAITextResponse Role: "assistant", Content: candidate.Content.Parts[0].Text, }, - FinishReason: "stop", + FinishReason: stopFinishReason, } fullTextResponse.Choices = append(fullTextResponse.Choices, choice) } @@ -160,7 +167,7 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *OpenAITextResponse func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *ChatCompletionsStreamResponse { var choice ChatCompletionsStreamResponseChoice - if len(geminiResponse.Candidates) > 0 { + if len(geminiResponse.Candidates) > 0 && len(geminiResponse.Candidates[0].Content.Parts) > 0 { choice.Delta.Content = geminiResponse.Candidates[0].Content.Parts[0].Text } choice.FinishReason = &stopFinishReason @@ -200,7 +207,9 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorW fullTextResponse := streamResponseGeminiChat2OpenAI(&geminiResponse) fullTextResponse.Id = responseId fullTextResponse.Created = createdTime - responseText = 
@@ -122,15 +135,6 @@ type GeminiChatCandidate struct {
 	SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
 }
 
-type GeminiChatContent struct {
-	Parts []GeminiChatPart `json:"parts"`
-	Role  string           `json:"role"`
-}
-
-type GeminiChatPart struct {
-	Text string `json:"text"`
-}
-
 type GeminiChatSafetyRating struct {
 	Category    string `json:"category"`
 	Probability string `json:"probability"`
@@ -142,6 +146,9 @@ type GeminiChatPromptFeedback struct {
 
 func responseGeminiChat2OpenAI(response *GeminiChatResponse) *OpenAITextResponse {
 	fullTextResponse := OpenAITextResponse{
+		Id:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
+		Object:  "chat.completion",
+		Created: common.GetTimestamp(),
 		Choices: make([]OpenAITextResponseChoice, 0, len(response.Candidates)),
 	}
 	for i, candidate := range response.Candidates {
@@ -151,7 +158,7 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *OpenAITextResponse
 				Role:    "assistant",
 				Content: candidate.Content.Parts[0].Text,
 			},
-			FinishReason: "stop",
+			FinishReason: stopFinishReason,
 		}
 		fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
 	}
@@ -160,7 +167,7 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *OpenAITextResponse
 
 func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *ChatCompletionsStreamResponse {
 	var choice ChatCompletionsStreamResponseChoice
-	if len(geminiResponse.Candidates) > 0 {
+	if len(geminiResponse.Candidates) > 0 && len(geminiResponse.Candidates[0].Content.Parts) > 0 {
 		choice.Delta.Content = geminiResponse.Candidates[0].Content.Parts[0].Text
 	}
 	choice.FinishReason = &stopFinishReason
@@ -200,7 +207,9 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorW
 			fullTextResponse := streamResponseGeminiChat2OpenAI(&geminiResponse)
 			fullTextResponse.Id = responseId
 			fullTextResponse.Created = createdTime
-			responseText = geminiResponse.Candidates[0].Content.Parts[0].Text
+			if len(geminiResponse.Candidates) > 0 && len(geminiResponse.Candidates[0].Content.Parts) > 0 {
+				responseText += geminiResponse.Candidates[0].Content.Parts[0].Text
+			}
 			jsonResponse, err := json.Marshal(fullTextResponse)
 			if err != nil {
 				common.SysError("error marshalling stream response: " + err.Error())
diff --git a/controller/relay-text.go b/controller/relay-text.go
index 662d05ec..211a34b3 100644
--- a/controller/relay-text.go
+++ b/controller/relay-text.go
@@ -27,7 +27,7 @@ const (
 	APITypeXunfei
 	APITypeAIProxyLibrary
 	APITypeTencent
-	APITypeGeminiChat
+	APITypeGemini
 )
 
 var httpClient *http.Client
@@ -119,8 +119,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 		apiType = APITypeAIProxyLibrary
 	case common.ChannelTypeTencent:
 		apiType = APITypeTencent
-	case common.ChannelTypeGeminiChat:
-		apiType = APITypeGeminiChat
+	case common.ChannelTypeGemini:
+		apiType = APITypeGemini
 	}
 	baseURL := common.ChannelBaseURLs[channelType]
 	requestURL := c.Request.URL.String()
@@ -180,11 +180,21 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 		apiKey := c.Request.Header.Get("Authorization")
 		apiKey = strings.TrimPrefix(apiKey, "Bearer ")
 		fullRequestURL += "?key=" + apiKey
-	case APITypeGeminiChat:
-		switch textRequest.Model {
-		case "gemini-pro":
-			fullRequestURL = "https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent"
+	case APITypeGemini:
+		requestBaseURL := "https://generativelanguage.googleapis.com"
+		if baseURL != "" {
+			requestBaseURL = baseURL
 		}
+		version := "v1"
+		if c.GetString("api_version") != "" {
+			version = c.GetString("api_version")
+		}
+		action := "generateContent"
+		// actually gemini does not support stream, it's fake
+		//if textRequest.Stream {
+		//	action = "streamGenerateContent"
+		//}
+		fullRequestURL = fmt.Sprintf("%s/%s/models/%s:%s", requestBaseURL, version, textRequest.Model, action)
 		apiKey := c.Request.Header.Get("Authorization")
 		apiKey = strings.TrimPrefix(apiKey, "Bearer ")
 		fullRequestURL += "?key=" + apiKey
@@ -285,8 +295,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
 		}
 		requestBody = bytes.NewBuffer(jsonStr)
-	case APITypeGeminiChat:
-		geminiChatRequest := requestOpenAI2GeminiChat(textRequest)
+	case APITypeGemini:
+		geminiChatRequest := requestOpenAI2Gemini(textRequest)
 		jsonStr, err := json.Marshal(geminiChatRequest)
 		if err != nil {
 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
@@ -385,7 +395,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 		req.Header.Set("Authorization", apiKey)
 	case APITypePaLM:
 		// do not set Authorization header
-	case APITypeGeminiChat:
+	case APITypeGemini:
 		// do not set Authorization header
 	default:
 		req.Header.Set("Authorization", "Bearer "+apiKey)
@@ -547,8 +557,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 			}
 			return nil
 		}
-	case APITypeGeminiChat:
-		if textRequest.Stream { // Gemini API does not support stream
+	case APITypeGemini:
+		if textRequest.Stream {
 			err, responseText := geminiChatStreamHandler(c, resp)
 			if err != nil {
 				return err
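The new URL construction in relay-text.go is easy to check in isolation. A minimal sketch of the same `fmt.Sprintf` composition with representative values, assuming a channel with no custom base URL and no `api_version` override:

```go
package main

import "fmt"

func main() {
	// Mirrors the new APITypeGemini branch: base URL, version, model, action.
	requestBaseURL := "https://generativelanguage.googleapis.com" // default when the channel has no base URL
	version := "v1"                                               // replaced by c.GetString("api_version") when set
	model := "gemini-pro"
	action := "generateContent" // streaming is faked, so streamGenerateContent stays commented out
	fullRequestURL := fmt.Sprintf("%s/%s/models/%s:%s", requestBaseURL, version, model, action)
	fmt.Println(fullRequestURL)
	// https://generativelanguage.googleapis.com/v1/models/gemini-pro:generateContent
}
```

Compared with the old hard-coded `v1beta`/`gemini-pro` string, any model name now flows into the path and the API version is configurable per channel.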
c.Set("api_version", channel.Other) case common.ChannelTypeXunfei: c.Set("api_version", channel.Other) + case common.ChannelTypeGemini: + c.Set("api_version", channel.Other) case common.ChannelTypeAIProxyLibrary: c.Set("library_id", channel.Other) case common.ChannelTypeAli: