From a013b1a166cf0ab8e3bb3b0e2d1119a320658edc Mon Sep 17 00:00:00 2001 From: Martial BE Date: Fri, 1 Dec 2023 10:54:07 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20add=20transcriptions=20api?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- common/client.go | 19 ++- common/form_builder.go | 14 ++- common/gin.go | 9 +- controller/relay-helper.go | 37 ++++++ controller/relay.go | 6 +- providers/ali/chat.go | 2 + providers/baidu/chat.go | 2 + providers/base/common.go | 2 + providers/base/interface.go | 6 + providers/claude/chat.go | 2 + providers/openai/base.go | 1 + providers/openai/interface.go | 10 -- providers/openai/transcriptions.go | 181 +++++++++++++++++++++++++++++ providers/openai/type.go | 11 ++ providers/palm/chat.go | 2 + providers/tencent/chat.go | 2 +- providers/zhipu/chat.go | 3 +- types/audio.go | 19 +++ 18 files changed, 304 insertions(+), 24 deletions(-) create mode 100644 providers/openai/transcriptions.go diff --git a/common/client.go b/common/client.go index 8d81f9e4..e7402378 100644 --- a/common/client.go +++ b/common/client.go @@ -27,13 +27,13 @@ func init() { type Client struct { requestBuilder RequestBuilder - createFormBuilder func(io.Writer) FormBuilder + CreateFormBuilder func(io.Writer) FormBuilder } func NewClient() *Client { return &Client{ requestBuilder: NewRequestBuilder(), - createFormBuilder: func(body io.Writer) FormBuilder { + CreateFormBuilder: func(body io.Writer) FormBuilder { return NewFormBuilder(body) }, } @@ -46,6 +46,10 @@ type requestOptions struct { type requestOption func(*requestOptions) +type Stringer interface { + GetString() *string +} + func WithBody(body any) requestOption { return func(args *requestOptions) { args.body = body @@ -60,6 +64,12 @@ func WithHeader(header map[string]string) requestOption { } } +func WithContentType(contentType string) requestOption { + return func(args *requestOptions) { + args.header.Set("Content-Type", contentType) + } +} + type RequestError struct { HTTPStatusCode int Err error @@ -173,6 +183,11 @@ func DecodeResponse(body io.Reader, v any) error { if result, ok := v.(*string); ok { return DecodeString(body, result) } + + if stringer, ok := v.(Stringer); ok { + return DecodeString(body, stringer.GetString()) + } + return json.NewDecoder(body).Decode(v) } diff --git a/common/form_builder.go b/common/form_builder.go index a30e18ff..e8d13dd2 100644 --- a/common/form_builder.go +++ b/common/form_builder.go @@ -4,12 +4,11 @@ import ( "fmt" "io" "mime/multipart" - "os" "path" ) type FormBuilder interface { - CreateFormFile(fieldname string, file *os.File) error + CreateFormFile(fieldname string, fileHeader *multipart.FileHeader) error CreateFormFileReader(fieldname string, r io.Reader, filename string) error WriteField(fieldname, value string) error Close() error @@ -26,8 +25,15 @@ func NewFormBuilder(body io.Writer) *DefaultFormBuilder { } } -func (fb *DefaultFormBuilder) CreateFormFile(fieldname string, file *os.File) error { - return fb.createFormFile(fieldname, file, file.Name()) +func (fb *DefaultFormBuilder) CreateFormFile(fieldname string, fileHeader *multipart.FileHeader) error { + file, err := fileHeader.Open() + if err != nil { + return err + } + + defer file.Close() + + return fb.createFormFile(fieldname, file, fileHeader.Filename) } func (fb *DefaultFormBuilder) CreateFormFileReader(fieldname string, r io.Reader, filename string) error { diff --git a/common/gin.go b/common/gin.go index f5012688..fa58fc44 100644 --- a/common/gin.go +++ b/common/gin.go @@ -3,9 +3,10 @@ package common import ( "bytes" "encoding/json" - "github.com/gin-gonic/gin" "io" "strings" + + "github.com/gin-gonic/gin" ) func UnmarshalBodyReusable(c *gin.Context, v any) error { @@ -20,9 +21,9 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error { contentType := c.Request.Header.Get("Content-Type") if strings.HasPrefix(contentType, "application/json") { err = json.Unmarshal(requestBody, &v) - } else { - // skip for now - // TODO: someday non json request have variant model, we will need to implementation this + } else if strings.HasPrefix(contentType, "multipart/form-data") { + c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) + err = c.ShouldBind(v) } if err != nil { return err diff --git a/controller/relay-helper.go b/controller/relay-helper.go index 8a9611c4..60934691 100644 --- a/controller/relay-helper.go +++ b/controller/relay-helper.go @@ -3,6 +3,7 @@ package controller import ( "context" "errors" + "fmt" "net/http" "one-api/common" "one-api/model" @@ -58,6 +59,8 @@ func relayHelper(c *gin.Context, relayMode int) *types.OpenAIErrorWithStatusCode usage, openAIErrorWithStatusCode = handleModerations(c, provider, modelMap, quotaInfo, group) case common.RelayModeAudioSpeech: usage, openAIErrorWithStatusCode = handleSpeech(c, provider, modelMap, quotaInfo, group) + case common.RelayModeAudioTranscription: + usage, openAIErrorWithStatusCode = handleTranscriptions(c, provider, modelMap, quotaInfo, group) default: return types.ErrorWrapper(errors.New("invalid relay mode"), "invalid_relay_mode", http.StatusBadRequest) } @@ -257,3 +260,37 @@ func handleSpeech(c *gin.Context, provider providers_base.ProviderInterface, mod } return speechProvider.SpeechAction(&speechRequest, isModelMapped, promptTokens) } + +func handleTranscriptions(c *gin.Context, provider providers_base.ProviderInterface, modelMap map[string]string, quotaInfo *QuotaInfo, group string) (*types.Usage, *types.OpenAIErrorWithStatusCode) { + var audioRequest types.AudioRequest + isModelMapped := false + speechProvider, ok := provider.(providers_base.TranscriptionsInterface) + if !ok { + return nil, types.ErrorWrapper(errors.New("channel not implemented"), "channel_not_implemented", http.StatusNotImplemented) + } + + err := common.UnmarshalBodyReusable(c, &audioRequest) + if err != nil { + return nil, types.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) + } + + if audioRequest.File == nil { + fmt.Println(audioRequest) + return nil, types.ErrorWrapper(errors.New("field file is required"), "required_field_missing", http.StatusBadRequest) + } + + if modelMap != nil && modelMap[audioRequest.Model] != "" { + audioRequest.Model = modelMap[audioRequest.Model] + isModelMapped = true + } + promptTokens := 0 + + quotaInfo.modelName = audioRequest.Model + quotaInfo.promptTokens = promptTokens + quotaInfo.initQuotaInfo(group) + quota_err := quotaInfo.preQuotaConsumption() + if quota_err != nil { + return nil, quota_err + } + return speechProvider.TranscriptionsAction(&audioRequest, isModelMapped, promptTokens) +} diff --git a/controller/relay.go b/controller/relay.go index e526cb62..e6c8ff10 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -222,6 +222,7 @@ type CompletionsStreamResponse struct { } func Relay(c *gin.Context) { + defer c.Request.Body.Close() var err *types.OpenAIErrorWithStatusCode relayMode := common.RelayModeUnknown @@ -237,13 +238,14 @@ func Relay(c *gin.Context) { relayMode = common.RelayModeModerations } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") { relayMode = common.RelayModeAudioSpeech + } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") { + relayMode = common.RelayModeAudioTranscription } // } else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { // relayMode = RelayModeImagesGenerations // } else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") { // relayMode = RelayModeEdits - // } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") { - // relayMode = RelayModeAudioTranscription + // } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") { // relayMode = RelayModeAudioTranslation // } diff --git a/providers/ali/chat.go b/providers/ali/chat.go index 721bfe13..00e5c522 100644 --- a/providers/ali/chat.go +++ b/providers/ali/chat.go @@ -151,6 +151,8 @@ func (p *AliProvider) streamResponseAli2OpenAI(aliResponse *AliChatResponse) *ty // 发送流请求 func (p *AliProvider) sendStreamRequest(req *http.Request) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { + defer req.Body.Close() + usage = &types.Usage{} // 发送请求 resp, err := common.HttpClient.Do(req) diff --git a/providers/baidu/chat.go b/providers/baidu/chat.go index 8fc7dafe..e961f4bf 100644 --- a/providers/baidu/chat.go +++ b/providers/baidu/chat.go @@ -124,6 +124,8 @@ func (p *BaiduProvider) streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStrea } func (p *BaiduProvider) sendStreamRequest(req *http.Request) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { + defer req.Body.Close() + usage = &types.Usage{} // 发送请求 resp, err := common.HttpClient.Do(req) diff --git a/providers/base/common.go b/providers/base/common.go index a7b19104..ba6c13fd 100644 --- a/providers/base/common.go +++ b/providers/base/common.go @@ -54,6 +54,7 @@ func (p *BaseProvider) CommonRequestHeaders(headers map[string]string) { // 发送请求 func (p *BaseProvider) SendRequest(req *http.Request, response ProviderResponseHandler, rawOutput bool) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) { + defer req.Body.Close() resp, openAIErrorWithStatusCode := common.SendRequest(req, response, true) if openAIErrorWithStatusCode != nil { @@ -95,6 +96,7 @@ func (p *BaseProvider) SendRequest(req *http.Request, response ProviderResponseH } func (p *BaseProvider) SendRequestRaw(req *http.Request) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode) { + defer req.Body.Close() // 发送请求 resp, err := common.HttpClient.Do(req) diff --git a/providers/base/interface.go b/providers/base/interface.go index 9981b1fc..3269ad63 100644 --- a/providers/base/interface.go +++ b/providers/base/interface.go @@ -44,6 +44,12 @@ type SpeechInterface interface { SpeechAction(request *types.SpeechAudioRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) } +// 语音转文字接口 +type TranscriptionsInterface interface { + ProviderInterface + TranscriptionsAction(request *types.AudioRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) +} + // 余额接口 type BalanceInterface interface { BalanceAction(channel *model.Channel) (float64, error) diff --git a/providers/claude/chat.go b/providers/claude/chat.go index bef595fc..02c7bc89 100644 --- a/providers/claude/chat.go +++ b/providers/claude/chat.go @@ -134,6 +134,8 @@ func (p *ClaudeProvider) streamResponseClaude2OpenAI(claudeResponse *ClaudeRespo } func (p *ClaudeProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErrorWithStatusCode, string) { + defer req.Body.Close() + // 发送请求 resp, err := common.HttpClient.Do(req) if err != nil { diff --git a/providers/openai/base.go b/providers/openai/base.go index f8be57c1..e10bed50 100644 --- a/providers/openai/base.go +++ b/providers/openai/base.go @@ -93,6 +93,7 @@ func (p *OpenAIProvider) getRequestBody(request any, isModelMapped bool) (reques // 发送流式请求 func (p *OpenAIProvider) sendStreamRequest(req *http.Request, response OpenAIProviderStreamResponseHandler) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode, responseText string) { + defer req.Body.Close() resp, err := common.HttpClient.Do(req) if err != nil { diff --git a/providers/openai/interface.go b/providers/openai/interface.go index 1695be8c..1ddd3b87 100644 --- a/providers/openai/interface.go +++ b/providers/openai/interface.go @@ -1,15 +1,5 @@ package openai -import ( - "net/http" - "one-api/types" -) - -type OpenAIProviderResponseHandler interface { - // 请求处理函数 - responseHandler(resp *http.Response) (errWithCode *types.OpenAIErrorWithStatusCode) -} - type OpenAIProviderStreamResponseHandler interface { // 请求流处理函数 responseStreamHandler() (responseText string) diff --git a/providers/openai/transcriptions.go b/providers/openai/transcriptions.go new file mode 100644 index 00000000..f4218cce --- /dev/null +++ b/providers/openai/transcriptions.go @@ -0,0 +1,181 @@ +package openai + +import ( + "bufio" + "bytes" + "fmt" + "net/http" + "one-api/common" + "one-api/types" + "regexp" + "strconv" + "strings" +) + +func (c *OpenAIProviderTranscriptionsResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) { + if c.Error.Type != "" { + errWithCode = &types.OpenAIErrorWithStatusCode{ + OpenAIError: c.Error, + StatusCode: resp.StatusCode, + } + return + } + return nil, nil +} + +func (c *OpenAIProviderTranscriptionsTextResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) { + return nil, nil +} + +func (p *OpenAIProvider) TranscriptionsAction(request *types.AudioRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) { + fullRequestURL := p.GetFullRequestURL(p.AudioTranscriptions, request.Model) + headers := p.GetRequestHeaders() + + client := common.NewClient() + + var formBody bytes.Buffer + var req *http.Request + var err error + if isModelMapped { + builder := client.CreateFormBuilder(&formBody) + if err := audioMultipartForm(request, builder); err != nil { + return nil, types.ErrorWrapper(err, "create_form_builder_failed", http.StatusInternalServerError) + } + req, err = client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(&formBody), common.WithHeader(headers), common.WithContentType(builder.FormDataContentType())) + + } else { + req, err = client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(p.Context.Request.Body), common.WithHeader(headers), common.WithContentType(p.Context.Request.Header.Get("Content-Type"))) + } + + if err != nil { + return nil, types.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) + } + + var textResponse string + if hasJSONResponse(request) { + openAIProviderTranscriptionsResponse := &OpenAIProviderTranscriptionsResponse{} + errWithCode = p.SendRequest(req, openAIProviderTranscriptionsResponse, true) + if errWithCode != nil { + return + } + textResponse = openAIProviderTranscriptionsResponse.Text + } else { + openAIProviderTranscriptionsTextResponse := new(OpenAIProviderTranscriptionsTextResponse) + errWithCode = p.SendRequest(req, openAIProviderTranscriptionsTextResponse, true) + if errWithCode != nil { + return + } + textResponse = getTextContent(*openAIProviderTranscriptionsTextResponse.GetString(), request.ResponseFormat) + } + + completionTokens := common.CountTokenText(textResponse, request.Model) + usage = &types.Usage{ + PromptTokens: promptTokens, + CompletionTokens: completionTokens, + TotalTokens: promptTokens + completionTokens, + } + return +} + +func hasJSONResponse(request *types.AudioRequest) bool { + return request.ResponseFormat == "" || request.ResponseFormat == "json" || request.ResponseFormat == "verbose_json" +} + +func audioMultipartForm(request *types.AudioRequest, b common.FormBuilder) error { + + err := b.CreateFormFile("file", request.File) + if err != nil { + return fmt.Errorf("creating form file: %w", err) + } + + err = b.WriteField("model", request.Model) + if err != nil { + return fmt.Errorf("writing model name: %w", err) + } + + if request.Prompt != "" { + err = b.WriteField("prompt", request.Prompt) + if err != nil { + return fmt.Errorf("writing prompt: %w", err) + } + } + + if request.ResponseFormat != "" { + err = b.WriteField("response_format", request.ResponseFormat) + if err != nil { + return fmt.Errorf("writing format: %w", err) + } + } + + if request.Temperature != 0 { + err = b.WriteField("temperature", fmt.Sprintf("%.2f", request.Temperature)) + if err != nil { + return fmt.Errorf("writing temperature: %w", err) + } + } + + if request.Language != "" { + err = b.WriteField("language", request.Language) + if err != nil { + return fmt.Errorf("writing language: %w", err) + } + } + + return b.Close() +} + +func getTextContent(text, format string) string { + switch format { + case "srt": + return extractTextFromSRT(text) + case "vtt": + return extractTextFromVTT(text) + default: + return text + } +} + +func extractTextFromVTT(vttContent string) string { + scanner := bufio.NewScanner(strings.NewReader(vttContent)) + re := regexp.MustCompile(`\d{2}:\d{2}:\d{2}\.\d{3} --> \d{2}:\d{2}:\d{2}\.\d{3}`) + text := []string{} + isStart := true + + for scanner.Scan() { + line := scanner.Text() + if isStart && strings.HasPrefix(line, "WEBVTT") { + isStart = false + continue + } + if !re.MatchString(line) && !isNumber(line) && line != "" { + text = append(text, line) + } + } + + return strings.Join(text, " ") +} + +func extractTextFromSRT(srtContent string) string { + scanner := bufio.NewScanner(strings.NewReader(srtContent)) + re := regexp.MustCompile(`\d{2}:\d{2}:\d{2},\d{3} --> \d{2}:\d{2}:\d{2},\d{3}`) + text := []string{} + isContent := false + + for scanner.Scan() { + line := scanner.Text() + if re.MatchString(line) { + isContent = true + } else if line == "" { + isContent = false + } else if isContent { + text = append(text, line) + } + } + + return strings.Join(text, " ") +} + +func isNumber(s string) bool { + _, err := strconv.Atoi(s) + return err == nil +} diff --git a/providers/openai/type.go b/providers/openai/type.go index 544e3e5f..85051847 100644 --- a/providers/openai/type.go +++ b/providers/openai/type.go @@ -26,3 +26,14 @@ type OpenAIProviderModerationResponse struct { types.ModerationResponse types.OpenAIErrorResponse } + +type OpenAIProviderTranscriptionsResponse struct { + types.AudioResponse + types.OpenAIErrorResponse +} + +type OpenAIProviderTranscriptionsTextResponse string + +func (a *OpenAIProviderTranscriptionsTextResponse) GetString() *string { + return (*string)(a) +} diff --git a/providers/palm/chat.go b/providers/palm/chat.go index 158e8aad..772f3f9b 100644 --- a/providers/palm/chat.go +++ b/providers/palm/chat.go @@ -128,6 +128,8 @@ func (p *PalmProvider) streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) } func (p *PalmProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErrorWithStatusCode, string) { + defer req.Body.Close() + // 发送请求 resp, err := common.HttpClient.Do(req) if err != nil { diff --git a/providers/tencent/chat.go b/providers/tencent/chat.go index 46e05f8c..070b191b 100644 --- a/providers/tencent/chat.go +++ b/providers/tencent/chat.go @@ -140,6 +140,7 @@ func (p *TencentProvider) streamResponseTencent2OpenAI(TencentResponse *TencentC } func (p *TencentProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErrorWithStatusCode, string) { + defer req.Body.Close() // 发送请求 resp, err := common.HttpClient.Do(req) if err != nil { @@ -208,6 +209,5 @@ func (p *TencentProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErr return false } }) - return nil, responseText } diff --git a/providers/zhipu/chat.go b/providers/zhipu/chat.go index 7254effe..02c149af 100644 --- a/providers/zhipu/chat.go +++ b/providers/zhipu/chat.go @@ -139,6 +139,8 @@ func (p *ZhipuProvider) streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStrea } func (p *ZhipuProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErrorWithStatusCode, *types.Usage) { + defer req.Body.Close() + // 发送请求 resp, err := common.HttpClient.Do(req) if err != nil { @@ -221,6 +223,5 @@ func (p *ZhipuProvider) sendStreamRequest(req *http.Request) (*types.OpenAIError return false } }) - return nil, usage } diff --git a/types/audio.go b/types/audio.go index 075cbf7a..86e351ba 100644 --- a/types/audio.go +++ b/types/audio.go @@ -1,5 +1,7 @@ package types +import "mime/multipart" + type SpeechAudioRequest struct { Model string `json:"model"` Input string `json:"input"` @@ -7,3 +9,20 @@ type SpeechAudioRequest struct { ResponseFormat string `json:"response_format,omitempty"` Speed float64 `json:"speed,omitempty"` } + +type AudioRequest struct { + File *multipart.FileHeader `form:"file"` + Model string `form:"model"` + Language string `form:"language"` + Prompt string `form:"prompt"` + ResponseFormat string `form:"response_format"` + Temperature float32 `form:"temperature"` +} + +type AudioResponse struct { + Task string `json:"task,omitempty"` + Language string `json:"language,omitempty"` + Duration float64 `json:"duration,omitempty"` + Segments any `json:"segments,omitempty"` + Text string `json:"text"` +}