diff --git a/controller/channel-test.go b/controller/channel-test.go index da523658..0f58f4d0 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -35,7 +35,7 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai case common.ChannelTypeAzure: defer func() { if err != nil { - err = errors.New("请确保已在 Azure 上创建了 gpt-35-turbo 模型,或正确配置部署映射,并且 apiVersion 已正确填写!") + err = errors.New("请确保已在 Azure 上创建了 gpt-35-turbo 模型,或正确配置模型映射,并且 apiVersion 已正确填写!") } }() fallthrough @@ -44,8 +44,11 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai } requestURL := common.ChannelBaseURLs[channel.Type] if channel.Type == common.ChannelTypeAzure { - deployment := channel.GetDeployment(request.Model) - requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=%s", channel.GetBaseURL(), deployment, channel.Other) + modelMap := channel.GetModelMapping() + if modelMap[request.Model] != "" { + request.Model = modelMap[request.Model] + } + requestURL = getFullRequestURL(channel.GetBaseURL(), fmt.Sprintf("/openai/deployments/%s/chat/completions?api-version=%s", request.Model, channel.Other), channel.Type) } else { if baseURL := channel.GetBaseURL(); len(baseURL) > 0 { requestURL = baseURL diff --git a/controller/relay-text.go b/controller/relay-text.go index e35c3784..463c1cd4 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -90,18 +90,12 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { } } // map model name - modelMapping := c.GetString("model_mapping") + modelMapping := c.GetStringMapString("model_mapping") isModelMapped := false - if modelMapping != "" && modelMapping != "{}" { - modelMap := make(map[string]string) - err := json.Unmarshal([]byte(modelMapping), &modelMap) - if err != nil { - return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) - } - if modelMap[textRequest.Model] != "" { - textRequest.Model = modelMap[textRequest.Model] - isModelMapped = true - } + originalModel := textRequest.Model + if modelMapping[textRequest.Model] != "" { + textRequest.Model = modelMapping[textRequest.Model] + isModelMapped = true } apiType := APITypeOpenAI switch channelType { @@ -142,9 +136,11 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { baseURL = c.GetString("base_url") task := strings.TrimPrefix(requestURL, "/v1/") - deploymentMapping := c.GetStringMapString("deployment_mapping") - deployment := deploymentMapping[textRequest.Model] - if deployment == "" { + var deployment string + if isModelMapped { + deployment = textRequest.Model + textRequest.Model = originalModel + } else { deployment = model.ModelToDeployment(textRequest.Model) } fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, deployment, task) diff --git a/middleware/distributor.go b/middleware/distributor.go index d68b2f4e..c4ddc3a0 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -85,7 +85,6 @@ func Distribute() func(c *gin.Context) { switch channel.Type { case common.ChannelTypeAzure: c.Set("api_version", channel.Other) - c.Set("deployment_mapping", channel.GetDeploymentMapping()) case common.ChannelTypeXunfei: c.Set("api_version", channel.Other) case common.ChannelTypeAIProxyLibrary: diff --git a/model/channel.go b/model/channel.go index 8f7188d6..78c37ba0 100644 --- a/model/channel.go +++ b/model/channel.go @@ -27,7 +27,6 @@ type Channel struct { UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"` Priority *int64 `json:"priority" gorm:"bigint;default:0"` - DeploymentMapping *string `json:"deployment_mapping" gorm:"type:varchar(1024);default:''"` } func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) { @@ -76,26 +75,6 @@ func BatchInsertChannels(channels []Channel) error { return nil } -func (channel *Channel) GetDeploymentMapping() (deploymentMapping map[string]string) { - if channel.DeploymentMapping == nil || *channel.DeploymentMapping == "" { - return - } - err := json.Unmarshal([]byte(*channel.DeploymentMapping), &deploymentMapping) - if err != nil { - common.SysError("failed to unmarshal deployment mapping: " + err.Error()) - } - return -} - -func (channel *Channel) GetDeployment(model string) string { - deploymentMapping := channel.GetDeploymentMapping() - deployment, ok := deploymentMapping[model] - if !ok { - return ModelToDeployment(model) - } - return deployment -} - func ModelToDeployment(model string) string { model = strings.Replace(model, ".", "", -1) // https://github.com/songquanpeng/one-api/issues/67 @@ -119,11 +98,15 @@ func (channel *Channel) GetBaseURL() string { return *channel.BaseURL } -func (channel *Channel) GetModelMapping() string { +func (channel *Channel) GetModelMapping() (m map[string]string) { if channel.ModelMapping == nil { - return "" + return } - return *channel.ModelMapping + err := json.Unmarshal([]byte(*channel.ModelMapping), &m) + if err != nil { + common.SysError("failed to unmarshal model mapping: " + err.Error()) + } + return } func (channel *Channel) Insert() error { diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index 35f6906e..75540e2d 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -7,12 +7,8 @@ import { CHANNEL_OPTIONS } from '../../constants'; const MODEL_MAPPING_EXAMPLE = { 'gpt-3.5-turbo-0301': 'gpt-3.5-turbo', 'gpt-4-0314': 'gpt-4', - 'gpt-4-32k-0314': 'gpt-4-32k' -}; - -const DEPLOYMENT_MAPPING_EXAMPLE = { - 'gpt-3.5-turbo': 'gpt-35-turbo', - 'gpt-4': 'custom-gpt4' + 'gpt-4-32k-0314': 'gpt-4-32k', + 'gpt-4': 'azure-deployment-1', }; function type2secretPrompt(type) { @@ -394,7 +390,7 @@ const EditChannel = () => { { autoComplete='new-password' /> - { - inputs.type === 3 && ( - <> - - - - - ) - } { batch ?