简单的渠道绑定模型
This commit is contained in:
parent
7e80e2da3a
commit
7228409c15
@ -1,10 +1,14 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/controller"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"strconv"
|
"strconv"
|
||||||
)
|
)
|
||||||
@ -49,7 +53,44 @@ func Distribute() func(c *gin.Context) {
|
|||||||
} else {
|
} else {
|
||||||
// Select a channel for the user
|
// Select a channel for the user
|
||||||
var err error
|
var err error
|
||||||
channel, err = model.GetRandomChannel()
|
var textRequest controller.GeneralOpenAIRequest
|
||||||
|
requestBody, err := io.ReadAll(c.Request.Body)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{
|
||||||
|
"error": gin.H{
|
||||||
|
"message": "read_request_body_failed",
|
||||||
|
"type": "one_api_error",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = c.Request.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{
|
||||||
|
"error": gin.H{
|
||||||
|
"message": "close_request_body_failed",
|
||||||
|
"type": "one_api_error",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = json.Unmarshal(requestBody, &textRequest)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{
|
||||||
|
"error": gin.H{
|
||||||
|
"message": "unmarshal_request_body_failed",
|
||||||
|
"type": "one_api_error",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Reset request body
|
||||||
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||||
|
model_ := textRequest.Model
|
||||||
|
channel, err = model.GetRandomChannel(model_)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"error": gin.H{
|
"error": gin.H{
|
||||||
|
@ -19,6 +19,7 @@ type Channel struct {
|
|||||||
Other string `json:"other"`
|
Other string `json:"other"`
|
||||||
Balance float64 `json:"balance"` // in USD
|
Balance float64 `json:"balance"` // in USD
|
||||||
BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"`
|
BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"`
|
||||||
|
Models string `json:"models" `
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
|
func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
|
||||||
@ -48,14 +49,14 @@ func GetChannelById(id int, selectAll bool) (*Channel, error) {
|
|||||||
return &channel, err
|
return &channel, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetRandomChannel() (*Channel, error) {
|
func GetRandomChannel(model string) (*Channel, error) {
|
||||||
// TODO: consider weight
|
// TODO: consider weight
|
||||||
channel := Channel{}
|
channel := Channel{}
|
||||||
var err error = nil
|
var err error = nil
|
||||||
if common.UsingSQLite {
|
if common.UsingSQLite {
|
||||||
err = DB.Where("status = ?", common.ChannelStatusEnabled).Order("RANDOM()").Limit(1).First(&channel).Error
|
err = DB.Where("status = ? and (models = ? or models like ? or models like ? or models like ?)", common.ChannelStatusEnabled, model, model+",%", "%,"+model, "%,"+model+",%").Order("RANDOM()").Limit(1).First(&channel).Error
|
||||||
} else {
|
} else {
|
||||||
err = DB.Where("status = ?", common.ChannelStatusEnabled).Order("RAND()").Limit(1).First(&channel).Error
|
err = DB.Where("status = ? and (models = ? or models like ? or models like ? or models like ?)", common.ChannelStatusEnabled, model, model+",%", "%,"+model, "%,"+model+",%").Order("RAND()").Limit(1).First(&channel).Error
|
||||||
}
|
}
|
||||||
return &channel, err
|
return &channel, err
|
||||||
}
|
}
|
||||||
|
@ -61,6 +61,7 @@ func SetApiRouter(router *gin.Engine) {
|
|||||||
channelRoute := apiRouter.Group("/channel")
|
channelRoute := apiRouter.Group("/channel")
|
||||||
channelRoute.Use(middleware.AdminAuth())
|
channelRoute.Use(middleware.AdminAuth())
|
||||||
{
|
{
|
||||||
|
channelRoute.GET("/models", controller.ListModels)
|
||||||
channelRoute.GET("/", controller.GetAllChannels)
|
channelRoute.GET("/", controller.GetAllChannels)
|
||||||
channelRoute.GET("/search", controller.SearchChannels)
|
channelRoute.GET("/search", controller.SearchChannels)
|
||||||
channelRoute.GET("/:id", controller.GetChannel)
|
channelRoute.GET("/:id", controller.GetChannel)
|
||||||
|
@ -231,6 +231,15 @@ const ChannelsTable = () => {
|
|||||||
setLoading(false);
|
setLoading(false);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const renderModels = (modelString) => {
|
||||||
|
let models = modelString.split(",");
|
||||||
|
return models.map((model) => (
|
||||||
|
<Label>
|
||||||
|
{model}
|
||||||
|
</Label>
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
<Form onSubmit={searchChannels}>
|
<Form onSubmit={searchChannels}>
|
||||||
@ -288,6 +297,9 @@ const ChannelsTable = () => {
|
|||||||
>
|
>
|
||||||
响应时间
|
响应时间
|
||||||
</Table.HeaderCell>
|
</Table.HeaderCell>
|
||||||
|
<Table.HeaderCell>
|
||||||
|
支持的模型
|
||||||
|
</Table.HeaderCell>
|
||||||
<Table.HeaderCell
|
<Table.HeaderCell
|
||||||
style={{ cursor: 'pointer' }}
|
style={{ cursor: 'pointer' }}
|
||||||
onClick={() => {
|
onClick={() => {
|
||||||
@ -322,6 +334,10 @@ const ChannelsTable = () => {
|
|||||||
basic
|
basic
|
||||||
/>
|
/>
|
||||||
</Table.Cell>
|
</Table.Cell>
|
||||||
|
<Table.Cell>
|
||||||
|
{channel.models.length > 0 ? renderModels(channel.models) :<Label>无</Label>
|
||||||
|
}
|
||||||
|
</Table.Cell>
|
||||||
<Table.Cell>
|
<Table.Cell>
|
||||||
<Popup
|
<Popup
|
||||||
content={channel.balance_updated_time ? renderTimestamp(channel.balance_updated_time) : '未更新'}
|
content={channel.balance_updated_time ? renderTimestamp(channel.balance_updated_time) : '未更新'}
|
||||||
|
@ -14,7 +14,8 @@ const EditChannel = () => {
|
|||||||
type: 1,
|
type: 1,
|
||||||
key: '',
|
key: '',
|
||||||
base_url: '',
|
base_url: '',
|
||||||
other: ''
|
other: '',
|
||||||
|
models: [],
|
||||||
};
|
};
|
||||||
const [batch, setBatch] = useState(false);
|
const [batch, setBatch] = useState(false);
|
||||||
const [inputs, setInputs] = useState(originInputs);
|
const [inputs, setInputs] = useState(originInputs);
|
||||||
@ -27,7 +28,11 @@ const EditChannel = () => {
|
|||||||
let res = await API.get(`/api/channel/${channelId}`);
|
let res = await API.get(`/api/channel/${channelId}`);
|
||||||
const { success, message, data } = res.data;
|
const { success, message, data } = res.data;
|
||||||
if (success) {
|
if (success) {
|
||||||
data.password = '';
|
if (data.models === "") {
|
||||||
|
data.models = []
|
||||||
|
} else {
|
||||||
|
data.models = data.models.split(",")
|
||||||
|
}
|
||||||
setInputs(data);
|
setInputs(data);
|
||||||
} else {
|
} else {
|
||||||
showError(message);
|
showError(message);
|
||||||
@ -50,6 +55,7 @@ const EditChannel = () => {
|
|||||||
localInputs.other = '2023-03-15-preview';
|
localInputs.other = '2023-03-15-preview';
|
||||||
}
|
}
|
||||||
let res;
|
let res;
|
||||||
|
localInputs.models = localInputs.models.join(",")
|
||||||
if (isEdit) {
|
if (isEdit) {
|
||||||
res = await API.put(`/api/channel/`, { ...localInputs, id: parseInt(channelId) });
|
res = await API.put(`/api/channel/`, { ...localInputs, id: parseInt(channelId) });
|
||||||
} else {
|
} else {
|
||||||
@ -68,6 +74,25 @@ const EditChannel = () => {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const [modelOptions, setModelOptions] = useState([]);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
const getModels = async () => {
|
||||||
|
try {
|
||||||
|
let res = await API.get(`/api/channel/models`);
|
||||||
|
setModelOptions(res.data.data.map((model) => ({
|
||||||
|
key: model.id,
|
||||||
|
text: model.id,
|
||||||
|
value: model.id,
|
||||||
|
})));
|
||||||
|
} catch (error) {
|
||||||
|
console.error('Error fetching models:', error);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
getModels();
|
||||||
|
}, []);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
<Segment loading={loading}>
|
<Segment loading={loading}>
|
||||||
@ -137,6 +162,19 @@ const EditChannel = () => {
|
|||||||
autoComplete='new-password'
|
autoComplete='new-password'
|
||||||
/>
|
/>
|
||||||
</Form.Field>
|
</Form.Field>
|
||||||
|
<Form.Field>
|
||||||
|
<Form.Dropdown
|
||||||
|
label='支持的模型'
|
||||||
|
name='models'
|
||||||
|
fluid
|
||||||
|
multiple
|
||||||
|
selection
|
||||||
|
onChange={handleInputChange}
|
||||||
|
value={inputs.models}
|
||||||
|
autoComplete='new-password'
|
||||||
|
options={modelOptions}
|
||||||
|
/>
|
||||||
|
</Form.Field>
|
||||||
{
|
{
|
||||||
batch ? <Form.Field>
|
batch ? <Form.Field>
|
||||||
<Form.TextArea
|
<Form.TextArea
|
||||||
|
Loading…
Reference in New Issue
Block a user