✨ feat: Add support for retrieving model list from providers (#188)
* ✨ feat: Add support for retrieving model list from providers
* 🔖 chore: Custom channel automatically gets the model list
parent ef63fbfd31
commit 7263582b9b
controller/channel-model.go (new file, 58 lines)
@@ -0,0 +1,58 @@
+package controller
+
+import (
+	"net/http"
+	"one-api/model"
+	"one-api/providers"
+	providersBase "one-api/providers/base"
+	"strings"
+
+	"github.com/gin-gonic/gin"
+)
+
+func GetModelList(c *gin.Context) {
+	channel := model.Channel{}
+	err := c.ShouldBindJSON(&channel)
+	if err != nil {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": err.Error(),
+		})
+		return
+	}
+	keys := strings.Split(channel.Key, "\n")
+	channel.Key = keys[0]
+
+	provider := providers.GetProvider(&channel, c)
+	if provider == nil {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": "provider not found",
+		})
+		return
+	}
+
+	modelProvider, ok := provider.(providersBase.ModelListInterface)
+	if !ok {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": "channel not implemented",
+		})
+		return
+	}
+
+	modelList, err := modelProvider.GetModelList()
+	if err != nil {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": err.Error(),
+		})
+		return
+	}
+
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+		"data":    modelList,
+	})
+}
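A quick usage sketch of the new handler, not part of this commit: it binds a model.Channel from the request body, so an admin client posts the channel's fields and reads back the model list. The route registration appears further down; the base URL, JSON field names, and auth header below are assumptions for illustration.

package main

import (
	"bytes"
	"fmt"
	"io"
	"net/http"
)

func main() {
	// Channel fields mirror what the frontend sends (see getProviderModels
	// further down); the concrete names and values here are assumptions.
	body := []byte(`{"type": 1, "key": "sk-xxx", "base_url": ""}`)
	req, err := http.NewRequest(http.MethodPost,
		"http://localhost:3000/api/channel/provider_models_list", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer <admin-token>") // assumed admin auth scheme

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	data, _ := io.ReadAll(resp.Body)
	fmt.Println(string(data)) // e.g. {"success":true,"message":"","data":["gpt-4o", ...]}
}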
@@ -25,6 +25,7 @@ type ProviderConfig struct {
 	ImagesGenerations string
 	ImagesEdit        string
 	ImagesVariations  string
+	ModelList         string
 }

 type BaseProvider struct {
@@ -99,6 +99,16 @@ type ImageVariationsInterface interface {
 	CreateImageVariations(request *types.ImageEditRequest) (*types.ImageResponse, *types.OpenAIErrorWithStatusCode)
 }

+// type RelayInterface interface {
+// 	ProviderInterface
+// 	CreateRelay() (*http.Response, *types.OpenAIErrorWithStatusCode)
+// }
+
+type ModelListInterface interface {
+	ProviderInterface
+	GetModelList() ([]string, error)
+}
+
 // Balance interface
 type BalanceInterface interface {
 	Balance() (float64, error)
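The new ModelListInterface keeps the contract small: ProviderInterface plus one method. As a minimal sketch (the provider type below is hypothetical, not in this commit), any provider that already satisfies ProviderInterface only needs to add:

// ExampleProvider is a hypothetical provider for illustration only; the real
// implementations for Cohere, Gemini, Mistral, and OpenAI appear below.
func (p *ExampleProvider) GetModelList() ([]string, error) {
	// Returning a static slice is enough to satisfy the interface; the real
	// providers fetch this list from the upstream API instead.
	return []string{"example-model-a", "example-model-b"}, nil
}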
@@ -32,6 +32,7 @@ func getConfig() base.ProviderConfig {
 	return base.ProviderConfig{
 		BaseURL:         "https://api.cohere.ai/v1",
 		ChatCompletions: "/chat",
+		ModelList:       "/models",
 	}
 }
providers/cohere/model.go (new file, 35 lines)
@@ -0,0 +1,35 @@
+package cohere
+
+import (
+	"errors"
+	"net/http"
+	"net/url"
+)
+
+func (p *CohereProvider) GetModelList() ([]string, error) {
+	params := url.Values{}
+	params.Add("page_size", "1000")
+	params.Add("endpoint", "chat")
+	queryString := params.Encode()
+
+	fullRequestURL := p.GetFullRequestURL(p.Config.ModelList) + "?" + queryString
+	headers := p.GetRequestHeaders()
+
+	req, err := p.Requester.NewRequest(http.MethodGet, fullRequestURL, p.Requester.WithHeader(headers))
+	if err != nil {
+		return nil, errors.New("new_request_failed")
+	}
+
+	response := &ModelListResponse{}
+	_, errWithCode := p.Requester.SendRequest(req, response, false)
+	if errWithCode != nil {
+		return nil, errors.New(errWithCode.Message)
+	}
+
+	var modelList []string
+	for _, model := range response.Models {
+		modelList = append(modelList, model.Name)
+	}
+
+	return modelList, nil
+}
@@ -15,28 +15,34 @@ type CohereConnector struct {
 }

 type CohereRequest struct {
 	Message          string                          `json:"message"`
 	Model            string                          `json:"model,omitempty"`
 	Stream           bool                            `json:"stream,omitempty"`
 	Preamble         string                          `json:"preamble,omitempty"`
 	ChatHistory      []ChatHistory                   `json:"chat_history,omitempty"`
 	ConversationId   string                          `json:"conversation_id,omitempty"`
 	PromptTruncation string                          `json:"prompt_truncation,omitempty"`
 	Connectors       []CohereConnector               `json:"connectors,omitempty"`
 	Temperature      float64                         `json:"temperature,omitempty"`
 	MaxTokens        int                             `json:"max_tokens,omitempty"`
 	MaxInputTokens   int                             `json:"max_input_tokens,omitempty"`
 	K                int                             `json:"k,omitempty"`
 	P                float64                         `json:"p,omitempty"`
 	Seed             *int                            `json:"seed,omitempty"`
 	StopSequences    any                             `json:"stop_sequences,omitempty"`
 	FrequencyPenalty float64                         `json:"frequency_penalty,omitempty"`
 	PresencePenalty  float64                         `json:"presence_penalty,omitempty"`
 	Tools            []*types.ChatCompletionFunction `json:"tools,omitempty"`
 	ToolResults      any                             `json:"tool_results,omitempty"`
-	// SearchQueriesOnly bool `json:"search_queries_only,omitempty"`
+	SearchQueriesOnly *bool          `json:"search_queries_only,omitempty"`
+	Documents         []ChatDocument `json:"documents,omitempty"`
+	CitationQuality   *string        `json:"citation_quality,omitempty"`
+	RawPrompting      *bool          `json:"raw_prompting,omitempty"`
+	ReturnPrompt      *bool          `json:"return_prompt,omitempty"`
 }

+type ChatDocument = map[string]string
+
 type APIVersion struct {
 	Version string `json:"version"`
 }
@@ -60,16 +66,46 @@ type CohereToolCall struct {
 }

 type CohereResponse struct {
 	Text       string `json:"text,omitempty"`
 	ResponseID string `json:"response_id,omitempty"`
-	GenerationID string           `json:"generation_id,omitempty"`
-	ChatHistory  []ChatHistory    `json:"chat_history,omitempty"`
-	FinishReason string           `json:"finish_reason,omitempty"`
-	ToolCalls    []CohereToolCall `json:"tool_calls,omitempty"`
-	Meta         Meta             `json:"meta,omitempty"`
+	Citations        []*ChatCitation     `json:"citations,omitempty"`
+	Documents        []ChatDocument      `json:"documents,omitempty"`
+	IsSearchRequired *bool               `json:"is_search_required,omitempty"`
+	SearchQueries    []*ChatSearchQuery  `json:"search_queries,omitempty"`
+	SearchResults    []*ChatSearchResult `json:"search_results,omitempty"`
+	GenerationID     string              `json:"generation_id,omitempty"`
+	ChatHistory      []ChatHistory       `json:"chat_history,omitempty"`
+	Prompt           *string             `json:"prompt,omitempty"`
+	FinishReason     string              `json:"finish_reason,omitempty"`
+	ToolCalls        []CohereToolCall    `json:"tool_calls,omitempty"`
+	Meta             Meta                `json:"meta,omitempty"`
 	CohereError
 }

+type ChatCitation struct {
+	Start       int      `json:"start"`
+	End         int      `json:"end"`
+	Text        string   `json:"text"`
+	DocumentIds []string `json:"document_ids,omitempty"`
+}
+
+type ChatSearchQuery struct {
+	Text         string `json:"text"`
+	GenerationId string `json:"generation_id"`
+}
+
+type ChatSearchResult struct {
+	SearchQuery       *ChatSearchQuery           `json:"search_query,omitempty" url:"search_query,omitempty"`
+	Connector         *ChatSearchResultConnector `json:"connector,omitempty" url:"connector,omitempty"`
+	DocumentIds       []string                   `json:"document_ids,omitempty" url:"document_ids,omitempty"`
+	ErrorMessage      *string                    `json:"error_message,omitempty" url:"error_message,omitempty"`
+	ContinueOnFailure *bool                      `json:"continue_on_failure,omitempty" url:"continue_on_failure,omitempty"`
+}
+
+type ChatSearchResultConnector struct {
+	Id string `json:"id" url:"id"`
+}
+
 type CohereError struct {
 	Message string `json:"message,omitempty"`
 }
@@ -83,3 +119,77 @@ type CohereStreamResponse struct {
 	FinishReason string           `json:"finish_reason,omitempty"`
 	ToolCalls    []CohereToolCall `json:"tool_calls,omitempty"`
 }
+
+type RerankRequest struct {
+	Model           *string                       `json:"model,omitempty"`
+	Query           string                        `json:"query" url:"query"`
+	Documents       []*RerankRequestDocumentsItem `json:"documents,omitempty"`
+	TopN            *int                          `json:"top_n,omitempty"`
+	RankFields      []string                      `json:"rank_fields,omitempty"`
+	ReturnDocuments *bool                         `json:"return_documents,omitempty"`
+	MaxChunksPerDoc *int                          `json:"max_chunks_per_doc,omitempty"`
+}
+
+type RerankRequestDocumentsItem struct {
+	String                         string
+	RerankRequestDocumentsItemText *RerankDocumentsItemText
+}
+type RerankDocumentsItemText struct {
+	Text string `json:"text"`
+}
+
+type RerankResponse struct {
+	Id      *string                      `json:"id,omitempty"`
+	Results []*RerankResponseResultsItem `json:"results,omitempty"`
+	Meta    *Meta                        `json:"meta,omitempty"`
+}
+
+type RerankResponseResultsItem struct {
+	Document       *RerankDocumentsItemText `json:"document,omitempty"`
+	Index          int                      `json:"index"`
+	RelevanceScore float64                  `json:"relevance_score"`
+}
+
+type EmbedRequest struct {
+	Texts          any      `json:"texts,omitempty"`
+	Model          *string  `json:"model,omitempty"`
+	InputType      *string  `json:"input_type,omitempty"`
+	EmbeddingTypes []string `json:"embedding_types,omitempty"`
+	Truncate       *string  `json:"truncate,omitempty"`
+}
+
+type EmbedResponse struct {
+	ResponseType string `json:"response_type"`
+	Embeddings   any    `json:"embeddings"`
+}
+
+type EmbedFloatsResponse struct {
+	Id         string      `json:"id"`
+	Embeddings [][]float64 `json:"embeddings,omitempty"`
+	Texts      []string    `json:"texts,omitempty"`
+	Meta       *Meta       `json:"meta,omitempty"`
+}
+
+type EmbedByTypeResponse struct {
+	Id         string                         `json:"id"`
+	Embeddings *EmbedByTypeResponseEmbeddings `json:"embeddings,omitempty"`
+	Texts      []string                       `json:"texts,omitempty"`
+	Meta       *Meta                          `json:"meta,omitempty"`
+}
+
+type EmbedByTypeResponseEmbeddings struct {
+	Float   [][]float64 `json:"float,omitempty"`
+	Int8    [][]int     `json:"int8,omitempty"`
+	Uint8   [][]int     `json:"uint8,omitempty"`
+	Binary  [][]int     `json:"binary,omitempty"`
+	Ubinary [][]int     `json:"ubinary,omitempty"`
+}
+
+type ModelListResponse struct {
+	Models []ModelDetails `json:"models"`
+}
+
+type ModelDetails struct {
+	Name      string   `json:"name"`
+	Endpoints []string `json:"endpoints"`
+}
@@ -28,6 +28,7 @@ func getDeepseekConfig() base.ProviderConfig {
 	return base.ProviderConfig{
 		BaseURL:         "https://api.deepseek.com",
 		ChatCompletions: "/v1/chat/completions",
+		ModelList:       "/v1/models",
 	}
 }
@@ -32,6 +32,7 @@ func getConfig() base.ProviderConfig {
 	return base.ProviderConfig{
 		BaseURL:         "https://generativelanguage.googleapis.com",
 		ChatCompletions: "/",
+		ModelList:       "/models",
 	}
 }
providers/gemini/model.go (new file, 45 lines)
@@ -0,0 +1,45 @@
+package gemini
+
+import (
+	"errors"
+	"fmt"
+	"net/http"
+	"net/url"
+	"strings"
+)
+
+func (p *GeminiProvider) GetModelList() ([]string, error) {
+	params := url.Values{}
+	params.Add("page_size", "1000")
+	queryString := params.Encode()
+
+	baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")
+	version := "v1beta"
+	fullRequestURL := fmt.Sprintf("%s/%s%s?%s", baseURL, version, p.Config.ModelList, queryString)
+
+	headers := p.GetRequestHeaders()
+
+	req, err := p.Requester.NewRequest(http.MethodGet, fullRequestURL, p.Requester.WithHeader(headers))
+	if err != nil {
+		return nil, errors.New("new_request_failed")
+	}
+
+	response := &ModelListResponse{}
+	_, errWithCode := p.Requester.SendRequest(req, response, false)
+	if errWithCode != nil {
+		return nil, errors.New(errWithCode.Message)
+	}
+
+	var modelList []string
+	for _, model := range response.Models {
+		for _, modelType := range model.SupportedGenerationMethods {
+			if modelType == "generateContent" {
+				modelName := strings.TrimPrefix(model.Name, "models/")
+				modelList = append(modelList, modelName)
+				break
+			}
+		}
+	}
+
+	return modelList, nil
+}
@@ -218,3 +218,12 @@ func ConvertRole(roleName string) string {
 		return types.ChatMessageRoleUser
 	}
 }
+
+type ModelListResponse struct {
+	Models []ModelDetails `json:"models"`
+}
+
+type ModelDetails struct {
+	Name                       string   `json:"name"`
+	SupportedGenerationMethods []string `json:"supportedGenerationMethods"`
+}
@@ -27,6 +27,7 @@ func getConfig() base.ProviderConfig {
 	return base.ProviderConfig{
 		BaseURL:         "https://api.groq.com/openai",
 		ChatCompletions: "/v1/chat/completions",
+		ModelList:       "/v1/models",
 	}
 }
@@ -41,6 +41,7 @@ func getMistralConfig(baseURL string) base.ProviderConfig {
 		BaseURL:         baseURL,
 		ChatCompletions: "/v1/chat/completions",
 		Embeddings:      "/v1/embeddings",
+		ModelList:       "/v1/models",
 	}
 }
providers/mistral/model.go (new file, 29 lines)
@@ -0,0 +1,29 @@
+package mistral
+
+import (
+	"errors"
+	"net/http"
+)
+
+func (p *MistralProvider) GetModelList() ([]string, error) {
+	fullRequestURL := p.GetFullRequestURL(p.Config.ModelList, "")
+	headers := p.GetRequestHeaders()
+
+	req, err := p.Requester.NewRequest(http.MethodGet, fullRequestURL, p.Requester.WithHeader(headers))
+	if err != nil {
+		return nil, errors.New("new_request_failed")
+	}
+
+	response := &ModelListResponse{}
+	_, errWithCode := p.Requester.SendRequest(req, response, false)
+	if errWithCode != nil {
+		return nil, errors.New(errWithCode.Message)
+	}
+
+	var modelList []string
+	for _, model := range response.Data {
+		modelList = append(modelList, model.Id)
+	}
+
+	return modelList, nil
+}
@@ -53,3 +53,15 @@ type ChatCompletionStreamResponse struct {
 	types.ChatCompletionStreamResponse
 	Usage *types.Usage `json:"usage,omitempty"`
 }
+
+type ModelListResponse struct {
+	Object string         `json:"object"`
+	Data   []ModelDetails `json:"data"`
+}
+
+type ModelDetails struct {
+	Id      string `json:"id"`
+	Object  string `json:"object"`
+	Created int64  `json:"created"`
+	OwnedBy string `json:"owned_by"`
+}
@@ -57,6 +57,7 @@ func getOpenAIConfig(baseURL string) base.ProviderConfig {
 		ImagesGenerations: "/v1/images/generations",
 		ImagesEdit:        "/v1/images/edits",
 		ImagesVariations:  "/v1/images/variations",
+		ModelList:         "/v1/models",
 	}
 }
providers/openai/model.go (new file, 29 lines)
@@ -0,0 +1,29 @@
+package openai
+
+import (
+	"errors"
+	"net/http"
+)
+
+func (p *OpenAIProvider) GetModelList() ([]string, error) {
+	fullRequestURL := p.GetFullRequestURL(p.Config.ModelList, "")
+	headers := p.GetRequestHeaders()
+
+	req, err := p.Requester.NewRequest(http.MethodGet, fullRequestURL, p.Requester.WithHeader(headers))
+	if err != nil {
+		return nil, errors.New("new_request_failed")
+	}
+
+	response := &ModelListResponse{}
+	_, errWithCode := p.Requester.SendRequest(req, response, false)
+	if errWithCode != nil {
+		return nil, errors.New(errWithCode.Message)
+	}
+
+	var modelList []string
+	for _, model := range response.Data {
+		modelList = append(modelList, model.Id)
+	}
+
+	return modelList, nil
+}
@@ -73,3 +73,15 @@ type OpenAIUsageResponse struct {
 	//DailyCosts []OpenAIUsageDailyCost `json:"daily_costs"`
 	TotalUsage float64 `json:"total_usage"` // unit: 0.01 dollar
 }
+
+type ModelListResponse struct {
+	Object string         `json:"object"`
+	Data   []ModelDetails `json:"data"`
+}
+
+type ModelDetails struct {
+	Id      string `json:"id"`
+	Object  string `json:"object"`
+	Created int64  `json:"created"`
+	OwnedBy string `json:"owned_by"`
+}
@@ -78,6 +78,7 @@ func SetApiRouter(router *gin.Engine) {
 		{
 			channelRoute.GET("/", controller.GetChannelsList)
 			channelRoute.GET("/models", relay.ListModelsForAdmin)
+			channelRoute.POST("/provider_models_list", controller.GetModelList)
 			channelRoute.GET("/:id", controller.GetChannel)
 			channelRoute.GET("/test", controller.TestAllChannels)
 			channelRoute.GET("/test/:id", controller.TestChannel)
@@ -24,8 +24,10 @@ import {
   Checkbox,
   Switch,
   FormControlLabel,
-  Typography
+  Typography,
+  Tooltip
 } from '@mui/material';
+import LoadingButton from '@mui/lab/LoadingButton';

 import { Formik } from 'formik';
 import * as Yup from 'yup';
@@ -78,6 +80,7 @@ const EditModal = ({ open, channelId, onCancel, onOk, groupOptions }) => {
   const [inputPrompt, setInputPrompt] = useState(defaultConfig.prompt);
   const [modelOptions, setModelOptions] = useState([]);
   const [batchAdd, setBatchAdd] = useState(false);
+  const [providerModelsLoad, setProviderModelsLoad] = useState(false);

   const initChannel = (typeValue) => {
     if (typeConfig[typeValue]?.inputLabel) {
@@ -144,6 +147,22 @@ const EditModal = ({ open, channelId, onCancel, onOk, groupOptions }) => {
     return modelList;
   };

+  const getProviderModels = async (values, setFieldValue) => {
+    setProviderModelsLoad(true);
+    try {
+      const res = await API.post(`/api/channel/provider_models_list`, { ...values, models: '' });
+      const { success, message, data } = res.data;
+      if (success && data) {
+        setFieldValue('models', data);
+      } else {
+        showError(message || '获取模型列表失败');
+      }
+    } catch (error) {
+      showError(error.message);
+    }
+    setProviderModelsLoad(false);
+  };
+
   const fetchModels = async () => {
     try {
       let res = await API.get(`/api/channel/models`);
@@ -505,6 +524,18 @@ const EditModal = ({ open, channelId, onCancel, onOk, groupOptions }) => {
                 >
                   填入所有模型
                 </Button>
+                {inputLabel.provider_models_list && (
+                  <Tooltip title={inputPrompt.provider_models_list} placement="top">
+                    <LoadingButton
+                      loading={providerModelsLoad}
+                      onClick={() => {
+                        getProviderModels(values, setFieldValue);
+                      }}
+                    >
+                      {inputLabel.provider_models_list}
+                    </LoadingButton>
+                  </Tooltip>
+                )}
               </ButtonGroup>
             </Container>
             <FormControl fullWidth error={Boolean(touched.key && errors.key)} sx={{ ...theme.typography.otherInput }}>
@@ -24,7 +24,8 @@ const defaultConfig = {
     models: '模型',
     model_mapping: '模型映射关系',
     groups: '用户组',
-    only_chat: '仅支持聊天'
+    only_chat: '仅支持聊天',
+    provider_models_list: ''
   },
   prompt: {
     type: '请选择渠道类型',
@@ -39,12 +40,23 @@ const defaultConfig = {
     model_mapping:
       '请输入要修改的模型映射关系,格式为:api请求模型ID:实际转发给渠道的模型ID,使用JSON数组表示,例如:{"gpt-3.5": "gpt-35"}',
     groups: '请选择该渠道所支持的用户组',
-    only_chat: '如果选择了仅支持聊天,那么遇到有函数调用的请求会跳过该渠道'
+    only_chat: '如果选择了仅支持聊天,那么遇到有函数调用的请求会跳过该渠道',
+    provider_models_list: '必须填写所有数据后才能获取模型列表'
   },
   modelGroup: 'OpenAI'
 };

 const typeConfig = {
+  1: {
+    inputLabel: {
+      provider_models_list: '从OpenAI获取模型列表'
+    }
+  },
+  8: {
+    inputLabel: {
+      provider_models_list: '从渠道获取模型列表'
+    }
+  },
   3: {
     inputLabel: {
       base_url: 'AZURE_OPENAI_ENDPOINT',
@@ -143,7 +155,8 @@ const typeConfig = {
   },
   25: {
     inputLabel: {
-      other: '版本号'
+      other: '版本号',
+      provider_models_list: '从Gemini获取模型列表'
     },
     input: {
       models: ['gemini-pro', 'gemini-pro-vision', 'gemini-1.0-pro', 'gemini-1.5-pro'],
@@ -189,6 +202,9 @@ const typeConfig = {
       models: ['deepseek-coder', 'deepseek-chat'],
       test_model: 'deepseek-chat'
     },
+    inputLabel: {
+      provider_models_list: '从Deepseek获取模型列表'
+    },
     modelGroup: 'Deepseek'
   },
   29: {
@@ -210,6 +226,9 @@ const typeConfig = {
       ],
       test_model: 'open-mistral-7b'
     },
+    inputLabel: {
+      provider_models_list: '从Mistral获取模型列表'
+    },
     modelGroup: 'Mistral'
   },
   31: {
@@ -217,6 +236,9 @@ const typeConfig = {
       models: ['llama2-7b-2048', 'llama2-70b-4096', 'mixtral-8x7b-32768', 'gemma-7b-it'],
       test_model: 'llama2-7b-2048'
     },
+    inputLabel: {
+      provider_models_list: '从Groq获取模型列表'
+    },
     modelGroup: 'Groq'
   },
   32: {
@@ -297,6 +319,9 @@ const typeConfig = {
       models: ['command-r', 'command-r-plus'],
       test_model: 'command-r'
     },
+    inputLabel: {
+      provider_models_list: '从Cohere获取模型列表'
+    },
     modelGroup: 'Cohere'
   },
   37: {