diff --git a/controller/channel-test.go b/controller/channel-test.go
index be658fa8..d81d78ae 100644
--- a/controller/channel-test.go
+++ b/controller/channel-test.go
@@ -1,17 +1,20 @@
package controller
import (
+ "bufio"
"bytes"
"encoding/json"
"errors"
"fmt"
- "github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
+ "strings"
"sync"
"time"
+
+ "github.com/gin-gonic/gin"
)
func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIError) {
@@ -58,21 +61,64 @@ func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIErr
return err, nil
}
defer resp.Body.Close()
- var response TextResponse
- err = json.NewDecoder(resp.Body).Decode(&response)
- if err != nil {
- return err, nil
- }
- if response.Usage.CompletionTokens == 0 {
- return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)), &response.Error
+
+ isStream := strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
+
+ if channel.AllowStreaming && isStream {
+ responseText := ""
+ scanner := bufio.NewScanner(resp.Body)
+ scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
+ if atEOF && len(data) == 0 {
+ return 0, nil, nil
+ }
+ if i := strings.Index(string(data), "\n"); i >= 0 {
+ return i + 1, data[0:i], nil
+ }
+ if atEOF {
+ return len(data), data, nil
+ }
+ return 0, nil, nil
+ })
+ for scanner.Scan() {
+ data := scanner.Text()
+ if len(data) < 6 { // ignore blank line or wrong format
+ continue
+ }
+ data = data[6:]
+ if !strings.HasPrefix(data, "[DONE]") {
+ var streamResponse ChatCompletionsStreamResponse
+ err := json.Unmarshal([]byte(data), &streamResponse)
+ if err != nil {
+ return err, nil
+ }
+ for _, choice := range streamResponse.Choices {
+ responseText += choice.Delta.Content
+ }
+ }
+ }
+
+ if responseText == "" {
+ return errors.New("Empty response"), nil
+ }
+ } else {
+ var response TextResponse
+ err = json.NewDecoder(resp.Body).Decode(&response)
+ if err != nil {
+ return err, nil
+ }
+ if response.Usage.CompletionTokens == 0 {
+ return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)), &response.Error
+ }
}
+
return nil, nil
}
-func buildTestRequest() *ChatRequest {
+func buildTestRequest(stream bool) *ChatRequest {
testRequest := &ChatRequest{
Model: "", // this will be set later
MaxTokens: 1,
+ Stream: stream,
}
testMessage := Message{
Role: "user",
@@ -99,7 +145,7 @@ func TestChannel(c *gin.Context) {
})
return
}
- testRequest := buildTestRequest()
+ testRequest := buildTestRequest(channel.AllowStreaming)
tik := time.Now()
err, _ = testChannel(channel, *testRequest)
tok := time.Now()
@@ -154,7 +200,6 @@ func testAllChannels(notify bool) error {
if err != nil {
return err
}
- testRequest := buildTestRequest()
var disableThreshold = int64(common.ChannelDisableThreshold * 1000)
if disableThreshold == 0 {
disableThreshold = 10000000 // a impossible value
@@ -165,6 +210,7 @@ func testAllChannels(notify bool) error {
continue
}
tik := time.Now()
+ testRequest := buildTestRequest(channel.AllowStreaming)
err, openaiErr := testChannel(channel, *testRequest)
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()
diff --git a/controller/channel.go b/controller/channel.go
index 8afc0eed..6dab76d7 100644
--- a/controller/channel.go
+++ b/controller/channel.go
@@ -1,12 +1,13 @@
package controller
import (
- "github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"strings"
+
+ "github.com/gin-gonic/gin"
)
func GetAllChannels(c *gin.Context) {
diff --git a/controller/relay.go b/controller/relay.go
index 9cfa5c4f..493412dd 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -46,6 +46,7 @@ type ChatRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
MaxTokens int `json:"max_tokens"`
+ Stream bool `json:"stream"`
}
type TextRequest struct {
diff --git a/middleware/distributor.go b/middleware/distributor.go
index 91c00e1a..1940c69c 100644
--- a/middleware/distributor.go
+++ b/middleware/distributor.go
@@ -12,7 +12,8 @@ import (
)
type ModelRequest struct {
- Model string `json:"model"`
+ Model string `json:"model"`
+ Stream bool `json:"stream" default:"true"`
}
func Distribute() func(c *gin.Context) {
@@ -84,7 +85,7 @@ func Distribute() func(c *gin.Context) {
modelRequest.Model = "dall-e"
}
}
- channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
+ channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, modelRequest.Stream)
if err != nil {
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
if channel != nil {
diff --git a/model/ability.go b/model/ability.go
index e87c3940..e167cf32 100644
--- a/model/ability.go
+++ b/model/ability.go
@@ -12,13 +12,22 @@ type Ability struct {
Enabled bool `json:"enabled"`
}
-func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
+func GetRandomSatisfiedChannel(group string, model string, stream bool) (*Channel, error) {
ability := Ability{}
var err error = nil
- if common.UsingSQLite {
- err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("RANDOM()").Limit(1).First(&ability).Error
+
+ cmd := "`group` = ? and model = ? and enabled = 1"
+
+ if stream {
+ cmd += " and allow_streaming = 1"
} else {
- err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("RAND()").Limit(1).First(&ability).Error
+ cmd += " and allow_non_streaming = 1"
+ }
+
+ if common.UsingSQLite {
+ err = DB.Where(cmd, group, model).Order("RANDOM()").Limit(1).First(&ability).Error
+ } else {
+ err = DB.Where(cmd, group, model).Order("RAND()").Limit(1).First(&ability).Error
}
if err != nil {
return nil, err
diff --git a/model/cache.go b/model/cache.go
index 64666c86..c2f29722 100644
--- a/model/cache.go
+++ b/model/cache.go
@@ -160,9 +160,9 @@ func SyncChannelCache(frequency int) {
}
}
-func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
+func CacheGetRandomSatisfiedChannel(group string, model string, stream bool) (*Channel, error) {
if !common.RedisEnabled {
- return GetRandomSatisfiedChannel(group, model)
+ return GetRandomSatisfiedChannel(group, model, stream)
}
channelSyncLock.RLock()
defer channelSyncLock.RUnlock()
@@ -170,6 +170,14 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
if len(channels) == 0 {
return nil, errors.New("channel not found")
}
- idx := rand.Intn(len(channels))
- return channels[idx], nil
+
+ var filteredChannels []*Channel
+ for _, channel := range channels {
+ if (stream && channel.AllowStreaming) || (!stream && channel.AllowNonStreaming) {
+ filteredChannels = append(filteredChannels, channel)
+ }
+ }
+
+ idx := rand.Intn(len(filteredChannels))
+ return filteredChannels[idx], nil
}
diff --git a/model/channel.go b/model/channel.go
index 7cc9fa9b..8b019418 100644
--- a/model/channel.go
+++ b/model/channel.go
@@ -1,8 +1,10 @@
package model
import (
- "gorm.io/gorm"
+ "encoding/json"
"one-api/common"
+
+ "gorm.io/gorm"
)
type Channel struct {
@@ -23,6 +25,8 @@ type Channel struct {
Group string `json:"group" gorm:"type:varchar(32);default:'default'"`
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"`
ModelMapping string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
+ AllowStreaming bool `json:"allow_streaming"`
+ AllowNonStreaming bool `json:"allow_non_streaming"`
}
func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
@@ -80,7 +84,19 @@ func BatchInsertChannels(channels []Channel) error {
func (channel *Channel) Insert() error {
var err error
- err = DB.Create(channel).Error
+ // turn channel into a map
+ channelMap := make(map[string]interface{})
+
+ // Convert channel struct to a map
+ channelBytes, err := json.Marshal(channel)
+ if err != nil {
+ return err
+ }
+ err = json.Unmarshal(channelBytes, &channelMap)
+ if err != nil {
+ return err
+ }
+ err = DB.Create(channelMap).Error
if err != nil {
return err
}
@@ -90,11 +106,24 @@ func (channel *Channel) Insert() error {
func (channel *Channel) Update() error {
var err error
- err = DB.Model(channel).Updates(channel).Error
+ // turn channel into a map
+ channelMap := make(map[string]interface{})
+
+ // Convert channel struct to a map
+ channelBytes, err := json.Marshal(channel)
if err != nil {
return err
}
- DB.Model(channel).First(channel, "id = ?", channel.Id)
+ err = json.Unmarshal(channelBytes, &channelMap)
+ if err != nil {
+ return err
+ }
+
+ err = DB.Model(channel).Updates(channelMap).Error
+ if err != nil {
+ return err
+ }
+ DB.Model(channel).First(channelMap, "id = ?", channel.Id)
err = channel.UpdateAbilities()
return err
}
diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js
index 7833c7f3..5da899d6 100644
--- a/web/src/pages/Channel/EditChannel.js
+++ b/web/src/pages/Channel/EditChannel.js
@@ -22,6 +22,8 @@ const EditChannel = () => {
base_url: '',
other: '',
model_mapping: '',
+ allow_streaming: true,
+ allow_non_streaming: true,
models: [],
groups: ['default']
};
@@ -94,6 +96,9 @@ const EditChannel = () => {
useEffect(() => {
let localModelOptions = [...originModelOptions];
+ if (!Array.isArray(inputs.models)) {
+ inputs.models = inputs.models.split(',');
+ }
inputs.models.forEach((model) => {
if (!localModelOptions.find((option) => option.key === model)) {
localModelOptions.push({
@@ -127,6 +132,11 @@ const EditChannel = () => {
showInfo('模型映射必须是合法的 JSON 格式!');
return;
}
+ // allow streaming and allow non streaming cannot be both false
+ if (!inputs.allow_streaming && !inputs.allow_non_streaming) {
+ showInfo('流式请求和非流式请求不能同时禁用!');
+ return;
+ }
let localInputs = inputs;
if (localInputs.base_url.endsWith('/')) {
localInputs.base_url = localInputs.base_url.slice(0, localInputs.base_url.length - 1);
@@ -176,7 +186,7 @@ const EditChannel = () => {