diff --git a/README.md b/README.md index d06df9ed..23d18786 100644 --- a/README.md +++ b/README.md @@ -250,6 +250,12 @@ graph LR + 例子:`SYNC_FREQUENCY=60` 6. `NODE_TYPE`:设置之后将指定节点类型,可选值为 `master` 和 `slave`,未设置则默认为 `master`。 + 例子:`NODE_TYPE=slave` +7. `CHANNEL_UPDATE_FREQUENCY`:设置之后将定期更新渠道余额,单位为分钟,未设置则不进行更新。 + + 例子:`CHANNEL_UPDATE_FREQUENCY=1440` +8. `CHANNEL_TEST_FREQUENCY`:设置之后将定期检查渠道,单位为分钟,未设置则不进行检查。 + + 例子:`CHANNEL_TEST_FREQUENCY=1440` +9. `REQUEST_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。 + + 例子:`POLLING_INTERVAL=5` ### 命令行参数 1. `--port `: 指定服务器监听的端口号,默认为 `3000`。 diff --git a/common/constants.go b/common/constants.go index 373ad88a..4439e2a5 100644 --- a/common/constants.go +++ b/common/constants.go @@ -2,6 +2,7 @@ package common import ( "os" + "strconv" "sync" "time" @@ -70,6 +71,9 @@ var RootUserEmail = "" var IsMasterNode = os.Getenv("NODE_TYPE") != "slave" +var requestInterval, _ = strconv.Atoi(os.Getenv("REQUEST_INTERVAL")) +var RequestInterval = time.Duration(requestInterval) * time.Second + const ( RoleGuestUser = 0 RoleCommonUser = 1 diff --git a/controller/channel-billing.go b/controller/channel-billing.go index 4f89c6ed..fbf57508 100644 --- a/controller/channel-billing.go +++ b/controller/channel-billing.go @@ -257,6 +257,7 @@ func updateAllChannelsBalance() error { disableChannel(channel.Id, channel.Name, "余额不足") } } + time.Sleep(common.RequestInterval) } return nil } @@ -277,3 +278,12 @@ func UpdateAllChannelsBalance(c *gin.Context) { }) return } + +func AutomaticallyUpdateChannels(frequency int) { + for { + time.Sleep(time.Duration(frequency) * time.Minute) + common.SysLog("updating all channels") + _ = updateAllChannelsBalance() + common.SysLog("channels update done") + } +} diff --git a/controller/channel-test.go b/controller/channel-test.go index b74ab095..f2ffad01 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -62,10 +62,9 @@ func testChannel(channel *model.Channel, request ChatRequest) error { return nil } -func buildTestRequest(c *gin.Context) *ChatRequest { - model_ := c.Query("model") +func buildTestRequest() *ChatRequest { testRequest := &ChatRequest{ - Model: model_, + Model: "", // this will be set later MaxTokens: 1, } testMessage := Message{ @@ -93,7 +92,7 @@ func TestChannel(c *gin.Context) { }) return } - testRequest := buildTestRequest(c) + testRequest := buildTestRequest() tik := time.Now() err = testChannel(channel, *testRequest) tok := time.Now() @@ -133,7 +132,7 @@ func disableChannel(channelId int, channelName string, reason string) { } } -func testAllChannels(c *gin.Context) error { +func testAllChannels(notify bool) error { if common.RootUserEmail == "" { common.RootUserEmail = model.GetRootUserEmail() } @@ -146,13 +145,9 @@ func testAllChannels(c *gin.Context) error { testAllChannelsLock.Unlock() channels, err := model.GetAllChannels(0, 0, true) if err != nil { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": err.Error(), - }) return err } - testRequest := buildTestRequest(c) + testRequest := buildTestRequest() var disableThreshold = int64(common.ChannelDisableThreshold * 1000) if disableThreshold == 0 { disableThreshold = 10000000 // a impossible value @@ -173,20 +168,23 @@ func testAllChannels(c *gin.Context) error { disableChannel(channel.Id, channel.Name, err.Error()) } channel.UpdateResponseTime(milliseconds) - } - err := common.SendEmail("通道测试完成", common.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常") - if err != nil { - common.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) + time.Sleep(common.RequestInterval) } testAllChannelsLock.Lock() testAllChannelsRunning = false testAllChannelsLock.Unlock() + if notify { + err := common.SendEmail("通道测试完成", common.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常") + if err != nil { + common.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) + } + } }() return nil } func TestAllChannels(c *gin.Context) { - err := testAllChannels(c) + err := testAllChannels(true) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -200,3 +198,12 @@ func TestAllChannels(c *gin.Context) { }) return } + +func AutomaticallyTestChannels(frequency int) { + for { + time.Sleep(time.Duration(frequency) * time.Minute) + common.SysLog("testing all channels") + _ = testAllChannels(false) + common.SysLog("channel test finished") + } +} diff --git a/main.go b/main.go index a6430a12..b9fa95d7 100644 --- a/main.go +++ b/main.go @@ -7,6 +7,7 @@ import ( "github.com/gin-contrib/sessions/redis" "github.com/gin-gonic/gin" "one-api/common" + "one-api/controller" "one-api/middleware" "one-api/model" "one-api/router" @@ -59,6 +60,20 @@ func main() { go model.SyncChannelCache(frequency) } } + if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" { + frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY")) + if err != nil { + common.FatalLog("failed to parse CHANNEL_UPDATE_FREQUENCY: " + err.Error()) + } + go controller.AutomaticallyUpdateChannels(frequency) + } + if os.Getenv("CHANNEL_TEST_FREQUENCY") != "" { + frequency, err := strconv.Atoi(os.Getenv("CHANNEL_TEST_FREQUENCY")) + if err != nil { + common.FatalLog("failed to parse CHANNEL_TEST_FREQUENCY: " + err.Error()) + } + go controller.AutomaticallyTestChannels(frequency) + } // Initialize HTTP server server := gin.Default()