feat: add weight_mapping for channel

This commit is contained in:
WqyJh 2023-11-01 14:58:04 +08:00
parent aec343dc38
commit 4a27d75d57
11 changed files with 153 additions and 37 deletions

View File

@ -82,6 +82,7 @@ var QuotaRemindThreshold = 1000
var PreConsumedQuota = 500 var PreConsumedQuota = 500
var ApproximateTokenEnabled = false var ApproximateTokenEnabled = false
var RetryTimes = 0 var RetryTimes = 0
var DefaultWeight = 10
var RootUserEmail = "" var RootUserEmail = ""

View File

@ -2,7 +2,6 @@ package common
import ( import (
"fmt" "fmt"
"github.com/google/uuid"
"html/template" "html/template"
"log" "log"
"math/rand" "math/rand"
@ -13,6 +12,8 @@ import (
"strconv" "strconv"
"strings" "strings"
"time" "time"
"github.com/google/uuid"
) )
func OpenBrowser(url string) { func OpenBrowser(url string) {
@ -207,3 +208,16 @@ func String2Int(str string) int {
} }
return num return num
} }
// SplitDistinct splits s around each occurrence of sep and returns the pieces
// with duplicates removed. The first occurrence of each piece is kept and the
// original order is preserved. Pieces are compared verbatim (no trimming), and
// an empty piece is treated like any other value.
func SplitDistinct(s, sep string) []string {
	parts := strings.Split(s, sep)
	seen := make(map[string]struct{}, len(parts))
	unique := make([]string, 0, len(parts))
	for _, part := range parts {
		if _, dup := seen[part]; dup {
			continue
		}
		seen[part] = struct{}{}
		unique = append(unique, part)
	}
	return unique
}

View File

