diff --git a/controller/relay-audio.go b/controller/relay-audio.go index 9e78dadc..bebf26b5 100644 --- a/controller/relay-audio.go +++ b/controller/relay-audio.go @@ -1,6 +1,7 @@ package controller import ( + "bufio" "bytes" "context" "encoding/json" @@ -102,7 +103,13 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioModel, apiVersion) } - requestBody := c.Request.Body + requestBody := &bytes.Buffer{} + _, err = io.Copy(requestBody, c.Request.Body) + if err != nil { + return errorWrapper(err, "new_request_body_failed", http.StatusInternalServerError) + } + c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody.Bytes())) + responseFormat := c.DefaultPostForm("response_format", "json") req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) if err != nil { @@ -144,12 +151,33 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode if err != nil { return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) } - var whisperResponse WhisperResponse - err = json.Unmarshal(responseBody, &whisperResponse) - if err != nil { - return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) + + var openAIErr TextResponse + if err = json.Unmarshal(responseBody, &openAIErr); err == nil { + if openAIErr.Error.Message != "" { + return errorWrapper(errors.New(fmt.Sprintf("type %s, code %v, message %s", openAIErr.Error.Type, openAIErr.Error.Code, openAIErr.Error.Message)), "request_error", http.StatusInternalServerError) + } } - quota = countTokenText(whisperResponse.Text, audioModel) + + var text string + switch responseFormat { + case "json": + text, err = getTextFromJson(responseBody) + case "text": + text, err = getTextFromText(responseBody) + case "srt": + text, err = getTextFromSrt(responseBody) + case "verbose_json": + text, err = 
getTextFromVerboseJson(responseBody) + case "vtt": + text, err = getTextFromVtt(responseBody) + default: + return errorWrapper(errors.New("not_support_response_format"), "not_support_response_format", http.StatusInternalServerError) + } + if err != nil { + return errorWrapper(err, "get_text_from_body_err", http.StatusInternalServerError) + } + quota = countTokenText(text, audioModel) resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) } if resp.StatusCode != http.StatusOK { @@ -187,3 +215,48 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode } return nil } + +func getTextFromVtt(body []byte) (string, error) { + return getTextFromSrt(body) +} + +func getTextFromVerboseJson(body []byte) (string, error) { + var whisperResponse WhisperVerboseJsonResponse + if err := json.Unmarshal(body, &whisperResponse); err != nil { + return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err) + } + return whisperResponse.Text, nil +} + +func getTextFromSrt(body []byte) (string, error) { + scanner := bufio.NewScanner(strings.NewReader(string(body))) + var builder strings.Builder + var textLine bool + for scanner.Scan() { + line := scanner.Text() + if textLine { + builder.WriteString(line) + textLine = false + continue + } else if strings.Contains(line, "-->") { + textLine = true + continue + } + } + if err := scanner.Err(); err != nil { + return "", err + } + return builder.String(), nil +} + +func getTextFromText(body []byte) (string, error) { + return strings.TrimSuffix(string(body), "\n"), nil +} + +func getTextFromJson(body []byte) (string, error) { + var whisperResponse WhisperJsonResponse + if err := json.Unmarshal(body, &whisperResponse); err != nil { + return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err) + } + return whisperResponse.Text, nil +} diff --git a/controller/relay.go b/controller/relay.go index 58ee8381..9d410fe8 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -141,10 +141,31 @@ type 
// JSON shapes decoded from audio transcription responses.
type (
	// WhisperJsonResponse carries the single "text" field of a response.
	// The field is omitted when marshalling an empty value.
	WhisperJsonResponse struct {
		Text string `json:"text,omitempty"`
	}

	// WhisperVerboseJsonResponse carries the top-level fields of a verbose
	// response together with its list of segments.
	WhisperVerboseJsonResponse struct {
		Task     string    `json:"task"`
		Language string    `json:"language"`
		Duration float64   `json:"duration"`
		Text     string    `json:"text"`
		Segments []Segment `json:"segments"`
	}

	// Segment is one element of the "segments" array of a
	// WhisperVerboseJsonResponse.
	Segment struct {
		Id               int     `json:"id"`
		Seek             int     `json:"seek"`
		Start            float64 `json:"start"`
		End              float64 `json:"end"`
		Text             string  `json:"text"`
		Tokens           []int   `json:"tokens"`
		Temperature      float64 `json:"temperature"`
		AvgLogprob       float64 `json:"avg_logprob"`
		CompressionRatio float64 `json:"compression_ratio"`
		NoSpeechProb     float64 `json:"no_speech_prob"`
	}
)