From c78fba83eb4ba36b579428eb12721da2bea1c24f Mon Sep 17 00:00:00 2001 From: ckt1031 <65409152+ckt1031@users.noreply.github.com> Date: Fri, 6 Oct 2023 22:34:28 +0800 Subject: [PATCH] fix: pgsql new --- model/ability.go | 33 ++++++++++++++++++++++----------- model/channel.go | 15 ++++++--------- 2 files changed, 28 insertions(+), 20 deletions(-) diff --git a/model/ability.go b/model/ability.go index cdc8fad6..d59c40f8 100644 --- a/model/ability.go +++ b/model/ability.go @@ -1,7 +1,6 @@ package model import ( - "fmt" "one-api/common" "strings" ) @@ -20,21 +19,33 @@ func GetRandomSatisfiedChannel(group string, model string, stream bool) (*Channe ability := Ability{} var err error = nil - cmd := "`group` = ? and model = ? and enabled = 1" - - if common.UsingPostgreSQL { - // Make cmd compatible with PostgreSQL - cmd = "\"group\" = ? and model = ? and enabled = true" - } + var cmdWhere *Ability if stream { - cmd += fmt.Sprintf(" and allow_streaming = %d", common.ChannelAllowStreamEnabled) + cmdWhere = &Ability{ + Group: group, + Model: model, + Enabled: true, + AllowStreaming: common.ChannelAllowStreamEnabled, + } } else { - cmd += fmt.Sprintf(" and allow_non_streaming = %d", common.ChannelAllowNonStreamEnabled) + cmdWhere = &Ability{ + Group: group, + Model: model, + Enabled: true, + AllowNonStreaming: common.ChannelAllowNonStreamEnabled, + } } - maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(cmd, group, model) - channelQuery := DB.Where("`group` = ? and model = ? and enabled = 1 and priority = (?)", group, model, maxPrioritySubQuery) + maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(cmdWhere) + + cmd1 := "`group` = ? and model = ? and enabled = 1 and priority = (?)" + + if common.UsingPostgreSQL { + cmd1 = "\"group\" = ? and model = ? and enabled = 1 and priority = (?)" + } + + channelQuery := DB.Where(cmd1, group, model, maxPrioritySubQuery) if common.UsingSQLite || common.UsingPostgreSQL { err = channelQuery.Order("RANDOM()").First(&ability).Error } else { diff --git a/model/channel.go b/model/channel.go index d20a5172..3f5b393b 100644 --- a/model/channel.go +++ b/model/channel.go @@ -2,6 +2,7 @@ package model import ( "one-api/common" + "strconv" "gorm.io/gorm" ) @@ -41,13 +42,9 @@ func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) { } func SearchChannels(keyword string) (channels []*Channel, err error) { - whereItem := "id = ? or name LIKE ? or `key` = ?" + idKeyword, err := strconv.Atoi(keyword) - if common.UsingPostgreSQL { - whereItem = "id = ? or name LIKE ? or \"key\" = ?" - } - - err = DB.Omit("key").Where(whereItem, keyword, keyword+"%", keyword).Find(&channels).Error + err = DB.Omit("key").Where("name LIKE ?", keyword+"%").Or(&Channel{Id: idKeyword}).Or(&Channel{Key: keyword}).Find(&channels).Error return channels, err } @@ -66,11 +63,11 @@ func GetRandomChannel() (*Channel, error) { channel := Channel{} var err error = nil if common.UsingPostgreSQL { - err = DB.Where("status = ? and \"group\" = ?", common.ChannelStatusEnabled, "default").Order("RANDOM()").Limit(1).First(&channel).Error + err = DB.Where(&Channel{Status: common.ChannelStatusEnabled, Group: "default"}).Order("RANDOM()").Limit(1).First(&channel).Error } else if common.UsingSQLite { - err = DB.Where("status = ? and `group` = ?", common.ChannelStatusEnabled, "default").Order("RANDOM()").Limit(1).First(&channel).Error + err = DB.Where(&Channel{Status: common.ChannelStatusEnabled, Group: "default"}).Order("RANDOM()").Limit(1).First(&channel).Error } else { - err = DB.Where("status = ? and `group` = ?", common.ChannelStatusEnabled, "default").Order("RAND()").Limit(1).First(&channel).Error + err = DB.Where(&Channel{Status: common.ChannelStatusEnabled, Group: "default"}).Order("RAND()").Limit(1).First(&channel).Error } return &channel, err }