From c2aa0e33122510c9e4173e66a90f3efdd898cd1b Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 22 Oct 2023 18:33:45 +0800 Subject: [PATCH] fix: fix pg support --- common/constants.go | 4 ---- common/database.go | 6 ++++++ common/utils.go | 8 ++++++++ model/ability.go | 11 ++++++----- model/cache.go | 3 +-- model/channel.go | 9 +++------ model/main.go | 2 +- model/redemption.go | 2 +- model/user.go | 2 +- 9 files changed, 27 insertions(+), 20 deletions(-) create mode 100644 common/database.go diff --git a/common/constants.go b/common/constants.go index eaffb29e..c25785c7 100644 --- a/common/constants.go +++ b/common/constants.go @@ -21,13 +21,9 @@ var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens var DisplayInCurrencyEnabled = true var DisplayTokenStatEnabled = true -var UsingSQLite = false -var UsingPG = false - // Any options with "Secret", "Token" in its key won't be return by GetOptions var SessionSecret = uuid.New().String() -var SQLitePath = "one-api.db" var OptionMap map[string]string var OptionMapRWMutex sync.RWMutex diff --git a/common/database.go b/common/database.go new file mode 100644 index 00000000..c7e9fd52 --- /dev/null +++ b/common/database.go @@ -0,0 +1,6 @@ +package common + +var UsingSQLite = false +var UsingPostgreSQL = false + +var SQLitePath = "one-api.db" diff --git a/common/utils.go b/common/utils.go index ab901b77..21bec8f5 100644 --- a/common/utils.go +++ b/common/utils.go @@ -199,3 +199,11 @@ func GetOrDefault(env string, defaultValue int) int { func MessageWithRequestId(message string, id string) string { return fmt.Sprintf("%s (request id: %s)", message, id) } + +func String2Int(str string) int { + num, err := strconv.Atoi(str) + if err != nil { + return 0 + } + return num +} diff --git a/model/ability.go b/model/ability.go index 9899ddd4..3da83be8 100644 --- a/model/ability.go +++ b/model/ability.go @@ -15,16 +15,17 @@ type Ability struct { func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) { ability := Ability{} - groupCol := "`group`" - if common.UsingPG { + trueVal := "1" + if common.UsingPostgreSQL { groupCol = `"group"` + trueVal = "true" } var err error = nil - maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = 1", group, model) - channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = 1 and priority = (?)", group, model, maxPrioritySubQuery) - if common.UsingSQLite { + maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model) + channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery) + if common.UsingSQLite || common.UsingPostgreSQL { err = channelQuery.Order("RANDOM()").First(&ability).Error } else { err = channelQuery.Order("RAND()").First(&ability).Error diff --git a/model/cache.go b/model/cache.go index 981daf11..c6d0c70a 100644 --- a/model/cache.go +++ b/model/cache.go @@ -22,10 +22,9 @@ var ( func CacheGetTokenByKey(key string) (*Token, error) { keyCol := "`key`" - if common.UsingPG { + if common.UsingPostgreSQL { keyCol = `"key"` } - var token Token if !common.RedisEnabled { err := DB.Where(keyCol+" = ?", key).First(&token).Error diff --git a/model/channel.go b/model/channel.go index 78de5d65..86d77f72 100644 --- a/model/channel.go +++ b/model/channel.go @@ -39,11 +39,10 @@ func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) { func SearchChannels(keyword string) (channels []*Channel, err error) { keyCol := "`key`" - if common.UsingPG { + if common.UsingPostgreSQL { keyCol = `"key"` } - - err = DB.Omit("key").Where("id = ? or name LIKE ? or "+keyCol+" = ?", keyword, keyword+"%", keyword).Find(&channels).Error + err = DB.Omit("key").Where("id = ? or name LIKE ? or "+keyCol+" = ?", common.String2Int(keyword), keyword+"%", keyword).Find(&channels).Error return channels, err } @@ -61,10 +60,8 @@ func GetChannelById(id int, selectAll bool) (*Channel, error) { func GetRandomChannel() (*Channel, error) { channel := Channel{} var err error = nil - if common.UsingSQLite { + if common.UsingSQLite || common.UsingPostgreSQL { err = DB.Where("status = ? and `group` = ?", common.ChannelStatusEnabled, "default").Order("RANDOM()").Limit(1).First(&channel).Error - } else if common.UsingPG { - err = DB.Where("status = ? and group = ?", common.ChannelStatusEnabled, "default").Order("RAND()").Limit(1).First(&channel).Error } else { err = DB.Where("status = ? and `group` = ?", common.ChannelStatusEnabled, "default").Order("RAND()").Limit(1).First(&channel).Error } diff --git a/model/main.go b/model/main.go index d281eeb9..08182634 100644 --- a/model/main.go +++ b/model/main.go @@ -42,7 +42,7 @@ func chooseDB() (*gorm.DB, error) { if strings.HasPrefix(dsn, "postgres://") { // Use PostgreSQL common.SysLog("using PostgreSQL as database") - common.UsingPG = true + common.UsingPostgreSQL = true return gorm.Open(postgres.New(postgres.Config{ DSN: dsn, PreferSimpleProtocol: true, // disables implicit prepared statement usage diff --git a/model/redemption.go b/model/redemption.go index 9045e8bb..f16412b5 100644 --- a/model/redemption.go +++ b/model/redemption.go @@ -51,7 +51,7 @@ func Redeem(key string, userId int) (quota int, err error) { redemption := &Redemption{} keyCol := "`key`" - if common.UsingPG { + if common.UsingPostgreSQL { keyCol = `"key"` } diff --git a/model/user.go b/model/user.go index f5e51d08..7844eb6a 100644 --- a/model/user.go +++ b/model/user.go @@ -267,7 +267,7 @@ func GetUserEmail(id int) (email string, err error) { func GetUserGroup(id int) (group string, err error) { groupCol := "`group`" - if common.UsingPG { + if common.UsingPostgreSQL { groupCol = `"group"` }