feat: allow toggling stream mode of channels

Author: ckt1031, 2023-07-24 15:30:08 +08:00
Parent: bc2f48b1f2
Commit: a588241515
8 changed files with 154 additions and 29 deletions

View File

@@ -1,17 +1,20 @@
package controller
import (
"bufio"
"bytes"
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
)
func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIError) {
@@ -58,6 +61,46 @@ func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIError) {
return err, nil
}
defer resp.Body.Close()
isStream := strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
if channel.AllowStreaming && isStream {
responseText := ""
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
for scanner.Scan() {
data := scanner.Text()
if len(data) < 6 { // ignore blank line or wrong format
continue
}
data = data[6:]
if !strings.HasPrefix(data, "[DONE]") {
var streamResponse ChatCompletionsStreamResponse
err := json.Unmarshal([]byte(data), &streamResponse)
if err != nil {
return err, nil
}
for _, choice := range streamResponse.Choices {
responseText += choice.Delta.Content
}
}
}
if responseText == "" {
return errors.New("Empty response"), nil
}
} else {
var response TextResponse
err = json.NewDecoder(resp.Body).Decode(&response)
if err != nil {
@@ -66,13 +109,16 @@ func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIError) {
if response.Usage.CompletionTokens == 0 {
return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)), &response.Error
}
}
return nil, nil
}
func buildTestRequest() *ChatRequest {
func buildTestRequest(stream bool) *ChatRequest {
testRequest := &ChatRequest{
Model: "", // this will be set later
MaxTokens: 1,
Stream: stream,
}
testMessage := Message{
Role: "user",
@@ -99,7 +145,7 @@ func TestChannel(c *gin.Context) {
})
return
}
testRequest := buildTestRequest()
testRequest := buildTestRequest(channel.AllowStreaming)
tik := time.Now()
err, _ = testChannel(channel, *testRequest)
tok := time.Now()
@@ -154,7 +200,6 @@ func testAllChannels(notify bool) error {
if err != nil {
return err
}
testRequest := buildTestRequest()
var disableThreshold = int64(common.ChannelDisableThreshold * 1000)
if disableThreshold == 0 {
disableThreshold = 10000000 // an impossible value
@@ -165,6 +210,7 @@ func testAllChannels(notify bool) error {
continue
}
tik := time.Now()
testRequest := buildTestRequest(channel.AllowStreaming)
err, openaiErr := testChannel(channel, *testRequest)
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()
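
For reference, a minimal, self-contained sketch of what the new streaming branch of testChannel does with the response body: the custom bufio.Scanner split emits one line per "data:" event, the six-byte prefix is trimmed, and the delta contents are concatenated until "[DONE]". The streamChunk type and the sample body below are simplified stand-ins for illustration, not one-api's actual types.

package main

import (
	"bufio"
	"encoding/json"
	"fmt"
	"strings"
)

// streamChunk is a trimmed stand-in for ChatCompletionsStreamResponse.
type streamChunk struct {
	Choices []struct {
		Delta struct {
			Content string `json:"content"`
		} `json:"delta"`
	} `json:"choices"`
}

func main() {
	// A made-up event-stream body of the kind the streaming test reads.
	body := "data: {\"choices\":[{\"delta\":{\"content\":\"Hel\"}}]}\n" +
		"data: {\"choices\":[{\"delta\":{\"content\":\"lo\"}}]}\n" +
		"data: [DONE]\n"

	scanner := bufio.NewScanner(strings.NewReader(body))
	// Same newline-based split as in the diff: one SSE line per token.
	scanner.Split(func(data []byte, atEOF bool) (int, []byte, error) {
		if atEOF && len(data) == 0 {
			return 0, nil, nil
		}
		if i := strings.Index(string(data), "\n"); i >= 0 {
			return i + 1, data[0:i], nil
		}
		if atEOF {
			return len(data), data, nil
		}
		return 0, nil, nil
	})

	responseText := ""
	for scanner.Scan() {
		line := scanner.Text()
		if len(line) < 6 { // skip blank or malformed lines
			continue
		}
		line = line[6:] // drop the "data: " prefix
		if strings.HasPrefix(line, "[DONE]") {
			continue
		}
		var chunk streamChunk
		if err := json.Unmarshal([]byte(line), &chunk); err != nil {
			fmt.Println("unmarshal error:", err)
			return
		}
		for _, choice := range chunk.Choices {
			responseText += choice.Delta.Content
		}
	}
	fmt.Println(responseText) // prints "Hello"
}

An empty responseText at the end is what the new code reports as a failed test, mirroring the errors.New("Empty response") return above.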

View File

@@ -1,12 +1,13 @@
package controller
import (
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"strings"
"github.com/gin-gonic/gin"
)
func GetAllChannels(c *gin.Context) {

View File

@@ -46,6 +46,7 @@ type ChatRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
MaxTokens int `json:"max_tokens"`
Stream bool `json:"stream"`
}
type TextRequest struct {
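
The new Stream field is what buildTestRequest(stream) sets, so a streaming-capable channel is now probed with "stream": true on the wire. A small sketch of the resulting request body, using trimmed stand-ins for Message and ChatRequest (the model name and message content are placeholders, not values taken from the commit):

package main

import (
	"encoding/json"
	"fmt"
)

type message struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

type chatRequest struct {
	Model     string    `json:"model"`
	Messages  []message `json:"messages"`
	MaxTokens int       `json:"max_tokens"`
	Stream    bool      `json:"stream"`
}

func main() {
	body, _ := json.Marshal(chatRequest{
		Model:     "gpt-3.5-turbo", // placeholder; the real value is filled in by the caller
		Messages:  []message{{Role: "user", Content: "hi"}},
		MaxTokens: 1,
		Stream:    true,
	})
	fmt.Println(string(body))
	// {"model":"gpt-3.5-turbo","messages":[{"role":"user","content":"hi"}],"max_tokens":1,"stream":true}
}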

View File

@@ -13,6 +13,7 @@ import (
type ModelRequest struct {
Model string `json:"model"`
Stream bool `json:"stream" default:"true"`
}
func Distribute() func(c *gin.Context) {
@@ -84,7 +85,7 @@ func Distribute() func(c *gin.Context) {
modelRequest.Model = "dall-e"
}
}
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, modelRequest.Stream)
if err != nil {
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
if channel != nil {
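
One caveat about the new ModelRequest field: the standard encoding/json package does not apply `default` struct tags, so unless the default is filled in elsewhere (nothing in this hunk does it), a request body without a "stream" key unmarshals to Stream == false and gets routed to a channel that allows non-streaming requests. A quick sketch of that behavior with a local copy of the struct:

package main

import (
	"encoding/json"
	"fmt"
)

type modelRequest struct {
	Model  string `json:"model"`
	Stream bool   `json:"stream" default:"true"`
}

func main() {
	var explicit, omitted modelRequest
	_ = json.Unmarshal([]byte(`{"model":"gpt-3.5-turbo","stream":true}`), &explicit)
	_ = json.Unmarshal([]byte(`{"model":"gpt-3.5-turbo"}`), &omitted)

	fmt.Println(explicit.Stream) // true
	fmt.Println(omitted.Stream)  // false: the `default:"true"` tag is ignored by encoding/json
}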

View File

@@ -12,13 +12,22 @@ type Ability struct {
Enabled bool `json:"enabled"`
}
func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
func GetRandomSatisfiedChannel(group string, model string, stream bool) (*Channel, error) {
ability := Ability{}
var err error = nil
if common.UsingSQLite {
err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("RANDOM()").Limit(1).First(&ability).Error
cmd := "`group` = ? and model = ? and enabled = 1"
if stream {
cmd += " and allow_streaming = 1"
} else {
err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("RAND()").Limit(1).First(&ability).Error
cmd += " and allow_non_streaming = 1"
}
if common.UsingSQLite {
err = DB.Where(cmd, group, model).Order("RANDOM()").Limit(1).First(&ability).Error
} else {
err = DB.Where(cmd, group, model).Order("RAND()").Limit(1).First(&ability).Error
}
if err != nil {
return nil, err
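
The effect of the stream parameter is easiest to see in the WHERE clause it builds; the SQLite and non-SQLite branches now differ only in RANDOM() versus RAND() for the random ordering. A standalone sketch of the filter assembly (this assumes the abilities table carries matching allow_streaming and allow_non_streaming columns, which this excerpt does not show being added):

package main

import "fmt"

// buildAbilityFilter mirrors the clause assembled in the new GetRandomSatisfiedChannel.
func buildAbilityFilter(stream bool) string {
	cmd := "`group` = ? and model = ? and enabled = 1"
	if stream {
		cmd += " and allow_streaming = 1"
	} else {
		cmd += " and allow_non_streaming = 1"
	}
	return cmd
}

func main() {
	fmt.Println(buildAbilityFilter(true))
	// `group` = ? and model = ? and enabled = 1 and allow_streaming = 1
	fmt.Println(buildAbilityFilter(false))
	// `group` = ? and model = ? and enabled = 1 and allow_non_streaming = 1
}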

View File

@@ -160,9 +160,9 @@ func SyncChannelCache(frequency int) {
}
}
func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
func CacheGetRandomSatisfiedChannel(group string, model string, stream bool) (*Channel, error) {
if !common.RedisEnabled {
return GetRandomSatisfiedChannel(group, model)
return GetRandomSatisfiedChannel(group, model, stream)
}
channelSyncLock.RLock()
defer channelSyncLock.RUnlock()
@@ -170,6 +170,14 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
if len(channels) == 0 {
return nil, errors.New("channel not found")
}
idx := rand.Intn(len(channels))
return channels[idx], nil
var filteredChannels []*Channel
for _, channel := range channels {
if (stream && channel.AllowStreaming) || (!stream && channel.AllowNonStreaming) {
filteredChannels = append(filteredChannels, channel)
}
}
idx := rand.Intn(len(filteredChannels))
return filteredChannels[idx], nil
}
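
The cached path could use one extra guard: if every cached channel has the requested mode disabled, filteredChannels stays empty and rand.Intn(0) panics. A minimal sketch of the same filter with that case turned into the "channel not found" error used earlier in the function (the channel type here is a trimmed stand-in, not model.Channel):

package main

import (
	"errors"
	"fmt"
	"math/rand"
)

type channel struct {
	Name              string
	AllowStreaming    bool
	AllowNonStreaming bool
}

func pickChannel(channels []*channel, stream bool) (*channel, error) {
	var filtered []*channel
	for _, ch := range channels {
		if (stream && ch.AllowStreaming) || (!stream && ch.AllowNonStreaming) {
			filtered = append(filtered, ch)
		}
	}
	if len(filtered) == 0 {
		return nil, errors.New("channel not found")
	}
	return filtered[rand.Intn(len(filtered))], nil
}

func main() {
	chs := []*channel{
		{Name: "stream-only", AllowStreaming: true},
		{Name: "non-stream-only", AllowNonStreaming: true},
	}
	ch, err := pickChannel(chs, true)
	fmt.Println(ch.Name, err) // stream-only <nil>

	_, err = pickChannel([]*channel{}, true)
	fmt.Println(err) // channel not found
}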

View File

@@ -1,8 +1,10 @@
package model
import (
"gorm.io/gorm"
"encoding/json"
"one-api/common"
"gorm.io/gorm"
)
type Channel struct {
@@ -23,6 +25,8 @@ type Channel struct {
Group string `json:"group" gorm:"type:varchar(32);default:'default'"`
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"`
ModelMapping string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
AllowStreaming bool `json:"allow_streaming"`
AllowNonStreaming bool `json:"allow_non_streaming"`
}
func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
@@ -80,7 +84,19 @@ func BatchInsertChannels(channels []Channel) error {
func (channel *Channel) Insert() error {
var err error
err = DB.Create(channel).Error
// turn channel into a map
channelMap := make(map[string]interface{})
// Convert channel struct to a map
channelBytes, err := json.Marshal(channel)
if err != nil {
return err
}
err = json.Unmarshal(channelBytes, &channelMap)
if err != nil {
return err
}
err = DB.Create(channelMap).Error
if err != nil {
return err
}
@@ -90,11 +106,24 @@ func (channel *Channel) Insert() error {
func (channel *Channel) Update() error {
var err error
err = DB.Model(channel).Updates(channel).Error
// turn channel into a map
channelMap := make(map[string]interface{})
// Convert channel struct to a map
channelBytes, err := json.Marshal(channel)
if err != nil {
return err
}
DB.Model(channel).First(channel, "id = ?", channel.Id)
err = json.Unmarshal(channelBytes, &channelMap)
if err != nil {
return err
}
err = DB.Model(channel).Updates(channelMap).Error
if err != nil {
return err
}
DB.Model(channel).First(channelMap, "id = ?", channel.Id)
err = channel.UpdateAbilities()
return err
}
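
The Insert and Update rewrites route the channel through a JSON round-trip into a map before handing it to gorm. The likely motivation (not stated in the commit) is that gorm's Updates with a struct skips zero-value fields, so a checkbox switched off (false) would never be written back, whereas a map-based update writes every key it contains; with a map, gorm treats the keys as column names, which lines up here because the json tags mirror the snake_case columns. A standalone sketch of the conversion, using a trimmed stand-in for model.Channel:

package main

import (
	"encoding/json"
	"fmt"
)

type channel struct {
	Id                int    `json:"id"`
	Name              string `json:"name"`
	AllowStreaming    bool   `json:"allow_streaming"`
	AllowNonStreaming bool   `json:"allow_non_streaming"`
}

// structToMap reproduces the JSON round-trip used in Insert and Update above.
func structToMap(ch *channel) (map[string]interface{}, error) {
	raw, err := json.Marshal(ch)
	if err != nil {
		return nil, err
	}
	m := make(map[string]interface{})
	if err := json.Unmarshal(raw, &m); err != nil {
		return nil, err
	}
	return m, nil
}

func main() {
	m, err := structToMap(&channel{Id: 1, Name: "test", AllowStreaming: false, AllowNonStreaming: true})
	if err != nil {
		panic(err)
	}
	// The false value survives explicitly in the map, unlike a struct-based
	// gorm Updates() call, which would skip it as a zero value.
	fmt.Println(m["allow_streaming"], m["allow_non_streaming"]) // false true
}

Note that gorm's documented map-based Create goes through Model(&Channel{}).Create(map[string]interface{}{...}), so the bare DB.Create(channelMap) call above may need that Model hint depending on the gorm version.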

View File

@@ -22,6 +22,8 @@ const EditChannel = () => {
base_url: '',
other: '',
model_mapping: '',
allow_streaming: true,
allow_non_streaming: true,
models: [],
groups: ['default']
};
@@ -94,6 +96,9 @@ const EditChannel = () => {
useEffect(() => {
let localModelOptions = [...originModelOptions];
if (!Array.isArray(inputs.models)) {
inputs.models = inputs.models.split(',');
}
inputs.models.forEach((model) => {
if (!localModelOptions.find((option) => option.key === model)) {
localModelOptions.push({
@@ -127,6 +132,11 @@ const EditChannel = () => {
showInfo('模型映射必须是合法的 JSON 格式!');
return;
}
// allow streaming and allow non streaming cannot be both false
if (!inputs.allow_streaming && !inputs.allow_non_streaming) {
showInfo('流式请求和非流式请求不能同时禁用!');
return;
}
let localInputs = inputs;
if (localInputs.base_url.endsWith('/')) {
localInputs.base_url = localInputs.base_url.slice(0, localInputs.base_url.length - 1);
@@ -270,7 +280,7 @@ const EditChannel = () => {
}}>清除所有模型</Button>
<Input
action={
<Button type={'button'} onClick={()=>{
<Button type={'button'} onClick={() => {
if (customModel.trim() === "") return;
if (inputs.models.includes(customModel)) return;
let localModels = [...inputs.models];
@@ -281,7 +291,7 @@ const EditChannel = () => {
text: customModel,
value: customModel,
});
setModelOptions(modelOptions=>{
setModelOptions(modelOptions => {
return [...modelOptions, ...localModelOptions];
});
setCustomModel('');
@@ -306,6 +316,26 @@ const EditChannel = () => {
autoComplete='new-password'
/>
</Form.Field>
<Form.Field>
<Form.Checkbox
checked={inputs.allow_streaming}
label='允许流式请求'
name='allow_streaming'
onChange={() => {
setInputs((inputs) => ({ ...inputs, allow_streaming: !inputs.allow_streaming }));
}}
/>
</Form.Field>
<Form.Field>
<Form.Checkbox
checked={inputs.allow_non_streaming}
label='允许非流式请求'
name='allow_non_streaming'
onChange={() => {
setInputs((inputs) => ({ ...inputs, allow_non_streaming: !inputs.allow_non_streaming }));
}}
/>
</Form.Field>
{
batch ? <Form.Field>
<Form.TextArea