From 4a27d75d57a1c73063442c9243b3d41c1c7babfb Mon Sep 17 00:00:00 2001 From: WqyJh <781345688@qq.com> Date: Wed, 1 Nov 2023 14:58:04 +0800 Subject: [PATCH] feat: add weight_mapping for channel --- common/constants.go | 1 + common/utils.go | 16 ++++++- controller/channel.go | 7 ++- go.mod | 1 + go.sum | 2 + model/ability.go | 5 +-- model/cache.go | 61 ++++++++++++++------------ model/channel.go | 50 ++++++++++++++++++++- model/option.go | 3 ++ web/src/components/OperationSetting.js | 19 +++++++- web/src/pages/Channel/EditChannel.js | 25 +++++++++++ 11 files changed, 153 insertions(+), 37 deletions(-) diff --git a/common/constants.go b/common/constants.go index c7d3f222..40b90cbc 100644 --- a/common/constants.go +++ b/common/constants.go @@ -82,6 +82,7 @@ var QuotaRemindThreshold = 1000 var PreConsumedQuota = 500 var ApproximateTokenEnabled = false var RetryTimes = 0 +var DefaultWeight = 10 var RootUserEmail = "" diff --git a/common/utils.go b/common/utils.go index 21bec8f5..903d7c2d 100644 --- a/common/utils.go +++ b/common/utils.go @@ -2,7 +2,6 @@ package common import ( "fmt" - "github.com/google/uuid" "html/template" "log" "math/rand" @@ -13,6 +12,8 @@ import ( "strconv" "strings" "time" + + "github.com/google/uuid" ) func OpenBrowser(url string) { @@ -207,3 +208,16 @@ func String2Int(str string) int { } return num } + +func SplitDistinct(s, sep string) []string { + splited := strings.Split(s, sep) + set := make(map[string]struct{}) + list := []string{} + for _, item := range splited { + if _, ok := set[item]; !ok { + set[item] = struct{}{} + list = append(list, item) + } + } + return list +} diff --git a/controller/channel.go b/controller/channel.go index 904abc23..99edfa21 100644 --- a/controller/channel.go +++ b/controller/channel.go @@ -1,12 +1,13 @@ package controller import ( - "github.com/gin-gonic/gin" "net/http" "one-api/common" "one-api/model" "strconv" "strings" + + "github.com/gin-gonic/gin" ) func GetAllChannels(c *gin.Context) { @@ -92,6 +93,7 @@ func AddChannel(c *gin.Context) { } localChannel := channel localChannel.Key = key + localChannel.FixWeightMapping() channels = append(channels, localChannel) } err = model.BatchInsertChannels(channels) @@ -154,6 +156,9 @@ func UpdateChannel(c *gin.Context) { }) return } + if channel.Models != "" { + channel.FixWeightMapping() + } err = channel.Update() if err != nil { c.JSON(http.StatusOK, gin.H{ diff --git a/go.mod b/go.mod index 10b78d68..4444860d 100644 --- a/go.mod +++ b/go.mod @@ -14,6 +14,7 @@ require ( github.com/golang-jwt/jwt v3.2.2+incompatible github.com/google/uuid v1.3.0 github.com/gorilla/websocket v1.5.0 + github.com/mroth/weightedrand/v2 v2.1.0 github.com/pkoukk/tiktoken-go v0.1.5 golang.org/x/crypto v0.14.0 gorm.io/driver/mysql v1.4.3 diff --git a/go.sum b/go.sum index 4865bcaa..af3a046d 100644 --- a/go.sum +++ b/go.sum @@ -111,6 +111,8 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJ github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= +github.com/mroth/weightedrand/v2 v2.1.0 h1:o1ascnB1CIVzsqlfArQQjeMy1U0NcIbBO5rfd5E/OeU= +github.com/mroth/weightedrand/v2 v2.1.0/go.mod h1:f2faGsfOGOwc1p94wzHKKZyTpcJUW7OJ/9U4yfiNAOU= github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE= diff --git a/model/ability.go b/model/ability.go index 3da83be8..c42a2352 100644 --- a/model/ability.go +++ b/model/ability.go @@ -2,7 +2,6 @@ package model import ( "one-api/common" - "strings" ) type Ability struct { @@ -40,8 +39,8 @@ func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) { } func (channel *Channel) AddAbilities() error { - models_ := strings.Split(channel.Models, ",") - groups_ := strings.Split(channel.Group, ",") + models_ := channel.GetModels() + groups_ := channel.GetGroups() abilities := make([]Ability, 0, len(models_)) for _, model := range models_ { for _, group := range groups_ { diff --git a/model/cache.go b/model/cache.go index c6d0c70a..09e1288d 100644 --- a/model/cache.go +++ b/model/cache.go @@ -4,13 +4,12 @@ import ( "encoding/json" "errors" "fmt" - "math/rand" "one-api/common" - "sort" "strconv" - "strings" "sync" "time" + + "github.com/mroth/weightedrand/v2" ) var ( @@ -132,7 +131,7 @@ func CacheIsUserEnabled(userId int) (bool, error) { return userEnabled, err } -var group2model2channels map[string]map[string][]*Channel +var group2model2channels map[string]map[string]*weightedrand.Chooser[*Channel, int] var channelSyncLock sync.RWMutex func InitChannelCache() { @@ -148,35 +147,51 @@ func InitChannelCache() { for _, ability := range abilities { groups[ability.Group] = true } - newGroup2model2channels := make(map[string]map[string][]*Channel) + newGroup2model2channels := make(map[string]map[string][]weightedrand.Choice[*Channel, int]) for group := range groups { - newGroup2model2channels[group] = make(map[string][]*Channel) + newGroup2model2channels[group] = make(map[string][]weightedrand.Choice[*Channel, int]) } for _, channel := range channels { - groups := strings.Split(channel.Group, ",") + groups := channel.GetGroups() for _, group := range groups { - models := strings.Split(channel.Models, ",") + models := channel.GetModels() + weightMapping := channel.GetWeightMapping() for _, model := range models { if _, ok := newGroup2model2channels[group][model]; !ok { - newGroup2model2channels[group][model] = make([]*Channel, 0) + newGroup2model2channels[group][model] = make([]weightedrand.Choice[*Channel, int], 0) } - newGroup2model2channels[group][model] = append(newGroup2model2channels[group][model], channel) + weight, ok := weightMapping[model] + if weight < 0 || !ok { + // use default value if: + // weight < 0: invalid + // !ok: weight not set + weight = common.DefaultWeight + } + newGroup2model2channels[group][model] = append(newGroup2model2channels[group][model], weightedrand.NewChoice(channel, weight)) } } } // sort by priority + m := make(map[string]map[string]*weightedrand.Chooser[*Channel, int]) for group, model2channels := range newGroup2model2channels { + m[group] = make(map[string]*weightedrand.Chooser[*Channel, int]) for model, channels := range model2channels { - sort.Slice(channels, func(i, j int) bool { - return channels[i].GetPriority() > channels[j].GetPriority() - }) - newGroup2model2channels[group][model] = channels + if len(channels) == 0 { + common.SysError(fmt.Sprintf("no channel found for group %s model %s", group, model)) + continue + } + c, err := weightedrand.NewChooser(channels...) + if err != nil { + common.SysError(fmt.Sprintf("failed to create chooser: %s", err.Error())) + continue + } + m[group][model] = c } } channelSyncLock.Lock() - group2model2channels = newGroup2model2channels + group2model2channels = m channelSyncLock.Unlock() common.SysLog("channels synced from database") } @@ -196,20 +211,8 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error channelSyncLock.RLock() defer channelSyncLock.RUnlock() channels := group2model2channels[group][model] - if len(channels) == 0 { + if channels == nil { return nil, errors.New("channel not found") } - endIdx := len(channels) - // choose by priority - firstChannel := channels[0] - if firstChannel.GetPriority() > 0 { - for i := range channels { - if channels[i].GetPriority() != firstChannel.GetPriority() { - endIdx = i - break - } - } - } - idx := rand.Intn(endIdx) - return channels[idx], nil + return channels.Pick(), nil } diff --git a/model/channel.go b/model/channel.go index 7e7b42e6..294367c1 100644 --- a/model/channel.go +++ b/model/channel.go @@ -1,8 +1,10 @@ package model import ( - "gorm.io/gorm" + "encoding/json" "one-api/common" + + "gorm.io/gorm" ) type Channel struct { @@ -23,6 +25,7 @@ type Channel struct { Group string `json:"group" gorm:"type:varchar(32);default:'default'"` UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"` + WeightMapping *string `json:"weight_mapping" gorm:"type:varchar(1024);default:''"` Priority *int64 `json:"priority" gorm:"bigint;default:0"` } @@ -72,6 +75,51 @@ func BatchInsertChannels(channels []Channel) error { return nil } +func (channel *Channel) GetWeightMapping() (weightMapping map[string]int) { + if channel.WeightMapping == nil || *channel.WeightMapping == "" { + return + } + err := json.Unmarshal([]byte(*channel.WeightMapping), &weightMapping) + if err != nil { + common.SysError("failed to unmarshal weight mapping: " + err.Error()) + } + return +} + +func (channel *Channel) FixWeightMapping() { + var weightMapping map[string]int + if channel.WeightMapping == nil || *channel.WeightMapping == "" { + weightMapping = make(map[string]int) + } else { + err := json.Unmarshal([]byte(*channel.WeightMapping), &weightMapping) + if err != nil { + common.SysError("failed to marshal weight mapping: " + err.Error()) + } + } + + models := channel.GetModels() + for _, model := range models { + if _, ok := weightMapping[model]; !ok { + weightMapping[model] = common.DefaultWeight + } + } + + jsonStr, err := json.Marshal(weightMapping) + if err != nil { + common.SysError("failed to marshal weight mapping: " + err.Error()) + } + var result = string(jsonStr) + channel.WeightMapping = &result +} + +func (channel *Channel) GetModels() []string { + return common.SplitDistinct(channel.Models, ",") +} + +func (channel *Channel) GetGroups() []string { + return common.SplitDistinct(channel.Group, ",") +} + func (channel *Channel) GetPriority() int64 { if channel.Priority == nil { return 0 diff --git a/model/option.go b/model/option.go index 4ef4d260..371b1e2c 100644 --- a/model/option.go +++ b/model/option.go @@ -71,6 +71,7 @@ func InitOptionMap() { common.OptionMap["ChatLink"] = common.ChatLink common.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(common.QuotaPerUnit, 'f', -1, 64) common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes) + common.OptionMap["DefaultWeight"] = strconv.Itoa(common.DefaultWeight) common.OptionMapRWMutex.Unlock() loadOptionsFromDatabase() } @@ -205,6 +206,8 @@ func updateOptionMap(key string, value string) (err error) { common.PreConsumedQuota, _ = strconv.Atoi(value) case "RetryTimes": common.RetryTimes, _ = strconv.Atoi(value) + case "DefaultWeight": + common.DefaultWeight, _ = strconv.Atoi(value) case "ModelRatio": err = common.UpdateModelRatioByJSONString(value) case "GroupRatio": diff --git a/web/src/components/OperationSetting.js b/web/src/components/OperationSetting.js index bf8b5ffd..1b2e2835 100644 --- a/web/src/components/OperationSetting.js +++ b/web/src/components/OperationSetting.js @@ -21,7 +21,8 @@ const OperationSetting = () => { DisplayInCurrencyEnabled: '', DisplayTokenStatEnabled: '', ApproximateTokenEnabled: '', - RetryTimes: 0 + RetryTimes: 0, + DefaultWeight: 10 }); const [originInputs, setOriginInputs] = useState({}); let [loading, setLoading] = useState(false); @@ -128,6 +129,9 @@ const OperationSetting = () => { if (originInputs['RetryTimes'] !== inputs.RetryTimes) { await updateOption('RetryTimes', inputs.RetryTimes); } + if (originInputs['DefaultWeight'] !== inputs.DefaultWeight) { + await updateOption('DefaultWeight', inputs.DefaultWeight); + } break; } }; @@ -150,7 +154,7 @@ const OperationSetting = () => {
通用设置
- + { value={inputs.RetryTimes} placeholder='失败重试次数' /> + { base_url: '', other: '', model_mapping: '', + weight_mapping: '', models: [], groups: ['default'] }; @@ -105,6 +112,9 @@ const EditChannel = () => { if (data.model_mapping !== '') { data.model_mapping = JSON.stringify(JSON.parse(data.model_mapping), null, 2); } + if (data.weight_mapping !== '') { + data.weight_mapping = JSON.stringify(JSON.parse(data.weight_mapping), null, 2); + } setInputs(data); } else { showError(message); @@ -178,6 +188,10 @@ const EditChannel = () => { showInfo('模型映射必须是合法的 JSON 格式!'); return; } + if (inputs.weight_mapping !== '' && !verifyJSON(inputs.weight_mapping)) { + showInfo('模型权重必须是合法的 JSON 格式!'); + return; + } let localInputs = inputs; if (localInputs.base_url && localInputs.base_url.endsWith('/')) { localInputs.base_url = localInputs.base_url.slice(0, localInputs.base_url.length - 1); @@ -396,6 +410,17 @@ const EditChannel = () => { autoComplete='new-password' /> + + + { batch ?