package controller import ( "bytes" "context" "encoding/json" "errors" "fmt" "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/relay" "github.com/songquanpeng/one-api/relay/adaptor/openai" billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" "github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/meta" relaymodel "github.com/songquanpeng/one-api/relay/model" "io" "net/http" ) func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, error) { imageRequest := &relaymodel.ImageRequest{} err := common.UnmarshalBodyReusable(c, imageRequest) if err != nil { return nil, err } if imageRequest.N == 0 { imageRequest.N = 1 } if imageRequest.Size == "" { imageRequest.Size = "1024x1024" } if imageRequest.Model == "" { imageRequest.Model = "dall-e-2" } return imageRequest, nil } func isValidImageSize(model string, size string) bool { if model == "cogview-3" || billingratio.ImageSizeRatios[model] == nil { return true } _, ok := billingratio.ImageSizeRatios[model][size] return ok } func isValidImagePromptLength(model string, promptLength int) bool { maxPromptLength, ok := billingratio.ImagePromptLengthLimitations[model] return !ok || promptLength <= maxPromptLength } func isWithinRange(element string, value int) bool { amounts, ok := billingratio.ImageGenerationAmounts[element] return !ok || (value >= amounts[0] && value <= amounts[1]) } func getImageSizeRatio(model string, size string) float64 { if ratio, ok := billingratio.ImageSizeRatios[model][size]; ok { return ratio } return 1 } func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *meta.Meta) *relaymodel.ErrorWithStatusCode { // check prompt length if imageRequest.Prompt == "" { return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest) } // model validation if !isValidImageSize(imageRequest.Model, imageRequest.Size) { return openai.ErrorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest) } if !isValidImagePromptLength(imageRequest.Model, len(imageRequest.Prompt)) { return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest) } // Number of generated images validation if !isWithinRange(imageRequest.Model, imageRequest.N) { return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest) } return nil } func getImageCostRatio(imageRequest *relaymodel.ImageRequest) (float64, error) { if imageRequest == nil { return 0, errors.New("imageRequest is nil") } imageCostRatio := getImageSizeRatio(imageRequest.Model, imageRequest.Size) if imageRequest.Quality == "hd" && imageRequest.Model == "dall-e-3" { if imageRequest.Size == "1024x1024" { imageCostRatio *= 2 } else { imageCostRatio *= 1.5 } } return imageCostRatio, nil } func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode { ctx := c.Request.Context() meta := meta.GetByContext(c) imageRequest, err := getImageRequest(c, meta.Mode) if err != nil { logger.Errorf(ctx, "getImageRequest failed: %s", err.Error()) return openai.ErrorWrapper(err, "invalid_image_request", http.StatusBadRequest) } // map model name var isModelMapped bool meta.OriginModelName = imageRequest.Model imageRequest.Model, isModelMapped = getMappedModelName(imageRequest.Model, meta.ModelMapping) meta.ActualModelName = imageRequest.Model // model validation bizErr := validateImageRequest(imageRequest, meta) if bizErr != nil { return bizErr } imageCostRatio, err := getImageCostRatio(imageRequest) if err != nil { return openai.ErrorWrapper(err, "get_image_cost_ratio_failed", http.StatusInternalServerError) } imageModel := imageRequest.Model // Convert the original image model imageRequest.Model, _ = getMappedModelName(imageRequest.Model, billingratio.ImageOriginModelName) c.Set("response_format", imageRequest.ResponseFormat) var requestBody io.Reader if isModelMapped || meta.ChannelType == channeltype.Azure { // make Azure channel request body jsonStr, err := json.Marshal(imageRequest) if err != nil { return openai.ErrorWrapper(err, "marshal_image_request_failed", http.StatusInternalServerError) } requestBody = bytes.NewBuffer(jsonStr) } else { requestBody = c.Request.Body } adaptor := relay.GetAdaptor(meta.APIType) if adaptor == nil { return openai.ErrorWrapper(fmt.Errorf("invalid api type: %d", meta.APIType), "invalid_api_type", http.StatusBadRequest) } adaptor.Init(meta) switch meta.ChannelType { case channeltype.Ali: fallthrough case channeltype.Baidu: fallthrough case channeltype.Zhipu: finalRequest, err := adaptor.ConvertImageRequest(imageRequest) if err != nil { return openai.ErrorWrapper(err, "convert_image_request_failed", http.StatusInternalServerError) } jsonStr, err := json.Marshal(finalRequest) if err != nil { return openai.ErrorWrapper(err, "marshal_image_request_failed", http.StatusInternalServerError) } requestBody = bytes.NewBuffer(jsonStr) } modelRatio := billingratio.GetModelRatio(imageModel) groupRatio := billingratio.GetGroupRatio(meta.Group) ratio := modelRatio * groupRatio userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId) quota := int64(ratio*imageCostRatio*1000) * int64(imageRequest.N) if userQuota-quota < 0 { return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) } // do request resp, err := adaptor.DoRequest(c, meta, requestBody) if err != nil { logger.Errorf(ctx, "DoRequest failed: %s", err.Error()) return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) } defer func(ctx context.Context) { if resp != nil && resp.StatusCode != http.StatusOK { return } err := model.PostConsumeTokenQuota(meta.TokenId, quota) if err != nil { logger.SysError("error consuming token remain quota: " + err.Error()) } err = model.CacheUpdateUserQuota(ctx, meta.UserId) if err != nil { logger.SysError("error update user quota cache: " + err.Error()) } if quota != 0 { tokenName := c.GetString(ctxkey.TokenName) logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, 0, 0, imageRequest.Model, tokenName, quota, logContent) model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota) channelId := c.GetInt(ctxkey.ChannelId) model.UpdateChannelUsedQuota(channelId, quota) } }(c.Request.Context()) // do response _, respErr := adaptor.DoResponse(c, resp, meta) if respErr != nil { logger.Errorf(ctx, "respErr is not nil: %+v", respErr) return respErr } return nil }