From 6198f86001e403a7aa300f48b111aa81d1e8f168 Mon Sep 17 00:00:00 2001 From: Singee Date: Tue, 17 Oct 2023 16:53:21 +0800 Subject: [PATCH] fix postgresql support fixes #517 --- common/constants.go | 1 + model/ability.go | 10 ++++++++-- model/cache.go | 9 +++++++-- model/channel.go | 9 ++++++++- model/main.go | 1 + model/redemption.go | 7 ++++++- model/user.go | 7 ++++++- 7 files changed, 37 insertions(+), 7 deletions(-) diff --git a/common/constants.go b/common/constants.go index a0361c35..eaffb29e 100644 --- a/common/constants.go +++ b/common/constants.go @@ -22,6 +22,7 @@ var DisplayInCurrencyEnabled = true var DisplayTokenStatEnabled = true var UsingSQLite = false +var UsingPG = false // Any options with "Secret", "Token" in its key won't be return by GetOptions diff --git a/model/ability.go b/model/ability.go index 50972a26..9899ddd4 100644 --- a/model/ability.go +++ b/model/ability.go @@ -15,9 +15,15 @@ type Ability struct { func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) { ability := Ability{} + + groupCol := "`group`" + if common.UsingPG { + groupCol = `"group"` + } + var err error = nil - maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where("`group` = ? and model = ? and enabled = 1", group, model) - channelQuery := DB.Where("`group` = ? and model = ? and enabled = 1 and priority = (?)", group, model, maxPrioritySubQuery) + maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = 1", group, model) + channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = 1 and priority = (?)", group, model, maxPrioritySubQuery) if common.UsingSQLite { err = channelQuery.Order("RANDOM()").First(&ability).Error } else { diff --git a/model/cache.go b/model/cache.go index a7f5c06f..981daf11 100644 --- a/model/cache.go +++ b/model/cache.go @@ -21,14 +21,19 @@ var ( ) func CacheGetTokenByKey(key string) (*Token, error) { + keyCol := "`key`" + if common.UsingPG { + keyCol = `"key"` + } + var token Token if !common.RedisEnabled { - err := DB.Where("`key` = ?", key).First(&token).Error + err := DB.Where(keyCol+" = ?", key).First(&token).Error return &token, err } tokenObjectString, err := common.RedisGet(fmt.Sprintf("token:%s", key)) if err != nil { - err := DB.Where("`key` = ?", key).First(&token).Error + err := DB.Where(keyCol+" = ?", key).First(&token).Error if err != nil { return nil, err } diff --git a/model/channel.go b/model/channel.go index 091a0d71..78de5d65 100644 --- a/model/channel.go +++ b/model/channel.go @@ -38,7 +38,12 @@ func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) { } func SearchChannels(keyword string) (channels []*Channel, err error) { - err = DB.Omit("key").Where("id = ? or name LIKE ? or `key` = ?", keyword, keyword+"%", keyword).Find(&channels).Error + keyCol := "`key`" + if common.UsingPG { + keyCol = `"key"` + } + + err = DB.Omit("key").Where("id = ? or name LIKE ? or "+keyCol+" = ?", keyword, keyword+"%", keyword).Find(&channels).Error return channels, err } @@ -58,6 +63,8 @@ func GetRandomChannel() (*Channel, error) { var err error = nil if common.UsingSQLite { err = DB.Where("status = ? and `group` = ?", common.ChannelStatusEnabled, "default").Order("RANDOM()").Limit(1).First(&channel).Error + } else if common.UsingPG { + err = DB.Where("status = ? and group = ?", common.ChannelStatusEnabled, "default").Order("RAND()").Limit(1).First(&channel).Error } else { err = DB.Where("status = ? and `group` = ?", common.ChannelStatusEnabled, "default").Order("RAND()").Limit(1).First(&channel).Error } diff --git a/model/main.go b/model/main.go index 0e962049..d281eeb9 100644 --- a/model/main.go +++ b/model/main.go @@ -42,6 +42,7 @@ func chooseDB() (*gorm.DB, error) { if strings.HasPrefix(dsn, "postgres://") { // Use PostgreSQL common.SysLog("using PostgreSQL as database") + common.UsingPG = true return gorm.Open(postgres.New(postgres.Config{ DSN: dsn, PreferSimpleProtocol: true, // disables implicit prepared statement usage diff --git a/model/redemption.go b/model/redemption.go index fafb2145..9045e8bb 100644 --- a/model/redemption.go +++ b/model/redemption.go @@ -50,8 +50,13 @@ func Redeem(key string, userId int) (quota int, err error) { } redemption := &Redemption{} + keyCol := "`key`" + if common.UsingPG { + keyCol = `"key"` + } + err = DB.Transaction(func(tx *gorm.DB) error { - err := tx.Set("gorm:query_option", "FOR UPDATE").Where("`key` = ?", key).First(redemption).Error + err := tx.Set("gorm:query_option", "FOR UPDATE").Where(keyCol+" = ?", key).First(redemption).Error if err != nil { return errors.New("无效的兑换码") } diff --git a/model/user.go b/model/user.go index 1b2ec7e6..f5e51d08 100644 --- a/model/user.go +++ b/model/user.go @@ -266,7 +266,12 @@ func GetUserEmail(id int) (email string, err error) { } func GetUserGroup(id int) (group string, err error) { - err = DB.Model(&User{}).Where("id = ?", id).Select("`group`").Find(&group).Error + groupCol := "`group`" + if common.UsingPG { + groupCol = `"group"` + } + + err = DB.Model(&User{}).Where("id = ?", id).Select(groupCol).Find(&group).Error return group, err }