diff --git a/model/ability.go b/model/ability.go index eb68fa0d..8724e69f 100644 --- a/model/ability.go +++ b/model/ability.go @@ -10,7 +10,7 @@ type Ability struct { Model string `json:"model" gorm:"primaryKey;autoIncrement:false"` ChannelId int `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"` Enabled bool `json:"enabled"` - Priority int64 `json:"priority" gorm:"bigint;default:0"` + Priority *int64 `json:"priority" gorm:"bigint;default:0"` } func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) { diff --git a/model/cache.go b/model/cache.go index 1b547842..b9d6b612 100644 --- a/model/cache.go +++ b/model/cache.go @@ -165,7 +165,7 @@ func InitChannelCache() { for group, model2channels := range newGroup2model2channels { for model, channels := range model2channels { sort.Slice(channels, func(i, j int) bool { - return channels[i].Priority > channels[j].Priority + return channels[i].GetPriority() > channels[j].GetPriority() }) newGroup2model2channels[group][model] = channels } @@ -195,11 +195,17 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error if len(channels) == 0 { return nil, errors.New("channel not found") } + endIdx := len(channels) // choose by priority firstChannel := channels[0] - if firstChannel.Priority > 0 { - return firstChannel, nil + if firstChannel.GetPriority() > 0 { + for i := range channels { + if channels[i].GetPriority() != firstChannel.GetPriority() { + endIdx = i + break + } + } } - idx := rand.Intn(len(channels)) + idx := rand.Intn(endIdx) return channels[idx], nil } diff --git a/model/channel.go b/model/channel.go index d146193b..1a478b91 100644 --- a/model/channel.go +++ b/model/channel.go @@ -23,7 +23,7 @@ type Channel struct { Group string `json:"group" gorm:"type:varchar(32);default:'default'"` UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` ModelMapping string `json:"model_mapping" gorm:"type:varchar(1024);default:''"` - Priority int64 `json:"priority" gorm:"bigint;default:0"` + Priority *int64 `json:"priority" gorm:"bigint;default:0"` } func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) { @@ -79,6 +79,13 @@ func BatchInsertChannels(channels []Channel) error { return nil } +func (channel *Channel) GetPriority() int64 { + if channel == nil { + return 0 + } + return *channel.Priority +} + func (channel *Channel) Insert() error { var err error err = DB.Create(channel).Error