package controller

import (
	"context"
	"errors"
	"fmt"
	"io"
	"math"
	"net/http"

	"github.com/gin-gonic/gin"
	"github.com/songquanpeng/one-api/common"
	"github.com/songquanpeng/one-api/common/config"
	"github.com/songquanpeng/one-api/common/logger"
	"github.com/songquanpeng/one-api/model"
	"github.com/songquanpeng/one-api/relay/channel/openai"
	"github.com/songquanpeng/one-api/relay/constant"
	"github.com/songquanpeng/one-api/relay/util"
)

// getAndValidateTextRequest parses the request body into a GeneralOpenAIRequest,
// fills in mode-specific default models, and validates the result.
func getAndValidateTextRequest(c *gin.Context, relayMode int) (*openai.GeneralOpenAIRequest, error) {
	textRequest := &openai.GeneralOpenAIRequest{}
	err := common.UnmarshalBodyReusable(c, textRequest)
	if err != nil {
		return nil, err
	}
	if relayMode == constant.RelayModeModerations && textRequest.Model == "" {
		textRequest.Model = "text-moderation-latest"
	}
	if relayMode == constant.RelayModeEmbeddings && textRequest.Model == "" {
		textRequest.Model = c.Param("model")
	}
	err = util.ValidateTextRequest(textRequest, relayMode)
	if err != nil {
		return nil, err
	}
	return textRequest, nil
}

// getPromptTokens estimates the number of prompt tokens for the given relay mode.
func getPromptTokens(textRequest *openai.GeneralOpenAIRequest, relayMode int) int {
	switch relayMode {
	case constant.RelayModeChatCompletions:
		return openai.CountTokenMessages(textRequest.Messages, textRequest.Model)
	case constant.RelayModeCompletions:
		return openai.CountTokenInput(textRequest.Prompt, textRequest.Model)
	case constant.RelayModeModerations:
		return openai.CountTokenInput(textRequest.Input, textRequest.Model)
	}
	return 0
}

// getPreConsumedQuota returns the quota to reserve before the upstream call:
// the configured default, or promptTokens+MaxTokens when the client set an
// explicit completion limit.
func getPreConsumedQuota(textRequest *openai.GeneralOpenAIRequest, promptTokens int, ratio float64) int {
	preConsumedTokens := config.PreConsumedQuota
	if textRequest.MaxTokens != 0 {
		preConsumedTokens = promptTokens + textRequest.MaxTokens
	}
	return int(float64(preConsumedTokens) * ratio)
}

// preConsumeQuota reserves quota for the request up front. Users whose balance
// exceeds 100x the reservation are trusted and skip the pre-consumption.
func preConsumeQuota(ctx context.Context, textRequest *openai.GeneralOpenAIRequest, promptTokens int, ratio float64, meta *util.RelayMeta) (int, *openai.ErrorWithStatusCode) {
	preConsumedQuota := getPreConsumedQuota(textRequest, promptTokens, ratio)
	userQuota, err := model.CacheGetUserQuota(meta.UserId)
	if err != nil {
		return preConsumedQuota, openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
	}
	if userQuota-preConsumedQuota < 0 {
		return preConsumedQuota, openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
	}
	err = model.CacheDecreaseUserQuota(meta.UserId, preConsumedQuota)
	if err != nil {
		return preConsumedQuota, openai.ErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
	}
	if userQuota > 100*preConsumedQuota {
		// in this case, we do not pre-consume quota
		// because the user has enough quota
		preConsumedQuota = 0
		logger.Info(ctx, fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", meta.UserId, userQuota))
	}
	if preConsumedQuota > 0 {
		err := model.PreConsumeTokenQuota(meta.TokenId, preConsumedQuota)
		if err != nil {
			return preConsumedQuota, openai.ErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
		}
	}
	return preConsumedQuota, nil
}

// postConsumeQuota settles the final quota from the actual usage, charges or
// refunds the difference against the pre-consumed amount, and records the
// consumption for the user and channel.
func postConsumeQuota(ctx context.Context, usage *openai.Usage, meta *util.RelayMeta, textRequest *openai.GeneralOpenAIRequest, ratio float64, preConsumedQuota int, modelRatio float64, groupRatio float64) {
	if usage == nil {
		logger.Error(ctx, "usage is nil, which is unexpected")
		return
	}
	quota := 0
	completionRatio := common.GetCompletionRatio(textRequest.Model)
	promptTokens := usage.PromptTokens
	completionTokens := usage.CompletionTokens
	quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio))
	if ratio != 0 && quota <= 0 {
		quota = 1
	}
	totalTokens := promptTokens + completionTokens
	if totalTokens == 0 {
		// in this case, some error must have happened
		// we cannot just return, because we may have to refund the pre-consumed quota
		quota = 0
	}
	quotaDelta := quota - preConsumedQuota
	err := model.PostConsumeTokenQuota(meta.TokenId, quotaDelta)
	if err != nil {
		logger.Error(ctx, "error consuming token remain quota: "+err.Error())
	}
	err = model.CacheUpdateUserQuota(meta.UserId)
	if err != nil {
		logger.Error(ctx, "error updating user quota cache: "+err.Error())
	}
	if quota != 0 {
		logContent := fmt.Sprintf("model ratio %.2f, group ratio %.2f, completion ratio %.2f", modelRatio, groupRatio, completionRatio)
		model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, promptTokens, completionTokens, textRequest.Model, meta.TokenName, quota, logContent)
		model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota)
		model.UpdateChannelUsedQuota(meta.ChannelId, quota)
	}
}

// doRequest forwards the request to fullRequestURL with the appropriate headers
// and returns the upstream response. Both the outgoing and the incoming request
// bodies are closed once the request has been sent.
func doRequest(ctx context.Context, c *gin.Context, meta *util.RelayMeta, isStream bool, fullRequestURL string, requestBody io.Reader) (*http.Response, error) {
	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
	if err != nil {
		return nil, err
	}
	SetupRequestHeaders(c, req, meta, isStream)
	resp, err := util.HTTPClient.Do(req)
	if err != nil {
		return nil, err
	}
	if resp == nil {
		return nil, errors.New("resp is nil")
	}
	// req.Body is nil when requestBody is nil (e.g. bodyless requests), so
	// guard before closing to avoid a nil pointer panic.
	if req.Body != nil {
		err = req.Body.Close()
		if err != nil {
			logger.Warnf(ctx, "close req.Body failed: %+v", err)
		}
	}
	err = c.Request.Body.Close()
	if err != nil {
		logger.Warnf(ctx, "close c.Request.Body failed: %+v", err)
	}
	return resp, nil
}
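
// relayTextSketch is an illustrative sketch, not part of the actual relay
// handler, showing how the helpers above are meant to compose: parse and
// validate the request, estimate prompt tokens, pre-consume quota, forward the
// request upstream, then settle the final quota from the actual usage. The
// upstream URL below is a placeholder (real code derives it from the channel),
// and usage parsing is elided (real code extracts it from the response via the
// channel adaptor); both assumptions are marked inline.
func relayTextSketch(c *gin.Context, meta *util.RelayMeta, relayMode int, modelRatio float64, groupRatio float64) *openai.ErrorWithStatusCode {
	ctx := c.Request.Context()
	textRequest, err := getAndValidateTextRequest(c, relayMode)
	if err != nil {
		return openai.ErrorWrapper(err, "invalid_text_request", http.StatusBadRequest)
	}
	ratio := modelRatio * groupRatio
	promptTokens := getPromptTokens(textRequest, relayMode)
	preConsumedQuota, errWithCode := preConsumeQuota(ctx, textRequest, promptTokens, ratio, meta)
	if errWithCode != nil {
		return errWithCode
	}
	// Placeholder URL: the real handler builds this from the channel's base URL
	// and the relay mode.
	fullRequestURL := "https://api.openai.com/v1/chat/completions"
	// UnmarshalBodyReusable keeps c.Request.Body replayable, so the original
	// body can be forwarded as-is when no request conversion is needed.
	resp, err := doRequest(ctx, c, meta, textRequest.Stream, fullRequestURL, c.Request.Body)
	if err != nil {
		return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
	}
	// Assumption: the real handler hands resp to a channel adaptor that parses
	// the (possibly streamed) response and returns the token usage; elided here,
	// so this zero-value usage takes the totalTokens == 0 refund path.
	var usage openai.Usage
	_ = resp
	postConsumeQuota(ctx, &usage, meta, textRequest, ratio, preConsumedQuota, modelRatio, groupRatio)
	return nil
}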