diff --git a/controller/channel-test.go b/controller/channel-test.go index 1974ef6e..ae4ab7d5 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -5,13 +5,14 @@ import ( "encoding/json" "errors" "fmt" - "github.com/gin-gonic/gin" "net/http" "one-api/common" "one-api/model" "strconv" "sync" "time" + + "github.com/gin-gonic/gin" ) func testChannel(channel *model.Channel, request ChatRequest) (err error, openaiErr *OpenAIError) { diff --git a/controller/relay-audio.go b/controller/relay-audio.go index 381c6feb..53833108 100644 --- a/controller/relay-audio.go +++ b/controller/relay-audio.go @@ -6,12 +6,11 @@ import ( "encoding/json" "errors" "fmt" + "github.com/gin-gonic/gin" "io" "net/http" "one-api/common" "one-api/model" - - "github.com/gin-gonic/gin" ) func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { @@ -66,12 +65,11 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode baseURL := common.ChannelBaseURLs[channelType] requestURL := c.Request.URL.String() - if c.GetString("base_url") != "" { baseURL = c.GetString("base_url") } - fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) + fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType) requestBody := c.Request.Body req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) diff --git a/controller/relay-image.go b/controller/relay-image.go index 998a7851..ccd52dce 100644 --- a/controller/relay-image.go +++ b/controller/relay-image.go @@ -6,12 +6,11 @@ import ( "encoding/json" "errors" "fmt" + "github.com/gin-gonic/gin" "io" "net/http" "one-api/common" "one-api/model" - - "github.com/gin-gonic/gin" ) func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { @@ -61,16 +60,12 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode isModelMapped = true } } - baseURL := common.ChannelBaseURLs[channelType] requestURL := c.Request.URL.String() - if c.GetString("base_url") != "" { baseURL = c.GetString("base_url") } - - fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) - + fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType) var requestBody io.Reader if isModelMapped { jsonStr, err := json.Marshal(imageRequest) diff --git a/controller/relay-text.go b/controller/relay-text.go index db1ec3a2..109cf5a8 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -118,12 +118,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { if c.GetString("base_url") != "" { baseURL = c.GetString("base_url") } - fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) - if channelType == common.ChannelTypeOpenAI { - if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") { - fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1")) - } - } + fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType) switch apiType { case APITypeOpenAI: if channelType == common.ChannelTypeAzure { diff --git a/controller/relay-utils.go b/controller/relay-utils.go index 4775ec88..cf5d9b69 100644 --- a/controller/relay-utils.go +++ b/controller/relay-utils.go @@ -176,3 +176,13 @@ func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIEr openAIErrorWithStatusCode.OpenAIError = textResponse.Error return } + +func getFullRequestURL(baseURL string, requestURL string, channelType int) string { + fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) + if channelType == common.ChannelTypeOpenAI { + if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") { + fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1")) + } + } + return fullRequestURL +}