diff --git a/controller/channel-test.go b/controller/channel-test.go index f2ffad01..fad5de8f 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -14,7 +14,7 @@ import ( "time" ) -func testChannel(channel *model.Channel, request ChatRequest) error { +func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIError) { switch channel.Type { case common.ChannelTypeAzure: request.Model = "gpt-35-turbo" @@ -33,11 +33,11 @@ func testChannel(channel *model.Channel, request ChatRequest) error { jsonData, err := json.Marshal(request) if err != nil { - return err + return err, nil } req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(jsonData)) if err != nil { - return err + return err, nil } if channel.Type == common.ChannelTypeAzure { req.Header.Set("api-key", channel.Key) @@ -48,18 +48,18 @@ func testChannel(channel *model.Channel, request ChatRequest) error { client := &http.Client{} resp, err := client.Do(req) if err != nil { - return err + return err, nil } defer resp.Body.Close() var response TextResponse err = json.NewDecoder(resp.Body).Decode(&response) if err != nil { - return err + return err, nil } if response.Usage.CompletionTokens == 0 { - return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)) + return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)), &response.Error } - return nil + return nil, nil } func buildTestRequest() *ChatRequest { @@ -94,7 +94,7 @@ func TestChannel(c *gin.Context) { } testRequest := buildTestRequest() tik := time.Now() - err = testChannel(channel, *testRequest) + err, _ = testChannel(channel, *testRequest) tok := time.Now() milliseconds := tok.Sub(tik).Milliseconds() go channel.UpdateResponseTime(milliseconds) @@ -158,13 +158,14 @@ func testAllChannels(notify bool) error { continue } tik := time.Now() - err := testChannel(channel, *testRequest) + err, openaiErr := testChannel(channel, *testRequest) tok := time.Now() milliseconds := tok.Sub(tik).Milliseconds() - if err != nil || milliseconds > disableThreshold { - if milliseconds > disableThreshold { - err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) - } + if milliseconds > disableThreshold { + err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) + disableChannel(channel.Id, channel.Name, err.Error()) + } + if shouldDisableChannel(openaiErr) { disableChannel(channel.Id, channel.Name, err.Error()) } channel.UpdateResponseTime(milliseconds) diff --git a/controller/relay-utils.go b/controller/relay-utils.go index 35c7fa82..2133d8be 100644 --- a/controller/relay-utils.go +++ b/controller/relay-utils.go @@ -91,3 +91,16 @@ func errorWrapper(err error, code string, statusCode int) *OpenAIErrorWithStatus StatusCode: statusCode, } } + +func shouldDisableChannel(err *OpenAIError) bool { + if !common.AutomaticDisableChannelEnabled { + return false + } + if err == nil { + return false + } + if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" { + return true + } + return false +} diff --git a/controller/relay.go b/controller/relay.go index 88ae8acc..9cfa5c4f 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -171,7 +171,7 @@ func Relay(c *gin.Context) { channelId := c.GetInt("channel_id") common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message)) // https://platform.openai.com/docs/guides/error-codes/api-errors - if common.AutomaticDisableChannelEnabled && (err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated") { + if shouldDisableChannel(&err.OpenAIError) { channelId := c.GetInt("channel_id") channelName := c.GetString("channel_name") disableChannel(channelId, channelName, err.Message)