From 8b9813d63b3e3303b7ac1f8a8944a7131b0d5ca6 Mon Sep 17 00:00:00 2001 From: JustSong Date: Thu, 4 Apr 2024 02:44:59 +0800 Subject: [PATCH] feat: /v1/models now only return available models --- controller/model.go | 35 ++++++++++++++++++++++++++++++++++- middleware/auth.go | 3 ++- 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/controller/model.go b/controller/model.go index bf4b83a7..53649391 100644 --- a/controller/model.go +++ b/controller/model.go @@ -11,6 +11,7 @@ import ( relaymodel "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/util" "net/http" + "strings" ) // https://platform.openai.com/docs/api-reference/models/list @@ -121,9 +122,41 @@ func DashboardListModels(c *gin.Context) { } func ListModels(c *gin.Context) { + ctx := c.Request.Context() + var availableModels []string + if c.GetString("available_models") != "" { + availableModels = strings.Split(c.GetString("available_models"), ",") + } else { + userId := c.GetInt("id") + userGroup, _ := model.CacheGetUserGroup(userId) + availableModels, _ = model.CacheGetGroupModels(ctx, userGroup) + } + modelSet := make(map[string]bool) + for _, availableModel := range availableModels { + modelSet[availableModel] = true + } + var availableOpenAIModels []OpenAIModels + for _, model := range openAIModels { + if _, ok := modelSet[model.Id]; ok { + modelSet[model.Id] = false + availableOpenAIModels = append(availableOpenAIModels, model) + } + } + for modelName, ok := range modelSet { + if ok { + availableOpenAIModels = append(availableOpenAIModels, OpenAIModels{ + Id: modelName, + Object: "model", + Created: 1626777600, + OwnedBy: "custom", + Root: modelName, + Parent: nil, + }) + } + } c.JSON(200, gin.H{ "object": "list", - "data": openAIModels, + "data": availableOpenAIModels, }) } diff --git a/middleware/auth.go b/middleware/auth.go index 443199d0..29701524 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -115,7 +115,8 @@ func TokenAuth() func(c *gin.Context) { } c.Set("request_model", requestModel) if token.Models != nil && *token.Models != "" { - if !isModelInList(requestModel, *token.Models) { + c.Set("available_models", *token.Models) + if requestModel != "" && !isModelInList(requestModel, *token.Models) { abortWithMessage(c, http.StatusForbidden, fmt.Sprintf("该令牌无权使用模型:%s", requestModel)) return }