From 065da8ef8c8bcbc0a7fc3ae22e39397ffe036b6a Mon Sep 17 00:00:00 2001 From: JustSong Date: Thu, 4 Apr 2024 00:46:30 +0800 Subject: [PATCH 1/5] fix: fix ali function call (#1242) --- relay/channel/ali/main.go | 2 +- relay/channel/ali/model.go | 24 ++++++++++++------------ 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/relay/channel/ali/main.go b/relay/channel/ali/main.go index 6fdfa4d4..dd1707ee 100644 --- a/relay/channel/ali/main.go +++ b/relay/channel/ali/main.go @@ -50,8 +50,8 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { TopP: request.TopP, TopK: request.TopK, ResultFormat: "message", + Tools: request.Tools, }, - Tools: request.Tools, } } diff --git a/relay/channel/ali/model.go b/relay/channel/ali/model.go index e19d427a..3b8a8372 100644 --- a/relay/channel/ali/model.go +++ b/relay/channel/ali/model.go @@ -16,21 +16,21 @@ type Input struct { } type Parameters struct { - TopP float64 `json:"top_p,omitempty"` - TopK int `json:"top_k,omitempty"` - Seed uint64 `json:"seed,omitempty"` - EnableSearch bool `json:"enable_search,omitempty"` - IncrementalOutput bool `json:"incremental_output,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - ResultFormat string `json:"result_format,omitempty"` + TopP float64 `json:"top_p,omitempty"` + TopK int `json:"top_k,omitempty"` + Seed uint64 `json:"seed,omitempty"` + EnableSearch bool `json:"enable_search,omitempty"` + IncrementalOutput bool `json:"incremental_output,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + ResultFormat string `json:"result_format,omitempty"` + Tools []model.Tool `json:"tools,omitempty"` } type ChatRequest struct { - Model string `json:"model"` - Input Input `json:"input"` - Parameters Parameters `json:"parameters,omitempty"` - Tools []model.Tool `json:"tools,omitempty"` + Model string `json:"model"` + Input Input `json:"input"` + Parameters Parameters `json:"parameters,omitempty"` } type EmbeddingRequest struct { From dc7aaf2de5aaf0000073a7466acbda6fe213c291 Mon Sep 17 00:00:00 2001 From: JustSong Date: Thu, 4 Apr 2024 02:08:18 +0800 Subject: [PATCH 2/5] feat: able to set model limitation for token (close #178) --- controller/model.go | 28 ++++++++++++ controller/token.go | 1 + middleware/auth.go | 13 ++++++ middleware/distributor.go | 38 +++------------- middleware/utils.go | 42 ++++++++++++++++++ model/cache.go | 20 +++++++++ model/channel.go | 30 ++++++++++++- model/token.go | 29 ++++++------ router/api-router.go | 1 + web/default/src/pages/Token/EditToken.js | 56 +++++++++++++++++++++--- 10 files changed, 204 insertions(+), 54 deletions(-) diff --git a/controller/model.go b/controller/model.go index 4c5476b4..bf4b83a7 100644 --- a/controller/model.go +++ b/controller/model.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/helper" @@ -142,3 +143,30 @@ func RetrieveModel(c *gin.Context) { }) } } + +func GetUserAvailableModels(c *gin.Context) { + ctx := c.Request.Context() + id := c.GetInt("id") + userGroup, err := model.CacheGetUserGroup(id) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + models, err := model.CacheGetGroupModels(ctx, userGroup) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": models, + }) + return +} diff --git a/controller/token.go b/controller/token.go index 949931da..c6128534 100644 --- a/controller/token.go +++ b/controller/token.go @@ -216,6 +216,7 @@ func UpdateToken(c *gin.Context) { cleanToken.ExpiredTime = token.ExpiredTime cleanToken.RemainQuota = token.RemainQuota cleanToken.UnlimitedQuota = token.UnlimitedQuota + cleanToken.Models = token.Models } err = cleanToken.Update() if err != nil { diff --git a/middleware/auth.go b/middleware/auth.go index 30997efd..443199d0 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -1,6 +1,7 @@ package middleware import ( + "fmt" "github.com/gin-contrib/sessions" "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" @@ -107,6 +108,18 @@ func TokenAuth() func(c *gin.Context) { abortWithMessage(c, http.StatusForbidden, "用户已被封禁") return } + requestModel, err := getRequestModel(c) + if err != nil { + abortWithMessage(c, http.StatusBadRequest, err.Error()) + return + } + c.Set("request_model", requestModel) + if token.Models != nil && *token.Models != "" { + if !isModelInList(requestModel, *token.Models) { + abortWithMessage(c, http.StatusForbidden, fmt.Sprintf("该令牌无权使用模型:%s", requestModel)) + return + } + } c.Set("id", token.UserId) c.Set("token_id", token.Id) c.Set("token_name", token.Name) diff --git a/middleware/distributor.go b/middleware/distributor.go index e845c2f8..04489a2b 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -2,14 +2,12 @@ package middleware import ( "fmt" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/model" "net/http" "strconv" - "strings" - - "github.com/gin-gonic/gin" ) type ModelRequest struct { @@ -40,37 +38,11 @@ func Distribute() func(c *gin.Context) { return } } else { - // Select a channel for the user - var modelRequest ModelRequest - err := common.UnmarshalBodyReusable(c, &modelRequest) + requestModel := c.GetString("request_model") + var err error + channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, requestModel, false) if err != nil { - abortWithMessage(c, http.StatusBadRequest, "无效的请求") - return - } - if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { - if modelRequest.Model == "" { - modelRequest.Model = "text-moderation-stable" - } - } - if strings.HasSuffix(c.Request.URL.Path, "embeddings") { - if modelRequest.Model == "" { - modelRequest.Model = c.Param("model") - } - } - if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { - if modelRequest.Model == "" { - modelRequest.Model = "dall-e-2" - } - } - if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") { - if modelRequest.Model == "" { - modelRequest.Model = "whisper-1" - } - } - requestModel = modelRequest.Model - channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, false) - if err != nil { - message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model) + message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, requestModel) if channel != nil { logger.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id)) message = "数据库一致性已被破坏,请联系管理员" diff --git a/middleware/utils.go b/middleware/utils.go index bc14c367..b65b018b 100644 --- a/middleware/utils.go +++ b/middleware/utils.go @@ -1,9 +1,12 @@ package middleware import ( + "fmt" "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" + "strings" ) func abortWithMessage(c *gin.Context, statusCode int, message string) { @@ -16,3 +19,42 @@ func abortWithMessage(c *gin.Context, statusCode int, message string) { c.Abort() logger.Error(c.Request.Context(), message) } + +func getRequestModel(c *gin.Context) (string, error) { + var modelRequest ModelRequest + err := common.UnmarshalBodyReusable(c, &modelRequest) + if err != nil { + return "", fmt.Errorf("common.UnmarshalBodyReusable failed: %w", err) + } + if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { + if modelRequest.Model == "" { + modelRequest.Model = "text-moderation-stable" + } + } + if strings.HasSuffix(c.Request.URL.Path, "embeddings") { + if modelRequest.Model == "" { + modelRequest.Model = c.Param("model") + } + } + if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { + if modelRequest.Model == "" { + modelRequest.Model = "dall-e-2" + } + } + if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") { + if modelRequest.Model == "" { + modelRequest.Model = "whisper-1" + } + } + return modelRequest.Model, nil +} + +func isModelInList(modelName string, models string) bool { + modelList := strings.Split(models, ",") + for _, model := range modelList { + if modelName == model { + return true + } + } + return false +} diff --git a/model/cache.go b/model/cache.go index 244fe6ac..cfc5445a 100644 --- a/model/cache.go +++ b/model/cache.go @@ -21,6 +21,7 @@ var ( UserId2GroupCacheSeconds = config.SyncFrequency UserId2QuotaCacheSeconds = config.SyncFrequency UserId2StatusCacheSeconds = config.SyncFrequency + GroupModelsCacheSeconds = config.SyncFrequency ) func CacheGetTokenByKey(key string) (*Token, error) { @@ -146,6 +147,25 @@ func CacheIsUserEnabled(userId int) (bool, error) { return userEnabled, err } +func CacheGetGroupModels(ctx context.Context, group string) ([]string, error) { + if !common.RedisEnabled { + return GetGroupModels(ctx, group) + } + modelsStr, err := common.RedisGet(fmt.Sprintf("group_models:%s", group)) + if err == nil { + return strings.Split(modelsStr, ","), nil + } + models, err := GetGroupModels(ctx, group) + if err != nil { + return nil, err + } + err = common.RedisSet(fmt.Sprintf("group_models:%s", group), strings.Join(models, ","), time.Duration(GroupModelsCacheSeconds)*time.Second) + if err != nil { + logger.SysError("Redis set group models error: " + err.Error()) + } + return models, nil +} + var group2model2channels map[string]map[string][]*Channel var channelSyncLock sync.RWMutex diff --git a/model/channel.go b/model/channel.go index fc4905b1..24829bc5 100644 --- a/model/channel.go +++ b/model/channel.go @@ -1,6 +1,7 @@ package model import ( + "context" "encoding/json" "fmt" "github.com/songquanpeng/one-api/common" @@ -8,6 +9,8 @@ import ( "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" "gorm.io/gorm" + "sort" + "strings" ) type Channel struct { @@ -25,7 +28,7 @@ type Channel struct { Balance float64 `json:"balance"` // in USD BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"` Models string `json:"models"` - Group string `json:"group" gorm:"type:varchar(32);default:'default'"` + Group string `json:"group" gorm:"index;type:varchar(32);default:'default'"` UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"` Priority *int64 `json:"priority" gorm:"bigint;default:0"` @@ -202,3 +205,28 @@ func DeleteDisabledChannel() (int64, error) { result := DB.Where("status = ? or status = ?", common.ChannelStatusAutoDisabled, common.ChannelStatusManuallyDisabled).Delete(&Channel{}) return result.RowsAffected, result.Error } + +func GetGroupModels(ctx context.Context, group string) ([]string, error) { + groupCol := "`group`" + if common.UsingPostgreSQL { + groupCol = `"group"` + } + var modelsList []string + err := DB.Model(&Channel{}).Distinct("models").Where(groupCol+" = ?", group).Pluck("models", &modelsList).Error + if err != nil { + return nil, err + } + set := make(map[string]bool) + for i := 0; i < len(modelsList); i++ { + modelList := strings.Split(modelsList[i], ",") + for _, model := range modelList { + set[model] = true + } + } + modelList := make([]string, 0, len(set)) + for model := range set { + modelList = append(modelList, model) + } + sort.Strings(modelList) + return modelList, err +} diff --git a/model/token.go b/model/token.go index 493e27c9..fef80fcf 100644 --- a/model/token.go +++ b/model/token.go @@ -12,24 +12,25 @@ import ( ) type Token struct { - Id int `json:"id"` - UserId int `json:"user_id"` - Key string `json:"key" gorm:"type:char(48);uniqueIndex"` - Status int `json:"status" gorm:"default:1"` - Name string `json:"name" gorm:"index" ` - CreatedTime int64 `json:"created_time" gorm:"bigint"` - AccessedTime int64 `json:"accessed_time" gorm:"bigint"` - ExpiredTime int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired - RemainQuota int64 `json:"remain_quota" gorm:"bigint;default:0"` - UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"` - UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` // used quota + Id int `json:"id"` + UserId int `json:"user_id"` + Key string `json:"key" gorm:"type:char(48);uniqueIndex"` + Status int `json:"status" gorm:"default:1"` + Name string `json:"name" gorm:"index" ` + CreatedTime int64 `json:"created_time" gorm:"bigint"` + AccessedTime int64 `json:"accessed_time" gorm:"bigint"` + ExpiredTime int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired + RemainQuota int64 `json:"remain_quota" gorm:"bigint;default:0"` + UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"` + UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` // used quota + Models *string `json:"models" gorm:"default:''"` } func GetAllUserTokens(userId int, startIdx int, num int, order string) ([]*Token, error) { var tokens []*Token var err error query := DB.Where("user_id = ?", userId) - + switch order { case "remain_quota": query = query.Order("unlimited_quota desc, remain_quota desc") @@ -38,7 +39,7 @@ func GetAllUserTokens(userId int, startIdx int, num int, order string) ([]*Token default: query = query.Order("id desc") } - + err = query.Limit(num).Offset(startIdx).Find(&tokens).Error return tokens, err } @@ -121,7 +122,7 @@ func (token *Token) Insert() error { // Update Make sure your token's fields is completed, because this will update non-zero values func (token *Token) Update() error { var err error - err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota").Updates(token).Error + err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "models").Updates(token).Error return err } diff --git a/router/api-router.go b/router/api-router.go index 5b755ede..4aa6d830 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -43,6 +43,7 @@ func SetApiRouter(router *gin.Engine) { selfRoute.GET("/token", controller.GenerateAccessToken) selfRoute.GET("/aff", controller.GetAffCode) selfRoute.POST("/topup", controller.TopUp) + selfRoute.GET("/available_models", controller.GetUserAvailableModels) } adminRoute := userRoute.Group("/") diff --git a/web/default/src/pages/Token/EditToken.js b/web/default/src/pages/Token/EditToken.js index 0ab37c29..6bc3ad23 100644 --- a/web/default/src/pages/Token/EditToken.js +++ b/web/default/src/pages/Token/EditToken.js @@ -1,19 +1,21 @@ import React, { useEffect, useState } from 'react'; import { Button, Form, Header, Message, Segment } from 'semantic-ui-react'; -import { useParams, useNavigate } from 'react-router-dom'; -import { API, showError, showSuccess, timestamp2string } from '../../helpers'; -import { renderQuota, renderQuotaWithPrompt } from '../../helpers/render'; +import { useNavigate, useParams } from 'react-router-dom'; +import { API, copy, showError, showSuccess, timestamp2string } from '../../helpers'; +import { renderQuotaWithPrompt } from '../../helpers/render'; const EditToken = () => { const params = useParams(); const tokenId = params.id; const isEdit = tokenId !== undefined; const [loading, setLoading] = useState(isEdit); + const [modelOptions, setModelOptions] = useState([]); const originInputs = { name: '', remain_quota: isEdit ? 0 : 500000, expired_time: -1, - unlimited_quota: false + unlimited_quota: false, + models: [] }; const [inputs, setInputs] = useState(originInputs); const { name, remain_quota, expired_time, unlimited_quota } = inputs; @@ -22,8 +24,8 @@ const EditToken = () => { setInputs((inputs) => ({ ...inputs, [name]: value })); }; const handleCancel = () => { - navigate("/token"); - } + navigate('/token'); + }; const setExpiredTime = (month, day, hour, minute) => { let now = new Date(); let timestamp = now.getTime() / 1000; @@ -50,6 +52,11 @@ const EditToken = () => { if (data.expired_time !== -1) { data.expired_time = timestamp2string(data.expired_time); } + if (data.models === '') { + data.models = []; + } else { + data.models = data.models.split(','); + } setInputs(data); } else { showError(message); @@ -60,8 +67,26 @@ const EditToken = () => { if (isEdit) { loadToken().then(); } + loadAvailableModels().then(); }, []); + const loadAvailableModels = async () => { + let res = await API.get(`/api/user/available_models`); + const { success, message, data } = res.data; + if (success) { + let options = data.map((model) => { + return { + key: model, + text: model, + value: model + }; + }); + setModelOptions(options); + } else { + showError(message); + } + }; + const submit = async () => { if (!isEdit && inputs.name === '') return; let localInputs = inputs; @@ -74,6 +99,7 @@ const EditToken = () => { } localInputs.expired_time = Math.ceil(time / 1000); } + localInputs.models = localInputs.models.join(','); let res; if (isEdit) { res = await API.put(`/api/token/`, { ...localInputs, id: parseInt(tokenId) }); @@ -109,6 +135,24 @@ const EditToken = () => { required={!isEdit} /> + + { + copy(value).then(); + }} + selection + onChange={handleInputChange} + value={inputs.models} + autoComplete='new-password' + options={modelOptions} + /> + Date: Thu, 4 Apr 2024 02:44:59 +0800 Subject: [PATCH 3/5] feat: /v1/models now only return available models --- controller/model.go | 35 ++++++++++++++++++++++++++++++++++- middleware/auth.go | 3 ++- 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/controller/model.go b/controller/model.go index bf4b83a7..53649391 100644 --- a/controller/model.go +++ b/controller/model.go @@ -11,6 +11,7 @@ import ( relaymodel "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/util" "net/http" + "strings" ) // https://platform.openai.com/docs/api-reference/models/list @@ -121,9 +122,41 @@ func DashboardListModels(c *gin.Context) { } func ListModels(c *gin.Context) { + ctx := c.Request.Context() + var availableModels []string + if c.GetString("available_models") != "" { + availableModels = strings.Split(c.GetString("available_models"), ",") + } else { + userId := c.GetInt("id") + userGroup, _ := model.CacheGetUserGroup(userId) + availableModels, _ = model.CacheGetGroupModels(ctx, userGroup) + } + modelSet := make(map[string]bool) + for _, availableModel := range availableModels { + modelSet[availableModel] = true + } + var availableOpenAIModels []OpenAIModels + for _, model := range openAIModels { + if _, ok := modelSet[model.Id]; ok { + modelSet[model.Id] = false + availableOpenAIModels = append(availableOpenAIModels, model) + } + } + for modelName, ok := range modelSet { + if ok { + availableOpenAIModels = append(availableOpenAIModels, OpenAIModels{ + Id: modelName, + Object: "model", + Created: 1626777600, + OwnedBy: "custom", + Root: modelName, + Parent: nil, + }) + } + } c.JSON(200, gin.H{ "object": "list", - "data": openAIModels, + "data": availableOpenAIModels, }) } diff --git a/middleware/auth.go b/middleware/auth.go index 443199d0..29701524 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -115,7 +115,8 @@ func TokenAuth() func(c *gin.Context) { } c.Set("request_model", requestModel) if token.Models != nil && *token.Models != "" { - if !isModelInList(requestModel, *token.Models) { + c.Set("available_models", *token.Models) + if requestModel != "" && !isModelInList(requestModel, *token.Models) { abortWithMessage(c, http.StatusForbidden, fmt.Sprintf("该令牌无权使用模型:%s", requestModel)) return } From 8b9fa3d6e452fbc95bfc37db836c69ed39f3f094 Mon Sep 17 00:00:00 2001 From: JustSong Date: Thu, 4 Apr 2024 02:58:21 +0800 Subject: [PATCH 4/5] fix: fix GetGroupModels --- model/ability.go | 18 ++++++++++++++++++ model/channel.go | 30 +----------------------------- 2 files changed, 19 insertions(+), 29 deletions(-) diff --git a/model/ability.go b/model/ability.go index 48b856a2..4a48bc51 100644 --- a/model/ability.go +++ b/model/ability.go @@ -1,8 +1,10 @@ package model import ( + "context" "github.com/songquanpeng/one-api/common" "gorm.io/gorm" + "sort" "strings" ) @@ -88,3 +90,19 @@ func (channel *Channel) UpdateAbilities() error { func UpdateAbilityStatus(channelId int, status bool) error { return DB.Model(&Ability{}).Where("channel_id = ?", channelId).Select("enabled").Update("enabled", status).Error } + +func GetGroupModels(ctx context.Context, group string) ([]string, error) { + groupCol := "`group`" + trueVal := "1" + if common.UsingPostgreSQL { + groupCol = `"group"` + trueVal = "true" + } + var models []string + err := DB.Model(&Ability{}).Distinct("model").Where(groupCol+" = ? and enabled = "+trueVal, group).Pluck("model", &models).Error + if err != nil { + return nil, err + } + sort.Strings(models) + return models, err +} diff --git a/model/channel.go b/model/channel.go index 24829bc5..fc4905b1 100644 --- a/model/channel.go +++ b/model/channel.go @@ -1,7 +1,6 @@ package model import ( - "context" "encoding/json" "fmt" "github.com/songquanpeng/one-api/common" @@ -9,8 +8,6 @@ import ( "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" "gorm.io/gorm" - "sort" - "strings" ) type Channel struct { @@ -28,7 +25,7 @@ type Channel struct { Balance float64 `json:"balance"` // in USD BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"` Models string `json:"models"` - Group string `json:"group" gorm:"index;type:varchar(32);default:'default'"` + Group string `json:"group" gorm:"type:varchar(32);default:'default'"` UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"` Priority *int64 `json:"priority" gorm:"bigint;default:0"` @@ -205,28 +202,3 @@ func DeleteDisabledChannel() (int64, error) { result := DB.Where("status = ? or status = ?", common.ChannelStatusAutoDisabled, common.ChannelStatusManuallyDisabled).Delete(&Channel{}) return result.RowsAffected, result.Error } - -func GetGroupModels(ctx context.Context, group string) ([]string, error) { - groupCol := "`group`" - if common.UsingPostgreSQL { - groupCol = `"group"` - } - var modelsList []string - err := DB.Model(&Channel{}).Distinct("models").Where(groupCol+" = ?", group).Pluck("models", &modelsList).Error - if err != nil { - return nil, err - } - set := make(map[string]bool) - for i := 0; i < len(modelsList); i++ { - modelList := strings.Split(modelsList[i], ",") - for _, model := range modelList { - set[model] = true - } - } - modelList := make([]string, 0, len(set)) - for model := range set { - modelList = append(modelList, model) - } - sort.Strings(modelList) - return modelList, err -} From ed70881a58bc77c9be86122d95612c2d225be633 Mon Sep 17 00:00:00 2001 From: JustSong Date: Thu, 4 Apr 2024 11:18:21 +0800 Subject: [PATCH 5/5] fix: fix token create --- controller/token.go | 1 + 1 file changed, 1 insertion(+) diff --git a/controller/token.go b/controller/token.go index c6128534..13b90de0 100644 --- a/controller/token.go +++ b/controller/token.go @@ -130,6 +130,7 @@ func AddToken(c *gin.Context) { ExpiredTime: token.ExpiredTime, RemainQuota: token.RemainQuota, UnlimitedQuota: token.UnlimitedQuota, + Models: token.Models, } err = cleanToken.Insert() if err != nil {