diff --git a/README.md b/README.md
index c1a651ca..897dc118 100644
--- a/README.md
+++ b/README.md
@@ -375,6 +375,9 @@ graph LR
 16. `SQLITE_BUSY_TIMEOUT`:SQLite 锁等待超时设置,单位为毫秒,默认 `3000`。
 17. `GEMINI_SAFETY_SETTING`:Gemini 的安全设置,默认 `BLOCK_NONE`。
 18. `THEME`:系统的主题设置,默认为 `default`,具体可选值参考[此处](./web/README.md)。
+19. `ENABLE_METRIC`:是否根据请求成功率禁用渠道,默认不开启,可选值为 `true` 和 `false`。
+20. `METRIC_QUEUE_SIZE`:请求成功率统计队列大小,默认为 `10`。
+21. `METRIC_SUCCESS_RATE_THRESHOLD`:请求成功率阈值,默认为 `0.8`。
 
 ### 命令行参数
 1. `--port `: 指定服务器监听的端口号,默认为 `3000`。
diff --git a/common/config/config.go b/common/config/config.go
index 47053e98..4d391aac 100644
--- a/common/config/config.go
+++ b/common/config/config.go
@@ -126,3 +126,9 @@ var (
 )
 
 var RateLimitKeyExpirationDuration = 20 * time.Minute
+
+var EnableMetric = helper.GetOrDefaultEnvBool("ENABLE_METRIC", false)
+var MetricQueueSize = helper.GetOrDefaultEnvInt("METRIC_QUEUE_SIZE", 10)
+var MetricSuccessRateThreshold = helper.GetOrDefaultEnvFloat64("METRIC_SUCCESS_RATE_THRESHOLD", 0.8)
+var MetricSuccessChanSize = helper.GetOrDefaultEnvInt("METRIC_SUCCESS_CHAN_SIZE", 1024)
+var MetricFailChanSize = helper.GetOrDefaultEnvInt("METRIC_FAIL_CHAN_SIZE", 128)
diff --git a/common/helper/helper.go b/common/helper/helper.go
index babe422b..23578842 100644
--- a/common/helper/helper.go
+++ b/common/helper/helper.go
@@ -195,6 +195,13 @@ func Max(a int, b int) int {
 	}
 }
 
