refactor: use model_mapping instead of deployment_mapping
This commit is contained in:
parent
9042b7c4e4
commit
10291630f1
@ -35,7 +35,7 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai
|
|||||||
case common.ChannelTypeAzure:
|
case common.ChannelTypeAzure:
|
||||||
defer func() {
|
defer func() {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = errors.New("请确保已在 Azure 上创建了 gpt-35-turbo 模型,或正确配置部署映射,并且 apiVersion 已正确填写!")
|
err = errors.New("请确保已在 Azure 上创建了 gpt-35-turbo 模型,或正确配置模型映射,并且 apiVersion 已正确填写!")
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
fallthrough
|
fallthrough
|
||||||
@ -44,8 +44,11 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai
|
|||||||
}
|
}
|
||||||
requestURL := common.ChannelBaseURLs[channel.Type]
|
requestURL := common.ChannelBaseURLs[channel.Type]
|
||||||
if channel.Type == common.ChannelTypeAzure {
|
if channel.Type == common.ChannelTypeAzure {
|
||||||
deployment := channel.GetDeployment(request.Model)
|
modelMap := channel.GetModelMapping()
|
||||||
requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=%s", channel.GetBaseURL(), deployment, channel.Other)
|
if modelMap[request.Model] != "" {
|
||||||
|
request.Model = modelMap[request.Model]
|
||||||
|
}
|
||||||
|
requestURL = getFullRequestURL(channel.GetBaseURL(), fmt.Sprintf("/openai/deployments/%s/chat/completions?api-version=%s", request.Model, channel.Other), channel.Type)
|
||||||
} else {
|
} else {
|
||||||
if baseURL := channel.GetBaseURL(); len(baseURL) > 0 {
|
if baseURL := channel.GetBaseURL(); len(baseURL) > 0 {
|
||||||
requestURL = baseURL
|
requestURL = baseURL
|
||||||
|
@ -90,18 +90,12 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
// map model name
|
// map model name
|
||||||
modelMapping := c.GetString("model_mapping")
|
modelMapping := c.GetStringMapString("model_mapping")
|
||||||
isModelMapped := false
|
isModelMapped := false
|
||||||
if modelMapping != "" && modelMapping != "{}" {
|
originalModel := textRequest.Model
|
||||||
modelMap := make(map[string]string)
|
if modelMapping[textRequest.Model] != "" {
|
||||||
err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
textRequest.Model = modelMapping[textRequest.Model]
|
||||||
if err != nil {
|
isModelMapped = true
|
||||||
return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
if modelMap[textRequest.Model] != "" {
|
|
||||||
textRequest.Model = modelMap[textRequest.Model]
|
|
||||||
isModelMapped = true
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
apiType := APITypeOpenAI
|
apiType := APITypeOpenAI
|
||||||
switch channelType {
|
switch channelType {
|
||||||
@ -142,9 +136,11 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
baseURL = c.GetString("base_url")
|
baseURL = c.GetString("base_url")
|
||||||
task := strings.TrimPrefix(requestURL, "/v1/")
|
task := strings.TrimPrefix(requestURL, "/v1/")
|
||||||
|
|
||||||
deploymentMapping := c.GetStringMapString("deployment_mapping")
|
var deployment string
|
||||||
deployment := deploymentMapping[textRequest.Model]
|
if isModelMapped {
|
||||||
if deployment == "" {
|
deployment = textRequest.Model
|
||||||
|
textRequest.Model = originalModel
|
||||||
|
} else {
|
||||||
deployment = model.ModelToDeployment(textRequest.Model)
|
deployment = model.ModelToDeployment(textRequest.Model)
|
||||||
}
|
}
|
||||||
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, deployment, task)
|
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, deployment, task)
|
||||||
|
@ -85,7 +85,6 @@ func Distribute() func(c *gin.Context) {
|
|||||||
switch channel.Type {
|
switch channel.Type {
|
||||||
case common.ChannelTypeAzure:
|
case common.ChannelTypeAzure:
|
||||||
c.Set("api_version", channel.Other)
|
c.Set("api_version", channel.Other)
|
||||||
c.Set("deployment_mapping", channel.GetDeploymentMapping())
|
|
||||||
case common.ChannelTypeXunfei:
|
case common.ChannelTypeXunfei:
|
||||||
c.Set("api_version", channel.Other)
|
c.Set("api_version", channel.Other)
|
||||||
case common.ChannelTypeAIProxyLibrary:
|
case common.ChannelTypeAIProxyLibrary:
|
||||||
|
@ -27,7 +27,6 @@ type Channel struct {
|
|||||||
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"`
|
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"`
|
||||||
ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
|
ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
|
||||||
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
|
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
|
||||||
DeploymentMapping *string `json:"deployment_mapping" gorm:"type:varchar(1024);default:''"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
|
func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
|
||||||
@ -76,26 +75,6 @@ func BatchInsertChannels(channels []Channel) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (channel *Channel) GetDeploymentMapping() (deploymentMapping map[string]string) {
|
|
||||||
if channel.DeploymentMapping == nil || *channel.DeploymentMapping == "" {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
err := json.Unmarshal([]byte(*channel.DeploymentMapping), &deploymentMapping)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("failed to unmarshal deployment mapping: " + err.Error())
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (channel *Channel) GetDeployment(model string) string {
|
|
||||||
deploymentMapping := channel.GetDeploymentMapping()
|
|
||||||
deployment, ok := deploymentMapping[model]
|
|
||||||
if !ok {
|
|
||||||
return ModelToDeployment(model)
|
|
||||||
}
|
|
||||||
return deployment
|
|
||||||
}
|
|
||||||
|
|
||||||
func ModelToDeployment(model string) string {
|
func ModelToDeployment(model string) string {
|
||||||
model = strings.Replace(model, ".", "", -1)
|
model = strings.Replace(model, ".", "", -1)
|
||||||
// https://github.com/songquanpeng/one-api/issues/67
|
// https://github.com/songquanpeng/one-api/issues/67
|
||||||
@ -119,11 +98,15 @@ func (channel *Channel) GetBaseURL() string {
|
|||||||
return *channel.BaseURL
|
return *channel.BaseURL
|
||||||
}
|
}
|
||||||
|
|
||||||
func (channel *Channel) GetModelMapping() string {
|
func (channel *Channel) GetModelMapping() (m map[string]string) {
|
||||||
if channel.ModelMapping == nil {
|
if channel.ModelMapping == nil {
|
||||||
return ""
|
return
|
||||||
}
|
}
|
||||||
return *channel.ModelMapping
|
err := json.Unmarshal([]byte(*channel.ModelMapping), &m)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("failed to unmarshal model mapping: " + err.Error())
|
||||||
|
}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (channel *Channel) Insert() error {
|
func (channel *Channel) Insert() error {
|
||||||
|
@ -7,12 +7,8 @@ import { CHANNEL_OPTIONS } from '../../constants';
|
|||||||
const MODEL_MAPPING_EXAMPLE = {
|
const MODEL_MAPPING_EXAMPLE = {
|
||||||
'gpt-3.5-turbo-0301': 'gpt-3.5-turbo',
|
'gpt-3.5-turbo-0301': 'gpt-3.5-turbo',
|
||||||
'gpt-4-0314': 'gpt-4',
|
'gpt-4-0314': 'gpt-4',
|
||||||
'gpt-4-32k-0314': 'gpt-4-32k'
|
'gpt-4-32k-0314': 'gpt-4-32k',
|
||||||
};
|
'gpt-4': 'azure-deployment-1',
|
||||||
|
|
||||||
const DEPLOYMENT_MAPPING_EXAMPLE = {
|
|
||||||
'gpt-3.5-turbo': 'gpt-35-turbo',
|
|
||||||
'gpt-4': 'custom-gpt4'
|
|
||||||
};
|
};
|
||||||
|
|
||||||
function type2secretPrompt(type) {
|
function type2secretPrompt(type) {
|
||||||
@ -394,7 +390,7 @@ const EditChannel = () => {
|
|||||||
<Form.Field>
|
<Form.Field>
|
||||||
<Form.TextArea
|
<Form.TextArea
|
||||||
label='模型重定向'
|
label='模型重定向'
|
||||||
placeholder={`此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,例如:\n${JSON.stringify(MODEL_MAPPING_EXAMPLE, null, 2)}`}
|
placeholder={`此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,如果是 Azure OpenAI 则替换为 deployment 名称,例如:\n${JSON.stringify(MODEL_MAPPING_EXAMPLE, null, 2)}`}
|
||||||
name='model_mapping'
|
name='model_mapping'
|
||||||
onChange={handleInputChange}
|
onChange={handleInputChange}
|
||||||
value={inputs.model_mapping}
|
value={inputs.model_mapping}
|
||||||
@ -402,23 +398,6 @@ const EditChannel = () => {
|
|||||||
autoComplete='new-password'
|
autoComplete='new-password'
|
||||||
/>
|
/>
|
||||||
</Form.Field>
|
</Form.Field>
|
||||||
{
|
|
||||||
inputs.type === 3 && (
|
|
||||||
<>
|
|
||||||
<Form.Field>
|
|
||||||
<Form.TextArea
|
|
||||||
label='部署映射'
|
|
||||||
placeholder={`此项可选,为一个 JSON 字符串,键为请求中模型名称,值为要替换的部署名称,用于修改 Azure OpenAI 的 deployment 参数。如果不填此参数,则部署名称必须与模型名称一致,例如 gpt-3.5-turbo 模型对应的部署名称应该为 gpt-35-turbo。若实际的部署名称不满足此规则,则可以通过该参数自定义。例如:\n${JSON.stringify(DEPLOYMENT_MAPPING_EXAMPLE, null, 2)}`}
|
|
||||||
name='deployment_mapping'
|
|
||||||
onChange={handleInputChange}
|
|
||||||
value={inputs.deployment_mapping}
|
|
||||||
style={{ minHeight: 150, fontFamily: 'JetBrains Mono, Consolas' }}
|
|
||||||
autoComplete='new-password'
|
|
||||||
/>
|
|
||||||
</Form.Field>
|
|
||||||
</>
|
|
||||||
)
|
|
||||||
}
|
|
||||||
{
|
{
|
||||||
batch ? <Form.Field>
|
batch ? <Form.Field>
|
||||||
<Form.TextArea
|
<Form.TextArea
|
||||||
|
Loading…
Reference in New Issue
Block a user