Merge branch 'songquanpeng:main' into main

This commit is contained in:
Calcium-Ion 2023-09-18 22:52:10 +08:00 committed by GitHub
commit 26f9d25860
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 51 additions and 25 deletions

View File

@ -111,7 +111,7 @@ func GetResponseBody(method, url string, channel *model.Channel, headers http.He
} }
func updateChannelCloseAIBalance(channel *model.Channel) (float64, error) { func updateChannelCloseAIBalance(channel *model.Channel) (float64, error) {
url := fmt.Sprintf("%s/dashboard/billing/credit_grants", channel.BaseURL) url := fmt.Sprintf("%s/dashboard/billing/credit_grants", channel.GetBaseURL())
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
if err != nil { if err != nil {
@ -201,18 +201,18 @@ func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) {
func updateChannelBalance(channel *model.Channel) (float64, error) { func updateChannelBalance(channel *model.Channel) (float64, error) {
baseURL := common.ChannelBaseURLs[channel.Type] baseURL := common.ChannelBaseURLs[channel.Type]
if channel.BaseURL == "" { if channel.GetBaseURL() == "" {
channel.BaseURL = baseURL channel.BaseURL = &baseURL
} }
switch channel.Type { switch channel.Type {
case common.ChannelTypeOpenAI: case common.ChannelTypeOpenAI:
if channel.BaseURL != "" { if channel.GetBaseURL() != "" {
baseURL = channel.BaseURL baseURL = channel.GetBaseURL()
} }
case common.ChannelTypeAzure: case common.ChannelTypeAzure:
return 0, errors.New("尚未实现") return 0, errors.New("尚未实现")
case common.ChannelTypeCustom: case common.ChannelTypeCustom:
baseURL = channel.BaseURL baseURL = channel.GetBaseURL()
case common.ChannelTypeCloseAI: case common.ChannelTypeCloseAI:
return updateChannelCloseAIBalance(channel) return updateChannelCloseAIBalance(channel)
case common.ChannelTypeOpenAISB: case common.ChannelTypeOpenAISB:

View File

@ -42,10 +42,10 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai
} }
requestURL := common.ChannelBaseURLs[channel.Type] requestURL := common.ChannelBaseURLs[channel.Type]
if channel.Type == common.ChannelTypeAzure { if channel.Type == common.ChannelTypeAzure {
requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.BaseURL, request.Model) requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.GetBaseURL(), request.Model)
} else { } else {
if channel.BaseURL != "" { if channel.GetBaseURL() != "" {
requestURL = channel.BaseURL requestURL = channel.GetBaseURL()
} }
requestURL += "/v1/chat/completions" requestURL += "/v1/chat/completions"
} }

View File

@ -94,9 +94,9 @@ func Distribute() func(c *gin.Context) {
c.Set("channel", channel.Type) c.Set("channel", channel.Type)
c.Set("channel_id", channel.Id) c.Set("channel_id", channel.Id)
c.Set("channel_name", channel.Name) c.Set("channel_name", channel.Name)
c.Set("model_mapping", channel.ModelMapping) c.Set("model_mapping", channel.GetModelMapping())
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
c.Set("base_url", channel.BaseURL) c.Set("base_url", channel.GetBaseURL())
switch channel.Type { switch channel.Type {
case common.ChannelTypeAzure: case common.ChannelTypeAzure:
c.Set("api_version", channel.Other) c.Set("api_version", channel.Other)

View File

@ -10,16 +10,18 @@ type Ability struct {
Model string `json:"model" gorm:"primaryKey;autoIncrement:false"` Model string `json:"model" gorm:"primaryKey;autoIncrement:false"`
ChannelId int `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"` ChannelId int `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"`
Enabled bool `json:"enabled"` Enabled bool `json:"enabled"`
Priority int64 `json:"priority" gorm:"bigint;default:0"` Priority *int64 `json:"priority" gorm:"bigint;default:0;index"`
} }
func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) { func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
ability := Ability{} ability := Ability{}
var err error = nil var err error = nil
if common.UsingSQLite { if common.UsingSQLite {
err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("CASE WHEN priority <> 0 THEN priority ELSE RANDOM() END DESC ").Limit(1).First(&ability).Error maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where("`group` = ? and model = ? and enabled = 1", group, model)
err = DB.Where("`group` = ? and model = ? and enabled = 1 and priority = (?)", group, model, maxPrioritySubQuery).Order("RANDOM()").Limit(1).First(&ability).Error
} else { } else {
err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("CASE WHEN priority <> 0 THEN priority ELSE RAND() END DESC").Limit(1).First(&ability).Error maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where("group = ? and model = ? and enabled = 1", group, model)
err = DB.Where("`group` = ? and model = ? and enabled = 1 and priority = (?)", group, model, maxPrioritySubQuery).Order("RAND()").Limit(1).First(&ability).Error
} }
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -165,7 +165,7 @@ func InitChannelCache() {
for group, model2channels := range newGroup2model2channels { for group, model2channels := range newGroup2model2channels {
for model, channels := range model2channels { for model, channels := range model2channels {
sort.Slice(channels, func(i, j int) bool { sort.Slice(channels, func(i, j int) bool {
return channels[i].Priority > channels[j].Priority return channels[i].GetPriority() > channels[j].GetPriority()
}) })
newGroup2model2channels[group][model] = channels newGroup2model2channels[group][model] = channels
} }
@ -195,11 +195,17 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
if len(channels) == 0 { if len(channels) == 0 {
return nil, errors.New("channel not found") return nil, errors.New("channel not found")
} }
endIdx := len(channels)
// choose by priority // choose by priority
firstChannel := channels[0] firstChannel := channels[0]
if firstChannel.Priority > 0 { if firstChannel.GetPriority() > 0 {
return firstChannel, nil for i := range channels {
if channels[i].GetPriority() != firstChannel.GetPriority() {
endIdx = i
break
} }
idx := rand.Intn(len(channels)) }
}
idx := rand.Intn(endIdx)
return channels[idx], nil return channels[idx], nil
} }

View File

@ -16,15 +16,15 @@ type Channel struct {
CreatedTime int64 `json:"created_time" gorm:"bigint"` CreatedTime int64 `json:"created_time" gorm:"bigint"`
TestTime int64 `json:"test_time" gorm:"bigint"` TestTime int64 `json:"test_time" gorm:"bigint"`
ResponseTime int `json:"response_time"` // in milliseconds ResponseTime int `json:"response_time"` // in milliseconds
BaseURL string `json:"base_url" gorm:"column:base_url"` BaseURL *string `json:"base_url" gorm:"column:base_url;default:''"`
Other string `json:"other"` Other string `json:"other"`
Balance float64 `json:"balance"` // in USD Balance float64 `json:"balance"` // in USD
BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"` BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"`
Models string `json:"models"` Models string `json:"models"`
Group string `json:"group" gorm:"type:varchar(32);default:'default'"` Group string `json:"group" gorm:"type:varchar(32);default:'default'"`
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"`
ModelMapping string `json:"model_mapping" gorm:"type:varchar(1024);default:''"` ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
Priority int64 `json:"priority" gorm:"bigint;default:0"` Priority *int64 `json:"priority" gorm:"bigint;default:0"`
} }
func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) { func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
@ -80,6 +80,27 @@ func BatchInsertChannels(channels []Channel) error {
return nil return nil
} }
func (channel *Channel) GetPriority() int64 {
if channel.Priority == nil {
return 0
}
return *channel.Priority
}
func (channel *Channel) GetBaseURL() string {
if channel.BaseURL == nil {
return ""
}
return *channel.BaseURL
}
func (channel *Channel) GetModelMapping() string {
if channel.ModelMapping == nil {
return ""
}
return *channel.ModelMapping
}
func (channel *Channel) Insert() error { func (channel *Channel) Insert() error {
var err error var err error
err = DB.Create(channel).Error err = DB.Create(channel).Error

View File

@ -175,7 +175,7 @@ const EditChannel = () => {
return; return;
} }
let localInputs = inputs; let localInputs = inputs;
if (localInputs.base_url.endsWith('/')) { if (localInputs.base_url && localInputs.base_url.endsWith('/')) {
localInputs.base_url = localInputs.base_url.slice(0, localInputs.base_url.length - 1); localInputs.base_url = localInputs.base_url.slice(0, localInputs.base_url.length - 1);
} }
if (localInputs.type === 3 && localInputs.other === '') { if (localInputs.type === 3 && localInputs.other === '') {
@ -184,9 +184,6 @@ const EditChannel = () => {
if (localInputs.type === 18 && localInputs.other === '') { if (localInputs.type === 18 && localInputs.other === '') {
localInputs.other = 'v2.1'; localInputs.other = 'v2.1';
} }
if (localInputs.model_mapping === '') {
localInputs.model_mapping = '{}';
}
let res; let res;
localInputs.models = localInputs.models.join(','); localInputs.models = localInputs.models.join(',');
localInputs.group = localInputs.groups.join(','); localInputs.group = localInputs.groups.join(',');