✨ add: add images edits and variations API
This commit is contained in:
parent
9dd92bbddd
commit
0f038d715d
@ -222,7 +222,7 @@ const (
|
|||||||
RelayModeEmbeddings
|
RelayModeEmbeddings
|
||||||
RelayModeModerations
|
RelayModeModerations
|
||||||
RelayModeImagesGenerations
|
RelayModeImagesGenerations
|
||||||
RelayModeImagesEdit
|
RelayModeImagesEdits
|
||||||
RelayModeImagesVariations
|
RelayModeImagesVariations
|
||||||
RelayModeEdits
|
RelayModeEdits
|
||||||
RelayModeAudioSpeech
|
RelayModeAudioSpeech
|
||||||
|
@ -109,12 +109,25 @@ func CountTokenText(text string, model string) int {
|
|||||||
return getTokenNum(tokenEncoder, text)
|
return getTokenNum(tokenEncoder, text)
|
||||||
}
|
}
|
||||||
|
|
||||||
func CountTokenImage(imageRequest types.ImageRequest) (int, error) {
|
func CountTokenImage(input interface{}) (int, error) {
|
||||||
imageCostRatio, hasValidSize := DalleSizeRatios[imageRequest.Model][imageRequest.Size]
|
switch v := input.(type) {
|
||||||
|
case types.ImageRequest:
|
||||||
|
// 处理 ImageRequest
|
||||||
|
return calculateToken(v.Model, v.Size, v.N, v.Quality)
|
||||||
|
case types.ImageEditRequest:
|
||||||
|
// 处理 ImageEditsRequest
|
||||||
|
return calculateToken(v.Model, v.Size, v.N, "")
|
||||||
|
default:
|
||||||
|
return 0, errors.New("unsupported type")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func calculateToken(model string, size string, n int, quality string) (int, error) {
|
||||||
|
imageCostRatio, hasValidSize := DalleSizeRatios[model][size]
|
||||||
|
|
||||||
if hasValidSize {
|
if hasValidSize {
|
||||||
if imageRequest.Quality == "hd" && imageRequest.Model == "dall-e-3" {
|
if quality == "hd" && model == "dall-e-3" {
|
||||||
if imageRequest.Size == "1024x1024" {
|
if size == "1024x1024" {
|
||||||
imageCostRatio *= 2
|
imageCostRatio *= 2
|
||||||
} else {
|
} else {
|
||||||
imageCostRatio *= 1.5
|
imageCostRatio *= 1.5
|
||||||
@ -124,5 +137,5 @@ func CountTokenImage(imageRequest types.ImageRequest) (int, error) {
|
|||||||
return 0, errors.New("size not supported for this image model")
|
return 0, errors.New("size not supported for this image model")
|
||||||
}
|
}
|
||||||
|
|
||||||
return int(imageCostRatio*1000) * imageRequest.N, nil
|
return int(imageCostRatio*1000) * n, nil
|
||||||
}
|
}
|
||||||
|
@ -65,6 +65,10 @@ func relayHelper(c *gin.Context, relayMode int) *types.OpenAIErrorWithStatusCode
|
|||||||
usage, openAIErrorWithStatusCode = handleTranslations(c, provider, modelMap, quotaInfo, group)
|
usage, openAIErrorWithStatusCode = handleTranslations(c, provider, modelMap, quotaInfo, group)
|
||||||
case common.RelayModeImagesGenerations:
|
case common.RelayModeImagesGenerations:
|
||||||
usage, openAIErrorWithStatusCode = handleImageGenerations(c, provider, modelMap, quotaInfo, group)
|
usage, openAIErrorWithStatusCode = handleImageGenerations(c, provider, modelMap, quotaInfo, group)
|
||||||
|
case common.RelayModeImagesEdits:
|
||||||
|
usage, openAIErrorWithStatusCode = handleImageEdits(c, provider, modelMap, quotaInfo, group, "edit")
|
||||||
|
case common.RelayModeImagesVariations:
|
||||||
|
usage, openAIErrorWithStatusCode = handleImageEdits(c, provider, modelMap, quotaInfo, group, "variation")
|
||||||
default:
|
default:
|
||||||
return types.ErrorWrapper(errors.New("invalid relay mode"), "invalid_relay_mode", http.StatusBadRequest)
|
return types.ErrorWrapper(errors.New("invalid relay mode"), "invalid_relay_mode", http.StatusBadRequest)
|
||||||
}
|
}
|
||||||
@ -336,7 +340,7 @@ func handleTranslations(c *gin.Context, provider providers_base.ProviderInterfac
|
|||||||
func handleImageGenerations(c *gin.Context, provider providers_base.ProviderInterface, modelMap map[string]string, quotaInfo *QuotaInfo, group string) (*types.Usage, *types.OpenAIErrorWithStatusCode) {
|
func handleImageGenerations(c *gin.Context, provider providers_base.ProviderInterface, modelMap map[string]string, quotaInfo *QuotaInfo, group string) (*types.Usage, *types.OpenAIErrorWithStatusCode) {
|
||||||
var imageRequest types.ImageRequest
|
var imageRequest types.ImageRequest
|
||||||
isModelMapped := false
|
isModelMapped := false
|
||||||
speechProvider, ok := provider.(providers_base.ImageGenerationsInterface)
|
imageGenerationsProvider, ok := provider.(providers_base.ImageGenerationsInterface)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, types.ErrorWrapper(errors.New("channel not implemented"), "channel_not_implemented", http.StatusNotImplemented)
|
return nil, types.ErrorWrapper(errors.New("channel not implemented"), "channel_not_implemented", http.StatusNotImplemented)
|
||||||
}
|
}
|
||||||
@ -374,5 +378,60 @@ func handleImageGenerations(c *gin.Context, provider providers_base.ProviderInte
|
|||||||
if quota_err != nil {
|
if quota_err != nil {
|
||||||
return nil, quota_err
|
return nil, quota_err
|
||||||
}
|
}
|
||||||
return speechProvider.ImageGenerationsAction(&imageRequest, isModelMapped, promptTokens)
|
return imageGenerationsProvider.ImageGenerationsAction(&imageRequest, isModelMapped, promptTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleImageEdits(c *gin.Context, provider providers_base.ProviderInterface, modelMap map[string]string, quotaInfo *QuotaInfo, group string, imageType string) (*types.Usage, *types.OpenAIErrorWithStatusCode) {
|
||||||
|
var imageEditRequest types.ImageEditRequest
|
||||||
|
isModelMapped := false
|
||||||
|
var imageEditsProvider providers_base.ImageEditsInterface
|
||||||
|
var imageVariations providers_base.ImageVariationsInterface
|
||||||
|
var ok bool
|
||||||
|
if imageType == "edit" {
|
||||||
|
imageEditsProvider, ok = provider.(providers_base.ImageEditsInterface)
|
||||||
|
if !ok {
|
||||||
|
return nil, types.ErrorWrapper(errors.New("channel not implemented"), "channel_not_implemented", http.StatusNotImplemented)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
imageVariations, ok = provider.(providers_base.ImageVariationsInterface)
|
||||||
|
if !ok {
|
||||||
|
return nil, types.ErrorWrapper(errors.New("channel not implemented"), "channel_not_implemented", http.StatusNotImplemented)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
err := common.UnmarshalBodyReusable(c, &imageEditRequest)
|
||||||
|
if err != nil {
|
||||||
|
return nil, types.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
|
||||||
|
if imageEditRequest.Model == "" {
|
||||||
|
imageEditRequest.Model = "dall-e-2"
|
||||||
|
}
|
||||||
|
|
||||||
|
if imageEditRequest.Size == "" {
|
||||||
|
imageEditRequest.Size = "1024x1024"
|
||||||
|
}
|
||||||
|
|
||||||
|
if modelMap != nil && modelMap[imageEditRequest.Model] != "" {
|
||||||
|
imageEditRequest.Model = modelMap[imageEditRequest.Model]
|
||||||
|
isModelMapped = true
|
||||||
|
}
|
||||||
|
promptTokens, err := common.CountTokenImage(imageEditRequest)
|
||||||
|
if err != nil {
|
||||||
|
return nil, types.ErrorWrapper(err, "count_token_image_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
quotaInfo.modelName = imageEditRequest.Model
|
||||||
|
quotaInfo.promptTokens = promptTokens
|
||||||
|
quotaInfo.initQuotaInfo(group)
|
||||||
|
quota_err := quotaInfo.preQuotaConsumption()
|
||||||
|
if quota_err != nil {
|
||||||
|
return nil, quota_err
|
||||||
|
}
|
||||||
|
|
||||||
|
if imageType == "edit" {
|
||||||
|
return imageEditsProvider.ImageEditsAction(&imageEditRequest, isModelMapped, promptTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
return imageVariations.ImageVariationsAction(&imageEditRequest, isModelMapped, promptTokens)
|
||||||
}
|
}
|
||||||
|
@ -244,6 +244,10 @@ func Relay(c *gin.Context) {
|
|||||||
relayMode = common.RelayModeAudioTranslation
|
relayMode = common.RelayModeAudioTranslation
|
||||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
|
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
|
||||||
relayMode = common.RelayModeImagesGenerations
|
relayMode = common.RelayModeImagesGenerations
|
||||||
|
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/edits") {
|
||||||
|
relayMode = common.RelayModeImagesEdits
|
||||||
|
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/variations") {
|
||||||
|
relayMode = common.RelayModeImagesVariations
|
||||||
}
|
}
|
||||||
// } else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") {
|
// } else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") {
|
||||||
// relayMode = RelayModeEdits
|
// relayMode = RelayModeEdits
|
||||||
|
@ -11,10 +11,35 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type ModelRequestInterface interface {
|
||||||
|
GetModel() string
|
||||||
|
SetModel(string)
|
||||||
|
}
|
||||||
|
|
||||||
type ModelRequest struct {
|
type ModelRequest struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *ModelRequest) GetModel() string {
|
||||||
|
return m.Model
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *ModelRequest) SetModel(model string) {
|
||||||
|
m.Model = model
|
||||||
|
}
|
||||||
|
|
||||||
|
type ModelFormRequest struct {
|
||||||
|
Model string `form:"model"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *ModelFormRequest) GetModel() string {
|
||||||
|
return m.Model
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *ModelFormRequest) SetModel(model string) {
|
||||||
|
m.Model = model
|
||||||
|
}
|
||||||
|
|
||||||
func Distribute() func(c *gin.Context) {
|
func Distribute() func(c *gin.Context) {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
userId := c.GetInt("id")
|
userId := c.GetInt("id")
|
||||||
@ -39,35 +64,36 @@ func Distribute() func(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Select a channel for the user
|
// Select a channel for the user
|
||||||
var modelRequest ModelRequest
|
modelRequest := getModelRequest(c)
|
||||||
err := common.UnmarshalBodyReusable(c, &modelRequest)
|
err := common.UnmarshalBodyReusable(c, modelRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
abortWithMessage(c, http.StatusBadRequest, "无效的请求")
|
abortWithMessage(c, http.StatusBadRequest, "无效的请求")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
|
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
|
||||||
if modelRequest.Model == "" {
|
if modelRequest.GetModel() == "" {
|
||||||
modelRequest.Model = "text-moderation-stable"
|
modelRequest.SetModel("text-moderation-stable")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
|
if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
|
||||||
if modelRequest.Model == "" {
|
if modelRequest.GetModel() == "" {
|
||||||
modelRequest.Model = c.Param("model")
|
modelRequest.SetModel(c.Param("model"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
|
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
|
||||||
if modelRequest.Model == "" {
|
if modelRequest.GetModel() == "" {
|
||||||
modelRequest.Model = "dall-e-2"
|
modelRequest.SetModel("dall-e-2")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
|
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
|
||||||
if modelRequest.Model == "" {
|
if modelRequest.GetModel() == "" {
|
||||||
modelRequest.Model = "whisper-1"
|
modelRequest.SetModel("whisper-1")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
|
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.GetModel())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
|
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.GetModel())
|
||||||
if channel != nil {
|
if channel != nil {
|
||||||
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
|
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
|
||||||
message = "数据库一致性已被破坏,请联系管理员"
|
message = "数据库一致性已被破坏,请联系管理员"
|
||||||
@ -94,3 +120,14 @@ func Distribute() func(c *gin.Context) {
|
|||||||
c.Next()
|
c.Next()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getModelRequest(c *gin.Context) (modelRequest ModelRequestInterface) {
|
||||||
|
contentType := c.Request.Header.Get("Content-Type")
|
||||||
|
if strings.HasPrefix(contentType, "application/json") {
|
||||||
|
modelRequest = &ModelRequest{}
|
||||||
|
} else if strings.HasPrefix(contentType, "multipart/form-data") {
|
||||||
|
modelRequest = &ModelFormRequest{}
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
@ -146,7 +146,7 @@ func (p *BaseProvider) SupportAPI(relayMode int) bool {
|
|||||||
return p.Moderation != ""
|
return p.Moderation != ""
|
||||||
case common.RelayModeImagesGenerations:
|
case common.RelayModeImagesGenerations:
|
||||||
return p.ImagesGenerations != ""
|
return p.ImagesGenerations != ""
|
||||||
case common.RelayModeImagesEdit:
|
case common.RelayModeImagesEdits:
|
||||||
return p.ImagesEdit != ""
|
return p.ImagesEdit != ""
|
||||||
case common.RelayModeImagesVariations:
|
case common.RelayModeImagesVariations:
|
||||||
return p.ImagesVariations != ""
|
return p.ImagesVariations != ""
|
||||||
|
@ -56,11 +56,23 @@ type TranslationInterface interface {
|
|||||||
TranslationAction(request *types.AudioRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
|
TranslationAction(request *types.AudioRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 图片生成接口
|
||||||
type ImageGenerationsInterface interface {
|
type ImageGenerationsInterface interface {
|
||||||
ProviderInterface
|
ProviderInterface
|
||||||
ImageGenerationsAction(request *types.ImageRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
|
ImageGenerationsAction(request *types.ImageRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 图片编辑接口
|
||||||
|
type ImageEditsInterface interface {
|
||||||
|
ProviderInterface
|
||||||
|
ImageEditsAction(request *types.ImageEditRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
type ImageVariationsInterface interface {
|
||||||
|
ProviderInterface
|
||||||
|
ImageVariationsAction(request *types.ImageEditRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
// 余额接口
|
// 余额接口
|
||||||
type BalanceInterface interface {
|
type BalanceInterface interface {
|
||||||
BalanceAction(channel *model.Channel) (float64, error)
|
BalanceAction(channel *model.Channel) (float64, error)
|
||||||
|
@ -39,7 +39,7 @@ func CreateOpenAIProvider(c *gin.Context, baseURL string) *OpenAIProvider {
|
|||||||
AudioTranscriptions: "/v1/audio/transcriptions",
|
AudioTranscriptions: "/v1/audio/transcriptions",
|
||||||
AudioTranslations: "/v1/audio/translations",
|
AudioTranslations: "/v1/audio/translations",
|
||||||
ImagesGenerations: "/v1/images/generations",
|
ImagesGenerations: "/v1/images/generations",
|
||||||
ImagesEdit: "/v1/images/edit",
|
ImagesEdit: "/v1/images/edits",
|
||||||
ImagesVariations: "/v1/images/variations",
|
ImagesVariations: "/v1/images/variations",
|
||||||
Context: c,
|
Context: c,
|
||||||
},
|
},
|
||||||
|
104
providers/openai/image_edits.go
Normal file
104
providers/openai/image_edits.go
Normal file
@ -0,0 +1,104 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (p *OpenAIProvider) ImageEditsAction(request *types.ImageEditRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||||
|
fullRequestURL := p.GetFullRequestURL(p.ImagesEdit, request.Model)
|
||||||
|
headers := p.GetRequestHeaders()
|
||||||
|
|
||||||
|
client := common.NewClient()
|
||||||
|
|
||||||
|
var formBody bytes.Buffer
|
||||||
|
var req *http.Request
|
||||||
|
var err error
|
||||||
|
if isModelMapped {
|
||||||
|
builder := client.CreateFormBuilder(&formBody)
|
||||||
|
if err := imagesEditsMultipartForm(request, builder); err != nil {
|
||||||
|
return nil, types.ErrorWrapper(err, "create_form_builder_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
req, err = client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(&formBody), common.WithHeader(headers), common.WithContentType(builder.FormDataContentType()))
|
||||||
|
req.ContentLength = int64(formBody.Len())
|
||||||
|
|
||||||
|
} else {
|
||||||
|
req, err = client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(p.Context.Request.Body), common.WithHeader(headers), common.WithContentType(p.Context.Request.Header.Get("Content-Type")))
|
||||||
|
req.ContentLength = p.Context.Request.ContentLength
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, types.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
openAIProviderImageResponseResponse := &OpenAIProviderImageResponseResponse{}
|
||||||
|
errWithCode = p.SendRequest(req, openAIProviderImageResponseResponse, true)
|
||||||
|
if errWithCode != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
usage = &types.Usage{
|
||||||
|
PromptTokens: promptTokens,
|
||||||
|
CompletionTokens: 0,
|
||||||
|
TotalTokens: promptTokens,
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func imagesEditsMultipartForm(request *types.ImageEditRequest, b common.FormBuilder) error {
|
||||||
|
err := b.CreateFormFile("image", request.Image)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("creating form file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = b.WriteField("prompt", request.Prompt)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("writing prompt: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = b.WriteField("model", request.Model)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("writing model name: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if request.Mask != nil {
|
||||||
|
err = b.CreateFormFile("mask", request.Mask)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("writing format: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if request.ResponseFormat != "" {
|
||||||
|
err = b.WriteField("response_format", request.ResponseFormat)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("writing format: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if request.N != 0 {
|
||||||
|
err = b.WriteField("n", fmt.Sprintf("%.2f", request.N))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("writing temperature: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if request.Size != "" {
|
||||||
|
err = b.WriteField("size", request.Size)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("writing language: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if request.User != "" {
|
||||||
|
err = b.WriteField("user", request.User)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("writing language: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return b.Close()
|
||||||
|
}
|
49
providers/openai/image_variations.go
Normal file
49
providers/openai/image_variations.go
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (p *OpenAIProvider) ImageVariationsAction(request *types.ImageEditRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||||
|
fullRequestURL := p.GetFullRequestURL(p.ImagesVariations, request.Model)
|
||||||
|
headers := p.GetRequestHeaders()
|
||||||
|
|
||||||
|
client := common.NewClient()
|
||||||
|
|
||||||
|
var formBody bytes.Buffer
|
||||||
|
var req *http.Request
|
||||||
|
var err error
|
||||||
|
if isModelMapped {
|
||||||
|
builder := client.CreateFormBuilder(&formBody)
|
||||||
|
if err := imagesEditsMultipartForm(request, builder); err != nil {
|
||||||
|
return nil, types.ErrorWrapper(err, "create_form_builder_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
req, err = client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(&formBody), common.WithHeader(headers), common.WithContentType(builder.FormDataContentType()))
|
||||||
|
req.ContentLength = int64(formBody.Len())
|
||||||
|
|
||||||
|
} else {
|
||||||
|
req, err = client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(p.Context.Request.Body), common.WithHeader(headers), common.WithContentType(p.Context.Request.Header.Get("Content-Type")))
|
||||||
|
req.ContentLength = p.Context.Request.ContentLength
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, types.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
openAIProviderImageResponseResponse := &OpenAIProviderImageResponseResponse{}
|
||||||
|
errWithCode = p.SendRequest(req, openAIProviderImageResponseResponse, true)
|
||||||
|
if errWithCode != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
usage = &types.Usage{
|
||||||
|
PromptTokens: promptTokens,
|
||||||
|
CompletionTokens: 0,
|
||||||
|
TotalTokens: promptTokens,
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
@ -23,8 +23,8 @@ func SetRelayRouter(router *gin.Engine) {
|
|||||||
relayV1Router.POST("/chat/completions", controller.Relay)
|
relayV1Router.POST("/chat/completions", controller.Relay)
|
||||||
relayV1Router.POST("/edits", controller.Relay)
|
relayV1Router.POST("/edits", controller.Relay)
|
||||||
relayV1Router.POST("/images/generations", controller.Relay)
|
relayV1Router.POST("/images/generations", controller.Relay)
|
||||||
relayV1Router.POST("/images/edits", controller.RelayNotImplemented)
|
relayV1Router.POST("/images/edits", controller.Relay)
|
||||||
relayV1Router.POST("/images/variations", controller.RelayNotImplemented)
|
relayV1Router.POST("/images/variations", controller.Relay)
|
||||||
relayV1Router.POST("/embeddings", controller.Relay)
|
relayV1Router.POST("/embeddings", controller.Relay)
|
||||||
relayV1Router.POST("/engines/:model/embeddings", controller.Relay)
|
relayV1Router.POST("/engines/:model/embeddings", controller.Relay)
|
||||||
relayV1Router.POST("/audio/transcriptions", controller.Relay)
|
relayV1Router.POST("/audio/transcriptions", controller.Relay)
|
||||||
|
Loading…
Reference in New Issue
Block a user