From 13b3bfee2a489a96738ca2329a41bd7aeaecec76 Mon Sep 17 00:00:00 2001 From: ckt1031 <65409152+ckt1031@users.noreply.github.com> Date: Mon, 24 Jul 2023 22:05:21 +0800 Subject: [PATCH] fix: channel issue --- middleware/distributor.go | 2 ++ model/ability.go | 20 ++++++++++++-------- model/channel.go | 4 ++-- 3 files changed, 16 insertions(+), 10 deletions(-) diff --git a/middleware/distributor.go b/middleware/distributor.go index 1940c69c..bf77ccc7 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -2,6 +2,7 @@ package middleware import ( "fmt" + "log" "net/http" "one-api/common" "one-api/model" @@ -85,6 +86,7 @@ func Distribute() func(c *gin.Context) { modelRequest.Model = "dall-e" } } + log.Print(modelRequest.Stream) channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, modelRequest.Stream) if err != nil { message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model) diff --git a/model/ability.go b/model/ability.go index 465acb88..e3c4c444 100644 --- a/model/ability.go +++ b/model/ability.go @@ -7,10 +7,12 @@ import ( ) type Ability struct { - Group string `json:"group" gorm:"type:varchar(32);primaryKey;autoIncrement:false"` - Model string `json:"model" gorm:"primaryKey;autoIncrement:false"` - ChannelId int `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"` - Enabled bool `json:"enabled"` + Group string `json:"group" gorm:"type:varchar(32);primaryKey;autoIncrement:false"` + Model string `json:"model" gorm:"primaryKey;autoIncrement:false"` + ChannelId int `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"` + Enabled bool `json:"enabled"` + AllowStreaming int `json:"allow_streaming" gorm:"default:1"` + AllowNonStreaming int `json:"allow_non_streaming" gorm:"default:1"` } func GetRandomSatisfiedChannel(group string, model string, stream bool) (*Channel, error) { @@ -46,10 +48,12 @@ func (channel *Channel) AddAbilities() error { for _, model := range models_ { for _, group := range groups_ { ability := Ability{ - Group: group, - Model: model, - ChannelId: channel.Id, - Enabled: channel.Status == common.ChannelStatusEnabled, + Group: group, + Model: model, + ChannelId: channel.Id, + Enabled: channel.Status == common.ChannelStatusEnabled, + AllowStreaming: channel.AllowStreaming, + AllowNonStreaming: channel.AllowNonStreaming, } abilities = append(abilities, ability) } diff --git a/model/channel.go b/model/channel.go index 738e6f7c..cb03ad11 100644 --- a/model/channel.go +++ b/model/channel.go @@ -24,8 +24,8 @@ type Channel struct { Group string `json:"group" gorm:"type:varchar(32);default:'default'"` UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` ModelMapping string `json:"model_mapping" gorm:"type:varchar(1024);default:''"` - AllowStreaming int `json:"allow_streaming" gorm:"default:2"` - AllowNonStreaming int `json:"allow_non_streaming" gorm:"default:2"` + AllowStreaming int `json:"allow_streaming" gorm:"default:1"` + AllowNonStreaming int `json:"allow_non_streaming" gorm:"default:1"` } func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {