fix postgresql support

fixes #517
This commit is contained in:
Singee 2023-10-17 16:53:21 +08:00
parent 64cdb7eafb
commit 6198f86001
No known key found for this signature in database
GPG Key ID: 4EC3FD10E03041D7
7 changed files with 37 additions and 7 deletions

View File

@ -22,6 +22,7 @@ var DisplayInCurrencyEnabled = true
var DisplayTokenStatEnabled = true
var UsingSQLite = false
var UsingPG = false
// Any options with "Secret", "Token" in its key won't be return by GetOptions

View File

@ -15,9 +15,15 @@ type Ability struct {
func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
ability := Ability{}
groupCol := "`group`"
if common.UsingPG {
groupCol = `"group"`
}
var err error = nil
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where("`group` = ? and model = ? and enabled = 1", group, model)
channelQuery := DB.Where("`group` = ? and model = ? and enabled = 1 and priority = (?)", group, model, maxPrioritySubQuery)
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = 1", group, model)
channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = 1 and priority = (?)", group, model, maxPrioritySubQuery)
if common.UsingSQLite {
err = channelQuery.Order("RANDOM()").First(&ability).Error
} else {

View File

@ -21,14 +21,19 @@ var (
)
func CacheGetTokenByKey(key string) (*Token, error) {
keyCol := "`key`"
if common.UsingPG {
keyCol = `"key"`
}
var token Token
if !common.RedisEnabled {
err := DB.Where("`key` = ?", key).First(&token).Error
err := DB.Where(keyCol+" = ?", key).First(&token).Error
return &token, err
}
tokenObjectString, err := common.RedisGet(fmt.Sprintf("token:%s", key))
if err != nil {
err := DB.Where("`key` = ?", key).First(&token).Error
err := DB.Where(keyCol+" = ?", key).First(&token).Error
if err != nil {
return nil, err
}

View File

@ -38,7 +38,12 @@ func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
}
func SearchChannels(keyword string) (channels []*Channel, err error) {
err = DB.Omit("key").Where("id = ? or name LIKE ? or `key` = ?", keyword, keyword+"%", keyword).Find(&channels).Error
keyCol := "`key`"
if common.UsingPG {
keyCol = `"key"`
}
err = DB.Omit("key").Where("id = ? or name LIKE ? or "+keyCol+" = ?", keyword, keyword+"%", keyword).Find(&channels).Error
return channels, err
}
@ -58,6 +63,8 @@ func GetRandomChannel() (*Channel, error) {
var err error = nil
if common.UsingSQLite {
err = DB.Where("status = ? and `group` = ?", common.ChannelStatusEnabled, "default").Order("RANDOM()").Limit(1).First(&channel).Error
} else if common.UsingPG {
err = DB.Where("status = ? and group = ?", common.ChannelStatusEnabled, "default").Order("RAND()").Limit(1).First(&channel).Error
} else {
err = DB.Where("status = ? and `group` = ?", common.ChannelStatusEnabled, "default").Order("RAND()").Limit(1).First(&channel).Error
}

View File

@ -42,6 +42,7 @@ func chooseDB() (*gorm.DB, error) {
if strings.HasPrefix(dsn, "postgres://") {
// Use PostgreSQL
common.SysLog("using PostgreSQL as database")
common.UsingPG = true
return gorm.Open(postgres.New(postgres.Config{
DSN: dsn,
PreferSimpleProtocol: true, // disables implicit prepared statement usage

View File

@ -50,8 +50,13 @@ func Redeem(key string, userId int) (quota int, err error) {
}
redemption := &Redemption{}
keyCol := "`key`"
if common.UsingPG {
keyCol = `"key"`
}
err = DB.Transaction(func(tx *gorm.DB) error {
err := tx.Set("gorm:query_option", "FOR UPDATE").Where("`key` = ?", key).First(redemption).Error
err := tx.Set("gorm:query_option", "FOR UPDATE").Where(keyCol+" = ?", key).First(redemption).Error
if err != nil {
return errors.New("无效的兑换码")
}

View File

@ -266,7 +266,12 @@ func GetUserEmail(id int) (email string, err error) {
}
func GetUserGroup(id int) (group string, err error) {
err = DB.Model(&User{}).Where("id = ?", id).Select("`group`").Find(&group).Error
groupCol := "`group`"
if common.UsingPG {
groupCol = `"group"`
}
err = DB.Model(&User{}).Where("id = ?", id).Select(groupCol).Find(&group).Error
return group, err
}