From 9042b7c4e4c059e5ca1c10c3932786ca32a6570b Mon Sep 17 00:00:00 2001 From: WqyJh <781345688@qq.com> Date: Mon, 20 Nov 2023 16:00:21 +0800 Subject: [PATCH] feat: add deployment_mapping for Azure OpenAI --- controller/channel-test.go | 7 +++--- controller/relay-text.go | 14 +++++------ middleware/distributor.go | 1 + model/channel.go | 35 +++++++++++++++++++++++++++- web/src/pages/Channel/EditChannel.js | 23 ++++++++++++++++++ 5 files changed, 68 insertions(+), 12 deletions(-) diff --git a/controller/channel-test.go b/controller/channel-test.go index 1b0b745a..da523658 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -33,18 +33,19 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai case common.ChannelTypeXunfei: return errors.New("该渠道类型当前版本不支持测试,请手动测试"), nil case common.ChannelTypeAzure: - request.Model = "gpt-35-turbo" defer func() { if err != nil { - err = errors.New("请确保已在 Azure 上创建了 gpt-35-turbo 模型,并且 apiVersion 已正确填写!") + err = errors.New("请确保已在 Azure 上创建了 gpt-35-turbo 模型,或正确配置部署映射,并且 apiVersion 已正确填写!") } }() + fallthrough default: request.Model = "gpt-3.5-turbo" } requestURL := common.ChannelBaseURLs[channel.Type] if channel.Type == common.ChannelTypeAzure { - requestURL = getFullRequestURL(channel.GetBaseURL(), fmt.Sprintf("/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", request.Model), channel.Type) + deployment := channel.GetDeployment(request.Model) + requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=%s", channel.GetBaseURL(), deployment, channel.Other) } else { if baseURL := channel.GetBaseURL(); len(baseURL) > 0 { requestURL = baseURL diff --git a/controller/relay-text.go b/controller/relay-text.go index 018c8d8a..e35c3784 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -141,15 +141,13 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion) baseURL = c.GetString("base_url") task := strings.TrimPrefix(requestURL, "/v1/") - model_ := textRequest.Model - model_ = strings.Replace(model_, ".", "", -1) - // https://github.com/songquanpeng/one-api/issues/67 - model_ = strings.TrimSuffix(model_, "-0301") - model_ = strings.TrimSuffix(model_, "-0314") - model_ = strings.TrimSuffix(model_, "-0613") - requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task) - fullRequestURL = getFullRequestURL(baseURL, requestURL, channelType) + deploymentMapping := c.GetStringMapString("deployment_mapping") + deployment := deploymentMapping[textRequest.Model] + if deployment == "" { + deployment = model.ModelToDeployment(textRequest.Model) + } + fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, deployment, task) } case APITypeClaude: fullRequestURL = "https://api.anthropic.com/v1/complete" diff --git a/middleware/distributor.go b/middleware/distributor.go index c4ddc3a0..d68b2f4e 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -85,6 +85,7 @@ func Distribute() func(c *gin.Context) { switch channel.Type { case common.ChannelTypeAzure: c.Set("api_version", channel.Other) + c.Set("deployment_mapping", channel.GetDeploymentMapping()) case common.ChannelTypeXunfei: c.Set("api_version", channel.Other) case common.ChannelTypeAIProxyLibrary: diff --git a/model/channel.go b/model/channel.go index 7e7b42e6..8f7188d6 100644 --- a/model/channel.go +++ b/model/channel.go @@ -1,8 +1,11 @@ package model import ( - "gorm.io/gorm" + "encoding/json" "one-api/common" + "strings" + + "gorm.io/gorm" ) type Channel struct { @@ -24,6 +27,7 @@ type Channel struct { UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"` Priority *int64 `json:"priority" gorm:"bigint;default:0"` + DeploymentMapping *string `json:"deployment_mapping" gorm:"type:varchar(1024);default:''"` } func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) { @@ -72,6 +76,35 @@ func BatchInsertChannels(channels []Channel) error { return nil } +func (channel *Channel) GetDeploymentMapping() (deploymentMapping map[string]string) { + if channel.DeploymentMapping == nil || *channel.DeploymentMapping == "" { + return + } + err := json.Unmarshal([]byte(*channel.DeploymentMapping), &deploymentMapping) + if err != nil { + common.SysError("failed to unmarshal deployment mapping: " + err.Error()) + } + return +} + +func (channel *Channel) GetDeployment(model string) string { + deploymentMapping := channel.GetDeploymentMapping() + deployment, ok := deploymentMapping[model] + if !ok { + return ModelToDeployment(model) + } + return deployment +} + +func ModelToDeployment(model string) string { + model = strings.Replace(model, ".", "", -1) + // https://github.com/songquanpeng/one-api/issues/67 + model = strings.TrimSuffix(model, "-0301") + model = strings.TrimSuffix(model, "-0314") + model = strings.TrimSuffix(model, "-0613") + return model +} + func (channel *Channel) GetPriority() int64 { if channel.Priority == nil { return 0 diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index 654a5d51..35f6906e 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -10,6 +10,11 @@ const MODEL_MAPPING_EXAMPLE = { 'gpt-4-32k-0314': 'gpt-4-32k' }; +const DEPLOYMENT_MAPPING_EXAMPLE = { + 'gpt-3.5-turbo': 'gpt-35-turbo', + 'gpt-4': 'custom-gpt4' +}; + function type2secretPrompt(type) { // inputs.type === 15 ? '按照如下格式输入:APIKey|SecretKey' : (inputs.type === 18 ? '按照如下格式输入:APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥') switch (type) { @@ -43,6 +48,7 @@ const EditChannel = () => { base_url: '', other: '', model_mapping: '', + deployment_mapping: '', models: [], groups: ['default'] }; @@ -396,6 +402,23 @@ const EditChannel = () => { autoComplete='new-password' /> + { + inputs.type === 3 && ( + <> +