package controller

import (
	"context"
	"errors"
	"fmt"
	"math"
	"net/http"
	"unicode/utf8"

	"github.com/gin-gonic/gin"

	"github.com/songquanpeng/one-api/common"
	"github.com/songquanpeng/one-api/common/config"
	"github.com/songquanpeng/one-api/common/logger"
	"github.com/songquanpeng/one-api/model"
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
	billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
	"github.com/songquanpeng/one-api/relay/channeltype"
	"github.com/songquanpeng/one-api/relay/controller/validator"
	"github.com/songquanpeng/one-api/relay/meta"
	relaymodel "github.com/songquanpeng/one-api/relay/model"
	"github.com/songquanpeng/one-api/relay/relaymode"
)

// getAndValidateTextRequest decodes the request body into a
// GeneralOpenAIRequest, fills in a model name when the request omitted one
// (a fixed default for moderations, the "model" route parameter for
// embeddings) and runs validator.ValidateTextRequest on the result.
//
// It returns the decoded request, or the first decoding/validation error.
func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.GeneralOpenAIRequest, error) {
	textRequest := &relaymodel.GeneralOpenAIRequest{}
	if err := common.UnmarshalBodyReusable(c, textRequest); err != nil {
		return nil, err
	}

	if textRequest.Model == "" {
		switch relayMode {
		case relaymode.Moderations:
			textRequest.Model = "text-moderation-latest"
		case relaymode.Embeddings:
			textRequest.Model = c.Param("model")
		}
	}

	if err := validator.ValidateTextRequest(textRequest, relayMode); err != nil {
		return nil, err
	}
	return textRequest, nil
}

// getImageRequest decodes the request body into an ImageRequest and applies
// defaults for fields left at their zero value: N=1, Size="1024x1024",
// Model="dall-e-2". relayMode is accepted but not used by this function.
func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, error) {
	imageRequest := &relaymodel.ImageRequest{}
	if err := common.UnmarshalBodyReusable(c, imageRequest); err != nil {
		return nil, err
	}

	if imageRequest.N == 0 {
		imageRequest.N = 1
	}
	if imageRequest.Size == "" {
		imageRequest.Size = "1024x1024"
	}
	if imageRequest.Model == "" {
		imageRequest.Model = "dall-e-2"
	}
	return imageRequest, nil
}

// isValidImageSize reports whether size has an entry for model in
// billingratio.ImageSizeRatios. The model "cogview-3" is accepted with any
// size.
func isValidImageSize(model string, size string) bool {
	if model == "cogview-3" {
		return true
	}
	_, ok := billingratio.ImageSizeRatios[model][size]
	return ok
}

// getImageSizeRatio returns the billing ratio registered for the given
// model/size pair in billingratio.ImageSizeRatios, or 1 when there is no
// such entry.
func getImageSizeRatio(model string, size string) float64 {
	if ratio, ok := billingratio.ImageSizeRatios[model][size]; ok {
		return ratio
	}
	return 1
}

// validateImageRequest checks an image request before it is relayed.
//
// It rejects, with http.StatusBadRequest:
//   - a size that isValidImageSize does not accept for the model,
//   - an empty prompt,
//   - a prompt longer than the model's entry in
//     billingratio.ImagePromptLengthLimitations,
//   - an N outside the range accepted by isWithinRange, unless the channel
//     type is Azure.
//
// It returns nil when the request passes every check.
func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *meta.Meta) *relaymodel.ErrorWithStatusCode {
	// Size validation (implicitly validates the model as well, since the
	// lookup is keyed by model first).
	if !isValidImageSize(imageRequest.Model, imageRequest.Size) {
		return openai.ErrorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest)
	}

	// Prompt validation.
	if imageRequest.Prompt == "" {
		return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest)
	}
	// Count runes rather than bytes: len() on a string is its UTF-8 byte
	// length, which would reject multi-byte (e.g. CJK) prompts long before
	// they reach the limit when it is expressed in characters.
	// NOTE: a model with no entry in ImagePromptLengthLimitations yields the
	// map's zero value as its limit, so every non-empty prompt is rejected.
	promptLimit := billingratio.ImagePromptLengthLimitations[imageRequest.Model]
	if utf8.RuneCountInString(imageRequest.Prompt) > promptLimit {
		return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest)
	}

	// Number-of-images validation; Azure channels are exempt from it.
	if !isWithinRange(imageRequest.Model, imageRequest.N) && meta.ChannelType != channeltype.Azure {
		return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest)
	}
	return nil
}

// getImageCostRatio returns the cost multiplier for an image request.
//
// The base value is getImageSizeRatio(model, size). For model "dall-e-3"
// with quality "hd" it is further multiplied by 2 for size "1024x1024" and
// by 1.5 for any other size. A nil request yields an error.
func getImageCostRatio(imageRequest *relaymodel.ImageRequest) (float64, error) {
	if imageRequest == nil {
		return 0, errors.New("imageRequest is nil")
	}

	imageCostRatio := getImageSizeRatio(imageRequest.Model, imageRequest.Size)
	if imageRequest.Quality == "hd" && imageRequest.Model == "dall-e-3" {
		if imageRequest.Size == "1024x1024" {
			imageCostRatio *= 2
		} else {
			imageCostRatio *= 1.5
		}
	}
	return imageCostRatio, nil
}

