diff --git a/README.md b/README.md index f0007143..bb35c6d0 100644 --- a/README.md +++ b/README.md @@ -59,8 +59,8 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 > **Warning** > 使用 Docker 拉取的最新镜像可能是 `alpha` 版本,如果追求稳定性请手动指定版本。 -> **此分叉最新版Docker镜像** -> calciumion/one-api-midjourney:latest +> **Note** +> 此分叉最新版Docker镜像 calciumion/one-api-midjourney:latest ## 此分叉版本的主要变更 1. 添加[Midjourney-Proxy](https://github.com/novicezk/midjourney-proxy)接口的支持: @@ -78,6 +78,8 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 + 配合项目[neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool)可实现用key查询使用情况,方便二次分销 4. 渠道显示已使用额度,支持指定组织访问 5. 分页支持选择每页显示数量 + + ## 功能 1. 支持多种大模型: + [x] [OpenAI ChatGPT 系列模型](https://platform.openai.com/docs/guides/gpt/chat-completions-api)(支持 [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference)) @@ -88,6 +90,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 + [x] [讯飞星火认知大模型](https://www.xfyun.cn/doc/spark/Web.html) + [x] [智谱 ChatGLM 系列模型](https://bigmodel.cn) + [x] [360 智脑](https://ai.360.cn) + + [x] [腾讯混元大模型](https://cloud.tencent.com/document/product/1729) 2. 支持配置镜像以及众多第三方代理服务: + [x] [OpenAI-SB](https://openai-sb.com) + [x] [CloseAI](https://console.closeai-asia.com/r/2412) @@ -110,23 +113,30 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 15. 支持模型映射,重定向用户的请求模型。 16. 支持失败自动重试。 17. 支持绘图接口。 -18. 支持丰富的**自定义**设置, +18. 支持 [Cloudflare AI Gateway](https://developers.cloudflare.com/ai-gateway/providers/openai/),渠道设置的代理部分填写 `https://gateway.ai.cloudflare.com/v1/ACCOUNT_TAG/GATEWAY/openai` 即可。 +19. 支持丰富的**自定义**设置, 1. 支持自定义系统名称,logo 以及页脚。 2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入。 -19. 支持通过系统访问令牌访问管理 API。 -20. 支持 Cloudflare Turnstile 用户校验。 -21. 支持用户管理,支持**多种用户登录注册方式**: +20. 支持通过系统访问令牌访问管理 API。 +21. 支持 Cloudflare Turnstile 用户校验。 +22. 支持用户管理,支持**多种用户登录注册方式**: + 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。 + [GitHub 开放授权](https://github.com/settings/applications/new)。 + 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。 ## 部署 ### 基于 Docker 进行部署 -部署命令:`docker run --name one-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/one-api:/data justsong/one-api` +```shell +# 使用 SQLite 的部署命令: +docker run --name one-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/one-api:/data justsong/one-api +# 使用 MySQL 的部署命令,在上面的基础上添加 `-e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi"`,请自行修改数据库连接参数,不清楚如何修改请参见下面环境变量一节。 +# 例如: +docker run --name one-api -d --restart always -p 3000:3000 -e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi" -e TZ=Asia/Shanghai -v /home/ubuntu/data/one-api:/data justsong/one-api +``` 其中,`-p 3000:3000` 中的第一个 `3000` 是宿主机的端口,可以根据需要进行修改。 -数据将会保存在宿主机的 `/home/ubuntu/data/one-api` 目录,请确保该目录存在且具有写入权限,或者更改为合适的目录。 +数据和日志将会保存在宿主机的 `/home/ubuntu/data/one-api` 目录,请确保该目录存在且具有写入权限,或者更改为合适的目录。 如果启动失败,请添加 `--privileged=true`,具体参考 https://github.com/songquanpeng/one-api/issues/482 。 @@ -270,6 +280,17 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope +
+部署到 Render +
+ +> Render 提供免费额度,绑卡后可以进一步提升额度 + +Render 可以直接部署 docker 镜像,不需要 fork 仓库:https://dashboard.render.com + +
+
+ ## 配置 系统本身开箱即用。 @@ -297,10 +318,11 @@ OPENAI_API_BASE="https://:/v1" ```mermaid graph LR A(用户) - A --->|请求| B(One API) + A --->|使用 One API 分发的 key 进行请求| B(One API) B -->|中继请求| C(OpenAI) B -->|中继请求| D(Azure) - B -->|中继请求| E(其他下游渠道) + B -->|中继请求| E(其他 OpenAI API 格式下游渠道) + B -->|中继并修改请求体和返回体| F(非 OpenAI API 格式下游渠道) ``` 可以通过在令牌后面添加渠道 ID 的方式指定使用哪一个渠道处理本次请求,例如:`Authorization: Bearer ONE_API_KEY-CHANNEL_ID`。 @@ -328,22 +350,24 @@ graph LR + `SQL_CONN_MAX_LIFETIME`:连接的最大生命周期,默认为 `60`,单位分钟。 4. `FRONTEND_BASE_URL`:设置之后将重定向页面请求到指定的地址,仅限从服务器设置。 + 例子:`FRONTEND_BASE_URL=https://openai.justsong.cn` -5. `SYNC_FREQUENCY`:设置之后将定期与数据库同步配置,单位为秒,未设置则不进行同步。 +5. `MEMORY_CACHE_ENABLED`:启用内存缓存,会导致用户额度的更新存在一定的延迟,可选值为 `true` 和 `false`,未设置则默认为 `false`。 + + 例子:`MEMORY_CACHE_ENABLED=true` +6. `SYNC_FREQUENCY`:在启用缓存的情况下与数据库同步配置的频率,单位为秒,默认为 `600` 秒。 + 例子:`SYNC_FREQUENCY=60` -6. `NODE_TYPE`:设置之后将指定节点类型,可选值为 `master` 和 `slave`,未设置则默认为 `master`。 +7. `NODE_TYPE`:设置之后将指定节点类型,可选值为 `master` 和 `slave`,未设置则默认为 `master`。 + 例子:`NODE_TYPE=slave` -7. `CHANNEL_UPDATE_FREQUENCY`:设置之后将定期更新渠道余额,单位为分钟,未设置则不进行更新。 +8. `CHANNEL_UPDATE_FREQUENCY`:设置之后将定期更新渠道余额,单位为分钟,未设置则不进行更新。 + 例子:`CHANNEL_UPDATE_FREQUENCY=1440` -8. `CHANNEL_TEST_FREQUENCY`:设置之后将定期检查渠道,单位为分钟,未设置则不进行检查。 +9. `CHANNEL_TEST_FREQUENCY`:设置之后将定期检查渠道,单位为分钟,未设置则不进行检查。 + 例子:`CHANNEL_TEST_FREQUENCY=1440` -9. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。 - + 例子:`POLLING_INTERVAL=5` -10. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true` 和 `false`,未设置则默认为 `false`。 +10. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。 + + 例子:`POLLING_INTERVAL=5` +11. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true` 和 `false`,未设置则默认为 `false`。 + 例子:`BATCH_UPDATE_ENABLED=true` + 如果你遇到了数据库连接数过多的问题,可以尝试启用该选项。 -11. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。 +12. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。 + 例子:`BATCH_UPDATE_INTERVAL=5` -12. 请求频率限制: +13. 请求频率限制: + `GLOBAL_API_RATE_LIMIT`:全局 API 速率限制(除中继请求外),单 ip 三分钟内的最大请求数,默认为 `180`。 + `GLOBAL_WEB_RATE_LIMIT`:全局 Web 速率限制,单 ip 三分钟内的最大请求数,默认为 `60`。 @@ -385,6 +409,12 @@ https://openai.justsong.cn + 检查是否启用了 HTTPS,浏览器会拦截 HTTPS 域名下的 HTTP 请求。 6. 报错:`当前分组负载已饱和,请稍后再试` + 上游通道 429 了。 +7. 升级之后我的数据会丢失吗? + + 如果使用 MySQL,不会。 + + 如果使用 SQLite,需要按照我所给的部署命令挂载 volume 持久化 one-api.db 数据库文件,否则容器重启后数据会丢失。 +8. 升级之前数据库需要做变更吗? + + 一般情况下不需要,系统将在初始化的时候自动调整。 + + 如果需要的话,我会在更新日志中说明,并给出脚本。 ## 相关项目 * [FastGPT](https://github.com/labring/FastGPT): 基于 LLM 大语言模型的知识库问答系统 diff --git a/common/constants.go b/common/constants.go index 725aa772..0c00018d 100644 --- a/common/constants.go +++ b/common/constants.go @@ -60,6 +60,7 @@ var EmailDomainWhitelist = []string{ } var DebugEnabled = os.Getenv("DEBUG") == "true" +var MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true" var LogConsumeEnabled = true @@ -96,7 +97,7 @@ var IsMasterNode = os.Getenv("NODE_TYPE") != "slave" var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL")) var RequestInterval = time.Duration(requestInterval) * time.Second -var SyncFrequency = 10 * 60 // unit is second, will be overwritten by SYNC_FREQUENCY +var SyncFrequency = GetOrDefault("SYNC_FREQUENCY", 10*60) // unit is second var BatchUpdateEnabled = false var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5) @@ -159,9 +160,10 @@ const ( ) const ( - ChannelStatusUnknown = 0 - ChannelStatusEnabled = 1 // don't use 0, 0 is the default value! - ChannelStatusDisabled = 2 // also don't use 0 + ChannelStatusUnknown = 0 + ChannelStatusEnabled = 1 // don't use 0, 0 is the default value! + ChannelStatusManuallyDisabled = 2 // also don't use 0 + ChannelStatusAutoDisabled = 3 ) const ( @@ -188,30 +190,32 @@ const ( ChannelTypeOpenRouter = 20 ChannelTypeAIProxyLibrary = 21 ChannelTypeFastGPT = 22 + ChannelTypeTencent = 23 ) var ChannelBaseURLs = []string{ - "", // 0 - "https://api.openai.com", // 1 - "https://oa.api2d.net", // 2 - "", // 3 - "https://api.closeai-proxy.xyz", // 4 - "https://api.openai-sb.com", // 5 - "https://api.openaimax.com", // 6 - "https://api.ohmygpt.com", // 7 - "", // 8 - "https://api.caipacity.com", // 9 - "https://api.aiproxy.io", // 10 - "", // 11 - "https://api.api2gpt.com", // 12 - "https://api.aigc2d.com", // 13 - "https://api.anthropic.com", // 14 - "https://aip.baidubce.com", // 15 - "https://open.bigmodel.cn", // 16 - "https://dashscope.aliyuncs.com", // 17 - "", // 18 - "https://ai.360.cn", // 19 - "https://openrouter.ai/api", // 20 - "https://api.aiproxy.io", // 21 - "https://fastgpt.run/api/openapi", // 22 + "", // 0 + "https://api.openai.com", // 1 + "https://oa.api2d.net", // 2 + "", // 3 + "https://api.closeai-proxy.xyz", // 4 + "https://api.openai-sb.com", // 5 + "https://api.openaimax.com", // 6 + "https://api.ohmygpt.com", // 7 + "", // 8 + "https://api.caipacity.com", // 9 + "https://api.aiproxy.io", // 10 + "", // 11 + "https://api.api2gpt.com", // 12 + "https://api.aigc2d.com", // 13 + "https://api.anthropic.com", // 14 + "https://aip.baidubce.com", // 15 + "https://open.bigmodel.cn", // 16 + "https://dashscope.aliyuncs.com", // 17 + "", // 18 + "https://ai.360.cn", // 19 + "https://openrouter.ai/api", // 20 + "https://api.aiproxy.io", // 21 + "https://fastgpt.run/api/openapi", // 22 + "https://hunyuan.cloud.tencent.com", //23 } diff --git a/common/model-ratio.go b/common/model-ratio.go index 4b3dd763..f1ce99a5 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -59,7 +59,7 @@ var ModelRatio = map[string]float64{ "embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens "embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens "semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens - "360GPT_S2_V9.4": 0.8572, // ¥0.012 / 1k tokens + "hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0 } func ModelRatio2JSONString() string { diff --git a/controller/channel-test.go b/controller/channel-test.go index 45cf604b..e91ccad3 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -141,7 +141,7 @@ func disableChannel(channelId int, channelName string, reason string) { if common.RootUserEmail == "" { common.RootUserEmail = model.GetRootUserEmail() } - model.UpdateChannelStatusById(channelId, common.ChannelStatusDisabled) + model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled) subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId) content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason) err := common.SendEmail(subject, common.RootUserEmail, content) diff --git a/controller/channel.go b/controller/channel.go index 5c733670..98f42421 100644 --- a/controller/channel.go +++ b/controller/channel.go @@ -131,6 +131,23 @@ func DeleteChannel(c *gin.Context) { return } +func DeleteDisabledChannel(c *gin.Context) { + rows, err := model.DeleteDisabledChannel() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": rows, + }) + return +} + func UpdateChannel(c *gin.Context) { channel := model.Channel{} err := c.ShouldBindJSON(&channel) diff --git a/controller/model.go b/controller/model.go index ae2061b3..e9b64514 100644 --- a/controller/model.go +++ b/controller/model.go @@ -424,12 +424,12 @@ func init() { Parent: nil, }, { - Id: "360GPT_S2_V9.4", + Id: "hunyuan", Object: "model", Created: 1677649963, - OwnedBy: "360", + OwnedBy: "tencent", Permission: permission, - Root: "360GPT_S2_V9.4", + Root: "hunyuan", Parent: nil, }, } diff --git a/controller/option.go b/controller/option.go index 9cf4ff1b..bbf83578 100644 --- a/controller/option.go +++ b/controller/option.go @@ -46,7 +46,7 @@ func UpdateOption(c *gin.Context) { if option.Value == "true" && common.GitHubClientId == "" { c.JSON(http.StatusOK, gin.H{ "success": false, - "message": "无法启用 GitHub OAuth,请先填入 GitHub Client ID 以及 GitHub Client Secret!", + "message": "无法启用 GitHub OAuth,请先填入 GitHub Client Id 以及 GitHub Client Secret!", }) return } diff --git a/controller/relay-audio.go b/controller/relay-audio.go index f5903dae..644a9d2d 100644 --- a/controller/relay-audio.go +++ b/controller/relay-audio.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "io" "net/http" @@ -31,6 +32,9 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode if err != nil { return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) } + if userQuota-preConsumedQuota < 0 { + return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) + } err = model.CacheDecreaseUserQuota(userId, preConsumedQuota) if err != nil { return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) diff --git a/controller/relay-image.go b/controller/relay-image.go index 0c3ec12c..db3d9242 100644 --- a/controller/relay-image.go +++ b/controller/relay-image.go @@ -99,7 +99,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode quota := int(ratio*sizeRatio*1000) * imageRequest.N if consumeQuota && userQuota-quota < 0 { - return errorWrapper(err, "insufficient_user_quota", http.StatusForbidden) + return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) } req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) diff --git a/controller/relay-tencent.go b/controller/relay-tencent.go new file mode 100644 index 00000000..024468bc --- /dev/null +++ b/controller/relay-tencent.go @@ -0,0 +1,287 @@ +package controller + +import ( + "bufio" + "crypto/hmac" + "crypto/sha1" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/common" + "sort" + "strconv" + "strings" +) + +// https://cloud.tencent.com/document/product/1729/97732 + +type TencentMessage struct { + Role string `json:"role"` + Content string `json:"content"` +} + +type TencentChatRequest struct { + AppId int64 `json:"app_id"` // 腾讯云账号的 APPID + SecretId string `json:"secret_id"` // 官网 SecretId + // Timestamp当前 UNIX 时间戳,单位为秒,可记录发起 API 请求的时间。 + // 例如1529223702,如果与当前时间相差过大,会引起签名过期错误 + Timestamp int64 `json:"timestamp"` + // Expired 签名的有效期,是一个符合 UNIX Epoch 时间戳规范的数值, + // 单位为秒;Expired 必须大于 Timestamp 且 Expired-Timestamp 小于90天 + Expired int64 `json:"expired"` + QueryID string `json:"query_id"` //请求 Id,用于问题排查 + // Temperature 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定 + // 默认 1.0,取值区间为[0.0,2.0],非必要不建议使用,不合理的取值会影响效果 + // 建议该参数和 top_p 只设置1个,不要同时更改 top_p + Temperature float64 `json:"temperature"` + // TopP 影响输出文本的多样性,取值越大,生成文本的多样性越强 + // 默认1.0,取值区间为[0.0, 1.0],非必要不建议使用, 不合理的取值会影响效果 + // 建议该参数和 temperature 只设置1个,不要同时更改 + TopP float64 `json:"top_p"` + // Stream 0:同步,1:流式 (默认,协议:SSE) + // 同步请求超时:60s,如果内容较长建议使用流式 + Stream int `json:"stream"` + // Messages 会话内容, 长度最多为40, 按对话时间从旧到新在数组中排列 + // 输入 content 总数最大支持 3000 token。 + Messages []TencentMessage `json:"messages"` +} + +type TencentError struct { + Code int `json:"code"` + Message string `json:"message"` +} + +type TencentUsage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + TotalTokens int `json:"total_tokens"` +} + +type TencentResponseChoices struct { + FinishReason string `json:"finish_reason,omitempty"` // 流式结束标志位,为 stop 则表示尾包 + Messages TencentMessage `json:"messages,omitempty"` // 内容,同步模式返回内容,流模式为 null 输出 content 内容总数最多支持 1024token。 + Delta TencentMessage `json:"delta,omitempty"` // 内容,流模式返回内容,同步模式为 null 输出 content 内容总数最多支持 1024token。 +} + +type TencentChatResponse struct { + Choices []TencentResponseChoices `json:"choices,omitempty"` // 结果 + Created string `json:"created,omitempty"` // unix 时间戳的字符串 + Id string `json:"id,omitempty"` // 会话 id + Usage Usage `json:"usage,omitempty"` // token 数量 + Error TencentError `json:"error,omitempty"` // 错误信息 注意:此字段可能返回 null,表示取不到有效值 + Note string `json:"note,omitempty"` // 注释 + ReqID string `json:"req_id,omitempty"` // 唯一请求 Id,每次请求都会返回。用于反馈接口入参 +} + +func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest { + messages := make([]TencentMessage, 0, len(request.Messages)) + for i := 0; i < len(request.Messages); i++ { + message := request.Messages[i] + if message.Role == "system" { + messages = append(messages, TencentMessage{ + Role: "user", + Content: message.Content, + }) + messages = append(messages, TencentMessage{ + Role: "assistant", + Content: "Okay", + }) + continue + } + messages = append(messages, TencentMessage{ + Content: message.Content, + Role: message.Role, + }) + } + stream := 0 + if request.Stream { + stream = 1 + } + return &TencentChatRequest{ + Timestamp: common.GetTimestamp(), + Expired: common.GetTimestamp() + 24*60*60, + QueryID: common.GetUUID(), + Temperature: request.Temperature, + TopP: request.TopP, + Stream: stream, + Messages: messages, + } +} + +func responseTencent2OpenAI(response *TencentChatResponse) *OpenAITextResponse { + fullTextResponse := OpenAITextResponse{ + Object: "chat.completion", + Created: common.GetTimestamp(), + Usage: response.Usage, + } + if len(response.Choices) > 0 { + choice := OpenAITextResponseChoice{ + Index: 0, + Message: Message{ + Role: "assistant", + Content: response.Choices[0].Messages.Content, + }, + FinishReason: response.Choices[0].FinishReason, + } + fullTextResponse.Choices = append(fullTextResponse.Choices, choice) + } + return &fullTextResponse +} + +func streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *ChatCompletionsStreamResponse { + response := ChatCompletionsStreamResponse{ + Object: "chat.completion.chunk", + Created: common.GetTimestamp(), + Model: "tencent-hunyuan", + } + if len(TencentResponse.Choices) > 0 { + var choice ChatCompletionsStreamResponseChoice + choice.Delta.Content = TencentResponse.Choices[0].Delta.Content + if TencentResponse.Choices[0].FinishReason == "stop" { + choice.FinishReason = &stopFinishReason + } + response.Choices = append(response.Choices, choice) + } + return &response +} + +func tencentStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) { + var responseText string + scanner := bufio.NewScanner(resp.Body) + scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, nil + } + if i := strings.Index(string(data), "\n"); i >= 0 { + return i + 1, data[0:i], nil + } + if atEOF { + return len(data), data, nil + } + return 0, nil, nil + }) + dataChan := make(chan string) + stopChan := make(chan bool) + go func() { + for scanner.Scan() { + data := scanner.Text() + if len(data) < 5 { // ignore blank line or wrong format + continue + } + if data[:5] != "data:" { + continue + } + data = data[5:] + dataChan <- data + } + stopChan <- true + }() + setEventStreamHeaders(c) + c.Stream(func(w io.Writer) bool { + select { + case data := <-dataChan: + var TencentResponse TencentChatResponse + err := json.Unmarshal([]byte(data), &TencentResponse) + if err != nil { + common.SysError("error unmarshalling stream response: " + err.Error()) + return true + } + response := streamResponseTencent2OpenAI(&TencentResponse) + if len(response.Choices) != 0 { + responseText += response.Choices[0].Delta.Content + } + jsonResponse, err := json.Marshal(response) + if err != nil { + common.SysError("error marshalling stream response: " + err.Error()) + return true + } + c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) + return true + case <-stopChan: + c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) + return false + } + }) + err := resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" + } + return nil, responseText +} + +func tencentHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { + var TencentResponse TencentChatResponse + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + err = json.Unmarshal(responseBody, &TencentResponse) + if err != nil { + return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + if TencentResponse.Error.Code != 0 { + return &OpenAIErrorWithStatusCode{ + OpenAIError: OpenAIError{ + Message: TencentResponse.Error.Message, + Code: TencentResponse.Error.Code, + }, + StatusCode: resp.StatusCode, + }, nil + } + fullTextResponse := responseTencent2OpenAI(&TencentResponse) + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, &fullTextResponse.Usage +} + +func parseTencentConfig(config string) (appId int64, secretId string, secretKey string, err error) { + parts := strings.Split(config, "|") + if len(parts) != 3 { + err = errors.New("invalid tencent config") + return + } + appId, err = strconv.ParseInt(parts[0], 10, 64) + secretId = parts[1] + secretKey = parts[2] + return +} + +func getTencentSign(req TencentChatRequest, secretKey string) string { + params := make([]string, 0) + params = append(params, "app_id="+strconv.FormatInt(req.AppId, 10)) + params = append(params, "secret_id="+req.SecretId) + params = append(params, "timestamp="+strconv.FormatInt(req.Timestamp, 10)) + params = append(params, "query_id="+req.QueryID) + params = append(params, "temperature="+strconv.FormatFloat(req.Temperature, 'f', -1, 64)) + params = append(params, "top_p="+strconv.FormatFloat(req.TopP, 'f', -1, 64)) + params = append(params, "stream="+strconv.Itoa(req.Stream)) + params = append(params, "expired="+strconv.FormatInt(req.Expired, 10)) + + var messageStr string + for _, msg := range req.Messages { + messageStr += fmt.Sprintf(`{"role":"%s","content":"%s"},`, msg.Role, msg.Content) + } + messageStr = strings.TrimSuffix(messageStr, ",") + params = append(params, "messages=["+messageStr+"]") + + sort.Sort(sort.StringSlice(params)) + url := "hunyuan.cloud.tencent.com/hyllm/v1/chat/completions?" + strings.Join(params, "&") + mac := hmac.New(sha1.New, []byte(secretKey)) + signURL := url + mac.Write([]byte(signURL)) + sign := mac.Sum([]byte(nil)) + return base64.StdEncoding.EncodeToString(sign) +} diff --git a/controller/relay-text.go b/controller/relay-text.go index 9243feea..9f131b7b 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -24,6 +24,7 @@ const ( APITypeAli APITypeXunfei APITypeAIProxyLibrary + APITypeTencent ) var httpClient *http.Client @@ -109,6 +110,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { apiType = APITypeXunfei case common.ChannelTypeAIProxyLibrary: apiType = APITypeAIProxyLibrary + case common.ChannelTypeTencent: + apiType = APITypeTencent } baseURL := common.ChannelBaseURLs[channelType] requestURL := c.Request.URL.String() @@ -116,6 +119,11 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { baseURL = c.GetString("base_url") } fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) + if channelType == common.ChannelTypeOpenAI { + if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") { + fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1")) + } + } switch apiType { case APITypeOpenAI: if channelType == common.ChannelTypeAzure { @@ -179,6 +187,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { if relayMode == RelayModeEmbeddings { fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding" } + case APITypeTencent: + fullRequestURL = "https://hunyuan.cloud.tencent.com/hyllm/v1/chat/completions" case APITypeAIProxyLibrary: fullRequestURL = fmt.Sprintf("%s/api/library/ask", baseURL) } @@ -204,6 +214,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { if err != nil { return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) } + if userQuota-preConsumedQuota < 0 { + return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) + } err = model.CacheDecreaseUserQuota(userId, preConsumedQuota) if err != nil { return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) @@ -282,6 +295,23 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) } requestBody = bytes.NewBuffer(jsonStr) + case APITypeTencent: + apiKey := c.Request.Header.Get("Authorization") + apiKey = strings.TrimPrefix(apiKey, "Bearer ") + appId, secretId, secretKey, err := parseTencentConfig(apiKey) + if err != nil { + return errorWrapper(err, "invalid_tencent_config", http.StatusInternalServerError) + } + tencentRequest := requestOpenAI2Tencent(textRequest) + tencentRequest.AppId = appId + tencentRequest.SecretId = secretId + jsonStr, err := json.Marshal(tencentRequest) + if err != nil { + return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) + } + sign := getTencentSign(*tencentRequest, secretKey) + c.Request.Header.Set("Authorization", sign) + requestBody = bytes.NewBuffer(jsonStr) case APITypeAIProxyLibrary: aiProxyLibraryRequest := requestOpenAI2AIProxyLibrary(textRequest) aiProxyLibraryRequest.LibraryId = c.GetString("library_id") @@ -336,6 +366,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { if textRequest.Stream { req.Header.Set("X-DashScope-SSE", "enable") } + case APITypeTencent: + req.Header.Set("Authorization", apiKey) default: req.Header.Set("Authorization", "Bearer "+apiKey) } @@ -588,6 +620,25 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { } return nil } + case APITypeTencent: + if isStream { + err, responseText := tencentStreamHandler(c, resp) + if err != nil { + return err + } + textResponse.Usage.PromptTokens = promptTokens + textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) + return nil + } else { + err, usage := tencentHandler(c, resp) + if err != nil { + return err + } + if usage != nil { + textResponse.Usage = *usage + } + return nil + } default: return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError) } diff --git a/controller/relay-utils.go b/controller/relay-utils.go index 947e956b..7109cd2e 100644 --- a/controller/relay-utils.go +++ b/controller/relay-utils.go @@ -10,44 +10,53 @@ import ( "one-api/common" "regexp" "strconv" + "strings" ) var stopFinishReason = "stop" +// tokenEncoderMap won't grow after initialization var tokenEncoderMap = map[string]*tiktoken.Tiktoken{} +var defaultTokenEncoder *tiktoken.Tiktoken func InitTokenEncoders() { common.SysLog("initializing token encoders") - fallbackTokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo") + gpt35TokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo") if err != nil { - common.FatalLog(fmt.Sprintf("failed to get fallback token encoder: %s", err.Error())) + common.FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error())) + } + defaultTokenEncoder = gpt35TokenEncoder + gpt4TokenEncoder, err := tiktoken.EncodingForModel("gpt-4") + if err != nil { + common.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error())) } for model, _ := range common.ModelRatio { - tokenEncoder, err := tiktoken.EncodingForModel(model) - if err != nil { - common.SysError(fmt.Sprintf("using fallback encoder for model %s", model)) - tokenEncoderMap[model] = fallbackTokenEncoder - continue + if strings.HasPrefix(model, "gpt-3.5") { + tokenEncoderMap[model] = gpt35TokenEncoder + } else if strings.HasPrefix(model, "gpt-4") { + tokenEncoderMap[model] = gpt4TokenEncoder + } else { + tokenEncoderMap[model] = nil } - tokenEncoderMap[model] = tokenEncoder } common.SysLog("token encoders initialized") } func getTokenEncoder(model string) *tiktoken.Tiktoken { - if tokenEncoder, ok := tokenEncoderMap[model]; ok { + tokenEncoder, ok := tokenEncoderMap[model] + if ok && tokenEncoder != nil { return tokenEncoder } - tokenEncoder, err := tiktoken.EncodingForModel(model) - if err != nil { - common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error())) - tokenEncoder, err = tiktoken.EncodingForModel("gpt-3.5-turbo") + if ok { + tokenEncoder, err := tiktoken.EncodingForModel(model) if err != nil { - common.FatalLog(fmt.Sprintf("failed to get token encoder for model gpt-3.5-turbo: %s", err.Error())) + common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error())) + tokenEncoder = defaultTokenEncoder } + tokenEncoderMap[model] = tokenEncoder + return tokenEncoder } - tokenEncoderMap[model] = tokenEncoder - return tokenEncoder + return defaultTokenEncoder } func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int { diff --git a/controller/relay-xunfei.go b/controller/relay-xunfei.go index ff6bf065..cbaf38fe 100644 --- a/controller/relay-xunfei.go +++ b/controller/relay-xunfei.go @@ -220,6 +220,9 @@ func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId strin for !stop { select { case xunfeiResponse = <-dataChan: + if len(xunfeiResponse.Payload.Choices.Text) == 0 { + continue + } content += xunfeiResponse.Payload.Choices.Text[0].Content usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens diff --git a/docker-compose.yml b/docker-compose.yml index 003122bb..9b814a03 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -23,7 +23,7 @@ services: depends_on: - redis healthcheck: - test: [ "CMD-SHELL", "curl -s http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk '{print $2}' | grep 'true'" ] + test: [ "CMD-SHELL", "wget -q -O - http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk -F: '{print $2}'" ] interval: 30s timeout: 10s retries: 3 diff --git a/go.mod b/go.mod index 5c5abe76..a82121b7 100644 --- a/go.mod +++ b/go.mod @@ -17,7 +17,7 @@ require ( github.com/pkoukk/tiktoken-go v0.1.1 github.com/samber/lo v1.38.1 github.com/star-horizon/go-epay v0.0.0-20230204124159-fa2e2293fdc2 - golang.org/x/crypto v0.9.0 + golang.org/x/crypto v0.14.0 gorm.io/driver/mysql v1.4.3 gorm.io/driver/postgres v1.5.2 gorm.io/driver/sqlite v1.4.3 @@ -57,9 +57,9 @@ require ( github.com/ugorji/go/codec v1.2.11 // indirect golang.org/x/arch v0.3.0 // indirect golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 // indirect - golang.org/x/net v0.10.0 // indirect - golang.org/x/sys v0.8.0 // indirect - golang.org/x/text v0.9.0 // indirect + golang.org/x/net v0.17.0 // indirect + golang.org/x/sys v0.13.0 // indirect + golang.org/x/text v0.13.0 // indirect google.golang.org/protobuf v1.30.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 3912c6e2..2d64620e 100644 --- a/go.sum +++ b/go.sum @@ -156,13 +156,13 @@ golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUu golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k= golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g= -golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= +golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc= +golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 h1:3MTrJm4PyNL9NBqvYDSj3DHl46qQakyfqfWo4jgfaEM= golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17/go.mod h1:lgLbSvA5ygNOMpwM/9anMpWVlVJ7Z+cHWq/eFuinpGE= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= -golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= +golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= +golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -170,14 +170,14 @@ golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU= -golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= +golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= -golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k= +golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= diff --git a/main.go b/main.go index a883d936..de0b2a07 100644 --- a/main.go +++ b/main.go @@ -2,6 +2,7 @@ package main import ( "embed" + "fmt" "github.com/gin-contrib/sessions" "github.com/gin-contrib/sessions/cookie" "github.com/gin-gonic/gin" @@ -50,18 +51,17 @@ func main() { // Initialize options model.InitOptionMap() if common.RedisEnabled { + // for compatibility with old versions + common.MemoryCacheEnabled = true + } + if common.MemoryCacheEnabled { + common.SysLog("memory cache enabled") + common.SysError(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency)) model.InitChannelCache() } - if os.Getenv("SYNC_FREQUENCY") != "" { - frequency, err := strconv.Atoi(os.Getenv("SYNC_FREQUENCY")) - if err != nil { - common.FatalLog("failed to parse SYNC_FREQUENCY: " + err.Error()) - } - common.SyncFrequency = frequency - go model.SyncOptions(frequency) - if common.RedisEnabled { - go model.SyncChannelCache(frequency) - } + if common.MemoryCacheEnabled { + go model.SyncOptions(common.SyncFrequency) + go model.SyncChannelCache(common.SyncFrequency) } if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" { frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY")) diff --git a/middleware/auth.go b/middleware/auth.go index 3f39752c..5b8670a9 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -103,7 +103,7 @@ func TokenAuth() func(c *gin.Context) { abortWithMessage(c, http.StatusUnauthorized, err.Error()) return } - userEnabled, err := model.IsUserEnabled(token.UserId) + userEnabled, err := model.CacheIsUserEnabled(token.UserId) if err != nil { abortWithMessage(c, http.StatusInternalServerError, err.Error()) return diff --git a/middleware/distributor.go b/middleware/distributor.go index 668241f2..c49a40d2 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -25,12 +25,12 @@ func Distribute() func(c *gin.Context) { if ok { id, err := strconv.Atoi(channelId.(string)) if err != nil { - abortWithMessage(c, http.StatusBadRequest, "无效的渠道 ID") + abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id") return } channel, err = model.GetChannelById(id, true) if err != nil { - abortWithMessage(c, http.StatusBadRequest, "无效的渠道 ID") + abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id") return } if channel.Status != common.ChannelStatusEnabled { diff --git a/model/cache.go b/model/cache.go index b9d6b612..a7f5c06f 100644 --- a/model/cache.go +++ b/model/cache.go @@ -186,7 +186,7 @@ func SyncChannelCache(frequency int) { } func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) { - if !common.RedisEnabled { + if !common.MemoryCacheEnabled { return GetRandomSatisfiedChannel(group, model) } channelSyncLock.RLock() diff --git a/model/channel.go b/model/channel.go index cd9dc9f3..96351310 100644 --- a/model/channel.go +++ b/model/channel.go @@ -12,7 +12,7 @@ type Channel struct { OpenAIOrganization *string `json:"openai_organization"` Status int `json:"status" gorm:"default:1"` Name string `json:"name" gorm:"index"` - Weight int `json:"weight"` + Weight *uint `json:"weight" gorm:"default:0"` CreatedTime int64 `json:"created_time" gorm:"bigint"` TestTime int64 `json:"test_time" gorm:"bigint"` ResponseTime int `json:"response_time"` // in milliseconds @@ -178,3 +178,13 @@ func updateChannelUsedQuota(id int, quota int) { common.SysError("failed to update channel used quota: " + err.Error()) } } + +func DeleteChannelByStatus(status int64) (int64, error) { + result := DB.Where("status = ?", status).Delete(&Channel{}) + return result.RowsAffected, result.Error +} + +func DeleteDisabledChannel() (int64, error) { + result := DB.Where("status = ? or status = ?", common.ChannelStatusAutoDisabled, common.ChannelStatusManuallyDisabled).Delete(&Channel{}) + return result.RowsAffected, result.Error +} diff --git a/model/log.go b/model/log.go index c7dec563..5ea9372b 100644 --- a/model/log.go +++ b/model/log.go @@ -9,19 +9,19 @@ import ( ) type Log struct { - Id int `json:"id"` - UserId int `json:"user_id"` - CreatedAt int64 `json:"created_at" gorm:"bigint;index"` - Type int `json:"type" gorm:"index"` + Id int `json:"id;index:idx_created_at_id,priority:1"` + UserId int `json:"user_id" gorm:"index"` + CreatedAt int64 `json:"created_at" gorm:"bigint;index:idx_created_at_id,priority:2;index:idx_created_at_type"` + Type int `json:"type" gorm:"index:idx_created_at_type"` Content string `json:"content"` - Username string `json:"username" gorm:"index;default:''"` + Username string `json:"username" gorm:"index:index_username_model_name,priority:2;default:''"` TokenName string `json:"token_name" gorm:"index;default:''"` - ModelName string `json:"model_name" gorm:"index;default:''"` + ModelName string `json:"model_name" gorm:"index;index:index_username_model_name,priority:1;default:''"` Quota int `json:"quota" gorm:"default:0"` PromptTokens int `json:"prompt_tokens" gorm:"default:0"` CompletionTokens int `json:"completion_tokens" gorm:"default:0"` + ChannelId int `json:"channel" gorm:"index"` TokenId int `json:"token_id" gorm:"default:0;index"` - Channel int `json:"channel" gorm:"default:0"` } const ( @@ -70,7 +70,7 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke TokenName: tokenName, ModelName: modelName, Quota: quota, - Channel: channelId, + ChannelId: channelId, TokenId: tokenId, } err := DB.Create(log).Error diff --git a/model/main.go b/model/main.go index 6b775cdf..4feceb48 100644 --- a/model/main.go +++ b/model/main.go @@ -81,6 +81,7 @@ func InitDB() (err error) { if !common.IsMasterNode { return nil } + common.SysLog("database migration started") err = db.AutoMigrate(&Channel{}) if err != nil { return err diff --git a/model/user.go b/model/user.go index 0e12f077..5baff0b0 100644 --- a/model/user.go +++ b/model/user.go @@ -312,7 +312,8 @@ func GetRootUserEmail() (email string) { func UpdateUserUsedQuotaAndRequestCount(id int, quota int) { if common.BatchUpdateEnabled { - addNewRecord(BatchUpdateTypeUsedQuotaAndRequestCount, id, quota) + addNewRecord(BatchUpdateTypeUsedQuota, id, quota) + addNewRecord(BatchUpdateTypeRequestCount, id, 1) return } updateUserUsedQuotaAndRequestCount(id, quota, 1) @@ -330,6 +331,24 @@ func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) { } } +func updateUserUsedQuota(id int, quota int) { + err := DB.Model(&User{}).Where("id = ?", id).Updates( + map[string]interface{}{ + "used_quota": gorm.Expr("used_quota + ?", quota), + }, + ).Error + if err != nil { + common.SysError("failed to update user used quota: " + err.Error()) + } +} + +func updateUserRequestCount(id int, count int) { + err := DB.Model(&User{}).Where("id = ?", id).Update("request_count", gorm.Expr("request_count + ?", count)).Error + if err != nil { + common.SysError("failed to update user request count: " + err.Error()) + } +} + func GetUsernameById(id int) (username string) { DB.Model(&User{}).Where("id = ?", id).Select("username").Find(&username) return username diff --git a/model/utils.go b/model/utils.go index 61734332..1c28340b 100644 --- a/model/utils.go +++ b/model/utils.go @@ -6,13 +6,13 @@ import ( "time" ) -const BatchUpdateTypeCount = 4 // if you add a new type, you need to add a new map and a new lock - const ( BatchUpdateTypeUserQuota = iota BatchUpdateTypeTokenQuota - BatchUpdateTypeUsedQuotaAndRequestCount + BatchUpdateTypeUsedQuota BatchUpdateTypeChannelUsedQuota + BatchUpdateTypeRequestCount + BatchUpdateTypeCount // if you add a new type, you need to add a new map and a new lock ) var batchUpdateStores []map[int]int @@ -51,7 +51,7 @@ func batchUpdate() { store := batchUpdateStores[i] batchUpdateStores[i] = make(map[int]int) batchUpdateLocks[i].Unlock() - + // TODO: maybe we can combine updates with same key? for key, value := range store { switch i { case BatchUpdateTypeUserQuota: @@ -64,8 +64,10 @@ func batchUpdate() { if err != nil { common.SysError("failed to batch update token quota: " + err.Error()) } - case BatchUpdateTypeUsedQuotaAndRequestCount: - updateUserUsedQuotaAndRequestCount(key, value, 1) // TODO: count is incorrect + case BatchUpdateTypeUsedQuota: + updateUserUsedQuota(key, value) + case BatchUpdateTypeRequestCount: + updateUserRequestCount(key, value) case BatchUpdateTypeChannelUsedQuota: updateChannelUsedQuota(key, value) } diff --git a/router/api-router.go b/router/api-router.go index e9b0c5ea..e3a25676 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -79,6 +79,7 @@ func SetApiRouter(router *gin.Engine) { channelRoute.GET("/update_balance/:id", controller.UpdateChannelBalance) channelRoute.POST("/", controller.AddChannel) channelRoute.PUT("/", controller.UpdateChannel) + channelRoute.DELETE("/disabled", controller.DeleteDisabledChannel) channelRoute.DELETE("/:id", controller.DeleteChannel) } tokenRoute := apiRouter.Group("/token") diff --git a/web/src/App.js b/web/src/App.js index 422b1522..13c884dc 100644 --- a/web/src/App.js +++ b/web/src/App.js @@ -1,20 +1,20 @@ -import React, {lazy, Suspense, useContext, useEffect} from 'react'; -import {Route, Routes} from 'react-router-dom'; +import React, { lazy, Suspense, useContext, useEffect } from 'react'; +import { Route, Routes } from 'react-router-dom'; import Loading from './components/Loading'; import User from './pages/User'; -import {PrivateRoute} from './components/PrivateRoute'; +import { PrivateRoute } from './components/PrivateRoute'; import RegisterForm from './components/RegisterForm'; import LoginForm from './components/LoginForm'; import NotFound from './pages/NotFound'; import Setting from './pages/Setting'; import EditUser from './pages/User/EditUser'; import AddUser from './pages/User/AddUser'; -import {API, getLogo, getSystemName, showError, showNotice} from './helpers'; +import { API, getLogo, getSystemName, showError, showNotice } from './helpers'; import PasswordResetForm from './components/PasswordResetForm'; import GitHubOAuth from './components/GitHubOAuth'; import PasswordResetConfirm from './components/PasswordResetConfirm'; -import {UserContext} from './context/User'; -import {StatusContext} from './context/Status'; +import { UserContext } from './context/User'; +import { StatusContext } from './context/Status'; import Channel from './pages/Channel'; import Token from './pages/Token'; import EditToken from './pages/Token/EditToken'; @@ -24,295 +24,270 @@ import EditRedemption from './pages/Redemption/EditRedemption'; import TopUp from './pages/TopUp'; import Log from './pages/Log'; import Chat from './pages/Chat'; -import Midjourney from './pages/Midjourney'; const Home = lazy(() => import('./pages/Home')); const About = lazy(() => import('./pages/About')); function App() { - const [userState, userDispatch] = useContext(UserContext); - const [statusState, statusDispatch] = useContext(StatusContext); + const [userState, userDispatch] = useContext(UserContext); + const [statusState, statusDispatch] = useContext(StatusContext); - const loadUser = () => { - let user = localStorage.getItem('user'); - if (user) { - let data = JSON.parse(user); - userDispatch({type: 'login', payload: data}); - } - }; - const loadStatus = async () => { - const res = await API.get('/api/status'); - const {success, data} = res.data; - if (success) { - localStorage.setItem('status', JSON.stringify(data)); - statusDispatch({type: 'set', payload: data}); - localStorage.setItem('system_name', data.system_name); - localStorage.setItem('logo', data.logo); - localStorage.setItem('footer_html', data.footer_html); - localStorage.setItem('quota_per_unit', data.quota_per_unit); - localStorage.setItem('display_in_currency', data.display_in_currency); - if (data.chat_link) { - localStorage.setItem('chat_link', data.chat_link); - } else { - localStorage.removeItem('chat_link'); - } - if ( - data.version !== process.env.REACT_APP_VERSION && - data.version !== 'v0.0.0' && - process.env.REACT_APP_VERSION !== '' - ) { - showNotice( - `新版本可用:${data.version},请使用快捷键 Shift + F5 刷新页面` - ); - } - } else { - showError('无法正常连接至服务器!'); - } - }; + const loadUser = () => { + let user = localStorage.getItem('user'); + if (user) { + let data = JSON.parse(user); + userDispatch({ type: 'login', payload: data }); + } + }; + const loadStatus = async () => { + const res = await API.get('/api/status'); + const { success, data } = res.data; + if (success) { + localStorage.setItem('status', JSON.stringify(data)); + statusDispatch({ type: 'set', payload: data }); + localStorage.setItem('system_name', data.system_name); + localStorage.setItem('logo', data.logo); + localStorage.setItem('footer_html', data.footer_html); + localStorage.setItem('quota_per_unit', data.quota_per_unit); + localStorage.setItem('display_in_currency', data.display_in_currency); + if (data.chat_link) { + localStorage.setItem('chat_link', data.chat_link); + } else { + localStorage.removeItem('chat_link'); + } + if ( + data.version !== process.env.REACT_APP_VERSION && + data.version !== 'v0.0.0' && + process.env.REACT_APP_VERSION !== '' + ) { + showNotice( + `新版本可用:${data.version},请使用快捷键 Shift + F5 刷新页面` + ); + } + } else { + showError('无法正常连接至服务器!'); + } + }; - // const getOptions = async () => { - // const res = await API.get('/api/option/'); - // const {success, message, data} = res.data; - // if (success) { - // let newInputs = {}; - // data.forEach((item) => { - // if (item.key === 'ModelRatio' || item.key === 'GroupRatio') { - // item.value = JSON.stringify(JSON.parse(item.value), null, 2); - // } - // newInputs[item.key] = item.value; - // }); - // setInputs(newInputs); - // setOriginInputs(newInputs); - // } else { - // showError(message); - // } - // }; + useEffect(() => { + loadUser(); + loadStatus().then(); + let systemName = getSystemName(); + if (systemName) { + document.title = systemName; + } + let logo = getLogo(); + if (logo) { + let linkElement = document.querySelector("link[rel~='icon']"); + if (linkElement) { + linkElement.href = logo; + } + } + }, []); - useEffect(() => { - loadUser(); - loadStatus().then(); - let systemName = getSystemName(); - if (systemName) { - document.title = systemName; + return ( + + }> + + } - let logo = getLogo(); - if (logo) { - let linkElement = document.querySelector("link[rel~='icon']"); - if (linkElement) { - linkElement.href = logo; - } + /> + + + } - }, []); - - return ( - - }> - - - } - /> - - - - } - /> - }> - - - } - /> - }> - - - } - /> - - - - } - /> - }> - - - } - /> - }> - - - } - /> - - - - } - /> - }> - - - } - /> - }> - - - } - /> - - - - } - /> - }> - - - } - /> - }> - - - } - /> - }> - - - } - /> - }> - - - } - /> - }> - - - } - /> - }> - - - } - /> - }> - - - } - /> - }> - - - } - /> - - }> - - - - } - /> - - }> - - - - } - /> - - - - } - /> - - - - } - /> - }> - - - } - /> - }> - - - } - /> - - - ); + /> + }> + + + } + /> + }> + + + } + /> + + + + } + /> + }> + + + } + /> + }> + + + } + /> + + + + } + /> + }> + + + } + /> + }> + + + } + /> + + + + } + /> + }> + + + } + /> + }> + + + } + /> + }> + + + } + /> + }> + + + } + /> + }> + + + } + /> + }> + + + } + /> + }> + + + } + /> + }> + + + } + /> + + }> + + + + } + /> + + }> + + + + } + /> + + + + } + /> + }> + + + } + /> + }> + + + } + /> + + } /> + + ); } export default App; diff --git a/web/src/components/ChannelsTable.js b/web/src/components/ChannelsTable.js index 895de27e..7f9b448a 100644 --- a/web/src/components/ChannelsTable.js +++ b/web/src/components/ChannelsTable.js @@ -1,7 +1,7 @@ import React, { useEffect, useState } from 'react'; -import {Button, Form, Input, Label, Pagination, Popup, Table} from 'semantic-ui-react'; +import { Button, Form, Input, Label, Message, Pagination, Popup, Table } from 'semantic-ui-react'; import { Link } from 'react-router-dom'; -import { API, showError, showInfo, showNotice, showSuccess, timestamp2string } from '../helpers'; +import { API, setPromptShown, shouldShowPrompt, showError, showInfo, showSuccess, timestamp2string } from '../helpers'; import { CHANNEL_OPTIONS, ITEMS_PER_PAGE } from '../constants'; import {renderGroup, renderNumber, renderQuota} from '../helpers/render'; @@ -56,6 +56,7 @@ const ChannelsTable = () => { const [searching, setSearching] = useState(false); const [updatingBalance, setUpdatingBalance] = useState(false); const [pageSize, setPageSize] = useState(ITEMS_PER_PAGE); + const [showPrompt, setShowPrompt] = useState(shouldShowPrompt("channel-test")); const loadChannels = async (startIdx) => { const res = await API.get(`/api/channel/?p=${startIdx}&page_size=${pageSize}`); @@ -104,7 +105,7 @@ const ChannelsTable = () => { }); }, []); - const manageChannel = async (id, action, idx, priority) => { + const manageChannel = async (id, action, idx, value) => { let data = { id }; let res; switch (action) { @@ -120,10 +121,20 @@ const ChannelsTable = () => { res = await API.put('/api/channel/', data); break; case 'priority': - if (priority === '') { + if (value === '') { return; } - data.priority = parseInt(priority); + data.priority = parseInt(value); + res = await API.put('/api/channel/', data); + break; + case 'weight': + if (value === '') { + return; + } + data.weight = parseInt(value); + if (data.weight < 0) { + data.weight = 0; + } res = await API.put('/api/channel/', data); break; } @@ -150,9 +161,23 @@ const ChannelsTable = () => { return ; case 2: return ( - + + 已禁用 + } + content='本渠道被手动禁用' + basic + /> + ); + case 3: + return ( + + 已禁用 + } + content='本渠道被程序自动禁用' + basic + /> ); default: return ( @@ -210,7 +235,6 @@ const ChannelsTable = () => { showInfo(`通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。`); } else { showError(message); - showNotice("当前版本测试是通过按照 OpenAI API 格式使用 gpt-3.5-turbo 模型进行非流式请求实现的,因此测试报错并不一定代表通道不可用,该功能后续会修复。") } }; @@ -224,6 +248,17 @@ const ChannelsTable = () => { } }; + const deleteAllDisabledChannels = async () => { + const res = await API.delete(`/api/channel/disabled`); + const { success, message, data } = res.data; + if (success) { + showSuccess(`已删除所有禁用渠道,共计 ${data} 个`); + await refresh(); + } else { + showError(message); + } + }; + const updateChannelBalance = async (id, name, idx) => { const res = await API.get(`/api/channel/update_balance/${id}/`); const { success, message, balance } = res.data; @@ -290,7 +325,19 @@ const ChannelsTable = () => { onChange={handleKeywordChange} /> + { + showPrompt && ( + { + setShowPrompt(false); + setPromptShown("channel-test"); + }}> + 当前版本测试是通过按照 OpenAI API 格式使用 gpt-3.5-turbo + 模型进行非流式请求实现的,因此测试报错并不一定代表通道不可用,该功能后续会修复。 + 另外,OpenAI 渠道已经不再支持通过 key 获取余额,因此余额显示为 0。对于支持的渠道类型,请点击余额进行刷新。 + + ) + } @@ -363,10 +410,10 @@ const ChannelsTable = () => { 余额 { - sortChannel('priority'); - }} + style={{ cursor: 'pointer' }} + onClick={() => { + sortChannel('priority'); + }} > 优先级 @@ -411,18 +458,18 @@ const ChannelsTable = () => { { - manageChannel( - channel.id, - 'priority', - idx, - event.target.value, - ); - }}> - - } - content='渠道选择优先级,越高越优先' - basic + trigger={ { + manageChannel( + channel.id, + 'priority', + idx, + event.target.value + ); + }}> + + } + content='渠道选择优先级,越高越优先' + basic /> @@ -519,6 +566,31 @@ const ChannelsTable = () => { } /> + + 删除禁用渠道 + + } + on='click' + flowing + hoverable + > + + + diff --git a/web/src/components/LoginForm.js b/web/src/components/LoginForm.js index b5c4e6f9..a3913220 100644 --- a/web/src/components/LoginForm.js +++ b/web/src/components/LoginForm.js @@ -2,8 +2,8 @@ import React, { useContext, useEffect, useState } from 'react'; import { Button, Divider, Form, Grid, Header, Image, Message, Modal, Segment } from 'semantic-ui-react'; import { Link, useNavigate, useSearchParams } from 'react-router-dom'; import { UserContext } from '../context/User'; -import { API, getLogo, showError, showSuccess } from '../helpers'; -import { getOAuthState, onGitHubOAuthClicked } from './utils'; +import { API, getLogo, showError, showSuccess, showWarning } from '../helpers'; +import { onGitHubOAuthClicked } from './utils'; const LoginForm = () => { const [inputs, setInputs] = useState({ @@ -68,8 +68,14 @@ const LoginForm = () => { if (success) { userDispatch({ type: 'login', payload: data }); localStorage.setItem('user', JSON.stringify(data)); - navigate('/'); - showSuccess('登录成功!'); + if (username === 'root' && password === '123456') { + navigate('/user/edit'); + showSuccess('登录成功!'); + showWarning('请立刻修改默认密码!'); + } else { + navigate('/token'); + showSuccess('登录成功!'); + } } else { showError(message); } @@ -126,7 +132,7 @@ const LoginForm = () => { circular color='black' icon='github' - onClick={()=>onGitHubOAuthClicked(status.github_client_id)} + onClick={() => onGitHubOAuthClicked(status.github_client_id)} /> ) : ( <> diff --git a/web/src/components/TokensTable.js b/web/src/components/TokensTable.js index 50d442de..6772c235 100644 --- a/web/src/components/TokensTable.js +++ b/web/src/components/TokensTable.js @@ -138,7 +138,7 @@ const TokensTable = () => { let defaultUrl; if (chatLink) { - defaultUrl = chatLink + `/#/?settings={"key":"sk-${key}"}`; + defaultUrl = chatLink + `/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`; } else { defaultUrl = `https://chat.oneapi.pro/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`; } diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js index e42afc6e..76407745 100644 --- a/web/src/constants/channel.constants.js +++ b/web/src/constants/channel.constants.js @@ -8,6 +8,7 @@ export const CHANNEL_OPTIONS = [ { key: 18, text: '讯飞星火认知', value: 18, color: 'blue' }, { key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet' }, { key: 19, text: '360 智脑', value: 19, color: 'blue' }, + { key: 23, text: '腾讯混元', value: 23, color: 'teal' }, { key: 8, text: '自定义渠道', value: 8, color: 'pink' }, { key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' }, { key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' }, diff --git a/web/src/helpers/utils.js b/web/src/helpers/utils.js index 399b7e89..01ce68c0 100644 --- a/web/src/helpers/utils.js +++ b/web/src/helpers/utils.js @@ -186,4 +186,14 @@ export const verifyJSON = (str) => { return false; } return true; -}; \ No newline at end of file +}; + +export function shouldShowPrompt(id) { + let prompt = localStorage.getItem(`prompt-${id}`); + return !prompt; + +} + +export function setPromptShown(id) { + localStorage.setItem(`prompt-${id}`, 'true'); +} \ No newline at end of file diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index 3a8022a8..50148bd5 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -19,6 +19,8 @@ function type2secretPrompt(type) { return '按照如下格式输入:APPID|APISecret|APIKey'; case 22: return '按照如下格式输入:APIKey-AppId,例如:fastgpt-0sp2gtvfdgyi4k30jwlgwf1i-64f335d84283f05518e9e041'; + case 23: + return '按照如下格式输入:AppId|SecretId|SecretKey'; default: return '请输入渠道对应的鉴权密钥'; } @@ -80,7 +82,10 @@ const EditChannel = () => { localModels = ['SparkDesk']; break; case 19: - localModels = ['360GPT_S2_V9', 'embedding-bert-512-v1', 'embedding_s1_v1', 'semantic_similarity_s1_v1', '360GPT_S2_V9.4']; + localModels = ['360GPT_S2_V9', 'embedding-bert-512-v1', 'embedding_s1_v1', 'semantic_similarity_s1_v1']; + break; + case 23: + localModels = ['hunyuan']; break; } setInputs((inputs) => ({ ...inputs, models: localModels })); diff --git a/web/src/pages/NotFound/index.js b/web/src/pages/NotFound/index.js index 08a95f9d..f92dbc90 100644 --- a/web/src/pages/NotFound/index.js +++ b/web/src/pages/NotFound/index.js @@ -1,19 +1,12 @@ import React from 'react'; -import { Segment, Header } from 'semantic-ui-react'; +import { Message } from 'semantic-ui-react'; const NotFound = () => ( <> -
- - 未找到所请求的页面 - + + 页面不存在 +

请检查你的浏览器地址是否正确

+
); diff --git a/web/src/pages/User/EditUser.js b/web/src/pages/User/EditUser.js index e8f96027..8ae0e556 100644 --- a/web/src/pages/User/EditUser.js +++ b/web/src/pages/User/EditUser.js @@ -102,7 +102,7 @@ const EditUser = () => { label='密码' name='password' type={'password'} - placeholder={'请输入新的密码'} + placeholder={'请输入新的密码,最短 8 位'} onChange={handleInputChange} value={password} autoComplete='new-password'