fix postgresql support

fixes #517
This commit is contained in:
Singee 2023-10-17 16:53:21 +08:00
parent 64cdb7eafb
commit 6198f86001
No known key found for this signature in database
GPG Key ID: 4EC3FD10E03041D7
7 changed files with 37 additions and 7 deletions

View File

@ -22,6 +22,7 @@ var DisplayInCurrencyEnabled = true
var DisplayTokenStatEnabled = true var DisplayTokenStatEnabled = true
var UsingSQLite = false var UsingSQLite = false
var UsingPG = false
// Any options with "Secret", "Token" in its key won't be return by GetOptions // Any options with "Secret", "Token" in its key won't be return by GetOptions

View File

@ -15,9 +15,15 @@ type Ability struct {
func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) { func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
ability := Ability{} ability := Ability{}
groupCol := "`group`"
if common.UsingPG {
groupCol = `"group"`
}
var err error = nil var err error = nil
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where("`group` = ? and model = ? and enabled = 1", group, model) maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = 1", group, model)
channelQuery := DB.Where("`group` = ? and model = ? and enabled = 1 and priority = (?)", group, model, maxPrioritySubQuery) channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = 1 and priority = (?)", group, model, maxPrioritySubQuery)
if common.UsingSQLite { if common.UsingSQLite {
err = channelQuery.Order("RANDOM()").First(&ability).Error err = channelQuery.Order("RANDOM()").First(&ability).Error
} else { } else {

View File

@ -21,14 +21,19 @@ var (
) )
func CacheGetTokenByKey(key string) (*Token, error) { func CacheGetTokenByKey(key string) (*Token, error) {
keyCol := "`key`"
if common.UsingPG {
keyCol = `"key"`
}
var token Token var token Token
if !common.RedisEnabled { if !common.RedisEnabled {
err := DB.Where("`key` = ?", key).First(&token).Error err := DB.Where(keyCol+" = ?", key).First(&token).Error
return &token, err return &token, err
} }
tokenObjectString, err := common.RedisGet(fmt.Sprintf("token:%s", key)) tokenObjectString, err := common.RedisGet(fmt.Sprintf("token:%s", key))
if err != nil { if err != nil {
err := DB.Where("`key` = ?", key).First(&token).Error err := DB.Where(keyCol+" = ?", key).First(&token).Error
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -38,7 +38,12 @@ func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
} }
func SearchChannels(keyword string) (channels []*Channel, err error) { func SearchChannels(keyword string) (channels []*Channel, err error) {
err = DB.Omit("key").Where("id = ? or name LIKE ? or `key` = ?", keyword, keyword+"%", keyword).Find(&channels).Error keyCol := "`key`"
if common.UsingPG {
keyCol = `"key"`
}
err = DB.Omit("key").Where("id = ? or name LIKE ? or "+keyCol+" = ?", keyword, keyword+"%", keyword).Find(&channels).Error
return channels, err return channels, err
} }
@ -58,6 +63,8 @@ func GetRandomChannel() (*Channel, error) {
var err error = nil var err error = nil
if common.UsingSQLite { if common.UsingSQLite {
err = DB.Where("status = ? and `group` = ?", common.ChannelStatusEnabled, "default").Order("RANDOM()").Limit(1).First(&channel).Error err = DB.Where("status = ? and `group` = ?", common.ChannelStatusEnabled, "default").Order("RANDOM()").Limit(1).First(&channel).Error
} else if common.UsingPG {
err = DB.Where("status = ? and group = ?", common.ChannelStatusEnabled, "default").Order("RAND()").Limit(1).First(&channel).Error
} else { } else {
err = DB.Where("status = ? and `group` = ?", common.ChannelStatusEnabled, "default").Order("RAND()").Limit(1).First(&channel).Error err = DB.Where("status = ? and `group` = ?", common.ChannelStatusEnabled, "default").Order("RAND()").Limit(1).First(&channel).Error
} }

View File

@ -42,6 +42,7 @@ func chooseDB() (*gorm.DB, error) {
if strings.HasPrefix(dsn, "postgres://") { if strings.HasPrefix(dsn, "postgres://") {
// Use PostgreSQL // Use PostgreSQL
common.SysLog("using PostgreSQL as database") common.SysLog("using PostgreSQL as database")
common.UsingPG = true
return gorm.Open(postgres.New(postgres.Config{ return gorm.Open(postgres.New(postgres.Config{
DSN: dsn, DSN: dsn,
PreferSimpleProtocol: true, // disables implicit prepared statement usage PreferSimpleProtocol: true, // disables implicit prepared statement usage

View File

@ -50,8 +50,13 @@ func Redeem(key string, userId int) (quota int, err error) {
} }
redemption := &Redemption{} redemption := &Redemption{}
keyCol := "`key`"
if common.UsingPG {
keyCol = `"key"`
}
err = DB.Transaction(func(tx *gorm.DB) error { err = DB.Transaction(func(tx *gorm.DB) error {
err := tx.Set("gorm:query_option", "FOR UPDATE").Where("`key` = ?", key).First(redemption).Error err := tx.Set("gorm:query_option", "FOR UPDATE").Where(keyCol+" = ?", key).First(redemption).Error
if err != nil { if err != nil {
return errors.New("无效的兑换码") return errors.New("无效的兑换码")
} }

View File

@ -266,7 +266,12 @@ func GetUserEmail(id int) (email string, err error) {
} }
func GetUserGroup(id int) (group string, err error) { func GetUserGroup(id int) (group string, err error) {
err = DB.Model(&User{}).Where("id = ?", id).Select("`group`").Find(&group).Error groupCol := "`group`"
if common.UsingPG {
groupCol = `"group"`
}
err = DB.Model(&User{}).Where("id = ?", id).Select(groupCol).Find(&group).Error
return group, err return group, err
} }