feat: able to disable channel by success rate

This commit is contained in:
JustSong 2024-03-10 17:57:47 +08:00
parent 6ebc99460e
commit 12440874b0
9 changed files with 168 additions and 32 deletions

View File

@ -375,6 +375,9 @@ graph LR
16. `SQLITE_BUSY_TIMEOUT`SQLite 锁等待超时设置,单位为毫秒,默认 `3000`
17. `GEMINI_SAFETY_SETTING`Gemini 的安全设置,默认 `BLOCK_NONE`
18. `THEME`:系统的主题设置,默认为 `default`,具体可选值参考[此处](./web/README.md)。
19. `ENABLE_METRIC`:是否根据请求成功率禁用渠道,默认不开启,可选值为 `true``false`
20. `METRIC_QUEUE_SIZE`:请求成功率统计队列大小,默认为 `10`
21. `METRIC_SUCCESS_RATE_THRESHOLD`:请求成功率阈值,默认为 `0.8`
### 命令行参数
1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`

View File

@ -126,3 +126,9 @@ var (
)
var RateLimitKeyExpirationDuration = 20 * time.Minute
var EnableMetric = helper.GetOrDefaultEnvBool("ENABLE_METRIC", false)
var MetricQueueSize = helper.GetOrDefaultEnvInt("METRIC_QUEUE_SIZE", 10)
var MetricSuccessRateThreshold = helper.GetOrDefaultEnvFloat64("METRIC_SUCCESS_RATE_THRESHOLD", 0.8)
var MetricSuccessChanSize = helper.GetOrDefaultEnvInt("METRIC_SUCCESS_CHAN_SIZE", 1024)
var MetricFailChanSize = helper.GetOrDefaultEnvInt("METRIC_FAIL_CHAN_SIZE", 128)

View File

@ -195,6 +195,13 @@ func Max(a int, b int) int {
}
}
func GetOrDefaultEnvBool(env string, defaultValue bool) bool {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
return os.Getenv(env) == "true"
}
func GetOrDefaultEnvInt(env string, defaultValue int) int {
if env == "" || os.Getenv(env) == "" {
return defaultValue
@ -207,6 +214,18 @@ func GetOrDefaultEnvInt(env string, defaultValue int) int {
return num
}
func GetOrDefaultEnvFloat64(env string, defaultValue float64) float64 {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
num, err := strconv.ParseFloat(os.Getenv(env), 64)
if err != nil {
logger.SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %f", env, err.Error(), defaultValue))
return defaultValue
}
return num
}
func GetOrDefaultEnvString(env string, defaultValue string) string {
if env == "" || os.Getenv(env) == "" {
return defaultValue

View File

@ -8,6 +8,7 @@ import (
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/monitor"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
@ -313,7 +314,7 @@ func updateAllChannelsBalance() error {
} else {
// err is nil & balance <= 0 means quota is used up
if balance <= 0 {
disableChannel(channel.Id, channel.Name, "余额不足")
monitor.DisableChannel(channel.Id, channel.Name, "余额不足")
}
}
time.Sleep(config.RequestInterval)

View File

@ -10,6 +10,7 @@ import (
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/middleware"
"github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/monitor"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/helper"
relaymodel "github.com/songquanpeng/one-api/relay/model"
@ -148,32 +149,6 @@ func TestChannel(c *gin.Context) {
var testAllChannelsLock sync.Mutex
var testAllChannelsRunning bool = false
func notifyRootUser(subject string, content string) {
if config.RootUserEmail == "" {
config.RootUserEmail = model.GetRootUserEmail()
}
err := common.SendEmail(subject, config.RootUserEmail, content)
if err != nil {
logger.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
}
}
// disable & notify
func disableChannel(channelId int, channelName string, reason string) {
model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled)
subject := fmt.Sprintf("通道「%s」#%d已被禁用", channelName, channelId)
content := fmt.Sprintf("通道「%s」#%d已被禁用原因%s", channelName, channelId, reason)
notifyRootUser(subject, content)
}
// enable & notify
func enableChannel(channelId int, channelName string) {
model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled)
subject := fmt.Sprintf("通道「%s」#%d已被启用", channelName, channelId)
content := fmt.Sprintf("通道「%s」#%d已被启用", channelName, channelId)
notifyRootUser(subject, content)
}
func testAllChannels(notify bool) error {
if config.RootUserEmail == "" {
config.RootUserEmail = model.GetRootUserEmail()
@ -202,13 +177,13 @@ func testAllChannels(notify bool) error {
milliseconds := tok.Sub(tik).Milliseconds()
if isChannelEnabled && milliseconds > disableThreshold {
err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
disableChannel(channel.Id, channel.Name, err.Error())
monitor.DisableChannel(channel.Id, channel.Name, err.Error())
}
if isChannelEnabled && util.ShouldDisableChannel(openaiErr, -1) {
disableChannel(channel.Id, channel.Name, err.Error())
monitor.DisableChannel(channel.Id, channel.Name, err.Error())
}
if !isChannelEnabled && util.ShouldEnableChannel(err, openaiErr) {
enableChannel(channel.Id, channel.Name)
monitor.EnableChannel(channel.Id, channel.Name)
}
channel.UpdateResponseTime(milliseconds)
time.Sleep(config.RequestInterval)

View File

@ -11,6 +11,7 @@ import (
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/middleware"
dbmodel "github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/monitor"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/controller"
"github.com/songquanpeng/one-api/relay/model"
@ -45,11 +46,12 @@ func Relay(c *gin.Context) {
requestBody, _ := common.GetRequestBody(c)
logger.Debugf(ctx, "request body: %s", string(requestBody))
}
channelId := c.GetInt("channel_id")
bizErr := relay(c, relayMode)
if bizErr == nil {
monitor.Emit(channelId, true)
return
}
channelId := c.GetInt("channel_id")
lastFailedChannelId := channelId
channelName := c.GetString("channel_name")
group := c.GetString("group")
@ -117,7 +119,9 @@ func processChannelRelayError(ctx context.Context, channelId int, channelName st
logger.Errorf(ctx, "relay error (channel #%d): %s", channelId, err.Message)
// https://platform.openai.com/docs/guides/error-codes/api-errors
if util.ShouldDisableChannel(&err.Error, err.StatusCode) {
disableChannel(channelId, channelName, err.Message)
monitor.DisableChannel(channelId, channelName, err.Message)
} else {
monitor.Emit(channelId, false)
}
}

View File

@ -83,6 +83,9 @@ func main() {
logger.SysLog("batch update enabled with interval " + strconv.Itoa(config.BatchUpdateInterval) + "s")
model.InitBatchUpdater()
}
if config.EnableMetric {
logger.SysLog("metric enabled, will disable channel if too much request failed")
}
openai.InitTokenEncoders()
// Initialize HTTP server

46
monitor/channel.go Normal file
View File

@ -0,0 +1,46 @@
package monitor
import (
"fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
)
func notifyRootUser(subject string, content string) {
if config.RootUserEmail == "" {
config.RootUserEmail = model.GetRootUserEmail()
}
err := common.SendEmail(subject, config.RootUserEmail, content)
if err != nil {
logger.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
}
}
// DisableChannel disable & notify
func DisableChannel(channelId int, channelName string, reason string) {
model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled)
logger.SysLog(fmt.Sprintf("channel #%d has been disabled: %s", channelId, reason))
subject := fmt.Sprintf("通道「%s」#%d已被禁用", channelName, channelId)
content := fmt.Sprintf("通道「%s」#%d已被禁用原因%s", channelName, channelId, reason)
notifyRootUser(subject, content)
}
func MetricDisableChannel(channelId int, successRate float64) {
model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled)
logger.SysLog(fmt.Sprintf("channel #%d has been disabled due to low success rate: %.2f", channelId, successRate*100))
subject := fmt.Sprintf("通道 #%d 已被禁用", channelId)
content := fmt.Sprintf("该渠道在最近 %d 次调用中成功率为 %.2f%%,低于阈值 %.2f%%,因此被系统自动禁用。",
config.MetricQueueSize, successRate*100, config.MetricSuccessRateThreshold*100)
notifyRootUser(subject, content)
}
// EnableChannel enable & notify
func EnableChannel(channelId int, channelName string) {
model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled)
logger.SysLog(fmt.Sprintf("channel #%d has been enabled", channelId))
subject := fmt.Sprintf("通道「%s」#%d已被启用", channelName, channelId)
content := fmt.Sprintf("通道「%s」#%d已被启用", channelName, channelId)
notifyRootUser(subject, content)
}

79
monitor/metric.go Normal file
View File

@ -0,0 +1,79 @@
package monitor
import (
"github.com/songquanpeng/one-api/common/config"
)
var store = make(map[int][]bool)
var metricSuccessChan = make(chan int, config.MetricSuccessChanSize)
var metricFailChan = make(chan int, config.MetricFailChanSize)
func consumeSuccess(channelId int) {
if len(store[channelId]) > config.MetricQueueSize {
store[channelId] = store[channelId][1:]
}
store[channelId] = append(store[channelId], true)
}
func consumeFail(channelId int) (bool, float64) {
if len(store[channelId]) > config.MetricQueueSize {
store[channelId] = store[channelId][1:]
}
store[channelId] = append(store[channelId], false)
successCount := 0
for _, success := range store[channelId] {
if success {
successCount++
}
}
successRate := float64(successCount) / float64(len(store[channelId]))
if len(store[channelId]) < config.MetricQueueSize {
return false, successRate
}
if successRate < config.MetricSuccessRateThreshold {
store[channelId] = make([]bool, 0)
return true, successRate
}
return false, successRate
}
func metricSuccessConsumer() {
for {
select {
case channelId := <-metricSuccessChan:
consumeSuccess(channelId)
}
}
}
func metricFailConsumer() {
for {
select {
case channelId := <-metricFailChan:
disable, successRate := consumeFail(channelId)
if disable {
go MetricDisableChannel(channelId, successRate)
}
}
}
}
func init() {
if config.EnableMetric {
go metricSuccessConsumer()
go metricFailConsumer()
}
}
func Emit(channelId int, success bool) {
if !config.EnableMetric {
return
}
go func() {
if success {
metricSuccessChan <- channelId
} else {
metricFailChan <- channelId
}
}()
}