diff --git a/controller/channel-test.go b/controller/channel-test.go index 3c6c8f43..c5d4d26c 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -5,14 +5,14 @@ import ( "encoding/json" "errors" "fmt" - "github.com/gin-gonic/gin" "net/http" "one-api/common" "one-api/model" "strconv" - "strings" "sync" "time" + + "github.com/gin-gonic/gin" ) func testChannel(channel *model.Channel, request ChatRequest) (err error, openaiErr *OpenAIError) { @@ -43,16 +43,14 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai } requestURL := common.ChannelBaseURLs[channel.Type] if channel.Type == common.ChannelTypeAzure { - requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.GetBaseURL(), request.Model) + requestURL = getFullRequestURL(channel.GetBaseURL(), fmt.Sprintf("/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", request.Model), channel.Type) } else { if channel.GetBaseURL() != "" { requestURL = channel.GetBaseURL() } - requestURL += "/v1/chat/completions" - } - // for Cloudflare AI gateway: https://github.com/songquanpeng/one-api/pull/639 - requestURL = strings.Replace(requestURL, "/v1/v1", "/v1", 1) + requestURL = getFullRequestURL(requestURL, "/v1/chat/completions", channel.Type) + } jsonData, err := json.Marshal(request) if err != nil { return err, nil diff --git a/controller/relay-text.go b/controller/relay-text.go index a61c6f7c..4806229e 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -147,7 +147,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { model_ = strings.TrimSuffix(model_, "-0301") model_ = strings.TrimSuffix(model_, "-0314") model_ = strings.TrimSuffix(model_, "-0613") - fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task) + + requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task) + fullRequestURL = getFullRequestURL(baseURL, requestURL, channelType) } case APITypeClaude: fullRequestURL = "https://api.anthropic.com/v1/complete" diff --git a/controller/relay-utils.go b/controller/relay-utils.go index cf5d9b69..e888efa7 100644 --- a/controller/relay-utils.go +++ b/controller/relay-utils.go @@ -3,13 +3,14 @@ package controller import ( "encoding/json" "fmt" - "github.com/gin-gonic/gin" - "github.com/pkoukk/tiktoken-go" "io" "net/http" "one-api/common" "strconv" "strings" + + "github.com/gin-gonic/gin" + "github.com/pkoukk/tiktoken-go" ) var stopFinishReason = "stop" @@ -179,10 +180,15 @@ func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIEr func getFullRequestURL(baseURL string, requestURL string, channelType int) string { fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) - if channelType == common.ChannelTypeOpenAI { - if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") { + + if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") { + switch channelType { + case common.ChannelTypeOpenAI: fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1")) + case common.ChannelTypeAzure: + fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments")) } } + return fullRequestURL }