+func GetOrDefaultEnvBool(env string, defaultValue bool) bool {
+	if env == "" || os.Getenv(env) == "" {
+		return defaultValue
+	}
+	return os.Getenv(env) == "true"
+}
+
 func GetOrDefaultEnvInt(env string, defaultValue int) int {
 	if env == "" || os.Getenv(env) == "" {
 		return defaultValue
@@ -207,6 +214,18 @@ func GetOrDefaultEnvInt(env string, defaultValue int) int {
 	return num
 }
 
+func GetOrDefaultEnvFloat64(env string, defaultValue float64) float64 {
+	if env == "" || os.Getenv(env) == "" {
+		return defaultValue
+	}
+	num, err := strconv.ParseFloat(os.Getenv(env), 64)
+	if err != nil {
+		logger.SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %f", env, err.Error(), defaultValue))
+		return defaultValue
+	}
+	return num
+}
+
 func GetOrDefaultEnvString(env string, defaultValue string) string {
 	if env == "" || os.Getenv(env) == "" {
 		return defaultValue
diff --git a/controller/channel-billing.go b/controller/channel-billing.go
index abeab26a..55eec54a 100644
--- a/controller/channel-billing.go
+++ b/controller/channel-billing.go
@@ -8,6 +8,7 @@ import (
 	"github.com/songquanpeng/one-api/common/config"
 	"github.com/songquanpeng/one-api/common/logger"
 	"github.com/songquanpeng/one-api/model"
+	"github.com/songquanpeng/one-api/monitor"
 	"github.com/songquanpeng/one-api/relay/util"
 	"io"
 	"net/http"
@@ -313,7 +314,7 @@ func updateAllChannelsBalance() error {
 		} else {
 			// err is nil & balance <= 0 means quota is used up
 			if balance <= 0 {
-				disableChannel(channel.Id, channel.Name, "余额不足")
+				monitor.DisableChannel(channel.Id, channel.Name, "余额不足")
 			}
 		}
 		time.Sleep(config.RequestInterval)
diff --git a/controller/channel-test.go b/controller/channel-test.go
index 7007e205..6fe18d6a 100644
--- a/controller/channel-test.go
+++ b/controller/channel-test.go
@@ -10,6 +10,7 @@ import (
 	"github.com/songquanpeng/one-api/common/logger"
 	"github.com/songquanpeng/one-api/middleware"
 	"github.com/songquanpeng/one-api/model"
+	"github.com/songquanpeng/one-api/monitor"
 	"github.com/songquanpeng/one-api/relay/constant"
 	"github.com/songquanpeng/one-api/relay/helper"
 	relaymodel "github.com/songquanpeng/one-api/relay/model"
@@ -148,32 +149,6 @@ func TestChannel(c *gin.Context) {
 var testAllChannelsLock sync.Mutex
 var testAllChannelsRunning bool = false
 
-func notifyRootUser(subject string, content string) {
-	if config.RootUserEmail == "" {
-		config.RootUserEmail = model.GetRootUserEmail()
-	}
-	err := common.SendEmail(subject, config.RootUserEmail, content)
-	if err != nil {
-		logger.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
-	}
-}
-
-// disable & notify
-func disableChannel(channelId int, channelName string, reason string) {
-	model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled)
-	subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId)
-	content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason)
-	notifyRootUser(subject, content)
-}
-
-// enable & notify
-func enableChannel(channelId int, channelName string) {
-	model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled)
-	subject := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId)
-	content := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId)
-	notifyRootUser(subject, content)
-}
-
 func testAllChannels(notify bool) error {
 	if config.RootUserEmail == "" {
 		config.RootUserEmail = model.GetRootUserEmail()
@@ -202,13 +177,13 @@ func testAllChannels(notify bool) error {
 			milliseconds := tok.Sub(tik).Milliseconds()
 			if isChannelEnabled && milliseconds > disableThreshold {
 				err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
-				disableChannel(channel.Id, channel.Name, err.Error())
+				monitor.DisableChannel(channel.Id, channel.Name, err.Error())
 			}
 			if isChannelEnabled && util.ShouldDisableChannel(openaiErr, -1) {
-				disableChannel(channel.Id, channel.Name, err.Error())
+				monitor.DisableChannel(channel.Id, channel.Name, err.Error())
 			}
 			if !isChannelEnabled && util.ShouldEnableChannel(err, openaiErr) {
-				enableChannel(channel.Id, channel.Name)
+				monitor.EnableChannel(channel.Id, channel.Name)
 			}
 			channel.UpdateResponseTime(milliseconds)
 			time.Sleep(config.RequestInterval)
diff --git a/controller/relay.go b/controller/relay.go
index 9b2d462c..b34768df 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -11,6 +11,7 @@ import (
 	"github.com/songquanpeng/one-api/common/logger"
 	"github.com/songquanpeng/one-api/middleware"
 	dbmodel "github.com/songquanpeng/one-api/model"
+	"github.com/songquanpeng/one-api/monitor"
 	"github.com/songquanpeng/one-api/relay/constant"
 	"github.com/songquanpeng/one-api/relay/controller"
 	"github.com/songquanpeng/one-api/relay/model"
@@ -45,11 +46,12 @@ func Relay(c *gin.Context) {
 		requestBody, _ := common.GetRequestBody(c)
 		logger.Debugf(ctx, "request body: %s", string(requestBody))
 	}
+	channelId := c.GetInt("channel_id")
 	bizErr := relay(c, relayMode)
 	if bizErr == nil {
+		monitor.Emit(channelId, true)
 		return
 	}
-	channelId := c.GetInt("channel_id")
 	lastFailedChannelId := channelId
 	channelName := c.GetString("channel_name")
 	group := c.GetString("group")
@@ -117,7 +119,9 @@ func processChannelRelayError(ctx context.Context, channelId int, channelName st
 	logger.Errorf(ctx, "relay error (channel #%d): %s", channelId, err.Message)
 	// https://platform.openai.com/docs/guides/error-codes/api-errors
 	if util.ShouldDisableChannel(&err.Error, err.StatusCode) {
-		disableChannel(channelId, channelName, err.Message)
+		monitor.DisableChannel(channelId, channelName, err.Message)
+	} else {
+		monitor.Emit(channelId, false)
 	}
 }
 
diff --git a/main.go b/main.go
index 1f43a45f..96603066 100644
--- a/main.go
+++ b/main.go
@@ -83,6 +83,9 @@ func main() {
 		logger.SysLog("batch update enabled with interval " + strconv.Itoa(config.BatchUpdateInterval) + "s")
 		model.InitBatchUpdater()
 	}
+	if config.EnableMetric {
+		logger.SysLog("metric enabled, will disable channel if too much request failed")
+	}
 	openai.InitTokenEncoders()
 
 	// Initialize HTTP server
diff --git a/monitor/channel.go b/monitor/channel.go
new file mode 100644
index 00000000..12394913
--- /dev/null
+++ b/monitor/channel.go
@@ -0,0 +1,46 @@
+package monitor
+
+import (
+	"fmt"
+	"github.com/songquanpeng/one-api/common"
+	"github.com/songquanpeng/one-api/common/config"
+	"github.com/songquanpeng/one-api/common/logger"
+	"github.com/songquanpeng/one-api/model"
+)
+
+func notifyRootUser(subject string, content string) {
+	if config.RootUserEmail == "" {
+		config.RootUserEmail = model.GetRootUserEmail()
+	}
+	err := common.SendEmail(subject, config.RootUserEmail, content)
+	if err != nil {
+		logger.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
+	}
+}
+
+// DisableChannel disable & notify
+func DisableChannel(channelId int, channelName string, reason string) {
+	model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled)
+	logger.SysLog(fmt.Sprintf("channel #%d has been disabled: %s", channelId, reason))
+	subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId)
+	content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason)
+	notifyRootUser(subject, content)
+}
+
+func MetricDisableChannel(channelId int, successRate float64) {
+	model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled)
+	logger.SysLog(fmt.Sprintf("channel #%d has been disabled due to low success rate: %.2f", channelId, successRate*100))
+	subject := fmt.Sprintf("通道 #%d 已被禁用", channelId)
+	content := fmt.Sprintf("该渠道在最近 %d 次调用中成功率为 %.2f%%,低于阈值 %.2f%%,因此被系统自动禁用。",
+		config.MetricQueueSize, successRate*100, config.MetricSuccessRateThreshold*100)
+	notifyRootUser(subject, content)
+}
+
+// EnableChannel enable & notify
+func EnableChannel(channelId int, channelName string) {
+	model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled)
+	logger.SysLog(fmt.Sprintf("channel #%d has been enabled", channelId))
+	subject := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId)
+	content := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId)
+	notifyRootUser(subject, content)
+}
diff --git a/monitor/metric.go b/monitor/metric.go
new file mode 100644
index 00000000..98bc546e
--- /dev/null
+++ b/monitor/metric.go
@@ -0,0 +1,109 @@
+package monitor
+
+import (
+	"sync"
+
+	"github.com/songquanpeng/one-api/common/config"
+)
+
+// storeMu guards store. The success and failure consumers run in separate
+// goroutines and both write to the map, which Go aborts with a fatal
+// "concurrent map writes" error unless the accesses are synchronized.
+var storeMu sync.Mutex
+
+// store holds, per channel id, the outcomes of the most recent requests
+// (true = success), oldest first, capped at config.MetricQueueSize entries.
+var store = make(map[int][]bool)
+var metricSuccessChan = make(chan int, config.MetricSuccessChanSize)
+var metricFailChan = make(chan int, config.MetricFailChanSize)
+
+// record appends one outcome to the channel's window and returns the updated
+// window. When the window is already full the oldest outcome is evicted
+// first, so the window holds at most config.MetricQueueSize entries (and
+// always at least the newest one). The caller must hold storeMu.
+func record(channelId int, success bool) []bool {
+	window := store[channelId]
+	if len(window) > 0 && len(window) >= config.MetricQueueSize {
+		window = window[1:]
+	}
+	window = append(window, success)
+	store[channelId] = window
+	return window
+}
+
+// consumeSuccess records a successful request for the channel.
+func consumeSuccess(channelId int) {
+	storeMu.Lock()
+	defer storeMu.Unlock()
+	record(channelId, true)
+}
+
+// consumeFail records a failed request for the channel and reports whether
+// the channel should be disabled, together with the success rate of the
+// current window. A channel is only disabled once the window is full and its
+// success rate is below config.MetricSuccessRateThreshold; the window is then
+// cleared so the statistics start over if the channel is enabled again.
+func consumeFail(channelId int) (bool, float64) {
+	storeMu.Lock()
+	defer storeMu.Unlock()
+	window := record(channelId, false)
+	successCount := 0
+	for _, success := range window {
+		if success {
+			successCount++
+		}
+	}
+	successRate := float64(successCount) / float64(len(window))
+	if len(window) < config.MetricQueueSize {
+		return false, successRate
+	}
+	if successRate < config.MetricSuccessRateThreshold {
+		delete(store, channelId)
+		return true, successRate
+	}
+	return false, successRate
+}
+
+// metricSuccessConsumer drains metricSuccessChan for the lifetime of the
+// process.
+func metricSuccessConsumer() {
+	for channelId := range metricSuccessChan {
+		consumeSuccess(channelId)
+	}
+}
+
+// metricFailConsumer drains metricFailChan for the lifetime of the process
+// and disables channels whose success rate fell below the threshold.
+func metricFailConsumer() {
+	for channelId := range metricFailChan {
+		disable, successRate := consumeFail(channelId)
+		if disable {
+			// MetricDisableChannel updates the channel status and emails the
+			// root user; run it separately so the consumer keeps draining.
+			go MetricDisableChannel(channelId, successRate)
+		}
+	}
+}
+
+func init() {
+	if config.EnableMetric {
+		go metricSuccessConsumer()
+		go metricFailConsumer()
+	}
+}
+
+// Emit reports the outcome of one request on the given channel. It is a no-op
+// unless config.EnableMetric is set. The send happens in its own goroutine so
+// the caller is never blocked when the buffer is full.
+func Emit(channelId int, success bool) {
+	if !config.EnableMetric {
+		return
+	}
+	go func() {
+		if success {
+			metricSuccessChan <- channelId
+		} else {
+			metricFailChan <- channelId
+		}
+	}()
+}