parent
f141a37a9e
commit
565ea58e68
@ -137,6 +137,7 @@ func GetUUID() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||||
|
const keyNumbers = "0123456789"
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
rand.Seed(time.Now().UnixNano())
|
rand.Seed(time.Now().UnixNano())
|
||||||
@ -168,6 +169,15 @@ func GetRandomString(length int) string {
|
|||||||
return string(key)
|
return string(key)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetRandomNumberString(length int) string {
|
||||||
|
rand.Seed(time.Now().UnixNano())
|
||||||
|
key := make([]byte, length)
|
||||||
|
for i := 0; i < length; i++ {
|
||||||
|
key[i] = keyNumbers[rand.Intn(len(keyNumbers))]
|
||||||
|
}
|
||||||
|
return string(key)
|
||||||
|
}
|
||||||
|
|
||||||
func GetTimestamp() int64 {
|
func GetTimestamp() int64 {
|
||||||
return time.Now().Unix()
|
return time.Now().Unix()
|
||||||
}
|
}
|
||||||
|
@ -1,23 +1,24 @@
|
|||||||
package controller
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/songquanpeng/one-api/common/config"
|
"github.com/songquanpeng/one-api/common/config"
|
||||||
"github.com/songquanpeng/one-api/common/helper"
|
"github.com/songquanpeng/one-api/common/helper"
|
||||||
"github.com/songquanpeng/one-api/common/logger"
|
"github.com/songquanpeng/one-api/common/logger"
|
||||||
|
"github.com/songquanpeng/one-api/middleware"
|
||||||
|
dbmodel "github.com/songquanpeng/one-api/model"
|
||||||
"github.com/songquanpeng/one-api/relay/constant"
|
"github.com/songquanpeng/one-api/relay/constant"
|
||||||
"github.com/songquanpeng/one-api/relay/controller"
|
"github.com/songquanpeng/one-api/relay/controller"
|
||||||
"github.com/songquanpeng/one-api/relay/model"
|
"github.com/songquanpeng/one-api/relay/model"
|
||||||
"github.com/songquanpeng/one-api/relay/util"
|
"github.com/songquanpeng/one-api/relay/util"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// https://platform.openai.com/docs/api-reference/chat
|
// https://platform.openai.com/docs/api-reference/chat
|
||||||
|
|
||||||
func Relay(c *gin.Context) {
|
func relay(c *gin.Context, relayMode int) *model.ErrorWithStatusCode {
|
||||||
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
|
|
||||||
var err *model.ErrorWithStatusCode
|
var err *model.ErrorWithStatusCode
|
||||||
switch relayMode {
|
switch relayMode {
|
||||||
case constant.RelayModeImagesGenerations:
|
case constant.RelayModeImagesGenerations:
|
||||||
@ -31,32 +32,80 @@ func Relay(c *gin.Context) {
|
|||||||
default:
|
default:
|
||||||
err = controller.RelayTextHelper(c)
|
err = controller.RelayTextHelper(c)
|
||||||
}
|
}
|
||||||
if err != nil {
|
return err
|
||||||
requestId := c.GetString(logger.RequestIdKey)
|
}
|
||||||
retryTimesStr := c.Query("retry")
|
|
||||||
retryTimes, _ := strconv.Atoi(retryTimesStr)
|
func Relay(c *gin.Context) {
|
||||||
if retryTimesStr == "" {
|
ctx := c.Request.Context()
|
||||||
retryTimes = config.RetryTimes
|
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
|
||||||
|
bizErr := relay(c, relayMode)
|
||||||
|
if bizErr == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
channelId := c.GetInt("channel_id")
|
||||||
|
lastFailedChannelId := channelId
|
||||||
|
channelName := c.GetString("channel_name")
|
||||||
|
group := c.GetString("group")
|
||||||
|
originalModel := c.GetString("original_model")
|
||||||
|
go processChannelRelayError(ctx, channelId, channelName, bizErr)
|
||||||
|
requestId := c.GetString(logger.RequestIdKey)
|
||||||
|
retryTimes := config.RetryTimes
|
||||||
|
if !shouldRetry(bizErr.StatusCode) {
|
||||||
|
logger.Errorf(ctx, "relay error happen, but status code is %d, won't retry in this case", bizErr.StatusCode)
|
||||||
|
retryTimes = 0
|
||||||
|
}
|
||||||
|
for i := retryTimes; i > 0; i-- {
|
||||||
|
channel, err := dbmodel.CacheGetRandomSatisfiedChannel(group, originalModel)
|
||||||
|
if err != nil {
|
||||||
|
logger.Errorf(ctx, "CacheGetRandomSatisfiedChannel failed: %w", err)
|
||||||
|
break
|
||||||
}
|
}
|
||||||
if retryTimes > 0 {
|
logger.Infof(ctx, "using channel #%d to retry (remain times %d)", channel.Id, i)
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1))
|
if channel.Id == lastFailedChannelId {
|
||||||
} else {
|
continue
|
||||||
if err.StatusCode == http.StatusTooManyRequests {
|
}
|
||||||
err.Error.Message = "当前分组上游负载已饱和,请稍后再试"
|
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
||||||
}
|
bizErr = relay(c, relayMode)
|
||||||
err.Error.Message = helper.MessageWithRequestId(err.Error.Message, requestId)
|
if bizErr == nil {
|
||||||
c.JSON(err.StatusCode, gin.H{
|
return
|
||||||
"error": err.Error,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
channelId := c.GetInt("channel_id")
|
channelId := c.GetInt("channel_id")
|
||||||
logger.Error(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
|
lastFailedChannelId = channelId
|
||||||
// https://platform.openai.com/docs/guides/error-codes/api-errors
|
channelName := c.GetString("channel_name")
|
||||||
if util.ShouldDisableChannel(&err.Error, err.StatusCode) {
|
go processChannelRelayError(ctx, channelId, channelName, bizErr)
|
||||||
channelId := c.GetInt("channel_id")
|
}
|
||||||
channelName := c.GetString("channel_name")
|
if bizErr != nil {
|
||||||
disableChannel(channelId, channelName, err.Message)
|
if bizErr.StatusCode == http.StatusTooManyRequests {
|
||||||
|
bizErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
|
||||||
}
|
}
|
||||||
|
bizErr.Error.Message = helper.MessageWithRequestId(bizErr.Error.Message, requestId)
|
||||||
|
c.JSON(bizErr.StatusCode, gin.H{
|
||||||
|
"error": bizErr.Error,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func shouldRetry(statusCode int) bool {
|
||||||
|
if statusCode == http.StatusTooManyRequests {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if statusCode/100 == 5 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if statusCode == http.StatusBadRequest {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if statusCode/100 == 2 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func processChannelRelayError(ctx context.Context, channelId int, channelName string, err *model.ErrorWithStatusCode) {
|
||||||
|
logger.Errorf(ctx, "relay error (channel #%d): %s", channelId, err.Message)
|
||||||
|
// https://platform.openai.com/docs/guides/error-codes/api-errors
|
||||||
|
if util.ShouldDisableChannel(&err.Error, err.StatusCode) {
|
||||||
|
disableChannel(channelId, channelName, err.Message)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -108,7 +108,7 @@ func TokenAuth() func(c *gin.Context) {
|
|||||||
c.Set("token_name", token.Name)
|
c.Set("token_name", token.Name)
|
||||||
if len(parts) > 1 {
|
if len(parts) > 1 {
|
||||||
if model.IsAdmin(token.UserId) {
|
if model.IsAdmin(token.UserId) {
|
||||||
c.Set("channelId", parts[1])
|
c.Set("specific_channel_id", parts[1])
|
||||||
} else {
|
} else {
|
||||||
abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
|
abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
|
||||||
return
|
return
|
||||||
|
@ -21,8 +21,9 @@ func Distribute() func(c *gin.Context) {
|
|||||||
userId := c.GetInt("id")
|
userId := c.GetInt("id")
|
||||||
userGroup, _ := model.CacheGetUserGroup(userId)
|
userGroup, _ := model.CacheGetUserGroup(userId)
|
||||||
c.Set("group", userGroup)
|
c.Set("group", userGroup)
|
||||||
|
var requestModel string
|
||||||
var channel *model.Channel
|
var channel *model.Channel
|
||||||
channelId, ok := c.Get("channelId")
|
channelId, ok := c.Get("specific_channel_id")
|
||||||
if ok {
|
if ok {
|
||||||
id, err := strconv.Atoi(channelId.(string))
|
id, err := strconv.Atoi(channelId.(string))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -66,6 +67,7 @@ func Distribute() func(c *gin.Context) {
|
|||||||
modelRequest.Model = "whisper-1"
|
modelRequest.Model = "whisper-1"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
requestModel = modelRequest.Model
|
||||||
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
|
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
|
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
|
||||||
@ -77,29 +79,34 @@ func Distribute() func(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
c.Set("channel", channel.Type)
|
SetupContextForSelectedChannel(c, channel, requestModel)
|
||||||
c.Set("channel_id", channel.Id)
|
|
||||||
c.Set("channel_name", channel.Name)
|
|
||||||
c.Set("model_mapping", channel.GetModelMapping())
|
|
||||||
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
|
||||||
c.Set("base_url", channel.GetBaseURL())
|
|
||||||
// this is for backward compatibility
|
|
||||||
switch channel.Type {
|
|
||||||
case common.ChannelTypeAzure:
|
|
||||||
c.Set(common.ConfigKeyAPIVersion, channel.Other)
|
|
||||||
case common.ChannelTypeXunfei:
|
|
||||||
c.Set(common.ConfigKeyAPIVersion, channel.Other)
|
|
||||||
case common.ChannelTypeGemini:
|
|
||||||
c.Set(common.ConfigKeyAPIVersion, channel.Other)
|
|
||||||
case common.ChannelTypeAIProxyLibrary:
|
|
||||||
c.Set(common.ConfigKeyLibraryID, channel.Other)
|
|
||||||
case common.ChannelTypeAli:
|
|
||||||
c.Set(common.ConfigKeyPlugin, channel.Other)
|
|
||||||
}
|
|
||||||
cfg, _ := channel.LoadConfig()
|
|
||||||
for k, v := range cfg {
|
|
||||||
c.Set(common.ConfigKeyPrefix+k, v)
|
|
||||||
}
|
|
||||||
c.Next()
|
c.Next()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) {
|
||||||
|
c.Set("channel", channel.Type)
|
||||||
|
c.Set("channel_id", channel.Id)
|
||||||
|
c.Set("channel_name", channel.Name)
|
||||||
|
c.Set("model_mapping", channel.GetModelMapping())
|
||||||
|
c.Set("original_model", modelName) // for retry
|
||||||
|
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
||||||
|
c.Set("base_url", channel.GetBaseURL())
|
||||||
|
// this is for backward compatibility
|
||||||
|
switch channel.Type {
|
||||||
|
case common.ChannelTypeAzure:
|
||||||
|
c.Set(common.ConfigKeyAPIVersion, channel.Other)
|
||||||
|
case common.ChannelTypeXunfei:
|
||||||
|
c.Set(common.ConfigKeyAPIVersion, channel.Other)
|
||||||
|
case common.ChannelTypeGemini:
|
||||||
|
c.Set(common.ConfigKeyAPIVersion, channel.Other)
|
||||||
|
case common.ChannelTypeAIProxyLibrary:
|
||||||
|
c.Set(common.ConfigKeyLibraryID, channel.Other)
|
||||||
|
case common.ChannelTypeAli:
|
||||||
|
c.Set(common.ConfigKeyPlugin, channel.Other)
|
||||||
|
}
|
||||||
|
cfg, _ := channel.LoadConfig()
|
||||||
|
for k, v := range cfg {
|
||||||
|
c.Set(common.ConfigKeyPrefix+k, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -9,7 +9,7 @@ import (
|
|||||||
|
|
||||||
func RequestId() func(c *gin.Context) {
|
func RequestId() func(c *gin.Context) {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
id := helper.GetTimeString() + helper.GetRandomString(8)
|
id := helper.GetTimeString() + helper.GetRandomNumberString(8)
|
||||||
c.Set(logger.RequestIdKey, id)
|
c.Set(logger.RequestIdKey, id)
|
||||||
ctx := context.WithValue(c.Request.Context(), logger.RequestIdKey, id)
|
ctx := context.WithValue(c.Request.Context(), logger.RequestIdKey, id)
|
||||||
c.Request = c.Request.WithContext(ctx)
|
c.Request = c.Request.WithContext(ctx)
|
||||||
|
Loading…
Reference in New Issue
Block a user