fix: fix GetGroupModels

This commit is contained in:
JustSong 2024-04-04 02:58:21 +08:00
parent 8b9813d63b
commit 8b9fa3d6e4
2 changed files with 19 additions and 29 deletions

View File

@ -1,8 +1,10 @@
package model package model
import ( import (
"context"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"gorm.io/gorm" "gorm.io/gorm"
"sort"
"strings" "strings"
) )
@ -88,3 +90,19 @@ func (channel *Channel) UpdateAbilities() error {
func UpdateAbilityStatus(channelId int, status bool) error { func UpdateAbilityStatus(channelId int, status bool) error {
return DB.Model(&Ability{}).Where("channel_id = ?", channelId).Select("enabled").Update("enabled", status).Error return DB.Model(&Ability{}).Where("channel_id = ?", channelId).Select("enabled").Update("enabled", status).Error
} }
func GetGroupModels(ctx context.Context, group string) ([]string, error) {
groupCol := "`group`"
trueVal := "1"
if common.UsingPostgreSQL {
groupCol = `"group"`
trueVal = "true"
}
var models []string
err := DB.Model(&Ability{}).Distinct("model").Where(groupCol+" = ? and enabled = "+trueVal, group).Pluck("model", &models).Error
if err != nil {
return nil, err
}
sort.Strings(models)
return models, err
}

View File

@ -1,7 +1,6 @@
package model package model
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
@ -9,8 +8,6 @@ import (
"github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"gorm.io/gorm" "gorm.io/gorm"
"sort"
"strings"
) )
type Channel struct { type Channel struct {
@ -28,7 +25,7 @@ type Channel struct {
Balance float64 `json:"balance"` // in USD Balance float64 `json:"balance"` // in USD
BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"` BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"`
Models string `json:"models"` Models string `json:"models"`
Group string `json:"group" gorm:"index;type:varchar(32);default:'default'"` Group string `json:"group" gorm:"type:varchar(32);default:'default'"`
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"`
ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"` ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
Priority *int64 `json:"priority" gorm:"bigint;default:0"` Priority *int64 `json:"priority" gorm:"bigint;default:0"`
@ -205,28 +202,3 @@ func DeleteDisabledChannel() (int64, error) {
result := DB.Where("status = ? or status = ?", common.ChannelStatusAutoDisabled, common.ChannelStatusManuallyDisabled).Delete(&Channel{}) result := DB.Where("status = ? or status = ?", common.ChannelStatusAutoDisabled, common.ChannelStatusManuallyDisabled).Delete(&Channel{})
return result.RowsAffected, result.Error return result.RowsAffected, result.Error
} }
func GetGroupModels(ctx context.Context, group string) ([]string, error) {
groupCol := "`group`"
if common.UsingPostgreSQL {
groupCol = `"group"`
}
var modelsList []string
err := DB.Model(&Channel{}).Distinct("models").Where(groupCol+" = ?", group).Pluck("models", &modelsList).Error
if err != nil {
return nil, err
}
set := make(map[string]bool)
for i := 0; i < len(modelsList); i++ {
modelList := strings.Split(modelsList[i], ",")
for _, model := range modelList {
set[model] = true
}
}
modelList := make([]string, 0, len(set))
for model := range set {
modelList = append(modelList, model)
}
sort.Strings(modelList)
return modelList, err
}