fix: use int instead of bool

This commit is contained in:
ckt1031 2023-07-24 18:12:16 +08:00
parent a588241515
commit 30187cebe8
6 changed files with 29 additions and 44 deletions

View File

@ -138,6 +138,16 @@ const (
ChannelStatusDisabled = 2 // also don't use 0 ChannelStatusDisabled = 2 // also don't use 0
) )
const (
ChannelAllowNonStreamEnabled = 1
ChannelAllowNonStreamDisabled = 2
)
const (
ChannelAllowStreamEnabled = 1
ChannelAllowStreamDisabled = 2
)
const ( const (
ChannelTypeUnknown = 0 ChannelTypeUnknown = 0
ChannelTypeOpenAI = 1 ChannelTypeOpenAI = 1

View File

@ -64,7 +64,7 @@ func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIErr
isStream := strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") isStream := strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
if channel.AllowStreaming && isStream { if channel.AllowStreaming == common.ChannelAllowStreamEnabled && isStream {
responseText := "" responseText := ""
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
@ -145,7 +145,7 @@ func TestChannel(c *gin.Context) {
}) })
return return
} }
testRequest := buildTestRequest(channel.AllowStreaming) testRequest := buildTestRequest(channel.AllowStreaming == common.ChannelAllowStreamEnabled)
tik := time.Now() tik := time.Now()
err, _ = testChannel(channel, *testRequest) err, _ = testChannel(channel, *testRequest)
tok := time.Now() tok := time.Now()
@ -210,7 +210,7 @@ func testAllChannels(notify bool) error {
continue continue
} }
tik := time.Now() tik := time.Now()
testRequest := buildTestRequest(channel.AllowStreaming) testRequest := buildTestRequest(channel.AllowStreaming == common.ChannelAllowStreamEnabled)
err, openaiErr := testChannel(channel, *testRequest) err, openaiErr := testChannel(channel, *testRequest)
tok := time.Now() tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds() milliseconds := tok.Sub(tik).Milliseconds()

View File

