diff --git a/controller/relay.go b/controller/relay.go index 2d81bc7e..e31dcd9b 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -45,6 +45,18 @@ type StreamResponse struct { } func Relay(c *gin.Context) { + err := relayHelper(c) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "error": gin.H{ + "message": err.Error(), + "type": "one_api_error", + }, + }) + } +} + +func relayHelper(c *gin.Context) error { channelType := c.GetInt("channel") tokenId := c.GetInt("token_id") consumeQuota := c.GetBool("consume_quota") @@ -54,47 +66,27 @@ func Relay(c *gin.Context) { } requestBody, err := io.ReadAll(c.Request.Body) if err != nil { - c.JSON(http.StatusOK, gin.H{ - "error": gin.H{ - "message": err.Error(), - "type": "one_api_error", - }, - }) - return + return err } err = c.Request.Body.Close() if err != nil { - c.JSON(http.StatusOK, gin.H{ - "error": gin.H{ - "message": err.Error(), - "type": "one_api_error", - }, - }) - return + return err } var textRequest TextRequest err = json.Unmarshal(requestBody, &textRequest) if err != nil { - c.JSON(http.StatusOK, gin.H{ - "error": gin.H{ - "message": err.Error(), - "type": "one_api_error", - }, - }) - return + return err } // Reset request body c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) requestURL := c.Request.URL.String() req, err := http.NewRequest(c.Request.Method, fmt.Sprintf("%s%s", baseURL, requestURL), c.Request.Body) if err != nil { - c.JSON(http.StatusOK, gin.H{ - "error": gin.H{ - "message": err.Error(), - "type": "one_api_error", - }, - }) - return + return err + } + err = c.Request.Body.Close() + if err != nil { + return err } req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) @@ -103,23 +95,11 @@ func Relay(c *gin.Context) { client := &http.Client{} resp, err := client.Do(req) if err != nil { - c.JSON(http.StatusOK, gin.H{ - "error": gin.H{ - "message": err.Error(), - "type": "one_api_error", - }, - }) - return + return err } err = req.Body.Close() if err != nil { - c.JSON(http.StatusOK, gin.H{ - "error": gin.H{ - "message": err.Error(), - "type": "one_api_error", - }, - }) - return + return err } var textResponse TextResponse @@ -192,53 +172,38 @@ func Relay(c *gin.Context) { return false } }) - return + err = resp.Body.Close() + if err != nil { + return err + } + return nil } else { for k, v := range resp.Header { c.Writer.Header().Set(k, v[0]) } responseBody, err := io.ReadAll(resp.Body) if err != nil { - c.JSON(http.StatusOK, gin.H{ - "error": gin.H{ - "message": err.Error(), - "type": "one_api_error", - }, - }) - return + return err } err = resp.Body.Close() if err != nil { - c.JSON(http.StatusOK, gin.H{ - "error": gin.H{ - "message": err.Error(), - "type": "one_api_error", - }, - }) - return + return err } err = json.Unmarshal(responseBody, &textResponse) if err != nil { - c.JSON(http.StatusOK, gin.H{ - "error": gin.H{ - "message": err.Error(), - "type": "one_api_error", - }, - }) - return + return err } // Reset response body resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) _, err = io.Copy(c.Writer, resp.Body) if err != nil { - c.JSON(http.StatusOK, gin.H{ - "error": gin.H{ - "message": err.Error(), - "type": "one_api_error", - }, - }) - return + return err } + err = resp.Body.Close() + if err != nil { + return err + } + return nil } }