From 2bc2dfca3991cdeb57999895af8c683694aa03b4 Mon Sep 17 00:00:00 2001 From: MartialBE Date: Wed, 15 May 2024 17:51:49 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20feat:=20Support=20channel=20setting?= =?UTF-8?q?=20chat=20only?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- model/balancer.go | 41 ++++++++++++++++---- model/channel.go | 1 + relay/chat.go | 4 ++ relay/common.go | 11 +++++- web/src/views/Channel/component/EditModal.js | 16 ++++++++ web/src/views/Channel/type/Config.js | 9 +++-- 6 files changed, 70 insertions(+), 12 deletions(-) diff --git a/model/balancer.go b/model/balancer.go index 1b65550f..c71df3ce 100644 --- a/model/balancer.go +++ b/model/balancer.go @@ -21,6 +21,20 @@ type ChannelsChooser struct { Match []string } +type ChannelsFilterFunc func(channelId int, choice *ChannelChoice) bool + +func FilterChannelId(skipChannelId int) ChannelsFilterFunc { + return func(channelId int, choice *ChannelChoice) bool { + return skipChannelId > 0 && channelId == skipChannelId + } +} + +func FilterOnlyChat() ChannelsFilterFunc { + return func(channelId int, choice *ChannelChoice) bool { + return choice.Channel.OnlyChat + } +} + func (cc *ChannelsChooser) Cooldowns(channelId int) bool { if common.RetryCooldownSeconds == 0 { return false @@ -35,20 +49,31 @@ func (cc *ChannelsChooser) Cooldowns(channelId int) bool { return true } -func (cc *ChannelsChooser) balancer(channelIds []int, skipChannelId int) *Channel { +func (cc *ChannelsChooser) balancer(channelIds []int, filters []ChannelsFilterFunc) *Channel { nowTime := time.Now().Unix() totalWeight := 0 validChannels := make([]*ChannelChoice, 0, len(channelIds)) for _, channelId := range channelIds { - if skipChannelId > 0 && channelId == skipChannelId { + choice, ok := cc.Channels[channelId] + if !ok || choice.CooldownsTime >= nowTime { continue } - if choice, ok := cc.Channels[channelId]; ok && choice.CooldownsTime < nowTime { - weight := int(*choice.Channel.Weight) - totalWeight += weight - validChannels = append(validChannels, choice) + + isSkip := false + for _, filter := range filters { + if filter(channelId, choice) { + isSkip = true + break + } } + if isSkip { + continue + } + + weight := int(*choice.Channel.Weight) + totalWeight += weight + validChannels = append(validChannels, choice) } if len(validChannels) == 0 { @@ -71,7 +96,7 @@ func (cc *ChannelsChooser) balancer(channelIds []int, skipChannelId int) *Channe return nil } -func (cc *ChannelsChooser) Next(group, modelName string, skipChannelId int) (*Channel, error) { +func (cc *ChannelsChooser) Next(group, modelName string, filters ...ChannelsFilterFunc) (*Channel, error) { cc.RLock() defer cc.RUnlock() if _, ok := cc.Rule[group]; !ok { @@ -92,7 +117,7 @@ func (cc *ChannelsChooser) Next(group, modelName string, skipChannelId int) (*Ch } for _, priority := range channelsPriority { - channel := cc.balancer(priority, skipChannelId) + channel := cc.balancer(priority, filters) if channel != nil { return channel, nil } diff --git a/model/channel.go b/model/channel.go index c7cd9a75..7f6eb615 100644 --- a/model/channel.go +++ b/model/channel.go @@ -29,6 +29,7 @@ type Channel struct { Priority *int64 `json:"priority" gorm:"bigint;default:0"` Proxy *string `json:"proxy" gorm:"type:varchar(255);default:''"` TestModel string `json:"test_model" form:"test_model" gorm:"type:varchar(50);default:''"` + OnlyChat bool `json:"only_chat" form:"only_chat" gorm:"default:false"` Plugin *datatypes.JSONType[PluginType] `json:"plugin" form:"plugin" gorm:"type:json"` } diff --git a/relay/chat.go b/relay/chat.go index 0e28be26..3e3de6b4 100644 --- a/relay/chat.go +++ b/relay/chat.go @@ -32,6 +32,10 @@ func (r *relayChat) setRequest() error { return errors.New("max_tokens is invalid") } + if r.chatRequest.Tools != nil { + r.c.Set("skip_only_chat", true) + } + r.originalModel = r.chatRequest.Model return nil diff --git a/relay/common.go b/relay/common.go index 56f46181..a230e098 100644 --- a/relay/common.go +++ b/relay/common.go @@ -101,7 +101,16 @@ func fetchChannelById(channelId int) (*model.Channel, error) { func fetchChannelByModel(c *gin.Context, modelName string) (*model.Channel, error) { group := c.GetString("group") skip_channel_id := c.GetInt("skip_channel_id") - channel, err := model.ChannelGroup.Next(group, modelName, skip_channel_id) + skip_only_chat := c.GetBool("skip_only_chat") + var filters []model.ChannelsFilterFunc + if skip_only_chat { + filters = append(filters, model.FilterOnlyChat()) + } + if skip_channel_id > 0 { + filters = append(filters, model.FilterChannelId(skip_channel_id)) + } + + channel, err := model.ChannelGroup.Next(group, modelName, filters...) if err != nil { message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", group, modelName) if channel != nil { diff --git a/web/src/views/Channel/component/EditModal.js b/web/src/views/Channel/component/EditModal.js index 069673b6..99730822 100644 --- a/web/src/views/Channel/component/EditModal.js +++ b/web/src/views/Channel/component/EditModal.js @@ -629,6 +629,22 @@ const EditModal = ({ open, channelId, onCancel, onOk, groupOptions }) => { )} )} + {inputPrompt.only_chat && ( + + { + setFieldValue('only_chat', !values.only_chat); + }} + /> + } + label={inputLabel.only_chat} + /> + {inputPrompt.only_chat} + + )} {pluginList[values.type] && Object.keys(pluginList[values.type]).map((pluginId) => { const plugin = pluginList[values.type][pluginId]; diff --git a/web/src/views/Channel/type/Config.js b/web/src/views/Channel/type/Config.js index 74ca393d..09d5537f 100644 --- a/web/src/views/Channel/type/Config.js +++ b/web/src/views/Channel/type/Config.js @@ -10,7 +10,8 @@ const defaultConfig = { model_mapping: '', models: [], groups: ['default'], - plugin: {} + plugin: {}, + only_chat: false }, inputLabel: { name: '渠道名称', @@ -22,7 +23,8 @@ const defaultConfig = { test_model: '测速模型', models: '模型', model_mapping: '模型映射关系', - groups: '用户组' + groups: '用户组', + only_chat: '仅支持聊天' }, prompt: { type: '请选择渠道类型', @@ -36,7 +38,8 @@ const defaultConfig = { '请选择该渠道所支持的模型,你也可以输入通配符*来匹配模型,例如:gpt-3.5*,表示支持所有gpt-3.5开头的模型,*号只能在最后一位使用,前面必须有字符,例如:gpt-3.5*是正确的,*gpt-3.5是错误的', model_mapping: '请输入要修改的模型映射关系,格式为:api请求模型ID:实际转发给渠道的模型ID,使用JSON数组表示,例如:{"gpt-3.5": "gpt-35"}', - groups: '请选择该渠道所支持的用户组' + groups: '请选择该渠道所支持的用户组', + only_chat: '如果选择了仅支持聊天,那么遇到有函数调用的请求会跳过该渠道' }, modelGroup: 'OpenAI' };