Refactored image size and model validation in

relayImageHelper function
This commit is contained in:
ckt1031 2023-11-15 20:33:49 +08:00
parent b59fbafbc3
commit a3a309cdc5

View File

@ -27,7 +27,7 @@ func isWithinRange(element string, value int) bool {
func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
imageModel := "dall-e-2" imageModel := "dall-e-2"
requestSize := "1024x1024" imageSize := "1024x1024"
tokenId := c.GetInt("token_id") tokenId := c.GetInt("token_id")
channelType := c.GetInt("channel") channelType := c.GetInt("channel")
@ -45,29 +45,34 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
} }
} }
// Size validation
if imageRequest.Size != "" {
imageSize = imageRequest.Size
}
// Model validation // Model validation
if imageRequest.Model != "" { if imageRequest.Model != "" {
// Check if model is supported
if _, ok := common.DalleSizeRatios[imageRequest.Model]; !ok {
return errorWrapper(errors.New("model not supported"), "model_not_supported", http.StatusBadRequest)
}
imageModel = imageRequest.Model imageModel = imageRequest.Model
} }
// Size validation imageCostRatio, hasValidSize := common.DalleSizeRatios[imageModel][imageSize]
if imageRequest.Size != "" {
requestSize = imageRequest.Size // Check if model is supported
if hasValidSize {
if imageRequest.Quality == "hd" && imageModel == "dall-e-3" {
if imageSize == "1024x1024" {
imageCostRatio *= 2
} else {
imageCostRatio *= 1.5
}
}
} else {
return errorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest)
} }
// Prompt validation // Prompt validation
if imageRequest.Prompt == "" { if imageRequest.Prompt == "" {
return errorWrapper(errors.New("prompt is required"), "required_field_missing", http.StatusBadRequest) return errorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest)
}
// Number of generated images validation
if isWithinRange(imageModel, imageRequest.N) == false {
return errorWrapper(errors.New("invalud value of n"), "number_of_generated_images_not_within_range", http.StatusBadRequest)
} }
// Check prompt length // Check prompt length
@ -75,6 +80,11 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
return errorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest) return errorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest)
} }
// Number of generated images validation
if isWithinRange(imageModel, imageRequest.N) == false {
return errorWrapper(errors.New("invalud value of n"), "n_not_within_range", http.StatusBadRequest)
}
// map model name // map model name
modelMapping := c.GetString("model_mapping") modelMapping := c.GetString("model_mapping")
isModelMapped := false isModelMapped := false
@ -111,20 +121,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
ratio := modelRatio * groupRatio ratio := modelRatio * groupRatio
userQuota, err := model.CacheGetUserQuota(userId) userQuota, err := model.CacheGetUserQuota(userId)
sizeRatio := 1.0 quota := int(ratio*imageCostRatio*1000) * imageRequest.N
if ratio, ok := common.DalleSizeRatios[imageModel][requestSize]; ok {
sizeRatio = ratio
if imageRequest.Quality == "hd" && imageRequest.Model == "dall-e-3" {
if requestSize == "1024x1024" {
sizeRatio *= 2
} else {
sizeRatio *= 1.5
}
}
}
quota := int(ratio*sizeRatio*1000) * imageRequest.N
if consumeQuota && userQuota-quota < 0 { if consumeQuota && userQuota-quota < 0 {
return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)