From 222ba4e2c6caeec04d9f3b076350c8710a4fe6a5 Mon Sep 17 00:00:00 2001
From: liyujie <29959257@qq.com>
Date: Sat, 2 Dec 2023 12:50:47 +0800
Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E5=8A=9F=E8=83=BD:=20?=
=?UTF-8?q?=E6=B8=A0=E9=81=93=20-=20=E6=B5=8B=E8=AF=95=E6=89=80=E6=9C=89?=
=?UTF-8?q?=E9=80=9A=E9=81=93;=20=E8=AE=BE=E7=BD=AE=20-=20=E8=BF=90?=
=?UTF-8?q?=E8=90=A5=E8=AE=BE=E7=BD=AE=20-=20=E7=9B=91=E6=8E=A7=E8=AE=BE?=
=?UTF-8?q?=E7=BD=AE=20-=20=E6=88=90=E5=8A=9F=E6=97=B6=E8=87=AA=E5=8A=A8?=
=?UTF-8?q?=E5=90=AF=E7=94=A8=E9=80=9A=E9=81=93?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
common/constants.go | 1 +
controller/channel-test.go | 84 +++++++++++++++++++++++++-
controller/relay-utils.go | 19 ++++++
i18n/en.json | 2 +
model/option.go | 3 +
router/api-router.go | 3 +-
web/src/components/ChannelsTable.js | 13 ++++
web/src/components/OperationSetting.js | 7 +++
8 files changed, 130 insertions(+), 2 deletions(-)
diff --git a/common/constants.go b/common/constants.go
index c7d3f222..f6860f67 100644
--- a/common/constants.go
+++ b/common/constants.go
@@ -78,6 +78,7 @@ var QuotaForInviter = 0
var QuotaForInvitee = 0
var ChannelDisableThreshold = 5.0
var AutomaticDisableChannelEnabled = false
+var AutomaticEnableChannelEnabled = false
var QuotaRemindThreshold = 1000
var PreConsumedQuota = 500
var ApproximateTokenEnabled = false
diff --git a/controller/channel-test.go b/controller/channel-test.go
index 1b0b745a..8599fe6b 100644
--- a/controller/channel-test.go
+++ b/controller/channel-test.go
@@ -80,7 +80,7 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai
if err != nil {
return fmt.Errorf("Error: %s\nResp body: %s", err, body), nil
}
- if response.Usage.CompletionTokens == 0 {
+ if response.Usage.CompletionTokens == 0 && (response.Error.Type != "" && response.Error.Code != "" && response.Error.Message != "") {
return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)), &response.Error
}
return nil, nil
@@ -156,7 +156,73 @@ func disableChannel(channelId int, channelName string, reason string) {
}
}
+// enable & notify
+func enableChannel(channelId int, channelName string) {
+ if common.RootUserEmail == "" {
+ common.RootUserEmail = model.GetRootUserEmail()
+ }
+ model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled)
+ subject := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId)
+ content := fmt.Sprintf("通道「%s」(#%d)现在已被启用", channelName, channelId)
+ err := common.SendEmail(subject, common.RootUserEmail, content)
+ if err != nil {
+ common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
+ }
+}
+
func testAllChannels(notify bool) error {
+ if common.RootUserEmail == "" {
+ common.RootUserEmail = model.GetRootUserEmail()
+ }
+ testAllChannelsLock.Lock()
+ if testAllChannelsRunning {
+ testAllChannelsLock.Unlock()
+ return errors.New("测试已在运行中")
+ }
+ testAllChannelsRunning = true
+ testAllChannelsLock.Unlock()
+ channels, err := model.GetAllChannels(0, 0, true)
+ if err != nil {
+ return err
+ }
+ testRequest := buildTestRequest()
+ var disableThreshold = int64(common.ChannelDisableThreshold * 1000)
+ if disableThreshold == 0 {
+ disableThreshold = 10000000 // a impossible value
+ }
+ go func() {
+ for _, channel := range channels {
+ tik := time.Now()
+ err, openaiErr := testChannel(channel, *testRequest)
+ tok := time.Now()
+ milliseconds := tok.Sub(tik).Milliseconds()
+ if milliseconds > disableThreshold {
+ err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
+ disableChannel(channel.Id, channel.Name, err.Error())
+ }
+ if shouldDisableChannel(openaiErr, -1) {
+ disableChannel(channel.Id, channel.Name, err.Error())
+ }
+ if shouldEnableChannel(channel.Status, err, openaiErr, -1) {
+ enableChannel(channel.Id, channel.Name)
+ }
+ channel.UpdateResponseTime(milliseconds)
+ time.Sleep(common.RequestInterval)
+ }
+ testAllChannelsLock.Lock()
+ testAllChannelsRunning = false
+ testAllChannelsLock.Unlock()
+ if notify {
+ err := common.SendEmail("通道测试完成", common.RootUserEmail, "通道测试完成,如果没有收到启用/禁用通知,说明所有通道都正常")
+ if err != nil {
+ common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
+ }
+ }
+ }()
+ return nil
+}
+
+func testAllEnableChannels(notify bool) error {
if common.RootUserEmail == "" {
common.RootUserEmail = model.GetRootUserEmail()
}
@@ -224,6 +290,22 @@ func TestAllChannels(c *gin.Context) {
return
}
+func TestAllEnableChannels(c *gin.Context) {
+ err := testAllEnableChannels(true)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ })
+ return
+}
+
func AutomaticallyTestChannels(frequency int) {
for {
time.Sleep(time.Duration(frequency) * time.Minute)
diff --git a/controller/relay-utils.go b/controller/relay-utils.go
index 391f28b4..1abac6a2 100644
--- a/controller/relay-utils.go
+++ b/controller/relay-utils.go
@@ -145,6 +145,25 @@ func shouldDisableChannel(err *OpenAIError, statusCode int) bool {
return false
}
+func shouldEnableChannel(status int, err error, openAIErr *OpenAIError, statusCode int) bool {
+ if status == common.ChannelStatusEnabled {
+ return false
+ }
+ if !common.AutomaticEnableChannelEnabled {
+ return false
+ }
+ if err != nil {
+ return false
+ }
+ if statusCode == http.StatusUnauthorized {
+ return false
+ }
+ if openAIErr != nil {
+ return false
+ }
+ return true
+}
+
func setEventStreamHeaders(c *gin.Context) {
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
diff --git a/i18n/en.json b/i18n/en.json
index 9b2ca4c8..b0deb83a 100644
--- a/i18n/en.json
+++ b/i18n/en.json
@@ -119,6 +119,7 @@
" 年 ": " y ",
"未测试": "Not tested",
"通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。": "Channel ${name} test succeeded, time consumed ${time.toFixed(2)} s.",
+ "已成功开始测试所有通道,请刷新页面查看结果。": "All channels have been successfully tested, please refresh the page to view the results.",
"已成功开始测试所有已启用通道,请刷新页面查看结果。": "All enabled channels have been successfully tested, please refresh the page to view the results.",
"通道 ${name} 余额更新成功!": "Channel ${name} balance updated successfully!",
"已更新完毕所有已启用通道余额!": "The balance of all enabled channels has been updated!",
@@ -139,6 +140,7 @@
"启用": "Enable",
"编辑": "Edit",
"添加新的渠道": "Add a new channel",
+ "测试所有通道": "Test all channels",
"测试所有已启用通道": "Test all enabled channels",
"更新所有已启用通道余额": "Update the balance of all enabled channels",
"刷新": "Refresh",
diff --git a/model/option.go b/model/option.go
index 4ef4d260..bb8b709c 100644
--- a/model/option.go
+++ b/model/option.go
@@ -34,6 +34,7 @@ func InitOptionMap() {
common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled)
common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled)
common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled)
+ common.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(common.AutomaticEnableChannelEnabled)
common.OptionMap["ApproximateTokenEnabled"] = strconv.FormatBool(common.ApproximateTokenEnabled)
common.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(common.LogConsumeEnabled)
common.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(common.DisplayInCurrencyEnabled)
@@ -147,6 +148,8 @@ func updateOptionMap(key string, value string) (err error) {
common.EmailDomainRestrictionEnabled = boolValue
case "AutomaticDisableChannelEnabled":
common.AutomaticDisableChannelEnabled = boolValue
+ case "AutomaticEnableChannelEnabled":
+ common.AutomaticEnableChannelEnabled = boolValue
case "ApproximateTokenEnabled":
common.ApproximateTokenEnabled = boolValue
case "LogConsumeEnabled":
diff --git a/router/api-router.go b/router/api-router.go
index da3f9e61..d5450329 100644
--- a/router/api-router.go
+++ b/router/api-router.go
@@ -68,7 +68,8 @@ func SetApiRouter(router *gin.Engine) {
channelRoute.GET("/search", controller.SearchChannels)
channelRoute.GET("/models", controller.ListModels)
channelRoute.GET("/:id", controller.GetChannel)
- channelRoute.GET("/test", controller.TestAllChannels)
+ channelRoute.GET("/testAll", controller.TestAllChannels)
+ channelRoute.GET("/test", controller.TestAllEnableChannels)
channelRoute.GET("/test/:id", controller.TestChannel)
channelRoute.GET("/update_balance", controller.UpdateAllChannelsBalance)
channelRoute.GET("/update_balance/:id", controller.UpdateChannelBalance)
diff --git a/web/src/components/ChannelsTable.js b/web/src/components/ChannelsTable.js
index d44ea2d7..64c40f7f 100644
--- a/web/src/components/ChannelsTable.js
+++ b/web/src/components/ChannelsTable.js
@@ -231,6 +231,16 @@ const ChannelsTable = () => {
};
const testAllChannels = async () => {
+ const res = await API.get(`/api/channel/testAll`);
+ const { success, message } = res.data;
+ if (success) {
+ showInfo('已成功开始测试所有通道,请刷新页面查看结果。');
+ } else {
+ showError(message);
+ }
+ };
+
+ const testAllEnableChannels = async () => {
const res = await API.get(`/api/channel/test`);
const { success, message } = res.data;
if (success) {
@@ -523,6 +533,9 @@ const ChannelsTable = () => {
添加新的渠道
+