diff --git a/common/constants.go b/common/constants.go index 81f98163..7165a544 100644 --- a/common/constants.go +++ b/common/constants.go @@ -138,6 +138,16 @@ const ( ChannelStatusDisabled = 2 // also don't use 0 ) +const ( + ChannelAllowNonStreamEnabled = 1 + ChannelAllowNonStreamDisabled = 2 +) + +const ( + ChannelAllowStreamEnabled = 1 + ChannelAllowStreamDisabled = 2 +) + const ( ChannelTypeUnknown = 0 ChannelTypeOpenAI = 1 diff --git a/controller/channel-test.go b/controller/channel-test.go index d81d78ae..ae3f60ab 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -64,7 +64,7 @@ func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIErr isStream := strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") - if channel.AllowStreaming && isStream { + if channel.AllowStreaming == common.ChannelAllowStreamEnabled && isStream { responseText := "" scanner := bufio.NewScanner(resp.Body) scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { @@ -145,7 +145,7 @@ func TestChannel(c *gin.Context) { }) return } - testRequest := buildTestRequest(channel.AllowStreaming) + testRequest := buildTestRequest(channel.AllowStreaming == common.ChannelAllowStreamEnabled) tik := time.Now() err, _ = testChannel(channel, *testRequest) tok := time.Now() @@ -210,7 +210,7 @@ func testAllChannels(notify bool) error { continue } tik := time.Now() - testRequest := buildTestRequest(channel.AllowStreaming) + testRequest := buildTestRequest(channel.AllowStreaming == common.ChannelAllowStreamEnabled) err, openaiErr := testChannel(channel, *testRequest) tok := time.Now() milliseconds := tok.Sub(tik).Milliseconds() diff --git a/model/ability.go b/model/ability.go index e167cf32..465acb88 100644 --- a/model/ability.go +++ b/model/ability.go @@ -1,6 +1,7 @@ package model import ( + "fmt" "one-api/common" "strings" ) @@ -19,9 +20,9 @@ func GetRandomSatisfiedChannel(group string, model string, stream bool) (*Channe cmd := "`group` = ? and model = ? and enabled = 1" if stream { - cmd += " and allow_streaming = 1" + cmd += fmt.Sprintf(" and allow_streaming = %d", common.ChannelAllowStreamEnabled) } else { - cmd += " and allow_non_streaming = 1" + cmd += fmt.Sprintf(" and allow_non_streaming = %d", common.ChannelAllowNonStreamEnabled) } if common.UsingSQLite { diff --git a/model/cache.go b/model/cache.go index c2f29722..d734351d 100644 --- a/model/cache.go +++ b/model/cache.go @@ -173,7 +173,7 @@ func CacheGetRandomSatisfiedChannel(group string, model string, stream bool) (*C var filteredChannels []*Channel for _, channel := range channels { - if (stream && channel.AllowStreaming) || (!stream && channel.AllowNonStreaming) { + if (stream && channel.AllowStreaming == common.ChannelAllowStreamEnabled) || (!stream && channel.AllowNonStreaming == common.ChannelAllowNonStreamEnabled) { filteredChannels = append(filteredChannels, channel) } } diff --git a/model/channel.go b/model/channel.go index 8b019418..738e6f7c 100644 --- a/model/channel.go +++ b/model/channel.go @@ -1,7 +1,6 @@ package model import ( - "encoding/json" "one-api/common" "gorm.io/gorm" @@ -25,8 +24,8 @@ type Channel struct { Group string `json:"group" gorm:"type:varchar(32);default:'default'"` UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` ModelMapping string `json:"model_mapping" gorm:"type:varchar(1024);default:''"` - AllowStreaming bool `json:"allow_streaming"` - AllowNonStreaming bool `json:"allow_non_streaming"` + AllowStreaming int `json:"allow_streaming" gorm:"default:2"` + AllowNonStreaming int `json:"allow_non_streaming" gorm:"default:2"` } func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) { @@ -84,19 +83,7 @@ func BatchInsertChannels(channels []Channel) error { func (channel *Channel) Insert() error { var err error - // turn channel into a map - channelMap := make(map[string]interface{}) - - // Convert channel struct to a map - channelBytes, err := json.Marshal(channel) - if err != nil { - return err - } - err = json.Unmarshal(channelBytes, &channelMap) - if err != nil { - return err - } - err = DB.Create(channelMap).Error + err = DB.Create(channel).Error if err != nil { return err } @@ -106,24 +93,11 @@ func (channel *Channel) Insert() error { func (channel *Channel) Update() error { var err error - // turn channel into a map - channelMap := make(map[string]interface{}) - - // Convert channel struct to a map - channelBytes, err := json.Marshal(channel) + err = DB.Model(channel).Updates(channel).Error if err != nil { return err } - err = json.Unmarshal(channelBytes, &channelMap) - if err != nil { - return err - } - - err = DB.Model(channel).Updates(channelMap).Error - if err != nil { - return err - } - DB.Model(channel).First(channelMap, "id = ?", channel.Id) + DB.Model(channel).First(channel, "id = ?", channel.Id) err = channel.UpdateAbilities() return err } diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index 5da899d6..b992e9a9 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -22,8 +22,8 @@ const EditChannel = () => { base_url: '', other: '', model_mapping: '', - allow_streaming: true, - allow_non_streaming: true, + allow_streaming: 1, + allow_non_streaming: 1, models: [], groups: ['default'] }; @@ -133,7 +133,7 @@ const EditChannel = () => { return; } // allow streaming and allow non streaming cannot be both false - if (!inputs.allow_streaming && !inputs.allow_non_streaming) { + if (inputs.allow_streaming === 2 && inputs.allow_non_streaming === 2) { showInfo('流式请求和非流式请求不能同时禁用!'); return; } @@ -318,21 +318,21 @@ const EditChannel = () => {