From fecaece71b700b43ba11c161a3f8a971af204971 Mon Sep 17 00:00:00 2001 From: igophper <34326532+igophper@users.noreply.github.com> Date: Sun, 30 Jun 2024 19:52:33 +0800 Subject: [PATCH] fix: fix size not support during image generation (#1564) Fixes #1224, #1068 --- relay/controller/helper.go | 72 -------------------------------- relay/controller/image.go | 84 +++++++++++++++++++++++++++++++++++--- 2 files changed, 78 insertions(+), 78 deletions(-) diff --git a/relay/controller/helper.go b/relay/controller/helper.go index dccff486..c47cb558 100644 --- a/relay/controller/helper.go +++ b/relay/controller/helper.go @@ -40,78 +40,6 @@ func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.Gener return textRequest, nil } -func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, error) { - imageRequest := &relaymodel.ImageRequest{} - err := common.UnmarshalBodyReusable(c, imageRequest) - if err != nil { - return nil, err - } - if imageRequest.N == 0 { - imageRequest.N = 1 - } - if imageRequest.Size == "" { - imageRequest.Size = "1024x1024" - } - if imageRequest.Model == "" { - imageRequest.Model = "dall-e-2" - } - return imageRequest, nil -} - -func isValidImageSize(model string, size string) bool { - if model == "cogview-3" { - return true - } - _, ok := billingratio.ImageSizeRatios[model][size] - return ok -} - -func getImageSizeRatio(model string, size string) float64 { - ratio, ok := billingratio.ImageSizeRatios[model][size] - if !ok { - return 1 - } - return ratio -} - -func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *meta.Meta) *relaymodel.ErrorWithStatusCode { - // model validation - hasValidSize := isValidImageSize(imageRequest.Model, imageRequest.Size) - if !hasValidSize { - return openai.ErrorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest) - } - // check prompt length - if imageRequest.Prompt == "" { - return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest) - } - if len(imageRequest.Prompt) > billingratio.ImagePromptLengthLimitations[imageRequest.Model] { - return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest) - } - // Number of generated images validation - if !isWithinRange(imageRequest.Model, imageRequest.N) { - // channel not azure - if meta.ChannelType != channeltype.Azure { - return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest) - } - } - return nil -} - -func getImageCostRatio(imageRequest *relaymodel.ImageRequest) (float64, error) { - if imageRequest == nil { - return 0, errors.New("imageRequest is nil") - } - imageCostRatio := getImageSizeRatio(imageRequest.Model, imageRequest.Size) - if imageRequest.Quality == "hd" && imageRequest.Model == "dall-e-3" { - if imageRequest.Size == "1024x1024" { - imageCostRatio *= 2 - } else { - imageCostRatio *= 1.5 - } - } - return imageCostRatio, nil -} - func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int) int { switch relayMode { case relaymode.ChatCompletions: diff --git a/relay/controller/image.go b/relay/controller/image.go index 691c7c0e..e6245226 100644 --- a/relay/controller/image.go +++ b/relay/controller/image.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/model" @@ -20,13 +21,84 @@ import ( "net/http" ) -func isWithinRange(element string, value int) bool { - if _, ok := billingratio.ImageGenerationAmounts[element]; !ok { - return false +func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, error) { + imageRequest := &relaymodel.ImageRequest{} + err := common.UnmarshalBodyReusable(c, imageRequest) + if err != nil { + return nil, err } - min := billingratio.ImageGenerationAmounts[element][0] - max := billingratio.ImageGenerationAmounts[element][1] - return value >= min && value <= max + if imageRequest.N == 0 { + imageRequest.N = 1 + } + if imageRequest.Size == "" { + imageRequest.Size = "1024x1024" + } + if imageRequest.Model == "" { + imageRequest.Model = "dall-e-2" + } + return imageRequest, nil +} + +func isValidImageSize(model string, size string) bool { + if model == "cogview-3" || billingratio.ImageSizeRatios[model] == nil { + return true + } + _, ok := billingratio.ImageSizeRatios[model][size] + return ok +} + +func isValidImagePromptLength(model string, promptLength int) bool { + maxPromptLength, ok := billingratio.ImagePromptLengthLimitations[model] + return !ok || promptLength <= maxPromptLength +} + +func isWithinRange(element string, value int) bool { + amounts, ok := billingratio.ImageGenerationAmounts[element] + return !ok || (value >= amounts[0] && value <= amounts[1]) +} + +func getImageSizeRatio(model string, size string) float64 { + if ratio, ok := billingratio.ImageSizeRatios[model][size]; ok { + return ratio + } + return 1 +} + +func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *meta.Meta) *relaymodel.ErrorWithStatusCode { + // check prompt length + if imageRequest.Prompt == "" { + return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest) + } + + // model validation + if !isValidImageSize(imageRequest.Model, imageRequest.Size) { + return openai.ErrorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest) + } + + if !isValidImagePromptLength(imageRequest.Model, len(imageRequest.Prompt)) { + return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest) + } + + // Number of generated images validation + if !isWithinRange(imageRequest.Model, imageRequest.N) { + return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest) + } + return nil +} + +func getImageCostRatio(imageRequest *relaymodel.ImageRequest) (float64, error) { + if imageRequest == nil { + return 0, errors.New("imageRequest is nil") + } + imageCostRatio := getImageSizeRatio(imageRequest.Model, imageRequest.Size) + if imageRequest.Quality == "hd" && imageRequest.Model == "dall-e-3" { + if imageRequest.Size == "1024x1024" { + imageCostRatio *= 2 + } else { + imageCostRatio *= 1.5 + } + } + return imageCostRatio, nil } func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode {