From 7263582b9bb914a9e58a67261966d8c8e03f52c6 Mon Sep 17 00:00:00 2001
From: Buer <42402987+MartialBE@users.noreply.github.com>
Date: Thu, 16 May 2024 15:21:13 +0800
Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20feat:=20Add=20support=20for=20retri?=
=?UTF-8?q?eving=20model=20list=20from=20providers=20(#188)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
* ✨ feat: Add support for retrieving model list from providers
* 🔖 chore: Custom channel automatically get the model
---
controller/channel-model.go | 58 +++++++
providers/base/common.go | 1 +
providers/base/interface.go | 10 ++
providers/cohere/base.go | 1 +
providers/cohere/model.go | 35 ++++
providers/cohere/type.go | 164 ++++++++++++++++---
providers/deepseek/base.go | 1 +
providers/gemini/base.go | 1 +
providers/gemini/model.go | 45 +++++
providers/gemini/type.go | 9 +
providers/groq/base.go | 1 +
providers/mistral/base.go | 1 +
providers/mistral/model.go | 29 ++++
providers/mistral/type.go | 12 ++
providers/openai/base.go | 1 +
providers/openai/model.go | 29 ++++
providers/openai/type.go | 12 ++
router/api-router.go | 1 +
web/src/views/Channel/component/EditModal.js | 33 +++-
web/src/views/Channel/type/Config.js | 31 +++-
20 files changed, 444 insertions(+), 31 deletions(-)
create mode 100644 controller/channel-model.go
create mode 100644 providers/cohere/model.go
create mode 100644 providers/gemini/model.go
create mode 100644 providers/mistral/model.go
create mode 100644 providers/openai/model.go
diff --git a/controller/channel-model.go b/controller/channel-model.go
new file mode 100644
index 00000000..8583a891
--- /dev/null
+++ b/controller/channel-model.go
@@ -0,0 +1,58 @@
+package controller
+
+import (
+ "net/http"
+ "one-api/model"
+ "one-api/providers"
+ providersBase "one-api/providers/base"
+ "strings"
+
+ "github.com/gin-gonic/gin"
+)
+
+func GetModelList(c *gin.Context) {
+ channel := model.Channel{}
+ err := c.ShouldBindJSON(&channel)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ keys := strings.Split(channel.Key, "\n")
+ channel.Key = keys[0]
+
+ provider := providers.GetProvider(&channel, c)
+ if provider == nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "provider not found",
+ })
+ return
+ }
+
+ modelProvider, ok := provider.(providersBase.ModelListInterface)
+ if !ok {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "channel not implemented",
+ })
+ return
+ }
+
+ modelList, err := modelProvider.GetModelList()
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": modelList,
+ })
+}
diff --git a/providers/base/common.go b/providers/base/common.go
index 1c486ab3..37429880 100644
--- a/providers/base/common.go
+++ b/providers/base/common.go
@@ -25,6 +25,7 @@ type ProviderConfig struct {
ImagesGenerations string
ImagesEdit string
ImagesVariations string
+ ModelList string
}
type BaseProvider struct {
diff --git a/providers/base/interface.go b/providers/base/interface.go
index 5618c154..32f6a0be 100644
--- a/providers/base/interface.go
+++ b/providers/base/interface.go
@@ -99,6 +99,16 @@ type ImageVariationsInterface interface {
CreateImageVariations(request *types.ImageEditRequest) (*types.ImageResponse, *types.OpenAIErrorWithStatusCode)
}
+// type RelayInterface interface {
+// ProviderInterface
+// CreateRelay() (*http.Response, *types.OpenAIErrorWithStatusCode)
+// }
+
+type ModelListInterface interface {
+ ProviderInterface
+ GetModelList() ([]string, error)
+}
+
// 余额接口
type BalanceInterface interface {
Balance() (float64, error)
diff --git a/providers/cohere/base.go b/providers/cohere/base.go
index 352657a5..d76f8417 100644
--- a/providers/cohere/base.go
+++ b/providers/cohere/base.go
@@ -32,6 +32,7 @@ func getConfig() base.ProviderConfig {
return base.ProviderConfig{
BaseURL: "https://api.cohere.ai/v1",
ChatCompletions: "/chat",
+ ModelList: "/models",
}
}
diff --git a/providers/cohere/model.go b/providers/cohere/model.go
new file mode 100644
index 00000000..ceaa9a2a
--- /dev/null
+++ b/providers/cohere/model.go
@@ -0,0 +1,35 @@
+package cohere
+
+import (
+ "errors"
+ "net/http"
+ "net/url"
+)
+
+func (p *CohereProvider) GetModelList() ([]string, error) {
+ params := url.Values{}
+ params.Add("page_size", "1000")
+ params.Add("endpoint", "chat")
+ queryString := params.Encode()
+
+ fullRequestURL := p.GetFullRequestURL(p.Config.ModelList) + "?" + queryString
+ headers := p.GetRequestHeaders()
+
+ req, err := p.Requester.NewRequest(http.MethodGet, fullRequestURL, p.Requester.WithHeader(headers))
+ if err != nil {
+ return nil, errors.New("new_request_failed")
+ }
+
+ response := &ModelListResponse{}
+ _, errWithCode := p.Requester.SendRequest(req, response, false)
+ if errWithCode != nil {
+ return nil, errors.New(errWithCode.Message)
+ }
+
+ var modelList []string
+ for _, model := range response.Models {
+ modelList = append(modelList, model.Name)
+ }
+
+ return modelList, nil
+}
diff --git a/providers/cohere/type.go b/providers/cohere/type.go
index c8fa9852..74d7c870 100644
--- a/providers/cohere/type.go
+++ b/providers/cohere/type.go
@@ -15,28 +15,34 @@ type CohereConnector struct {
}
type CohereRequest struct {
- Message string `json:"message"`
- Model string `json:"model,omitempty"`
- Stream bool `json:"stream,omitempty"`
- Preamble string `json:"preamble,omitempty"`
- ChatHistory []ChatHistory `json:"chat_history,omitempty"`
- ConversationId string `json:"conversation_id,omitempty"`
- PromptTruncation string `json:"prompt_truncation,omitempty"`
- Connectors []CohereConnector `json:"connectors,omitempty"`
- Temperature float64 `json:"temperature,omitempty"`
- MaxTokens int `json:"max_tokens,omitempty"`
- MaxInputTokens int `json:"max_input_tokens,omitempty"`
- K int `json:"k,omitempty"`
- P float64 `json:"p,omitempty"`
- Seed *int `json:"seed,omitempty"`
- StopSequences any `json:"stop_sequences,omitempty"`
- FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
- PresencePenalty float64 `json:"presence_penalty,omitempty"`
- Tools []*types.ChatCompletionFunction `json:"tools,omitempty"`
- ToolResults any `json:"tool_results,omitempty"`
- // SearchQueriesOnly bool `json:"search_queries_only,omitempty"`
+ Message string `json:"message"`
+ Model string `json:"model,omitempty"`
+ Stream bool `json:"stream,omitempty"`
+ Preamble string `json:"preamble,omitempty"`
+ ChatHistory []ChatHistory `json:"chat_history,omitempty"`
+ ConversationId string `json:"conversation_id,omitempty"`
+ PromptTruncation string `json:"prompt_truncation,omitempty"`
+ Connectors []CohereConnector `json:"connectors,omitempty"`
+ Temperature float64 `json:"temperature,omitempty"`
+ MaxTokens int `json:"max_tokens,omitempty"`
+ MaxInputTokens int `json:"max_input_tokens,omitempty"`
+ K int `json:"k,omitempty"`
+ P float64 `json:"p,omitempty"`
+ Seed *int `json:"seed,omitempty"`
+ StopSequences any `json:"stop_sequences,omitempty"`
+ FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
+ PresencePenalty float64 `json:"presence_penalty,omitempty"`
+ Tools []*types.ChatCompletionFunction `json:"tools,omitempty"`
+ ToolResults any `json:"tool_results,omitempty"`
+ SearchQueriesOnly *bool `json:"search_queries_only,omitempty"`
+ Documents []ChatDocument `json:"documents,omitempty"`
+ CitationQuality *string `json:"citation_quality,omitempty"`
+ RawPrompting *bool `json:"raw_prompting,omitempty"`
+ ReturnPrompt *bool `json:"return_prompt,omitempty"`
}
+type ChatDocument = map[string]string
+
type APIVersion struct {
Version string `json:"version"`
}
@@ -60,16 +66,46 @@ type CohereToolCall struct {
}
type CohereResponse struct {
- Text string `json:"text,omitempty"`
- ResponseID string `json:"response_id,omitempty"`
- GenerationID string `json:"generation_id,omitempty"`
- ChatHistory []ChatHistory `json:"chat_history,omitempty"`
- FinishReason string `json:"finish_reason,omitempty"`
- ToolCalls []CohereToolCall `json:"tool_calls,omitempty"`
- Meta Meta `json:"meta,omitempty"`
+ Text string `json:"text,omitempty"`
+ ResponseID string `json:"response_id,omitempty"`
+ Citations []*ChatCitation `json:"citations,omitempty"`
+ Documents []ChatDocument `json:"documents,omitempty"`
+ IsSearchRequired *bool `json:"is_search_required,omitempty"`
+ SearchQueries []*ChatSearchQuery `json:"search_queries,omitempty"`
+ SearchResults []*ChatSearchResult `json:"search_results,omitempty"`
+ GenerationID string `json:"generation_id,omitempty"`
+ ChatHistory []ChatHistory `json:"chat_history,omitempty"`
+ Prompt *string `json:"prompt,omitempty"`
+ FinishReason string `json:"finish_reason,omitempty"`
+ ToolCalls []CohereToolCall `json:"tool_calls,omitempty"`
+ Meta Meta `json:"meta,omitempty"`
CohereError
}
+type ChatCitation struct {
+ Start int `json:"start"`
+ End int `json:"end"`
+ Text string `json:"text"`
+ DocumentIds []string `json:"document_ids,omitempty"`
+}
+
+type ChatSearchQuery struct {
+ Text string `json:"text"`
+ GenerationId string `json:"generation_id"`
+}
+
+type ChatSearchResult struct {
+ SearchQuery *ChatSearchQuery `json:"search_query,omitempty" url:"search_query,omitempty"`
+ Connector *ChatSearchResultConnector `json:"connector,omitempty" url:"connector,omitempty"`
+ DocumentIds []string `json:"document_ids,omitempty" url:"document_ids,omitempty"`
+ ErrorMessage *string `json:"error_message,omitempty" url:"error_message,omitempty"`
+ ContinueOnFailure *bool `json:"continue_on_failure,omitempty" url:"continue_on_failure,omitempty"`
+}
+
+type ChatSearchResultConnector struct {
+ Id string `json:"id" url:"id"`
+}
+
type CohereError struct {
Message string `json:"message,omitempty"`
}
@@ -83,3 +119,77 @@ type CohereStreamResponse struct {
FinishReason string `json:"finish_reason,omitempty"`
ToolCalls []CohereToolCall `json:"tool_calls,omitempty"`
}
+
+type RerankRequest struct {
+ Model *string `json:"model,omitempty"`
+ Query string `json:"query" url:"query"`
+ Documents []*RerankRequestDocumentsItem `json:"documents,omitempty"`
+ TopN *int `json:"top_n,omitempty"`
+ RankFields []string `json:"rank_fields,omitempty"`
+ ReturnDocuments *bool `json:"return_documents,omitempty"`
+ MaxChunksPerDoc *int `json:"max_chunks_per_doc,omitempty"`
+}
+
+type RerankRequestDocumentsItem struct {
+ String string
+ RerankRequestDocumentsItemText *RerankDocumentsItemText
+}
+type RerankDocumentsItemText struct {
+ Text string `json:"text"`
+}
+
+type RerankResponse struct {
+ Id *string `json:"id,omitempty"`
+ Results []*RerankResponseResultsItem `json:"results,omitempty"`
+ Meta *Meta `json:"meta,omitempty"`
+}
+
+type RerankResponseResultsItem struct {
+ Document *RerankDocumentsItemText `json:"document,omitempty"`
+ Index int `json:"index"`
+ RelevanceScore float64 `json:"relevance_score"`
+}
+
+type EmbedRequest struct {
+ Texts any `json:"texts,omitempty"`
+ Model *string `json:"model,omitempty"`
+ InputType *string `json:"input_type,omitempty"`
+ EmbeddingTypes []string `json:"embedding_types,omitempty"`
+ Truncate *string `json:"truncate,omitempty"`
+}
+
+type EmbedResponse struct {
+ ResponseType string `json:"response_type"`
+ Embeddings any `json:"embeddings"`
+}
+
+type EmbedFloatsResponse struct {
+ Id string `json:"id"`
+ Embeddings [][]float64 `json:"embeddings,omitempty"`
+ Texts []string `json:"texts,omitempty"`
+ Meta *Meta `json:"meta,omitempty"`
+}
+
+type EmbedByTypeResponse struct {
+ Id string `json:"id"`
+ Embeddings *EmbedByTypeResponseEmbeddings `json:"embeddings,omitempty"`
+ Texts []string `json:"texts,omitempty"`
+ Meta *Meta `json:"meta,omitempty"`
+}
+
+type EmbedByTypeResponseEmbeddings struct {
+ Float [][]float64 `json:"float,omitempty"`
+ Int8 [][]int `json:"int8,omitempty"`
+ Uint8 [][]int `json:"uint8,omitempty"`
+ Binary [][]int `json:"binary,omitempty"`
+ Ubinary [][]int `json:"ubinary,omitempty"`
+}
+
+type ModelListResponse struct {
+ Models []ModelDetails `json:"models"`
+}
+
+type ModelDetails struct {
+ Name string `json:"name"`
+ Endpoints []string `json:"endpoints"`
+}
diff --git a/providers/deepseek/base.go b/providers/deepseek/base.go
index 5f0936d7..a4cdefef 100644
--- a/providers/deepseek/base.go
+++ b/providers/deepseek/base.go
@@ -28,6 +28,7 @@ func getDeepseekConfig() base.ProviderConfig {
return base.ProviderConfig{
BaseURL: "https://api.deepseek.com",
ChatCompletions: "/v1/chat/completions",
+ ModelList: "/v1/models",
}
}
diff --git a/providers/gemini/base.go b/providers/gemini/base.go
index 3b2b4fb6..364f9f80 100644
--- a/providers/gemini/base.go
+++ b/providers/gemini/base.go
@@ -32,6 +32,7 @@ func getConfig() base.ProviderConfig {
return base.ProviderConfig{
BaseURL: "https://generativelanguage.googleapis.com",
ChatCompletions: "/",
+ ModelList: "/models",
}
}
diff --git a/providers/gemini/model.go b/providers/gemini/model.go
new file mode 100644
index 00000000..2a24d73a
--- /dev/null
+++ b/providers/gemini/model.go
@@ -0,0 +1,45 @@
+package gemini
+
+import (
+ "errors"
+ "fmt"
+ "net/http"
+ "net/url"
+ "strings"
+)
+
+func (p *GeminiProvider) GetModelList() ([]string, error) {
+ params := url.Values{}
+ params.Add("page_size", "1000")
+ queryString := params.Encode()
+
+ baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")
+ version := "v1beta"
+ fullRequestURL := fmt.Sprintf("%s/%s%s?%s", baseURL, version, p.Config.ModelList, queryString)
+
+ headers := p.GetRequestHeaders()
+
+ req, err := p.Requester.NewRequest(http.MethodGet, fullRequestURL, p.Requester.WithHeader(headers))
+ if err != nil {
+ return nil, errors.New("new_request_failed")
+ }
+
+ response := &ModelListResponse{}
+ _, errWithCode := p.Requester.SendRequest(req, response, false)
+ if errWithCode != nil {
+ return nil, errors.New(errWithCode.Message)
+ }
+
+ var modelList []string
+ for _, model := range response.Models {
+ for _, modelType := range model.SupportedGenerationMethods {
+ if modelType == "generateContent" {
+ modelName := strings.TrimPrefix(model.Name, "models/")
+ modelList = append(modelList, modelName)
+ break
+ }
+ }
+ }
+
+ return modelList, nil
+}
diff --git a/providers/gemini/type.go b/providers/gemini/type.go
index d653dccf..65c7c9ce 100644
--- a/providers/gemini/type.go
+++ b/providers/gemini/type.go
@@ -218,3 +218,12 @@ func ConvertRole(roleName string) string {
return types.ChatMessageRoleUser
}
}
+
+type ModelListResponse struct {
+ Models []ModelDetails `json:"models"`
+}
+
+type ModelDetails struct {
+ Name string `json:"name"`
+ SupportedGenerationMethods []string `json:"supportedGenerationMethods"`
+}
diff --git a/providers/groq/base.go b/providers/groq/base.go
index 0152b8fd..c72a302c 100644
--- a/providers/groq/base.go
+++ b/providers/groq/base.go
@@ -27,6 +27,7 @@ func getConfig() base.ProviderConfig {
return base.ProviderConfig{
BaseURL: "https://api.groq.com/openai",
ChatCompletions: "/v1/chat/completions",
+ ModelList: "/v1/models",
}
}
diff --git a/providers/mistral/base.go b/providers/mistral/base.go
index 66a8c234..f69b3a55 100644
--- a/providers/mistral/base.go
+++ b/providers/mistral/base.go
@@ -41,6 +41,7 @@ func getMistralConfig(baseURL string) base.ProviderConfig {
BaseURL: baseURL,
ChatCompletions: "/v1/chat/completions",
Embeddings: "/v1/embeddings",
+ ModelList: "/v1/models",
}
}
diff --git a/providers/mistral/model.go b/providers/mistral/model.go
new file mode 100644
index 00000000..69204814
--- /dev/null
+++ b/providers/mistral/model.go
@@ -0,0 +1,29 @@
+package mistral
+
+import (
+ "errors"
+ "net/http"
+)
+
+func (p *MistralProvider) GetModelList() ([]string, error) {
+ fullRequestURL := p.GetFullRequestURL(p.Config.ModelList, "")
+ headers := p.GetRequestHeaders()
+
+ req, err := p.Requester.NewRequest(http.MethodGet, fullRequestURL, p.Requester.WithHeader(headers))
+ if err != nil {
+ return nil, errors.New("new_request_failed")
+ }
+
+ response := &ModelListResponse{}
+ _, errWithCode := p.Requester.SendRequest(req, response, false)
+ if errWithCode != nil {
+ return nil, errors.New(errWithCode.Message)
+ }
+
+ var modelList []string
+ for _, model := range response.Data {
+ modelList = append(modelList, model.Id)
+ }
+
+ return modelList, nil
+}
diff --git a/providers/mistral/type.go b/providers/mistral/type.go
index 17fff033..9b2b01ef 100644
--- a/providers/mistral/type.go
+++ b/providers/mistral/type.go
@@ -53,3 +53,15 @@ type ChatCompletionStreamResponse struct {
types.ChatCompletionStreamResponse
Usage *types.Usage `json:"usage,omitempty"`
}
+
+type ModelListResponse struct {
+ Object string `json:"object"`
+ Data []ModelDetails `json:"data"`
+}
+
+type ModelDetails struct {
+ Id string `json:"id"`
+ Object string `json:"object"`
+ Created int64 `json:"created"`
+ OwnedBy string `json:"owned_by"`
+}
diff --git a/providers/openai/base.go b/providers/openai/base.go
index 9121af47..0a3eebbc 100644
--- a/providers/openai/base.go
+++ b/providers/openai/base.go
@@ -57,6 +57,7 @@ func getOpenAIConfig(baseURL string) base.ProviderConfig {
ImagesGenerations: "/v1/images/generations",
ImagesEdit: "/v1/images/edits",
ImagesVariations: "/v1/images/variations",
+ ModelList: "/v1/models",
}
}
diff --git a/providers/openai/model.go b/providers/openai/model.go
new file mode 100644
index 00000000..f6b4ade0
--- /dev/null
+++ b/providers/openai/model.go
@@ -0,0 +1,29 @@
+package openai
+
+import (
+ "errors"
+ "net/http"
+)
+
+func (p *OpenAIProvider) GetModelList() ([]string, error) {
+ fullRequestURL := p.GetFullRequestURL(p.Config.ModelList, "")
+ headers := p.GetRequestHeaders()
+
+ req, err := p.Requester.NewRequest(http.MethodGet, fullRequestURL, p.Requester.WithHeader(headers))
+ if err != nil {
+ return nil, errors.New("new_request_failed")
+ }
+
+ response := &ModelListResponse{}
+ _, errWithCode := p.Requester.SendRequest(req, response, false)
+ if errWithCode != nil {
+ return nil, errors.New(errWithCode.Message)
+ }
+
+ var modelList []string
+ for _, model := range response.Data {
+ modelList = append(modelList, model.Id)
+ }
+
+ return modelList, nil
+}
diff --git a/providers/openai/type.go b/providers/openai/type.go
index 0670dca9..ce82ba5f 100644
--- a/providers/openai/type.go
+++ b/providers/openai/type.go
@@ -73,3 +73,15 @@ type OpenAIUsageResponse struct {
//DailyCosts []OpenAIUsageDailyCost `json:"daily_costs"`
TotalUsage float64 `json:"total_usage"` // unit: 0.01 dollar
}
+
+type ModelListResponse struct {
+ Object string `json:"object"`
+ Data []ModelDetails `json:"data"`
+}
+
+type ModelDetails struct {
+ Id string `json:"id"`
+ Object string `json:"object"`
+ Created int64 `json:"created"`
+ OwnedBy string `json:"owned_by"`
+}
diff --git a/router/api-router.go b/router/api-router.go
index dcc8dc08..4bc0e3d8 100644
--- a/router/api-router.go
+++ b/router/api-router.go
@@ -78,6 +78,7 @@ func SetApiRouter(router *gin.Engine) {
{
channelRoute.GET("/", controller.GetChannelsList)
channelRoute.GET("/models", relay.ListModelsForAdmin)
+ channelRoute.POST("/provider_models_list", controller.GetModelList)
channelRoute.GET("/:id", controller.GetChannel)
channelRoute.GET("/test", controller.TestAllChannels)
channelRoute.GET("/test/:id", controller.TestChannel)
diff --git a/web/src/views/Channel/component/EditModal.js b/web/src/views/Channel/component/EditModal.js
index 99730822..3229227b 100644
--- a/web/src/views/Channel/component/EditModal.js
+++ b/web/src/views/Channel/component/EditModal.js
@@ -24,8 +24,10 @@ import {
Checkbox,
Switch,
FormControlLabel,
- Typography
+ Typography,
+ Tooltip
} from '@mui/material';
+import LoadingButton from '@mui/lab/LoadingButton';
import { Formik } from 'formik';
import * as Yup from 'yup';
@@ -78,6 +80,7 @@ const EditModal = ({ open, channelId, onCancel, onOk, groupOptions }) => {
const [inputPrompt, setInputPrompt] = useState(defaultConfig.prompt);
const [modelOptions, setModelOptions] = useState([]);
const [batchAdd, setBatchAdd] = useState(false);
+ const [providerModelsLoad, setProviderModelsLoad] = useState(false);
const initChannel = (typeValue) => {
if (typeConfig[typeValue]?.inputLabel) {
@@ -144,6 +147,22 @@ const EditModal = ({ open, channelId, onCancel, onOk, groupOptions }) => {
return modelList;
};
+ const getProviderModels = async (values, setFieldValue) => {
+ setProviderModelsLoad(true);
+ try {
+ const res = await API.post(`/api/channel/provider_models_list`, { ...values, models: '' });
+ const { success, message, data } = res.data;
+ if (success && data) {
+ setFieldValue('models', data);
+ } else {
+ showError(message || '获取模型列表失败');
+ }
+ } catch (error) {
+ showError(error.message);
+ }
+ setProviderModelsLoad(false);
+ };
+
const fetchModels = async () => {
try {
let res = await API.get(`/api/channel/models`);
@@ -505,6 +524,18 @@ const EditModal = ({ open, channelId, onCancel, onOk, groupOptions }) => {
>
填入所有模型
+ {inputLabel.provider_models_list && (
+
+ {
+ getProviderModels(values, setFieldValue);
+ }}
+ >
+ {inputLabel.provider_models_list}
+
+
+ )}
diff --git a/web/src/views/Channel/type/Config.js b/web/src/views/Channel/type/Config.js
index 09d5537f..6a6accb5 100644
--- a/web/src/views/Channel/type/Config.js
+++ b/web/src/views/Channel/type/Config.js
@@ -24,7 +24,8 @@ const defaultConfig = {
models: '模型',
model_mapping: '模型映射关系',
groups: '用户组',
- only_chat: '仅支持聊天'
+ only_chat: '仅支持聊天',
+ provider_models_list: ''
},
prompt: {
type: '请选择渠道类型',
@@ -39,12 +40,23 @@ const defaultConfig = {
model_mapping:
'请输入要修改的模型映射关系,格式为:api请求模型ID:实际转发给渠道的模型ID,使用JSON数组表示,例如:{"gpt-3.5": "gpt-35"}',
groups: '请选择该渠道所支持的用户组',
- only_chat: '如果选择了仅支持聊天,那么遇到有函数调用的请求会跳过该渠道'
+ only_chat: '如果选择了仅支持聊天,那么遇到有函数调用的请求会跳过该渠道',
+ provider_models_list: '必须填写所有数据后才能获取模型列表'
},
modelGroup: 'OpenAI'
};
const typeConfig = {
+ 1: {
+ inputLabel: {
+ provider_models_list: '从OpenAI获取模型列表'
+ }
+ },
+ 8: {
+ inputLabel: {
+ provider_models_list: '从渠道获取模型列表'
+ }
+ },
3: {
inputLabel: {
base_url: 'AZURE_OPENAI_ENDPOINT',
@@ -143,7 +155,8 @@ const typeConfig = {
},
25: {
inputLabel: {
- other: '版本号'
+ other: '版本号',
+ provider_models_list: '从Gemini获取模型列表'
},
input: {
models: ['gemini-pro', 'gemini-pro-vision', 'gemini-1.0-pro', 'gemini-1.5-pro'],
@@ -189,6 +202,9 @@ const typeConfig = {
models: ['deepseek-coder', 'deepseek-chat'],
test_model: 'deepseek-chat'
},
+ inputLabel: {
+ provider_models_list: '从Deepseek获取模型列表'
+ },
modelGroup: 'Deepseek'
},
29: {
@@ -210,6 +226,9 @@ const typeConfig = {
],
test_model: 'open-mistral-7b'
},
+ inputLabel: {
+ provider_models_list: '从Mistral获取模型列表'
+ },
modelGroup: 'Mistral'
},
31: {
@@ -217,6 +236,9 @@ const typeConfig = {
models: ['llama2-7b-2048', 'llama2-70b-4096', 'mixtral-8x7b-32768', 'gemma-7b-it'],
test_model: 'llama2-7b-2048'
},
+ inputLabel: {
+ provider_models_list: '从Groq获取模型列表'
+ },
modelGroup: 'Groq'
},
32: {
@@ -297,6 +319,9 @@ const typeConfig = {
models: ['command-r', 'command-r-plus'],
test_model: 'command-r'
},
+ inputLabel: {
+ provider_models_list: '从Cohere获取模型列表'
+ },
modelGroup: 'Cohere'
},
37: {