diff --git a/controller/channel-test.go b/controller/channel-test.go index 1123af69..ae4ab7d5 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -9,7 +9,6 @@ import ( "one-api/common" "one-api/model" "strconv" - "strings" "sync" "time" @@ -42,25 +41,21 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai default: request.Model = "gpt-3.5-turbo" } - - baseURL := common.ChannelBaseURLs[channel.Type] - requestURL := "/v1/chat/completions" // 这是原始的请求URL路径 - if channel.GetBaseURL() != "" { - baseURL = channel.GetBaseURL() - } - - // 构建 fullRequestURL - fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) - if channel.Type == common.ChannelTypeOpenAI { - if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") { - fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1")) + requestURL := common.ChannelBaseURLs[channel.Type] + if channel.Type == common.ChannelTypeAzure { + requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.GetBaseURL(), request.Model) + } else { + if channel.GetBaseURL() != "" { + requestURL = channel.GetBaseURL() } + requestURL += "/v1/chat/completions" } + jsonData, err := json.Marshal(request) if err != nil { return err, nil } - req, err := http.NewRequest("POST", fullRequestURL, bytes.NewBuffer(jsonData)) // 使用 fullRequestURL 替换 requestURL + req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(jsonData)) if err != nil { return err, nil } diff --git a/controller/relay-audio.go b/controller/relay-audio.go index 5beea15e..53833108 100644 --- a/controller/relay-audio.go +++ b/controller/relay-audio.go @@ -6,13 +6,11 @@ import ( "encoding/json" "errors" "fmt" + "github.com/gin-gonic/gin" "io" "net/http" "one-api/common" "one-api/model" - "strings" - - "github.com/gin-gonic/gin" ) func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { @@ -71,12 +69,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode baseURL = c.GetString("base_url") } - fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) - if channelType == common.ChannelTypeOpenAI { - if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") { - fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1")) - } - } + fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType) requestBody := c.Request.Body req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) diff --git a/controller/relay-image.go b/controller/relay-image.go index f297e0b7..ccd52dce 100644 --- a/controller/relay-image.go +++ b/controller/relay-image.go @@ -6,13 +6,11 @@ import ( "encoding/json" "errors" "fmt" + "github.com/gin-gonic/gin" "io" "net/http" "one-api/common" "one-api/model" - "strings" - - "github.com/gin-gonic/gin" ) func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { @@ -67,12 +65,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode if c.GetString("base_url") != "" { baseURL = c.GetString("base_url") } - fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) - if channelType == common.ChannelTypeOpenAI { - if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") { - fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1")) - } - } + fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType) var requestBody io.Reader if isModelMapped { jsonStr, err := json.Marshal(imageRequest) diff --git a/controller/relay-text.go b/controller/relay-text.go index db1ec3a2..109cf5a8 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -118,12 +118,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { if c.GetString("base_url") != "" { baseURL = c.GetString("base_url") } - fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) - if channelType == common.ChannelTypeOpenAI { - if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") { - fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1")) - } - } + fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType) switch apiType { case APITypeOpenAI: if channelType == common.ChannelTypeAzure { diff --git a/controller/relay-utils.go b/controller/relay-utils.go index 4775ec88..cf5d9b69 100644 --- a/controller/relay-utils.go +++ b/controller/relay-utils.go @@ -176,3 +176,13 @@ func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIEr openAIErrorWithStatusCode.OpenAIError = textResponse.Error return } + +func getFullRequestURL(baseURL string, requestURL string, channelType int) string { + fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) + if channelType == common.ChannelTypeOpenAI { + if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") { + fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1")) + } + } + return fullRequestURL +}