fix: fix postgresql support (#606)
* fix postgresql support fixes #517 * fix: fix pg support * chore: delete useless code --------- Co-authored-by: JustSong <songquanpeng@foxmail.com>
This commit is contained in:
parent
57aa637c77
commit
a398f35968
@ -21,12 +21,9 @@ var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens
|
|||||||
var DisplayInCurrencyEnabled = true
|
var DisplayInCurrencyEnabled = true
|
||||||
var DisplayTokenStatEnabled = true
|
var DisplayTokenStatEnabled = true
|
||||||
|
|
||||||
var UsingSQLite = false
|
|
||||||
|
|
||||||
// Any options with "Secret", "Token" in its key won't be return by GetOptions
|
// Any options with "Secret", "Token" in its key won't be return by GetOptions
|
||||||
|
|
||||||
var SessionSecret = uuid.New().String()
|
var SessionSecret = uuid.New().String()
|
||||||
var SQLitePath = "one-api.db"
|
|
||||||
|
|
||||||
var OptionMap map[string]string
|
var OptionMap map[string]string
|
||||||
var OptionMapRWMutex sync.RWMutex
|
var OptionMapRWMutex sync.RWMutex
|
||||||
|
6
common/database.go
Normal file
6
common/database.go
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
var UsingSQLite = false
|
||||||
|
var UsingPostgreSQL = false
|
||||||
|
|
||||||
|
var SQLitePath = "one-api.db"
|
@ -199,3 +199,11 @@ func GetOrDefault(env string, defaultValue int) int {
|
|||||||
func MessageWithRequestId(message string, id string) string {
|
func MessageWithRequestId(message string, id string) string {
|
||||||
return fmt.Sprintf("%s (request id: %s)", message, id)
|
return fmt.Sprintf("%s (request id: %s)", message, id)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func String2Int(str string) int {
|
||||||
|
num, err := strconv.Atoi(str)
|
||||||
|
if err != nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return num
|
||||||
|
}
|
||||||
|
@ -15,10 +15,17 @@ type Ability struct {
|
|||||||
|
|
||||||
func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
|
func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
|
||||||
ability := Ability{}
|
ability := Ability{}
|
||||||
|
groupCol := "`group`"
|
||||||
|
trueVal := "1"
|
||||||
|
if common.UsingPostgreSQL {
|
||||||
|
groupCol = `"group"`
|
||||||
|
trueVal = "true"
|
||||||
|
}
|
||||||
|
|
||||||
var err error = nil
|
var err error = nil
|
||||||
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where("`group` = ? and model = ? and enabled = 1", group, model)
|
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model)
|
||||||
channelQuery := DB.Where("`group` = ? and model = ? and enabled = 1 and priority = (?)", group, model, maxPrioritySubQuery)
|
channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery)
|
||||||
if common.UsingSQLite {
|
if common.UsingSQLite || common.UsingPostgreSQL {
|
||||||
err = channelQuery.Order("RANDOM()").First(&ability).Error
|
err = channelQuery.Order("RANDOM()").First(&ability).Error
|
||||||
} else {
|
} else {
|
||||||
err = channelQuery.Order("RAND()").First(&ability).Error
|
err = channelQuery.Order("RAND()").First(&ability).Error
|
||||||
|
@ -21,14 +21,18 @@ var (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func CacheGetTokenByKey(key string) (*Token, error) {
|
func CacheGetTokenByKey(key string) (*Token, error) {
|
||||||
|
keyCol := "`key`"
|
||||||
|
if common.UsingPostgreSQL {
|
||||||
|
keyCol = `"key"`
|
||||||
|
}
|
||||||
var token Token
|
var token Token
|
||||||
if !common.RedisEnabled {
|
if !common.RedisEnabled {
|
||||||
err := DB.Where("`key` = ?", key).First(&token).Error
|
err := DB.Where(keyCol+" = ?", key).First(&token).Error
|
||||||
return &token, err
|
return &token, err
|
||||||
}
|
}
|
||||||
tokenObjectString, err := common.RedisGet(fmt.Sprintf("token:%s", key))
|
tokenObjectString, err := common.RedisGet(fmt.Sprintf("token:%s", key))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err := DB.Where("`key` = ?", key).First(&token).Error
|
err := DB.Where(keyCol+" = ?", key).First(&token).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -38,7 +38,11 @@ func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func SearchChannels(keyword string) (channels []*Channel, err error) {
|
func SearchChannels(keyword string) (channels []*Channel, err error) {
|
||||||
err = DB.Omit("key").Where("id = ? or name LIKE ? or `key` = ?", keyword, keyword+"%", keyword).Find(&channels).Error
|
keyCol := "`key`"
|
||||||
|
if common.UsingPostgreSQL {
|
||||||
|
keyCol = `"key"`
|
||||||
|
}
|
||||||
|
err = DB.Omit("key").Where("id = ? or name LIKE ? or "+keyCol+" = ?", common.String2Int(keyword), keyword+"%", keyword).Find(&channels).Error
|
||||||
return channels, err
|
return channels, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -53,17 +57,6 @@ func GetChannelById(id int, selectAll bool) (*Channel, error) {
|
|||||||
return &channel, err
|
return &channel, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetRandomChannel() (*Channel, error) {
|
|
||||||
channel := Channel{}
|
|
||||||
var err error = nil
|
|
||||||
if common.UsingSQLite {
|
|
||||||
err = DB.Where("status = ? and `group` = ?", common.ChannelStatusEnabled, "default").Order("RANDOM()").Limit(1).First(&channel).Error
|
|
||||||
} else {
|
|
||||||
err = DB.Where("status = ? and `group` = ?", common.ChannelStatusEnabled, "default").Order("RAND()").Limit(1).First(&channel).Error
|
|
||||||
}
|
|
||||||
return &channel, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func BatchInsertChannels(channels []Channel) error {
|
func BatchInsertChannels(channels []Channel) error {
|
||||||
var err error
|
var err error
|
||||||
err = DB.Create(&channels).Error
|
err = DB.Create(&channels).Error
|
||||||
|
@ -42,6 +42,7 @@ func chooseDB() (*gorm.DB, error) {
|
|||||||
if strings.HasPrefix(dsn, "postgres://") {
|
if strings.HasPrefix(dsn, "postgres://") {
|
||||||
// Use PostgreSQL
|
// Use PostgreSQL
|
||||||
common.SysLog("using PostgreSQL as database")
|
common.SysLog("using PostgreSQL as database")
|
||||||
|
common.UsingPostgreSQL = true
|
||||||
return gorm.Open(postgres.New(postgres.Config{
|
return gorm.Open(postgres.New(postgres.Config{
|
||||||
DSN: dsn,
|
DSN: dsn,
|
||||||
PreferSimpleProtocol: true, // disables implicit prepared statement usage
|
PreferSimpleProtocol: true, // disables implicit prepared statement usage
|
||||||
|
@ -50,8 +50,13 @@ func Redeem(key string, userId int) (quota int, err error) {
|
|||||||
}
|
}
|
||||||
redemption := &Redemption{}
|
redemption := &Redemption{}
|
||||||
|
|
||||||
|
keyCol := "`key`"
|
||||||
|
if common.UsingPostgreSQL {
|
||||||
|
keyCol = `"key"`
|
||||||
|
}
|
||||||
|
|
||||||
err = DB.Transaction(func(tx *gorm.DB) error {
|
err = DB.Transaction(func(tx *gorm.DB) error {
|
||||||
err := tx.Set("gorm:query_option", "FOR UPDATE").Where("`key` = ?", key).First(redemption).Error
|
err := tx.Set("gorm:query_option", "FOR UPDATE").Where(keyCol+" = ?", key).First(redemption).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.New("无效的兑换码")
|
return errors.New("无效的兑换码")
|
||||||
}
|
}
|
||||||
|
@ -266,7 +266,12 @@ func GetUserEmail(id int) (email string, err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func GetUserGroup(id int) (group string, err error) {
|
func GetUserGroup(id int) (group string, err error) {
|
||||||
err = DB.Model(&User{}).Where("id = ?", id).Select("`group`").Find(&group).Error
|
groupCol := "`group`"
|
||||||
|
if common.UsingPostgreSQL {
|
||||||
|
groupCol = `"group"`
|
||||||
|
}
|
||||||
|
|
||||||
|
err = DB.Model(&User{}).Where("id = ?", id).Select(groupCol).Find(&group).Error
|
||||||
return group, err
|
return group, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user