From 6c5ef2d7f7d4408b7734a944cd40350bb9750976 Mon Sep 17 00:00:00 2001 From: igophper Date: Wed, 6 Dec 2023 00:50:18 +0800 Subject: [PATCH] feat: Refactor response parsing logic to support multiple formats The parsing logic for responses in relay.go and relay-audio.go was refactored to support multiple response formats - 'json', 'text', 'srt', 'verbose_json', and 'vtt'. The existing `WhisperResponse` struct was renamed to `WhisperJsonResponse` and a new struct `WhisperVerboseJsonResponse` was added to support the 'verbose_json' format. Additional parsing functions were added to extract text from these new response types. This change was necessary to make the parsing logic more flexible and extendable for different types of responses. --- controller/relay-audio.go | 85 ++++++++++++++++++++++++++++++++++++--- controller/relay.go | 23 ++++++++++- 2 files changed, 101 insertions(+), 7 deletions(-) diff --git a/controller/relay-audio.go b/controller/relay-audio.go index 9e78dadc..bebf26b5 100644 --- a/controller/relay-audio.go +++ b/controller/relay-audio.go @@ -1,6 +1,7 @@ package controller import ( + "bufio" "bytes" "context" "encoding/json" @@ -102,7 +103,13 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioModel, apiVersion) } - requestBody := c.Request.Body + requestBody := &bytes.Buffer{} + _, err = io.Copy(requestBody, c.Request.Body) + if err != nil { + return errorWrapper(err, "new_request_body_failed", http.StatusInternalServerError) + } + c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody.Bytes())) + responseFormat := c.DefaultPostForm("response_format", "json") req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) if err != nil { @@ -144,12 +151,33 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode if err != nil { return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) } - var whisperResponse WhisperResponse - err = json.Unmarshal(responseBody, &whisperResponse) - if err != nil { - return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) + + var openAIErr TextResponse + if err = json.Unmarshal(responseBody, &openAIErr); err == nil { + if openAIErr.Error.Message != "" { + return errorWrapper(errors.New(fmt.Sprintf("type %s, code %v, message %s", openAIErr.Error.Type, openAIErr.Error.Code, openAIErr.Error.Message)), "request_error", http.StatusInternalServerError) + } } - quota = countTokenText(whisperResponse.Text, audioModel) + + var text string + switch responseFormat { + case "json": + text, err = getTextFromJson(responseBody) + case "text": + text, err = getTextFromText(responseBody) + case "srt": + text, err = getTextFromSrt(responseBody) + case "verbose_json": + text, err = getTextFromVerboseJson(responseBody) + case "vtt": + text, err = getTextFromVtt(responseBody) + default: + return errorWrapper(errors.New("not_support_response_format"), "not_support_response_format", http.StatusInternalServerError) + } + if err != nil { + return errorWrapper(err, "get_text_from_body_err", http.StatusInternalServerError) + } + quota = countTokenText(text, audioModel) resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) } if resp.StatusCode != http.StatusOK { @@ -187,3 +215,48 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode } return nil } + +func getTextFromVtt(body []byte) (string, error) { + return getTextFromSrt(body) +} + +func getTextFromVerboseJson(body []byte) (string, error) { + var whisperResponse WhisperVerboseJsonResponse + if err := json.Unmarshal(body, &whisperResponse); err != nil { + return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err) + } + return whisperResponse.Text, nil +} + +func getTextFromSrt(body []byte) (string, error) { + scanner := bufio.NewScanner(strings.NewReader(string(body))) + var builder strings.Builder + var textLine bool + for scanner.Scan() { + line := scanner.Text() + if textLine { + builder.WriteString(line) + textLine = false + continue + } else if strings.Contains(line, "-->") { + textLine = true + continue + } + } + if err := scanner.Err(); err != nil { + return "", err + } + return builder.String(), nil +} + +func getTextFromText(body []byte) (string, error) { + return strings.TrimSuffix(string(body), "\n"), nil +} + +func getTextFromJson(body []byte) (string, error) { + var whisperResponse WhisperJsonResponse + if err := json.Unmarshal(body, &whisperResponse); err != nil { + return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err) + } + return whisperResponse.Text, nil +} diff --git a/controller/relay.go b/controller/relay.go index 58ee8381..9d410fe8 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -141,10 +141,31 @@ type ImageRequest struct { User string `json:"user,omitempty"` } -type WhisperResponse struct { +type WhisperJsonResponse struct { Text string `json:"text,omitempty"` } +type WhisperVerboseJsonResponse struct { + Task string `json:"task"` + Language string `json:"language"` + Duration float64 `json:"duration"` + Text string `json:"text"` + Segments []Segment `json:"segments"` +} + +type Segment struct { + Id int `json:"id"` + Seek int `json:"seek"` + Start float64 `json:"start"` + End float64 `json:"end"` + Text string `json:"text"` + Tokens []int `json:"tokens"` + Temperature float64 `json:"temperature"` + AvgLogprob float64 `json:"avg_logprob"` + CompressionRatio float64 `json:"compression_ratio"` + NoSpeechProb float64 `json:"no_speech_prob"` +} + type TextToSpeechRequest struct { Model string `json:"model" binding:"required"` Input string `json:"input" binding:"required"`