diff --git a/README.en.md b/README.en.md index 9345a219..82dceb5b 100644 --- a/README.en.md +++ b/README.en.md @@ -60,7 +60,7 @@ _✨ Access all LLM through the standard OpenAI API format, easy to deploy & use 1. Support for multiple large models: + [x] [OpenAI ChatGPT Series Models](https://platform.openai.com/docs/guides/gpt/chat-completions-api) (Supports [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference)) + [x] [Anthropic Claude Series Models](https://anthropic.com) - + [x] [Google PaLM2 Series Models](https://developers.generativeai.google) + + [x] [Google PaLM2 and Gemini Series Models](https://developers.generativeai.google) + [x] [Baidu Wenxin Yiyuan Series Models](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html) + [x] [Alibaba Tongyi Qianwen Series Models](https://help.aliyun.com/document_detail/2400395.html) + [x] [Zhipu ChatGLM Series Models](https://bigmodel.cn) diff --git a/README.ja.md b/README.ja.md index 6faf9bee..089fc2b5 100644 --- a/README.ja.md +++ b/README.ja.md @@ -60,7 +60,7 @@ _✨ 標準的な OpenAI API フォーマットを通じてすべての LLM に 1. 複数の大型モデルをサポート: + [x] [OpenAI ChatGPT シリーズモデル](https://platform.openai.com/docs/guides/gpt/chat-completions-api) ([Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference) をサポート) + [x] [Anthropic Claude シリーズモデル](https://anthropic.com) - + [x] [Google PaLM2 シリーズモデル](https://developers.generativeai.google) + + [x] [Google PaLM2/Gemini シリーズモデル](https://developers.generativeai.google) + [x] [Baidu Wenxin Yiyuan シリーズモデル](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html) + [x] [Alibaba Tongyi Qianwen シリーズモデル](https://help.aliyun.com/document_detail/2400395.html) + [x] [Zhipu ChatGLM シリーズモデル](https://bigmodel.cn) diff --git a/README.md b/README.md index 7e6a7b38..ff9e0bc0 100644 --- a/README.md +++ b/README.md @@ -66,20 +66,14 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 1. 支持多种大模型: + [x] [OpenAI ChatGPT 系列模型](https://platform.openai.com/docs/guides/gpt/chat-completions-api)(支持 [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference)) + [x] [Anthropic Claude 系列模型](https://anthropic.com) - + [x] [Google PaLM2 系列模型](https://developers.generativeai.google) + + [x] [Google PaLM2/Gemini 系列模型](https://developers.generativeai.google) + [x] [百度文心一言系列模型](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html) + [x] [阿里通义千问系列模型](https://help.aliyun.com/document_detail/2400395.html) + [x] [讯飞星火认知大模型](https://www.xfyun.cn/doc/spark/Web.html) + [x] [智谱 ChatGLM 系列模型](https://bigmodel.cn) + [x] [360 智脑](https://ai.360.cn) + [x] [腾讯混元大模型](https://cloud.tencent.com/document/product/1729) -2. 支持配置镜像以及众多第三方代理服务: - + [x] [OpenAI-SB](https://openai-sb.com) - + [x] [CloseAI](https://referer.shadowai.xyz/r/2412) - + [x] [API2D](https://api2d.com/r/197971) - + [x] [OhMyGPT](https://aigptx.top?aff=uFpUl2Kf) - + [x] [AI Proxy](https://aiproxy.io/?i=OneAPI) (邀请码:`OneAPI`) - + [x] 自定义渠道:例如各种未收录的第三方代理服务 +2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。 3. 支持通过**负载均衡**的方式访问多个渠道。 4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。 5. 支持**多机部署**,[详见此处](#多机部署)。 @@ -371,6 +365,7 @@ graph LR + `TIKTOKEN_CACHE_DIR`:默认程序启动时会联网下载一些通用的词元的编码,如:`gpt-3.5-turbo`,在一些网络环境不稳定,或者离线情况,可能会导致启动有问题,可以配置此目录缓存数据,可迁移到离线环境。 + `DATA_GYM_CACHE_DIR`:目前该配置作用与 `TIKTOKEN_CACHE_DIR` 一致,但是优先级没有它高。 15. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。 +16. `SQLITE_BUSY_TIMEOUT`:SQLite 锁等待超时设置,单位为毫秒,默认 `3000`。 ### 命令行参数 1. `--port `: 指定服务器监听的端口号,默认为 `3000`。 diff --git a/common/constants.go b/common/constants.go index f6860f67..60700ec8 100644 --- a/common/constants.go +++ b/common/constants.go @@ -187,6 +187,7 @@ const ( ChannelTypeAIProxyLibrary = 21 ChannelTypeFastGPT = 22 ChannelTypeTencent = 23 + ChannelTypeGemini = 24 ) var ChannelBaseURLs = []string{ @@ -214,4 +215,5 @@ var ChannelBaseURLs = []string{ "https://api.aiproxy.io", // 21 "https://fastgpt.run/api/openapi", // 22 "https://hunyuan.cloud.tencent.com", //23 + "", //24 } diff --git a/common/database.go b/common/database.go index c7e9fd52..76f2cd55 100644 --- a/common/database.go +++ b/common/database.go @@ -4,3 +4,4 @@ var UsingSQLite = false var UsingPostgreSQL = false var SQLitePath = "one-api.db" +var SQLiteBusyTimeout = GetOrDefault("SQLITE_BUSY_TIMEOUT", 3000) diff --git a/common/image/image.go b/common/image/image.go new file mode 100644 index 00000000..cbb656ad --- /dev/null +++ b/common/image/image.go @@ -0,0 +1,47 @@ +package image + +import ( + "image" + _ "image/gif" + _ "image/jpeg" + _ "image/png" + "net/http" + "regexp" + "strings" + + _ "golang.org/x/image/webp" +) + +func GetImageSizeFromUrl(url string) (width int, height int, err error) { + resp, err := http.Get(url) + if err != nil { + return + } + defer resp.Body.Close() + img, _, err := image.DecodeConfig(resp.Body) + if err != nil { + return + } + return img.Width, img.Height, nil +} + +var ( + reg = regexp.MustCompile(`data:image/([^;]+);base64,`) +) + +func GetImageSizeFromBase64(encoded string) (width int, height int, err error) { + encoded = strings.TrimPrefix(encoded, "data:image/png;base64,") + base64 := strings.NewReader(reg.ReplaceAllString(encoded, "")) + img, _, err := image.DecodeConfig(base64) + if err != nil { + return + } + return img.Width, img.Height, nil +} + +func GetImageSize(image string) (width int, height int, err error) { + if strings.HasPrefix(image, "data:image/") { + return GetImageSizeFromBase64(image) + } + return GetImageSizeFromUrl(image) +} diff --git a/common/image/image_test.go b/common/image/image_test.go new file mode 100644 index 00000000..366eda6e --- /dev/null +++ b/common/image/image_test.go @@ -0,0 +1,154 @@ +package image_test + +import ( + "encoding/base64" + "image" + _ "image/gif" + _ "image/jpeg" + _ "image/png" + "io" + "net/http" + "strconv" + "strings" + "testing" + + img "one-api/common/image" + + "github.com/stretchr/testify/assert" + _ "golang.org/x/image/webp" +) + +type CountingReader struct { + reader io.Reader + BytesRead int +} + +func (r *CountingReader) Read(p []byte) (n int, err error) { + n, err = r.reader.Read(p) + r.BytesRead += n + return n, err +} + +var ( + cases = []struct { + url string + format string + width int + height int + }{ + {"https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", "jpeg", 2560, 1669}, + {"https://upload.wikimedia.org/wikipedia/commons/9/97/Basshunter_live_performances.png", "png", 4500, 2592}, + {"https://upload.wikimedia.org/wikipedia/commons/c/c6/TO_THE_ONE_SOMETHINGNESS.webp", "webp", 984, 985}, + {"https://upload.wikimedia.org/wikipedia/commons/d/d0/01_Das_Sandberg-Modell.gif", "gif", 1917, 1533}, + {"https://upload.wikimedia.org/wikipedia/commons/6/62/102Cervus.jpg", "jpeg", 270, 230}, + } +) + +func TestDecode(t *testing.T) { + // Bytes read: varies sometimes + // jpeg: 1063892 + // png: 294462 + // webp: 99529 + // gif: 956153 + // jpeg#01: 32805 + for _, c := range cases { + t.Run("Decode:"+c.format, func(t *testing.T) { + resp, err := http.Get(c.url) + assert.NoError(t, err) + defer resp.Body.Close() + reader := &CountingReader{reader: resp.Body} + img, format, err := image.Decode(reader) + assert.NoError(t, err) + size := img.Bounds().Size() + assert.Equal(t, c.format, format) + assert.Equal(t, c.width, size.X) + assert.Equal(t, c.height, size.Y) + t.Logf("Bytes read: %d", reader.BytesRead) + }) + } + + // Bytes read: + // jpeg: 4096 + // png: 4096 + // webp: 4096 + // gif: 4096 + // jpeg#01: 4096 + for _, c := range cases { + t.Run("DecodeConfig:"+c.format, func(t *testing.T) { + resp, err := http.Get(c.url) + assert.NoError(t, err) + defer resp.Body.Close() + reader := &CountingReader{reader: resp.Body} + config, format, err := image.DecodeConfig(reader) + assert.NoError(t, err) + assert.Equal(t, c.format, format) + assert.Equal(t, c.width, config.Width) + assert.Equal(t, c.height, config.Height) + t.Logf("Bytes read: %d", reader.BytesRead) + }) + } +} + +func TestBase64(t *testing.T) { + // Bytes read: + // jpeg: 1063892 + // png: 294462 + // webp: 99072 + // gif: 953856 + // jpeg#01: 32805 + for _, c := range cases { + t.Run("Decode:"+c.format, func(t *testing.T) { + resp, err := http.Get(c.url) + assert.NoError(t, err) + defer resp.Body.Close() + data, err := io.ReadAll(resp.Body) + assert.NoError(t, err) + encoded := base64.StdEncoding.EncodeToString(data) + body := base64.NewDecoder(base64.StdEncoding, strings.NewReader(encoded)) + reader := &CountingReader{reader: body} + img, format, err := image.Decode(reader) + assert.NoError(t, err) + size := img.Bounds().Size() + assert.Equal(t, c.format, format) + assert.Equal(t, c.width, size.X) + assert.Equal(t, c.height, size.Y) + t.Logf("Bytes read: %d", reader.BytesRead) + }) + } + + // Bytes read: + // jpeg: 1536 + // png: 768 + // webp: 768 + // gif: 1536 + // jpeg#01: 3840 + for _, c := range cases { + t.Run("DecodeConfig:"+c.format, func(t *testing.T) { + resp, err := http.Get(c.url) + assert.NoError(t, err) + defer resp.Body.Close() + data, err := io.ReadAll(resp.Body) + assert.NoError(t, err) + encoded := base64.StdEncoding.EncodeToString(data) + body := base64.NewDecoder(base64.StdEncoding, strings.NewReader(encoded)) + reader := &CountingReader{reader: body} + config, format, err := image.DecodeConfig(reader) + assert.NoError(t, err) + assert.Equal(t, c.format, format) + assert.Equal(t, c.width, config.Width) + assert.Equal(t, c.height, config.Height) + t.Logf("Bytes read: %d", reader.BytesRead) + }) + } +} + +func TestGetImageSize(t *testing.T) { + for i, c := range cases { + t.Run("Decode:"+strconv.Itoa(i), func(t *testing.T) { + width, height, err := img.GetImageSize(c.url) + assert.NoError(t, err) + assert.Equal(t, c.width, width) + assert.Equal(t, c.height, height) + }) + } +} diff --git a/common/model-ratio.go b/common/model-ratio.go index ccbc05dd..d1c96d96 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -83,12 +83,15 @@ var ModelRatio = map[string]float64{ "ERNIE-Bot-4": 8.572, // ¥0.12 / 1k tokens "Embedding-V1": 0.1429, // ¥0.002 / 1k tokens "PaLM-2": 1, + "gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens "chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens "chatglm_pro": 0.7143, // ¥0.01 / 1k tokens "chatglm_std": 0.3572, // ¥0.005 / 1k tokens "chatglm_lite": 0.1429, // ¥0.002 / 1k tokens - "qwen-turbo": 0.8572, // ¥0.012 / 1k tokens - "qwen-plus": 10, // ¥0.14 / 1k tokens + "qwen-turbo": 0.5715, // ¥0.008 / 1k tokens // https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions-metering-and-billing + "qwen-plus": 1.4286, // ¥0.02 / 1k tokens + "qwen-max": 1.4286, // ¥0.02 / 1k tokens + "qwen-max-longcontext": 1.4286, // ¥0.02 / 1k tokens "text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens "SparkDesk": 1.2858, // ¥0.018 / 1k tokens "360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens diff --git a/controller/channel-test.go b/controller/channel-test.go index bba9a657..3aaa4897 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -20,6 +20,8 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai switch channel.Type { case common.ChannelTypePaLM: fallthrough + case common.ChannelTypeGemini: + fallthrough case common.ChannelTypeAnthropic: fallthrough case common.ChannelTypeBaidu: diff --git a/controller/model.go b/controller/model.go index 8f79524d..9ae40f5c 100644 --- a/controller/model.go +++ b/controller/model.go @@ -423,6 +423,15 @@ func init() { Root: "PaLM-2", Parent: nil, }, + { + Id: "gemini-pro", + Object: "model", + Created: 1677649963, + OwnedBy: "google", + Permission: permission, + Root: "gemini-pro", + Parent: nil, + }, { Id: "chatglm_turbo", Object: "model", @@ -477,6 +486,24 @@ func init() { Root: "qwen-plus", Parent: nil, }, + { + Id: "qwen-max", + Object: "model", + Created: 1677649963, + OwnedBy: "ali", + Permission: permission, + Root: "qwen-max", + Parent: nil, + }, + { + Id: "qwen-max-longcontext", + Object: "model", + Created: 1677649963, + OwnedBy: "ali", + Permission: permission, + Root: "qwen-max-longcontext", + Parent: nil, + }, { Id: "text-embedding-v1", Object: "model", diff --git a/controller/relay-ali.go b/controller/relay-ali.go index b41ca327..65626f6a 100644 --- a/controller/relay-ali.go +++ b/controller/relay-ali.go @@ -13,13 +13,13 @@ import ( // https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r type AliMessage struct { - User string `json:"user"` - Bot string `json:"bot"` + Content string `json:"content"` + Role string `json:"role"` } type AliInput struct { - Prompt string `json:"prompt"` - History []AliMessage `json:"history"` + //Prompt string `json:"prompt"` + Messages []AliMessage `json:"messages"` } type AliParameters struct { @@ -83,32 +83,17 @@ type AliChatResponse struct { func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest { messages := make([]AliMessage, 0, len(request.Messages)) - prompt := "" for i := 0; i < len(request.Messages); i++ { message := request.Messages[i] - if message.Role == "system" { - messages = append(messages, AliMessage{ - User: message.StringContent(), - Bot: "Okay", - }) - continue - } else { - if i == len(request.Messages)-1 { - prompt = message.StringContent() - break - } - messages = append(messages, AliMessage{ - User: message.StringContent(), - Bot: request.Messages[i+1].StringContent(), - }) - i++ - } + messages = append(messages, AliMessage{ + Content: message.StringContent(), + Role: strings.ToLower(message.Role), + }) } return &AliChatRequest{ Model: request.Model, Input: AliInput{ - Prompt: prompt, - History: messages, + Messages: messages, }, //Parameters: AliParameters{ // ChatGPT's parameters are not compatible with Ali's // TopP: request.TopP, diff --git a/controller/relay-audio.go b/controller/relay-audio.go index 9e78dadc..2247f4c7 100644 --- a/controller/relay-audio.go +++ b/controller/relay-audio.go @@ -1,6 +1,7 @@ package controller import ( + "bufio" "bytes" "context" "encoding/json" @@ -102,7 +103,13 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioModel, apiVersion) } - requestBody := c.Request.Body + requestBody := &bytes.Buffer{} + _, err = io.Copy(requestBody, c.Request.Body) + if err != nil { + return errorWrapper(err, "new_request_body_failed", http.StatusInternalServerError) + } + c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody.Bytes())) + responseFormat := c.DefaultPostForm("response_format", "json") req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) if err != nil { @@ -144,12 +151,33 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode if err != nil { return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) } - var whisperResponse WhisperResponse - err = json.Unmarshal(responseBody, &whisperResponse) - if err != nil { - return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) + + var openAIErr TextResponse + if err = json.Unmarshal(responseBody, &openAIErr); err == nil { + if openAIErr.Error.Message != "" { + return errorWrapper(fmt.Errorf("type %s, code %v, message %s", openAIErr.Error.Type, openAIErr.Error.Code, openAIErr.Error.Message), "request_error", http.StatusInternalServerError) + } } - quota = countTokenText(whisperResponse.Text, audioModel) + + var text string + switch responseFormat { + case "json": + text, err = getTextFromJSON(responseBody) + case "text": + text, err = getTextFromText(responseBody) + case "srt": + text, err = getTextFromSRT(responseBody) + case "verbose_json": + text, err = getTextFromVerboseJSON(responseBody) + case "vtt": + text, err = getTextFromVTT(responseBody) + default: + return errorWrapper(errors.New("unexpected_response_format"), "unexpected_response_format", http.StatusInternalServerError) + } + if err != nil { + return errorWrapper(err, "get_text_from_body_err", http.StatusInternalServerError) + } + quota = countTokenText(text, audioModel) resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) } if resp.StatusCode != http.StatusOK { @@ -187,3 +215,48 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode } return nil } + +func getTextFromVTT(body []byte) (string, error) { + return getTextFromSRT(body) +} + +func getTextFromVerboseJSON(body []byte) (string, error) { + var whisperResponse WhisperVerboseJSONResponse + if err := json.Unmarshal(body, &whisperResponse); err != nil { + return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err) + } + return whisperResponse.Text, nil +} + +func getTextFromSRT(body []byte) (string, error) { + scanner := bufio.NewScanner(strings.NewReader(string(body))) + var builder strings.Builder + var textLine bool + for scanner.Scan() { + line := scanner.Text() + if textLine { + builder.WriteString(line) + textLine = false + continue + } else if strings.Contains(line, "-->") { + textLine = true + continue + } + } + if err := scanner.Err(); err != nil { + return "", err + } + return builder.String(), nil +} + +func getTextFromText(body []byte) (string, error) { + return strings.TrimSuffix(string(body), "\n"), nil +} + +func getTextFromJSON(body []byte) (string, error) { + var whisperResponse WhisperJSONResponse + if err := json.Unmarshal(body, &whisperResponse); err != nil { + return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err) + } + return whisperResponse.Text, nil +} diff --git a/controller/relay-gemini.go b/controller/relay-gemini.go new file mode 100644 index 00000000..2458458e --- /dev/null +++ b/controller/relay-gemini.go @@ -0,0 +1,305 @@ +package controller + +import ( + "bufio" + "encoding/json" + "fmt" + "io" + "net/http" + "one-api/common" + "strings" + + "github.com/gin-gonic/gin" +) + +type GeminiChatRequest struct { + Contents []GeminiChatContent `json:"contents"` + SafetySettings []GeminiChatSafetySettings `json:"safety_settings,omitempty"` + GenerationConfig GeminiChatGenerationConfig `json:"generation_config,omitempty"` + Tools []GeminiChatTools `json:"tools,omitempty"` +} + +type GeminiInlineData struct { + MimeType string `json:"mimeType"` + Data string `json:"data"` +} + +type GeminiPart struct { + Text string `json:"text,omitempty"` + InlineData *GeminiInlineData `json:"inlineData,omitempty"` +} + +type GeminiChatContent struct { + Role string `json:"role,omitempty"` + Parts []GeminiPart `json:"parts"` +} + +type GeminiChatSafetySettings struct { + Category string `json:"category"` + Threshold string `json:"threshold"` +} + +type GeminiChatTools struct { + FunctionDeclarations any `json:"functionDeclarations,omitempty"` +} + +type GeminiChatGenerationConfig struct { + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"topP,omitempty"` + TopK float64 `json:"topK,omitempty"` + MaxOutputTokens int `json:"maxOutputTokens,omitempty"` + CandidateCount int `json:"candidateCount,omitempty"` + StopSequences []string `json:"stopSequences,omitempty"` +} + +// Setting safety to the lowest possible values since Gemini is already powerless enough +func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest { + geminiRequest := GeminiChatRequest{ + Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)), + //SafetySettings: []GeminiChatSafetySettings{ + // { + // Category: "HARM_CATEGORY_HARASSMENT", + // Threshold: "BLOCK_ONLY_HIGH", + // }, + // { + // Category: "HARM_CATEGORY_HATE_SPEECH", + // Threshold: "BLOCK_ONLY_HIGH", + // }, + // { + // Category: "HARM_CATEGORY_SEXUALLY_EXPLICIT", + // Threshold: "BLOCK_ONLY_HIGH", + // }, + // { + // Category: "HARM_CATEGORY_DANGEROUS_CONTENT", + // Threshold: "BLOCK_ONLY_HIGH", + // }, + //}, + GenerationConfig: GeminiChatGenerationConfig{ + Temperature: textRequest.Temperature, + TopP: textRequest.TopP, + MaxOutputTokens: textRequest.MaxTokens, + }, + } + if textRequest.Functions != nil { + geminiRequest.Tools = []GeminiChatTools{ + { + FunctionDeclarations: textRequest.Functions, + }, + } + } + shouldAddDummyModelMessage := false + for _, message := range textRequest.Messages { + content := GeminiChatContent{ + Role: message.Role, + Parts: []GeminiPart{ + { + Text: message.StringContent(), + }, + }, + } + // there's no assistant role in gemini and API shall vomit if Role is not user or model + if content.Role == "assistant" { + content.Role = "model" + } + // Converting system prompt to prompt from user for the same reason + if content.Role == "system" { + content.Role = "user" + shouldAddDummyModelMessage = true + } + geminiRequest.Contents = append(geminiRequest.Contents, content) + + // If a system message is the last message, we need to add a dummy model message to make gemini happy + if shouldAddDummyModelMessage { + geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{ + Role: "model", + Parts: []GeminiPart{ + { + Text: "Okay", + }, + }, + }) + shouldAddDummyModelMessage = false + } + } + + return &geminiRequest +} + +type GeminiChatResponse struct { + Candidates []GeminiChatCandidate `json:"candidates"` + PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"` +} + +func (g *GeminiChatResponse) GetResponseText() string { + if g == nil { + return "" + } + if len(g.Candidates) > 0 && len(g.Candidates[0].Content.Parts) > 0 { + return g.Candidates[0].Content.Parts[0].Text + } + return "" +} + +type GeminiChatCandidate struct { + Content GeminiChatContent `json:"content"` + FinishReason string `json:"finishReason"` + Index int64 `json:"index"` + SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"` +} + +type GeminiChatSafetyRating struct { + Category string `json:"category"` + Probability string `json:"probability"` +} + +type GeminiChatPromptFeedback struct { + SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"` +} + +func responseGeminiChat2OpenAI(response *GeminiChatResponse) *OpenAITextResponse { + fullTextResponse := OpenAITextResponse{ + Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), + Object: "chat.completion", + Created: common.GetTimestamp(), + Choices: make([]OpenAITextResponseChoice, 0, len(response.Candidates)), + } + for i, candidate := range response.Candidates { + choice := OpenAITextResponseChoice{ + Index: i, + Message: Message{ + Role: "assistant", + Content: "", + }, + FinishReason: stopFinishReason, + } + if len(candidate.Content.Parts) > 0 { + choice.Message.Content = candidate.Content.Parts[0].Text + } + fullTextResponse.Choices = append(fullTextResponse.Choices, choice) + } + return &fullTextResponse +} + +func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *ChatCompletionsStreamResponse { + var choice ChatCompletionsStreamResponseChoice + choice.Delta.Content = geminiResponse.GetResponseText() + choice.FinishReason = &stopFinishReason + var response ChatCompletionsStreamResponse + response.Object = "chat.completion.chunk" + response.Model = "gemini" + response.Choices = []ChatCompletionsStreamResponseChoice{choice} + return &response +} + +func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) { + responseText := "" + dataChan := make(chan string) + stopChan := make(chan bool) + scanner := bufio.NewScanner(resp.Body) + scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, nil + } + if i := strings.Index(string(data), "\n"); i >= 0 { + return i + 1, data[0:i], nil + } + if atEOF { + return len(data), data, nil + } + return 0, nil, nil + }) + go func() { + for scanner.Scan() { + data := scanner.Text() + data = strings.TrimSpace(data) + if !strings.HasPrefix(data, "\"text\": \"") { + continue + } + data = strings.TrimPrefix(data, "\"text\": \"") + data = strings.TrimSuffix(data, "\"") + dataChan <- data + } + stopChan <- true + }() + setEventStreamHeaders(c) + c.Stream(func(w io.Writer) bool { + select { + case data := <-dataChan: + // this is used to prevent annoying \ related format bug + data = fmt.Sprintf("{\"content\": \"%s\"}", data) + type dummyStruct struct { + Content string `json:"content"` + } + var dummy dummyStruct + err := json.Unmarshal([]byte(data), &dummy) + responseText += dummy.Content + var choice ChatCompletionsStreamResponseChoice + choice.Delta.Content = dummy.Content + response := ChatCompletionsStreamResponse{ + Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), + Object: "chat.completion.chunk", + Created: common.GetTimestamp(), + Model: "gemini-pro", + Choices: []ChatCompletionsStreamResponseChoice{choice}, + } + jsonResponse, err := json.Marshal(response) + if err != nil { + common.SysError("error marshalling stream response: " + err.Error()) + return true + } + c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) + return true + case <-stopChan: + c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) + return false + } + }) + err := resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" + } + return nil, responseText +} + +func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + var geminiResponse GeminiChatResponse + err = json.Unmarshal(responseBody, &geminiResponse) + if err != nil { + return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + if len(geminiResponse.Candidates) == 0 { + return &OpenAIErrorWithStatusCode{ + OpenAIError: OpenAIError{ + Message: "No candidates returned", + Type: "server_error", + Param: "", + Code: 500, + }, + StatusCode: resp.StatusCode, + }, nil + } + fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse) + completionTokens := countTokenText(geminiResponse.GetResponseText(), model) + usage := Usage{ + PromptTokens: promptTokens, + CompletionTokens: completionTokens, + TotalTokens: promptTokens + completionTokens, + } + fullTextResponse.Usage = usage + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, &usage +} diff --git a/controller/relay-image.go b/controller/relay-image.go index b3248fcc..7e1fed39 100644 --- a/controller/relay-image.go +++ b/controller/relay-image.go @@ -19,7 +19,6 @@ func isWithinRange(element string, value int) bool { if _, ok := common.DalleGenerationImageAmounts[element]; !ok { return false } - min := common.DalleGenerationImageAmounts[element][0] max := common.DalleGenerationImageAmounts[element][1] @@ -42,6 +41,10 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) } + if imageRequest.N == 0 { + imageRequest.N = 1 + } + // Size validation if imageRequest.Size != "" { imageSize = imageRequest.Size @@ -79,7 +82,10 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode // Number of generated images validation if isWithinRange(imageModel, imageRequest.N) == false { - return errorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest) + // channel not azure + if channelType != common.ChannelTypeAzure { + return errorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest) + } } // map model name @@ -102,7 +108,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode baseURL = c.GetString("base_url") } fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType) - if channelType == common.ChannelTypeAzure && relayMode == RelayModeImagesGenerations { + if channelType == common.ChannelTypeAzure { // https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api apiVersion := GetAPIVersion(c) // https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2023-06-01-preview diff --git a/controller/relay-text.go b/controller/relay-text.go index a3e233d3..b53b0caa 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -27,6 +27,7 @@ const ( APITypeXunfei APITypeAIProxyLibrary APITypeTencent + APITypeGemini ) var httpClient *http.Client @@ -118,6 +119,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { apiType = APITypeAIProxyLibrary case common.ChannelTypeTencent: apiType = APITypeTencent + case common.ChannelTypeGemini: + apiType = APITypeGemini } baseURL := common.ChannelBaseURLs[channelType] requestURL := c.Request.URL.String() @@ -177,6 +180,23 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { apiKey := c.Request.Header.Get("Authorization") apiKey = strings.TrimPrefix(apiKey, "Bearer ") fullRequestURL += "?key=" + apiKey + case APITypeGemini: + requestBaseURL := "https://generativelanguage.googleapis.com" + if baseURL != "" { + requestBaseURL = baseURL + } + version := "v1" + if c.GetString("api_version") != "" { + version = c.GetString("api_version") + } + action := "generateContent" + if textRequest.Stream { + action = "streamGenerateContent" + } + fullRequestURL = fmt.Sprintf("%s/%s/models/%s:%s", requestBaseURL, version, textRequest.Model, action) + apiKey := c.Request.Header.Get("Authorization") + apiKey = strings.TrimPrefix(apiKey, "Bearer ") + fullRequestURL += "?key=" + apiKey case APITypeZhipu: method := "invoke" if textRequest.Stream { @@ -274,6 +294,13 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) } requestBody = bytes.NewBuffer(jsonStr) + case APITypeGemini: + geminiChatRequest := requestOpenAI2Gemini(textRequest) + jsonStr, err := json.Marshal(geminiChatRequest) + if err != nil { + return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) + } + requestBody = bytes.NewBuffer(jsonStr) case APITypeZhipu: zhipuRequest := requestOpenAI2Zhipu(textRequest) jsonStr, err := json.Marshal(zhipuRequest) @@ -360,10 +387,15 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { if textRequest.Stream { req.Header.Set("X-DashScope-SSE", "enable") } + if c.GetString("plugin") != "" { + req.Header.Set("X-DashScope-Plugin", c.GetString("plugin")) + } case APITypeTencent: req.Header.Set("Authorization", apiKey) case APITypePaLM: // do not set Authorization header + case APITypeGemini: + // do not set Authorization header default: req.Header.Set("Authorization", "Bearer "+apiKey) } @@ -524,6 +556,25 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { } return nil } + case APITypeGemini: + if textRequest.Stream { + err, responseText := geminiChatStreamHandler(c, resp) + if err != nil { + return err + } + textResponse.Usage.PromptTokens = promptTokens + textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) + return nil + } else { + err, usage := geminiChatHandler(c, resp, promptTokens, textRequest.Model) + if err != nil { + return err + } + if usage != nil { + textResponse.Usage = *usage + } + return nil + } case APITypeZhipu: if isStream { err, usage := zhipuStreamHandler(c, resp) diff --git a/controller/relay-utils.go b/controller/relay-utils.go index 38408c7f..a6a1f0f6 100644 --- a/controller/relay-utils.go +++ b/controller/relay-utils.go @@ -3,10 +3,13 @@ package controller import ( "context" "encoding/json" + "errors" "fmt" "io" + "math" "net/http" "one-api/common" + "one-api/common/image" "one-api/model" "strconv" "strings" @@ -87,7 +90,33 @@ func countTokenMessages(messages []Message, model string) int { tokenNum := 0 for _, message := range messages { tokenNum += tokensPerMessage - tokenNum += getTokenNum(tokenEncoder, message.StringContent()) + switch v := message.Content.(type) { + case string: + tokenNum += getTokenNum(tokenEncoder, v) + case []any: + for _, it := range v { + m := it.(map[string]any) + switch m["type"] { + case "text": + tokenNum += getTokenNum(tokenEncoder, m["text"].(string)) + case "image_url": + imageUrl, ok := m["image_url"].(map[string]any) + if ok { + url := imageUrl["url"].(string) + detail := "" + if imageUrl["detail"] != nil { + detail = imageUrl["detail"].(string) + } + imageTokens, err := countImageTokens(url, detail) + if err != nil { + common.SysError("error counting image tokens: " + err.Error()) + } else { + tokenNum += imageTokens + } + } + } + } + } tokenNum += getTokenNum(tokenEncoder, message.Role) if message.Name != nil { tokenNum += tokensPerName @@ -98,13 +127,81 @@ func countTokenMessages(messages []Message, model string) int { return tokenNum } +const ( + lowDetailCost = 85 + highDetailCostPerTile = 170 + additionalCost = 85 +) + +// https://platform.openai.com/docs/guides/vision/calculating-costs +// https://github.com/openai/openai-cookbook/blob/05e3f9be4c7a2ae7ecf029a7c32065b024730ebe/examples/How_to_count_tokens_with_tiktoken.ipynb +func countImageTokens(url string, detail string) (_ int, err error) { + var fetchSize = true + var width, height int + // Reference: https://platform.openai.com/docs/guides/vision/low-or-high-fidelity-image-understanding + // detail == "auto" is undocumented on how it works, it just said the model will use the auto setting which will look at the image input size and decide if it should use the low or high setting. + // According to the official guide, "low" disable the high-res model, + // and only receive low-res 512px x 512px version of the image, indicating + // that image is treated as low-res when size is smaller than 512px x 512px, + // then we can assume that image size larger than 512px x 512px is treated + // as high-res. Then we have the following logic: + // if detail == "" || detail == "auto" { + // width, height, err = image.GetImageSize(url) + // if err != nil { + // return 0, err + // } + // fetchSize = false + // // not sure if this is correct + // if width > 512 || height > 512 { + // detail = "high" + // } else { + // detail = "low" + // } + // } + + // However, in my test, it seems to be always the same as "high". + // The following image, which is 125x50, is still treated as high-res, taken + // 255 tokens in the response of non-stream chat completion api. + // https://upload.wikimedia.org/wikipedia/commons/1/10/18_Infantry_Division_Messina.jpg + if detail == "" || detail == "auto" { + // assume by test, not sure if this is correct + detail = "high" + } + switch detail { + case "low": + return lowDetailCost, nil + case "high": + if fetchSize { + width, height, err = image.GetImageSize(url) + if err != nil { + return 0, err + } + } + if width > 2048 || height > 2048 { // max(width, height) > 2048 + ratio := float64(2048) / math.Max(float64(width), float64(height)) + width = int(float64(width) * ratio) + height = int(float64(height) * ratio) + } + if width > 768 && height > 768 { // min(width, height) > 768 + ratio := float64(768) / math.Min(float64(width), float64(height)) + width = int(float64(width) * ratio) + height = int(float64(height) * ratio) + } + numSquares := int(math.Ceil(float64(width)/512) * math.Ceil(float64(height)/512)) + result := numSquares*highDetailCostPerTile + additionalCost + return result, nil + default: + return 0, errors.New("invalid detail option") + } +} + func countTokenInput(input any, model string) int { - switch input.(type) { + switch v := input.(type) { case string: - return countTokenText(input.(string), model) + return countTokenText(v, model) case []string: text := "" - for _, s := range input.([]string) { + for _, s := range v { text += s } return countTokenText(text, model) @@ -166,11 +263,52 @@ func setEventStreamHeaders(c *gin.Context) { c.Writer.Header().Set("X-Accel-Buffering", "no") } +type GeneralErrorResponse struct { + Error OpenAIError `json:"error"` + Message string `json:"message"` + Msg string `json:"msg"` + Err string `json:"err"` + ErrorMsg string `json:"error_msg"` + Header struct { + Message string `json:"message"` + } `json:"header"` + Response struct { + Error struct { + Message string `json:"message"` + } `json:"error"` + } `json:"response"` +} + +func (e GeneralErrorResponse) ToMessage() string { + if e.Error.Message != "" { + return e.Error.Message + } + if e.Message != "" { + return e.Message + } + if e.Msg != "" { + return e.Msg + } + if e.Err != "" { + return e.Err + } + if e.ErrorMsg != "" { + return e.ErrorMsg + } + if e.Header.Message != "" { + return e.Header.Message + } + if e.Response.Error.Message != "" { + return e.Response.Error.Message + } + return "" +} + func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIErrorWithStatusCode) { openAIErrorWithStatusCode = &OpenAIErrorWithStatusCode{ StatusCode: resp.StatusCode, OpenAIError: OpenAIError{ - Message: fmt.Sprintf("bad response status code %d", resp.StatusCode), + Message: "", Type: "upstream_error", Code: "bad_response_status_code", Param: strconv.Itoa(resp.StatusCode), @@ -184,12 +322,20 @@ func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIEr if err != nil { return } - var textResponse TextResponse - err = json.Unmarshal(responseBody, &textResponse) + var errResponse GeneralErrorResponse + err = json.Unmarshal(responseBody, &errResponse) if err != nil { return } - openAIErrorWithStatusCode.OpenAIError = textResponse.Error + if errResponse.Error.Message != "" { + // OpenAI format error, so we override the default one + openAIErrorWithStatusCode.OpenAIError = errResponse.Error + } else { + openAIErrorWithStatusCode.OpenAIError.Message = errResponse.ToMessage() + } + if openAIErrorWithStatusCode.OpenAIError.Message == "" { + openAIErrorWithStatusCode.OpenAIError.Message = fmt.Sprintf("bad response status code %d", resp.StatusCode) + } return } diff --git a/controller/relay-xunfei.go b/controller/relay-xunfei.go index 00ec8981..904e6d14 100644 --- a/controller/relay-xunfei.go +++ b/controller/relay-xunfei.go @@ -230,7 +230,13 @@ func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId strin case stop = <-stopChan: } } - + if len(xunfeiResponse.Payload.Choices.Text) == 0 { + xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{ + { + Content: "", + }, + } + } xunfeiResponse.Payload.Choices.Text[0].Content = content response := responseXunfei2OpenAI(&xunfeiResponse) diff --git a/controller/relay.go b/controller/relay.go index 58ee8381..15021997 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -141,10 +141,31 @@ type ImageRequest struct { User string `json:"user,omitempty"` } -type WhisperResponse struct { +type WhisperJSONResponse struct { Text string `json:"text,omitempty"` } +type WhisperVerboseJSONResponse struct { + Task string `json:"task,omitempty"` + Language string `json:"language,omitempty"` + Duration float64 `json:"duration,omitempty"` + Text string `json:"text,omitempty"` + Segments []Segment `json:"segments,omitempty"` +} + +type Segment struct { + Id int `json:"id"` + Seek int `json:"seek"` + Start float64 `json:"start"` + End float64 `json:"end"` + Text string `json:"text"` + Tokens []int `json:"tokens"` + Temperature float64 `json:"temperature"` + AvgLogprob float64 `json:"avg_logprob"` + CompressionRatio float64 `json:"compression_ratio"` + NoSpeechProb float64 `json:"no_speech_prob"` +} + type TextToSpeechRequest struct { Model string `json:"model" binding:"required"` Input string `json:"input" binding:"required"` @@ -215,7 +236,7 @@ type ChatCompletionsStreamResponseChoice struct { Delta struct { Content string `json:"content"` } `json:"delta"` - FinishReason *string `json:"finish_reason"` + FinishReason *string `json:"finish_reason,omitempty"` } type ChatCompletionsStreamResponse struct { diff --git a/go.mod b/go.mod index 10b78d68..1fe5eabc 100644 --- a/go.mod +++ b/go.mod @@ -15,7 +15,9 @@ require ( github.com/google/uuid v1.3.0 github.com/gorilla/websocket v1.5.0 github.com/pkoukk/tiktoken-go v0.1.5 + github.com/stretchr/testify v1.8.3 golang.org/x/crypto v0.14.0 + golang.org/x/image v0.14.0 gorm.io/driver/mysql v1.4.3 gorm.io/driver/postgres v1.5.2 gorm.io/driver/sqlite v1.4.3 @@ -26,6 +28,7 @@ require ( github.com/bytedance/sonic v1.9.1 // indirect github.com/cespare/xxhash/v2 v2.1.2 // indirect github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/dlclark/regexp2 v1.10.0 // indirect github.com/gabriel-vasile/mimetype v1.4.2 // indirect @@ -50,12 +53,13 @@ require ( github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pelletier/go-toml/v2 v2.0.8 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.11 // indirect golang.org/x/arch v0.3.0 // indirect golang.org/x/net v0.17.0 // indirect golang.org/x/sys v0.13.0 // indirect - golang.org/x/text v0.13.0 // indirect + golang.org/x/text v0.14.0 // indirect google.golang.org/protobuf v1.30.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 4865bcaa..fb252aa7 100644 --- a/go.sum +++ b/go.sum @@ -152,6 +152,8 @@ golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc= golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= +golang.org/x/image v0.14.0 h1:tNgSxAFe3jC4uYqvZdTr84SZoM1KfwdC9SKIFrLjFn4= +golang.org/x/image v0.14.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= @@ -168,8 +170,8 @@ golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9sn golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k= -golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/middleware/distributor.go b/middleware/distributor.go index c4ddc3a0..81338130 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -87,8 +87,12 @@ func Distribute() func(c *gin.Context) { c.Set("api_version", channel.Other) case common.ChannelTypeXunfei: c.Set("api_version", channel.Other) + case common.ChannelTypeGemini: + c.Set("api_version", channel.Other) case common.ChannelTypeAIProxyLibrary: c.Set("library_id", channel.Other) + case common.ChannelTypeAli: + c.Set("plugin", channel.Other) } c.Next() } diff --git a/middleware/recover.go b/middleware/recover.go new file mode 100644 index 00000000..c3a3d748 --- /dev/null +++ b/middleware/recover.go @@ -0,0 +1,26 @@ +package middleware + +import ( + "fmt" + "github.com/gin-gonic/gin" + "net/http" + "one-api/common" +) + +func RelayPanicRecover() gin.HandlerFunc { + return func(c *gin.Context) { + defer func() { + if err := recover(); err != nil { + common.SysError(fmt.Sprintf("panic detected: %v", err)) + c.JSON(http.StatusInternalServerError, gin.H{ + "error": gin.H{ + "message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/songquanpeng/one-api", err), + "type": "one_api_panic", + }, + }) + c.Abort() + } + }() + c.Next() + } +} diff --git a/model/main.go b/model/main.go index 08182634..bfd6888b 100644 --- a/model/main.go +++ b/model/main.go @@ -1,6 +1,7 @@ package model import ( + "fmt" "gorm.io/driver/mysql" "gorm.io/driver/postgres" "gorm.io/driver/sqlite" @@ -59,7 +60,8 @@ func chooseDB() (*gorm.DB, error) { // Use SQLite common.SysLog("SQL_DSN not set, using SQLite as database") common.UsingSQLite = true - return gorm.Open(sqlite.Open(common.SQLitePath), &gorm.Config{ + config := fmt.Sprintf("?_busy_timeout=%d", common.SQLiteBusyTimeout) + return gorm.Open(sqlite.Open(common.SQLitePath+config), &gorm.Config{ PrepareStmt: true, // precompile SQL }) } diff --git a/router/relay-router.go b/router/relay-router.go index 24edc9a9..56ab9b28 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -17,7 +17,7 @@ func SetRelayRouter(router *gin.Engine) { modelsRouter.GET("/:model", controller.RetrieveModel) } relayV1Router := router.Group("/v1") - relayV1Router.Use(middleware.TokenAuth(), middleware.Distribute()) + relayV1Router.Use(middleware.RelayPanicRecover(), middleware.TokenAuth(), middleware.Distribute()) { relayV1Router.POST("/completions", controller.Relay) relayV1Router.POST("/chat/completions", controller.Relay) diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js index 76407745..264dbefb 100644 --- a/web/src/constants/channel.constants.js +++ b/web/src/constants/channel.constants.js @@ -3,6 +3,7 @@ export const CHANNEL_OPTIONS = [ { key: 14, text: 'Anthropic Claude', value: 14, color: 'black' }, { key: 3, text: 'Azure OpenAI', value: 3, color: 'olive' }, { key: 11, text: 'Google PaLM2', value: 11, color: 'orange' }, + { key: 24, text: 'Google Gemini', value: 24, color: 'orange' }, { key: 15, text: '百度文心千帆', value: 15, color: 'blue' }, { key: 17, text: '阿里通义千问', value: 17, color: 'orange' }, { key: 18, text: '讯飞星火认知', value: 18, color: 'blue' }, diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index bc3886a0..364da69d 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -69,7 +69,7 @@ const EditChannel = () => { localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'ERNIE-Bot-4', 'Embedding-V1']; break; case 17: - localModels = ['qwen-turbo', 'qwen-plus', 'text-embedding-v1']; + localModels = ['qwen-turbo', 'qwen-plus', 'qwen-max', 'qwen-max-longcontext', 'text-embedding-v1']; break; case 16: localModels = ['chatglm_turbo', 'chatglm_pro', 'chatglm_std', 'chatglm_lite']; @@ -83,6 +83,9 @@ const EditChannel = () => { case 23: localModels = ['hunyuan']; break; + case 24: + localModels = ['gemini-pro']; + break; } setInputs((inputs) => ({ ...inputs, models: localModels })); } @@ -343,6 +346,20 @@ const EditChannel = () => { ) } + { + inputs.type === 17 && ( + + + + ) + }