From 7228409c15f8533506dafa5035cfb832949f9074 Mon Sep 17 00:00:00 2001 From: quzard <1191890118@qq.com> Date: Fri, 2 Jun 2023 20:27:00 +0800 Subject: [PATCH] =?UTF-8?q?=E7=AE=80=E5=8D=95=E7=9A=84=E6=B8=A0=E9=81=93?= =?UTF-8?q?=E7=BB=91=E5=AE=9A=E6=A8=A1=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- middleware/distributor.go | 43 +++++++++++++++++++++++++++- model/channel.go | 7 +++-- router/api-router.go | 1 + web/src/components/ChannelsTable.js | 16 +++++++++++ web/src/pages/Channel/EditChannel.js | 42 +++++++++++++++++++++++++-- 5 files changed, 103 insertions(+), 6 deletions(-) diff --git a/middleware/distributor.go b/middleware/distributor.go index 357849e7..04382a24 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -1,10 +1,14 @@ package middleware import ( + "bytes" + "encoding/json" "fmt" "github.com/gin-gonic/gin" + "io" "net/http" "one-api/common" + "one-api/controller" "one-api/model" "strconv" ) @@ -49,7 +53,44 @@ func Distribute() func(c *gin.Context) { } else { // Select a channel for the user var err error - channel, err = model.GetRandomChannel() + var textRequest controller.GeneralOpenAIRequest + requestBody, err := io.ReadAll(c.Request.Body) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": gin.H{ + "message": "read_request_body_failed", + "type": "one_api_error", + }, + }) + c.Abort() + return + } + err = c.Request.Body.Close() + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": gin.H{ + "message": "close_request_body_failed", + "type": "one_api_error", + }, + }) + c.Abort() + return + } + err = json.Unmarshal(requestBody, &textRequest) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": gin.H{ + "message": "unmarshal_request_body_failed", + "type": "one_api_error", + }, + }) + c.Abort() + return + } + // Reset request body + c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) + model_ := textRequest.Model + channel, err = model.GetRandomChannel(model_) if err != nil { c.JSON(200, gin.H{ "error": gin.H{ diff --git a/model/channel.go b/model/channel.go index 35d65827..6012d94b 100644 --- a/model/channel.go +++ b/model/channel.go @@ -19,6 +19,7 @@ type Channel struct { Other string `json:"other"` Balance float64 `json:"balance"` // in USD BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"` + Models string `json:"models" ` } func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) { @@ -48,14 +49,14 @@ func GetChannelById(id int, selectAll bool) (*Channel, error) { return &channel, err } -func GetRandomChannel() (*Channel, error) { +func GetRandomChannel(model string) (*Channel, error) { // TODO: consider weight channel := Channel{} var err error = nil if common.UsingSQLite { - err = DB.Where("status = ?", common.ChannelStatusEnabled).Order("RANDOM()").Limit(1).First(&channel).Error + err = DB.Where("status = ? and (models = ? or models like ? or models like ? or models like ?)", common.ChannelStatusEnabled, model, model+",%", "%,"+model, "%,"+model+",%").Order("RANDOM()").Limit(1).First(&channel).Error } else { - err = DB.Where("status = ?", common.ChannelStatusEnabled).Order("RAND()").Limit(1).First(&channel).Error + err = DB.Where("status = ? and (models = ? or models like ? or models like ? or models like ?)", common.ChannelStatusEnabled, model, model+",%", "%,"+model, "%,"+model+",%").Order("RAND()").Limit(1).First(&channel).Error } return &channel, err } diff --git a/router/api-router.go b/router/api-router.go index 9ca2226a..889e9a62 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -61,6 +61,7 @@ func SetApiRouter(router *gin.Engine) { channelRoute := apiRouter.Group("/channel") channelRoute.Use(middleware.AdminAuth()) { + channelRoute.GET("/models", controller.ListModels) channelRoute.GET("/", controller.GetAllChannels) channelRoute.GET("/search", controller.SearchChannels) channelRoute.GET("/:id", controller.GetChannel) diff --git a/web/src/components/ChannelsTable.js b/web/src/components/ChannelsTable.js index a0a0f5dd..82ee1ada 100644 --- a/web/src/components/ChannelsTable.js +++ b/web/src/components/ChannelsTable.js @@ -231,6 +231,15 @@ const ChannelsTable = () => { setLoading(false); }; + const renderModels = (modelString) => { + let models = modelString.split(","); + return models.map((model) => ( + + )) + } + return ( <>