✨ add: add images edits and variations API
This commit is contained in:
parent
9dd92bbddd
commit
0f038d715d
@ -222,7 +222,7 @@ const (
|
||||
RelayModeEmbeddings
|
||||
RelayModeModerations
|
||||
RelayModeImagesGenerations
|
||||
RelayModeImagesEdit
|
||||
RelayModeImagesEdits
|
||||
RelayModeImagesVariations
|
||||
RelayModeEdits
|
||||
RelayModeAudioSpeech
|
||||
|
@ -109,12 +109,25 @@ func CountTokenText(text string, model string) int {
|
||||
return getTokenNum(tokenEncoder, text)
|
||||
}
|
||||
|
||||
func CountTokenImage(imageRequest types.ImageRequest) (int, error) {
|
||||
imageCostRatio, hasValidSize := DalleSizeRatios[imageRequest.Model][imageRequest.Size]
|
||||
func CountTokenImage(input interface{}) (int, error) {
|
||||
switch v := input.(type) {
|
||||
case types.ImageRequest:
|
||||
// 处理 ImageRequest
|
||||
return calculateToken(v.Model, v.Size, v.N, v.Quality)
|
||||
case types.ImageEditRequest:
|
||||
// 处理 ImageEditsRequest
|
||||
return calculateToken(v.Model, v.Size, v.N, "")
|
||||
default:
|
||||
return 0, errors.New("unsupported type")
|
||||
}
|
||||
}
|
||||
|
||||
func calculateToken(model string, size string, n int, quality string) (int, error) {
|
||||
imageCostRatio, hasValidSize := DalleSizeRatios[model][size]
|
||||
|
||||
if hasValidSize {
|
||||
if imageRequest.Quality == "hd" && imageRequest.Model == "dall-e-3" {
|
||||
if imageRequest.Size == "1024x1024" {
|
||||
if quality == "hd" && model == "dall-e-3" {
|
||||
if size == "1024x1024" {
|
||||
imageCostRatio *= 2
|
||||
} else {
|
||||
imageCostRatio *= 1.5
|
||||
@ -124,5 +137,5 @@ func CountTokenImage(imageRequest types.ImageRequest) (int, error) {
|
||||
return 0, errors.New("size not supported for this image model")
|
||||
}
|
||||
|
||||
return int(imageCostRatio*1000) * imageRequest.N, nil
|
||||
return int(imageCostRatio*1000) * n, nil
|
||||
}
|
||||
|
@ -65,6 +65,10 @@ func relayHelper(c *gin.Context, relayMode int) *types.OpenAIErrorWithStatusCode
|
||||
usage, openAIErrorWithStatusCode = handleTranslations(c, provider, modelMap, quotaInfo, group)
|
||||
case common.RelayModeImagesGenerations:
|
||||
usage, openAIErrorWithStatusCode = handleImageGenerations(c, provider, modelMap, quotaInfo, group)
|
||||
case common.RelayModeImagesEdits:
|
||||
usage, openAIErrorWithStatusCode = handleImageEdits(c, provider, modelMap, quotaInfo, group, "edit")
|
||||
case common.RelayModeImagesVariations:
|
||||
usage, openAIErrorWithStatusCode = handleImageEdits(c, provider, modelMap, quotaInfo, group, "variation")
|
||||
default:
|
||||
return types.ErrorWrapper(errors.New("invalid relay mode"), "invalid_relay_mode", http.StatusBadRequest)
|
||||
}
|
||||
@ -336,7 +340,7 @@ func handleTranslations(c *gin.Context, provider providers_base.ProviderInterfac
|
||||
func handleImageGenerations(c *gin.Context, provider providers_base.ProviderInterface, modelMap map[string]string, quotaInfo *QuotaInfo, group string) (*types.Usage, *types.OpenAIErrorWithStatusCode) {
|
||||
var imageRequest types.ImageRequest
|
||||
isModelMapped := false
|
||||
speechProvider, ok := provider.(providers_base.ImageGenerationsInterface)
|
||||
imageGenerationsProvider, ok := provider.(providers_base.ImageGenerationsInterface)
|
||||
if !ok {
|
||||
return nil, types.ErrorWrapper(errors.New("channel not implemented"), "channel_not_implemented", http.StatusNotImplemented)
|
||||
}
|
||||
@ -374,5 +378,60 @@ func handleImageGenerations(c *gin.Context, provider providers_base.ProviderInte
|
||||
if quota_err != nil {
|
||||
return nil, quota_err
|
||||
}
|
||||
return speechProvider.ImageGenerationsAction(&imageRequest, isModelMapped, promptTokens)
|
||||
return imageGenerationsProvider.ImageGenerationsAction(&imageRequest, isModelMapped, promptTokens)
|
||||
}
|
||||
|
||||
func handleImageEdits(c *gin.Context, provider providers_base.ProviderInterface, modelMap map[string]string, quotaInfo *QuotaInfo, group string, imageType string) (*types.Usage, *types.OpenAIErrorWithStatusCode) {
|
||||
var imageEditRequest types.ImageEditRequest
|
||||
isModelMapped := false
|
||||
var imageEditsProvider providers_base.ImageEditsInterface
|
||||
var imageVariations providers_base.ImageVariationsInterface
|
||||
var ok bool
|
||||
if imageType == "edit" {
|
||||
imageEditsProvider, ok = provider.(providers_base.ImageEditsInterface)
|
||||
if !ok {
|
||||
return nil, types.ErrorWrapper(errors.New("channel not implemented"), "channel_not_implemented", http.StatusNotImplemented)
|
||||
}
|
||||
} else {
|
||||
imageVariations, ok = provider.(providers_base.ImageVariationsInterface)
|
||||
if !ok {
|
||||
return nil, types.ErrorWrapper(errors.New("channel not implemented"), "channel_not_implemented", http.StatusNotImplemented)
|
||||
}
|
||||
}
|
||||
|
||||
err := common.UnmarshalBodyReusable(c, &imageEditRequest)
|
||||
if err != nil {
|
||||
return nil, types.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
if imageEditRequest.Model == "" {
|
||||
imageEditRequest.Model = "dall-e-2"
|
||||
}
|
||||
|
||||
if imageEditRequest.Size == "" {
|
||||
imageEditRequest.Size = "1024x1024"
|
||||
}
|
||||
|
||||
if modelMap != nil && modelMap[imageEditRequest.Model] != "" {
|
||||
imageEditRequest.Model = modelMap[imageEditRequest.Model]
|
||||
isModelMapped = true
|
||||
}
|
||||
promptTokens, err := common.CountTokenImage(imageEditRequest)
|
||||
if err != nil {
|
||||
return nil, types.ErrorWrapper(err, "count_token_image_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
quotaInfo.modelName = imageEditRequest.Model
|
||||
quotaInfo.promptTokens = promptTokens
|
||||
quotaInfo.initQuotaInfo(group)
|
||||
quota_err := quotaInfo.preQuotaConsumption()
|
||||
if quota_err != nil {
|
||||
return nil, quota_err
|
||||
}
|
||||
|
||||
if imageType == "edit" {
|
||||
return imageEditsProvider.ImageEditsAction(&imageEditRequest, isModelMapped, promptTokens)
|
||||
}
|
||||
|
||||
return imageVariations.ImageVariationsAction(&imageEditRequest, isModelMapped, promptTokens)
|
||||
}
|
||||
|
@ -244,6 +244,10 @@ func Relay(c *gin.Context) {
|
||||
relayMode = common.RelayModeAudioTranslation
|
||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
|
||||
relayMode = common.RelayModeImagesGenerations
|
||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/edits") {
|
||||
relayMode = common.RelayModeImagesEdits
|
||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/variations") {
|
||||
relayMode = common.RelayModeImagesVariations
|
||||
}
|
||||
// } else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") {
|
||||
// relayMode = RelayModeEdits
|
||||
|
@ -11,10 +11,35 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type ModelRequestInterface interface {
|
||||
GetModel() string
|
||||
SetModel(string)
|
||||
}
|
||||
|
||||
type ModelRequest struct {
|
||||
Model string `json:"model"`
|
||||
}
|
||||
|
||||
func (m *ModelRequest) GetModel() string {
|
||||
return m.Model
|
||||
}
|
||||
|
||||
func (m *ModelRequest) SetModel(model string) {
|
||||
m.Model = model
|
||||
}
|
||||
|
||||
type ModelFormRequest struct {
|
||||
Model string `form:"model"`
|
||||
}
|
||||
|
||||
func (m *ModelFormRequest) GetModel() string {
|
||||
return m.Model
|
||||
}
|
||||
|
||||
func (m *ModelFormRequest) SetModel(model string) {
|
||||
m.Model = model
|
||||
}
|
||||
|
||||
func Distribute() func(c *gin.Context) {
|
||||
return func(c *gin.Context) {
|
||||
userId := c.GetInt("id")
|
||||
@ -39,35 +64,36 @@ func Distribute() func(c *gin.Context) {
|
||||
}
|
||||
} else {
|
||||
// Select a channel for the user
|
||||
var modelRequest ModelRequest
|
||||
err := common.UnmarshalBodyReusable(c, &modelRequest)
|
||||
modelRequest := getModelRequest(c)
|
||||
err := common.UnmarshalBodyReusable(c, modelRequest)
|
||||
if err != nil {
|
||||
abortWithMessage(c, http.StatusBadRequest, "无效的请求")
|
||||
return
|
||||
}
|
||||
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
|
||||
if modelRequest.Model == "" {
|
||||
modelRequest.Model = "text-moderation-stable"
|
||||
if modelRequest.GetModel() == "" {
|
||||
modelRequest.SetModel("text-moderation-stable")
|
||||
}
|
||||
}
|
||||
if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
|
||||
if modelRequest.Model == "" {
|
||||
modelRequest.Model = c.Param("model")
|
||||
if modelRequest.GetModel() == "" {
|
||||
modelRequest.SetModel(c.Param("model"))
|
||||
}
|
||||
}
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
|
||||
if modelRequest.Model == "" {
|
||||
modelRequest.Model = "dall-e-2"
|
||||
if modelRequest.GetModel() == "" {
|
||||
modelRequest.SetModel("dall-e-2")
|
||||
}
|
||||
}
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
|
||||
if modelRequest.Model == "" {
|
||||
modelRequest.Model = "whisper-1"
|
||||
if modelRequest.GetModel() == "" {
|
||||
modelRequest.SetModel("whisper-1")
|
||||
}
|
||||
}
|
||||
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
|
||||
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.GetModel())
|
||||
if err != nil {
|
||||
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
|
||||
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.GetModel())
|
||||
if channel != nil {
|
||||
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
|
||||
message = "数据库一致性已被破坏,请联系管理员"
|
||||
@ -94,3 +120,14 @@ func Distribute() func(c *gin.Context) {
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func getModelRequest(c *gin.Context) (modelRequest ModelRequestInterface) {
|
||||
contentType := c.Request.Header.Get("Content-Type")
|
||||
if strings.HasPrefix(contentType, "application/json") {
|
||||
modelRequest = &ModelRequest{}
|
||||
} else if strings.HasPrefix(contentType, "multipart/form-data") {
|
||||
modelRequest = &ModelFormRequest{}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
@ -146,7 +146,7 @@ func (p *BaseProvider) SupportAPI(relayMode int) bool {
|
||||
return p.Moderation != ""
|
||||
case common.RelayModeImagesGenerations:
|
||||
return p.ImagesGenerations != ""
|
||||
case common.RelayModeImagesEdit:
|
||||
case common.RelayModeImagesEdits:
|
||||
return p.ImagesEdit != ""
|
||||
case common.RelayModeImagesVariations:
|
||||
return p.ImagesVariations != ""
|
||||
|
@ -56,11 +56,23 @@ type TranslationInterface interface {
|
||||
TranslationAction(request *types.AudioRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
|
||||
}
|
||||
|
||||
// 图片生成接口
|
||||
type ImageGenerationsInterface interface {
|
||||
ProviderInterface
|
||||
ImageGenerationsAction(request *types.ImageRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
|
||||
}
|
||||
|
||||
// 图片编辑接口
|
||||
type ImageEditsInterface interface {
|
||||
ProviderInterface
|
||||
ImageEditsAction(request *types.ImageEditRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
|
||||
}
|
||||
|
||||
type ImageVariationsInterface interface {
|
||||
ProviderInterface
|
||||
ImageVariationsAction(request *types.ImageEditRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
|
||||
}
|
||||
|
||||
// 余额接口
|
||||
type BalanceInterface interface {
|
||||
BalanceAction(channel *model.Channel) (float64, error)
|
||||
|
@ -39,7 +39,7 @@ func CreateOpenAIProvider(c *gin.Context, baseURL string) *OpenAIProvider {
|
||||
AudioTranscriptions: "/v1/audio/transcriptions",
|
||||
AudioTranslations: "/v1/audio/translations",
|
||||
ImagesGenerations: "/v1/images/generations",
|
||||
ImagesEdit: "/v1/images/edit",
|
||||
ImagesEdit: "/v1/images/edits",
|
||||
ImagesVariations: "/v1/images/variations",
|
||||
Context: c,
|
||||
},
|
||||
|
104
providers/openai/image_edits.go
Normal file
104
providers/openai/image_edits.go
Normal file
@ -0,0 +1,104 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/types"
|
||||
)
|
||||
|
||||
func (p *OpenAIProvider) ImageEditsAction(request *types.ImageEditRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
fullRequestURL := p.GetFullRequestURL(p.ImagesEdit, request.Model)
|
||||
headers := p.GetRequestHeaders()
|
||||
|
||||
client := common.NewClient()
|
||||
|
||||
var formBody bytes.Buffer
|
||||
var req *http.Request
|
||||
var err error
|
||||
if isModelMapped {
|
||||
builder := client.CreateFormBuilder(&formBody)
|
||||
if err := imagesEditsMultipartForm(request, builder); err != nil {
|
||||
return nil, types.ErrorWrapper(err, "create_form_builder_failed", http.StatusInternalServerError)
|
||||
}
|
||||
req, err = client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(&formBody), common.WithHeader(headers), common.WithContentType(builder.FormDataContentType()))
|
||||
req.ContentLength = int64(formBody.Len())
|
||||
|
||||
} else {
|
||||
req, err = client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(p.Context.Request.Body), common.WithHeader(headers), common.WithContentType(p.Context.Request.Header.Get("Content-Type")))
|
||||
req.ContentLength = p.Context.Request.ContentLength
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, types.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
openAIProviderImageResponseResponse := &OpenAIProviderImageResponseResponse{}
|
||||
errWithCode = p.SendRequest(req, openAIProviderImageResponseResponse, true)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
}
|
||||
|
||||
usage = &types.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: 0,
|
||||
TotalTokens: promptTokens,
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func imagesEditsMultipartForm(request *types.ImageEditRequest, b common.FormBuilder) error {
|
||||
err := b.CreateFormFile("image", request.Image)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating form file: %w", err)
|
||||
}
|
||||
|
||||
err = b.WriteField("prompt", request.Prompt)
|
||||
if err != nil {
|
||||
return fmt.Errorf("writing prompt: %w", err)
|
||||
}
|
||||
|
||||
err = b.WriteField("model", request.Model)
|
||||
if err != nil {
|
||||
return fmt.Errorf("writing model name: %w", err)
|
||||
}
|
||||
|
||||
if request.Mask != nil {
|
||||
err = b.CreateFormFile("mask", request.Mask)
|
||||
if err != nil {
|
||||
return fmt.Errorf("writing format: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if request.ResponseFormat != "" {
|
||||
err = b.WriteField("response_format", request.ResponseFormat)
|
||||
if err != nil {
|
||||
return fmt.Errorf("writing format: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if request.N != 0 {
|
||||
err = b.WriteField("n", fmt.Sprintf("%.2f", request.N))
|
||||
if err != nil {
|
||||
return fmt.Errorf("writing temperature: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if request.Size != "" {
|
||||
err = b.WriteField("size", request.Size)
|
||||
if err != nil {
|
||||
return fmt.Errorf("writing language: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if request.User != "" {
|
||||
err = b.WriteField("user", request.User)
|
||||
if err != nil {
|
||||
return fmt.Errorf("writing language: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return b.Close()
|
||||
}
|
49
providers/openai/image_variations.go
Normal file
49
providers/openai/image_variations.go
Normal file
@ -0,0 +1,49 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/types"
|
||||
)
|
||||
|
||||
func (p *OpenAIProvider) ImageVariationsAction(request *types.ImageEditRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
fullRequestURL := p.GetFullRequestURL(p.ImagesVariations, request.Model)
|
||||
headers := p.GetRequestHeaders()
|
||||
|
||||
client := common.NewClient()
|
||||
|
||||
var formBody bytes.Buffer
|
||||
var req *http.Request
|
||||
var err error
|
||||
if isModelMapped {
|
||||
builder := client.CreateFormBuilder(&formBody)
|
||||
if err := imagesEditsMultipartForm(request, builder); err != nil {
|
||||
return nil, types.ErrorWrapper(err, "create_form_builder_failed", http.StatusInternalServerError)
|
||||
}
|
||||
req, err = client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(&formBody), common.WithHeader(headers), common.WithContentType(builder.FormDataContentType()))
|
||||
req.ContentLength = int64(formBody.Len())
|
||||
|
||||
} else {
|
||||
req, err = client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(p.Context.Request.Body), common.WithHeader(headers), common.WithContentType(p.Context.Request.Header.Get("Content-Type")))
|
||||
req.ContentLength = p.Context.Request.ContentLength
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, types.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
openAIProviderImageResponseResponse := &OpenAIProviderImageResponseResponse{}
|
||||
errWithCode = p.SendRequest(req, openAIProviderImageResponseResponse, true)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
}
|
||||
|
||||
usage = &types.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: 0,
|
||||
TotalTokens: promptTokens,
|
||||
}
|
||||
|
||||
return
|
||||
}
|
@ -23,8 +23,8 @@ func SetRelayRouter(router *gin.Engine) {
|
||||
relayV1Router.POST("/chat/completions", controller.Relay)
|
||||
relayV1Router.POST("/edits", controller.Relay)
|
||||
relayV1Router.POST("/images/generations", controller.Relay)
|
||||
relayV1Router.POST("/images/edits", controller.RelayNotImplemented)
|
||||
relayV1Router.POST("/images/variations", controller.RelayNotImplemented)
|
||||
relayV1Router.POST("/images/edits", controller.Relay)
|
||||
relayV1Router.POST("/images/variations", controller.Relay)
|
||||
relayV1Router.POST("/embeddings", controller.Relay)
|
||||
relayV1Router.POST("/engines/:model/embeddings", controller.Relay)
|
||||
relayV1Router.POST("/audio/transcriptions", controller.Relay)
|
||||
|
Loading…
Reference in New Issue
Block a user