diff --git a/common/constants.go b/common/constants.go index c7d3f222..f6860f67 100644 --- a/common/constants.go +++ b/common/constants.go @@ -78,6 +78,7 @@ var QuotaForInviter = 0 var QuotaForInvitee = 0 var ChannelDisableThreshold = 5.0 var AutomaticDisableChannelEnabled = false +var AutomaticEnableChannelEnabled = false var QuotaRemindThreshold = 1000 var PreConsumedQuota = 500 var ApproximateTokenEnabled = false diff --git a/controller/channel-test.go b/controller/channel-test.go index 1b0b745a..8599fe6b 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -80,7 +80,7 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai if err != nil { return fmt.Errorf("Error: %s\nResp body: %s", err, body), nil } - if response.Usage.CompletionTokens == 0 { + if response.Usage.CompletionTokens == 0 && (response.Error.Type != "" && response.Error.Code != "" && response.Error.Message != "") { return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)), &response.Error } return nil, nil @@ -156,7 +156,73 @@ func disableChannel(channelId int, channelName string, reason string) { } } +// enable & notify +func enableChannel(channelId int, channelName string) { + if common.RootUserEmail == "" { + common.RootUserEmail = model.GetRootUserEmail() + } + model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled) + subject := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId) + content := fmt.Sprintf("通道「%s」(#%d)现在已被启用", channelName, channelId) + err := common.SendEmail(subject, common.RootUserEmail, content) + if err != nil { + common.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) + } +} + func testAllChannels(notify bool) error { + if common.RootUserEmail == "" { + common.RootUserEmail = model.GetRootUserEmail() + } + testAllChannelsLock.Lock() + if testAllChannelsRunning { + testAllChannelsLock.Unlock() + return errors.New("测试已在运行中") + } + testAllChannelsRunning = true + testAllChannelsLock.Unlock() + channels, err := model.GetAllChannels(0, 0, true) + if err != nil { + return err + } + testRequest := buildTestRequest() + var disableThreshold = int64(common.ChannelDisableThreshold * 1000) + if disableThreshold == 0 { + disableThreshold = 10000000 // a impossible value + } + go func() { + for _, channel := range channels { + tik := time.Now() + err, openaiErr := testChannel(channel, *testRequest) + tok := time.Now() + milliseconds := tok.Sub(tik).Milliseconds() + if milliseconds > disableThreshold { + err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) + disableChannel(channel.Id, channel.Name, err.Error()) + } + if shouldDisableChannel(openaiErr, -1) { + disableChannel(channel.Id, channel.Name, err.Error()) + } + if shouldEnableChannel(channel.Status, err, openaiErr, -1) { + enableChannel(channel.Id, channel.Name) + } + channel.UpdateResponseTime(milliseconds) + time.Sleep(common.RequestInterval) + } + testAllChannelsLock.Lock() + testAllChannelsRunning = false + testAllChannelsLock.Unlock() + if notify { + err := common.SendEmail("通道测试完成", common.RootUserEmail, "通道测试完成,如果没有收到启用/禁用通知,说明所有通道都正常") + if err != nil { + common.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) + } + } + }() + return nil +} + +func testAllEnableChannels(notify bool) error { if common.RootUserEmail == "" { common.RootUserEmail = model.GetRootUserEmail() } @@ -224,6 +290,22 @@ func TestAllChannels(c *gin.Context) { return } +func TestAllEnableChannels(c *gin.Context) { + err := testAllEnableChannels(true) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + }) + return +} + func AutomaticallyTestChannels(frequency int) { for { time.Sleep(time.Duration(frequency) * time.Minute) diff --git a/controller/relay-utils.go b/controller/relay-utils.go index 391f28b4..1abac6a2 100644 --- a/controller/relay-utils.go +++ b/controller/relay-utils.go @@ -145,6 +145,25 @@ func shouldDisableChannel(err *OpenAIError, statusCode int) bool { return false } +func shouldEnableChannel(status int, err error, openAIErr *OpenAIError, statusCode int) bool { + if status == common.ChannelStatusEnabled { + return false + } + if !common.AutomaticEnableChannelEnabled { + return false + } + if err != nil { + return false + } + if statusCode == http.StatusUnauthorized { + return false + } + if openAIErr != nil { + return false + } + return true +} + func setEventStreamHeaders(c *gin.Context) { c.Writer.Header().Set("Content-Type", "text/event-stream") c.Writer.Header().Set("Cache-Control", "no-cache") diff --git a/i18n/en.json b/i18n/en.json index 9b2ca4c8..b0deb83a 100644 --- a/i18n/en.json +++ b/i18n/en.json @@ -119,6 +119,7 @@ " 年 ": " y ", "未测试": "Not tested", "通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。": "Channel ${name} test succeeded, time consumed ${time.toFixed(2)} s.", + "已成功开始测试所有通道,请刷新页面查看结果。": "All channels have been successfully tested, please refresh the page to view the results.", "已成功开始测试所有已启用通道,请刷新页面查看结果。": "All enabled channels have been successfully tested, please refresh the page to view the results.", "通道 ${name} 余额更新成功!": "Channel ${name} balance updated successfully!", "已更新完毕所有已启用通道余额!": "The balance of all enabled channels has been updated!", @@ -139,6 +140,7 @@ "启用": "Enable", "编辑": "Edit", "添加新的渠道": "Add a new channel", + "测试所有通道": "Test all channels", "测试所有已启用通道": "Test all enabled channels", "更新所有已启用通道余额": "Update the balance of all enabled channels", "刷新": "Refresh", diff --git a/model/option.go b/model/option.go index 4ef4d260..bb8b709c 100644 --- a/model/option.go +++ b/model/option.go @@ -34,6 +34,7 @@ func InitOptionMap() { common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled) common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled) common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled) + common.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(common.AutomaticEnableChannelEnabled) common.OptionMap["ApproximateTokenEnabled"] = strconv.FormatBool(common.ApproximateTokenEnabled) common.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(common.LogConsumeEnabled) common.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(common.DisplayInCurrencyEnabled) @@ -147,6 +148,8 @@ func updateOptionMap(key string, value string) (err error) { common.EmailDomainRestrictionEnabled = boolValue case "AutomaticDisableChannelEnabled": common.AutomaticDisableChannelEnabled = boolValue + case "AutomaticEnableChannelEnabled": + common.AutomaticEnableChannelEnabled = boolValue case "ApproximateTokenEnabled": common.ApproximateTokenEnabled = boolValue case "LogConsumeEnabled": diff --git a/router/api-router.go b/router/api-router.go index da3f9e61..d5450329 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -68,7 +68,8 @@ func SetApiRouter(router *gin.Engine) { channelRoute.GET("/search", controller.SearchChannels) channelRoute.GET("/models", controller.ListModels) channelRoute.GET("/:id", controller.GetChannel) - channelRoute.GET("/test", controller.TestAllChannels) + channelRoute.GET("/testAll", controller.TestAllChannels) + channelRoute.GET("/test", controller.TestAllEnableChannels) channelRoute.GET("/test/:id", controller.TestChannel) channelRoute.GET("/update_balance", controller.UpdateAllChannelsBalance) channelRoute.GET("/update_balance/:id", controller.UpdateChannelBalance) diff --git a/web/src/components/ChannelsTable.js b/web/src/components/ChannelsTable.js index d44ea2d7..64c40f7f 100644 --- a/web/src/components/ChannelsTable.js +++ b/web/src/components/ChannelsTable.js @@ -231,6 +231,16 @@ const ChannelsTable = () => { }; const testAllChannels = async () => { + const res = await API.get(`/api/channel/testAll`); + const { success, message } = res.data; + if (success) { + showInfo('已成功开始测试所有通道,请刷新页面查看结果。'); + } else { + showError(message); + } + }; + + const testAllEnableChannels = async () => { const res = await API.get(`/api/channel/test`); const { success, message } = res.data; if (success) { @@ -523,6 +533,9 @@ const ChannelsTable = () => { 添加新的渠道 +