From a588241515fa36aa3ee531a10a716c1930612526 Mon Sep 17 00:00:00 2001 From: ckt1031 <65409152+ckt1031@users.noreply.github.com> Date: Mon, 24 Jul 2023 15:30:08 +0800 Subject: [PATCH] feat: allow toggling stream mode of channels --- controller/channel-test.go | 68 +++++++++++++++++++++++----- controller/channel.go | 3 +- controller/relay.go | 1 + middleware/distributor.go | 5 +- model/ability.go | 17 +++++-- model/cache.go | 16 +++++-- model/channel.go | 37 +++++++++++++-- web/src/pages/Channel/EditChannel.js | 36 +++++++++++++-- 8 files changed, 154 insertions(+), 29 deletions(-) diff --git a/controller/channel-test.go b/controller/channel-test.go index be658fa8..d81d78ae 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -1,17 +1,20 @@ package controller import ( + "bufio" "bytes" "encoding/json" "errors" "fmt" - "github.com/gin-gonic/gin" "net/http" "one-api/common" "one-api/model" "strconv" + "strings" "sync" "time" + + "github.com/gin-gonic/gin" ) func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIError) { @@ -58,21 +61,64 @@ func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIErr return err, nil } defer resp.Body.Close() - var response TextResponse - err = json.NewDecoder(resp.Body).Decode(&response) - if err != nil { - return err, nil - } - if response.Usage.CompletionTokens == 0 { - return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)), &response.Error + + isStream := strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") + + if channel.AllowStreaming && isStream { + responseText := "" + scanner := bufio.NewScanner(resp.Body) + scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, nil + } + if i := strings.Index(string(data), "\n"); i >= 0 { + return i + 1, data[0:i], nil + } + if atEOF { + return len(data), data, nil + } + return 0, nil, nil + }) + for scanner.Scan() { + data := scanner.Text() + if len(data) < 6 { // ignore blank line or wrong format + continue + } + data = data[6:] + if !strings.HasPrefix(data, "[DONE]") { + var streamResponse ChatCompletionsStreamResponse + err := json.Unmarshal([]byte(data), &streamResponse) + if err != nil { + return err, nil + } + for _, choice := range streamResponse.Choices { + responseText += choice.Delta.Content + } + } + } + + if responseText == "" { + return errors.New("Empty response"), nil + } + } else { + var response TextResponse + err = json.NewDecoder(resp.Body).Decode(&response) + if err != nil { + return err, nil + } + if response.Usage.CompletionTokens == 0 { + return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)), &response.Error + } } + return nil, nil } -func buildTestRequest() *ChatRequest { +func buildTestRequest(stream bool) *ChatRequest { testRequest := &ChatRequest{ Model: "", // this will be set later MaxTokens: 1, + Stream: stream, } testMessage := Message{ Role: "user", @@ -99,7 +145,7 @@ func TestChannel(c *gin.Context) { }) return } - testRequest := buildTestRequest() + testRequest := buildTestRequest(channel.AllowStreaming) tik := time.Now() err, _ = testChannel(channel, *testRequest) tok := time.Now() @@ -154,7 +200,6 @@ func testAllChannels(notify bool) error { if err != nil { return err } - testRequest := buildTestRequest() var disableThreshold = int64(common.ChannelDisableThreshold * 1000) if disableThreshold == 0 { disableThreshold = 10000000 // a impossible value @@ -165,6 +210,7 @@ func testAllChannels(notify bool) error { continue } tik := time.Now() + testRequest := buildTestRequest(channel.AllowStreaming) err, openaiErr := testChannel(channel, *testRequest) tok := time.Now() milliseconds := tok.Sub(tik).Milliseconds() diff --git a/controller/channel.go b/controller/channel.go index 8afc0eed..6dab76d7 100644 --- a/controller/channel.go +++ b/controller/channel.go @@ -1,12 +1,13 @@ package controller import ( - "github.com/gin-gonic/gin" "net/http" "one-api/common" "one-api/model" "strconv" "strings" + + "github.com/gin-gonic/gin" ) func GetAllChannels(c *gin.Context) { diff --git a/controller/relay.go b/controller/relay.go index 9cfa5c4f..493412dd 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -46,6 +46,7 @@ type ChatRequest struct { Model string `json:"model"` Messages []Message `json:"messages"` MaxTokens int `json:"max_tokens"` + Stream bool `json:"stream"` } type TextRequest struct { diff --git a/middleware/distributor.go b/middleware/distributor.go index 91c00e1a..1940c69c 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -12,7 +12,8 @@ import ( ) type ModelRequest struct { - Model string `json:"model"` + Model string `json:"model"` + Stream bool `json:"stream" default:"true"` } func Distribute() func(c *gin.Context) { @@ -84,7 +85,7 @@ func Distribute() func(c *gin.Context) { modelRequest.Model = "dall-e" } } - channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model) + channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, modelRequest.Stream) if err != nil { message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model) if channel != nil { diff --git a/model/ability.go b/model/ability.go index e87c3940..e167cf32 100644 --- a/model/ability.go +++ b/model/ability.go @@ -12,13 +12,22 @@ type Ability struct { Enabled bool `json:"enabled"` } -func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) { +func GetRandomSatisfiedChannel(group string, model string, stream bool) (*Channel, error) { ability := Ability{} var err error = nil - if common.UsingSQLite { - err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("RANDOM()").Limit(1).First(&ability).Error + + cmd := "`group` = ? and model = ? and enabled = 1" + + if stream { + cmd += " and allow_streaming = 1" } else { - err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("RAND()").Limit(1).First(&ability).Error + cmd += " and allow_non_streaming = 1" + } + + if common.UsingSQLite { + err = DB.Where(cmd, group, model).Order("RANDOM()").Limit(1).First(&ability).Error + } else { + err = DB.Where(cmd, group, model).Order("RAND()").Limit(1).First(&ability).Error } if err != nil { return nil, err diff --git a/model/cache.go b/model/cache.go index 64666c86..c2f29722 100644 --- a/model/cache.go +++ b/model/cache.go @@ -160,9 +160,9 @@ func SyncChannelCache(frequency int) { } } -func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) { +func CacheGetRandomSatisfiedChannel(group string, model string, stream bool) (*Channel, error) { if !common.RedisEnabled { - return GetRandomSatisfiedChannel(group, model) + return GetRandomSatisfiedChannel(group, model, stream) } channelSyncLock.RLock() defer channelSyncLock.RUnlock() @@ -170,6 +170,14 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error if len(channels) == 0 { return nil, errors.New("channel not found") } - idx := rand.Intn(len(channels)) - return channels[idx], nil + + var filteredChannels []*Channel + for _, channel := range channels { + if (stream && channel.AllowStreaming) || (!stream && channel.AllowNonStreaming) { + filteredChannels = append(filteredChannels, channel) + } + } + + idx := rand.Intn(len(filteredChannels)) + return filteredChannels[idx], nil } diff --git a/model/channel.go b/model/channel.go index 7cc9fa9b..8b019418 100644 --- a/model/channel.go +++ b/model/channel.go @@ -1,8 +1,10 @@ package model import ( - "gorm.io/gorm" + "encoding/json" "one-api/common" + + "gorm.io/gorm" ) type Channel struct { @@ -23,6 +25,8 @@ type Channel struct { Group string `json:"group" gorm:"type:varchar(32);default:'default'"` UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` ModelMapping string `json:"model_mapping" gorm:"type:varchar(1024);default:''"` + AllowStreaming bool `json:"allow_streaming"` + AllowNonStreaming bool `json:"allow_non_streaming"` } func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) { @@ -80,7 +84,19 @@ func BatchInsertChannels(channels []Channel) error { func (channel *Channel) Insert() error { var err error - err = DB.Create(channel).Error + // turn channel into a map + channelMap := make(map[string]interface{}) + + // Convert channel struct to a map + channelBytes, err := json.Marshal(channel) + if err != nil { + return err + } + err = json.Unmarshal(channelBytes, &channelMap) + if err != nil { + return err + } + err = DB.Create(channelMap).Error if err != nil { return err } @@ -90,11 +106,24 @@ func (channel *Channel) Insert() error { func (channel *Channel) Update() error { var err error - err = DB.Model(channel).Updates(channel).Error + // turn channel into a map + channelMap := make(map[string]interface{}) + + // Convert channel struct to a map + channelBytes, err := json.Marshal(channel) if err != nil { return err } - DB.Model(channel).First(channel, "id = ?", channel.Id) + err = json.Unmarshal(channelBytes, &channelMap) + if err != nil { + return err + } + + err = DB.Model(channel).Updates(channelMap).Error + if err != nil { + return err + } + DB.Model(channel).First(channelMap, "id = ?", channel.Id) err = channel.UpdateAbilities() return err } diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index 7833c7f3..5da899d6 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -22,6 +22,8 @@ const EditChannel = () => { base_url: '', other: '', model_mapping: '', + allow_streaming: true, + allow_non_streaming: true, models: [], groups: ['default'] }; @@ -94,6 +96,9 @@ const EditChannel = () => { useEffect(() => { let localModelOptions = [...originModelOptions]; + if (!Array.isArray(inputs.models)) { + inputs.models = inputs.models.split(','); + } inputs.models.forEach((model) => { if (!localModelOptions.find((option) => option.key === model)) { localModelOptions.push({ @@ -127,6 +132,11 @@ const EditChannel = () => { showInfo('模型映射必须是合法的 JSON 格式!'); return; } + // allow streaming and allow non streaming cannot be both false + if (!inputs.allow_streaming && !inputs.allow_non_streaming) { + showInfo('流式请求和非流式请求不能同时禁用!'); + return; + } let localInputs = inputs; if (localInputs.base_url.endsWith('/')) { localInputs.base_url = localInputs.base_url.slice(0, localInputs.base_url.length - 1); @@ -176,7 +186,7 @@ const EditChannel = () => { 注意,模型部署名称必须和模型名称保持一致,因为 One API 会把请求体中的 model 参数替换为你的部署名称(模型名称中的点会被剔除),图片演示。 + href='https://github.com/songquanpeng/one-api/issues/133?notification_referrer_id=NT_kwDOAmJSYrM2NjIwMzI3NDgyOjM5OTk4MDUw#issuecomment-1571602271'>图片演示。 { }}>清除所有模型 { +