This commit is contained in:
Qiying Wang 2024-01-01 17:52:20 +08:00 committed by GitHub
commit 75980d504a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 153 additions and 37 deletions

View File

@ -83,6 +83,7 @@ var QuotaRemindThreshold = 1000
var PreConsumedQuota = 500
var ApproximateTokenEnabled = false
var RetryTimes = 0
// DefaultWeight is the weight assigned to a channel for a model when the
// channel's weight mapping has no entry for that model or the entry is
// negative. It can be changed at runtime through the "DefaultWeight" option.
var DefaultWeight = 10
var RootUserEmail = ""

View File

@ -2,7 +2,6 @@ package common
import (
"fmt"
"github.com/google/uuid"
"html/template"
"log"
"math/rand"
@ -13,6 +12,8 @@ import (
"strconv"
"strings"
"time"
"github.com/google/uuid"
)
func OpenBrowser(url string) {
@ -214,3 +215,16 @@ func String2Int(str string) int {
}
return num
}
// SplitDistinct splits s around every occurrence of sep and returns the
// resulting pieces with duplicates removed, keeping each piece at the position
// of its first appearance.
//
// Pieces are compared verbatim: nothing is trimmed and empty pieces are kept,
// so SplitDistinct("", ",") yields a slice holding a single empty string.
// The returned slice is never nil.
func SplitDistinct(s, sep string) []string {
	parts := strings.Split(s, sep)
	seen := make(map[string]struct{}, len(parts))
	unique := make([]string, 0, len(parts))
	for _, part := range parts {
		if _, dup := seen[part]; dup {
			continue
		}
		seen[part] = struct{}{}
		unique = append(unique, part)
	}
	return unique
}

View File

@ -1,12 +1,13 @@
package controller
import (
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"strings"
"github.com/gin-gonic/gin"
)
func GetAllChannels(c *gin.Context) {
@ -92,6 +93,7 @@ func AddChannel(c *gin.Context) {
}
localChannel := channel
localChannel.Key = key
localChannel.FixWeightMapping()
channels = append(channels, localChannel)
}
err = model.BatchInsertChannels(channels)
@ -154,6 +156,9 @@ func UpdateChannel(c *gin.Context) {
})
return
}
if channel.Models != "" {
channel.FixWeightMapping()
}
err = channel.Update()
if err != nil {
c.JSON(http.StatusOK, gin.H{

1
go.mod
View File

@ -14,6 +14,7 @@ require (
github.com/golang-jwt/jwt v3.2.2+incompatible
github.com/google/uuid v1.3.0
github.com/gorilla/websocket v1.5.0
github.com/mroth/weightedrand/v2 v2.1.0
github.com/pkoukk/tiktoken-go v0.1.5
github.com/stretchr/testify v1.8.3
golang.org/x/crypto v0.17.0

2
go.sum
View File

@ -111,6 +111,8 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJ
github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/mroth/weightedrand/v2 v2.1.0 h1:o1ascnB1CIVzsqlfArQQjeMy1U0NcIbBO5rfd5E/OeU=
github.com/mroth/weightedrand/v2 v2.1.0/go.mod h1:f2faGsfOGOwc1p94wzHKKZyTpcJUW7OJ/9U4yfiNAOU=
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE=

View File

@ -2,7 +2,6 @@ package model
import (
"one-api/common"
"strings"
)
type Ability struct {
@ -40,8 +39,8 @@ func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
}
func (channel *Channel) AddAbilities() error {
models_ := strings.Split(channel.Models, ",")
groups_ := strings.Split(channel.Group, ",")
models_ := channel.GetModels()
groups_ := channel.GetGroups()
abilities := make([]Ability, 0, len(models_))
for _, model := range models_ {
for _, group := range groups_ {

View File

@ -4,13 +4,12 @@ import (
"encoding/json"
"errors"
"fmt"
"math/rand"
"one-api/common"
"sort"
"strconv"
"strings"
"sync"
"time"
"github.com/mroth/weightedrand/v2"
)
var (
@ -132,7 +131,7 @@ func CacheIsUserEnabled(userId int) (bool, error) {
return userEnabled, err
}
var group2model2channels map[string]map[string][]*Channel
var group2model2channels map[string]map[string]*weightedrand.Chooser[*Channel, int]
var channelSyncLock sync.RWMutex
func InitChannelCache() {
@ -148,35 +147,51 @@ func InitChannelCache() {
for _, ability := range abilities {
groups[ability.Group] = true
}
newGroup2model2channels := make(map[string]map[string][]*Channel)
newGroup2model2channels := make(map[string]map[string][]weightedrand.Choice[*Channel, int])
for group := range groups {
newGroup2model2channels[group] = make(map[string][]*Channel)
newGroup2model2channels[group] = make(map[string][]weightedrand.Choice[*Channel, int])
}
for _, channel := range channels {
groups := strings.Split(channel.Group, ",")
groups := channel.GetGroups()
for _, group := range groups {
models := strings.Split(channel.Models, ",")
models := channel.GetModels()
weightMapping := channel.GetWeightMapping()
for _, model := range models {
if _, ok := newGroup2model2channels[group][model]; !ok {
newGroup2model2channels[group][model] = make([]*Channel, 0)
newGroup2model2channels[group][model] = make([]weightedrand.Choice[*Channel, int], 0)
}
newGroup2model2channels[group][model] = append(newGroup2model2channels[group][model], channel)
weight, ok := weightMapping[model]
if weight < 0 || !ok {
// use default value if:
// weight < 0: invalid
// !ok: weight not set
weight = common.DefaultWeight
}
newGroup2model2channels[group][model] = append(newGroup2model2channels[group][model], weightedrand.NewChoice(channel, weight))
}
}
}
// sort by priority
m := make(map[string]map[string]*weightedrand.Chooser[*Channel, int])
for group, model2channels := range newGroup2model2channels {
m[group] = make(map[string]*weightedrand.Chooser[*Channel, int])
for model, channels := range model2channels {
sort.Slice(channels, func(i, j int) bool {
return channels[i].GetPriority() > channels[j].GetPriority()
})
newGroup2model2channels[group][model] = channels
if len(channels) == 0 {
common.SysError(fmt.Sprintf("no channel found for group %s model %s", group, model))
continue
}
c, err := weightedrand.NewChooser(channels...)
if err != nil {
common.SysError(fmt.Sprintf("failed to create chooser: %s", err.Error()))
continue
}
m[group][model] = c
}
}
channelSyncLock.Lock()
group2model2channels = newGroup2model2channels
group2model2channels = m
channelSyncLock.Unlock()
common.SysLog("channels synced from database")
}
@ -196,20 +211,8 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
channelSyncLock.RLock()
defer channelSyncLock.RUnlock()
channels := group2model2channels[group][model]
if len(channels) == 0 {
if channels == nil {
return nil, errors.New("channel not found")
}
endIdx := len(channels)
// choose by priority
firstChannel := channels[0]
if firstChannel.GetPriority() > 0 {
for i := range channels {
if channels[i].GetPriority() != firstChannel.GetPriority() {
endIdx = i
break
}
}
}
idx := rand.Intn(endIdx)
return channels[idx], nil
return channels.Pick(), nil
}

View File

@ -1,8 +1,10 @@
package model
import (
"gorm.io/gorm"
"encoding/json"
"one-api/common"
"gorm.io/gorm"
)
type Channel struct {
@ -23,6 +25,7 @@ type Channel struct {
Group string `json:"group" gorm:"type:varchar(32);default:'default'"`
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"`
ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
WeightMapping *string `json:"weight_mapping" gorm:"type:varchar(1024);default:''"`
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
}
@ -72,6 +75,51 @@ func BatchInsertChannels(channels []Channel) error {
return nil
}
// GetWeightMapping decodes the channel's WeightMapping JSON string into a map
// keyed by model name with integer weights as values.
//
// It returns nil when WeightMapping is unset or empty. If decoding fails the
// error is reported through common.SysError and whatever json.Unmarshal
// produced (possibly nil) is still returned.
func (channel *Channel) GetWeightMapping() (weightMapping map[string]int) {
	raw := channel.WeightMapping
	if raw != nil && *raw != "" {
		if err := json.Unmarshal([]byte(*raw), &weightMapping); err != nil {
			common.SysError("failed to unmarshal weight mapping: " + err.Error())
		}
	}
	return weightMapping
}
// FixWeightMapping rewrites channel.WeightMapping so that every model returned
// by GetModels has a weight entry. Models missing from the stored mapping are
// given common.DefaultWeight; existing entries are left untouched.
//
// A stored mapping that cannot be decoded is reported through common.SysError
// and treated as empty. If the resulting mapping cannot be encoded, the error
// is reported and channel.WeightMapping is left unchanged.
func (channel *Channel) FixWeightMapping() {
	var weightMapping map[string]int
	if channel.WeightMapping != nil && *channel.WeightMapping != "" {
		if err := json.Unmarshal([]byte(*channel.WeightMapping), &weightMapping); err != nil {
			common.SysError("failed to unmarshal weight mapping: " + err.Error())
		}
	}
	// json.Unmarshal can leave the map nil (malformed JSON, JSON null, or a
	// non-object value); writing to a nil map panics, so always start from an
	// allocated map.
	if weightMapping == nil {
		weightMapping = make(map[string]int)
	}
	for _, model := range channel.GetModels() {
		if _, ok := weightMapping[model]; !ok {
			weightMapping[model] = common.DefaultWeight
		}
	}
	jsonStr, err := json.Marshal(weightMapping)
	if err != nil {
		common.SysError("failed to marshal weight mapping: " + err.Error())
		// Keep the previous value rather than overwriting it with an empty string.
		return
	}
	result := string(jsonStr)
	channel.WeightMapping = &result
}
// GetModels returns the entries of the channel's comma-separated Models field
// with duplicates removed, in order of first appearance.
func (channel *Channel) GetModels() []string {
	models := common.SplitDistinct(channel.Models, ",")
	return models
}
// GetGroups returns the entries of the channel's comma-separated Group field
// with duplicates removed, in order of first appearance.
func (channel *Channel) GetGroups() []string {
	groups := common.SplitDistinct(channel.Group, ",")
	return groups
}
func (channel *Channel) GetPriority() int64 {
if channel.Priority == nil {
return 0

View File

@ -72,6 +72,7 @@ func InitOptionMap() {
common.OptionMap["ChatLink"] = common.ChatLink
common.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(common.QuotaPerUnit, 'f', -1, 64)
common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes)
common.OptionMap["DefaultWeight"] = strconv.Itoa(common.DefaultWeight)
common.OptionMapRWMutex.Unlock()
loadOptionsFromDatabase()
}
@ -208,6 +209,8 @@ func updateOptionMap(key string, value string) (err error) {
common.PreConsumedQuota, _ = strconv.Atoi(value)
case "RetryTimes":
common.RetryTimes, _ = strconv.Atoi(value)
case "DefaultWeight":
common.DefaultWeight, _ = strconv.Atoi(value)
case "ModelRatio":
err = common.UpdateModelRatioByJSONString(value)
case "GroupRatio":

View File

@ -22,7 +22,8 @@ const OperationSetting = () => {
DisplayInCurrencyEnabled: '',
DisplayTokenStatEnabled: '',
ApproximateTokenEnabled: '',
RetryTimes: 0
RetryTimes: 0,
DefaultWeight: 10
});
const [originInputs, setOriginInputs] = useState({});
let [loading, setLoading] = useState(false);
@ -129,6 +130,9 @@ const OperationSetting = () => {
if (originInputs['RetryTimes'] !== inputs.RetryTimes) {
await updateOption('RetryTimes', inputs.RetryTimes);
}
if (originInputs['DefaultWeight'] !== inputs.DefaultWeight) {
await updateOption('DefaultWeight', inputs.DefaultWeight);
}
break;
}
};
@ -151,7 +155,7 @@ const OperationSetting = () => {
<Header as='h3'>
通用设置
</Header>
<Form.Group widths={4}>
<Form.Group widths={5}>
<Form.Input
label='充值链接'
name='TopUpLink'
@ -191,6 +195,17 @@ const OperationSetting = () => {
value={inputs.RetryTimes}
placeholder='失败重试次数'
/>
<Form.Input
label='默认权重'
name='DefaultWeight'
type={'number'}
step='1'
min='0'
onChange={handleInputChange}
autoComplete='new-password'
value={inputs.DefaultWeight}
placeholder='默认权重'
/>
</Form.Group>
<Form.Group inline>
<Form.Checkbox

View File

@ -10,6 +10,12 @@ const MODEL_MAPPING_EXAMPLE = {
'gpt-4-32k-0314': 'gpt-4-32k'
};
// Sample weight mapping (model name -> integer weight) that is pretty-printed
// into the placeholder of the weight_mapping text area of the channel form.
const WEIGHT_MAPPING_EXAMPLE = {
'gpt-3.5-turbo-0301': 120,
'gpt-4-0314': 10,
'gpt-4-32k-0314': 10
};
function type2secretPrompt(type) {
// inputs.type === 15 ? '按照如下格式输入APIKey|SecretKey' : (inputs.type === 18 ? '按照如下格式输入APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥')
switch (type) {
@ -43,6 +49,7 @@ const EditChannel = () => {
base_url: '',
other: '',
model_mapping: '',
weight_mapping: '',
models: [],
groups: ['default']
};
@ -115,6 +122,9 @@ const EditChannel = () => {
if (data.model_mapping !== '') {
data.model_mapping = JSON.stringify(JSON.parse(data.model_mapping), null, 2);
}
if (data.weight_mapping !== '') {
data.weight_mapping = JSON.stringify(JSON.parse(data.weight_mapping), null, 2);
}
setInputs(data);
} else {
showError(message);
@ -188,6 +198,10 @@ const EditChannel = () => {
showInfo('模型映射必须是合法的 JSON 格式!');
return;
}
if (inputs.weight_mapping !== '' && !verifyJSON(inputs.weight_mapping)) {
showInfo('模型权重必须是合法的 JSON 格式!');
return;
}
let localInputs = inputs;
if (localInputs.base_url && localInputs.base_url.endsWith('/')) {
localInputs.base_url = localInputs.base_url.slice(0, localInputs.base_url.length - 1);
@ -420,6 +434,17 @@ const EditChannel = () => {
autoComplete='new-password'
/>
</Form.Field>
<Form.Field>
<Form.TextArea
label='模型权重'
placeholder={`此项可选,用于修改请求体中的模型权重,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型权重,推荐填写模型的 TPM 值,例如:\n${JSON.stringify(WEIGHT_MAPPING_EXAMPLE, null, 2)}`}
name='weight_mapping'
onChange={handleInputChange}
value={inputs.weight_mapping}
style={{ minHeight: 150, fontFamily: 'JetBrains Mono, Consolas' }}
autoComplete='new-password'
/>
</Form.Field>
{
batch ? <Form.Field>
<Form.TextArea