@ -1,6 +1,7 @@
package model package model
import ( import (
"fmt"
"one-api/common" "one-api/common"
"strings" "strings"
) )
@ -19,9 +20,9 @@ func GetRandomSatisfiedChannel(group string, model string, stream bool) (*Channe
cmd := "`group` = ? and model = ? and enabled = 1" cmd := "`group` = ? and model = ? and enabled = 1"
if stream { if stream {
cmd += " and allow_streaming = 1" cmd += fmt.Sprintf(" and allow_streaming = %d", common.ChannelAllowStreamEnabled)
} else { } else {
cmd += " and allow_non_streaming = 1" cmd += fmt.Sprintf(" and allow_non_streaming = %d", common.ChannelAllowNonStreamEnabled)
} }
if common.UsingSQLite { if common.UsingSQLite {

View File

@ -173,7 +173,7 @@ func CacheGetRandomSatisfiedChannel(group string, model string, stream bool) (*C
var filteredChannels []*Channel var filteredChannels []*Channel
for _, channel := range channels { for _, channel := range channels {
if (stream && channel.AllowStreaming) || (!stream && channel.AllowNonStreaming) { if (stream && channel.AllowStreaming == common.ChannelAllowStreamEnabled) || (!stream && channel.AllowNonStreaming == common.ChannelAllowNonStreamEnabled) {
filteredChannels = append(filteredChannels, channel) filteredChannels = append(filteredChannels, channel)
} }
} }

View File

@ -1,7 +1,6 @@
package model package model
import ( import (
"encoding/json"
"one-api/common" "one-api/common"
"gorm.io/gorm" "gorm.io/gorm"
@ -25,8 +24,8 @@ type Channel struct {
Group string `json:"group" gorm:"type:varchar(32);default:'default'"` Group string `json:"group" gorm:"type:varchar(32);default:'default'"`
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"`
ModelMapping string `json:"model_mapping" gorm:"type:varchar(1024);default:''"` ModelMapping string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
AllowStreaming bool `json:"allow_streaming"` AllowStreaming int `json:"allow_streaming" gorm:"default:2"`
AllowNonStreaming bool `json:"allow_non_streaming"` AllowNonStreaming int `json:"allow_non_streaming" gorm:"default:2"`
} }
func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) { func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
@ -84,19 +83,7 @@ func BatchInsertChannels(channels []Channel) error {
func (channel *Channel) Insert() error { func (channel *Channel) Insert() error {
var err error var err error
// turn channel into a map err = DB.Create(channel).Error
channelMap := make(map[string]interface{})
// Convert channel struct to a map
channelBytes, err := json.Marshal(channel)
if err != nil {
return err
}
err = json.Unmarshal(channelBytes, &channelMap)
if err != nil {
return err
}
err = DB.Create(channelMap).Error
if err != nil { if err != nil {
return err return err
} }
@ -106,24 +93,11 @@ func (channel *Channel) Insert() error {
func (channel *Channel) Update() error { func (channel *Channel) Update() error {
var err error var err error
// turn channel into a map err = DB.Model(channel).Updates(channel).Error
channelMap := make(map[string]interface{})
// Convert channel struct to a map
channelBytes, err := json.Marshal(channel)
if err != nil { if err != nil {
return err return err
} }
err = json.Unmarshal(channelBytes, &channelMap) DB.Model(channel).First(channel, "id = ?", channel.Id)
if err != nil {
return err
}
err = DB.Model(channel).Updates(channelMap).Error
if err != nil {
return err
}
DB.Model(channel).First(channelMap, "id = ?", channel.Id)
err = channel.UpdateAbilities() err = channel.UpdateAbilities()
return err return err
} }

View File

@ -22,8 +22,8 @@ const EditChannel = () => {
base_url: '', base_url: '',
other: '', other: '',
model_mapping: '', model_mapping: '',
allow_streaming: true, allow_streaming: 1,
allow_non_streaming: true, allow_non_streaming: 1,
models: [], models: [],
groups: ['default'] groups: ['default']
}; };
@ -133,7 +133,7 @@ const EditChannel = () => {
return; return;
} }
// allow streaming and allow non streaming cannot be both false // allow streaming and allow non streaming cannot be both false
if (!inputs.allow_streaming && !inputs.allow_non_streaming) { if (inputs.allow_streaming === 2 && inputs.allow_non_streaming === 2) {
showInfo('流式请求和非流式请求不能同时禁用!'); showInfo('流式请求和非流式请求不能同时禁用!');
return; return;
} }
@ -318,21 +318,21 @@ const EditChannel = () => {
</Form.Field> </Form.Field>
<Form.Field> <Form.Field>
<Form.Checkbox <Form.Checkbox
checked={inputs.allow_streaming} checked={inputs.allow_streaming === 1}
label='允许流式请求' label='允许流式请求'
name='allow_streaming' name='allow_streaming'
onChange={() => { onChange={() => {
setInputs((inputs) => ({ ...inputs, allow_streaming: !inputs.allow_streaming })); setInputs((inputs) => ({ ...inputs, allow_streaming: inputs.allow_streaming === 1 ? 2 : 1 }));
}} }}
/> />
</Form.Field> </Form.Field>
<Form.Field> <Form.Field>
<Form.Checkbox <Form.Checkbox
checked={inputs.allow_non_streaming} checked={inputs.allow_non_streaming === 1}
label='允许非流式请求' label='允许非流式请求'
name='allow_non_streaming' name='allow_non_streaming'
onChange={() => { onChange={() => {
setInputs((inputs) => ({ ...inputs, allow_non_streaming: !inputs.allow_non_streaming })); setInputs((inputs) => ({ ...inputs, allow_non_streaming: inputs.allow_non_streaming === 1 ? 2 : 1 }));
}} }}
/> />
</Form.Field> </Form.Field>