From b9d3cb0c4520fd42bdae085b98af76d369527cf1 Mon Sep 17 00:00:00 2001
From: JustSong
Date: Sun, 28 Jan 2024 19:13:11 +0800
Subject: [PATCH] refactor: split RelayTextHelper function

---
 controller/relay.go                   |   2 +-
 relay/controller/helper.go            | 146 ++++++++++++++++++++++++
 relay/controller/{util.go => temp.go} |  14 +--
 relay/controller/text.go              | 157 +++++---------------------
 relay/util/relay_meta.go              |   5 +
 5 files changed, 189 insertions(+), 135 deletions(-)
 create mode 100644 relay/controller/helper.go
 rename relay/controller/{util.go => temp.go} (95%)

diff --git a/controller/relay.go b/controller/relay.go
index 46fedc7e..3f5b67fc 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -29,7 +29,7 @@ func Relay(c *gin.Context) {
 	case constant.RelayModeAudioTranscription:
 		err = controller.RelayAudioHelper(c, relayMode)
 	default:
-		err = controller.RelayTextHelper(c, relayMode)
+		err = controller.RelayTextHelper(c)
 	}
 	if err != nil {
 		requestId := c.GetString(logger.RequestIdKey)
diff --git a/relay/controller/helper.go b/relay/controller/helper.go
new file mode 100644
index 00000000..46e57969
--- /dev/null
+++ b/relay/controller/helper.go
@@ -0,0 +1,146 @@
+package controller
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"io"
+	"math"
+	"net/http"
+	"one-api/common"
+	"one-api/common/config"
+	"one-api/common/logger"
+	"one-api/model"
+	"one-api/relay/channel/openai"
+	"one-api/relay/constant"
+	"one-api/relay/util"
+)
+
+func getAndValidateTextRequest(c *gin.Context, relayMode int) (*openai.GeneralOpenAIRequest, error) {
+	textRequest := &openai.GeneralOpenAIRequest{}
+	err := common.UnmarshalBodyReusable(c, textRequest)
+	if err != nil {
+		return nil, err
+	}
+	if relayMode == constant.RelayModeModerations && textRequest.Model == "" {
+		textRequest.Model = "text-moderation-latest"
+	}
+	if relayMode == constant.RelayModeEmbeddings && textRequest.Model == "" {
+		textRequest.Model = c.Param("model")
+	}
+	err = util.ValidateTextRequest(textRequest, relayMode)
+	if err != nil {
+		return nil, err
+	}
+	return textRequest, nil
+}
+
+func getPromptTokens(textRequest *openai.GeneralOpenAIRequest, relayMode int) int {
+	switch relayMode {
+	case constant.RelayModeChatCompletions:
+		return openai.CountTokenMessages(textRequest.Messages, textRequest.Model)
+	case constant.RelayModeCompletions:
+		return openai.CountTokenInput(textRequest.Prompt, textRequest.Model)
+	case constant.RelayModeModerations:
+		return openai.CountTokenInput(textRequest.Input, textRequest.Model)
+	}
+	return 0
+}
+
+func getPreConsumedQuota(textRequest *openai.GeneralOpenAIRequest, promptTokens int, ratio float64) int {
+	preConsumedTokens := config.PreConsumedQuota
+	if textRequest.MaxTokens != 0 {
+		preConsumedTokens = promptTokens + textRequest.MaxTokens
+	}
+	return int(float64(preConsumedTokens) * ratio)
+}
+
+func preConsumeQuota(ctx context.Context, textRequest *openai.GeneralOpenAIRequest, promptTokens int, ratio float64, meta *util.RelayMeta) (int, *openai.ErrorWithStatusCode) {
+	preConsumedQuota := getPreConsumedQuota(textRequest, promptTokens, ratio)
+
+	userQuota, err := model.CacheGetUserQuota(meta.UserId)
+	if err != nil {
+		return preConsumedQuota, openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
+	}
+	if userQuota-preConsumedQuota < 0 {
+		return preConsumedQuota, openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
+	}
+	err = model.CacheDecreaseUserQuota(meta.UserId, preConsumedQuota)
+	if err != nil {
+		return preConsumedQuota, openai.ErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
+	}
+	if userQuota > 100*preConsumedQuota {
+		// in this case, we do not pre-consume quota
+		// because the user has enough quota
+		preConsumedQuota = 0
+		logger.Info(ctx, fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", meta.UserId, userQuota))
+	}
+	if preConsumedQuota > 0 {
+		err := model.PreConsumeTokenQuota(meta.TokenId, preConsumedQuota)
+		if err != nil {
+			return preConsumedQuota, openai.ErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
+		}
+	}
+	return preConsumedQuota, nil
+}
+
+func postConsumeQuota(ctx context.Context, usage *openai.Usage, meta *util.RelayMeta, textRequest *openai.GeneralOpenAIRequest, ratio float64, preConsumedQuota int, modelRatio float64, groupRatio float64) {
+	if usage == nil {
+		logger.Error(ctx, "usage is nil, which is unexpected")
+		return
+	}
+	quota := 0
+	completionRatio := common.GetCompletionRatio(textRequest.Model)
+	promptTokens := usage.PromptTokens
+	completionTokens := usage.CompletionTokens
+	quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio))
+	if ratio != 0 && quota <= 0 {
+		quota = 1
+	}
+	totalTokens := promptTokens + completionTokens
+	if totalTokens == 0 {
+		// in this case, must be some error happened
+		// we cannot just return, because we may have to return the pre-consumed quota
+		quota = 0
+	}
+	quotaDelta := quota - preConsumedQuota
+	err := model.PostConsumeTokenQuota(meta.TokenId, quotaDelta)
+	if err != nil {
+		logger.Error(ctx, "error consuming token remain quota: "+err.Error())
+	}
+	err = model.CacheUpdateUserQuota(meta.UserId)
+	if err != nil {
+		logger.Error(ctx, "error update user quota cache: "+err.Error())
+	}
+	if quota != 0 {
+		logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f", modelRatio, groupRatio, completionRatio)
+		model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, promptTokens, completionTokens, textRequest.Model, meta.TokenName, quota, logContent)
+		model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota)
+		model.UpdateChannelUsedQuota(meta.ChannelId, quota)
+	}
+}
+
+func doRequest(ctx context.Context, c *gin.Context, meta *util.RelayMeta, isStream bool, fullRequestURL string, requestBody io.Reader) (*http.Response, error) {
+	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
+	if err != nil {
+		return nil, err
+	}
+	SetupRequestHeaders(c, req, meta, isStream)
+	resp, err := util.HTTPClient.Do(req)
+	if err != nil {
+		return nil, err
+	}
+	if resp == nil {
+		return nil, errors.New("resp is nil")
+	}
+	err = req.Body.Close()
+	if err != nil {
+		logger.Warnf(ctx, "close req.Body failed: %+v", err)
+	}
+	err = c.Request.Body.Close()
+	if err != nil {
+		logger.Warnf(ctx, "close c.Request.Body failed: %+v", err)
+	}
+	return resp, nil
+}
diff --git a/relay/controller/util.go b/relay/controller/temp.go
similarity index 95%
rename from relay/controller/util.go
rename to relay/controller/temp.go
index 02f1b30f..7161cfb9 100644
--- a/relay/controller/util.go
+++ b/relay/controller/temp.go
@@ -24,9 +24,9 @@ import (
 	"strings"
 )
 
