From ceb289cb4d7cd0fda4a2a44e9002ffb6051da939 Mon Sep 17 00:00:00 2001 From: JustSong Date: Thu, 18 May 2023 11:11:15 +0800 Subject: [PATCH] fix: handel error response from server correctly (close #90) --- controller/channel.go | 8 ++-- controller/relay.go | 90 +++++++++++++++++++++++++++---------------- 2 files changed, 61 insertions(+), 37 deletions(-) diff --git a/controller/channel.go b/controller/channel.go index af8b0b28..1b863a0d 100644 --- a/controller/channel.go +++ b/controller/channel.go @@ -265,14 +265,14 @@ var testAllChannelsLock sync.Mutex var testAllChannelsRunning bool = false // disable & notify -func disableChannel(channelId int, channelName string, err error) { +func disableChannel(channelId int, channelName string, reason string) { if common.RootUserEmail == "" { common.RootUserEmail = model.GetRootUserEmail() } model.UpdateChannelStatusById(channelId, common.ChannelStatusDisabled) subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId) - content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, err.Error()) - err = common.SendEmail(subject, common.RootUserEmail, content) + content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason) + err := common.SendEmail(subject, common.RootUserEmail, content) if err != nil { common.SysError(fmt.Sprintf("发送邮件失败:%s", err.Error())) } @@ -312,7 +312,7 @@ func testAllChannels(c *gin.Context) error { if milliseconds > disableThreshold { err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) } - disableChannel(channel.Id, channel.Name, err) + disableChannel(channel.Id, channel.Name, err.Error()) } channel.UpdateResponseTime(milliseconds) } diff --git a/controller/relay.go b/controller/relay.go index a24b7c5b..4803794b 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -4,7 +4,6 @@ import ( "bufio" "bytes" "encoding/json" - "errors" "fmt" "github.com/gin-gonic/gin" "github.com/pkoukk/tiktoken-go" @@ -47,6 +46,11 @@ type OpenAIError struct { Code string `json:"code"` } +type OpenAIErrorWithStatusCode struct { + OpenAIError + StatusCode int `json:"status_code"` +} + type TextResponse struct { Usage `json:"usage"` Error OpenAIError `json:"error"` @@ -71,23 +75,33 @@ func countToken(text string) int { func Relay(c *gin.Context) { err := relayHelper(c) if err != nil { - c.JSON(http.StatusOK, gin.H{ - "error": gin.H{ - "message": err.Error(), - "type": "one_api_error", - }, + c.JSON(err.StatusCode, gin.H{ + "error": err.OpenAIError, }) channelId := c.GetInt("channel_id") - common.SysError(fmt.Sprintf("Relay error: %s, channel id: %d", err.Error(), channelId)) - if common.AutomaticDisableChannelEnabled { + common.SysError(fmt.Sprintf("Relay error (channel #%d): %s", channelId, err.Message)) + if err.Type != "invalid_request_error" && err.StatusCode != http.StatusTooManyRequests && + common.AutomaticDisableChannelEnabled { channelId := c.GetInt("channel_id") channelName := c.GetString("channel_name") - disableChannel(channelId, channelName, err) + disableChannel(channelId, channelName, err.Message) } } } -func relayHelper(c *gin.Context) error { +func errorWrapper(err error, code string, statusCode int) *OpenAIErrorWithStatusCode { + openAIError := OpenAIError{ + Message: err.Error(), + Type: "one_api_error", + Code: code, + } + return &OpenAIErrorWithStatusCode{ + OpenAIError: openAIError, + StatusCode: statusCode, + } +} + +func relayHelper(c *gin.Context) *OpenAIErrorWithStatusCode { channelType := c.GetInt("channel") tokenId := c.GetInt("token_id") consumeQuota := c.GetBool("consume_quota") @@ -95,15 +109,15 @@ func relayHelper(c *gin.Context) error { if consumeQuota || channelType == common.ChannelTypeAzure { requestBody, err := io.ReadAll(c.Request.Body) if err != nil { - return err + return errorWrapper(err, "read_request_body_failed", http.StatusBadRequest) } err = c.Request.Body.Close() if err != nil { - return err + return errorWrapper(err, "close_request_body_failed", http.StatusBadRequest) } err = json.Unmarshal(requestBody, &textRequest) if err != nil { - return err + return errorWrapper(err, "unmarshal_request_body_failed", http.StatusBadRequest) } // Reset request body c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) @@ -146,12 +160,12 @@ func relayHelper(c *gin.Context) error { if consumeQuota { err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) if err != nil { - return err + return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusOK) } } req, err := http.NewRequest(c.Request.Method, fullRequestURL, c.Request.Body) if err != nil { - return err + return errorWrapper(err, "new_request_failed", http.StatusOK) } if channelType == common.ChannelTypeAzure { key := c.Request.Header.Get("Authorization") @@ -166,15 +180,15 @@ func relayHelper(c *gin.Context) error { client := &http.Client{} resp, err := client.Do(req) if err != nil { - return err + return errorWrapper(err, "do_request_failed", http.StatusOK) } err = req.Body.Close() if err != nil { - return err + return errorWrapper(err, "close_request_body_failed", http.StatusOK) } err = c.Request.Body.Close() if err != nil { - return err + return errorWrapper(err, "close_request_body_failed", http.StatusOK) } var textResponse TextResponse isStream := strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") @@ -259,50 +273,60 @@ func relayHelper(c *gin.Context) error { }) err = resp.Body.Close() if err != nil { - return err + return errorWrapper(err, "close_response_body_failed", http.StatusOK) } return nil } else { - for k, v := range resp.Header { - c.Writer.Header().Set(k, v[0]) - } if consumeQuota { responseBody, err := io.ReadAll(resp.Body) if err != nil { - return err + return errorWrapper(err, "read_response_body_failed", http.StatusOK) } err = resp.Body.Close() if err != nil { - return err + return errorWrapper(err, "close_response_body_failed", http.StatusOK) } err = json.Unmarshal(responseBody, &textResponse) if err != nil { - return err + return errorWrapper(err, "unmarshal_response_body_failed", http.StatusOK) } if textResponse.Error.Type != "" { - return errors.New(fmt.Sprintf("type %s, code %s, message %s", - textResponse.Error.Type, textResponse.Error.Code, textResponse.Error.Message)) + return &OpenAIErrorWithStatusCode{ + OpenAIError: textResponse.Error, + StatusCode: resp.StatusCode, + } } // Reset response body resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) } + // We shouldn't set the header before we parse the response body, because the parse part may fail. + // And then we will have to send an error response, but in this case, the header has already been set. + // So the client will be confused by the response. + // For example, Postman will report error, and we cannot check the response at all. + for k, v := range resp.Header { + c.Writer.Header().Set(k, v[0]) + } + c.Writer.WriteHeader(resp.StatusCode) _, err = io.Copy(c.Writer, resp.Body) if err != nil { - return err + return errorWrapper(err, "copy_response_body_failed", http.StatusOK) } err = resp.Body.Close() if err != nil { - return err + return errorWrapper(err, "close_response_body_failed", http.StatusOK) } return nil } } func RelayNotImplemented(c *gin.Context) { + err := OpenAIError{ + Message: "API not implemented", + Type: "one_api_error", + Param: "", + Code: "api_not_implemented", + } c.JSON(http.StatusOK, gin.H{ - "error": gin.H{ - "message": "Not Implemented", - "type": "one_api_error", - }, + "error": err, }) }