diff --git a/.gitignore b/.gitignore index 1b2cf071..60abb13e 100644 --- a/.gitignore +++ b/.gitignore @@ -5,4 +5,5 @@ upload *.db build *.db-journal -logs \ No newline at end of file +logs +data \ No newline at end of file diff --git a/README.en.md b/README.en.md index 783c140c..9345a219 100644 --- a/README.en.md +++ b/README.en.md @@ -189,6 +189,8 @@ If you encounter a blank page after deployment, refer to [#97](https://github.co > Zeabur's servers are located overseas, automatically solving network issues, and the free quota is sufficient for personal usage. +[![Deploy on Zeabur](https://zeabur.com/button.svg)](https://zeabur.com/templates/7Q0KO3) + 1. First, fork the code. 2. Go to [Zeabur](https://zeabur.com?referralCode=songquanpeng), log in, and enter the console. 3. Create a new project. In Service -> Add Service, select Marketplace, and choose MySQL. Note down the connection parameters (username, password, address, and port). diff --git a/README.ja.md b/README.ja.md index fa3339c2..6faf9bee 100644 --- a/README.ja.md +++ b/README.ja.md @@ -190,6 +190,8 @@ Please refer to the [environment variables](#environment-variables) section for > Zeabur のサーバーは海外にあるため、ネットワークの問題は自動的に解決されます。 +[![Deploy on Zeabur](https://zeabur.com/button.svg)](https://zeabur.com/templates/7Q0KO3) + 1. まず、コードをフォークする。 2. [Zeabur](https://zeabur.com?referralCode=songquanpeng) にアクセスしてログインし、コンソールに入る。 3. 新しいプロジェクトを作成します。Service -> Add ServiceでMarketplace を選択し、MySQL を選択する。接続パラメータ(ユーザー名、パスワード、アドレス、ポート)をメモします。 diff --git a/README.md b/README.md index cb641947..20c81361 100644 --- a/README.md +++ b/README.md @@ -75,7 +75,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 + [x] [腾讯混元大模型](https://cloud.tencent.com/document/product/1729) 2. 支持配置镜像以及众多第三方代理服务: + [x] [OpenAI-SB](https://openai-sb.com) - + [x] [CloseAI](https://console.closeai-asia.com/r/2412) + + [x] [CloseAI](https://referer.shadowai.xyz/r/2412) + [x] [API2D](https://api2d.com/r/197971) + [x] [OhMyGPT](https://aigptx.top?aff=uFpUl2Kf) + [x] [AI Proxy](https://aiproxy.io/?i=OneAPI) (邀请码:`OneAPI`) @@ -92,15 +92,16 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 12. 支持**用户邀请奖励**。 13. 支持以美元为单位显示额度。 14. 支持发布公告,设置充值链接,设置新用户初始额度。 -15. 支持模型映射,重定向用户的请求模型。 +15. 支持模型映射,重定向用户的请求模型,如无必要请不要设置,设置之后会导致请求体被重新构造而非直接透传,会导致部分还未正式支持的字段无法传递成功。 16. 支持失败自动重试。 17. 支持绘图接口。 -18. 支持丰富的**自定义**设置, +18. 支持 [Cloudflare AI Gateway](https://developers.cloudflare.com/ai-gateway/providers/openai/),渠道设置的代理部分填写 `https://gateway.ai.cloudflare.com/v1/ACCOUNT_TAG/GATEWAY/openai` 即可。 +19. 支持丰富的**自定义**设置, 1. 支持自定义系统名称,logo 以及页脚。 2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入。 -19. 支持通过系统访问令牌访问管理 API。 -20. 支持 Cloudflare Turnstile 用户校验。 -21. 支持用户管理,支持**多种用户登录注册方式**: +20. 支持通过系统访问令牌访问管理 API(bearer token,用以替代 cookie,你可以自行抓包来查看 API 的用法)。 +21. 支持 Cloudflare Turnstile 用户校验。 +22. 支持用户管理,支持**多种用户登录注册方式**: + 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。 + [GitHub 开放授权](https://github.com/settings/applications/new)。 + 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。 @@ -159,6 +160,19 @@ sudo service nginx restart 初始账号用户名为 `root`,密码为 `123456`。 + +### 基于 Docker Compose 进行部署 + +> 仅启动方式不同,参数设置不变,请参考基于 Docker 部署部分 + +```shell +# 目前支持 MySQL 启动,数据存储在 ./data/mysql 文件夹内 +docker-compose up -d + +# 查看部署状态 +docker-compose ps +``` + ### 手动部署 1. 从 [GitHub Releases](https://github.com/songquanpeng/one-api/releases/latest) 下载可执行文件或者从源码编译: ```shell @@ -248,6 +262,8 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope > Zeabur 的服务器在国外,自动解决了网络的问题,同时免费的额度也足够个人使用 +[![Deploy on Zeabur](https://zeabur.com/button.svg)](https://zeabur.com/templates/7Q0KO3) + 1. 首先 fork 一份代码。 2. 进入 [Zeabur](https://zeabur.com?referralCode=songquanpeng),登录,进入控制台。 3. 新建一个 Project,在 Service -> Add Service 选择 Marketplace,选择 MySQL,并记下连接参数(用户名、密码、地址、端口)。 @@ -351,6 +367,10 @@ graph LR 13. 请求频率限制: + `GLOBAL_API_RATE_LIMIT`:全局 API 速率限制(除中继请求外),单 ip 三分钟内的最大请求数,默认为 `180`。 + `GLOBAL_WEB_RATE_LIMIT`:全局 Web 速率限制,单 ip 三分钟内的最大请求数,默认为 `60`。 +14. 编码器缓存设置: + + `TIKTOKEN_CACHE_DIR`:默认程序启动时会联网下载一些通用的词元的编码,如:`gpt-3.5-turbo`,在一些网络环境不稳定,或者离线情况,可能会导致启动有问题,可以配置此目录缓存数据,可迁移到离线环境。 + + `DATA_GYM_CACHE_DIR`:目前该配置作用与 `TIKTOKEN_CACHE_DIR` 一致,但是优先级没有它高。 +15. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。 ### 命令行参数 1. `--port `: 指定服务器监听的端口号,默认为 `3000`。 diff --git a/common/constants.go b/common/constants.go index a0361c35..c7d3f222 100644 --- a/common/constants.go +++ b/common/constants.go @@ -21,12 +21,9 @@ var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens var DisplayInCurrencyEnabled = true var DisplayTokenStatEnabled = true -var UsingSQLite = false - // Any options with "Secret", "Token" in its key won't be return by GetOptions var SessionSecret = uuid.New().String() -var SQLitePath = "one-api.db" var OptionMap map[string]string var OptionMapRWMutex sync.RWMutex @@ -98,6 +95,8 @@ var SyncFrequency = GetOrDefault("SYNC_FREQUENCY", 10*60) // unit is second var BatchUpdateEnabled = false var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5) +var RelayTimeout = GetOrDefault("RELAY_TIMEOUT", 0) // unit is second + const ( RequestIdKey = "X-Oneapi-Request-Id" ) diff --git a/common/database.go b/common/database.go new file mode 100644 index 00000000..c7e9fd52 --- /dev/null +++ b/common/database.go @@ -0,0 +1,6 @@ +package common + +var UsingSQLite = false +var UsingPostgreSQL = false + +var SQLitePath = "one-api.db" diff --git a/common/gin.go b/common/gin.go index ffa1e218..f5012688 100644 --- a/common/gin.go +++ b/common/gin.go @@ -5,6 +5,7 @@ import ( "encoding/json" "github.com/gin-gonic/gin" "io" + "strings" ) func UnmarshalBodyReusable(c *gin.Context, v any) error { @@ -16,7 +17,13 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error { if err != nil { return err } - err = json.Unmarshal(requestBody, &v) + contentType := c.Request.Header.Get("Content-Type") + if strings.HasPrefix(contentType, "application/json") { + err = json.Unmarshal(requestBody, &v) + } else { + // skip for now + // TODO: someday non json request have variant model, we will need to implementation this + } if err != nil { return err } diff --git a/common/model-ratio.go b/common/model-ratio.go index b121ba4f..9060f3be 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -3,8 +3,32 @@ package common import ( "encoding/json" "strings" + "time" ) +var DalleSizeRatios = map[string]map[string]float64{ + "dall-e-2": { + "256x256": 1, + "512x512": 1.125, + "1024x1024": 1.25, + }, + "dall-e-3": { + "1024x1024": 1, + "1024x1792": 2, + "1792x1024": 2, + }, +} + +var DalleGenerationImageAmounts = map[string][2]int{ + "dall-e-2": {1, 10}, + "dall-e-3": {1, 1}, // OpenAI allows n=1 currently. +} + +var DalleImagePromptLengthLimitations = map[string]int{ + "dall-e-2": 1000, + "dall-e-3": 4000, +} + // ModelRatio // https://platform.openai.com/docs/models/model-endpoint-compatibility // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Blfmc9dlf @@ -19,12 +43,15 @@ var ModelRatio = map[string]float64{ "gpt-4-32k": 30, "gpt-4-32k-0314": 30, "gpt-4-32k-0613": 30, + "gpt-4-1106-preview": 5, // $0.01 / 1K tokens + "gpt-4-vision-preview": 5, // $0.01 / 1K tokens "gpt-3.5-turbo": 0.75, // $0.0015 / 1K tokens "gpt-3.5-turbo-0301": 0.75, "gpt-3.5-turbo-0613": 0.75, "gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens "gpt-3.5-turbo-16k-0613": 1.5, "gpt-3.5-turbo-instruct": 0.75, // $0.0015 / 1K tokens + "gpt-3.5-turbo-1106": 0.5, // $0.001 / 1K tokens "text-ada-001": 0.2, "text-babbage-001": 0.25, "text-curie-001": 1, @@ -32,7 +59,11 @@ var ModelRatio = map[string]float64{ "text-davinci-003": 10, "text-davinci-edit-001": 10, "code-davinci-edit-001": 10, - "whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens + "whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens + "tts-1": 7.5, // $0.015 / 1K characters + "tts-1-1106": 7.5, + "tts-1-hd": 15, // $0.030 / 1K characters + "tts-1-hd-1106": 15, "davinci": 10, "curie": 10, "babbage": 10, @@ -41,13 +72,16 @@ var ModelRatio = map[string]float64{ "text-search-ada-doc-001": 10, "text-moderation-stable": 0.1, "text-moderation-latest": 0.1, - "dall-e": 8, + "dall-e-2": 8, // $0.016 - $0.020 / image + "dall-e-3": 20, // $0.040 - $0.120 / image "claude-instant-1": 0.815, // $1.63 / 1M tokens "claude-2": 5.51, // $11.02 / 1M tokens "ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens "ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens + "ERNIE-Bot-4": 8.572, // ¥0.12 / 1k tokens "Embedding-V1": 0.1429, // ¥0.002 / 1k tokens "PaLM-2": 1, + "chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens "chatglm_pro": 0.7143, // ¥0.01 / 1k tokens "chatglm_std": 0.3572, // ¥0.005 / 1k tokens "chatglm_lite": 0.1429, // ¥0.002 / 1k tokens @@ -87,9 +121,24 @@ func GetModelRatio(name string) float64 { func GetCompletionRatio(name string) float64 { if strings.HasPrefix(name, "gpt-3.5") { + if strings.HasSuffix(name, "1106") { + return 2 + } + if name == "gpt-3.5-turbo" || name == "gpt-3.5-turbo-16k" { + // TODO: clear this after 2023-12-11 + now := time.Now() + // https://platform.openai.com/docs/models/continuous-model-upgrades + // if after 2023-12-11, use 2 + if now.After(time.Date(2023, 12, 11, 0, 0, 0, 0, time.UTC)) { + return 2 + } + } return 1.333333 } if strings.HasPrefix(name, "gpt-4") { + if strings.HasSuffix(name, "preview") { + return 3 + } return 2 } if strings.HasPrefix(name, "claude-instant-1") { diff --git a/common/utils.go b/common/utils.go index ab901b77..21bec8f5 100644 --- a/common/utils.go +++ b/common/utils.go @@ -199,3 +199,11 @@ func GetOrDefault(env string, defaultValue int) int { func MessageWithRequestId(message string, id string) string { return fmt.Sprintf("%s (request id: %s)", message, id) } + +func String2Int(str string) int { + num, err := strconv.Atoi(str) + if err != nil { + return 0 + } + return num +} diff --git a/controller/channel-test.go b/controller/channel-test.go index 1974ef6e..1b0b745a 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -5,13 +5,15 @@ import ( "encoding/json" "errors" "fmt" - "github.com/gin-gonic/gin" + "io" "net/http" "one-api/common" "one-api/model" "strconv" "sync" "time" + + "github.com/gin-gonic/gin" ) func testChannel(channel *model.Channel, request ChatRequest) (err error, openaiErr *OpenAIError) { @@ -42,14 +44,14 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai } requestURL := common.ChannelBaseURLs[channel.Type] if channel.Type == common.ChannelTypeAzure { - requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.GetBaseURL(), request.Model) + requestURL = getFullRequestURL(channel.GetBaseURL(), fmt.Sprintf("/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", request.Model), channel.Type) } else { - if channel.GetBaseURL() != "" { - requestURL = channel.GetBaseURL() + if baseURL := channel.GetBaseURL(); len(baseURL) > 0 { + requestURL = baseURL } - requestURL += "/v1/chat/completions" - } + requestURL = getFullRequestURL(requestURL, "/v1/chat/completions", channel.Type) + } jsonData, err := json.Marshal(request) if err != nil { return err, nil @@ -70,10 +72,14 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai } defer resp.Body.Close() var response TextResponse - err = json.NewDecoder(resp.Body).Decode(&response) + body, err := io.ReadAll(resp.Body) if err != nil { return err, nil } + err = json.Unmarshal(body, &response) + if err != nil { + return fmt.Errorf("Error: %s\nResp body: %s", err, body), nil + } if response.Usage.CompletionTokens == 0 { return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)), &response.Error } diff --git a/controller/channel.go b/controller/channel.go index 41a55550..904abc23 100644 --- a/controller/channel.go +++ b/controller/channel.go @@ -127,8 +127,8 @@ func DeleteChannel(c *gin.Context) { return } -func DeleteManuallyDisabledChannel(c *gin.Context) { - rows, err := model.DeleteChannelByStatus(common.ChannelStatusManuallyDisabled) +func DeleteDisabledChannel(c *gin.Context) { + rows, err := model.DeleteDisabledChannel() if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, diff --git a/controller/model.go b/controller/model.go index a8e71db7..feb1a0bc 100644 --- a/controller/model.go +++ b/controller/model.go @@ -55,12 +55,21 @@ func init() { // https://platform.openai.com/docs/models/model-endpoint-compatibility openAIModels = []OpenAIModels{ { - Id: "dall-e", + Id: "dall-e-2", Object: "model", Created: 1677649963, OwnedBy: "openai", Permission: permission, - Root: "dall-e", + Root: "dall-e-2", + Parent: nil, + }, + { + Id: "dall-e-3", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "dall-e-3", Parent: nil, }, { @@ -72,6 +81,42 @@ func init() { Root: "whisper-1", Parent: nil, }, + { + Id: "tts-1", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "tts-1", + Parent: nil, + }, + { + Id: "tts-1-1106", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "tts-1-1106", + Parent: nil, + }, + { + Id: "tts-1-hd", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "tts-1-hd", + Parent: nil, + }, + { + Id: "tts-1-hd-1106", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "tts-1-hd-1106", + Parent: nil, + }, { Id: "gpt-3.5-turbo", Object: "model", @@ -117,6 +162,15 @@ func init() { Root: "gpt-3.5-turbo-16k-0613", Parent: nil, }, + { + Id: "gpt-3.5-turbo-1106", + Object: "model", + Created: 1699593571, + OwnedBy: "openai", + Permission: permission, + Root: "gpt-3.5-turbo-1106", + Parent: nil, + }, { Id: "gpt-3.5-turbo-instruct", Object: "model", @@ -180,6 +234,24 @@ func init() { Root: "gpt-4-32k-0613", Parent: nil, }, + { + Id: "gpt-4-1106-preview", + Object: "model", + Created: 1699593571, + OwnedBy: "openai", + Permission: permission, + Root: "gpt-4-1106-preview", + Parent: nil, + }, + { + Id: "gpt-4-vision-preview", + Object: "model", + Created: 1699593571, + OwnedBy: "openai", + Permission: permission, + Root: "gpt-4-vision-preview", + Parent: nil, + }, { Id: "text-embedding-ada-002", Object: "model", @@ -274,7 +346,7 @@ func init() { Id: "claude-instant-1", Object: "model", Created: 1677649963, - OwnedBy: "anturopic", + OwnedBy: "anthropic", Permission: permission, Root: "claude-instant-1", Parent: nil, @@ -283,7 +355,7 @@ func init() { Id: "claude-2", Object: "model", Created: 1677649963, - OwnedBy: "anturopic", + OwnedBy: "anthropic", Permission: permission, Root: "claude-2", Parent: nil, @@ -306,6 +378,15 @@ func init() { Root: "ERNIE-Bot-turbo", Parent: nil, }, + { + Id: "ERNIE-Bot-4", + Object: "model", + Created: 1677649963, + OwnedBy: "baidu", + Permission: permission, + Root: "ERNIE-Bot-4", + Parent: nil, + }, { Id: "Embedding-V1", Object: "model", @@ -324,6 +405,15 @@ func init() { Root: "PaLM-2", Parent: nil, }, + { + Id: "chatglm_turbo", + Object: "model", + Created: 1677649963, + OwnedBy: "zhipu", + Permission: permission, + Root: "chatglm_turbo", + Parent: nil, + }, { Id: "chatglm_pro", Object: "model", diff --git a/controller/relay-aiproxy.go b/controller/relay-aiproxy.go index d0159ce8..543954f7 100644 --- a/controller/relay-aiproxy.go +++ b/controller/relay-aiproxy.go @@ -48,7 +48,7 @@ type AIProxyLibraryStreamResponse struct { func requestOpenAI2AIProxyLibrary(request GeneralOpenAIRequest) *AIProxyLibraryRequest { query := "" if len(request.Messages) != 0 { - query = request.Messages[len(request.Messages)-1].Content + query = request.Messages[len(request.Messages)-1].StringContent() } return &AIProxyLibraryRequest{ Model: request.Model, diff --git a/controller/relay-ali.go b/controller/relay-ali.go index 50dc743c..b41ca327 100644 --- a/controller/relay-ali.go +++ b/controller/relay-ali.go @@ -88,18 +88,18 @@ func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest { message := request.Messages[i] if message.Role == "system" { messages = append(messages, AliMessage{ - User: message.Content, + User: message.StringContent(), Bot: "Okay", }) continue } else { if i == len(request.Messages)-1 { - prompt = message.Content + prompt = message.StringContent() break } messages = append(messages, AliMessage{ - User: message.Content, - Bot: request.Messages[i+1].Content, + User: message.StringContent(), + Bot: request.Messages[i+1].StringContent(), }) i++ } diff --git a/controller/relay-audio.go b/controller/relay-audio.go index 381c6feb..01267fbf 100644 --- a/controller/relay-audio.go +++ b/controller/relay-audio.go @@ -5,13 +5,11 @@ import ( "context" "encoding/json" "errors" - "fmt" + "github.com/gin-gonic/gin" "io" "net/http" "one-api/common" "one-api/model" - - "github.com/gin-gonic/gin" ) func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { @@ -22,6 +20,22 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode channelId := c.GetInt("channel_id") userId := c.GetInt("id") group := c.GetString("group") + tokenName := c.GetString("token_name") + + var ttsRequest TextToSpeechRequest + if relayMode == RelayModeAudioSpeech { + // Read JSON + err := common.UnmarshalBodyReusable(c, &ttsRequest) + // Check if JSON is valid + if err != nil { + return errorWrapper(err, "invalid_json", http.StatusBadRequest) + } + audioModel = ttsRequest.Model + // Check if text is too long 4096 + if len(ttsRequest.Input) > 4096 { + return errorWrapper(errors.New("input is too long (over 4096 characters)"), "text_too_long", http.StatusBadRequest) + } + } preConsumedTokens := common.PreConsumedQuota modelRatio := common.GetModelRatio(audioModel) @@ -32,22 +46,32 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode if err != nil { return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) } - if userQuota-preConsumedQuota < 0 { - return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) - } - err = model.CacheDecreaseUserQuota(userId, preConsumedQuota) - if err != nil { - return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) - } - if userQuota > 100*preConsumedQuota { - // in this case, we do not pre-consume quota - // because the user has enough quota - preConsumedQuota = 0 - } - if preConsumedQuota > 0 { - err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) + + quota := 0 + // Check if user quota is enough + if relayMode == RelayModeAudioSpeech { + quota = int(float64(len(ttsRequest.Input)) * modelRatio * groupRatio) + if quota > userQuota { + return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) + } + } else { + if userQuota-preConsumedQuota < 0 { + return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) + } + err = model.CacheDecreaseUserQuota(userId, preConsumedQuota) if err != nil { - return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) + return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) + } + if userQuota > 100*preConsumedQuota { + // in this case, we do not pre-consume quota + // because the user has enough quota + preConsumedQuota = 0 + } + if preConsumedQuota > 0 { + err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) + if err != nil { + return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) + } } } @@ -66,12 +90,11 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode baseURL := common.ChannelBaseURLs[channelType] requestURL := c.Request.URL.String() - if c.GetString("base_url") != "" { baseURL = c.GetString("base_url") } - fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) + fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType) requestBody := c.Request.Body req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) @@ -95,47 +118,32 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode if err != nil { return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) } - var audioResponse AudioResponse - defer func(ctx context.Context) { - go func() { - quota := countTokenText(audioResponse.Text, audioModel) + if relayMode == RelayModeAudioSpeech { + defer func(ctx context.Context) { + go postConsumeQuota(ctx, tokenId, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName) + }(c.Request.Context()) + } else { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) + } + err = resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) + } + var whisperResponse WhisperResponse + err = json.Unmarshal(responseBody, &whisperResponse) + if err != nil { + return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) + } + defer func(ctx context.Context) { + quota := countTokenText(whisperResponse.Text, audioModel) quotaDelta := quota - preConsumedQuota - err := model.PostConsumeTokenQuota(tokenId, quotaDelta) - if err != nil { - common.SysError("error consuming token remain quota: " + err.Error()) - } - err = model.CacheUpdateUserQuota(userId) - if err != nil { - common.SysError("error update user quota cache: " + err.Error()) - } - if quota != 0 { - tokenName := c.GetString("token_name") - logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) - model.RecordConsumeLog(ctx, userId, channelId, 0, 0, audioModel, tokenName, quota, logContent) - model.UpdateUserUsedQuotaAndRequestCount(userId, quota) - channelId := c.GetInt("channel_id") - model.UpdateChannelUsedQuota(channelId, quota) - } - }() - }(c.Request.Context()) - - responseBody, err := io.ReadAll(resp.Body) - - if err != nil { - return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) + go postConsumeQuota(ctx, tokenId, quotaDelta, userId, channelId, modelRatio, groupRatio, audioModel, tokenName) + }(c.Request.Context()) + resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) } - err = resp.Body.Close() - if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) - } - err = json.Unmarshal(responseBody, &audioResponse) - if err != nil { - return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) - } - - resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) - for k, v := range resp.Header { c.Writer.Header().Set(k, v[0]) } diff --git a/controller/relay-baidu.go b/controller/relay-baidu.go index ed08ac04..c75ec09a 100644 --- a/controller/relay-baidu.go +++ b/controller/relay-baidu.go @@ -89,7 +89,7 @@ func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest { if message.Role == "system" { messages = append(messages, BaiduMessage{ Role: "user", - Content: message.Content, + Content: message.StringContent(), }) messages = append(messages, BaiduMessage{ Role: "assistant", @@ -98,7 +98,7 @@ func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest { } else { messages = append(messages, BaiduMessage{ Role: message.Role, - Content: message.Content, + Content: message.StringContent(), }) } } diff --git a/controller/relay-image.go b/controller/relay-image.go index 998a7851..1d1b71ba 100644 --- a/controller/relay-image.go +++ b/controller/relay-image.go @@ -14,8 +14,20 @@ import ( "github.com/gin-gonic/gin" ) +func isWithinRange(element string, value int) bool { + if _, ok := common.DalleGenerationImageAmounts[element]; !ok { + return false + } + + min := common.DalleGenerationImageAmounts[element][0] + max := common.DalleGenerationImageAmounts[element][1] + + return value >= min && value <= max +} + func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { - imageModel := "dall-e" + imageModel := "dall-e-2" + imageSize := "1024x1024" tokenId := c.GetInt("token_id") channelType := c.GetInt("channel") @@ -32,19 +44,44 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode } } + // Size validation + if imageRequest.Size != "" { + imageSize = imageRequest.Size + } + + // Model validation + if imageRequest.Model != "" { + imageModel = imageRequest.Model + } + + imageCostRatio, hasValidSize := common.DalleSizeRatios[imageModel][imageSize] + + // Check if model is supported + if hasValidSize { + if imageRequest.Quality == "hd" && imageModel == "dall-e-3" { + if imageSize == "1024x1024" { + imageCostRatio *= 2 + } else { + imageCostRatio *= 1.5 + } + } + } else { + return errorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest) + } + // Prompt validation if imageRequest.Prompt == "" { - return errorWrapper(errors.New("prompt is required"), "required_field_missing", http.StatusBadRequest) + return errorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest) } - // Not "256x256", "512x512", or "1024x1024" - if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" { - return errorWrapper(errors.New("size must be one of 256x256, 512x512, or 1024x1024"), "invalid_field_value", http.StatusBadRequest) + // Check prompt length + if len(imageRequest.Prompt) > common.DalleImagePromptLengthLimitations[imageModel] { + return errorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest) } - // N should between 1 and 10 - if imageRequest.N != 0 && (imageRequest.N < 1 || imageRequest.N > 10) { - return errorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest) + // Number of generated images validation + if isWithinRange(imageModel, imageRequest.N) == false { + return errorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest) } // map model name @@ -61,16 +98,12 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode isModelMapped = true } } - baseURL := common.ChannelBaseURLs[channelType] requestURL := c.Request.URL.String() - if c.GetString("base_url") != "" { baseURL = c.GetString("base_url") } - - fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) - + fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType) var requestBody io.Reader if isModelMapped { jsonStr, err := json.Marshal(imageRequest) @@ -87,16 +120,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode ratio := modelRatio * groupRatio userQuota, err := model.CacheGetUserQuota(userId) - sizeRatio := 1.0 - // Size - if imageRequest.Size == "256x256" { - sizeRatio = 1 - } else if imageRequest.Size == "512x512" { - sizeRatio = 1.125 - } else if imageRequest.Size == "1024x1024" { - sizeRatio = 1.25 - } - quota := int(ratio*sizeRatio*1000) * imageRequest.N + quota := int(ratio*imageCostRatio*1000) * imageRequest.N if consumeQuota && userQuota-quota < 0 { return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) diff --git a/controller/relay-openai.go b/controller/relay-openai.go index 6bdfbc08..dcd20115 100644 --- a/controller/relay-openai.go +++ b/controller/relay-openai.go @@ -132,7 +132,7 @@ func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promp if textResponse.Usage.TotalTokens == 0 { completionTokens := 0 for _, choice := range textResponse.Choices { - completionTokens += countTokenText(choice.Message.Content, model) + completionTokens += countTokenText(choice.Message.StringContent(), model) } textResponse.Usage = Usage{ PromptTokens: promptTokens, diff --git a/controller/relay-palm.go b/controller/relay-palm.go index a705b318..2bd0bcd8 100644 --- a/controller/relay-palm.go +++ b/controller/relay-palm.go @@ -59,7 +59,7 @@ func requestOpenAI2PaLM(textRequest GeneralOpenAIRequest) *PaLMChatRequest { } for _, message := range textRequest.Messages { palmMessage := PaLMChatMessage{ - Content: message.Content, + Content: message.StringContent(), } if message.Role == "user" { palmMessage.Author = "0" diff --git a/controller/relay-tencent.go b/controller/relay-tencent.go index 024468bc..f66bf38f 100644 --- a/controller/relay-tencent.go +++ b/controller/relay-tencent.go @@ -84,7 +84,7 @@ func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest { if message.Role == "system" { messages = append(messages, TencentMessage{ Role: "user", - Content: message.Content, + Content: message.StringContent(), }) messages = append(messages, TencentMessage{ Role: "assistant", @@ -93,7 +93,7 @@ func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest { continue } messages = append(messages, TencentMessage{ - Content: message.Content, + Content: message.StringContent(), Role: message.Role, }) } diff --git a/controller/relay-text.go b/controller/relay-text.go index 24d09e5a..fce54b89 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -6,13 +6,15 @@ import ( "encoding/json" "errors" "fmt" - "github.com/gin-gonic/gin" "io" + "math" "net/http" "one-api/common" "one-api/model" "strings" "time" + + "github.com/gin-gonic/gin" ) const ( @@ -31,7 +33,14 @@ var httpClient *http.Client var impatientHTTPClient *http.Client func init() { - httpClient = &http.Client{} + if common.RelayTimeout == 0 { + httpClient = &http.Client{} + } else { + httpClient = &http.Client{ + Timeout: time.Duration(common.RelayTimeout) * time.Second, + } + } + impatientHTTPClient = &http.Client{ Timeout: 5 * time.Second, } @@ -118,7 +127,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { if c.GetString("base_url") != "" { baseURL = c.GetString("base_url") } - fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) + fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType) switch apiType { case APITypeOpenAI: if channelType == common.ChannelTypeAzure { @@ -138,7 +147,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { model_ = strings.TrimSuffix(model_, "-0301") model_ = strings.TrimSuffix(model_, "-0314") model_ = strings.TrimSuffix(model_, "-0613") - fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task) + + requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task) + fullRequestURL = getFullRequestURL(baseURL, requestURL, channelType) } case APITypeClaude: fullRequestURL = "https://api.anthropic.com/v1/complete" @@ -151,6 +162,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions" case "ERNIE-Bot-turbo": fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant" + case "ERNIE-Bot-4": + fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro" case "BLOOMZ-7B": fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1" case "Embedding-V1": @@ -367,11 +380,16 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { } case APITypeTencent: req.Header.Set("Authorization", apiKey) + case APITypePaLM: + // do not set Authorization header default: req.Header.Set("Authorization", "Bearer "+apiKey) } req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Header.Set("Accept", c.Request.Header.Get("Accept")) + if isStream && c.Request.Header.Get("Accept") == "" { + req.Header.Set("Accept", "text/event-stream") + } //req.Header.Set("Connection", c.Request.Header.Get("Connection")) resp, err = httpClient.Do(req) if err != nil { @@ -412,9 +430,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { completionRatio := common.GetCompletionRatio(textRequest.Model) promptTokens = textResponse.Usage.PromptTokens completionTokens = textResponse.Usage.CompletionTokens - - quota = promptTokens + int(float64(completionTokens)*completionRatio) - quota = int(float64(quota) * ratio) + quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio)) if ratio != 0 && quota <= 0 { quota = 1 } diff --git a/controller/relay-utils.go b/controller/relay-utils.go index 4775ec88..c7cd4766 100644 --- a/controller/relay-utils.go +++ b/controller/relay-utils.go @@ -1,15 +1,18 @@ package controller import ( + "context" "encoding/json" "fmt" - "github.com/gin-gonic/gin" - "github.com/pkoukk/tiktoken-go" "io" "net/http" "one-api/common" + "one-api/model" "strconv" "strings" + + "github.com/gin-gonic/gin" + "github.com/pkoukk/tiktoken-go" ) var stopFinishReason = "stop" @@ -84,7 +87,7 @@ func countTokenMessages(messages []Message, model string) int { tokenNum := 0 for _, message := range messages { tokenNum += tokensPerMessage - tokenNum += getTokenNum(tokenEncoder, message.Content) + tokenNum += getTokenNum(tokenEncoder, message.StringContent()) tokenNum += getTokenNum(tokenEncoder, message.Role) if message.Name != nil { tokenNum += tokensPerName @@ -176,3 +179,35 @@ func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIEr openAIErrorWithStatusCode.OpenAIError = textResponse.Error return } + +func getFullRequestURL(baseURL string, requestURL string, channelType int) string { + fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) + + if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") { + switch channelType { + case common.ChannelTypeOpenAI: + fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1")) + case common.ChannelTypeAzure: + fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments")) + } + } + + return fullRequestURL +} + +func postConsumeQuota(ctx context.Context, tokenId int, quota int, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) { + err := model.PostConsumeTokenQuota(tokenId, quota) + if err != nil { + common.SysError("error consuming token remain quota: " + err.Error()) + } + err = model.CacheUpdateUserQuota(userId) + if err != nil { + common.SysError("error update user quota cache: " + err.Error()) + } + if quota != 0 { + logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) + model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent) + model.UpdateUserUsedQuotaAndRequestCount(userId, quota) + model.UpdateChannelUsedQuota(channelId, quota) + } +} diff --git a/controller/relay-xunfei.go b/controller/relay-xunfei.go index ff6bf065..00ec8981 100644 --- a/controller/relay-xunfei.go +++ b/controller/relay-xunfei.go @@ -81,7 +81,7 @@ func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, doma if message.Role == "system" { messages = append(messages, XunfeiMessage{ Role: "user", - Content: message.Content, + Content: message.StringContent(), }) messages = append(messages, XunfeiMessage{ Role: "assistant", @@ -90,7 +90,7 @@ func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, doma } else { messages = append(messages, XunfeiMessage{ Role: message.Role, - Content: message.Content, + Content: message.StringContent(), }) } } @@ -220,6 +220,9 @@ func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId strin for !stop { select { case xunfeiResponse = <-dataChan: + if len(xunfeiResponse.Payload.Choices.Text) == 0 { + continue + } content += xunfeiResponse.Payload.Choices.Text[0].Content usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens @@ -295,8 +298,8 @@ func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string) (string, common.SysLog("api_version not found, use default: " + apiVersion) } domain := "general" - if apiVersion == "v2.1" { - domain = "generalv2" + if apiVersion != "v1.1" { + domain += strings.Split(apiVersion, ".")[0] } authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret) return domain, authUrl diff --git a/controller/relay-zhipu.go b/controller/relay-zhipu.go index b7376bd0..05022937 100644 --- a/controller/relay-zhipu.go +++ b/controller/relay-zhipu.go @@ -154,7 +154,7 @@ func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest { if message.Role == "system" { messages = append(messages, ZhipuMessage{ Role: "system", - Content: message.Content, + Content: message.StringContent(), }) messages = append(messages, ZhipuMessage{ Role: "user", @@ -163,7 +163,7 @@ func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest { } else { messages = append(messages, ZhipuMessage{ Role: message.Role, - Content: message.Content, + Content: message.StringContent(), }) } } diff --git a/controller/relay.go b/controller/relay.go index 1926110e..f91ba6da 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -12,10 +12,49 @@ import ( type Message struct { Role string `json:"role"` - Content string `json:"content"` + Content any `json:"content"` Name *string `json:"name,omitempty"` } +type ImageURL struct { + Url string `json:"url,omitempty"` + Detail string `json:"detail,omitempty"` +} + +type TextContent struct { + Type string `json:"type,omitempty"` + Text string `json:"text,omitempty"` +} + +type ImageContent struct { + Type string `json:"type,omitempty"` + ImageURL *ImageURL `json:"image_url,omitempty"` +} + +func (m Message) StringContent() string { + content, ok := m.Content.(string) + if ok { + return content + } + contentList, ok := m.Content.([]any) + if ok { + var contentStr string + for _, contentItem := range contentList { + contentMap, ok := contentItem.(map[string]any) + if !ok { + continue + } + if contentMap["type"] == "text" { + if subStr, ok := contentMap["text"].(string); ok { + contentStr += subStr + } + } + } + return contentStr + } + return "" +} + const ( RelayModeUnknown = iota RelayModeChatCompletions @@ -24,24 +63,37 @@ const ( RelayModeModerations RelayModeImagesGenerations RelayModeEdits - RelayModeAudio + RelayModeAudioSpeech + RelayModeAudioTranscription + RelayModeAudioTranslation ) // https://platform.openai.com/docs/api-reference/chat +type ResponseFormat struct { + Type string `json:"type,omitempty"` +} + type GeneralOpenAIRequest struct { - Model string `json:"model,omitempty"` - Messages []Message `json:"messages,omitempty"` - Prompt any `json:"prompt,omitempty"` - Stream bool `json:"stream,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` - N int `json:"n,omitempty"` - Input any `json:"input,omitempty"` - Instruction string `json:"instruction,omitempty"` - Size string `json:"size,omitempty"` - Functions any `json:"functions,omitempty"` + Model string `json:"model,omitempty"` + Messages []Message `json:"messages,omitempty"` + Prompt any `json:"prompt,omitempty"` + Stream bool `json:"stream,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + N int `json:"n,omitempty"` + Input any `json:"input,omitempty"` + Instruction string `json:"instruction,omitempty"` + Size string `json:"size,omitempty"` + Functions any `json:"functions,omitempty"` + FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` + PresencePenalty float64 `json:"presence_penalty,omitempty"` + ResponseFormat *ResponseFormat `json:"response_format,omitempty"` + Seed float64 `json:"seed,omitempty"` + Tools any `json:"tools,omitempty"` + ToolChoice any `json:"tool_choice,omitempty"` + User string `json:"user,omitempty"` } func (r GeneralOpenAIRequest) ParseInput() []string { @@ -77,16 +129,30 @@ type TextRequest struct { //Stream bool `json:"stream"` } +// ImageRequest docs: https://platform.openai.com/docs/api-reference/images/create type ImageRequest struct { - Prompt string `json:"prompt"` - N int `json:"n"` - Size string `json:"size"` + Model string `json:"model"` + Prompt string `json:"prompt" binding:"required"` + N int `json:"n"` + Size string `json:"size"` + Quality string `json:"quality"` + ResponseFormat string `json:"response_format"` + Style string `json:"style"` + User string `json:"user"` } -type AudioResponse struct { +type WhisperResponse struct { Text string `json:"text,omitempty"` } +type TextToSpeechRequest struct { + Model string `json:"model" binding:"required"` + Input string `json:"input" binding:"required"` + Voice string `json:"voice" binding:"required"` + Speed float64 `json:"speed"` + ResponseFormat string `json:"response_format"` +} + type Usage struct { PromptTokens int `json:"prompt_tokens"` CompletionTokens int `json:"completion_tokens"` @@ -183,14 +249,22 @@ func Relay(c *gin.Context) { relayMode = RelayModeImagesGenerations } else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") { relayMode = RelayModeEdits - } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { - relayMode = RelayModeAudio + } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") { + relayMode = RelayModeAudioSpeech + } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") { + relayMode = RelayModeAudioTranscription + } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") { + relayMode = RelayModeAudioTranslation } var err *OpenAIErrorWithStatusCode switch relayMode { case RelayModeImagesGenerations: err = relayImageHelper(c, relayMode) - case RelayModeAudio: + case RelayModeAudioSpeech: + fallthrough + case RelayModeAudioTranslation: + fallthrough + case RelayModeAudioTranscription: err = relayAudioHelper(c, relayMode) default: err = relayTextHelper(c, relayMode) diff --git a/docker-compose.yml b/docker-compose.yml index 003122bb..30edb281 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -9,21 +9,21 @@ services: ports: - "3000:3000" volumes: - - ./data:/data + - ./data/oneapi:/data - ./logs:/app/logs environment: - - SQL_DSN=root:123456@tcp(host.docker.internal:3306)/one-api # 修改此行,或注释掉以使用 SQLite 作为数据库 + - SQL_DSN=oneapi:123456@tcp(db:3306)/one-api # 修改此行,或注释掉以使用 SQLite 作为数据库 - REDIS_CONN_STRING=redis://redis - SESSION_SECRET=random_string # 修改为随机字符串 - TZ=Asia/Shanghai # - NODE_TYPE=slave # 多机部署时从节点取消注释该行 # - SYNC_FREQUENCY=60 # 需要定期从数据库加载数据时取消注释该行 # - FRONTEND_BASE_URL=https://openai.justsong.cn # 多机部署时从节点取消注释该行 - depends_on: - redis + - db healthcheck: - test: [ "CMD-SHELL", "curl -s http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk '{print $2}' | grep 'true'" ] + test: [ "CMD-SHELL", "wget -q -O - http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk -F: '{print $2}'" ] interval: 30s timeout: 10s retries: 3 @@ -32,3 +32,18 @@ services: image: redis:latest container_name: redis restart: always + + db: + image: mysql:8.2.0 + restart: always + container_name: mysql + volumes: + - ./data/mysql:/var/lib/mysql # 挂载目录,持久化存储 + ports: + - '3306:3306' + environment: + TZ: Asia/Shanghai # 设置时区 + MYSQL_ROOT_PASSWORD: 'OneAPI@justsong' # 设置 root 用户的密码 + MYSQL_USER: oneapi # 创建专用用户 + MYSQL_PASSWORD: '123456' # 设置专用用户密码 + MYSQL_DATABASE: one-api # 自动创建数据库 \ No newline at end of file diff --git a/go.mod b/go.mod index 79b01f93..10b78d68 100644 --- a/go.mod +++ b/go.mod @@ -15,8 +15,9 @@ require ( github.com/google/uuid v1.3.0 github.com/gorilla/websocket v1.5.0 github.com/pkoukk/tiktoken-go v0.1.5 - golang.org/x/crypto v0.9.0 + golang.org/x/crypto v0.14.0 gorm.io/driver/mysql v1.4.3 + gorm.io/driver/postgres v1.5.2 gorm.io/driver/sqlite v1.4.3 gorm.io/gorm v1.25.0 ) @@ -52,10 +53,9 @@ require ( github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.11 // indirect golang.org/x/arch v0.3.0 // indirect - golang.org/x/net v0.10.0 // indirect - golang.org/x/sys v0.8.0 // indirect - golang.org/x/text v0.9.0 // indirect + golang.org/x/net v0.17.0 // indirect + golang.org/x/sys v0.13.0 // indirect + golang.org/x/text v0.13.0 // indirect google.golang.org/protobuf v1.30.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect - gorm.io/driver/postgres v1.5.2 // indirect ) diff --git a/go.sum b/go.sum index 810e7819..4865bcaa 100644 --- a/go.sum +++ b/go.sum @@ -150,11 +150,11 @@ golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUu golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k= golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g= -golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= +golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc= +golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= -golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= +golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= +golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -162,14 +162,14 @@ golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU= -golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= +golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= -golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k= +golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -198,7 +198,6 @@ gorm.io/driver/postgres v1.5.2/go.mod h1:fmpX0m2I1PKuR7mKZiEluwrP3hbs+ps7JIGMUBp gorm.io/driver/sqlite v1.4.3 h1:HBBcZSDnWi5BW3B3rwvVTc510KGkBkexlOg0QrmLUuU= gorm.io/driver/sqlite v1.4.3/go.mod h1:0Aq3iPO+v9ZKbcdiz8gLWRw5VOPcBOPUQJFLq5e2ecI= gorm.io/gorm v1.23.8/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk= -gorm.io/gorm v1.24.0 h1:j/CoiSm6xpRpmzbFJsQHYj+I8bGYWLXVHeYEyyKlF74= gorm.io/gorm v1.24.0/go.mod h1:DVrVomtaYTbqs7gB/x2uVvqnXzv0nqjB396B8cG4dBA= gorm.io/gorm v1.25.0 h1:+KtYtb2roDz14EQe4bla8CbQlmb9dN3VejSai3lprfU= gorm.io/gorm v1.25.0/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k= diff --git a/middleware/distributor.go b/middleware/distributor.go index d80945fc..c4ddc3a0 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -40,10 +40,7 @@ func Distribute() func(c *gin.Context) { } else { // Select a channel for the user var modelRequest ModelRequest - var err error - if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { - err = common.UnmarshalBodyReusable(c, &modelRequest) - } + err := common.UnmarshalBodyReusable(c, &modelRequest) if err != nil { abortWithMessage(c, http.StatusBadRequest, "无效的请求") return @@ -60,10 +57,10 @@ func Distribute() func(c *gin.Context) { } if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { if modelRequest.Model == "" { - modelRequest.Model = "dall-e" + modelRequest.Model = "dall-e-2" } } - if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { + if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") { if modelRequest.Model == "" { modelRequest.Model = "whisper-1" } diff --git a/model/ability.go b/model/ability.go index 50972a26..3da83be8 100644 --- a/model/ability.go +++ b/model/ability.go @@ -15,10 +15,17 @@ type Ability struct { func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) { ability := Ability{} + groupCol := "`group`" + trueVal := "1" + if common.UsingPostgreSQL { + groupCol = `"group"` + trueVal = "true" + } + var err error = nil - maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where("`group` = ? and model = ? and enabled = 1", group, model) - channelQuery := DB.Where("`group` = ? and model = ? and enabled = 1 and priority = (?)", group, model, maxPrioritySubQuery) - if common.UsingSQLite { + maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model) + channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery) + if common.UsingSQLite || common.UsingPostgreSQL { err = channelQuery.Order("RANDOM()").First(&ability).Error } else { err = channelQuery.Order("RAND()").First(&ability).Error diff --git a/model/cache.go b/model/cache.go index a7f5c06f..c6d0c70a 100644 --- a/model/cache.go +++ b/model/cache.go @@ -21,14 +21,18 @@ var ( ) func CacheGetTokenByKey(key string) (*Token, error) { + keyCol := "`key`" + if common.UsingPostgreSQL { + keyCol = `"key"` + } var token Token if !common.RedisEnabled { - err := DB.Where("`key` = ?", key).First(&token).Error + err := DB.Where(keyCol+" = ?", key).First(&token).Error return &token, err } tokenObjectString, err := common.RedisGet(fmt.Sprintf("token:%s", key)) if err != nil { - err := DB.Where("`key` = ?", key).First(&token).Error + err := DB.Where(keyCol+" = ?", key).First(&token).Error if err != nil { return nil, err } diff --git a/model/channel.go b/model/channel.go index 36bb78a5..7e7b42e6 100644 --- a/model/channel.go +++ b/model/channel.go @@ -38,7 +38,11 @@ func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) { } func SearchChannels(keyword string) (channels []*Channel, err error) { - err = DB.Omit("key").Where("id = ? or name LIKE ? or `key` = ?", keyword, keyword+"%", keyword).Find(&channels).Error + keyCol := "`key`" + if common.UsingPostgreSQL { + keyCol = `"key"` + } + err = DB.Omit("key").Where("id = ? or name LIKE ? or "+keyCol+" = ?", common.String2Int(keyword), keyword+"%", keyword).Find(&channels).Error return channels, err } @@ -53,17 +57,6 @@ func GetChannelById(id int, selectAll bool) (*Channel, error) { return &channel, err } -func GetRandomChannel() (*Channel, error) { - channel := Channel{} - var err error = nil - if common.UsingSQLite { - err = DB.Where("status = ? and `group` = ?", common.ChannelStatusEnabled, "default").Order("RANDOM()").Limit(1).First(&channel).Error - } else { - err = DB.Where("status = ? and `group` = ?", common.ChannelStatusEnabled, "default").Order("RAND()").Limit(1).First(&channel).Error - } - return &channel, err -} - func BatchInsertChannels(channels []Channel) error { var err error err = DB.Create(&channels).Error @@ -181,3 +174,8 @@ func DeleteChannelByStatus(status int64) (int64, error) { result := DB.Where("status = ?", status).Delete(&Channel{}) return result.RowsAffected, result.Error } + +func DeleteDisabledChannel() (int64, error) { + result := DB.Where("status = ? or status = ?", common.ChannelStatusAutoDisabled, common.ChannelStatusManuallyDisabled).Delete(&Channel{}) + return result.RowsAffected, result.Error +} diff --git a/model/log.go b/model/log.go index d26da9a2..3d3ffae3 100644 --- a/model/log.go +++ b/model/log.go @@ -94,7 +94,7 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName tx = tx.Where("created_at <= ?", endTimestamp) } if channel != 0 { - tx = tx.Where("channel = ?", channel) + tx = tx.Where("channel_id = ?", channel) } err = tx.Order("id desc").Limit(num).Offset(startIdx).Find(&logs).Error return logs, err @@ -151,7 +151,7 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa tx = tx.Where("model_name = ?", modelName) } if channel != 0 { - tx = tx.Where("channel = ?", channel) + tx = tx.Where("channel_id = ?", channel) } tx.Where("type = ?", LogTypeConsume).Scan("a) return quota diff --git a/model/main.go b/model/main.go index 0e962049..08182634 100644 --- a/model/main.go +++ b/model/main.go @@ -42,6 +42,7 @@ func chooseDB() (*gorm.DB, error) { if strings.HasPrefix(dsn, "postgres://") { // Use PostgreSQL common.SysLog("using PostgreSQL as database") + common.UsingPostgreSQL = true return gorm.Open(postgres.New(postgres.Config{ DSN: dsn, PreferSimpleProtocol: true, // disables implicit prepared statement usage diff --git a/model/redemption.go b/model/redemption.go index fafb2145..f16412b5 100644 --- a/model/redemption.go +++ b/model/redemption.go @@ -50,8 +50,13 @@ func Redeem(key string, userId int) (quota int, err error) { } redemption := &Redemption{} + keyCol := "`key`" + if common.UsingPostgreSQL { + keyCol = `"key"` + } + err = DB.Transaction(func(tx *gorm.DB) error { - err := tx.Set("gorm:query_option", "FOR UPDATE").Where("`key` = ?", key).First(redemption).Error + err := tx.Set("gorm:query_option", "FOR UPDATE").Where(keyCol+" = ?", key).First(redemption).Error if err != nil { return errors.New("无效的兑换码") } diff --git a/model/user.go b/model/user.go index cee4b023..7844eb6a 100644 --- a/model/user.go +++ b/model/user.go @@ -266,7 +266,12 @@ func GetUserEmail(id int) (email string, err error) { } func GetUserGroup(id int) (group string, err error) { - err = DB.Model(&User{}).Where("id = ?", id).Select("`group`").Find(&group).Error + groupCol := "`group`" + if common.UsingPostgreSQL { + groupCol = `"group"` + } + + err = DB.Model(&User{}).Where("id = ?", id).Select(groupCol).Find(&group).Error return group, err } @@ -309,7 +314,8 @@ func GetRootUserEmail() (email string) { func UpdateUserUsedQuotaAndRequestCount(id int, quota int) { if common.BatchUpdateEnabled { - addNewRecord(BatchUpdateTypeUsedQuotaAndRequestCount, id, quota) + addNewRecord(BatchUpdateTypeUsedQuota, id, quota) + addNewRecord(BatchUpdateTypeRequestCount, id, 1) return } updateUserUsedQuotaAndRequestCount(id, quota, 1) @@ -327,6 +333,24 @@ func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) { } } +func updateUserUsedQuota(id int, quota int) { + err := DB.Model(&User{}).Where("id = ?", id).Updates( + map[string]interface{}{ + "used_quota": gorm.Expr("used_quota + ?", quota), + }, + ).Error + if err != nil { + common.SysError("failed to update user used quota: " + err.Error()) + } +} + +func updateUserRequestCount(id int, count int) { + err := DB.Model(&User{}).Where("id = ?", id).Update("request_count", gorm.Expr("request_count + ?", count)).Error + if err != nil { + common.SysError("failed to update user request count: " + err.Error()) + } +} + func GetUsernameById(id int) (username string) { DB.Model(&User{}).Where("id = ?", id).Select("username").Find(&username) return username diff --git a/model/utils.go b/model/utils.go index 61734332..1c28340b 100644 --- a/model/utils.go +++ b/model/utils.go @@ -6,13 +6,13 @@ import ( "time" ) -const BatchUpdateTypeCount = 4 // if you add a new type, you need to add a new map and a new lock - const ( BatchUpdateTypeUserQuota = iota BatchUpdateTypeTokenQuota - BatchUpdateTypeUsedQuotaAndRequestCount + BatchUpdateTypeUsedQuota BatchUpdateTypeChannelUsedQuota + BatchUpdateTypeRequestCount + BatchUpdateTypeCount // if you add a new type, you need to add a new map and a new lock ) var batchUpdateStores []map[int]int @@ -51,7 +51,7 @@ func batchUpdate() { store := batchUpdateStores[i] batchUpdateStores[i] = make(map[int]int) batchUpdateLocks[i].Unlock() - + // TODO: maybe we can combine updates with same key? for key, value := range store { switch i { case BatchUpdateTypeUserQuota: @@ -64,8 +64,10 @@ func batchUpdate() { if err != nil { common.SysError("failed to batch update token quota: " + err.Error()) } - case BatchUpdateTypeUsedQuotaAndRequestCount: - updateUserUsedQuotaAndRequestCount(key, value, 1) // TODO: count is incorrect + case BatchUpdateTypeUsedQuota: + updateUserUsedQuota(key, value) + case BatchUpdateTypeRequestCount: + updateUserRequestCount(key, value) case BatchUpdateTypeChannelUsedQuota: updateChannelUsedQuota(key, value) } diff --git a/router/api-router.go b/router/api-router.go index 5ec385dc..da3f9e61 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -74,7 +74,7 @@ func SetApiRouter(router *gin.Engine) { channelRoute.GET("/update_balance/:id", controller.UpdateChannelBalance) channelRoute.POST("/", controller.AddChannel) channelRoute.PUT("/", controller.UpdateChannel) - channelRoute.DELETE("/manually_disabled", controller.DeleteManuallyDisabledChannel) + channelRoute.DELETE("/disabled", controller.DeleteDisabledChannel) channelRoute.DELETE("/:id", controller.DeleteChannel) } tokenRoute := apiRouter.Group("/token") diff --git a/router/relay-router.go b/router/relay-router.go index e84f02db..912f4989 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -29,6 +29,7 @@ func SetRelayRouter(router *gin.Engine) { relayV1Router.POST("/engines/:model/embeddings", controller.Relay) relayV1Router.POST("/audio/transcriptions", controller.Relay) relayV1Router.POST("/audio/translations", controller.Relay) + relayV1Router.POST("/audio/speech", controller.Relay) relayV1Router.GET("/files", controller.RelayNotImplemented) relayV1Router.POST("/files", controller.RelayNotImplemented) relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented) diff --git a/web/src/App.js b/web/src/App.js index c967ce2c..13c884dc 100644 --- a/web/src/App.js +++ b/web/src/App.js @@ -283,7 +283,9 @@ function App() { } /> - + + } /> ); } diff --git a/web/src/components/ChannelsTable.js b/web/src/components/ChannelsTable.js index 57d45c55..d44ea2d7 100644 --- a/web/src/components/ChannelsTable.js +++ b/web/src/components/ChannelsTable.js @@ -1,7 +1,7 @@ import React, { useEffect, useState } from 'react'; -import { Button, Form, Input, Label, Pagination, Popup, Table } from 'semantic-ui-react'; +import { Button, Form, Input, Label, Message, Pagination, Popup, Table } from 'semantic-ui-react'; import { Link } from 'react-router-dom'; -import { API, showError, showInfo, showNotice, showSuccess, timestamp2string } from '../helpers'; +import { API, setPromptShown, shouldShowPrompt, showError, showInfo, showSuccess, timestamp2string } from '../helpers'; import { CHANNEL_OPTIONS, ITEMS_PER_PAGE } from '../constants'; import { renderGroup, renderNumber } from '../helpers/render'; @@ -55,6 +55,7 @@ const ChannelsTable = () => { const [searchKeyword, setSearchKeyword] = useState(''); const [searching, setSearching] = useState(false); const [updatingBalance, setUpdatingBalance] = useState(false); + const [showPrompt, setShowPrompt] = useState(shouldShowPrompt("channel-test")); const loadChannels = async (startIdx) => { const res = await API.get(`/api/channel/?p=${startIdx}`); @@ -226,7 +227,6 @@ const ChannelsTable = () => { showInfo(`通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。`); } else { showError(message); - showNotice('当前版本测试是通过按照 OpenAI API 格式使用 gpt-3.5-turbo 模型进行非流式请求实现的,因此测试报错并不一定代表通道不可用,该功能后续会修复。'); } }; @@ -240,11 +240,11 @@ const ChannelsTable = () => { } }; - const deleteAllManuallyDisabledChannels = async () => { - const res = await API.delete(`/api/channel/manually_disabled`); + const deleteAllDisabledChannels = async () => { + const res = await API.delete(`/api/channel/disabled`); const { success, message, data } = res.data; if (success) { - showSuccess(`已删除所有手动禁用渠道,共计 ${data} 个`); + showSuccess(`已删除所有禁用渠道,共计 ${data} 个`); await refresh(); } else { showError(message); @@ -286,17 +286,15 @@ const ChannelsTable = () => { if (channels.length === 0) return; setLoading(true); let sortedChannels = [...channels]; - if (typeof sortedChannels[0][key] === 'string') { - sortedChannels.sort((a, b) => { + sortedChannels.sort((a, b) => { + if (!isNaN(a[key])) { + // If the value is numeric, subtract to sort + return a[key] - b[key]; + } else { + // If the value is not numeric, sort as strings return ('' + a[key]).localeCompare(b[key]); - }); - } else { - sortedChannels.sort((a, b) => { - if (a[key] === b[key]) return 0; - if (a[key] > b[key]) return -1; - if (a[key] < b[key]) return 1; - }); - } + } + }); if (sortedChannels[0].id === channels[0].id) { sortedChannels.reverse(); } @@ -304,6 +302,7 @@ const ChannelsTable = () => { setLoading(false); }; + return ( <>
@@ -317,7 +316,19 @@ const ChannelsTable = () => { onChange={handleKeywordChange} />
+ { + showPrompt && ( + { + setShowPrompt(false); + setPromptShown("channel-test"); + }}> + 当前版本测试是通过按照 OpenAI API 格式使用 gpt-3.5-turbo + 模型进行非流式请求实现的,因此测试报错并不一定代表通道不可用,该功能后续会修复。 + 另外,OpenAI 渠道已经不再支持通过 key 获取余额,因此余额显示为 0。对于支持的渠道类型,请点击余额进行刷新。 + + ) + } @@ -519,14 +530,14 @@ const ChannelsTable = () => { - 删除所有手动禁用渠道 + 删除禁用渠道 } on='click' flowing hoverable > - diff --git a/web/src/components/RedemptionsTable.js b/web/src/components/RedemptionsTable.js index ae8b5b03..dfd59685 100644 --- a/web/src/components/RedemptionsTable.js +++ b/web/src/components/RedemptionsTable.js @@ -130,7 +130,13 @@ const RedemptionsTable = () => { setLoading(true); let sortedRedemptions = [...redemptions]; sortedRedemptions.sort((a, b) => { - return ('' + a[key]).localeCompare(b[key]); + if (!isNaN(a[key])) { + // If the value is numeric, subtract to sort + return a[key] - b[key]; + } else { + // If the value is not numeric, sort as strings + return ('' + a[key]).localeCompare(b[key]); + } }); if (sortedRedemptions[0].id === redemptions[0].id) { sortedRedemptions.reverse(); diff --git a/web/src/components/TokensTable.js b/web/src/components/TokensTable.js index c7ec9b48..db4745e4 100644 --- a/web/src/components/TokensTable.js +++ b/web/src/components/TokensTable.js @@ -138,7 +138,7 @@ const TokensTable = () => { let defaultUrl; if (chatLink) { - defaultUrl = chatLink + `/#/?settings={"key":"sk-${key}"}`; + defaultUrl = chatLink + `/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`; } else { defaultUrl = `https://chat.oneapi.pro/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`; } @@ -228,7 +228,13 @@ const TokensTable = () => { setLoading(true); let sortedTokens = [...tokens]; sortedTokens.sort((a, b) => { - return ('' + a[key]).localeCompare(b[key]); + if (!isNaN(a[key])) { + // If the value is numeric, subtract to sort + return a[key] - b[key]; + } else { + // If the value is not numeric, sort as strings + return ('' + a[key]).localeCompare(b[key]); + } }); if (sortedTokens[0].id === tokens[0].id) { sortedTokens.reverse(); diff --git a/web/src/components/UsersTable.js b/web/src/components/UsersTable.js index f8fb0a75..ad4e9b49 100644 --- a/web/src/components/UsersTable.js +++ b/web/src/components/UsersTable.js @@ -133,7 +133,13 @@ const UsersTable = () => { setLoading(true); let sortedUsers = [...users]; sortedUsers.sort((a, b) => { - return ('' + a[key]).localeCompare(b[key]); + if (!isNaN(a[key])) { + // If the value is numeric, subtract to sort + return a[key] - b[key]; + } else { + // If the value is not numeric, sort as strings + return ('' + a[key]).localeCompare(b[key]); + } }); if (sortedUsers[0].id === users[0].id) { sortedUsers.reverse(); diff --git a/web/src/helpers/utils.js b/web/src/helpers/utils.js index 3871a43e..28ae4992 100644 --- a/web/src/helpers/utils.js +++ b/web/src/helpers/utils.js @@ -186,4 +186,14 @@ export const verifyJSON = (str) => { return false; } return true; -}; \ No newline at end of file +}; + +export function shouldShowPrompt(id) { + let prompt = localStorage.getItem(`prompt-${id}`); + return !prompt; + +} + +export function setPromptShown(id) { + localStorage.setItem(`prompt-${id}`, 'true'); +} \ No newline at end of file diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index caee3c61..299fc19b 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -66,13 +66,13 @@ const EditChannel = () => { localModels = ['PaLM-2']; break; case 15: - localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'Embedding-V1']; + localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'ERNIE-Bot-4', 'Embedding-V1']; break; case 17: localModels = ['qwen-turbo', 'qwen-plus', 'text-embedding-v1']; break; case 16: - localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite', 'text_embedding']; + localModels = ['chatglm_turbo', 'chatglm_pro', 'chatglm_std', 'chatglm_lite', 'text_embedding']; break; case 18: localModels = ['SparkDesk']; diff --git a/web/src/pages/NotFound/index.js b/web/src/pages/NotFound/index.js index 08a95f9d..f92dbc90 100644 --- a/web/src/pages/NotFound/index.js +++ b/web/src/pages/NotFound/index.js @@ -1,19 +1,12 @@ import React from 'react'; -import { Segment, Header } from 'semantic-ui-react'; +import { Message } from 'semantic-ui-react'; const NotFound = () => ( <> -
- - 未找到所请求的页面 - + + 页面不存在 +

请检查你的浏览器地址是否正确

+
);