refactor: split RelayTextHelper function
parent ea407f0054
commit b9d3cb0c45
@@ -29,7 +29,7 @@ func Relay(c *gin.Context) {
     case constant.RelayModeAudioTranscription:
         err = controller.RelayAudioHelper(c, relayMode)
     default:
-        err = controller.RelayTextHelper(c, relayMode)
+        err = controller.RelayTextHelper(c)
     }
     if err != nil {
         requestId := c.GetString(logger.RequestIdKey)
relay/controller/helper.go (new file, 146 lines)
@@ -0,0 +1,146 @@
+package controller
+
+import (
+    "context"
+    "errors"
+    "fmt"
+    "github.com/gin-gonic/gin"
+    "io"
+    "math"
+    "net/http"
+    "one-api/common"
+    "one-api/common/config"
+    "one-api/common/logger"
+    "one-api/model"
+    "one-api/relay/channel/openai"
+    "one-api/relay/constant"
+    "one-api/relay/util"
+)
+
+func getAndValidateTextRequest(c *gin.Context, relayMode int) (*openai.GeneralOpenAIRequest, error) {
+    textRequest := &openai.GeneralOpenAIRequest{}
+    err := common.UnmarshalBodyReusable(c, textRequest)
+    if err != nil {
+        return nil, err
+    }
+    if relayMode == constant.RelayModeModerations && textRequest.Model == "" {
+        textRequest.Model = "text-moderation-latest"
+    }
+    if relayMode == constant.RelayModeEmbeddings && textRequest.Model == "" {
+        textRequest.Model = c.Param("model")
+    }
+    err = util.ValidateTextRequest(textRequest, relayMode)
+    if err != nil {
+        return nil, err
+    }
+    return textRequest, nil
+}
+
+func getPromptTokens(textRequest *openai.GeneralOpenAIRequest, relayMode int) int {
+    switch relayMode {
+    case constant.RelayModeChatCompletions:
+        return openai.CountTokenMessages(textRequest.Messages, textRequest.Model)
+    case constant.RelayModeCompletions:
+        return openai.CountTokenInput(textRequest.Prompt, textRequest.Model)
+    case constant.RelayModeModerations:
+        return openai.CountTokenInput(textRequest.Input, textRequest.Model)
+    }
+    return 0
+}
+
+func getPreConsumedQuota(textRequest *openai.GeneralOpenAIRequest, promptTokens int, ratio float64) int {
+    preConsumedTokens := config.PreConsumedQuota
+    if textRequest.MaxTokens != 0 {
+        preConsumedTokens = promptTokens + textRequest.MaxTokens
+    }
+    return int(float64(preConsumedTokens) * ratio)
+}
+
+func preConsumeQuota(ctx context.Context, textRequest *openai.GeneralOpenAIRequest, promptTokens int, ratio float64, meta *util.RelayMeta) (int, *openai.ErrorWithStatusCode) {
+    preConsumedQuota := getPreConsumedQuota(textRequest, promptTokens, ratio)
+
+    userQuota, err := model.CacheGetUserQuota(meta.UserId)
+    if err != nil {
+        return preConsumedQuota, openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
+    }
+    if userQuota-preConsumedQuota < 0 {
+        return preConsumedQuota, openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
+    }
+    err = model.CacheDecreaseUserQuota(meta.UserId, preConsumedQuota)
+    if err != nil {
+        return preConsumedQuota, openai.ErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
+    }
+    if userQuota > 100*preConsumedQuota {
+        // in this case, we do not pre-consume quota
+        // because the user has enough quota
+        preConsumedQuota = 0
+        logger.Info(ctx, fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", meta.UserId, userQuota))
+    }
+    if preConsumedQuota > 0 {
+        err := model.PreConsumeTokenQuota(meta.TokenId, preConsumedQuota)
+        if err != nil {
+            return preConsumedQuota, openai.ErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
+        }
+    }
+    return preConsumedQuota, nil
+}
+
+func postConsumeQuota(ctx context.Context, usage *openai.Usage, meta *util.RelayMeta, textRequest *openai.GeneralOpenAIRequest, ratio float64, preConsumedQuota int, modelRatio float64, groupRatio float64) {
+    if usage == nil {
+        logger.Error(ctx, "usage is nil, which is unexpected")
+        return
+    }
+    quota := 0
+    completionRatio := common.GetCompletionRatio(textRequest.Model)
+    promptTokens := usage.PromptTokens
+    completionTokens := usage.CompletionTokens
+    quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio))
+    if ratio != 0 && quota <= 0 {
+        quota = 1
+    }
+    totalTokens := promptTokens + completionTokens
+    if totalTokens == 0 {
+        // in this case, must be some error happened
+        // we cannot just return, because we may have to return the pre-consumed quota
+        quota = 0
+    }
+    quotaDelta := quota - preConsumedQuota
+    err := model.PostConsumeTokenQuota(meta.TokenId, quotaDelta)
+    if err != nil {
+        logger.Error(ctx, "error consuming token remain quota: "+err.Error())
+    }
+    err = model.CacheUpdateUserQuota(meta.UserId)
+    if err != nil {
+        logger.Error(ctx, "error update user quota cache: "+err.Error())
+    }
+    if quota != 0 {
+        logContent := fmt.Sprintf("model ratio %.2f, group ratio %.2f, completion ratio %.2f", modelRatio, groupRatio, completionRatio)
+        model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, promptTokens, completionTokens, textRequest.Model, meta.TokenName, quota, logContent)
+        model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota)
+        model.UpdateChannelUsedQuota(meta.ChannelId, quota)
+    }
+}
+
+func doRequest(ctx context.Context, c *gin.Context, meta *util.RelayMeta, isStream bool, fullRequestURL string, requestBody io.Reader) (*http.Response, error) {
+    req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
+    if err != nil {
+        return nil, err
+    }
+    SetupRequestHeaders(c, req, meta, isStream)
+    resp, err := util.HTTPClient.Do(req)
+    if err != nil {
+        return nil, err
+    }
+    if resp == nil {
+        return nil, errors.New("resp is nil")
+    }
+    err = req.Body.Close()
+    if err != nil {
+        logger.Warnf(ctx, "close req.Body failed: %+v", err)
+    }
+    err = c.Request.Body.Close()
+    if err != nil {
+        logger.Warnf(ctx, "close c.Request.Body failed: %+v", err)
+    }
+    return resp, nil
+}
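To make the reservation arithmetic above concrete, here is a standalone sketch (not project code: defaultPreConsumedQuota stands in for config.PreConsumedQuota, and all numbers are made up) that exercises getPreConsumedQuota and the trust rule from preConsumeQuota:

package main

import "fmt"

// defaultPreConsumedQuota stands in for config.PreConsumedQuota (hypothetical value).
const defaultPreConsumedQuota = 500

// getPreConsumedQuota mirrors the helper above: if the request caps max_tokens,
// reserve promptTokens+MaxTokens, otherwise fall back to the configured default,
// then scale by the combined ratio.
func getPreConsumedQuota(promptTokens, maxTokens int, ratio float64) int {
    preConsumedTokens := defaultPreConsumedQuota
    if maxTokens != 0 {
        preConsumedTokens = promptTokens + maxTokens
    }
    return int(float64(preConsumedTokens) * ratio)
}

func main() {
    // 120 prompt tokens, max_tokens=400, combined ratio 1.5 -> reserve 780.
    fmt.Println(getPreConsumedQuota(120, 400, 1.5))
    // No max_tokens set -> the default 500 scaled by the ratio -> 750.
    fmt.Println(getPreConsumedQuota(120, 0, 1.5))
    // Trust rule from preConsumeQuota: the reservation is waived when
    // userQuota > 100*preConsumedQuota. With a 780 reservation the threshold
    // is 78000: a user holding 50000 still pre-pays, one holding 100000 does not.
    fmt.Println(50000 > 100*780, 100000 > 100*780) // false true
}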
@@ -24,9 +24,9 @@ import (
     "strings"
 )
 
-func GetRequestURL(requestURL string, apiType int, relayMode int, meta *util.RelayMeta, textRequest *openai.GeneralOpenAIRequest) (string, error) {
+func GetRequestURL(requestURL string, meta *util.RelayMeta, textRequest *openai.GeneralOpenAIRequest) (string, error) {
     fullRequestURL := util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType)
-    switch apiType {
+    switch meta.APIType {
     case constant.APITypeOpenAI:
         if meta.ChannelType == common.ChannelTypeAzure {
             // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
@@ -81,7 +81,7 @@ func GetRequestURL(requestURL string, apiType int, relayMode int, meta *util.Rel
         fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method)
     case constant.APITypeAli:
         fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
-        if relayMode == constant.RelayModeEmbeddings {
+        if meta.Mode == constant.RelayModeEmbeddings {
             fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding"
         }
     case constant.APITypeTencent:
@@ -191,8 +191,8 @@ func GetRequestBody(c *gin.Context, textRequest openai.GeneralOpenAIRequest, isM
     return requestBody, nil
 }
 
-func SetupRequestHeaders(c *gin.Context, req *http.Request, apiType int, meta *util.RelayMeta, isStream bool) {
-    SetupAuthHeaders(c, req, apiType, meta, isStream)
+func SetupRequestHeaders(c *gin.Context, req *http.Request, meta *util.RelayMeta, isStream bool) {
+    SetupAuthHeaders(c, req, meta, isStream)
     req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
     req.Header.Set("Accept", c.Request.Header.Get("Accept"))
     if isStream && c.Request.Header.Get("Accept") == "" {
@@ -200,9 +200,9 @@ func SetupRequestHeaders(c *gin.Context, req *http.Request, apiType int, meta *u
     }
 }
 
-func SetupAuthHeaders(c *gin.Context, req *http.Request, apiType int, meta *util.RelayMeta, isStream bool) {
+func SetupAuthHeaders(c *gin.Context, req *http.Request, meta *util.RelayMeta, isStream bool) {
     apiKey := meta.APIKey
-    switch apiType {
+    switch meta.APIType {
     case constant.APITypeOpenAI:
         if meta.ChannelType == common.ChannelTypeAzure {
             req.Header.Set("api-key", apiKey)
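The signature changes above all follow one pattern: apiType and relayMode stop travelling as loose arguments and are read from RelayMeta instead. A minimal standalone sketch of that parameter-object idea, with hypothetical names:

package main

import "fmt"

// Meta plays the role of util.RelayMeta: derived values are computed once
// and carried together, so helpers need only one argument.
type Meta struct {
    APIType int
    Mode    int
}

// Before: every helper needed the loose values threaded through.
func buildURLBefore(apiType, mode int) string {
    return fmt.Sprintf("/v1?api=%d&mode=%d", apiType, mode)
}

// After: helpers read what they need from the shared meta.
func buildURLAfter(meta *Meta) string {
    return fmt.Sprintf("/v1?api=%d&mode=%d", meta.APIType, meta.Mode)
}

func main() {
    meta := &Meta{APIType: 1, Mode: 2}
    fmt.Println(buildURLBefore(1, 2) == buildURLAfter(meta)) // true
}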
@@ -1,115 +1,61 @@
 package controller
 
 import (
-    "context"
-    "errors"
     "fmt"
     "github.com/gin-gonic/gin"
-    "math"
     "net/http"
     "one-api/common"
-    "one-api/common/config"
     "one-api/common/logger"
-    "one-api/model"
     "one-api/relay/channel/openai"
     "one-api/relay/constant"
     "one-api/relay/util"
     "strings"
 )
 
-func RelayTextHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode {
+func RelayTextHelper(c *gin.Context) *openai.ErrorWithStatusCode {
     ctx := c.Request.Context()
     meta := util.GetRelayMeta(c)
-    var textRequest openai.GeneralOpenAIRequest
-    err := common.UnmarshalBodyReusable(c, &textRequest)
-    if err != nil {
-        return openai.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
-    }
-    if relayMode == constant.RelayModeModerations && textRequest.Model == "" {
-        textRequest.Model = "text-moderation-latest"
-    }
-    if relayMode == constant.RelayModeEmbeddings && textRequest.Model == "" {
-        textRequest.Model = c.Param("model")
-    }
-    err = util.ValidateTextRequest(&textRequest, relayMode)
+    // get & validate textRequest
+    textRequest, err := getAndValidateTextRequest(c, meta.Mode)
     if err != nil {
+        logger.Errorf(ctx, "getAndValidateTextRequest failed: %s", err.Error())
         return openai.ErrorWrapper(err, "invalid_text_request", http.StatusBadRequest)
     }
+    // map model name
     var isModelMapped bool
     textRequest.Model, isModelMapped = util.GetMappedModelName(textRequest.Model, meta.ModelMapping)
-    apiType := constant.ChannelType2APIType(meta.ChannelType)
-    fullRequestURL, err := GetRequestURL(c.Request.URL.String(), apiType, relayMode, meta, &textRequest)
-    if err != nil {
-        logger.Error(ctx, fmt.Sprintf("util.GetRequestURL failed: %s", err.Error()))
-        return openai.ErrorWrapper(fmt.Errorf("util.GetRequestURL failed"), "get_request_url_failed", http.StatusInternalServerError)
-    }
-    var promptTokens int
-    var completionTokens int
-    switch relayMode {
-    case constant.RelayModeChatCompletions:
-        promptTokens = openai.CountTokenMessages(textRequest.Messages, textRequest.Model)
-    case constant.RelayModeCompletions:
-        promptTokens = openai.CountTokenInput(textRequest.Prompt, textRequest.Model)
-    case constant.RelayModeModerations:
-        promptTokens = openai.CountTokenInput(textRequest.Input, textRequest.Model)
-    }
-    preConsumedTokens := config.PreConsumedQuota
-    if textRequest.MaxTokens != 0 {
-        preConsumedTokens = promptTokens + textRequest.MaxTokens
-    }
+    // get model ratio & group ratio
     modelRatio := common.GetModelRatio(textRequest.Model)
     groupRatio := common.GetGroupRatio(meta.Group)
     ratio := modelRatio * groupRatio
-    preConsumedQuota := int(float64(preConsumedTokens) * ratio)
-    userQuota, err := model.CacheGetUserQuota(meta.UserId)
-    if err != nil {
-        return openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
-    }
-    if userQuota-preConsumedQuota < 0 {
-        return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
-    }
-    err = model.CacheDecreaseUserQuota(meta.UserId, preConsumedQuota)
-    if err != nil {
-        return openai.ErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
-    }
-    if userQuota > 100*preConsumedQuota {
-        // in this case, we do not pre-consume quota
-        // because the user has enough quota
-        preConsumedQuota = 0
-        logger.Info(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", meta.UserId, userQuota))
-    }
-    if preConsumedQuota > 0 {
-        err := model.PreConsumeTokenQuota(meta.TokenId, preConsumedQuota)
-        if err != nil {
-            return openai.ErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
-        }
-    }
-    requestBody, err := GetRequestBody(c, textRequest, isModelMapped, apiType, relayMode)
+    // pre-consume quota
+    promptTokens := getPromptTokens(textRequest, meta.Mode)
+    preConsumedQuota, bizErr := preConsumeQuota(ctx, textRequest, promptTokens, ratio, meta)
+    if bizErr != nil {
+        logger.Warnf(ctx, "preConsumeQuota failed: %+v", *bizErr)
+        return bizErr
+    }
+    // get request body
+    requestBody, err := GetRequestBody(c, *textRequest, isModelMapped, meta.APIType, meta.Mode)
     if err != nil {
         return openai.ErrorWrapper(err, "get_request_body_failed", http.StatusInternalServerError)
     }
-    var req *http.Request
+    // do request
     var resp *http.Response
     isStream := textRequest.Stream
-    if apiType != constant.APITypeXunfei { // cause xunfei use websocket
-        req, err = http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
-        if err != nil {
-            return openai.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
-        }
-        SetupRequestHeaders(c, req, apiType, meta, isStream)
-        resp, err = util.HTTPClient.Do(req)
+    if meta.APIType != constant.APITypeXunfei { // cause xunfei use websocket
+        fullRequestURL, err := GetRequestURL(c.Request.URL.String(), meta, textRequest)
+        if err != nil {
+            logger.Error(ctx, fmt.Sprintf("util.GetRequestURL failed: %s", err.Error()))
+            return openai.ErrorWrapper(fmt.Errorf("util.GetRequestURL failed"), "get_request_url_failed", http.StatusInternalServerError)
+        }
+        resp, err = doRequest(ctx, c, meta, isStream, fullRequestURL, requestBody)
         if err != nil {
+            logger.Errorf(ctx, "doRequest failed: %s", err.Error())
             return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
         }
-        err = req.Body.Close()
-        if err != nil {
-            return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
-        }
-        err = c.Request.Body.Close()
-        if err != nil {
-            return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
-        }
         isStream = isStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
 
         if resp.StatusCode != http.StatusOK {
@@ -117,57 +63,14 @@ func RelayTextHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode
             return util.RelayErrorHandler(resp)
         }
     }
-    var respErr *openai.ErrorWithStatusCode
-    var usage *openai.Usage
-
-    defer func(ctx context.Context) {
-        // Why we use defer here? Because if error happened, we will have to return the pre-consumed quota.
-        if respErr != nil {
-            logger.Errorf(ctx, "respErr is not nil: %+v", respErr)
-            util.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)
-            return
-        }
-        if usage == nil {
-            logger.Error(ctx, "usage is nil, which is unexpected")
-            return
-        }
-
-        go func() {
-            quota := 0
-            completionRatio := common.GetCompletionRatio(textRequest.Model)
-            promptTokens = usage.PromptTokens
-            completionTokens = usage.CompletionTokens
-            quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio))
-            if ratio != 0 && quota <= 0 {
-                quota = 1
-            }
-            totalTokens := promptTokens + completionTokens
-            if totalTokens == 0 {
-                // in this case, must be some error happened
-                // we cannot just return, because we may have to return the pre-consumed quota
-                quota = 0
-            }
-            quotaDelta := quota - preConsumedQuota
-            err := model.PostConsumeTokenQuota(meta.TokenId, quotaDelta)
-            if err != nil {
-                logger.Error(ctx, "error consuming token remain quota: "+err.Error())
-            }
-            err = model.CacheUpdateUserQuota(meta.UserId)
-            if err != nil {
-                logger.Error(ctx, "error update user quota cache: "+err.Error())
-            }
-            if quota != 0 {
-                logContent := fmt.Sprintf("model ratio %.2f, group ratio %.2f", modelRatio, groupRatio)
-                model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, promptTokens, completionTokens, textRequest.Model, meta.TokenName, quota, logContent)
-                model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota)
-                model.UpdateChannelUsedQuota(meta.ChannelId, quota)
-            }
-        }()
-    }(ctx)
-    usage, respErr = DoResponse(c, &textRequest, resp, relayMode, apiType, isStream, promptTokens)
+    // do response
+    usage, respErr := DoResponse(c, textRequest, resp, meta.Mode, meta.APIType, isStream, promptTokens)
     if respErr != nil {
+        logger.Errorf(ctx, "respErr is not nil: %+v", respErr)
+        util.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)
         return respErr
     }
+    // post-consume quota
+    go postConsumeQuota(ctx, usage, meta, textRequest, ratio, preConsumedQuota, modelRatio, groupRatio)
     return nil
 }
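The billing line inside postConsumeQuota is worth a worked example. A standalone sketch with made-up numbers (completion billed at 2x, combined ratio 1.5), continuing the 780-unit reservation from the earlier sketch:

package main

import (
    "fmt"
    "math"
)

// computeQuota mirrors the billing line in postConsumeQuota:
// quota = ceil((promptTokens + completionTokens*completionRatio) * ratio),
// clamped to at least 1 whenever the ratio is non-zero.
func computeQuota(promptTokens, completionTokens int, completionRatio, ratio float64) int {
    quota := int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio))
    if ratio != 0 && quota <= 0 {
        quota = 1
    }
    return quota
}

func main() {
    // 120 prompt + 80 completion tokens: ceil((120 + 80*2.0) * 1.5) = 420.
    fmt.Println(computeQuota(120, 80, 2.0, 1.5))
    // Against the 780 reserved earlier, quotaDelta = 420 - 780 = -340,
    // i.e. PostConsumeTokenQuota refunds the over-reservation.
    fmt.Println(420 - 780)
}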
@@ -3,10 +3,12 @@ package util
 import (
     "github.com/gin-gonic/gin"
     "one-api/common"
+    "one-api/relay/constant"
     "strings"
 )
 
 type RelayMeta struct {
+    Mode        int
     ChannelType int
     ChannelId   int
     TokenId     int
@@ -17,11 +19,13 @@ type RelayMeta struct {
     BaseURL     string
     APIVersion  string
     APIKey      string
+    APIType     int
     Config      map[string]string
 }
 
 func GetRelayMeta(c *gin.Context) *RelayMeta {
     meta := RelayMeta{
+        Mode:        constant.Path2RelayMode(c.Request.URL.Path),
         ChannelType: c.GetInt("channel"),
         ChannelId:   c.GetInt("channel_id"),
         TokenId:     c.GetInt("token_id"),
@@ -40,5 +44,6 @@ func GetRelayMeta(c *gin.Context) *RelayMeta {
     if meta.BaseURL == "" {
         meta.BaseURL = common.ChannelBaseURLs[meta.ChannelType]
     }
+    meta.APIType = constant.ChannelType2APIType(meta.ChannelType)
     return &meta
 }
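Mode and APIType are now derived once, when the meta is built, instead of being recomputed at each call site. A standalone sketch of that construct-once pattern; pathToMode and the mode values are hypothetical stand-ins for constant.Path2RelayMode and the RelayMode constants:

package main

import (
    "fmt"
    "strings"
)

// pathToMode stands in for constant.Path2RelayMode: derive the relay mode
// from the request path once, instead of re-parsing it in every helper.
func pathToMode(path string) int {
    if strings.HasPrefix(path, "/v1/chat/completions") {
        return 1 // hypothetical RelayModeChatCompletions
    }
    if strings.HasPrefix(path, "/v1/embeddings") {
        return 2 // hypothetical RelayModeEmbeddings
    }
    return 0
}

type Meta struct{ Mode int }

func newMeta(path string) *Meta {
    return &Meta{Mode: pathToMode(path)} // derived once at construction
}

func main() {
    fmt.Println(newMeta("/v1/chat/completions").Mode) // 1
}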