diff --git a/controller/channel-billing.go b/controller/channel-billing.go index 46262f6c..6ddad7ea 100644 --- a/controller/channel-billing.go +++ b/controller/channel-billing.go @@ -111,7 +111,7 @@ func GetResponseBody(method, url string, channel *model.Channel, headers http.He } func updateChannelCloseAIBalance(channel *model.Channel) (float64, error) { - url := fmt.Sprintf("%s/dashboard/billing/credit_grants", channel.BaseURL) + url := fmt.Sprintf("%s/dashboard/billing/credit_grants", channel.GetBaseURL()) body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) if err != nil { @@ -201,18 +201,18 @@ func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) { func updateChannelBalance(channel *model.Channel) (float64, error) { baseURL := common.ChannelBaseURLs[channel.Type] - if channel.BaseURL == "" { - channel.BaseURL = baseURL + if channel.GetBaseURL() == "" { + channel.BaseURL = &baseURL } switch channel.Type { case common.ChannelTypeOpenAI: - if channel.BaseURL != "" { - baseURL = channel.BaseURL + if channel.GetBaseURL() != "" { + baseURL = channel.GetBaseURL() } case common.ChannelTypeAzure: return 0, errors.New("尚未实现") case common.ChannelTypeCustom: - baseURL = channel.BaseURL + baseURL = channel.GetBaseURL() case common.ChannelTypeCloseAI: return updateChannelCloseAIBalance(channel) case common.ChannelTypeOpenAISB: diff --git a/controller/channel-test.go b/controller/channel-test.go index 8c7e6f0d..f7a565a2 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -42,10 +42,10 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai } requestURL := common.ChannelBaseURLs[channel.Type] if channel.Type == common.ChannelTypeAzure { - requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.BaseURL, request.Model) + requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.GetBaseURL(), request.Model) } else { - if channel.BaseURL != "" { - requestURL = channel.BaseURL + if channel.GetBaseURL() != "" { + requestURL = channel.GetBaseURL() } requestURL += "/v1/chat/completions" } diff --git a/middleware/distributor.go b/middleware/distributor.go index ab374a85..9ded3231 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -82,9 +82,9 @@ func Distribute() func(c *gin.Context) { c.Set("channel", channel.Type) c.Set("channel_id", channel.Id) c.Set("channel_name", channel.Name) - c.Set("model_mapping", channel.ModelMapping) + c.Set("model_mapping", channel.GetModelMapping()) c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) - c.Set("base_url", channel.BaseURL) + c.Set("base_url", channel.GetBaseURL()) switch channel.Type { case common.ChannelTypeAzure: c.Set("api_version", channel.Other) diff --git a/model/channel.go b/model/channel.go index 1a478b91..8a5b79ff 100644 --- a/model/channel.go +++ b/model/channel.go @@ -15,14 +15,14 @@ type Channel struct { CreatedTime int64 `json:"created_time" gorm:"bigint"` TestTime int64 `json:"test_time" gorm:"bigint"` ResponseTime int `json:"response_time"` // in milliseconds - BaseURL string `json:"base_url" gorm:"column:base_url"` + BaseURL *string `json:"base_url" gorm:"column:base_url"` Other string `json:"other"` Balance float64 `json:"balance"` // in USD BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"` Models string `json:"models"` Group string `json:"group" gorm:"type:varchar(32);default:'default'"` UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` - ModelMapping string `json:"model_mapping" gorm:"type:varchar(1024);default:''"` + ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"` Priority *int64 `json:"priority" gorm:"bigint;default:0"` } @@ -80,12 +80,26 @@ func BatchInsertChannels(channels []Channel) error { } func (channel *Channel) GetPriority() int64 { - if channel == nil { + if channel.Priority == nil { return 0 } return *channel.Priority } +func (channel *Channel) GetBaseURL() string { + if channel.BaseURL == nil { + return "" + } + return *channel.BaseURL +} + +func (channel *Channel) GetModelMapping() string { + if channel.ModelMapping == nil { + return "" + } + return *channel.ModelMapping +} + func (channel *Channel) Insert() error { var err error err = DB.Create(channel).Error diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index 78ff1952..e0053709 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -183,9 +183,6 @@ const EditChannel = () => { if (localInputs.type === 18 && localInputs.other === '') { localInputs.other = 'v2.1'; } - if (localInputs.model_mapping === '') { - localInputs.model_mapping = '{}'; - } let res; localInputs.models = localInputs.models.join(','); localInputs.group = localInputs.groups.join(',');