diff --git a/common/group-ratio.go b/common/group-ratio.go new file mode 100644 index 00000000..0a9cf4ba --- /dev/null +++ b/common/group-ratio.go @@ -0,0 +1,30 @@ +package common + +import "encoding/json" + +var GroupRatio = map[string]float64{ + "default": 1, + "vip": 1, + "svip": 1, +} + +func GroupRatio2JSONString() string { + jsonBytes, err := json.Marshal(GroupRatio) + if err != nil { + SysError("Error marshalling model ratio: " + err.Error()) + } + return string(jsonBytes) +} + +func UpdateGroupRatioByJSONString(jsonStr string) error { + return json.Unmarshal([]byte(jsonStr), &GroupRatio) +} + +func GetGroupRatio(name string) float64 { + ratio, ok := GroupRatio[name] + if !ok { + SysError("Group ratio not found: " + name) + return 1 + } + return ratio +} diff --git a/controller/group.go b/controller/group.go new file mode 100644 index 00000000..2b2f6006 --- /dev/null +++ b/controller/group.go @@ -0,0 +1,19 @@ +package controller + +import ( + "github.com/gin-gonic/gin" + "net/http" + "one-api/common" +) + +func GetGroups(c *gin.Context) { + groupNames := make([]string, 0) + for groupName, _ := range common.GroupRatio { + groupNames = append(groupNames, groupName) + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": groupNames, + }) +} diff --git a/controller/relay.go b/controller/relay.go index ed47ceb3..fb4c358f 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -140,6 +140,7 @@ func relayHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { channelType := c.GetInt("channel") tokenId := c.GetInt("token_id") consumeQuota := c.GetBool("consume_quota") + group := c.GetString("group") var textRequest GeneralOpenAIRequest if consumeQuota || channelType == common.ChannelTypeAzure || channelType == common.ChannelTypePaLM { err := common.UnmarshalBodyReusable(c, &textRequest) @@ -194,7 +195,7 @@ func relayHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { if textRequest.MaxTokens != 0 { preConsumedTokens = promptTokens + textRequest.MaxTokens } - ratio := common.GetModelRatio(textRequest.Model) + ratio := common.GetModelRatio(textRequest.Model) * common.GetGroupRatio(group) preConsumedQuota := int(float64(preConsumedTokens) * ratio) if consumeQuota { err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) diff --git a/middleware/distributor.go b/middleware/distributor.go index 0f4221bf..08568ea1 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -16,6 +16,9 @@ type ModelRequest struct { func Distribute() func(c *gin.Context) { return func(c *gin.Context) { + userId := c.GetInt("id") + userGroup, _ := model.GetUserGroup(userId) + c.Set("group", userGroup) var channel *model.Channel channelId, ok := c.Get("channelId") if ok { @@ -70,8 +73,6 @@ func Distribute() func(c *gin.Context) { modelRequest.Model = "text-moderation-stable" } } - userId := c.GetInt("id") - userGroup, _ := model.GetUserGroup(userId) channel, err = model.GetRandomSatisfiedChannel(userGroup, modelRequest.Model) if err != nil { c.JSON(200, gin.H{ diff --git a/model/option.go b/model/option.go index 5f7ad36c..5e74984d 100644 --- a/model/option.go +++ b/model/option.go @@ -58,6 +58,7 @@ func InitOptionMap() { common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold) common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota) common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString() + common.OptionMap["GroupRatio"] = common.GroupRatio2JSONString() common.OptionMap["TopUpLink"] = common.TopUpLink common.OptionMapRWMutex.Unlock() loadOptionsFromDatabase() @@ -177,6 +178,8 @@ func updateOptionMap(key string, value string) (err error) { common.PreConsumedQuota, _ = strconv.Atoi(value) case "ModelRatio": err = common.UpdateModelRatioByJSONString(value) + case "GroupRatio": + err = common.UpdateGroupRatioByJSONString(value) case "TopUpLink": common.TopUpLink = value case "ChannelDisableThreshold": diff --git a/router/api-router.go b/router/api-router.go index de8249b1..062ccac1 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -98,5 +98,10 @@ func SetApiRouter(router *gin.Engine) { logRoute.GET("/search", middleware.AdminAuth(), controller.SearchAllLogs) logRoute.GET("/self", middleware.UserAuth(), controller.GetUserLogs) logRoute.GET("/self/search", middleware.UserAuth(), controller.SearchUserLogs) + groupRoute := apiRouter.Group("/group") + groupRoute.Use(middleware.AdminAuth()) + { + groupRoute.GET("/", controller.GetGroups) + } } } diff --git a/web/src/components/SystemSetting.js b/web/src/components/SystemSetting.js index 0977a804..3b40822b 100644 --- a/web/src/components/SystemSetting.js +++ b/web/src/components/SystemSetting.js @@ -30,6 +30,7 @@ const SystemSetting = () => { QuotaRemindThreshold: 0, PreConsumedQuota: 0, ModelRatio: '', + GroupRatio: '', TopUpLink: '', AutomaticDisableChannelEnabled: '', ChannelDisableThreshold: 0, @@ -101,6 +102,7 @@ const SystemSetting = () => { name === 'QuotaRemindThreshold' || name === 'PreConsumedQuota' || name === 'ModelRatio' || + name === 'GroupRatio' || name === 'TopUpLink' ) { setInputs((inputs) => ({ ...inputs, [name]: value })); @@ -131,6 +133,13 @@ const SystemSetting = () => { } await updateOption('ModelRatio', inputs.ModelRatio); } + if (originInputs['GroupRatio'] !== inputs.GroupRatio) { + if (!verifyJSON(inputs.GroupRatio)) { + showError('分组倍率不是合法的 JSON 字符串'); + return; + } + await updateOption('GroupRatio', inputs.GroupRatio); + } if (originInputs['TopUpLink'] !== inputs.TopUpLink) { await updateOption('TopUpLink', inputs.TopUpLink); } @@ -329,6 +338,17 @@ const SystemSetting = () => { placeholder='为一个 JSON 文本,键为模型名称,值为倍率' /> +