Merge branch 'songquanpeng:main' into main
This commit is contained in:
commit
26f9d25860
@ -111,7 +111,7 @@ func GetResponseBody(method, url string, channel *model.Channel, headers http.He
|
|||||||
}
|
}
|
||||||
|
|
||||||
func updateChannelCloseAIBalance(channel *model.Channel) (float64, error) {
|
func updateChannelCloseAIBalance(channel *model.Channel) (float64, error) {
|
||||||
url := fmt.Sprintf("%s/dashboard/billing/credit_grants", channel.BaseURL)
|
url := fmt.Sprintf("%s/dashboard/billing/credit_grants", channel.GetBaseURL())
|
||||||
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -201,18 +201,18 @@ func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) {
|
|||||||
|
|
||||||
func updateChannelBalance(channel *model.Channel) (float64, error) {
|
func updateChannelBalance(channel *model.Channel) (float64, error) {
|
||||||
baseURL := common.ChannelBaseURLs[channel.Type]
|
baseURL := common.ChannelBaseURLs[channel.Type]
|
||||||
if channel.BaseURL == "" {
|
if channel.GetBaseURL() == "" {
|
||||||
channel.BaseURL = baseURL
|
channel.BaseURL = &baseURL
|
||||||
}
|
}
|
||||||
switch channel.Type {
|
switch channel.Type {
|
||||||
case common.ChannelTypeOpenAI:
|
case common.ChannelTypeOpenAI:
|
||||||
if channel.BaseURL != "" {
|
if channel.GetBaseURL() != "" {
|
||||||
baseURL = channel.BaseURL
|
baseURL = channel.GetBaseURL()
|
||||||
}
|
}
|
||||||
case common.ChannelTypeAzure:
|
case common.ChannelTypeAzure:
|
||||||
return 0, errors.New("尚未实现")
|
return 0, errors.New("尚未实现")
|
||||||
case common.ChannelTypeCustom:
|
case common.ChannelTypeCustom:
|
||||||
baseURL = channel.BaseURL
|
baseURL = channel.GetBaseURL()
|
||||||
case common.ChannelTypeCloseAI:
|
case common.ChannelTypeCloseAI:
|
||||||
return updateChannelCloseAIBalance(channel)
|
return updateChannelCloseAIBalance(channel)
|
||||||
case common.ChannelTypeOpenAISB:
|
case common.ChannelTypeOpenAISB:
|
||||||
|
@ -42,10 +42,10 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai
|
|||||||
}
|
}
|
||||||
requestURL := common.ChannelBaseURLs[channel.Type]
|
requestURL := common.ChannelBaseURLs[channel.Type]
|
||||||
if channel.Type == common.ChannelTypeAzure {
|
if channel.Type == common.ChannelTypeAzure {
|
||||||
requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.BaseURL, request.Model)
|
requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.GetBaseURL(), request.Model)
|
||||||
} else {
|
} else {
|
||||||
if channel.BaseURL != "" {
|
if channel.GetBaseURL() != "" {
|
||||||
requestURL = channel.BaseURL
|
requestURL = channel.GetBaseURL()
|
||||||
}
|
}
|
||||||
requestURL += "/v1/chat/completions"
|
requestURL += "/v1/chat/completions"
|
||||||
}
|
}
|
||||||
|
@ -94,9 +94,9 @@ func Distribute() func(c *gin.Context) {
|
|||||||
c.Set("channel", channel.Type)
|
c.Set("channel", channel.Type)
|
||||||
c.Set("channel_id", channel.Id)
|
c.Set("channel_id", channel.Id)
|
||||||
c.Set("channel_name", channel.Name)
|
c.Set("channel_name", channel.Name)
|
||||||
c.Set("model_mapping", channel.ModelMapping)
|
c.Set("model_mapping", channel.GetModelMapping())
|
||||||
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
||||||
c.Set("base_url", channel.BaseURL)
|
c.Set("base_url", channel.GetBaseURL())
|
||||||
switch channel.Type {
|
switch channel.Type {
|
||||||
case common.ChannelTypeAzure:
|
case common.ChannelTypeAzure:
|
||||||
c.Set("api_version", channel.Other)
|
c.Set("api_version", channel.Other)
|
||||||
|
@ -10,16 +10,18 @@ type Ability struct {
|
|||||||
Model string `json:"model" gorm:"primaryKey;autoIncrement:false"`
|
Model string `json:"model" gorm:"primaryKey;autoIncrement:false"`
|
||||||
ChannelId int `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"`
|
ChannelId int `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"`
|
||||||
Enabled bool `json:"enabled"`
|
Enabled bool `json:"enabled"`
|
||||||
Priority int64 `json:"priority" gorm:"bigint;default:0"`
|
Priority *int64 `json:"priority" gorm:"bigint;default:0;index"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
|
func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
|
||||||
ability := Ability{}
|
ability := Ability{}
|
||||||
var err error = nil
|
var err error = nil
|
||||||
if common.UsingSQLite {
|
if common.UsingSQLite {
|
||||||
err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("CASE WHEN priority <> 0 THEN priority ELSE RANDOM() END DESC ").Limit(1).First(&ability).Error
|
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where("`group` = ? and model = ? and enabled = 1", group, model)
|
||||||
|
err = DB.Where("`group` = ? and model = ? and enabled = 1 and priority = (?)", group, model, maxPrioritySubQuery).Order("RANDOM()").Limit(1).First(&ability).Error
|
||||||
} else {
|
} else {
|
||||||
err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("CASE WHEN priority <> 0 THEN priority ELSE RAND() END DESC").Limit(1).First(&ability).Error
|
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where("group = ? and model = ? and enabled = 1", group, model)
|
||||||
|
err = DB.Where("`group` = ? and model = ? and enabled = 1 and priority = (?)", group, model, maxPrioritySubQuery).Order("RAND()").Limit(1).First(&ability).Error
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -165,7 +165,7 @@ func InitChannelCache() {
|
|||||||
for group, model2channels := range newGroup2model2channels {
|
for group, model2channels := range newGroup2model2channels {
|
||||||
for model, channels := range model2channels {
|
for model, channels := range model2channels {
|
||||||
sort.Slice(channels, func(i, j int) bool {
|
sort.Slice(channels, func(i, j int) bool {
|
||||||
return channels[i].Priority > channels[j].Priority
|
return channels[i].GetPriority() > channels[j].GetPriority()
|
||||||
})
|
})
|
||||||
newGroup2model2channels[group][model] = channels
|
newGroup2model2channels[group][model] = channels
|
||||||
}
|
}
|
||||||
@ -195,11 +195,17 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
|
|||||||
if len(channels) == 0 {
|
if len(channels) == 0 {
|
||||||
return nil, errors.New("channel not found")
|
return nil, errors.New("channel not found")
|
||||||
}
|
}
|
||||||
|
endIdx := len(channels)
|
||||||
// choose by priority
|
// choose by priority
|
||||||
firstChannel := channels[0]
|
firstChannel := channels[0]
|
||||||
if firstChannel.Priority > 0 {
|
if firstChannel.GetPriority() > 0 {
|
||||||
return firstChannel, nil
|
for i := range channels {
|
||||||
|
if channels[i].GetPriority() != firstChannel.GetPriority() {
|
||||||
|
endIdx = i
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
idx := rand.Intn(len(channels))
|
idx := rand.Intn(endIdx)
|
||||||
return channels[idx], nil
|
return channels[idx], nil
|
||||||
}
|
}
|
||||||
|
@ -16,15 +16,15 @@ type Channel struct {
|
|||||||
CreatedTime int64 `json:"created_time" gorm:"bigint"`
|
CreatedTime int64 `json:"created_time" gorm:"bigint"`
|
||||||
TestTime int64 `json:"test_time" gorm:"bigint"`
|
TestTime int64 `json:"test_time" gorm:"bigint"`
|
||||||
ResponseTime int `json:"response_time"` // in milliseconds
|
ResponseTime int `json:"response_time"` // in milliseconds
|
||||||
BaseURL string `json:"base_url" gorm:"column:base_url"`
|
BaseURL *string `json:"base_url" gorm:"column:base_url;default:''"`
|
||||||
Other string `json:"other"`
|
Other string `json:"other"`
|
||||||
Balance float64 `json:"balance"` // in USD
|
Balance float64 `json:"balance"` // in USD
|
||||||
BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"`
|
BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"`
|
||||||
Models string `json:"models"`
|
Models string `json:"models"`
|
||||||
Group string `json:"group" gorm:"type:varchar(32);default:'default'"`
|
Group string `json:"group" gorm:"type:varchar(32);default:'default'"`
|
||||||
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"`
|
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"`
|
||||||
ModelMapping string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
|
ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
|
||||||
Priority int64 `json:"priority" gorm:"bigint;default:0"`
|
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
|
func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
|
||||||
@ -80,6 +80,27 @@ func BatchInsertChannels(channels []Channel) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (channel *Channel) GetPriority() int64 {
|
||||||
|
if channel.Priority == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return *channel.Priority
|
||||||
|
}
|
||||||
|
|
||||||
|
func (channel *Channel) GetBaseURL() string {
|
||||||
|
if channel.BaseURL == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return *channel.BaseURL
|
||||||
|
}
|
||||||
|
|
||||||
|
func (channel *Channel) GetModelMapping() string {
|
||||||
|
if channel.ModelMapping == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return *channel.ModelMapping
|
||||||
|
}
|
||||||
|
|
||||||
func (channel *Channel) Insert() error {
|
func (channel *Channel) Insert() error {
|
||||||
var err error
|
var err error
|
||||||
err = DB.Create(channel).Error
|
err = DB.Create(channel).Error
|
||||||
|
@ -175,7 +175,7 @@ const EditChannel = () => {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
let localInputs = inputs;
|
let localInputs = inputs;
|
||||||
if (localInputs.base_url.endsWith('/')) {
|
if (localInputs.base_url && localInputs.base_url.endsWith('/')) {
|
||||||
localInputs.base_url = localInputs.base_url.slice(0, localInputs.base_url.length - 1);
|
localInputs.base_url = localInputs.base_url.slice(0, localInputs.base_url.length - 1);
|
||||||
}
|
}
|
||||||
if (localInputs.type === 3 && localInputs.other === '') {
|
if (localInputs.type === 3 && localInputs.other === '') {
|
||||||
@ -184,9 +184,6 @@ const EditChannel = () => {
|
|||||||
if (localInputs.type === 18 && localInputs.other === '') {
|
if (localInputs.type === 18 && localInputs.other === '') {
|
||||||
localInputs.other = 'v2.1';
|
localInputs.other = 'v2.1';
|
||||||
}
|
}
|
||||||
if (localInputs.model_mapping === '') {
|
|
||||||
localInputs.model_mapping = '{}';
|
|
||||||
}
|
|
||||||
let res;
|
let res;
|
||||||
localInputs.models = localInputs.models.join(',');
|
localInputs.models = localInputs.models.join(',');
|
||||||
localInputs.group = localInputs.groups.join(',');
|
localInputs.group = localInputs.groups.join(',');
|
||||||
|
Loading…
Reference in New Issue
Block a user