feat: able to test all enabled channels (#59)
This commit is contained in:
parent
570b3bc71c
commit
d267211ee7
@ -11,6 +11,7 @@ import (
|
|||||||
"one-api/model"
|
"one-api/model"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -19,7 +20,7 @@ func GetAllChannels(c *gin.Context) {
|
|||||||
if p < 0 {
|
if p < 0 {
|
||||||
p = 0
|
p = 0
|
||||||
}
|
}
|
||||||
channels, err := model.GetAllChannels(p*common.ItemsPerPage, common.ItemsPerPage)
|
channels, err := model.GetAllChannels(p*common.ItemsPerPage, common.ItemsPerPage, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
@ -206,6 +207,19 @@ func testChannel(channel *model.Channel, request *ChatRequest) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func buildTestRequest(c *gin.Context) *ChatRequest {
|
||||||
|
model_ := c.Query("model")
|
||||||
|
testRequest := &ChatRequest{
|
||||||
|
Model: model_,
|
||||||
|
}
|
||||||
|
testMessage := Message{
|
||||||
|
Role: "user",
|
||||||
|
Content: "echo hi",
|
||||||
|
}
|
||||||
|
testRequest.Messages = append(testRequest.Messages, testMessage)
|
||||||
|
return testRequest
|
||||||
|
}
|
||||||
|
|
||||||
func TestChannel(c *gin.Context) {
|
func TestChannel(c *gin.Context) {
|
||||||
id, err := strconv.Atoi(c.Param("id"))
|
id, err := strconv.Atoi(c.Param("id"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -223,17 +237,9 @@ func TestChannel(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
model_ := c.Query("model")
|
testRequest := buildTestRequest(c)
|
||||||
chatRequest := &ChatRequest{
|
|
||||||
Model: model_,
|
|
||||||
}
|
|
||||||
testMessage := Message{
|
|
||||||
Role: "user",
|
|
||||||
Content: "echo hi",
|
|
||||||
}
|
|
||||||
chatRequest.Messages = append(chatRequest.Messages, testMessage)
|
|
||||||
tik := time.Now()
|
tik := time.Now()
|
||||||
err = testChannel(channel, chatRequest)
|
err = testChannel(channel, testRequest)
|
||||||
tok := time.Now()
|
tok := time.Now()
|
||||||
milliseconds := tok.Sub(tik).Milliseconds()
|
milliseconds := tok.Sub(tik).Milliseconds()
|
||||||
go channel.UpdateResponseTime(milliseconds)
|
go channel.UpdateResponseTime(milliseconds)
|
||||||
@ -253,3 +259,70 @@ func TestChannel(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var testAllChannelsLock sync.Mutex
|
||||||
|
|
||||||
|
func testAllChannels(c *gin.Context) error {
|
||||||
|
ok := testAllChannelsLock.TryLock()
|
||||||
|
if !ok {
|
||||||
|
return errors.New("测试已在运行")
|
||||||
|
}
|
||||||
|
defer testAllChannelsLock.Unlock()
|
||||||
|
channels, err := model.GetAllChannels(0, 0, true)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
testRequest := buildTestRequest(c)
|
||||||
|
var disableThreshold int64 = 5000 // TODO: make it configurable
|
||||||
|
email := model.GetRootUserEmail()
|
||||||
|
go func() {
|
||||||
|
for _, channel := range channels {
|
||||||
|
if channel.Status != common.ChannelStatusEnabled {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
tik := time.Now()
|
||||||
|
err := testChannel(channel, testRequest)
|
||||||
|
tok := time.Now()
|
||||||
|
milliseconds := tok.Sub(tik).Milliseconds()
|
||||||
|
if err != nil || milliseconds > disableThreshold {
|
||||||
|
if milliseconds > disableThreshold {
|
||||||
|
err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
|
||||||
|
}
|
||||||
|
// disable & notify
|
||||||
|
channel.UpdateStatus(common.ChannelStatusDisabled)
|
||||||
|
subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channel.Name, channel.Id)
|
||||||
|
content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channel.Name, channel.Id, err.Error())
|
||||||
|
err = common.SendEmail(subject, email, content)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError(fmt.Sprintf("发送邮件失败:%s", err.Error()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
channel.UpdateResponseTime(milliseconds)
|
||||||
|
}
|
||||||
|
err := common.SendEmail("通道测试完成", email, "通道测试完成")
|
||||||
|
if err != nil {
|
||||||
|
common.SysError(fmt.Sprintf("发送邮件失败:%s", err.Error()))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAllChannels(c *gin.Context) {
|
||||||
|
err := testAllChannels(c)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"message": "",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
@ -19,10 +19,14 @@ type Channel struct {
|
|||||||
Other string `json:"other"`
|
Other string `json:"other"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetAllChannels(startIdx int, num int) ([]*Channel, error) {
|
func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
|
||||||
var channels []*Channel
|
var channels []*Channel
|
||||||
var err error
|
var err error
|
||||||
err = DB.Order("id desc").Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error
|
if selectAll {
|
||||||
|
err = DB.Order("id desc").Find(&channels).Error
|
||||||
|
} else {
|
||||||
|
err = DB.Order("id desc").Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error
|
||||||
|
}
|
||||||
return channels, err
|
return channels, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -82,6 +86,13 @@ func (channel *Channel) UpdateResponseTime(responseTime int64) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (channel *Channel) UpdateStatus(status int) {
|
||||||
|
err := DB.Model(channel).Update("status", status).Error
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("failed to update response time: " + err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (channel *Channel) Delete() error {
|
func (channel *Channel) Delete() error {
|
||||||
var err error
|
var err error
|
||||||
err = DB.Delete(channel).Error
|
err = DB.Delete(channel).Error
|
||||||
|
@ -234,3 +234,8 @@ func DecreaseUserQuota(id int, quota int) (err error) {
|
|||||||
err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error
|
err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetRootUserEmail() (email string) {
|
||||||
|
DB.Model(&User{}).Where("role = ?", common.RoleRootUser).Select("email").Find(&email)
|
||||||
|
return email
|
||||||
|
}
|
||||||
|
@ -63,6 +63,7 @@ func SetApiRouter(router *gin.Engine) {
|
|||||||
channelRoute.GET("/", controller.GetAllChannels)
|
channelRoute.GET("/", controller.GetAllChannels)
|
||||||
channelRoute.GET("/search", controller.SearchChannels)
|
channelRoute.GET("/search", controller.SearchChannels)
|
||||||
channelRoute.GET("/:id", controller.GetChannel)
|
channelRoute.GET("/:id", controller.GetChannel)
|
||||||
|
channelRoute.GET("/test", controller.TestAllChannels)
|
||||||
channelRoute.GET("/test/:id", controller.TestChannel)
|
channelRoute.GET("/test/:id", controller.TestChannel)
|
||||||
channelRoute.POST("/", controller.AddChannel)
|
channelRoute.POST("/", controller.AddChannel)
|
||||||
channelRoute.PUT("/", controller.UpdateChannel)
|
channelRoute.PUT("/", controller.UpdateChannel)
|
||||||
|
@ -170,6 +170,16 @@ const ChannelsTable = () => {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const testAllChannels = async () => {
|
||||||
|
const res = await API.get(`/api/channel/test`);
|
||||||
|
const { success, message } = res.data;
|
||||||
|
if (success) {
|
||||||
|
showSuccess("已成功开始测试所有已启用通道,请刷新页面查看结果。");
|
||||||
|
} else {
|
||||||
|
showError(message);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const handleKeywordChange = async (e, { value }) => {
|
const handleKeywordChange = async (e, { value }) => {
|
||||||
setSearchKeyword(value.trim());
|
setSearchKeyword(value.trim());
|
||||||
};
|
};
|
||||||
@ -335,6 +345,9 @@ const ChannelsTable = () => {
|
|||||||
<Button size='small' as={Link} to='/channel/add' loading={loading}>
|
<Button size='small' as={Link} to='/channel/add' loading={loading}>
|
||||||
添加新的渠道
|
添加新的渠道
|
||||||
</Button>
|
</Button>
|
||||||
|
<Button size='small' loading={loading} onClick={testAllChannels}>
|
||||||
|
测试所有已启用通道
|
||||||
|
</Button>
|
||||||
<Pagination
|
<Pagination
|
||||||
floated='right'
|
floated='right'
|
||||||
activePage={activePage}
|
activePage={activePage}
|
||||||
|
Loading…
Reference in New Issue
Block a user