diff --git a/middleware/auth.go b/middleware/auth.go index 4964a4dd..1eb0c2b4 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -108,16 +108,20 @@ func tokenAuth(c *gin.Context, key string) { c.Set("chat_cache", token.ChatCache) if len(parts) > 1 { if model.IsAdmin(token.UserId) { - channelId := common.String2Int(parts[1]) - if channelId == 0 { - abortWithMessage(c, http.StatusForbidden, "无效的渠道 Id") - return + if strings.HasPrefix(parts[1], "!") { + channelId := common.String2Int(parts[1][1:]) + c.Set("skip_channel_id", channelId) + } else { + channelId := common.String2Int(parts[1]) + if channelId == 0 { + abortWithMessage(c, http.StatusForbidden, "无效的渠道 Id") + return + } + c.Set("specific_channel_id", channelId) + if len(parts) == 3 && parts[2] == "ignore" { + c.Set("specific_channel_id_ignore", true) + } } - c.Set("specific_channel_id", channelId) - if len(parts) == 3 && parts[2] == "ignore" { - c.Set("specific_channel_id_ignore", true) - } - } else { abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道") return diff --git a/model/balancer.go b/model/balancer.go index ecf2721d..1b65550f 100644 --- a/model/balancer.go +++ b/model/balancer.go @@ -35,12 +35,15 @@ func (cc *ChannelsChooser) Cooldowns(channelId int) bool { return true } -func (cc *ChannelsChooser) balancer(channelIds []int) *Channel { +func (cc *ChannelsChooser) balancer(channelIds []int, skipChannelId int) *Channel { nowTime := time.Now().Unix() totalWeight := 0 validChannels := make([]*ChannelChoice, 0, len(channelIds)) for _, channelId := range channelIds { + if skipChannelId > 0 && channelId == skipChannelId { + continue + } if choice, ok := cc.Channels[channelId]; ok && choice.CooldownsTime < nowTime { weight := int(*choice.Channel.Weight) totalWeight += weight @@ -68,7 +71,7 @@ func (cc *ChannelsChooser) balancer(channelIds []int) *Channel { return nil } -func (cc *ChannelsChooser) Next(group, modelName string) (*Channel, error) { +func (cc *ChannelsChooser) Next(group, modelName string, skipChannelId int) (*Channel, error) { cc.RLock() defer cc.RUnlock() if _, ok := cc.Rule[group]; !ok { @@ -89,7 +92,7 @@ func (cc *ChannelsChooser) Next(group, modelName string) (*Channel, error) { } for _, priority := range channelsPriority { - channel := cc.balancer(priority) + channel := cc.balancer(priority, skipChannelId) if channel != nil { return channel, nil } diff --git a/relay/common.go b/relay/common.go index b6872b51..56f46181 100644 --- a/relay/common.go +++ b/relay/common.go @@ -100,7 +100,8 @@ func fetchChannelById(channelId int) (*model.Channel, error) { func fetchChannelByModel(c *gin.Context, modelName string) (*model.Channel, error) { group := c.GetString("group") - channel, err := model.ChannelGroup.Next(group, modelName) + skip_channel_id := c.GetInt("skip_channel_id") + channel, err := model.ChannelGroup.Next(group, modelName, skip_channel_id) if err != nil { message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", group, modelName) if channel != nil {