ai-gateway/relay/controller/text.go
JustSong 2d760d4a01
refactor: refactor relay part (#957)
* refactor: refactor relay part

* refactor: refactor config part
2024-01-21 23:21:42 +08:00

174 lines
6.4 KiB
Go

package controller
import (
"context"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"math"
"net/http"
"one-api/common"
"one-api/common/config"
"one-api/common/logger"
"one-api/model"
"one-api/relay/channel/openai"
"one-api/relay/constant"
"one-api/relay/util"
"strings"
)
// RelayTextHelper relays a text-style request (chat completions, completions,
// moderations or embeddings, selected by relayMode) to the upstream channel
// described by the relay meta attached to c, and settles the quota for it.
//
// The flow is:
//  1. Parse and validate the request body, filling in a default model for
//     moderations and embeddings when the client sent none.
//  2. Apply the channel's model mapping and build the upstream URL.
//  3. Estimate the cost and pre-consume token quota for it, unless the user's
//     quota is more than 100 times the estimate.
//  4. Send the request upstream (skipped for Xunfei, which uses a websocket)
//     and hand the response to DoResponse.
//  5. Settle the difference between the pre-consumed and the actual quota in
//     a background goroutine.
//
// It returns nil on success, or an *openai.ErrorWithStatusCode describing the
// failure. Every error return that happens after the token quota has been
// pre-consumed passes that quota to util.ReturnPreConsumedQuota.
func RelayTextHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode {
	ctx := c.Request.Context()
	meta := util.GetRelayMeta(c)

	var textRequest openai.GeneralOpenAIRequest
	if err := common.UnmarshalBodyReusable(c, &textRequest); err != nil {
		return openai.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
	}
	// Fill in a model when the client omitted it for modes that allow that.
	if relayMode == constant.RelayModeModerations && textRequest.Model == "" {
		textRequest.Model = "text-moderation-latest"
	}
	if relayMode == constant.RelayModeEmbeddings && textRequest.Model == "" {
		textRequest.Model = c.Param("model")
	}
	if err := util.ValidateTextRequest(&textRequest, relayMode); err != nil {
		return openai.ErrorWrapper(err, "invalid_text_request", http.StatusBadRequest)
	}

	var isModelMapped bool
	textRequest.Model, isModelMapped = util.GetMappedModelName(textRequest.Model, meta.ModelMapping)
	apiType := constant.ChannelType2APIType(meta.ChannelType)
	fullRequestURL, err := GetRequestURL(c.Request.URL.String(), apiType, relayMode, meta, &textRequest)
	if err != nil {
		// The detailed error is only logged; the client gets a generic message.
		logger.Error(ctx, fmt.Sprintf("util.GetRequestURL failed: %s", err.Error()))
		return openai.ErrorWrapper(fmt.Errorf("util.GetRequestURL failed"), "get_request_url_failed", http.StatusInternalServerError)
	}

	// Count prompt tokens for the modes that have a case here; any other mode
	// (e.g. embeddings) leaves promptTokens at zero.
	var promptTokens int
	switch relayMode {
	case constant.RelayModeChatCompletions:
		promptTokens = openai.CountTokenMessages(textRequest.Messages, textRequest.Model)
	case constant.RelayModeCompletions:
		promptTokens = openai.CountTokenInput(textRequest.Prompt, textRequest.Model)
	case constant.RelayModeModerations:
		promptTokens = openai.CountTokenInput(textRequest.Input, textRequest.Model)
	}

	// Without an explicit max_tokens the configured default is pre-consumed;
	// otherwise the estimate is prompt tokens plus the requested maximum.
	preConsumedTokens := config.PreConsumedQuota
	if textRequest.MaxTokens != 0 {
		preConsumedTokens = promptTokens + textRequest.MaxTokens
	}
	modelRatio := common.GetModelRatio(textRequest.Model)
	groupRatio := common.GetGroupRatio(meta.Group)
	ratio := modelRatio * groupRatio
	preConsumedQuota := int(float64(preConsumedTokens) * ratio)

	userQuota, err := model.CacheGetUserQuota(meta.UserId)
	if err != nil {
		return openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
	}
	if userQuota-preConsumedQuota < 0 {
		return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
	}
	err = model.CacheDecreaseUserQuota(meta.UserId, preConsumedQuota)
	if err != nil {
		return openai.ErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
	}
	if userQuota > 100*preConsumedQuota {
		// The user has far more quota than this request can cost, so the
		// token-level pre-consumption is skipped entirely.
		preConsumedQuota = 0
		logger.Info(ctx, fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", meta.UserId, userQuota))
	}
	if preConsumedQuota > 0 {
		if err := model.PreConsumeTokenQuota(meta.TokenId, preConsumedQuota); err != nil {
			return openai.ErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
		}
	}

	// From here on the token quota may already be reduced, so every error
	// return has to hand it back, exactly as the non-200 upstream path does.
	refundPreConsumed := func() {
		util.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)
	}

	requestBody, err := GetRequestBody(c, textRequest, isModelMapped, apiType, relayMode)
	if err != nil {
		refundPreConsumed()
		return openai.ErrorWrapper(err, "get_request_body_failed", http.StatusInternalServerError)
	}

	var req *http.Request
	var resp *http.Response
	isStream := textRequest.Stream
	if apiType != constant.APITypeXunfei { // Xunfei uses a websocket, so no HTTP request is built here.
		req, err = http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
		if err != nil {
			refundPreConsumed()
			return openai.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
		}
		SetupRequestHeaders(c, req, apiType, meta, isStream)
		resp, err = util.HTTPClient.Do(req)
		if err != nil {
			refundPreConsumed()
			return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
		}
		if err = req.Body.Close(); err != nil {
			// The upstream response is abandoned on this path, so release it;
			// its close error is secondary to the one being reported.
			_ = resp.Body.Close()
			refundPreConsumed()
			return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
		}
		if err = c.Request.Body.Close(); err != nil {
			_ = resp.Body.Close()
			refundPreConsumed()
			return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
		}
		// Treat the response as a stream if the upstream says so, even when
		// the client did not ask for one.
		isStream = isStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
		if resp.StatusCode != http.StatusOK {
			refundPreConsumed()
			return util.RelayErrorHandler(resp)
		}
	}

	var respErr *openai.ErrorWithStatusCode
	var usage *openai.Usage
	defer func(ctx context.Context) {
		// Deferred so that it sees the outcome of DoResponse: on error the
		// pre-consumed quota is returned, otherwise the real usage is settled.
		if respErr != nil {
			logger.Errorf(ctx, "respErr is not nil: %+v", respErr)
			util.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)
			return
		}
		if usage == nil {
			logger.Error(ctx, "usage is nil, which is unexpected")
			return
		}
		// Settlement runs in the background so the handler can return first.
		go func() {
			completionRatio := common.GetCompletionRatio(textRequest.Model)
			usedPromptTokens := usage.PromptTokens
			usedCompletionTokens := usage.CompletionTokens
			quota := int(math.Ceil((float64(usedPromptTokens) + float64(usedCompletionTokens)*completionRatio) * ratio))
			// A billable request costs at least one unit of quota.
			if ratio != 0 && quota <= 0 {
				quota = 1
			}
			if usedPromptTokens+usedCompletionTokens == 0 {
				// No tokens at all means something went wrong upstream. Do not
				// bail out: the delta below still has to give back whatever
				// was pre-consumed.
				quota = 0
			}
			quotaDelta := quota - preConsumedQuota
			if err := model.PostConsumeTokenQuota(meta.TokenId, quotaDelta); err != nil {
				logger.Error(ctx, "error consuming token remain quota: "+err.Error())
			}
			if err := model.CacheUpdateUserQuota(meta.UserId); err != nil {
				logger.Error(ctx, "error update user quota cache: "+err.Error())
			}
			if quota != 0 {
				logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
				model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, usedPromptTokens, usedCompletionTokens, textRequest.Model, meta.TokenName, quota, logContent)
				model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota)
				model.UpdateChannelUsedQuota(meta.ChannelId, quota)
			}
		}()
	}(ctx)

	usage, respErr = DoResponse(c, &textRequest, resp, relayMode, apiType, isStream, promptTokens)
	if respErr != nil {
		return respErr
	}
	return nil
}