From 43be1982d7f4077404de1a60deef286d21b3abab Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Sat, 9 Sep 2023 01:50:41 +0800 Subject: [PATCH] merge --- controller/channel-test.go | 2 +- controller/midjourney.go | 10 +++++----- controller/relay-text.go | 14 +------------- controller/relay-utils.go | 36 +++++++++++++++++++++++++++++++++++- controller/relay.go | 2 +- 5 files changed, 43 insertions(+), 21 deletions(-) diff --git a/controller/channel-test.go b/controller/channel-test.go index 8465d51d..4acb2e3b 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -174,7 +174,7 @@ func testAllChannels(notify bool) error { err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) disableChannel(channel.Id, channel.Name, err.Error()) } - if shouldDisableChannel(openaiErr) { + if shouldDisableChannel(openaiErr, -1) { disableChannel(channel.Id, channel.Name, err.Error()) } channel.UpdateResponseTime(milliseconds) diff --git a/controller/midjourney.go b/controller/midjourney.go index ff7f5115..e9674efc 100644 --- a/controller/midjourney.go +++ b/controller/midjourney.go @@ -15,13 +15,13 @@ import ( func UpdateMidjourneyTask() { //revocer - defer func() { - if err := recover(); err != nil { - log.Printf("UpdateMidjourneyTask: %v", err) - } - }() imageModel := "midjourney" for { + defer func() { + if err := recover(); err != nil { + log.Printf("UpdateMidjourneyTask: %v", err) + } + }() time.Sleep(time.Duration(15) * time.Second) tasks := model.GetAllUnFinishTasks() if len(tasks) != 0 { diff --git a/controller/relay-text.go b/controller/relay-text.go index 841bb4c3..20b1696a 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -7,7 +7,6 @@ import ( "fmt" "github.com/gin-gonic/gin" "io" - "log" "net/http" "one-api/common" "one-api/model" @@ -331,18 +330,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { isStream = isStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") if resp.StatusCode != http.StatusOK { - //print resp body - body, err := io.ReadAll(resp.Body) - if err != nil { - log.Println("read resp err body failed", err) - } - log.Println("resp body:", string(body)) - errStr := fmt.Sprintf("bad status code: %d", resp.StatusCode) - if resp.StatusCode == 503 { - errStr = string(body) - } - return errorWrapper( - fmt.Errorf(errStr), "bad_status_code", resp.StatusCode) + return relayErrorHandler(resp) } } diff --git a/controller/relay-utils.go b/controller/relay-utils.go index 5b3e0274..1a9ee0d1 100644 --- a/controller/relay-utils.go +++ b/controller/relay-utils.go @@ -1,10 +1,14 @@ package controller import ( + "encoding/json" "fmt" "github.com/gin-gonic/gin" "github.com/pkoukk/tiktoken-go" + "io" + "net/http" "one-api/common" + "strconv" ) var stopFinishReason = "stop" @@ -95,13 +99,16 @@ func errorWrapper(err error, code string, statusCode int) *OpenAIErrorWithStatus } } -func shouldDisableChannel(err *OpenAIError) bool { +func shouldDisableChannel(err *OpenAIError, statusCode int) bool { if !common.AutomaticDisableChannelEnabled { return false } if err == nil { return false } + if statusCode == http.StatusUnauthorized { + return true + } if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" { return true } @@ -115,3 +122,30 @@ func setEventStreamHeaders(c *gin.Context) { c.Writer.Header().Set("Transfer-Encoding", "chunked") c.Writer.Header().Set("X-Accel-Buffering", "no") } + +func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIErrorWithStatusCode) { + openAIErrorWithStatusCode = &OpenAIErrorWithStatusCode{ + StatusCode: resp.StatusCode, + OpenAIError: OpenAIError{ + Message: fmt.Sprintf("bad response status code %d", resp.StatusCode), + Type: "one_api_error", + Code: "bad_response_status_code", + Param: strconv.Itoa(resp.StatusCode), + }, + } + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return + } + err = resp.Body.Close() + if err != nil { + return + } + var textResponse TextResponse + err = json.Unmarshal(responseBody, &textResponse) + if err != nil { + return + } + openAIErrorWithStatusCode.OpenAIError = textResponse.Error + return +} diff --git a/controller/relay.go b/controller/relay.go index d5498f9c..ea676d79 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -207,7 +207,7 @@ func Relay(c *gin.Context) { channelId := c.GetInt("channel_id") common.SysError(fmt.Sprintf("relay error (channel #%d): %v ", channelId, err)) // https://platform.openai.com/docs/guides/error-codes/api-errors - if shouldDisableChannel(&err.OpenAIError) { + if shouldDisableChannel(&err.OpenAIError, err.StatusCode) { channelId := c.GetInt("channel_id") channelName := c.GetString("channel_name") disableChannel(channelId, channelName, err.Message)