diff --git a/common/client.go b/common/client.go index e7402378..4a592c05 100644 --- a/common/client.go +++ b/common/client.go @@ -117,7 +117,7 @@ func SendRequest(req *http.Request, response any, outputResp bool) (*http.Respon // 将响应体重新写入 resp.Body resp.Body = io.NopCloser(&buf) } else { - err = DecodeResponse(resp.Body, nil) + err = DecodeResponse(resp.Body, response) } if err != nil { return nil, types.ErrorWrapper(err, "decode_response_failed", http.StatusInternalServerError) diff --git a/common/constants.go b/common/constants.go index 36e88769..52ac4395 100644 --- a/common/constants.go +++ b/common/constants.go @@ -222,6 +222,8 @@ const ( RelayModeEmbeddings RelayModeModerations RelayModeImagesGenerations + RelayModeImagesEdit + RelayModeImagesVariations RelayModeEdits RelayModeAudioSpeech RelayModeAudioTranscription diff --git a/common/token.go b/common/token.go index 5cac6f20..4a6a6fbb 100644 --- a/common/token.go +++ b/common/token.go @@ -1,6 +1,7 @@ package common import ( + "errors" "fmt" "strings" @@ -107,3 +108,21 @@ func CountTokenText(text string, model string) int { tokenEncoder := getTokenEncoder(model) return getTokenNum(tokenEncoder, text) } + +func CountTokenImage(imageRequest types.ImageRequest) (int, error) { + imageCostRatio, hasValidSize := DalleSizeRatios[imageRequest.Model][imageRequest.Size] + + if hasValidSize { + if imageRequest.Quality == "hd" && imageRequest.Model == "dall-e-3" { + if imageRequest.Size == "1024x1024" { + imageCostRatio *= 2 + } else { + imageCostRatio *= 1.5 + } + } + } else { + return 0, errors.New("size not supported for this image model") + } + + return int(imageCostRatio*1000) * imageRequest.N, nil +} diff --git a/controller/relay-helper.go b/controller/relay-helper.go index 63b474b3..94af4d64 100644 --- a/controller/relay-helper.go +++ b/controller/relay-helper.go @@ -63,6 +63,8 @@ func relayHelper(c *gin.Context, relayMode int) *types.OpenAIErrorWithStatusCode usage, openAIErrorWithStatusCode = handleTranscriptions(c, provider, modelMap, quotaInfo, group) case common.RelayModeAudioTranslation: usage, openAIErrorWithStatusCode = handleTranslations(c, provider, modelMap, quotaInfo, group) + case common.RelayModeImagesGenerations: + usage, openAIErrorWithStatusCode = handleImageGenerations(c, provider, modelMap, quotaInfo, group) default: return types.ErrorWrapper(errors.New("invalid relay mode"), "invalid_relay_mode", http.StatusBadRequest) } @@ -330,3 +332,47 @@ func handleTranslations(c *gin.Context, provider providers_base.ProviderInterfac } return speechProvider.TranslationAction(&audioRequest, isModelMapped, promptTokens) } + +func handleImageGenerations(c *gin.Context, provider providers_base.ProviderInterface, modelMap map[string]string, quotaInfo *QuotaInfo, group string) (*types.Usage, *types.OpenAIErrorWithStatusCode) { + var imageRequest types.ImageRequest + isModelMapped := false + speechProvider, ok := provider.(providers_base.ImageGenerationsInterface) + if !ok { + return nil, types.ErrorWrapper(errors.New("channel not implemented"), "channel_not_implemented", http.StatusNotImplemented) + } + + err := common.UnmarshalBodyReusable(c, &imageRequest) + if err != nil { + return nil, types.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) + } + + if imageRequest.Model == "" { + imageRequest.Model = "dall-e-2" + } + + if imageRequest.Size == "" { + imageRequest.Size = "1024x1024" + } + + if imageRequest.Quality == "" { + imageRequest.Quality = "standard" + } + + if modelMap != nil && modelMap[imageRequest.Model] != "" { + imageRequest.Model = modelMap[imageRequest.Model] + isModelMapped = true + } + promptTokens, err := common.CountTokenImage(imageRequest) + if err != nil { + return nil, types.ErrorWrapper(err, "count_token_image_failed", http.StatusInternalServerError) + } + + quotaInfo.modelName = imageRequest.Model + quotaInfo.promptTokens = promptTokens + quotaInfo.initQuotaInfo(group) + quota_err := quotaInfo.preQuotaConsumption() + if quota_err != nil { + return nil, quota_err + } + return speechProvider.ImageGenerationsAction(&imageRequest, isModelMapped, promptTokens) +} diff --git a/controller/relay.go b/controller/relay.go index 92a4db98..c1bae21f 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -242,9 +242,9 @@ func Relay(c *gin.Context) { relayMode = common.RelayModeAudioTranscription } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") { relayMode = common.RelayModeAudioTranslation + } else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { + relayMode = common.RelayModeImagesGenerations } - // } else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { - // relayMode = RelayModeImagesGenerations // } else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") { // relayMode = RelayModeEdits diff --git a/providers/azure/base.go b/providers/azure/base.go index 8a3def4e..1a2f0aaa 100644 --- a/providers/azure/base.go +++ b/providers/azure/base.go @@ -20,10 +20,13 @@ func CreateAzureProvider(c *gin.Context) *AzureProvider { Completions: "/completions", ChatCompletions: "/chat/completions", Embeddings: "/embeddings", - AudioSpeech: "/audio/speech", AudioTranscriptions: "/audio/transcriptions", AudioTranslations: "/audio/translations", - Context: c, + ImagesGenerations: "/images/generations", + // ImagesEdit: "/images/edit", + // ImagesVariations: "/images/variations", + Context: c, + // AudioSpeech: "/audio/speech", }, IsAzure: true, }, diff --git a/providers/azure/image_generations.go b/providers/azure/image_generations.go new file mode 100644 index 00000000..ed5c7378 --- /dev/null +++ b/providers/azure/image_generations.go @@ -0,0 +1,102 @@ +package azure + +import ( + "errors" + "fmt" + "net/http" + "one-api/common" + "one-api/providers/openai" + "one-api/types" + "time" +) + +func (c *ImageAzureResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) { + if c.Status == "canceled" || c.Status == "failed" { + errWithCode = &types.OpenAIErrorWithStatusCode{ + OpenAIError: types.OpenAIError{ + Message: c.Error.Message, + Type: "one_api_error", + Code: c.Error.Code, + }, + StatusCode: resp.StatusCode, + } + return + } + + operation_location := resp.Header.Get("operation-location") + if operation_location == "" { + return nil, types.ErrorWrapper(errors.New("image url is empty"), "get_images_url_failed", http.StatusInternalServerError) + } + + client := common.NewClient() + req, err := client.NewRequest("GET", operation_location, common.WithHeader(c.Header)) + if err != nil { + return nil, types.ErrorWrapper(err, "get_images_request_failed", http.StatusInternalServerError) + } + + getImageAzureResponse := ImageAzureResponse{} + for i := 0; i < 3; i++ { + // 休眠 2 秒 + time.Sleep(2 * time.Second) + _, errWithCode = common.SendRequest(req, &getImageAzureResponse, false) + fmt.Println("getImageAzureResponse", getImageAzureResponse) + if errWithCode != nil { + return + } + + if getImageAzureResponse.Status == "canceled" || getImageAzureResponse.Status == "failed" { + return nil, &types.OpenAIErrorWithStatusCode{ + OpenAIError: types.OpenAIError{ + Message: c.Error.Message, + Type: "get_images_request_failed", + Code: c.Error.Code, + }, + StatusCode: resp.StatusCode, + } + } + if getImageAzureResponse.Status == "succeeded" { + return getImageAzureResponse.Result, nil + } + } + + return nil, types.ErrorWrapper(errors.New("get image Timeout"), "get_images_url_failed", http.StatusInternalServerError) +} + +func (p *AzureProvider) ImageGenerationsAction(request *types.ImageRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { + + requestBody, err := p.GetRequestBody(&request, isModelMapped) + if err != nil { + return nil, types.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError) + } + + fullRequestURL := p.GetFullRequestURL(p.ImagesGenerations, request.Model) + headers := p.GetRequestHeaders() + + client := common.NewClient() + req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers)) + if err != nil { + return nil, types.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) + } + + if request.Model == "dall-e-2" { + imageAzureResponse := &ImageAzureResponse{ + Header: headers, + } + errWithCode = p.SendRequest(req, imageAzureResponse, false) + } else { + openAIProviderImageResponseResponse := &openai.OpenAIProviderImageResponseResponse{} + errWithCode = p.SendRequest(req, openAIProviderImageResponseResponse, true) + } + + if errWithCode != nil { + return + } + + usage = &types.Usage{ + PromptTokens: promptTokens, + CompletionTokens: 0, + TotalTokens: promptTokens, + } + + return +} diff --git a/providers/azure/type.go b/providers/azure/type.go new file mode 100644 index 00000000..7452fee1 --- /dev/null +++ b/providers/azure/type.go @@ -0,0 +1,21 @@ +package azure + +import "one-api/types" + +type ImageAzureResponse struct { + ID string `json:"id,omitempty"` + Created int64 `json:"created,omitempty"` + Expires int64 `json:"expires,omitempty"` + Result types.ImageResponse `json:"result,omitempty"` + Status string `json:"status,omitempty"` + Error ImageAzureError `json:"error,omitempty"` + Header map[string]string `json:"header,omitempty"` +} + +type ImageAzureError struct { + Code string `json:"code,omitempty"` + Target string `json:"target,omitempty"` + Message string `json:"message,omitempty"` + Details []string `json:"details,omitempty"` + InnerError any `json:"innererror,omitempty"` +} diff --git a/providers/base/common.go b/providers/base/common.go index ba6c13fd..e6cc49f8 100644 --- a/providers/base/common.go +++ b/providers/base/common.go @@ -23,6 +23,9 @@ type BaseProvider struct { Moderation string AudioTranscriptions string AudioTranslations string + ImagesGenerations string + ImagesEdit string + ImagesVariations string Proxy string Context *gin.Context } @@ -141,6 +144,12 @@ func (p *BaseProvider) SupportAPI(relayMode int) bool { return p.AudioTranslations != "" case common.RelayModeModerations: return p.Moderation != "" + case common.RelayModeImagesGenerations: + return p.ImagesGenerations != "" + case common.RelayModeImagesEdit: + return p.ImagesEdit != "" + case common.RelayModeImagesVariations: + return p.ImagesVariations != "" default: return false } diff --git a/providers/base/interface.go b/providers/base/interface.go index 03a3338a..6c6bcfd4 100644 --- a/providers/base/interface.go +++ b/providers/base/interface.go @@ -50,11 +50,17 @@ type TranscriptionsInterface interface { TranscriptionsAction(request *types.AudioRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) } +// 语音翻译接口 type TranslationInterface interface { ProviderInterface TranslationAction(request *types.AudioRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) } +type ImageGenerationsInterface interface { + ProviderInterface + ImageGenerationsAction(request *types.ImageRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) +} + // 余额接口 type BalanceInterface interface { BalanceAction(channel *model.Channel) (float64, error) diff --git a/providers/openai/base.go b/providers/openai/base.go index e10bed50..cd34edb9 100644 --- a/providers/openai/base.go +++ b/providers/openai/base.go @@ -38,6 +38,9 @@ func CreateOpenAIProvider(c *gin.Context, baseURL string) *OpenAIProvider { AudioSpeech: "/v1/audio/speech", AudioTranscriptions: "/v1/audio/transcriptions", AudioTranslations: "/v1/audio/translations", + ImagesGenerations: "/v1/images/generations", + ImagesEdit: "/v1/images/edit", + ImagesVariations: "/v1/images/variations", Context: c, }, IsAzure: false, @@ -50,7 +53,13 @@ func (p *OpenAIProvider) GetFullRequestURL(requestURL string, modelName string) if p.IsAzure { apiVersion := p.Context.GetString("api_version") - requestURL = fmt.Sprintf("/openai/deployments/%s%s?api-version=%s", modelName, requestURL, apiVersion) + if modelName == "dall-e-2" { + // 因为dall-e-3需要api-version=2023-12-01-preview,但是该版本 + // 已经没有dall-e-2了,所以暂时写死 + requestURL = fmt.Sprintf("/openai/%s:submit?api-version=2023-09-01-preview", requestURL) + } else { + requestURL = fmt.Sprintf("/openai/deployments/%s%s?api-version=%s", modelName, requestURL, apiVersion) + } } if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") { @@ -78,7 +87,7 @@ func (p *OpenAIProvider) GetRequestHeaders() (headers map[string]string) { } // 获取请求体 -func (p *OpenAIProvider) getRequestBody(request any, isModelMapped bool) (requestBody io.Reader, err error) { +func (p *OpenAIProvider) GetRequestBody(request any, isModelMapped bool) (requestBody io.Reader, err error) { if isModelMapped { jsonStr, err := json.Marshal(request) if err != nil { diff --git a/providers/openai/chat.go b/providers/openai/chat.go index 41905f92..981c07c6 100644 --- a/providers/openai/chat.go +++ b/providers/openai/chat.go @@ -26,7 +26,7 @@ func (c *OpenAIProviderChatStreamResponse) responseStreamHandler() (responseText } func (p *OpenAIProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { - requestBody, err := p.getRequestBody(&request, isModelMapped) + requestBody, err := p.GetRequestBody(&request, isModelMapped) if err != nil { return nil, types.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError) } diff --git a/providers/openai/completion.go b/providers/openai/completion.go index cee0256b..8e935da0 100644 --- a/providers/openai/completion.go +++ b/providers/openai/completion.go @@ -26,7 +26,7 @@ func (c *OpenAIProviderCompletionResponse) responseStreamHandler() (responseText } func (p *OpenAIProvider) CompleteAction(request *types.CompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { - requestBody, err := p.getRequestBody(&request, isModelMapped) + requestBody, err := p.GetRequestBody(&request, isModelMapped) if err != nil { return nil, types.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError) } diff --git a/providers/openai/embeddings.go b/providers/openai/embeddings.go index 6cc48d7f..67354d26 100644 --- a/providers/openai/embeddings.go +++ b/providers/openai/embeddings.go @@ -19,7 +19,7 @@ func (c *OpenAIProviderEmbeddingsResponse) ResponseHandler(resp *http.Response) func (p *OpenAIProvider) EmbeddingsAction(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { - requestBody, err := p.getRequestBody(&request, isModelMapped) + requestBody, err := p.GetRequestBody(&request, isModelMapped) if err != nil { return nil, types.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError) } diff --git a/providers/openai/image_generations.go b/providers/openai/image_generations.go new file mode 100644 index 00000000..f63fb337 --- /dev/null +++ b/providers/openai/image_generations.go @@ -0,0 +1,49 @@ +package openai + +import ( + "net/http" + "one-api/common" + "one-api/types" +) + +func (c *OpenAIProviderImageResponseResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) { + if c.Error.Type != "" { + errWithCode = &types.OpenAIErrorWithStatusCode{ + OpenAIError: c.Error, + StatusCode: resp.StatusCode, + } + return + } + return nil, nil +} + +func (p *OpenAIProvider) ImageGenerationsAction(request *types.ImageRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { + + requestBody, err := p.GetRequestBody(&request, isModelMapped) + if err != nil { + return nil, types.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError) + } + + fullRequestURL := p.GetFullRequestURL(p.ImagesGenerations, request.Model) + headers := p.GetRequestHeaders() + + client := common.NewClient() + req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers)) + if err != nil { + return nil, types.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) + } + + openAIProviderImageResponseResponse := &OpenAIProviderImageResponseResponse{} + errWithCode = p.SendRequest(req, openAIProviderImageResponseResponse, true) + if errWithCode != nil { + return + } + + usage = &types.Usage{ + PromptTokens: promptTokens, + CompletionTokens: 0, + TotalTokens: promptTokens, + } + + return +} diff --git a/providers/openai/moderation.go b/providers/openai/moderation.go index 2eceb12d..a49641e8 100644 --- a/providers/openai/moderation.go +++ b/providers/openai/moderation.go @@ -19,7 +19,7 @@ func (c *OpenAIProviderModerationResponse) ResponseHandler(resp *http.Response) func (p *OpenAIProvider) ModerationAction(request *types.ModerationRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { - requestBody, err := p.getRequestBody(&request, isModelMapped) + requestBody, err := p.GetRequestBody(&request, isModelMapped) if err != nil { return nil, types.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError) } diff --git a/providers/openai/speech.go b/providers/openai/speech.go index 4968387f..410c6c74 100644 --- a/providers/openai/speech.go +++ b/providers/openai/speech.go @@ -8,7 +8,7 @@ import ( func (p *OpenAIProvider) SpeechAction(request *types.SpeechAudioRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { - requestBody, err := p.getRequestBody(&request, isModelMapped) + requestBody, err := p.GetRequestBody(&request, isModelMapped) if err != nil { return nil, types.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError) } diff --git a/providers/openai/type.go b/providers/openai/type.go index 85051847..7153739f 100644 --- a/providers/openai/type.go +++ b/providers/openai/type.go @@ -37,3 +37,8 @@ type OpenAIProviderTranscriptionsTextResponse string func (a *OpenAIProviderTranscriptionsTextResponse) GetString() *string { return (*string)(a) } + +type OpenAIProviderImageResponseResponse struct { + types.ImageResponse + types.OpenAIErrorResponse +} diff --git a/types/image.go b/types/image.go index a3254769..6cb07887 100644 --- a/types/image.go +++ b/types/image.go @@ -1,5 +1,7 @@ package types +import "mime/multipart" + type ImageRequest struct { Prompt string `json:"prompt,omitempty"` Model string `json:"model,omitempty"` @@ -21,3 +23,14 @@ type ImageResponseDataInner struct { B64JSON string `json:"b64_json,omitempty"` RevisedPrompt string `json:"revised_prompt,omitempty"` } + +type ImageEditRequest struct { + Image *multipart.FileHeader `form:"image"` + Mask *multipart.FileHeader `form:"mask"` + Model string `form:"model"` + Prompt string `form:"prompt"` + N int `form:"n"` + Size string `form:"size"` + ResponseFormat string `form:"response_format"` + User string `form:"user"` +}