feat: support test specific model (#1600)

This commit is contained in:
Qiying Wang 2024-07-05 18:05:16 +08:00 committed by GitHub
parent 273be55797
commit d7a78f3397
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 77 additions and 29 deletions

View File

@ -14,6 +14,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
@ -27,15 +28,15 @@ import (
"github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/meta"
relaymodel "github.com/songquanpeng/one-api/relay/model" relaymodel "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode" "github.com/songquanpeng/one-api/relay/relaymode"
"github.com/gin-gonic/gin"
) )
func buildTestRequest() *relaymodel.GeneralOpenAIRequest { func buildTestRequest(model string) *relaymodel.GeneralOpenAIRequest {
if model == "" {
model = "gpt-3.5-turbo"
}
testRequest := &relaymodel.GeneralOpenAIRequest{ testRequest := &relaymodel.GeneralOpenAIRequest{
MaxTokens: 2, MaxTokens: 2,
Stream: false, Model: model,
Model: "gpt-3.5-turbo",
} }
testMessage := relaymodel.Message{ testMessage := relaymodel.Message{
Role: "user", Role: "user",
@ -45,7 +46,7 @@ func buildTestRequest() *relaymodel.GeneralOpenAIRequest {
return testRequest return testRequest
} }
func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error) { func testChannel(channel *model.Channel, request *relaymodel.GeneralOpenAIRequest) (err error, openaiErr *relaymodel.Error) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
c.Request = &http.Request{ c.Request = &http.Request{
@ -68,12 +69,8 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error
return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
} }
adaptor.Init(meta) adaptor.Init(meta)
var modelName string modelName := request.Model
modelList := adaptor.GetModelList()
modelMap := channel.GetModelMapping() modelMap := channel.GetModelMapping()
if len(modelList) != 0 {
modelName = modelList[0]
}
if modelName == "" || !strings.Contains(channel.Models, modelName) { if modelName == "" || !strings.Contains(channel.Models, modelName) {
modelNames := strings.Split(channel.Models, ",") modelNames := strings.Split(channel.Models, ",")
if len(modelNames) > 0 { if len(modelNames) > 0 {
@ -83,9 +80,8 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error
modelName = modelMap[modelName] modelName = modelMap[modelName]
} }
} }
request := buildTestRequest() meta.OriginModelName, meta.ActualModelName = request.Model, modelName
request.Model = modelName request.Model = modelName
meta.OriginModelName, meta.ActualModelName = modelName, modelName
convertedRequest, err := adaptor.ConvertRequest(c, relaymode.ChatCompletions, request) convertedRequest, err := adaptor.ConvertRequest(c, relaymode.ChatCompletions, request)
if err != nil { if err != nil {
return err, nil return err, nil
@ -139,10 +135,15 @@ func TestChannel(c *gin.Context) {
}) })
return return
} }
model := c.Query("model")
testRequest := buildTestRequest(model)
tik := time.Now() tik := time.Now()
err, _ = testChannel(channel) err, _ = testChannel(channel, testRequest)
tok := time.Now() tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds() milliseconds := tok.Sub(tik).Milliseconds()
if err != nil {
milliseconds = 0
}
go channel.UpdateResponseTime(milliseconds) go channel.UpdateResponseTime(milliseconds)
consumedTime := float64(milliseconds) / 1000.0 consumedTime := float64(milliseconds) / 1000.0
if err != nil { if err != nil {
@ -150,6 +151,7 @@ func TestChannel(c *gin.Context) {
"success": false, "success": false,
"message": err.Error(), "message": err.Error(),
"time": consumedTime, "time": consumedTime,
"model": model,
}) })
return return
} }
@ -157,6 +159,7 @@ func TestChannel(c *gin.Context) {
"success": true, "success": true,
"message": "", "message": "",
"time": consumedTime, "time": consumedTime,
"model": model,
}) })
return return
} }
@ -187,11 +190,12 @@ func testChannels(notify bool, scope string) error {
for _, channel := range channels { for _, channel := range channels {
isChannelEnabled := channel.Status == model.ChannelStatusEnabled isChannelEnabled := channel.Status == model.ChannelStatusEnabled
tik := time.Now() tik := time.Now()
err, openaiErr := testChannel(channel) testRequest := buildTestRequest("")
err, openaiErr := testChannel(channel, testRequest)
tok := time.Now() tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds() milliseconds := tok.Sub(tik).Milliseconds()
if isChannelEnabled && milliseconds > disableThreshold { if isChannelEnabled && milliseconds > disableThreshold {
err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) err = fmt.Errorf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)
if config.AutomaticDisableChannelEnabled { if config.AutomaticDisableChannelEnabled {
monitor.DisableChannel(channel.Id, channel.Name, err.Error()) monitor.DisableChannel(channel.Id, channel.Name, err.Error())
} else { } else {

View File

@ -1,5 +1,5 @@
import React, { useEffect, useState } from 'react'; import React, { useEffect, useState } from 'react';
import { Button, Form, Input, Label, Message, Pagination, Popup, Table } from 'semantic-ui-react'; import { Button, Dropdown, Form, Input, Label, Message, Pagination, Popup, Table } from 'semantic-ui-react';
import { Link } from 'react-router-dom'; import { Link } from 'react-router-dom';
import { import {
API, API,
@ -70,11 +70,31 @@ const ChannelsTable = () => {
const res = await API.get(`/api/channel/?p=${startIdx}`); const res = await API.get(`/api/channel/?p=${startIdx}`);
const { success, message, data } = res.data; const { success, message, data } = res.data;
if (success) { if (success) {
let localChannels = data.map((channel) => {
if (channel.models === '') {
channel.models = [];
channel.test_model = "";
} else {
channel.models = channel.models.split(',');
if (channel.models.length > 0) {
channel.test_model = channel.models[0];
}
channel.model_options = channel.models.map((model) => {
return {
key: model,
text: model,
value: model,
}
})
console.log('channel', channel)
}
return channel;
});
if (startIdx === 0) { if (startIdx === 0) {
setChannels(data); setChannels(localChannels);
} else { } else {
let newChannels = [...channels]; let newChannels = [...channels];
newChannels.splice(startIdx * ITEMS_PER_PAGE, data.length, ...data); newChannels.splice(startIdx * ITEMS_PER_PAGE, data.length, ...localChannels);
setChannels(newChannels); setChannels(newChannels);
} }
} else { } else {
@ -225,19 +245,31 @@ const ChannelsTable = () => {
setSearching(false); setSearching(false);
}; };
const testChannel = async (id, name, idx) => { const switchTestModel = async (idx, model) => {
const res = await API.get(`/api/channel/test/${id}/`); let newChannels = [...channels];
const { success, message, time } = res.data; let realIdx = (activePage - 1) * ITEMS_PER_PAGE + idx;
newChannels[realIdx].test_model = model;
setChannels(newChannels);
};
const testChannel = async (id, name, idx, m) => {
const res = await API.get(`/api/channel/test/${id}?model=${m}`);
const { success, message, time, model } = res.data;
if (success) { if (success) {
let newChannels = [...channels]; let newChannels = [...channels];
let realIdx = (activePage - 1) * ITEMS_PER_PAGE + idx; let realIdx = (activePage - 1) * ITEMS_PER_PAGE + idx;
newChannels[realIdx].response_time = time * 1000; newChannels[realIdx].response_time = time * 1000;
newChannels[realIdx].test_time = Date.now() / 1000; newChannels[realIdx].test_time = Date.now() / 1000;
setChannels(newChannels); setChannels(newChannels);
showInfo(`渠道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。`); showInfo(`渠道 ${name} 测试成功,模型 ${model}耗时 ${time.toFixed(2)} 秒。`);
} else { } else {
showError(message); showError(message);
} }
let newChannels = [...channels];
let realIdx = (activePage - 1) * ITEMS_PER_PAGE + idx;
newChannels[realIdx].response_time = time * 1000;
newChannels[realIdx].test_time = Date.now() / 1000;
setChannels(newChannels);
}; };
const testChannels = async (scope) => { const testChannels = async (scope) => {
@ -405,6 +437,7 @@ const ChannelsTable = () => {
> >
优先级 优先级
</Table.HeaderCell> </Table.HeaderCell>
<Table.HeaderCell>测试模型</Table.HeaderCell>
<Table.HeaderCell>操作</Table.HeaderCell> <Table.HeaderCell>操作</Table.HeaderCell>
</Table.Row> </Table.Row>
</Table.Header> </Table.Header>
@ -459,13 +492,24 @@ const ChannelsTable = () => {
basic basic
/> />
</Table.Cell> </Table.Cell>
<Table.Cell>
<Dropdown
placeholder='请选择测试模型'
selection
options={channel.model_options}
defaultValue={channel.test_model}
onChange={(event, data) => {
switchTestModel(idx, data.value);
}}
/>
</Table.Cell>
<Table.Cell> <Table.Cell>
<div> <div>
<Button <Button
size={'small'} size={'small'}
positive positive
onClick={() => { onClick={() => {
testChannel(channel.id, channel.name, idx); testChannel(channel.id, channel.name, idx, channel.test_model);
}} }}
> >
测试 测试