简单的渠道绑定模型

This commit is contained in:
quzard 2023-06-02 20:27:00 +08:00
parent 7e80e2da3a
commit 7228409c15
5 changed files with 103 additions and 6 deletions

View File

@ -1,10 +1,14 @@
package middleware
import (
"bytes"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/controller"
"one-api/model"
"strconv"
)
@ -49,7 +53,44 @@ func Distribute() func(c *gin.Context) {
} else {
// Select a channel for the user
var err error
channel, err = model.GetRandomChannel()
var textRequest controller.GeneralOpenAIRequest
requestBody, err := io.ReadAll(c.Request.Body)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": gin.H{
"message": "read_request_body_failed",
"type": "one_api_error",
},
})
c.Abort()
return
}
err = c.Request.Body.Close()
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": gin.H{
"message": "close_request_body_failed",
"type": "one_api_error",
},
})
c.Abort()
return
}
err = json.Unmarshal(requestBody, &textRequest)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"error": gin.H{
"message": "unmarshal_request_body_failed",
"type": "one_api_error",
},
})
c.Abort()
return
}
// Reset request body
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
model_ := textRequest.Model
channel, err = model.GetRandomChannel(model_)
if err != nil {
c.JSON(200, gin.H{
"error": gin.H{

View File

@ -19,6 +19,7 @@ type Channel struct {
Other string `json:"other"`
Balance float64 `json:"balance"` // in USD
BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"`
Models string `json:"models" `
}
func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
@ -48,14 +49,14 @@ func GetChannelById(id int, selectAll bool) (*Channel, error) {
return &channel, err
}
func GetRandomChannel() (*Channel, error) {
func GetRandomChannel(model string) (*Channel, error) {
// TODO: consider weight
channel := Channel{}
var err error = nil
if common.UsingSQLite {
err = DB.Where("status = ?", common.ChannelStatusEnabled).Order("RANDOM()").Limit(1).First(&channel).Error
err = DB.Where("status = ? and (models = ? or models like ? or models like ? or models like ?)", common.ChannelStatusEnabled, model, model+",%", "%,"+model, "%,"+model+",%").Order("RANDOM()").Limit(1).First(&channel).Error
} else {
err = DB.Where("status = ?", common.ChannelStatusEnabled).Order("RAND()").Limit(1).First(&channel).Error
err = DB.Where("status = ? and (models = ? or models like ? or models like ? or models like ?)", common.ChannelStatusEnabled, model, model+",%", "%,"+model, "%,"+model+",%").Order("RAND()").Limit(1).First(&channel).Error
}
return &channel, err
}

View File

@ -61,6 +61,7 @@ func SetApiRouter(router *gin.Engine) {
channelRoute := apiRouter.Group("/channel")
channelRoute.Use(middleware.AdminAuth())
{
channelRoute.GET("/models", controller.ListModels)
channelRoute.GET("/", controller.GetAllChannels)
channelRoute.GET("/search", controller.SearchChannels)
channelRoute.GET("/:id", controller.GetChannel)

View File

@ -231,6 +231,15 @@ const ChannelsTable = () => {
setLoading(false);
};
const renderModels = (modelString) => {
let models = modelString.split(",");
return models.map((model) => (
<Label>
{model}
</Label>
))
}
return (
<>
<Form onSubmit={searchChannels}>
@ -288,6 +297,9 @@ const ChannelsTable = () => {
>
响应时间
</Table.HeaderCell>
<Table.HeaderCell>
支持的模型
</Table.HeaderCell>
<Table.HeaderCell
style={{ cursor: 'pointer' }}
onClick={() => {
@ -322,6 +334,10 @@ const ChannelsTable = () => {
basic
/>
</Table.Cell>
<Table.Cell>
{channel.models.length > 0 ? renderModels(channel.models) :<Label></Label>
}
</Table.Cell>
<Table.Cell>
<Popup
content={channel.balance_updated_time ? renderTimestamp(channel.balance_updated_time) : '未更新'}

View File

@ -14,7 +14,8 @@ const EditChannel = () => {
type: 1,
key: '',
base_url: '',
other: ''
other: '',
models: [],
};
const [batch, setBatch] = useState(false);
const [inputs, setInputs] = useState(originInputs);
@ -27,7 +28,11 @@ const EditChannel = () => {
let res = await API.get(`/api/channel/${channelId}`);
const { success, message, data } = res.data;
if (success) {
data.password = '';
if (data.models === "") {
data.models = []
} else {
data.models = data.models.split(",")
}
setInputs(data);
} else {
showError(message);
@ -50,6 +55,7 @@ const EditChannel = () => {
localInputs.other = '2023-03-15-preview';
}
let res;
localInputs.models = localInputs.models.join(",")
if (isEdit) {
res = await API.put(`/api/channel/`, { ...localInputs, id: parseInt(channelId) });
} else {
@ -68,6 +74,25 @@ const EditChannel = () => {
}
};
const [modelOptions, setModelOptions] = useState([]);
useEffect(() => {
const getModels = async () => {
try {
let res = await API.get(`/api/channel/models`);
setModelOptions(res.data.data.map((model) => ({
key: model.id,
text: model.id,
value: model.id,
})));
} catch (error) {
console.error('Error fetching models:', error);
}
};
getModels();
}, []);
return (
<>
<Segment loading={loading}>
@ -137,6 +162,19 @@ const EditChannel = () => {
autoComplete='new-password'
/>
</Form.Field>
<Form.Field>
<Form.Dropdown
label='支持的模型'
name='models'
fluid
multiple
selection
onChange={handleInputChange}
value={inputs.models}
autoComplete='new-password'
options={modelOptions}
/>
</Form.Field>
{
batch ? <Form.Field>
<Form.TextArea