diff --git a/web/berry/src/utils/common.js b/web/berry/src/utils/common.js index 947df3bf..626727d4 100644 --- a/web/berry/src/utils/common.js +++ b/web/berry/src/utils/common.js @@ -193,3 +193,29 @@ export function removeTrailingSlash(url) { return url; } } + +let channelModels = undefined; +export async function loadChannelModels() { + const res = await API.get('/api/models'); + const { success, data } = res.data; + if (!success) { + return; + } + channelModels = data; + localStorage.setItem('channel_models', JSON.stringify(data)); +} + +export function getChannelModels(type) { + if (channelModels !== undefined && type in channelModels) { + return channelModels[type]; + } + let models = localStorage.getItem('channel_models'); + if (!models) { + return []; + } + channelModels = JSON.parse(models); + if (type in channelModels) { + return channelModels[type]; + } + return []; +} diff --git a/web/berry/src/views/Channel/component/EditModal.js b/web/berry/src/views/Channel/component/EditModal.js index 03b4df57..fc568d35 100644 --- a/web/berry/src/views/Channel/component/EditModal.js +++ b/web/berry/src/views/Channel/component/EditModal.js @@ -1,9 +1,9 @@ -import PropTypes from "prop-types"; -import { useState, useEffect } from "react"; -import { CHANNEL_OPTIONS } from "constants/ChannelConstants"; -import { useTheme } from "@mui/material/styles"; -import { API } from "utils/api"; -import { showError, showSuccess } from "utils/common"; +import PropTypes from 'prop-types'; +import { useState, useEffect } from 'react'; +import { CHANNEL_OPTIONS } from 'constants/ChannelConstants'; +import { useTheme } from '@mui/material/styles'; +import { API } from 'utils/api'; +import { showError, showSuccess, getChannelModels } from 'utils/common'; import { Dialog, DialogTitle, @@ -22,15 +22,15 @@ import { Autocomplete, FormHelperText, Switch, - Checkbox, -} from "@mui/material"; + Checkbox +} from '@mui/material'; -import { Formik } from "formik"; -import * as Yup from "yup"; -import { defaultConfig, typeConfig } from "../type/Config"; //typeConfig -import { createFilterOptions } from "@mui/material/Autocomplete"; -import CheckBoxOutlineBlankIcon from "@mui/icons-material/CheckBoxOutlineBlank"; -import CheckBoxIcon from "@mui/icons-material/CheckBox"; +import { Formik } from 'formik'; +import * as Yup from 'yup'; +import { defaultConfig, typeConfig } from '../type/Config'; //typeConfig +import { createFilterOptions } from '@mui/material/Autocomplete'; +import CheckBoxOutlineBlankIcon from '@mui/icons-material/CheckBoxOutlineBlank'; +import CheckBoxIcon from '@mui/icons-material/CheckBox'; const icon = ; const checkedIcon = ; @@ -38,38 +38,34 @@ const checkedIcon = ; const filter = createFilterOptions(); const validationSchema = Yup.object().shape({ is_edit: Yup.boolean(), - name: Yup.string().required("名称 不能为空"), - type: Yup.number().required("渠道 不能为空"), - key: Yup.string().when("is_edit", { + name: Yup.string().required('名称 不能为空'), + type: Yup.number().required('渠道 不能为空'), + key: Yup.string().when('is_edit', { is: false, - then: Yup.string().required("密钥 不能为空"), + then: Yup.string().required('密钥 不能为空') }), other: Yup.string(), - models: Yup.array().min(1, "模型 不能为空"), - groups: Yup.array().min(1, "用户组 不能为空"), - base_url: Yup.string().when("type", { + models: Yup.array().min(1, '模型 不能为空'), + groups: Yup.array().min(1, '用户组 不能为空'), + base_url: Yup.string().when('type', { is: (value) => [3, 8].includes(value), - then: Yup.string().required("渠道API地址 不能为空"), // base_url 是必需的 - otherwise: Yup.string(), // 在其他情况下,base_url 可以是任意字符串 + then: Yup.string().required('渠道API地址 不能为空'), // base_url 是必需的 + otherwise: Yup.string() // 在其他情况下,base_url 可以是任意字符串 }), - model_mapping: Yup.string().test( - "is-json", - "必须是有效的JSON字符串", - function (value) { - try { - if (value === "" || value === null || value === undefined) { - return true; - } - const parsedValue = JSON.parse(value); - if (typeof parsedValue === "object") { - return true; - } - } catch (e) { - return false; + model_mapping: Yup.string().test('is-json', '必须是有效的JSON字符串', function (value) { + try { + if (value === '' || value === null || value === undefined) { + return true; } + const parsedValue = JSON.parse(value); + if (typeof parsedValue === 'object') { + return true; + } + } catch (e) { return false; } - ), + return false; + }) }); const EditModal = ({ open, channelId, onCancel, onOk }) => { @@ -81,12 +77,13 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => { const [groupOptions, setGroupOptions] = useState([]); const [modelOptions, setModelOptions] = useState([]); const [batchAdd, setBatchAdd] = useState(false); + const [basicModels, setBasicModels] = useState([]); const initChannel = (typeValue) => { if (typeConfig[typeValue]?.inputLabel) { setInputLabel({ ...defaultConfig.inputLabel, - ...typeConfig[typeValue].inputLabel, + ...typeConfig[typeValue].inputLabel }); } else { setInputLabel(defaultConfig.inputLabel); @@ -95,7 +92,7 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => { if (typeConfig[typeValue]?.prompt) { setInputPrompt({ ...defaultConfig.prompt, - ...typeConfig[typeValue].prompt, + ...typeConfig[typeValue].prompt }); } else { setInputPrompt(defaultConfig.prompt); @@ -104,42 +101,14 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => { return typeConfig[typeValue]?.input; }; const handleTypeChange = (setFieldValue, typeValue, values) => { - const newInput = initChannel(typeValue); - - if (newInput) { - Object.keys(newInput).forEach((key) => { - if ( - (!Array.isArray(values[key]) && - values[key] !== null && - values[key] !== undefined && - values[key] !== "") || - (Array.isArray(values[key]) && values[key].length > 0) - ) { - return; - } - - if (key === "models") { - setFieldValue(key, initialModel(newInput[key])); - return; - } - setFieldValue(key, newInput[key]); - }); + initChannel(typeValue); + let localModels = getChannelModels(typeValue); + setBasicModels(localModels); + if (localModels.length > 0 && Array.isArray(values['models']) && values['models'].length == 0) { + setFieldValue('models', initialModel(localModels)); } }; - const basicModels = (channelType) => { - let modelGroup = - typeConfig[channelType]?.modelGroup || defaultConfig.modelGroup; - // 循环 modelOptions,找到 modelGroup 对应的模型 - let modelList = []; - modelOptions.forEach((model) => { - if (model.group === modelGroup) { - modelList.push(model); - } - }); - return modelList; - }; - const fetchGroups = async () => { try { let res = await API.get(`/api/group/`); @@ -155,7 +124,7 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => { const { data } = res.data; data.forEach((item) => { if (!item.owned_by) { - item.owned_by = "未知"; + item.owned_by = '未知'; } }); // 先对data排序 @@ -171,7 +140,7 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => { data.map((model) => { return { id: model.id, - group: model.owned_by, + group: model.owned_by }; }) ); @@ -182,23 +151,23 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => { const submit = async (values, { setErrors, setStatus, setSubmitting }) => { setSubmitting(true); - if (values.base_url && values.base_url.endsWith("/")) { + if (values.base_url && values.base_url.endsWith('/')) { values.base_url = values.base_url.slice(0, values.base_url.length - 1); } - if (values.type === 3 && values.other === "") { - values.other = "2023-09-01-preview"; + if (values.type === 3 && values.other === '') { + values.other = '2023-09-01-preview'; } - if (values.type === 18 && values.other === "") { - values.other = "v2.1"; + if (values.type === 18 && values.other === '') { + values.other = 'v2.1'; } let res; - const modelsStr = values.models.map((model) => model.id).join(","); - values.group = values.groups.join(","); + const modelsStr = values.models.map((model) => model.id).join(','); + values.group = values.groups.join(','); if (channelId) { res = await API.put(`/api/channel/`, { ...values, id: parseInt(channelId), - models: modelsStr, + models: modelsStr }); } else { res = await API.post(`/api/channel/`, { ...values, models: modelsStr }); @@ -206,9 +175,9 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => { const { success, message } = res.data; if (success) { if (channelId) { - showSuccess("渠道更新成功!"); + showSuccess('渠道更新成功!'); } else { - showSuccess("渠道创建成功!"); + showSuccess('渠道创建成功!'); } setSubmitting(false); setStatus({ success: true }); @@ -226,15 +195,15 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => { } // 如果 channelModel 是一个字符串 - if (typeof channelModel === "string") { - channelModel = channelModel.split(","); + if (typeof channelModel === 'string') { + channelModel = channelModel.split(','); } let modelList = channelModel.map((model) => { const modelOption = modelOptions.find((option) => option.id === model); if (modelOption) { return modelOption; } - return { id: model, group: "自定义:点击或回车输入" }; + return { id: model, group: '自定义:点击或回车输入' }; }); return modelList; } @@ -243,24 +212,20 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => { let res = await API.get(`/api/channel/${channelId}`); const { success, message, data } = res.data; if (success) { - if (data.models === "") { + if (data.models === '') { data.models = []; } else { data.models = initialModel(data.models); } - if (data.group === "") { + if (data.group === '') { data.groups = []; } else { - data.groups = data.group.split(","); + data.groups = data.group.split(','); } - if (data.model_mapping !== "") { - data.model_mapping = JSON.stringify( - JSON.parse(data.model_mapping), - null, - 2 - ); + if (data.model_mapping !== '') { + data.model_mapping = JSON.stringify(JSON.parse(data.model_mapping), null, 2); } - data.base_url = data.base_url ?? ""; + data.base_url = data.base_url ?? ''; data.is_edit = true; initChannel(data.type); setInitialInput(data); @@ -286,45 +251,25 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => { }, [channelId]); return ( - + - {channelId ? "编辑渠道" : "新建渠道"} + {channelId ? '编辑渠道' : '新建渠道'} - - {({ - errors, - handleBlur, - handleChange, - handleSubmit, - isSubmitting, - touched, - values, - setFieldValue, - }) => ( + + {({ errors, handleBlur, handleChange, handleSubmit, isSubmitting, touched, values, setFieldValue }) => (
- - - {inputLabel.type} - + + {inputLabel.type}