diff --git a/controller/relay-image.go b/controller/relay-image.go index a7941ba5..8ee07670 100644 --- a/controller/relay-image.go +++ b/controller/relay-image.go @@ -27,7 +27,7 @@ func isWithinRange(element string, value int) bool { func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { imageModel := "dall-e-2" - requestSize := "1024x1024" + imageSize := "1024x1024" tokenId := c.GetInt("token_id") channelType := c.GetInt("channel") @@ -45,29 +45,34 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode } } + // Size validation + if imageRequest.Size != "" { + imageSize = imageRequest.Size + } + // Model validation if imageRequest.Model != "" { - // Check if model is supported - if _, ok := common.DalleSizeRatios[imageRequest.Model]; !ok { - return errorWrapper(errors.New("model not supported"), "model_not_supported", http.StatusBadRequest) - } - imageModel = imageRequest.Model } - // Size validation - if imageRequest.Size != "" { - requestSize = imageRequest.Size + imageCostRatio, hasValidSize := common.DalleSizeRatios[imageModel][imageSize] + + // Check if model is supported + if hasValidSize { + if imageRequest.Quality == "hd" && imageModel == "dall-e-3" { + if imageSize == "1024x1024" { + imageCostRatio *= 2 + } else { + imageCostRatio *= 1.5 + } + } + } else { + return errorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest) } // Prompt validation if imageRequest.Prompt == "" { - return errorWrapper(errors.New("prompt is required"), "required_field_missing", http.StatusBadRequest) - } - - // Number of generated images validation - if isWithinRange(imageModel, imageRequest.N) == false { - return errorWrapper(errors.New("invalud value of n"), "number_of_generated_images_not_within_range", http.StatusBadRequest) + return errorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest) } // Check prompt length @@ -75,6 +80,11 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode return errorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest) } + // Number of generated images validation + if isWithinRange(imageModel, imageRequest.N) == false { + return errorWrapper(errors.New("invalud value of n"), "n_not_within_range", http.StatusBadRequest) + } + // map model name modelMapping := c.GetString("model_mapping") isModelMapped := false @@ -111,20 +121,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode ratio := modelRatio * groupRatio userQuota, err := model.CacheGetUserQuota(userId) - sizeRatio := 1.0 - - if ratio, ok := common.DalleSizeRatios[imageModel][requestSize]; ok { - sizeRatio = ratio - if imageRequest.Quality == "hd" && imageRequest.Model == "dall-e-3" { - if requestSize == "1024x1024" { - sizeRatio *= 2 - } else { - sizeRatio *= 1.5 - } - } - } - - quota := int(ratio*sizeRatio*1000) * imageRequest.N + quota := int(ratio*imageCostRatio*1000) * imageRequest.N if consumeQuota && userQuota-quota < 0 { return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)