package controller import ( "bytes" "context" "fmt" "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/middleware" dbmodel "github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/monitor" "github.com/songquanpeng/one-api/relay/controller" "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/relaymode" "io" "net/http" ) // https://platform.openai.com/docs/api-reference/chat func relayHelper(c *gin.Context, relayMode int) *model.ErrorWithStatusCode { var err *model.ErrorWithStatusCode switch relayMode { case relaymode.ImagesGenerations: err = controller.RelayImageHelper(c, relayMode) case relaymode.AudioSpeech: fallthrough case relaymode.AudioTranslation: fallthrough case relaymode.AudioTranscription: err = controller.RelayAudioHelper(c, relayMode) default: err = controller.RelayTextHelper(c) } return err } func Relay(c *gin.Context) { ctx := c.Request.Context() relayMode := relaymode.GetByPath(c.Request.URL.Path) if config.DebugEnabled { requestBody, _ := common.GetRequestBody(c) logger.Debugf(ctx, "request body: %s", string(requestBody)) } channelId := c.GetInt("channel_id") bizErr := relayHelper(c, relayMode) if bizErr == nil { monitor.Emit(channelId, true) return } lastFailedChannelId := channelId channelName := c.GetString("channel_name") group := c.GetString("group") originalModel := c.GetString(ctxkey.OriginalModel) go processChannelRelayError(ctx, channelId, channelName, bizErr) requestId := c.GetString(logger.RequestIdKey) retryTimes := config.RetryTimes if !shouldRetry(c, bizErr.StatusCode) { logger.Errorf(ctx, "relay error happen, status code is %d, won't retry in this case", bizErr.StatusCode) retryTimes = 0 } for i := retryTimes; i > 0; i-- { channel, err := dbmodel.CacheGetRandomSatisfiedChannel(group, originalModel, i != retryTimes) if err != nil { logger.Errorf(ctx, "CacheGetRandomSatisfiedChannel failed: %+v", err) break } logger.Infof(ctx, "using channel #%d to retry (remain times %d)", channel.Id, i) if channel.Id == lastFailedChannelId { continue } middleware.SetupContextForSelectedChannel(c, channel, originalModel) requestBody, err := common.GetRequestBody(c) c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) bizErr = relayHelper(c, relayMode) if bizErr == nil { return } channelId := c.GetInt("channel_id") lastFailedChannelId = channelId channelName := c.GetString("channel_name") go processChannelRelayError(ctx, channelId, channelName, bizErr) } if bizErr != nil { if bizErr.StatusCode == http.StatusTooManyRequests { bizErr.Error.Message = "当前分组上游负载已饱和,请稍后再试" } bizErr.Error.Message = helper.MessageWithRequestId(bizErr.Error.Message, requestId) c.JSON(bizErr.StatusCode, gin.H{ "error": bizErr.Error, }) } } func shouldRetry(c *gin.Context, statusCode int) bool { if _, ok := c.Get("specific_channel_id"); ok { return false } if statusCode == http.StatusTooManyRequests { return true } if statusCode/100 == 5 { return true } if statusCode == http.StatusBadRequest { return false } if statusCode/100 == 2 { return false } return true } func processChannelRelayError(ctx context.Context, channelId int, channelName string, err *model.ErrorWithStatusCode) { logger.Errorf(ctx, "relay error (channel #%d): %s", channelId, err.Message) // https://platform.openai.com/docs/guides/error-codes/api-errors if monitor.ShouldDisableChannel(&err.Error, err.StatusCode) { monitor.DisableChannel(channelId, channelName, err.Message) } else { monitor.Emit(channelId, false) } } func RelayNotImplemented(c *gin.Context) { err := model.Error{ Message: "API not implemented", Type: "one_api_error", Param: "", Code: "api_not_implemented", } c.JSON(http.StatusNotImplemented, gin.H{ "error": err, }) } func RelayNotFound(c *gin.Context) { err := model.Error{ Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path), Type: "invalid_request_error", Param: "", Code: "", } c.JSON(http.StatusNotFound, gin.H{ "error": err, }) }