diff --git a/common/helper/helper.go b/common/helper/helper.go
index a0d88ec2..babe422b 100644
--- a/common/helper/helper.go
+++ b/common/helper/helper.go
@@ -137,6 +137,7 @@ func GetUUID() string {
 }
 
 const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
+const keyNumbers = "0123456789"
 
 func init() {
 	rand.Seed(time.Now().UnixNano())
@@ -168,6 +169,17 @@ func GetRandomString(length int) string {
 	return string(key)
 }
 
+// GetRandomNumberString returns a string of length random decimal digits.
+// It uses the global math/rand source that init seeds once; reseeding on every
+// call would make calls within the same clock tick return identical strings.
+func GetRandomNumberString(length int) string {
+	key := make([]byte, length)
+	for i := 0; i < length; i++ {
+		key[i] = keyNumbers[rand.Intn(len(keyNumbers))]
+	}
+	return string(key)
+}
+
 func GetTimestamp() int64 {
 	return time.Now().Unix()
 }
diff --git a/controller/relay.go b/controller/relay.go
index 6c6d268e..240042b6 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -1,23 +1,26 @@
 package controller
 
 import (
+	"context"
 	"fmt"
 	"github.com/gin-gonic/gin"
 	"github.com/songquanpeng/one-api/common/config"
 	"github.com/songquanpeng/one-api/common/helper"
 	"github.com/songquanpeng/one-api/common/logger"
+	"github.com/songquanpeng/one-api/middleware"
+	dbmodel "github.com/songquanpeng/one-api/model"
 	"github.com/songquanpeng/one-api/relay/constant"
 	"github.com/songquanpeng/one-api/relay/controller"
 	"github.com/songquanpeng/one-api/relay/model"
 	"github.com/songquanpeng/one-api/relay/util"
 	"net/http"
-	"strconv"
 )
 
 // https://platform.openai.com/docs/api-reference/chat
 
-func Relay(c *gin.Context) {
-	relayMode := constant.Path2RelayMode(c.Request.URL.Path)
+// relay runs the relay helper that matches relayMode and returns that helper's
+// error result; Relay treats a nil result as success.
+func relay(c *gin.Context, relayMode int) *model.ErrorWithStatusCode {
 	var err *model.ErrorWithStatusCode
 	switch relayMode {
 	case constant.RelayModeImagesGenerations:
@@ -31,32 +34,95 @@ func Relay(c *gin.Context) {
 	default:
 		err = controller.RelayTextHelper(c)
 	}
-	if err != nil {
-		requestId := c.GetString(logger.RequestIdKey)
-		retryTimesStr := c.Query("retry")
-		retryTimes, _ := strconv.Atoi(retryTimesStr)
-		if retryTimesStr == "" {
-			retryTimes = config.RetryTimes
-		}
-		if retryTimes > 0 {
-			c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1))
-		} else {
-			if err.StatusCode == http.StatusTooManyRequests {
-				err.Error.Message = "当前分组上游负载已饱和,请稍后再试"
-			}
-			err.Error.Message = helper.MessageWithRequestId(err.Error.Message, requestId)
-			c.JSON(err.StatusCode, gin.H{
-				"error": err.Error,
-			})
-		}
-		channelId := c.GetInt("channel_id")
-		logger.Error(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
-		// https://platform.openai.com/docs/guides/error-codes/api-errors
-		if util.ShouldDisableChannel(&err.Error, err.StatusCode) {
-			channelId := c.GetInt("channel_id")
-			channelName := c.GetString("channel_name")
-			disableChannel(channelId, channelName, err.Message)
-		}
-	}
+	return err
+}
+
+// Relay tries the channel already selected for the request and, when the
+// failure is retryable, retries on other channels of the same group before
+// writing the last error to the client.
+func Relay(c *gin.Context) {
+	ctx := c.Request.Context()
+	relayMode := constant.Path2RelayMode(c.Request.URL.Path)
+	bizErr := relay(c, relayMode)
+	if bizErr == nil {
+		return
+	}
+	channelId := c.GetInt("channel_id")
+	lastFailedChannelId := channelId
+	channelName := c.GetString("channel_name")
+	group := c.GetString("group")
+	originalModel := c.GetString("original_model")
+	go processChannelRelayError(ctx, channelId, channelName, bizErr)
+	requestId := c.GetString(logger.RequestIdKey)
+	retryTimes := config.RetryTimes
+	if !shouldRetry(bizErr.StatusCode) {
+		logger.Errorf(ctx, "relay error happen, but status code is %d, won't retry in this case", bizErr.StatusCode)
+		retryTimes = 0
+	}
+	if _, pinned := c.Get("specific_channel_id"); pinned {
+		// The request was pinned to one channel; do not fall back to another.
+		retryTimes = 0
+	}
+	for i := retryTimes; i > 0; i-- {
+		channel, err := dbmodel.CacheGetRandomSatisfiedChannel(group, originalModel)
+		if err != nil {
+			// %v, not %w: the %w verb is only meaningful to fmt.Errorf.
+			logger.Errorf(ctx, "CacheGetRandomSatisfiedChannel failed: %v", err)
+			break
+		}
+		if channel.Id == lastFailedChannelId {
+			continue
+		}
+		logger.Infof(ctx, "using channel #%d to retry (remain times %d)", channel.Id, i)
+		// NOTE(review): relay presumably reads c.Request.Body; confirm the body
+		// can be read again before retrying here.
+		middleware.SetupContextForSelectedChannel(c, channel, originalModel)
+		bizErr = relay(c, relayMode)
+		if bizErr == nil {
+			return
+		}
+		channelId := c.GetInt("channel_id")
+		lastFailedChannelId = channelId
+		channelName := c.GetString("channel_name")
+		go processChannelRelayError(ctx, channelId, channelName, bizErr)
+	}
+	// Respond from a copy of the error: the goroutines started above still read
+	// bizErr, so rewriting its message in place would race with them.
+	respErr := bizErr.Error
+	if bizErr.StatusCode == http.StatusTooManyRequests {
+		respErr.Message = "当前分组上游负载已饱和,请稍后再试"
+	}
+	respErr.Message = helper.MessageWithRequestId(respErr.Message, requestId)
+	c.JSON(bizErr.StatusCode, gin.H{
+		"error": respErr,
+	})
+}
+
+// shouldRetry reports whether a relay failure with the given HTTP status code
+// should be retried on another channel: 400 and 2xx are not, the rest are.
+func shouldRetry(statusCode int) bool {
+	if statusCode == http.StatusTooManyRequests {
+		return true
+	}
+	if statusCode/100 == 5 {
+		return true
+	}
+	if statusCode == http.StatusBadRequest {
+		return false
+	}
+	if statusCode/100 == 2 {
+		return false
+	}
+	return true
+}
+
+// processChannelRelayError logs a relay failure on the given channel and
+// disables that channel when util.ShouldDisableChannel reports it should be.
+func processChannelRelayError(ctx context.Context, channelId int, channelName string, err *model.ErrorWithStatusCode) {
+	logger.Errorf(ctx, "relay error (channel #%d): %s", channelId, err.Message)
+	// https://platform.openai.com/docs/guides/error-codes/api-errors
+	if util.ShouldDisableChannel(&err.Error, err.StatusCode) {
+		disableChannel(channelId, channelName, err.Message)
+	}
 }
 
diff --git a/middleware/auth.go b/middleware/auth.go
index 42a599d0..9d25f395 100644
--- a/middleware/auth.go
+++ b/middleware/auth.go
@@ -108,7 +108,7 @@ func TokenAuth() func(c *gin.Context) {
 		c.Set("token_name", token.Name)
 		if len(parts) > 1 {
 			if model.IsAdmin(token.UserId) {
-				c.Set("channelId", parts[1])
+				c.Set("specific_channel_id", parts[1])
 			} else {
 				abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
 				return
diff --git a/middleware/distributor.go b/middleware/distributor.go
index 704f6236..aeb2796a 100644
--- a/middleware/distributor.go
+++ b/middleware/distributor.go
@@ -21,8 +21,9 @@ func Distribute() func(c *gin.Context) {
 		userId := c.GetInt("id")
 		userGroup, _ := model.CacheGetUserGroup(userId)
 		c.Set("group", userGroup)
+		var requestModel string
 		var channel *model.Channel
-		channelId, ok := c.Get("channelId")
+		channelId, ok := c.Get("specific_channel_id")
 		if ok {
 			id, err := strconv.Atoi(channelId.(string))
 			if err != nil {
@@ -66,6 +67,7 @@ func Distribute() func(c *gin.Context) {
 					modelRequest.Model = "whisper-1"
 				}
 			}
+			requestModel = modelRequest.Model
 			channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
 			if err != nil {
 				message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
@@ -77,29 +79,35 @@ func Distribute() func(c *gin.Context) {
 				return
 			}
 		}
-		c.Set("channel", channel.Type)
-		c.Set("channel_id", channel.Id)
-		c.Set("channel_name", channel.Name)
-		c.Set("model_mapping", channel.GetModelMapping())
-		c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
-		c.Set("base_url", channel.GetBaseURL())
-		// this is for backward compatibility
-		switch channel.Type {
-		case common.ChannelTypeAzure:
-			c.Set(common.ConfigKeyAPIVersion, channel.Other)
-		case common.ChannelTypeXunfei:
-			c.Set(common.ConfigKeyAPIVersion, channel.Other)
-		case common.ChannelTypeGemini:
-			c.Set(common.ConfigKeyAPIVersion, channel.Other)
-		case common.ChannelTypeAIProxyLibrary:
-			c.Set(common.ConfigKeyLibraryID, channel.Other)
-		case common.ChannelTypeAli:
-			c.Set(common.ConfigKeyPlugin, channel.Other)
-		}
-		cfg, _ := channel.LoadConfig()
-		for k, v := range cfg {
-			c.Set(common.ConfigKeyPrefix+k, v)
-		}
+		SetupContextForSelectedChannel(c, channel, requestModel)
 		c.Next()
 	}
 }
+
+// SetupContextForSelectedChannel stores channel's settings on the gin context
+// and sets the request's Authorization header to the channel key. modelName is
+// saved as "original_model" so that Relay can pick another channel on retry.
+// Keys written for a previously selected channel are not cleared.
+func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) {
+	c.Set("channel", channel.Type)
+	c.Set("channel_id", channel.Id)
+	c.Set("channel_name", channel.Name)
+	c.Set("model_mapping", channel.GetModelMapping())
+	c.Set("original_model", modelName) // for retry
+	c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
+	c.Set("base_url", channel.GetBaseURL())
+	// this is for backward compatibility
+	switch channel.Type {
+	case common.ChannelTypeAzure, common.ChannelTypeXunfei, common.ChannelTypeGemini:
+		c.Set(common.ConfigKeyAPIVersion, channel.Other)
+	case common.ChannelTypeAIProxyLibrary:
+		c.Set(common.ConfigKeyLibraryID, channel.Other)
+	case common.ChannelTypeAli:
+		c.Set(common.ConfigKeyPlugin, channel.Other)
+	}
+	// NOTE(review): the LoadConfig error is ignored; presumably cfg is empty then — confirm.
+	cfg, _ := channel.LoadConfig()
+	for k, v := range cfg {
+		c.Set(common.ConfigKeyPrefix+k, v)
+	}
+}
diff --git a/middleware/request-id.go b/middleware/request-id.go
index 7cb66e93..234a93d8 100644
--- a/middleware/request-id.go
+++ b/middleware/request-id.go
@@ -9,7 +9,7 @@ import (
 
 func RequestId() func(c *gin.Context) {
 	return func(c *gin.Context) {
-		id := helper.GetTimeString() + helper.GetRandomString(8)
+		id := helper.GetTimeString() + helper.GetRandomNumberString(8)
 		c.Set(logger.RequestIdKey, id)
 		ctx := context.WithValue(c.Request.Context(), logger.RequestIdKey, id)
 		c.Request = c.Request.WithContext(ctx)