diff --git a/controller/relay-text.go b/controller/relay-text.go index 00fc6f89..9ce11601 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -53,6 +53,20 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { return errorWrapper(errors.New("field instruction is required"), "required_field_missing", http.StatusBadRequest) } } + // map model name + modelMapping := c.GetString("model_mapping") + isModelMapped := false + if modelMapping != "" { + modelMap := make(map[string]string) + err := json.Unmarshal([]byte(modelMapping), &modelMap) + if err != nil { + return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) + } + if modelMap[textRequest.Model] != "" { + textRequest.Model = modelMap[textRequest.Model] + isModelMapped = true + } + } baseURL := common.ChannelBaseURLs[channelType] requestURL := c.Request.URL.String() if c.GetString("base_url") != "" { @@ -114,7 +128,17 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) } } - req, err := http.NewRequest(c.Request.Method, fullRequestURL, c.Request.Body) + var requestBody io.Reader + if isModelMapped { + jsonStr, err := json.Marshal(textRequest) + if err != nil { + return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) + } + requestBody = bytes.NewBuffer(jsonStr) + } else { + requestBody = c.Request.Body + } + req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) if err != nil { return errorWrapper(err, "new_request_failed", http.StatusInternalServerError) } diff --git a/controller/relay.go b/controller/relay.go index 2910cc97..42aa0c0f 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -27,16 +27,16 @@ const ( // https://platform.openai.com/docs/api-reference/chat type GeneralOpenAIRequest struct { - Model string `json:"model"` - Messages []Message `json:"messages"` - Prompt any `json:"prompt"` - Stream bool `json:"stream"` - MaxTokens int `json:"max_tokens"` - Temperature float64 `json:"temperature"` - TopP float64 `json:"top_p"` - N int `json:"n"` - Input any `json:"input"` - Instruction string `json:"instruction"` + Model string `json:"model,omitempty"` + Messages []Message `json:"messages,omitempty"` + Prompt any `json:"prompt,omitempty"` + Stream bool `json:"stream,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + N int `json:"n,omitempty"` + Input any `json:"input,omitempty"` + Instruction string `json:"instruction,omitempty"` } type ChatRequest struct { diff --git a/middleware/distributor.go b/middleware/distributor.go index e7532432..9da2ed1e 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -88,6 +88,7 @@ func Distribute() func(c *gin.Context) { c.Set("channel", channel.Type) c.Set("channel_id", channel.Id) c.Set("channel_name", channel.Name) + c.Set("model_mapping", channel.ModelMapping) c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) c.Set("base_url", channel.BaseURL) if channel.Type == common.ChannelTypeAzure { diff --git a/model/channel.go b/model/channel.go index e99f4d10..bf2afbfc 100644 --- a/model/channel.go +++ b/model/channel.go @@ -22,6 +22,7 @@ type Channel struct { Models string `json:"models"` Group string `json:"group" gorm:"type:varchar(32);default:'default'"` UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` + ModelMapping string `json:"model_mapping" gorm:"type:varchar(1024);default:''"` } func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) { diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index 58b80219..ec04234d 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -1,7 +1,7 @@ import React, { useEffect, useState } from 'react'; import { Button, Form, Header, Message, Segment } from 'semantic-ui-react'; import { useParams } from 'react-router-dom'; -import { API, showError, showInfo, showSuccess } from '../../helpers'; +import { API, showError, showInfo, showSuccess, verifyJSON } from '../../helpers'; import { CHANNEL_OPTIONS } from '../../constants'; const EditChannel = () => { @@ -15,6 +15,7 @@ const EditChannel = () => { key: '', base_url: '', other: '', + model_mapping:'', models: [], groups: ['default'] }; @@ -42,6 +43,9 @@ const EditChannel = () => { } else { data.groups = data.group.split(','); } + if (data.model_mapping !== '') { + data.model_mapping = JSON.stringify(JSON.parse(data.model_mapping), null, 2); + } setInputs(data); } else { showError(message); @@ -94,6 +98,10 @@ const EditChannel = () => { showInfo('请至少选择一个模型!'); return; } + if (inputs.model_mapping !== "" && !verifyJSON(inputs.model_mapping)) { + showInfo('模型映射必须是合法的 JSON 格式!'); + return; + } let localInputs = inputs; if (localInputs.base_url.endsWith('/')) { localInputs.base_url = localInputs.base_url.slice(0, localInputs.base_url.length - 1); @@ -246,6 +254,17 @@ const EditChannel = () => { handleInputChange(null, { name: 'models', value: [] }); }}>清除所有模型 +