diff --git a/model/ability.go b/model/ability.go index 48b856a2..4a48bc51 100644 --- a/model/ability.go +++ b/model/ability.go @@ -1,8 +1,10 @@ package model import ( + "context" "github.com/songquanpeng/one-api/common" "gorm.io/gorm" + "sort" "strings" ) @@ -88,3 +90,19 @@ func (channel *Channel) UpdateAbilities() error { func UpdateAbilityStatus(channelId int, status bool) error { return DB.Model(&Ability{}).Where("channel_id = ?", channelId).Select("enabled").Update("enabled", status).Error } + +func GetGroupModels(ctx context.Context, group string) ([]string, error) { + groupCol := "`group`" + trueVal := "1" + if common.UsingPostgreSQL { + groupCol = `"group"` + trueVal = "true" + } + var models []string + err := DB.Model(&Ability{}).Distinct("model").Where(groupCol+" = ? and enabled = "+trueVal, group).Pluck("model", &models).Error + if err != nil { + return nil, err + } + sort.Strings(models) + return models, err +} diff --git a/model/channel.go b/model/channel.go index 24829bc5..fc4905b1 100644 --- a/model/channel.go +++ b/model/channel.go @@ -1,7 +1,6 @@ package model import ( - "context" "encoding/json" "fmt" "github.com/songquanpeng/one-api/common" @@ -9,8 +8,6 @@ import ( "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" "gorm.io/gorm" - "sort" - "strings" ) type Channel struct { @@ -28,7 +25,7 @@ type Channel struct { Balance float64 `json:"balance"` // in USD BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"` Models string `json:"models"` - Group string `json:"group" gorm:"index;type:varchar(32);default:'default'"` + Group string `json:"group" gorm:"type:varchar(32);default:'default'"` UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"` Priority *int64 `json:"priority" gorm:"bigint;default:0"` @@ -205,28 +202,3 @@ func DeleteDisabledChannel() (int64, error) { result := DB.Where("status = ? or status = ?", common.ChannelStatusAutoDisabled, common.ChannelStatusManuallyDisabled).Delete(&Channel{}) return result.RowsAffected, result.Error } - -func GetGroupModels(ctx context.Context, group string) ([]string, error) { - groupCol := "`group`" - if common.UsingPostgreSQL { - groupCol = `"group"` - } - var modelsList []string - err := DB.Model(&Channel{}).Distinct("models").Where(groupCol+" = ?", group).Pluck("models", &modelsList).Error - if err != nil { - return nil, err - } - set := make(map[string]bool) - for i := 0; i < len(modelsList); i++ { - modelList := strings.Split(modelsList[i], ",") - for _, model := range modelList { - set[model] = true - } - } - modelList := make([]string, 0, len(set)) - for model := range set { - modelList = append(modelList, model) - } - sort.Strings(modelList) - return modelList, err -}