fix: fix size not support during image generation (#1564)
Fixes #1224, #1068
This commit is contained in:
parent
c135d74f13
commit
fecaece71b
@ -40,78 +40,6 @@ func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.Gener
|
|||||||
return textRequest, nil
|
return textRequest, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, error) {
|
|
||||||
imageRequest := &relaymodel.ImageRequest{}
|
|
||||||
err := common.UnmarshalBodyReusable(c, imageRequest)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if imageRequest.N == 0 {
|
|
||||||
imageRequest.N = 1
|
|
||||||
}
|
|
||||||
if imageRequest.Size == "" {
|
|
||||||
imageRequest.Size = "1024x1024"
|
|
||||||
}
|
|
||||||
if imageRequest.Model == "" {
|
|
||||||
imageRequest.Model = "dall-e-2"
|
|
||||||
}
|
|
||||||
return imageRequest, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func isValidImageSize(model string, size string) bool {
|
|
||||||
if model == "cogview-3" {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
_, ok := billingratio.ImageSizeRatios[model][size]
|
|
||||||
return ok
|
|
||||||
}
|
|
||||||
|
|
||||||
func getImageSizeRatio(model string, size string) float64 {
|
|
||||||
ratio, ok := billingratio.ImageSizeRatios[model][size]
|
|
||||||
if !ok {
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
return ratio
|
|
||||||
}
|
|
||||||
|
|
||||||
func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *meta.Meta) *relaymodel.ErrorWithStatusCode {
|
|
||||||
// model validation
|
|
||||||
hasValidSize := isValidImageSize(imageRequest.Model, imageRequest.Size)
|
|
||||||
if !hasValidSize {
|
|
||||||
return openai.ErrorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest)
|
|
||||||
}
|
|
||||||
// check prompt length
|
|
||||||
if imageRequest.Prompt == "" {
|
|
||||||
return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest)
|
|
||||||
}
|
|
||||||
if len(imageRequest.Prompt) > billingratio.ImagePromptLengthLimitations[imageRequest.Model] {
|
|
||||||
return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest)
|
|
||||||
}
|
|
||||||
// Number of generated images validation
|
|
||||||
if !isWithinRange(imageRequest.Model, imageRequest.N) {
|
|
||||||
// channel not azure
|
|
||||||
if meta.ChannelType != channeltype.Azure {
|
|
||||||
return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func getImageCostRatio(imageRequest *relaymodel.ImageRequest) (float64, error) {
|
|
||||||
if imageRequest == nil {
|
|
||||||
return 0, errors.New("imageRequest is nil")
|
|
||||||
}
|
|
||||||
imageCostRatio := getImageSizeRatio(imageRequest.Model, imageRequest.Size)
|
|
||||||
if imageRequest.Quality == "hd" && imageRequest.Model == "dall-e-3" {
|
|
||||||
if imageRequest.Size == "1024x1024" {
|
|
||||||
imageCostRatio *= 2
|
|
||||||
} else {
|
|
||||||
imageCostRatio *= 1.5
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return imageCostRatio, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int) int {
|
func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int) int {
|
||||||
switch relayMode {
|
switch relayMode {
|
||||||
case relaymode.ChatCompletions:
|
case relaymode.ChatCompletions:
|
||||||
|
@ -7,6 +7,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/songquanpeng/one-api/common"
|
||||||
"github.com/songquanpeng/one-api/common/ctxkey"
|
"github.com/songquanpeng/one-api/common/ctxkey"
|
||||||
"github.com/songquanpeng/one-api/common/logger"
|
"github.com/songquanpeng/one-api/common/logger"
|
||||||
"github.com/songquanpeng/one-api/model"
|
"github.com/songquanpeng/one-api/model"
|
||||||
@ -20,13 +21,84 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
func isWithinRange(element string, value int) bool {
|
func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, error) {
|
||||||
if _, ok := billingratio.ImageGenerationAmounts[element]; !ok {
|
imageRequest := &relaymodel.ImageRequest{}
|
||||||
return false
|
err := common.UnmarshalBodyReusable(c, imageRequest)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
min := billingratio.ImageGenerationAmounts[element][0]
|
if imageRequest.N == 0 {
|
||||||
max := billingratio.ImageGenerationAmounts[element][1]
|
imageRequest.N = 1
|
||||||
return value >= min && value <= max
|
}
|
||||||
|
if imageRequest.Size == "" {
|
||||||
|
imageRequest.Size = "1024x1024"
|
||||||
|
}
|
||||||
|
if imageRequest.Model == "" {
|
||||||
|
imageRequest.Model = "dall-e-2"
|
||||||
|
}
|
||||||
|
return imageRequest, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isValidImageSize(model string, size string) bool {
|
||||||
|
if model == "cogview-3" || billingratio.ImageSizeRatios[model] == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
_, ok := billingratio.ImageSizeRatios[model][size]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func isValidImagePromptLength(model string, promptLength int) bool {
|
||||||
|
maxPromptLength, ok := billingratio.ImagePromptLengthLimitations[model]
|
||||||
|
return !ok || promptLength <= maxPromptLength
|
||||||
|
}
|
||||||
|
|
||||||
|
func isWithinRange(element string, value int) bool {
|
||||||
|
amounts, ok := billingratio.ImageGenerationAmounts[element]
|
||||||
|
return !ok || (value >= amounts[0] && value <= amounts[1])
|
||||||
|
}
|
||||||
|
|
||||||
|
func getImageSizeRatio(model string, size string) float64 {
|
||||||
|
if ratio, ok := billingratio.ImageSizeRatios[model][size]; ok {
|
||||||
|
return ratio
|
||||||
|
}
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *meta.Meta) *relaymodel.ErrorWithStatusCode {
|
||||||
|
// check prompt length
|
||||||
|
if imageRequest.Prompt == "" {
|
||||||
|
return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
|
||||||
|
// model validation
|
||||||
|
if !isValidImageSize(imageRequest.Model, imageRequest.Size) {
|
||||||
|
return openai.ErrorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !isValidImagePromptLength(imageRequest.Model, len(imageRequest.Prompt)) {
|
||||||
|
return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Number of generated images validation
|
||||||
|
if !isWithinRange(imageRequest.Model, imageRequest.N) {
|
||||||
|
return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getImageCostRatio(imageRequest *relaymodel.ImageRequest) (float64, error) {
|
||||||
|
if imageRequest == nil {
|
||||||
|
return 0, errors.New("imageRequest is nil")
|
||||||
|
}
|
||||||
|
imageCostRatio := getImageSizeRatio(imageRequest.Model, imageRequest.Size)
|
||||||
|
if imageRequest.Quality == "hd" && imageRequest.Model == "dall-e-3" {
|
||||||
|
if imageRequest.Size == "1024x1024" {
|
||||||
|
imageCostRatio *= 2
|
||||||
|
} else {
|
||||||
|
imageCostRatio *= 1.5
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return imageCostRatio, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode {
|
func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode {
|
||||||
|
Loading…
Reference in New Issue
Block a user