refactor: update implementation

This commit is contained in:
JustSong 2023-12-03 17:33:24 +08:00
parent aa904af903
commit a64dc6a909
4 changed files with 16 additions and 23 deletions

View File

@ -98,12 +98,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType) fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure { if relayMode == RelayModeAudioTranscription && channelType == common.ChannelTypeAzure {
// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api // https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
query := c.Request.URL.Query() apiVersion := GetAPIVersion(c)
apiVersion := query.Get("api-version")
if apiVersion == "" {
apiVersion = c.GetString("api_version")
}
baseURL = c.GetString("base_url")
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioModel, apiVersion) fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioModel, apiVersion)
} }

View File

@ -33,7 +33,6 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
tokenId := c.GetInt("token_id") tokenId := c.GetInt("token_id")
channelType := c.GetInt("channel") channelType := c.GetInt("channel")
channelId := c.GetInt("channel_id") channelId := c.GetInt("channel_id")
apiVersion := c.GetString("api_version")
userId := c.GetInt("id") userId := c.GetInt("id")
group := c.GetString("group") group := c.GetString("group")
@ -103,11 +102,11 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
baseURL = c.GetString("base_url") baseURL = c.GetString("base_url")
} }
fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType) fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
if channelType == common.ChannelTypeAzure && relayMode == RelayModeImagesGenerations {
// make Azure channel requestURL // https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api
if channelType == common.ChannelTypeAzure { apiVersion := GetAPIVersion(c)
// url https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2023-06-01-preview // https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2023-06-01-preview
fullRequestURL = fmt.Sprintf("%s/%s%s?api-version=%s", fullRequestURL, imageModel, "/images/generations", apiVersion) fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", baseURL, imageModel, apiVersion)
} }
var requestBody io.Reader var requestBody io.Reader

View File

@ -129,11 +129,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
case APITypeOpenAI: case APITypeOpenAI:
if channelType == common.ChannelTypeAzure { if channelType == common.ChannelTypeAzure {
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
query := c.Request.URL.Query() apiVersion := GetAPIVersion(c)
apiVersion := query.Get("api-version")
if apiVersion == "" {
apiVersion = c.GetString("api_version")
}
requestURL := strings.Split(requestURL, "?")[0] requestURL := strings.Split(requestURL, "?")[0]
requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion) requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion)
baseURL = c.GetString("base_url") baseURL = c.GetString("base_url")

View File

@ -191,12 +191,6 @@ func getFullRequestURL(baseURL string, requestURL string, channelType int) strin
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments")) fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments"))
} }
} }
// Azure dall-e model compatibility
if channelType == common.ChannelTypeAzure && requestURL == "/v1/images/generations" {
fullRequestURL = fmt.Sprintf("%s%s", baseURL, "/openai/deployments")
}
return fullRequestURL return fullRequestURL
} }
@ -221,3 +215,12 @@ func postConsumeQuota(ctx context.Context, tokenId int, quotaDelta int, totalQuo
common.LogError(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota)) common.LogError(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota))
} }
} }
func GetAPIVersion(c *gin.Context) string {
query := c.Request.URL.Query()
apiVersion := query.Get("api-version")
if apiVersion == "" {
apiVersion = c.GetString("api_version")
}
return apiVersion
}