package controller import ( "context" "fmt" "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/middleware" dbmodel "github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/controller" "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/util" "net/http" ) // https://platform.openai.com/docs/api-reference/chat func relay(c *gin.Context, relayMode int) *model.ErrorWithStatusCode { var err *model.ErrorWithStatusCode switch relayMode { case constant.RelayModeImagesGenerations: err = controller.RelayImageHelper(c, relayMode) case constant.RelayModeAudioSpeech: fallthrough case constant.RelayModeAudioTranslation: fallthrough case constant.RelayModeAudioTranscription: err = controller.RelayAudioHelper(c, relayMode) default: err = controller.RelayTextHelper(c) } return err } func Relay(c *gin.Context) { ctx := c.Request.Context() relayMode := constant.Path2RelayMode(c.Request.URL.Path) bizErr := relay(c, relayMode) if bizErr == nil { return } channelId := c.GetInt("channel_id") lastFailedChannelId := channelId channelName := c.GetString("channel_name") group := c.GetString("group") originalModel := c.GetString("original_model") go processChannelRelayError(ctx, channelId, channelName, bizErr) requestId := c.GetString(logger.RequestIdKey) retryTimes := config.RetryTimes if !shouldRetry(bizErr.StatusCode) { logger.Errorf(ctx, "relay error happen, but status code is %d, won't retry in this case", bizErr.StatusCode) retryTimes = 0 } for i := retryTimes; i > 0; i-- { channel, err := dbmodel.CacheGetRandomSatisfiedChannel(group, originalModel) if err != nil { logger.Errorf(ctx, "CacheGetRandomSatisfiedChannel failed: %w", err) break } logger.Infof(ctx, "using channel #%d to retry (remain times %d)", channel.Id, i) if channel.Id == lastFailedChannelId { continue } middleware.SetupContextForSelectedChannel(c, channel, originalModel) bizErr = relay(c, relayMode) if bizErr == nil { return } channelId := c.GetInt("channel_id") lastFailedChannelId = channelId channelName := c.GetString("channel_name") go processChannelRelayError(ctx, channelId, channelName, bizErr) } if bizErr != nil { if bizErr.StatusCode == http.StatusTooManyRequests { bizErr.Error.Message = "当前分组上游负载已饱和,请稍后再试" } bizErr.Error.Message = helper.MessageWithRequestId(bizErr.Error.Message, requestId) c.JSON(bizErr.StatusCode, gin.H{ "error": bizErr.Error, }) } } func shouldRetry(statusCode int) bool { if statusCode == http.StatusTooManyRequests { return true } if statusCode/100 == 5 { return true } if statusCode == http.StatusBadRequest { return false } if statusCode/100 == 2 { return false } return true } func processChannelRelayError(ctx context.Context, channelId int, channelName string, err *model.ErrorWithStatusCode) { logger.Errorf(ctx, "relay error (channel #%d): %s", channelId, err.Message) // https://platform.openai.com/docs/guides/error-codes/api-errors if util.ShouldDisableChannel(&err.Error, err.StatusCode) { disableChannel(channelId, channelName, err.Message) } } func RelayNotImplemented(c *gin.Context) { err := model.Error{ Message: "API not implemented", Type: "one_api_error", Param: "", Code: "api_not_implemented", } c.JSON(http.StatusNotImplemented, gin.H{ "error": err, }) } func RelayNotFound(c *gin.Context) { err := model.Error{ Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path), Type: "invalid_request_error", Param: "", Code: "", } c.JSON(http.StatusNotFound, gin.H{ "error": err, }) }