feat: support cloudflare gateway for azure (#666)

* 🐛 Fix cloudflare gateway request failure

* 🐛 fix channel test url error
This commit is contained in:
Buer 2023-11-19 15:52:35 +08:00 committed by GitHub
parent 34d517cfa2
commit 54e5f8ecd2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 17 additions and 8 deletions

View File

@ -5,7 +5,6 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/gin-gonic/gin"
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
@ -13,6 +12,8 @@ import (
"strconv" "strconv"
"sync" "sync"
"time" "time"
"github.com/gin-gonic/gin"
) )
func testChannel(channel *model.Channel, request ChatRequest) (err error, openaiErr *OpenAIError) { func testChannel(channel *model.Channel, request ChatRequest) (err error, openaiErr *OpenAIError) {
@ -43,14 +44,14 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai
} }
requestURL := common.ChannelBaseURLs[channel.Type] requestURL := common.ChannelBaseURLs[channel.Type]
if channel.Type == common.ChannelTypeAzure { if channel.Type == common.ChannelTypeAzure {
requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.GetBaseURL(), request.Model) requestURL = getFullRequestURL(channel.GetBaseURL(), fmt.Sprintf("/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", request.Model), channel.Type)
} else { } else {
if baseURL := channel.GetBaseURL(); len(baseURL) > 0 { if baseURL := channel.GetBaseURL(); len(baseURL) > 0 {
requestURL = baseURL requestURL = baseURL
} }
requestURL = getFullRequestURL(requestURL, "/v1/chat/completions", channel.Type) requestURL = getFullRequestURL(requestURL, "/v1/chat/completions", channel.Type)
} }
jsonData, err := json.Marshal(request) jsonData, err := json.Marshal(request)
if err != nil { if err != nil {
return err, nil return err, nil

View File

@ -147,7 +147,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
model_ = strings.TrimSuffix(model_, "-0301") model_ = strings.TrimSuffix(model_, "-0301")
model_ = strings.TrimSuffix(model_, "-0314") model_ = strings.TrimSuffix(model_, "-0314")
model_ = strings.TrimSuffix(model_, "-0613") model_ = strings.TrimSuffix(model_, "-0613")
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task)
requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
fullRequestURL = getFullRequestURL(baseURL, requestURL, channelType)
} }
case APITypeClaude: case APITypeClaude:
fullRequestURL = "https://api.anthropic.com/v1/complete" fullRequestURL = "https://api.anthropic.com/v1/complete"

View File

@ -4,14 +4,15 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/gin-gonic/gin"
"github.com/pkoukk/tiktoken-go"
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/model" "one-api/model"
"strconv" "strconv"
"strings" "strings"
"github.com/gin-gonic/gin"
"github.com/pkoukk/tiktoken-go"
) )
var stopFinishReason = "stop" var stopFinishReason = "stop"
@ -181,11 +182,16 @@ func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIEr
func getFullRequestURL(baseURL string, requestURL string, channelType int) string { func getFullRequestURL(baseURL string, requestURL string, channelType int) string {
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
if channelType == common.ChannelTypeOpenAI {
if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") { if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
switch channelType {
case common.ChannelTypeOpenAI:
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1")) fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
case common.ChannelTypeAzure:
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments"))
} }
} }
return fullRequestURL return fullRequestURL
} }