Merge 354ad3867a
into 505817ca17
This commit is contained in:
commit
ed0a3dcf5e
28
controller/tag.go
Normal file
28
controller/tag.go
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"net/http"
|
||||||
|
"one-api/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
func GetTags(c *gin.Context) {
|
||||||
|
tags := make([]string, 0)
|
||||||
|
var checkMap = make(map[string]int)
|
||||||
|
channels, err := model.GetAllChannels(0, 0, true)
|
||||||
|
if err == nil{
|
||||||
|
for i := range channels {
|
||||||
|
if _, ok := checkMap[channels[i].Tag];ok{
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
tags = append(tags, channels[i].Tag)
|
||||||
|
checkMap[channels[i].Tag] = 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"message": "",
|
||||||
|
"data": tags,
|
||||||
|
})
|
||||||
|
}
|
@ -119,6 +119,7 @@ func AddToken(c *gin.Context) {
|
|||||||
cleanToken := model.Token{
|
cleanToken := model.Token{
|
||||||
UserId: c.GetInt("id"),
|
UserId: c.GetInt("id"),
|
||||||
Name: token.Name,
|
Name: token.Name,
|
||||||
|
Tag: token.Tag,
|
||||||
Key: common.GenerateKey(),
|
Key: common.GenerateKey(),
|
||||||
CreatedTime: common.GetTimestamp(),
|
CreatedTime: common.GetTimestamp(),
|
||||||
AccessedTime: common.GetTimestamp(),
|
AccessedTime: common.GetTimestamp(),
|
||||||
@ -210,6 +211,7 @@ func UpdateToken(c *gin.Context) {
|
|||||||
cleanToken.ExpiredTime = token.ExpiredTime
|
cleanToken.ExpiredTime = token.ExpiredTime
|
||||||
cleanToken.RemainQuota = token.RemainQuota
|
cleanToken.RemainQuota = token.RemainQuota
|
||||||
cleanToken.UnlimitedQuota = token.UnlimitedQuota
|
cleanToken.UnlimitedQuota = token.UnlimitedQuota
|
||||||
|
cleanToken.Tag = token.Tag
|
||||||
}
|
}
|
||||||
err = cleanToken.Update()
|
err = cleanToken.Update()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -106,6 +106,8 @@ func TokenAuth() func(c *gin.Context) {
|
|||||||
c.Set("id", token.UserId)
|
c.Set("id", token.UserId)
|
||||||
c.Set("token_id", token.Id)
|
c.Set("token_id", token.Id)
|
||||||
c.Set("token_name", token.Name)
|
c.Set("token_name", token.Name)
|
||||||
|
c.Set("token_tag", token.Tag)
|
||||||
|
print(token.Tag)
|
||||||
if len(parts) > 1 {
|
if len(parts) > 1 {
|
||||||
if model.IsAdmin(token.UserId) {
|
if model.IsAdmin(token.UserId) {
|
||||||
c.Set("channelId", parts[1])
|
c.Set("channelId", parts[1])
|
||||||
|
@ -65,7 +65,9 @@ func Distribute() func(c *gin.Context) {
|
|||||||
modelRequest.Model = "whisper-1"
|
modelRequest.Model = "whisper-1"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
|
tag := c.GetString("token_tag")
|
||||||
|
//channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
|
||||||
|
channel, err = model.CacheGetChannelByTag(userGroup, modelRequest.Model, tag)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
|
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
|
||||||
if channel != nil {
|
if channel != nil {
|
||||||
@ -80,6 +82,9 @@ func Distribute() func(c *gin.Context) {
|
|||||||
c.Set("channel_id", channel.Id)
|
c.Set("channel_id", channel.Id)
|
||||||
c.Set("channel_name", channel.Name)
|
c.Set("channel_name", channel.Name)
|
||||||
c.Set("model_mapping", channel.GetModelMapping())
|
c.Set("model_mapping", channel.GetModelMapping())
|
||||||
|
if channel.Organization != ""{
|
||||||
|
c.Request.Header.Set("Organization", channel.Organization)
|
||||||
|
}
|
||||||
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
||||||
c.Set("base_url", channel.GetBaseURL())
|
c.Set("base_url", channel.GetBaseURL())
|
||||||
switch channel.Type {
|
switch channel.Type {
|
||||||
|
@ -8,6 +8,7 @@ import (
|
|||||||
type Ability struct {
|
type Ability struct {
|
||||||
Group string `json:"group" gorm:"type:varchar(32);primaryKey;autoIncrement:false"`
|
Group string `json:"group" gorm:"type:varchar(32);primaryKey;autoIncrement:false"`
|
||||||
Model string `json:"model" gorm:"primaryKey;autoIncrement:false"`
|
Model string `json:"model" gorm:"primaryKey;autoIncrement:false"`
|
||||||
|
Tag string `json:"tag" gorm:"index"`
|
||||||
ChannelId int `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"`
|
ChannelId int `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"`
|
||||||
Enabled bool `json:"enabled"`
|
Enabled bool `json:"enabled"`
|
||||||
Priority *int64 `json:"priority" gorm:"bigint;default:0;index"`
|
Priority *int64 `json:"priority" gorm:"bigint;default:0;index"`
|
||||||
@ -39,6 +40,32 @@ func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
|
|||||||
return &channel, err
|
return &channel, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetRandomSatisfiedChannelByTag(group string, model string, tag string) (*Channel, error) {
|
||||||
|
ability := Ability{}
|
||||||
|
groupCol := "`group`"
|
||||||
|
trueVal := "1"
|
||||||
|
if common.UsingPostgreSQL {
|
||||||
|
groupCol = `"group"`
|
||||||
|
trueVal = "true"
|
||||||
|
}
|
||||||
|
|
||||||
|
var err error = nil
|
||||||
|
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and tag= ? and enabled = "+trueVal, group, model, tag)
|
||||||
|
channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?) and tag= ?", group, model, maxPrioritySubQuery, tag)
|
||||||
|
if common.UsingSQLite || common.UsingPostgreSQL {
|
||||||
|
err = channelQuery.Order("RANDOM()").First(&ability).Error
|
||||||
|
} else {
|
||||||
|
err = channelQuery.Order("RAND()").First(&ability).Error
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
channel := Channel{}
|
||||||
|
channel.Id = ability.ChannelId
|
||||||
|
err = DB.First(&channel, "id = ?", ability.ChannelId).Error
|
||||||
|
return &channel, err
|
||||||
|
}
|
||||||
|
|
||||||
func (channel *Channel) AddAbilities() error {
|
func (channel *Channel) AddAbilities() error {
|
||||||
models_ := strings.Split(channel.Models, ",")
|
models_ := strings.Split(channel.Models, ",")
|
||||||
groups_ := strings.Split(channel.Group, ",")
|
groups_ := strings.Split(channel.Group, ",")
|
||||||
@ -51,6 +78,7 @@ func (channel *Channel) AddAbilities() error {
|
|||||||
ChannelId: channel.Id,
|
ChannelId: channel.Id,
|
||||||
Enabled: channel.Status == common.ChannelStatusEnabled,
|
Enabled: channel.Status == common.ChannelStatusEnabled,
|
||||||
Priority: channel.Priority,
|
Priority: channel.Priority,
|
||||||
|
Tag: channel.Tag,
|
||||||
}
|
}
|
||||||
abilities = append(abilities, ability)
|
abilities = append(abilities, ability)
|
||||||
}
|
}
|
||||||
|
@ -213,3 +213,23 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
|
|||||||
idx := rand.Intn(endIdx)
|
idx := rand.Intn(endIdx)
|
||||||
return channels[idx], nil
|
return channels[idx], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func CacheGetChannelByTag(group string, model string, tag string) (*Channel, error){
|
||||||
|
if !common.MemoryCacheEnabled {
|
||||||
|
return GetRandomSatisfiedChannelByTag(group, model, tag)
|
||||||
|
}
|
||||||
|
channelSyncLock.RLock()
|
||||||
|
defer channelSyncLock.RUnlock()
|
||||||
|
channels := group2model2channels[group][model]
|
||||||
|
if len(channels) == 0 {
|
||||||
|
return nil, errors.New("channel not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range channels {
|
||||||
|
if channels[i].Tag == tag{
|
||||||
|
return channels[i], nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, errors.New("channel not found")
|
||||||
|
}
|
||||||
|
@ -9,6 +9,8 @@ type Channel struct {
|
|||||||
Id int `json:"id"`
|
Id int `json:"id"`
|
||||||
Type int `json:"type" gorm:"default:0"`
|
Type int `json:"type" gorm:"default:0"`
|
||||||
Key string `json:"key" gorm:"not null;index"`
|
Key string `json:"key" gorm:"not null;index"`
|
||||||
|
Organization string `json:"organization" gorm:"index"`
|
||||||
|
Tag string `json:"tag" gorm:"index"`
|
||||||
Status int `json:"status" gorm:"default:1"`
|
Status int `json:"status" gorm:"default:1"`
|
||||||
Name string `json:"name" gorm:"index"`
|
Name string `json:"name" gorm:"index"`
|
||||||
Weight *uint `json:"weight" gorm:"default:0"`
|
Weight *uint `json:"weight" gorm:"default:0"`
|
||||||
|
@ -10,6 +10,7 @@ import (
|
|||||||
type Token struct {
|
type Token struct {
|
||||||
Id int `json:"id"`
|
Id int `json:"id"`
|
||||||
UserId int `json:"user_id"`
|
UserId int `json:"user_id"`
|
||||||
|
Tag string `json:"tag" gorm:"index"`
|
||||||
Key string `json:"key" gorm:"type:char(48);uniqueIndex"`
|
Key string `json:"key" gorm:"type:char(48);uniqueIndex"`
|
||||||
Status int `json:"status" gorm:"default:1"`
|
Status int `json:"status" gorm:"default:1"`
|
||||||
Name string `json:"name" gorm:"index" `
|
Name string `json:"name" gorm:"index" `
|
||||||
@ -102,7 +103,7 @@ func (token *Token) Insert() error {
|
|||||||
// Update Make sure your token's fields is completed, because this will update non-zero values
|
// Update Make sure your token's fields is completed, because this will update non-zero values
|
||||||
func (token *Token) Update() error {
|
func (token *Token) Update() error {
|
||||||
var err error
|
var err error
|
||||||
err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota").Updates(token).Error
|
err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "tag").Updates(token).Error
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -110,5 +110,10 @@ func SetApiRouter(router *gin.Engine) {
|
|||||||
{
|
{
|
||||||
groupRoute.GET("/", controller.GetGroups)
|
groupRoute.GET("/", controller.GetGroups)
|
||||||
}
|
}
|
||||||
|
tagRoute := apiRouter.Group("/tags")
|
||||||
|
tagRoute.Use(middleware.AdminAuth())
|
||||||
|
{
|
||||||
|
tagRoute.GET("/", controller.GetTags)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -308,6 +308,17 @@ const EditChannel = () => {
|
|||||||
autoComplete='new-password'
|
autoComplete='new-password'
|
||||||
/>
|
/>
|
||||||
</Form.Field>
|
</Form.Field>
|
||||||
|
<Form.Field>
|
||||||
|
<Form.Input
|
||||||
|
label='Tag'
|
||||||
|
required
|
||||||
|
name='tag'
|
||||||
|
placeholder={'请添加tag'}
|
||||||
|
onChange={handleInputChange}
|
||||||
|
value={inputs.tag}
|
||||||
|
autoComplete='new-password'
|
||||||
|
/>
|
||||||
|
</Form.Field>
|
||||||
<Form.Field>
|
<Form.Field>
|
||||||
<Form.Dropdown
|
<Form.Dropdown
|
||||||
label='分组'
|
label='分组'
|
||||||
@ -420,19 +431,8 @@ const EditChannel = () => {
|
|||||||
autoComplete='new-password'
|
autoComplete='new-password'
|
||||||
/>
|
/>
|
||||||
</Form.Field>
|
</Form.Field>
|
||||||
{
|
|
||||||
batch ? <Form.Field>
|
<Form.Field>
|
||||||
<Form.TextArea
|
|
||||||
label='密钥'
|
|
||||||
name='key'
|
|
||||||
required
|
|
||||||
placeholder={'请输入密钥,一行一个'}
|
|
||||||
onChange={handleInputChange}
|
|
||||||
value={inputs.key}
|
|
||||||
style={{ minHeight: 150, fontFamily: 'JetBrains Mono, Consolas' }}
|
|
||||||
autoComplete='new-password'
|
|
||||||
/>
|
|
||||||
</Form.Field> : <Form.Field>
|
|
||||||
<Form.Input
|
<Form.Input
|
||||||
label='密钥'
|
label='密钥'
|
||||||
name='key'
|
name='key'
|
||||||
@ -443,17 +443,18 @@ const EditChannel = () => {
|
|||||||
autoComplete='new-password'
|
autoComplete='new-password'
|
||||||
/>
|
/>
|
||||||
</Form.Field>
|
</Form.Field>
|
||||||
}
|
<Form.Field>
|
||||||
{
|
<Form.Input
|
||||||
!isEdit && (
|
label='organization'
|
||||||
<Form.Checkbox
|
name='organization'
|
||||||
checked={batch}
|
placeholder={type2secretPrompt(inputs.type)}
|
||||||
label='批量创建'
|
onChange={handleInputChange}
|
||||||
name='batch'
|
value={inputs.organization}
|
||||||
onChange={() => setBatch(!batch)}
|
autoComplete='new-password'
|
||||||
/>
|
/>
|
||||||
)
|
</Form.Field>
|
||||||
}
|
|
||||||
|
|
||||||
{
|
{
|
||||||
inputs.type !== 3 && inputs.type !== 8 && inputs.type !== 22 && (
|
inputs.type !== 3 && inputs.type !== 8 && inputs.type !== 22 && (
|
||||||
<Form.Field>
|
<Form.Field>
|
||||||
|
@ -16,6 +16,7 @@ const EditToken = () => {
|
|||||||
unlimited_quota: false
|
unlimited_quota: false
|
||||||
};
|
};
|
||||||
const [inputs, setInputs] = useState(originInputs);
|
const [inputs, setInputs] = useState(originInputs);
|
||||||
|
const [tagOptions, setTagOptions] = useState([]);
|
||||||
const { name, remain_quota, expired_time, unlimited_quota } = inputs;
|
const { name, remain_quota, expired_time, unlimited_quota } = inputs;
|
||||||
const navigate = useNavigate();
|
const navigate = useNavigate();
|
||||||
const handleInputChange = (e, { name, value }) => {
|
const handleInputChange = (e, { name, value }) => {
|
||||||
@ -59,9 +60,24 @@ const EditToken = () => {
|
|||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (isEdit) {
|
if (isEdit) {
|
||||||
loadToken().then();
|
loadToken().then();
|
||||||
|
fetchTags().then();
|
||||||
}
|
}
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
|
|
||||||
|
const fetchTags = async () => {
|
||||||
|
try {
|
||||||
|
let res = await API.get(`/api/tags/`);
|
||||||
|
setTagOptions(res.data.data.map((group) => ({
|
||||||
|
key: group,
|
||||||
|
text: group,
|
||||||
|
value: group
|
||||||
|
})));
|
||||||
|
} catch (error) {
|
||||||
|
showError(error.message);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
const submit = async () => {
|
const submit = async () => {
|
||||||
if (!isEdit && inputs.name === '') return;
|
if (!isEdit && inputs.name === '') return;
|
||||||
let localInputs = inputs;
|
let localInputs = inputs;
|
||||||
@ -109,6 +125,21 @@ const EditToken = () => {
|
|||||||
required={!isEdit}
|
required={!isEdit}
|
||||||
/>
|
/>
|
||||||
</Form.Field>
|
</Form.Field>
|
||||||
|
<Form.Field>
|
||||||
|
<Form.Dropdown
|
||||||
|
label='tag'
|
||||||
|
placeholder={'请选择可以使用该渠道的分组'}
|
||||||
|
name='tag'
|
||||||
|
required
|
||||||
|
fluid
|
||||||
|
selection
|
||||||
|
allowAdditions
|
||||||
|
onChange={handleInputChange}
|
||||||
|
value={inputs.tag}
|
||||||
|
autoComplete='new-password'
|
||||||
|
options={tagOptions}
|
||||||
|
/>
|
||||||
|
</Form.Field>
|
||||||
<Form.Field>
|
<Form.Field>
|
||||||
<Form.Input
|
<Form.Input
|
||||||
label='过期时间'
|
label='过期时间'
|
||||||
|
Loading…
Reference in New Issue
Block a user