feat: able to test all enabled channels (#59)

This commit is contained in:
JustSong 2023-05-15 12:36:55 +08:00
parent 570b3bc71c
commit d267211ee7
5 changed files with 116 additions and 13 deletions

View File

@ -11,6 +11,7 @@ import (
"one-api/model" "one-api/model"
"strconv" "strconv"
"strings" "strings"
"sync"
"time" "time"
) )
@ -19,7 +20,7 @@ func GetAllChannels(c *gin.Context) {
if p < 0 { if p < 0 {
p = 0 p = 0
} }
channels, err := model.GetAllChannels(p*common.ItemsPerPage, common.ItemsPerPage) channels, err := model.GetAllChannels(p*common.ItemsPerPage, common.ItemsPerPage, false)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
@ -206,6 +207,19 @@ func testChannel(channel *model.Channel, request *ChatRequest) error {
return nil return nil
} }
func buildTestRequest(c *gin.Context) *ChatRequest {
model_ := c.Query("model")
testRequest := &ChatRequest{
Model: model_,
}
testMessage := Message{
Role: "user",
Content: "echo hi",
}
testRequest.Messages = append(testRequest.Messages, testMessage)
return testRequest
}
func TestChannel(c *gin.Context) { func TestChannel(c *gin.Context) {
id, err := strconv.Atoi(c.Param("id")) id, err := strconv.Atoi(c.Param("id"))
if err != nil { if err != nil {
@ -223,17 +237,9 @@ func TestChannel(c *gin.Context) {
}) })
return return
} }
model_ := c.Query("model") testRequest := buildTestRequest(c)
chatRequest := &ChatRequest{
Model: model_,
}
testMessage := Message{
Role: "user",
Content: "echo hi",
}
chatRequest.Messages = append(chatRequest.Messages, testMessage)
tik := time.Now() tik := time.Now()
err = testChannel(channel, chatRequest) err = testChannel(channel, testRequest)
tok := time.Now() tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds() milliseconds := tok.Sub(tik).Milliseconds()
go channel.UpdateResponseTime(milliseconds) go channel.UpdateResponseTime(milliseconds)
@ -253,3 +259,70 @@ func TestChannel(c *gin.Context) {
}) })
return return
} }
var testAllChannelsLock sync.Mutex
func testAllChannels(c *gin.Context) error {
ok := testAllChannelsLock.TryLock()
if !ok {
return errors.New("测试已在运行")
}
defer testAllChannelsLock.Unlock()
channels, err := model.GetAllChannels(0, 0, true)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return err
}
testRequest := buildTestRequest(c)
var disableThreshold int64 = 5000 // TODO: make it configurable
email := model.GetRootUserEmail()
go func() {
for _, channel := range channels {
if channel.Status != common.ChannelStatusEnabled {
continue
}
tik := time.Now()
err := testChannel(channel, testRequest)
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()
if err != nil || milliseconds > disableThreshold {
if milliseconds > disableThreshold {
err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
}
// disable & notify
channel.UpdateStatus(common.ChannelStatusDisabled)
subject := fmt.Sprintf("通道「%s」#%d已被禁用", channel.Name, channel.Id)
content := fmt.Sprintf("通道「%s」#%d已被禁用原因%s", channel.Name, channel.Id, err.Error())
err = common.SendEmail(subject, email, content)
if err != nil {
common.SysError(fmt.Sprintf("发送邮件失败:%s", err.Error()))
}
}
channel.UpdateResponseTime(milliseconds)
}
err := common.SendEmail("通道测试完成", email, "通道测试完成")
if err != nil {
common.SysError(fmt.Sprintf("发送邮件失败:%s", err.Error()))
}
}()
return nil
}
func TestAllChannels(c *gin.Context) {
err := testAllChannels(c)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
})
return
}

View File

@ -19,10 +19,14 @@ type Channel struct {
Other string `json:"other"` Other string `json:"other"`
} }
func GetAllChannels(startIdx int, num int) ([]*Channel, error) { func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
var channels []*Channel var channels []*Channel
var err error var err error
if selectAll {
err = DB.Order("id desc").Find(&channels).Error
} else {
err = DB.Order("id desc").Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error err = DB.Order("id desc").Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error
}
return channels, err return channels, err
} }
@ -82,6 +86,13 @@ func (channel *Channel) UpdateResponseTime(responseTime int64) {
} }
} }
func (channel *Channel) UpdateStatus(status int) {
err := DB.Model(channel).Update("status", status).Error
if err != nil {
common.SysError("failed to update response time: " + err.Error())
}
}
func (channel *Channel) Delete() error { func (channel *Channel) Delete() error {
var err error var err error
err = DB.Delete(channel).Error err = DB.Delete(channel).Error

View File

@ -234,3 +234,8 @@ func DecreaseUserQuota(id int, quota int) (err error) {
err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error
return err return err
} }
func GetRootUserEmail() (email string) {
DB.Model(&User{}).Where("role = ?", common.RoleRootUser).Select("email").Find(&email)
return email
}

View File

@ -63,6 +63,7 @@ func SetApiRouter(router *gin.Engine) {
channelRoute.GET("/", controller.GetAllChannels) channelRoute.GET("/", controller.GetAllChannels)
channelRoute.GET("/search", controller.SearchChannels) channelRoute.GET("/search", controller.SearchChannels)
channelRoute.GET("/:id", controller.GetChannel) channelRoute.GET("/:id", controller.GetChannel)
channelRoute.GET("/test", controller.TestAllChannels)
channelRoute.GET("/test/:id", controller.TestChannel) channelRoute.GET("/test/:id", controller.TestChannel)
channelRoute.POST("/", controller.AddChannel) channelRoute.POST("/", controller.AddChannel)
channelRoute.PUT("/", controller.UpdateChannel) channelRoute.PUT("/", controller.UpdateChannel)

View File

@ -170,6 +170,16 @@ const ChannelsTable = () => {
} }
}; };
const testAllChannels = async () => {
const res = await API.get(`/api/channel/test`);
const { success, message } = res.data;
if (success) {
showSuccess("已成功开始测试所有已启用通道,请刷新页面查看结果。");
} else {
showError(message);
}
}
const handleKeywordChange = async (e, { value }) => { const handleKeywordChange = async (e, { value }) => {
setSearchKeyword(value.trim()); setSearchKeyword(value.trim());
}; };
@ -335,6 +345,9 @@ const ChannelsTable = () => {
<Button size='small' as={Link} to='/channel/add' loading={loading}> <Button size='small' as={Link} to='/channel/add' loading={loading}>
添加新的渠道 添加新的渠道
</Button> </Button>
<Button size='small' loading={loading} onClick={testAllChannels}>
测试所有已启用通道
</Button>
<Pagination <Pagination
floated='right' floated='right'
activePage={activePage} activePage={activePage}