简单的渠道绑定模型

This commit is contained in:
quzard 2023-06-02 20:27:00 +08:00
parent 7e80e2da3a
commit 7228409c15
5 changed files with 103 additions and 6 deletions

View File

@ -1,10 +1,14 @@
package middleware package middleware
import ( import (
"bytes"
"encoding/json"
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/controller"
"one-api/model" "one-api/model"
"strconv" "strconv"
) )
@ -49,7 +53,44 @@ func Distribute() func(c *gin.Context) {
} else { } else {
// Select a channel for the user // Select a channel for the user
var err error var err error
channel, err = model.GetRandomChannel() var textRequest controller.GeneralOpenAIRequest
requestBody, err := io.ReadAll(c.Request.Body)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": gin.H{
"message": "read_request_body_failed",
"type": "one_api_error",
},
})
c.Abort()
return
}
err = c.Request.Body.Close()
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": gin.H{
"message": "close_request_body_failed",
"type": "one_api_error",
},
})
c.Abort()
return
}
err = json.Unmarshal(requestBody, &textRequest)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": gin.H{
"message": "unmarshal_request_body_failed",
"type": "one_api_error",
},
})
c.Abort()
return
}
// Reset request body
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
model_ := textRequest.Model
channel, err = model.GetRandomChannel(model_)
if err != nil { if err != nil {
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"error": gin.H{ "error": gin.H{

View File

@ -19,6 +19,7 @@ type Channel struct {
Other string `json:"other"` Other string `json:"other"`
Balance float64 `json:"balance"` // in USD Balance float64 `json:"balance"` // in USD
BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"` BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"`
Models string `json:"models" `
} }
func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) { func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
@ -48,14 +49,14 @@ func GetChannelById(id int, selectAll bool) (*Channel, error) {
return &channel, err return &channel, err
} }
func GetRandomChannel() (*Channel, error) { func GetRandomChannel(model string) (*Channel, error) {
// TODO: consider weight // TODO: consider weight
channel := Channel{} channel := Channel{}
var err error = nil var err error = nil
if common.UsingSQLite { if common.UsingSQLite {
err = DB.Where("status = ?", common.ChannelStatusEnabled).Order("RANDOM()").Limit(1).First(&channel).Error err = DB.Where("status = ? and (models = ? or models like ? or models like ? or models like ?)", common.ChannelStatusEnabled, model, model+",%", "%,"+model, "%,"+model+",%").Order("RANDOM()").Limit(1).First(&channel).Error
} else { } else {
err = DB.Where("status = ?", common.ChannelStatusEnabled).Order("RAND()").Limit(1).First(&channel).Error err = DB.Where("status = ? and (models = ? or models like ? or models like ? or models like ?)", common.ChannelStatusEnabled, model, model+",%", "%,"+model, "%,"+model+",%").Order("RAND()").Limit(1).First(&channel).Error
} }
return &channel, err return &channel, err
} }

View File

@ -61,6 +61,7 @@ func SetApiRouter(router *gin.Engine) {
channelRoute := apiRouter.Group("/channel") channelRoute := apiRouter.Group("/channel")
channelRoute.Use(middleware.AdminAuth()) channelRoute.Use(middleware.AdminAuth())
{ {
channelRoute.GET("/models", controller.ListModels)
channelRoute.GET("/", controller.GetAllChannels) channelRoute.GET("/", controller.GetAllChannels)
channelRoute.GET("/search", controller.SearchChannels) channelRoute.GET("/search", controller.SearchChannels)
channelRoute.GET("/:id", controller.GetChannel) channelRoute.GET("/:id", controller.GetChannel)

View File

@ -231,6 +231,15 @@ const ChannelsTable = () => {
setLoading(false); setLoading(false);
}; };
const renderModels = (modelString) => {
let models = modelString.split(",");
return models.map((model) => (
<Label>
{model}
</Label>
))
}
return ( return (
<> <>
<Form onSubmit={searchChannels}> <Form onSubmit={searchChannels}>
@ -288,6 +297,9 @@ const ChannelsTable = () => {
> >
响应时间 响应时间
</Table.HeaderCell> </Table.HeaderCell>
<Table.HeaderCell>
支持的模型
</Table.HeaderCell>
<Table.HeaderCell <Table.HeaderCell
style={{ cursor: 'pointer' }} style={{ cursor: 'pointer' }}
onClick={() => { onClick={() => {
@ -322,6 +334,10 @@ const ChannelsTable = () => {
basic basic
/> />
</Table.Cell> </Table.Cell>
<Table.Cell>
{channel.models.length > 0 ? renderModels(channel.models) :<Label></Label>
}
</Table.Cell>
<Table.Cell> <Table.Cell>
<Popup <Popup
content={channel.balance_updated_time ? renderTimestamp(channel.balance_updated_time) : '未更新'} content={channel.balance_updated_time ? renderTimestamp(channel.balance_updated_time) : '未更新'}

View File

@ -14,7 +14,8 @@ const EditChannel = () => {
type: 1, type: 1,
key: '', key: '',
base_url: '', base_url: '',
other: '' other: '',
models: [],
}; };
const [batch, setBatch] = useState(false); const [batch, setBatch] = useState(false);
const [inputs, setInputs] = useState(originInputs); const [inputs, setInputs] = useState(originInputs);
@ -27,7 +28,11 @@ const EditChannel = () => {
let res = await API.get(`/api/channel/${channelId}`); let res = await API.get(`/api/channel/${channelId}`);
const { success, message, data } = res.data; const { success, message, data } = res.data;
if (success) { if (success) {
data.password = ''; if (data.models === "") {
data.models = []
} else {
data.models = data.models.split(",")
}
setInputs(data); setInputs(data);
} else { } else {
showError(message); showError(message);
@ -50,6 +55,7 @@ const EditChannel = () => {
localInputs.other = '2023-03-15-preview'; localInputs.other = '2023-03-15-preview';
} }
let res; let res;
localInputs.models = localInputs.models.join(",")
if (isEdit) { if (isEdit) {
res = await API.put(`/api/channel/`, { ...localInputs, id: parseInt(channelId) }); res = await API.put(`/api/channel/`, { ...localInputs, id: parseInt(channelId) });
} else { } else {
@ -68,6 +74,25 @@ const EditChannel = () => {
} }
}; };
const [modelOptions, setModelOptions] = useState([]);
useEffect(() => {
const getModels = async () => {
try {
let res = await API.get(`/api/channel/models`);
setModelOptions(res.data.data.map((model) => ({
key: model.id,
text: model.id,
value: model.id,
})));
} catch (error) {
console.error('Error fetching models:', error);
}
};
getModels();
}, []);
return ( return (
<> <>
<Segment loading={loading}> <Segment loading={loading}>
@ -137,6 +162,19 @@ const EditChannel = () => {
autoComplete='new-password' autoComplete='new-password'
/> />
</Form.Field> </Form.Field>
<Form.Field>
<Form.Dropdown
label='支持的模型'
name='models'
fluid
multiple
selection
onChange={handleInputChange}
value={inputs.models}
autoComplete='new-password'
options={modelOptions}
/>
</Form.Field>
{ {
batch ? <Form.Field> batch ? <Form.Field>
<Form.TextArea <Form.TextArea