feat: support automatic channel testing & balance updates (close #11, close #59)

This commit is contained in:
JustSong 2023-06-22 22:01:03 +08:00
parent ad1049b0cf
commit 4463224f04
5 changed files with 57 additions and 15 deletions

View File

@ -250,6 +250,12 @@ graph LR
+ 例子:`SYNC_FREQUENCY=60` + 例子:`SYNC_FREQUENCY=60`
6. `NODE_TYPE`:设置之后将指定节点类型,可选值为 `master``slave`,未设置则默认为 `master` 6. `NODE_TYPE`:设置之后将指定节点类型,可选值为 `master``slave`,未设置则默认为 `master`
+ 例子:`NODE_TYPE=slave` + 例子:`NODE_TYPE=slave`
7. `CHANNEL_UPDATE_FREQUENCY`:设置之后将定期更新渠道余额,单位为分钟,未设置则不进行更新。
+ 例子:`CHANNEL_UPDATE_FREQUENCY=1440`
8. `CHANNEL_TEST_FREQUENCY`:设置之后将定期检查渠道,单位为分钟,未设置则不进行检查。
+ 例子:`CHANNEL_TEST_FREQUENCY=1440`
9. `REQUEST_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。
+ 例子:`POLLING_INTERVAL=5`
### 命令行参数 ### 命令行参数
1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000` 1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`

View File

@ -2,6 +2,7 @@ package common
import ( import (
"os" "os"
"strconv"
"sync" "sync"
"time" "time"
@ -70,6 +71,9 @@ var RootUserEmail = ""
var IsMasterNode = os.Getenv("NODE_TYPE") != "slave" var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
var requestInterval, _ = strconv.Atoi(os.Getenv("REQUEST_INTERVAL"))
var RequestInterval = time.Duration(requestInterval) * time.Second
const ( const (
RoleGuestUser = 0 RoleGuestUser = 0
RoleCommonUser = 1 RoleCommonUser = 1

View File

@ -257,6 +257,7 @@ func updateAllChannelsBalance() error {
disableChannel(channel.Id, channel.Name, "余额不足") disableChannel(channel.Id, channel.Name, "余额不足")
} }
} }
time.Sleep(common.RequestInterval)
} }
return nil return nil
} }
@ -277,3 +278,12 @@ func UpdateAllChannelsBalance(c *gin.Context) {
}) })
return return
} }
func AutomaticallyUpdateChannels(frequency int) {
for {
time.Sleep(time.Duration(frequency) * time.Minute)
common.SysLog("updating all channels")
_ = updateAllChannelsBalance()
common.SysLog("channels update done")
}
}

View File

@ -62,10 +62,9 @@ func testChannel(channel *model.Channel, request ChatRequest) error {
return nil return nil
} }
func buildTestRequest(c *gin.Context) *ChatRequest { func buildTestRequest() *ChatRequest {
model_ := c.Query("model")
testRequest := &ChatRequest{ testRequest := &ChatRequest{
Model: model_, Model: "", // this will be set later
MaxTokens: 1, MaxTokens: 1,
} }
testMessage := Message{ testMessage := Message{
@ -93,7 +92,7 @@ func TestChannel(c *gin.Context) {
}) })
return return
} }
testRequest := buildTestRequest(c) testRequest := buildTestRequest()
tik := time.Now() tik := time.Now()
err = testChannel(channel, *testRequest) err = testChannel(channel, *testRequest)
tok := time.Now() tok := time.Now()
@ -133,7 +132,7 @@ func disableChannel(channelId int, channelName string, reason string) {
} }
} }
func testAllChannels(c *gin.Context) error { func testAllChannels(notify bool) error {
if common.RootUserEmail == "" { if common.RootUserEmail == "" {
common.RootUserEmail = model.GetRootUserEmail() common.RootUserEmail = model.GetRootUserEmail()
} }
@ -146,13 +145,9 @@ func testAllChannels(c *gin.Context) error {
testAllChannelsLock.Unlock() testAllChannelsLock.Unlock()
channels, err := model.GetAllChannels(0, 0, true) channels, err := model.GetAllChannels(0, 0, true)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return err return err
} }
testRequest := buildTestRequest(c) testRequest := buildTestRequest()
var disableThreshold = int64(common.ChannelDisableThreshold * 1000) var disableThreshold = int64(common.ChannelDisableThreshold * 1000)
if disableThreshold == 0 { if disableThreshold == 0 {
disableThreshold = 10000000 // a impossible value disableThreshold = 10000000 // a impossible value
@ -173,20 +168,23 @@ func testAllChannels(c *gin.Context) error {
disableChannel(channel.Id, channel.Name, err.Error()) disableChannel(channel.Id, channel.Name, err.Error())
} }
channel.UpdateResponseTime(milliseconds) channel.UpdateResponseTime(milliseconds)
} time.Sleep(common.RequestInterval)
err := common.SendEmail("通道测试完成", common.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常")
if err != nil {
common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
} }
testAllChannelsLock.Lock() testAllChannelsLock.Lock()
testAllChannelsRunning = false testAllChannelsRunning = false
testAllChannelsLock.Unlock() testAllChannelsLock.Unlock()
if notify {
err := common.SendEmail("通道测试完成", common.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常")
if err != nil {
common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
}
}
}() }()
return nil return nil
} }
func TestAllChannels(c *gin.Context) { func TestAllChannels(c *gin.Context) {
err := testAllChannels(c) err := testAllChannels(true)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
@ -200,3 +198,12 @@ func TestAllChannels(c *gin.Context) {
}) })
return return
} }
func AutomaticallyTestChannels(frequency int) {
for {
time.Sleep(time.Duration(frequency) * time.Minute)
common.SysLog("testing all channels")
_ = testAllChannels(false)
common.SysLog("channel test finished")
}
}

15
main.go
View File

@ -7,6 +7,7 @@ import (
"github.com/gin-contrib/sessions/redis" "github.com/gin-contrib/sessions/redis"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"one-api/common" "one-api/common"
"one-api/controller"
"one-api/middleware" "one-api/middleware"
"one-api/model" "one-api/model"
"one-api/router" "one-api/router"
@ -59,6 +60,20 @@ func main() {
go model.SyncChannelCache(frequency) go model.SyncChannelCache(frequency)
} }
} }
if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" {
frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY"))
if err != nil {
common.FatalLog("failed to parse CHANNEL_UPDATE_FREQUENCY: " + err.Error())
}
go controller.AutomaticallyUpdateChannels(frequency)
}
if os.Getenv("CHANNEL_TEST_FREQUENCY") != "" {
frequency, err := strconv.Atoi(os.Getenv("CHANNEL_TEST_FREQUENCY"))
if err != nil {
common.FatalLog("failed to parse CHANNEL_TEST_FREQUENCY: " + err.Error())
}
go controller.AutomaticallyTestChannels(frequency)
}
// Initialize HTTP server // Initialize HTTP server
server := gin.Default() server := gin.Default()