From bba49c959eb7623a6f28c09540c86441cb29c3e3 Mon Sep 17 00:00:00 2001 From: ckt1031 <65409152+ckt1031@users.noreply.github.com> Date: Thu, 27 Jul 2023 14:47:45 +0800 Subject: [PATCH] feat: support postgres --- common/constants.go | 1 + go.mod | 9 ++++++++- go.sum | 10 ++++++++++ model/ability.go | 12 ++++++++++-- model/cache.go | 10 ++++++++-- model/channel.go | 15 ++++++++++++--- model/main.go | 17 +++++++++++++---- model/redemption.go | 12 ++++++++++-- model/user.go | 8 +++++++- 9 files changed, 79 insertions(+), 15 deletions(-) diff --git a/common/constants.go b/common/constants.go index 81f98163..4fb55995 100644 --- a/common/constants.go +++ b/common/constants.go @@ -22,6 +22,7 @@ var DisplayInCurrencyEnabled = true var DisplayTokenStatEnabled = true var UsingSQLite = false +var UsingPostgreSQL = false // Any options with "Secret", "Token" in its key won't be return by GetOptions diff --git a/go.mod b/go.mod index 2e0cf017..a95a100a 100644 --- a/go.mod +++ b/go.mod @@ -17,7 +17,13 @@ require ( golang.org/x/crypto v0.9.0 gorm.io/driver/mysql v1.4.3 gorm.io/driver/sqlite v1.4.3 - gorm.io/gorm v1.24.0 + gorm.io/gorm v1.25.0 +) + +require ( + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect + github.com/jackc/pgx/v5 v5.3.1 // indirect ) require ( @@ -53,4 +59,5 @@ require ( golang.org/x/text v0.9.0 // indirect google.golang.org/protobuf v1.30.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect + gorm.io/driver/postgres v1.5.2 ) diff --git a/go.sum b/go.sum index 7287206a..15d7d42f 100644 --- a/go.sum +++ b/go.sum @@ -67,6 +67,12 @@ github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyC github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7FsgI= github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.3.1 h1:Fcr8QJ1ZeLi5zsPZqQeUZhNhxfkkKBOgJuYkJHoBOtU= +github.com/jackc/pgx/v5 v5.3.1/go.mod h1:t3JDKnCBlYIc0ewLF0Q7B8MXmoIaBOZj/ic7iHozM/8= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= @@ -185,9 +191,13 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gorm.io/driver/mysql v1.4.3 h1:/JhWJhO2v17d8hjApTltKNADm7K7YI2ogkR7avJUL3k= gorm.io/driver/mysql v1.4.3/go.mod h1:sSIebwZAVPiT+27jK9HIwvsqOGKx3YMPmrA3mBJR10c= +gorm.io/driver/postgres v1.5.2 h1:ytTDxxEv+MplXOfFe3Lzm7SjG09fcdb3Z/c056DTBx0= +gorm.io/driver/postgres v1.5.2/go.mod h1:fmpX0m2I1PKuR7mKZiEluwrP3hbs+ps7JIGMUBpCgl8= gorm.io/driver/sqlite v1.4.3 h1:HBBcZSDnWi5BW3B3rwvVTc510KGkBkexlOg0QrmLUuU= gorm.io/driver/sqlite v1.4.3/go.mod h1:0Aq3iPO+v9ZKbcdiz8gLWRw5VOPcBOPUQJFLq5e2ecI= gorm.io/gorm v1.23.8/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk= gorm.io/gorm v1.24.0 h1:j/CoiSm6xpRpmzbFJsQHYj+I8bGYWLXVHeYEyyKlF74= gorm.io/gorm v1.24.0/go.mod h1:DVrVomtaYTbqs7gB/x2uVvqnXzv0nqjB396B8cG4dBA= +gorm.io/gorm v1.25.0 h1:+KtYtb2roDz14EQe4bla8CbQlmb9dN3VejSai3lprfU= +gorm.io/gorm v1.25.0/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k= rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= diff --git a/model/ability.go b/model/ability.go index e87c3940..4fd50fcb 100644 --- a/model/ability.go +++ b/model/ability.go @@ -15,8 +15,16 @@ type Ability struct { func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) { ability := Ability{} var err error = nil - if common.UsingSQLite { - err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("RANDOM()").Limit(1).First(&ability).Error + + cmd := "`group` = ? and model = ? and enabled = 1" + + if common.UsingPostgreSQL { + // Make cmd compatible with PostgreSQL + cmd = "\"group\" = ? and model = ? and enabled = true" + } + + if common.UsingSQLite || common.UsingPostgreSQL { + err = DB.Where(cmd, group, model).Order("RANDOM()").Limit(1).First(&ability).Error } else { err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("RAND()").Limit(1).First(&ability).Error } diff --git a/model/cache.go b/model/cache.go index 64666c86..d8683800 100644 --- a/model/cache.go +++ b/model/cache.go @@ -21,13 +21,19 @@ var ( func CacheGetTokenByKey(key string) (*Token, error) { var token Token + whereItem := "`key` = ?" + if common.UsingPostgreSQL { + // Make cmd compatible with PostgreSQL + whereItem = "\"key\" = ?" + } + if !common.RedisEnabled { - err := DB.Where("`key` = ?", key).First(&token).Error + err := DB.Where(whereItem, key).First(&token).Error return &token, err } tokenObjectString, err := common.RedisGet(fmt.Sprintf("token:%s", key)) if err != nil { - err := DB.Where("`key` = ?", key).First(&token).Error + err := DB.Where(whereItem, key).First(&token).Error if err != nil { return nil, err } diff --git a/model/channel.go b/model/channel.go index 7cc9fa9b..56c96f24 100644 --- a/model/channel.go +++ b/model/channel.go @@ -1,8 +1,9 @@ package model import ( - "gorm.io/gorm" "one-api/common" + + "gorm.io/gorm" ) type Channel struct { @@ -37,7 +38,13 @@ func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) { } func SearchChannels(keyword string) (channels []*Channel, err error) { - err = DB.Omit("key").Where("id = ? or name LIKE ? or `key` = ?", keyword, keyword+"%", keyword).Find(&channels).Error + whereItem := "id = ? or name LIKE ? or `key` = ?" + + if common.UsingPostgreSQL { + whereItem = "id = ? or name LIKE ? or \"key\" = ?" + } + + err = DB.Omit("key").Where(whereItem, keyword, keyword+"%", keyword).Find(&channels).Error return channels, err } @@ -55,7 +62,9 @@ func GetChannelById(id int, selectAll bool) (*Channel, error) { func GetRandomChannel() (*Channel, error) { channel := Channel{} var err error = nil - if common.UsingSQLite { + if common.UsingPostgreSQL { + err = DB.Where("status = ? and \"group\" = ?", common.ChannelStatusEnabled, "default").Order("RANDOM()").Limit(1).First(&channel).Error + } else if common.UsingSQLite { err = DB.Where("status = ? and `group` = ?", common.ChannelStatusEnabled, "default").Order("RANDOM()").Limit(1).First(&channel).Error } else { err = DB.Where("status = ? and `group` = ?", common.ChannelStatusEnabled, "default").Order("RAND()").Limit(1).First(&channel).Error diff --git a/model/main.go b/model/main.go index 5bc5ce19..741c8c4b 100644 --- a/model/main.go +++ b/model/main.go @@ -1,11 +1,13 @@ package model import ( - "gorm.io/driver/mysql" - "gorm.io/driver/sqlite" - "gorm.io/gorm" "one-api/common" "os" + + "gorm.io/driver/mysql" + "gorm.io/driver/postgres" + "gorm.io/driver/sqlite" + "gorm.io/gorm" ) var DB *gorm.DB @@ -40,7 +42,14 @@ func CountTable(tableName string) (num int64) { func InitDB() (err error) { var db *gorm.DB - if os.Getenv("SQL_DSN") != "" { + if os.Getenv("POSTGRES_DSN") != "" { + // Use PostgreSQL + common.SysLog("using PostgreSQL as database") + common.UsingPostgreSQL = true + db, err = gorm.Open(postgres.Open(os.Getenv("POSTGRES_DSN")), &gorm.Config{ + PrepareStmt: true, // precompile SQL + }) + } else if os.Getenv("SQL_DSN") != "" { // Use MySQL common.SysLog("using MySQL as database") db, err = gorm.Open(mysql.Open(os.Getenv("SQL_DSN")), &gorm.Config{ diff --git a/model/redemption.go b/model/redemption.go index fafb2145..67dfe290 100644 --- a/model/redemption.go +++ b/model/redemption.go @@ -3,8 +3,9 @@ package model import ( "errors" "fmt" - "gorm.io/gorm" "one-api/common" + + "gorm.io/gorm" ) type Redemption struct { @@ -51,7 +52,14 @@ func Redeem(key string, userId int) (quota int, err error) { redemption := &Redemption{} err = DB.Transaction(func(tx *gorm.DB) error { - err := tx.Set("gorm:query_option", "FOR UPDATE").Where("`key` = ?", key).First(redemption).Error + whereItem := "`key` = ?" + + if common.UsingPostgreSQL { + // Make cmd compatible with PostgreSQL + whereItem = "\"key\" = ?" + } + + err := tx.Set("gorm:query_option", "FOR UPDATE").Where(whereItem, key).First(redemption).Error if err != nil { return errors.New("无效的兑换码") } diff --git a/model/user.go b/model/user.go index 7c771840..b699673c 100644 --- a/model/user.go +++ b/model/user.go @@ -267,7 +267,13 @@ func GetUserEmail(id int) (email string, err error) { } func GetUserGroup(id int) (group string, err error) { - err = DB.Model(&User{}).Where("id = ?", id).Select("`group`").Find(&group).Error + selectItem := "`group`" + + if common.UsingPostgreSQL { + selectItem = "\"group\"" + } + + err = DB.Model(&User{}).Where("id = ?", id).Select(selectItem).Find(&group).Error return group, err }