From 5cf23d8698cfebd720a6722e4b89a839c8a804c0 Mon Sep 17 00:00:00 2001 From: David Zhuang Date: Sat, 16 Dec 2023 23:48:32 -0500 Subject: [PATCH 01/11] feat: add Google Gemini Pro support (#826) * fest: Add Google Gemini Pro, fix #810 * fest: Add tooling to Gemini; Add OpenAI-like system prompt to Gemini * refactor: removing unused if statement * fest: Add dummy model message for system message in gemini model * chore: update implementation --------- Co-authored-by: JustSong --- README.en.md | 2 +- README.ja.md | 2 +- README.md | 2 +- common/constants.go | 2 + common/model-ratio.go | 1 + controller/channel-test.go | 2 + controller/model.go | 9 + controller/relay-gemini.go | 281 +++++++++++++++++++++++++ controller/relay-text.go | 49 +++++ middleware/distributor.go | 2 + web/src/constants/channel.constants.js | 1 + web/src/pages/Channel/EditChannel.js | 3 + 12 files changed, 353 insertions(+), 3 deletions(-) create mode 100644 controller/relay-gemini.go diff --git a/README.en.md b/README.en.md index 9345a219..82dceb5b 100644 --- a/README.en.md +++ b/README.en.md @@ -60,7 +60,7 @@ _✨ Access all LLM through the standard OpenAI API format, easy to deploy & use 1. Support for multiple large models: + [x] [OpenAI ChatGPT Series Models](https://platform.openai.com/docs/guides/gpt/chat-completions-api) (Supports [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference)) + [x] [Anthropic Claude Series Models](https://anthropic.com) - + [x] [Google PaLM2 Series Models](https://developers.generativeai.google) + + [x] [Google PaLM2 and Gemini Series Models](https://developers.generativeai.google) + [x] [Baidu Wenxin Yiyuan Series Models](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html) + [x] [Alibaba Tongyi Qianwen Series Models](https://help.aliyun.com/document_detail/2400395.html) + [x] [Zhipu ChatGLM Series Models](https://bigmodel.cn) diff --git a/README.ja.md b/README.ja.md index 6faf9bee..089fc2b5 100644 --- a/README.ja.md +++ b/README.ja.md @@ -60,7 +60,7 @@ _✨ 標準的な OpenAI API フォーマットを通じてすべての LLM に 1. 複数の大型モデルをサポート: + [x] [OpenAI ChatGPT シリーズモデル](https://platform.openai.com/docs/guides/gpt/chat-completions-api) ([Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference) をサポート) + [x] [Anthropic Claude シリーズモデル](https://anthropic.com) - + [x] [Google PaLM2 シリーズモデル](https://developers.generativeai.google) + + [x] [Google PaLM2/Gemini シリーズモデル](https://developers.generativeai.google) + [x] [Baidu Wenxin Yiyuan シリーズモデル](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html) + [x] [Alibaba Tongyi Qianwen シリーズモデル](https://help.aliyun.com/document_detail/2400395.html) + [x] [Zhipu ChatGLM シリーズモデル](https://bigmodel.cn) diff --git a/README.md b/README.md index 7e6a7b38..8a1d6caf 100644 --- a/README.md +++ b/README.md @@ -66,7 +66,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 1. 支持多种大模型: + [x] [OpenAI ChatGPT 系列模型](https://platform.openai.com/docs/guides/gpt/chat-completions-api)(支持 [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference)) + [x] [Anthropic Claude 系列模型](https://anthropic.com) - + [x] [Google PaLM2 系列模型](https://developers.generativeai.google) + + [x] [Google PaLM2/Gemini 系列模型](https://developers.generativeai.google) + [x] [百度文心一言系列模型](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html) + [x] [阿里通义千问系列模型](https://help.aliyun.com/document_detail/2400395.html) + [x] [讯飞星火认知大模型](https://www.xfyun.cn/doc/spark/Web.html) diff --git a/common/constants.go b/common/constants.go index f6860f67..60700ec8 100644 --- a/common/constants.go +++ b/common/constants.go @@ -187,6 +187,7 @@ const ( ChannelTypeAIProxyLibrary = 21 ChannelTypeFastGPT = 22 ChannelTypeTencent = 23 + ChannelTypeGemini = 24 ) var ChannelBaseURLs = []string{ @@ -214,4 +215,5 @@ var ChannelBaseURLs = []string{ "https://api.aiproxy.io", // 21 "https://fastgpt.run/api/openapi", // 22 "https://hunyuan.cloud.tencent.com", //23 + "", //24 } diff --git a/common/model-ratio.go b/common/model-ratio.go index ccbc05dd..c054fa5f 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -83,6 +83,7 @@ var ModelRatio = map[string]float64{ "ERNIE-Bot-4": 8.572, // ¥0.12 / 1k tokens "Embedding-V1": 0.1429, // ¥0.002 / 1k tokens "PaLM-2": 1, + "gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens "chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens "chatglm_pro": 0.7143, // ¥0.01 / 1k tokens "chatglm_std": 0.3572, // ¥0.005 / 1k tokens diff --git a/controller/channel-test.go b/controller/channel-test.go index bba9a657..3aaa4897 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -20,6 +20,8 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai switch channel.Type { case common.ChannelTypePaLM: fallthrough + case common.ChannelTypeGemini: + fallthrough case common.ChannelTypeAnthropic: fallthrough case common.ChannelTypeBaidu: diff --git a/controller/model.go b/controller/model.go index 8f79524d..5c8aebc0 100644 --- a/controller/model.go +++ b/controller/model.go @@ -423,6 +423,15 @@ func init() { Root: "PaLM-2", Parent: nil, }, + { + Id: "gemini-pro", + Object: "model", + Created: 1677649963, + OwnedBy: "google", + Permission: permission, + Root: "gemini-pro", + Parent: nil, + }, { Id: "chatglm_turbo", Object: "model", diff --git a/controller/relay-gemini.go b/controller/relay-gemini.go new file mode 100644 index 00000000..455e30d8 --- /dev/null +++ b/controller/relay-gemini.go @@ -0,0 +1,281 @@ +package controller + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "one-api/common" + + "github.com/gin-gonic/gin" +) + +type GeminiChatRequest struct { + Contents []GeminiChatContent `json:"contents"` + SafetySettings []GeminiChatSafetySettings `json:"safety_settings,omitempty"` + GenerationConfig GeminiChatGenerationConfig `json:"generation_config,omitempty"` + Tools []GeminiChatTools `json:"tools,omitempty"` +} + +type GeminiInlineData struct { + MimeType string `json:"mimeType"` + Data string `json:"data"` +} + +type GeminiPart struct { + Text string `json:"text,omitempty"` + InlineData *GeminiInlineData `json:"inlineData,omitempty"` +} + +type GeminiChatContent struct { + Role string `json:"role,omitempty"` + Parts []GeminiPart `json:"parts"` +} + +type GeminiChatSafetySettings struct { + Category string `json:"category"` + Threshold string `json:"threshold"` +} + +type GeminiChatTools struct { + FunctionDeclarations any `json:"functionDeclarations,omitempty"` +} + +type GeminiChatGenerationConfig struct { + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"topP,omitempty"` + TopK float64 `json:"topK,omitempty"` + MaxOutputTokens int `json:"maxOutputTokens,omitempty"` + CandidateCount int `json:"candidateCount,omitempty"` + StopSequences []string `json:"stopSequences,omitempty"` +} + +// Setting safety to the lowest possible values since Gemini is already powerless enough +func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest { + geminiRequest := GeminiChatRequest{ + Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)), + //SafetySettings: []GeminiChatSafetySettings{ + // { + // Category: "HARM_CATEGORY_HARASSMENT", + // Threshold: "BLOCK_ONLY_HIGH", + // }, + // { + // Category: "HARM_CATEGORY_HATE_SPEECH", + // Threshold: "BLOCK_ONLY_HIGH", + // }, + // { + // Category: "HARM_CATEGORY_SEXUALLY_EXPLICIT", + // Threshold: "BLOCK_ONLY_HIGH", + // }, + // { + // Category: "HARM_CATEGORY_DANGEROUS_CONTENT", + // Threshold: "BLOCK_ONLY_HIGH", + // }, + //}, + GenerationConfig: GeminiChatGenerationConfig{ + Temperature: textRequest.Temperature, + TopP: textRequest.TopP, + MaxOutputTokens: textRequest.MaxTokens, + }, + } + if textRequest.Functions != nil { + geminiRequest.Tools = []GeminiChatTools{ + { + FunctionDeclarations: textRequest.Functions, + }, + } + } + shouldAddDummyModelMessage := false + for _, message := range textRequest.Messages { + content := GeminiChatContent{ + Role: message.Role, + Parts: []GeminiPart{ + { + Text: message.StringContent(), + }, + }, + } + // there's no assistant role in gemini and API shall vomit if Role is not user or model + if content.Role == "assistant" { + content.Role = "model" + } + // Converting system prompt to prompt from user for the same reason + if content.Role == "system" { + content.Role = "user" + shouldAddDummyModelMessage = true + } + geminiRequest.Contents = append(geminiRequest.Contents, content) + + // If a system message is the last message, we need to add a dummy model message to make gemini happy + if shouldAddDummyModelMessage { + geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{ + Role: "model", + Parts: []GeminiPart{ + { + Text: "ok", + }, + }, + }) + shouldAddDummyModelMessage = false + } + } + + return &geminiRequest +} + +type GeminiChatResponse struct { + Candidates []GeminiChatCandidate `json:"candidates"` + PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"` +} + +type GeminiChatCandidate struct { + Content GeminiChatContent `json:"content"` + FinishReason string `json:"finishReason"` + Index int64 `json:"index"` + SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"` +} + +type GeminiChatSafetyRating struct { + Category string `json:"category"` + Probability string `json:"probability"` +} + +type GeminiChatPromptFeedback struct { + SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"` +} + +func responseGeminiChat2OpenAI(response *GeminiChatResponse) *OpenAITextResponse { + fullTextResponse := OpenAITextResponse{ + Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), + Object: "chat.completion", + Created: common.GetTimestamp(), + Choices: make([]OpenAITextResponseChoice, 0, len(response.Candidates)), + } + for i, candidate := range response.Candidates { + choice := OpenAITextResponseChoice{ + Index: i, + Message: Message{ + Role: "assistant", + Content: candidate.Content.Parts[0].Text, + }, + FinishReason: stopFinishReason, + } + fullTextResponse.Choices = append(fullTextResponse.Choices, choice) + } + return &fullTextResponse +} + +func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *ChatCompletionsStreamResponse { + var choice ChatCompletionsStreamResponseChoice + if len(geminiResponse.Candidates) > 0 && len(geminiResponse.Candidates[0].Content.Parts) > 0 { + choice.Delta.Content = geminiResponse.Candidates[0].Content.Parts[0].Text + } + choice.FinishReason = &stopFinishReason + var response ChatCompletionsStreamResponse + response.Object = "chat.completion.chunk" + response.Model = "gemini" + response.Choices = []ChatCompletionsStreamResponseChoice{choice} + return &response +} + +func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) { + responseText := "" + responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) + createdTime := common.GetTimestamp() + dataChan := make(chan string) + stopChan := make(chan bool) + go func() { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + common.SysError("error reading stream response: " + err.Error()) + stopChan <- true + return + } + err = resp.Body.Close() + if err != nil { + common.SysError("error closing stream response: " + err.Error()) + stopChan <- true + return + } + var geminiResponse GeminiChatResponse + err = json.Unmarshal(responseBody, &geminiResponse) + if err != nil { + common.SysError("error unmarshalling stream response: " + err.Error()) + stopChan <- true + return + } + fullTextResponse := streamResponseGeminiChat2OpenAI(&geminiResponse) + fullTextResponse.Id = responseId + fullTextResponse.Created = createdTime + if len(geminiResponse.Candidates) > 0 && len(geminiResponse.Candidates[0].Content.Parts) > 0 { + responseText += geminiResponse.Candidates[0].Content.Parts[0].Text + } + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + common.SysError("error marshalling stream response: " + err.Error()) + stopChan <- true + return + } + dataChan <- string(jsonResponse) + stopChan <- true + }() + setEventStreamHeaders(c) + c.Stream(func(w io.Writer) bool { + select { + case data := <-dataChan: + c.Render(-1, common.CustomEvent{Data: "data: " + data}) + return true + case <-stopChan: + c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) + return false + } + }) + err := resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" + } + return nil, responseText +} + +func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + var geminiResponse GeminiChatResponse + err = json.Unmarshal(responseBody, &geminiResponse) + if err != nil { + return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + if len(geminiResponse.Candidates) == 0 { + return &OpenAIErrorWithStatusCode{ + OpenAIError: OpenAIError{ + Message: "No candidates returned", + Type: "server_error", + Param: "", + Code: 500, + }, + StatusCode: resp.StatusCode, + }, nil + } + fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse) + completionTokens := countTokenText(geminiResponse.Candidates[0].Content.Parts[0].Text, model) + usage := Usage{ + PromptTokens: promptTokens, + CompletionTokens: completionTokens, + TotalTokens: promptTokens + completionTokens, + } + fullTextResponse.Usage = usage + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, &usage +} diff --git a/controller/relay-text.go b/controller/relay-text.go index a69c7f8b..211a34b3 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -27,6 +27,7 @@ const ( APITypeXunfei APITypeAIProxyLibrary APITypeTencent + APITypeGemini ) var httpClient *http.Client @@ -118,6 +119,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { apiType = APITypeAIProxyLibrary case common.ChannelTypeTencent: apiType = APITypeTencent + case common.ChannelTypeGemini: + apiType = APITypeGemini } baseURL := common.ChannelBaseURLs[channelType] requestURL := c.Request.URL.String() @@ -177,6 +180,24 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { apiKey := c.Request.Header.Get("Authorization") apiKey = strings.TrimPrefix(apiKey, "Bearer ") fullRequestURL += "?key=" + apiKey + case APITypeGemini: + requestBaseURL := "https://generativelanguage.googleapis.com" + if baseURL != "" { + requestBaseURL = baseURL + } + version := "v1" + if c.GetString("api_version") != "" { + version = c.GetString("api_version") + } + action := "generateContent" + // actually gemini does not support stream, it's fake + //if textRequest.Stream { + // action = "streamGenerateContent" + //} + fullRequestURL = fmt.Sprintf("%s/%s/models/%s:%s", requestBaseURL, version, textRequest.Model, action) + apiKey := c.Request.Header.Get("Authorization") + apiKey = strings.TrimPrefix(apiKey, "Bearer ") + fullRequestURL += "?key=" + apiKey case APITypeZhipu: method := "invoke" if textRequest.Stream { @@ -274,6 +295,13 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) } requestBody = bytes.NewBuffer(jsonStr) + case APITypeGemini: + geminiChatRequest := requestOpenAI2Gemini(textRequest) + jsonStr, err := json.Marshal(geminiChatRequest) + if err != nil { + return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) + } + requestBody = bytes.NewBuffer(jsonStr) case APITypeZhipu: zhipuRequest := requestOpenAI2Zhipu(textRequest) jsonStr, err := json.Marshal(zhipuRequest) @@ -367,6 +395,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { req.Header.Set("Authorization", apiKey) case APITypePaLM: // do not set Authorization header + case APITypeGemini: + // do not set Authorization header default: req.Header.Set("Authorization", "Bearer "+apiKey) } @@ -527,6 +557,25 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { } return nil } + case APITypeGemini: + if textRequest.Stream { + err, responseText := geminiChatStreamHandler(c, resp) + if err != nil { + return err + } + textResponse.Usage.PromptTokens = promptTokens + textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) + return nil + } else { + err, usage := geminiChatHandler(c, resp, promptTokens, textRequest.Model) + if err != nil { + return err + } + if usage != nil { + textResponse.Usage = *usage + } + return nil + } case APITypeZhipu: if isStream { err, usage := zhipuStreamHandler(c, resp) diff --git a/middleware/distributor.go b/middleware/distributor.go index 8be986c9..81338130 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -87,6 +87,8 @@ func Distribute() func(c *gin.Context) { c.Set("api_version", channel.Other) case common.ChannelTypeXunfei: c.Set("api_version", channel.Other) + case common.ChannelTypeGemini: + c.Set("api_version", channel.Other) case common.ChannelTypeAIProxyLibrary: c.Set("library_id", channel.Other) case common.ChannelTypeAli: diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js index 76407745..264dbefb 100644 --- a/web/src/constants/channel.constants.js +++ b/web/src/constants/channel.constants.js @@ -3,6 +3,7 @@ export const CHANNEL_OPTIONS = [ { key: 14, text: 'Anthropic Claude', value: 14, color: 'black' }, { key: 3, text: 'Azure OpenAI', value: 3, color: 'olive' }, { key: 11, text: 'Google PaLM2', value: 11, color: 'orange' }, + { key: 24, text: 'Google Gemini', value: 24, color: 'orange' }, { key: 15, text: '百度文心千帆', value: 15, color: 'blue' }, { key: 17, text: '阿里通义千问', value: 17, color: 'orange' }, { key: 18, text: '讯飞星火认知', value: 18, color: 'blue' }, diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index 62e8a155..114e5933 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -83,6 +83,9 @@ const EditChannel = () => { case 23: localModels = ['hunyuan']; break; + case 24: + localModels = ['gemini-pro']; + break; } setInputs((inputs) => ({ ...inputs, models: localModels })); } From 58dee76bf7828718021f924dd741c3754515cd7c Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 17 Dec 2023 16:16:18 +0800 Subject: [PATCH 02/11] fix: fix Gemini stream problem --- controller/relay-gemini.go | 81 ++++++++++++++++++++++---------------- controller/relay-text.go | 7 ++-- controller/relay.go | 2 +- 3 files changed, 51 insertions(+), 39 deletions(-) diff --git a/controller/relay-gemini.go b/controller/relay-gemini.go index 455e30d8..4c2daba9 100644 --- a/controller/relay-gemini.go +++ b/controller/relay-gemini.go @@ -1,11 +1,13 @@ package controller import ( + "bufio" "encoding/json" "fmt" "io" "net/http" "one-api/common" + "strings" "github.com/gin-gonic/gin" ) @@ -180,50 +182,61 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *ChatCo func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) { responseText := "" - responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) - createdTime := common.GetTimestamp() dataChan := make(chan string) stopChan := make(chan bool) + scanner := bufio.NewScanner(resp.Body) + scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, nil + } + if i := strings.Index(string(data), "\n"); i >= 0 { + return i + 1, data[0:i], nil + } + if atEOF { + return len(data), data, nil + } + return 0, nil, nil + }) go func() { - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - common.SysError("error reading stream response: " + err.Error()) - stopChan <- true - return + for scanner.Scan() { + data := scanner.Text() + data = strings.TrimSpace(data) + if !strings.HasPrefix(data, "\"text\": \"") { + continue + } + data = strings.TrimPrefix(data, "\"text\": \"") + data = strings.TrimSuffix(data, "\"") + dataChan <- data } - err = resp.Body.Close() - if err != nil { - common.SysError("error closing stream response: " + err.Error()) - stopChan <- true - return - } - var geminiResponse GeminiChatResponse - err = json.Unmarshal(responseBody, &geminiResponse) - if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) - stopChan <- true - return - } - fullTextResponse := streamResponseGeminiChat2OpenAI(&geminiResponse) - fullTextResponse.Id = responseId - fullTextResponse.Created = createdTime - if len(geminiResponse.Candidates) > 0 && len(geminiResponse.Candidates[0].Content.Parts) > 0 { - responseText += geminiResponse.Candidates[0].Content.Parts[0].Text - } - jsonResponse, err := json.Marshal(fullTextResponse) - if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) - stopChan <- true - return - } - dataChan <- string(jsonResponse) stopChan <- true }() setEventStreamHeaders(c) c.Stream(func(w io.Writer) bool { select { case data := <-dataChan: - c.Render(-1, common.CustomEvent{Data: "data: " + data}) + // this is used to prevent annoying \ related format bug + data = fmt.Sprintf("{\"content\": \"%s\"}", data) + type dummyStruct struct { + Content string `json:"content"` + } + var dummy dummyStruct + err := json.Unmarshal([]byte(data), &dummy) + responseText += dummy.Content + var choice ChatCompletionsStreamResponseChoice + choice.Delta.Content = dummy.Content + response := ChatCompletionsStreamResponse{ + Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), + Object: "chat.completion.chunk", + Created: common.GetTimestamp(), + Model: "gemini-pro", + Choices: []ChatCompletionsStreamResponseChoice{choice}, + } + jsonResponse, err := json.Marshal(response) + if err != nil { + common.SysError("error marshalling stream response: " + err.Error()) + return true + } + c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) return true case <-stopChan: c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) diff --git a/controller/relay-text.go b/controller/relay-text.go index 211a34b3..b53b0caa 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -190,10 +190,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { version = c.GetString("api_version") } action := "generateContent" - // actually gemini does not support stream, it's fake - //if textRequest.Stream { - // action = "streamGenerateContent" - //} + if textRequest.Stream { + action = "streamGenerateContent" + } fullRequestURL = fmt.Sprintf("%s/%s/models/%s:%s", requestBaseURL, version, textRequest.Model, action) apiKey := c.Request.Header.Get("Authorization") apiKey = strings.TrimPrefix(apiKey, "Bearer ") diff --git a/controller/relay.go b/controller/relay.go index 0e660a68..15021997 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -236,7 +236,7 @@ type ChatCompletionsStreamResponseChoice struct { Delta struct { Content string `json:"content"` } `json:"delta"` - FinishReason *string `json:"finish_reason"` + FinishReason *string `json:"finish_reason,omitempty"` } type ChatCompletionsStreamResponse struct { From 7069c49bdf4e4ae39ab4944faaa5fd66b121d3fb Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 17 Dec 2023 18:06:37 +0800 Subject: [PATCH 03/11] fix: fix xunfei panic error (close #820) --- controller/relay-xunfei.go | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/controller/relay-xunfei.go b/controller/relay-xunfei.go index 00ec8981..904e6d14 100644 --- a/controller/relay-xunfei.go +++ b/controller/relay-xunfei.go @@ -230,7 +230,13 @@ func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId strin case stop = <-stopChan: } } - + if len(xunfeiResponse.Payload.Choices.Text) == 0 { + xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{ + { + Content: "", + }, + } + } xunfeiResponse.Payload.Choices.Text[0].Content = content response := responseXunfei2OpenAI(&xunfeiResponse) From 6acb9537a921008813eb691e021aa87f316fb82e Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 17 Dec 2023 18:33:27 +0800 Subject: [PATCH 04/11] fix: try to return a more meaningful error message (close #817) --- controller/relay-utils.go | 57 ++++++++++++++++++++++++++++++++++++--- 1 file changed, 53 insertions(+), 4 deletions(-) diff --git a/controller/relay-utils.go b/controller/relay-utils.go index 9deca75a..a6a1f0f6 100644 --- a/controller/relay-utils.go +++ b/controller/relay-utils.go @@ -263,11 +263,52 @@ func setEventStreamHeaders(c *gin.Context) { c.Writer.Header().Set("X-Accel-Buffering", "no") } +type GeneralErrorResponse struct { + Error OpenAIError `json:"error"` + Message string `json:"message"` + Msg string `json:"msg"` + Err string `json:"err"` + ErrorMsg string `json:"error_msg"` + Header struct { + Message string `json:"message"` + } `json:"header"` + Response struct { + Error struct { + Message string `json:"message"` + } `json:"error"` + } `json:"response"` +} + +func (e GeneralErrorResponse) ToMessage() string { + if e.Error.Message != "" { + return e.Error.Message + } + if e.Message != "" { + return e.Message + } + if e.Msg != "" { + return e.Msg + } + if e.Err != "" { + return e.Err + } + if e.ErrorMsg != "" { + return e.ErrorMsg + } + if e.Header.Message != "" { + return e.Header.Message + } + if e.Response.Error.Message != "" { + return e.Response.Error.Message + } + return "" +} + func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIErrorWithStatusCode) { openAIErrorWithStatusCode = &OpenAIErrorWithStatusCode{ StatusCode: resp.StatusCode, OpenAIError: OpenAIError{ - Message: fmt.Sprintf("bad response status code %d", resp.StatusCode), + Message: "", Type: "upstream_error", Code: "bad_response_status_code", Param: strconv.Itoa(resp.StatusCode), @@ -281,12 +322,20 @@ func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIEr if err != nil { return } - var textResponse TextResponse - err = json.Unmarshal(responseBody, &textResponse) + var errResponse GeneralErrorResponse + err = json.Unmarshal(responseBody, &errResponse) if err != nil { return } - openAIErrorWithStatusCode.OpenAIError = textResponse.Error + if errResponse.Error.Message != "" { + // OpenAI format error, so we override the default one + openAIErrorWithStatusCode.OpenAIError = errResponse.Error + } else { + openAIErrorWithStatusCode.OpenAIError.Message = errResponse.ToMessage() + } + if openAIErrorWithStatusCode.OpenAIError.Message == "" { + openAIErrorWithStatusCode.OpenAIError.Message = fmt.Sprintf("bad response status code %d", resp.StatusCode) + } return } From 66f06e5d6f14828c3522fb59bfca35c612685bef Mon Sep 17 00:00:00 2001 From: Ghostz <137054651+ye4293@users.noreply.github.com> Date: Sun, 17 Dec 2023 18:54:08 +0800 Subject: [PATCH 05/11] feat: reset image num to 1 when not given (#821) * Update relay-image.go * fix: reset image num to 1 when not given --------- Co-authored-by: JustSong --- controller/relay-image.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/controller/relay-image.go b/controller/relay-image.go index b3248fcc..bf916474 100644 --- a/controller/relay-image.go +++ b/controller/relay-image.go @@ -19,7 +19,6 @@ func isWithinRange(element string, value int) bool { if _, ok := common.DalleGenerationImageAmounts[element]; !ok { return false } - min := common.DalleGenerationImageAmounts[element][0] max := common.DalleGenerationImageAmounts[element][1] @@ -42,6 +41,10 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) } + if imageRequest.N == 0 { + imageRequest.N = 1 + } + // Size validation if imageRequest.Size != "" { imageSize = imageRequest.Size From 7d6a169669f1eb2666abb0747964ab1cd9f86f12 Mon Sep 17 00:00:00 2001 From: Calcium-Ion <61247483+Calcium-Ion@users.noreply.github.com> Date: Sun, 17 Dec 2023 19:17:00 +0800 Subject: [PATCH 06/11] feat: able to set sqlite busy_timeout (#818) * add sqlite busy_timeout=3000 * chore: update impl --------- Co-authored-by: JustSong --- README.md | 1 + common/database.go | 1 + model/main.go | 4 +++- 3 files changed, 5 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 8a1d6caf..916e4331 100644 --- a/README.md +++ b/README.md @@ -371,6 +371,7 @@ graph LR + `TIKTOKEN_CACHE_DIR`:默认程序启动时会联网下载一些通用的词元的编码,如:`gpt-3.5-turbo`,在一些网络环境不稳定,或者离线情况,可能会导致启动有问题,可以配置此目录缓存数据,可迁移到离线环境。 + `DATA_GYM_CACHE_DIR`:目前该配置作用与 `TIKTOKEN_CACHE_DIR` 一致,但是优先级没有它高。 15. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。 +16. `SQLITE_BUSY_TIMEOUT`:SQLite 锁等待超时设置,单位为毫秒,默认 `3000`。 ### 命令行参数 1. `--port `: 指定服务器监听的端口号,默认为 `3000`。 diff --git a/common/database.go b/common/database.go index c7e9fd52..76f2cd55 100644 --- a/common/database.go +++ b/common/database.go @@ -4,3 +4,4 @@ var UsingSQLite = false var UsingPostgreSQL = false var SQLitePath = "one-api.db" +var SQLiteBusyTimeout = GetOrDefault("SQLITE_BUSY_TIMEOUT", 3000) diff --git a/model/main.go b/model/main.go index 08182634..bfd6888b 100644 --- a/model/main.go +++ b/model/main.go @@ -1,6 +1,7 @@ package model import ( + "fmt" "gorm.io/driver/mysql" "gorm.io/driver/postgres" "gorm.io/driver/sqlite" @@ -59,7 +60,8 @@ func chooseDB() (*gorm.DB, error) { // Use SQLite common.SysLog("SQL_DSN not set, using SQLite as database") common.UsingSQLite = true - return gorm.Open(sqlite.Open(common.SQLitePath), &gorm.Config{ + config := fmt.Sprintf("?_busy_timeout=%d", common.SQLiteBusyTimeout) + return gorm.Open(sqlite.Open(common.SQLitePath+config), &gorm.Config{ PrepareStmt: true, // precompile SQL }) } From 0fe26cc4bd9d2e85b47ac6ae23fc974295c72f52 Mon Sep 17 00:00:00 2001 From: Oliver Lee Date: Sun, 17 Dec 2023 19:43:23 +0800 Subject: [PATCH 07/11] feat: update ali relay implementation (#830) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 修改通译千问最新接口:1.删除history参数,改用官方推荐的messages参数 2.整理messages参数顺序,补充必要上下文信息 3.用autogen调试测试通过 * chore: update impl --------- Co-authored-by: JustSong --- common/model-ratio.go | 6 +++-- controller/model.go | 18 +++++++++++++++ controller/relay-ali.go | 33 ++++++++-------------------- web/src/pages/Channel/EditChannel.js | 2 +- 4 files changed, 32 insertions(+), 27 deletions(-) diff --git a/common/model-ratio.go b/common/model-ratio.go index c054fa5f..d1c96d96 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -88,8 +88,10 @@ var ModelRatio = map[string]float64{ "chatglm_pro": 0.7143, // ¥0.01 / 1k tokens "chatglm_std": 0.3572, // ¥0.005 / 1k tokens "chatglm_lite": 0.1429, // ¥0.002 / 1k tokens - "qwen-turbo": 0.8572, // ¥0.012 / 1k tokens - "qwen-plus": 10, // ¥0.14 / 1k tokens + "qwen-turbo": 0.5715, // ¥0.008 / 1k tokens // https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions-metering-and-billing + "qwen-plus": 1.4286, // ¥0.02 / 1k tokens + "qwen-max": 1.4286, // ¥0.02 / 1k tokens + "qwen-max-longcontext": 1.4286, // ¥0.02 / 1k tokens "text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens "SparkDesk": 1.2858, // ¥0.018 / 1k tokens "360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens diff --git a/controller/model.go b/controller/model.go index 5c8aebc0..9ae40f5c 100644 --- a/controller/model.go +++ b/controller/model.go @@ -486,6 +486,24 @@ func init() { Root: "qwen-plus", Parent: nil, }, + { + Id: "qwen-max", + Object: "model", + Created: 1677649963, + OwnedBy: "ali", + Permission: permission, + Root: "qwen-max", + Parent: nil, + }, + { + Id: "qwen-max-longcontext", + Object: "model", + Created: 1677649963, + OwnedBy: "ali", + Permission: permission, + Root: "qwen-max-longcontext", + Parent: nil, + }, { Id: "text-embedding-v1", Object: "model", diff --git a/controller/relay-ali.go b/controller/relay-ali.go index b41ca327..65626f6a 100644 --- a/controller/relay-ali.go +++ b/controller/relay-ali.go @@ -13,13 +13,13 @@ import ( // https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r type AliMessage struct { - User string `json:"user"` - Bot string `json:"bot"` + Content string `json:"content"` + Role string `json:"role"` } type AliInput struct { - Prompt string `json:"prompt"` - History []AliMessage `json:"history"` + //Prompt string `json:"prompt"` + Messages []AliMessage `json:"messages"` } type AliParameters struct { @@ -83,32 +83,17 @@ type AliChatResponse struct { func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest { messages := make([]AliMessage, 0, len(request.Messages)) - prompt := "" for i := 0; i < len(request.Messages); i++ { message := request.Messages[i] - if message.Role == "system" { - messages = append(messages, AliMessage{ - User: message.StringContent(), - Bot: "Okay", - }) - continue - } else { - if i == len(request.Messages)-1 { - prompt = message.StringContent() - break - } - messages = append(messages, AliMessage{ - User: message.StringContent(), - Bot: request.Messages[i+1].StringContent(), - }) - i++ - } + messages = append(messages, AliMessage{ + Content: message.StringContent(), + Role: strings.ToLower(message.Role), + }) } return &AliChatRequest{ Model: request.Model, Input: AliInput{ - Prompt: prompt, - History: messages, + Messages: messages, }, //Parameters: AliParameters{ // ChatGPT's parameters are not compatible with Ali's // TopP: request.TopP, diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index 114e5933..364da69d 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -69,7 +69,7 @@ const EditChannel = () => { localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'ERNIE-Bot-4', 'Embedding-V1']; break; case 17: - localModels = ['qwen-turbo', 'qwen-plus', 'text-embedding-v1']; + localModels = ['qwen-turbo', 'qwen-plus', 'qwen-max', 'qwen-max-longcontext', 'text-embedding-v1']; break; case 16: localModels = ['chatglm_turbo', 'chatglm_pro', 'chatglm_std', 'chatglm_lite']; From bc6769826bb827ca323beac6118bfd15e829fee4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?ShinChven=20=E2=9C=A8?= Date: Sun, 17 Dec 2023 19:49:08 +0800 Subject: [PATCH 08/11] feat: add condition to validate n value for non-Azure channels (#775) - Add a condition to validate the n value only for non-Azure channels, ensuring it falls within the acceptable range. - Fix Azure compatibility --- controller/relay-image.go | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/controller/relay-image.go b/controller/relay-image.go index bf916474..7e1fed39 100644 --- a/controller/relay-image.go +++ b/controller/relay-image.go @@ -82,7 +82,10 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode // Number of generated images validation if isWithinRange(imageModel, imageRequest.N) == false { - return errorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest) + // channel not azure + if channelType != common.ChannelTypeAzure { + return errorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest) + } } // map model name @@ -105,7 +108,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode baseURL = c.GetString("base_url") } fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType) - if channelType == common.ChannelTypeAzure && relayMode == RelayModeImagesGenerations { + if channelType == common.ChannelTypeAzure { // https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api apiVersion := GetAPIVersion(c) // https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2023-06-01-preview From af378c59afe3510ebc28b7f081a7a8ccfa8de891 Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 17 Dec 2023 22:19:16 +0800 Subject: [PATCH 09/11] docs: update readme --- README.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/README.md b/README.md index 916e4331..1a967ace 100644 --- a/README.md +++ b/README.md @@ -77,8 +77,6 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 + [x] [OpenAI-SB](https://openai-sb.com) + [x] [CloseAI](https://referer.shadowai.xyz/r/2412) + [x] [API2D](https://api2d.com/r/197971) - + [x] [OhMyGPT](https://aigptx.top?aff=uFpUl2Kf) - + [x] [AI Proxy](https://aiproxy.io/?i=OneAPI) (邀请码:`OneAPI`) + [x] 自定义渠道:例如各种未收录的第三方代理服务 3. 支持通过**负载均衡**的方式访问多个渠道。 4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。 From 461f5dab561ff09d1bada624058c943daa7b9f99 Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 17 Dec 2023 22:25:03 +0800 Subject: [PATCH 10/11] docs: update readme --- README.md | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/README.md b/README.md index 1a967ace..ff9e0bc0 100644 --- a/README.md +++ b/README.md @@ -73,11 +73,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 + [x] [智谱 ChatGLM 系列模型](https://bigmodel.cn) + [x] [360 智脑](https://ai.360.cn) + [x] [腾讯混元大模型](https://cloud.tencent.com/document/product/1729) -2. 支持配置镜像以及众多第三方代理服务: - + [x] [OpenAI-SB](https://openai-sb.com) - + [x] [CloseAI](https://referer.shadowai.xyz/r/2412) - + [x] [API2D](https://api2d.com/r/197971) - + [x] 自定义渠道:例如各种未收录的第三方代理服务 +2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。 3. 支持通过**负载均衡**的方式访问多个渠道。 4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。 5. 支持**多机部署**,[详见此处](#多机部署)。 From 97030e27f88ac27a9ab4bcc2484f7d8e83d29d04 Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 17 Dec 2023 23:30:45 +0800 Subject: [PATCH 11/11] fix: fix gemini panic (close #833) --- controller/relay-gemini.go | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/controller/relay-gemini.go b/controller/relay-gemini.go index 4c2daba9..2458458e 100644 --- a/controller/relay-gemini.go +++ b/controller/relay-gemini.go @@ -114,7 +114,7 @@ func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest { Role: "model", Parts: []GeminiPart{ { - Text: "ok", + Text: "Okay", }, }, }) @@ -130,6 +130,16 @@ type GeminiChatResponse struct { PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"` } +func (g *GeminiChatResponse) GetResponseText() string { + if g == nil { + return "" + } + if len(g.Candidates) > 0 && len(g.Candidates[0].Content.Parts) > 0 { + return g.Candidates[0].Content.Parts[0].Text + } + return "" +} + type GeminiChatCandidate struct { Content GeminiChatContent `json:"content"` FinishReason string `json:"finishReason"` @@ -158,10 +168,13 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *OpenAITextResponse Index: i, Message: Message{ Role: "assistant", - Content: candidate.Content.Parts[0].Text, + Content: "", }, FinishReason: stopFinishReason, } + if len(candidate.Content.Parts) > 0 { + choice.Message.Content = candidate.Content.Parts[0].Text + } fullTextResponse.Choices = append(fullTextResponse.Choices, choice) } return &fullTextResponse @@ -169,9 +182,7 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *OpenAITextResponse func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *ChatCompletionsStreamResponse { var choice ChatCompletionsStreamResponseChoice - if len(geminiResponse.Candidates) > 0 && len(geminiResponse.Candidates[0].Content.Parts) > 0 { - choice.Delta.Content = geminiResponse.Candidates[0].Content.Parts[0].Text - } + choice.Delta.Content = geminiResponse.GetResponseText() choice.FinishReason = &stopFinishReason var response ChatCompletionsStreamResponse response.Object = "chat.completion.chunk" @@ -276,7 +287,7 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo }, nil } fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse) - completionTokens := countTokenText(geminiResponse.Candidates[0].Content.Parts[0].Text, model) + completionTokens := countTokenText(geminiResponse.GetResponseText(), model) usage := Usage{ PromptTokens: promptTokens, CompletionTokens: completionTokens,