diff --git a/README.en.md b/README.en.md index fb54e5f3..1108e615 100644 --- a/README.en.md +++ b/README.en.md @@ -63,7 +63,7 @@ _✨ Access all LLM through the standard OpenAI API format, easy to deploy & use 1. Support for multiple large models: - [x] [OpenAI ChatGPT Series Models](https://platform.openai.com/docs/guides/gpt/chat-completions-api) (Supports [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference)) - [x] [Anthropic Claude Series Models](https://anthropic.com) - - [x] [Google PaLM2 Series Models](https://developers.generativeai.google) + - [x] [Google PaLM2 and Gemini Series Models](https://developers.generativeai.google) - [x] [Baidu Wenxin Yiyuan Series Models](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html) - [x] [Alibaba Tongyi Qianwen Series Models](https://help.aliyun.com/document_detail/2400395.html) - [x] [Zhipu ChatGLM Series Models](https://bigmodel.cn) diff --git a/README.ja.md b/README.ja.md index b0f7e660..a0bb5e78 100644 --- a/README.ja.md +++ b/README.ja.md @@ -63,7 +63,7 @@ _✨ 標準的な OpenAI API フォーマットを通じてすべての LLM に 1. 複数の大型モデルをサポート: - [x] [OpenAI ChatGPT シリーズモデル](https://platform.openai.com/docs/guides/gpt/chat-completions-api) ([Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference) をサポート) - [x] [Anthropic Claude シリーズモデル](https://anthropic.com) - - [x] [Google PaLM2 シリーズモデル](https://developers.generativeai.google) + - [x] [Google PaLM2/Gemini シリーズモデル](https://developers.generativeai.google) - [x] [Baidu Wenxin Yiyuan シリーズモデル](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html) - [x] [Alibaba Tongyi Qianwen シリーズモデル](https://help.aliyun.com/document_detail/2400395.html) - [x] [Zhipu ChatGLM シリーズモデル](https://bigmodel.cn) diff --git a/README.md b/README.md index fee730c1..9f44adfe 100644 --- a/README.md +++ b/README.md @@ -52,15 +52,15 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 赞赏支持

-> **Note** +> [!NOTE] > 本项目为开源项目,使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。 > > 根据[《生成式人工智能服务管理暂行办法》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。 -> **Warning** +> [!WARNING] > 使用 Docker 拉取的最新镜像可能是 `alpha` 版本,如果追求稳定性请手动指定版本。 -> **Warning** +> [!WARNING] > 使用 root 用户初次登录系统后,务必修改默认密码 `123456`! ## 功能 @@ -68,20 +68,14 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 1. 支持多种大模型: - [x] [OpenAI ChatGPT 系列模型](https://platform.openai.com/docs/guides/gpt/chat-completions-api)(支持 [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference)) - [x] [Anthropic Claude 系列模型](https://anthropic.com) - - [x] [Google PaLM2 系列模型](https://developers.generativeai.google) + - [x] [Google PaLM2/Gemini 系列模型](https://developers.generativeai.google) - [x] [百度文心一言系列模型](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html) - [x] [阿里通义千问系列模型](https://help.aliyun.com/document_detail/2400395.html) - [x] [讯飞星火认知大模型](https://www.xfyun.cn/doc/spark/Web.html) - [x] [智谱 ChatGLM 系列模型](https://bigmodel.cn) - [x] [360 智脑](https://ai.360.cn) - [x] [腾讯混元大模型](https://cloud.tencent.com/document/product/1729) -2. 支持配置镜像以及众多第三方代理服务: - - [x] [OpenAI-SB](https://openai-sb.com) - - [x] [CloseAI](https://referer.shadowai.xyz/r/2412) - - [x] [API2D](https://api2d.com/r/197971) - - [x] [OhMyGPT](https://aigptx.top?aff=uFpUl2Kf) - - [x] [AI Proxy](https://aiproxy.io/?i=OneAPI) (邀请码:`OneAPI`) - - [x] 自定义渠道:例如各种未收录的第三方代理服务 +2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。 3. 支持通过**负载均衡**的方式访问多个渠道。 4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。 5. 支持**多机部署**,[详见此处](#多机部署)。 @@ -389,6 +383,7 @@ graph LR - `TIKTOKEN_CACHE_DIR`:默认程序启动时会联网下载一些通用的词元的编码,如:`gpt-3.5-turbo`,在一些网络环境不稳定,或者离线情况,可能会导致启动有问题,可以配置此目录缓存数据,可迁移到离线环境。 - `DATA_GYM_CACHE_DIR`:目前该配置作用与 `TIKTOKEN_CACHE_DIR` 一致,但是优先级没有它高。 15. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。 +16. 
`SQLITE_BUSY_TIMEOUT`:SQLite 锁等待超时设置,单位为毫秒,默认 `3000`。 ### 命令行参数 diff --git a/common/client.go b/common/client.go index d482eeec..91903ac0 100644 --- a/common/client.go +++ b/common/client.go @@ -130,12 +130,53 @@ func SendRequest(req *http.Request, response any, outputResp bool) (*http.Respon return nil, nil } +type GeneralErrorResponse struct { + Error types.OpenAIError `json:"error"` + Message string `json:"message"` + Msg string `json:"msg"` + Err string `json:"err"` + ErrorMsg string `json:"error_msg"` + Header struct { + Message string `json:"message"` + } `json:"header"` + Response struct { + Error struct { + Message string `json:"message"` + } `json:"error"` + } `json:"response"` +} + +func (e GeneralErrorResponse) ToMessage() string { + if e.Error.Message != "" { + return e.Error.Message + } + if e.Message != "" { + return e.Message + } + if e.Msg != "" { + return e.Msg + } + if e.Err != "" { + return e.Err + } + if e.ErrorMsg != "" { + return e.ErrorMsg + } + if e.Header.Message != "" { + return e.Header.Message + } + if e.Response.Error.Message != "" { + return e.Response.Error.Message + } + return "" +} + // 处理错误响应 func HandleErrorResp(resp *http.Response) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) { openAIErrorWithStatusCode = &types.OpenAIErrorWithStatusCode{ StatusCode: resp.StatusCode, OpenAIError: types.OpenAIError{ - Message: fmt.Sprintf("bad response status code %d", resp.StatusCode), + Message: "", Type: "upstream_error", Code: "bad_response_status_code", Param: strconv.Itoa(resp.StatusCode), @@ -149,16 +190,23 @@ func HandleErrorResp(resp *http.Response) (openAIErrorWithStatusCode *types.Open if err != nil { return } - var errorResponse types.OpenAIErrorResponse + // var errorResponse types.OpenAIErrorResponse + var errorResponse GeneralErrorResponse err = json.Unmarshal(responseBody, &errorResponse) if err != nil { return } - if errorResponse.Error.Type != "" { + + if errorResponse.Error.Message != "" { + // OpenAI format 
error, so we override the default one openAIErrorWithStatusCode.OpenAIError = errorResponse.Error } else { - openAIErrorWithStatusCode.OpenAIError.Message = string(responseBody) + openAIErrorWithStatusCode.OpenAIError.Message = errorResponse.ToMessage() } + if openAIErrorWithStatusCode.OpenAIError.Message == "" { + openAIErrorWithStatusCode.OpenAIError.Message = fmt.Sprintf("bad response status code %d", resp.StatusCode) + } + return } diff --git a/common/constants.go b/common/constants.go index 527d3a28..a9265b86 100644 --- a/common/constants.go +++ b/common/constants.go @@ -78,6 +78,7 @@ var QuotaForInviter = 0 var QuotaForInvitee = 0 var ChannelDisableThreshold = 5.0 var AutomaticDisableChannelEnabled = false +var AutomaticEnableChannelEnabled = false var QuotaRemindThreshold = 1000 var PreConsumedQuota = 500 var ApproximateTokenEnabled = false @@ -187,6 +188,7 @@ const ( ChannelTypeFastGPT = 22 ChannelTypeTencent = 23 ChannelTypeAzureSpeech = 24 + ChannelTypeGemini = 25 ) var ChannelBaseURLs = []string{ @@ -213,8 +215,9 @@ var ChannelBaseURLs = []string{ "https://openrouter.ai/api", // 20 "https://api.aiproxy.io", // 21 "https://fastgpt.run/api/openapi", // 22 - "https://hunyuan.cloud.tencent.com", // 23 - "", // 24 + "https://hunyuan.cloud.tencent.com", //23 + "", //24 + "", //25 } const ( diff --git a/common/database.go b/common/database.go index c7e9fd52..76f2cd55 100644 --- a/common/database.go +++ b/common/database.go @@ -4,3 +4,4 @@ var UsingSQLite = false var UsingPostgreSQL = false var SQLitePath = "one-api.db" +var SQLiteBusyTimeout = GetOrDefault("SQLITE_BUSY_TIMEOUT", 3000) diff --git a/common/email.go b/common/email.go index 74f4cccd..b915f0f9 100644 --- a/common/email.go +++ b/common/email.go @@ -1,11 +1,13 @@ package common import ( + "crypto/rand" "crypto/tls" "encoding/base64" "fmt" "net/smtp" "strings" + "time" ) func SendEmail(subject string, receiver string, content string) error { @@ -13,15 +15,32 @@ func SendEmail(subject string, receiver 
string, content string) error { SMTPFrom = SMTPAccount } encodedSubject := fmt.Sprintf("=?UTF-8?B?%s?=", base64.StdEncoding.EncodeToString([]byte(subject))) + + // Extract domain from SMTPFrom + parts := strings.Split(SMTPFrom, "@") + var domain string + if len(parts) > 1 { + domain = parts[1] + } + // Generate a unique Message-ID + buf := make([]byte, 16) + _, err := rand.Read(buf) + if err != nil { + return err + } + messageId := fmt.Sprintf("<%x@%s>", buf, domain) + mail := []byte(fmt.Sprintf("To: %s\r\n"+ "From: %s<%s>\r\n"+ "Subject: %s\r\n"+ + "Message-ID: %s\r\n"+ // add Message-ID header to avoid being treated as spam, RFC 5322 + "Date: %s\r\n"+ "Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n", - receiver, SystemName, SMTPFrom, encodedSubject, content)) + receiver, SystemName, SMTPFrom, encodedSubject, messageId, time.Now().Format(time.RFC1123Z), content)) auth := smtp.PlainAuth("", SMTPAccount, SMTPToken, SMTPServer) addr := fmt.Sprintf("%s:%d", SMTPServer, SMTPPort) to := strings.Split(receiver, ";") - var err error + if SMTPPort == 465 { tlsConfig := &tls.Config{ InsecureSkipVerify: true, diff --git a/common/image/image.go b/common/image/image.go new file mode 100644 index 00000000..cbb656ad --- /dev/null +++ b/common/image/image.go @@ -0,0 +1,47 @@ +package image + +import ( + "image" + _ "image/gif" + _ "image/jpeg" + _ "image/png" + "net/http" + "regexp" + "strings" + + _ "golang.org/x/image/webp" +) + +func GetImageSizeFromUrl(url string) (width int, height int, err error) { + resp, err := http.Get(url) + if err != nil { + return + } + defer resp.Body.Close() + img, _, err := image.DecodeConfig(resp.Body) + if err != nil { + return + } + return img.Width, img.Height, nil +} + +var ( + reg = regexp.MustCompile(`data:image/([^;]+);base64,`) +) + +func GetImageSizeFromBase64(encoded string) (width int, height int, err error) { + encoded = strings.TrimPrefix(encoded, "data:image/png;base64,") + base64 := 
strings.NewReader(reg.ReplaceAllString(encoded, "")) + img, _, err := image.DecodeConfig(base64) + if err != nil { + return + } + return img.Width, img.Height, nil +} + +func GetImageSize(image string) (width int, height int, err error) { + if strings.HasPrefix(image, "data:image/") { + return GetImageSizeFromBase64(image) + } + return GetImageSizeFromUrl(image) +} diff --git a/common/image/image_test.go b/common/image/image_test.go new file mode 100644 index 00000000..366eda6e --- /dev/null +++ b/common/image/image_test.go @@ -0,0 +1,154 @@ +package image_test + +import ( + "encoding/base64" + "image" + _ "image/gif" + _ "image/jpeg" + _ "image/png" + "io" + "net/http" + "strconv" + "strings" + "testing" + + img "one-api/common/image" + + "github.com/stretchr/testify/assert" + _ "golang.org/x/image/webp" +) + +type CountingReader struct { + reader io.Reader + BytesRead int +} + +func (r *CountingReader) Read(p []byte) (n int, err error) { + n, err = r.reader.Read(p) + r.BytesRead += n + return n, err +} + +var ( + cases = []struct { + url string + format string + width int + height int + }{ + {"https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", "jpeg", 2560, 1669}, + {"https://upload.wikimedia.org/wikipedia/commons/9/97/Basshunter_live_performances.png", "png", 4500, 2592}, + {"https://upload.wikimedia.org/wikipedia/commons/c/c6/TO_THE_ONE_SOMETHINGNESS.webp", "webp", 984, 985}, + {"https://upload.wikimedia.org/wikipedia/commons/d/d0/01_Das_Sandberg-Modell.gif", "gif", 1917, 1533}, + {"https://upload.wikimedia.org/wikipedia/commons/6/62/102Cervus.jpg", "jpeg", 270, 230}, + } +) + +func TestDecode(t *testing.T) { + // Bytes read: varies sometimes + // jpeg: 1063892 + // png: 294462 + // webp: 99529 + // gif: 956153 + // jpeg#01: 32805 + for _, c := range cases { + t.Run("Decode:"+c.format, func(t *testing.T) { + resp, err := http.Get(c.url) + 
assert.NoError(t, err) + defer resp.Body.Close() + reader := &CountingReader{reader: resp.Body} + img, format, err := image.Decode(reader) + assert.NoError(t, err) + size := img.Bounds().Size() + assert.Equal(t, c.format, format) + assert.Equal(t, c.width, size.X) + assert.Equal(t, c.height, size.Y) + t.Logf("Bytes read: %d", reader.BytesRead) + }) + } + + // Bytes read: + // jpeg: 4096 + // png: 4096 + // webp: 4096 + // gif: 4096 + // jpeg#01: 4096 + for _, c := range cases { + t.Run("DecodeConfig:"+c.format, func(t *testing.T) { + resp, err := http.Get(c.url) + assert.NoError(t, err) + defer resp.Body.Close() + reader := &CountingReader{reader: resp.Body} + config, format, err := image.DecodeConfig(reader) + assert.NoError(t, err) + assert.Equal(t, c.format, format) + assert.Equal(t, c.width, config.Width) + assert.Equal(t, c.height, config.Height) + t.Logf("Bytes read: %d", reader.BytesRead) + }) + } +} + +func TestBase64(t *testing.T) { + // Bytes read: + // jpeg: 1063892 + // png: 294462 + // webp: 99072 + // gif: 953856 + // jpeg#01: 32805 + for _, c := range cases { + t.Run("Decode:"+c.format, func(t *testing.T) { + resp, err := http.Get(c.url) + assert.NoError(t, err) + defer resp.Body.Close() + data, err := io.ReadAll(resp.Body) + assert.NoError(t, err) + encoded := base64.StdEncoding.EncodeToString(data) + body := base64.NewDecoder(base64.StdEncoding, strings.NewReader(encoded)) + reader := &CountingReader{reader: body} + img, format, err := image.Decode(reader) + assert.NoError(t, err) + size := img.Bounds().Size() + assert.Equal(t, c.format, format) + assert.Equal(t, c.width, size.X) + assert.Equal(t, c.height, size.Y) + t.Logf("Bytes read: %d", reader.BytesRead) + }) + } + + // Bytes read: + // jpeg: 1536 + // png: 768 + // webp: 768 + // gif: 1536 + // jpeg#01: 3840 + for _, c := range cases { + t.Run("DecodeConfig:"+c.format, func(t *testing.T) { + resp, err := http.Get(c.url) + assert.NoError(t, err) + defer resp.Body.Close() + data, err := 
io.ReadAll(resp.Body) + assert.NoError(t, err) + encoded := base64.StdEncoding.EncodeToString(data) + body := base64.NewDecoder(base64.StdEncoding, strings.NewReader(encoded)) + reader := &CountingReader{reader: body} + config, format, err := image.DecodeConfig(reader) + assert.NoError(t, err) + assert.Equal(t, c.format, format) + assert.Equal(t, c.width, config.Width) + assert.Equal(t, c.height, config.Height) + t.Logf("Bytes read: %d", reader.BytesRead) + }) + } +} + +func TestGetImageSize(t *testing.T) { + for i, c := range cases { + t.Run("Decode:"+strconv.Itoa(i), func(t *testing.T) { + width, height, err := img.GetImageSize(c.url) + assert.NoError(t, err) + assert.Equal(t, c.width, width) + assert.Equal(t, c.height, height) + }) + } +} diff --git a/common/init.go b/common/init.go index 1e8b4dcd..7e3fb16f 100644 --- a/common/init.go +++ b/common/init.go @@ -43,7 +43,11 @@ func init() { } if os.Getenv("SESSION_SECRET") != "" { - SessionSecret = os.Getenv("SESSION_SECRET") + if os.Getenv("SESSION_SECRET") == "random_string" { + SysError("SESSION_SECRET is set to an example value, please change it to a random string.") + } else { + SessionSecret = os.Getenv("SESSION_SECRET") + } } if os.Getenv("SQLITE_PATH") != "" { SQLitePath = os.Getenv("SQLITE_PATH") diff --git a/common/model-ratio.go b/common/model-ratio.go index 74c74a90..d1c96d96 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -76,17 +76,22 @@ var ModelRatio = map[string]float64{ "dall-e-3": 20, // $0.040 - $0.120 / image "claude-instant-1": 0.815, // $1.63 / 1M tokens "claude-2": 5.51, // $11.02 / 1M tokens + "claude-2.0": 5.51, // $11.02 / 1M tokens + "claude-2.1": 5.51, // $11.02 / 1M tokens "ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens "ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens "ERNIE-Bot-4": 8.572, // ¥0.12 / 1k tokens "Embedding-V1": 0.1429, // ¥0.002 / 1k tokens "PaLM-2": 1, + "gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens "chatglm_turbo": 0.3572, // ¥0.005 / 
1k tokens "chatglm_pro": 0.7143, // ¥0.01 / 1k tokens "chatglm_std": 0.3572, // ¥0.005 / 1k tokens "chatglm_lite": 0.1429, // ¥0.002 / 1k tokens - "qwen-turbo": 0.8572, // ¥0.012 / 1k tokens - "qwen-plus": 10, // ¥0.14 / 1k tokens + "qwen-turbo": 0.5715, // ¥0.008 / 1k tokens // https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions-metering-and-billing + "qwen-plus": 1.4286, // ¥0.02 / 1k tokens + "qwen-max": 1.4286, // ¥0.02 / 1k tokens + "qwen-max-longcontext": 1.4286, // ¥0.02 / 1k tokens "text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens "SparkDesk": 1.2858, // ¥0.018 / 1k tokens "360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens diff --git a/common/token.go b/common/token.go index 59d58343..887989ad 100644 --- a/common/token.go +++ b/common/token.go @@ -3,8 +3,10 @@ package common import ( "errors" "fmt" + "math" "strings" + "one-api/common/image" "one-api/types" "github.com/pkoukk/tiktoken-go" @@ -79,6 +81,33 @@ func CountTokenMessages(messages []types.ChatCompletionMessage, model string) in tokenNum := 0 for _, message := range messages { tokenNum += tokensPerMessage + switch v := message.Content.(type) { + case string: + tokenNum += getTokenNum(tokenEncoder, v) + case []any: + for _, it := range v { + m := it.(map[string]any) + switch m["type"] { + case "text": + tokenNum += getTokenNum(tokenEncoder, m["text"].(string)) + case "image_url": + imageUrl, ok := m["image_url"].(map[string]any) + if ok { + url := imageUrl["url"].(string) + detail := "" + if imageUrl["detail"] != nil { + detail = imageUrl["detail"].(string) + } + imageTokens, err := countImageTokens(url, detail) + if err != nil { + SysError("error counting image tokens: " + err.Error()) + } else { + tokenNum += imageTokens + } + } + } + } + } tokenNum += getTokenNum(tokenEncoder, message.StringContent()) tokenNum += getTokenNum(tokenEncoder, message.Role) if message.Name != nil { @@ -90,16 +119,84 @@ func CountTokenMessages(messages []types.ChatCompletionMessage, model 
string) in return tokenNum } +const ( + lowDetailCost = 85 + highDetailCostPerTile = 170 + additionalCost = 85 +) + +// https://platform.openai.com/docs/guides/vision/calculating-costs +// https://github.com/openai/openai-cookbook/blob/05e3f9be4c7a2ae7ecf029a7c32065b024730ebe/examples/How_to_count_tokens_with_tiktoken.ipynb +func countImageTokens(url string, detail string) (_ int, err error) { + var fetchSize = true + var width, height int + // Reference: https://platform.openai.com/docs/guides/vision/low-or-high-fidelity-image-understanding + // detail == "auto" is undocumented on how it works, it just said the model will use the auto setting which will look at the image input size and decide if it should use the low or high setting. + // According to the official guide, "low" disable the high-res model, + // and only receive low-res 512px x 512px version of the image, indicating + // that image is treated as low-res when size is smaller than 512px x 512px, + // then we can assume that image size larger than 512px x 512px is treated + // as high-res. Then we have the following logic: + // if detail == "" || detail == "auto" { + // width, height, err = image.GetImageSize(url) + // if err != nil { + // return 0, err + // } + // fetchSize = false + // // not sure if this is correct + // if width > 512 || height > 512 { + // detail = "high" + // } else { + // detail = "low" + // } + // } + + // However, in my test, it seems to be always the same as "high". + // The following image, which is 125x50, is still treated as high-res, taken + // 255 tokens in the response of non-stream chat completion api. 
+ // https://upload.wikimedia.org/wikipedia/commons/1/10/18_Infantry_Division_Messina.jpg + if detail == "" || detail == "auto" { + // assume by test, not sure if this is correct + detail = "high" + } + switch detail { + case "low": + return lowDetailCost, nil + case "high": + if fetchSize { + width, height, err = image.GetImageSize(url) + if err != nil { + return 0, err + } + } + if width > 2048 || height > 2048 { // max(width, height) > 2048 + ratio := float64(2048) / math.Max(float64(width), float64(height)) + width = int(float64(width) * ratio) + height = int(float64(height) * ratio) + } + if width > 768 && height > 768 { // min(width, height) > 768 + ratio := float64(768) / math.Min(float64(width), float64(height)) + width = int(float64(width) * ratio) + height = int(float64(height) * ratio) + } + numSquares := int(math.Ceil(float64(width)/512) * math.Ceil(float64(height)/512)) + result := numSquares*highDetailCostPerTile + additionalCost + return result, nil + default: + return 0, errors.New("invalid detail option") + } +} + func CountTokenInput(input any, model string) int { - switch input.(type) { + switch v := input.(type) { case string: - return CountTokenText(input.(string), model) + return CountTokenInput(v, model) case []string: text := "" - for _, s := range input.([]string) { + for _, s := range v { text += s } - return CountTokenText(text, model) + return CountTokenInput(text, model) } return 0 } diff --git a/controller/channel-test.go b/controller/channel-test.go index 1f3d5b71..e8024778 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -68,11 +68,15 @@ func testChannel(channel *model.Channel, request types.ChatCompletionRequest) (e } promptTokens := common.CountTokenMessages(request.Messages, request.Model) - _, openAIErrorWithStatusCode := chatProvider.ChatAction(&request, true, promptTokens) + Usage, openAIErrorWithStatusCode := chatProvider.ChatAction(&request, true, promptTokens) if openAIErrorWithStatusCode != nil { 
return nil, &openAIErrorWithStatusCode.OpenAIError } + if Usage.CompletionTokens == 0 { + return errors.New(fmt.Sprintf("channel %s, message 补全 tokens 非预期返回 0", channel.Name)), nil + } + return nil, nil } @@ -134,20 +138,32 @@ func TestChannel(c *gin.Context) { var testAllChannelsLock sync.Mutex var testAllChannelsRunning bool = false -// disable & notify -func disableChannel(channelId int, channelName string, reason string) { +func notifyRootUser(subject string, content string) { if common.RootUserEmail == "" { common.RootUserEmail = model.GetRootUserEmail() } - model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled) - subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId) - content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason) err := common.SendEmail(subject, common.RootUserEmail, content) if err != nil { common.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) } } +// disable & notify +func disableChannel(channelId int, channelName string, reason string) { + model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled) + subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId) + content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason) + notifyRootUser(subject, content) +} + +// enable & notify +func enableChannel(channelId int, channelName string) { + model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled) + subject := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId) + content := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId) + notifyRootUser(subject, content) +} + func testAllChannels(notify bool) error { if common.RootUserEmail == "" { common.RootUserEmail = model.GetRootUserEmail() @@ -170,9 +186,7 @@ func testAllChannels(notify bool) error { } go func() { for _, channel := range channels { - if channel.Status != common.ChannelStatusEnabled { - continue - } + isChannelEnabled := channel.Status == 
common.ChannelStatusEnabled tik := time.Now() err, openaiErr := testChannel(channel, *testRequest) tok := time.Now() @@ -181,9 +195,12 @@ func testAllChannels(notify bool) error { err = fmt.Errorf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0) disableChannel(channel.Id, channel.Name, err.Error()) } - if shouldDisableChannel(openaiErr, -1) { + if isChannelEnabled && shouldDisableChannel(openaiErr, -1) { disableChannel(channel.Id, channel.Name, err.Error()) } + if !isChannelEnabled && shouldEnableChannel(err, openaiErr) { + enableChannel(channel.Id, channel.Name) + } channel.UpdateResponseTime(milliseconds) time.Sleep(common.RequestInterval) } diff --git a/controller/model.go b/controller/model.go index 1bae0f74..0a78fd17 100644 --- a/controller/model.go +++ b/controller/model.go @@ -361,6 +361,24 @@ func init() { Root: "claude-2", Parent: nil, }, + { + Id: "claude-2.1", + Object: "model", + Created: 1677649963, + OwnedBy: "anthropic", + Permission: permission, + Root: "claude-2.1", + Parent: nil, + }, + { + Id: "claude-2.0", + Object: "model", + Created: 1677649963, + OwnedBy: "anthropic", + Permission: permission, + Root: "claude-2.0", + Parent: nil, + }, { Id: "ERNIE-Bot", Object: "model", @@ -406,6 +424,15 @@ func init() { Root: "PaLM-2", Parent: nil, }, + { + Id: "gemini-pro", + Object: "model", + Created: 1677649963, + OwnedBy: "google", + Permission: permission, + Root: "gemini-pro", + Parent: nil, + }, { Id: "chatglm_turbo", Object: "model", @@ -460,6 +487,24 @@ func init() { Root: "qwen-plus", Parent: nil, }, + { + Id: "qwen-max", + Object: "model", + Created: 1677649963, + OwnedBy: "ali", + Permission: permission, + Root: "qwen-max", + Parent: nil, + }, + { + Id: "qwen-max-longcontext", + Object: "model", + Created: 1677649963, + OwnedBy: "ali", + Permission: permission, + Root: "qwen-max-longcontext", + Parent: nil, + }, { Id: "text-embedding-v1", Object: "model", diff --git a/controller/relay-chat.go 
b/controller/relay-chat.go index 17dc8039..a1e93e25 100644 --- a/controller/relay-chat.go +++ b/controller/relay-chat.go @@ -2,6 +2,7 @@ package controller import ( "context" + "math" "net/http" "one-api/common" "one-api/model" @@ -24,6 +25,11 @@ func RelayChat(c *gin.Context) { return } + if chatRequest.MaxTokens < 0 || chatRequest.MaxTokens > math.MaxInt32/2 { + common.AbortWithMessage(c, http.StatusBadRequest, "max_tokens is invalid") + return + } + // 解析模型映射 var isModelMapped bool modelMap, err := parseModelMapping(channel.GetModelMapping()) diff --git a/controller/relay-completions.go b/controller/relay-completions.go index c6f7ab86..da60a773 100644 --- a/controller/relay-completions.go +++ b/controller/relay-completions.go @@ -2,6 +2,7 @@ package controller import ( "context" + "math" "net/http" "one-api/common" "one-api/model" @@ -24,6 +25,11 @@ func RelayCompletions(c *gin.Context) { return } + if completionRequest.MaxTokens < 0 || completionRequest.MaxTokens > math.MaxInt32/2 { + common.AbortWithMessage(c, http.StatusBadRequest, "max_tokens is invalid") + return + } + // 解析模型映射 var isModelMapped bool modelMap, err := parseModelMapping(channel.GetModelMapping()) diff --git a/controller/relay-image-generations.go b/controller/relay-image-generations.go index f339b79c..20092e0e 100644 --- a/controller/relay-image-generations.go +++ b/controller/relay-image-generations.go @@ -24,6 +24,10 @@ func RelayImageGenerations(c *gin.Context) { imageRequest.Model = "dall-e-2" } + if imageRequest.N == 0 { + imageRequest.N = 1 + } + if imageRequest.Size == "" { imageRequest.Size = "1024x1024" } diff --git a/controller/relay-utils.go b/controller/relay-utils.go index 6e930df3..3307b0d1 100644 --- a/controller/relay-utils.go +++ b/controller/relay-utils.go @@ -110,9 +110,14 @@ func setChannelToContext(c *gin.Context, channel *model.Channel) { c.Set("api_version", channel.Other) case common.ChannelTypeXunfei: c.Set("api_version", channel.Other) + case 
common.ChannelTypeGemini: + c.Set("api_version", channel.Other) case common.ChannelTypeAIProxyLibrary: c.Set("library_id", channel.Other) + case common.ChannelTypeAli: + c.Set("plugin", channel.Other) } + } func shouldDisableChannel(err *types.OpenAIError, statusCode int) bool { @@ -131,8 +136,22 @@ func shouldDisableChannel(err *types.OpenAIError, statusCode int) bool { return false } -func postConsumeQuota(ctx context.Context, tokenId int, quota int, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) { - err := model.PostConsumeTokenQuota(tokenId, quota) +func shouldEnableChannel(err error, openAIErr *types.OpenAIError) bool { + if !common.AutomaticEnableChannelEnabled { + return false + } + if err != nil { + return false + } + if openAIErr != nil { + return false + } + return true +} + +func postConsumeQuota(ctx context.Context, tokenId int, quotaDelta int, totalQuota int, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) { + // quotaDelta is remaining quota to be consumed + err := model.PostConsumeTokenQuota(tokenId, quotaDelta) if err != nil { common.SysError("error consuming token remain quota: " + err.Error()) } @@ -140,11 +159,15 @@ func postConsumeQuota(ctx context.Context, tokenId int, quota int, userId int, c if err != nil { common.SysError("error update user quota cache: " + err.Error()) } - if quota != 0 { + // totalQuota is total quota consumed + if totalQuota != 0 { logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) - model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent) - model.UpdateUserUsedQuotaAndRequestCount(userId, quota) - model.UpdateChannelUsedQuota(channelId, quota) + model.RecordConsumeLog(ctx, userId, channelId, totalQuota, 0, modelName, tokenName, totalQuota, logContent) + model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota) + model.UpdateChannelUsedQuota(channelId, 
totalQuota) + } + if totalQuota <= 0 { + common.LogError(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota)) } } diff --git a/go.mod b/go.mod index 632e2a6a..19b5b72d 100644 --- a/go.mod +++ b/go.mod @@ -15,7 +15,9 @@ require ( github.com/google/uuid v1.3.0 github.com/gorilla/websocket v1.5.0 github.com/pkoukk/tiktoken-go v0.1.5 + github.com/stretchr/testify v1.8.3 golang.org/x/crypto v0.14.0 + golang.org/x/image v0.14.0 gorm.io/driver/mysql v1.4.3 gorm.io/driver/postgres v1.5.2 gorm.io/driver/sqlite v1.4.3 @@ -26,6 +28,7 @@ require ( github.com/bytedance/sonic v1.9.1 // indirect github.com/cespare/xxhash/v2 v2.1.2 // indirect github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/dlclark/regexp2 v1.10.0 // indirect github.com/gabriel-vasile/mimetype v1.4.2 // indirect @@ -51,12 +54,13 @@ require ( github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pelletier/go-toml/v2 v2.0.8 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.11 // indirect golang.org/x/arch v0.3.0 // indirect golang.org/x/net v0.17.0 // indirect golang.org/x/sys v0.13.0 // indirect - golang.org/x/text v0.13.0 // indirect + golang.org/x/text v0.14.0 // indirect google.golang.org/protobuf v1.30.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 9d3407fe..a37fd320 100644 --- a/go.sum +++ b/go.sum @@ -154,6 +154,8 @@ golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc= 
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= +golang.org/x/image v0.14.0 h1:tNgSxAFe3jC4uYqvZdTr84SZoM1KfwdC9SKIFrLjFn4= +golang.org/x/image v0.14.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= @@ -170,8 +172,8 @@ golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9sn golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k= -golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/i18n/en.json b/i18n/en.json index 9b2ca4c8..b0deb83a 100644 --- a/i18n/en.json +++ b/i18n/en.json @@ -119,6 +119,7 @@ " 年 ": " y ", "未测试": "Not tested", "通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。": "Channel ${name} test succeeded, time consumed ${time.toFixed(2)} s.", + "已成功开始测试所有通道,请刷新页面查看结果。": "All channels have been successfully tested, please refresh the page to view the results.", "已成功开始测试所有已启用通道,请刷新页面查看结果。": "All enabled channels have been successfully tested, please refresh the page to view the 
results.", "通道 ${name} 余额更新成功!": "Channel ${name} balance updated successfully!", "已更新完毕所有已启用通道余额!": "The balance of all enabled channels has been updated!", @@ -139,6 +140,7 @@ "启用": "Enable", "编辑": "Edit", "添加新的渠道": "Add a new channel", + "测试所有通道": "Test all channels", "测试所有已启用通道": "Test all enabled channels", "更新所有已启用通道余额": "Update the balance of all enabled channels", "刷新": "Refresh", diff --git a/middleware/auth.go b/middleware/auth.go index b0803612..ad7e64b7 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -106,12 +106,6 @@ func TokenAuth() func(c *gin.Context) { c.Set("id", token.UserId) c.Set("token_id", token.Id) c.Set("token_name", token.Name) - requestURL := c.Request.URL.String() - consumeQuota := true - if strings.HasPrefix(requestURL, "/v1/models") { - consumeQuota = false - } - c.Set("consume_quota", consumeQuota) if len(parts) > 1 { if model.IsAdmin(token.UserId) { c.Set("channelId", parts[1]) diff --git a/middleware/distributor.go b/middleware/distributor.go index 72a1b362..b40ed496 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -11,7 +11,6 @@ func Distribute() func(c *gin.Context) { userId := c.GetInt("id") userGroup, _ := model.CacheGetUserGroup(userId) c.Set("group", userGroup) - c.Next() } } diff --git a/middleware/recover.go b/middleware/recover.go new file mode 100644 index 00000000..c3a3d748 --- /dev/null +++ b/middleware/recover.go @@ -0,0 +1,26 @@ +package middleware + +import ( + "fmt" + "github.com/gin-gonic/gin" + "net/http" + "one-api/common" +) + +func RelayPanicRecover() gin.HandlerFunc { + return func(c *gin.Context) { + defer func() { + if err := recover(); err != nil { + common.SysError(fmt.Sprintf("panic detected: %v", err)) + c.JSON(http.StatusInternalServerError, gin.H{ + "error": gin.H{ + "message": fmt.Sprintf("Panic detected, error: %v. 
Please submit an issue here: https://github.com/songquanpeng/one-api", err),
common.ApproximateTokenEnabled = boolValue case "LogConsumeEnabled": diff --git a/providers/ali/base.go b/providers/ali/base.go index f49067c0..72ee3317 100644 --- a/providers/ali/base.go +++ b/providers/ali/base.go @@ -33,6 +33,9 @@ func (p *AliProvider) GetRequestHeaders() (headers map[string]string) { headers = make(map[string]string) p.CommonRequestHeaders(headers) headers["Authorization"] = fmt.Sprintf("Bearer %s", p.Context.GetString("api_key")) + if p.Context.GetString("plugin") != "" { + headers["X-DashScope-Plugin"] = p.Context.GetString("plugin") + } return headers } diff --git a/providers/ali/chat.go b/providers/ali/chat.go index 30606298..68d8376f 100644 --- a/providers/ali/chat.go +++ b/providers/ali/chat.go @@ -53,32 +53,17 @@ func (aliResponse *AliChatResponse) ResponseHandler(resp *http.Response) (OpenAI // 获取聊天请求体 func (p *AliProvider) getChatRequestBody(request *types.ChatCompletionRequest) *AliChatRequest { messages := make([]AliMessage, 0, len(request.Messages)) - prompt := "" for i := 0; i < len(request.Messages); i++ { message := request.Messages[i] - if message.Role == "system" { - messages = append(messages, AliMessage{ - User: message.StringContent(), - Bot: "Okay", - }) - continue - } else { - if i == len(request.Messages)-1 { - prompt = message.StringContent() - break - } - messages = append(messages, AliMessage{ - User: message.StringContent(), - Bot: request.Messages[i+1].StringContent(), - }) - i++ - } + messages = append(messages, AliMessage{ + Content: message.StringContent(), + Role: strings.ToLower(message.Role), + }) } return &AliChatRequest{ Model: request.Model, Input: AliInput{ - Prompt: prompt, - History: messages, + Messages: messages, }, } } diff --git a/providers/ali/type.go b/providers/ali/type.go index e4c5d3d2..da24dcb3 100644 --- a/providers/ali/type.go +++ b/providers/ali/type.go @@ -13,13 +13,13 @@ type AliUsage struct { } type AliMessage struct { - User string `json:"user"` - Bot string `json:"bot"` + Content string 
`json:"content"` + Role string `json:"role"` } type AliInput struct { - Prompt string `json:"prompt"` - History []AliMessage `json:"history"` + // Prompt string `json:"prompt"` + Messages []AliMessage `json:"messages"` } type AliParameters struct { diff --git a/providers/claude/chat.go b/providers/claude/chat.go index cd5ba74d..cfff6bc0 100644 --- a/providers/claude/chat.go +++ b/providers/claude/chat.go @@ -69,7 +69,9 @@ func (p *ClaudeProvider) getChatRequestBody(request *types.ChatCompletionRequest } else if message.Role == "assistant" { prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content) } else if message.Role == "system" { - prompt += fmt.Sprintf("\n\nSystem: %s", message.Content) + if prompt == "" { + prompt = message.StringContent() + } } } prompt += "\n\nAssistant:" diff --git a/providers/gemini/base.go b/providers/gemini/base.go new file mode 100644 index 00000000..0e448cd6 --- /dev/null +++ b/providers/gemini/base.go @@ -0,0 +1,45 @@ +package gemini + +import ( + "fmt" + "one-api/providers/base" + "strings" + + "github.com/gin-gonic/gin" +) + +type GeminiProviderFactory struct{} + +// 创建 ClaudeProvider +func (f GeminiProviderFactory) Create(c *gin.Context) base.ProviderInterface { + return &GeminiProvider{ + BaseProvider: base.BaseProvider{ + BaseURL: "https://generativelanguage.googleapis.com", + ChatCompletions: "/", + Context: c, + }, + } +} + +type GeminiProvider struct { + base.BaseProvider +} + +func (p *GeminiProvider) GetFullRequestURL(requestURL string, modelName string) string { + baseURL := strings.TrimSuffix(p.GetBaseURL(), "/") + version := "v1" + if p.Context.GetString("api_version") != "" { + version = p.Context.GetString("api_version") + } + + return fmt.Sprintf("%s/%s/models/%s:%s?key=%s", baseURL, version, modelName, requestURL, p.Context.GetString("api_key")) + +} + +// 获取请求头 +func (p *GeminiProvider) GetRequestHeaders() (headers map[string]string) { + headers = make(map[string]string) + p.CommonRequestHeaders(headers) + + 
return headers +} diff --git a/providers/gemini/chat.go b/providers/gemini/chat.go new file mode 100644 index 00000000..9e2efa15 --- /dev/null +++ b/providers/gemini/chat.go @@ -0,0 +1,261 @@ +package gemini + +import ( + "bufio" + "encoding/json" + "fmt" + "io" + "net/http" + "one-api/common" + "one-api/providers/base" + "one-api/types" + "strings" +) + +func (response *GeminiChatResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) { + if len(response.Candidates) == 0 { + return nil, &types.OpenAIErrorWithStatusCode{ + OpenAIError: types.OpenAIError{ + Message: "No candidates returned", + Type: "server_error", + Param: "", + Code: 500, + }, + StatusCode: resp.StatusCode, + } + } + + fullTextResponse := &types.ChatCompletionResponse{ + ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), + Object: "chat.completion", + Created: common.GetTimestamp(), + Choices: make([]types.ChatCompletionChoice, 0, len(response.Candidates)), + } + for i, candidate := range response.Candidates { + choice := types.ChatCompletionChoice{ + Index: i, + Message: types.ChatCompletionMessage{ + Role: "assistant", + Content: "", + }, + FinishReason: base.StopFinishReason, + } + if len(candidate.Content.Parts) > 0 { + choice.Message.Content = candidate.Content.Parts[0].Text + } + fullTextResponse.Choices = append(fullTextResponse.Choices, choice) + } + + completionTokens := common.CountTokenText(response.GetResponseText(), "gemini-pro") + response.Usage.CompletionTokens = completionTokens + response.Usage.TotalTokens = response.Usage.PromptTokens + completionTokens + + return fullTextResponse, nil +} + +// Setting safety to the lowest possible values since Gemini is already powerless enough +func (p *GeminiProvider) getChatRequestBody(request *types.ChatCompletionRequest) (requestBody *GeminiChatRequest) { + geminiRequest := GeminiChatRequest{ + Contents: make([]GeminiChatContent, 0, len(request.Messages)), + //SafetySettings: 
[]GeminiChatSafetySettings{ + // { + // Category: "HARM_CATEGORY_HARASSMENT", + // Threshold: "BLOCK_ONLY_HIGH", + // }, + // { + // Category: "HARM_CATEGORY_HATE_SPEECH", + // Threshold: "BLOCK_ONLY_HIGH", + // }, + // { + // Category: "HARM_CATEGORY_SEXUALLY_EXPLICIT", + // Threshold: "BLOCK_ONLY_HIGH", + // }, + // { + // Category: "HARM_CATEGORY_DANGEROUS_CONTENT", + // Threshold: "BLOCK_ONLY_HIGH", + // }, + //}, + GenerationConfig: GeminiChatGenerationConfig{ + Temperature: request.Temperature, + TopP: request.TopP, + MaxOutputTokens: request.MaxTokens, + }, + } + if request.Functions != nil { + geminiRequest.Tools = []GeminiChatTools{ + { + FunctionDeclarations: request.Functions, + }, + } + } + shouldAddDummyModelMessage := false + for _, message := range request.Messages { + content := GeminiChatContent{ + Role: message.Role, + Parts: []GeminiPart{ + { + Text: message.StringContent(), + }, + }, + } + // there's no assistant role in gemini and API shall vomit if Role is not user or model + if content.Role == "assistant" { + content.Role = "model" + } + // Converting system prompt to prompt from user for the same reason + if content.Role == "system" { + content.Role = "user" + shouldAddDummyModelMessage = true + } + geminiRequest.Contents = append(geminiRequest.Contents, content) + + // If a system message is the last message, we need to add a dummy model message to make gemini happy + if shouldAddDummyModelMessage { + geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{ + Role: "model", + Parts: []GeminiPart{ + { + Text: "Okay", + }, + }, + }) + shouldAddDummyModelMessage = false + } + } + + return &geminiRequest +} + +func (p *GeminiProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { + requestBody := p.getChatRequestBody(request) + fullRequestURL := p.GetFullRequestURL("generateContent", request.Model) + headers := 
p.GetRequestHeaders() + if request.Stream { + headers["Accept"] = "text/event-stream" + } + + client := common.NewClient() + req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers)) + if err != nil { + return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) + } + + if request.Stream { + var responseText string + errWithCode, responseText = p.sendStreamRequest(req) + if errWithCode != nil { + return + } + + usage.PromptTokens = promptTokens + usage.CompletionTokens = common.CountTokenText(responseText, request.Model) + usage.TotalTokens = promptTokens + usage.CompletionTokens + + } else { + var geminiResponse = &GeminiChatResponse{ + Usage: &types.Usage{ + PromptTokens: promptTokens, + }, + } + errWithCode = p.SendRequest(req, geminiResponse, false) + if errWithCode != nil { + return + } + + usage = geminiResponse.Usage + } + return + +} + +func (p *GeminiProvider) streamResponseClaude2OpenAI(geminiResponse *GeminiChatResponse) *types.ChatCompletionStreamResponse { + var choice types.ChatCompletionStreamChoice + choice.Delta.Content = geminiResponse.GetResponseText() + choice.FinishReason = &base.StopFinishReason + var response types.ChatCompletionStreamResponse + response.Object = "chat.completion.chunk" + response.Model = "gemini" + response.Choices = []types.ChatCompletionStreamChoice{choice} + return &response +} + +func (p *GeminiProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErrorWithStatusCode, string) { + defer req.Body.Close() + + // 发送请求 + resp, err := common.HttpClient.Do(req) + if err != nil { + return common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), "" + } + + if common.IsFailureStatusCode(resp) { + return common.HandleErrorResp(resp), "" + } + + defer resp.Body.Close() + + responseText := "" + dataChan := make(chan string) + stopChan := make(chan bool) + scanner := bufio.NewScanner(resp.Body) + 
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, nil + } + if i := strings.Index(string(data), "\n"); i >= 0 { + return i + 1, data[0:i], nil + } + if atEOF { + return len(data), data, nil + } + return 0, nil, nil + }) + go func() { + for scanner.Scan() { + data := scanner.Text() + data = strings.TrimSpace(data) + if !strings.HasPrefix(data, "\"text\": \"") { + continue + } + data = strings.TrimPrefix(data, "\"text\": \"") + data = strings.TrimSuffix(data, "\"") + dataChan <- data + } + stopChan <- true + }() + common.SetEventStreamHeaders(p.Context) + p.Context.Stream(func(w io.Writer) bool { + select { + case data := <-dataChan: + // this is used to prevent annoying \ related format bug + data = fmt.Sprintf("{\"content\": \"%s\"}", data) + type dummyStruct struct { + Content string `json:"content"` + } + var dummy dummyStruct + err := json.Unmarshal([]byte(data), &dummy) + responseText += dummy.Content + var choice types.ChatCompletionStreamChoice + choice.Delta.Content = dummy.Content + response := types.ChatCompletionStreamResponse{ + ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), + Object: "chat.completion.chunk", + Created: common.GetTimestamp(), + Model: "gemini-pro", + Choices: []types.ChatCompletionStreamChoice{choice}, + } + jsonResponse, err := json.Marshal(response) + if err != nil { + common.SysError("error marshalling stream response: " + err.Error()) + return true + } + p.Context.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) + return true + case <-stopChan: + p.Context.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) + return false + } + }) + + return nil, responseText +} diff --git a/providers/gemini/type.go b/providers/gemini/type.go new file mode 100644 index 00000000..333dfcc7 --- /dev/null +++ b/providers/gemini/type.go @@ -0,0 +1,75 @@ +package gemini + +import "one-api/types" + +type GeminiChatRequest struct { + Contents 
[]GeminiChatContent `json:"contents"` + SafetySettings []GeminiChatSafetySettings `json:"safety_settings,omitempty"` + GenerationConfig GeminiChatGenerationConfig `json:"generation_config,omitempty"` + Tools []GeminiChatTools `json:"tools,omitempty"` +} + +type GeminiInlineData struct { + MimeType string `json:"mimeType"` + Data string `json:"data"` +} + +type GeminiPart struct { + Text string `json:"text,omitempty"` + InlineData *GeminiInlineData `json:"inlineData,omitempty"` +} + +type GeminiChatContent struct { + Role string `json:"role,omitempty"` + Parts []GeminiPart `json:"parts"` +} + +type GeminiChatSafetySettings struct { + Category string `json:"category"` + Threshold string `json:"threshold"` +} + +type GeminiChatTools struct { + FunctionDeclarations any `json:"functionDeclarations,omitempty"` +} + +type GeminiChatGenerationConfig struct { + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"topP,omitempty"` + TopK float64 `json:"topK,omitempty"` + MaxOutputTokens int `json:"maxOutputTokens,omitempty"` + CandidateCount int `json:"candidateCount,omitempty"` + StopSequences []string `json:"stopSequences,omitempty"` +} + +type GeminiChatResponse struct { + Candidates []GeminiChatCandidate `json:"candidates"` + PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"` + Usage *types.Usage `json:"usage,omitempty"` +} + +type GeminiChatCandidate struct { + Content GeminiChatContent `json:"content"` + FinishReason string `json:"finishReason"` + Index int64 `json:"index"` + SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"` +} + +type GeminiChatSafetyRating struct { + Category string `json:"category"` + Probability string `json:"probability"` +} + +type GeminiChatPromptFeedback struct { + SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"` +} + +func (g *GeminiChatResponse) GetResponseText() string { + if g == nil { + return "" + } + if len(g.Candidates) > 0 && len(g.Candidates[0].Content.Parts) > 0 { + return 
g.Candidates[0].Content.Parts[0].Text + } + return "" +} diff --git a/providers/openai/image_generations.go b/providers/openai/image_generations.go index d1f3ae09..ba5a450b 100644 --- a/providers/openai/image_generations.go +++ b/providers/openai/image_generations.go @@ -19,6 +19,10 @@ func (c *OpenAIProviderImageResponseResponse) ResponseHandler(resp *http.Respons func (p *OpenAIProvider) ImageGenerationsAction(request *types.ImageRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { + if isWithinRange(request.Model, request.N) == false { + return nil, common.StringErrorWrapper("n_not_within_range", "n_not_within_range", http.StatusBadRequest) + } + requestBody, err := p.GetRequestBody(&request, isModelMapped) if err != nil { return nil, common.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError) @@ -47,3 +51,13 @@ func (p *OpenAIProvider) ImageGenerationsAction(request *types.ImageRequest, isM return } + +func isWithinRange(element string, value int) bool { + if _, ok := common.DalleGenerationImageAmounts[element]; !ok { + return false + } + min := common.DalleGenerationImageAmounts[element][0] + max := common.DalleGenerationImageAmounts[element][1] + + return value >= min && value <= max +} diff --git a/providers/xunfei/chat.go b/providers/xunfei/chat.go index 31294e27..9a8c1317 100644 --- a/providers/xunfei/chat.go +++ b/providers/xunfei/chat.go @@ -46,6 +46,14 @@ func (p *XunfeiProvider) sendRequest(request *types.ChatCompletionRequest, authU } } + if len(xunfeiResponse.Payload.Choices.Text) == 0 { + xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{ + { + Content: "", + }, + } + } + xunfeiResponse.Payload.Choices.Text[0].Content = content response := p.responseXunfei2OpenAI(&xunfeiResponse) diff --git a/pull_request_template.md b/pull_request_template.md new file mode 100644 index 00000000..bbcd969c --- /dev/null +++ b/pull_request_template.md @@ -0,0 +1,3 @@ 
+close #issue_number + +我已确认该 PR 已自测通过,相关截图如下: \ No newline at end of file diff --git a/router/relay-router.go b/router/relay-router.go index 4f7d5c15..171864c6 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -17,7 +17,7 @@ func SetRelayRouter(router *gin.Engine) { modelsRouter.GET("/:model", controller.RetrieveModel) } relayV1Router := router.Group("/v1") - relayV1Router.Use(middleware.TokenAuth(), middleware.Distribute()) + relayV1Router.Use(middleware.RelayPanicRecover(), middleware.TokenAuth(), middleware.Distribute()) { relayV1Router.POST("/completions", controller.RelayCompletions) relayV1Router.POST("/chat/completions", controller.RelayChat) @@ -36,11 +36,37 @@ func SetRelayRouter(router *gin.Engine) { relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented) relayV1Router.GET("/files/:id", controller.RelayNotImplemented) relayV1Router.GET("/files/:id/content", controller.RelayNotImplemented) - relayV1Router.POST("/fine-tunes", controller.RelayNotImplemented) - relayV1Router.GET("/fine-tunes", controller.RelayNotImplemented) - relayV1Router.GET("/fine-tunes/:id", controller.RelayNotImplemented) - relayV1Router.POST("/fine-tunes/:id/cancel", controller.RelayNotImplemented) - relayV1Router.GET("/fine-tunes/:id/events", controller.RelayNotImplemented) + relayV1Router.POST("/fine_tuning/jobs", controller.RelayNotImplemented) + relayV1Router.GET("/fine_tuning/jobs", controller.RelayNotImplemented) + relayV1Router.GET("/fine_tuning/jobs/:id", controller.RelayNotImplemented) + relayV1Router.POST("/fine_tuning/jobs/:id/cancel", controller.RelayNotImplemented) + relayV1Router.GET("/fine_tuning/jobs/:id/events", controller.RelayNotImplemented) relayV1Router.DELETE("/models/:model", controller.RelayNotImplemented) + relayV1Router.POST("/assistants", controller.RelayNotImplemented) + relayV1Router.GET("/assistants/:id", controller.RelayNotImplemented) + relayV1Router.POST("/assistants/:id", controller.RelayNotImplemented) + 
relayV1Router.DELETE("/assistants/:id", controller.RelayNotImplemented) + relayV1Router.GET("/assistants", controller.RelayNotImplemented) + relayV1Router.POST("/assistants/:id/files", controller.RelayNotImplemented) + relayV1Router.GET("/assistants/:id/files/:fileId", controller.RelayNotImplemented) + relayV1Router.DELETE("/assistants/:id/files/:fileId", controller.RelayNotImplemented) + relayV1Router.GET("/assistants/:id/files", controller.RelayNotImplemented) + relayV1Router.POST("/threads", controller.RelayNotImplemented) + relayV1Router.GET("/threads/:id", controller.RelayNotImplemented) + relayV1Router.POST("/threads/:id", controller.RelayNotImplemented) + relayV1Router.DELETE("/threads/:id", controller.RelayNotImplemented) + relayV1Router.POST("/threads/:id/messages", controller.RelayNotImplemented) + relayV1Router.GET("/threads/:id/messages/:messageId", controller.RelayNotImplemented) + relayV1Router.POST("/threads/:id/messages/:messageId", controller.RelayNotImplemented) + relayV1Router.GET("/threads/:id/messages/:messageId/files/:filesId", controller.RelayNotImplemented) + relayV1Router.GET("/threads/:id/messages/:messageId/files", controller.RelayNotImplemented) + relayV1Router.POST("/threads/:id/runs", controller.RelayNotImplemented) + relayV1Router.GET("/threads/:id/runs/:runsId", controller.RelayNotImplemented) + relayV1Router.POST("/threads/:id/runs/:runsId", controller.RelayNotImplemented) + relayV1Router.GET("/threads/:id/runs", controller.RelayNotImplemented) + relayV1Router.POST("/threads/:id/runs/:runsId/submit_tool_outputs", controller.RelayNotImplemented) + relayV1Router.POST("/threads/:id/runs/:runsId/cancel", controller.RelayNotImplemented) + relayV1Router.GET("/threads/:id/runs/:runsId/steps/:stepId", controller.RelayNotImplemented) + relayV1Router.GET("/threads/:id/runs/:runsId/steps", controller.RelayNotImplemented) } } diff --git a/web/src/constants/ChannelConstants.js b/web/src/constants/ChannelConstants.js index f81fb994..c91c5de8 
100644 --- a/web/src/constants/ChannelConstants.js +++ b/web/src/constants/ChannelConstants.js @@ -23,6 +23,12 @@ export const CHANNEL_OPTIONS = { value: 11, color: 'orange' }, + 25: { + key: 25, + text: 'Google Gemini', + value: 25, + color: 'orange' + }, 15: { key: 15, text: '百度文心千帆', diff --git a/web/src/views/Channel/index.js b/web/src/views/Channel/index.js index 59dd699b..2c47614f 100644 --- a/web/src/views/Channel/index.js +++ b/web/src/views/Channel/index.js @@ -137,7 +137,7 @@ export default function ChannelPage() { const res = await API.get(`/api/channel/test`); const { success, message } = res.data; if (success) { - showInfo('已成功开始测试所有已启用通道,请刷新页面查看结果。'); + showInfo('已成功开始测试所有通道,请刷新页面查看结果。'); } else { showError(message); } diff --git a/web/src/views/Channel/type/Config.js b/web/src/views/Channel/type/Config.js index c790313f..576fbdd4 100644 --- a/web/src/views/Channel/type/Config.js +++ b/web/src/views/Channel/type/Config.js @@ -50,7 +50,7 @@ const typeConfig = { }, 14: { input: { - models: ['claude-instant-1', 'claude-2'] + models: ['claude-instant-1', 'claude-2', 'claude-2.0', 'claude-2.1'] } }, 15: { @@ -67,8 +67,14 @@ const typeConfig = { } }, 17: { + inputLabel: { + other: '插件参数' + }, input: { - models: ['qwen-turbo', 'qwen-plus', 'text-embedding-v1'] + models: ['qwen-turbo', 'qwen-plus', 'qwen-max', 'qwen-max-longcontext', 'text-embedding-v1'] + }, + prompt: { + other: '请输入插件参数,即 X-DashScope-Plugin 请求头的取值' } }, 18: { @@ -96,6 +102,11 @@ const typeConfig = { prompt: { key: '按照如下格式输入:AppId|SecretId|SecretKey' } + }, + 25: { + input: { + models: ['gemini-pro'] + } } }; diff --git a/web/src/views/Setting/component/OperationSetting.js b/web/src/views/Setting/component/OperationSetting.js index 820ca2e6..a52b57d8 100644 --- a/web/src/views/Setting/component/OperationSetting.js +++ b/web/src/views/Setting/component/OperationSetting.js @@ -23,6 +23,7 @@ const OperationSetting = () => { ChatLink: '', QuotaPerUnit: 0, AutomaticDisableChannelEnabled: '', + 
AutomaticEnableChannelEnabled: '', ChannelDisableThreshold: 0, LogConsumeEnabled: '', DisplayInCurrencyEnabled: '', @@ -327,6 +328,16 @@ const OperationSetting = () => { /> } /> + + } + />