package controller

import (
	"context"
	"errors"
	"fmt"
	"math"
	"net/http"
	"strings"

	"github.com/gin-gonic/gin"
	"github.com/songquanpeng/one-api/common"
	"github.com/songquanpeng/one-api/common/config"
	"github.com/songquanpeng/one-api/common/logger"
	"github.com/songquanpeng/one-api/model"
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
	billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
	"github.com/songquanpeng/one-api/relay/channeltype"
	"github.com/songquanpeng/one-api/relay/controller/validator"
	"github.com/songquanpeng/one-api/relay/meta"
	relaymodel "github.com/songquanpeng/one-api/relay/model"
	"github.com/songquanpeng/one-api/relay/relaymode"
)

// getAndValidateTextRequest parses the request body into a GeneralOpenAIRequest,
// fills in default model names for moderation and embedding requests, and validates it.
func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.GeneralOpenAIRequest, error) {
	textRequest := &relaymodel.GeneralOpenAIRequest{}
	err := common.UnmarshalBodyReusable(c, textRequest)
	if err != nil {
		return nil, err
	}
	if relayMode == relaymode.Moderations && textRequest.Model == "" {
		textRequest.Model = "text-moderation-latest"
	}
	if relayMode == relaymode.Embeddings && textRequest.Model == "" {
		textRequest.Model = c.Param("model")
	}
	err = validator.ValidateTextRequest(textRequest, relayMode)
	if err != nil {
		return nil, err
	}
	return textRequest, nil
}

// getPromptTokens counts the prompt tokens of the request according to the relay mode.
func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int) int {
	switch relayMode {
	case relaymode.ChatCompletions:
		return openai.CountTokenMessages(textRequest.Messages, textRequest.Model)
	case relaymode.Completions:
		return openai.CountTokenInput(textRequest.Prompt, textRequest.Model)
	case relaymode.Moderations:
		return openai.CountTokenInput(textRequest.Input, textRequest.Model)
	}
	return 0
}

// getPreConsumedQuota estimates the quota to reserve before the upstream call:
// a configurable base amount plus the prompt tokens and, if set, the requested
// max completion tokens, all scaled by the billing ratio.
func getPreConsumedQuota(textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64) int64 {
	preConsumedTokens := config.PreConsumedQuota + int64(promptTokens)
	if textRequest.MaxTokens != 0 {
		preConsumedTokens += int64(textRequest.MaxTokens)
	}
	return int64(float64(preConsumedTokens) * ratio)
}

// preConsumeQuota reserves the estimated quota before relaying the request.
// It rejects the request if the user's quota is insufficient, and skips the
// reservation for users whose quota far exceeds the estimate.
func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64, meta *meta.Meta) (int64, *relaymodel.ErrorWithStatusCode) {
	preConsumedQuota := getPreConsumedQuota(textRequest, promptTokens, ratio)
	userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId)
	if err != nil {
		return preConsumedQuota, openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
	}
	if userQuota-preConsumedQuota < 0 {
		return preConsumedQuota, openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
	}
	err = model.CacheDecreaseUserQuota(meta.UserId, preConsumedQuota)
	if err != nil {
		return preConsumedQuota, openai.ErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
	}
	if userQuota > 100*preConsumedQuota {
		// The user has far more quota than the estimate, so do not pre-consume;
		// the bill is settled entirely in postConsumeQuota instead.
		preConsumedQuota = 0
		logger.Info(ctx, fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", meta.UserId, userQuota))
	}
	if preConsumedQuota > 0 {
		err := model.PreConsumeTokenQuota(meta.TokenId, preConsumedQuota)
		if err != nil {
			return preConsumedQuota, openai.ErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
		}
	}
	return preConsumedQuota, nil
}
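// postConsumeQuota settles the bill after the upstream call: it computes the
// actual quota from the reported usage, charges (or refunds) the difference
// against the pre-consumed amount, and records the consumption.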
logger.Error(ctx, "usage is nil, which is unexpected") return } var quota int64 completionRatio := billingratio.GetCompletionRatio(textRequest.Model, meta.ChannelType) promptTokens := usage.PromptTokens completionTokens := usage.CompletionTokens quota = int64(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio)) if ratio != 0 && quota <= 0 { quota = 1 } totalTokens := promptTokens + completionTokens if totalTokens == 0 { // in this case, must be some error happened // we cannot just return, because we may have to return the pre-consumed quota quota = 0 } quotaDelta := quota - preConsumedQuota err := model.PostConsumeTokenQuota(meta.TokenId, quotaDelta) if err != nil { logger.Error(ctx, "error consuming token remain quota: "+err.Error()) } err = model.CacheUpdateUserQuota(ctx, meta.UserId) if err != nil { logger.Error(ctx, "error update user quota cache: "+err.Error()) } logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f", modelRatio, groupRatio, completionRatio) model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, promptTokens, completionTokens, textRequest.Model, meta.TokenName, quota, logContent) model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota) model.UpdateChannelUsedQuota(meta.ChannelId, quota) } func getMappedModelName(modelName string, mapping map[string]string) (string, bool) { if mapping == nil { return modelName, false } mappedModelName := mapping[modelName] if mappedModelName != "" { return mappedModelName, true } return modelName, false } func isErrorHappened(meta *meta.Meta, resp *http.Response) bool { if resp == nil { if meta.ChannelType == channeltype.AwsClaude { return false } return true } if resp.StatusCode != http.StatusOK { return true } if meta.ChannelType == channeltype.DeepL { // skip stream check for deepl return false } if meta.IsStream && strings.HasPrefix(resp.Header.Get("Content-Type"), "application/json") { return true } return false }