diff --git a/controller/relay-image.go b/controller/relay-image.go index 1d1b71ba..ed092679 100644 --- a/controller/relay-image.go +++ b/controller/relay-image.go @@ -10,6 +10,7 @@ import ( "net/http" "one-api/common" "one-api/model" + "strings" "github.com/gin-gonic/gin" ) @@ -32,6 +33,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode tokenId := c.GetInt("token_id") channelType := c.GetInt("channel") channelId := c.GetInt("channel_id") + apiVersion := c.GetString("api_version") userId := c.GetInt("id") consumeQuota := c.GetBool("consume_quota") group := c.GetString("group") @@ -104,8 +106,15 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode baseURL = c.GetString("base_url") } fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType) + + // make Azure channel requestURL + if channelType == common.ChannelTypeAzure { + // url https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2023-06-01-preview + fullRequestURL = fmt.Sprintf("%s/%s%s?api-version=%s", fullRequestURL, imageModel, "/images/generations", apiVersion) + } + var requestBody io.Reader - if isModelMapped { + if isModelMapped || channelType == common.ChannelTypeAzure { // make Azure channel request body jsonStr, err := json.Marshal(imageRequest) if err != nil { return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) @@ -130,7 +139,13 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode if err != nil { return errorWrapper(err, "new_request_failed", http.StatusInternalServerError) } - req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) + token := c.Request.Header.Get("Authorization") + if channelType == common.ChannelTypeAzure { // Azure authentication + token = strings.TrimPrefix(token, "Bearer ") + req.Header.Set("api-key", token) + } else { + req.Header.Set("Authorization", token) + } req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Header.Set("Accept", c.Request.Header.Get("Accept")) diff --git a/controller/relay-utils.go b/controller/relay-utils.go index c7cd4766..4765fdc1 100644 --- a/controller/relay-utils.go +++ b/controller/relay-utils.go @@ -192,6 +192,11 @@ func getFullRequestURL(baseURL string, requestURL string, channelType int) strin } } + // Azure dall-e model compatibility + if channelType == common.ChannelTypeAzure && requestURL == "/v1/images/generations" { + fullRequestURL = fmt.Sprintf("%s%s", baseURL, "/openai/deployments") + } + return fullRequestURL }