From e5311892d1075a7facfd32c6aa358e4a574138da Mon Sep 17 00:00:00 2001 From: JustSong Date: Wed, 8 Nov 2023 23:17:12 +0800 Subject: [PATCH 01/10] docs: update readme --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 79229a94..38e3e0b7 100644 --- a/README.md +++ b/README.md @@ -75,7 +75,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 + [x] [腾讯混元大模型](https://cloud.tencent.com/document/product/1729) 2. 支持配置镜像以及众多第三方代理服务: + [x] [OpenAI-SB](https://openai-sb.com) - + [x] [CloseAI](https://console.closeai-asia.com/r/2412) + + [x] [CloseAI](https://referer.shadowai.xyz/r/2412) + [x] [API2D](https://api2d.com/r/197971) + [x] [OhMyGPT](https://aigptx.top?aff=uFpUl2Kf) + [x] [AI Proxy](https://aiproxy.io/?i=OneAPI) (邀请码:`OneAPI`) From 9d43ec57d8e67aa89acb1af78f21c1469d26a113 Mon Sep 17 00:00:00 2001 From: Mikey Date: Fri, 10 Nov 2023 05:08:23 -0800 Subject: [PATCH 02/10] feat: sync pricing for new 1106 models (#696) * feat: sync pricing for new 1106 models * chore: change ratio after 2023-12-11 --------- Co-authored-by: JustSong --- common/model-ratio.go | 19 +++++++++++++++++++ controller/model.go | 27 +++++++++++++++++++++++++++ 2 files changed, 46 insertions(+) diff --git a/common/model-ratio.go b/common/model-ratio.go index 8f4be8c3..681f0ae7 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -3,6 +3,7 @@ package common import ( "encoding/json" "strings" + "time" ) // ModelRatio @@ -19,12 +20,15 @@ var ModelRatio = map[string]float64{ "gpt-4-32k": 30, "gpt-4-32k-0314": 30, "gpt-4-32k-0613": 30, + "gpt-4-1106-preview": 5, // $0.01 / 1K tokens + "gpt-4-vision-preview": 5, // $0.01 / 1K tokens "gpt-3.5-turbo": 0.75, // $0.0015 / 1K tokens "gpt-3.5-turbo-0301": 0.75, "gpt-3.5-turbo-0613": 0.75, "gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens "gpt-3.5-turbo-16k-0613": 1.5, "gpt-3.5-turbo-instruct": 0.75, // $0.0015 / 1K tokens + "gpt-3.5-turbo-1106": 0.5, // $0.001 / 1K tokens "text-ada-001": 0.2, "text-babbage-001": 0.25, "text-curie-001": 1, @@ -88,9 +92,24 @@ func GetModelRatio(name string) float64 { func GetCompletionRatio(name string) float64 { if strings.HasPrefix(name, "gpt-3.5") { + if strings.HasSuffix(name, "1106") { + return 2 + } + if name == "gpt-3.5-turbo" || name == "gpt-3.5-turbo-16k" { + // TODO: clear this after 2023-12-11 + now := time.Now() + // https://platform.openai.com/docs/models/continuous-model-upgrades + // if after 2023-12-11, use 2 + if now.After(time.Date(2023, 12, 11, 0, 0, 0, 0, time.UTC)) { + return 2 + } + } return 1.333333 } if strings.HasPrefix(name, "gpt-4") { + if strings.HasSuffix(name, "preview") { + return 3 + } return 2 } if strings.HasPrefix(name, "claude-instant-1") { diff --git a/controller/model.go b/controller/model.go index 2a7dc538..7bd9d097 100644 --- a/controller/model.go +++ b/controller/model.go @@ -117,6 +117,15 @@ func init() { Root: "gpt-3.5-turbo-16k-0613", Parent: nil, }, + { + Id: "gpt-3.5-turbo-1106", + Object: "model", + Created: 1699593571, + OwnedBy: "openai", + Permission: permission, + Root: "gpt-3.5-turbo-1106", + Parent: nil, + }, { Id: "gpt-3.5-turbo-instruct", Object: "model", @@ -180,6 +189,24 @@ func init() { Root: "gpt-4-32k-0613", Parent: nil, }, + { + Id: "gpt-4-1106-preview", + Object: "model", + Created: 1699593571, + OwnedBy: "openai", + Permission: permission, + Root: "gpt-4-1106-preview", + Parent: nil, + }, + { + Id: "gpt-4-vision-preview", + Object: "model", + Created: 1699593571, + OwnedBy: "openai", + Permission: permission, + Root: "gpt-4-vision-preview", + Parent: nil, + }, { Id: "text-embedding-ada-002", Object: "model", From 7c4505bdfc4af036dacea6050e04e4aa92f9c8d2 Mon Sep 17 00:00:00 2001 From: Baksi Date: Fri, 10 Nov 2023 21:20:05 +0800 Subject: [PATCH 03/10] fix: numeric sorting in tables (#695) * Update sorting method for id * Update sorting method for id (token) * Update sorting method for id (redemptions) * Update sorting method for id (channel) * chore: use same logic for all tables --------- Co-authored-by: JustSong --- web/src/components/ChannelsTable.js | 19 +++++++++---------- web/src/components/RedemptionsTable.js | 8 +++++++- web/src/components/TokensTable.js | 8 +++++++- web/src/components/UsersTable.js | 8 +++++++- 4 files changed, 30 insertions(+), 13 deletions(-) diff --git a/web/src/components/ChannelsTable.js b/web/src/components/ChannelsTable.js index 732189cb..d44ea2d7 100644 --- a/web/src/components/ChannelsTable.js +++ b/web/src/components/ChannelsTable.js @@ -286,17 +286,15 @@ const ChannelsTable = () => { if (channels.length === 0) return; setLoading(true); let sortedChannels = [...channels]; - if (typeof sortedChannels[0][key] === 'string') { - sortedChannels.sort((a, b) => { + sortedChannels.sort((a, b) => { + if (!isNaN(a[key])) { + // If the value is numeric, subtract to sort + return a[key] - b[key]; + } else { + // If the value is not numeric, sort as strings return ('' + a[key]).localeCompare(b[key]); - }); - } else { - sortedChannels.sort((a, b) => { - if (a[key] === b[key]) return 0; - if (a[key] > b[key]) return -1; - if (a[key] < b[key]) return 1; - }); - } + } + }); if (sortedChannels[0].id === channels[0].id) { sortedChannels.reverse(); } @@ -304,6 +302,7 @@ const ChannelsTable = () => { setLoading(false); }; + return ( <>
diff --git a/web/src/components/RedemptionsTable.js b/web/src/components/RedemptionsTable.js index ae8b5b03..dfd59685 100644 --- a/web/src/components/RedemptionsTable.js +++ b/web/src/components/RedemptionsTable.js @@ -130,7 +130,13 @@ const RedemptionsTable = () => { setLoading(true); let sortedRedemptions = [...redemptions]; sortedRedemptions.sort((a, b) => { - return ('' + a[key]).localeCompare(b[key]); + if (!isNaN(a[key])) { + // If the value is numeric, subtract to sort + return a[key] - b[key]; + } else { + // If the value is not numeric, sort as strings + return ('' + a[key]).localeCompare(b[key]); + } }); if (sortedRedemptions[0].id === redemptions[0].id) { sortedRedemptions.reverse(); diff --git a/web/src/components/TokensTable.js b/web/src/components/TokensTable.js index a3bb6f91..db4745e4 100644 --- a/web/src/components/TokensTable.js +++ b/web/src/components/TokensTable.js @@ -228,7 +228,13 @@ const TokensTable = () => { setLoading(true); let sortedTokens = [...tokens]; sortedTokens.sort((a, b) => { - return ('' + a[key]).localeCompare(b[key]); + if (!isNaN(a[key])) { + // If the value is numeric, subtract to sort + return a[key] - b[key]; + } else { + // If the value is not numeric, sort as strings + return ('' + a[key]).localeCompare(b[key]); + } }); if (sortedTokens[0].id === tokens[0].id) { sortedTokens.reverse(); diff --git a/web/src/components/UsersTable.js b/web/src/components/UsersTable.js index f8fb0a75..ad4e9b49 100644 --- a/web/src/components/UsersTable.js +++ b/web/src/components/UsersTable.js @@ -133,7 +133,13 @@ const UsersTable = () => { setLoading(true); let sortedUsers = [...users]; sortedUsers.sort((a, b) => { - return ('' + a[key]).localeCompare(b[key]); + if (!isNaN(a[key])) { + // If the value is numeric, subtract to sort + return a[key] - b[key]; + } else { + // If the value is not numeric, sort as strings + return ('' + a[key]).localeCompare(b[key]); + } }); if (sortedUsers[0].id === users[0].id) { sortedUsers.reverse(); From 6c5307d0c4c7740ad6eb7917dced6dd2b939f3b1 Mon Sep 17 00:00:00 2001 From: Yuhang <2312744987@qq.com> Date: Fri, 10 Nov 2023 21:20:59 +0800 Subject: [PATCH 04/10] docs: add deploy to zeabur button (#693) * Update README.md * Update README.en.md * Update README.ja.md --- README.en.md | 2 ++ README.ja.md | 2 ++ README.md | 2 ++ 3 files changed, 6 insertions(+) diff --git a/README.en.md b/README.en.md index 783c140c..9345a219 100644 --- a/README.en.md +++ b/README.en.md @@ -189,6 +189,8 @@ If you encounter a blank page after deployment, refer to [#97](https://github.co > Zeabur's servers are located overseas, automatically solving network issues, and the free quota is sufficient for personal usage. +[![Deploy on Zeabur](https://zeabur.com/button.svg)](https://zeabur.com/templates/7Q0KO3) + 1. First, fork the code. 2. Go to [Zeabur](https://zeabur.com?referralCode=songquanpeng), log in, and enter the console. 3. Create a new project. In Service -> Add Service, select Marketplace, and choose MySQL. Note down the connection parameters (username, password, address, and port). diff --git a/README.ja.md b/README.ja.md index fa3339c2..6faf9bee 100644 --- a/README.ja.md +++ b/README.ja.md @@ -190,6 +190,8 @@ Please refer to the [environment variables](#environment-variables) section for > Zeabur のサーバーは海外にあるため、ネットワークの問題は自動的に解決されます。 +[![Deploy on Zeabur](https://zeabur.com/button.svg)](https://zeabur.com/templates/7Q0KO3) + 1. まず、コードをフォークする。 2. [Zeabur](https://zeabur.com?referralCode=songquanpeng) にアクセスしてログインし、コンソールに入る。 3. 新しいプロジェクトを作成します。Service -> Add ServiceでMarketplace を選択し、MySQL を選択する。接続パラメータ(ユーザー名、パスワード、アドレス、ポート)をメモします。 diff --git a/README.md b/README.md index 38e3e0b7..39eb5fa1 100644 --- a/README.md +++ b/README.md @@ -249,6 +249,8 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope > Zeabur 的服务器在国外,自动解决了网络的问题,同时免费的额度也足够个人使用 +[![Deploy on Zeabur](https://zeabur.com/button.svg)](https://zeabur.com/templates/7Q0KO3) + 1. 首先 fork 一份代码。 2. 进入 [Zeabur](https://zeabur.com?referralCode=songquanpeng),登录,进入控制台。 3. 新建一个 Project,在 Service -> Add Service 选择 Marketplace,选择 MySQL,并记下连接参数(用户名、密码、地址、端口)。 From d306cb52293b4053105d0de869795bffc4593730 Mon Sep 17 00:00:00 2001 From: qingfengfenga <41416092+qingfengfenga@users.noreply.github.com> Date: Fri, 10 Nov 2023 21:40:00 +0800 Subject: [PATCH 05/10] feat: add improve docker-compose.yml and support fast startup (#685) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: 王彦朋 Penn Wang --- .gitignore | 3 ++- README.md | 13 +++++++++++++ docker-compose.yml | 21 ++++++++++++++++++--- 3 files changed, 33 insertions(+), 4 deletions(-) diff --git a/.gitignore b/.gitignore index 1b2cf071..60abb13e 100644 --- a/.gitignore +++ b/.gitignore @@ -5,4 +5,5 @@ upload *.db build *.db-journal -logs \ No newline at end of file +logs +data \ No newline at end of file diff --git a/README.md b/README.md index 39eb5fa1..4ef6505c 100644 --- a/README.md +++ b/README.md @@ -160,6 +160,19 @@ sudo service nginx restart 初始账号用户名为 `root`,密码为 `123456`。 + +### 基于 Docker Compose 进行部署 + +> 仅启动方式不同,参数设置不变,请参考基于 Docker 部署部分 + +```shell +# 目前支持 MySQL 启动,数据存储在 ./data/mysql 文件夹内 +docker-compose up -d + +# 查看部署状态 +docker-compose ps +``` + ### 手动部署 1. 从 [GitHub Releases](https://github.com/songquanpeng/one-api/releases/latest) 下载可执行文件或者从源码编译: ```shell diff --git a/docker-compose.yml b/docker-compose.yml index 9b814a03..30edb281 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -9,19 +9,19 @@ services: ports: - "3000:3000" volumes: - - ./data:/data + - ./data/oneapi:/data - ./logs:/app/logs environment: - - SQL_DSN=root:123456@tcp(host.docker.internal:3306)/one-api # 修改此行,或注释掉以使用 SQLite 作为数据库 + - SQL_DSN=oneapi:123456@tcp(db:3306)/one-api # 修改此行,或注释掉以使用 SQLite 作为数据库 - REDIS_CONN_STRING=redis://redis - SESSION_SECRET=random_string # 修改为随机字符串 - TZ=Asia/Shanghai # - NODE_TYPE=slave # 多机部署时从节点取消注释该行 # - SYNC_FREQUENCY=60 # 需要定期从数据库加载数据时取消注释该行 # - FRONTEND_BASE_URL=https://openai.justsong.cn # 多机部署时从节点取消注释该行 - depends_on: - redis + - db healthcheck: test: [ "CMD-SHELL", "wget -q -O - http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk -F: '{print $2}'" ] interval: 30s @@ -32,3 +32,18 @@ services: image: redis:latest container_name: redis restart: always + + db: + image: mysql:8.2.0 + restart: always + container_name: mysql + volumes: + - ./data/mysql:/var/lib/mysql # 挂载目录,持久化存储 + ports: + - '3306:3306' + environment: + TZ: Asia/Shanghai # 设置时区 + MYSQL_ROOT_PASSWORD: 'OneAPI@justsong' # 设置 root 用户的密码 + MYSQL_USER: oneapi # 创建专用用户 + MYSQL_PASSWORD: '123456' # 设置专用用户密码 + MYSQL_DATABASE: one-api # 自动创建数据库 \ No newline at end of file From 58bb3ab6f6008af8135464877259db610f89992d Mon Sep 17 00:00:00 2001 From: Dafei Zhao Date: Fri, 10 Nov 2023 08:50:52 -0500 Subject: [PATCH 06/10] fix: fix channel_id column name (#681, close #688) --- model/log.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/model/log.go b/model/log.go index d26da9a2..3d3ffae3 100644 --- a/model/log.go +++ b/model/log.go @@ -94,7 +94,7 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName tx = tx.Where("created_at <= ?", endTimestamp) } if channel != 0 { - tx = tx.Where("channel = ?", channel) + tx = tx.Where("channel_id = ?", channel) } err = tx.Order("id desc").Limit(num).Offset(startIdx).Find(&logs).Error return logs, err @@ -151,7 +151,7 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa tx = tx.Where("model_name = ?", modelName) } if channel != 0 { - tx = tx.Where("channel = ?", channel) + tx = tx.Where("channel_id = ?", channel) } tx.Where("type = ?", LogTypeConsume).Scan("a) return quota From de7b9710a52a943e9b7dd54837916949cbda1663 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=AE=A1=E5=AE=9C=E5=B0=A7?= Date: Fri, 17 Nov 2023 19:40:59 +0800 Subject: [PATCH 07/10] fix: fix PaLM not working issue (#667) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * bugfix for #515 最新版本谷歌PaLM模型无法使用 * update * chore: remove unrelated file * chore: add comment --------- Co-authored-by: JustSong --- controller/relay-text.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/controller/relay-text.go b/controller/relay-text.go index a61c6f7c..b9a300b4 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -367,6 +367,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { } case APITypeTencent: req.Header.Set("Authorization", apiKey) + case APITypePaLM: + // do not set Authorization header default: req.Header.Set("Authorization", "Bearer "+apiKey) } From 1d15157f7d2f3d2b2e98c965a383a06c6535e665 Mon Sep 17 00:00:00 2001 From: ckt1031 <65409152+ckt1031@users.noreply.github.com> Date: Fri, 17 Nov 2023 20:03:16 +0800 Subject: [PATCH 08/10] feat: keep sync with dall-e updates (#679) * Updated ImageRequest struct and OpenAIModels, added new Dall-E models and size ratios * Fixed suspect `or` * Refactored size ratio calculation in relayImageHelper function * Updated the format of resolution keys in DalleSizeRatios map * Added error handling for unsupported image size in relayImageHelper function * Added validation for number of generated images and defined image generation ratios * Refactored variable name from DalleGenerationImageAmountRatios to DalleGenerationImageAmounts * Added validation for prompt length in relayImageHelper function * Updated model validation and removed size not supported error in relayImageHelper function * Refactored image size and model validation in relayImageHelper function * chore: discard binary file * chore: update impl --------- Co-authored-by: cktsun1031 <65409152+cktsun1031@users.noreply.github.com> Co-authored-by: JustSong --- common/model-ratio.go | 26 ++++++++++++++- controller/model.go | 13 ++++++-- controller/relay-image.go | 67 ++++++++++++++++++++++++++++----------- controller/relay.go | 12 +++++-- 4 files changed, 93 insertions(+), 25 deletions(-) diff --git a/common/model-ratio.go b/common/model-ratio.go index 681f0ae7..b4a471dc 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -6,6 +6,29 @@ import ( "time" ) +var DalleSizeRatios = map[string]map[string]float64{ + "dall-e-2": { + "256x256": 1, + "512x512": 1.125, + "1024x1024": 1.25, + }, + "dall-e-3": { + "1024x1024": 1, + "1024x1792": 2, + "1792x1024": 2, + }, +} + +var DalleGenerationImageAmounts = map[string][2]int{ + "dall-e-2": {1, 10}, + "dall-e-3": {1, 1}, // OpenAI allows n=1 currently. +} + +var DalleImagePromptLengthLimitations = map[string]int{ + "dall-e-2": 1000, + "dall-e-3": 4000, +} + // ModelRatio // https://platform.openai.com/docs/models/model-endpoint-compatibility // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Blfmc9dlf @@ -45,7 +68,8 @@ var ModelRatio = map[string]float64{ "text-search-ada-doc-001": 10, "text-moderation-stable": 0.1, "text-moderation-latest": 0.1, - "dall-e": 8, + "dall-e-2": 8, // $0.016 - $0.020 / image + "dall-e-3": 20, // $0.040 - $0.120 / image "claude-instant-1": 0.815, // $1.63 / 1M tokens "claude-2": 5.51, // $11.02 / 1M tokens "ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens diff --git a/controller/model.go b/controller/model.go index 7bd9d097..f9904330 100644 --- a/controller/model.go +++ b/controller/model.go @@ -55,12 +55,21 @@ func init() { // https://platform.openai.com/docs/models/model-endpoint-compatibility openAIModels = []OpenAIModels{ { - Id: "dall-e", + Id: "dall-e-2", Object: "model", Created: 1677649963, OwnedBy: "openai", Permission: permission, - Root: "dall-e", + Root: "dall-e-2", + Parent: nil, + }, + { + Id: "dall-e-3", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "dall-e-3", Parent: nil, }, { diff --git a/controller/relay-image.go b/controller/relay-image.go index ccd52dce..1d1b71ba 100644 --- a/controller/relay-image.go +++ b/controller/relay-image.go @@ -6,15 +6,28 @@ import ( "encoding/json" "errors" "fmt" - "github.com/gin-gonic/gin" "io" "net/http" "one-api/common" "one-api/model" + + "github.com/gin-gonic/gin" ) +func isWithinRange(element string, value int) bool { + if _, ok := common.DalleGenerationImageAmounts[element]; !ok { + return false + } + + min := common.DalleGenerationImageAmounts[element][0] + max := common.DalleGenerationImageAmounts[element][1] + + return value >= min && value <= max +} + func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { - imageModel := "dall-e" + imageModel := "dall-e-2" + imageSize := "1024x1024" tokenId := c.GetInt("token_id") channelType := c.GetInt("channel") @@ -31,19 +44,44 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode } } + // Size validation + if imageRequest.Size != "" { + imageSize = imageRequest.Size + } + + // Model validation + if imageRequest.Model != "" { + imageModel = imageRequest.Model + } + + imageCostRatio, hasValidSize := common.DalleSizeRatios[imageModel][imageSize] + + // Check if model is supported + if hasValidSize { + if imageRequest.Quality == "hd" && imageModel == "dall-e-3" { + if imageSize == "1024x1024" { + imageCostRatio *= 2 + } else { + imageCostRatio *= 1.5 + } + } + } else { + return errorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest) + } + // Prompt validation if imageRequest.Prompt == "" { - return errorWrapper(errors.New("prompt is required"), "required_field_missing", http.StatusBadRequest) + return errorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest) } - // Not "256x256", "512x512", or "1024x1024" - if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" { - return errorWrapper(errors.New("size must be one of 256x256, 512x512, or 1024x1024"), "invalid_field_value", http.StatusBadRequest) + // Check prompt length + if len(imageRequest.Prompt) > common.DalleImagePromptLengthLimitations[imageModel] { + return errorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest) } - // N should between 1 and 10 - if imageRequest.N != 0 && (imageRequest.N < 1 || imageRequest.N > 10) { - return errorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest) + // Number of generated images validation + if isWithinRange(imageModel, imageRequest.N) == false { + return errorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest) } // map model name @@ -82,16 +120,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode ratio := modelRatio * groupRatio userQuota, err := model.CacheGetUserQuota(userId) - sizeRatio := 1.0 - // Size - if imageRequest.Size == "256x256" { - sizeRatio = 1 - } else if imageRequest.Size == "512x512" { - sizeRatio = 1.125 - } else if imageRequest.Size == "1024x1024" { - sizeRatio = 1.25 - } - quota := int(ratio*sizeRatio*1000) * imageRequest.N + quota := int(ratio*imageCostRatio*1000) * imageRequest.N if consumeQuota && userQuota-quota < 0 { return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) diff --git a/controller/relay.go b/controller/relay.go index 1926110e..9cff887b 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -77,10 +77,16 @@ type TextRequest struct { //Stream bool `json:"stream"` } +// ImageRequest docs: https://platform.openai.com/docs/api-reference/images/create type ImageRequest struct { - Prompt string `json:"prompt"` - N int `json:"n"` - Size string `json:"size"` + Model string `json:"model"` + Prompt string `json:"prompt" binding:"required"` + N int `json:"n"` + Size string `json:"size"` + Quality string `json:"quality"` + ResponseFormat string `json:"response_format"` + Style string `json:"style"` + User string `json:"user"` } type AudioResponse struct { From ddcaf95f5faddce75c095395744526d2d5713343 Mon Sep 17 00:00:00 2001 From: ckt1031 <65409152+ckt1031@users.noreply.github.com> Date: Fri, 17 Nov 2023 21:18:51 +0800 Subject: [PATCH 09/10] feat: support tts model (#713) * Added support for Text-to-Speech models and endpoints * chore: update impl --------- Co-authored-by: JustSong --- common/model-ratio.go | 6 +- controller/model.go | 36 ++++++++++++ controller/relay-audio.go | 118 +++++++++++++++++++++----------------- controller/relay-utils.go | 19 ++++++ controller/relay.go | 28 +++++++-- middleware/distributor.go | 9 +-- router/relay-router.go | 1 + 7 files changed, 151 insertions(+), 66 deletions(-) diff --git a/common/model-ratio.go b/common/model-ratio.go index b4a471dc..74c74a90 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -59,7 +59,11 @@ var ModelRatio = map[string]float64{ "text-davinci-003": 10, "text-davinci-edit-001": 10, "code-davinci-edit-001": 10, - "whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens + "whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens + "tts-1": 7.5, // $0.015 / 1K characters + "tts-1-1106": 7.5, + "tts-1-hd": 15, // $0.030 / 1K characters + "tts-1-hd-1106": 15, "davinci": 10, "curie": 10, "babbage": 10, diff --git a/controller/model.go b/controller/model.go index f9904330..59ea22e8 100644 --- a/controller/model.go +++ b/controller/model.go @@ -81,6 +81,42 @@ func init() { Root: "whisper-1", Parent: nil, }, + { + Id: "tts-1", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "tts-1", + Parent: nil, + }, + { + Id: "tts-1-1106", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "tts-1-1106", + Parent: nil, + }, + { + Id: "tts-1-hd", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "tts-1-hd", + Parent: nil, + }, + { + Id: "tts-1-hd-1106", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "tts-1-hd-1106", + Parent: nil, + }, { Id: "gpt-3.5-turbo", Object: "model", diff --git a/controller/relay-audio.go b/controller/relay-audio.go index 53833108..01267fbf 100644 --- a/controller/relay-audio.go +++ b/controller/relay-audio.go @@ -5,7 +5,6 @@ import ( "context" "encoding/json" "errors" - "fmt" "github.com/gin-gonic/gin" "io" "net/http" @@ -21,6 +20,22 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode channelId := c.GetInt("channel_id") userId := c.GetInt("id") group := c.GetString("group") + tokenName := c.GetString("token_name") + + var ttsRequest TextToSpeechRequest + if relayMode == RelayModeAudioSpeech { + // Read JSON + err := common.UnmarshalBodyReusable(c, &ttsRequest) + // Check if JSON is valid + if err != nil { + return errorWrapper(err, "invalid_json", http.StatusBadRequest) + } + audioModel = ttsRequest.Model + // Check if text is too long 4096 + if len(ttsRequest.Input) > 4096 { + return errorWrapper(errors.New("input is too long (over 4096 characters)"), "text_too_long", http.StatusBadRequest) + } + } preConsumedTokens := common.PreConsumedQuota modelRatio := common.GetModelRatio(audioModel) @@ -31,22 +46,32 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode if err != nil { return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) } - if userQuota-preConsumedQuota < 0 { - return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) - } - err = model.CacheDecreaseUserQuota(userId, preConsumedQuota) - if err != nil { - return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) - } - if userQuota > 100*preConsumedQuota { - // in this case, we do not pre-consume quota - // because the user has enough quota - preConsumedQuota = 0 - } - if preConsumedQuota > 0 { - err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) + + quota := 0 + // Check if user quota is enough + if relayMode == RelayModeAudioSpeech { + quota = int(float64(len(ttsRequest.Input)) * modelRatio * groupRatio) + if quota > userQuota { + return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) + } + } else { + if userQuota-preConsumedQuota < 0 { + return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) + } + err = model.CacheDecreaseUserQuota(userId, preConsumedQuota) if err != nil { - return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) + return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) + } + if userQuota > 100*preConsumedQuota { + // in this case, we do not pre-consume quota + // because the user has enough quota + preConsumedQuota = 0 + } + if preConsumedQuota > 0 { + err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) + if err != nil { + return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) + } } } @@ -93,47 +118,32 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode if err != nil { return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) } - var audioResponse AudioResponse - defer func(ctx context.Context) { - go func() { - quota := countTokenText(audioResponse.Text, audioModel) + if relayMode == RelayModeAudioSpeech { + defer func(ctx context.Context) { + go postConsumeQuota(ctx, tokenId, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName) + }(c.Request.Context()) + } else { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) + } + err = resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) + } + var whisperResponse WhisperResponse + err = json.Unmarshal(responseBody, &whisperResponse) + if err != nil { + return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) + } + defer func(ctx context.Context) { + quota := countTokenText(whisperResponse.Text, audioModel) quotaDelta := quota - preConsumedQuota - err := model.PostConsumeTokenQuota(tokenId, quotaDelta) - if err != nil { - common.SysError("error consuming token remain quota: " + err.Error()) - } - err = model.CacheUpdateUserQuota(userId) - if err != nil { - common.SysError("error update user quota cache: " + err.Error()) - } - if quota != 0 { - tokenName := c.GetString("token_name") - logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) - model.RecordConsumeLog(ctx, userId, channelId, 0, 0, audioModel, tokenName, quota, logContent) - model.UpdateUserUsedQuotaAndRequestCount(userId, quota) - channelId := c.GetInt("channel_id") - model.UpdateChannelUsedQuota(channelId, quota) - } - }() - }(c.Request.Context()) - - responseBody, err := io.ReadAll(resp.Body) - - if err != nil { - return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) + go postConsumeQuota(ctx, tokenId, quotaDelta, userId, channelId, modelRatio, groupRatio, audioModel, tokenName) + }(c.Request.Context()) + resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) } - err = resp.Body.Close() - if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) - } - err = json.Unmarshal(responseBody, &audioResponse) - if err != nil { - return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) - } - - resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) - for k, v := range resp.Header { c.Writer.Header().Set(k, v[0]) } diff --git a/controller/relay-utils.go b/controller/relay-utils.go index cf5d9b69..888187cb 100644 --- a/controller/relay-utils.go +++ b/controller/relay-utils.go @@ -1,6 +1,7 @@ package controller import ( + "context" "encoding/json" "fmt" "github.com/gin-gonic/gin" @@ -8,6 +9,7 @@ import ( "io" "net/http" "one-api/common" + "one-api/model" "strconv" "strings" ) @@ -186,3 +188,20 @@ func getFullRequestURL(baseURL string, requestURL string, channelType int) strin } return fullRequestURL } + +func postConsumeQuota(ctx context.Context, tokenId int, quota int, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) { + err := model.PostConsumeTokenQuota(tokenId, quota) + if err != nil { + common.SysError("error consuming token remain quota: " + err.Error()) + } + err = model.CacheUpdateUserQuota(userId) + if err != nil { + common.SysError("error update user quota cache: " + err.Error()) + } + if quota != 0 { + logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) + model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent) + model.UpdateUserUsedQuotaAndRequestCount(userId, quota) + model.UpdateChannelUsedQuota(channelId, quota) + } +} diff --git a/controller/relay.go b/controller/relay.go index 9cff887b..863267b4 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -24,7 +24,9 @@ const ( RelayModeModerations RelayModeImagesGenerations RelayModeEdits - RelayModeAudio + RelayModeAudioSpeech + RelayModeAudioTranscription + RelayModeAudioTranslation ) // https://platform.openai.com/docs/api-reference/chat @@ -89,10 +91,18 @@ type ImageRequest struct { User string `json:"user"` } -type AudioResponse struct { +type WhisperResponse struct { Text string `json:"text,omitempty"` } +type TextToSpeechRequest struct { + Model string `json:"model" binding:"required"` + Input string `json:"input" binding:"required"` + Voice string `json:"voice" binding:"required"` + Speed float64 `json:"speed"` + ResponseFormat string `json:"response_format"` +} + type Usage struct { PromptTokens int `json:"prompt_tokens"` CompletionTokens int `json:"completion_tokens"` @@ -189,14 +199,22 @@ func Relay(c *gin.Context) { relayMode = RelayModeImagesGenerations } else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") { relayMode = RelayModeEdits - } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { - relayMode = RelayModeAudio + } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") { + relayMode = RelayModeAudioSpeech + } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcription") { + relayMode = RelayModeAudioTranscription + } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translation") { + relayMode = RelayModeAudioTranslation } var err *OpenAIErrorWithStatusCode switch relayMode { case RelayModeImagesGenerations: err = relayImageHelper(c, relayMode) - case RelayModeAudio: + case RelayModeAudioSpeech: + fallthrough + case RelayModeAudioTranslation: + fallthrough + case RelayModeAudioTranscription: err = relayAudioHelper(c, relayMode) default: err = relayTextHelper(c, relayMode) diff --git a/middleware/distributor.go b/middleware/distributor.go index d80945fc..c4ddc3a0 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -40,10 +40,7 @@ func Distribute() func(c *gin.Context) { } else { // Select a channel for the user var modelRequest ModelRequest - var err error - if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { - err = common.UnmarshalBodyReusable(c, &modelRequest) - } + err := common.UnmarshalBodyReusable(c, &modelRequest) if err != nil { abortWithMessage(c, http.StatusBadRequest, "无效的请求") return @@ -60,10 +57,10 @@ func Distribute() func(c *gin.Context) { } if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { if modelRequest.Model == "" { - modelRequest.Model = "dall-e" + modelRequest.Model = "dall-e-2" } } - if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { + if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") { if modelRequest.Model == "" { modelRequest.Model = "whisper-1" } diff --git a/router/relay-router.go b/router/relay-router.go index e84f02db..912f4989 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -29,6 +29,7 @@ func SetRelayRouter(router *gin.Engine) { relayV1Router.POST("/engines/:model/embeddings", controller.Relay) relayV1Router.POST("/audio/transcriptions", controller.Relay) relayV1Router.POST("/audio/translations", controller.Relay) + relayV1Router.POST("/audio/speech", controller.Relay) relayV1Router.GET("/files", controller.RelayNotImplemented) relayV1Router.POST("/files", controller.RelayNotImplemented) relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented) From 34d517cfa233d90b470e4d6ac0208051de4b4fea Mon Sep 17 00:00:00 2001 From: Mikey Date: Fri, 17 Nov 2023 05:45:55 -0800 Subject: [PATCH 10/10] fix: cloudflare test & expose detailed info about test failures (#715) * fix: cloudflare test & expose detailed info about test failures * fix: cloudflare test & expose detailed info about test failures --------- Co-authored-by: JustSong --- controller/channel-test.go | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/controller/channel-test.go b/controller/channel-test.go index 3c6c8f43..b47a44b9 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -6,11 +6,11 @@ import ( "errors" "fmt" "github.com/gin-gonic/gin" + "io" "net/http" "one-api/common" "one-api/model" "strconv" - "strings" "sync" "time" ) @@ -45,13 +45,11 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai if channel.Type == common.ChannelTypeAzure { requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.GetBaseURL(), request.Model) } else { - if channel.GetBaseURL() != "" { - requestURL = channel.GetBaseURL() + if baseURL := channel.GetBaseURL(); len(baseURL) > 0 { + requestURL = baseURL } - requestURL += "/v1/chat/completions" + requestURL = getFullRequestURL(requestURL, "/v1/chat/completions", channel.Type) } - // for Cloudflare AI gateway: https://github.com/songquanpeng/one-api/pull/639 - requestURL = strings.Replace(requestURL, "/v1/v1", "/v1", 1) jsonData, err := json.Marshal(request) if err != nil { @@ -73,10 +71,14 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai } defer resp.Body.Close() var response TextResponse - err = json.NewDecoder(resp.Body).Decode(&response) + body, err := io.ReadAll(resp.Body) if err != nil { return err, nil } + err = json.Unmarshal(body, &response) + if err != nil { + return fmt.Errorf("Error: %s\nResp body: %s", err, body), nil + } if response.Usage.CompletionTokens == 0 { return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)), &response.Error }