diff --git a/controller/channel-test.go b/controller/channel-test.go index f7a565a2..45cf604b 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -183,7 +183,12 @@ func testAllChannels(notify bool) error { err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) disableChannel(channel.Id, channel.Name, err.Error()) } - if shouldDisableChannel(openaiErr, -1) { + ban := true + // parse *int to bool + if channel.AutoBan != nil && *channel.AutoBan == 0 { + ban = false + } + if shouldDisableChannel(openaiErr, -1) && ban { disableChannel(channel.Id, channel.Name, err.Error()) } channel.UpdateResponseTime(milliseconds) diff --git a/controller/relay-utils.go b/controller/relay-utils.go index 04d365aa..0b132522 100644 --- a/controller/relay-utils.go +++ b/controller/relay-utils.go @@ -128,7 +128,7 @@ func shouldDisableChannel(err *OpenAIError, statusCode int) bool { if err == nil { return false } - if statusCode == http.StatusUnauthorized { + if statusCode == http.StatusUnauthorized || statusCode == http.StatusTooManyRequests { return true } if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" { diff --git a/controller/relay.go b/controller/relay.go index 79c99191..cd8d80b6 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -236,9 +236,10 @@ func Relay(c *gin.Context) { }) } channelId := c.GetInt("channel_id") + autoBan := c.GetBool("auto_ban") common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message)) // https://platform.openai.com/docs/guides/error-codes/api-errors - if shouldDisableChannel(&err.OpenAIError, err.StatusCode) { + if shouldDisableChannel(&err.OpenAIError, err.StatusCode) && autoBan { channelId := c.GetInt("channel_id") channelName := c.GetString("channel_name") disableChannel(channelId, channelName, err.Message) diff --git a/middleware/distributor.go b/middleware/distributor.go index 71235b34..8c0ecddd 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -87,6 +87,7 @@ func Distribute() func(c *gin.Context) { c.Set("channel", channel.Type) c.Set("channel_id", channel.Id) c.Set("channel_name", channel.Name) + c.Set("auto_ban", channel.AutoBan) c.Set("model_mapping", channel.GetModelMapping()) c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) c.Set("base_url", channel.GetBaseURL()) diff --git a/model/channel.go b/model/channel.go index 628f435e..a41a04e8 100644 --- a/model/channel.go +++ b/model/channel.go @@ -25,6 +25,7 @@ type Channel struct { UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"` Priority *int64 `json:"priority" gorm:"bigint;default:0"` + AutoBan *int `json:"auto_ban" gorm:"default:1"` } func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) { diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index 3ebb6b8f..3a8022a8 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -43,9 +43,12 @@ const EditChannel = () => { other: '', model_mapping: '', models: [], + auto_ban: 1, groups: ['default'] }; const [batch, setBatch] = useState(false); + const [autoBan, setAutoBan] = useState(true); + // const [autoBan, setAutoBan] = useState(true); const [inputs, setInputs] = useState(originInputs); const [originModelOptions, setOriginModelOptions] = useState([]); const [modelOptions, setModelOptions] = useState([]); @@ -82,6 +85,7 @@ const EditChannel = () => { } setInputs((inputs) => ({ ...inputs, models: localModels })); } + //setAutoBan }; const loadChannel = async () => { @@ -102,6 +106,12 @@ const EditChannel = () => { data.model_mapping = JSON.stringify(JSON.parse(data.model_mapping), null, 2); } setInputs(data); + if (data.auto_ban === 0) { + setAutoBan(false); + } else { + setAutoBan(true); + } + // console.log(data); } else { showError(message); } @@ -161,6 +171,11 @@ const EditChannel = () => { fetchGroups().then(); }, []); + useEffect(() => { + setInputs((inputs) => ({ ...inputs, auto_ban: autoBan ? 1 : 0 })); + console.log(autoBan); + }, [autoBan]); + const submit = async () => { if (!isEdit && (inputs.name === '' || inputs.key === '')) { showInfo('请填写渠道名称和渠道密钥!'); @@ -185,6 +200,11 @@ const EditChannel = () => { localInputs.other = 'v2.1'; } let res; + if (!Array.isArray(localInputs.models)) { + showError('提交失败,请勿重复提交!'); + handleCancel(); + return; + } localInputs.models = localInputs.models.join(','); localInputs.group = localInputs.groups.join(','); if (isEdit) { @@ -423,7 +443,20 @@ const EditChannel = () => { placeholder='请输入组织org-xxx' onChange={handleInputChange} value={inputs.openai_organization} - autoComplete='new-password' + /> + + + { + setAutoBan(!autoBan); + + } + } + // onChange={handleInputChange} /> {