Merge branch 'main' into feat/zhipu_support_text_embedding

This commit is contained in:
igophper 2023-11-20 21:56:37 +08:00 committed by GitHub
commit 3148a525b4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
47 changed files with 683 additions and 246 deletions

3
.gitignore vendored
View File

@ -5,4 +5,5 @@ upload
*.db *.db
build build
*.db-journal *.db-journal
logs logs
data

View File

@ -189,6 +189,8 @@ If you encounter a blank page after deployment, refer to [#97](https://github.co
> Zeabur's servers are located overseas, automatically solving network issues, and the free quota is sufficient for personal usage. > Zeabur's servers are located overseas, automatically solving network issues, and the free quota is sufficient for personal usage.
[![Deploy on Zeabur](https://zeabur.com/button.svg)](https://zeabur.com/templates/7Q0KO3)
1. First, fork the code. 1. First, fork the code.
2. Go to [Zeabur](https://zeabur.com?referralCode=songquanpeng), log in, and enter the console. 2. Go to [Zeabur](https://zeabur.com?referralCode=songquanpeng), log in, and enter the console.
3. Create a new project. In Service -> Add Service, select Marketplace, and choose MySQL. Note down the connection parameters (username, password, address, and port). 3. Create a new project. In Service -> Add Service, select Marketplace, and choose MySQL. Note down the connection parameters (username, password, address, and port).

View File

@ -190,6 +190,8 @@ Please refer to the [environment variables](#environment-variables) section for
> Zeabur のサーバーは海外にあるため、ネットワークの問題は自動的に解決されます。 > Zeabur のサーバーは海外にあるため、ネットワークの問題は自動的に解決されます。
[![Deploy on Zeabur](https://zeabur.com/button.svg)](https://zeabur.com/templates/7Q0KO3)
1. まず、コードをフォークする。 1. まず、コードをフォークする。
2. [Zeabur](https://zeabur.com?referralCode=songquanpeng) にアクセスしてログインし、コンソールに入る。 2. [Zeabur](https://zeabur.com?referralCode=songquanpeng) にアクセスしてログインし、コンソールに入る。
3. 新しいプロジェクトを作成します。Service -> Add ServiceでMarketplace を選択し、MySQL を選択する。接続パラメータ(ユーザー名、パスワード、アドレス、ポート)をメモします。 3. 新しいプロジェクトを作成します。Service -> Add ServiceでMarketplace を選択し、MySQL を選択する。接続パラメータ(ユーザー名、パスワード、アドレス、ポート)をメモします。

View File

@ -75,7 +75,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
+ [x] [腾讯混元大模型](https://cloud.tencent.com/document/product/1729) + [x] [腾讯混元大模型](https://cloud.tencent.com/document/product/1729)
2. 支持配置镜像以及众多第三方代理服务: 2. 支持配置镜像以及众多第三方代理服务:
+ [x] [OpenAI-SB](https://openai-sb.com) + [x] [OpenAI-SB](https://openai-sb.com)
+ [x] [CloseAI](https://console.closeai-asia.com/r/2412) + [x] [CloseAI](https://referer.shadowai.xyz/r/2412)
+ [x] [API2D](https://api2d.com/r/197971) + [x] [API2D](https://api2d.com/r/197971)
+ [x] [OhMyGPT](https://aigptx.top?aff=uFpUl2Kf) + [x] [OhMyGPT](https://aigptx.top?aff=uFpUl2Kf)
+ [x] [AI Proxy](https://aiproxy.io/?i=OneAPI) (邀请码:`OneAPI` + [x] [AI Proxy](https://aiproxy.io/?i=OneAPI) (邀请码:`OneAPI`
@ -92,15 +92,16 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
12. 支持**用户邀请奖励**。 12. 支持**用户邀请奖励**。
13. 支持以美元为单位显示额度。 13. 支持以美元为单位显示额度。
14. 支持发布公告,设置充值链接,设置新用户初始额度。 14. 支持发布公告,设置充值链接,设置新用户初始额度。
15. 支持模型映射,重定向用户的请求模型。 15. 支持模型映射,重定向用户的请求模型,如无必要请不要设置,设置之后会导致请求体被重新构造而非直接透传,会导致部分还未正式支持的字段无法传递成功
16. 支持失败自动重试。 16. 支持失败自动重试。
17. 支持绘图接口。 17. 支持绘图接口。
18. 支持丰富的**自定义**设置, 18. 支持 [Cloudflare AI Gateway](https://developers.cloudflare.com/ai-gateway/providers/openai/),渠道设置的代理部分填写 `https://gateway.ai.cloudflare.com/v1/ACCOUNT_TAG/GATEWAY/openai` 即可。
19. 支持丰富的**自定义**设置,
1. 支持自定义系统名称logo 以及页脚。 1. 支持自定义系统名称logo 以及页脚。
2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入。 2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入。
19. 支持通过系统访问令牌访问管理 API 20. 支持通过系统访问令牌访问管理 APIbearer token用以替代 cookie你可以自行抓包来查看 API 的用法)
20. 支持 Cloudflare Turnstile 用户校验。 21. 支持 Cloudflare Turnstile 用户校验。
21. 支持用户管理,支持**多种用户登录注册方式** 22. 支持用户管理,支持**多种用户登录注册方式**
+ 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。 + 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。
+ [GitHub 开放授权](https://github.com/settings/applications/new)。 + [GitHub 开放授权](https://github.com/settings/applications/new)。
+ 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。 + 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。
@ -159,6 +160,19 @@ sudo service nginx restart
初始账号用户名为 `root`,密码为 `123456` 初始账号用户名为 `root`,密码为 `123456`
### 基于 Docker Compose 进行部署
> 仅启动方式不同,参数设置不变,请参考基于 Docker 部署部分
```shell
# 目前支持 MySQL 启动,数据存储在 ./data/mysql 文件夹内
docker-compose up -d
# 查看部署状态
docker-compose ps
```
### 手动部署 ### 手动部署
1. 从 [GitHub Releases](https://github.com/songquanpeng/one-api/releases/latest) 下载可执行文件或者从源码编译: 1. 从 [GitHub Releases](https://github.com/songquanpeng/one-api/releases/latest) 下载可执行文件或者从源码编译:
```shell ```shell
@ -248,6 +262,8 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope
> Zeabur 的服务器在国外,自动解决了网络的问题,同时免费的额度也足够个人使用 > Zeabur 的服务器在国外,自动解决了网络的问题,同时免费的额度也足够个人使用
[![Deploy on Zeabur](https://zeabur.com/button.svg)](https://zeabur.com/templates/7Q0KO3)
1. 首先 fork 一份代码。 1. 首先 fork 一份代码。
2. 进入 [Zeabur](https://zeabur.com?referralCode=songquanpeng),登录,进入控制台。 2. 进入 [Zeabur](https://zeabur.com?referralCode=songquanpeng),登录,进入控制台。
3. 新建一个 Project在 Service -> Add Service 选择 Marketplace选择 MySQL并记下连接参数用户名、密码、地址、端口 3. 新建一个 Project在 Service -> Add Service 选择 Marketplace选择 MySQL并记下连接参数用户名、密码、地址、端口
@ -351,6 +367,10 @@ graph LR
13. 请求频率限制: 13. 请求频率限制:
+ `GLOBAL_API_RATE_LIMIT`:全局 API 速率限制(除中继请求外),单 ip 三分钟内的最大请求数,默认为 `180` + `GLOBAL_API_RATE_LIMIT`:全局 API 速率限制(除中继请求外),单 ip 三分钟内的最大请求数,默认为 `180`
+ `GLOBAL_WEB_RATE_LIMIT`:全局 Web 速率限制,单 ip 三分钟内的最大请求数,默认为 `60` + `GLOBAL_WEB_RATE_LIMIT`:全局 Web 速率限制,单 ip 三分钟内的最大请求数,默认为 `60`
14. 编码器缓存设置:
+ `TIKTOKEN_CACHE_DIR`:默认程序启动时会联网下载一些通用的词元的编码,如:`gpt-3.5-turbo`,在一些网络环境不稳定,或者离线情况,可能会导致启动有问题,可以配置此目录缓存数据,可迁移到离线环境。
+ `DATA_GYM_CACHE_DIR`:目前该配置作用与 `TIKTOKEN_CACHE_DIR` 一致,但是优先级没有它高。
15. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。
### 命令行参数 ### 命令行参数
1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000` 1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`

View File

@ -21,12 +21,9 @@ var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens
var DisplayInCurrencyEnabled = true var DisplayInCurrencyEnabled = true
var DisplayTokenStatEnabled = true var DisplayTokenStatEnabled = true
var UsingSQLite = false
// Any options with "Secret", "Token" in its key won't be return by GetOptions // Any options with "Secret", "Token" in its key won't be return by GetOptions
var SessionSecret = uuid.New().String() var SessionSecret = uuid.New().String()
var SQLitePath = "one-api.db"
var OptionMap map[string]string var OptionMap map[string]string
var OptionMapRWMutex sync.RWMutex var OptionMapRWMutex sync.RWMutex
@ -98,6 +95,8 @@ var SyncFrequency = GetOrDefault("SYNC_FREQUENCY", 10*60) // unit is second
var BatchUpdateEnabled = false var BatchUpdateEnabled = false
var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5) var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5)
var RelayTimeout = GetOrDefault("RELAY_TIMEOUT", 0) // unit is second
const ( const (
RequestIdKey = "X-Oneapi-Request-Id" RequestIdKey = "X-Oneapi-Request-Id"
) )

6
common/database.go Normal file
View File

@ -0,0 +1,6 @@
package common
var UsingSQLite = false
var UsingPostgreSQL = false
var SQLitePath = "one-api.db"

View File

@ -5,6 +5,7 @@ import (
"encoding/json" "encoding/json"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"io" "io"
"strings"
) )
func UnmarshalBodyReusable(c *gin.Context, v any) error { func UnmarshalBodyReusable(c *gin.Context, v any) error {
@ -16,7 +17,13 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
if err != nil { if err != nil {
return err return err
} }
err = json.Unmarshal(requestBody, &v) contentType := c.Request.Header.Get("Content-Type")
if strings.HasPrefix(contentType, "application/json") {
err = json.Unmarshal(requestBody, &v)
} else {
// skip for now
// TODO: someday non json request have variant model, we will need to implementation this
}
if err != nil { if err != nil {
return err return err
} }

View File

@ -3,8 +3,32 @@ package common
import ( import (
"encoding/json" "encoding/json"
"strings" "strings"
"time"
) )
var DalleSizeRatios = map[string]map[string]float64{
"dall-e-2": {
"256x256": 1,
"512x512": 1.125,
"1024x1024": 1.25,
},
"dall-e-3": {
"1024x1024": 1,
"1024x1792": 2,
"1792x1024": 2,
},
}
var DalleGenerationImageAmounts = map[string][2]int{
"dall-e-2": {1, 10},
"dall-e-3": {1, 1}, // OpenAI allows n=1 currently.
}
var DalleImagePromptLengthLimitations = map[string]int{
"dall-e-2": 1000,
"dall-e-3": 4000,
}
// ModelRatio // ModelRatio
// https://platform.openai.com/docs/models/model-endpoint-compatibility // https://platform.openai.com/docs/models/model-endpoint-compatibility
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Blfmc9dlf // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Blfmc9dlf
@ -19,12 +43,15 @@ var ModelRatio = map[string]float64{
"gpt-4-32k": 30, "gpt-4-32k": 30,
"gpt-4-32k-0314": 30, "gpt-4-32k-0314": 30,
"gpt-4-32k-0613": 30, "gpt-4-32k-0613": 30,
"gpt-4-1106-preview": 5, // $0.01 / 1K tokens
"gpt-4-vision-preview": 5, // $0.01 / 1K tokens
"gpt-3.5-turbo": 0.75, // $0.0015 / 1K tokens "gpt-3.5-turbo": 0.75, // $0.0015 / 1K tokens
"gpt-3.5-turbo-0301": 0.75, "gpt-3.5-turbo-0301": 0.75,
"gpt-3.5-turbo-0613": 0.75, "gpt-3.5-turbo-0613": 0.75,
"gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens "gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens
"gpt-3.5-turbo-16k-0613": 1.5, "gpt-3.5-turbo-16k-0613": 1.5,
"gpt-3.5-turbo-instruct": 0.75, // $0.0015 / 1K tokens "gpt-3.5-turbo-instruct": 0.75, // $0.0015 / 1K tokens
"gpt-3.5-turbo-1106": 0.5, // $0.001 / 1K tokens
"text-ada-001": 0.2, "text-ada-001": 0.2,
"text-babbage-001": 0.25, "text-babbage-001": 0.25,
"text-curie-001": 1, "text-curie-001": 1,
@ -32,7 +59,11 @@ var ModelRatio = map[string]float64{
"text-davinci-003": 10, "text-davinci-003": 10,
"text-davinci-edit-001": 10, "text-davinci-edit-001": 10,
"code-davinci-edit-001": 10, "code-davinci-edit-001": 10,
"whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens "whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
"tts-1": 7.5, // $0.015 / 1K characters
"tts-1-1106": 7.5,
"tts-1-hd": 15, // $0.030 / 1K characters
"tts-1-hd-1106": 15,
"davinci": 10, "davinci": 10,
"curie": 10, "curie": 10,
"babbage": 10, "babbage": 10,
@ -41,13 +72,16 @@ var ModelRatio = map[string]float64{
"text-search-ada-doc-001": 10, "text-search-ada-doc-001": 10,
"text-moderation-stable": 0.1, "text-moderation-stable": 0.1,
"text-moderation-latest": 0.1, "text-moderation-latest": 0.1,
"dall-e": 8, "dall-e-2": 8, // $0.016 - $0.020 / image
"dall-e-3": 20, // $0.040 - $0.120 / image
"claude-instant-1": 0.815, // $1.63 / 1M tokens "claude-instant-1": 0.815, // $1.63 / 1M tokens
"claude-2": 5.51, // $11.02 / 1M tokens "claude-2": 5.51, // $11.02 / 1M tokens
"ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens "ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens
"ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens "ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens
"ERNIE-Bot-4": 8.572, // ¥0.12 / 1k tokens
"Embedding-V1": 0.1429, // ¥0.002 / 1k tokens "Embedding-V1": 0.1429, // ¥0.002 / 1k tokens
"PaLM-2": 1, "PaLM-2": 1,
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens "chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens "chatglm_std": 0.3572, // ¥0.005 / 1k tokens
"chatglm_lite": 0.1429, // ¥0.002 / 1k tokens "chatglm_lite": 0.1429, // ¥0.002 / 1k tokens
@ -87,9 +121,24 @@ func GetModelRatio(name string) float64 {
func GetCompletionRatio(name string) float64 { func GetCompletionRatio(name string) float64 {
if strings.HasPrefix(name, "gpt-3.5") { if strings.HasPrefix(name, "gpt-3.5") {
if strings.HasSuffix(name, "1106") {
return 2
}
if name == "gpt-3.5-turbo" || name == "gpt-3.5-turbo-16k" {
// TODO: clear this after 2023-12-11
now := time.Now()
// https://platform.openai.com/docs/models/continuous-model-upgrades
// if after 2023-12-11, use 2
if now.After(time.Date(2023, 12, 11, 0, 0, 0, 0, time.UTC)) {
return 2
}
}
return 1.333333 return 1.333333
} }
if strings.HasPrefix(name, "gpt-4") { if strings.HasPrefix(name, "gpt-4") {
if strings.HasSuffix(name, "preview") {
return 3
}
return 2 return 2
} }
if strings.HasPrefix(name, "claude-instant-1") { if strings.HasPrefix(name, "claude-instant-1") {

View File

@ -199,3 +199,11 @@ func GetOrDefault(env string, defaultValue int) int {
func MessageWithRequestId(message string, id string) string { func MessageWithRequestId(message string, id string) string {
return fmt.Sprintf("%s (request id: %s)", message, id) return fmt.Sprintf("%s (request id: %s)", message, id)
} }
func String2Int(str string) int {
num, err := strconv.Atoi(str)
if err != nil {
return 0
}
return num
}

View File

@ -5,13 +5,15 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/gin-gonic/gin" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/model" "one-api/model"
"strconv" "strconv"
"sync" "sync"
"time" "time"
"github.com/gin-gonic/gin"
) )
func testChannel(channel *model.Channel, request ChatRequest) (err error, openaiErr *OpenAIError) { func testChannel(channel *model.Channel, request ChatRequest) (err error, openaiErr *OpenAIError) {
@ -42,14 +44,14 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai
} }
requestURL := common.ChannelBaseURLs[channel.Type] requestURL := common.ChannelBaseURLs[channel.Type]
if channel.Type == common.ChannelTypeAzure { if channel.Type == common.ChannelTypeAzure {
requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.GetBaseURL(), request.Model) requestURL = getFullRequestURL(channel.GetBaseURL(), fmt.Sprintf("/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", request.Model), channel.Type)
} else { } else {
if channel.GetBaseURL() != "" { if baseURL := channel.GetBaseURL(); len(baseURL) > 0 {
requestURL = channel.GetBaseURL() requestURL = baseURL
} }
requestURL += "/v1/chat/completions"
}
requestURL = getFullRequestURL(requestURL, "/v1/chat/completions", channel.Type)
}
jsonData, err := json.Marshal(request) jsonData, err := json.Marshal(request)
if err != nil { if err != nil {
return err, nil return err, nil
@ -70,10 +72,14 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai
} }
defer resp.Body.Close() defer resp.Body.Close()
var response TextResponse var response TextResponse
err = json.NewDecoder(resp.Body).Decode(&response) body, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return err, nil return err, nil
} }
err = json.Unmarshal(body, &response)
if err != nil {
return fmt.Errorf("Error: %s\nResp body: %s", err, body), nil
}
if response.Usage.CompletionTokens == 0 { if response.Usage.CompletionTokens == 0 {
return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)), &response.Error return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)), &response.Error
} }

View File

@ -127,8 +127,8 @@ func DeleteChannel(c *gin.Context) {
return return
} }
func DeleteManuallyDisabledChannel(c *gin.Context) { func DeleteDisabledChannel(c *gin.Context) {
rows, err := model.DeleteChannelByStatus(common.ChannelStatusManuallyDisabled) rows, err := model.DeleteDisabledChannel()
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,

View File

@ -55,12 +55,21 @@ func init() {
// https://platform.openai.com/docs/models/model-endpoint-compatibility // https://platform.openai.com/docs/models/model-endpoint-compatibility
openAIModels = []OpenAIModels{ openAIModels = []OpenAIModels{
{ {
Id: "dall-e", Id: "dall-e-2",
Object: "model", Object: "model",
Created: 1677649963, Created: 1677649963,
OwnedBy: "openai", OwnedBy: "openai",
Permission: permission, Permission: permission,
Root: "dall-e", Root: "dall-e-2",
Parent: nil,
},
{
Id: "dall-e-3",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "dall-e-3",
Parent: nil, Parent: nil,
}, },
{ {
@ -72,6 +81,42 @@ func init() {
Root: "whisper-1", Root: "whisper-1",
Parent: nil, Parent: nil,
}, },
{
Id: "tts-1",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "tts-1",
Parent: nil,
},
{
Id: "tts-1-1106",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "tts-1-1106",
Parent: nil,
},
{
Id: "tts-1-hd",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "tts-1-hd",
Parent: nil,
},
{
Id: "tts-1-hd-1106",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "tts-1-hd-1106",
Parent: nil,
},
{ {
Id: "gpt-3.5-turbo", Id: "gpt-3.5-turbo",
Object: "model", Object: "model",
@ -117,6 +162,15 @@ func init() {
Root: "gpt-3.5-turbo-16k-0613", Root: "gpt-3.5-turbo-16k-0613",
Parent: nil, Parent: nil,
}, },
{
Id: "gpt-3.5-turbo-1106",
Object: "model",
Created: 1699593571,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-3.5-turbo-1106",
Parent: nil,
},
{ {
Id: "gpt-3.5-turbo-instruct", Id: "gpt-3.5-turbo-instruct",
Object: "model", Object: "model",
@ -180,6 +234,24 @@ func init() {
Root: "gpt-4-32k-0613", Root: "gpt-4-32k-0613",
Parent: nil, Parent: nil,
}, },
{
Id: "gpt-4-1106-preview",
Object: "model",
Created: 1699593571,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-4-1106-preview",
Parent: nil,
},
{
Id: "gpt-4-vision-preview",
Object: "model",
Created: 1699593571,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-4-vision-preview",
Parent: nil,
},
{ {
Id: "text-embedding-ada-002", Id: "text-embedding-ada-002",
Object: "model", Object: "model",
@ -274,7 +346,7 @@ func init() {
Id: "claude-instant-1", Id: "claude-instant-1",
Object: "model", Object: "model",
Created: 1677649963, Created: 1677649963,
OwnedBy: "anturopic", OwnedBy: "anthropic",
Permission: permission, Permission: permission,
Root: "claude-instant-1", Root: "claude-instant-1",
Parent: nil, Parent: nil,
@ -283,7 +355,7 @@ func init() {
Id: "claude-2", Id: "claude-2",
Object: "model", Object: "model",
Created: 1677649963, Created: 1677649963,
OwnedBy: "anturopic", OwnedBy: "anthropic",
Permission: permission, Permission: permission,
Root: "claude-2", Root: "claude-2",
Parent: nil, Parent: nil,
@ -306,6 +378,15 @@ func init() {
Root: "ERNIE-Bot-turbo", Root: "ERNIE-Bot-turbo",
Parent: nil, Parent: nil,
}, },
{
Id: "ERNIE-Bot-4",
Object: "model",
Created: 1677649963,
OwnedBy: "baidu",
Permission: permission,
Root: "ERNIE-Bot-4",
Parent: nil,
},
{ {
Id: "Embedding-V1", Id: "Embedding-V1",
Object: "model", Object: "model",
@ -324,6 +405,15 @@ func init() {
Root: "PaLM-2", Root: "PaLM-2",
Parent: nil, Parent: nil,
}, },
{
Id: "chatglm_turbo",
Object: "model",
Created: 1677649963,
OwnedBy: "zhipu",
Permission: permission,
Root: "chatglm_turbo",
Parent: nil,
},
{ {
Id: "chatglm_pro", Id: "chatglm_pro",
Object: "model", Object: "model",

View File

@ -48,7 +48,7 @@ type AIProxyLibraryStreamResponse struct {
func requestOpenAI2AIProxyLibrary(request GeneralOpenAIRequest) *AIProxyLibraryRequest { func requestOpenAI2AIProxyLibrary(request GeneralOpenAIRequest) *AIProxyLibraryRequest {
query := "" query := ""
if len(request.Messages) != 0 { if len(request.Messages) != 0 {
query = request.Messages[len(request.Messages)-1].Content query = request.Messages[len(request.Messages)-1].StringContent()
} }
return &AIProxyLibraryRequest{ return &AIProxyLibraryRequest{
Model: request.Model, Model: request.Model,

View File

@ -88,18 +88,18 @@ func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest {
message := request.Messages[i] message := request.Messages[i]
if message.Role == "system" { if message.Role == "system" {
messages = append(messages, AliMessage{ messages = append(messages, AliMessage{
User: message.Content, User: message.StringContent(),
Bot: "Okay", Bot: "Okay",
}) })
continue continue
} else { } else {
if i == len(request.Messages)-1 { if i == len(request.Messages)-1 {
prompt = message.Content prompt = message.StringContent()
break break
} }
messages = append(messages, AliMessage{ messages = append(messages, AliMessage{
User: message.Content, User: message.StringContent(),
Bot: request.Messages[i+1].Content, Bot: request.Messages[i+1].StringContent(),
}) })
i++ i++
} }

View File

@ -5,13 +5,11 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "github.com/gin-gonic/gin"
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/model" "one-api/model"
"github.com/gin-gonic/gin"
) )
func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
@ -22,6 +20,22 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
channelId := c.GetInt("channel_id") channelId := c.GetInt("channel_id")
userId := c.GetInt("id") userId := c.GetInt("id")
group := c.GetString("group") group := c.GetString("group")
tokenName := c.GetString("token_name")
var ttsRequest TextToSpeechRequest
if relayMode == RelayModeAudioSpeech {
// Read JSON
err := common.UnmarshalBodyReusable(c, &ttsRequest)
// Check if JSON is valid
if err != nil {
return errorWrapper(err, "invalid_json", http.StatusBadRequest)
}
audioModel = ttsRequest.Model
// Check if text is too long 4096
if len(ttsRequest.Input) > 4096 {
return errorWrapper(errors.New("input is too long (over 4096 characters)"), "text_too_long", http.StatusBadRequest)
}
}
preConsumedTokens := common.PreConsumedQuota preConsumedTokens := common.PreConsumedQuota
modelRatio := common.GetModelRatio(audioModel) modelRatio := common.GetModelRatio(audioModel)
@ -32,22 +46,32 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
if err != nil { if err != nil {
return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
} }
if userQuota-preConsumedQuota < 0 {
return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) quota := 0
} // Check if user quota is enough
err = model.CacheDecreaseUserQuota(userId, preConsumedQuota) if relayMode == RelayModeAudioSpeech {
if err != nil { quota = int(float64(len(ttsRequest.Input)) * modelRatio * groupRatio)
return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) if quota > userQuota {
} return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
if userQuota > 100*preConsumedQuota { }
// in this case, we do not pre-consume quota } else {
// because the user has enough quota if userQuota-preConsumedQuota < 0 {
preConsumedQuota = 0 return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
} }
if preConsumedQuota > 0 { err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
if err != nil { if err != nil {
return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
}
if userQuota > 100*preConsumedQuota {
// in this case, we do not pre-consume quota
// because the user has enough quota
preConsumedQuota = 0
}
if preConsumedQuota > 0 {
err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
if err != nil {
return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
}
} }
} }
@ -66,12 +90,11 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
baseURL := common.ChannelBaseURLs[channelType] baseURL := common.ChannelBaseURLs[channelType]
requestURL := c.Request.URL.String() requestURL := c.Request.URL.String()
if c.GetString("base_url") != "" { if c.GetString("base_url") != "" {
baseURL = c.GetString("base_url") baseURL = c.GetString("base_url")
} }
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
requestBody := c.Request.Body requestBody := c.Request.Body
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
@ -95,47 +118,32 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
if err != nil { if err != nil {
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
} }
var audioResponse AudioResponse
defer func(ctx context.Context) { if relayMode == RelayModeAudioSpeech {
go func() { defer func(ctx context.Context) {
quota := countTokenText(audioResponse.Text, audioModel) go postConsumeQuota(ctx, tokenId, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName)
}(c.Request.Context())
} else {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
}
var whisperResponse WhisperResponse
err = json.Unmarshal(responseBody, &whisperResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
}
defer func(ctx context.Context) {
quota := countTokenText(whisperResponse.Text, audioModel)
quotaDelta := quota - preConsumedQuota quotaDelta := quota - preConsumedQuota
err := model.PostConsumeTokenQuota(tokenId, quotaDelta) go postConsumeQuota(ctx, tokenId, quotaDelta, userId, channelId, modelRatio, groupRatio, audioModel, tokenName)
if err != nil { }(c.Request.Context())
common.SysError("error consuming token remain quota: " + err.Error()) resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
}
err = model.CacheUpdateUserQuota(userId)
if err != nil {
common.SysError("error update user quota cache: " + err.Error())
}
if quota != 0 {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, audioModel, tokenName, quota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)
}
}()
}(c.Request.Context())
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
} }
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
}
err = json.Unmarshal(responseBody, &audioResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
}
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
for k, v := range resp.Header { for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0]) c.Writer.Header().Set(k, v[0])
} }

View File

@ -89,7 +89,7 @@ func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
if message.Role == "system" { if message.Role == "system" {
messages = append(messages, BaiduMessage{ messages = append(messages, BaiduMessage{
Role: "user", Role: "user",
Content: message.Content, Content: message.StringContent(),
}) })
messages = append(messages, BaiduMessage{ messages = append(messages, BaiduMessage{
Role: "assistant", Role: "assistant",
@ -98,7 +98,7 @@ func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
} else { } else {
messages = append(messages, BaiduMessage{ messages = append(messages, BaiduMessage{
Role: message.Role, Role: message.Role,
Content: message.Content, Content: message.StringContent(),
}) })
} }
} }

View File

@ -14,8 +14,20 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
func isWithinRange(element string, value int) bool {
if _, ok := common.DalleGenerationImageAmounts[element]; !ok {
return false
}
min := common.DalleGenerationImageAmounts[element][0]
max := common.DalleGenerationImageAmounts[element][1]
return value >= min && value <= max
}
func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
imageModel := "dall-e" imageModel := "dall-e-2"
imageSize := "1024x1024"
tokenId := c.GetInt("token_id") tokenId := c.GetInt("token_id")
channelType := c.GetInt("channel") channelType := c.GetInt("channel")
@ -32,19 +44,44 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
} }
} }
// Size validation
if imageRequest.Size != "" {
imageSize = imageRequest.Size
}
// Model validation
if imageRequest.Model != "" {
imageModel = imageRequest.Model
}
imageCostRatio, hasValidSize := common.DalleSizeRatios[imageModel][imageSize]
// Check if model is supported
if hasValidSize {
if imageRequest.Quality == "hd" && imageModel == "dall-e-3" {
if imageSize == "1024x1024" {
imageCostRatio *= 2
} else {
imageCostRatio *= 1.5
}
}
} else {
return errorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest)
}
// Prompt validation // Prompt validation
if imageRequest.Prompt == "" { if imageRequest.Prompt == "" {
return errorWrapper(errors.New("prompt is required"), "required_field_missing", http.StatusBadRequest) return errorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest)
} }
// Not "256x256", "512x512", or "1024x1024" // Check prompt length
if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" { if len(imageRequest.Prompt) > common.DalleImagePromptLengthLimitations[imageModel] {
return errorWrapper(errors.New("size must be one of 256x256, 512x512, or 1024x1024"), "invalid_field_value", http.StatusBadRequest) return errorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest)
} }
// N should between 1 and 10 // Number of generated images validation
if imageRequest.N != 0 && (imageRequest.N < 1 || imageRequest.N > 10) { if isWithinRange(imageModel, imageRequest.N) == false {
return errorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest) return errorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest)
} }
// map model name // map model name
@ -61,16 +98,12 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
isModelMapped = true isModelMapped = true
} }
} }
baseURL := common.ChannelBaseURLs[channelType] baseURL := common.ChannelBaseURLs[channelType]
requestURL := c.Request.URL.String() requestURL := c.Request.URL.String()
if c.GetString("base_url") != "" { if c.GetString("base_url") != "" {
baseURL = c.GetString("base_url") baseURL = c.GetString("base_url")
} }
fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
var requestBody io.Reader var requestBody io.Reader
if isModelMapped { if isModelMapped {
jsonStr, err := json.Marshal(imageRequest) jsonStr, err := json.Marshal(imageRequest)
@ -87,16 +120,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
ratio := modelRatio * groupRatio ratio := modelRatio * groupRatio
userQuota, err := model.CacheGetUserQuota(userId) userQuota, err := model.CacheGetUserQuota(userId)
sizeRatio := 1.0 quota := int(ratio*imageCostRatio*1000) * imageRequest.N
// Size
if imageRequest.Size == "256x256" {
sizeRatio = 1
} else if imageRequest.Size == "512x512" {
sizeRatio = 1.125
} else if imageRequest.Size == "1024x1024" {
sizeRatio = 1.25
}
quota := int(ratio*sizeRatio*1000) * imageRequest.N
if consumeQuota && userQuota-quota < 0 { if consumeQuota && userQuota-quota < 0 {
return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)

View File

@ -132,7 +132,7 @@ func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promp
if textResponse.Usage.TotalTokens == 0 { if textResponse.Usage.TotalTokens == 0 {
completionTokens := 0 completionTokens := 0
for _, choice := range textResponse.Choices { for _, choice := range textResponse.Choices {
completionTokens += countTokenText(choice.Message.Content, model) completionTokens += countTokenText(choice.Message.StringContent(), model)
} }
textResponse.Usage = Usage{ textResponse.Usage = Usage{
PromptTokens: promptTokens, PromptTokens: promptTokens,

View File

@ -59,7 +59,7 @@ func requestOpenAI2PaLM(textRequest GeneralOpenAIRequest) *PaLMChatRequest {
} }
for _, message := range textRequest.Messages { for _, message := range textRequest.Messages {
palmMessage := PaLMChatMessage{ palmMessage := PaLMChatMessage{
Content: message.Content, Content: message.StringContent(),
} }
if message.Role == "user" { if message.Role == "user" {
palmMessage.Author = "0" palmMessage.Author = "0"

View File

@ -84,7 +84,7 @@ func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest {
if message.Role == "system" { if message.Role == "system" {
messages = append(messages, TencentMessage{ messages = append(messages, TencentMessage{
Role: "user", Role: "user",
Content: message.Content, Content: message.StringContent(),
}) })
messages = append(messages, TencentMessage{ messages = append(messages, TencentMessage{
Role: "assistant", Role: "assistant",
@ -93,7 +93,7 @@ func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest {
continue continue
} }
messages = append(messages, TencentMessage{ messages = append(messages, TencentMessage{
Content: message.Content, Content: message.StringContent(),
Role: message.Role, Role: message.Role,
}) })
} }

View File

@ -6,13 +6,15 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/gin-gonic/gin"
"io" "io"
"math"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/model" "one-api/model"
"strings" "strings"
"time" "time"
"github.com/gin-gonic/gin"
) )
const ( const (
@ -31,7 +33,14 @@ var httpClient *http.Client
var impatientHTTPClient *http.Client var impatientHTTPClient *http.Client
func init() { func init() {
httpClient = &http.Client{} if common.RelayTimeout == 0 {
httpClient = &http.Client{}
} else {
httpClient = &http.Client{
Timeout: time.Duration(common.RelayTimeout) * time.Second,
}
}
impatientHTTPClient = &http.Client{ impatientHTTPClient = &http.Client{
Timeout: 5 * time.Second, Timeout: 5 * time.Second,
} }
@ -118,7 +127,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
if c.GetString("base_url") != "" { if c.GetString("base_url") != "" {
baseURL = c.GetString("base_url") baseURL = c.GetString("base_url")
} }
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
switch apiType { switch apiType {
case APITypeOpenAI: case APITypeOpenAI:
if channelType == common.ChannelTypeAzure { if channelType == common.ChannelTypeAzure {
@ -138,7 +147,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
model_ = strings.TrimSuffix(model_, "-0301") model_ = strings.TrimSuffix(model_, "-0301")
model_ = strings.TrimSuffix(model_, "-0314") model_ = strings.TrimSuffix(model_, "-0314")
model_ = strings.TrimSuffix(model_, "-0613") model_ = strings.TrimSuffix(model_, "-0613")
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task)
requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
fullRequestURL = getFullRequestURL(baseURL, requestURL, channelType)
} }
case APITypeClaude: case APITypeClaude:
fullRequestURL = "https://api.anthropic.com/v1/complete" fullRequestURL = "https://api.anthropic.com/v1/complete"
@ -151,6 +162,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions" fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"
case "ERNIE-Bot-turbo": case "ERNIE-Bot-turbo":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant" fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
case "ERNIE-Bot-4":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro"
case "BLOOMZ-7B": case "BLOOMZ-7B":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1" fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
case "Embedding-V1": case "Embedding-V1":
@ -367,11 +380,16 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
} }
case APITypeTencent: case APITypeTencent:
req.Header.Set("Authorization", apiKey) req.Header.Set("Authorization", apiKey)
case APITypePaLM:
// do not set Authorization header
default: default:
req.Header.Set("Authorization", "Bearer "+apiKey) req.Header.Set("Authorization", "Bearer "+apiKey)
} }
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept")) req.Header.Set("Accept", c.Request.Header.Get("Accept"))
if isStream && c.Request.Header.Get("Accept") == "" {
req.Header.Set("Accept", "text/event-stream")
}
//req.Header.Set("Connection", c.Request.Header.Get("Connection")) //req.Header.Set("Connection", c.Request.Header.Get("Connection"))
resp, err = httpClient.Do(req) resp, err = httpClient.Do(req)
if err != nil { if err != nil {
@ -412,9 +430,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
completionRatio := common.GetCompletionRatio(textRequest.Model) completionRatio := common.GetCompletionRatio(textRequest.Model)
promptTokens = textResponse.Usage.PromptTokens promptTokens = textResponse.Usage.PromptTokens
completionTokens = textResponse.Usage.CompletionTokens completionTokens = textResponse.Usage.CompletionTokens
quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio))
quota = promptTokens + int(float64(completionTokens)*completionRatio)
quota = int(float64(quota) * ratio)
if ratio != 0 && quota <= 0 { if ratio != 0 && quota <= 0 {
quota = 1 quota = 1
} }

View File

@ -1,15 +1,18 @@
package controller package controller
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/gin-gonic/gin"
"github.com/pkoukk/tiktoken-go"
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/model"
"strconv" "strconv"
"strings" "strings"
"github.com/gin-gonic/gin"
"github.com/pkoukk/tiktoken-go"
) )
var stopFinishReason = "stop" var stopFinishReason = "stop"
@ -84,7 +87,7 @@ func countTokenMessages(messages []Message, model string) int {
tokenNum := 0 tokenNum := 0
for _, message := range messages { for _, message := range messages {
tokenNum += tokensPerMessage tokenNum += tokensPerMessage
tokenNum += getTokenNum(tokenEncoder, message.Content) tokenNum += getTokenNum(tokenEncoder, message.StringContent())
tokenNum += getTokenNum(tokenEncoder, message.Role) tokenNum += getTokenNum(tokenEncoder, message.Role)
if message.Name != nil { if message.Name != nil {
tokenNum += tokensPerName tokenNum += tokensPerName
@ -176,3 +179,35 @@ func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIEr
openAIErrorWithStatusCode.OpenAIError = textResponse.Error openAIErrorWithStatusCode.OpenAIError = textResponse.Error
return return
} }
func getFullRequestURL(baseURL string, requestURL string, channelType int) string {
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
switch channelType {
case common.ChannelTypeOpenAI:
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
case common.ChannelTypeAzure:
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments"))
}
}
return fullRequestURL
}
func postConsumeQuota(ctx context.Context, tokenId int, quota int, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) {
err := model.PostConsumeTokenQuota(tokenId, quota)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}
err = model.CacheUpdateUserQuota(userId)
if err != nil {
common.SysError("error update user quota cache: " + err.Error())
}
if quota != 0 {
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
model.UpdateChannelUsedQuota(channelId, quota)
}
}

View File

@ -81,7 +81,7 @@ func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, doma
if message.Role == "system" { if message.Role == "system" {
messages = append(messages, XunfeiMessage{ messages = append(messages, XunfeiMessage{
Role: "user", Role: "user",
Content: message.Content, Content: message.StringContent(),
}) })
messages = append(messages, XunfeiMessage{ messages = append(messages, XunfeiMessage{
Role: "assistant", Role: "assistant",
@ -90,7 +90,7 @@ func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, doma
} else { } else {
messages = append(messages, XunfeiMessage{ messages = append(messages, XunfeiMessage{
Role: message.Role, Role: message.Role,
Content: message.Content, Content: message.StringContent(),
}) })
} }
} }
@ -220,6 +220,9 @@ func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId strin
for !stop { for !stop {
select { select {
case xunfeiResponse = <-dataChan: case xunfeiResponse = <-dataChan:
if len(xunfeiResponse.Payload.Choices.Text) == 0 {
continue
}
content += xunfeiResponse.Payload.Choices.Text[0].Content content += xunfeiResponse.Payload.Choices.Text[0].Content
usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
@ -295,8 +298,8 @@ func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string) (string,
common.SysLog("api_version not found, use default: " + apiVersion) common.SysLog("api_version not found, use default: " + apiVersion)
} }
domain := "general" domain := "general"
if apiVersion == "v2.1" { if apiVersion != "v1.1" {
domain = "generalv2" domain += strings.Split(apiVersion, ".")[0]
} }
authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret) authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret)
return domain, authUrl return domain, authUrl

View File

@ -154,7 +154,7 @@ func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest {
if message.Role == "system" { if message.Role == "system" {
messages = append(messages, ZhipuMessage{ messages = append(messages, ZhipuMessage{
Role: "system", Role: "system",
Content: message.Content, Content: message.StringContent(),
}) })
messages = append(messages, ZhipuMessage{ messages = append(messages, ZhipuMessage{
Role: "user", Role: "user",
@ -163,7 +163,7 @@ func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest {
} else { } else {
messages = append(messages, ZhipuMessage{ messages = append(messages, ZhipuMessage{
Role: message.Role, Role: message.Role,
Content: message.Content, Content: message.StringContent(),
}) })
} }
} }

View File

@ -12,10 +12,49 @@ import (
type Message struct { type Message struct {
Role string `json:"role"` Role string `json:"role"`
Content string `json:"content"` Content any `json:"content"`
Name *string `json:"name,omitempty"` Name *string `json:"name,omitempty"`
} }
type ImageURL struct {
Url string `json:"url,omitempty"`
Detail string `json:"detail,omitempty"`
}
type TextContent struct {
Type string `json:"type,omitempty"`
Text string `json:"text,omitempty"`
}
type ImageContent struct {
Type string `json:"type,omitempty"`
ImageURL *ImageURL `json:"image_url,omitempty"`
}
func (m Message) StringContent() string {
content, ok := m.Content.(string)
if ok {
return content
}
contentList, ok := m.Content.([]any)
if ok {
var contentStr string
for _, contentItem := range contentList {
contentMap, ok := contentItem.(map[string]any)
if !ok {
continue
}
if contentMap["type"] == "text" {
if subStr, ok := contentMap["text"].(string); ok {
contentStr += subStr
}
}
}
return contentStr
}
return ""
}
const ( const (
RelayModeUnknown = iota RelayModeUnknown = iota
RelayModeChatCompletions RelayModeChatCompletions
@ -24,24 +63,37 @@ const (
RelayModeModerations RelayModeModerations
RelayModeImagesGenerations RelayModeImagesGenerations
RelayModeEdits RelayModeEdits
RelayModeAudio RelayModeAudioSpeech
RelayModeAudioTranscription
RelayModeAudioTranslation
) )
// https://platform.openai.com/docs/api-reference/chat // https://platform.openai.com/docs/api-reference/chat
type ResponseFormat struct {
Type string `json:"type,omitempty"`
}
type GeneralOpenAIRequest struct { type GeneralOpenAIRequest struct {
Model string `json:"model,omitempty"` Model string `json:"model,omitempty"`
Messages []Message `json:"messages,omitempty"` Messages []Message `json:"messages,omitempty"`
Prompt any `json:"prompt,omitempty"` Prompt any `json:"prompt,omitempty"`
Stream bool `json:"stream,omitempty"` Stream bool `json:"stream,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"` MaxTokens int `json:"max_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"` Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"` TopP float64 `json:"top_p,omitempty"`
N int `json:"n,omitempty"` N int `json:"n,omitempty"`
Input any `json:"input,omitempty"` Input any `json:"input,omitempty"`
Instruction string `json:"instruction,omitempty"` Instruction string `json:"instruction,omitempty"`
Size string `json:"size,omitempty"` Size string `json:"size,omitempty"`
Functions any `json:"functions,omitempty"` Functions any `json:"functions,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
Seed float64 `json:"seed,omitempty"`
Tools any `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
User string `json:"user,omitempty"`
} }
func (r GeneralOpenAIRequest) ParseInput() []string { func (r GeneralOpenAIRequest) ParseInput() []string {
@ -77,16 +129,30 @@ type TextRequest struct {
//Stream bool `json:"stream"` //Stream bool `json:"stream"`
} }
// ImageRequest docs: https://platform.openai.com/docs/api-reference/images/create
type ImageRequest struct { type ImageRequest struct {
Prompt string `json:"prompt"` Model string `json:"model"`
N int `json:"n"` Prompt string `json:"prompt" binding:"required"`
Size string `json:"size"` N int `json:"n"`
Size string `json:"size"`
Quality string `json:"quality"`
ResponseFormat string `json:"response_format"`
Style string `json:"style"`
User string `json:"user"`
} }
type AudioResponse struct { type WhisperResponse struct {
Text string `json:"text,omitempty"` Text string `json:"text,omitempty"`
} }
type TextToSpeechRequest struct {
Model string `json:"model" binding:"required"`
Input string `json:"input" binding:"required"`
Voice string `json:"voice" binding:"required"`
Speed float64 `json:"speed"`
ResponseFormat string `json:"response_format"`
}
type Usage struct { type Usage struct {
PromptTokens int `json:"prompt_tokens"` PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"` CompletionTokens int `json:"completion_tokens"`
@ -183,14 +249,22 @@ func Relay(c *gin.Context) {
relayMode = RelayModeImagesGenerations relayMode = RelayModeImagesGenerations
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") { } else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") {
relayMode = RelayModeEdits relayMode = RelayModeEdits
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
relayMode = RelayModeAudio relayMode = RelayModeAudioSpeech
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
relayMode = RelayModeAudioTranscription
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
relayMode = RelayModeAudioTranslation
} }
var err *OpenAIErrorWithStatusCode var err *OpenAIErrorWithStatusCode
switch relayMode { switch relayMode {
case RelayModeImagesGenerations: case RelayModeImagesGenerations:
err = relayImageHelper(c, relayMode) err = relayImageHelper(c, relayMode)
case RelayModeAudio: case RelayModeAudioSpeech:
fallthrough
case RelayModeAudioTranslation:
fallthrough
case RelayModeAudioTranscription:
err = relayAudioHelper(c, relayMode) err = relayAudioHelper(c, relayMode)
default: default:
err = relayTextHelper(c, relayMode) err = relayTextHelper(c, relayMode)

View File

@ -9,21 +9,21 @@ services:
ports: ports:
- "3000:3000" - "3000:3000"
volumes: volumes:
- ./data:/data - ./data/oneapi:/data
- ./logs:/app/logs - ./logs:/app/logs
environment: environment:
- SQL_DSN=root:123456@tcp(host.docker.internal:3306)/one-api # 修改此行,或注释掉以使用 SQLite 作为数据库 - SQL_DSN=oneapi:123456@tcp(db:3306)/one-api # 修改此行,或注释掉以使用 SQLite 作为数据库
- REDIS_CONN_STRING=redis://redis - REDIS_CONN_STRING=redis://redis
- SESSION_SECRET=random_string # 修改为随机字符串 - SESSION_SECRET=random_string # 修改为随机字符串
- TZ=Asia/Shanghai - TZ=Asia/Shanghai
# - NODE_TYPE=slave # 多机部署时从节点取消注释该行 # - NODE_TYPE=slave # 多机部署时从节点取消注释该行
# - SYNC_FREQUENCY=60 # 需要定期从数据库加载数据时取消注释该行 # - SYNC_FREQUENCY=60 # 需要定期从数据库加载数据时取消注释该行
# - FRONTEND_BASE_URL=https://openai.justsong.cn # 多机部署时从节点取消注释该行 # - FRONTEND_BASE_URL=https://openai.justsong.cn # 多机部署时从节点取消注释该行
depends_on: depends_on:
- redis - redis
- db
healthcheck: healthcheck:
test: [ "CMD-SHELL", "curl -s http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk '{print $2}' | grep 'true'" ] test: [ "CMD-SHELL", "wget -q -O - http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk -F: '{print $2}'" ]
interval: 30s interval: 30s
timeout: 10s timeout: 10s
retries: 3 retries: 3
@ -32,3 +32,18 @@ services:
image: redis:latest image: redis:latest
container_name: redis container_name: redis
restart: always restart: always
db:
image: mysql:8.2.0
restart: always
container_name: mysql
volumes:
- ./data/mysql:/var/lib/mysql # 挂载目录,持久化存储
ports:
- '3306:3306'
environment:
TZ: Asia/Shanghai # 设置时区
MYSQL_ROOT_PASSWORD: 'OneAPI@justsong' # 设置 root 用户的密码
MYSQL_USER: oneapi # 创建专用用户
MYSQL_PASSWORD: '123456' # 设置专用用户密码
MYSQL_DATABASE: one-api # 自动创建数据库

10
go.mod
View File

@ -15,8 +15,9 @@ require (
github.com/google/uuid v1.3.0 github.com/google/uuid v1.3.0
github.com/gorilla/websocket v1.5.0 github.com/gorilla/websocket v1.5.0
github.com/pkoukk/tiktoken-go v0.1.5 github.com/pkoukk/tiktoken-go v0.1.5
golang.org/x/crypto v0.9.0 golang.org/x/crypto v0.14.0
gorm.io/driver/mysql v1.4.3 gorm.io/driver/mysql v1.4.3
gorm.io/driver/postgres v1.5.2
gorm.io/driver/sqlite v1.4.3 gorm.io/driver/sqlite v1.4.3
gorm.io/gorm v1.25.0 gorm.io/gorm v1.25.0
) )
@ -52,10 +53,9 @@ require (
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.11 // indirect github.com/ugorji/go/codec v1.2.11 // indirect
golang.org/x/arch v0.3.0 // indirect golang.org/x/arch v0.3.0 // indirect
golang.org/x/net v0.10.0 // indirect golang.org/x/net v0.17.0 // indirect
golang.org/x/sys v0.8.0 // indirect golang.org/x/sys v0.13.0 // indirect
golang.org/x/text v0.9.0 // indirect golang.org/x/text v0.13.0 // indirect
google.golang.org/protobuf v1.30.0 // indirect google.golang.org/protobuf v1.30.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect
gorm.io/driver/postgres v1.5.2 // indirect
) )

17
go.sum
View File

@ -150,11 +150,11 @@ golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUu
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k= golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g= golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc=
golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
@ -162,14 +162,14 @@ golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU= golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
@ -198,7 +198,6 @@ gorm.io/driver/postgres v1.5.2/go.mod h1:fmpX0m2I1PKuR7mKZiEluwrP3hbs+ps7JIGMUBp
gorm.io/driver/sqlite v1.4.3 h1:HBBcZSDnWi5BW3B3rwvVTc510KGkBkexlOg0QrmLUuU= gorm.io/driver/sqlite v1.4.3 h1:HBBcZSDnWi5BW3B3rwvVTc510KGkBkexlOg0QrmLUuU=
gorm.io/driver/sqlite v1.4.3/go.mod h1:0Aq3iPO+v9ZKbcdiz8gLWRw5VOPcBOPUQJFLq5e2ecI= gorm.io/driver/sqlite v1.4.3/go.mod h1:0Aq3iPO+v9ZKbcdiz8gLWRw5VOPcBOPUQJFLq5e2ecI=
gorm.io/gorm v1.23.8/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk= gorm.io/gorm v1.23.8/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk=
gorm.io/gorm v1.24.0 h1:j/CoiSm6xpRpmzbFJsQHYj+I8bGYWLXVHeYEyyKlF74=
gorm.io/gorm v1.24.0/go.mod h1:DVrVomtaYTbqs7gB/x2uVvqnXzv0nqjB396B8cG4dBA= gorm.io/gorm v1.24.0/go.mod h1:DVrVomtaYTbqs7gB/x2uVvqnXzv0nqjB396B8cG4dBA=
gorm.io/gorm v1.25.0 h1:+KtYtb2roDz14EQe4bla8CbQlmb9dN3VejSai3lprfU= gorm.io/gorm v1.25.0 h1:+KtYtb2roDz14EQe4bla8CbQlmb9dN3VejSai3lprfU=
gorm.io/gorm v1.25.0/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k= gorm.io/gorm v1.25.0/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=

View File

@ -40,10 +40,7 @@ func Distribute() func(c *gin.Context) {
} else { } else {
// Select a channel for the user // Select a channel for the user
var modelRequest ModelRequest var modelRequest ModelRequest
var err error err := common.UnmarshalBodyReusable(c, &modelRequest)
if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
err = common.UnmarshalBodyReusable(c, &modelRequest)
}
if err != nil { if err != nil {
abortWithMessage(c, http.StatusBadRequest, "无效的请求") abortWithMessage(c, http.StatusBadRequest, "无效的请求")
return return
@ -60,10 +57,10 @@ func Distribute() func(c *gin.Context) {
} }
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
if modelRequest.Model == "" { if modelRequest.Model == "" {
modelRequest.Model = "dall-e" modelRequest.Model = "dall-e-2"
} }
} }
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
if modelRequest.Model == "" { if modelRequest.Model == "" {
modelRequest.Model = "whisper-1" modelRequest.Model = "whisper-1"
} }

View File

@ -15,10 +15,17 @@ type Ability struct {
func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) { func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
ability := Ability{} ability := Ability{}
groupCol := "`group`"
trueVal := "1"
if common.UsingPostgreSQL {
groupCol = `"group"`
trueVal = "true"
}
var err error = nil var err error = nil
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where("`group` = ? and model = ? and enabled = 1", group, model) maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model)
channelQuery := DB.Where("`group` = ? and model = ? and enabled = 1 and priority = (?)", group, model, maxPrioritySubQuery) channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery)
if common.UsingSQLite { if common.UsingSQLite || common.UsingPostgreSQL {
err = channelQuery.Order("RANDOM()").First(&ability).Error err = channelQuery.Order("RANDOM()").First(&ability).Error
} else { } else {
err = channelQuery.Order("RAND()").First(&ability).Error err = channelQuery.Order("RAND()").First(&ability).Error

View File

@ -21,14 +21,18 @@ var (
) )
func CacheGetTokenByKey(key string) (*Token, error) { func CacheGetTokenByKey(key string) (*Token, error) {
keyCol := "`key`"
if common.UsingPostgreSQL {
keyCol = `"key"`
}
var token Token var token Token
if !common.RedisEnabled { if !common.RedisEnabled {
err := DB.Where("`key` = ?", key).First(&token).Error err := DB.Where(keyCol+" = ?", key).First(&token).Error
return &token, err return &token, err
} }
tokenObjectString, err := common.RedisGet(fmt.Sprintf("token:%s", key)) tokenObjectString, err := common.RedisGet(fmt.Sprintf("token:%s", key))
if err != nil { if err != nil {
err := DB.Where("`key` = ?", key).First(&token).Error err := DB.Where(keyCol+" = ?", key).First(&token).Error
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -38,7 +38,11 @@ func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
} }
func SearchChannels(keyword string) (channels []*Channel, err error) { func SearchChannels(keyword string) (channels []*Channel, err error) {
err = DB.Omit("key").Where("id = ? or name LIKE ? or `key` = ?", keyword, keyword+"%", keyword).Find(&channels).Error keyCol := "`key`"
if common.UsingPostgreSQL {
keyCol = `"key"`
}
err = DB.Omit("key").Where("id = ? or name LIKE ? or "+keyCol+" = ?", common.String2Int(keyword), keyword+"%", keyword).Find(&channels).Error
return channels, err return channels, err
} }
@ -53,17 +57,6 @@ func GetChannelById(id int, selectAll bool) (*Channel, error) {
return &channel, err return &channel, err
} }
func GetRandomChannel() (*Channel, error) {
channel := Channel{}
var err error = nil
if common.UsingSQLite {
err = DB.Where("status = ? and `group` = ?", common.ChannelStatusEnabled, "default").Order("RANDOM()").Limit(1).First(&channel).Error
} else {
err = DB.Where("status = ? and `group` = ?", common.ChannelStatusEnabled, "default").Order("RAND()").Limit(1).First(&channel).Error
}
return &channel, err
}
func BatchInsertChannels(channels []Channel) error { func BatchInsertChannels(channels []Channel) error {
var err error var err error
err = DB.Create(&channels).Error err = DB.Create(&channels).Error
@ -181,3 +174,8 @@ func DeleteChannelByStatus(status int64) (int64, error) {
result := DB.Where("status = ?", status).Delete(&Channel{}) result := DB.Where("status = ?", status).Delete(&Channel{})
return result.RowsAffected, result.Error return result.RowsAffected, result.Error
} }
func DeleteDisabledChannel() (int64, error) {
result := DB.Where("status = ? or status = ?", common.ChannelStatusAutoDisabled, common.ChannelStatusManuallyDisabled).Delete(&Channel{})
return result.RowsAffected, result.Error
}

View File

@ -94,7 +94,7 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
tx = tx.Where("created_at <= ?", endTimestamp) tx = tx.Where("created_at <= ?", endTimestamp)
} }
if channel != 0 { if channel != 0 {
tx = tx.Where("channel = ?", channel) tx = tx.Where("channel_id = ?", channel)
} }
err = tx.Order("id desc").Limit(num).Offset(startIdx).Find(&logs).Error err = tx.Order("id desc").Limit(num).Offset(startIdx).Find(&logs).Error
return logs, err return logs, err
@ -151,7 +151,7 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
tx = tx.Where("model_name = ?", modelName) tx = tx.Where("model_name = ?", modelName)
} }
if channel != 0 { if channel != 0 {
tx = tx.Where("channel = ?", channel) tx = tx.Where("channel_id = ?", channel)
} }
tx.Where("type = ?", LogTypeConsume).Scan(&quota) tx.Where("type = ?", LogTypeConsume).Scan(&quota)
return quota return quota

View File

@ -42,6 +42,7 @@ func chooseDB() (*gorm.DB, error) {
if strings.HasPrefix(dsn, "postgres://") { if strings.HasPrefix(dsn, "postgres://") {
// Use PostgreSQL // Use PostgreSQL
common.SysLog("using PostgreSQL as database") common.SysLog("using PostgreSQL as database")
common.UsingPostgreSQL = true
return gorm.Open(postgres.New(postgres.Config{ return gorm.Open(postgres.New(postgres.Config{
DSN: dsn, DSN: dsn,
PreferSimpleProtocol: true, // disables implicit prepared statement usage PreferSimpleProtocol: true, // disables implicit prepared statement usage

View File

@ -50,8 +50,13 @@ func Redeem(key string, userId int) (quota int, err error) {
} }
redemption := &Redemption{} redemption := &Redemption{}
keyCol := "`key`"
if common.UsingPostgreSQL {
keyCol = `"key"`
}
err = DB.Transaction(func(tx *gorm.DB) error { err = DB.Transaction(func(tx *gorm.DB) error {
err := tx.Set("gorm:query_option", "FOR UPDATE").Where("`key` = ?", key).First(redemption).Error err := tx.Set("gorm:query_option", "FOR UPDATE").Where(keyCol+" = ?", key).First(redemption).Error
if err != nil { if err != nil {
return errors.New("无效的兑换码") return errors.New("无效的兑换码")
} }

View File

@ -266,7 +266,12 @@ func GetUserEmail(id int) (email string, err error) {
} }
func GetUserGroup(id int) (group string, err error) { func GetUserGroup(id int) (group string, err error) {
err = DB.Model(&User{}).Where("id = ?", id).Select("`group`").Find(&group).Error groupCol := "`group`"
if common.UsingPostgreSQL {
groupCol = `"group"`
}
err = DB.Model(&User{}).Where("id = ?", id).Select(groupCol).Find(&group).Error
return group, err return group, err
} }
@ -309,7 +314,8 @@ func GetRootUserEmail() (email string) {
func UpdateUserUsedQuotaAndRequestCount(id int, quota int) { func UpdateUserUsedQuotaAndRequestCount(id int, quota int) {
if common.BatchUpdateEnabled { if common.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeUsedQuotaAndRequestCount, id, quota) addNewRecord(BatchUpdateTypeUsedQuota, id, quota)
addNewRecord(BatchUpdateTypeRequestCount, id, 1)
return return
} }
updateUserUsedQuotaAndRequestCount(id, quota, 1) updateUserUsedQuotaAndRequestCount(id, quota, 1)
@ -327,6 +333,24 @@ func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) {
} }
} }
func updateUserUsedQuota(id int, quota int) {
err := DB.Model(&User{}).Where("id = ?", id).Updates(
map[string]interface{}{
"used_quota": gorm.Expr("used_quota + ?", quota),
},
).Error
if err != nil {
common.SysError("failed to update user used quota: " + err.Error())
}
}
func updateUserRequestCount(id int, count int) {
err := DB.Model(&User{}).Where("id = ?", id).Update("request_count", gorm.Expr("request_count + ?", count)).Error
if err != nil {
common.SysError("failed to update user request count: " + err.Error())
}
}
func GetUsernameById(id int) (username string) { func GetUsernameById(id int) (username string) {
DB.Model(&User{}).Where("id = ?", id).Select("username").Find(&username) DB.Model(&User{}).Where("id = ?", id).Select("username").Find(&username)
return username return username

View File

@ -6,13 +6,13 @@ import (
"time" "time"
) )
const BatchUpdateTypeCount = 4 // if you add a new type, you need to add a new map and a new lock
const ( const (
BatchUpdateTypeUserQuota = iota BatchUpdateTypeUserQuota = iota
BatchUpdateTypeTokenQuota BatchUpdateTypeTokenQuota
BatchUpdateTypeUsedQuotaAndRequestCount BatchUpdateTypeUsedQuota
BatchUpdateTypeChannelUsedQuota BatchUpdateTypeChannelUsedQuota
BatchUpdateTypeRequestCount
BatchUpdateTypeCount // if you add a new type, you need to add a new map and a new lock
) )
var batchUpdateStores []map[int]int var batchUpdateStores []map[int]int
@ -51,7 +51,7 @@ func batchUpdate() {
store := batchUpdateStores[i] store := batchUpdateStores[i]
batchUpdateStores[i] = make(map[int]int) batchUpdateStores[i] = make(map[int]int)
batchUpdateLocks[i].Unlock() batchUpdateLocks[i].Unlock()
// TODO: maybe we can combine updates with same key?
for key, value := range store { for key, value := range store {
switch i { switch i {
case BatchUpdateTypeUserQuota: case BatchUpdateTypeUserQuota:
@ -64,8 +64,10 @@ func batchUpdate() {
if err != nil { if err != nil {
common.SysError("failed to batch update token quota: " + err.Error()) common.SysError("failed to batch update token quota: " + err.Error())
} }
case BatchUpdateTypeUsedQuotaAndRequestCount: case BatchUpdateTypeUsedQuota:
updateUserUsedQuotaAndRequestCount(key, value, 1) // TODO: count is incorrect updateUserUsedQuota(key, value)
case BatchUpdateTypeRequestCount:
updateUserRequestCount(key, value)
case BatchUpdateTypeChannelUsedQuota: case BatchUpdateTypeChannelUsedQuota:
updateChannelUsedQuota(key, value) updateChannelUsedQuota(key, value)
} }

View File

@ -74,7 +74,7 @@ func SetApiRouter(router *gin.Engine) {
channelRoute.GET("/update_balance/:id", controller.UpdateChannelBalance) channelRoute.GET("/update_balance/:id", controller.UpdateChannelBalance)
channelRoute.POST("/", controller.AddChannel) channelRoute.POST("/", controller.AddChannel)
channelRoute.PUT("/", controller.UpdateChannel) channelRoute.PUT("/", controller.UpdateChannel)
channelRoute.DELETE("/manually_disabled", controller.DeleteManuallyDisabledChannel) channelRoute.DELETE("/disabled", controller.DeleteDisabledChannel)
channelRoute.DELETE("/:id", controller.DeleteChannel) channelRoute.DELETE("/:id", controller.DeleteChannel)
} }
tokenRoute := apiRouter.Group("/token") tokenRoute := apiRouter.Group("/token")

View File

@ -29,6 +29,7 @@ func SetRelayRouter(router *gin.Engine) {
relayV1Router.POST("/engines/:model/embeddings", controller.Relay) relayV1Router.POST("/engines/:model/embeddings", controller.Relay)
relayV1Router.POST("/audio/transcriptions", controller.Relay) relayV1Router.POST("/audio/transcriptions", controller.Relay)
relayV1Router.POST("/audio/translations", controller.Relay) relayV1Router.POST("/audio/translations", controller.Relay)
relayV1Router.POST("/audio/speech", controller.Relay)
relayV1Router.GET("/files", controller.RelayNotImplemented) relayV1Router.GET("/files", controller.RelayNotImplemented)
relayV1Router.POST("/files", controller.RelayNotImplemented) relayV1Router.POST("/files", controller.RelayNotImplemented)
relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented) relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented)

View File

@ -283,7 +283,9 @@ function App() {
</Suspense> </Suspense>
} }
/> />
<Route path='*' element={NotFound} /> <Route path='*' element={
<NotFound />
} />
</Routes> </Routes>
); );
} }

View File

@ -1,7 +1,7 @@
import React, { useEffect, useState } from 'react'; import React, { useEffect, useState } from 'react';
import { Button, Form, Input, Label, Pagination, Popup, Table } from 'semantic-ui-react'; import { Button, Form, Input, Label, Message, Pagination, Popup, Table } from 'semantic-ui-react';
import { Link } from 'react-router-dom'; import { Link } from 'react-router-dom';
import { API, showError, showInfo, showNotice, showSuccess, timestamp2string } from '../helpers'; import { API, setPromptShown, shouldShowPrompt, showError, showInfo, showSuccess, timestamp2string } from '../helpers';
import { CHANNEL_OPTIONS, ITEMS_PER_PAGE } from '../constants'; import { CHANNEL_OPTIONS, ITEMS_PER_PAGE } from '../constants';
import { renderGroup, renderNumber } from '../helpers/render'; import { renderGroup, renderNumber } from '../helpers/render';
@ -55,6 +55,7 @@ const ChannelsTable = () => {
const [searchKeyword, setSearchKeyword] = useState(''); const [searchKeyword, setSearchKeyword] = useState('');
const [searching, setSearching] = useState(false); const [searching, setSearching] = useState(false);
const [updatingBalance, setUpdatingBalance] = useState(false); const [updatingBalance, setUpdatingBalance] = useState(false);
const [showPrompt, setShowPrompt] = useState(shouldShowPrompt("channel-test"));
const loadChannels = async (startIdx) => { const loadChannels = async (startIdx) => {
const res = await API.get(`/api/channel/?p=${startIdx}`); const res = await API.get(`/api/channel/?p=${startIdx}`);
@ -226,7 +227,6 @@ const ChannelsTable = () => {
showInfo(`通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。`); showInfo(`通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。`);
} else { } else {
showError(message); showError(message);
showNotice('当前版本测试是通过按照 OpenAI API 格式使用 gpt-3.5-turbo 模型进行非流式请求实现的,因此测试报错并不一定代表通道不可用,该功能后续会修复。');
} }
}; };
@ -240,11 +240,11 @@ const ChannelsTable = () => {
} }
}; };
const deleteAllManuallyDisabledChannels = async () => { const deleteAllDisabledChannels = async () => {
const res = await API.delete(`/api/channel/manually_disabled`); const res = await API.delete(`/api/channel/disabled`);
const { success, message, data } = res.data; const { success, message, data } = res.data;
if (success) { if (success) {
showSuccess(`已删除所有手动禁用渠道,共计 ${data}`); showSuccess(`已删除所有禁用渠道,共计 ${data}`);
await refresh(); await refresh();
} else { } else {
showError(message); showError(message);
@ -286,17 +286,15 @@ const ChannelsTable = () => {
if (channels.length === 0) return; if (channels.length === 0) return;
setLoading(true); setLoading(true);
let sortedChannels = [...channels]; let sortedChannels = [...channels];
if (typeof sortedChannels[0][key] === 'string') { sortedChannels.sort((a, b) => {
sortedChannels.sort((a, b) => { if (!isNaN(a[key])) {
// If the value is numeric, subtract to sort
return a[key] - b[key];
} else {
// If the value is not numeric, sort as strings
return ('' + a[key]).localeCompare(b[key]); return ('' + a[key]).localeCompare(b[key]);
}); }
} else { });
sortedChannels.sort((a, b) => {
if (a[key] === b[key]) return 0;
if (a[key] > b[key]) return -1;
if (a[key] < b[key]) return 1;
});
}
if (sortedChannels[0].id === channels[0].id) { if (sortedChannels[0].id === channels[0].id) {
sortedChannels.reverse(); sortedChannels.reverse();
} }
@ -304,6 +302,7 @@ const ChannelsTable = () => {
setLoading(false); setLoading(false);
}; };
return ( return (
<> <>
<Form onSubmit={searchChannels}> <Form onSubmit={searchChannels}>
@ -317,7 +316,19 @@ const ChannelsTable = () => {
onChange={handleKeywordChange} onChange={handleKeywordChange}
/> />
</Form> </Form>
{
showPrompt && (
<Message onDismiss={() => {
setShowPrompt(false);
setPromptShown("channel-test");
}}>
当前版本测试是通过按照 OpenAI API 格式使用 gpt-3.5-turbo
模型进行非流式请求实现的因此测试报错并不一定代表通道不可用该功能后续会修复
另外OpenAI 渠道已经不再支持通过 key 获取余额因此余额显示为 0对于支持的渠道类型请点击余额进行刷新
</Message>
)
}
<Table basic compact size='small'> <Table basic compact size='small'>
<Table.Header> <Table.Header>
<Table.Row> <Table.Row>
@ -519,14 +530,14 @@ const ChannelsTable = () => {
<Popup <Popup
trigger={ trigger={
<Button size='small' loading={loading}> <Button size='small' loading={loading}>
删除所有手动禁用渠道 删除禁用渠道
</Button> </Button>
} }
on='click' on='click'
flowing flowing
hoverable hoverable
> >
<Button size='small' loading={loading} negative onClick={deleteAllManuallyDisabledChannels}> <Button size='small' loading={loading} negative onClick={deleteAllDisabledChannels}>
确认删除 确认删除
</Button> </Button>
</Popup> </Popup>

View File

@ -130,7 +130,13 @@ const RedemptionsTable = () => {
setLoading(true); setLoading(true);
let sortedRedemptions = [...redemptions]; let sortedRedemptions = [...redemptions];
sortedRedemptions.sort((a, b) => { sortedRedemptions.sort((a, b) => {
return ('' + a[key]).localeCompare(b[key]); if (!isNaN(a[key])) {
// If the value is numeric, subtract to sort
return a[key] - b[key];
} else {
// If the value is not numeric, sort as strings
return ('' + a[key]).localeCompare(b[key]);
}
}); });
if (sortedRedemptions[0].id === redemptions[0].id) { if (sortedRedemptions[0].id === redemptions[0].id) {
sortedRedemptions.reverse(); sortedRedemptions.reverse();

View File

@ -138,7 +138,7 @@ const TokensTable = () => {
let defaultUrl; let defaultUrl;
if (chatLink) { if (chatLink) {
defaultUrl = chatLink + `/#/?settings={"key":"sk-${key}"}`; defaultUrl = chatLink + `/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`;
} else { } else {
defaultUrl = `https://chat.oneapi.pro/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`; defaultUrl = `https://chat.oneapi.pro/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`;
} }
@ -228,7 +228,13 @@ const TokensTable = () => {
setLoading(true); setLoading(true);
let sortedTokens = [...tokens]; let sortedTokens = [...tokens];
sortedTokens.sort((a, b) => { sortedTokens.sort((a, b) => {
return ('' + a[key]).localeCompare(b[key]); if (!isNaN(a[key])) {
// If the value is numeric, subtract to sort
return a[key] - b[key];
} else {
// If the value is not numeric, sort as strings
return ('' + a[key]).localeCompare(b[key]);
}
}); });
if (sortedTokens[0].id === tokens[0].id) { if (sortedTokens[0].id === tokens[0].id) {
sortedTokens.reverse(); sortedTokens.reverse();

View File

@ -133,7 +133,13 @@ const UsersTable = () => {
setLoading(true); setLoading(true);
let sortedUsers = [...users]; let sortedUsers = [...users];
sortedUsers.sort((a, b) => { sortedUsers.sort((a, b) => {
return ('' + a[key]).localeCompare(b[key]); if (!isNaN(a[key])) {
// If the value is numeric, subtract to sort
return a[key] - b[key];
} else {
// If the value is not numeric, sort as strings
return ('' + a[key]).localeCompare(b[key]);
}
}); });
if (sortedUsers[0].id === users[0].id) { if (sortedUsers[0].id === users[0].id) {
sortedUsers.reverse(); sortedUsers.reverse();

View File

@ -186,4 +186,14 @@ export const verifyJSON = (str) => {
return false; return false;
} }
return true; return true;
}; };
export function shouldShowPrompt(id) {
let prompt = localStorage.getItem(`prompt-${id}`);
return !prompt;
}
export function setPromptShown(id) {
localStorage.setItem(`prompt-${id}`, 'true');
}

View File

@ -66,13 +66,13 @@ const EditChannel = () => {
localModels = ['PaLM-2']; localModels = ['PaLM-2'];
break; break;
case 15: case 15:
localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'Embedding-V1']; localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'ERNIE-Bot-4', 'Embedding-V1'];
break; break;
case 17: case 17:
localModels = ['qwen-turbo', 'qwen-plus', 'text-embedding-v1']; localModels = ['qwen-turbo', 'qwen-plus', 'text-embedding-v1'];
break; break;
case 16: case 16:
localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite', 'text_embedding']; localModels = ['chatglm_turbo', 'chatglm_pro', 'chatglm_std', 'chatglm_lite', 'text_embedding'];
break; break;
case 18: case 18:
localModels = ['SparkDesk']; localModels = ['SparkDesk'];

View File

@ -1,19 +1,12 @@
import React from 'react'; import React from 'react';
import { Segment, Header } from 'semantic-ui-react'; import { Message } from 'semantic-ui-react';
const NotFound = () => ( const NotFound = () => (
<> <>
<Header <Message negative>
block <Message.Header>页面不存在</Message.Header>
as="h4" <p>请检查你的浏览器地址是否正确</p>
content="404" </Message>
attached="top"
icon="info"
className="small-icon"
/>
<Segment attached="bottom">
未找到所请求的页面
</Segment>
</> </>
); );