From a5647b1ea7c859100c86e517cdfcac660d44dab2 Mon Sep 17 00:00:00 2001 From: JustSong Date: Mon, 18 Sep 2023 21:43:45 +0800 Subject: [PATCH 1/4] fix: fix priority not updated & random choice not working --- model/ability.go | 2 +- model/cache.go | 14 ++++++++++---- model/channel.go | 9 ++++++++- 3 files changed, 19 insertions(+), 6 deletions(-) diff --git a/model/ability.go b/model/ability.go index eb68fa0d..8724e69f 100644 --- a/model/ability.go +++ b/model/ability.go @@ -10,7 +10,7 @@ type Ability struct { Model string `json:"model" gorm:"primaryKey;autoIncrement:false"` ChannelId int `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"` Enabled bool `json:"enabled"` - Priority int64 `json:"priority" gorm:"bigint;default:0"` + Priority *int64 `json:"priority" gorm:"bigint;default:0"` } func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) { diff --git a/model/cache.go b/model/cache.go index 1b547842..b9d6b612 100644 --- a/model/cache.go +++ b/model/cache.go @@ -165,7 +165,7 @@ func InitChannelCache() { for group, model2channels := range newGroup2model2channels { for model, channels := range model2channels { sort.Slice(channels, func(i, j int) bool { - return channels[i].Priority > channels[j].Priority + return channels[i].GetPriority() > channels[j].GetPriority() }) newGroup2model2channels[group][model] = channels } @@ -195,11 +195,17 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error if len(channels) == 0 { return nil, errors.New("channel not found") } + endIdx := len(channels) // choose by priority firstChannel := channels[0] - if firstChannel.Priority > 0 { - return firstChannel, nil + if firstChannel.GetPriority() > 0 { + for i := range channels { + if channels[i].GetPriority() != firstChannel.GetPriority() { + endIdx = i + break + } + } } - idx := rand.Intn(len(channels)) + idx := rand.Intn(endIdx) return channels[idx], nil } diff --git a/model/channel.go b/model/channel.go index d146193b..1a478b91 100644 --- a/model/channel.go +++ b/model/channel.go @@ -23,7 +23,7 @@ type Channel struct { Group string `json:"group" gorm:"type:varchar(32);default:'default'"` UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` ModelMapping string `json:"model_mapping" gorm:"type:varchar(1024);default:''"` - Priority int64 `json:"priority" gorm:"bigint;default:0"` + Priority *int64 `json:"priority" gorm:"bigint;default:0"` } func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) { @@ -79,6 +79,13 @@ func BatchInsertChannels(channels []Channel) error { return nil } +func (channel *Channel) GetPriority() int64 { + if channel == nil { + return 0 + } + return *channel.Priority +} + func (channel *Channel) Insert() error { var err error err = DB.Create(channel).Error From 159b9e3369fdca0fb90cafdc96317c0f299d26bb Mon Sep 17 00:00:00 2001 From: JustSong Date: Mon, 18 Sep 2023 22:07:17 +0800 Subject: [PATCH 2/4] fix: fix unable to set zero value for base url & model mapping --- controller/channel-billing.go | 12 ++++++------ controller/channel-test.go | 6 +++--- middleware/distributor.go | 4 ++-- model/channel.go | 20 +++++++++++++++++--- web/src/pages/Channel/EditChannel.js | 3 --- 5 files changed, 28 insertions(+), 17 deletions(-) diff --git a/controller/channel-billing.go b/controller/channel-billing.go index 46262f6c..6ddad7ea 100644 --- a/controller/channel-billing.go +++ b/controller/channel-billing.go @@ -111,7 +111,7 @@ func GetResponseBody(method, url string, channel *model.Channel, headers http.He } func updateChannelCloseAIBalance(channel *model.Channel) (float64, error) { - url := fmt.Sprintf("%s/dashboard/billing/credit_grants", channel.BaseURL) + url := fmt.Sprintf("%s/dashboard/billing/credit_grants", channel.GetBaseURL()) body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) if err != nil { @@ -201,18 +201,18 @@ func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) { func updateChannelBalance(channel *model.Channel) (float64, error) { baseURL := common.ChannelBaseURLs[channel.Type] - if channel.BaseURL == "" { - channel.BaseURL = baseURL + if channel.GetBaseURL() == "" { + channel.BaseURL = &baseURL } switch channel.Type { case common.ChannelTypeOpenAI: - if channel.BaseURL != "" { - baseURL = channel.BaseURL + if channel.GetBaseURL() != "" { + baseURL = channel.GetBaseURL() } case common.ChannelTypeAzure: return 0, errors.New("尚未实现") case common.ChannelTypeCustom: - baseURL = channel.BaseURL + baseURL = channel.GetBaseURL() case common.ChannelTypeCloseAI: return updateChannelCloseAIBalance(channel) case common.ChannelTypeOpenAISB: diff --git a/controller/channel-test.go b/controller/channel-test.go index 8c7e6f0d..f7a565a2 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -42,10 +42,10 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai } requestURL := common.ChannelBaseURLs[channel.Type] if channel.Type == common.ChannelTypeAzure { - requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.BaseURL, request.Model) + requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.GetBaseURL(), request.Model) } else { - if channel.BaseURL != "" { - requestURL = channel.BaseURL + if channel.GetBaseURL() != "" { + requestURL = channel.GetBaseURL() } requestURL += "/v1/chat/completions" } diff --git a/middleware/distributor.go b/middleware/distributor.go index ab374a85..9ded3231 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -82,9 +82,9 @@ func Distribute() func(c *gin.Context) { c.Set("channel", channel.Type) c.Set("channel_id", channel.Id) c.Set("channel_name", channel.Name) - c.Set("model_mapping", channel.ModelMapping) + c.Set("model_mapping", channel.GetModelMapping()) c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) - c.Set("base_url", channel.BaseURL) + c.Set("base_url", channel.GetBaseURL()) switch channel.Type { case common.ChannelTypeAzure: c.Set("api_version", channel.Other) diff --git a/model/channel.go b/model/channel.go index 1a478b91..8a5b79ff 100644 --- a/model/channel.go +++ b/model/channel.go @@ -15,14 +15,14 @@ type Channel struct { CreatedTime int64 `json:"created_time" gorm:"bigint"` TestTime int64 `json:"test_time" gorm:"bigint"` ResponseTime int `json:"response_time"` // in milliseconds - BaseURL string `json:"base_url" gorm:"column:base_url"` + BaseURL *string `json:"base_url" gorm:"column:base_url"` Other string `json:"other"` Balance float64 `json:"balance"` // in USD BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"` Models string `json:"models"` Group string `json:"group" gorm:"type:varchar(32);default:'default'"` UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` - ModelMapping string `json:"model_mapping" gorm:"type:varchar(1024);default:''"` + ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"` Priority *int64 `json:"priority" gorm:"bigint;default:0"` } @@ -80,12 +80,26 @@ func BatchInsertChannels(channels []Channel) error { } func (channel *Channel) GetPriority() int64 { - if channel == nil { + if channel.Priority == nil { return 0 } return *channel.Priority } +func (channel *Channel) GetBaseURL() string { + if channel.BaseURL == nil { + return "" + } + return *channel.BaseURL +} + +func (channel *Channel) GetModelMapping() string { + if channel.ModelMapping == nil { + return "" + } + return *channel.ModelMapping +} + func (channel *Channel) Insert() error { var err error err = DB.Create(channel).Error diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index 78ff1952..e0053709 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -183,9 +183,6 @@ const EditChannel = () => { if (localInputs.type === 18 && localInputs.other === '') { localInputs.other = 'v2.1'; } - if (localInputs.model_mapping === '') { - localInputs.model_mapping = '{}'; - } let res; localInputs.models = localInputs.models.join(','); localInputs.group = localInputs.groups.join(','); From 37e09d764cb7ec93fc96ba8742fb7644f27be917 Mon Sep 17 00:00:00 2001 From: JustSong Date: Mon, 18 Sep 2023 22:39:10 +0800 Subject: [PATCH 3/4] fix: fix random selection is not working when directly using database --- model/ability.go | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/model/ability.go b/model/ability.go index 8724e69f..919de227 100644 --- a/model/ability.go +++ b/model/ability.go @@ -10,16 +10,18 @@ type Ability struct { Model string `json:"model" gorm:"primaryKey;autoIncrement:false"` ChannelId int `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"` Enabled bool `json:"enabled"` - Priority *int64 `json:"priority" gorm:"bigint;default:0"` + Priority *int64 `json:"priority" gorm:"bigint;default:0;index"` } func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) { ability := Ability{} var err error = nil if common.UsingSQLite { - err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("CASE WHEN priority <> 0 THEN priority ELSE RANDOM() END DESC ").Limit(1).First(&ability).Error + maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where("`group` = ? and model = ? and enabled = 1", group, model) + err = DB.Where("`group` = ? and model = ? and enabled = 1 and priority = (?)", group, model, maxPrioritySubQuery).Order("RANDOM()").Limit(1).First(&ability).Error } else { - err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("CASE WHEN priority <> 0 THEN priority ELSE RAND() END DESC").Limit(1).First(&ability).Error + maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where("group = ? and model = ? and enabled = 1", group, model) + err = DB.Where("`group` = ? and model = ? and enabled = 1 and priority = (?)", group, model, maxPrioritySubQuery).Order("RAND()").Limit(1).First(&ability).Error } if err != nil { return nil, err From 1d258cc898d87366dce0ac746b3f53f70eeccf35 Mon Sep 17 00:00:00 2001 From: JustSong Date: Mon, 18 Sep 2023 22:49:05 +0800 Subject: [PATCH 4/4] fix: add default value for base url --- model/channel.go | 2 +- web/src/pages/Channel/EditChannel.js | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/model/channel.go b/model/channel.go index 8a5b79ff..aa3b8a10 100644 --- a/model/channel.go +++ b/model/channel.go @@ -15,7 +15,7 @@ type Channel struct { CreatedTime int64 `json:"created_time" gorm:"bigint"` TestTime int64 `json:"test_time" gorm:"bigint"` ResponseTime int `json:"response_time"` // in milliseconds - BaseURL *string `json:"base_url" gorm:"column:base_url"` + BaseURL *string `json:"base_url" gorm:"column:base_url;default:''"` Other string `json:"other"` Balance float64 `json:"balance"` // in USD BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"` diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index e0053709..4c8dd0c4 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -174,7 +174,7 @@ const EditChannel = () => { return; } let localInputs = inputs; - if (localInputs.base_url.endsWith('/')) { + if (localInputs.base_url && localInputs.base_url.endsWith('/')) { localInputs.base_url = localInputs.base_url.slice(0, localInputs.base_url.length - 1); } if (localInputs.type === 3 && localInputs.other === '') {