diff --git a/controller/channel.go b/controller/channel.go
index b393273c..19be1299 100644
--- a/controller/channel.go
+++ b/controller/channel.go
@@ -11,6 +11,7 @@ import (
"one-api/model"
"strconv"
"strings"
+ "sync"
"time"
)
@@ -19,7 +20,7 @@ func GetAllChannels(c *gin.Context) {
if p < 0 {
p = 0
}
- channels, err := model.GetAllChannels(p*common.ItemsPerPage, common.ItemsPerPage)
+ channels, err := model.GetAllChannels(p*common.ItemsPerPage, common.ItemsPerPage, false)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -206,6 +207,19 @@ func testChannel(channel *model.Channel, request *ChatRequest) error {
return nil
}
+func buildTestRequest(c *gin.Context) *ChatRequest {
+ model_ := c.Query("model")
+ testRequest := &ChatRequest{
+ Model: model_,
+ }
+ testMessage := Message{
+ Role: "user",
+ Content: "echo hi",
+ }
+ testRequest.Messages = append(testRequest.Messages, testMessage)
+ return testRequest
+}
+
func TestChannel(c *gin.Context) {
id, err := strconv.Atoi(c.Param("id"))
if err != nil {
@@ -223,17 +237,9 @@ func TestChannel(c *gin.Context) {
})
return
}
- model_ := c.Query("model")
- chatRequest := &ChatRequest{
- Model: model_,
- }
- testMessage := Message{
- Role: "user",
- Content: "echo hi",
- }
- chatRequest.Messages = append(chatRequest.Messages, testMessage)
+ testRequest := buildTestRequest(c)
tik := time.Now()
- err = testChannel(channel, chatRequest)
+ err = testChannel(channel, testRequest)
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()
go channel.UpdateResponseTime(milliseconds)
@@ -253,3 +259,70 @@ func TestChannel(c *gin.Context) {
})
return
}
+
+var testAllChannelsLock sync.Mutex
+
+func testAllChannels(c *gin.Context) error {
+ ok := testAllChannelsLock.TryLock()
+ if !ok {
+ return errors.New("测试已在运行")
+ }
+ defer testAllChannelsLock.Unlock()
+ channels, err := model.GetAllChannels(0, 0, true)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return err
+ }
+ testRequest := buildTestRequest(c)
+ var disableThreshold int64 = 5000 // TODO: make it configurable
+ email := model.GetRootUserEmail()
+ go func() {
+ for _, channel := range channels {
+ if channel.Status != common.ChannelStatusEnabled {
+ continue
+ }
+ tik := time.Now()
+ err := testChannel(channel, testRequest)
+ tok := time.Now()
+ milliseconds := tok.Sub(tik).Milliseconds()
+ if err != nil || milliseconds > disableThreshold {
+ if milliseconds > disableThreshold {
+ err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
+ }
+ // disable & notify
+ channel.UpdateStatus(common.ChannelStatusDisabled)
+ subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channel.Name, channel.Id)
+ content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channel.Name, channel.Id, err.Error())
+ err = common.SendEmail(subject, email, content)
+ if err != nil {
+ common.SysError(fmt.Sprintf("发送邮件失败:%s", err.Error()))
+ }
+ }
+ channel.UpdateResponseTime(milliseconds)
+ }
+ err := common.SendEmail("通道测试完成", email, "通道测试完成")
+ if err != nil {
+ common.SysError(fmt.Sprintf("发送邮件失败:%s", err.Error()))
+ }
+ }()
+ return nil
+}
+
+func TestAllChannels(c *gin.Context) {
+ err := testAllChannels(c)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ })
+ return
+}
diff --git a/model/channel.go b/model/channel.go
index e443d7be..7b1c9ec2 100644
--- a/model/channel.go
+++ b/model/channel.go
@@ -19,10 +19,14 @@ type Channel struct {
Other string `json:"other"`
}
-func GetAllChannels(startIdx int, num int) ([]*Channel, error) {
+func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
var channels []*Channel
var err error
- err = DB.Order("id desc").Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error
+ if selectAll {
+ err = DB.Order("id desc").Find(&channels).Error
+ } else {
+ err = DB.Order("id desc").Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error
+ }
return channels, err
}
@@ -82,6 +86,13 @@ func (channel *Channel) UpdateResponseTime(responseTime int64) {
}
}
+func (channel *Channel) UpdateStatus(status int) {
+ err := DB.Model(channel).Update("status", status).Error
+ if err != nil {
+ common.SysError("failed to update response time: " + err.Error())
+ }
+}
+
func (channel *Channel) Delete() error {
var err error
err = DB.Delete(channel).Error
diff --git a/model/user.go b/model/user.go
index 3a8f2313..b121753d 100644
--- a/model/user.go
+++ b/model/user.go
@@ -234,3 +234,8 @@ func DecreaseUserQuota(id int, quota int) (err error) {
err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error
return err
}
+
+func GetRootUserEmail() (email string) {
+ DB.Model(&User{}).Where("role = ?", common.RoleRootUser).Select("email").Find(&email)
+ return email
+}
diff --git a/router/api-router.go b/router/api-router.go
index 0e249dc9..9e7f580d 100644
--- a/router/api-router.go
+++ b/router/api-router.go
@@ -63,6 +63,7 @@ func SetApiRouter(router *gin.Engine) {
channelRoute.GET("/", controller.GetAllChannels)
channelRoute.GET("/search", controller.SearchChannels)
channelRoute.GET("/:id", controller.GetChannel)
+ channelRoute.GET("/test", controller.TestAllChannels)
channelRoute.GET("/test/:id", controller.TestChannel)
channelRoute.POST("/", controller.AddChannel)
channelRoute.PUT("/", controller.UpdateChannel)
diff --git a/web/src/components/ChannelsTable.js b/web/src/components/ChannelsTable.js
index 7fe24645..c3531fea 100644
--- a/web/src/components/ChannelsTable.js
+++ b/web/src/components/ChannelsTable.js
@@ -170,6 +170,16 @@ const ChannelsTable = () => {
}
};
+ const testAllChannels = async () => {
+ const res = await API.get(`/api/channel/test`);
+ const { success, message } = res.data;
+ if (success) {
+ showSuccess("已成功开始测试所有已启用通道,请刷新页面查看结果。");
+ } else {
+ showError(message);
+ }
+ }
+
const handleKeywordChange = async (e, { value }) => {
setSearchKeyword(value.trim());
};
@@ -335,6 +345,9 @@ const ChannelsTable = () => {
+