diff --git a/controller/relay.go b/controller/relay.go index 932e023b..dadec293 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -29,7 +29,7 @@ func relayHelper(c *gin.Context, relayMode int) *model.ErrorWithStatusCode { case relaymode.ImagesGenerations: err = controller.RelayImageHelper(c, relayMode) case relaymode.AudioSpeech: - fallthrough + err = controller.RelayAudioSpeechHelper(c) case relaymode.AudioTranslation: fallthrough case relaymode.AudioTranscription: diff --git a/relay/adaptor/aiproxy/adaptor.go b/relay/adaptor/aiproxy/adaptor.go index 42d49c0a..4269e39a 100644 --- a/relay/adaptor/aiproxy/adaptor.go +++ b/relay/adaptor/aiproxy/adaptor.go @@ -45,6 +45,13 @@ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) return request, nil } +func (a *Adaptor) ConvertTextToSpeechRequest(request *model.TextToSpeechRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { return adaptor.DoRequestHelper(a, c, meta, requestBody) } diff --git a/relay/adaptor/ali/adaptor.go b/relay/adaptor/ali/adaptor.go index 4aa8a11a..787855b6 100644 --- a/relay/adaptor/ali/adaptor.go +++ b/relay/adaptor/ali/adaptor.go @@ -1,9 +1,11 @@ package ali import ( + "encoding/json" "errors" "fmt" "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" "github.com/songquanpeng/one-api/relay/adaptor" "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" @@ -76,7 +78,19 @@ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) return aliRequest, nil } +func (a *Adaptor) ConvertTextToSpeechRequest(request *model.TextToSpeechRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + + aliRequest := ConvertTextToSpeechRequest(*request) + return aliRequest, nil +} + func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { + if meta.Mode == relaymode.AudioSpeech { + return a.DoWSSRequest(c, meta, requestBody) + } return adaptor.DoRequestHelper(a, c, meta, requestBody) } @@ -89,6 +103,8 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Met err, usage = EmbeddingHandler(c, resp) case relaymode.ImagesGenerations: err, usage = ImageHandler(c, resp) + case relaymode.AudioSpeech: + err, usage = AudioSpeechHandler(c, resp) default: err, usage = Handler(c, resp) } @@ -103,3 +119,74 @@ func (a *Adaptor) GetModelList() []string { func (a *Adaptor) GetChannelName() string { return "ali" } + +func (a *Adaptor) DoWSSRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { + baseURL := "wss://dashscope.aliyuncs.com/api-ws/v1/inference" + var usage Usage + // Create an empty http.Response object + response := &http.Response{ + StatusCode: http.StatusInternalServerError, + Body: io.NopCloser(nil), + } + + conn, _, err := websocket.DefaultDialer.Dial(baseURL, http.Header{"Authorization": {"Bearer " + meta.APIKey}}) + if err != nil { + return response, errors.New("ali_wss_conn_failed") + } + defer conn.Close() + + var requestBodyBytes []byte + requestBodyBytes, err = io.ReadAll(requestBody) + if err != nil { + return response, errors.New("ali_failed_to_read_request_body") + } + + // Convert JSON strings to map[string]interface{} + var requestBodyMap map[string]interface{} + err = json.Unmarshal(requestBodyBytes, &requestBodyMap) + if err != nil { + return response, errors.New("ali_failed_to_parse_request_body") + } + + if err := conn.WriteJSON(requestBodyMap); err != nil { + return response, errors.New("ali_wss_write_msg_failed") + } + + const chunkSize = 1024 + + for { + messageType, audioData, err := conn.ReadMessage() + if err != nil { + if err == io.EOF { + break + } + return response, errors.New("ali_wss_read_msg_failed") + } + + var msg WSSMessage + switch messageType { + case websocket.TextMessage: + err = json.Unmarshal(audioData, &msg) + if msg.Header.Event == "task-finished" { + response.StatusCode = http.StatusOK + usage.TotalTokens = msg.Payload.Usage.Characters + return response, nil + } + case websocket.BinaryMessage: + for i := 0; i < len(audioData); i += chunkSize { + end := i + chunkSize + if end > len(audioData) { + end = len(audioData) + } + chunk := audioData[i:end] + + _, writeErr := c.Writer.Write(chunk) + if writeErr != nil { + return response, errors.New("wss_write_chunk_failed") + } + } + } + } + + return response, nil +} diff --git a/relay/adaptor/ali/audio-speech.go b/relay/adaptor/ali/audio-speech.go new file mode 100644 index 00000000..7dcd976d --- /dev/null +++ b/relay/adaptor/ali/audio-speech.go @@ -0,0 +1,21 @@ +package ali + +import ( + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/model" + "net/http" +) + +func AudioSpeechHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { + for k, v := range resp.Header { + c.Writer.Header().Set(k, v[0]) + } + c.Writer.WriteHeader(resp.StatusCode) + + err := resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + return nil, nil +} diff --git a/relay/adaptor/ali/constants.go b/relay/adaptor/ali/constants.go index 3f24ce2e..c97c345e 100644 --- a/relay/adaptor/ali/constants.go +++ b/relay/adaptor/ali/constants.go @@ -4,4 +4,48 @@ var ModelList = []string{ "qwen-turbo", "qwen-plus", "qwen-max", "qwen-max-longcontext", "text-embedding-v1", "ali-stable-diffusion-xl", "ali-stable-diffusion-v1.5", "wanx-v1", + + "sambert-zhichu-v1", + "sambert-zhiwei-v1", + "sambert-zhixiang-v1", + "sambert-zhide-v1", + "sambert-zhijia-v1", + "sambert-zhinan-v1", + "sambert-zhiqi-v1", + "sambert-zhiqian-v1", + "sambert-zhiru-v1", + "sambert-zhimiao-emo-v1", + "sambert-zhida-v1", + "sambert-zhifei-v1", + "sambert-zhigui-v1", + "sambert-zhihao-v1", + "sambert-zhijing-v1", + "sambert-zhilun-v1", + "sambert-zhimao-v1", + "sambert-zhiming-v1", + "sambert-zhimo-v1", + "sambert-zhina-v1", + "sambert-zhishu-v1", + "sambert-zhishuo-v1", + "sambert-zhistella-v1", + "sambert-zhiting-v1", + "sambert-zhixiao-v1", + "sambert-zhiya-v1", + "sambert-zhiye-v1", + "sambert-zhiying-v1", + "sambert-zhiyuan-v1", + "sambert-zhiyue-v1", + "sambert-camila-v1", + "sambert-perla-v1", + "sambert-indah-v1", + "sambert-clara-v1", + "sambert-hanna-v1", + "sambert-beth-v1", + "sambert-betty-v1", + "sambert-cally-v1", + "sambert-cindy-v1", + "sambert-eva-v1", + "sambert-donna-v1", + "sambert-brian-v1", + "sambert-waan-v1", } diff --git a/relay/adaptor/ali/main.go b/relay/adaptor/ali/main.go index 0462c26b..a054d4eb 100644 --- a/relay/adaptor/ali/main.go +++ b/relay/adaptor/ali/main.go @@ -4,6 +4,7 @@ import ( "bufio" "encoding/json" "github.com/gin-gonic/gin" + "github.com/google/uuid" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" @@ -77,6 +78,37 @@ func ConvertImageRequest(request model.ImageRequest) *ImageRequest { return &imageRequest } +func ConvertTextToSpeechRequest(request model.TextToSpeechRequest) *WSSMessage { + var ttsRequest WSSMessage + ttsRequest.Header.Action = "run-task" + ttsRequest.Header.Streaming = "out" + ttsRequest.Header.TaskID = uuid.New().String() + ttsRequest.Payload.Function = "SpeechSynthesizer" + ttsRequest.Payload.Input.Text = request.Input + ttsRequest.Payload.Model = request.Model + ttsRequest.Payload.Parameters.Format = "wav" + //ttsRequest.Payload.Parameters.SampleRate = 48000 + ttsRequest.Payload.Parameters.Rate = 1.0 + ttsRequest.Payload.Task = "tts" + ttsRequest.Payload.TaskGroup = "audio" + + format := map[string]bool{ + "pcm": true, + "wav": true, + "mp3": true, + } + + if _, ok := format[request.ResponseFormat]; ok { + ttsRequest.Payload.Parameters.Format = request.ResponseFormat + } + + if 0.5 <= request.Speed && request.Speed <= 2 { + ttsRequest.Payload.Parameters.Rate = request.Speed + } + + return &ttsRequest +} + func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { var aliResponse EmbeddingResponse err := json.NewDecoder(resp.Body).Decode(&aliResponse) diff --git a/relay/adaptor/anthropic/adaptor.go b/relay/adaptor/anthropic/adaptor.go index b1136e84..c59808a3 100644 --- a/relay/adaptor/anthropic/adaptor.go +++ b/relay/adaptor/anthropic/adaptor.go @@ -48,6 +48,13 @@ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) return request, nil } +func (a *Adaptor) ConvertTextToSpeechRequest(request *model.TextToSpeechRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { return adaptor.DoRequestHelper(a, c, meta, requestBody) } diff --git a/relay/adaptor/aws/adapter.go b/relay/adaptor/aws/adapter.go index 7245d3d9..ecc2e02f 100644 --- a/relay/adaptor/aws/adapter.go +++ b/relay/adaptor/aws/adapter.go @@ -57,6 +57,13 @@ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) return request, nil } +func (a *Adaptor) ConvertTextToSpeechRequest(request *model.TextToSpeechRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { return nil, nil } diff --git a/relay/adaptor/baidu/adaptor.go b/relay/adaptor/baidu/adaptor.go index 15306b95..19af64dd 100644 --- a/relay/adaptor/baidu/adaptor.go +++ b/relay/adaptor/baidu/adaptor.go @@ -116,6 +116,13 @@ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) return request, nil } +func (a *Adaptor) ConvertTextToSpeechRequest(request *model.TextToSpeechRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { return adaptor.DoRequestHelper(a, c, meta, requestBody) } diff --git a/relay/adaptor/cloudflare/adaptor.go b/relay/adaptor/cloudflare/adaptor.go index 6ff6b0d3..d3ae2b29 100644 --- a/relay/adaptor/cloudflare/adaptor.go +++ b/relay/adaptor/cloudflare/adaptor.go @@ -44,6 +44,13 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G return ConvertRequest(*request), nil } +func (a *Adaptor) ConvertTextToSpeechRequest(request *model.TextToSpeechRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { return adaptor.DoRequestHelper(a, c, meta, requestBody) } diff --git a/relay/adaptor/cohere/adaptor.go b/relay/adaptor/cohere/adaptor.go index 6fdb1b04..523bf761 100644 --- a/relay/adaptor/cohere/adaptor.go +++ b/relay/adaptor/cohere/adaptor.go @@ -42,6 +42,13 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G return ConvertRequest(*request), nil } +func (a *Adaptor) ConvertTextToSpeechRequest(request *model.TextToSpeechRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { return adaptor.DoRequestHelper(a, c, meta, requestBody) } diff --git a/relay/adaptor/coze/adaptor.go b/relay/adaptor/coze/adaptor.go index 44f560e8..95f84f01 100644 --- a/relay/adaptor/coze/adaptor.go +++ b/relay/adaptor/coze/adaptor.go @@ -45,6 +45,13 @@ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) return request, nil } +func (a *Adaptor) ConvertTextToSpeechRequest(request *model.TextToSpeechRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { return adaptor.DoRequestHelper(a, c, meta, requestBody) } diff --git a/relay/adaptor/deepl/adaptor.go b/relay/adaptor/deepl/adaptor.go index d018a096..ff4efec4 100644 --- a/relay/adaptor/deepl/adaptor.go +++ b/relay/adaptor/deepl/adaptor.go @@ -46,6 +46,13 @@ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) return request, nil } +func (a *Adaptor) ConvertTextToSpeechRequest(request *model.TextToSpeechRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { return adaptor.DoRequestHelper(a, c, meta, requestBody) } diff --git a/relay/adaptor/gemini/adaptor.go b/relay/adaptor/gemini/adaptor.go index 12f48c71..aff89287 100644 --- a/relay/adaptor/gemini/adaptor.go +++ b/relay/adaptor/gemini/adaptor.go @@ -66,6 +66,13 @@ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) return request, nil } +func (a *Adaptor) ConvertTextToSpeechRequest(request *model.TextToSpeechRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { return channelhelper.DoRequestHelper(a, c, meta, requestBody) } diff --git a/relay/adaptor/interface.go b/relay/adaptor/interface.go index 01b2e2cb..f0af3f14 100644 --- a/relay/adaptor/interface.go +++ b/relay/adaptor/interface.go @@ -14,6 +14,7 @@ type Adaptor interface { SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) ConvertImageRequest(request *model.ImageRequest) (any, error) + ConvertTextToSpeechRequest(request *model.TextToSpeechRequest) (any, error) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) GetModelList() []string diff --git a/relay/adaptor/ollama/adaptor.go b/relay/adaptor/ollama/adaptor.go index 66702c5d..4c726d35 100644 --- a/relay/adaptor/ollama/adaptor.go +++ b/relay/adaptor/ollama/adaptor.go @@ -55,6 +55,13 @@ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) return request, nil } +func (a *Adaptor) ConvertTextToSpeechRequest(request *model.TextToSpeechRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { return adaptor.DoRequestHelper(a, c, meta, requestBody) } diff --git a/relay/adaptor/openai/adaptor.go b/relay/adaptor/openai/adaptor.go index db569e4f..f6668c55 100644 --- a/relay/adaptor/openai/adaptor.go +++ b/relay/adaptor/openai/adaptor.go @@ -32,6 +32,14 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { // https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2024-03-01-preview fullRequestURL := fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", meta.BaseURL, meta.ActualModelName, meta.Config.APIVersion) return fullRequestURL, nil + } else if meta.Mode == relaymode.AudioTranscription { + // https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api + fullRequestURL := fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", meta.BaseURL, meta.ActualModelName, meta.Config.APIVersion) + return fullRequestURL, nil + } else if meta.Mode == relaymode.AudioSpeech { + // https://learn.microsoft.com/en-us/azure/ai-services/openai/text-to-speech-quickstart?tabs=command-line#rest-api + fullRequestURL := fmt.Sprintf("%s/openai/deployments/%s/audio/speech?api-version=%s", meta.BaseURL, meta.ActualModelName, meta.Config.APIVersion) + return fullRequestURL, nil } // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api @@ -57,6 +65,9 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *me adaptor.SetupCommonRequestHeader(c, req, meta) if meta.ChannelType == channeltype.Azure { req.Header.Set("api-key", meta.APIKey) + if meta.Mode == relaymode.AudioTranscription || meta.Mode == relaymode.AudioSpeech { + req.ContentLength = c.Request.ContentLength + } return nil } req.Header.Set("Authorization", "Bearer "+meta.APIKey) @@ -81,6 +92,13 @@ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) return request, nil } +func (a *Adaptor) ConvertTextToSpeechRequest(request *model.TextToSpeechRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { return adaptor.DoRequestHelper(a, c, meta, requestBody) } @@ -100,6 +118,8 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Met switch meta.Mode { case relaymode.ImagesGenerations: err, _ = ImageHandler(c, resp) + case relaymode.AudioSpeech: + err, _ = TextToSpeechHandler(c, resp) default: err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName) } diff --git a/relay/adaptor/openai/audio.go b/relay/adaptor/openai/audio.go new file mode 100644 index 00000000..674f3e8d --- /dev/null +++ b/relay/adaptor/openai/audio.go @@ -0,0 +1,26 @@ +package openai + +import ( + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/relay/model" + "io" + "net/http" +) + +func TextToSpeechHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { + var err error + for k, v := range resp.Header { + c.Writer.Header().Set(k, v[0]) + } + c.Writer.WriteHeader(resp.StatusCode) + + _, err = io.Copy(c.Writer, resp.Body) + if err != nil { + return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + return nil, nil +} diff --git a/relay/adaptor/palm/adaptor.go b/relay/adaptor/palm/adaptor.go index 98aa3e18..0ecc6744 100644 --- a/relay/adaptor/palm/adaptor.go +++ b/relay/adaptor/palm/adaptor.go @@ -43,6 +43,13 @@ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) return request, nil } +func (a *Adaptor) ConvertTextToSpeechRequest(request *model.TextToSpeechRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { return adaptor.DoRequestHelper(a, c, meta, requestBody) } diff --git a/relay/adaptor/tencent/adaptor.go b/relay/adaptor/tencent/adaptor.go index 0de92d4a..0b95a9f7 100644 --- a/relay/adaptor/tencent/adaptor.go +++ b/relay/adaptor/tencent/adaptor.go @@ -65,6 +65,13 @@ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) return request, nil } +func (a *Adaptor) ConvertTextToSpeechRequest(request *model.TextToSpeechRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { return adaptor.DoRequestHelper(a, c, meta, requestBody) } diff --git a/relay/adaptor/xunfei/adaptor.go b/relay/adaptor/xunfei/adaptor.go index b5967f26..809dfc6b 100644 --- a/relay/adaptor/xunfei/adaptor.go +++ b/relay/adaptor/xunfei/adaptor.go @@ -46,6 +46,13 @@ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) return request, nil } +func (a *Adaptor) ConvertTextToSpeechRequest(request *model.TextToSpeechRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { // xunfei's request is not http request, so we don't need to do anything here dummyResp := &http.Response{} diff --git a/relay/adaptor/zhipu/adaptor.go b/relay/adaptor/zhipu/adaptor.go index 78b01fb3..5e4317f5 100644 --- a/relay/adaptor/zhipu/adaptor.go +++ b/relay/adaptor/zhipu/adaptor.go @@ -92,6 +92,13 @@ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) return newRequest, nil } +func (a *Adaptor) ConvertTextToSpeechRequest(request *model.TextToSpeechRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { return adaptor.DoRequestHelper(a, c, meta, requestBody) } diff --git a/relay/billing/ratio/model.go b/relay/billing/ratio/model.go index b1a8a5b4..dc056f4a 100644 --- a/relay/billing/ratio/model.go +++ b/relay/billing/ratio/model.go @@ -120,6 +120,50 @@ var ModelRatio = map[string]float64{ "ali-stable-diffusion-xl": 8, "ali-stable-diffusion-v1.5": 8, "wanx-v1": 8, + // https://help.aliyun.com/zh/dashscope/developer-reference/sambert-speech-synthesis-metered-billing?spm=a2c4g.11186623.0.0.12a52e5cvlyKYj + "sambert-zhichu-v1": 0.1 * RMB, // 1 RMB / 10K characters -> 0.1 RMB / 1K characters -> 0.1 RMB / 1K tokens + "sambert-zhiwei-v1": 0.1 * RMB, + "sambert-zhixiang-v1": 0.1 * RMB, + "sambert-zhide-v1": 0.1 * RMB, + "sambert-zhijia-v1": 0.1 * RMB, + "sambert-zhinan-v1": 0.1 * RMB, + "sambert-zhiqi-v1": 0.1 * RMB, + "sambert-zhiqian-v1": 0.1 * RMB, + "sambert-zhiru-v1": 0.1 * RMB, + "sambert-zhimiao-emo-v1": 0.1 * RMB, + "sambert-zhida-v1": 0.1 * RMB, + "sambert-zhifei-v1": 0.1 * RMB, + "sambert-zhigui-v1": 0.1 * RMB, + "sambert-zhihao-v1": 0.1 * RMB, + "sambert-zhijing-v1": 0.1 * RMB, + "sambert-zhilun-v1": 0.1 * RMB, + "sambert-zhimao-v1": 0.1 * RMB, + "sambert-zhiming-v1": 0.1 * RMB, + "sambert-zhimo-v1": 0.1 * RMB, + "sambert-zhina-v1": 0.1 * RMB, + "sambert-zhishu-v1": 0.1 * RMB, + "sambert-zhishuo-v1": 0.1 * RMB, + "sambert-zhistella-v1": 0.1 * RMB, + "sambert-zhiting-v1": 0.1 * RMB, + "sambert-zhixiao-v1": 0.1 * RMB, + "sambert-zhiya-v1": 0.1 * RMB, + "sambert-zhiye-v1": 0.1 * RMB, + "sambert-zhiying-v1": 0.1 * RMB, + "sambert-zhiyuan-v1": 0.1 * RMB, + "sambert-zhiyue-v1": 0.1 * RMB, + "sambert-camila-v1": 0.1 * RMB, + "sambert-perla-v1": 0.1 * RMB, + "sambert-indah-v1": 0.1 * RMB, + "sambert-clara-v1": 0.1 * RMB, + "sambert-hanna-v1": 0.1 * RMB, + "sambert-beth-v1": 0.1 * RMB, + "sambert-betty-v1": 0.1 * RMB, + "sambert-cally-v1": 0.1 * RMB, + "sambert-cindy-v1": 0.1 * RMB, + "sambert-eva-v1": 0.1 * RMB, + "sambert-donna-v1": 0.1 * RMB, + "sambert-brian-v1": 0.1 * RMB, + "sambert-waan-v1": 0.1 * RMB, "SparkDesk": 1.2858, // ¥0.018 / 1k tokens "SparkDesk-v1.1": 1.2858, // ¥0.018 / 1k tokens "SparkDesk-v2.1": 1.2858, // ¥0.018 / 1k tokens diff --git a/relay/controller/audio.go b/relay/controller/audio.go index 8f9708d0..81f63e8d 100644 --- a/relay/controller/audio.go +++ b/relay/controller/audio.go @@ -14,6 +14,7 @@ import ( "github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/model" + "github.com/songquanpeng/one-api/relay" "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/billing" billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" @@ -26,6 +27,140 @@ import ( "strings" ) +func RelayAudioSpeechHelper(c *gin.Context) *relaymodel.ErrorWithStatusCode { + ctx := c.Request.Context() + meta := meta.GetByContext(c) + audioModel := "tts-1" + + tokenId := c.GetInt(ctxkey.TokenId) + channelId := c.GetInt(ctxkey.ChannelId) + userId := c.GetInt(ctxkey.Id) + group := c.GetString(ctxkey.Group) + tokenName := c.GetString(ctxkey.TokenName) + + ttsRequest, err := getTextToSpeechRequest(c) + if err != nil { + logger.Errorf(ctx, "getTextToSpeechRequest failed: %s", err.Error()) + return openai.ErrorWrapper(err, "invalid_tts_request", http.StatusBadRequest) + } + + audioModel = ttsRequest.Model + // Check if text is too long 4096 + if len(ttsRequest.Input) > 4096 { + return openai.ErrorWrapper(errors.New("input is too long (over 4096 characters)"), "text_too_long", http.StatusBadRequest) + } + + adaptor := relay.GetAdaptor(meta.APIType) + if adaptor == nil { + return openai.ErrorWrapper(fmt.Errorf("invalid api type: %d", meta.APIType), "invalid_api_type", http.StatusBadRequest) + } + + modelRatio := billingratio.GetModelRatio(audioModel) + groupRatio := billingratio.GetGroupRatio(group) + ratio := modelRatio * groupRatio + + preConsumedQuota := int64(float64(len(ttsRequest.Input)) * ratio) + quota := preConsumedQuota + + userQuota, err := model.CacheGetUserQuota(ctx, userId) + if err != nil { + return openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) + } + + // Check if user quota is enough + if userQuota-preConsumedQuota < 0 { + return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) + } + err = model.CacheDecreaseUserQuota(userId, preConsumedQuota) + if err != nil { + return openai.ErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) + } + if userQuota > 100*preConsumedQuota { + // in this case, we do not pre-consume quota + // because the user has enough quota + preConsumedQuota = 0 + } + if preConsumedQuota > 0 { + err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) + if err != nil { + return openai.ErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) + } + } + succeed := false + defer func() { + if succeed { + return + } + if preConsumedQuota > 0 { + // we need to roll back the pre-consumed quota + defer func(ctx context.Context) { + go func() { + // negative means add quota back for token & user + err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota) + if err != nil { + logger.Error(ctx, fmt.Sprintf("error rollback pre-consumed quota: %s", err.Error())) + } + }() + }(c.Request.Context()) + } + }() + + // map model name + modelMapping := c.GetString(ctxkey.ModelMapping) + if modelMapping != "" { + modelMap := make(map[string]string) + err := json.Unmarshal([]byte(modelMapping), &modelMap) + if err != nil { + return openai.ErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) + } + if modelMap[audioModel] != "" { + audioModel = modelMap[audioModel] + } + } + + var requestBody io.Reader + + switch meta.ChannelType { + case channeltype.Ali: + finalRequest, err := adaptor.ConvertTextToSpeechRequest(ttsRequest) + if err != nil { + return openai.ErrorWrapper(err, "convert_image_request_failed", http.StatusInternalServerError) + } + jsonStr, err := json.Marshal(finalRequest) + if err != nil { + return openai.ErrorWrapper(err, "marshal_image_request_failed", http.StatusInternalServerError) + } + requestBody = bytes.NewBuffer(jsonStr) + default: + requestBody = c.Request.Body + } + + // do request + resp, err := adaptor.DoRequest(c, meta, requestBody) + if err != nil { + logger.Errorf(ctx, "DoRequest failed: %s", err.Error()) + return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) + } + + if resp.StatusCode != http.StatusOK { + return RelayErrorHandler(resp) + } + succeed = true + quotaDelta := quota - preConsumedQuota + defer func(ctx context.Context) { + go billing.PostConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName) + }(c.Request.Context()) + + // do response + _, respErr := adaptor.DoResponse(c, resp, meta) + if respErr != nil { + logger.Errorf(ctx, "respErr is not nil: %+v", respErr) + return respErr + } + + return nil +} + func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode { ctx := c.Request.Context() meta := meta.GetByContext(c) diff --git a/relay/controller/helper.go b/relay/controller/helper.go index dccff486..030e21a5 100644 --- a/relay/controller/helper.go +++ b/relay/controller/helper.go @@ -58,6 +58,16 @@ func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, e return imageRequest, nil } +func getTextToSpeechRequest(c *gin.Context) (*relaymodel.TextToSpeechRequest, error) { + ttsRequest := &relaymodel.TextToSpeechRequest{} + err := common.UnmarshalBodyReusable(c, ttsRequest) + if err != nil { + return nil, err + } + + return ttsRequest, nil +} + func isValidImageSize(model string, size string) bool { if model == "cogview-3" { return true diff --git a/relay/model/audio.go b/relay/model/audio.go new file mode 100644 index 00000000..7542036a --- /dev/null +++ b/relay/model/audio.go @@ -0,0 +1,9 @@ +package model + +type TextToSpeechRequest struct { + Model string `json:"model" binding:"required"` + Input string `json:"input" binding:"required"` + Voice string `json:"voice" binding:"required"` + Speed float64 `json:"speed"` + ResponseFormat string `json:"response_format"` +}