feat: able to set model limitation for token (close #178)
This commit is contained in:
parent
065da8ef8c
commit
dc7aaf2de5
@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
"github.com/songquanpeng/one-api/model"
|
||||
"github.com/songquanpeng/one-api/relay/channel/openai"
|
||||
"github.com/songquanpeng/one-api/relay/constant"
|
||||
"github.com/songquanpeng/one-api/relay/helper"
|
||||
@ -142,3 +143,30 @@ func RetrieveModel(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func GetUserAvailableModels(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
id := c.GetInt("id")
|
||||
userGroup, err := model.CacheGetUserGroup(id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
models, err := model.CacheGetGroupModels(ctx, userGroup)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": models,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
@ -216,6 +216,7 @@ func UpdateToken(c *gin.Context) {
|
||||
cleanToken.ExpiredTime = token.ExpiredTime
|
||||
cleanToken.RemainQuota = token.RemainQuota
|
||||
cleanToken.UnlimitedQuota = token.UnlimitedQuota
|
||||
cleanToken.Models = token.Models
|
||||
}
|
||||
err = cleanToken.Update()
|
||||
if err != nil {
|
||||
|
@ -1,6 +1,7 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
@ -107,6 +108,18 @@ func TokenAuth() func(c *gin.Context) {
|
||||
abortWithMessage(c, http.StatusForbidden, "用户已被封禁")
|
||||
return
|
||||
}
|
||||
requestModel, err := getRequestModel(c)
|
||||
if err != nil {
|
||||
abortWithMessage(c, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
c.Set("request_model", requestModel)
|
||||
if token.Models != nil && *token.Models != "" {
|
||||
if !isModelInList(requestModel, *token.Models) {
|
||||
abortWithMessage(c, http.StatusForbidden, fmt.Sprintf("该令牌无权使用模型:%s", requestModel))
|
||||
return
|
||||
}
|
||||
}
|
||||
c.Set("id", token.UserId)
|
||||
c.Set("token_id", token.Id)
|
||||
c.Set("token_name", token.Name)
|
||||
|
@ -2,14 +2,12 @@ package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/model"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type ModelRequest struct {
|
||||
@ -40,37 +38,11 @@ func Distribute() func(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// Select a channel for the user
|
||||
var modelRequest ModelRequest
|
||||
err := common.UnmarshalBodyReusable(c, &modelRequest)
|
||||
requestModel := c.GetString("request_model")
|
||||
var err error
|
||||
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, requestModel, false)
|
||||
if err != nil {
|
||||
abortWithMessage(c, http.StatusBadRequest, "无效的请求")
|
||||
return
|
||||
}
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
|
||||
if modelRequest.Model == "" {
|
||||
modelRequest.Model = "text-moderation-stable"
|
||||
}
|
||||
}
|
||||
if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
|
||||
if modelRequest.Model == "" {
|
||||
modelRequest.Model = c.Param("model")
|
||||
}
|
||||
}
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
|
||||
if modelRequest.Model == "" {
|
||||
modelRequest.Model = "dall-e-2"
|
||||
}
|
||||
}
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
|
||||
if modelRequest.Model == "" {
|
||||
modelRequest.Model = "whisper-1"
|
||||
}
|
||||
}
|
||||
requestModel = modelRequest.Model
|
||||
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, false)
|
||||
if err != nil {
|
||||
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
|
||||
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, requestModel)
|
||||
if channel != nil {
|
||||
logger.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
|
||||
message = "数据库一致性已被破坏,请联系管理员"
|
||||
|
@ -1,9 +1,12 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
"github.com/songquanpeng/one-api/common/helper"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func abortWithMessage(c *gin.Context, statusCode int, message string) {
|
||||
@ -16,3 +19,42 @@ func abortWithMessage(c *gin.Context, statusCode int, message string) {
|
||||
c.Abort()
|
||||
logger.Error(c.Request.Context(), message)
|
||||
}
|
||||
|
||||
func getRequestModel(c *gin.Context) (string, error) {
|
||||
var modelRequest ModelRequest
|
||||
err := common.UnmarshalBodyReusable(c, &modelRequest)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("common.UnmarshalBodyReusable failed: %w", err)
|
||||
}
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
|
||||
if modelRequest.Model == "" {
|
||||
modelRequest.Model = "text-moderation-stable"
|
||||
}
|
||||
}
|
||||
if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
|
||||
if modelRequest.Model == "" {
|
||||
modelRequest.Model = c.Param("model")
|
||||
}
|
||||
}
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
|
||||
if modelRequest.Model == "" {
|
||||
modelRequest.Model = "dall-e-2"
|
||||
}
|
||||
}
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
|
||||
if modelRequest.Model == "" {
|
||||
modelRequest.Model = "whisper-1"
|
||||
}
|
||||
}
|
||||
return modelRequest.Model, nil
|
||||
}
|
||||
|
||||
func isModelInList(modelName string, models string) bool {
|
||||
modelList := strings.Split(models, ",")
|
||||
for _, model := range modelList {
|
||||
if modelName == model {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
@ -21,6 +21,7 @@ var (
|
||||
UserId2GroupCacheSeconds = config.SyncFrequency
|
||||
UserId2QuotaCacheSeconds = config.SyncFrequency
|
||||
UserId2StatusCacheSeconds = config.SyncFrequency
|
||||
GroupModelsCacheSeconds = config.SyncFrequency
|
||||
)
|
||||
|
||||
func CacheGetTokenByKey(key string) (*Token, error) {
|
||||
@ -146,6 +147,25 @@ func CacheIsUserEnabled(userId int) (bool, error) {
|
||||
return userEnabled, err
|
||||
}
|
||||
|
||||
func CacheGetGroupModels(ctx context.Context, group string) ([]string, error) {
|
||||
if !common.RedisEnabled {
|
||||
return GetGroupModels(ctx, group)
|
||||
}
|
||||
modelsStr, err := common.RedisGet(fmt.Sprintf("group_models:%s", group))
|
||||
if err == nil {
|
||||
return strings.Split(modelsStr, ","), nil
|
||||
}
|
||||
models, err := GetGroupModels(ctx, group)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = common.RedisSet(fmt.Sprintf("group_models:%s", group), strings.Join(models, ","), time.Duration(GroupModelsCacheSeconds)*time.Second)
|
||||
if err != nil {
|
||||
logger.SysError("Redis set group models error: " + err.Error())
|
||||
}
|
||||
return models, nil
|
||||
}
|
||||
|
||||
var group2model2channels map[string]map[string][]*Channel
|
||||
var channelSyncLock sync.RWMutex
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
@ -8,6 +9,8 @@ import (
|
||||
"github.com/songquanpeng/one-api/common/helper"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"gorm.io/gorm"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Channel struct {
|
||||
@ -25,7 +28,7 @@ type Channel struct {
|
||||
Balance float64 `json:"balance"` // in USD
|
||||
BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"`
|
||||
Models string `json:"models"`
|
||||
Group string `json:"group" gorm:"type:varchar(32);default:'default'"`
|
||||
Group string `json:"group" gorm:"index;type:varchar(32);default:'default'"`
|
||||
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"`
|
||||
ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
|
||||
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
|
||||
@ -202,3 +205,28 @@ func DeleteDisabledChannel() (int64, error) {
|
||||
result := DB.Where("status = ? or status = ?", common.ChannelStatusAutoDisabled, common.ChannelStatusManuallyDisabled).Delete(&Channel{})
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
func GetGroupModels(ctx context.Context, group string) ([]string, error) {
|
||||
groupCol := "`group`"
|
||||
if common.UsingPostgreSQL {
|
||||
groupCol = `"group"`
|
||||
}
|
||||
var modelsList []string
|
||||
err := DB.Model(&Channel{}).Distinct("models").Where(groupCol+" = ?", group).Pluck("models", &modelsList).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
set := make(map[string]bool)
|
||||
for i := 0; i < len(modelsList); i++ {
|
||||
modelList := strings.Split(modelsList[i], ",")
|
||||
for _, model := range modelList {
|
||||
set[model] = true
|
||||
}
|
||||
}
|
||||
modelList := make([]string, 0, len(set))
|
||||
for model := range set {
|
||||
modelList = append(modelList, model)
|
||||
}
|
||||
sort.Strings(modelList)
|
||||
return modelList, err
|
||||
}
|
||||
|
@ -12,24 +12,25 @@ import (
|
||||
)
|
||||
|
||||
type Token struct {
|
||||
Id int `json:"id"`
|
||||
UserId int `json:"user_id"`
|
||||
Key string `json:"key" gorm:"type:char(48);uniqueIndex"`
|
||||
Status int `json:"status" gorm:"default:1"`
|
||||
Name string `json:"name" gorm:"index" `
|
||||
CreatedTime int64 `json:"created_time" gorm:"bigint"`
|
||||
AccessedTime int64 `json:"accessed_time" gorm:"bigint"`
|
||||
ExpiredTime int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired
|
||||
RemainQuota int64 `json:"remain_quota" gorm:"bigint;default:0"`
|
||||
UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"`
|
||||
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` // used quota
|
||||
Id int `json:"id"`
|
||||
UserId int `json:"user_id"`
|
||||
Key string `json:"key" gorm:"type:char(48);uniqueIndex"`
|
||||
Status int `json:"status" gorm:"default:1"`
|
||||
Name string `json:"name" gorm:"index" `
|
||||
CreatedTime int64 `json:"created_time" gorm:"bigint"`
|
||||
AccessedTime int64 `json:"accessed_time" gorm:"bigint"`
|
||||
ExpiredTime int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired
|
||||
RemainQuota int64 `json:"remain_quota" gorm:"bigint;default:0"`
|
||||
UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"`
|
||||
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` // used quota
|
||||
Models *string `json:"models" gorm:"default:''"`
|
||||
}
|
||||
|
||||
func GetAllUserTokens(userId int, startIdx int, num int, order string) ([]*Token, error) {
|
||||
var tokens []*Token
|
||||
var err error
|
||||
query := DB.Where("user_id = ?", userId)
|
||||
|
||||
|
||||
switch order {
|
||||
case "remain_quota":
|
||||
query = query.Order("unlimited_quota desc, remain_quota desc")
|
||||
@ -38,7 +39,7 @@ func GetAllUserTokens(userId int, startIdx int, num int, order string) ([]*Token
|
||||
default:
|
||||
query = query.Order("id desc")
|
||||
}
|
||||
|
||||
|
||||
err = query.Limit(num).Offset(startIdx).Find(&tokens).Error
|
||||
return tokens, err
|
||||
}
|
||||
@ -121,7 +122,7 @@ func (token *Token) Insert() error {
|
||||
// Update Make sure your token's fields is completed, because this will update non-zero values
|
||||
func (token *Token) Update() error {
|
||||
var err error
|
||||
err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota").Updates(token).Error
|
||||
err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "models").Updates(token).Error
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -43,6 +43,7 @@ func SetApiRouter(router *gin.Engine) {
|
||||
selfRoute.GET("/token", controller.GenerateAccessToken)
|
||||
selfRoute.GET("/aff", controller.GetAffCode)
|
||||
selfRoute.POST("/topup", controller.TopUp)
|
||||
selfRoute.GET("/available_models", controller.GetUserAvailableModels)
|
||||
}
|
||||
|
||||
adminRoute := userRoute.Group("/")
|
||||
|
@ -1,19 +1,21 @@
|
||||
import React, { useEffect, useState } from 'react';
|
||||
import { Button, Form, Header, Message, Segment } from 'semantic-ui-react';
|
||||
import { useParams, useNavigate } from 'react-router-dom';
|
||||
import { API, showError, showSuccess, timestamp2string } from '../../helpers';
|
||||
import { renderQuota, renderQuotaWithPrompt } from '../../helpers/render';
|
||||
import { useNavigate, useParams } from 'react-router-dom';
|
||||
import { API, copy, showError, showSuccess, timestamp2string } from '../../helpers';
|
||||
import { renderQuotaWithPrompt } from '../../helpers/render';
|
||||
|
||||
const EditToken = () => {
|
||||
const params = useParams();
|
||||
const tokenId = params.id;
|
||||
const isEdit = tokenId !== undefined;
|
||||
const [loading, setLoading] = useState(isEdit);
|
||||
const [modelOptions, setModelOptions] = useState([]);
|
||||
const originInputs = {
|
||||
name: '',
|
||||
remain_quota: isEdit ? 0 : 500000,
|
||||
expired_time: -1,
|
||||
unlimited_quota: false
|
||||
unlimited_quota: false,
|
||||
models: []
|
||||
};
|
||||
const [inputs, setInputs] = useState(originInputs);
|
||||
const { name, remain_quota, expired_time, unlimited_quota } = inputs;
|
||||
@ -22,8 +24,8 @@ const EditToken = () => {
|
||||
setInputs((inputs) => ({ ...inputs, [name]: value }));
|
||||
};
|
||||
const handleCancel = () => {
|
||||
navigate("/token");
|
||||
}
|
||||
navigate('/token');
|
||||
};
|
||||
const setExpiredTime = (month, day, hour, minute) => {
|
||||
let now = new Date();
|
||||
let timestamp = now.getTime() / 1000;
|
||||
@ -50,6 +52,11 @@ const EditToken = () => {
|
||||
if (data.expired_time !== -1) {
|
||||
data.expired_time = timestamp2string(data.expired_time);
|
||||
}
|
||||
if (data.models === '') {
|
||||
data.models = [];
|
||||
} else {
|
||||
data.models = data.models.split(',');
|
||||
}
|
||||
setInputs(data);
|
||||
} else {
|
||||
showError(message);
|
||||
@ -60,8 +67,26 @@ const EditToken = () => {
|
||||
if (isEdit) {
|
||||
loadToken().then();
|
||||
}
|
||||
loadAvailableModels().then();
|
||||
}, []);
|
||||
|
||||
const loadAvailableModels = async () => {
|
||||
let res = await API.get(`/api/user/available_models`);
|
||||
const { success, message, data } = res.data;
|
||||
if (success) {
|
||||
let options = data.map((model) => {
|
||||
return {
|
||||
key: model,
|
||||
text: model,
|
||||
value: model
|
||||
};
|
||||
});
|
||||
setModelOptions(options);
|
||||
} else {
|
||||
showError(message);
|
||||
}
|
||||
};
|
||||
|
||||
const submit = async () => {
|
||||
if (!isEdit && inputs.name === '') return;
|
||||
let localInputs = inputs;
|
||||
@ -74,6 +99,7 @@ const EditToken = () => {
|
||||
}
|
||||
localInputs.expired_time = Math.ceil(time / 1000);
|
||||
}
|
||||
localInputs.models = localInputs.models.join(',');
|
||||
let res;
|
||||
if (isEdit) {
|
||||
res = await API.put(`/api/token/`, { ...localInputs, id: parseInt(tokenId) });
|
||||
@ -109,6 +135,24 @@ const EditToken = () => {
|
||||
required={!isEdit}
|
||||
/>
|
||||
</Form.Field>
|
||||
<Form.Field>
|
||||
<Form.Dropdown
|
||||
label='模型范围'
|
||||
placeholder={'请选择允许使用的模型,留空则不进行限制'}
|
||||
name='models'
|
||||
fluid
|
||||
multiple
|
||||
search
|
||||
onLabelClick={(e, { value }) => {
|
||||
copy(value).then();
|
||||
}}
|
||||
selection
|
||||
onChange={handleInputChange}
|
||||
value={inputs.models}
|
||||
autoComplete='new-password'
|
||||
options={modelOptions}
|
||||
/>
|
||||
</Form.Field>
|
||||
<Form.Field>
|
||||
<Form.Input
|
||||
label='过期时间'
|
||||
|
Loading…
Reference in New Issue
Block a user