refactor: split RelayTextHelper function

Author: JustSong
Date:   2024-01-28 19:13:11 +08:00
Parent: ea407f0054
Commit: b9d3cb0c45

5 changed files with 189 additions and 135 deletions

@@ -29,7 +29,7 @@ func Relay(c *gin.Context) {
     case constant.RelayModeAudioTranscription:
         err = controller.RelayAudioHelper(c, relayMode)
     default:
-        err = controller.RelayTextHelper(c, relayMode)
+        err = controller.RelayTextHelper(c)
     }
     if err != nil {
         requestId := c.GetString(logger.RequestIdKey)

relay/controller/helper.go (new file)

@@ -0,0 +1,146 @@
package controller

import (
    "context"
    "errors"
    "fmt"
    "github.com/gin-gonic/gin"
    "io"
    "math"
    "net/http"
    "one-api/common"
    "one-api/common/config"
    "one-api/common/logger"
    "one-api/model"
    "one-api/relay/channel/openai"
    "one-api/relay/constant"
    "one-api/relay/util"
)

// getAndValidateTextRequest parses the request body into a GeneralOpenAIRequest,
// fills in mode-specific default model names, and validates the result.
func getAndValidateTextRequest(c *gin.Context, relayMode int) (*openai.GeneralOpenAIRequest, error) {
    textRequest := &openai.GeneralOpenAIRequest{}
    err := common.UnmarshalBodyReusable(c, textRequest)
    if err != nil {
        return nil, err
    }
    if relayMode == constant.RelayModeModerations && textRequest.Model == "" {
        textRequest.Model = "text-moderation-latest"
    }
    if relayMode == constant.RelayModeEmbeddings && textRequest.Model == "" {
        textRequest.Model = c.Param("model")
    }
    err = util.ValidateTextRequest(textRequest, relayMode)
    if err != nil {
        return nil, err
    }
    return textRequest, nil
}

// getPromptTokens counts the prompt-side tokens for the given relay mode.
func getPromptTokens(textRequest *openai.GeneralOpenAIRequest, relayMode int) int {
    switch relayMode {
    case constant.RelayModeChatCompletions:
        return openai.CountTokenMessages(textRequest.Messages, textRequest.Model)
    case constant.RelayModeCompletions:
        return openai.CountTokenInput(textRequest.Prompt, textRequest.Model)
    case constant.RelayModeModerations:
        return openai.CountTokenInput(textRequest.Input, textRequest.Model)
    }
    return 0
}

// getPreConsumedQuota estimates the quota to hold up front: the configured
// default, or promptTokens+MaxTokens when the client set a completion cap.
func getPreConsumedQuota(textRequest *openai.GeneralOpenAIRequest, promptTokens int, ratio float64) int {
    preConsumedTokens := config.PreConsumedQuota
    if textRequest.MaxTokens != 0 {
        preConsumedTokens = promptTokens + textRequest.MaxTokens
    }
    return int(float64(preConsumedTokens) * ratio)
}

// preConsumeQuota places a quota hold before the upstream call; the hold is
// skipped for users whose cached balance far exceeds it.
func preConsumeQuota(ctx context.Context, textRequest *openai.GeneralOpenAIRequest, promptTokens int, ratio float64, meta *util.RelayMeta) (int, *openai.ErrorWithStatusCode) {
    preConsumedQuota := getPreConsumedQuota(textRequest, promptTokens, ratio)
    userQuota, err := model.CacheGetUserQuota(meta.UserId)
    if err != nil {
        return preConsumedQuota, openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
    }
    if userQuota-preConsumedQuota < 0 {
        return preConsumedQuota, openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
    }
    err = model.CacheDecreaseUserQuota(meta.UserId, preConsumedQuota)
    if err != nil {
        return preConsumedQuota, openai.ErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
    }
    if userQuota > 100*preConsumedQuota {
        // in this case, we do not pre-consume quota
        // because the user has enough quota
        preConsumedQuota = 0
        logger.Info(ctx, fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", meta.UserId, userQuota))
    }
    if preConsumedQuota > 0 {
        err := model.PreConsumeTokenQuota(meta.TokenId, preConsumedQuota)
        if err != nil {
            return preConsumedQuota, openai.ErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
        }
    }
    return preConsumedQuota, nil
}

// postConsumeQuota settles the real cost once usage is known and records
// the consumption log.
func postConsumeQuota(ctx context.Context, usage *openai.Usage, meta *util.RelayMeta, textRequest *openai.GeneralOpenAIRequest, ratio float64, preConsumedQuota int, modelRatio float64, groupRatio float64) {
    if usage == nil {
        logger.Error(ctx, "usage is nil, which is unexpected")
        return
    }
    quota := 0
    completionRatio := common.GetCompletionRatio(textRequest.Model)
    promptTokens := usage.PromptTokens
    completionTokens := usage.CompletionTokens
    quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio))
    if ratio != 0 && quota <= 0 {
        quota = 1
    }
    totalTokens := promptTokens + completionTokens
    if totalTokens == 0 {
        // in this case, must be some error happened
        // we cannot just return, because we may have to return the pre-consumed quota
        quota = 0
    }
    quotaDelta := quota - preConsumedQuota
    err := model.PostConsumeTokenQuota(meta.TokenId, quotaDelta)
    if err != nil {
        logger.Error(ctx, "error consuming token remain quota: "+err.Error())
    }
    err = model.CacheUpdateUserQuota(meta.UserId)
    if err != nil {
        logger.Error(ctx, "error update user quota cache: "+err.Error())
    }
    if quota != 0 {
        logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f", modelRatio, groupRatio, completionRatio)
        model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, promptTokens, completionTokens, textRequest.Model, meta.TokenName, quota, logContent)
        model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota)
        model.UpdateChannelUsedQuota(meta.ChannelId, quota)
    }
}

// doRequest builds the upstream HTTP request, sends it, and closes both
// request bodies.
func doRequest(ctx context.Context, c *gin.Context, meta *util.RelayMeta, isStream bool, fullRequestURL string, requestBody io.Reader) (*http.Response, error) {
    req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
    if err != nil {
        return nil, err
    }
    SetupRequestHeaders(c, req, meta, isStream)
    resp, err := util.HTTPClient.Do(req)
    if err != nil {
        return nil, err
    }
    if resp == nil {
        return nil, errors.New("resp is nil")
    }
    err = req.Body.Close()
    if err != nil {
        logger.Warnf(ctx, "close req.Body failed: %+v", err)
    }
    err = c.Request.Body.Close()
    if err != nil {
        logger.Warnf(ctx, "close c.Request.Body failed: %+v", err)
    }
    return resp, nil
}
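
Two accounting details in these helpers are worth spelling out. preConsumeQuota deliberately skips the up-front hold when userQuota > 100*preConsumedQuota, trusting users whose balance dwarfs the reservation. And postConsumeQuota bills completion tokens at a model-specific premium before posting only the delta against the hold. A runnable sketch of that formula with invented numbers (the real completionRatio and ratio come from common.GetCompletionRatio and modelRatio*groupRatio):

package main

import (
    "fmt"
    "math"
)

func main() {
    // Invented example values; in the real code these come from the
    // upstream usage object and the model/group ratio tables.
    promptTokens := 120
    completionTokens := 300
    completionRatio := 1.33 // completions billed at a premium over prompts
    ratio := 7.5            // modelRatio * groupRatio

    quota := int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio))
    fmt.Println(quota) // ceil((120 + 300*1.33) * 7.5) = 3893

    // Only the delta against the up-front hold is posted afterwards,
    // mirroring quotaDelta := quota - preConsumedQuota above.
    preConsumedQuota := 2000
    fmt.Println(quota - preConsumedQuota) // 1893
}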

@@ -24,9 +24,9 @@ import (
     "strings"
 )

-func GetRequestURL(requestURL string, apiType int, relayMode int, meta *util.RelayMeta, textRequest *openai.GeneralOpenAIRequest) (string, error) {
+func GetRequestURL(requestURL string, meta *util.RelayMeta, textRequest *openai.GeneralOpenAIRequest) (string, error) {
     fullRequestURL := util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType)
-    switch apiType {
+    switch meta.APIType {
     case constant.APITypeOpenAI:
         if meta.ChannelType == common.ChannelTypeAzure {
             // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
@@ -81,7 +81,7 @@ func GetRequestURL(requestURL string, apiType int, relayMode int, meta *util.RelayMeta, textRequest *openai.GeneralOpenAIRequest) (string, error) {
         fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method)
     case constant.APITypeAli:
         fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
-        if relayMode == constant.RelayModeEmbeddings {
+        if meta.Mode == constant.RelayModeEmbeddings {
             fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding"
         }
     case constant.APITypeTencent:
@@ -191,8 +191,8 @@ func GetRequestBody(c *gin.Context, textRequest openai.GeneralOpenAIRequest, isM
     return requestBody, nil
 }

-func SetupRequestHeaders(c *gin.Context, req *http.Request, apiType int, meta *util.RelayMeta, isStream bool) {
-    SetupAuthHeaders(c, req, apiType, meta, isStream)
+func SetupRequestHeaders(c *gin.Context, req *http.Request, meta *util.RelayMeta, isStream bool) {
+    SetupAuthHeaders(c, req, meta, isStream)
     req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
     req.Header.Set("Accept", c.Request.Header.Get("Accept"))
     if isStream && c.Request.Header.Get("Accept") == "" {
@@ -200,9 +200,9 @@ func SetupRequestHeaders(c *gin.Context, req *http.Request, apiType int, meta *util.RelayMeta, isStream bool) {
     }
 }

-func SetupAuthHeaders(c *gin.Context, req *http.Request, apiType int, meta *util.RelayMeta, isStream bool) {
+func SetupAuthHeaders(c *gin.Context, req *http.Request, meta *util.RelayMeta, isStream bool) {
     apiKey := meta.APIKey
-    switch apiType {
+    switch meta.APIType {
     case constant.APITypeOpenAI:
         if meta.ChannelType == common.ChannelTypeAzure {
             req.Header.Set("api-key", apiKey)
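
These hunks are all the same mechanical change: apiType and relayMode stop being positional parameters and are read from *util.RelayMeta, which now carries them (see the relay_meta.go hunks at the end of this commit). A self-contained toy version of that parameter-object move, with illustrative names that are not from the codebase:

package main

import "fmt"

// Toy illustration only; relayMeta and buildURL are stand-ins, not
// real one-api types.
type relayMeta struct {
    APIType int
    Mode    int
}

// Before: buildURL(apiType int, mode int, base string) — every caller had
// to thread both ints through. After: they ride on the meta struct, so
// signatures stay stable as fields are added.
func buildURL(meta *relayMeta, base string) string {
    return fmt.Sprintf("%s (apiType=%d, mode=%d)", base, meta.APIType, meta.Mode)
}

func main() {
    meta := &relayMeta{APIType: 1, Mode: 2}
    fmt.Println(buildURL(meta, "https://upstream.example"))
}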

@@ -1,115 +1,61 @@
 package controller

 import (
-    "context"
-    "errors"
     "fmt"
     "github.com/gin-gonic/gin"
-    "math"
     "net/http"
     "one-api/common"
-    "one-api/common/config"
     "one-api/common/logger"
-    "one-api/model"
     "one-api/relay/channel/openai"
     "one-api/relay/constant"
     "one-api/relay/util"
     "strings"
 )

-func RelayTextHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode {
+func RelayTextHelper(c *gin.Context) *openai.ErrorWithStatusCode {
     ctx := c.Request.Context()
     meta := util.GetRelayMeta(c)
-    var textRequest openai.GeneralOpenAIRequest
-    err := common.UnmarshalBodyReusable(c, &textRequest)
-    if err != nil {
-        return openai.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
-    }
-    if relayMode == constant.RelayModeModerations && textRequest.Model == "" {
-        textRequest.Model = "text-moderation-latest"
-    }
-    if relayMode == constant.RelayModeEmbeddings && textRequest.Model == "" {
-        textRequest.Model = c.Param("model")
-    }
-    err = util.ValidateTextRequest(&textRequest, relayMode)
+    // get & validate textRequest
+    textRequest, err := getAndValidateTextRequest(c, meta.Mode)
     if err != nil {
+        logger.Errorf(ctx, "getAndValidateTextRequest failed: %s", err.Error())
         return openai.ErrorWrapper(err, "invalid_text_request", http.StatusBadRequest)
     }
+    // map model name
     var isModelMapped bool
     textRequest.Model, isModelMapped = util.GetMappedModelName(textRequest.Model, meta.ModelMapping)
-    apiType := constant.ChannelType2APIType(meta.ChannelType)
-    fullRequestURL, err := GetRequestURL(c.Request.URL.String(), apiType, relayMode, meta, &textRequest)
-    if err != nil {
-        logger.Error(ctx, fmt.Sprintf("util.GetRequestURL failed: %s", err.Error()))
-        return openai.ErrorWrapper(fmt.Errorf("util.GetRequestURL failed"), "get_request_url_failed", http.StatusInternalServerError)
-    }
-    var promptTokens int
-    var completionTokens int
-    switch relayMode {
-    case constant.RelayModeChatCompletions:
-        promptTokens = openai.CountTokenMessages(textRequest.Messages, textRequest.Model)
-    case constant.RelayModeCompletions:
-        promptTokens = openai.CountTokenInput(textRequest.Prompt, textRequest.Model)
-    case constant.RelayModeModerations:
-        promptTokens = openai.CountTokenInput(textRequest.Input, textRequest.Model)
-    }
-    preConsumedTokens := config.PreConsumedQuota
-    if textRequest.MaxTokens != 0 {
-        preConsumedTokens = promptTokens + textRequest.MaxTokens
-    }
+    // get model ratio & group ratio
     modelRatio := common.GetModelRatio(textRequest.Model)
     groupRatio := common.GetGroupRatio(meta.Group)
     ratio := modelRatio * groupRatio
-    preConsumedQuota := int(float64(preConsumedTokens) * ratio)
-    userQuota, err := model.CacheGetUserQuota(meta.UserId)
-    if err != nil {
-        return openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
-    }
+    // pre-consume quota
+    promptTokens := getPromptTokens(textRequest, meta.Mode)
+    preConsumedQuota, bizErr := preConsumeQuota(ctx, textRequest, promptTokens, ratio, meta)
+    if bizErr != nil {
+        logger.Warnf(ctx, "preConsumeQuota failed: %+v", *bizErr)
+        return bizErr
     }
-    if userQuota-preConsumedQuota < 0 {
-        return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
-    }
-    err = model.CacheDecreaseUserQuota(meta.UserId, preConsumedQuota)
-    if err != nil {
-        return openai.ErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
-    }
-    if userQuota > 100*preConsumedQuota {
-        // in this case, we do not pre-consume quota
-        // because the user has enough quota
-        preConsumedQuota = 0
-        logger.Info(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", meta.UserId, userQuota))
-    }
-    if preConsumedQuota > 0 {
-        err := model.PreConsumeTokenQuota(meta.TokenId, preConsumedQuota)
-        if err != nil {
-            return openai.ErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
-        }
-    }
-    requestBody, err := GetRequestBody(c, textRequest, isModelMapped, apiType, relayMode)
+    // get request body
+    requestBody, err := GetRequestBody(c, *textRequest, isModelMapped, meta.APIType, meta.Mode)
     if err != nil {
         return openai.ErrorWrapper(err, "get_request_body_failed", http.StatusInternalServerError)
     }
-    var req *http.Request
+    // do request
     var resp *http.Response
     isStream := textRequest.Stream
-    if apiType != constant.APITypeXunfei { // cause xunfei use websocket
-        req, err = http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
+    if meta.APIType != constant.APITypeXunfei { // cause xunfei use websocket
+        fullRequestURL, err := GetRequestURL(c.Request.URL.String(), meta, textRequest)
         if err != nil {
-            return openai.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
+            logger.Error(ctx, fmt.Sprintf("util.GetRequestURL failed: %s", err.Error()))
+            return openai.ErrorWrapper(fmt.Errorf("util.GetRequestURL failed"), "get_request_url_failed", http.StatusInternalServerError)
         }
-        SetupRequestHeaders(c, req, apiType, meta, isStream)
-        resp, err = util.HTTPClient.Do(req)
+        resp, err = doRequest(ctx, c, meta, isStream, fullRequestURL, requestBody)
         if err != nil {
+            logger.Errorf(ctx, "doRequest failed: %s", err.Error())
             return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
         }
-        err = req.Body.Close()
-        if err != nil {
-            return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
-        }
-        err = c.Request.Body.Close()
-        if err != nil {
-            return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
-        }
         isStream = isStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
         if resp.StatusCode != http.StatusOK {
@@ -117,57 +63,14 @@ func RelayTextHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode
             return util.RelayErrorHandler(resp)
         }
     }
-    var respErr *openai.ErrorWithStatusCode
-    var usage *openai.Usage
-    defer func(ctx context.Context) {
-        // Why we use defer here? Because if error happened, we will have to return the pre-consumed quota.
-        if respErr != nil {
-            logger.Errorf(ctx, "respErr is not nil: %+v", respErr)
-            util.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)
-            return
-        }
-        if usage == nil {
-            logger.Error(ctx, "usage is nil, which is unexpected")
-            return
-        }
-        go func() {
-            quota := 0
-            completionRatio := common.GetCompletionRatio(textRequest.Model)
-            promptTokens = usage.PromptTokens
-            completionTokens = usage.CompletionTokens
-            quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio))
-            if ratio != 0 && quota <= 0 {
-                quota = 1
-            }
-            totalTokens := promptTokens + completionTokens
-            if totalTokens == 0 {
-                // in this case, must be some error happened
-                // we cannot just return, because we may have to return the pre-consumed quota
-                quota = 0
-            }
-            quotaDelta := quota - preConsumedQuota
-            err := model.PostConsumeTokenQuota(meta.TokenId, quotaDelta)
-            if err != nil {
-                logger.Error(ctx, "error consuming token remain quota: "+err.Error())
-            }
-            err = model.CacheUpdateUserQuota(meta.UserId)
-            if err != nil {
-                logger.Error(ctx, "error update user quota cache: "+err.Error())
-            }
-            if quota != 0 {
-                logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
-                model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, promptTokens, completionTokens, textRequest.Model, meta.TokenName, quota, logContent)
-                model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota)
-                model.UpdateChannelUsedQuota(meta.ChannelId, quota)
-            }
-        }()
-    }(ctx)
-    usage, respErr = DoResponse(c, &textRequest, resp, relayMode, apiType, isStream, promptTokens)
+    // do response
+    usage, respErr := DoResponse(c, textRequest, resp, meta.Mode, meta.APIType, isStream, promptTokens)
     if respErr != nil {
+        logger.Errorf(ctx, "respErr is not nil: %+v", respErr)
+        util.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)
         return respErr
     }
+    // post-consume quota
+    go postConsumeQuota(ctx, usage, meta, textRequest, ratio, preConsumedQuota, modelRatio, groupRatio)
     return nil
 }
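
With the defer-closure removed, the quota lifecycle in RelayTextHelper reads linearly: preConsumeQuota places a hold, util.ReturnPreConsumedQuota refunds it on the failure paths, and go postConsumeQuota settles the true cost off the request path once usage is known. A compressed, runnable sketch of that reserve/refund/settle shape, with hypothetical stand-in functions:

package main

import "fmt"

// Hypothetical stand-ins for preConsumeQuota, ReturnPreConsumedQuota
// and postConsumeQuota; the numbers are invented.
func reserve() int          { return 2000 }
func refund(held int)       { fmt.Println("refunded hold:", held) }
func settle(used, held int) { fmt.Println("posting delta:", used-held) }

func relay(upstreamFails bool) error {
    held := reserve()
    if upstreamFails {
        refund(held) // mirrors util.ReturnPreConsumedQuota on error paths
        return fmt.Errorf("upstream error")
    }
    // The real code runs this in a goroutine: go postConsumeQuota(...)
    settle(3893, held)
    return nil
}

func main() {
    fmt.Println(relay(true))
    fmt.Println(relay(false))
}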

@@ -3,10 +3,12 @@ package util
 import (
     "github.com/gin-gonic/gin"
     "one-api/common"
+    "one-api/relay/constant"
     "strings"
 )

 type RelayMeta struct {
+    Mode        int
     ChannelType int
     ChannelId   int
     TokenId     int
@@ -17,11 +19,13 @@ type RelayMeta struct {
     BaseURL    string
     APIVersion string
     APIKey     string
+    APIType    int
     Config     map[string]string
 }

 func GetRelayMeta(c *gin.Context) *RelayMeta {
     meta := RelayMeta{
+        Mode:        constant.Path2RelayMode(c.Request.URL.Path),
         ChannelType: c.GetInt("channel"),
         ChannelId:   c.GetInt("channel_id"),
         TokenId:     c.GetInt("token_id"),
@@ -40,5 +44,6 @@ func GetRelayMeta(c *gin.Context) *RelayMeta {
     if meta.BaseURL == "" {
         meta.BaseURL = common.ChannelBaseURLs[meta.ChannelType]
     }
+    meta.APIType = constant.ChannelType2APIType(meta.ChannelType)
     return &meta
 }
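
GetRelayMeta now resolves Mode and APIType once per request, from the URL path and the channel type respectively, so every downstream helper agrees on them. The body of constant.Path2RelayMode is not part of this diff; a minimal sketch of what such a path-to-mode mapping typically looks like (the constants and path prefixes here are assumptions, not the real implementation):

package constant

import "strings"

// Hypothetical sketch only; the real Path2RelayMode and its mode
// constants are not shown in this commit.
const (
    RelayModeUnknown = iota
    RelayModeChatCompletions
    RelayModeCompletions
    RelayModeEmbeddings
    RelayModeModerations
)

func Path2RelayMode(path string) int {
    switch {
    case strings.HasPrefix(path, "/v1/chat/completions"):
        return RelayModeChatCompletions
    case strings.HasPrefix(path, "/v1/completions"):
        return RelayModeCompletions
    case strings.HasPrefix(path, "/v1/embeddings"):
        return RelayModeEmbeddings
    case strings.HasPrefix(path, "/v1/moderations"):
        return RelayModeModerations
    default:
        return RelayModeUnknown
    }
}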