diff --git a/common/model-ratio.go b/common/model-ratio.go
index 18a6e8bc..0e81157c 100644
--- a/common/model-ratio.go
+++ b/common/model-ratio.go
@@ -32,10 +32,10 @@ var ModelRatio = map[string]float64{
 	"text-davinci-003":      10,
 	"text-davinci-edit-001": 10,
 	"code-davinci-edit-001": 10,
-	"whisper-1":             15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
-	"tts-1":                 7.5,
+	"whisper-1":             15,  // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
+	"tts-1":                 7.5, // $0.015 / 1K characters
 	"tts-1-1106":            7.5,
-	"tts-1-hd":              15,
+	"tts-1-hd":              15,  // $0.030 / 1K characters
 	"tts-1-hd-1106":         15,
 	"davinci":               10,
 	"curie":                 10,
diff --git a/controller/relay-audio.go b/controller/relay-audio.go
index 717175a7..01267fbf 100644
--- a/controller/relay-audio.go
+++ b/controller/relay-audio.go
@@ -5,14 +5,11 @@ import (
 	"context"
 	"encoding/json"
 	"errors"
-	"fmt"
+	"github.com/gin-gonic/gin"
 	"io"
 	"net/http"
 	"one-api/common"
 	"one-api/model"
-	"path"
-
-	"github.com/gin-gonic/gin"
 )
 
 func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
@@ -23,24 +20,17 @@
 	channelId := c.GetInt("channel_id")
 	userId := c.GetInt("id")
 	group := c.GetString("group")
-
-	// Get last path of request URL
-	// Example: v1/audio/speech -> speech
-	requestPath := path.Base(c.Request.URL.Path) // speech
+	tokenName := c.GetString("token_name")
 
 	var ttsRequest TextToSpeechRequest
-
-	if requestPath == "speech" {
+	if relayMode == RelayModeAudioSpeech {
 		// Read JSON
 		err := common.UnmarshalBodyReusable(c, &ttsRequest)
-		// Check if JSON is valid
 		if err != nil {
 			return errorWrapper(err, "invalid_json", http.StatusBadRequest)
 		}
-
 		audioModel = ttsRequest.Model
-
 		// Check if text is too long 4096
 		if len(ttsRequest.Input) > 4096 {
 			return errorWrapper(errors.New("input is too long (over 4096 characters)"), "text_too_long", http.StatusBadRequest)
@@ -53,19 +43,14 @@
 	ratio := modelRatio * groupRatio
 	preConsumedQuota := int(float64(preConsumedTokens) * ratio)
 	userQuota, err := model.CacheGetUserQuota(userId)
-
 	if err != nil {
 		return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
 	}
 
 	quota := 0
-	// Check if user quota is enough
-	if requestPath == "speech" {
+	if relayMode == RelayModeAudioSpeech {
 		quota = int(float64(len(ttsRequest.Input)) * modelRatio * groupRatio)
-
-		fmt.Print(len(ttsRequest.Input), quota)
-
 		if quota > userQuota {
 			return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
 		}
@@ -134,72 +119,31 @@
 		return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
 	}
 
-	responseBody, err := io.ReadAll(resp.Body)
-
-	if err != nil {
-		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
-	}
-	err = resp.Body.Close()
-	if err != nil {
-		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
-	}
-
-	if requestPath == "speech" {
+	if relayMode == RelayModeAudioSpeech {
 		defer func(ctx context.Context) {
-			go func(quota int) {
-				err := model.PostConsumeTokenQuota(tokenId, quota)
-				if err != nil {
-					common.SysError("error consuming token remain quota: " + err.Error())
-				}
-				err = model.CacheUpdateUserQuota(userId)
-				if err != nil {
-					common.SysError("error update user quota cache: " + err.Error())
-				}
-				if quota != 0 {
-					tokenName := c.GetString("token_name")
-					logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
-					model.RecordConsumeLog(ctx, userId, channelId, 0, 0, audioModel, tokenName, quota, logContent)
-					model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
-					channelId := c.GetInt("channel_id")
-					model.UpdateChannelUsedQuota(channelId, quota)
-				}
-			}(quota)
+			go postConsumeQuota(ctx, tokenId, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName)
 		}(c.Request.Context())
 	} else {
+		responseBody, err := io.ReadAll(resp.Body)
+		if err != nil {
+			return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
+		}
+		err = resp.Body.Close()
+		if err != nil {
+			return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
+		}
 		var whisperResponse WhisperResponse
-
-		defer func(ctx context.Context) {
-			go func() {
-				quota := countTokenText(whisperResponse.Text, audioModel)
-				quotaDelta := quota - preConsumedQuota
-				err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
-				if err != nil {
-					common.SysError("error consuming token remain quota: " + err.Error())
-				}
-				err = model.CacheUpdateUserQuota(userId)
-				if err != nil {
-					common.SysError("error update user quota cache: " + err.Error())
-				}
-				if quota != 0 {
-					tokenName := c.GetString("token_name")
-					logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
-					model.RecordConsumeLog(ctx, userId, channelId, 0, 0, audioModel, tokenName, quota, logContent)
-					model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
-					channelId := c.GetInt("channel_id")
-					model.UpdateChannelUsedQuota(channelId, quota)
-				}
-			}()
-		}(c.Request.Context())
-
 		err = json.Unmarshal(responseBody, &whisperResponse)
-
 		if err != nil {
 			return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
 		}
+		defer func(ctx context.Context) {
+			quota := countTokenText(whisperResponse.Text, audioModel)
+			quotaDelta := quota - preConsumedQuota
+			go postConsumeQuota(ctx, tokenId, quotaDelta, userId, channelId, modelRatio, groupRatio, audioModel, tokenName)
+		}(c.Request.Context())
+		resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
 	}
-
-	resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
-
 	for k, v := range resp.Header {
 		c.Writer.Header().Set(k, v[0])
 	}
diff --git a/controller/relay-utils.go b/controller/relay-utils.go
index cf5d9b69..888187cb 100644
--- a/controller/relay-utils.go
+++ b/controller/relay-utils.go
@@ -1,6 +1,7 @@
 package controller
 
 import (
+	"context"
 	"encoding/json"
 	"fmt"
 	"github.com/gin-gonic/gin"
@@ -8,6 +9,7 @@
 	"io"
 	"net/http"
 	"one-api/common"
+	"one-api/model"
 	"strconv"
 	"strings"
 )
@@ -186,3 +188,20 @@ func getFullRequestURL(baseURL string, requestURL string, channelType int) strin
 	}
 	return fullRequestURL
 }
+
+func postConsumeQuota(ctx context.Context, tokenId int, quota int, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) {
+	err := model.PostConsumeTokenQuota(tokenId, quota)
+	if err != nil {
+		common.SysError("error consuming token remain quota: " + err.Error())
+	}
+	err = model.CacheUpdateUserQuota(userId)
+	if err != nil {
+		common.SysError("error update user quota cache: " + err.Error())
+	}
+	if quota != 0 {
+		logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
+		model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent)
+		model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
+		model.UpdateChannelUsedQuota(channelId, quota)
+	}
+}
diff --git a/controller/relay.go b/controller/relay.go
index f89bbd1f..0832ea7f 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -24,7 +24,9 @@ const (
 	RelayModeModerations
 	RelayModeImagesGenerations
 	RelayModeEdits
-	RelayModeAudio
+	RelayModeAudioSpeech
+	RelayModeAudioTranscription
+	RelayModeAudioTranslation
 )
 
 // https://platform.openai.com/docs/api-reference/chat
@@ -88,11 +90,11 @@ type WhisperResponse struct {
 }
 
 type TextToSpeechRequest struct {
-	Model         string `json:"model" binding:"required"`
-	Input         string `json:"input" binding:"required"`
-	Voice         string `json:"voice" binding:"required"`
-	Speed         int    `json:"speed"`
-	ReponseFormat string `json:"response_format"`
+	Model          string  `json:"model" binding:"required"`
+	Input          string  `json:"input" binding:"required"`
+	Voice          string  `json:"voice" binding:"required"`
+	Speed          float64 `json:"speed"`
+	ResponseFormat string  `json:"response_format"`
 }
 
 type Usage struct {
@@ -191,14 +193,22 @@ func Relay(c *gin.Context) {
 		relayMode = RelayModeImagesGenerations
 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") {
 		relayMode = RelayModeEdits
-	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
-		relayMode = RelayModeAudio
+	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
+		relayMode = RelayModeAudioSpeech
+	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcription") {
+		relayMode = RelayModeAudioTranscription
+	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translation") {
+		relayMode = RelayModeAudioTranslation
 	}
 	var err *OpenAIErrorWithStatusCode
 	switch relayMode {
 	case RelayModeImagesGenerations:
 		err = relayImageHelper(c, relayMode)
-	case RelayModeAudio:
+	case RelayModeAudioSpeech:
+		fallthrough
+	case RelayModeAudioTranslation:
+		fallthrough
+	case RelayModeAudioTranscription:
 		err = relayAudioHelper(c, relayMode)
 	default:
 		err = relayTextHelper(c, relayMode)
diff --git a/middleware/distributor.go b/middleware/distributor.go
index 17137eea..c4ddc3a0 100644
--- a/middleware/distributor.go
+++ b/middleware/distributor.go
@@ -40,10 +40,7 @@ func Distribute() func(c *gin.Context) {
 		} else {
 			// Select a channel for the user
 			var modelRequest ModelRequest
-			var err error
-			if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") && !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
-				err = common.UnmarshalBodyReusable(c, &modelRequest)
-			}
+			err := common.UnmarshalBodyReusable(c, &modelRequest)
 			if err != nil {
 				abortWithMessage(c, http.StatusBadRequest, "无效的请求")
 				return
 			}
@@ -60,7 +57,7 @@
 			}
 			if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
 				if modelRequest.Model == "" {
-					modelRequest.Model = "dall-e"
+					modelRequest.Model = "dall-e-2"
 				}
 			}
 			if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {