diff --git a/controller/relay-helper.go b/controller/relay-helper.go
index 60934691..63b474b3 100644
--- a/controller/relay-helper.go
+++ b/controller/relay-helper.go
@@ -61,6 +61,8 @@ func relayHelper(c *gin.Context, relayMode int) *types.OpenAIErrorWithStatusCode
 		usage, openAIErrorWithStatusCode = handleSpeech(c, provider, modelMap, quotaInfo, group)
 	case common.RelayModeAudioTranscription:
 		usage, openAIErrorWithStatusCode = handleTranscriptions(c, provider, modelMap, quotaInfo, group)
+	case common.RelayModeAudioTranslation:
+		usage, openAIErrorWithStatusCode = handleTranslations(c, provider, modelMap, quotaInfo, group)
 	default:
 		return types.ErrorWrapper(errors.New("invalid relay mode"), "invalid_relay_mode", http.StatusBadRequest)
 	}
@@ -294,3 +296,45 @@ func handleTranscriptions(c *gin.Context, provider providers_base.ProviderInterf
 	}
 	return speechProvider.TranscriptionsAction(&audioRequest, isModelMapped, promptTokens)
 }
+
+// handleTranslations relays an audio translation request to the provider.
+//
+// It returns a 501 error when provider does not implement
+// providers_base.TranslationInterface and a 400 error when the request body
+// cannot be bound or has no file. A non-empty modelMap entry for the
+// requested model replaces the model name before quota is pre-consumed and
+// the request is passed to the provider's TranslationAction.
+func handleTranslations(c *gin.Context, provider providers_base.ProviderInterface, modelMap map[string]string, quotaInfo *QuotaInfo, group string) (*types.Usage, *types.OpenAIErrorWithStatusCode) {
+	translationProvider, ok := provider.(providers_base.TranslationInterface)
+	if !ok {
+		return nil, types.ErrorWrapper(errors.New("channel not implemented"), "channel_not_implemented", http.StatusNotImplemented)
+	}
+
+	var audioRequest types.AudioRequest
+	if err := common.UnmarshalBodyReusable(c, &audioRequest); err != nil {
+		return nil, types.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
+	}
+
+	if audioRequest.File == nil {
+		return nil, types.ErrorWrapper(errors.New("field file is required"), "required_field_missing", http.StatusBadRequest)
+	}
+
+	// Indexing a nil map is safe in Go, so no separate nil check is needed.
+	isModelMapped := false
+	if mapped := modelMap[audioRequest.Model]; mapped != "" {
+		audioRequest.Model = mapped
+		isModelMapped = true
+	}
+
+	// promptTokens is fixed at 0 here; the same value is used for quota and the provider call.
+	promptTokens := 0
+
+	quotaInfo.modelName = audioRequest.Model
+	quotaInfo.promptTokens = promptTokens
+	quotaInfo.initQuotaInfo(group)
+	if quotaErr := quotaInfo.preQuotaConsumption(); quotaErr != nil {
+		return nil, quotaErr
+	}
+
+	return translationProvider.TranslationAction(&audioRequest, isModelMapped, promptTokens)
+}
diff --git a/controller/relay.go b/controller/relay.go
index e6c8ff10..92a4db98 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -240,15 +240,14 @@ func Relay(c *gin.Context) {
 		relayMode = common.RelayModeAudioSpeech
 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
 		relayMode = common.RelayModeAudioTranscription
+	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
+		relayMode = common.RelayModeAudioTranslation
 	}
 	// } else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
 	// 	relayMode = RelayModeImagesGenerations
 	// } else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") {
 	// 	relayMode = RelayModeEdits
-	// } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
-	// 	relayMode = RelayModeAudioTranslation
-	// }
 
 	switch relayMode {
 	// case RelayModeImagesGenerations:
 	// 	err = relayImageHelper(c, relayMode)
diff --git a/providers/base/interface.go b/providers/base/interface.go
index 3269ad63..03a3338a 100644
--- a/providers/base/interface.go
+++ b/providers/base/interface.go
@@ -50,6 +50,13 @@ type TranscriptionsInterface interface {
 	TranscriptionsAction(request *types.AudioRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
 }
 
+// TranslationInterface is a ProviderInterface that can also handle audio
+// translation requests via TranslationAction.
+type TranslationInterface interface {
+	ProviderInterface
+	TranslationAction(request *types.AudioRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
+}
+
 // 余额接口
 type BalanceInterface interface {
 	BalanceAction(channel *model.Channel) (float64, error)
diff --git a/providers/openai/translations.go b/providers/openai/translations.go
new file mode 100644
index
00000000..59adacdf
--- /dev/null
+++ b/providers/openai/translations.go
@@ -0,0 +1,69 @@
+package openai
+
+import (
+	"bytes"
+	"net/http"
+
+	"one-api/common"
+	"one-api/types"
+)
+
+// TranslationAction sends an audio translation request upstream and returns
+// the token usage computed from the translated text.
+//
+// When isModelMapped is true the multipart body is rebuilt from request via
+// audioMultipartForm (presumably so the mapped model name is what gets sent);
+// otherwise the incoming request body is forwarded unchanged. promptTokens is
+// copied into the returned usage, and completion tokens are counted from the
+// response text.
+func (p *OpenAIProvider) TranslationAction(request *types.AudioRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
+	fullRequestURL := p.GetFullRequestURL(p.AudioTranslations, request.Model)
+	headers := p.GetRequestHeaders()
+	client := common.NewClient()
+
+	var (
+		formBody      bytes.Buffer
+		req           *http.Request
+		err           error
+		contentLength int64
+	)
+	if isModelMapped {
+		builder := client.CreateFormBuilder(&formBody)
+		if formErr := audioMultipartForm(request, builder); formErr != nil {
+			return nil, types.ErrorWrapper(formErr, "create_form_builder_failed", http.StatusInternalServerError)
+		}
+		req, err = client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(&formBody), common.WithHeader(headers), common.WithContentType(builder.FormDataContentType()))
+		contentLength = int64(formBody.Len())
+	} else {
+		req, err = client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(p.Context.Request.Body), common.WithHeader(headers), common.WithContentType(p.Context.Request.Header.Get("Content-Type")))
+		contentLength = p.Context.Request.ContentLength
+	}
+	// req must not be touched until err is checked: it may be nil on failure.
+	if err != nil {
+		return nil, types.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
+	}
+	req.ContentLength = contentLength
+
+	var textResponse string
+	if hasJSONResponse(request) {
+		jsonResp := &OpenAIProviderTranscriptionsResponse{}
+		if errWithCode = p.SendRequest(req, jsonResp, true); errWithCode != nil {
+			return nil, errWithCode
+		}
+		textResponse = jsonResp.Text
+	} else {
+		textResp := new(OpenAIProviderTranscriptionsTextResponse)
+		if errWithCode = p.SendRequest(req, textResp, true); errWithCode != nil {
+			return nil, errWithCode
+		}
+		textResponse = getTextContent(*textResp.GetString(), request.ResponseFormat)
+	}
+
+	completionTokens := common.CountTokenText(textResponse, request.Model)
+	usage = &types.Usage{
+		PromptTokens:     promptTokens,
+		CompletionTokens: completionTokens,
+		TotalTokens:      promptTokens + completionTokens,
+	}
+	return usage, nil
+}