feat: able to set group ratio now (close #62, close #142)

This commit is contained in:
JustSong 2023-06-11 11:08:16 +08:00
parent 9d0bec83df
commit 596446dba4
10 changed files with 134 additions and 8 deletions

30
common/group-ratio.go Normal file
View File

@ -0,0 +1,30 @@
package common
import "encoding/json"
var GroupRatio = map[string]float64{
"default": 1,
"vip": 1,
"svip": 1,
}
func GroupRatio2JSONString() string {
jsonBytes, err := json.Marshal(GroupRatio)
if err != nil {
SysError("Error marshalling model ratio: " + err.Error())
}
return string(jsonBytes)
}
func UpdateGroupRatioByJSONString(jsonStr string) error {
return json.Unmarshal([]byte(jsonStr), &GroupRatio)
}
func GetGroupRatio(name string) float64 {
ratio, ok := GroupRatio[name]
if !ok {
SysError("Group ratio not found: " + name)
return 1
}
return ratio
}

19
controller/group.go Normal file
View File

@ -0,0 +1,19 @@
package controller
import (
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
)
func GetGroups(c *gin.Context) {
groupNames := make([]string, 0)
for groupName, _ := range common.GroupRatio {
groupNames = append(groupNames, groupName)
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": groupNames,
})
}

View File