// getPromptTokens counts the prompt tokens of a text request using the
// counter matching the relay mode: messages for chat completions, Prompt for
// completions, Input for moderations. Any other relay mode yields 0.
func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int) int {
	switch relayMode {
	case relaymode.ChatCompletions:
		return openai.CountTokenMessages(textRequest.Messages, textRequest.Model)
	case relaymode.Completions:
		return openai.CountTokenInput(textRequest.Prompt, textRequest.Model)
	case relaymode.Moderations:
		return openai.CountTokenInput(textRequest.Input, textRequest.Model)
	}
	return 0
}

// getPreConsumedQuota computes the quota to reserve before a request is
// relayed.
//
// The token estimate is promptTokens+MaxTokens when the request sets
// MaxTokens, and config.PreConsumedQuota otherwise; the estimate is then
// scaled by ratio and truncated to int64.
func getPreConsumedQuota(textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64) int64 {
	preConsumedTokens := config.PreConsumedQuota
	if textRequest.MaxTokens != 0 {
		preConsumedTokens = int64(promptTokens) + int64(textRequest.MaxTokens)
	}
	return int64(float64(preConsumedTokens) * ratio)
}

// preConsumeQuota reserves quota for a text request before it is relayed.
//
// Steps, in order:
//  1. compute the reservation with getPreConsumedQuota;
//  2. read the user's quota via model.CacheGetUserQuota and fail with
//     "insufficient_user_quota" (403) when it is below the reservation;
//  3. call model.CacheDecreaseUserQuota with the reservation;
//  4. if the user's quota exceeds 100x the reservation, treat the user as
//     trusted and reset the reservation to 0;
//  5. otherwise call model.PreConsumeTokenQuota with the reservation.
//
// On success it returns the amount handed to PreConsumeTokenQuota (0 for a
// trusted user). On every error path it returns the reservation value as it
// stood at that point together with the wrapped error.
func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64, meta *meta.Meta) (int64, *relaymodel.ErrorWithStatusCode) {
	preConsumedQuota := getPreConsumedQuota(textRequest, promptTokens, ratio)

	userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId)
	if err != nil {
		return preConsumedQuota, openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
	}
	if userQuota-preConsumedQuota < 0 {
		return preConsumedQuota, openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
	}
	// Note: this runs before the "trusted user" check below, so the cache is
	// adjusted by the full reservation even when nothing is pre-consumed
	// from the token.
	if err = model.CacheDecreaseUserQuota(meta.UserId, preConsumedQuota); err != nil {
		return preConsumedQuota, openai.ErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
	}

	if userQuota > 100*preConsumedQuota {
		// In this case we do not pre-consume quota, because the user has
		// plenty of it.
		preConsumedQuota = 0
		logger.Info(ctx, fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", meta.UserId, userQuota))
	}
	if preConsumedQuota > 0 {
		if err := model.PreConsumeTokenQuota(meta.TokenId, preConsumedQuota); err != nil {
			return preConsumedQuota, openai.ErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
		}
	}
	return preConsumedQuota, nil
}

// postConsumeQuota settles the quota for a finished text request.
//
// The final quota is ceil((promptTokens + completionTokens*completionRatio)
// * ratio), where completionRatio comes from
// billingratio.GetCompletionRatio. A non-positive result is raised to 1 when
// ratio is non-zero, and the quota is forced to 0 when usage reports no
// tokens at all. The difference between the final quota and preConsumedQuota
// is passed to model.PostConsumeTokenQuota, the user's quota cache is
// refreshed, and a consume log plus user/channel usage counters are
// recorded.
//
// Failures of the token settlement and cache refresh are logged, not
// returned. A nil usage is logged as an error and nothing is settled.
func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.Meta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int64, modelRatio float64, groupRatio float64) {
	if usage == nil {
		logger.Error(ctx, "usage is nil, which is unexpected")
		return
	}

	completionRatio := billingratio.GetCompletionRatio(textRequest.Model)
	promptTokens := usage.PromptTokens
	completionTokens := usage.CompletionTokens

	quota := int64(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio))
	if ratio != 0 && quota <= 0 {
		quota = 1
	}
	if promptTokens+completionTokens == 0 {
		// In this case some error must have happened. We cannot just
		// return, because we may have to give back the pre-consumed quota.
		quota = 0
	}

	quotaDelta := quota - preConsumedQuota
	if err := model.PostConsumeTokenQuota(meta.TokenId, quotaDelta); err != nil {
		logger.Error(ctx, "error consuming token remain quota: "+err.Error())
	}
	if err := model.CacheUpdateUserQuota(ctx, meta.UserId); err != nil {
		logger.Error(ctx, "error update user quota cache: "+err.Error())
	}

	// The log text is user-facing (model ratio, group ratio, completion
	// ratio) and is intentionally kept as-is.
	logContent := fmt.Sprintf("模型倍率 %.2f，分组倍率 %.2f，补全倍率 %.2f", modelRatio, groupRatio, completionRatio)
	model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, promptTokens, completionTokens, textRequest.Model, meta.TokenName, quota, logContent)
	model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota)
	model.UpdateChannelUsedQuota(meta.ChannelId, quota)
}

// getMappedModelName looks modelName up in mapping.
//
// It returns the mapped name and true when mapping holds a non-empty value
// for modelName; otherwise (nil mapping, missing key, or empty value) it
// returns modelName unchanged and false.
func getMappedModelName(modelName string, mapping map[string]string) (string, bool) {
	// Indexing a nil map is safe in Go and yields the zero value, so the nil
	// case needs no separate branch.
	if mapped := mapping[modelName]; mapped != "" {
		return mapped, true
	}
	return modelName, false
}