From c3dc315e7515723a172b811b6d5b7b166e188984 Mon Sep 17 00:00:00 2001 From: JustSong Date: Sun, 3 Sep 2023 14:58:20 +0800 Subject: [PATCH] feat: add batch update support (close #414) --- README.md | 4 +++ common/constants.go | 3 ++ main.go | 5 +++ model/channel.go | 8 +++++ model/token.go | 16 ++++++++++ model/user.go | 26 +++++++++++++++- model/utils.go | 75 +++++++++++++++++++++++++++++++++++++++++++++ 7 files changed, 136 insertions(+), 1 deletion(-) create mode 100644 model/utils.go diff --git a/README.md b/README.md index a0f3bcb9..a2105df2 100644 --- a/README.md +++ b/README.md @@ -306,6 +306,10 @@ graph LR + 例子:`CHANNEL_TEST_FREQUENCY=1440` 9. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。 + 例子:`POLLING_INTERVAL=5` +10. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true` 和 `false`,未设置则默认为 `false`。 + + 例子:`BATCH_UPDATE_ENABLED=true` +11. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。 + + 例子:`BATCH_UPDATE_INTERVAL=5` ### 命令行参数 1. `--port `: 指定服务器监听的端口号,默认为 `3000`。 diff --git a/common/constants.go b/common/constants.go index 66ca06f4..b272fbe6 100644 --- a/common/constants.go +++ b/common/constants.go @@ -94,6 +94,9 @@ var RequestInterval = time.Duration(requestInterval) * time.Second var SyncFrequency = 10 * 60 // unit is second, will be overwritten by SYNC_FREQUENCY +var BatchUpdateEnabled = false +var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5) + const ( RoleGuestUser = 0 RoleCommonUser = 1 diff --git a/main.go b/main.go index 9fb0a73e..8c5f2f31 100644 --- a/main.go +++ b/main.go @@ -77,6 +77,11 @@ func main() { } go controller.AutomaticallyTestChannels(frequency) } + if os.Getenv("BATCH_UPDATE_ENABLED") == "true" { + common.BatchUpdateEnabled = true + common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s") + model.InitBatchUpdater() + } controller.InitTokenEncoders() // Initialize HTTP server diff --git a/model/channel.go b/model/channel.go index 7cc9fa9b..5c495bab 100644 --- a/model/channel.go +++ b/model/channel.go @@ -141,6 +141,14 @@ func UpdateChannelStatusById(id int, status int) { } func UpdateChannelUsedQuota(id int, quota int) { + if common.BatchUpdateEnabled { + addNewRecord(BatchUpdateTypeChannelUsedQuota, id, quota) + return + } + updateChannelUsedQuota(id, quota) +} + +func updateChannelUsedQuota(id int, quota int) { err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error if err != nil { common.SysError("failed to update channel used quota: " + err.Error()) diff --git a/model/token.go b/model/token.go index 7cd226c6..dfda27e3 100644 --- a/model/token.go +++ b/model/token.go @@ -131,6 +131,14 @@ func IncreaseTokenQuota(id int, quota int) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") } + if common.BatchUpdateEnabled { + addNewRecord(BatchUpdateTypeTokenQuota, id, quota) + return nil + } + return increaseTokenQuota(id, quota) +} + +func increaseTokenQuota(id int, quota int) (err error) { err = DB.Model(&Token{}).Where("id = ?", id).Updates( map[string]interface{}{ "remain_quota": gorm.Expr("remain_quota + ?", quota), @@ -144,6 +152,14 @@ func DecreaseTokenQuota(id int, quota int) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") } + if common.BatchUpdateEnabled { + addNewRecord(BatchUpdateTypeTokenQuota, id, -quota) + return nil + } + return decreaseTokenQuota(id, quota) +} + +func decreaseTokenQuota(id int, quota int) (err error) { err = DB.Model(&Token{}).Where("id = ?", id).Updates( map[string]interface{}{ "remain_quota": gorm.Expr("remain_quota - ?", quota), diff --git a/model/user.go b/model/user.go index 7c771840..67511267 100644 --- a/model/user.go +++ b/model/user.go @@ -275,6 +275,14 @@ func IncreaseUserQuota(id int, quota int) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") } + if common.BatchUpdateEnabled { + addNewRecord(BatchUpdateTypeUserQuota, id, quota) + return nil + } + return increaseUserQuota(id, quota) +} + +func increaseUserQuota(id int, quota int) (err error) { err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota + ?", quota)).Error return err } @@ -283,6 +291,14 @@ func DecreaseUserQuota(id int, quota int) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") } + if common.BatchUpdateEnabled { + addNewRecord(BatchUpdateTypeUserQuota, id, -quota) + return nil + } + return decreaseUserQuota(id, quota) +} + +func decreaseUserQuota(id int, quota int) (err error) { err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error return err } @@ -293,10 +309,18 @@ func GetRootUserEmail() (email string) { } func UpdateUserUsedQuotaAndRequestCount(id int, quota int) { + if common.BatchUpdateEnabled { + addNewRecord(BatchUpdateTypeUsedQuotaAndRequestCount, id, quota) + return + } + updateUserUsedQuotaAndRequestCount(id, quota, 1) +} + +func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) { err := DB.Model(&User{}).Where("id = ?", id).Updates( map[string]interface{}{ "used_quota": gorm.Expr("used_quota + ?", quota), - "request_count": gorm.Expr("request_count + ?", 1), + "request_count": gorm.Expr("request_count + ?", count), }, ).Error if err != nil { diff --git a/model/utils.go b/model/utils.go new file mode 100644 index 00000000..61734332 --- /dev/null +++ b/model/utils.go @@ -0,0 +1,75 @@ +package model + +import ( + "one-api/common" + "sync" + "time" +) + +const BatchUpdateTypeCount = 4 // if you add a new type, you need to add a new map and a new lock + +const ( + BatchUpdateTypeUserQuota = iota + BatchUpdateTypeTokenQuota + BatchUpdateTypeUsedQuotaAndRequestCount + BatchUpdateTypeChannelUsedQuota +) + +var batchUpdateStores []map[int]int +var batchUpdateLocks []sync.Mutex + +func init() { + for i := 0; i < BatchUpdateTypeCount; i++ { + batchUpdateStores = append(batchUpdateStores, make(map[int]int)) + batchUpdateLocks = append(batchUpdateLocks, sync.Mutex{}) + } +} + +func InitBatchUpdater() { + go func() { + for { + time.Sleep(time.Duration(common.BatchUpdateInterval) * time.Second) + batchUpdate() + } + }() +} + +func addNewRecord(type_ int, id int, value int) { + batchUpdateLocks[type_].Lock() + defer batchUpdateLocks[type_].Unlock() + if _, ok := batchUpdateStores[type_][id]; !ok { + batchUpdateStores[type_][id] = value + } else { + batchUpdateStores[type_][id] += value + } +} + +func batchUpdate() { + common.SysLog("batch update started") + for i := 0; i < BatchUpdateTypeCount; i++ { + batchUpdateLocks[i].Lock() + store := batchUpdateStores[i] + batchUpdateStores[i] = make(map[int]int) + batchUpdateLocks[i].Unlock() + + for key, value := range store { + switch i { + case BatchUpdateTypeUserQuota: + err := increaseUserQuota(key, value) + if err != nil { + common.SysError("failed to batch update user quota: " + err.Error()) + } + case BatchUpdateTypeTokenQuota: + err := increaseTokenQuota(key, value) + if err != nil { + common.SysError("failed to batch update token quota: " + err.Error()) + } + case BatchUpdateTypeUsedQuotaAndRequestCount: + updateUserUsedQuotaAndRequestCount(key, value, 1) // TODO: count is incorrect + case BatchUpdateTypeChannelUsedQuota: + updateChannelUsedQuota(key, value) + } + } + } + common.SysLog("batch update finished") +}