From 3d76a974d186b99eb70d58bc00b5380bf12e27bd Mon Sep 17 00:00:00 2001 From: JustSong Date: Tue, 20 Jun 2023 19:09:49 +0800 Subject: [PATCH] feat: use cache to avoid database access (#158) --- common/redis.go | 15 ++++++ main.go | 6 +++ middleware/distributor.go | 4 +- model/cache.go | 99 +++++++++++++++++++++++++++++++++++++++ model/token.go | 3 +- 5 files changed, 123 insertions(+), 4 deletions(-) create mode 100644 model/cache.go diff --git a/common/redis.go b/common/redis.go index 56db2b40..8b34a083 100644 --- a/common/redis.go +++ b/common/redis.go @@ -37,3 +37,18 @@ func ParseRedisOption() *redis.Options { } return opt } + +func RedisSet(key string, value string, expiration time.Duration) error { + ctx := context.Background() + return RDB.Set(ctx, key, value, expiration).Err() +} + +func RedisGet(key string) (string, error) { + ctx := context.Background() + return RDB.Get(ctx, key).Result() +} + +func RedisDel(key string) error { + ctx := context.Background() + return RDB.Del(ctx, key).Err() +} diff --git a/main.go b/main.go index c8656c7a..fb6dc1b5 100644 --- a/main.go +++ b/main.go @@ -47,12 +47,18 @@ func main() { // Initialize options model.InitOptionMap() + if common.RedisEnabled { + model.InitChannelCache() + } if os.Getenv("SYNC_FREQUENCY") != "" { frequency, err := strconv.Atoi(os.Getenv("SYNC_FREQUENCY")) if err != nil { common.FatalLog(err) } go model.SyncOptions(frequency) + if common.RedisEnabled { + go model.SyncChannelCache(frequency) + } } // Initialize HTTP server diff --git a/middleware/distributor.go b/middleware/distributor.go index 08568ea1..091c7f3a 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -17,7 +17,7 @@ type ModelRequest struct { func Distribute() func(c *gin.Context) { return func(c *gin.Context) { userId := c.GetInt("id") - userGroup, _ := model.GetUserGroup(userId) + userGroup, _ := model.CacheGetUserGroup(userId) c.Set("group", userGroup) var channel *model.Channel channelId, ok := c.Get("channelId") @@ -73,7 +73,7 @@ func Distribute() func(c *gin.Context) { modelRequest.Model = "text-moderation-stable" } } - channel, err = model.GetRandomSatisfiedChannel(userGroup, modelRequest.Model) + channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model) if err != nil { c.JSON(200, gin.H{ "error": gin.H{ diff --git a/model/cache.go b/model/cache.go new file mode 100644 index 00000000..a405c77b --- /dev/null +++ b/model/cache.go @@ -0,0 +1,99 @@ +package model + +import ( + "encoding/json" + "fmt" + "one-api/common" + "sync" + "time" +) + +const ( + TokenCacheSeconds = 60 * 60 + UserId2GroupCacheSeconds = 60 * 60 +) + +func CacheGetTokenByKey(key string) (*Token, error) { + var token Token + if !common.RedisEnabled { + err := DB.Where("`key` = ?", key).First(token).Error + return &token, err + } + tokenObjectString, err := common.RedisGet(fmt.Sprintf("token:%s", key)) + if err != nil { + err := DB.Where("`key` = ?", key).First(token).Error + if err != nil { + return nil, err + } + jsonBytes, err := json.Marshal(token) + if err != nil { + return nil, err + } + err = common.RedisSet(fmt.Sprintf("token:%s", key), string(jsonBytes), TokenCacheSeconds*time.Second) + if err != nil { + common.SysError("Redis set token error: " + err.Error()) + } + } + err = json.Unmarshal([]byte(tokenObjectString), &token) + return &token, err +} + +func CacheGetUserGroup(id int) (group string, err error) { + if !common.RedisEnabled { + return GetUserGroup(id) + } + group, err = common.RedisGet(fmt.Sprintf("user_group:%d", id)) + if err != nil { + group, err = GetUserGroup(id) + if err != nil { + return "", err + } + err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, UserId2GroupCacheSeconds*time.Second) + if err != nil { + common.SysError("Redis set user group error: " + err.Error()) + } + } + return group, err +} + +var channelId2channel map[int]*Channel +var channelSyncLock sync.RWMutex +var group2model2channels map[string]map[string][]*Channel + +func InitChannelCache() { + channelSyncLock.Lock() + defer channelSyncLock.Unlock() + channelId2channel = make(map[int]*Channel) + var channels []*Channel + DB.Find(&channels) + for _, channel := range channels { + channelId2channel[channel.Id] = channel + } + var abilities []*Ability + DB.Find(&abilities) + groups := make(map[string]bool) + for _, ability := range abilities { + groups[ability.Group] = true + } + group2model2channels = make(map[string]map[string][]*Channel) + for group := range groups { + group2model2channels[group] = make(map[string][]*Channel) + // TODO: implement this + } +} + +func SyncChannelCache(frequency int) { + for { + time.Sleep(time.Duration(frequency) * time.Second) + common.SysLog("Syncing channels from database") + InitChannelCache() + } +} + +func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) { + if !common.RedisEnabled { + return GetRandomSatisfiedChannel(group, model) + } + // TODO: implement this + return nil, nil +} diff --git a/model/token.go b/model/token.go index 64e52dcd..8744f582 100644 --- a/model/token.go +++ b/model/token.go @@ -36,8 +36,7 @@ func ValidateUserToken(key string) (token *Token, err error) { if key == "" { return nil, errors.New("未提供 token") } - token = &Token{} - err = DB.Where("`key` = ?", key).First(token).Error + token, err = CacheGetTokenByKey(key) if err == nil { if token.Status != common.TokenStatusEnabled { return nil, errors.New("该 token 状态不可用")