From 4f3d925e8db204b9d177269ab5b487f5c4b2de17 Mon Sep 17 00:00:00 2001 From: glzjin Date: Wed, 16 Aug 2023 00:16:00 +0800 Subject: [PATCH] Fix postgres support --- model/ability.go | 16 ++++++++++------ model/cache.go | 4 ++-- model/channel.go | 24 ++++++++++++++++++------ model/redemption.go | 9 +++++++-- model/user.go | 9 +++++++-- 5 files changed, 44 insertions(+), 18 deletions(-) diff --git a/model/ability.go b/model/ability.go index e87c3940..7eb15e6f 100644 --- a/model/ability.go +++ b/model/ability.go @@ -1,8 +1,11 @@ package model import ( + "errors" + "math/rand" "one-api/common" "strings" + "time" ) type Ability struct { @@ -13,16 +16,17 @@ type Ability struct { } func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) { - ability := Ability{} + var abilities []Ability var err error = nil - if common.UsingSQLite { - err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("RANDOM()").Limit(1).First(&ability).Error - } else { - err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("RAND()").Limit(1).First(&ability).Error - } + err = DB.Where(&Ability{Group: group, Model: model, Enabled: true}).Find(&abilities).Error if err != nil { return nil, err } + if len(abilities) == 0 { + return nil, errors.New("channel not found") + } + r := rand.New(rand.NewSource(time.Now().UnixNano())) + ability := abilities[r.Intn(len(abilities))] channel := Channel{} channel.Id = ability.ChannelId err = DB.First(&channel, "id = ?", ability.ChannelId).Error diff --git a/model/cache.go b/model/cache.go index 64666c86..e2355726 100644 --- a/model/cache.go +++ b/model/cache.go @@ -22,12 +22,12 @@ var ( func CacheGetTokenByKey(key string) (*Token, error) { var token Token if !common.RedisEnabled { - err := DB.Where("`key` = ?", key).First(&token).Error + err := DB.Where(&Token{Key: key}).First(&token).Error return &token, err } tokenObjectString, err := common.RedisGet(fmt.Sprintf("token:%s", key)) if err != nil { - err := DB.Where("`key` = ?", key).First(&token).Error + err := DB.Where(&Token{Key: key}).First(&token).Error if err != nil { return nil, err } diff --git a/model/channel.go b/model/channel.go index 7cc9fa9b..275839d5 100644 --- a/model/channel.go +++ b/model/channel.go @@ -1,8 +1,12 @@ package model import ( + "errors" "gorm.io/gorm" + "math/rand" "one-api/common" + "strconv" + "time" ) type Channel struct { @@ -37,7 +41,11 @@ func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) { } func SearchChannels(keyword string) (channels []*Channel, err error) { - err = DB.Omit("key").Where("id = ? or name LIKE ? or `key` = ?", keyword, keyword+"%", keyword).Find(&channels).Error + idKeyword, err := strconv.Atoi(keyword) + if err != nil { + idKeyword = 0 + } + err = DB.Omit("key").Where("name LIKE ?", keyword+"%").Or(&Channel{Id: idKeyword}).Or(&Channel{Key: keyword}).Find(&channels).Error return channels, err } @@ -53,13 +61,17 @@ func GetChannelById(id int, selectAll bool) (*Channel, error) { } func GetRandomChannel() (*Channel, error) { - channel := Channel{} + var channels []Channel var err error = nil - if common.UsingSQLite { - err = DB.Where("status = ? and `group` = ?", common.ChannelStatusEnabled, "default").Order("RANDOM()").Limit(1).First(&channel).Error - } else { - err = DB.Where("status = ? and `group` = ?", common.ChannelStatusEnabled, "default").Order("RAND()").Limit(1).First(&channel).Error + err = DB.Where(&Channel{Status: common.ChannelStatusEnabled, Group: "default"}).Find(&channels).Error + if err != nil { + return nil, err } + if len(channels) == 0 { + return nil, errors.New("no enabled channel") + } + r := rand.New(rand.NewSource(time.Now().UnixNano())) + channel := channels[r.Intn(len(channels))] return &channel, err } diff --git a/model/redemption.go b/model/redemption.go index fafb2145..e77cc891 100644 --- a/model/redemption.go +++ b/model/redemption.go @@ -5,6 +5,7 @@ import ( "fmt" "gorm.io/gorm" "one-api/common" + "strconv" ) type Redemption struct { @@ -27,7 +28,11 @@ func GetAllRedemptions(startIdx int, num int) ([]*Redemption, error) { } func SearchRedemptions(keyword string) (redemptions []*Redemption, err error) { - err = DB.Where("id = ? or name LIKE ?", keyword, keyword+"%").Find(&redemptions).Error + idKeyword, err := strconv.Atoi(keyword) + if err != nil { + idKeyword = 0 + } + err = DB.Where("name LIKE ?", keyword+"%").Or(&Redemption{Id: idKeyword}).Find(&redemptions).Error return redemptions, err } @@ -51,7 +56,7 @@ func Redeem(key string, userId int) (quota int, err error) { redemption := &Redemption{} err = DB.Transaction(func(tx *gorm.DB) error { - err := tx.Set("gorm:query_option", "FOR UPDATE").Where("`key` = ?", key).First(redemption).Error + err := tx.Set("gorm:query_option", "FOR UPDATE").Where(&Redemption{Key: key}).First(redemption).Error if err != nil { return errors.New("无效的兑换码") } diff --git a/model/user.go b/model/user.go index 7c771840..57e6b843 100644 --- a/model/user.go +++ b/model/user.go @@ -5,6 +5,7 @@ import ( "fmt" "gorm.io/gorm" "one-api/common" + "strconv" "strings" ) @@ -42,7 +43,11 @@ func GetAllUsers(startIdx int, num int) (users []*User, err error) { } func SearchUsers(keyword string) (users []*User, err error) { - err = DB.Omit("password").Where("id = ? or username LIKE ? or email LIKE ? or display_name LIKE ?", keyword, keyword+"%", keyword+"%", keyword+"%").Find(&users).Error + idKeyword, err := strconv.Atoi(keyword) + if err != nil { + idKeyword = 0 + } + err = DB.Omit("password").Where("username LIKE ? or email LIKE ? or display_name LIKE ?", keyword+"%", keyword+"%", keyword+"%").Or(&User{Id: idKeyword}).Find(&users).Error return users, err } @@ -267,7 +272,7 @@ func GetUserEmail(id int) (email string, err error) { } func GetUserGroup(id int) (group string, err error) { - err = DB.Model(&User{}).Where("id = ?", id).Select("`group`").Find(&group).Error + err = DB.Model(&User{}).Where("id = ?", id).Select("group").Find(&group).Error return group, err }