-func GetRequestURL(requestURL string, apiType int, relayMode int, meta *util.RelayMeta, textRequest *openai.GeneralOpenAIRequest) (string, error) {
+func GetRequestURL(requestURL string, meta *util.RelayMeta, textRequest *openai.GeneralOpenAIRequest) (string, error) {
 	fullRequestURL := util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType)
-	switch apiType {
+	switch meta.APIType {
 	case constant.APITypeOpenAI:
 		if meta.ChannelType == common.ChannelTypeAzure {
 			// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
@@ -81,7 +81,7 @@ func GetRequestURL(requestURL string, apiType int, relayMode int, meta *util.Rel
 		fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method)
 	case constant.APITypeAli:
 		fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
-		if relayMode == constant.RelayModeEmbeddings {
+		if meta.Mode == constant.RelayModeEmbeddings {
 			fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding"
 		}
 	case constant.APITypeTencent:
@@ -191,8 +191,8 @@ func GetRequestBody(c *gin.Context, textRequest openai.GeneralOpenAIRequest, isM
 	return requestBody, nil
 }
 
-func SetupRequestHeaders(c *gin.Context, req *http.Request, apiType int, meta *util.RelayMeta, isStream bool) {
-	SetupAuthHeaders(c, req, apiType, meta, isStream)
+func SetupRequestHeaders(c *gin.Context, req *http.Request, meta *util.RelayMeta, isStream bool) {
+	SetupAuthHeaders(c, req, meta, isStream)
 	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
 	req.Header.Set("Accept", c.Request.Header.Get("Accept"))
 	if isStream && c.Request.Header.Get("Accept") == "" {
@@ -200,9 +200,9 @@ func SetupRequestHeaders(c *gin.Context, req *http.Request, apiType int, meta *u
 	}
 }
 
-func SetupAuthHeaders(c *gin.Context, req *http.Request, apiType int, meta *util.RelayMeta, isStream bool) {
+func SetupAuthHeaders(c *gin.Context, req *http.Request, meta *util.RelayMeta, isStream bool) {
 	apiKey := meta.APIKey
-	switch apiType {
+	switch meta.APIType {
 	case constant.APITypeOpenAI:
 		if meta.ChannelType == common.ChannelTypeAzure {
 			req.Header.Set("api-key", apiKey)
diff --git a/relay/controller/text.go b/relay/controller/text.go
index 68354628..335d53be 100644
--- a/relay/controller/text.go
+++ b/relay/controller/text.go
@@ -1,115 +1,61 @@
 package controller
 
 import (
-	"context"
-	"errors"
 	"fmt"
 	"github.com/gin-gonic/gin"
-	"math"
 	"net/http"
 	"one-api/common"
-	"one-api/common/config"
 	"one-api/common/logger"
-	"one-api/model"
 	"one-api/relay/channel/openai"
 	"one-api/relay/constant"
 	"one-api/relay/util"
 	"strings"
 )
 
-func RelayTextHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode {
+func RelayTextHelper(c *gin.Context) *openai.ErrorWithStatusCode {
 	ctx := c.Request.Context()
 	meta := util.GetRelayMeta(c)
-	var textRequest openai.GeneralOpenAIRequest
-	err := common.UnmarshalBodyReusable(c, &textRequest)
-	if err != nil {
-		return openai.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
-	}
-	if relayMode == constant.RelayModeModerations && textRequest.Model == "" {
-		textRequest.Model = "text-moderation-latest"
-	}
-	if relayMode == constant.RelayModeEmbeddings && textRequest.Model == "" {
-		textRequest.Model = c.Param("model")
-	}
-	err = util.ValidateTextRequest(&textRequest, relayMode)
+	// get & validate textRequest
+	textRequest, err := getAndValidateTextRequest(c, meta.Mode)
 	if err != nil {
+		logger.Errorf(ctx, "getAndValidateTextRequest failed: %s", err.Error())
 		return openai.ErrorWrapper(err, "invalid_text_request", http.StatusBadRequest)
 	}
+	// map model name
 	var isModelMapped bool
 	textRequest.Model, isModelMapped = util.GetMappedModelName(textRequest.Model, meta.ModelMapping)
-	apiType := constant.ChannelType2APIType(meta.ChannelType)
-	fullRequestURL, err := GetRequestURL(c.Request.URL.String(), apiType, relayMode, meta, &textRequest)
-	if err != nil {
-		logger.Error(ctx, fmt.Sprintf("util.GetRequestURL failed: %s", err.Error()))
-		return openai.ErrorWrapper(fmt.Errorf("util.GetRequestURL failed"), "get_request_url_failed", http.StatusInternalServerError)
-	}
-	var promptTokens int
-	var completionTokens int
-	switch relayMode {
-	case constant.RelayModeChatCompletions:
-		promptTokens = openai.CountTokenMessages(textRequest.Messages, textRequest.Model)
-	case constant.RelayModeCompletions:
-		promptTokens = openai.CountTokenInput(textRequest.Prompt, textRequest.Model)
-	case constant.RelayModeModerations:
-		promptTokens = openai.CountTokenInput(textRequest.Input, textRequest.Model)
-	}
-	preConsumedTokens := config.PreConsumedQuota
-	if textRequest.MaxTokens != 0 {
-		preConsumedTokens = promptTokens + textRequest.MaxTokens
-	}
+	// get model ratio & group ratio
 	modelRatio := common.GetModelRatio(textRequest.Model)
 	groupRatio := common.GetGroupRatio(meta.Group)
 	ratio := modelRatio * groupRatio
-	preConsumedQuota := int(float64(preConsumedTokens) * ratio)
-	userQuota, err := model.CacheGetUserQuota(meta.UserId)
-	if err != nil {
-		return openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
+	// pre-consume quota
+	promptTokens := getPromptTokens(textRequest, meta.Mode)
+	preConsumedQuota, bizErr := preConsumeQuota(ctx, textRequest, promptTokens, ratio, meta)
+	if bizErr != nil {
+		logger.Warnf(ctx, "preConsumeQuota failed: %+v", *bizErr)
+		return bizErr
 	}
-	if userQuota-preConsumedQuota < 0 {
-		return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
-	}
-	err = model.CacheDecreaseUserQuota(meta.UserId, preConsumedQuota)
-	if err != nil {
-		return openai.ErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
-	}
-	if userQuota > 100*preConsumedQuota {
-		// in this case, we do not pre-consume quota
-		// because the user has enough quota
-		preConsumedQuota = 0
-		logger.Info(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", meta.UserId, userQuota))
-	}
-	if preConsumedQuota > 0 {
-		err := model.PreConsumeTokenQuota(meta.TokenId, preConsumedQuota)
-		if err != nil {
-			return openai.ErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
-		}
-	}
-	requestBody, err := GetRequestBody(c, textRequest, isModelMapped, apiType, relayMode)
+
+	// get request body
+	requestBody, err := GetRequestBody(c, *textRequest, isModelMapped, meta.APIType, meta.Mode)
 	if err != nil {
 		return openai.ErrorWrapper(err, "get_request_body_failed", http.StatusInternalServerError)
 	}
-	var req *http.Request
+	// do request
 	var resp *http.Response
 	isStream := textRequest.Stream
+	if meta.APIType != constant.APITypeXunfei { // cause xunfei use websocket
+		fullRequestURL, err := GetRequestURL(c.Request.URL.String(), meta, textRequest)
+		if err != nil {
+			logger.Error(ctx, fmt.Sprintf("util.GetRequestURL failed: %s", err.Error()))
+			return openai.ErrorWrapper(fmt.Errorf("util.GetRequestURL failed"), "get_request_url_failed", http.StatusInternalServerError)
+		}
 
-	if apiType != constant.APITypeXunfei { // cause xunfei use websocket
-		req, err = http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
-		if err != nil {
-			return openai.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
-		}
-		SetupRequestHeaders(c, req, apiType, meta, isStream)
-		resp, err = util.HTTPClient.Do(req)
+		resp, err = doRequest(ctx, c, meta, isStream, fullRequestURL, requestBody)
 		if err != nil {
+			logger.Errorf(ctx, "doRequest failed: %s", err.Error())
 			return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
 		}
-		err = req.Body.Close()
-		if err != nil {
-			return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
-		}
-		err = c.Request.Body.Close()
-		if err != nil {
-			return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
-		}
 		isStream = isStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
 
 		if resp.StatusCode != http.StatusOK {
@@ -117,57 +63,14 @@ func RelayTextHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode
 			return util.RelayErrorHandler(resp)
 		}
 	}
-
-	var respErr *openai.ErrorWithStatusCode
-	var usage *openai.Usage
-
-	defer func(ctx context.Context) {
-		// Why we use defer here? Because if error happened, we will have to return the pre-consumed quota.
-		if respErr != nil {
-			logger.Errorf(ctx, "respErr is not nil: %+v", respErr)
-			util.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)
-			return
-		}
-		if usage == nil {
-			logger.Error(ctx, "usage is nil, which is unexpected")
-			return
-		}
-
-		go func() {
-			quota := 0
-			completionRatio := common.GetCompletionRatio(textRequest.Model)
-			promptTokens = usage.PromptTokens
-			completionTokens = usage.CompletionTokens
-			quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio))
-			if ratio != 0 && quota <= 0 {
-				quota = 1
-			}
-			totalTokens := promptTokens + completionTokens
-			if totalTokens == 0 {
-				// in this case, must be some error happened
-				// we cannot just return, because we may have to return the pre-consumed quota
-				quota = 0
-			}
-			quotaDelta := quota - preConsumedQuota
-			err := model.PostConsumeTokenQuota(meta.TokenId, quotaDelta)
-			if err != nil {
-				logger.Error(ctx, "error consuming token remain quota: "+err.Error())
-			}
-			err = model.CacheUpdateUserQuota(meta.UserId)
-			if err != nil {
-				logger.Error(ctx, "error update user quota cache: "+err.Error())
-			}
-			if quota != 0 {
-				logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
-				model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, promptTokens, completionTokens, textRequest.Model, meta.TokenName, quota, logContent)
-				model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota)
-				model.UpdateChannelUsedQuota(meta.ChannelId, quota)
-			}
-		}()
-	}(ctx)
-	usage, respErr = DoResponse(c, &textRequest, resp, relayMode, apiType, isStream, promptTokens)
+	// do response
+	usage, respErr := DoResponse(c, textRequest, resp, meta.Mode, meta.APIType, isStream, promptTokens)
 	if respErr != nil {
+		logger.Errorf(ctx, "respErr is not nil: %+v", respErr)
+		util.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)
 		return respErr
 	}
+	// post-consume quota
+	go postConsumeQuota(ctx, usage, meta, textRequest, ratio, preConsumedQuota, modelRatio, groupRatio)
 	return nil
 }
diff --git a/relay/util/relay_meta.go b/relay/util/relay_meta.go
index 19936e49..bea5a406 100644
--- a/relay/util/relay_meta.go
+++ b/relay/util/relay_meta.go
@@ -3,10 +3,12 @@ package util
 import (
 	"github.com/gin-gonic/gin"
 	"one-api/common"
+	"one-api/relay/constant"
 	"strings"
 )
 
 type RelayMeta struct {
+	Mode        int
 	ChannelType int
 	ChannelId   int
 	TokenId     int
@@ -17,11 +19,13 @@ type RelayMeta struct {
 	BaseURL     string
 	APIVersion  string
 	APIKey      string
+	APIType     int
 	Config      map[string]string
 }
 
 func GetRelayMeta(c *gin.Context) *RelayMeta {
 	meta := RelayMeta{
+		Mode:        constant.Path2RelayMode(c.Request.URL.Path),
 		ChannelType: c.GetInt("channel"),
 		ChannelId:   c.GetInt("channel_id"),
 		TokenId:     c.GetInt("token_id"),
@@ -40,5 +44,6 @@ func GetRelayMeta(c *gin.Context) *RelayMeta {
 	if meta.BaseURL == "" {
 		meta.BaseURL = common.ChannelBaseURLs[meta.ChannelType]
 	}
+	meta.APIType = constant.ChannelType2APIType(meta.ChannelType)
 	return &meta
 }