Fix postgres support

This commit is contained in:
glzjin 2023-08-16 00:16:00 +08:00
parent da1d81998f
commit 4f3d925e8d
5 changed files with 44 additions and 18 deletions

View File

@ -1,8 +1,11 @@
package model package model
import ( import (
"errors"
"math/rand"
"one-api/common" "one-api/common"
"strings" "strings"
"time"
) )
type Ability struct { type Ability struct {
@ -13,16 +16,17 @@ type Ability struct {
} }
func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) { func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
ability := Ability{} var abilities []Ability
var err error = nil var err error = nil
if common.UsingSQLite { err = DB.Where(&Ability{Group: group, Model: model, Enabled: true}).Find(&abilities).Error
err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("RANDOM()").Limit(1).First(&ability).Error
} else {
err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("RAND()").Limit(1).First(&ability).Error
}
if err != nil { if err != nil {
return nil, err return nil, err
} }
if len(abilities) == 0 {
return nil, errors.New("channel not found")
}
r := rand.New(rand.NewSource(time.Now().UnixNano()))
ability := abilities[r.Intn(len(abilities))]
channel := Channel{} channel := Channel{}
channel.Id = ability.ChannelId channel.Id = ability.ChannelId
err = DB.First(&channel, "id = ?", ability.ChannelId).Error err = DB.First(&channel, "id = ?", ability.ChannelId).Error

View File

@ -22,12 +22,12 @@ var (
func CacheGetTokenByKey(key string) (*Token, error) { func CacheGetTokenByKey(key string) (*Token, error) {
var token Token var token Token
if !common.RedisEnabled { if !common.RedisEnabled {
err := DB.Where("`key` = ?", key).First(&token).Error err := DB.Where(&Token{Key: key}).First(&token).Error
return &token, err return &token, err
} }
tokenObjectString, err := common.RedisGet(fmt.Sprintf("token:%s", key)) tokenObjectString, err := common.RedisGet(fmt.Sprintf("token:%s", key))
if err != nil { if err != nil {
err := DB.Where("`key` = ?", key).First(&token).Error err := DB.Where(&Token{Key: key}).First(&token).Error
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -1,8 +1,12 @@
package model package model
import ( import (
"errors"
"gorm.io/gorm" "gorm.io/gorm"
"math/rand"
"one-api/common" "one-api/common"
"strconv"
"time"
) )
type Channel struct { type Channel struct {
@ -37,7 +41,11 @@ func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
} }
func SearchChannels(keyword string) (channels []*Channel, err error) { func SearchChannels(keyword string) (channels []*Channel, err error) {
err = DB.Omit("key").Where("id = ? or name LIKE ? or `key` = ?", keyword, keyword+"%", keyword).Find(&channels).Error idKeyword, err := strconv.Atoi(keyword)
if err != nil {
idKeyword = 0
}
err = DB.Omit("key").Where("name LIKE ?", keyword+"%").Or(&Channel{Id: idKeyword}).Or(&Channel{Key: keyword}).Find(&channels).Error
return channels, err return channels, err
} }
@ -53,13 +61,17 @@ func GetChannelById(id int, selectAll bool) (*Channel, error) {
} }
func GetRandomChannel() (*Channel, error) { func GetRandomChannel() (*Channel, error) {
channel := Channel{} var channels []Channel
var err error = nil var err error = nil
if common.UsingSQLite { err = DB.Where(&Channel{Status: common.ChannelStatusEnabled, Group: "default"}).Find(&channels).Error
err = DB.Where("status = ? and `group` = ?", common.ChannelStatusEnabled, "default").Order("RANDOM()").Limit(1).First(&channel).Error if err != nil {
} else { return nil, err
err = DB.Where("status = ? and `group` = ?", common.ChannelStatusEnabled, "default").Order("RAND()").Limit(1).First(&channel).Error
} }
if len(channels) == 0 {
return nil, errors.New("no enabled channel")
}
r := rand.New(rand.NewSource(time.Now().UnixNano()))
channel := channels[r.Intn(len(channels))]
return &channel, err return &channel, err
} }

View File

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"gorm.io/gorm" "gorm.io/gorm"
"one-api/common" "one-api/common"
"strconv"
) )
type Redemption struct { type Redemption struct {
@ -27,7 +28,11 @@ func GetAllRedemptions(startIdx int, num int) ([]*Redemption, error) {
} }
func SearchRedemptions(keyword string) (redemptions []*Redemption, err error) { func SearchRedemptions(keyword string) (redemptions []*Redemption, err error) {
err = DB.Where("id = ? or name LIKE ?", keyword, keyword+"%").Find(&redemptions).Error idKeyword, err := strconv.Atoi(keyword)
if err != nil {
idKeyword = 0
}
err = DB.Where("name LIKE ?", keyword+"%").Or(&Redemption{Id: idKeyword}).Find(&redemptions).Error
return redemptions, err return redemptions, err
} }
@ -51,7 +56,7 @@ func Redeem(key string, userId int) (quota int, err error) {
redemption := &Redemption{} redemption := &Redemption{}
err = DB.Transaction(func(tx *gorm.DB) error { err = DB.Transaction(func(tx *gorm.DB) error {
err := tx.Set("gorm:query_option", "FOR UPDATE").Where("`key` = ?", key).First(redemption).Error err := tx.Set("gorm:query_option", "FOR UPDATE").Where(&Redemption{Key: key}).First(redemption).Error
if err != nil { if err != nil {
return errors.New("无效的兑换码") return errors.New("无效的兑换码")
} }

View File

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"gorm.io/gorm" "gorm.io/gorm"
"one-api/common" "one-api/common"
"strconv"
"strings" "strings"
) )
@ -42,7 +43,11 @@ func GetAllUsers(startIdx int, num int) (users []*User, err error) {
} }
func SearchUsers(keyword string) (users []*User, err error) { func SearchUsers(keyword string) (users []*User, err error) {
err = DB.Omit("password").Where("id = ? or username LIKE ? or email LIKE ? or display_name LIKE ?", keyword, keyword+"%", keyword+"%", keyword+"%").Find(&users).Error idKeyword, err := strconv.Atoi(keyword)
if err != nil {
idKeyword = 0
}
err = DB.Omit("password").Where("username LIKE ? or email LIKE ? or display_name LIKE ?", keyword+"%", keyword+"%", keyword+"%").Or(&User{Id: idKeyword}).Find(&users).Error
return users, err return users, err
} }
@ -267,7 +272,7 @@ func GetUserEmail(id int) (email string, err error) {
} }
func GetUserGroup(id int) (group string, err error) { func GetUserGroup(id int) (group string, err error) {
err = DB.Model(&User{}).Where("id = ?", id).Select("`group`").Find(&group).Error err = DB.Model(&User{}).Where("id = ?", id).Select("group").Find(&group).Error
return group, err return group, err
} }