From 2ad22e14256264f0787f5c82d7b5409d431a2d63 Mon Sep 17 00:00:00 2001 From: JustSong Date: Wed, 7 Jun 2023 23:26:00 +0800 Subject: [PATCH] feat: support group now (close #17, close #72, close #85, close #104, close #136) Co-authored-by: quzard <1191890118@qq.com> --- common/gin.go | 26 ++++++++++ controller/relay.go | 14 +----- middleware/distributor.go | 21 +++++++- model/ability.go | 72 ++++++++++++++++++++++++++++ model/channel.go | 37 +++++++++++--- model/main.go | 4 ++ model/redemption.go | 1 - model/token.go | 1 - model/user.go | 6 +++ router/api-router.go | 1 + router/relay-router.go | 8 +++- web/src/components/ChannelsTable.js | 12 ++++- web/src/components/UsersTable.js | 13 ++++- web/src/helpers/render.js | 9 ++++ web/src/pages/Channel/EditChannel.js | 39 ++++++++++++++- 15 files changed, 235 insertions(+), 29 deletions(-) create mode 100644 common/gin.go create mode 100644 model/ability.go diff --git a/common/gin.go b/common/gin.go new file mode 100644 index 00000000..ffa1e218 --- /dev/null +++ b/common/gin.go @@ -0,0 +1,26 @@ +package common + +import ( + "bytes" + "encoding/json" + "github.com/gin-gonic/gin" + "io" +) + +func UnmarshalBodyReusable(c *gin.Context, v any) error { + requestBody, err := io.ReadAll(c.Request.Body) + if err != nil { + return err + } + err = c.Request.Body.Close() + if err != nil { + return err + } + err = json.Unmarshal(requestBody, &v) + if err != nil { + return err + } + // Reset request body + c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) + return nil +} diff --git a/controller/relay.go b/controller/relay.go index 81497d81..fb3b8bc4 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -116,20 +116,10 @@ func relayHelper(c *gin.Context) *OpenAIErrorWithStatusCode { consumeQuota := c.GetBool("consume_quota") var textRequest GeneralOpenAIRequest if consumeQuota || channelType == common.ChannelTypeAzure || channelType == common.ChannelTypePaLM { - requestBody, err := io.ReadAll(c.Request.Body) + err := common.UnmarshalBodyReusable(c, &textRequest) if err != nil { - return errorWrapper(err, "read_request_body_failed", http.StatusBadRequest) + return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) } - err = c.Request.Body.Close() - if err != nil { - return errorWrapper(err, "close_request_body_failed", http.StatusBadRequest) - } - err = json.Unmarshal(requestBody, &textRequest) - if err != nil { - return errorWrapper(err, "unmarshal_request_body_failed", http.StatusBadRequest) - } - // Reset request body - c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) } baseURL := common.ChannelBaseURLs[channelType] requestURL := c.Request.URL.String() diff --git a/middleware/distributor.go b/middleware/distributor.go index 357849e7..624cf3b1 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -9,6 +9,10 @@ import ( "strconv" ) +type ModelRequest struct { + Model string `json:"model"` +} + func Distribute() func(c *gin.Context) { return func(c *gin.Context) { var channel *model.Channel @@ -48,8 +52,21 @@ func Distribute() func(c *gin.Context) { } } else { // Select a channel for the user - var err error - channel, err = model.GetRandomChannel() + var modelRequest ModelRequest + err := common.UnmarshalBodyReusable(c, &modelRequest) + if err != nil { + c.JSON(200, gin.H{ + "error": gin.H{ + "message": "无效的请求", + "type": "one_api_error", + }, + }) + c.Abort() + return + } + userId := c.GetInt("id") + userGroup, _ := model.GetUserGroup(userId) + channel, err = model.GetRandomSatisfiedChannel(userGroup, modelRequest.Model) if err != nil { c.JSON(200, gin.H{ "error": gin.H{ diff --git a/model/ability.go b/model/ability.go new file mode 100644 index 00000000..1270ea8a --- /dev/null +++ b/model/ability.go @@ -0,0 +1,72 @@ +package model + +import ( + "one-api/common" + "strings" +) + +type Ability struct { + Group string `json:"group" gorm:"type:varchar(32);primaryKey;autoIncrement:false"` + Model string `json:"model" gorm:"primaryKey;autoIncrement:false"` + ChannelId int `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"` + Enabled bool `json:"enabled" gorm:"default:1"` +} + +func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) { + if group == "default" { + return GetRandomChannel() + } + ability := Ability{} + var err error = nil + if common.UsingSQLite { + err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("RANDOM()").Limit(1).First(&ability).Error + } else { + err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("RAND()").Limit(1).First(&ability).Error + } + if err != nil { + return nil, err + } + channel := Channel{} + err = DB.First(&channel, "id = ?", ability.ChannelId).Error + return &channel, err +} + +func (channel *Channel) AddAbilities() error { + models_ := strings.Split(channel.Models, ",") + abilities := make([]Ability, 0, len(models_)) + for _, model := range models_ { + ability := Ability{ + Group: channel.Group, + Model: model, + ChannelId: channel.Id, + Enabled: channel.Status == common.ChannelStatusEnabled, + } + abilities = append(abilities, ability) + } + return DB.Create(&abilities).Error +} + +func (channel *Channel) DeleteAbilities() error { + return DB.Where("channel_id = ?", channel.Id).Delete(&Ability{}).Error +} + +// UpdateAbilities updates abilities of this channel. +// Make sure the channel is completed before calling this function. +func (channel *Channel) UpdateAbilities() error { + // A quick and dirty way to update abilities + // First delete all abilities of this channel + err := channel.DeleteAbilities() + if err != nil { + return err + } + // Then add new abilities + err = channel.AddAbilities() + if err != nil { + return err + } + return nil +} + +func UpdateAbilityStatus(channelId int, status bool) error { + return DB.Model(&Ability{}).Where("channel_id = ?", channelId).Update("enabled", status).Error +} diff --git a/model/channel.go b/model/channel.go index 35d65827..006a67d9 100644 --- a/model/channel.go +++ b/model/channel.go @@ -1,7 +1,6 @@ package model import ( - _ "gorm.io/driver/sqlite" "one-api/common" ) @@ -19,6 +18,8 @@ type Channel struct { Other string `json:"other"` Balance float64 `json:"balance"` // in USD BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"` + Models string `json:"models"` + Group string `json:"group" gorm:"type:varchar(32);default:'default'"` } func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) { @@ -49,13 +50,12 @@ func GetChannelById(id int, selectAll bool) (*Channel, error) { } func GetRandomChannel() (*Channel, error) { - // TODO: consider weight channel := Channel{} var err error = nil if common.UsingSQLite { - err = DB.Where("status = ?", common.ChannelStatusEnabled).Order("RANDOM()").Limit(1).First(&channel).Error + err = DB.Where("status = ? and `group` = ?", common.ChannelStatusEnabled, "default").Order("RANDOM()").Limit(1).First(&channel).Error } else { - err = DB.Where("status = ?", common.ChannelStatusEnabled).Order("RAND()").Limit(1).First(&channel).Error + err = DB.Where("status = ? and `group` = ?", common.ChannelStatusEnabled, "default").Order("RAND()").Limit(1).First(&channel).Error } return &channel, err } @@ -63,18 +63,35 @@ func GetRandomChannel() (*Channel, error) { func BatchInsertChannels(channels []Channel) error { var err error err = DB.Create(&channels).Error - return err + if err != nil { + return err + } + for _, channel_ := range channels { + err = channel_.AddAbilities() + if err != nil { + return err + } + } + return nil } func (channel *Channel) Insert() error { var err error err = DB.Create(channel).Error + if err != nil { + return err + } + err = channel.AddAbilities() return err } func (channel *Channel) Update() error { var err error err = DB.Model(channel).Updates(channel).Error + if err != nil { + return err + } + err = channel.UpdateAbilities() return err } @@ -101,11 +118,19 @@ func (channel *Channel) UpdateBalance(balance float64) { func (channel *Channel) Delete() error { var err error err = DB.Delete(channel).Error + if err != nil { + return err + } + err = channel.DeleteAbilities() return err } func UpdateChannelStatusById(id int, status int) { - err := DB.Model(&Channel{}).Where("id = ?", id).Update("status", status).Error + err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled) + if err != nil { + common.SysError("failed to update ability status: " + err.Error()) + } + err = DB.Model(&Channel{}).Where("id = ?", id).Update("status", status).Error if err != nil { common.SysError("failed to update channel status: " + err.Error()) } diff --git a/model/main.go b/model/main.go index 3f6fafbf..8d55cee6 100644 --- a/model/main.go +++ b/model/main.go @@ -75,6 +75,10 @@ func InitDB() (err error) { if err != nil { return err } + err = db.AutoMigrate(&Ability{}) + if err != nil { + return err + } err = createRootAccountIfNeed() return err } else { diff --git a/model/redemption.go b/model/redemption.go index b731acf7..c3444f33 100644 --- a/model/redemption.go +++ b/model/redemption.go @@ -2,7 +2,6 @@ package model import ( "errors" - _ "gorm.io/driver/sqlite" "one-api/common" ) diff --git a/model/token.go b/model/token.go index 4adf42e5..8ce252b2 100644 --- a/model/token.go +++ b/model/token.go @@ -3,7 +3,6 @@ package model import ( "errors" "fmt" - _ "gorm.io/driver/sqlite" "gorm.io/gorm" "one-api/common" ) diff --git a/model/user.go b/model/user.go index 2ca0d6a4..23a97896 100644 --- a/model/user.go +++ b/model/user.go @@ -22,6 +22,7 @@ type User struct { VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database! AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management Quota int `json:"quota" gorm:"type:int;default:0"` + Group string `json:"group" gorm:"type:varchar(32);default:'default'"` } func GetMaxUserId() int { @@ -229,6 +230,11 @@ func GetUserEmail(id int) (email string, err error) { return email, err } +func GetUserGroup(id int) (group string, err error) { + err = DB.Model(&User{}).Where("id = ?", id).Select("`group`").Find(&group).Error + return group, err +} + func IncreaseUserQuota(id int, quota int) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") diff --git a/router/api-router.go b/router/api-router.go index 9ca2226a..abd4d23b 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -63,6 +63,7 @@ func SetApiRouter(router *gin.Engine) { { channelRoute.GET("/", controller.GetAllChannels) channelRoute.GET("/search", controller.SearchChannels) + channelRoute.GET("/models", controller.ListModels) channelRoute.GET("/:id", controller.GetChannel) channelRoute.GET("/test", controller.TestAllChannels) channelRoute.GET("/test/:id", controller.TestChannel) diff --git a/router/relay-router.go b/router/relay-router.go index 46c37a89..6d5b74a9 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -8,11 +8,15 @@ import ( func SetRelayRouter(router *gin.Engine) { // https://platform.openai.com/docs/api-reference/introduction + modelsRouter := router.Group("/v1/models") + modelsRouter.Use(middleware.TokenAuth()) + { + modelsRouter.GET("/", controller.ListModels) + modelsRouter.GET("/:model", controller.RetrieveModel) + } relayV1Router := router.Group("/v1") relayV1Router.Use(middleware.TokenAuth(), middleware.Distribute()) { - relayV1Router.GET("/models", controller.ListModels) - relayV1Router.GET("/models/:model", controller.RetrieveModel) relayV1Router.POST("/completions", controller.RelayNotImplemented) relayV1Router.POST("/chat/completions", controller.Relay) relayV1Router.POST("/edits", controller.RelayNotImplemented) diff --git a/web/src/components/ChannelsTable.js b/web/src/components/ChannelsTable.js index a0a0f5dd..be0bba16 100644 --- a/web/src/components/ChannelsTable.js +++ b/web/src/components/ChannelsTable.js @@ -4,6 +4,7 @@ import { Link } from 'react-router-dom'; import { API, showError, showInfo, showSuccess, timestamp2string } from '../helpers'; import { CHANNEL_OPTIONS, ITEMS_PER_PAGE } from '../constants'; +import { renderGroup } from '../helpers/render'; function renderTimestamp(timestamp) { return ( @@ -264,6 +265,14 @@ const ChannelsTable = () => { > 名称 + { + sortChannel('group'); + }} + > + 分组 + { @@ -312,6 +321,7 @@ const ChannelsTable = () => { {channel.id} {channel.name ? channel.name : '无'} + {renderGroup(channel.group)} {renderType(channel.type)} {renderStatus(channel.status)} @@ -398,7 +408,7 @@ const ChannelsTable = () => { - + diff --git a/web/src/components/UsersTable.js b/web/src/components/UsersTable.js index 1fdd8923..9906bca7 100644 --- a/web/src/components/UsersTable.js +++ b/web/src/components/UsersTable.js @@ -4,7 +4,7 @@ import { Link } from 'react-router-dom'; import { API, showError, showSuccess } from '../helpers'; import { ITEMS_PER_PAGE } from '../constants'; -import { renderText } from '../helpers/render'; +import { renderGroup, renderText } from '../helpers/render'; function renderRole(role) { switch (role) { @@ -175,6 +175,14 @@ const UsersTable = () => { > 用户名 + { + sortUser('group'); + }} + > + 分组 + { @@ -231,6 +239,7 @@ const UsersTable = () => { hoverable /> + {renderGroup(user.group)} {user.email ? renderText(user.email, 30) : '无'} {user.quota} {renderRole(user.role)} @@ -306,7 +315,7 @@ const UsersTable = () => { - + diff --git a/web/src/helpers/render.js b/web/src/helpers/render.js index 20bfedd5..1817feb1 100644 --- a/web/src/helpers/render.js +++ b/web/src/helpers/render.js @@ -1,6 +1,15 @@ +import { Label } from 'semantic-ui-react'; + export function renderText(text, limit) { if (text.length > limit) { return text.slice(0, limit - 3) + '...'; } return text; +} + +export function renderGroup(group) { + if (group === "") { + return + } + return } \ No newline at end of file diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index 5cf6e6a1..05607f98 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -14,10 +14,12 @@ const EditChannel = () => { type: 1, key: '', base_url: '', - other: '' + other: '', + models: [], }; const [batch, setBatch] = useState(false); const [inputs, setInputs] = useState(originInputs); + const [modelOptions, setModelOptions] = useState([]); const handleInputChange = (e, { name, value }) => { console.log(name, value); setInputs((inputs) => ({ ...inputs, [name]: value })); @@ -27,17 +29,36 @@ const EditChannel = () => { let res = await API.get(`/api/channel/${channelId}`); const { success, message, data } = res.data; if (success) { - data.password = ''; + if (data.models === "") { + data.models = [] + } else { + data.models = data.models.split(",") + } setInputs(data); } else { showError(message); } setLoading(false); }; + + const fetchModels = async () => { + try { + let res = await API.get(`/api/channel/models`); + setModelOptions(res.data.data.map((model) => ({ + key: model.id, + text: model.id, + value: model.id, + }))); + } catch (error) { + console.error('Error fetching models:', error); + } + }; + useEffect(() => { if (isEdit) { loadChannel().then(); } + fetchModels().then(); }, []); const submit = async () => { @@ -50,6 +71,7 @@ const EditChannel = () => { localInputs.other = '2023-03-15-preview'; } let res; + localInputs.models = localInputs.models.join(",") if (isEdit) { res = await API.put(`/api/channel/`, { ...localInputs, id: parseInt(channelId) }); } else { @@ -137,6 +159,19 @@ const EditChannel = () => { autoComplete='new-password' /> + + + { batch ?