diff --git a/middleware/distributor.go b/middleware/distributor.go index 357849e7..04382a24 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -1,10 +1,14 @@ package middleware import ( + "bytes" + "encoding/json" "fmt" "github.com/gin-gonic/gin" + "io" "net/http" "one-api/common" + "one-api/controller" "one-api/model" "strconv" ) @@ -49,7 +53,44 @@ func Distribute() func(c *gin.Context) { } else { // Select a channel for the user var err error - channel, err = model.GetRandomChannel() + var textRequest controller.GeneralOpenAIRequest + requestBody, err := io.ReadAll(c.Request.Body) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": gin.H{ + "message": "read_request_body_failed", + "type": "one_api_error", + }, + }) + c.Abort() + return + } + err = c.Request.Body.Close() + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": gin.H{ + "message": "close_request_body_failed", + "type": "one_api_error", + }, + }) + c.Abort() + return + } + err = json.Unmarshal(requestBody, &textRequest) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": gin.H{ + "message": "unmarshal_request_body_failed", + "type": "one_api_error", + }, + }) + c.Abort() + return + } + // Reset request body + c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) + model_ := textRequest.Model + channel, err = model.GetRandomChannel(model_) if err != nil { c.JSON(200, gin.H{ "error": gin.H{ diff --git a/model/channel.go b/model/channel.go index 35d65827..6012d94b 100644 --- a/model/channel.go +++ b/model/channel.go @@ -19,6 +19,7 @@ type Channel struct { Other string `json:"other"` Balance float64 `json:"balance"` // in USD BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"` + Models string `json:"models" ` } func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) { @@ -48,14 +49,14 @@ func GetChannelById(id int, selectAll bool) (*Channel, error) { return &channel, err } -func GetRandomChannel() (*Channel, error) { +func GetRandomChannel(model string) (*Channel, error) { // TODO: consider weight channel := Channel{} var err error = nil if common.UsingSQLite { - err = DB.Where("status = ?", common.ChannelStatusEnabled).Order("RANDOM()").Limit(1).First(&channel).Error + err = DB.Where("status = ? and (models = ? or models like ? or models like ? or models like ?)", common.ChannelStatusEnabled, model, model+",%", "%,"+model, "%,"+model+",%").Order("RANDOM()").Limit(1).First(&channel).Error } else { - err = DB.Where("status = ?", common.ChannelStatusEnabled).Order("RAND()").Limit(1).First(&channel).Error + err = DB.Where("status = ? and (models = ? or models like ? or models like ? or models like ?)", common.ChannelStatusEnabled, model, model+",%", "%,"+model, "%,"+model+",%").Order("RAND()").Limit(1).First(&channel).Error } return &channel, err } diff --git a/router/api-router.go b/router/api-router.go index 9ca2226a..889e9a62 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -61,6 +61,7 @@ func SetApiRouter(router *gin.Engine) { channelRoute := apiRouter.Group("/channel") channelRoute.Use(middleware.AdminAuth()) { + channelRoute.GET("/models", controller.ListModels) channelRoute.GET("/", controller.GetAllChannels) channelRoute.GET("/search", controller.SearchChannels) channelRoute.GET("/:id", controller.GetChannel) diff --git a/web/src/components/ChannelsTable.js b/web/src/components/ChannelsTable.js index a0a0f5dd..82ee1ada 100644 --- a/web/src/components/ChannelsTable.js +++ b/web/src/components/ChannelsTable.js @@ -231,6 +231,15 @@ const ChannelsTable = () => { setLoading(false); }; + const renderModels = (modelString) => { + let models = modelString.split(","); + return models.map((model) => ( + + )) + } + return ( <>
@@ -288,6 +297,9 @@ const ChannelsTable = () => { > 响应时间 + + 支持的模型 + { @@ -322,6 +334,10 @@ const ChannelsTable = () => { basic /> + + {channel.models.length > 0 ? renderModels(channel.models) : + } + { type: 1, key: '', base_url: '', - other: '' + other: '', + models: [], }; const [batch, setBatch] = useState(false); const [inputs, setInputs] = useState(originInputs); @@ -27,7 +28,11 @@ const EditChannel = () => { let res = await API.get(`/api/channel/${channelId}`); const { success, message, data } = res.data; if (success) { - data.password = ''; + if (data.models === "") { + data.models = [] + } else { + data.models = data.models.split(",") + } setInputs(data); } else { showError(message); @@ -50,6 +55,7 @@ const EditChannel = () => { localInputs.other = '2023-03-15-preview'; } let res; + localInputs.models = localInputs.models.join(",") if (isEdit) { res = await API.put(`/api/channel/`, { ...localInputs, id: parseInt(channelId) }); } else { @@ -68,6 +74,25 @@ const EditChannel = () => { } }; + const [modelOptions, setModelOptions] = useState([]); + + useEffect(() => { + const getModels = async () => { + try { + let res = await API.get(`/api/channel/models`); + setModelOptions(res.data.data.map((model) => ({ + key: model.id, + text: model.id, + value: model.id, + }))); + } catch (error) { + console.error('Error fetching models:', error); + } + }; + + getModels(); + }, []); + return ( <> @@ -137,6 +162,19 @@ const EditChannel = () => { autoComplete='new-password' /> + + + { batch ?