feat: built in retry supported (close #1036, close #770)

This commit is contained in:
JustSong 2024-02-25 19:01:49 +08:00
parent f141a37a9e
commit 565ea58e68
5 changed files with 117 additions and 51 deletions

View File

@ -137,6 +137,7 @@ func GetUUID() string {
} }
const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
const keyNumbers = "0123456789"
func init() { func init() {
rand.Seed(time.Now().UnixNano()) rand.Seed(time.Now().UnixNano())
@ -168,6 +169,15 @@ func GetRandomString(length int) string {
return string(key) return string(key)
} }
func GetRandomNumberString(length int) string {
rand.Seed(time.Now().UnixNano())
key := make([]byte, length)
for i := 0; i < length; i++ {
key[i] = keyNumbers[rand.Intn(len(keyNumbers))]
}
return string(key)
}
// GetTimestamp reports the current wall-clock time as Unix seconds.
func GetTimestamp() int64 {
	now := time.Now()
	return now.Unix()
}

View File

@ -1,23 +1,24 @@
package controller package controller
import ( import (
"context"
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/middleware"
dbmodel "github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/controller" "github.com/songquanpeng/one-api/relay/controller"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util" "github.com/songquanpeng/one-api/relay/util"
"net/http" "net/http"
"strconv"
) )
// https://platform.openai.com/docs/api-reference/chat // https://platform.openai.com/docs/api-reference/chat
func Relay(c *gin.Context) { func relay(c *gin.Context, relayMode int) *model.ErrorWithStatusCode {
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
var err *model.ErrorWithStatusCode var err *model.ErrorWithStatusCode
switch relayMode { switch relayMode {
case constant.RelayModeImagesGenerations: case constant.RelayModeImagesGenerations:
@ -31,34 +32,82 @@ func Relay(c *gin.Context) {
default: default:
err = controller.RelayTextHelper(c) err = controller.RelayTextHelper(c)
} }
if err != nil { return err
}
func Relay(c *gin.Context) {
ctx := c.Request.Context()
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
bizErr := relay(c, relayMode)
if bizErr == nil {
return
}
channelId := c.GetInt("channel_id")
lastFailedChannelId := channelId
channelName := c.GetString("channel_name")
group := c.GetString("group")
originalModel := c.GetString("original_model")
go processChannelRelayError(ctx, channelId, channelName, bizErr)
requestId := c.GetString(logger.RequestIdKey) requestId := c.GetString(logger.RequestIdKey)
retryTimesStr := c.Query("retry") retryTimes := config.RetryTimes
retryTimes, _ := strconv.Atoi(retryTimesStr) if !shouldRetry(bizErr.StatusCode) {
if retryTimesStr == "" { logger.Errorf(ctx, "relay error happen, but status code is %d, won't retry in this case", bizErr.StatusCode)
retryTimes = config.RetryTimes retryTimes = 0
} }
if retryTimes > 0 { for i := retryTimes; i > 0; i-- {
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1)) channel, err := dbmodel.CacheGetRandomSatisfiedChannel(group, originalModel)
} else { if err != nil {
if err.StatusCode == http.StatusTooManyRequests { logger.Errorf(ctx, "CacheGetRandomSatisfiedChannel failed: %w", err)
err.Error.Message = "当前分组上游负载已饱和,请稍后再试" break
} }
err.Error.Message = helper.MessageWithRequestId(err.Error.Message, requestId) logger.Infof(ctx, "using channel #%d to retry (remain times %d)", channel.Id, i)
c.JSON(err.StatusCode, gin.H{ if channel.Id == lastFailedChannelId {
"error": err.Error, continue
}
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
bizErr = relay(c, relayMode)
if bizErr == nil {
return
}
channelId := c.GetInt("channel_id")
lastFailedChannelId = channelId
channelName := c.GetString("channel_name")
go processChannelRelayError(ctx, channelId, channelName, bizErr)
}
if bizErr != nil {
if bizErr.StatusCode == http.StatusTooManyRequests {
bizErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
}
bizErr.Error.Message = helper.MessageWithRequestId(bizErr.Error.Message, requestId)
c.JSON(bizErr.StatusCode, gin.H{
"error": bizErr.Error,
}) })
} }
channelId := c.GetInt("channel_id") }
logger.Error(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
// shouldRetry reports whether a relay failure with the given HTTP status code
// is worth retrying on another channel. 429 and every 5xx code are retried;
// 400 and every 2xx code are not; any other code is retried.
func shouldRetry(statusCode int) bool {
	switch {
	case statusCode == http.StatusTooManyRequests, statusCode/100 == 5:
		return true
	case statusCode == http.StatusBadRequest, statusCode/100 == 2:
		return false
	default:
		return true
	}
}
func processChannelRelayError(ctx context.Context, channelId int, channelName string, err *model.ErrorWithStatusCode) {
logger.Errorf(ctx, "relay error (channel #%d): %s", channelId, err.Message)
// https://platform.openai.com/docs/guides/error-codes/api-errors // https://platform.openai.com/docs/guides/error-codes/api-errors
if util.ShouldDisableChannel(&err.Error, err.StatusCode) { if util.ShouldDisableChannel(&err.Error, err.StatusCode) {
channelId := c.GetInt("channel_id")
channelName := c.GetString("channel_name")
disableChannel(channelId, channelName, err.Message) disableChannel(channelId, channelName, err.Message)
} }
} }
}
func RelayNotImplemented(c *gin.Context) { func RelayNotImplemented(c *gin.Context) {
err := model.Error{ err := model.Error{

View File

@ -108,7 +108,7 @@ func TokenAuth() func(c *gin.Context) {
c.Set("token_name", token.Name) c.Set("token_name", token.Name)
if len(parts) > 1 { if len(parts) > 1 {
if model.IsAdmin(token.UserId) { if model.IsAdmin(token.UserId) {
c.Set("channelId", parts[1]) c.Set("specific_channel_id", parts[1])
} else { } else {
abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道") abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
return return

View File

@ -21,8 +21,9 @@ func Distribute() func(c *gin.Context) {
userId := c.GetInt("id") userId := c.GetInt("id")
userGroup, _ := model.CacheGetUserGroup(userId) userGroup, _ := model.CacheGetUserGroup(userId)
c.Set("group", userGroup) c.Set("group", userGroup)
var requestModel string
var channel *model.Channel var channel *model.Channel
channelId, ok := c.Get("channelId") channelId, ok := c.Get("specific_channel_id")
if ok { if ok {
id, err := strconv.Atoi(channelId.(string)) id, err := strconv.Atoi(channelId.(string))
if err != nil { if err != nil {
@ -66,6 +67,7 @@ func Distribute() func(c *gin.Context) {
modelRequest.Model = "whisper-1" modelRequest.Model = "whisper-1"
} }
} }
requestModel = modelRequest.Model
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model) channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
if err != nil { if err != nil {
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model) message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
@ -77,10 +79,17 @@ func Distribute() func(c *gin.Context) {
return return
} }
} }
SetupContextForSelectedChannel(c, channel, requestModel)
c.Next()
}
}
func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) {
c.Set("channel", channel.Type) c.Set("channel", channel.Type)
c.Set("channel_id", channel.Id) c.Set("channel_id", channel.Id)
c.Set("channel_name", channel.Name) c.Set("channel_name", channel.Name)
c.Set("model_mapping", channel.GetModelMapping()) c.Set("model_mapping", channel.GetModelMapping())
c.Set("original_model", modelName) // for retry
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
c.Set("base_url", channel.GetBaseURL()) c.Set("base_url", channel.GetBaseURL())
// this is for backward compatibility // this is for backward compatibility
@ -100,6 +109,4 @@ func Distribute() func(c *gin.Context) {
for k, v := range cfg { for k, v := range cfg {
c.Set(common.ConfigKeyPrefix+k, v) c.Set(common.ConfigKeyPrefix+k, v)
} }
c.Next()
}
} }

View File

@ -9,7 +9,7 @@ import (
func RequestId() func(c *gin.Context) { func RequestId() func(c *gin.Context) {
return func(c *gin.Context) { return func(c *gin.Context) {
id := helper.GetTimeString() + helper.GetRandomString(8) id := helper.GetTimeString() + helper.GetRandomNumberString(8)
c.Set(logger.RequestIdKey, id) c.Set(logger.RequestIdKey, id)
ctx := context.WithValue(c.Request.Context(), logger.RequestIdKey, id) ctx := context.WithValue(c.Request.Context(), logger.RequestIdKey, id)
c.Request = c.Request.WithContext(ctx) c.Request = c.Request.WithContext(ctx)