refactor: update implementation
This commit is contained in:
parent
aa904af903
commit
a64dc6a909
@ -98,12 +98,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
|
|||||||
fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
|
fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
|
||||||
if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure {
|
if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure {
|
||||||
// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
|
// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
|
||||||
query := c.Request.URL.Query()
|
apiVersion := GetAPIVersion(c)
|
||||||
apiVersion := query.Get("api-version")
|
|
||||||
if apiVersion == "" {
|
|
||||||
apiVersion = c.GetString("api_version")
|
|
||||||
}
|
|
||||||
baseURL = c.GetString("base_url")
|
|
||||||
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioModel, apiVersion)
|
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioModel, apiVersion)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -33,7 +33,6 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
|
|||||||
tokenId := c.GetInt("token_id")
|
tokenId := c.GetInt("token_id")
|
||||||
channelType := c.GetInt("channel")
|
channelType := c.GetInt("channel")
|
||||||
channelId := c.GetInt("channel_id")
|
channelId := c.GetInt("channel_id")
|
||||||
apiVersion := c.GetString("api_version")
|
|
||||||
userId := c.GetInt("id")
|
userId := c.GetInt("id")
|
||||||
group := c.GetString("group")
|
group := c.GetString("group")
|
||||||
|
|
||||||
@ -103,11 +102,11 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
|
|||||||
baseURL = c.GetString("base_url")
|
baseURL = c.GetString("base_url")
|
||||||
}
|
}
|
||||||
fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
|
fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
|
||||||
|
if channelType == common.ChannelTypeAzure && relayMode == RelayModeImagesGenerations {
|
||||||
// make Azure channel requestURL
|
// https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api
|
||||||
if channelType == common.ChannelTypeAzure {
|
apiVersion := GetAPIVersion(c)
|
||||||
// url https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2023-06-01-preview
|
// https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2023-06-01-preview
|
||||||
fullRequestURL = fmt.Sprintf("%s/%s%s?api-version=%s", fullRequestURL, imageModel, "/images/generations", apiVersion)
|
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", baseURL, imageModel, apiVersion)
|
||||||
}
|
}
|
||||||
|
|
||||||
var requestBody io.Reader
|
var requestBody io.Reader
|
||||||
|
@ -129,11 +129,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
case APITypeOpenAI:
|
case APITypeOpenAI:
|
||||||
if channelType == common.ChannelTypeAzure {
|
if channelType == common.ChannelTypeAzure {
|
||||||
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
|
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
|
||||||
query := c.Request.URL.Query()
|
apiVersion := GetAPIVersion(c)
|
||||||
apiVersion := query.Get("api-version")
|
|
||||||
if apiVersion == "" {
|
|
||||||
apiVersion = c.GetString("api_version")
|
|
||||||
}
|
|
||||||
requestURL := strings.Split(requestURL, "?")[0]
|
requestURL := strings.Split(requestURL, "?")[0]
|
||||||
requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion)
|
requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion)
|
||||||
baseURL = c.GetString("base_url")
|
baseURL = c.GetString("base_url")
|
||||||
|
@ -191,12 +191,6 @@ func getFullRequestURL(baseURL string, requestURL string, channelType int) strin
|
|||||||
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments"))
|
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Azure dall-e model compatibility
|
|
||||||
if channelType == common.ChannelTypeAzure && requestURL == "/v1/images/generations" {
|
|
||||||
fullRequestURL = fmt.Sprintf("%s%s", baseURL, "/openai/deployments")
|
|
||||||
}
|
|
||||||
|
|
||||||
return fullRequestURL
|
return fullRequestURL
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -221,3 +215,12 @@ func postConsumeQuota(ctx context.Context, tokenId int, quotaDelta int, totalQuo
|
|||||||
common.LogError(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota))
|
common.LogError(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetAPIVersion(c *gin.Context) string {
|
||||||
|
query := c.Request.URL.Query()
|
||||||
|
apiVersion := query.Get("api-version")
|
||||||
|
if apiVersion == "" {
|
||||||
|
apiVersion = c.GetString("api_version")
|
||||||
|
}
|
||||||
|
return apiVersion
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user