diff --git a/controller/channel-test.go b/controller/channel-test.go index 5a8e9bda..ea4ab886 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -18,7 +18,7 @@ import ( "github.com/gin-gonic/gin" ) -func testChannel(channel *model.Channel, request types.ChatCompletionRequest) (err error, openaiErr *types.OpenAIError) { +func testChannel(channel *model.Channel, testModel string) (err error, openaiErr *types.OpenAIError) { if channel.TestModel == "" { return errors.New("请填写测速模型后再试"), nil } @@ -33,7 +33,13 @@ func testChannel(channel *model.Channel, request types.ChatCompletionRequest) (e w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = req - request.Model = channel.TestModel + request := buildTestRequest() + + if testModel != "" { + request.Model = testModel + } else { + request.Model = channel.TestModel + } provider := providers.GetProvider(channel, c) if provider == nil { @@ -54,21 +60,15 @@ func testChannel(channel *model.Channel, request types.ChatCompletionRequest) (e chatProvider.SetUsage(&types.Usage{}) - response, openAIErrorWithStatusCode := chatProvider.CreateChatCompletion(&request) + response, openAIErrorWithStatusCode := chatProvider.CreateChatCompletion(request) if openAIErrorWithStatusCode != nil { return errors.New(openAIErrorWithStatusCode.Message), &openAIErrorWithStatusCode.OpenAIError } - usage := chatProvider.GetUsage() - - if usage.CompletionTokens == 0 { - return fmt.Errorf("channel %s, message 补全 tokens 非预期返回 0", channel.Name), nil - } - // 转换为JSON字符串 jsonBytes, _ := json.Marshal(response) - common.SysLog(fmt.Sprintf("测试模型 %s 返回内容为:%s", channel.Name, string(jsonBytes))) + common.SysLog(fmt.Sprintf("测试渠道 %s : %s 返回内容为:%s", channel.Name, request.Model, string(jsonBytes))) return nil, nil } @@ -81,9 +81,9 @@ func buildTestRequest() *types.ChatCompletionRequest { Content: "You just need to output 'hi' next.", }, }, - Model: "", - // MaxTokens: 1, - Stream: false, + Model: "", + MaxTokens: 2, + Stream: false, } return testRequest } @@ -105,9 +105,9 @@ func TestChannel(c *gin.Context) { }) return } - testRequest := buildTestRequest() + testModel := c.Query("model") tik := time.Now() - err, _ = testChannel(channel, *testRequest) + err, _ = testChannel(channel, testModel) tok := time.Now() milliseconds := tok.Sub(tik).Milliseconds() go channel.UpdateResponseTime(milliseconds) @@ -163,7 +163,6 @@ func testAllChannels(notify bool) error { if err != nil { return err } - testRequest := buildTestRequest() var disableThreshold = int64(common.ChannelDisableThreshold * 1000) if disableThreshold == 0 { disableThreshold = 10000000 // a impossible value @@ -172,7 +171,7 @@ func testAllChannels(notify bool) error { for _, channel := range channels { isChannelEnabled := channel.Status == common.ChannelStatusEnabled tik := time.Now() - err, openaiErr := testChannel(channel, *testRequest) + err, openaiErr := testChannel(channel, "") tok := time.Now() milliseconds := tok.Sub(tik).Milliseconds() if milliseconds > disableThreshold { diff --git a/web/src/views/Channel/component/TableRow.js b/web/src/views/Channel/component/TableRow.js index 06d0881a..cb746fea 100644 --- a/web/src/views/Channel/component/TableRow.js +++ b/web/src/views/Channel/component/TableRow.js @@ -22,6 +22,8 @@ import { Collapse, Typography, TextField, + Stack, + Menu, Box } from '@mui/material'; @@ -32,12 +34,51 @@ import ResponseTimeLabel from './ResponseTimeLabel'; import GroupLabel from './GroupLabel'; import { IconDotsVertical, IconEdit, IconTrash, IconCopy, IconWorldWww } from '@tabler/icons-react'; +import { styled, alpha } from '@mui/material/styles'; import KeyboardArrowDownIcon from '@mui/icons-material/KeyboardArrowDown'; import KeyboardArrowUpIcon from '@mui/icons-material/KeyboardArrowUp'; import { copy } from 'utils/common'; +const StyledMenu = styled((props) => ( + +))(({ theme }) => ({ + '& .MuiPaper-root': { + borderRadius: 6, + marginTop: theme.spacing(1), + minWidth: 180, + color: theme.palette.mode === 'light' ? 'rgb(55, 65, 81)' : theme.palette.grey[300], + boxShadow: + 'rgb(255, 255, 255) 0px 0px 0px 0px, rgba(0, 0, 0, 0.05) 0px 0px 0px 1px, rgba(0, 0, 0, 0.1) 0px 10px 15px -3px, rgba(0, 0, 0, 0.05) 0px 4px 6px -2px', + '& .MuiMenu-list': { + padding: '4px 0' + }, + '& .MuiMenuItem-root': { + '& .MuiSvgIcon-root': { + fontSize: 18, + color: theme.palette.text.secondary, + marginRight: theme.spacing(1.5) + }, + '&:active': { + backgroundColor: alpha(theme.palette.primary.main, theme.palette.action.selectedOpacity) + } + } + } +})); + export default function ChannelTableRow({ item, manageChannel, handleOpenModal, setModalChannelId }) { const [open, setOpen] = useState(null); + const [openTest, setOpenTest] = useState(false); const [openDelete, setOpenDelete] = useState(false); const [statusSwitch, setStatusSwitch] = useState(item.status); const [priorityValve, setPriority] = useState(item.priority); @@ -63,6 +104,10 @@ export default function ChannelTableRow({ item, manageChannel, handleOpenModal, setOpen(event.currentTarget); }; + const handleTestModel = (event) => { + setOpenTest(event.currentTarget); + }; + const handleCloseMenu = () => { setOpen(null); }; @@ -105,11 +150,21 @@ export default function ChannelTableRow({ item, manageChannel, handleOpenModal, setWeight(currentValue); }; - const handleResponseTime = async () => { - const { success, time } = await manageChannel(item.id, 'test', ''); + const handleResponseTime = async (modelName) => { + setOpenTest(null); + + if (typeof modelName !== 'string') { + modelName = item.test_model; + } + + if (modelName == '') { + showError('请先设置测试模型'); + return; + } + const { success, time } = await manageChannel(item.id, 'test', modelName); if (success) { setResponseTimeData({ test_time: Date.now() / 1000, response_time: time * 1000 }); - showInfo(`通道 ${item.name} 测试成功,耗时 ${time.toFixed(2)} 秒。`); + showInfo(`通道 ${item.name}: ${modelName} 测试成功,耗时 ${time.toFixed(2)} 秒。`); } }; @@ -202,9 +257,24 @@ export default function ChannelTableRow({ item, manageChannel, handleOpenModal, - - - + + + + + + + @@ -256,6 +326,28 @@ export default function ChannelTableRow({ item, manageChannel, handleOpenModal, + { + setOpenTest(null); + }} + > + {modelMap.map((model) => ( + { + handleResponseTime(model); + }} + > + {model} + + ))} + diff --git a/web/src/views/Channel/index.js b/web/src/views/Channel/index.js index dadebec4..ac150e5c 100644 --- a/web/src/views/Channel/index.js +++ b/web/src/views/Channel/index.js @@ -158,7 +158,9 @@ export default function ChannelPage() { }); break; case 'test': - res = await API.get(url + `test/${id}`); + res = await API.get(url + `test/${id}`, { + params: { model: value } + }); break; } const { success, message } = res.data; @@ -377,16 +379,16 @@ export default function ChannelPage() { orderBy={orderBy} onRequestSort={handleSort} headLabel={[ - { id: 'collapse', label: '', disableSort: true }, - { id: 'id', label: 'ID', disableSort: false }, + { id: 'collapse', label: '', disableSort: true, width: '50px' }, + { id: 'id', label: 'ID', disableSort: false, width: '80px' }, { id: 'name', label: '名称', disableSort: false }, { id: 'group', label: '分组', disableSort: true }, { id: 'type', label: '类型', disableSort: false }, { id: 'status', label: '状态', disableSort: false }, { id: 'response_time', label: '响应时间', disableSort: false }, { id: 'balance', label: '余额', disableSort: false }, - { id: 'priority', label: '优先级', disableSort: false }, - { id: 'weight', label: '权重', disableSort: false }, + { id: 'priority', label: '优先级', disableSort: false, width: '80px' }, + { id: 'weight', label: '权重', disableSort: false, width: '80px' }, { id: 'action', label: '操作', disableSort: true } ]} />