diff --git a/common/model-ratio.go b/common/model-ratio.go index 681f0ae7..b4a471dc 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -6,6 +6,29 @@ import ( "time" ) +var DalleSizeRatios = map[string]map[string]float64{ + "dall-e-2": { + "256x256": 1, + "512x512": 1.125, + "1024x1024": 1.25, + }, + "dall-e-3": { + "1024x1024": 1, + "1024x1792": 2, + "1792x1024": 2, + }, +} + +var DalleGenerationImageAmounts = map[string][2]int{ + "dall-e-2": {1, 10}, + "dall-e-3": {1, 1}, // OpenAI allows n=1 currently. +} + +var DalleImagePromptLengthLimitations = map[string]int{ + "dall-e-2": 1000, + "dall-e-3": 4000, +} + // ModelRatio // https://platform.openai.com/docs/models/model-endpoint-compatibility // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Blfmc9dlf @@ -45,7 +68,8 @@ var ModelRatio = map[string]float64{ "text-search-ada-doc-001": 10, "text-moderation-stable": 0.1, "text-moderation-latest": 0.1, - "dall-e": 8, + "dall-e-2": 8, // $0.016 - $0.020 / image + "dall-e-3": 20, // $0.040 - $0.120 / image "claude-instant-1": 0.815, // $1.63 / 1M tokens "claude-2": 5.51, // $11.02 / 1M tokens "ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens diff --git a/controller/model.go b/controller/model.go index 7bd9d097..f9904330 100644 --- a/controller/model.go +++ b/controller/model.go @@ -55,12 +55,21 @@ func init() { // https://platform.openai.com/docs/models/model-endpoint-compatibility openAIModels = []OpenAIModels{ { - Id: "dall-e", + Id: "dall-e-2", Object: "model", Created: 1677649963, OwnedBy: "openai", Permission: permission, - Root: "dall-e", + Root: "dall-e-2", + Parent: nil, + }, + { + Id: "dall-e-3", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "dall-e-3", Parent: nil, }, { diff --git a/controller/relay-image.go b/controller/relay-image.go index ccd52dce..1d1b71ba 100644 --- a/controller/relay-image.go +++ b/controller/relay-image.go @@ -6,15 +6,28 @@ import ( "encoding/json" "errors" "fmt" - "github.com/gin-gonic/gin" "io" "net/http" "one-api/common" "one-api/model" + + "github.com/gin-gonic/gin" ) +func isWithinRange(element string, value int) bool { + if _, ok := common.DalleGenerationImageAmounts[element]; !ok { + return false + } + + min := common.DalleGenerationImageAmounts[element][0] + max := common.DalleGenerationImageAmounts[element][1] + + return value >= min && value <= max +} + func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { - imageModel := "dall-e" + imageModel := "dall-e-2" + imageSize := "1024x1024" tokenId := c.GetInt("token_id") channelType := c.GetInt("channel") @@ -31,19 +44,44 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode } } + // Size validation + if imageRequest.Size != "" { + imageSize = imageRequest.Size + } + + // Model validation + if imageRequest.Model != "" { + imageModel = imageRequest.Model + } + + imageCostRatio, hasValidSize := common.DalleSizeRatios[imageModel][imageSize] + + // Check if model is supported + if hasValidSize { + if imageRequest.Quality == "hd" && imageModel == "dall-e-3" { + if imageSize == "1024x1024" { + imageCostRatio *= 2 + } else { + imageCostRatio *= 1.5 + } + } + } else { + return errorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest) + } + // Prompt validation if imageRequest.Prompt == "" { - return errorWrapper(errors.New("prompt is required"), "required_field_missing", http.StatusBadRequest) + return errorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest) } - // Not "256x256", "512x512", or "1024x1024" - if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" { - return errorWrapper(errors.New("size must be one of 256x256, 512x512, or 1024x1024"), "invalid_field_value", http.StatusBadRequest) + // Check prompt length + if len(imageRequest.Prompt) > common.DalleImagePromptLengthLimitations[imageModel] { + return errorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest) } - // N should between 1 and 10 - if imageRequest.N != 0 && (imageRequest.N < 1 || imageRequest.N > 10) { - return errorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest) + // Number of generated images validation + if isWithinRange(imageModel, imageRequest.N) == false { + return errorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest) } // map model name @@ -82,16 +120,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode ratio := modelRatio * groupRatio userQuota, err := model.CacheGetUserQuota(userId) - sizeRatio := 1.0 - // Size - if imageRequest.Size == "256x256" { - sizeRatio = 1 - } else if imageRequest.Size == "512x512" { - sizeRatio = 1.125 - } else if imageRequest.Size == "1024x1024" { - sizeRatio = 1.25 - } - quota := int(ratio*sizeRatio*1000) * imageRequest.N + quota := int(ratio*imageCostRatio*1000) * imageRequest.N if consumeQuota && userQuota-quota < 0 { return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) diff --git a/controller/relay.go b/controller/relay.go index 1926110e..9cff887b 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -77,10 +77,16 @@ type TextRequest struct { //Stream bool `json:"stream"` } +// ImageRequest docs: https://platform.openai.com/docs/api-reference/images/create type ImageRequest struct { - Prompt string `json:"prompt"` - N int `json:"n"` - Size string `json:"size"` + Model string `json:"model"` + Prompt string `json:"prompt" binding:"required"` + N int `json:"n"` + Size string `json:"size"` + Quality string `json:"quality"` + ResponseFormat string `json:"response_format"` + Style string `json:"style"` + User string `json:"user"` } type AudioResponse struct {