From 6395c7377eee09b9918b4ad717064d390b244f44 Mon Sep 17 00:00:00 2001
From: David Zhuang
Date: Sat, 16 Dec 2023 01:33:46 -0500
Subject: [PATCH] feat: Add Google Gemini Pro, fix #810

---
 README.en.md                           |   2 +-
 README.ja.md                           |   2 +-
 README.md                              |   2 +-
 common/constants.go                    |   1 +
 common/model-ratio.go                  |   1 +
 controller/channel-test.go             |   2 +
 controller/model.go                    |   9 +
 controller/relay-gemini-chat.go        | 260 ++++++++++++++++++++++++++
 controller/relay-text.go               |  43 +++++
 web/src/constants/channel.constants.js |   1 +
 web/src/pages/Channel/EditChannel.js   |   3 +
 11 files changed, 323 insertions(+), 3 deletions(-)
 create mode 100644 controller/relay-gemini-chat.go

diff --git a/README.en.md b/README.en.md
index 9345a219..82dceb5b 100644
--- a/README.en.md
+++ b/README.en.md
@@ -60,7 +60,7 @@ _✨ Access all LLM through the standard OpenAI API format, easy to deploy & use
 1. Support for multiple large models:
    + [x] [OpenAI ChatGPT Series Models](https://platform.openai.com/docs/guides/gpt/chat-completions-api) (Supports [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference))
    + [x] [Anthropic Claude Series Models](https://anthropic.com)
-   + [x] [Google PaLM2 Series Models](https://developers.generativeai.google)
+   + [x] [Google PaLM2 and Gemini Series Models](https://developers.generativeai.google)
    + [x] [Baidu Wenxin Yiyuan Series Models](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html)
    + [x] [Alibaba Tongyi Qianwen Series Models](https://help.aliyun.com/document_detail/2400395.html)
    + [x] [Zhipu ChatGLM Series Models](https://bigmodel.cn)
diff --git a/README.ja.md b/README.ja.md
index 6faf9bee..089fc2b5 100644
--- a/README.ja.md
+++ b/README.ja.md
@@ -60,7 +60,7 @@ _✨ 標準的な OpenAI API フォーマットを通じてすべての LLM に
 1. 複数の大型モデルをサポート:
    + [x] [OpenAI ChatGPT シリーズモデル](https://platform.openai.com/docs/guides/gpt/chat-completions-api) ([Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference) をサポート)
    + [x] [Anthropic Claude シリーズモデル](https://anthropic.com)
-   + [x] [Google PaLM2 シリーズモデル](https://developers.generativeai.google)
+   + [x] [Google PaLM2/Gemini シリーズモデル](https://developers.generativeai.google)
    + [x] [Baidu Wenxin Yiyuan シリーズモデル](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html)
    + [x] [Alibaba Tongyi Qianwen シリーズモデル](https://help.aliyun.com/document_detail/2400395.html)
    + [x] [Zhipu ChatGLM シリーズモデル](https://bigmodel.cn)
diff --git a/README.md b/README.md
index 7e6a7b38..8a1d6caf 100644
--- a/README.md
+++ b/README.md
@@ -66,7 +66,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
 1. 支持多种大模型:
    + [x] [OpenAI ChatGPT 系列模型](https://platform.openai.com/docs/guides/gpt/chat-completions-api)(支持 [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference))
    + [x] [Anthropic Claude 系列模型](https://anthropic.com)
-   + [x] [Google PaLM2 系列模型](https://developers.generativeai.google)
+   + [x] [Google PaLM2/Gemini 系列模型](https://developers.generativeai.google)
    + [x] [百度文心一言系列模型](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html)
    + [x] [阿里通义千问系列模型](https://help.aliyun.com/document_detail/2400395.html)
    + [x] [讯飞星火认知大模型](https://www.xfyun.cn/doc/spark/Web.html)
diff --git a/common/constants.go b/common/constants.go
index f6860f67..ac8194de 100644
--- a/common/constants.go
+++ b/common/constants.go
@@ -187,6 +187,7 @@ const (
 	ChannelTypeAIProxyLibrary = 21
 	ChannelTypeFastGPT        = 22
 	ChannelTypeTencent        = 23
+	ChannelTypeGeminiChat     = 24
 )
 
 var ChannelBaseURLs = []string{
diff --git a/common/model-ratio.go b/common/model-ratio.go
index ccbc05dd..a9d5bd51 100644
--- a/common/model-ratio.go
+++ b/common/model-ratio.go
@@ -83,6 +83,7 @@ var ModelRatio = map[string]float64{
 	"ERNIE-Bot-4":     8.572,  // ¥0.12 / 1k tokens
 	"Embedding-V1":    0.1429, // ¥0.002 / 1k tokens
 	"PaLM-2":          1,
+	"gemini-pro":      1,
 	"chatglm_turbo":   0.3572, // ¥0.005 / 1k tokens
 	"chatglm_pro":     0.7143, // ¥0.01 / 1k tokens
 	"chatglm_std":     0.3572, // ¥0.005 / 1k tokens
diff --git a/controller/channel-test.go b/controller/channel-test.go
index bba9a657..a9c1a1b4 100644
--- a/controller/channel-test.go
+++ b/controller/channel-test.go
@@ -20,6 +20,8 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai
 	switch channel.Type {
 	case common.ChannelTypePaLM:
 		fallthrough
+	case common.ChannelTypeGeminiChat:
+		fallthrough
 	case common.ChannelTypeAnthropic:
 		fallthrough
 	case common.ChannelTypeBaidu:
diff --git a/controller/model.go b/controller/model.go
index 8f79524d..5c8aebc0 100644
--- a/controller/model.go
+++ b/controller/model.go
@@ -423,6 +423,15 @@ func init() {
 		Root:       "PaLM-2",
 		Parent:     nil,
 	},
+	{
+		Id:         "gemini-pro",
+		Object:     "model",
+		Created:    1677649963,
+		OwnedBy:    "google",
+		Permission: permission,
+		Root:       "gemini-pro",
+		Parent:     nil,
+	},
 	{
 		Id:         "chatglm_turbo",
 		Object:     "model",
diff --git a/controller/relay-gemini-chat.go b/controller/relay-gemini-chat.go
new file mode 100644
index 00000000..c1a9e63d
--- /dev/null
+++ b/controller/relay-gemini-chat.go
@@ -0,0 +1,260 @@
+package controller
+
+import (
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+	"one-api/common"
+
+	"github.com/gin-gonic/gin"
+)
+
+type GeminiChatRequest struct {
+	Contents         []GeminiChatContents       `json:"contents"`
+	SafetySettings   []GeminiChatSafetySettings `json:"safety_settings"`
+	GenerationConfig GeminiChatGenerationConfig `json:"generation_config"`
+}
+type GeminiChatParts struct {
+	Text string `json:"text"`
+}
+type GeminiChatContents struct {
+	Role  string          `json:"role"`
+	Parts GeminiChatParts `json:"parts"`
+}
+type GeminiChatSafetySettings struct {
+	Category  string `json:"category"`
+	Threshold string `json:"threshold"`
+}
+type GeminiChatGenerationConfig struct {
+	Temperature     float64 `json:"temperature"`
+	TopP            float64 `json:"topP"`
+	TopK            int     `json:"topK,omitempty"`
+	MaxOutputTokens int     `json:"maxOutputTokens"`
+}
+
+// Relax all four safety categories to BLOCK_ONLY_HIGH so that relayed
+// prompts are not rejected by Gemini's stricter default thresholds.
+func requestOpenAI2GeminiChat(textRequest GeneralOpenAIRequest) *GeminiChatRequest {
+	geminiRequest := GeminiChatRequest{
+		Contents: make([]GeminiChatContents, 0, len(textRequest.Messages)),
+		SafetySettings: []GeminiChatSafetySettings{
+			{
+				Category:  "HARM_CATEGORY_HARASSMENT",
+				Threshold: "BLOCK_ONLY_HIGH",
+			},
+			{
+				Category:  "HARM_CATEGORY_HATE_SPEECH",
+				Threshold: "BLOCK_ONLY_HIGH",
+			},
+			{
+				Category:  "HARM_CATEGORY_SEXUALLY_EXPLICIT",
+				Threshold: "BLOCK_ONLY_HIGH",
+			},
+			{
+				Category:  "HARM_CATEGORY_DANGEROUS_CONTENT",
+				Threshold: "BLOCK_ONLY_HIGH",
+			},
+		},
+		GenerationConfig: GeminiChatGenerationConfig{
+			Temperature: textRequest.Temperature,
+			TopP:        textRequest.TopP,
+			// The OpenAI-style request carries no top-k parameter, so topK is
+			// left unset (omitempty) instead of reusing max_tokens for it.
+			MaxOutputTokens: textRequest.MaxTokens,
+		},
+	}
+	for _, message := range textRequest.Messages {
+		// Gemini only accepts the roles "user" and "model"; map the OpenAI
+		// roles accordingly and relay system prompts as user turns.
+		role := message.Role
+		switch role {
+		case "assistant":
+			role = "model"
+		case "system":
+			role = "user"
+		}
+		text := message.StringContent()
+		// Gemini requires contents to alternate between roles; merge
+		// consecutive turns that end up with the same role.
+		if n := len(geminiRequest.Contents); n > 0 && geminiRequest.Contents[n-1].Role == role {
+			geminiRequest.Contents[n-1].Parts.Text += "\n" + text
+			continue
+		}
+		geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContents{
+			Role: role,
+			Parts: GeminiChatParts{
+				Text: text,
+			},
+		})
+	}
+	return &geminiRequest
+}
+
+type GeminiChatResponse struct {
+	Candidates     []GeminiChatCandidate    `json:"candidates"`
+	PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"`
+}
+
+type GeminiChatCandidate struct {
+	Content       GeminiChatContent        `json:"content"`
+	FinishReason  string                   `json:"finishReason"`
+	Index         int64                    `json:"index"`
+	SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
+}
+
+type GeminiChatContent struct {
+	Parts []GeminiChatPart `json:"parts"`
+	Role  string           `json:"role"`
+}
+
+type GeminiChatPart struct {
+	Text string `json:"text"`
+}
+
+type GeminiChatSafetyRating struct {
+	Category    string `json:"category"`
+	Probability string `json:"probability"`
+}
+
+type GeminiChatPromptFeedback struct {
+	SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
+}
+
+func responseGeminiChat2OpenAI(response *GeminiChatResponse) *OpenAITextResponse {
+	fullTextResponse := OpenAITextResponse{
+		Choices: make([]OpenAITextResponseChoice, 0, len(response.Candidates)),
+	}
+	for i, candidate := range response.Candidates {
+		// A candidate blocked by safety filters may carry no parts at all.
+		text := ""
+		if len(candidate.Content.Parts) > 0 {
+			text = candidate.Content.Parts[0].Text
+		}
+		choice := OpenAITextResponseChoice{
+			Index: i,
+			Message: Message{
+				Role:    "assistant",
+				Content: text,
+			},
+			FinishReason: "stop",
+		}
+		fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
+	}
+	return &fullTextResponse
+}
+
+func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *ChatCompletionsStreamResponse {
+	var choice ChatCompletionsStreamResponseChoice
+	if len(geminiResponse.Candidates) > 0 && len(geminiResponse.Candidates[0].Content.Parts) > 0 {
+		choice.Delta.Content = geminiResponse.Candidates[0].Content.Parts[0].Text
+	}
+	choice.FinishReason = &stopFinishReason
+	var response ChatCompletionsStreamResponse
+	response.Object = "chat.completion.chunk"
+	response.Model = "gemini"
+	response.Choices = []ChatCompletionsStreamResponseChoice{choice}
+	return &response
+}
+
+func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
+	responseText := ""
+	responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
+	createdTime := common.GetTimestamp()
+	dataChan := make(chan string)
+	stopChan := make(chan bool)
+	go func() {
+		// The body is read exactly once and closed here, on every path.
+		defer resp.Body.Close()
+		responseBody, err := io.ReadAll(resp.Body)
+		if err != nil {
+			common.SysError("error reading stream response: " + err.Error())
+			stopChan <- true
+			return
+		}
+		var geminiResponse GeminiChatResponse
+		err = json.Unmarshal(responseBody, &geminiResponse)
+		if err != nil {
+			common.SysError("error unmarshalling stream response: " + err.Error())
+			stopChan <- true
+			return
+		}
+		fullTextResponse := streamResponseGeminiChat2OpenAI(&geminiResponse)
+		fullTextResponse.Id = responseId
+		fullTextResponse.Created = createdTime
+		if len(geminiResponse.Candidates) > 0 && len(geminiResponse.Candidates[0].Content.Parts) > 0 {
+			responseText = geminiResponse.Candidates[0].Content.Parts[0].Text
+		}
+		jsonResponse, err := json.Marshal(fullTextResponse)
+		if err != nil {
+			common.SysError("error marshalling stream response: " + err.Error())
+			stopChan <- true
+			return
+		}
+		dataChan <- string(jsonResponse)
+		stopChan <- true
+	}()
+	setEventStreamHeaders(c)
+	c.Stream(func(w io.Writer) bool {
+		select {
+		case data := <-dataChan:
+			c.Render(-1, common.CustomEvent{Data: "data: " + data})
+			return true
+		case <-stopChan:
+			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
+			return false
+		}
+	})
+	// The response body is closed by the reader goroutine above; closing it
+	// again here would be a double close.
+	return nil, responseText
+}
+
+func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+	var geminiResponse GeminiChatResponse
+	err = json.Unmarshal(responseBody, &geminiResponse)
+	if err != nil {
+		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	if len(geminiResponse.Candidates) == 0 {
+		return &OpenAIErrorWithStatusCode{
+			OpenAIError: OpenAIError{
+				Message: "No candidates returned",
+				Type:    "server_error",
+				Param:   "",
+				Code:    500,
+			},
+			StatusCode: resp.StatusCode,
+		}, nil
+	}
+	fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
+	// Guard the parts access so an empty (safety-blocked) candidate cannot
+	// cause an index-out-of-range panic when counting completion tokens.
+	completionText := ""
+	if len(geminiResponse.Candidates[0].Content.Parts) > 0 {
+		completionText = geminiResponse.Candidates[0].Content.Parts[0].Text
+	}
+	completionTokens := countTokenText(completionText, model)
+	usage := Usage{
+		PromptTokens:     promptTokens,
+		CompletionTokens: completionTokens,
+		TotalTokens:      promptTokens + completionTokens,
+	}
+	fullTextResponse.Usage = usage
+	jsonResponse, err := json.Marshal(fullTextResponse)
+	if err != nil {
+		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	c.Writer.Header().Set("Content-Type", "application/json")
+	c.Writer.WriteHeader(resp.StatusCode)
+	_, err = c.Writer.Write(jsonResponse)
+	return nil, &usage
+}
diff --git a/controller/relay-text.go b/controller/relay-text.go
index a69c7f8b..e14e18b8 100644
--- a/controller/relay-text.go
+++ b/controller/relay-text.go
@@ -27,6 +27,7 @@ const (
 	APITypeXunfei
 	APITypeAIProxyLibrary
 	APITypeTencent
+	APITypeGeminiChat
 )
 
 var httpClient *http.Client
@@ -177,6 +178,18 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 		apiKey := c.Request.Header.Get("Authorization")
 		apiKey = strings.TrimPrefix(apiKey, "Bearer ")
 		fullRequestURL += "?key=" + apiKey
+	case APITypeGeminiChat:
+		// Only gemini-pro is routed for now; add further model-to-endpoint
+		// mappings here as more Gemini models become available.
+		requestURLSuffix := "/v1beta/models/gemini-pro:generateContent"
+		if baseURL != "" {
+			fullRequestURL = fmt.Sprintf("%s%s", baseURL, requestURLSuffix)
+		} else {
+			fullRequestURL = fmt.Sprintf("https://generativelanguage.googleapis.com%s", requestURLSuffix)
+		}
+		apiKey := c.Request.Header.Get("Authorization")
+		apiKey = strings.TrimPrefix(apiKey, "Bearer ")
+		fullRequestURL += "?key=" + apiKey
 	case APITypeZhipu:
 		method := "invoke"
 		if textRequest.Stream {
@@ -274,6 +287,13 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
 		}
 		requestBody = bytes.NewBuffer(jsonStr)
+	case APITypeGeminiChat:
+		geminiChatRequest := requestOpenAI2GeminiChat(textRequest)
+		jsonStr, err := json.Marshal(geminiChatRequest)
+		if err != nil {
+			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
+		}
+		requestBody = bytes.NewBuffer(jsonStr)
 	case APITypeZhipu:
 		zhipuRequest := requestOpenAI2Zhipu(textRequest)
 		jsonStr, err := json.Marshal(zhipuRequest)
@@ -367,6 +387,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 		req.Header.Set("Authorization", apiKey)
 	case APITypePaLM:
 		// do not set Authorization header
+	case APITypeGeminiChat:
+		// do not set Authorization header; the key is passed as a query parameter
 	default:
 		req.Header.Set("Authorization", "Bearer "+apiKey)
 	}
@@ -527,6 +549,27 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 			}
 			return nil
 		}
+	case APITypeGeminiChat:
+		if textRequest.Stream {
+			// Pseudo-streaming: the non-streaming generateContent response is
+			// relayed to the client as a single OpenAI-style SSE chunk.
+			err, responseText := geminiChatStreamHandler(c, resp)
+			if err != nil {
+				return err
+			}
+			textResponse.Usage.PromptTokens = promptTokens
+			textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
+			return nil
+		} else {
+			err, usage := geminiChatHandler(c, resp, promptTokens, textRequest.Model)
+			if err != nil {
+				return err
+			}
+			if usage != nil {
+				textResponse.Usage = *usage
+			}
+			return nil
+		}
 	case APITypeZhipu:
 		if isStream {
 			err, usage := zhipuStreamHandler(c, resp)
diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js
index 76407745..264dbefb 100644
--- a/web/src/constants/channel.constants.js
+++ b/web/src/constants/channel.constants.js
@@ -3,6 +3,7 @@ export const CHANNEL_OPTIONS = [
   { key: 14, text: 'Anthropic Claude', value: 14, color: 'black' },
   { key: 3, text: 'Azure OpenAI', value: 3, color: 'olive' },
   { key: 11, text: 'Google PaLM2', value: 11, color: 'orange' },
+  { key: 24, text: 'Google Gemini', value: 24, color: 'orange' },
   { key: 15, text: '百度文心千帆', value: 15, color: 'blue' },
   { key: 17, text: '阿里通义千问', value: 17, color: 'orange' },
   { key: 18, text: '讯飞星火认知', value: 18, color: 'blue' },
diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js
index 62e8a155..114e5933 100644
--- a/web/src/pages/Channel/EditChannel.js
+++ b/web/src/pages/Channel/EditChannel.js
@@ -83,6 +83,9 @@ const EditChannel = () => {
       case 23:
         localModels = ['hunyuan'];
         break;
+      case 24:
+        localModels = ['gemini-pro'];
+        break;
     }
     setInputs((inputs) => ({ ...inputs, models: localModels }));
   }
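
Reviewer note (not part of the patch): the standalone sketch below mirrors the
request translation performed by requestOpenAI2GeminiChat in
controller/relay-gemini-chat.go and prints the JSON that is POSTed to the
generateContent endpoint. The gemini* types are trimmed copies of the patch's
structs, and the sample conversation is made up for illustration; run it with
go run to inspect the wire format.

    package main

    import (
    	"encoding/json"
    	"fmt"
    )

    // Trimmed copies of the request structs added by this patch.
    type geminiPart struct {
    	Text string `json:"text"`
    }
    type geminiContent struct {
    	Role  string     `json:"role"`
    	Parts geminiPart `json:"parts"`
    }
    type geminiRequest struct {
    	Contents []geminiContent `json:"contents"`
    }

    func main() {
    	// OpenAI-style messages as they arrive at /v1/chat/completions.
    	messages := []struct{ Role, Content string }{
    		{"system", "You are a terse assistant."},
    		{"user", "Hello"},
    		{"assistant", "Hi."},
    		{"user", "What is Gemini?"},
    	}
    	var req geminiRequest
    	for _, m := range messages {
    		role := m.Role
    		switch role {
    		case "assistant":
    			role = "model" // Gemini's name for the assistant role
    		case "system":
    			role = "user" // Gemini has no system role
    		}
    		// Merge consecutive same-role turns so contents alternate.
    		if n := len(req.Contents); n > 0 && req.Contents[n-1].Role == role {
    			req.Contents[n-1].Parts.Text += "\n" + m.Content
    			continue
    		}
    		req.Contents = append(req.Contents, geminiContent{Role: role, Parts: geminiPart{Text: m.Content}})
    	}
    	out, _ := json.MarshalIndent(req, "", "  ")
    	fmt.Println(string(out))
    }

The real request additionally carries the safety_settings and
generation_config blocks shown in the patch, and the channel's key travels as
the key query parameter of the request URL rather than as an Authorization
header.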
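
A sketch of a unit test for the role mapping and turn merging, assuming it is
dropped in alongside the patch as controller/relay-gemini-chat_test.go and
that Message.Content accepts a plain string literal (the patch's use of
message.StringContent() suggests a flexible content type; adjust the literals
if it does not):

    package controller

    import "testing"

    func TestRequestOpenAI2GeminiChatRoles(t *testing.T) {
    	req := GeneralOpenAIRequest{
    		Messages: []Message{
    			{Role: "system", Content: "be brief"},
    			{Role: "user", Content: "hi"},
    			{Role: "assistant", Content: "hello"},
    		},
    	}
    	geminiReq := requestOpenAI2GeminiChat(req)
    	// The system and user turns collapse into one user entry, and the
    	// assistant turn becomes a model entry.
    	wantRoles := []string{"user", "model"}
    	if len(geminiReq.Contents) != len(wantRoles) {
    		t.Fatalf("got %d contents, want %d", len(geminiReq.Contents), len(wantRoles))
    	}
    	for i, content := range geminiReq.Contents {
    		if content.Role != wantRoles[i] {
    			t.Errorf("contents[%d]: got role %q, want %q", i, content.Role, wantRoles[i])
    		}
    	}
    	if got := geminiReq.Contents[0].Parts.Text; got != "be brief\nhi" {
    		t.Errorf("merged user turn: got %q", got)
    	}
    }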