diff --git a/controller/tag.go b/controller/tag.go new file mode 100644 index 00000000..4654238e --- /dev/null +++ b/controller/tag.go @@ -0,0 +1,28 @@ +package controller + +import ( + "github.com/gin-gonic/gin" + "net/http" + "one-api/model" +) + +func GetTags(c *gin.Context) { + tags := make([]string, 0) + var checkMap = make(map[string]int) + channels, err := model.GetAllChannels(0, 0, true) + if err == nil{ + for i := range channels { + if _, ok := checkMap[channels[i].Tag];ok{ + continue + } + tags = append(tags, channels[i].Tag) + checkMap[channels[i].Tag] = 1 + } + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": tags, + }) +} diff --git a/controller/token.go b/controller/token.go index 8642122c..a0f5cb28 100644 --- a/controller/token.go +++ b/controller/token.go @@ -119,6 +119,7 @@ func AddToken(c *gin.Context) { cleanToken := model.Token{ UserId: c.GetInt("id"), Name: token.Name, + Tag: token.Tag, Key: common.GenerateKey(), CreatedTime: common.GetTimestamp(), AccessedTime: common.GetTimestamp(), @@ -210,6 +211,7 @@ func UpdateToken(c *gin.Context) { cleanToken.ExpiredTime = token.ExpiredTime cleanToken.RemainQuota = token.RemainQuota cleanToken.UnlimitedQuota = token.UnlimitedQuota + cleanToken.Tag = token.Tag } err = cleanToken.Update() if err != nil { diff --git a/middleware/auth.go b/middleware/auth.go index ad7e64b7..cef98c9a 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -106,6 +106,8 @@ func TokenAuth() func(c *gin.Context) { c.Set("id", token.UserId) c.Set("token_id", token.Id) c.Set("token_name", token.Name) + c.Set("token_tag", token.Tag) + print(token.Tag) if len(parts) > 1 { if model.IsAdmin(token.UserId) { c.Set("channelId", parts[1]) diff --git a/middleware/distributor.go b/middleware/distributor.go index 81338130..36150773 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -65,7 +65,9 @@ func Distribute() func(c *gin.Context) { modelRequest.Model = "whisper-1" } } - channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model) + tag := c.GetString("token_tag") + //channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model) + channel, err = model.CacheGetChannelByTag(userGroup, modelRequest.Model, tag) if err != nil { message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model) if channel != nil { @@ -80,6 +82,9 @@ func Distribute() func(c *gin.Context) { c.Set("channel_id", channel.Id) c.Set("channel_name", channel.Name) c.Set("model_mapping", channel.GetModelMapping()) + if channel.Organization != ""{ + c.Request.Header.Set("Organization", channel.Organization) + } c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) c.Set("base_url", channel.GetBaseURL()) switch channel.Type { diff --git a/model/ability.go b/model/ability.go index 3da83be8..509eb4ee 100644 --- a/model/ability.go +++ b/model/ability.go @@ -8,6 +8,7 @@ import ( type Ability struct { Group string `json:"group" gorm:"type:varchar(32);primaryKey;autoIncrement:false"` Model string `json:"model" gorm:"primaryKey;autoIncrement:false"` + Tag string `json:"tag" gorm:"index"` ChannelId int `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"` Enabled bool `json:"enabled"` Priority *int64 `json:"priority" gorm:"bigint;default:0;index"` @@ -23,7 +24,7 @@ func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) { } var err error = nil - maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model) + maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model) channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery) if common.UsingSQLite || common.UsingPostgreSQL { err = channelQuery.Order("RANDOM()").First(&ability).Error @@ -39,6 +40,32 @@ func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) { return &channel, err } +func GetRandomSatisfiedChannelByTag(group string, model string, tag string) (*Channel, error) { + ability := Ability{} + groupCol := "`group`" + trueVal := "1" + if common.UsingPostgreSQL { + groupCol = `"group"` + trueVal = "true" + } + + var err error = nil + maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and tag= ? and enabled = "+trueVal, group, model, tag) + channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?) and tag= ?", group, model, maxPrioritySubQuery, tag) + if common.UsingSQLite || common.UsingPostgreSQL { + err = channelQuery.Order("RANDOM()").First(&ability).Error + } else { + err = channelQuery.Order("RAND()").First(&ability).Error + } + if err != nil { + return nil, err + } + channel := Channel{} + channel.Id = ability.ChannelId + err = DB.First(&channel, "id = ?", ability.ChannelId).Error + return &channel, err +} + func (channel *Channel) AddAbilities() error { models_ := strings.Split(channel.Models, ",") groups_ := strings.Split(channel.Group, ",") @@ -51,6 +78,7 @@ func (channel *Channel) AddAbilities() error { ChannelId: channel.Id, Enabled: channel.Status == common.ChannelStatusEnabled, Priority: channel.Priority, + Tag: channel.Tag, } abilities = append(abilities, ability) } diff --git a/model/cache.go b/model/cache.go index c6d0c70a..10b69f66 100644 --- a/model/cache.go +++ b/model/cache.go @@ -213,3 +213,23 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error idx := rand.Intn(endIdx) return channels[idx], nil } + +func CacheGetChannelByTag(group string, model string, tag string) (*Channel, error){ + if !common.MemoryCacheEnabled { + return GetRandomSatisfiedChannelByTag(group, model, tag) + } + channelSyncLock.RLock() + defer channelSyncLock.RUnlock() + channels := group2model2channels[group][model] + if len(channels) == 0 { + return nil, errors.New("channel not found") + } + + for i := range channels { + if channels[i].Tag == tag{ + return channels[i], nil + } + } + + return nil, errors.New("channel not found") +} diff --git a/model/channel.go b/model/channel.go index 7e7b42e6..14161eea 100644 --- a/model/channel.go +++ b/model/channel.go @@ -9,6 +9,8 @@ type Channel struct { Id int `json:"id"` Type int `json:"type" gorm:"default:0"` Key string `json:"key" gorm:"not null;index"` + Organization string `json:"organization" gorm:"index"` + Tag string `json:"tag" gorm:"index"` Status int `json:"status" gorm:"default:1"` Name string `json:"name" gorm:"index"` Weight *uint `json:"weight" gorm:"default:0"` diff --git a/model/token.go b/model/token.go index 0fa984d3..ed03a731 100644 --- a/model/token.go +++ b/model/token.go @@ -10,6 +10,7 @@ import ( type Token struct { Id int `json:"id"` UserId int `json:"user_id"` + Tag string `json:"tag" gorm:"index"` Key string `json:"key" gorm:"type:char(48);uniqueIndex"` Status int `json:"status" gorm:"default:1"` Name string `json:"name" gorm:"index" ` @@ -102,7 +103,7 @@ func (token *Token) Insert() error { // Update Make sure your token's fields is completed, because this will update non-zero values func (token *Token) Update() error { var err error - err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota").Updates(token).Error + err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "tag").Updates(token).Error return err } diff --git a/router/api-router.go b/router/api-router.go index da3f9e61..84dd953a 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -110,5 +110,10 @@ func SetApiRouter(router *gin.Engine) { { groupRoute.GET("/", controller.GetGroups) } + tagRoute := apiRouter.Group("/tags") + tagRoute.Use(middleware.AdminAuth()) + { + tagRoute.GET("/", controller.GetTags) + } } } diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index 0d4e114d..7616dcb3 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -308,6 +308,17 @@ const EditChannel = () => { autoComplete='new-password' /> +