@ -1,12 +1,13 @@
package controller package controller
import ( import (
"github.com/gin-gonic/gin"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/model" "one-api/model"
"strconv" "strconv"
"strings" "strings"
"github.com/gin-gonic/gin"
) )
func GetAllChannels(c *gin.Context) { func GetAllChannels(c *gin.Context) {
@ -92,6 +93,7 @@ func AddChannel(c *gin.Context) {
} }
localChannel := channel localChannel := channel
localChannel.Key = key localChannel.Key = key
localChannel.FixWeightMapping()
channels = append(channels, localChannel) channels = append(channels, localChannel)
} }
err = model.BatchInsertChannels(channels) err = model.BatchInsertChannels(channels)
@ -154,6 +156,9 @@ func UpdateChannel(c *gin.Context) {
}) })
return return
} }
if channel.Models != "" {
channel.FixWeightMapping()
}
err = channel.Update() err = channel.Update()
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{

1
go.mod
View File

@ -14,6 +14,7 @@ require (
github.com/golang-jwt/jwt v3.2.2+incompatible github.com/golang-jwt/jwt v3.2.2+incompatible
github.com/google/uuid v1.3.0 github.com/google/uuid v1.3.0
github.com/gorilla/websocket v1.5.0 github.com/gorilla/websocket v1.5.0
github.com/mroth/weightedrand/v2 v2.1.0
github.com/pkoukk/tiktoken-go v0.1.5 github.com/pkoukk/tiktoken-go v0.1.5
golang.org/x/crypto v0.14.0 golang.org/x/crypto v0.14.0
gorm.io/driver/mysql v1.4.3 gorm.io/driver/mysql v1.4.3

2
go.sum
View File

@ -111,6 +111,8 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJ
github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/mroth/weightedrand/v2 v2.1.0 h1:o1ascnB1CIVzsqlfArQQjeMy1U0NcIbBO5rfd5E/OeU=
github.com/mroth/weightedrand/v2 v2.1.0/go.mod h1:f2faGsfOGOwc1p94wzHKKZyTpcJUW7OJ/9U4yfiNAOU=
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE= github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE=

View File

@ -2,7 +2,6 @@ package model
import ( import (
"one-api/common" "one-api/common"
"strings"
) )
type Ability struct { type Ability struct {
@ -40,8 +39,8 @@ func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
} }
func (channel *Channel) AddAbilities() error { func (channel *Channel) AddAbilities() error {
models_ := strings.Split(channel.Models, ",") models_ := channel.GetModels()
groups_ := strings.Split(channel.Group, ",") groups_ := channel.GetGroups()
abilities := make([]Ability, 0, len(models_)) abilities := make([]Ability, 0, len(models_))
for _, model := range models_ { for _, model := range models_ {
for _, group := range groups_ { for _, group := range groups_ {

View File

@ -4,13 +4,12 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"math/rand"
"one-api/common" "one-api/common"
"sort"
"strconv" "strconv"
"strings"
"sync" "sync"
"time" "time"
"github.com/mroth/weightedrand/v2"
) )
var ( var (
@ -132,7 +131,7 @@ func CacheIsUserEnabled(userId int) (bool, error) {
return userEnabled, err return userEnabled, err
} }
var group2model2channels map[string]map[string][]*Channel var group2model2channels map[string]map[string]*weightedrand.Chooser[*Channel, int]
var channelSyncLock sync.RWMutex var channelSyncLock sync.RWMutex
func InitChannelCache() { func InitChannelCache() {
@ -148,35 +147,51 @@ func InitChannelCache() {
for _, ability := range abilities { for _, ability := range abilities {
groups[ability.Group] = true groups[ability.Group] = true
} }
newGroup2model2channels := make(map[string]map[string][]*Channel) newGroup2model2channels := make(map[string]map[string][]weightedrand.Choice[*Channel, int])
for group := range groups { for group := range groups {
newGroup2model2channels[group] = make(map[string][]*Channel) newGroup2model2channels[group] = make(map[string][]weightedrand.Choice[*Channel, int])
} }
for _, channel := range channels { for _, channel := range channels {
groups := strings.Split(channel.Group, ",") groups := channel.GetGroups()
for _, group := range groups { for _, group := range groups {
models := strings.Split(channel.Models, ",") models := channel.GetModels()
weightMapping := channel.GetWeightMapping()
for _, model := range models { for _, model := range models {
if _, ok := newGroup2model2channels[group][model]; !ok { if _, ok := newGroup2model2channels[group][model]; !ok {
newGroup2model2channels[group][model] = make([]*Channel, 0) newGroup2model2channels[group][model] = make([]weightedrand.Choice[*Channel, int], 0)
} }
newGroup2model2channels[group][model] = append(newGroup2model2channels[group][model], channel) weight, ok := weightMapping[model]
if weight < 0 || !ok {
// use default value if:
// weight < 0: invalid
// !ok: weight not set
weight = common.DefaultWeight
}
newGroup2model2channels[group][model] = append(newGroup2model2channels[group][model], weightedrand.NewChoice(channel, weight))
} }
} }
} }
// sort by priority // sort by priority
m := make(map[string]map[string]*weightedrand.Chooser[*Channel, int])
for group, model2channels := range newGroup2model2channels { for group, model2channels := range newGroup2model2channels {
m[group] = make(map[string]*weightedrand.Chooser[*Channel, int])
for model, channels := range model2channels { for model, channels := range model2channels {
sort.Slice(channels, func(i, j int) bool { if len(channels) == 0 {
return channels[i].GetPriority() > channels[j].GetPriority() common.SysError(fmt.Sprintf("no channel found for group %s model %s", group, model))
}) continue
newGroup2model2channels[group][model] = channels }
c, err := weightedrand.NewChooser(channels...)
if err != nil {
common.SysError(fmt.Sprintf("failed to create chooser: %s", err.Error()))
continue
}
m[group][model] = c
} }
} }
channelSyncLock.Lock() channelSyncLock.Lock()
group2model2channels = newGroup2model2channels group2model2channels = m
channelSyncLock.Unlock() channelSyncLock.Unlock()
common.SysLog("channels synced from database") common.SysLog("channels synced from database")
} }
@ -196,20 +211,8 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
channelSyncLock.RLock() channelSyncLock.RLock()
defer channelSyncLock.RUnlock() defer channelSyncLock.RUnlock()
channels := group2model2channels[group][model] channels := group2model2channels[group][model]
if len(channels) == 0 { if channels == nil {
return nil, errors.New("channel not found") return nil, errors.New("channel not found")
} }
endIdx := len(channels) return channels.Pick(), nil
// choose by priority
firstChannel := channels[0]
if firstChannel.GetPriority() > 0 {
for i := range channels {
if channels[i].GetPriority() != firstChannel.GetPriority() {
endIdx = i
break
}
}
}
idx := rand.Intn(endIdx)
return channels[idx], nil
} }

View File

@ -1,8 +1,10 @@
package model package model
import ( import (
"gorm.io/gorm" "encoding/json"
"one-api/common" "one-api/common"
"gorm.io/gorm"
) )
type Channel struct { type Channel struct {
@ -23,6 +25,7 @@ type Channel struct {
Group string `json:"group" gorm:"type:varchar(32);default:'default'"` Group string `json:"group" gorm:"type:varchar(32);default:'default'"`
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"`
ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"` ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
WeightMapping *string `json:"weight_mapping" gorm:"type:varchar(1024);default:''"`
Priority *int64 `json:"priority" gorm:"bigint;default:0"` Priority *int64 `json:"priority" gorm:"bigint;default:0"`
} }
@ -72,6 +75,51 @@ func BatchInsertChannels(channels []Channel) error {
return nil return nil
} }
// GetWeightMapping decodes the channel's WeightMapping JSON string into a map
// from model name to weight.
//
// It returns a nil map when WeightMapping is unset or empty. A decoding error
// is logged via common.SysError and whatever json.Unmarshal produced (possibly
// nil) is returned, so callers must tolerate missing entries.
func (channel *Channel) GetWeightMapping() (weightMapping map[string]int) {
	raw := channel.WeightMapping
	if raw == nil || len(*raw) == 0 {
		return nil
	}
	if err := json.Unmarshal([]byte(*raw), &weightMapping); err != nil {
		common.SysError("failed to unmarshal weight mapping: " + err.Error())
	}
	return weightMapping
}
// FixWeightMapping normalizes channel.WeightMapping so that every model
// returned by channel.GetModels has an entry, then stores the result back on
// the channel as a JSON object string.
//
// Entries already present in the stored mapping are kept unchanged; models
// without an entry receive common.DefaultWeight. A nil, empty, unparsable or
// JSON null mapping is treated as having no entries.
func (channel *Channel) FixWeightMapping() {
	var weightMapping map[string]int
	if channel.WeightMapping != nil && *channel.WeightMapping != "" {
		if err := json.Unmarshal([]byte(*channel.WeightMapping), &weightMapping); err != nil {
			common.SysError("failed to unmarshal weight mapping: " + err.Error())
		}
	}
	// json.Unmarshal can leave the map nil (syntax error, non-object value,
	// JSON null). Writing to a nil map panics, so always start from a real map.
	if weightMapping == nil {
		weightMapping = make(map[string]int)
	}
	for _, model := range channel.GetModels() {
		if _, ok := weightMapping[model]; !ok {
			weightMapping[model] = common.DefaultWeight
		}
	}
	jsonStr, err := json.Marshal(weightMapping)
	if err != nil {
		// Keep the previous value rather than overwriting it with an empty string.
		common.SysError("failed to marshal weight mapping: " + err.Error())
		return
	}
	result := string(jsonStr)
	channel.WeightMapping = &result
}
// GetModels returns the distinct entries of the channel's comma-separated
// Models field, in their original order.
func (channel *Channel) GetModels() []string {
	models := common.SplitDistinct(channel.Models, ",")
	return models
}
// GetGroups returns the distinct entries of the channel's comma-separated
// Group field, in their original order.
func (channel *Channel) GetGroups() []string {
	groups := common.SplitDistinct(channel.Group, ",")
	return groups
}
func (channel *Channel) GetPriority() int64 { func (channel *Channel) GetPriority() int64 {
if channel.Priority == nil { if channel.Priority == nil {
return 0 return 0

View File

@ -71,6 +71,7 @@ func InitOptionMap() {
common.OptionMap["ChatLink"] = common.ChatLink common.OptionMap["ChatLink"] = common.ChatLink
common.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(common.QuotaPerUnit, 'f', -1, 64) common.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(common.QuotaPerUnit, 'f', -1, 64)
common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes) common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes)
common.OptionMap["DefaultWeight"] = strconv.Itoa(common.DefaultWeight)
common.OptionMapRWMutex.Unlock() common.OptionMapRWMutex.Unlock()
loadOptionsFromDatabase() loadOptionsFromDatabase()
} }
@ -205,6 +206,8 @@ func updateOptionMap(key string, value string) (err error) {
common.PreConsumedQuota, _ = strconv.Atoi(value) common.PreConsumedQuota, _ = strconv.Atoi(value)
case "RetryTimes": case "RetryTimes":
common.RetryTimes, _ = strconv.Atoi(value) common.RetryTimes, _ = strconv.Atoi(value)
case "DefaultWeight":
common.DefaultWeight, _ = strconv.Atoi(value)
case "ModelRatio": case "ModelRatio":
err = common.UpdateModelRatioByJSONString(value) err = common.UpdateModelRatioByJSONString(value)
case "GroupRatio": case "GroupRatio":

View File

@ -21,7 +21,8 @@ const OperationSetting = () => {
DisplayInCurrencyEnabled: '', DisplayInCurrencyEnabled: '',
DisplayTokenStatEnabled: '', DisplayTokenStatEnabled: '',
ApproximateTokenEnabled: '', ApproximateTokenEnabled: '',
RetryTimes: 0 RetryTimes: 0,
DefaultWeight: 10
}); });
const [originInputs, setOriginInputs] = useState({}); const [originInputs, setOriginInputs] = useState({});
let [loading, setLoading] = useState(false); let [loading, setLoading] = useState(false);
@ -128,6 +129,9 @@ const OperationSetting = () => {
if (originInputs['RetryTimes'] !== inputs.RetryTimes) { if (originInputs['RetryTimes'] !== inputs.RetryTimes) {
await updateOption('RetryTimes', inputs.RetryTimes); await updateOption('RetryTimes', inputs.RetryTimes);
} }
if (originInputs['DefaultWeight'] !== inputs.DefaultWeight) {
await updateOption('DefaultWeight', inputs.DefaultWeight);
}
break; break;
} }
}; };
@ -150,7 +154,7 @@ const OperationSetting = () => {
<Header as='h3'> <Header as='h3'>
通用设置 通用设置
</Header> </Header>
<Form.Group widths={4}> <Form.Group widths={5}>
<Form.Input <Form.Input
label='充值链接' label='充值链接'
name='TopUpLink' name='TopUpLink'
@ -190,6 +194,17 @@ const OperationSetting = () => {
value={inputs.RetryTimes} value={inputs.RetryTimes}
placeholder='失败重试次数' placeholder='失败重试次数'
/> />
<Form.Input
label='默认权重'
name='DefaultWeight'
type={'number'}
step='1'
min='0'
onChange={handleInputChange}
autoComplete='new-password'
value={inputs.DefaultWeight}
placeholder='默认权重'
/>
</Form.Group> </Form.Group>
<Form.Group inline> <Form.Group inline>
<Form.Checkbox <Form.Checkbox

View File

@ -10,6 +10,12 @@ const MODEL_MAPPING_EXAMPLE = {
'gpt-4-32k-0314': 'gpt-4-32k' 'gpt-4-32k-0314': 'gpt-4-32k'
}; };
// Sample weight mapping shown, pretty-printed as JSON, in the placeholder of
// the weight_mapping text area: keys are model names, values are weights.
const WEIGHT_MAPPING_EXAMPLE = {
'gpt-3.5-turbo-0301': 120,
'gpt-4-0314': 10,
'gpt-4-32k-0314': 10
};
function type2secretPrompt(type) { function type2secretPrompt(type) {
// inputs.type === 15 ? '按照如下格式输入APIKey|SecretKey' : (inputs.type === 18 ? '按照如下格式输入APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥') // inputs.type === 15 ? '按照如下格式输入APIKey|SecretKey' : (inputs.type === 18 ? '按照如下格式输入APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥')
switch (type) { switch (type) {
@ -43,6 +49,7 @@ const EditChannel = () => {
base_url: '', base_url: '',
other: '', other: '',
model_mapping: '', model_mapping: '',
weight_mapping: '',
models: [], models: [],
groups: ['default'] groups: ['default']
}; };
@ -105,6 +112,9 @@ const EditChannel = () => {
if (data.model_mapping !== '') { if (data.model_mapping !== '') {
data.model_mapping = JSON.stringify(JSON.parse(data.model_mapping), null, 2); data.model_mapping = JSON.stringify(JSON.parse(data.model_mapping), null, 2);
} }
if (data.weight_mapping !== '') {
data.weight_mapping = JSON.stringify(JSON.parse(data.weight_mapping), null, 2);
}
setInputs(data); setInputs(data);
} else { } else {
showError(message); showError(message);
@ -178,6 +188,10 @@ const EditChannel = () => {
showInfo('模型映射必须是合法的 JSON 格式!'); showInfo('模型映射必须是合法的 JSON 格式!');
return; return;
} }
if (inputs.weight_mapping !== '' && !verifyJSON(inputs.weight_mapping)) {
showInfo('模型权重必须是合法的 JSON 格式!');
return;
}
let localInputs = inputs; let localInputs = inputs;
if (localInputs.base_url && localInputs.base_url.endsWith('/')) { if (localInputs.base_url && localInputs.base_url.endsWith('/')) {
localInputs.base_url = localInputs.base_url.slice(0, localInputs.base_url.length - 1); localInputs.base_url = localInputs.base_url.slice(0, localInputs.base_url.length - 1);
@ -396,6 +410,17 @@ const EditChannel = () => {
autoComplete='new-password' autoComplete='new-password'
/> />
</Form.Field> </Form.Field>
<Form.Field>
<Form.TextArea
label='模型权重'
placeholder={`此项可选,用于修改请求体中的模型权重,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型权重,推荐填写模型的 TPM 值,例如:\n${JSON.stringify(WEIGHT_MAPPING_EXAMPLE, null, 2)}`}
name='weight_mapping'
onChange={handleInputChange}
value={inputs.weight_mapping}
style={{ minHeight: 150, fontFamily: 'JetBrains Mono, Consolas' }}
autoComplete='new-password'
/>
</Form.Field>
{ {
batch ? <Form.Field> batch ? <Form.Field>
<Form.TextArea <Form.TextArea