diff --git a/common/random.go b/common/random.go new file mode 100644 index 00000000..44bd2856 --- /dev/null +++ b/common/random.go @@ -0,0 +1,8 @@ +package common + +import "math/rand" + +// RandRange returns a random number between min and max (max is not included) +func RandRange(min, max int) int { + return min + rand.Intn(max-min) +} diff --git a/controller/relay.go b/controller/relay.go index 278c0b32..33a8243d 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -62,7 +62,7 @@ func Relay(c *gin.Context) { retryTimes = 0 } for i := retryTimes; i > 0; i-- { - channel, err := dbmodel.CacheGetRandomSatisfiedChannel(group, originalModel) + channel, err := dbmodel.CacheGetRandomSatisfiedChannel(group, originalModel, true) if err != nil { logger.Errorf(ctx, "CacheGetRandomSatisfiedChannel failed: %w", err) break diff --git a/middleware/distributor.go b/middleware/distributor.go index aeb2796a..e845c2f8 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -68,7 +68,7 @@ func Distribute() func(c *gin.Context) { } } requestModel = modelRequest.Model - channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model) + channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, false) if err != nil { message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model) if channel != nil { diff --git a/model/cache.go b/model/cache.go index 04a60348..3c3575b8 100644 --- a/model/cache.go +++ b/model/cache.go @@ -191,7 +191,7 @@ func SyncChannelCache(frequency int) { } } -func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) { +func CacheGetRandomSatisfiedChannel(group string, model string, ignoreFirstPriority bool) (*Channel, error) { if !config.MemoryCacheEnabled { return GetRandomSatisfiedChannel(group, model) } @@ -213,5 +213,10 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error } } idx := rand.Intn(endIdx) + if ignoreFirstPriority { + if endIdx < len(channels) { // which means there are more than one priority + idx = common.RandRange(endIdx, len(channels)) + } + } return channels[idx], nil }