diff --git a/common/constants.go b/common/constants.go index c7d3f222..05d06443 100644 --- a/common/constants.go +++ b/common/constants.go @@ -83,6 +83,7 @@ var PreConsumedQuota = 500 var ApproximateTokenEnabled = false var RetryTimes = 0 +var RetryInterval = 0 // unit is millisecond var RootUserEmail = "" var IsMasterNode = os.Getenv("NODE_TYPE") != "slave" diff --git a/controller/relay.go b/controller/relay.go index 1926110e..12921f1b 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -1,10 +1,11 @@ package controller import ( + "encoding/json" + "errors" "fmt" "net/http" "one-api/common" - "strconv" "strings" "github.com/gin-gonic/gin" @@ -197,30 +198,28 @@ func Relay(c *gin.Context) { } if err != nil { requestId := c.GetString(common.RequestIdKey) - retryTimesStr := c.Query("retry") - retryTimes, _ := strconv.Atoi(retryTimesStr) - if retryTimesStr == "" { - retryTimes = common.RetryTimes - } - if retryTimes > 0 { - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1)) - } else { - if err.StatusCode == http.StatusTooManyRequests { - err.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试" - } - err.OpenAIError.Message = common.MessageWithRequestId(err.OpenAIError.Message, requestId) - c.JSON(err.StatusCode, gin.H{ - "error": err.OpenAIError, - }) - } - channelId := c.GetInt("channel_id") - common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message)) - // https://platform.openai.com/docs/guides/error-codes/api-errors - if shouldDisableChannel(&err.OpenAIError, err.StatusCode) { + go func() { + defer func() { + if r := recover(); r != nil { + //ignore + } + }() channelId := c.GetInt("channel_id") - channelName := c.GetString("channel_name") - disableChannel(channelId, channelName, err.Message) + common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message)) + // https://platform.openai.com/docs/guides/error-codes/api-errors + if shouldDisableChannel(&err.OpenAIError, err.StatusCode) { + channelId := c.GetInt("channel_id") + channelName := c.GetString("channel_name") + disableChannel(channelId, channelName, err.Message) + } + }() + if err.StatusCode == http.StatusTooManyRequests { + err.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试" } + err.OpenAIError.Message = common.MessageWithRequestId(err.OpenAIError.Message, requestId) + openaiErr, _ := json.Marshal(err) + _ = c.Error(errors.New(string(openaiErr))) + return } } diff --git a/middleware/retry.go b/middleware/retry.go new file mode 100644 index 00000000..26341341 --- /dev/null +++ b/middleware/retry.go @@ -0,0 +1,98 @@ +package middleware + +import ( + "bytes" + "encoding/json" + "fmt" + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/common" + "strconv" + "time" +) + +type OpenAIErrorWithStatusCode struct { + OpenAIError + StatusCode int `json:"status_code"` +} + +type OpenAIError struct { + Message string `json:"message"` + Type string `json:"type"` + Param string `json:"param"` + Code any `json:"code"` +} + +func RetryHandler(group *gin.RouterGroup) gin.HandlerFunc { + var retryHandler gin.HandlerFunc + // 获取RetryHandler在当前HandlersChain的位置 + index := len(group.Handlers) + 1 + retryHandler = func(c *gin.Context) { + // Backup request + hasBody := c.Request.ContentLength > 0 + backupHeader := c.Request.Header.Clone() + var backupBody []byte + var err error + if hasBody { + backupBody, err = io.ReadAll(c.Request.Body) + if err != nil { + abortWithMessage(c, http.StatusBadRequest, "Invalid request") + return + } + _ = c.Request.Body.Close() + c.Request.Body = io.NopCloser(bytes.NewBuffer(backupBody)) + } + + // 获取 retryHandler 后续的中间件 + // Get next handlers + nextHandlers := group.Handlers[index:] + + // 加入Relay处理函数 c.Handler() => c.handlers.Last() => controller.Relay + // Add Relay handler + nextHandlers = append(nextHandlers, c.Handler()) + + // Retry + maxRetryStr := c.Query("retry") + maxRetry, err := strconv.Atoi(maxRetryStr) + if err != nil || maxRetryStr == "" || maxRetry < 0 || maxRetry > common.RetryTimes { + maxRetry = common.RetryTimes + } + retryDelay := time.Duration(common.RetryInterval) * time.Millisecond + var openaiErr *OpenAIErrorWithStatusCode + for i := 0; i < maxRetry; i++ { + if i == 0 { + // 第一次请求, 直接执行使用c.Next()调用后续中间件, 防止直接使用handler 内部调用c.Next() 导致重复执行 + // First request, execute next middleware + c.Next() + fmt.Println("c.Next()") + } else { + // Clear errors to avoid confusion in next middleware + c.Errors = c.Errors[:0] + // 重试, 恢复请求头和请求体, 并执行后续中间件 + // Retry, restore request and execute next middleware + c.Request.Header = backupHeader.Clone() + if hasBody { + c.Request.Body = io.NopCloser(bytes.NewBuffer(backupBody)) + } + for _, handler := range nextHandlers { + handler(c) + } + } + + // If no errors, return + if len(c.Errors) == 0 { + return + } + // c.index 指向 AbortIndex 可以防止出错时重复执行后续中间件 + c.Abort() + // If errors, retry after delay + time.Sleep(retryDelay) + } + _ = json.Unmarshal([]byte(c.Errors.Last().Error()), &openaiErr) + c.JSON(openaiErr.StatusCode, gin.H{ + "error": openaiErr.OpenAIError, + }) + } + return retryHandler +} diff --git a/model/option.go b/model/option.go index 4ef4d260..236292f0 100644 --- a/model/option.go +++ b/model/option.go @@ -205,6 +205,8 @@ func updateOptionMap(key string, value string) (err error) { common.PreConsumedQuota, _ = strconv.Atoi(value) case "RetryTimes": common.RetryTimes, _ = strconv.Atoi(value) + case "RetryInterval": + common.RetryInterval, _ = strconv.Atoi(value) case "ModelRatio": err = common.UpdateModelRatioByJSONString(value) case "GroupRatio": diff --git a/router/relay-router.go b/router/relay-router.go index e84f02db..f53fc8ea 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -17,6 +17,7 @@ func SetRelayRouter(router *gin.Engine) { modelsRouter.GET("/:model", controller.RetrieveModel) } relayV1Router := router.Group("/v1") + relayV1Router.Use(middleware.RetryHandler(relayV1Router)) relayV1Router.Use(middleware.TokenAuth(), middleware.Distribute()) { relayV1Router.POST("/completions", controller.Relay) diff --git a/web/src/components/OperationSetting.js b/web/src/components/OperationSetting.js index bf8b5ffd..0348bfee 100644 --- a/web/src/components/OperationSetting.js +++ b/web/src/components/OperationSetting.js @@ -21,7 +21,8 @@ const OperationSetting = () => { DisplayInCurrencyEnabled: '', DisplayTokenStatEnabled: '', ApproximateTokenEnabled: '', - RetryTimes: 0 + RetryTimes: 0, + RetryInterval: 0, }); const [originInputs, setOriginInputs] = useState({}); let [loading, setLoading] = useState(false); @@ -128,6 +129,9 @@ const OperationSetting = () => { if (originInputs['RetryTimes'] !== inputs.RetryTimes) { await updateOption('RetryTimes', inputs.RetryTimes); } + if (originInputs['RetryInterval'] !== inputs.RetryInterval) { + await updateOption('RetryInterval', inputs.RetryInterval); + } break; } }; @@ -190,6 +194,19 @@ const OperationSetting = () => { value={inputs.RetryTimes} placeholder='失败重试次数' /> + +