This commit is contained in:
duyazhe 2024-01-01 17:52:21 +08:00 committed by GitHub
commit ed0a3dcf5e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 151 additions and 26 deletions

28
controller/tag.go Normal file
View File

@ -0,0 +1,28 @@
package controller
import (
"github.com/gin-gonic/gin"
"net/http"
"one-api/model"
)
func GetTags(c *gin.Context) {
tags := make([]string, 0)
var checkMap = make(map[string]int)
channels, err := model.GetAllChannels(0, 0, true)
if err == nil{
for i := range channels {
if _, ok := checkMap[channels[i].Tag];ok{
continue
}
tags = append(tags, channels[i].Tag)
checkMap[channels[i].Tag] = 1
}
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": tags,
})
}

View File

@ -119,6 +119,7 @@ func AddToken(c *gin.Context) {
cleanToken := model.Token{ cleanToken := model.Token{
UserId: c.GetInt("id"), UserId: c.GetInt("id"),
Name: token.Name, Name: token.Name,
Tag: token.Tag,
Key: common.GenerateKey(), Key: common.GenerateKey(),
CreatedTime: common.GetTimestamp(), CreatedTime: common.GetTimestamp(),
AccessedTime: common.GetTimestamp(), AccessedTime: common.GetTimestamp(),
@ -210,6 +211,7 @@ func UpdateToken(c *gin.Context) {
cleanToken.ExpiredTime = token.ExpiredTime cleanToken.ExpiredTime = token.ExpiredTime
cleanToken.RemainQuota = token.RemainQuota cleanToken.RemainQuota = token.RemainQuota
cleanToken.UnlimitedQuota = token.UnlimitedQuota cleanToken.UnlimitedQuota = token.UnlimitedQuota
cleanToken.Tag = token.Tag
} }
err = cleanToken.Update() err = cleanToken.Update()
if err != nil { if err != nil {

View File

@ -106,6 +106,8 @@ func TokenAuth() func(c *gin.Context) {
c.Set("id", token.UserId) c.Set("id", token.UserId)
c.Set("token_id", token.Id) c.Set("token_id", token.Id)
c.Set("token_name", token.Name) c.Set("token_name", token.Name)
c.Set("token_tag", token.Tag)
print(token.Tag)
if len(parts) > 1 { if len(parts) > 1 {
if model.IsAdmin(token.UserId) { if model.IsAdmin(token.UserId) {
c.Set("channelId", parts[1]) c.Set("channelId", parts[1])

View File

@ -65,7 +65,9 @@ func Distribute() func(c *gin.Context) {
modelRequest.Model = "whisper-1" modelRequest.Model = "whisper-1"
} }
} }
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model) tag := c.GetString("token_tag")
//channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
channel, err = model.CacheGetChannelByTag(userGroup, modelRequest.Model, tag)
if err != nil { if err != nil {
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model) message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
if channel != nil { if channel != nil {
@ -80,6 +82,9 @@ func Distribute() func(c *gin.Context) {
c.Set("channel_id", channel.Id) c.Set("channel_id", channel.Id)
c.Set("channel_name", channel.Name) c.Set("channel_name", channel.Name)
c.Set("model_mapping", channel.GetModelMapping()) c.Set("model_mapping", channel.GetModelMapping())
if channel.Organization != ""{
c.Request.Header.Set("Organization", channel.Organization)
}
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
c.Set("base_url", channel.GetBaseURL()) c.Set("base_url", channel.GetBaseURL())
switch channel.Type { switch channel.Type {

View File

@ -8,6 +8,7 @@ import (
type Ability struct { type Ability struct {
Group string `json:"group" gorm:"type:varchar(32);primaryKey;autoIncrement:false"` Group string `json:"group" gorm:"type:varchar(32);primaryKey;autoIncrement:false"`
Model string `json:"model" gorm:"primaryKey;autoIncrement:false"` Model string `json:"model" gorm:"primaryKey;autoIncrement:false"`
Tag string `json:"tag" gorm:"index"`
ChannelId int `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"` ChannelId int `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"`
Enabled bool `json:"enabled"` Enabled bool `json:"enabled"`
Priority *int64 `json:"priority" gorm:"bigint;default:0;index"` Priority *int64 `json:"priority" gorm:"bigint;default:0;index"`
@ -39,6 +40,32 @@ func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
return &channel, err return &channel, err
} }
func GetRandomSatisfiedChannelByTag(group string, model string, tag string) (*Channel, error) {
ability := Ability{}
groupCol := "`group`"
trueVal := "1"
if common.UsingPostgreSQL {
groupCol = `"group"`
trueVal = "true"
}
var err error = nil
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and tag= ? and enabled = "+trueVal, group, model, tag)
channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?) and tag= ?", group, model, maxPrioritySubQuery, tag)
if common.UsingSQLite || common.UsingPostgreSQL {
err = channelQuery.Order("RANDOM()").First(&ability).Error
} else {
err = channelQuery.Order("RAND()").First(&ability).Error
}
if err != nil {
return nil, err
}
channel := Channel{}
channel.Id = ability.ChannelId
err = DB.First(&channel, "id = ?", ability.ChannelId).Error
return &channel, err
}
func (channel *Channel) AddAbilities() error { func (channel *Channel) AddAbilities() error {
models_ := strings.Split(channel.Models, ",") models_ := strings.Split(channel.Models, ",")
groups_ := strings.Split(channel.Group, ",") groups_ := strings.Split(channel.Group, ",")
@ -51,6 +78,7 @@ func (channel *Channel) AddAbilities() error {
ChannelId: channel.Id, ChannelId: channel.Id,
Enabled: channel.Status == common.ChannelStatusEnabled, Enabled: channel.Status == common.ChannelStatusEnabled,
Priority: channel.Priority, Priority: channel.Priority,
Tag: channel.Tag,
} }
abilities = append(abilities, ability) abilities = append(abilities, ability)
} }

View File

@ -213,3 +213,23 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
idx := rand.Intn(endIdx) idx := rand.Intn(endIdx)
return channels[idx], nil return channels[idx], nil
} }
func CacheGetChannelByTag(group string, model string, tag string) (*Channel, error){
if !common.MemoryCacheEnabled {
return GetRandomSatisfiedChannelByTag(group, model, tag)
}
channelSyncLock.RLock()
defer channelSyncLock.RUnlock()
channels := group2model2channels[group][model]
if len(channels) == 0 {
return nil, errors.New("channel not found")
}
for i := range channels {
if channels[i].Tag == tag{
return channels[i], nil
}
}
return nil, errors.New("channel not found")
}

View File

@ -9,6 +9,8 @@ type Channel struct {
Id int `json:"id"` Id int `json:"id"`
Type int `json:"type" gorm:"default:0"` Type int `json:"type" gorm:"default:0"`
Key string `json:"key" gorm:"not null;index"` Key string `json:"key" gorm:"not null;index"`
Organization string `json:"organization" gorm:"index"`
Tag string `json:"tag" gorm:"index"`
Status int `json:"status" gorm:"default:1"` Status int `json:"status" gorm:"default:1"`
Name string `json:"name" gorm:"index"` Name string `json:"name" gorm:"index"`
Weight *uint `json:"weight" gorm:"default:0"` Weight *uint `json:"weight" gorm:"default:0"`

View File

@ -10,6 +10,7 @@ import (
type Token struct { type Token struct {
Id int `json:"id"` Id int `json:"id"`
UserId int `json:"user_id"` UserId int `json:"user_id"`
Tag string `json:"tag" gorm:"index"`
Key string `json:"key" gorm:"type:char(48);uniqueIndex"` Key string `json:"key" gorm:"type:char(48);uniqueIndex"`
Status int `json:"status" gorm:"default:1"` Status int `json:"status" gorm:"default:1"`
Name string `json:"name" gorm:"index" ` Name string `json:"name" gorm:"index" `
@ -102,7 +103,7 @@ func (token *Token) Insert() error {
// Update Make sure your token's fields is completed, because this will update non-zero values // Update Make sure your token's fields is completed, because this will update non-zero values
func (token *Token) Update() error { func (token *Token) Update() error {
var err error var err error
err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota").Updates(token).Error err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "tag").Updates(token).Error
return err return err
} }

View File

@ -110,5 +110,10 @@ func SetApiRouter(router *gin.Engine) {
{ {
groupRoute.GET("/", controller.GetGroups) groupRoute.GET("/", controller.GetGroups)
} }
tagRoute := apiRouter.Group("/tags")
tagRoute.Use(middleware.AdminAuth())
{
tagRoute.GET("/", controller.GetTags)
}
} }
} }

View File

@ -308,6 +308,17 @@ const EditChannel = () => {
autoComplete='new-password' autoComplete='new-password'
/> />
</Form.Field> </Form.Field>
<Form.Field>
<Form.Input
label='Tag'
required
name='tag'
placeholder={'请添加tag'}
onChange={handleInputChange}
value={inputs.tag}
autoComplete='new-password'
/>
</Form.Field>
<Form.Field> <Form.Field>
<Form.Dropdown <Form.Dropdown
label='分组' label='分组'
@ -420,19 +431,8 @@ const EditChannel = () => {
autoComplete='new-password' autoComplete='new-password'
/> />
</Form.Field> </Form.Field>
{
batch ? <Form.Field> <Form.Field>
<Form.TextArea
label='密钥'
name='key'
required
placeholder={'请输入密钥,一行一个'}
onChange={handleInputChange}
value={inputs.key}
style={{ minHeight: 150, fontFamily: 'JetBrains Mono, Consolas' }}
autoComplete='new-password'
/>
</Form.Field> : <Form.Field>
<Form.Input <Form.Input
label='密钥' label='密钥'
name='key' name='key'
@ -443,17 +443,18 @@ const EditChannel = () => {
autoComplete='new-password' autoComplete='new-password'
/> />
</Form.Field> </Form.Field>
} <Form.Field>
{ <Form.Input
!isEdit && ( label='organization'
<Form.Checkbox name='organization'
checked={batch} placeholder={type2secretPrompt(inputs.type)}
label='批量创建' onChange={handleInputChange}
name='batch' value={inputs.organization}
onChange={() => setBatch(!batch)} autoComplete='new-password'
/> />
) </Form.Field>
}
{ {
inputs.type !== 3 && inputs.type !== 8 && inputs.type !== 22 && ( inputs.type !== 3 && inputs.type !== 8 && inputs.type !== 22 && (
<Form.Field> <Form.Field>

View File

@ -16,6 +16,7 @@ const EditToken = () => {
unlimited_quota: false unlimited_quota: false
}; };
const [inputs, setInputs] = useState(originInputs); const [inputs, setInputs] = useState(originInputs);
const [tagOptions, setTagOptions] = useState([]);
const { name, remain_quota, expired_time, unlimited_quota } = inputs; const { name, remain_quota, expired_time, unlimited_quota } = inputs;
const navigate = useNavigate(); const navigate = useNavigate();
const handleInputChange = (e, { name, value }) => { const handleInputChange = (e, { name, value }) => {
@ -59,9 +60,24 @@ const EditToken = () => {
useEffect(() => { useEffect(() => {
if (isEdit) { if (isEdit) {
loadToken().then(); loadToken().then();
fetchTags().then();
} }
}, []); }, []);
const fetchTags = async () => {
try {
let res = await API.get(`/api/tags/`);
setTagOptions(res.data.data.map((group) => ({
key: group,
text: group,
value: group
})));
} catch (error) {
showError(error.message);
}
};
const submit = async () => { const submit = async () => {
if (!isEdit && inputs.name === '') return; if (!isEdit && inputs.name === '') return;
let localInputs = inputs; let localInputs = inputs;
@ -109,6 +125,21 @@ const EditToken = () => {
required={!isEdit} required={!isEdit}
/> />
</Form.Field> </Form.Field>
<Form.Field>
<Form.Dropdown
label='tag'
placeholder={'请选择可以使用该渠道的分组'}
name='tag'
required
fluid
selection
allowAdditions
onChange={handleInputChange}
value={inputs.tag}
autoComplete='new-password'
options={tagOptions}
/>
</Form.Field>
<Form.Field> <Form.Field>
<Form.Input <Form.Input
label='过期时间' label='过期时间'