173 lines
6.3 KiB
Go
173 lines
6.3 KiB
Go
package controller
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"github.com/gin-gonic/gin"
|
|
"math"
|
|
"net/http"
|
|
"one-api/common"
|
|
"one-api/common/logger"
|
|
"one-api/model"
|
|
"one-api/relay/channel/openai"
|
|
"one-api/relay/constant"
|
|
"one-api/relay/util"
|
|
"strings"
|
|
)
|
|
|
|
func RelayTextHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode {
|
|
ctx := c.Request.Context()
|
|
meta := util.GetRelayMeta(c)
|
|
var textRequest openai.GeneralOpenAIRequest
|
|
err := common.UnmarshalBodyReusable(c, &textRequest)
|
|
if err != nil {
|
|
return openai.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
|
|
}
|
|
if relayMode == constant.RelayModeModerations && textRequest.Model == "" {
|
|
textRequest.Model = "text-moderation-latest"
|
|
}
|
|
if relayMode == constant.RelayModeEmbeddings && textRequest.Model == "" {
|
|
textRequest.Model = c.Param("model")
|
|
}
|
|
err = util.ValidateTextRequest(&textRequest, relayMode)
|
|
if err != nil {
|
|
return openai.ErrorWrapper(err, "invalid_text_request", http.StatusBadRequest)
|
|
}
|
|
var isModelMapped bool
|
|
textRequest.Model, isModelMapped = util.GetMappedModelName(textRequest.Model, meta.ModelMapping)
|
|
apiType := constant.ChannelType2APIType(meta.ChannelType)
|
|
fullRequestURL, err := GetRequestURL(c.Request.URL.String(), apiType, relayMode, meta, &textRequest)
|
|
if err != nil {
|
|
logger.Error(ctx, fmt.Sprintf("util.GetRequestURL failed: %s", err.Error()))
|
|
return openai.ErrorWrapper(fmt.Errorf("util.GetRequestURL failed"), "get_request_url_failed", http.StatusInternalServerError)
|
|
}
|
|
var promptTokens int
|
|
var completionTokens int
|
|
switch relayMode {
|
|
case constant.RelayModeChatCompletions:
|
|
promptTokens = openai.CountTokenMessages(textRequest.Messages, textRequest.Model)
|
|
case constant.RelayModeCompletions:
|
|
promptTokens = openai.CountTokenInput(textRequest.Prompt, textRequest.Model)
|
|
case constant.RelayModeModerations:
|
|
promptTokens = openai.CountTokenInput(textRequest.Input, textRequest.Model)
|
|
}
|
|
preConsumedTokens := common.PreConsumedQuota
|
|
if textRequest.MaxTokens != 0 {
|
|
preConsumedTokens = promptTokens + textRequest.MaxTokens
|
|
}
|
|
modelRatio := common.GetModelRatio(textRequest.Model)
|
|
groupRatio := common.GetGroupRatio(meta.Group)
|
|
ratio := modelRatio * groupRatio
|
|
preConsumedQuota := int(float64(preConsumedTokens) * ratio)
|
|
userQuota, err := model.CacheGetUserQuota(meta.UserId)
|
|
if err != nil {
|
|
return openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
|
|
}
|
|
if userQuota-preConsumedQuota < 0 {
|
|
return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
|
|
}
|
|
err = model.CacheDecreaseUserQuota(meta.UserId, preConsumedQuota)
|
|
if err != nil {
|
|
return openai.ErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
|
|
}
|
|
if userQuota > 100*preConsumedQuota {
|
|
// in this case, we do not pre-consume quota
|
|
// because the user has enough quota
|
|
preConsumedQuota = 0
|
|
logger.Info(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", meta.UserId, userQuota))
|
|
}
|
|
if preConsumedQuota > 0 {
|
|
err := model.PreConsumeTokenQuota(meta.TokenId, preConsumedQuota)
|
|
if err != nil {
|
|
return openai.ErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
|
|
}
|
|
}
|
|
requestBody, err := GetRequestBody(c, textRequest, isModelMapped, apiType, relayMode)
|
|
if err != nil {
|
|
return openai.ErrorWrapper(err, "get_request_body_failed", http.StatusInternalServerError)
|
|
}
|
|
var req *http.Request
|
|
var resp *http.Response
|
|
isStream := textRequest.Stream
|
|
|
|
if apiType != constant.APITypeXunfei { // cause xunfei use websocket
|
|
req, err = http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
|
if err != nil {
|
|
return openai.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
|
}
|
|
SetupRequestHeaders(c, req, apiType, meta, isStream)
|
|
resp, err = util.HTTPClient.Do(req)
|
|
if err != nil {
|
|
return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
|
}
|
|
err = req.Body.Close()
|
|
if err != nil {
|
|
return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
|
|
}
|
|
err = c.Request.Body.Close()
|
|
if err != nil {
|
|
return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
|
|
}
|
|
isStream = isStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
util.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)
|
|
return util.RelayErrorHandler(resp)
|
|
}
|
|
}
|
|
|
|
var respErr *openai.ErrorWithStatusCode
|
|
var usage *openai.Usage
|
|
|
|
defer func(ctx context.Context) {
|
|
// Why we use defer here? Because if error happened, we will have to return the pre-consumed quota.
|
|
if respErr != nil {
|
|
logger.Errorf(ctx, "respErr is not nil: %+v", respErr)
|
|
util.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)
|
|
return
|
|
}
|
|
if usage == nil {
|
|
logger.Error(ctx, "usage is nil, which is unexpected")
|
|
return
|
|
}
|
|
|
|
go func() {
|
|
quota := 0
|
|
completionRatio := common.GetCompletionRatio(textRequest.Model)
|
|
promptTokens = usage.PromptTokens
|
|
completionTokens = usage.CompletionTokens
|
|
quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio))
|
|
if ratio != 0 && quota <= 0 {
|
|
quota = 1
|
|
}
|
|
totalTokens := promptTokens + completionTokens
|
|
if totalTokens == 0 {
|
|
// in this case, must be some error happened
|
|
// we cannot just return, because we may have to return the pre-consumed quota
|
|
quota = 0
|
|
}
|
|
quotaDelta := quota - preConsumedQuota
|
|
err := model.PostConsumeTokenQuota(meta.TokenId, quotaDelta)
|
|
if err != nil {
|
|
logger.Error(ctx, "error consuming token remain quota: "+err.Error())
|
|
}
|
|
err = model.CacheUpdateUserQuota(meta.UserId)
|
|
if err != nil {
|
|
logger.Error(ctx, "error update user quota cache: "+err.Error())
|
|
}
|
|
if quota != 0 {
|
|
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
|
model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, promptTokens, completionTokens, textRequest.Model, meta.TokenName, quota, logContent)
|
|
model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota)
|
|
model.UpdateChannelUsedQuota(meta.ChannelId, quota)
|
|
}
|
|
}()
|
|
}(ctx)
|
|
usage, respErr = DoResponse(c, &textRequest, resp, relayMode, apiType, isStream, promptTokens)
|
|
if respErr != nil {
|
|
return respErr
|
|
}
|
|
return nil
|
|
}
|