diff --git a/controller/relay-image.go b/controller/relay-image.go index 1d1b71ba..6c590e37 100644 --- a/controller/relay-image.go +++ b/controller/relay-image.go @@ -10,6 +10,7 @@ import ( "net/http" "one-api/common" "one-api/model" + "strings" "github.com/gin-gonic/gin" ) @@ -104,6 +105,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode baseURL = c.GetString("base_url") } fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType) + var requestBody io.Reader if isModelMapped { jsonStr, err := json.Marshal(imageRequest) @@ -127,10 +129,36 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode } req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) + + if channelType == common.ChannelTypeAzure { + query := c.Request.URL.Query() + apiVersion := query.Get("api-version") + if apiVersion == "" { + apiVersion = c.GetString("api_version") + } + + params := fmt.Sprintf("?api-version=%s", apiVersion) + baseURL = c.GetString("base_url") + + if imageModel == "dall-e-2" { + fullRequestURL = fmt.Sprintf("%s/openai/images/generations:submit%s", baseURL, params) + } else { + fullRequestURL = fmt.Sprintf("%s/openai/deployments/dalle3/images/generations%s", baseURL, params) + } + + req, err = http.NewRequest(c.Request.Method, fullRequestURL, requestBody) + + apiKey := c.Request.Header.Get("Authorization") + apiKey = strings.TrimPrefix(apiKey, "Bearer ") + + req.Header.Set("api-key", apiKey) + } else { + req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) + } + if err != nil { return errorWrapper(err, "new_request_failed", http.StatusInternalServerError) } - req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Header.Set("Accept", c.Request.Header.Get("Accept")) @@ -148,7 +176,50 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode if err != nil { return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) } - var textResponse ImageResponse + + var textResponse OpenAIImageResponse + + contentLength := resp.ContentLength + + if consumeQuota { + responseBody, err := io.ReadAll(resp.Body) + + if err != nil { + return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) + } + err = resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) + } + + if channelType == common.ChannelTypeAzure && imageModel == "dall-e-2" { + var azureDalle2Response AzureDalle2Response + err = json.Unmarshal(responseBody, &azureDalle2Response) + + if err != nil { + return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) + } + + // Rerite response body + textResponse = OpenAIImageResponse{ + Created: azureDalle2Response.Created, + Data: azureDalle2Response.Result.Data, + } + + responseBody, err = json.Marshal(textResponse) + + // Fix transfer closed with ... bytes remaining to read + contentLength = int64(len(responseBody)) + } else { + err = json.Unmarshal(responseBody, &textResponse) + } + + if err != nil { + return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) + } + + resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) + } defer func(ctx context.Context) { if consumeQuota { @@ -171,28 +242,11 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode } }(c.Request.Context()) - if consumeQuota { - responseBody, err := io.ReadAll(resp.Body) - - if err != nil { - return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) - } - err = resp.Body.Close() - if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) - } - err = json.Unmarshal(responseBody, &textResponse) - if err != nil { - return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) - } - - resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) - } - for k, v := range resp.Header { c.Writer.Header().Set(k, v[0]) } c.Writer.WriteHeader(resp.StatusCode) + c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", contentLength)) _, err = io.Copy(c.Writer, resp.Body) if err != nil { diff --git a/controller/relay.go b/controller/relay.go index 863267b4..3761085b 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -154,13 +154,27 @@ type OpenAIEmbeddingResponse struct { Usage `json:"usage"` } -type ImageResponse struct { +type OpenAIImageResponse struct { Created int `json:"created"` Data []struct { Url string `json:"url"` + } `json:"data"` +} + +type AzureDalle2ResultData struct { + Data []struct { + Url string `json:"url"` } } +type AzureDalle2Response struct { + Created int `json:"created"` + Expires int `json:"expires"` + ID string `json:"id"` + Result AzureDalle2ResultData `json:"result"` + Status string `json:"status"` +} + type ChatCompletionsStreamResponseChoice struct { Delta struct { Content string `json:"content"`