From 0f038d715d8e106f683fe5ebc6615849e9dc0e73 Mon Sep 17 00:00:00 2001 From: Martial BE Date: Fri, 1 Dec 2023 18:25:05 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20add:=20add=20images=20edits=20and?= =?UTF-8?q?=20variations=20API?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- common/constants.go | 2 +- common/token.go | 23 ++++-- controller/relay-helper.go | 63 +++++++++++++++- controller/relay.go | 4 ++ middleware/distributor.go | 61 ++++++++++++---- providers/base/common.go | 2 +- providers/base/interface.go | 12 ++++ providers/openai/base.go | 2 +- providers/openai/image_edits.go | 104 +++++++++++++++++++++++++++ providers/openai/image_variations.go | 49 +++++++++++++ router/relay-router.go | 4 +- 11 files changed, 302 insertions(+), 24 deletions(-) create mode 100644 providers/openai/image_edits.go create mode 100644 providers/openai/image_variations.go diff --git a/common/constants.go b/common/constants.go index 52ac4395..c9fed019 100644 --- a/common/constants.go +++ b/common/constants.go @@ -222,7 +222,7 @@ const ( RelayModeEmbeddings RelayModeModerations RelayModeImagesGenerations - RelayModeImagesEdit + RelayModeImagesEdits RelayModeImagesVariations RelayModeEdits RelayModeAudioSpeech diff --git a/common/token.go b/common/token.go index 4a6a6fbb..59d58343 100644 --- a/common/token.go +++ b/common/token.go @@ -109,12 +109,25 @@ func CountTokenText(text string, model string) int { return getTokenNum(tokenEncoder, text) } -func CountTokenImage(imageRequest types.ImageRequest) (int, error) { - imageCostRatio, hasValidSize := DalleSizeRatios[imageRequest.Model][imageRequest.Size] +func CountTokenImage(input interface{}) (int, error) { + switch v := input.(type) { + case types.ImageRequest: + // 处理 ImageRequest + return calculateToken(v.Model, v.Size, v.N, v.Quality) + case types.ImageEditRequest: + // 处理 ImageEditsRequest + return calculateToken(v.Model, v.Size, v.N, "") + default: + return 0, errors.New("unsupported type") + } +} + +func calculateToken(model string, size string, n int, quality string) (int, error) { + imageCostRatio, hasValidSize := DalleSizeRatios[model][size] if hasValidSize { - if imageRequest.Quality == "hd" && imageRequest.Model == "dall-e-3" { - if imageRequest.Size == "1024x1024" { + if quality == "hd" && model == "dall-e-3" { + if size == "1024x1024" { imageCostRatio *= 2 } else { imageCostRatio *= 1.5 @@ -124,5 +137,5 @@ func CountTokenImage(imageRequest types.ImageRequest) (int, error) { return 0, errors.New("size not supported for this image model") } - return int(imageCostRatio*1000) * imageRequest.N, nil + return int(imageCostRatio*1000) * n, nil } diff --git a/controller/relay-helper.go b/controller/relay-helper.go index 94af4d64..fae69582 100644 --- a/controller/relay-helper.go +++ b/controller/relay-helper.go @@ -65,6 +65,10 @@ func relayHelper(c *gin.Context, relayMode int) *types.OpenAIErrorWithStatusCode usage, openAIErrorWithStatusCode = handleTranslations(c, provider, modelMap, quotaInfo, group) case common.RelayModeImagesGenerations: usage, openAIErrorWithStatusCode = handleImageGenerations(c, provider, modelMap, quotaInfo, group) + case common.RelayModeImagesEdits: + usage, openAIErrorWithStatusCode = handleImageEdits(c, provider, modelMap, quotaInfo, group, "edit") + case common.RelayModeImagesVariations: + usage, openAIErrorWithStatusCode = handleImageEdits(c, provider, modelMap, quotaInfo, group, "variation") default: return types.ErrorWrapper(errors.New("invalid relay mode"), "invalid_relay_mode", http.StatusBadRequest) } @@ -336,7 +340,7 @@ func handleTranslations(c *gin.Context, provider providers_base.ProviderInterfac func handleImageGenerations(c *gin.Context, provider providers_base.ProviderInterface, modelMap map[string]string, quotaInfo *QuotaInfo, group string) (*types.Usage, *types.OpenAIErrorWithStatusCode) { var imageRequest types.ImageRequest isModelMapped := false - speechProvider, ok := provider.(providers_base.ImageGenerationsInterface) + imageGenerationsProvider, ok := provider.(providers_base.ImageGenerationsInterface) if !ok { return nil, types.ErrorWrapper(errors.New("channel not implemented"), "channel_not_implemented", http.StatusNotImplemented) } @@ -374,5 +378,60 @@ func handleImageGenerations(c *gin.Context, provider providers_base.ProviderInte if quota_err != nil { return nil, quota_err } - return speechProvider.ImageGenerationsAction(&imageRequest, isModelMapped, promptTokens) + return imageGenerationsProvider.ImageGenerationsAction(&imageRequest, isModelMapped, promptTokens) +} + +func handleImageEdits(c *gin.Context, provider providers_base.ProviderInterface, modelMap map[string]string, quotaInfo *QuotaInfo, group string, imageType string) (*types.Usage, *types.OpenAIErrorWithStatusCode) { + var imageEditRequest types.ImageEditRequest + isModelMapped := false + var imageEditsProvider providers_base.ImageEditsInterface + var imageVariations providers_base.ImageVariationsInterface + var ok bool + if imageType == "edit" { + imageEditsProvider, ok = provider.(providers_base.ImageEditsInterface) + if !ok { + return nil, types.ErrorWrapper(errors.New("channel not implemented"), "channel_not_implemented", http.StatusNotImplemented) + } + } else { + imageVariations, ok = provider.(providers_base.ImageVariationsInterface) + if !ok { + return nil, types.ErrorWrapper(errors.New("channel not implemented"), "channel_not_implemented", http.StatusNotImplemented) + } + } + + err := common.UnmarshalBodyReusable(c, &imageEditRequest) + if err != nil { + return nil, types.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) + } + + if imageEditRequest.Model == "" { + imageEditRequest.Model = "dall-e-2" + } + + if imageEditRequest.Size == "" { + imageEditRequest.Size = "1024x1024" + } + + if modelMap != nil && modelMap[imageEditRequest.Model] != "" { + imageEditRequest.Model = modelMap[imageEditRequest.Model] + isModelMapped = true + } + promptTokens, err := common.CountTokenImage(imageEditRequest) + if err != nil { + return nil, types.ErrorWrapper(err, "count_token_image_failed", http.StatusInternalServerError) + } + + quotaInfo.modelName = imageEditRequest.Model + quotaInfo.promptTokens = promptTokens + quotaInfo.initQuotaInfo(group) + quota_err := quotaInfo.preQuotaConsumption() + if quota_err != nil { + return nil, quota_err + } + + if imageType == "edit" { + return imageEditsProvider.ImageEditsAction(&imageEditRequest, isModelMapped, promptTokens) + } + + return imageVariations.ImageVariationsAction(&imageEditRequest, isModelMapped, promptTokens) } diff --git a/controller/relay.go b/controller/relay.go index c1bae21f..f73e7593 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -244,6 +244,10 @@ func Relay(c *gin.Context) { relayMode = common.RelayModeAudioTranslation } else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { relayMode = common.RelayModeImagesGenerations + } else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/edits") { + relayMode = common.RelayModeImagesEdits + } else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/variations") { + relayMode = common.RelayModeImagesVariations } // } else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") { // relayMode = RelayModeEdits diff --git a/middleware/distributor.go b/middleware/distributor.go index 50ac1c29..811b512e 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -11,10 +11,35 @@ import ( "github.com/gin-gonic/gin" ) +type ModelRequestInterface interface { + GetModel() string + SetModel(string) +} + type ModelRequest struct { Model string `json:"model"` } +func (m *ModelRequest) GetModel() string { + return m.Model +} + +func (m *ModelRequest) SetModel(model string) { + m.Model = model +} + +type ModelFormRequest struct { + Model string `form:"model"` +} + +func (m *ModelFormRequest) GetModel() string { + return m.Model +} + +func (m *ModelFormRequest) SetModel(model string) { + m.Model = model +} + func Distribute() func(c *gin.Context) { return func(c *gin.Context) { userId := c.GetInt("id") @@ -39,35 +64,36 @@ func Distribute() func(c *gin.Context) { } } else { // Select a channel for the user - var modelRequest ModelRequest - err := common.UnmarshalBodyReusable(c, &modelRequest) + modelRequest := getModelRequest(c) + err := common.UnmarshalBodyReusable(c, modelRequest) if err != nil { abortWithMessage(c, http.StatusBadRequest, "无效的请求") return } + if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { - if modelRequest.Model == "" { - modelRequest.Model = "text-moderation-stable" + if modelRequest.GetModel() == "" { + modelRequest.SetModel("text-moderation-stable") } } if strings.HasSuffix(c.Request.URL.Path, "embeddings") { - if modelRequest.Model == "" { - modelRequest.Model = c.Param("model") + if modelRequest.GetModel() == "" { + modelRequest.SetModel(c.Param("model")) } } if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { - if modelRequest.Model == "" { - modelRequest.Model = "dall-e-2" + if modelRequest.GetModel() == "" { + modelRequest.SetModel("dall-e-2") } } if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") { - if modelRequest.Model == "" { - modelRequest.Model = "whisper-1" + if modelRequest.GetModel() == "" { + modelRequest.SetModel("whisper-1") } } - channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model) + channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.GetModel()) if err != nil { - message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model) + message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.GetModel()) if channel != nil { common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id)) message = "数据库一致性已被破坏,请联系管理员" @@ -94,3 +120,14 @@ func Distribute() func(c *gin.Context) { c.Next() } } + +func getModelRequest(c *gin.Context) (modelRequest ModelRequestInterface) { + contentType := c.Request.Header.Get("Content-Type") + if strings.HasPrefix(contentType, "application/json") { + modelRequest = &ModelRequest{} + } else if strings.HasPrefix(contentType, "multipart/form-data") { + modelRequest = &ModelFormRequest{} + } + + return +} diff --git a/providers/base/common.go b/providers/base/common.go index e6cc49f8..4222515b 100644 --- a/providers/base/common.go +++ b/providers/base/common.go @@ -146,7 +146,7 @@ func (p *BaseProvider) SupportAPI(relayMode int) bool { return p.Moderation != "" case common.RelayModeImagesGenerations: return p.ImagesGenerations != "" - case common.RelayModeImagesEdit: + case common.RelayModeImagesEdits: return p.ImagesEdit != "" case common.RelayModeImagesVariations: return p.ImagesVariations != "" diff --git a/providers/base/interface.go b/providers/base/interface.go index 6c6bcfd4..714bd5c1 100644 --- a/providers/base/interface.go +++ b/providers/base/interface.go @@ -56,11 +56,23 @@ type TranslationInterface interface { TranslationAction(request *types.AudioRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) } +// 图片生成接口 type ImageGenerationsInterface interface { ProviderInterface ImageGenerationsAction(request *types.ImageRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) } +// 图片编辑接口 +type ImageEditsInterface interface { + ProviderInterface + ImageEditsAction(request *types.ImageEditRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) +} + +type ImageVariationsInterface interface { + ProviderInterface + ImageVariationsAction(request *types.ImageEditRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) +} + // 余额接口 type BalanceInterface interface { BalanceAction(channel *model.Channel) (float64, error) diff --git a/providers/openai/base.go b/providers/openai/base.go index cd34edb9..b8d0d9b8 100644 --- a/providers/openai/base.go +++ b/providers/openai/base.go @@ -39,7 +39,7 @@ func CreateOpenAIProvider(c *gin.Context, baseURL string) *OpenAIProvider { AudioTranscriptions: "/v1/audio/transcriptions", AudioTranslations: "/v1/audio/translations", ImagesGenerations: "/v1/images/generations", - ImagesEdit: "/v1/images/edit", + ImagesEdit: "/v1/images/edits", ImagesVariations: "/v1/images/variations", Context: c, }, diff --git a/providers/openai/image_edits.go b/providers/openai/image_edits.go new file mode 100644 index 00000000..dce7c4ff --- /dev/null +++ b/providers/openai/image_edits.go @@ -0,0 +1,104 @@ +package openai + +import ( + "bytes" + "fmt" + "net/http" + "one-api/common" + "one-api/types" +) + +func (p *OpenAIProvider) ImageEditsAction(request *types.ImageEditRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { + fullRequestURL := p.GetFullRequestURL(p.ImagesEdit, request.Model) + headers := p.GetRequestHeaders() + + client := common.NewClient() + + var formBody bytes.Buffer + var req *http.Request + var err error + if isModelMapped { + builder := client.CreateFormBuilder(&formBody) + if err := imagesEditsMultipartForm(request, builder); err != nil { + return nil, types.ErrorWrapper(err, "create_form_builder_failed", http.StatusInternalServerError) + } + req, err = client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(&formBody), common.WithHeader(headers), common.WithContentType(builder.FormDataContentType())) + req.ContentLength = int64(formBody.Len()) + + } else { + req, err = client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(p.Context.Request.Body), common.WithHeader(headers), common.WithContentType(p.Context.Request.Header.Get("Content-Type"))) + req.ContentLength = p.Context.Request.ContentLength + } + + if err != nil { + return nil, types.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) + } + + openAIProviderImageResponseResponse := &OpenAIProviderImageResponseResponse{} + errWithCode = p.SendRequest(req, openAIProviderImageResponseResponse, true) + if errWithCode != nil { + return + } + + usage = &types.Usage{ + PromptTokens: promptTokens, + CompletionTokens: 0, + TotalTokens: promptTokens, + } + + return +} + +func imagesEditsMultipartForm(request *types.ImageEditRequest, b common.FormBuilder) error { + err := b.CreateFormFile("image", request.Image) + if err != nil { + return fmt.Errorf("creating form file: %w", err) + } + + err = b.WriteField("prompt", request.Prompt) + if err != nil { + return fmt.Errorf("writing prompt: %w", err) + } + + err = b.WriteField("model", request.Model) + if err != nil { + return fmt.Errorf("writing model name: %w", err) + } + + if request.Mask != nil { + err = b.CreateFormFile("mask", request.Mask) + if err != nil { + return fmt.Errorf("writing format: %w", err) + } + } + + if request.ResponseFormat != "" { + err = b.WriteField("response_format", request.ResponseFormat) + if err != nil { + return fmt.Errorf("writing format: %w", err) + } + } + + if request.N != 0 { + err = b.WriteField("n", fmt.Sprintf("%.2f", request.N)) + if err != nil { + return fmt.Errorf("writing temperature: %w", err) + } + } + + if request.Size != "" { + err = b.WriteField("size", request.Size) + if err != nil { + return fmt.Errorf("writing language: %w", err) + } + } + + if request.User != "" { + err = b.WriteField("user", request.User) + if err != nil { + return fmt.Errorf("writing language: %w", err) + } + } + + return b.Close() +} diff --git a/providers/openai/image_variations.go b/providers/openai/image_variations.go new file mode 100644 index 00000000..2ddd03d7 --- /dev/null +++ b/providers/openai/image_variations.go @@ -0,0 +1,49 @@ +package openai + +import ( + "bytes" + "net/http" + "one-api/common" + "one-api/types" +) + +func (p *OpenAIProvider) ImageVariationsAction(request *types.ImageEditRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { + fullRequestURL := p.GetFullRequestURL(p.ImagesVariations, request.Model) + headers := p.GetRequestHeaders() + + client := common.NewClient() + + var formBody bytes.Buffer + var req *http.Request + var err error + if isModelMapped { + builder := client.CreateFormBuilder(&formBody) + if err := imagesEditsMultipartForm(request, builder); err != nil { + return nil, types.ErrorWrapper(err, "create_form_builder_failed", http.StatusInternalServerError) + } + req, err = client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(&formBody), common.WithHeader(headers), common.WithContentType(builder.FormDataContentType())) + req.ContentLength = int64(formBody.Len()) + + } else { + req, err = client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(p.Context.Request.Body), common.WithHeader(headers), common.WithContentType(p.Context.Request.Header.Get("Content-Type"))) + req.ContentLength = p.Context.Request.ContentLength + } + + if err != nil { + return nil, types.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) + } + + openAIProviderImageResponseResponse := &OpenAIProviderImageResponseResponse{} + errWithCode = p.SendRequest(req, openAIProviderImageResponseResponse, true) + if errWithCode != nil { + return + } + + usage = &types.Usage{ + PromptTokens: promptTokens, + CompletionTokens: 0, + TotalTokens: promptTokens, + } + + return +} diff --git a/router/relay-router.go b/router/relay-router.go index 912f4989..4b1be206 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -23,8 +23,8 @@ func SetRelayRouter(router *gin.Engine) { relayV1Router.POST("/chat/completions", controller.Relay) relayV1Router.POST("/edits", controller.Relay) relayV1Router.POST("/images/generations", controller.Relay) - relayV1Router.POST("/images/edits", controller.RelayNotImplemented) - relayV1Router.POST("/images/variations", controller.RelayNotImplemented) + relayV1Router.POST("/images/edits", controller.Relay) + relayV1Router.POST("/images/variations", controller.Relay) relayV1Router.POST("/embeddings", controller.Relay) relayV1Router.POST("/engines/:model/embeddings", controller.Relay) relayV1Router.POST("/audio/transcriptions", controller.Relay)