feat: add cloudflare ai gateway support for image & audio (#607)

* Update channel-test.go

* Update relay-audio.go

* Update relay-image.go

* chore: using a util function

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
This commit is contained in:
vc 2023-10-22 17:50:52 +08:00 committed by GitHub
parent 22980b4c44
commit 3b483639a4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 17 additions and 18 deletions

View File

@ -5,13 +5,14 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/gin-gonic/gin"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/model" "one-api/model"
"strconv" "strconv"
"sync" "sync"
"time" "time"
"github.com/gin-gonic/gin"
) )
func testChannel(channel *model.Channel, request ChatRequest) (err error, openaiErr *OpenAIError) { func testChannel(channel *model.Channel, request ChatRequest) (err error, openaiErr *OpenAIError) {

View File

@ -6,12 +6,11 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/gin-gonic/gin"
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/model" "one-api/model"
"github.com/gin-gonic/gin"
) )
func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
@ -66,12 +65,11 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
baseURL := common.ChannelBaseURLs[channelType] baseURL := common.ChannelBaseURLs[channelType]
requestURL := c.Request.URL.String() requestURL := c.Request.URL.String()
if c.GetString("base_url") != "" { if c.GetString("base_url") != "" {
baseURL = c.GetString("base_url") baseURL = c.GetString("base_url")
} }
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
requestBody := c.Request.Body requestBody := c.Request.Body
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)

View File

@ -6,12 +6,11 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/gin-gonic/gin"
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/model" "one-api/model"
"github.com/gin-gonic/gin"
) )
func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
@ -61,16 +60,12 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
isModelMapped = true isModelMapped = true
} }
} }
baseURL := common.ChannelBaseURLs[channelType] baseURL := common.ChannelBaseURLs[channelType]
requestURL := c.Request.URL.String() requestURL := c.Request.URL.String()
if c.GetString("base_url") != "" { if c.GetString("base_url") != "" {
baseURL = c.GetString("base_url") baseURL = c.GetString("base_url")
} }
fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
var requestBody io.Reader var requestBody io.Reader
if isModelMapped { if isModelMapped {
jsonStr, err := json.Marshal(imageRequest) jsonStr, err := json.Marshal(imageRequest)

View File

@ -118,12 +118,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
if c.GetString("base_url") != "" { if c.GetString("base_url") != "" {
baseURL = c.GetString("base_url") baseURL = c.GetString("base_url")
} }
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
if channelType == common.ChannelTypeOpenAI {
if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
}
}
switch apiType { switch apiType {
case APITypeOpenAI: case APITypeOpenAI:
if channelType == common.ChannelTypeAzure { if channelType == common.ChannelTypeAzure {

View File

@ -176,3 +176,13 @@ func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIEr
openAIErrorWithStatusCode.OpenAIError = textResponse.Error openAIErrorWithStatusCode.OpenAIError = textResponse.Error
return return
} }
func getFullRequestURL(baseURL string, requestURL string, channelType int) string {
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
if channelType == common.ChannelTypeOpenAI {
if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
}
}
return fullRequestURL
}