From b7570d5c772f5e4ba8aac4441f8ab9ec86b7aa55 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?ShinChven=20=E2=9C=A8?= Date: Sun, 3 Dec 2023 17:34:59 +0800 Subject: [PATCH] feat: support dalle for Azure (#754) * feat: Add Message-ID to email headers to comply with RFC 5322 - Extract domain from SMTPFrom - Generate a unique Message-ID - Add Message-ID to email headers * chore: check slice length * feat: Add Azure compatibility for relayImageHelper - Handle Azure channel requestURL compatibility - Set api-key header for Azure channel authentication - Handle Azure channel request body fixes: https://github.com/songquanpeng/one-api/issues/751 * refactor: update implementation --------- Co-authored-by: JustSong --- controller/relay-audio.go | 7 +------ controller/relay-image.go | 18 ++++++++++++++++-- controller/relay-text.go | 6 +----- controller/relay-utils.go | 10 +++++++++- 4 files changed, 27 insertions(+), 14 deletions(-) diff --git a/controller/relay-audio.go b/controller/relay-audio.go index 5b8898a7..9e78dadc 100644 --- a/controller/relay-audio.go +++ b/controller/relay-audio.go @@ -98,12 +98,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType) if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure { // https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api - query := c.Request.URL.Query() - apiVersion := query.Get("api-version") - if apiVersion == "" { - apiVersion = c.GetString("api_version") - } - baseURL = c.GetString("base_url") + apiVersion := GetAPIVersion(c) fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioModel, apiVersion) } diff --git a/controller/relay-image.go b/controller/relay-image.go index 0ff18309..b3248fcc 100644 --- a/controller/relay-image.go +++ b/controller/relay-image.go @@ -10,6 +10,7 @@ import ( "net/http" "one-api/common" "one-api/model" + "strings" "github.com/gin-gonic/gin" ) @@ -101,8 +102,15 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode baseURL = c.GetString("base_url") } fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType) + if channelType == common.ChannelTypeAzure && relayMode == RelayModeImagesGenerations { + // https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api + apiVersion := GetAPIVersion(c) + // https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2023-06-01-preview + fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", baseURL, imageModel, apiVersion) + } + var requestBody io.Reader - if isModelMapped { + if isModelMapped || channelType == common.ChannelTypeAzure { // make Azure channel request body jsonStr, err := json.Marshal(imageRequest) if err != nil { return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) @@ -127,7 +135,13 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode if err != nil { return errorWrapper(err, "new_request_failed", http.StatusInternalServerError) } - req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) + token := c.Request.Header.Get("Authorization") + if channelType == common.ChannelTypeAzure { // Azure authentication + token = strings.TrimPrefix(token, "Bearer ") + req.Header.Set("api-key", token) + } else { + req.Header.Set("Authorization", token) + } req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Header.Set("Accept", c.Request.Header.Get("Accept")) diff --git a/controller/relay-text.go b/controller/relay-text.go index dd9e7153..a3e233d3 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -129,11 +129,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { case APITypeOpenAI: if channelType == common.ChannelTypeAzure { // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api - query := c.Request.URL.Query() - apiVersion := query.Get("api-version") - if apiVersion == "" { - apiVersion = c.GetString("api_version") - } + apiVersion := GetAPIVersion(c) requestURL := strings.Split(requestURL, "?")[0] requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion) baseURL = c.GetString("base_url") diff --git a/controller/relay-utils.go b/controller/relay-utils.go index 391f28b4..839d6ae5 100644 --- a/controller/relay-utils.go +++ b/controller/relay-utils.go @@ -191,7 +191,6 @@ func getFullRequestURL(baseURL string, requestURL string, channelType int) strin fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments")) } } - return fullRequestURL } @@ -216,3 +215,12 @@ func postConsumeQuota(ctx context.Context, tokenId int, quotaDelta int, totalQuo common.LogError(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota)) } } + +func GetAPIVersion(c *gin.Context) string { + query := c.Request.URL.Query() + apiVersion := query.Get("api-version") + if apiVersion == "" { + apiVersion = c.GetString("api_version") + } + return apiVersion +}