@ -140,6 +140,7 @@ func relayHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
channelType := c.GetInt("channel") channelType := c.GetInt("channel")
tokenId := c.GetInt("token_id") tokenId := c.GetInt("token_id")
consumeQuota := c.GetBool("consume_quota") consumeQuota := c.GetBool("consume_quota")
group := c.GetString("group")
var textRequest GeneralOpenAIRequest var textRequest GeneralOpenAIRequest
if consumeQuota || channelType == common.ChannelTypeAzure || channelType == common.ChannelTypePaLM { if consumeQuota || channelType == common.ChannelTypeAzure || channelType == common.ChannelTypePaLM {
err := common.UnmarshalBodyReusable(c, &textRequest) err := common.UnmarshalBodyReusable(c, &textRequest)
@ -194,7 +195,7 @@ func relayHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
if textRequest.MaxTokens != 0 { if textRequest.MaxTokens != 0 {
preConsumedTokens = promptTokens + textRequest.MaxTokens preConsumedTokens = promptTokens + textRequest.MaxTokens
} }
ratio := common.GetModelRatio(textRequest.Model) ratio := common.GetModelRatio(textRequest.Model) * common.GetGroupRatio(group)
preConsumedQuota := int(float64(preConsumedTokens) * ratio) preConsumedQuota := int(float64(preConsumedTokens) * ratio)
if consumeQuota { if consumeQuota {
err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)

View File

@ -16,6 +16,9 @@ type ModelRequest struct {
func Distribute() func(c *gin.Context) { func Distribute() func(c *gin.Context) {
return func(c *gin.Context) { return func(c *gin.Context) {
userId := c.GetInt("id")
userGroup, _ := model.GetUserGroup(userId)
c.Set("group", userGroup)
var channel *model.Channel var channel *model.Channel
channelId, ok := c.Get("channelId") channelId, ok := c.Get("channelId")
if ok { if ok {
@ -70,8 +73,6 @@ func Distribute() func(c *gin.Context) {
modelRequest.Model = "text-moderation-stable" modelRequest.Model = "text-moderation-stable"
} }
} }
userId := c.GetInt("id")
userGroup, _ := model.GetUserGroup(userId)
channel, err = model.GetRandomSatisfiedChannel(userGroup, modelRequest.Model) channel, err = model.GetRandomSatisfiedChannel(userGroup, modelRequest.Model)
if err != nil { if err != nil {
c.JSON(200, gin.H{ c.JSON(200, gin.H{

View File

@ -58,6 +58,7 @@ func InitOptionMap() {
common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold) common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold)
common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota) common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota)
common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString() common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
common.OptionMap["GroupRatio"] = common.GroupRatio2JSONString()
common.OptionMap["TopUpLink"] = common.TopUpLink common.OptionMap["TopUpLink"] = common.TopUpLink
common.OptionMapRWMutex.Unlock() common.OptionMapRWMutex.Unlock()
loadOptionsFromDatabase() loadOptionsFromDatabase()
@ -177,6 +178,8 @@ func updateOptionMap(key string, value string) (err error) {
common.PreConsumedQuota, _ = strconv.Atoi(value) common.PreConsumedQuota, _ = strconv.Atoi(value)
case "ModelRatio": case "ModelRatio":
err = common.UpdateModelRatioByJSONString(value) err = common.UpdateModelRatioByJSONString(value)
case "GroupRatio":
err = common.UpdateGroupRatioByJSONString(value)
case "TopUpLink": case "TopUpLink":
common.TopUpLink = value common.TopUpLink = value
case "ChannelDisableThreshold": case "ChannelDisableThreshold":

View File

@ -98,5 +98,10 @@ func SetApiRouter(router *gin.Engine) {
logRoute.GET("/search", middleware.AdminAuth(), controller.SearchAllLogs) logRoute.GET("/search", middleware.AdminAuth(), controller.SearchAllLogs)
logRoute.GET("/self", middleware.UserAuth(), controller.GetUserLogs) logRoute.GET("/self", middleware.UserAuth(), controller.GetUserLogs)
logRoute.GET("/self/search", middleware.UserAuth(), controller.SearchUserLogs) logRoute.GET("/self/search", middleware.UserAuth(), controller.SearchUserLogs)
groupRoute := apiRouter.Group("/group")
groupRoute.Use(middleware.AdminAuth())
{
groupRoute.GET("/", controller.GetGroups)
}
} }
} }

View File

@ -30,6 +30,7 @@ const SystemSetting = () => {
QuotaRemindThreshold: 0, QuotaRemindThreshold: 0,
PreConsumedQuota: 0, PreConsumedQuota: 0,
ModelRatio: '', ModelRatio: '',
GroupRatio: '',
TopUpLink: '', TopUpLink: '',
AutomaticDisableChannelEnabled: '', AutomaticDisableChannelEnabled: '',
ChannelDisableThreshold: 0, ChannelDisableThreshold: 0,
@ -101,6 +102,7 @@ const SystemSetting = () => {
name === 'QuotaRemindThreshold' || name === 'QuotaRemindThreshold' ||
name === 'PreConsumedQuota' || name === 'PreConsumedQuota' ||
name === 'ModelRatio' || name === 'ModelRatio' ||
name === 'GroupRatio' ||
name === 'TopUpLink' name === 'TopUpLink'
) { ) {
setInputs((inputs) => ({ ...inputs, [name]: value })); setInputs((inputs) => ({ ...inputs, [name]: value }));
@ -131,6 +133,13 @@ const SystemSetting = () => {
} }
await updateOption('ModelRatio', inputs.ModelRatio); await updateOption('ModelRatio', inputs.ModelRatio);
} }
if (originInputs['GroupRatio'] !== inputs.GroupRatio) {
if (!verifyJSON(inputs.GroupRatio)) {
showError('分组倍率不是合法的 JSON 字符串');
return;
}
await updateOption('GroupRatio', inputs.GroupRatio);
}
if (originInputs['TopUpLink'] !== inputs.TopUpLink) { if (originInputs['TopUpLink'] !== inputs.TopUpLink) {
await updateOption('TopUpLink', inputs.TopUpLink); await updateOption('TopUpLink', inputs.TopUpLink);
} }
@ -329,6 +338,17 @@ const SystemSetting = () => {
placeholder='为一个 JSON 文本,键为模型名称,值为倍率' placeholder='为一个 JSON 文本,键为模型名称,值为倍率'
/> />
</Form.Group> </Form.Group>
<Form.Group widths='equal'>
<Form.TextArea
label='分组倍率'
name='GroupRatio'
onChange={handleInputChange}
style={{ minHeight: 250, fontFamily: 'JetBrains Mono, Consolas' }}
autoComplete='new-password'
value={inputs.GroupRatio}
placeholder='为一个 JSON 文本,键为分组名称,值为倍率'
/>
</Form.Group>
<Form.Button onClick={submitOperationConfig}>保存运营设置</Form.Button> <Form.Button onClick={submitOperationConfig}>保存运营设置</Form.Button>
<Divider /> <Divider />
<Header as='h3'> <Header as='h3'>

View File

@ -10,6 +10,10 @@ export function renderText(text, limit) {
export function renderGroup(group) { export function renderGroup(group) {
if (group === "") { if (group === "") {
return <Label>default</Label> return <Label>default</Label>
} else if (group === "vip" || group === "pro") {
return <Label color='yellow'>{group}</Label>
} else if (group === "svip" || group === "premium") {
return <Label color='red'>{group}</Label>
} }
return <Label>{group}</Label> return <Label>{group}</Label>
} }

View File

@ -21,6 +21,7 @@ const EditChannel = () => {
const [batch, setBatch] = useState(false); const [batch, setBatch] = useState(false);
const [inputs, setInputs] = useState(originInputs); const [inputs, setInputs] = useState(originInputs);
const [modelOptions, setModelOptions] = useState([]); const [modelOptions, setModelOptions] = useState([]);
const [groupOptions, setGroupOptions] = useState([]);
const [basicModels, setBasicModels] = useState([]); const [basicModels, setBasicModels] = useState([]);
const [fullModels, setFullModels] = useState([]); const [fullModels, setFullModels] = useState([]);
const handleInputChange = (e, { name, value }) => { const handleInputChange = (e, { name, value }) => {
@ -58,11 +59,25 @@ const EditChannel = () => {
} }
}; };
const fetchGroups = async () => {
try {
let res = await API.get(`/api/group`);
setGroupOptions(res.data.data.map((group) => ({
key: group,
text: group,
value: group,
})));
} catch (error) {
showError(error.message);
}
};
useEffect(() => { useEffect(() => {
if (isEdit) { if (isEdit) {
loadChannel().then(); loadChannel().then();
} }
fetchModels().then(); fetchModels().then();
fetchGroups().then();
}, []); }, []);
const submit = async () => { const submit = async () => {
@ -167,13 +182,19 @@ const EditChannel = () => {
/> />
</Form.Field> </Form.Field>
<Form.Field> <Form.Field>
<Form.Input <Form.Dropdown
label='分组' label='分组'
placeholder={'请选择分组'}
name='group' name='group'
placeholder={'请输入分组'} fluid
search
selection
allowAdditions
additionLabel={'请在系统设置页面编辑分组倍率以添加新的分组:'}
onChange={handleInputChange} onChange={handleInputChange}
value={inputs.group} value={inputs.group}
autoComplete='new-password' autoComplete='new-password'
options={groupOptions}
/> />
</Form.Field> </Form.Field>
<Form.Field> <Form.Field>

View File

@ -17,11 +17,24 @@ const EditUser = () => {
quota: 0, quota: 0,
group: 'default' group: 'default'
}); });
const [groupOptions, setGroupOptions] = useState([]);
const { username, display_name, password, github_id, wechat_id, email, quota, group } = const { username, display_name, password, github_id, wechat_id, email, quota, group } =
inputs; inputs;
const handleInputChange = (e, { name, value }) => { const handleInputChange = (e, { name, value }) => {
setInputs((inputs) => ({ ...inputs, [name]: value })); setInputs((inputs) => ({ ...inputs, [name]: value }));
}; };
const fetchGroups = async () => {
try {
let res = await API.get(`/api/group`);
setGroupOptions(res.data.data.map((group) => ({
key: group,
text: group,
value: group,
})));
} catch (error) {
showError(error.message);
}
};
const loadUser = async () => { const loadUser = async () => {
let res = undefined; let res = undefined;
@ -41,6 +54,9 @@ const EditUser = () => {
}; };
useEffect(() => { useEffect(() => {
loadUser().then(); loadUser().then();
if (userId) {
fetchGroups().then();
}
}, []); }, []);
const submit = async () => { const submit = async () => {
@ -101,13 +117,19 @@ const EditUser = () => {
{ {
userId && <> userId && <>
<Form.Field> <Form.Field>
<Form.Input <Form.Dropdown
label='分组' label='分组'
placeholder={'请选择分组'}
name='group' name='group'
placeholder={'请输入用户分组'} fluid
search
selection
allowAdditions
additionLabel={'请在系统设置页面编辑分组倍率以添加新的分组:'}
onChange={handleInputChange} onChange={handleInputChange}
value={group} value={inputs.group}
autoComplete='new-password' autoComplete='new-password'
options={groupOptions}
/> />
</Form.Field> </Form.Field>
<Form.Field> <Form.Field>