From 7ef4a7db593a702d9008e4a861db82bd5c9ad075 Mon Sep 17 00:00:00 2001 From: MartialBE Date: Mon, 1 Jan 2024 14:36:58 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20feat:=20Optimize=20model=20list.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- common/model-ratio.go | 172 ++++++++++-------- controller/model.go | 93 +++++++--- providers/openai/base.go | 1 - router/api-router.go | 1 + web/src/views/Channel/component/EditModal.js | 86 +++++++-- web/src/views/Channel/component/NameLabel.js | 5 +- web/src/views/Channel/type/Config.js | 33 ++-- .../Dashboard/component/SupportModels.js | 69 +++++++ web/src/views/Dashboard/index.js | 5 + 9 files changed, 327 insertions(+), 138 deletions(-) create mode 100644 web/src/views/Dashboard/component/SupportModels.js diff --git a/common/model-ratio.go b/common/model-ratio.go index 3857f3d9..4453a800 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -6,6 +6,100 @@ import ( "time" ) +type ModelType struct { + Ratio float64 + Type int +} + +var ModelTypes map[string]ModelType + +// ModelRatio +// https://platform.openai.com/docs/models/model-endpoint-compatibility +// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Blfmc9dlf +// https://openai.com/pricing +// TODO: when a new api is enabled, check the pricing here +// 1 === $0.002 / 1K tokens +// 1 === ¥0.014 / 1k tokens +var ModelRatio map[string]float64 + +func init() { + ModelTypes = map[string]ModelType{ + "gpt-4": {15, ChannelTypeOpenAI}, + "gpt-4-0314": {15, ChannelTypeOpenAI}, + "gpt-4-0613": {15, ChannelTypeOpenAI}, + "gpt-4-32k": {30, ChannelTypeOpenAI}, + "gpt-4-32k-0314": {30, ChannelTypeOpenAI}, + "gpt-4-32k-0613": {30, ChannelTypeOpenAI}, + "gpt-4-1106-preview": {5, ChannelTypeOpenAI}, // $0.01 / 1K tokens + "gpt-4-vision-preview": {5, ChannelTypeOpenAI}, // $0.01 / 1K tokens + "gpt-3.5-turbo": {0.75, ChannelTypeOpenAI}, // $0.0015 / 1K tokens + "gpt-3.5-turbo-0301": {0.75, ChannelTypeOpenAI}, + "gpt-3.5-turbo-0613": {0.75, ChannelTypeOpenAI}, + "gpt-3.5-turbo-16k": {1.5, ChannelTypeOpenAI}, // $0.003 / 1K tokens + "gpt-3.5-turbo-16k-0613": {1.5, ChannelTypeOpenAI}, + "gpt-3.5-turbo-instruct": {0.75, ChannelTypeOpenAI}, // $0.0015 / 1K tokens + "gpt-3.5-turbo-1106": {0.5, ChannelTypeOpenAI}, // $0.001 / 1K tokens + "text-ada-001": {0.2, ChannelTypeOpenAI}, + "text-babbage-001": {0.25, ChannelTypeOpenAI}, + "text-curie-001": {1, ChannelTypeOpenAI}, + "text-davinci-002": {10, ChannelTypeOpenAI}, + "text-davinci-003": {10, ChannelTypeOpenAI}, + "text-davinci-edit-001": {10, ChannelTypeOpenAI}, + "code-davinci-edit-001": {10, ChannelTypeOpenAI}, + "whisper-1": {15, ChannelTypeOpenAI}, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens + "tts-1": {7.5, ChannelTypeOpenAI}, // $0.015 / 1K characters + "tts-1-1106": {7.5, ChannelTypeOpenAI}, + "tts-1-hd": {15, ChannelTypeOpenAI}, // $0.030 / 1K characters + "tts-1-hd-1106": {15, ChannelTypeOpenAI}, + "davinci": {10, ChannelTypeOpenAI}, + "curie": {10, ChannelTypeOpenAI}, + "babbage": {10, ChannelTypeOpenAI}, + "ada": {10, ChannelTypeOpenAI}, + "text-embedding-ada-002": {0.05, ChannelTypeOpenAI}, + "text-search-ada-doc-001": {10, ChannelTypeOpenAI}, + "text-moderation-stable": {0.1, ChannelTypeOpenAI}, + "text-moderation-latest": {0.1, ChannelTypeOpenAI}, + "dall-e-2": {8, ChannelTypeOpenAI}, // $0.016 - $0.020 / image + "dall-e-3": {20, ChannelTypeOpenAI}, // $0.040 - $0.120 / image + "claude-instant-1": {0.815, ChannelTypeAnthropic}, // $1.63 / 1M tokens + "claude-2": {5.51, ChannelTypeAnthropic}, // $11.02 / 1M tokens + "claude-2.0": {5.51, ChannelTypeAnthropic}, // $11.02 / 1M tokens + "claude-2.1": {5.51, ChannelTypeAnthropic}, // $11.02 / 1M tokens + "ERNIE-Bot": {0.8572, ChannelTypeBaidu}, // ¥0.012 / 1k tokens + "ERNIE-Bot-turbo": {0.5715, ChannelTypeBaidu}, // ¥0.008 / 1k tokens + "ERNIE-Bot-4": {8.572, ChannelTypeBaidu}, // ¥0.12 / 1k tokens + "Embedding-V1": {0.1429, ChannelTypeBaidu}, // ¥0.002 / 1k tokens + "PaLM-2": {1, ChannelTypePaLM}, + "gemini-pro": {1, ChannelTypeGemini}, // $0.00025 / 1k characters -> $0.001 / 1k tokens + "gemini-pro-vision": {1, ChannelTypeGemini}, // $0.00025 / 1k characters -> $0.001 / 1k tokens + "chatglm_turbo": {0.3572, ChannelTypeZhipu}, // ¥0.005 / 1k tokens + "chatglm_pro": {0.7143, ChannelTypeZhipu}, // ¥0.01 / 1k tokens + "chatglm_std": {0.3572, ChannelTypeZhipu}, // ¥0.005 / 1k tokens + "chatglm_lite": {0.1429, ChannelTypeZhipu}, // ¥0.002 / 1k tokens + "qwen-turbo": {0.5715, ChannelTypeAli}, // ¥0.008 / 1k tokens // https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions-metering-and-billing + "qwen-plus": {1.4286, ChannelTypeAli}, // ¥0.02 / 1k tokens + "qwen-max": {1.4286, ChannelTypeAli}, // ¥0.02 / 1k tokens + "qwen-max-longcontext": {1.4286, ChannelTypeAli}, // ¥0.02 / 1k tokens + "qwen-vl-plus": {0.5715, ChannelTypeAli}, // ¥0.008 / 1k tokens + "text-embedding-v1": {0.05, ChannelTypeAli}, // ¥0.0007 / 1k tokens + "SparkDesk": {1.2858, ChannelTypeXunfei}, // ¥0.018 / 1k tokens + "360GPT_S2_V9": {0.8572, ChannelType360}, // ¥0.012 / 1k tokens + "embedding-bert-512-v1": {0.0715, ChannelType360}, // ¥0.001 / 1k tokens + "embedding_s1_v1": {0.0715, ChannelType360}, // ¥0.001 / 1k tokens + "semantic_similarity_s1_v1": {0.0715, ChannelType360}, // ¥0.001 / 1k tokens + "hunyuan": {7.143, ChannelTypeTencent}, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0 + "Baichuan2-Turbo": {0.5715, ChannelTypeBaichuan}, // ¥0.008 / 1k tokens + "Baichuan2-Turbo-192k": {1.143, ChannelTypeBaichuan}, // ¥0.016 / 1k tokens + "Baichuan2-53B": {1.4286, ChannelTypeBaichuan}, // ¥0.02 / 1k tokens + "Baichuan-Text-Embedding": {0.0357, ChannelTypeBaichuan}, // ¥0.0005 / 1k tokens + } + + ModelRatio = make(map[string]float64) + for name, modelType := range ModelTypes { + ModelRatio[name] = modelType.Ratio + } +} + var DalleSizeRatios = map[string]map[string]float64{ "dall-e-2": { "256x256": 1, @@ -29,84 +123,6 @@ var DalleImagePromptLengthLimitations = map[string]int{ "dall-e-3": 4000, } -// ModelRatio -// https://platform.openai.com/docs/models/model-endpoint-compatibility -// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Blfmc9dlf -// https://openai.com/pricing -// TODO: when a new api is enabled, check the pricing here -// 1 === $0.002 / 1K tokens -// 1 === ¥0.014 / 1k tokens -var ModelRatio = map[string]float64{ - "gpt-4": 15, - "gpt-4-0314": 15, - "gpt-4-0613": 15, - "gpt-4-32k": 30, - "gpt-4-32k-0314": 30, - "gpt-4-32k-0613": 30, - "gpt-4-1106-preview": 5, // $0.01 / 1K tokens - "gpt-4-vision-preview": 5, // $0.01 / 1K tokens - "gpt-3.5-turbo": 0.75, // $0.0015 / 1K tokens - "gpt-3.5-turbo-0301": 0.75, - "gpt-3.5-turbo-0613": 0.75, - "gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens - "gpt-3.5-turbo-16k-0613": 1.5, - "gpt-3.5-turbo-instruct": 0.75, // $0.0015 / 1K tokens - "gpt-3.5-turbo-1106": 0.5, // $0.001 / 1K tokens - "text-ada-001": 0.2, - "text-babbage-001": 0.25, - "text-curie-001": 1, - "text-davinci-002": 10, - "text-davinci-003": 10, - "text-davinci-edit-001": 10, - "code-davinci-edit-001": 10, - "whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens - "tts-1": 7.5, // $0.015 / 1K characters - "tts-1-1106": 7.5, - "tts-1-hd": 15, // $0.030 / 1K characters - "tts-1-hd-1106": 15, - "davinci": 10, - "curie": 10, - "babbage": 10, - "ada": 10, - "text-embedding-ada-002": 0.05, - "text-search-ada-doc-001": 10, - "text-moderation-stable": 0.1, - "text-moderation-latest": 0.1, - "dall-e-2": 8, // $0.016 - $0.020 / image - "dall-e-3": 20, // $0.040 - $0.120 / image - "claude-instant-1": 0.815, // $1.63 / 1M tokens - "claude-2": 5.51, // $11.02 / 1M tokens - "claude-2.0": 5.51, // $11.02 / 1M tokens - "claude-2.1": 5.51, // $11.02 / 1M tokens - "ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens - "ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens - "ERNIE-Bot-4": 8.572, // ¥0.12 / 1k tokens - "Embedding-V1": 0.1429, // ¥0.002 / 1k tokens - "PaLM-2": 1, - "gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens - "gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens - "chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens - "chatglm_pro": 0.7143, // ¥0.01 / 1k tokens - "chatglm_std": 0.3572, // ¥0.005 / 1k tokens - "chatglm_lite": 0.1429, // ¥0.002 / 1k tokens - "qwen-turbo": 0.5715, // ¥0.008 / 1k tokens // https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions-metering-and-billing - "qwen-plus": 1.4286, // ¥0.02 / 1k tokens - "qwen-max": 1.4286, // ¥0.02 / 1k tokens - "qwen-max-longcontext": 1.4286, // ¥0.02 / 1k tokens - "qwen-vl-plus": 0.5715, // ¥0.008 / 1k tokens - "text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens - "SparkDesk": 1.2858, // ¥0.018 / 1k tokens - "360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens - "embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens - "embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens - "semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens - "hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0 - "Baichuan2-Turbo": 0.5715, // ¥0.008 / 1k tokens - "Baichuan2-Turbo-192k": 1.143, // ¥0.016 / 1k tokens - "Baichuan2-53B": 1.4286, // ¥0.02 / 1k tokens - "Baichuan-Text-Embedding": 0.0357, // ¥0.0005 / 1k tokens -} - func ModelRatio2JSONString() string { jsonBytes, err := json.Marshal(ModelRatio) if err != nil { diff --git a/controller/model.go b/controller/model.go index 9e023c93..f938d59d 100644 --- a/controller/model.go +++ b/controller/model.go @@ -38,37 +38,35 @@ type OpenAIModels struct { Parent *string `json:"parent"` } -var openAIModels []OpenAIModels -var openAIModelsMap map[string]OpenAIModels +var modelOwnedBy map[int]string func init() { - // https://platform.openai.com/docs/models/model-endpoint-compatibility - keys := make([]string, 0, len(common.ModelRatio)) - for k := range common.ModelRatio { - keys = append(keys, k) - } - sort.Strings(keys) - - for _, modelId := range keys { - openAIModels = append(openAIModels, OpenAIModels{ - Id: modelId, - Object: "model", - Created: 1677649963, - OwnedBy: nil, - Permission: nil, - Root: nil, - Parent: nil, - }) - } - - openAIModelsMap = make(map[string]OpenAIModels) - for _, model := range openAIModels { - openAIModelsMap[model.Id] = model + modelOwnedBy = map[int]string{ + common.ChannelTypeOpenAI: "OpenAI", + common.ChannelTypeAnthropic: "Anthropic", + common.ChannelTypeBaidu: "Baidu", + common.ChannelTypePaLM: "Google PaLM", + common.ChannelTypeGemini: "Google Gemini", + common.ChannelTypeZhipu: "Zhipu", + common.ChannelTypeAli: "Ali", + common.ChannelTypeXunfei: "Xunfei", + common.ChannelType360: "360", + common.ChannelTypeTencent: "Tencent", + common.ChannelTypeBaichuan: "Baichuan", } } func ListModels(c *gin.Context) { groupName := c.GetString("group") + if groupName == "" { + id := c.GetInt("id") + user, err := model.GetUserById(id, false) + if err != nil { + common.AbortWithMessage(c, http.StatusServiceUnavailable, err.Error()) + return + } + groupName = user.Group + } models, err := model.CacheGetGroupModels(groupName) if err != nil { @@ -83,13 +81,18 @@ func ListModels(c *gin.Context) { Id: modelId, Object: "model", Created: 1677649963, - OwnedBy: nil, + OwnedBy: getModelOwnedBy(modelId), Permission: nil, Root: nil, Parent: nil, }) } + // 根据 OwnedBy 排序 + sort.Slice(groupOpenAIModels, func(i, j int) bool { + return *groupOpenAIModels[i].OwnedBy < *groupOpenAIModels[j].OwnedBy + }) + c.JSON(200, gin.H{ "object": "list", "data": groupOpenAIModels, @@ -97,6 +100,23 @@ func ListModels(c *gin.Context) { } func ListModelsForAdmin(c *gin.Context) { + openAIModels := make([]OpenAIModels, 0, len(common.ModelTypes)) + for modelId := range common.ModelRatio { + openAIModels = append(openAIModels, OpenAIModels{ + Id: modelId, + Object: "model", + Created: 1677649963, + OwnedBy: getModelOwnedBy(modelId), + Permission: nil, + Root: nil, + Parent: nil, + }) + } + // 根据 OwnedBy 排序 + sort.Slice(openAIModels, func(i, j int) bool { + return *openAIModels[i].OwnedBy < *openAIModels[j].OwnedBy + }) + c.JSON(200, gin.H{ "object": "list", "data": openAIModels, @@ -105,8 +125,17 @@ func ListModelsForAdmin(c *gin.Context) { func RetrieveModel(c *gin.Context) { modelId := c.Param("model") - if model, ok := openAIModelsMap[modelId]; ok { - c.JSON(200, model) + ownedByName := getModelOwnedBy(modelId) + if ownedByName != nil { + c.JSON(200, OpenAIModels{ + Id: modelId, + Object: "model", + Created: 1677649963, + OwnedBy: ownedByName, + Permission: nil, + Root: nil, + Parent: nil, + }) } else { openAIError := types.OpenAIError{ Message: fmt.Sprintf("The model '%s' does not exist", modelId), @@ -119,3 +148,13 @@ func RetrieveModel(c *gin.Context) { }) } } + +func getModelOwnedBy(modelId string) (ownedBy *string) { + if modelType, ok := common.ModelTypes[modelId]; ok { + if ownedByName, ok := modelOwnedBy[modelType.Type]; ok { + return &ownedByName + } + } + + return +} diff --git a/providers/openai/base.go b/providers/openai/base.go index 7e61ea9f..f37546f7 100644 --- a/providers/openai/base.go +++ b/providers/openai/base.go @@ -75,7 +75,6 @@ func (p *OpenAIProvider) GetFullRequestURL(requestURL string, modelName string) // 检测模型是是否包含 . 如果有则直接去掉 if strings.Contains(requestURL, ".") { requestURL = strings.Replace(requestURL, ".", "", -1) - fmt.Println(requestURL) } } diff --git a/router/api-router.go b/router/api-router.go index 5cb2407e..323abe7b 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -42,6 +42,7 @@ func SetApiRouter(router *gin.Engine) { selfRoute.GET("/token", controller.GenerateAccessToken) selfRoute.GET("/aff", controller.GetAffCode) selfRoute.POST("/topup", controller.TopUp) + selfRoute.GET("/models", controller.ListModels) } adminRoute := userRoute.Group("/") diff --git a/web/src/views/Channel/component/EditModal.js b/web/src/views/Channel/component/EditModal.js index 233150d8..efbb53ce 100644 --- a/web/src/views/Channel/component/EditModal.js +++ b/web/src/views/Channel/component/EditModal.js @@ -68,7 +68,6 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => { const [inputPrompt, setInputPrompt] = useState(defaultConfig.prompt); const [groupOptions, setGroupOptions] = useState([]); const [modelOptions, setModelOptions] = useState([]); - const [basicModels, setBasicModels] = useState([]); const initChannel = (typeValue) => { if (typeConfig[typeValue]?.inputLabel) { @@ -96,11 +95,28 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => { ) { return; } + + if (key === 'models') { + setFieldValue(key, initialModel(newInput[key])); + return; + } setFieldValue(key, newInput[key]); }); } }; + const basicModels = (channelType) => { + let modelGroup = typeConfig[channelType]?.modelGroup || defaultConfig.modelGroup; + // 循环 modelOptions,找到 modelGroup 对应的模型 + let modelList = []; + modelOptions.forEach((model) => { + if (model.group === modelGroup) { + modelList.push(model); + } + }); + return modelList; + }; + const fetchGroups = async () => { try { let res = await API.get(`/api/group/`); @@ -113,13 +129,13 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => { const fetchModels = async () => { try { let res = await API.get(`/api/channel/models`); - setModelOptions(res.data.data.map((model) => model.id)); - setBasicModels( - res.data.data - .filter((model) => { - return model.id.startsWith('gpt-3') || model.id.startsWith('gpt-4'); - }) - .map((model) => model.id) + setModelOptions( + res.data.data.map((model) => { + return { + id: model.id, + group: model.owned_by + }; + }) ); } catch (error) { showError(error.message); @@ -138,12 +154,12 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => { values.other = 'v2.1'; } let res; - values.models = values.models.join(','); + const modelsStr = values.models.map((model) => model.id).join(','); values.group = values.groups.join(','); if (channelId) { - res = await API.put(`/api/channel/`, { ...values, id: parseInt(channelId) }); + res = await API.put(`/api/channel/`, { ...values, id: parseInt(channelId), models: modelsStr }); } else { - res = await API.post(`/api/channel/`, values); + res = await API.post(`/api/channel/`, { ...values, models: modelsStr }); } const { success, message } = res.data; if (success) { @@ -157,11 +173,30 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => { onOk(true); } else { setStatus({ success: false }); - // showError(message); + showError(message); setErrors({ submit: message }); } }; + function initialModel(channelModel) { + if (!channelModel) { + return []; + } + + // 如果 channelModel 是一个字符串 + if (typeof channelModel === 'string') { + channelModel = channelModel.split(','); + } + let modelList = channelModel.map((model) => { + const modelOption = modelOptions.find((option) => option.id === model); + if (modelOption) { + return modelOption; + } + return { id: model, group: '自定义:点击或回车输入' }; + }); + return modelList; + } + const loadChannel = async () => { let res = await API.get(`/api/channel/${channelId}`); const { success, message, data } = res.data; @@ -169,7 +204,7 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => { if (data.models === '') { data.models = []; } else { - data.models = data.models.split(','); + data.models = initialModel(data.models); } if (data.group === '') { data.groups = []; @@ -348,12 +383,12 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => { freeSolo id="channel-models-label" options={modelOptions} - value={Array.isArray(values.models) ? values.models : values.models.split(',')} + value={values.models} onChange={(e, value) => { const event = { target: { name: 'models', - value: value + value: value.map((item) => (typeof item === 'string' ? { id: item, group: '自定义:点击或回车输入' } : item)) } }; handleChange(event); @@ -361,12 +396,25 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => { onBlur={handleBlur} filterSelectedOptions renderInput={(params) => } + groupBy={(option) => option.group} + getOptionLabel={(option) => { + if (typeof option === 'string') { + return option; + } + if (option.inputValue) { + return option.inputValue; + } + return option.id; + }} filterOptions={(options, params) => { const filtered = filter(options, params); const { inputValue } = params; - const isExisting = options.some((option) => inputValue === option); + const isExisting = options.some((option) => inputValue === option.id); if (inputValue !== '' && !isExisting) { - filtered.push(inputValue); + filtered.push({ + id: inputValue, + group: '自定义:点击或回车输入' + }); } return filtered; }} @@ -387,10 +435,10 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => {