From 9c0a49b97ae4e6a84957dce4175800273d503ba9 Mon Sep 17 00:00:00 2001 From: Martial BE Date: Fri, 29 Dec 2023 16:23:25 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20add=20custom=20test=20model?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- controller/channel-test.go | 29 +++++---------- model/channel.go | 1 + web/src/views/Channel/component/EditModal.js | 26 +++++++++++++- web/src/views/Channel/type/Config.js | 38 +++++++++++++++----- 4 files changed, 63 insertions(+), 31 deletions(-) diff --git a/controller/channel-test.go b/controller/channel-test.go index 15faba24..b291c04c 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -18,6 +18,10 @@ import ( ) func testChannel(channel *model.Channel, request types.ChatCompletionRequest) (err error, openaiErr *types.OpenAIError) { + if channel.TestModel == "" { + return errors.New("请填写测速模型后再试"), nil + } + // 创建一个 http.Request req, err := http.NewRequest("POST", "/v1/chat/completions", nil) if err != nil { @@ -28,26 +32,7 @@ func testChannel(channel *model.Channel, request types.ChatCompletionRequest) (e w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = req - - // 创建映射 - channelTypeToModel := map[int]string{ - common.ChannelTypePaLM: "PaLM-2", - common.ChannelTypeAnthropic: "claude-2", - common.ChannelTypeBaidu: "ERNIE-Bot", - common.ChannelTypeZhipu: "chatglm_lite", - common.ChannelTypeAli: "qwen-turbo", - common.ChannelType360: "360GPT_S2_V9", - common.ChannelTypeXunfei: "SparkDesk", - common.ChannelTypeTencent: "hunyuan", - common.ChannelTypeAzure: "gpt-3.5-turbo", - } - - // 从映射中获取模型名称 - model, ok := channelTypeToModel[channel.Type] - if !ok { - model = "gpt-3.5-turbo" // 默认值 - } - request.Model = model + request.Model = channel.TestModel provider := providers.GetProvider(channel, c) if provider == nil { @@ -69,13 +54,15 @@ func testChannel(channel *model.Channel, request types.ChatCompletionRequest) (e promptTokens := common.CountTokenMessages(request.Messages, request.Model) Usage, openAIErrorWithStatusCode := chatProvider.ChatAction(&request, true, promptTokens) if openAIErrorWithStatusCode != nil { - return nil, &openAIErrorWithStatusCode.OpenAIError + return errors.New(openAIErrorWithStatusCode.Message), &openAIErrorWithStatusCode.OpenAIError } if Usage.CompletionTokens == 0 { return fmt.Errorf("channel %s, message 补全 tokens 非预期返回 0", channel.Name), nil } + common.SysLog(fmt.Sprintf("测试模型 %s 返回内容为:%s", channel.Name, w.Body.String())) + return nil, nil } diff --git a/model/channel.go b/model/channel.go index b8352862..1c6d7121 100644 --- a/model/channel.go +++ b/model/channel.go @@ -26,6 +26,7 @@ type Channel struct { ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"` Priority *int64 `json:"priority" gorm:"bigint;default:0"` Proxy string `json:"proxy" gorm:"type:varchar(255);default:''"` + TestModel string `json:"test_model" gorm:"type:varchar(50);default:''"` } func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) { diff --git a/web/src/views/Channel/component/EditModal.js b/web/src/views/Channel/component/EditModal.js index bef5b5bc..233150d8 100644 --- a/web/src/views/Channel/component/EditModal.js +++ b/web/src/views/Channel/component/EditModal.js @@ -36,6 +36,7 @@ const validationSchema = Yup.object().shape({ key: Yup.string().when('is_edit', { is: false, then: Yup.string().required('密钥 不能为空') }), other: Yup.string(), proxy: Yup.string(), + test_model: Yup.string(), models: Yup.array().min(1, '模型 不能为空'), groups: Yup.array().min(1, '用户组 不能为空'), base_url: Yup.string().when('type', { @@ -90,7 +91,7 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => { if (newInput) { Object.keys(newInput).forEach((key) => { if ( - (!Array.isArray(values[key]) && values[key] !== null && values[key] !== undefined) || + (!Array.isArray(values[key]) && values[key] !== null && values[key] !== undefined && values[key] !== '') || (Array.isArray(values[key]) && values[key].length > 0) ) { return; @@ -464,6 +465,29 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => { {inputPrompt.proxy} )} + {inputPrompt.test_model && ( + + {inputLabel.test_model} + + {touched.test_model && errors.test_model ? ( + + {errors.test_model} + + ) : ( + {inputPrompt.test_model} + )} + + )}