From cfb1e2ac5b28b033408c6814bc650917ff835a15 Mon Sep 17 00:00:00 2001 From: ckt1031 <65409152+ckt1031@users.noreply.github.com> Date: Fri, 17 Nov 2023 16:22:56 +0800 Subject: [PATCH] Added support for Text-to-Speech models and endpoints --- common/model-ratio.go | 4 + controller/model.go | 36 +++++++++ controller/relay-audio.go | 152 +++++++++++++++++++++++++++----------- controller/relay.go | 10 ++- middleware/distributor.go | 4 +- router/relay-router.go | 1 + 6 files changed, 161 insertions(+), 46 deletions(-) diff --git a/common/model-ratio.go b/common/model-ratio.go index 8f4be8c3..18a6e8bc 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -33,6 +33,10 @@ var ModelRatio = map[string]float64{ "text-davinci-edit-001": 10, "code-davinci-edit-001": 10, "whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens + "tts-1": 7.5, + "tts-1-1106": 7.5, + "tts-1-hd": 15, + "tts-1-hd-1106": 15, "davinci": 10, "curie": 10, "babbage": 10, diff --git a/controller/model.go b/controller/model.go index 2a7dc538..0ac0ecd3 100644 --- a/controller/model.go +++ b/controller/model.go @@ -72,6 +72,42 @@ func init() { Root: "whisper-1", Parent: nil, }, + { + Id: "tts-1", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "tts-1", + Parent: nil, + }, + { + Id: "tts-1-1106", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "tts-1-1106", + Parent: nil, + }, + { + Id: "tts-1-hd", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "tts-1-hd", + Parent: nil, + }, + { + Id: "tts-1-hd-1106", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "tts-1-hd-1106", + Parent: nil, + }, { Id: "gpt-3.5-turbo", Object: "model", diff --git a/controller/relay-audio.go b/controller/relay-audio.go index 53833108..717175a7 100644 --- a/controller/relay-audio.go +++ b/controller/relay-audio.go @@ -6,11 +6,13 @@ import ( "encoding/json" "errors" "fmt" - "github.com/gin-gonic/gin" "io" "net/http" "one-api/common" "one-api/model" + "path" + + "github.com/gin-gonic/gin" ) func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { @@ -22,31 +24,69 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode userId := c.GetInt("id") group := c.GetString("group") + // Get last path of request URL + // Example: v1/audio/speech -> speech + requestPath := path.Base(c.Request.URL.Path) // speech + + var ttsRequest TextToSpeechRequest + + if requestPath == "speech" { + // Read JSON + err := common.UnmarshalBodyReusable(c, &ttsRequest) + + // Check if JSON is valid + if err != nil { + return errorWrapper(err, "invalid_json", http.StatusBadRequest) + } + + audioModel = ttsRequest.Model + + // Check if text is too long 4096 + if len(ttsRequest.Input) > 4096 { + return errorWrapper(errors.New("input is too long (over 4096 characters)"), "text_too_long", http.StatusBadRequest) + } + } + preConsumedTokens := common.PreConsumedQuota modelRatio := common.GetModelRatio(audioModel) groupRatio := common.GetGroupRatio(group) ratio := modelRatio * groupRatio preConsumedQuota := int(float64(preConsumedTokens) * ratio) userQuota, err := model.CacheGetUserQuota(userId) + if err != nil { return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) } - if userQuota-preConsumedQuota < 0 { - return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) - } - err = model.CacheDecreaseUserQuota(userId, preConsumedQuota) - if err != nil { - return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) - } - if userQuota > 100*preConsumedQuota { - // in this case, we do not pre-consume quota - // because the user has enough quota - preConsumedQuota = 0 - } - if preConsumedQuota > 0 { - err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) + + quota := 0 + + // Check if user quota is enough + if requestPath == "speech" { + quota = int(float64(len(ttsRequest.Input)) * modelRatio * groupRatio) + + fmt.Print(len(ttsRequest.Input), quota) + + if quota > userQuota { + return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) + } + } else { + if userQuota-preConsumedQuota < 0 { + return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) + } + err = model.CacheDecreaseUserQuota(userId, preConsumedQuota) if err != nil { - return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) + return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) + } + if userQuota > 100*preConsumedQuota { + // in this case, we do not pre-consume quota + // because the user has enough quota + preConsumedQuota = 0 + } + if preConsumedQuota > 0 { + err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) + if err != nil { + return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) + } } } @@ -93,30 +133,6 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode if err != nil { return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) } - var audioResponse AudioResponse - - defer func(ctx context.Context) { - go func() { - quota := countTokenText(audioResponse.Text, audioModel) - quotaDelta := quota - preConsumedQuota - err := model.PostConsumeTokenQuota(tokenId, quotaDelta) - if err != nil { - common.SysError("error consuming token remain quota: " + err.Error()) - } - err = model.CacheUpdateUserQuota(userId) - if err != nil { - common.SysError("error update user quota cache: " + err.Error()) - } - if quota != 0 { - tokenName := c.GetString("token_name") - logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) - model.RecordConsumeLog(ctx, userId, channelId, 0, 0, audioModel, tokenName, quota, logContent) - model.UpdateUserUsedQuotaAndRequestCount(userId, quota) - channelId := c.GetInt("channel_id") - model.UpdateChannelUsedQuota(channelId, quota) - } - }() - }(c.Request.Context()) responseBody, err := io.ReadAll(resp.Body) @@ -127,9 +143,59 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode if err != nil { return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) } - err = json.Unmarshal(responseBody, &audioResponse) - if err != nil { - return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) + + if requestPath == "speech" { + defer func(ctx context.Context) { + go func(quota int) { + err := model.PostConsumeTokenQuota(tokenId, quota) + if err != nil { + common.SysError("error consuming token remain quota: " + err.Error()) + } + err = model.CacheUpdateUserQuota(userId) + if err != nil { + common.SysError("error update user quota cache: " + err.Error()) + } + if quota != 0 { + tokenName := c.GetString("token_name") + logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) + model.RecordConsumeLog(ctx, userId, channelId, 0, 0, audioModel, tokenName, quota, logContent) + model.UpdateUserUsedQuotaAndRequestCount(userId, quota) + channelId := c.GetInt("channel_id") + model.UpdateChannelUsedQuota(channelId, quota) + } + }(quota) + }(c.Request.Context()) + } else { + var whisperResponse WhisperResponse + + defer func(ctx context.Context) { + go func() { + quota := countTokenText(whisperResponse.Text, audioModel) + quotaDelta := quota - preConsumedQuota + err := model.PostConsumeTokenQuota(tokenId, quotaDelta) + if err != nil { + common.SysError("error consuming token remain quota: " + err.Error()) + } + err = model.CacheUpdateUserQuota(userId) + if err != nil { + common.SysError("error update user quota cache: " + err.Error()) + } + if quota != 0 { + tokenName := c.GetString("token_name") + logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) + model.RecordConsumeLog(ctx, userId, channelId, 0, 0, audioModel, tokenName, quota, logContent) + model.UpdateUserUsedQuotaAndRequestCount(userId, quota) + channelId := c.GetInt("channel_id") + model.UpdateChannelUsedQuota(channelId, quota) + } + }() + }(c.Request.Context()) + + err = json.Unmarshal(responseBody, &whisperResponse) + + if err != nil { + return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) + } } resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) diff --git a/controller/relay.go b/controller/relay.go index 1926110e..f89bbd1f 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -83,10 +83,18 @@ type ImageRequest struct { Size string `json:"size"` } -type AudioResponse struct { +type WhisperResponse struct { Text string `json:"text,omitempty"` } +type TextToSpeechRequest struct { + Model string `json:"model" binding:"required"` + Input string `json:"input" binding:"required"` + Voice string `json:"voice" binding:"required"` + Speed int `json:"speed"` + ReponseFormat string `json:"response_format"` +} + type Usage struct { PromptTokens int `json:"prompt_tokens"` CompletionTokens int `json:"completion_tokens"` diff --git a/middleware/distributor.go b/middleware/distributor.go index d80945fc..17137eea 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -41,7 +41,7 @@ func Distribute() func(c *gin.Context) { // Select a channel for the user var modelRequest ModelRequest var err error - if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { + if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") && !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") { err = common.UnmarshalBodyReusable(c, &modelRequest) } if err != nil { @@ -63,7 +63,7 @@ func Distribute() func(c *gin.Context) { modelRequest.Model = "dall-e" } } - if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { + if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") { if modelRequest.Model == "" { modelRequest.Model = "whisper-1" } diff --git a/router/relay-router.go b/router/relay-router.go index e84f02db..912f4989 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -29,6 +29,7 @@ func SetRelayRouter(router *gin.Engine) { relayV1Router.POST("/engines/:model/embeddings", controller.Relay) relayV1Router.POST("/audio/transcriptions", controller.Relay) relayV1Router.POST("/audio/translations", controller.Relay) + relayV1Router.POST("/audio/speech", controller.Relay) relayV1Router.GET("/files", controller.RelayNotImplemented) relayV1Router.POST("/files", controller.RelayNotImplemented) relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented)