chore: using a util function
This commit is contained in:
parent
5135c9f711
commit
fed5cdba5b
@ -9,7 +9,6 @@ import (
|
|||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -42,25 +41,21 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai
|
|||||||
default:
|
default:
|
||||||
request.Model = "gpt-3.5-turbo"
|
request.Model = "gpt-3.5-turbo"
|
||||||
}
|
}
|
||||||
|
requestURL := common.ChannelBaseURLs[channel.Type]
|
||||||
baseURL := common.ChannelBaseURLs[channel.Type]
|
if channel.Type == common.ChannelTypeAzure {
|
||||||
requestURL := "/v1/chat/completions" // 这是原始的请求URL路径
|
requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.GetBaseURL(), request.Model)
|
||||||
|
} else {
|
||||||
if channel.GetBaseURL() != "" {
|
if channel.GetBaseURL() != "" {
|
||||||
baseURL = channel.GetBaseURL()
|
requestURL = channel.GetBaseURL()
|
||||||
|
}
|
||||||
|
requestURL += "/v1/chat/completions"
|
||||||
}
|
}
|
||||||
|
|
||||||
// 构建 fullRequestURL
|
|
||||||
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
|
|
||||||
if channel.Type == common.ChannelTypeOpenAI {
|
|
||||||
if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
|
|
||||||
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
jsonData, err := json.Marshal(request)
|
jsonData, err := json.Marshal(request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err, nil
|
return err, nil
|
||||||
}
|
}
|
||||||
req, err := http.NewRequest("POST", fullRequestURL, bytes.NewBuffer(jsonData)) // 使用 fullRequestURL 替换 requestURL
|
req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(jsonData))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err, nil
|
return err, nil
|
||||||
}
|
}
|
||||||
|
@ -6,13 +6,11 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||||
@ -71,12 +69,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
|
|||||||
baseURL = c.GetString("base_url")
|
baseURL = c.GetString("base_url")
|
||||||
}
|
}
|
||||||
|
|
||||||
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
|
fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
|
||||||
if channelType == common.ChannelTypeOpenAI {
|
|
||||||
if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
|
|
||||||
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
requestBody := c.Request.Body
|
requestBody := c.Request.Body
|
||||||
|
|
||||||
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
||||||
|
@ -6,13 +6,11 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||||
@ -67,12 +65,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
|
|||||||
if c.GetString("base_url") != "" {
|
if c.GetString("base_url") != "" {
|
||||||
baseURL = c.GetString("base_url")
|
baseURL = c.GetString("base_url")
|
||||||
}
|
}
|
||||||
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
|
fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
|
||||||
if channelType == common.ChannelTypeOpenAI {
|
|
||||||
if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
|
|
||||||
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
var requestBody io.Reader
|
var requestBody io.Reader
|
||||||
if isModelMapped {
|
if isModelMapped {
|
||||||
jsonStr, err := json.Marshal(imageRequest)
|
jsonStr, err := json.Marshal(imageRequest)
|
||||||
|
@ -118,12 +118,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
if c.GetString("base_url") != "" {
|
if c.GetString("base_url") != "" {
|
||||||
baseURL = c.GetString("base_url")
|
baseURL = c.GetString("base_url")
|
||||||
}
|
}
|
||||||
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
|
fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
|
||||||
if channelType == common.ChannelTypeOpenAI {
|
|
||||||
if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
|
|
||||||
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
switch apiType {
|
switch apiType {
|
||||||
case APITypeOpenAI:
|
case APITypeOpenAI:
|
||||||
if channelType == common.ChannelTypeAzure {
|
if channelType == common.ChannelTypeAzure {
|
||||||
|
@ -176,3 +176,13 @@ func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIEr
|
|||||||
openAIErrorWithStatusCode.OpenAIError = textResponse.Error
|
openAIErrorWithStatusCode.OpenAIError = textResponse.Error
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getFullRequestURL(baseURL string, requestURL string, channelType int) string {
|
||||||
|
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
|
||||||
|
if channelType == common.ChannelTypeOpenAI {
|
||||||
|
if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
|
||||||
|
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return fullRequestURL
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user