feat: support model remapping

JustSong committed 2023-06-27 13:42:45 +08:00
parent 431d505f79
commit 0941e294bf
5 changed files with 57 additions and 12 deletions

View File

@@ -53,6 +53,20 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
             return errorWrapper(errors.New("field instruction is required"), "required_field_missing", http.StatusBadRequest)
         }
     }
+    // map model name
+    modelMapping := c.GetString("model_mapping")
+    isModelMapped := false
+    if modelMapping != "" {
+        modelMap := make(map[string]string)
+        err := json.Unmarshal([]byte(modelMapping), &modelMap)
+        if err != nil {
+            return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
+        }
+        if modelMap[textRequest.Model] != "" {
+            textRequest.Model = modelMap[textRequest.Model]
+            isModelMapped = true
+        }
+    }
     baseURL := common.ChannelBaseURLs[channelType]
     requestURL := c.Request.URL.String()
     if c.GetString("base_url") != "" {
@@ -114,7 +128,17 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
            return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
        }
    }
-    req, err := http.NewRequest(c.Request.Method, fullRequestURL, c.Request.Body)
+    var requestBody io.Reader
+    if isModelMapped {
+        jsonStr, err := json.Marshal(textRequest)
+        if err != nil {
+            return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
+        }
+        requestBody = bytes.NewBuffer(jsonStr)
+    } else {
+        requestBody = c.Request.Body
+    }
+    req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
     if err != nil {
         return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
     }
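
Taken together, the two hunks above remap the requested model name and, only when a remap actually happened, rebuild the outgoing request body from the already-parsed textRequest instead of forwarding c.Request.Body untouched. Below is a minimal standalone sketch of that flow; the mapping, the model names, and the textRequest stand-in type are hypothetical and not values from this commit.

package main

import (
    "bytes"
    "encoding/json"
    "fmt"
    "io"
    "strings"
)

// textRequest is a trimmed-down stand-in for GeneralOpenAIRequest.
type textRequest struct {
    Model  string `json:"model,omitempty"`
    Prompt string `json:"prompt,omitempty"`
}

func main() {
    // Hypothetical per-channel mapping, as it would come out of c.GetString("model_mapping").
    modelMapping := `{"gpt-3.5-turbo": "gpt-35-turbo"}`

    req := textRequest{Model: "gpt-3.5-turbo", Prompt: "hi"}
    clientBody := strings.NewReader(`{"model":"gpt-3.5-turbo","prompt":"hi"}`)

    isModelMapped := false
    if modelMapping != "" {
        modelMap := make(map[string]string)
        if err := json.Unmarshal([]byte(modelMapping), &modelMap); err != nil {
            panic(err) // the real handler returns unmarshal_model_mapping_failed here
        }
        if modelMap[req.Model] != "" {
            req.Model = modelMap[req.Model]
            isModelMapped = true
        }
    }

    // Only rebuild the body when a remap happened; otherwise forward the client body as-is.
    var requestBody io.Reader = clientBody
    if isModelMapped {
        jsonStr, err := json.Marshal(req)
        if err != nil {
            panic(err)
        }
        requestBody = bytes.NewBuffer(jsonStr)
    }

    forwarded, _ := io.ReadAll(requestBody)
    fmt.Println(string(forwarded)) // {"model":"gpt-35-turbo","prompt":"hi"}
}

Because the rebuild is gated on isModelMapped, unmapped requests keep streaming the original client body and pay no extra marshaling cost.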

View File

@@ -27,16 +27,16 @@ const (
 
 // https://platform.openai.com/docs/api-reference/chat
 type GeneralOpenAIRequest struct {
-    Model       string    `json:"model"`
-    Messages    []Message `json:"messages"`
-    Prompt      any       `json:"prompt"`
-    Stream      bool      `json:"stream"`
-    MaxTokens   int       `json:"max_tokens"`
-    Temperature float64   `json:"temperature"`
-    TopP        float64   `json:"top_p"`
-    N           int       `json:"n"`
-    Input       any       `json:"input"`
-    Instruction string    `json:"instruction"`
+    Model       string    `json:"model,omitempty"`
+    Messages    []Message `json:"messages,omitempty"`
+    Prompt      any       `json:"prompt,omitempty"`
+    Stream      bool      `json:"stream,omitempty"`
+    MaxTokens   int       `json:"max_tokens,omitempty"`
+    Temperature float64   `json:"temperature,omitempty"`
+    TopP        float64   `json:"top_p,omitempty"`
+    N           int       `json:"n,omitempty"`
+    Input       any       `json:"input,omitempty"`
+    Instruction string    `json:"instruction,omitempty"`
 }
 
 type ChatRequest struct {
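
The omitempty tags matter because of the re-marshal in the previous file: whenever a model is remapped, textRequest is serialized again and forwarded, and without omitempty every field the client never sent would be emitted as its zero value (empty prompt, max_tokens 0, and so on), which the upstream API may reject or misinterpret. A small illustration with a trimmed-down stand-in for the struct; the field names come from the diff, the values are hypothetical.

package main

import (
    "encoding/json"
    "fmt"
)

// Stand-ins for GeneralOpenAIRequest before and after this commit, trimmed to four fields.
type withoutOmitempty struct {
    Model       string  `json:"model"`
    Prompt      string  `json:"prompt"`
    MaxTokens   int     `json:"max_tokens"`
    Temperature float64 `json:"temperature"`
}

type withOmitempty struct {
    Model       string  `json:"model,omitempty"`
    Prompt      string  `json:"prompt,omitempty"`
    MaxTokens   int     `json:"max_tokens,omitempty"`
    Temperature float64 `json:"temperature,omitempty"`
}

func main() {
    before, _ := json.Marshal(withoutOmitempty{Model: "gpt-4"})
    after, _ := json.Marshal(withOmitempty{Model: "gpt-4"})
    fmt.Println(string(before)) // {"model":"gpt-4","prompt":"","max_tokens":0,"temperature":0}
    fmt.Println(string(after))  // {"model":"gpt-4"}
}

The trade-off of omitempty is that deliberately chosen zero values, such as an explicit temperature of 0, are also dropped from the forwarded body.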

View File

@@ -88,6 +88,7 @@ func Distribute() func(c *gin.Context) {
         c.Set("channel", channel.Type)
         c.Set("channel_id", channel.Id)
         c.Set("channel_name", channel.Name)
+        c.Set("model_mapping", channel.ModelMapping)
         c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
         c.Set("base_url", channel.BaseURL)
         if channel.Type == common.ChannelTypeAzure {

View File

@@ -22,6 +22,7 @@ type Channel struct {
     Models       string `json:"models"`
     Group        string `json:"group" gorm:"type:varchar(32);default:'default'"`
     UsedQuota    int64  `json:"used_quota" gorm:"bigint;default:0"`
+    ModelMapping string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
 }
 
 func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
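
The mapping is persisted as a raw JSON string in a varchar(1024) column and unmarshaled on every relayed request; the only eager validation added by this commit happens client-side through verifyJSON before the form is submitted (see the frontend diff below). For illustration, a rough Go equivalent of that check follows. The helper name is hypothetical and it mirrors the shape relayTextHelper actually expects (a JSON object of string to string) rather than verifyJSON's exact implementation, which is not shown in this commit.

package main

import (
    "encoding/json"
    "fmt"
)

// isValidModelMapping is a hypothetical helper, not part of this commit. It accepts the
// empty string (mapping disabled) and otherwise requires a JSON object of string to string,
// which is what relayTextHelper unmarshals the stored value into.
func isValidModelMapping(s string) bool {
    if s == "" {
        return true
    }
    m := make(map[string]string)
    return json.Unmarshal([]byte(s), &m) == nil
}

func main() {
    fmt.Println(isValidModelMapping(`{"gpt-3.5-turbo": "gpt-35-turbo"}`)) // true
    fmt.Println(isValidModelMapping(`not json`))                          // false
}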

View File

@@ -1,7 +1,7 @@
 import React, { useEffect, useState } from 'react';
 import { Button, Form, Header, Message, Segment } from 'semantic-ui-react';
 import { useParams } from 'react-router-dom';
-import { API, showError, showInfo, showSuccess } from '../../helpers';
+import { API, showError, showInfo, showSuccess, verifyJSON } from '../../helpers';
 import { CHANNEL_OPTIONS } from '../../constants';
 
 const EditChannel = () => {
@@ -15,6 +15,7 @@ const EditChannel = () => {
     key: '',
     base_url: '',
     other: '',
+    model_mapping: '',
     models: [],
     groups: ['default']
   };
@@ -42,6 +43,9 @@ const EditChannel = () => {
       } else {
         data.groups = data.group.split(',');
       }
+      if (data.model_mapping !== '') {
+        data.model_mapping = JSON.stringify(JSON.parse(data.model_mapping), null, 2);
+      }
       setInputs(data);
     } else {
       showError(message);
@@ -94,6 +98,10 @@ const EditChannel = () => {
       showInfo('请至少选择一个模型!');
       return;
     }
+    if (inputs.model_mapping !== "" && !verifyJSON(inputs.model_mapping)) {
+      showInfo('模型映射必须是合法的 JSON 格式!');
+      return;
+    }
     let localInputs = inputs;
     if (localInputs.base_url.endsWith('/')) {
       localInputs.base_url = localInputs.base_url.slice(0, localInputs.base_url.length - 1);
@@ -246,6 +254,17 @@ const EditChannel = () => {
         handleInputChange(null, { name: 'models', value: [] });
       }}>清除所有模型</Button>
     </div>
+    <Form.Field>
+      <Form.TextArea
+        label='模型映射'
+        placeholder={'为一个 JSON 文本,键为用户请求的模型名称,值为要替换的模型名称'}
+        name='model_mapping'
+        onChange={handleInputChange}
+        value={inputs.model_mapping}
+        style={{ minHeight: 100, fontFamily: 'JetBrains Mono, Consolas' }}
+        autoComplete='new-password'
+      />
+    </Form.Field>
     {
       batch ? <Form.Field>
         <Form.TextArea