ai-gateway/controller/relay.go

147 lines
4.2 KiB
Go
Raw Normal View History

2023-04-23 10:24:11 +00:00
package controller
import (
"bytes"
"context"
2023-04-23 10:24:11 +00:00
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
2024-01-28 11:38:58 +00:00
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/middleware"
dbmodel "github.com/songquanpeng/one-api/model"
2024-01-28 11:38:58 +00:00
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/controller"
"github.com/songquanpeng/one-api/relay/model"
2024-01-28 11:38:58 +00:00
"github.com/songquanpeng/one-api/relay/util"
"io"
2023-04-23 10:24:11 +00:00
"net/http"
)
2023-05-21 06:26:59 +00:00
// https://platform.openai.com/docs/api-reference/chat
func relay(c *gin.Context, relayMode int) *model.ErrorWithStatusCode {
var err *model.ErrorWithStatusCode
2023-06-19 02:28:55 +00:00
switch relayMode {
2024-01-14 11:21:03 +00:00
case constant.RelayModeImagesGenerations:
err = controller.RelayImageHelper(c, relayMode)
case constant.RelayModeAudioSpeech:
fallthrough
2024-01-14 11:21:03 +00:00
case constant.RelayModeAudioTranslation:
fallthrough
2024-01-14 11:21:03 +00:00
case constant.RelayModeAudioTranscription:
err = controller.RelayAudioHelper(c, relayMode)
2023-06-19 02:28:55 +00:00
default:
err = controller.RelayTextHelper(c)
}
return err
}
func Relay(c *gin.Context) {
ctx := c.Request.Context()
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
if config.DebugEnabled {
requestBody, _ := common.GetRequestBody(c)
logger.Debugf(ctx, "request body: %s", string(requestBody))
}
bizErr := relay(c, relayMode)
if bizErr == nil {
return
}
channelId := c.GetInt("channel_id")
lastFailedChannelId := channelId
channelName := c.GetString("channel_name")
group := c.GetString("group")
originalModel := c.GetString("original_model")
go processChannelRelayError(ctx, channelId, channelName, bizErr)
requestId := c.GetString(logger.RequestIdKey)
retryTimes := config.RetryTimes
if !shouldRetry(c, bizErr.StatusCode) {
logger.Errorf(ctx, "relay error happen, status code is %d, won't retry in this case", bizErr.StatusCode)
retryTimes = 0
}
for i := retryTimes; i > 0; i-- {
channel, err := dbmodel.CacheGetRandomSatisfiedChannel(group, originalModel, i != retryTimes)
if err != nil {
logger.Errorf(ctx, "CacheGetRandomSatisfiedChannel failed: %w", err)
break
2023-07-15 11:06:51 +00:00
}
logger.Infof(ctx, "using channel #%d to retry (remain times %d)", channel.Id, i)
if channel.Id == lastFailedChannelId {
continue
}
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
requestBody, err := common.GetRequestBody(c)
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
bizErr = relay(c, relayMode)
if bizErr == nil {
return
}
channelId := c.GetInt("channel_id")
lastFailedChannelId = channelId
channelName := c.GetString("channel_name")
go processChannelRelayError(ctx, channelId, channelName, bizErr)
}
if bizErr != nil {
if bizErr.StatusCode == http.StatusTooManyRequests {
bizErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
}
bizErr.Error.Message = helper.MessageWithRequestId(bizErr.Error.Message, requestId)
c.JSON(bizErr.StatusCode, gin.H{
"error": bizErr.Error,
})
}
}
// shouldRetry reports whether a relay attempt that ended with statusCode may
// be retried.
//
// A request whose context carries "specific_channel_id" is never retried
// (presumably because the caller pinned a channel; confirm against the
// middleware that sets the key). Otherwise 429 and 5xx are retried, 400 and
// 2xx are not, and every remaining status code is retried.
func shouldRetry(c *gin.Context, statusCode int) bool {
	if _, pinned := c.Get("specific_channel_id"); pinned {
		return false
	}
	switch {
	case statusCode == http.StatusTooManyRequests, statusCode/100 == 5:
		return true
	case statusCode == http.StatusBadRequest, statusCode/100 == 2:
		return false
	default:
		return true
	}
}
// processChannelRelayError logs a relay failure for the given channel and
// disables that channel when util.ShouldDisableChannel judges the error and
// its status code to warrant it.
func processChannelRelayError(ctx context.Context, channelId int, channelName string, err *model.ErrorWithStatusCode) {
	logger.Errorf(ctx, "relay error (channel #%d): %s", channelId, err.Message)
	// https://platform.openai.com/docs/guides/error-codes/api-errors
	if !util.ShouldDisableChannel(&err.Error, err.StatusCode) {
		return
	}
	disableChannel(channelId, channelName, err.Message)
}
func RelayNotImplemented(c *gin.Context) {
err := model.Error{
Message: "API not implemented",
Type: "one_api_error",
Param: "",
Code: "api_not_implemented",
}
2023-06-23 14:59:44 +00:00
c.JSON(http.StatusNotImplemented, gin.H{
"error": err,
})
}
func RelayNotFound(c *gin.Context) {
err := model.Error{
2023-08-11 11:53:01 +00:00
Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path),
Type: "invalid_request_error",
Param: "",
2023-08-11 11:53:01 +00:00
Code: "",
}
2023-06-23 14:59:44 +00:00
c.JSON(http.StatusNotFound, gin.H{
"error": err,
})
}