diff --git a/controller/channel-billing.go b/controller/channel-billing.go index 46262f6c..6ddad7ea 100644 --- a/controller/channel-billing.go +++ b/controller/channel-billing.go @@ -111,7 +111,7 @@ func GetResponseBody(method, url string, channel *model.Channel, headers http.He } func updateChannelCloseAIBalance(channel *model.Channel) (float64, error) { - url := fmt.Sprintf("%s/dashboard/billing/credit_grants", channel.BaseURL) + url := fmt.Sprintf("%s/dashboard/billing/credit_grants", channel.GetBaseURL()) body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) if err != nil { @@ -201,18 +201,18 @@ func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) { func updateChannelBalance(channel *model.Channel) (float64, error) { baseURL := common.ChannelBaseURLs[channel.Type] - if channel.BaseURL == "" { - channel.BaseURL = baseURL + if channel.GetBaseURL() == "" { + channel.BaseURL = &baseURL } switch channel.Type { case common.ChannelTypeOpenAI: - if channel.BaseURL != "" { - baseURL = channel.BaseURL + if channel.GetBaseURL() != "" { + baseURL = channel.GetBaseURL() } case common.ChannelTypeAzure: return 0, errors.New("尚未实现") case common.ChannelTypeCustom: - baseURL = channel.BaseURL + baseURL = channel.GetBaseURL() case common.ChannelTypeCloseAI: return updateChannelCloseAIBalance(channel) case common.ChannelTypeOpenAISB: diff --git a/controller/channel-test.go b/controller/channel-test.go index 8c7e6f0d..f7a565a2 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -42,10 +42,10 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai } requestURL := common.ChannelBaseURLs[channel.Type] if channel.Type == common.ChannelTypeAzure { - requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.BaseURL, request.Model) + requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.GetBaseURL(), request.Model) } else { - if channel.BaseURL != "" { - requestURL = channel.BaseURL + if channel.GetBaseURL() != "" { + requestURL = channel.GetBaseURL() } requestURL += "/v1/chat/completions" } diff --git a/middleware/distributor.go b/middleware/distributor.go index c87c0b00..c171ffa6 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -94,9 +94,9 @@ func Distribute() func(c *gin.Context) { c.Set("channel", channel.Type) c.Set("channel_id", channel.Id) c.Set("channel_name", channel.Name) - c.Set("model_mapping", channel.ModelMapping) + c.Set("model_mapping", channel.GetModelMapping()) c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) - c.Set("base_url", channel.BaseURL) + c.Set("base_url", channel.GetBaseURL()) switch channel.Type { case common.ChannelTypeAzure: c.Set("api_version", channel.Other) diff --git a/model/ability.go b/model/ability.go index eb68fa0d..919de227 100644 --- a/model/ability.go +++ b/model/ability.go @@ -10,16 +10,18 @@ type Ability struct { Model string `json:"model" gorm:"primaryKey;autoIncrement:false"` ChannelId int `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"` Enabled bool `json:"enabled"` - Priority int64 `json:"priority" gorm:"bigint;default:0"` + Priority *int64 `json:"priority" gorm:"bigint;default:0;index"` } func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) { ability := Ability{} var err error = nil if common.UsingSQLite { - err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("CASE WHEN priority <> 0 THEN priority ELSE RANDOM() END DESC ").Limit(1).First(&ability).Error + maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where("`group` = ? and model = ? and enabled = 1", group, model) + err = DB.Where("`group` = ? and model = ? and enabled = 1 and priority = (?)", group, model, maxPrioritySubQuery).Order("RANDOM()").Limit(1).First(&ability).Error } else { - err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("CASE WHEN priority <> 0 THEN priority ELSE RAND() END DESC").Limit(1).First(&ability).Error + maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where("group = ? and model = ? and enabled = 1", group, model) + err = DB.Where("`group` = ? and model = ? and enabled = 1 and priority = (?)", group, model, maxPrioritySubQuery).Order("RAND()").Limit(1).First(&ability).Error } if err != nil { return nil, err diff --git a/model/cache.go b/model/cache.go index 1b547842..b9d6b612 100644 --- a/model/cache.go +++ b/model/cache.go @@ -165,7 +165,7 @@ func InitChannelCache() { for group, model2channels := range newGroup2model2channels { for model, channels := range model2channels { sort.Slice(channels, func(i, j int) bool { - return channels[i].Priority > channels[j].Priority + return channels[i].GetPriority() > channels[j].GetPriority() }) newGroup2model2channels[group][model] = channels } @@ -195,11 +195,17 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error if len(channels) == 0 { return nil, errors.New("channel not found") } + endIdx := len(channels) // choose by priority firstChannel := channels[0] - if firstChannel.Priority > 0 { - return firstChannel, nil + if firstChannel.GetPriority() > 0 { + for i := range channels { + if channels[i].GetPriority() != firstChannel.GetPriority() { + endIdx = i + break + } + } } - idx := rand.Intn(len(channels)) + idx := rand.Intn(endIdx) return channels[idx], nil } diff --git a/model/channel.go b/model/channel.go index 16e7cda3..628f435e 100644 --- a/model/channel.go +++ b/model/channel.go @@ -16,15 +16,15 @@ type Channel struct { CreatedTime int64 `json:"created_time" gorm:"bigint"` TestTime int64 `json:"test_time" gorm:"bigint"` ResponseTime int `json:"response_time"` // in milliseconds - BaseURL string `json:"base_url" gorm:"column:base_url"` + BaseURL *string `json:"base_url" gorm:"column:base_url;default:''"` Other string `json:"other"` Balance float64 `json:"balance"` // in USD BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"` Models string `json:"models"` Group string `json:"group" gorm:"type:varchar(32);default:'default'"` UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` - ModelMapping string `json:"model_mapping" gorm:"type:varchar(1024);default:''"` - Priority int64 `json:"priority" gorm:"bigint;default:0"` + ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"` + Priority *int64 `json:"priority" gorm:"bigint;default:0"` } func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) { @@ -80,6 +80,27 @@ func BatchInsertChannels(channels []Channel) error { return nil } +func (channel *Channel) GetPriority() int64 { + if channel.Priority == nil { + return 0 + } + return *channel.Priority +} + +func (channel *Channel) GetBaseURL() string { + if channel.BaseURL == nil { + return "" + } + return *channel.BaseURL +} + +func (channel *Channel) GetModelMapping() string { + if channel.ModelMapping == nil { + return "" + } + return *channel.ModelMapping +} + func (channel *Channel) Insert() error { var err error err = DB.Create(channel).Error diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index dffdd141..711a1963 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -175,7 +175,7 @@ const EditChannel = () => { return; } let localInputs = inputs; - if (localInputs.base_url.endsWith('/')) { + if (localInputs.base_url && localInputs.base_url.endsWith('/')) { localInputs.base_url = localInputs.base_url.slice(0, localInputs.base_url.length - 1); } if (localInputs.type === 3 && localInputs.other === '') { @@ -184,9 +184,6 @@ const EditChannel = () => { if (localInputs.type === 18 && localInputs.other === '') { localInputs.other = 'v2.1'; } - if (localInputs.model_mapping === '') { - localInputs.model_mapping = '{}'; - } let res; localInputs.models = localInputs.models.join(','); localInputs.group = localInputs.groups.join(',');