From 80271b33ba0c2f8fd7f4c3ad6be8e943d1de9197 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Mon, 25 Sep 2023 18:44:10 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=B8=A0=E9=81=93=E6=96=B0=E5=8F=AF?= =?UTF-8?q?=E9=80=89=E6=98=AF=E5=90=A6=E8=87=AA=E5=8A=A8=E7=A6=81=E7=94=A8?= =?UTF-8?q?=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- controller/channel-test.go | 7 +++++- controller/relay-utils.go | 2 +- controller/relay.go | 3 ++- middleware/distributor.go | 1 + model/channel.go | 1 + web/src/pages/Channel/EditChannel.js | 35 +++++++++++++++++++++++++++- 6 files changed, 45 insertions(+), 4 deletions(-) diff --git a/controller/channel-test.go b/controller/channel-test.go index f7a565a2..45cf604b 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -183,7 +183,12 @@ func testAllChannels(notify bool) error { err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) disableChannel(channel.Id, channel.Name, err.Error()) } - if shouldDisableChannel(openaiErr, -1) { + ban := true + // parse *int to bool + if channel.AutoBan != nil && *channel.AutoBan == 0 { + ban = false + } + if shouldDisableChannel(openaiErr, -1) && ban { disableChannel(channel.Id, channel.Name, err.Error()) } channel.UpdateResponseTime(milliseconds) diff --git a/controller/relay-utils.go b/controller/relay-utils.go index 04d365aa..0b132522 100644 --- a/controller/relay-utils.go +++ b/controller/relay-utils.go @@ -128,7 +128,7 @@ func shouldDisableChannel(err *OpenAIError, statusCode int) bool { if err == nil { return false } - if statusCode == http.StatusUnauthorized { + if statusCode == http.StatusUnauthorized || statusCode == http.StatusTooManyRequests { return true } if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" { diff --git a/controller/relay.go b/controller/relay.go index 79c99191..cd8d80b6 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -236,9 +236,10 @@ func Relay(c *gin.Context) { }) } channelId := c.GetInt("channel_id") + autoBan := c.GetBool("auto_ban") common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message)) // https://platform.openai.com/docs/guides/error-codes/api-errors - if shouldDisableChannel(&err.OpenAIError, err.StatusCode) { + if shouldDisableChannel(&err.OpenAIError, err.StatusCode) && autoBan { channelId := c.GetInt("channel_id") channelName := c.GetString("channel_name") disableChannel(channelId, channelName, err.Message) diff --git a/middleware/distributor.go b/middleware/distributor.go index 71235b34..8c0ecddd 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -87,6 +87,7 @@ func Distribute() func(c *gin.Context) { c.Set("channel", channel.Type) c.Set("channel_id", channel.Id) c.Set("channel_name", channel.Name) + c.Set("auto_ban", channel.AutoBan) c.Set("model_mapping", channel.GetModelMapping()) c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) c.Set("base_url", channel.GetBaseURL()) diff --git a/model/channel.go b/model/channel.go index 628f435e..a41a04e8 100644 --- a/model/channel.go +++ b/model/channel.go @@ -25,6 +25,7 @@ type Channel struct { UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"` Priority *int64 `json:"priority" gorm:"bigint;default:0"` + AutoBan *int `json:"auto_ban" gorm:"default:1"` } func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) { diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index 3ebb6b8f..3a8022a8 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -43,9 +43,12 @@ const EditChannel = () => { other: '', model_mapping: '', models: [], + auto_ban: 1, groups: ['default'] }; const [batch, setBatch] = useState(false); + const [autoBan, setAutoBan] = useState(true); + // const [autoBan, setAutoBan] = useState(true); const [inputs, setInputs] = useState(originInputs); const [originModelOptions, setOriginModelOptions] = useState([]); const [modelOptions, setModelOptions] = useState([]); @@ -82,6 +85,7 @@ const EditChannel = () => { } setInputs((inputs) => ({ ...inputs, models: localModels })); } + //setAutoBan }; const loadChannel = async () => { @@ -102,6 +106,12 @@ const EditChannel = () => { data.model_mapping = JSON.stringify(JSON.parse(data.model_mapping), null, 2); } setInputs(data); + if (data.auto_ban === 0) { + setAutoBan(false); + } else { + setAutoBan(true); + } + // console.log(data); } else { showError(message); } @@ -161,6 +171,11 @@ const EditChannel = () => { fetchGroups().then(); }, []); + useEffect(() => { + setInputs((inputs) => ({ ...inputs, auto_ban: autoBan ? 1 : 0 })); + console.log(autoBan); + }, [autoBan]); + const submit = async () => { if (!isEdit && (inputs.name === '' || inputs.key === '')) { showInfo('请填写渠道名称和渠道密钥!'); @@ -185,6 +200,11 @@ const EditChannel = () => { localInputs.other = 'v2.1'; } let res; + if (!Array.isArray(localInputs.models)) { + showError('提交失败,请勿重复提交!'); + handleCancel(); + return; + } localInputs.models = localInputs.models.join(','); localInputs.group = localInputs.groups.join(','); if (isEdit) { @@ -423,7 +443,20 @@ const EditChannel = () => { placeholder='请输入组织org-xxx' onChange={handleInputChange} value={inputs.openai_organization} - autoComplete='new-password' + /> + + + { + setAutoBan(!autoBan); + + } + } + // onChange={handleInputChange} /> {