add: add images edits and variations API

This commit is contained in:
Martial BE 2023-12-01 18:25:05 +08:00
parent 9dd92bbddd
commit 0f038d715d
No known key found for this signature in database
GPG Key ID: D06C32DF0EDB9084
11 changed files with 302 additions and 24 deletions

View File

@ -222,7 +222,7 @@ const (
RelayModeEmbeddings
RelayModeModerations
RelayModeImagesGenerations
RelayModeImagesEdit
RelayModeImagesEdits
RelayModeImagesVariations
RelayModeEdits
RelayModeAudioSpeech

View File

@ -109,12 +109,25 @@ func CountTokenText(text string, model string) int {
return getTokenNum(tokenEncoder, text)
}
func CountTokenImage(imageRequest types.ImageRequest) (int, error) {
imageCostRatio, hasValidSize := DalleSizeRatios[imageRequest.Model][imageRequest.Size]
func CountTokenImage(input interface{}) (int, error) {
switch v := input.(type) {
case types.ImageRequest:
// 处理 ImageRequest
return calculateToken(v.Model, v.Size, v.N, v.Quality)
case types.ImageEditRequest:
// 处理 ImageEditsRequest
return calculateToken(v.Model, v.Size, v.N, "")
default:
return 0, errors.New("unsupported type")
}
}
func calculateToken(model string, size string, n int, quality string) (int, error) {
imageCostRatio, hasValidSize := DalleSizeRatios[model][size]
if hasValidSize {
if imageRequest.Quality == "hd" && imageRequest.Model == "dall-e-3" {
if imageRequest.Size == "1024x1024" {
if quality == "hd" && model == "dall-e-3" {
if size == "1024x1024" {
imageCostRatio *= 2
} else {
imageCostRatio *= 1.5
@ -124,5 +137,5 @@ func CountTokenImage(imageRequest types.ImageRequest) (int, error) {
return 0, errors.New("size not supported for this image model")
}
return int(imageCostRatio*1000) * imageRequest.N, nil
return int(imageCostRatio*1000) * n, nil
}

View File

@ -65,6 +65,10 @@ func relayHelper(c *gin.Context, relayMode int) *types.OpenAIErrorWithStatusCode
usage, openAIErrorWithStatusCode = handleTranslations(c, provider, modelMap, quotaInfo, group)
case common.RelayModeImagesGenerations:
usage, openAIErrorWithStatusCode = handleImageGenerations(c, provider, modelMap, quotaInfo, group)
case common.RelayModeImagesEdits:
usage, openAIErrorWithStatusCode = handleImageEdits(c, provider, modelMap, quotaInfo, group, "edit")
case common.RelayModeImagesVariations:
usage, openAIErrorWithStatusCode = handleImageEdits(c, provider, modelMap, quotaInfo, group, "variation")
default:
return types.ErrorWrapper(errors.New("invalid relay mode"), "invalid_relay_mode", http.StatusBadRequest)
}
@ -336,7 +340,7 @@ func handleTranslations(c *gin.Context, provider providers_base.ProviderInterfac
func handleImageGenerations(c *gin.Context, provider providers_base.ProviderInterface, modelMap map[string]string, quotaInfo *QuotaInfo, group string) (*types.Usage, *types.OpenAIErrorWithStatusCode) {
var imageRequest types.ImageRequest
isModelMapped := false
speechProvider, ok := provider.(providers_base.ImageGenerationsInterface)
imageGenerationsProvider, ok := provider.(providers_base.ImageGenerationsInterface)
if !ok {
return nil, types.ErrorWrapper(errors.New("channel not implemented"), "channel_not_implemented", http.StatusNotImplemented)
}
@ -374,5 +378,60 @@ func handleImageGenerations(c *gin.Context, provider providers_base.ProviderInte
if quota_err != nil {
return nil, quota_err
}
return speechProvider.ImageGenerationsAction(&imageRequest, isModelMapped, promptTokens)
return imageGenerationsProvider.ImageGenerationsAction(&imageRequest, isModelMapped, promptTokens)
}
func handleImageEdits(c *gin.Context, provider providers_base.ProviderInterface, modelMap map[string]string, quotaInfo *QuotaInfo, group string, imageType string) (*types.Usage, *types.OpenAIErrorWithStatusCode) {
var imageEditRequest types.ImageEditRequest
isModelMapped := false
var imageEditsProvider providers_base.ImageEditsInterface
var imageVariations providers_base.ImageVariationsInterface
var ok bool
if imageType == "edit" {
imageEditsProvider, ok = provider.(providers_base.ImageEditsInterface)
if !ok {
return nil, types.ErrorWrapper(errors.New("channel not implemented"), "channel_not_implemented", http.StatusNotImplemented)
}
} else {
imageVariations, ok = provider.(providers_base.ImageVariationsInterface)
if !ok {
return nil, types.ErrorWrapper(errors.New("channel not implemented"), "channel_not_implemented", http.StatusNotImplemented)
}
}
err := common.UnmarshalBodyReusable(c, &imageEditRequest)
if err != nil {
return nil, types.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
}
if imageEditRequest.Model == "" {
imageEditRequest.Model = "dall-e-2"
}
if imageEditRequest.Size == "" {
imageEditRequest.Size = "1024x1024"
}
if modelMap != nil && modelMap[imageEditRequest.Model] != "" {
imageEditRequest.Model = modelMap[imageEditRequest.Model]
isModelMapped = true
}
promptTokens, err := common.CountTokenImage(imageEditRequest)
if err != nil {
return nil, types.ErrorWrapper(err, "count_token_image_failed", http.StatusInternalServerError)
}
quotaInfo.modelName = imageEditRequest.Model
quotaInfo.promptTokens = promptTokens
quotaInfo.initQuotaInfo(group)
quota_err := quotaInfo.preQuotaConsumption()
if quota_err != nil {
return nil, quota_err
}
if imageType == "edit" {
return imageEditsProvider.ImageEditsAction(&imageEditRequest, isModelMapped, promptTokens)
}
return imageVariations.ImageVariationsAction(&imageEditRequest, isModelMapped, promptTokens)
}

View File

@ -244,6 +244,10 @@ func Relay(c *gin.Context) {
relayMode = common.RelayModeAudioTranslation
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
relayMode = common.RelayModeImagesGenerations
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/edits") {
relayMode = common.RelayModeImagesEdits
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/variations") {
relayMode = common.RelayModeImagesVariations
}
// } else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") {
// relayMode = RelayModeEdits

View File

@ -11,10 +11,35 @@ import (
"github.com/gin-gonic/gin"
)
type ModelRequestInterface interface {
GetModel() string
SetModel(string)
}
type ModelRequest struct {
Model string `json:"model"`
}
func (m *ModelRequest) GetModel() string {
return m.Model
}
func (m *ModelRequest) SetModel(model string) {
m.Model = model
}
type ModelFormRequest struct {
Model string `form:"model"`
}
func (m *ModelFormRequest) GetModel() string {
return m.Model
}
func (m *ModelFormRequest) SetModel(model string) {
m.Model = model
}
func Distribute() func(c *gin.Context) {
return func(c *gin.Context) {
userId := c.GetInt("id")
@ -39,35 +64,36 @@ func Distribute() func(c *gin.Context) {
}
} else {
// Select a channel for the user
var modelRequest ModelRequest
err := common.UnmarshalBodyReusable(c, &modelRequest)
modelRequest := getModelRequest(c)
err := common.UnmarshalBodyReusable(c, modelRequest)
if err != nil {
abortWithMessage(c, http.StatusBadRequest, "无效的请求")
return
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
if modelRequest.Model == "" {
modelRequest.Model = "text-moderation-stable"
if modelRequest.GetModel() == "" {
modelRequest.SetModel("text-moderation-stable")
}
}
if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
if modelRequest.Model == "" {
modelRequest.Model = c.Param("model")
if modelRequest.GetModel() == "" {
modelRequest.SetModel(c.Param("model"))
}
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
if modelRequest.Model == "" {
modelRequest.Model = "dall-e-2"
if modelRequest.GetModel() == "" {
modelRequest.SetModel("dall-e-2")
}
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
if modelRequest.Model == "" {
modelRequest.Model = "whisper-1"
if modelRequest.GetModel() == "" {
modelRequest.SetModel("whisper-1")
}
}
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.GetModel())
if err != nil {
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.GetModel())
if channel != nil {
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
message = "数据库一致性已被破坏,请联系管理员"
@ -94,3 +120,14 @@ func Distribute() func(c *gin.Context) {
c.Next()
}
}
func getModelRequest(c *gin.Context) (modelRequest ModelRequestInterface) {
contentType := c.Request.Header.Get("Content-Type")
if strings.HasPrefix(contentType, "application/json") {
modelRequest = &ModelRequest{}
} else if strings.HasPrefix(contentType, "multipart/form-data") {
modelRequest = &ModelFormRequest{}
}
return
}

View File

@ -146,7 +146,7 @@ func (p *BaseProvider) SupportAPI(relayMode int) bool {
return p.Moderation != ""
case common.RelayModeImagesGenerations:
return p.ImagesGenerations != ""
case common.RelayModeImagesEdit:
case common.RelayModeImagesEdits:
return p.ImagesEdit != ""
case common.RelayModeImagesVariations:
return p.ImagesVariations != ""

View File

@ -56,11 +56,23 @@ type TranslationInterface interface {
TranslationAction(request *types.AudioRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
}
// 图片生成接口
type ImageGenerationsInterface interface {
ProviderInterface
ImageGenerationsAction(request *types.ImageRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
}
// 图片编辑接口
type ImageEditsInterface interface {
ProviderInterface
ImageEditsAction(request *types.ImageEditRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
}
type ImageVariationsInterface interface {
ProviderInterface
ImageVariationsAction(request *types.ImageEditRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
}
// 余额接口
type BalanceInterface interface {
BalanceAction(channel *model.Channel) (float64, error)

View File

@ -39,7 +39,7 @@ func CreateOpenAIProvider(c *gin.Context, baseURL string) *OpenAIProvider {
AudioTranscriptions: "/v1/audio/transcriptions",
AudioTranslations: "/v1/audio/translations",
ImagesGenerations: "/v1/images/generations",
ImagesEdit: "/v1/images/edit",
ImagesEdit: "/v1/images/edits",
ImagesVariations: "/v1/images/variations",
Context: c,
},

View File

@ -0,0 +1,104 @@
package openai
import (
"bytes"
"fmt"
"net/http"
"one-api/common"
"one-api/types"
)
func (p *OpenAIProvider) ImageEditsAction(request *types.ImageEditRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
fullRequestURL := p.GetFullRequestURL(p.ImagesEdit, request.Model)
headers := p.GetRequestHeaders()
client := common.NewClient()
var formBody bytes.Buffer
var req *http.Request
var err error
if isModelMapped {
builder := client.CreateFormBuilder(&formBody)
if err := imagesEditsMultipartForm(request, builder); err != nil {
return nil, types.ErrorWrapper(err, "create_form_builder_failed", http.StatusInternalServerError)
}
req, err = client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(&formBody), common.WithHeader(headers), common.WithContentType(builder.FormDataContentType()))
req.ContentLength = int64(formBody.Len())
} else {
req, err = client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(p.Context.Request.Body), common.WithHeader(headers), common.WithContentType(p.Context.Request.Header.Get("Content-Type")))
req.ContentLength = p.Context.Request.ContentLength
}
if err != nil {
return nil, types.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
openAIProviderImageResponseResponse := &OpenAIProviderImageResponseResponse{}
errWithCode = p.SendRequest(req, openAIProviderImageResponseResponse, true)
if errWithCode != nil {
return
}
usage = &types.Usage{
PromptTokens: promptTokens,
CompletionTokens: 0,
TotalTokens: promptTokens,
}
return
}
func imagesEditsMultipartForm(request *types.ImageEditRequest, b common.FormBuilder) error {
err := b.CreateFormFile("image", request.Image)
if err != nil {
return fmt.Errorf("creating form file: %w", err)
}
err = b.WriteField("prompt", request.Prompt)
if err != nil {
return fmt.Errorf("writing prompt: %w", err)
}
err = b.WriteField("model", request.Model)
if err != nil {
return fmt.Errorf("writing model name: %w", err)
}
if request.Mask != nil {
err = b.CreateFormFile("mask", request.Mask)
if err != nil {
return fmt.Errorf("writing format: %w", err)
}
}
if request.ResponseFormat != "" {
err = b.WriteField("response_format", request.ResponseFormat)
if err != nil {
return fmt.Errorf("writing format: %w", err)
}
}
if request.N != 0 {
err = b.WriteField("n", fmt.Sprintf("%.2f", request.N))
if err != nil {
return fmt.Errorf("writing temperature: %w", err)
}
}
if request.Size != "" {
err = b.WriteField("size", request.Size)
if err != nil {
return fmt.Errorf("writing language: %w", err)
}
}
if request.User != "" {
err = b.WriteField("user", request.User)
if err != nil {
return fmt.Errorf("writing language: %w", err)
}
}
return b.Close()
}

View File

@ -0,0 +1,49 @@
package openai
import (
"bytes"
"net/http"
"one-api/common"
"one-api/types"
)
func (p *OpenAIProvider) ImageVariationsAction(request *types.ImageEditRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
fullRequestURL := p.GetFullRequestURL(p.ImagesVariations, request.Model)
headers := p.GetRequestHeaders()
client := common.NewClient()
var formBody bytes.Buffer
var req *http.Request
var err error
if isModelMapped {
builder := client.CreateFormBuilder(&formBody)
if err := imagesEditsMultipartForm(request, builder); err != nil {
return nil, types.ErrorWrapper(err, "create_form_builder_failed", http.StatusInternalServerError)
}
req, err = client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(&formBody), common.WithHeader(headers), common.WithContentType(builder.FormDataContentType()))
req.ContentLength = int64(formBody.Len())
} else {
req, err = client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(p.Context.Request.Body), common.WithHeader(headers), common.WithContentType(p.Context.Request.Header.Get("Content-Type")))
req.ContentLength = p.Context.Request.ContentLength
}
if err != nil {
return nil, types.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
openAIProviderImageResponseResponse := &OpenAIProviderImageResponseResponse{}
errWithCode = p.SendRequest(req, openAIProviderImageResponseResponse, true)
if errWithCode != nil {
return
}
usage = &types.Usage{
PromptTokens: promptTokens,
CompletionTokens: 0,
TotalTokens: promptTokens,
}
return
}

View File

@ -23,8 +23,8 @@ func SetRelayRouter(router *gin.Engine) {
relayV1Router.POST("/chat/completions", controller.Relay)
relayV1Router.POST("/edits", controller.Relay)
relayV1Router.POST("/images/generations", controller.Relay)
relayV1Router.POST("/images/edits", controller.RelayNotImplemented)
relayV1Router.POST("/images/variations", controller.RelayNotImplemented)
relayV1Router.POST("/images/edits", controller.Relay)
relayV1Router.POST("/images/variations", controller.Relay)
relayV1Router.POST("/embeddings", controller.Relay)
relayV1Router.POST("/engines/:model/embeddings", controller.Relay)
relayV1Router.POST("/audio/transcriptions", controller.Relay)