Merge branch 'channel-stream-mode' into refactor-main
This commit is contained in:
commit
00d3a78bef
@ -1,17 +1,20 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIError) {
|
||||
@ -58,21 +61,64 @@ func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIErr
|
||||
return err, nil
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
var response TextResponse
|
||||
err = json.NewDecoder(resp.Body).Decode(&response)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
}
|
||||
if response.Usage.CompletionTokens == 0 {
|
||||
return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)), &response.Error
|
||||
|
||||
isStream := strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
|
||||
|
||||
if channel.AllowStreaming && isStream {
|
||||
responseText := ""
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
if i := strings.Index(string(data), "\n"); i >= 0 {
|
||||
return i + 1, data[0:i], nil
|
||||
}
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
}
|
||||
return 0, nil, nil
|
||||
})
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
if len(data) < 6 { // ignore blank line or wrong format
|
||||
continue
|
||||
}
|
||||
data = data[6:]
|
||||
if !strings.HasPrefix(data, "[DONE]") {
|
||||
var streamResponse ChatCompletionsStreamResponse
|
||||
err := json.Unmarshal([]byte(data), &streamResponse)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
}
|
||||
for _, choice := range streamResponse.Choices {
|
||||
responseText += choice.Delta.Content
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if responseText == "" {
|
||||
return errors.New("Empty response"), nil
|
||||
}
|
||||
} else {
|
||||
var response TextResponse
|
||||
err = json.NewDecoder(resp.Body).Decode(&response)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
}
|
||||
if response.Usage.CompletionTokens == 0 {
|
||||
return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)), &response.Error
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func buildTestRequest() *ChatRequest {
|
||||
func buildTestRequest(stream bool) *ChatRequest {
|
||||
testRequest := &ChatRequest{
|
||||
Model: "", // this will be set later
|
||||
MaxTokens: 1,
|
||||
Stream: stream,
|
||||
}
|
||||
testMessage := Message{
|
||||
Role: "user",
|
||||
@ -99,7 +145,7 @@ func TestChannel(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
testRequest := buildTestRequest()
|
||||
testRequest := buildTestRequest(channel.AllowStreaming)
|
||||
tik := time.Now()
|
||||
err, _ = testChannel(channel, *testRequest)
|
||||
tok := time.Now()
|
||||
@ -154,7 +200,6 @@ func testAllChannels(notify bool) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
testRequest := buildTestRequest()
|
||||
var disableThreshold = int64(common.ChannelDisableThreshold * 1000)
|
||||
if disableThreshold == 0 {
|
||||
disableThreshold = 10000000 // a impossible value
|
||||
@ -165,6 +210,7 @@ func testAllChannels(notify bool) error {
|
||||
continue
|
||||
}
|
||||
tik := time.Now()
|
||||
testRequest := buildTestRequest(channel.AllowStreaming)
|
||||
err, openaiErr := testChannel(channel, *testRequest)
|
||||
tok := time.Now()
|
||||
milliseconds := tok.Sub(tik).Milliseconds()
|
||||
|
@ -1,12 +1,13 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func GetAllChannels(c *gin.Context) {
|
||||
|
@ -46,6 +46,7 @@ type ChatRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []Message `json:"messages"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
Stream bool `json:"stream"`
|
||||
}
|
||||
|
||||
type TextRequest struct {
|
||||
|
@ -12,7 +12,8 @@ import (
|
||||
)
|
||||
|
||||
type ModelRequest struct {
|
||||
Model string `json:"model"`
|
||||
Model string `json:"model"`
|
||||
Stream bool `json:"stream" default:"true"`
|
||||
}
|
||||
|
||||
func Distribute() func(c *gin.Context) {
|
||||
@ -84,7 +85,7 @@ func Distribute() func(c *gin.Context) {
|
||||
modelRequest.Model = "dall-e"
|
||||
}
|
||||
}
|
||||
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
|
||||
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, modelRequest.Stream)
|
||||
if err != nil {
|
||||
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
|
||||
if channel != nil {
|
||||
|
@ -12,13 +12,22 @@ type Ability struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
}
|
||||
|
||||
func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
|
||||
func GetRandomSatisfiedChannel(group string, model string, stream bool) (*Channel, error) {
|
||||
ability := Ability{}
|
||||
var err error = nil
|
||||
if common.UsingSQLite {
|
||||
err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("RANDOM()").Limit(1).First(&ability).Error
|
||||
|
||||
cmd := "`group` = ? and model = ? and enabled = 1"
|
||||
|
||||
if stream {
|
||||
cmd += " and allow_streaming = 1"
|
||||
} else {
|
||||
err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("RAND()").Limit(1).First(&ability).Error
|
||||
cmd += " and allow_non_streaming = 1"
|
||||
}
|
||||
|
||||
if common.UsingSQLite {
|
||||
err = DB.Where(cmd, group, model).Order("RANDOM()").Limit(1).First(&ability).Error
|
||||
} else {
|
||||
err = DB.Where(cmd, group, model).Order("RAND()").Limit(1).First(&ability).Error
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -160,9 +160,9 @@ func SyncChannelCache(frequency int) {
|
||||
}
|
||||
}
|
||||
|
||||
func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
|
||||
func CacheGetRandomSatisfiedChannel(group string, model string, stream bool) (*Channel, error) {
|
||||
if !common.RedisEnabled {
|
||||
return GetRandomSatisfiedChannel(group, model)
|
||||
return GetRandomSatisfiedChannel(group, model, stream)
|
||||
}
|
||||
channelSyncLock.RLock()
|
||||
defer channelSyncLock.RUnlock()
|
||||
@ -170,6 +170,14 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
|
||||
if len(channels) == 0 {
|
||||
return nil, errors.New("channel not found")
|
||||
}
|
||||
idx := rand.Intn(len(channels))
|
||||
return channels[idx], nil
|
||||
|
||||
var filteredChannels []*Channel
|
||||
for _, channel := range channels {
|
||||
if (stream && channel.AllowStreaming) || (!stream && channel.AllowNonStreaming) {
|
||||
filteredChannels = append(filteredChannels, channel)
|
||||
}
|
||||
}
|
||||
|
||||
idx := rand.Intn(len(filteredChannels))
|
||||
return filteredChannels[idx], nil
|
||||
}
|
||||
|
@ -1,8 +1,10 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"gorm.io/gorm"
|
||||
"encoding/json"
|
||||
"one-api/common"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type Channel struct {
|
||||
@ -23,6 +25,8 @@ type Channel struct {
|
||||
Group string `json:"group" gorm:"type:varchar(32);default:'default'"`
|
||||
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"`
|
||||
ModelMapping string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
|
||||
AllowStreaming bool `json:"allow_streaming"`
|
||||
AllowNonStreaming bool `json:"allow_non_streaming"`
|
||||
}
|
||||
|
||||
func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
|
||||
@ -80,7 +84,19 @@ func BatchInsertChannels(channels []Channel) error {
|
||||
|
||||
func (channel *Channel) Insert() error {
|
||||
var err error
|
||||
err = DB.Create(channel).Error
|
||||
// turn channel into a map
|
||||
channelMap := make(map[string]interface{})
|
||||
|
||||
// Convert channel struct to a map
|
||||
channelBytes, err := json.Marshal(channel)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = json.Unmarshal(channelBytes, &channelMap)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = DB.Create(channelMap).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -90,11 +106,24 @@ func (channel *Channel) Insert() error {
|
||||
|
||||
func (channel *Channel) Update() error {
|
||||
var err error
|
||||
err = DB.Model(channel).Updates(channel).Error
|
||||
// turn channel into a map
|
||||
channelMap := make(map[string]interface{})
|
||||
|
||||
// Convert channel struct to a map
|
||||
channelBytes, err := json.Marshal(channel)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
DB.Model(channel).First(channel, "id = ?", channel.Id)
|
||||
err = json.Unmarshal(channelBytes, &channelMap)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = DB.Model(channel).Updates(channelMap).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
DB.Model(channel).First(channelMap, "id = ?", channel.Id)
|
||||
err = channel.UpdateAbilities()
|
||||
return err
|
||||
}
|
||||
|
@ -22,6 +22,8 @@ const EditChannel = () => {
|
||||
base_url: '',
|
||||
other: '',
|
||||
model_mapping: '',
|
||||
allow_streaming: true,
|
||||
allow_non_streaming: true,
|
||||
models: [],
|
||||
groups: ['default']
|
||||
};
|
||||
@ -94,6 +96,9 @@ const EditChannel = () => {
|
||||
|
||||
useEffect(() => {
|
||||
let localModelOptions = [...originModelOptions];
|
||||
if (!Array.isArray(inputs.models)) {
|
||||
inputs.models = inputs.models.split(',');
|
||||
}
|
||||
inputs.models.forEach((model) => {
|
||||
if (!localModelOptions.find((option) => option.key === model)) {
|
||||
localModelOptions.push({
|
||||
@ -127,6 +132,11 @@ const EditChannel = () => {
|
||||
showInfo('模型映射必须是合法的 JSON 格式!');
|
||||
return;
|
||||
}
|
||||
// allow streaming and allow non streaming cannot be both false
|
||||
if (!inputs.allow_streaming && !inputs.allow_non_streaming) {
|
||||
showInfo('流式请求和非流式请求不能同时禁用!');
|
||||
return;
|
||||
}
|
||||
let localInputs = inputs;
|
||||
if (localInputs.base_url.endsWith('/')) {
|
||||
localInputs.base_url = localInputs.base_url.slice(0, localInputs.base_url.length - 1);
|
||||
@ -176,7 +186,7 @@ const EditChannel = () => {
|
||||
<Message>
|
||||
注意,<strong>模型部署名称必须和模型名称保持一致</strong>,因为 One API 会把请求体中的 model
|
||||
参数替换为你的部署名称(模型名称中的点会被剔除),<a target='_blank'
|
||||
href='https://github.com/songquanpeng/one-api/issues/133?notification_referrer_id=NT_kwDOAmJSYrM2NjIwMzI3NDgyOjM5OTk4MDUw#issuecomment-1571602271'>图片演示</a>。
|
||||
href='https://github.com/songquanpeng/one-api/issues/133?notification_referrer_id=NT_kwDOAmJSYrM2NjIwMzI3NDgyOjM5OTk4MDUw#issuecomment-1571602271'>图片演示</a>。
|
||||
</Message>
|
||||
<Form.Field>
|
||||
<Form.Input
|
||||
@ -270,7 +280,7 @@ const EditChannel = () => {
|
||||
}}>清除所有模型</Button>
|
||||
<Input
|
||||
action={
|
||||
<Button type={'button'} onClick={()=>{
|
||||
<Button type={'button'} onClick={() => {
|
||||
if (customModel.trim() === "") return;
|
||||
if (inputs.models.includes(customModel)) return;
|
||||
let localModels = [...inputs.models];
|
||||
@ -281,7 +291,7 @@ const EditChannel = () => {
|
||||
text: customModel,
|
||||
value: customModel,
|
||||
});
|
||||
setModelOptions(modelOptions=>{
|
||||
setModelOptions(modelOptions => {
|
||||
return [...modelOptions, ...localModelOptions];
|
||||
});
|
||||
setCustomModel('');
|
||||
@ -306,6 +316,26 @@ const EditChannel = () => {
|
||||
autoComplete='new-password'
|
||||
/>
|
||||
</Form.Field>
|
||||
<Form.Field>
|
||||
<Form.Checkbox
|
||||
checked={inputs.allow_streaming}
|
||||
label='允许流式请求'
|
||||
name='allow_streaming'
|
||||
onChange={() => {
|
||||
setInputs((inputs) => ({ ...inputs, allow_streaming: !inputs.allow_streaming }));
|
||||
}}
|
||||
/>
|
||||
</Form.Field>
|
||||
<Form.Field>
|
||||
<Form.Checkbox
|
||||
checked={inputs.allow_non_streaming}
|
||||
label='允许非流式请求'
|
||||
name='allow_non_streaming'
|
||||
onChange={() => {
|
||||
setInputs((inputs) => ({ ...inputs, allow_non_streaming: !inputs.allow_non_streaming }));
|
||||
}}
|
||||
/>
|
||||
</Form.Field>
|
||||
{
|
||||
batch ? <Form.Field>
|
||||
<Form.TextArea
|
||||
|
Loading…
Reference in New Issue
Block a user