From ba54c71948b765866da0de918db6cd5af7c39866 Mon Sep 17 00:00:00 2001 From: JustSong Date: Wed, 21 Jun 2023 17:04:18 +0800 Subject: [PATCH] feat: select channel without database (#158) --- model/cache.go | 44 ++++++++++++++++++++++++++++++++------------ 1 file changed, 32 insertions(+), 12 deletions(-) diff --git a/model/cache.go b/model/cache.go index cf7cf4f7..6cf4808a 100644 --- a/model/cache.go +++ b/model/cache.go @@ -2,8 +2,11 @@ package model import ( "encoding/json" + "errors" "fmt" + "math/rand" "one-api/common" + "strings" "sync" "time" ) @@ -57,18 +60,15 @@ func CacheGetUserGroup(id int) (group string, err error) { return group, err } -var channelId2channel map[int]*Channel -var channelSyncLock sync.RWMutex var group2model2channels map[string]map[string][]*Channel +var channelSyncLock sync.RWMutex func InitChannelCache() { - channelSyncLock.Lock() - defer channelSyncLock.Unlock() - channelId2channel = make(map[int]*Channel) + newChannelId2channel := make(map[int]*Channel) var channels []*Channel DB.Find(&channels) for _, channel := range channels { - channelId2channel[channel.Id] = channel + newChannelId2channel[channel.Id] = channel } var abilities []*Ability DB.Find(&abilities) @@ -76,11 +76,26 @@ func InitChannelCache() { for _, ability := range abilities { groups[ability.Group] = true } - group2model2channels = make(map[string]map[string][]*Channel) + newGroup2model2channels := make(map[string]map[string][]*Channel) for group := range groups { - group2model2channels[group] = make(map[string][]*Channel) - // TODO: implement this + newGroup2model2channels[group] = make(map[string][]*Channel) } + for _, channel := range channels { + groups := strings.Split(channel.Group, ",") + for _, group := range groups { + models := strings.Split(channel.Models, ",") + for _, model := range models { + if _, ok := newGroup2model2channels[group][model]; !ok { + newGroup2model2channels[group][model] = make([]*Channel, 0) + } + newGroup2model2channels[group][model] = append(newGroup2model2channels[group][model], channel) + } + } + } + channelSyncLock.Lock() + group2model2channels = newGroup2model2channels + channelSyncLock.Unlock() + common.SysLog("Channels synced from database") } func SyncChannelCache(frequency int) { @@ -95,7 +110,12 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error if !common.RedisEnabled { return GetRandomSatisfiedChannel(group, model) } - return GetRandomSatisfiedChannel(group, model) - // TODO: implement this - return nil, nil + channelSyncLock.RLock() + defer channelSyncLock.RUnlock() + channels := group2model2channels[group][model] + if len(channels) == 0 { + return nil, errors.New("channel not found") + } + idx := rand.Intn(len(channels)) + return channels[idx], nil }