fix: using whitelist when disabling channels (close #292)

This commit is contained in:
JustSong 2023-07-22 18:15:30 +08:00
parent 12a0e7105e
commit 0495b9a0d7
3 changed files with 28 additions and 14 deletions

View File

@ -14,7 +14,7 @@ import (
"time" "time"
) )
func testChannel(channel *model.Channel, request ChatRequest) error { func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIError) {
switch channel.Type { switch channel.Type {
case common.ChannelTypeAzure: case common.ChannelTypeAzure:
request.Model = "gpt-35-turbo" request.Model = "gpt-35-turbo"
@ -33,11 +33,11 @@ func testChannel(channel *model.Channel, request ChatRequest) error {
jsonData, err := json.Marshal(request) jsonData, err := json.Marshal(request)
if err != nil { if err != nil {
return err return err, nil
} }
req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(jsonData)) req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(jsonData))
if err != nil { if err != nil {
return err return err, nil
} }
if channel.Type == common.ChannelTypeAzure { if channel.Type == common.ChannelTypeAzure {
req.Header.Set("api-key", channel.Key) req.Header.Set("api-key", channel.Key)
@ -48,18 +48,18 @@ func testChannel(channel *model.Channel, request ChatRequest) error {
client := &http.Client{} client := &http.Client{}
resp, err := client.Do(req) resp, err := client.Do(req)
if err != nil { if err != nil {
return err return err, nil
} }
defer resp.Body.Close() defer resp.Body.Close()
var response TextResponse var response TextResponse
err = json.NewDecoder(resp.Body).Decode(&response) err = json.NewDecoder(resp.Body).Decode(&response)
if err != nil { if err != nil {
return err return err, nil
} }
if response.Usage.CompletionTokens == 0 { if response.Usage.CompletionTokens == 0 {
return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)) return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)), &response.Error
} }
return nil return nil, nil
} }
func buildTestRequest() *ChatRequest { func buildTestRequest() *ChatRequest {
@ -94,7 +94,7 @@ func TestChannel(c *gin.Context) {
} }
testRequest := buildTestRequest() testRequest := buildTestRequest()
tik := time.Now() tik := time.Now()
err = testChannel(channel, *testRequest) err, _ = testChannel(channel, *testRequest)
tok := time.Now() tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds() milliseconds := tok.Sub(tik).Milliseconds()
go channel.UpdateResponseTime(milliseconds) go channel.UpdateResponseTime(milliseconds)
@ -158,13 +158,14 @@ func testAllChannels(notify bool) error {
continue continue
} }
tik := time.Now() tik := time.Now()
err := testChannel(channel, *testRequest) err, openaiErr := testChannel(channel, *testRequest)
tok := time.Now() tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds() milliseconds := tok.Sub(tik).Milliseconds()
if err != nil || milliseconds > disableThreshold { if milliseconds > disableThreshold {
if milliseconds > disableThreshold { err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) disableChannel(channel.Id, channel.Name, err.Error())
} }
if shouldDisableChannel(openaiErr) {
disableChannel(channel.Id, channel.Name, err.Error()) disableChannel(channel.Id, channel.Name, err.Error())
} }
channel.UpdateResponseTime(milliseconds) channel.UpdateResponseTime(milliseconds)

View File

@ -91,3 +91,16 @@ func errorWrapper(err error, code string, statusCode int) *OpenAIErrorWithStatus
StatusCode: statusCode, StatusCode: statusCode,
} }
} }
func shouldDisableChannel(err *OpenAIError) bool {
if !common.AutomaticDisableChannelEnabled {
return false
}
if err == nil {
return false
}
if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" {
return true
}
return false
}

View File

@ -171,7 +171,7 @@ func Relay(c *gin.Context) {
channelId := c.GetInt("channel_id") channelId := c.GetInt("channel_id")
common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message)) common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
// https://platform.openai.com/docs/guides/error-codes/api-errors // https://platform.openai.com/docs/guides/error-codes/api-errors
if common.AutomaticDisableChannelEnabled && (err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated") { if shouldDisableChannel(&err.OpenAIError) {
channelId := c.GetInt("channel_id") channelId := c.GetInt("channel_id")
channelName := c.GetString("channel_name") channelName := c.GetString("channel_name")
disableChannel(channelId, channelName, err.Message) disableChannel(channelId, channelName, err.Message)