diff --git a/common/constants.go b/common/constants.go
index c7d3f222..40b90cbc 100644
--- a/common/constants.go
+++ b/common/constants.go
@@ -82,6 +82,7 @@ var QuotaRemindThreshold = 1000
var PreConsumedQuota = 500
var ApproximateTokenEnabled = false
var RetryTimes = 0
+var DefaultWeight = 10
var RootUserEmail = ""
diff --git a/common/utils.go b/common/utils.go
index 21bec8f5..903d7c2d 100644
--- a/common/utils.go
+++ b/common/utils.go
@@ -2,7 +2,6 @@ package common
import (
"fmt"
- "github.com/google/uuid"
"html/template"
"log"
"math/rand"
@@ -13,6 +12,8 @@ import (
"strconv"
"strings"
"time"
+
+ "github.com/google/uuid"
)
func OpenBrowser(url string) {
@@ -207,3 +208,16 @@ func String2Int(str string) int {
}
return num
}
+
+func SplitDistinct(s, sep string) []string {
+ splited := strings.Split(s, sep)
+ set := make(map[string]struct{})
+ list := []string{}
+ for _, item := range splited {
+ if _, ok := set[item]; !ok {
+ set[item] = struct{}{}
+ list = append(list, item)
+ }
+ }
+ return list
+}
diff --git a/controller/channel.go b/controller/channel.go
index 904abc23..99edfa21 100644
--- a/controller/channel.go
+++ b/controller/channel.go
@@ -1,12 +1,13 @@
package controller
import (
- "github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"strings"
+
+ "github.com/gin-gonic/gin"
)
func GetAllChannels(c *gin.Context) {
@@ -92,6 +93,7 @@ func AddChannel(c *gin.Context) {
}
localChannel := channel
localChannel.Key = key
+ localChannel.FixWeightMapping()
channels = append(channels, localChannel)
}
err = model.BatchInsertChannels(channels)
@@ -154,6 +156,9 @@ func UpdateChannel(c *gin.Context) {
})
return
}
+ if channel.Models != "" {
+ channel.FixWeightMapping()
+ }
err = channel.Update()
if err != nil {
c.JSON(http.StatusOK, gin.H{
diff --git a/go.mod b/go.mod
index 10b78d68..4444860d 100644
--- a/go.mod
+++ b/go.mod
@@ -14,6 +14,7 @@ require (
github.com/golang-jwt/jwt v3.2.2+incompatible
github.com/google/uuid v1.3.0
github.com/gorilla/websocket v1.5.0
+ github.com/mroth/weightedrand/v2 v2.1.0
github.com/pkoukk/tiktoken-go v0.1.5
golang.org/x/crypto v0.14.0
gorm.io/driver/mysql v1.4.3
diff --git a/go.sum b/go.sum
index 4865bcaa..af3a046d 100644
--- a/go.sum
+++ b/go.sum
@@ -111,6 +111,8 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJ
github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
+github.com/mroth/weightedrand/v2 v2.1.0 h1:o1ascnB1CIVzsqlfArQQjeMy1U0NcIbBO5rfd5E/OeU=
+github.com/mroth/weightedrand/v2 v2.1.0/go.mod h1:f2faGsfOGOwc1p94wzHKKZyTpcJUW7OJ/9U4yfiNAOU=
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE=
diff --git a/model/ability.go b/model/ability.go
index 3da83be8..c42a2352 100644
--- a/model/ability.go
+++ b/model/ability.go
@@ -2,7 +2,6 @@ package model
import (
"one-api/common"
- "strings"
)
type Ability struct {
@@ -40,8 +39,8 @@ func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
}
func (channel *Channel) AddAbilities() error {
- models_ := strings.Split(channel.Models, ",")
- groups_ := strings.Split(channel.Group, ",")
+ models_ := channel.GetModels()
+ groups_ := channel.GetGroups()
abilities := make([]Ability, 0, len(models_))
for _, model := range models_ {
for _, group := range groups_ {
diff --git a/model/cache.go b/model/cache.go
index c6d0c70a..09e1288d 100644
--- a/model/cache.go
+++ b/model/cache.go
@@ -4,13 +4,12 @@ import (
"encoding/json"
"errors"
"fmt"
- "math/rand"
"one-api/common"
- "sort"
"strconv"
- "strings"
"sync"
"time"
+
+ "github.com/mroth/weightedrand/v2"
)
var (
@@ -132,7 +131,7 @@ func CacheIsUserEnabled(userId int) (bool, error) {
return userEnabled, err
}
-var group2model2channels map[string]map[string][]*Channel
+var group2model2channels map[string]map[string]*weightedrand.Chooser[*Channel, int]
var channelSyncLock sync.RWMutex
func InitChannelCache() {
@@ -148,35 +147,51 @@ func InitChannelCache() {
for _, ability := range abilities {
groups[ability.Group] = true
}
- newGroup2model2channels := make(map[string]map[string][]*Channel)
+ newGroup2model2channels := make(map[string]map[string][]weightedrand.Choice[*Channel, int])
for group := range groups {
- newGroup2model2channels[group] = make(map[string][]*Channel)
+ newGroup2model2channels[group] = make(map[string][]weightedrand.Choice[*Channel, int])
}
for _, channel := range channels {
- groups := strings.Split(channel.Group, ",")
+ groups := channel.GetGroups()
for _, group := range groups {
- models := strings.Split(channel.Models, ",")
+ models := channel.GetModels()
+ weightMapping := channel.GetWeightMapping()
for _, model := range models {
if _, ok := newGroup2model2channels[group][model]; !ok {
- newGroup2model2channels[group][model] = make([]*Channel, 0)
+ newGroup2model2channels[group][model] = make([]weightedrand.Choice[*Channel, int], 0)
}
- newGroup2model2channels[group][model] = append(newGroup2model2channels[group][model], channel)
+ weight, ok := weightMapping[model]
+ if weight < 0 || !ok {
+ // use default value if:
+ // weight < 0: invalid
+ // !ok: weight not set
+ weight = common.DefaultWeight
+ }
+ newGroup2model2channels[group][model] = append(newGroup2model2channels[group][model], weightedrand.NewChoice(channel, weight))
}
}
}
// sort by priority
+ m := make(map[string]map[string]*weightedrand.Chooser[*Channel, int])
for group, model2channels := range newGroup2model2channels {
+ m[group] = make(map[string]*weightedrand.Chooser[*Channel, int])
for model, channels := range model2channels {
- sort.Slice(channels, func(i, j int) bool {
- return channels[i].GetPriority() > channels[j].GetPriority()
- })
- newGroup2model2channels[group][model] = channels
+ if len(channels) == 0 {
+ common.SysError(fmt.Sprintf("no channel found for group %s model %s", group, model))
+ continue
+ }
+ c, err := weightedrand.NewChooser(channels...)
+ if err != nil {
+ common.SysError(fmt.Sprintf("failed to create chooser: %s", err.Error()))
+ continue
+ }
+ m[group][model] = c
}
}
channelSyncLock.Lock()
- group2model2channels = newGroup2model2channels
+ group2model2channels = m
channelSyncLock.Unlock()
common.SysLog("channels synced from database")
}
@@ -196,20 +211,8 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
channelSyncLock.RLock()
defer channelSyncLock.RUnlock()
channels := group2model2channels[group][model]
- if len(channels) == 0 {
+ if channels == nil {
return nil, errors.New("channel not found")
}
- endIdx := len(channels)
- // choose by priority
- firstChannel := channels[0]
- if firstChannel.GetPriority() > 0 {
- for i := range channels {
- if channels[i].GetPriority() != firstChannel.GetPriority() {
- endIdx = i
- break
- }
- }
- }
- idx := rand.Intn(endIdx)
- return channels[idx], nil
+ return channels.Pick(), nil
}
diff --git a/model/channel.go b/model/channel.go
index 7e7b42e6..294367c1 100644
--- a/model/channel.go
+++ b/model/channel.go
@@ -1,8 +1,10 @@
package model
import (
- "gorm.io/gorm"
+ "encoding/json"
"one-api/common"
+
+ "gorm.io/gorm"
)
type Channel struct {
@@ -23,6 +25,7 @@ type Channel struct {
Group string `json:"group" gorm:"type:varchar(32);default:'default'"`
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"`
ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
+ WeightMapping *string `json:"weight_mapping" gorm:"type:varchar(1024);default:''"`
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
}
@@ -72,6 +75,51 @@ func BatchInsertChannels(channels []Channel) error {
return nil
}
+func (channel *Channel) GetWeightMapping() (weightMapping map[string]int) {
+ if channel.WeightMapping == nil || *channel.WeightMapping == "" {
+ return
+ }
+ err := json.Unmarshal([]byte(*channel.WeightMapping), &weightMapping)
+ if err != nil {
+ common.SysError("failed to unmarshal weight mapping: " + err.Error())
+ }
+ return
+}
+
+func (channel *Channel) FixWeightMapping() {
+ var weightMapping map[string]int
+ if channel.WeightMapping == nil || *channel.WeightMapping == "" {
+ weightMapping = make(map[string]int)
+ } else {
+ err := json.Unmarshal([]byte(*channel.WeightMapping), &weightMapping)
+ if err != nil {
+ common.SysError("failed to marshal weight mapping: " + err.Error())
+ }
+ }
+
+ models := channel.GetModels()
+ for _, model := range models {
+ if _, ok := weightMapping[model]; !ok {
+ weightMapping[model] = common.DefaultWeight
+ }
+ }
+
+ jsonStr, err := json.Marshal(weightMapping)
+ if err != nil {
+ common.SysError("failed to marshal weight mapping: " + err.Error())
+ }
+ var result = string(jsonStr)
+ channel.WeightMapping = &result
+}
+
+func (channel *Channel) GetModels() []string {
+ return common.SplitDistinct(channel.Models, ",")
+}
+
+func (channel *Channel) GetGroups() []string {
+ return common.SplitDistinct(channel.Group, ",")
+}
+
func (channel *Channel) GetPriority() int64 {
if channel.Priority == nil {
return 0
diff --git a/model/option.go b/model/option.go
index 4ef4d260..371b1e2c 100644
--- a/model/option.go
+++ b/model/option.go
@@ -71,6 +71,7 @@ func InitOptionMap() {
common.OptionMap["ChatLink"] = common.ChatLink
common.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(common.QuotaPerUnit, 'f', -1, 64)
common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes)
+ common.OptionMap["DefaultWeight"] = strconv.Itoa(common.DefaultWeight)
common.OptionMapRWMutex.Unlock()
loadOptionsFromDatabase()
}
@@ -205,6 +206,8 @@ func updateOptionMap(key string, value string) (err error) {
common.PreConsumedQuota, _ = strconv.Atoi(value)
case "RetryTimes":
common.RetryTimes, _ = strconv.Atoi(value)
+ case "DefaultWeight":
+ common.DefaultWeight, _ = strconv.Atoi(value)
case "ModelRatio":
err = common.UpdateModelRatioByJSONString(value)
case "GroupRatio":
diff --git a/web/src/components/OperationSetting.js b/web/src/components/OperationSetting.js
index bf8b5ffd..1b2e2835 100644
--- a/web/src/components/OperationSetting.js
+++ b/web/src/components/OperationSetting.js
@@ -21,7 +21,8 @@ const OperationSetting = () => {
DisplayInCurrencyEnabled: '',
DisplayTokenStatEnabled: '',
ApproximateTokenEnabled: '',
- RetryTimes: 0
+ RetryTimes: 0,
+ DefaultWeight: 10
});
const [originInputs, setOriginInputs] = useState({});
let [loading, setLoading] = useState(false);
@@ -128,6 +129,9 @@ const OperationSetting = () => {
if (originInputs['RetryTimes'] !== inputs.RetryTimes) {
await updateOption('RetryTimes', inputs.RetryTimes);
}
+ if (originInputs['DefaultWeight'] !== inputs.DefaultWeight) {
+ await updateOption('DefaultWeight', inputs.DefaultWeight);
+ }
break;
}
};
@@ -150,7 +154,7 @@ const OperationSetting = () => {