diff --git a/middleware/distributor.go b/middleware/distributor.go index 357849e7..04382a24 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -1,10 +1,14 @@ package middleware import ( + "bytes" + "encoding/json" "fmt" "github.com/gin-gonic/gin" + "io" "net/http" "one-api/common" + "one-api/controller" "one-api/model" "strconv" ) @@ -49,7 +53,44 @@ func Distribute() func(c *gin.Context) { } else { // Select a channel for the user var err error - channel, err = model.GetRandomChannel() + var textRequest controller.GeneralOpenAIRequest + requestBody, err := io.ReadAll(c.Request.Body) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": gin.H{ + "message": "read_request_body_failed", + "type": "one_api_error", + }, + }) + c.Abort() + return + } + err = c.Request.Body.Close() + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": gin.H{ + "message": "close_request_body_failed", + "type": "one_api_error", + }, + }) + c.Abort() + return + } + err = json.Unmarshal(requestBody, &textRequest) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": gin.H{ + "message": "unmarshal_request_body_failed", + "type": "one_api_error", + }, + }) + c.Abort() + return + } + // Reset request body + c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) + model_ := textRequest.Model + channel, err = model.GetRandomChannel(model_) if err != nil { c.JSON(200, gin.H{ "error": gin.H{ diff --git a/model/channel.go b/model/channel.go index 35d65827..6012d94b 100644 --- a/model/channel.go +++ b/model/channel.go @@ -19,6 +19,7 @@ type Channel struct { Other string `json:"other"` Balance float64 `json:"balance"` // in USD BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"` + Models string `json:"models" ` } func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) { @@ -48,14 +49,14 @@ func GetChannelById(id int, selectAll bool) (*Channel, error) { return &channel, err } -func GetRandomChannel() (*Channel, error) { +func GetRandomChannel(model string) (*Channel, error) { // TODO: consider weight channel := Channel{} var err error = nil if common.UsingSQLite { - err = DB.Where("status = ?", common.ChannelStatusEnabled).Order("RANDOM()").Limit(1).First(&channel).Error + err = DB.Where("status = ? and (models = ? or models like ? or models like ? or models like ?)", common.ChannelStatusEnabled, model, model+",%", "%,"+model, "%,"+model+",%").Order("RANDOM()").Limit(1).First(&channel).Error } else { - err = DB.Where("status = ?", common.ChannelStatusEnabled).Order("RAND()").Limit(1).First(&channel).Error + err = DB.Where("status = ? and (models = ? or models like ? or models like ? or models like ?)", common.ChannelStatusEnabled, model, model+",%", "%,"+model, "%,"+model+",%").Order("RAND()").Limit(1).First(&channel).Error } return &channel, err } diff --git a/router/api-router.go b/router/api-router.go index 9ca2226a..889e9a62 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -61,6 +61,7 @@ func SetApiRouter(router *gin.Engine) { channelRoute := apiRouter.Group("/channel") channelRoute.Use(middleware.AdminAuth()) { + channelRoute.GET("/models", controller.ListModels) channelRoute.GET("/", controller.GetAllChannels) channelRoute.GET("/search", controller.SearchChannels) channelRoute.GET("/:id", controller.GetChannel) diff --git a/web/src/components/ChannelsTable.js b/web/src/components/ChannelsTable.js index a0a0f5dd..82ee1ada 100644 --- a/web/src/components/ChannelsTable.js +++ b/web/src/components/ChannelsTable.js @@ -231,6 +231,15 @@ const ChannelsTable = () => { setLoading(false); }; + const renderModels = (modelString) => { + let models = modelString.split(","); + return models.map((model) => ( + + )) + } + return ( <>