diff --git a/controller/channel.go b/controller/channel.go index b393273c..19be1299 100644 --- a/controller/channel.go +++ b/controller/channel.go @@ -11,6 +11,7 @@ import ( "one-api/model" "strconv" "strings" + "sync" "time" ) @@ -19,7 +20,7 @@ func GetAllChannels(c *gin.Context) { if p < 0 { p = 0 } - channels, err := model.GetAllChannels(p*common.ItemsPerPage, common.ItemsPerPage) + channels, err := model.GetAllChannels(p*common.ItemsPerPage, common.ItemsPerPage, false) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -206,6 +207,19 @@ func testChannel(channel *model.Channel, request *ChatRequest) error { return nil } +func buildTestRequest(c *gin.Context) *ChatRequest { + model_ := c.Query("model") + testRequest := &ChatRequest{ + Model: model_, + } + testMessage := Message{ + Role: "user", + Content: "echo hi", + } + testRequest.Messages = append(testRequest.Messages, testMessage) + return testRequest +} + func TestChannel(c *gin.Context) { id, err := strconv.Atoi(c.Param("id")) if err != nil { @@ -223,17 +237,9 @@ func TestChannel(c *gin.Context) { }) return } - model_ := c.Query("model") - chatRequest := &ChatRequest{ - Model: model_, - } - testMessage := Message{ - Role: "user", - Content: "echo hi", - } - chatRequest.Messages = append(chatRequest.Messages, testMessage) + testRequest := buildTestRequest(c) tik := time.Now() - err = testChannel(channel, chatRequest) + err = testChannel(channel, testRequest) tok := time.Now() milliseconds := tok.Sub(tik).Milliseconds() go channel.UpdateResponseTime(milliseconds) @@ -253,3 +259,70 @@ func TestChannel(c *gin.Context) { }) return } + +var testAllChannelsLock sync.Mutex + +func testAllChannels(c *gin.Context) error { + ok := testAllChannelsLock.TryLock() + if !ok { + return errors.New("测试已在运行") + } + defer testAllChannelsLock.Unlock() + channels, err := model.GetAllChannels(0, 0, true) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return err + } + testRequest := buildTestRequest(c) + var disableThreshold int64 = 5000 // TODO: make it configurable + email := model.GetRootUserEmail() + go func() { + for _, channel := range channels { + if channel.Status != common.ChannelStatusEnabled { + continue + } + tik := time.Now() + err := testChannel(channel, testRequest) + tok := time.Now() + milliseconds := tok.Sub(tik).Milliseconds() + if err != nil || milliseconds > disableThreshold { + if milliseconds > disableThreshold { + err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) + } + // disable & notify + channel.UpdateStatus(common.ChannelStatusDisabled) + subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channel.Name, channel.Id) + content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channel.Name, channel.Id, err.Error()) + err = common.SendEmail(subject, email, content) + if err != nil { + common.SysError(fmt.Sprintf("发送邮件失败:%s", err.Error())) + } + } + channel.UpdateResponseTime(milliseconds) + } + err := common.SendEmail("通道测试完成", email, "通道测试完成") + if err != nil { + common.SysError(fmt.Sprintf("发送邮件失败:%s", err.Error())) + } + }() + return nil +} + +func TestAllChannels(c *gin.Context) { + err := testAllChannels(c) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + }) + return +} diff --git a/model/channel.go b/model/channel.go index e443d7be..7b1c9ec2 100644 --- a/model/channel.go +++ b/model/channel.go @@ -19,10 +19,14 @@ type Channel struct { Other string `json:"other"` } -func GetAllChannels(startIdx int, num int) ([]*Channel, error) { +func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) { var channels []*Channel var err error - err = DB.Order("id desc").Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error + if selectAll { + err = DB.Order("id desc").Find(&channels).Error + } else { + err = DB.Order("id desc").Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error + } return channels, err } @@ -82,6 +86,13 @@ func (channel *Channel) UpdateResponseTime(responseTime int64) { } } +func (channel *Channel) UpdateStatus(status int) { + err := DB.Model(channel).Update("status", status).Error + if err != nil { + common.SysError("failed to update response time: " + err.Error()) + } +} + func (channel *Channel) Delete() error { var err error err = DB.Delete(channel).Error diff --git a/model/user.go b/model/user.go index 3a8f2313..b121753d 100644 --- a/model/user.go +++ b/model/user.go @@ -234,3 +234,8 @@ func DecreaseUserQuota(id int, quota int) (err error) { err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error return err } + +func GetRootUserEmail() (email string) { + DB.Model(&User{}).Where("role = ?", common.RoleRootUser).Select("email").Find(&email) + return email +} diff --git a/router/api-router.go b/router/api-router.go index 0e249dc9..9e7f580d 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -63,6 +63,7 @@ func SetApiRouter(router *gin.Engine) { channelRoute.GET("/", controller.GetAllChannels) channelRoute.GET("/search", controller.SearchChannels) channelRoute.GET("/:id", controller.GetChannel) + channelRoute.GET("/test", controller.TestAllChannels) channelRoute.GET("/test/:id", controller.TestChannel) channelRoute.POST("/", controller.AddChannel) channelRoute.PUT("/", controller.UpdateChannel) diff --git a/web/src/components/ChannelsTable.js b/web/src/components/ChannelsTable.js index 7fe24645..c3531fea 100644 --- a/web/src/components/ChannelsTable.js +++ b/web/src/components/ChannelsTable.js @@ -170,6 +170,16 @@ const ChannelsTable = () => { } }; + const testAllChannels = async () => { + const res = await API.get(`/api/channel/test`); + const { success, message } = res.data; + if (success) { + showSuccess("已成功开始测试所有已启用通道,请刷新页面查看结果。"); + } else { + showError(message); + } + } + const handleKeywordChange = async (e, { value }) => { setSearchKeyword(value.trim()); }; @@ -335,6 +345,9 @@ const ChannelsTable = () => { +