diff --git a/web/berry/src/utils/common.js b/web/berry/src/utils/common.js
index 947df3bf..626727d4 100644
--- a/web/berry/src/utils/common.js
+++ b/web/berry/src/utils/common.js
@@ -193,3 +193,29 @@ export function removeTrailingSlash(url) {
return url;
}
}
+
+let channelModels = undefined;
+export async function loadChannelModels() {
+ const res = await API.get('/api/models');
+ const { success, data } = res.data;
+ if (!success) {
+ return;
+ }
+ channelModels = data;
+ localStorage.setItem('channel_models', JSON.stringify(data));
+}
+
+export function getChannelModels(type) {
+ if (channelModels !== undefined && type in channelModels) {
+ return channelModels[type];
+ }
+ let models = localStorage.getItem('channel_models');
+ if (!models) {
+ return [];
+ }
+ channelModels = JSON.parse(models);
+ if (type in channelModels) {
+ return channelModels[type];
+ }
+ return [];
+}
diff --git a/web/berry/src/views/Channel/component/EditModal.js b/web/berry/src/views/Channel/component/EditModal.js
index 03b4df57..fc568d35 100644
--- a/web/berry/src/views/Channel/component/EditModal.js
+++ b/web/berry/src/views/Channel/component/EditModal.js
@@ -1,9 +1,9 @@
-import PropTypes from "prop-types";
-import { useState, useEffect } from "react";
-import { CHANNEL_OPTIONS } from "constants/ChannelConstants";
-import { useTheme } from "@mui/material/styles";
-import { API } from "utils/api";
-import { showError, showSuccess } from "utils/common";
+import PropTypes from 'prop-types';
+import { useState, useEffect } from 'react';
+import { CHANNEL_OPTIONS } from 'constants/ChannelConstants';
+import { useTheme } from '@mui/material/styles';
+import { API } from 'utils/api';
+import { showError, showSuccess, getChannelModels } from 'utils/common';
import {
Dialog,
DialogTitle,
@@ -22,15 +22,15 @@ import {
Autocomplete,
FormHelperText,
Switch,
- Checkbox,
-} from "@mui/material";
+ Checkbox
+} from '@mui/material';
-import { Formik } from "formik";
-import * as Yup from "yup";
-import { defaultConfig, typeConfig } from "../type/Config"; //typeConfig
-import { createFilterOptions } from "@mui/material/Autocomplete";
-import CheckBoxOutlineBlankIcon from "@mui/icons-material/CheckBoxOutlineBlank";
-import CheckBoxIcon from "@mui/icons-material/CheckBox";
+import { Formik } from 'formik';
+import * as Yup from 'yup';
+import { defaultConfig, typeConfig } from '../type/Config'; //typeConfig
+import { createFilterOptions } from '@mui/material/Autocomplete';
+import CheckBoxOutlineBlankIcon from '@mui/icons-material/CheckBoxOutlineBlank';
+import CheckBoxIcon from '@mui/icons-material/CheckBox';
const icon = ;
const checkedIcon = ;
@@ -38,38 +38,34 @@ const checkedIcon = ;
const filter = createFilterOptions();
const validationSchema = Yup.object().shape({
is_edit: Yup.boolean(),
- name: Yup.string().required("名称 不能为空"),
- type: Yup.number().required("渠道 不能为空"),
- key: Yup.string().when("is_edit", {
+ name: Yup.string().required('名称 不能为空'),
+ type: Yup.number().required('渠道 不能为空'),
+ key: Yup.string().when('is_edit', {
is: false,
- then: Yup.string().required("密钥 不能为空"),
+ then: Yup.string().required('密钥 不能为空')
}),
other: Yup.string(),
- models: Yup.array().min(1, "模型 不能为空"),
- groups: Yup.array().min(1, "用户组 不能为空"),
- base_url: Yup.string().when("type", {
+ models: Yup.array().min(1, '模型 不能为空'),
+ groups: Yup.array().min(1, '用户组 不能为空'),
+ base_url: Yup.string().when('type', {
is: (value) => [3, 8].includes(value),
- then: Yup.string().required("渠道API地址 不能为空"), // base_url 是必需的
- otherwise: Yup.string(), // 在其他情况下,base_url 可以是任意字符串
+ then: Yup.string().required('渠道API地址 不能为空'), // base_url 是必需的
+ otherwise: Yup.string() // 在其他情况下,base_url 可以是任意字符串
}),
- model_mapping: Yup.string().test(
- "is-json",
- "必须是有效的JSON字符串",
- function (value) {
- try {
- if (value === "" || value === null || value === undefined) {
- return true;
- }
- const parsedValue = JSON.parse(value);
- if (typeof parsedValue === "object") {
- return true;
- }
- } catch (e) {
- return false;
+ model_mapping: Yup.string().test('is-json', '必须是有效的JSON字符串', function (value) {
+ try {
+ if (value === '' || value === null || value === undefined) {
+ return true;
}
+ const parsedValue = JSON.parse(value);
+ if (typeof parsedValue === 'object') {
+ return true;
+ }
+ } catch (e) {
return false;
}
- ),
+ return false;
+ })
});
const EditModal = ({ open, channelId, onCancel, onOk }) => {
@@ -81,12 +77,13 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => {
const [groupOptions, setGroupOptions] = useState([]);
const [modelOptions, setModelOptions] = useState([]);
const [batchAdd, setBatchAdd] = useState(false);
+ const [basicModels, setBasicModels] = useState([]);
const initChannel = (typeValue) => {
if (typeConfig[typeValue]?.inputLabel) {
setInputLabel({
...defaultConfig.inputLabel,
- ...typeConfig[typeValue].inputLabel,
+ ...typeConfig[typeValue].inputLabel
});
} else {
setInputLabel(defaultConfig.inputLabel);
@@ -95,7 +92,7 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => {
if (typeConfig[typeValue]?.prompt) {
setInputPrompt({
...defaultConfig.prompt,
- ...typeConfig[typeValue].prompt,
+ ...typeConfig[typeValue].prompt
});
} else {
setInputPrompt(defaultConfig.prompt);
@@ -104,42 +101,14 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => {
return typeConfig[typeValue]?.input;
};
const handleTypeChange = (setFieldValue, typeValue, values) => {
- const newInput = initChannel(typeValue);
-
- if (newInput) {
- Object.keys(newInput).forEach((key) => {
- if (
- (!Array.isArray(values[key]) &&
- values[key] !== null &&
- values[key] !== undefined &&
- values[key] !== "") ||
- (Array.isArray(values[key]) && values[key].length > 0)
- ) {
- return;
- }
-
- if (key === "models") {
- setFieldValue(key, initialModel(newInput[key]));
- return;
- }
- setFieldValue(key, newInput[key]);
- });
+ initChannel(typeValue);
+ let localModels = getChannelModels(typeValue);
+ setBasicModels(localModels);
+ if (localModels.length > 0 && Array.isArray(values['models']) && values['models'].length == 0) {
+ setFieldValue('models', initialModel(localModels));
}
};
- const basicModels = (channelType) => {
- let modelGroup =
- typeConfig[channelType]?.modelGroup || defaultConfig.modelGroup;
- // 循环 modelOptions,找到 modelGroup 对应的模型
- let modelList = [];
- modelOptions.forEach((model) => {
- if (model.group === modelGroup) {
- modelList.push(model);
- }
- });
- return modelList;
- };
-
const fetchGroups = async () => {
try {
let res = await API.get(`/api/group/`);
@@ -155,7 +124,7 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => {
const { data } = res.data;
data.forEach((item) => {
if (!item.owned_by) {
- item.owned_by = "未知";
+ item.owned_by = '未知';
}
});
// 先对data排序
@@ -171,7 +140,7 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => {
data.map((model) => {
return {
id: model.id,
- group: model.owned_by,
+ group: model.owned_by
};
})
);
@@ -182,23 +151,23 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => {
const submit = async (values, { setErrors, setStatus, setSubmitting }) => {
setSubmitting(true);
- if (values.base_url && values.base_url.endsWith("/")) {
+ if (values.base_url && values.base_url.endsWith('/')) {
values.base_url = values.base_url.slice(0, values.base_url.length - 1);
}
- if (values.type === 3 && values.other === "") {
- values.other = "2023-09-01-preview";
+ if (values.type === 3 && values.other === '') {
+ values.other = '2023-09-01-preview';
}
- if (values.type === 18 && values.other === "") {
- values.other = "v2.1";
+ if (values.type === 18 && values.other === '') {
+ values.other = 'v2.1';
}
let res;
- const modelsStr = values.models.map((model) => model.id).join(",");
- values.group = values.groups.join(",");
+ const modelsStr = values.models.map((model) => model.id).join(',');
+ values.group = values.groups.join(',');
if (channelId) {
res = await API.put(`/api/channel/`, {
...values,
id: parseInt(channelId),
- models: modelsStr,
+ models: modelsStr
});
} else {
res = await API.post(`/api/channel/`, { ...values, models: modelsStr });
@@ -206,9 +175,9 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => {
const { success, message } = res.data;
if (success) {
if (channelId) {
- showSuccess("渠道更新成功!");
+ showSuccess('渠道更新成功!');
} else {
- showSuccess("渠道创建成功!");
+ showSuccess('渠道创建成功!');
}
setSubmitting(false);
setStatus({ success: true });
@@ -226,15 +195,15 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => {
}
// 如果 channelModel 是一个字符串
- if (typeof channelModel === "string") {
- channelModel = channelModel.split(",");
+ if (typeof channelModel === 'string') {
+ channelModel = channelModel.split(',');
}
let modelList = channelModel.map((model) => {
const modelOption = modelOptions.find((option) => option.id === model);
if (modelOption) {
return modelOption;
}
- return { id: model, group: "自定义:点击或回车输入" };
+ return { id: model, group: '自定义:点击或回车输入' };
});
return modelList;
}
@@ -243,24 +212,20 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => {
let res = await API.get(`/api/channel/${channelId}`);
const { success, message, data } = res.data;
if (success) {
- if (data.models === "") {
+ if (data.models === '') {
data.models = [];
} else {
data.models = initialModel(data.models);
}
- if (data.group === "") {
+ if (data.group === '') {
data.groups = [];
} else {
- data.groups = data.group.split(",");
+ data.groups = data.group.split(',');
}
- if (data.model_mapping !== "") {
- data.model_mapping = JSON.stringify(
- JSON.parse(data.model_mapping),
- null,
- 2
- );
+ if (data.model_mapping !== '') {
+ data.model_mapping = JSON.stringify(JSON.parse(data.model_mapping), null, 2);
}
- data.base_url = data.base_url ?? "";
+ data.base_url = data.base_url ?? '';
data.is_edit = true;
initChannel(data.type);
setInitialInput(data);
@@ -286,45 +251,25 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => {
}, [channelId]);
return (
-