diff --git a/.gitignore b/.gitignore index 60abb13e..1cfa1e7f 100644 --- a/.gitignore +++ b/.gitignore @@ -6,4 +6,5 @@ upload build *.db-journal logs -data \ No newline at end of file +data +node_modules diff --git a/common/audio/audio.go b/common/audio/audio.go new file mode 100644 index 00000000..58a8b308 --- /dev/null +++ b/common/audio/audio.go @@ -0,0 +1,19 @@ +package audio + +import ( + "bytes" + "context" + "os/exec" + "strconv" +) + +// GetAudioDuration returns the duration of an audio file in seconds. +func GetAudioDuration(ctx context.Context, filename string) (float64, error) { + // ffprobe -v error -show_entries format=duration -of default=noprint_wrappers=1:nokey=1 {{input}} + c := exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", filename) + output, err := c.Output() + if err != nil { + return 0, err + } + return strconv.ParseFloat(string(bytes.TrimSpace(output)), 64) +} diff --git a/common/file.go b/common/file.go new file mode 100644 index 00000000..695d8306 --- /dev/null +++ b/common/file.go @@ -0,0 +1,20 @@ +package common + +import ( + "io" + "os" +) + +// SaveTmpFile saves data to a temporary file. The filename would be apppended with a random string. +func SaveTmpFile(filename string, data io.Reader) (string, error) { + f, err := os.CreateTemp(os.TempDir(), filename) + if err != nil { + return "", err + } + defer f.Close() + _, err = io.Copy(f, data) + if err != nil { + return "", err + } + return f.Name(), nil +} diff --git a/common/http.go b/common/http.go new file mode 100644 index 00000000..c0d75c6f --- /dev/null +++ b/common/http.go @@ -0,0 +1,22 @@ +package common + +import ( + "bytes" + "io" + "net/http" +) + +func CloneRequest(old *http.Request) *http.Request { + req := old.Clone(old.Context()) + oldBody, err := io.ReadAll(old.Body) + if err != nil { + return nil + } + err = old.Body.Close() + if err != nil { + return nil + } + old.Body = io.NopCloser(bytes.NewBuffer(oldBody)) + req.Body = io.NopCloser(bytes.NewBuffer(oldBody)) + return req +} diff --git a/common/model-ratio.go b/common/model-ratio.go index 97cb060d..7b05b27a 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -61,7 +61,7 @@ var ModelRatio = map[string]float64{ "text-davinci-003": 10, "text-davinci-edit-001": 10, "code-davinci-edit-001": 10, - "whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens + "whisper-1": 1, // $0.006 / minute -> $0.002 / 20 seconds -> $0.002 / 1K tokens "tts-1": 7.5, // $0.015 / 1K characters "tts-1-1106": 7.5, "tts-1-hd": 15, // $0.030 / 1K characters @@ -157,5 +157,8 @@ func GetCompletionRatio(name string) float64 { if strings.HasPrefix(name, "claude-2") { return 2.965517 } + if strings.HasPrefix(name, "whisper-1") { + return 0 // only count input audio duration + } return 1 } diff --git a/relay/controller/audio.go b/relay/controller/audio.go index 08d9af2a..6a1ad8a2 100644 --- a/relay/controller/audio.go +++ b/relay/controller/audio.go @@ -7,17 +7,45 @@ import ( "encoding/json" "errors" "fmt" - "github.com/gin-gonic/gin" "io" + "math" "net/http" "one-api/common" + "one-api/common/audio" "one-api/model" "one-api/relay/channel/openai" "one-api/relay/constant" "one-api/relay/util" + "os" "strings" + + "github.com/gin-gonic/gin" ) +const ( + TokensPerSecond = 1000 / 20 // $0.006 / minute -> $0.002 / 20 seconds -> $0.002 / 1K tokens +) + +func countAudioTokens(req *http.Request) (int, error) { + cloned := common.CloneRequest(req) + defer cloned.Body.Close() + file, header, err := cloned.FormFile("file") + if err != nil { + return 0, err + } + defer file.Close() + f, err := common.SaveTmpFile(header.Filename, file) + if err != nil { + return 0, err + } + defer os.Remove(f) + duration, err := audio.GetAudioDuration(cloned.Context(), f) + if err != nil { + return 0, err + } + return int(math.Ceil(duration)) * TokensPerSecond, nil +} + func RelayAudioHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode { audioModel := "whisper-1" @@ -28,8 +56,15 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode group := c.GetString("group") tokenName := c.GetString("token_name") + var inputTokens int var ttsRequest openai.TextToSpeechRequest - if relayMode == constant.RelayModeAudioSpeech { + modelRatio := common.GetModelRatio(audioModel) + groupRatio := common.GetGroupRatio(group) + ratio := modelRatio * groupRatio + var quota int + var preConsumedQuota int + switch relayMode { + case constant.RelayModeAudioSpeech: // Read JSON err := common.UnmarshalBodyReusable(c, &ttsRequest) // Check if JSON is valid @@ -41,20 +76,17 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode if len(ttsRequest.Input) > 4096 { return openai.ErrorWrapper(errors.New("input is too long (over 4096 characters)"), "text_too_long", http.StatusBadRequest) } - } - - modelRatio := common.GetModelRatio(audioModel) - groupRatio := common.GetGroupRatio(group) - ratio := modelRatio * groupRatio - var quota int - var preConsumedQuota int - switch relayMode { - case constant.RelayModeAudioSpeech: - preConsumedQuota = int(float64(len(ttsRequest.Input)) * ratio) - quota = preConsumedQuota + inputTokens = len(ttsRequest.Input) default: - preConsumedQuota = int(float64(common.PreConsumedQuota) * ratio) + // whisper-1 audio transcription + audioTokens, err := countAudioTokens(c.Request) + if err != nil { + return openai.ErrorWrapper(err, "get_audio_duration_failed", http.StatusInternalServerError) + } + inputTokens = audioTokens } + preConsumedQuota = int(float64(inputTokens) * ratio) + quota = preConsumedQuota userQuota, err := model.CacheGetUserQuota(userId) if err != nil { return openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) @@ -112,7 +144,6 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode return openai.ErrorWrapper(err, "new_request_body_failed", http.StatusInternalServerError) } c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody.Bytes())) - responseFormat := c.DefaultPostForm("response_format", "json") req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) if err != nil { @@ -145,44 +176,6 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) } - if relayMode != constant.RelayModeAudioSpeech { - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) - } - err = resp.Body.Close() - if err != nil { - return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) - } - - var openAIErr openai.SlimTextResponse - if err = json.Unmarshal(responseBody, &openAIErr); err == nil { - if openAIErr.Error.Message != "" { - return openai.ErrorWrapper(fmt.Errorf("type %s, code %v, message %s", openAIErr.Error.Type, openAIErr.Error.Code, openAIErr.Error.Message), "request_error", http.StatusInternalServerError) - } - } - - var text string - switch responseFormat { - case "json": - text, err = getTextFromJSON(responseBody) - case "text": - text, err = getTextFromText(responseBody) - case "srt": - text, err = getTextFromSRT(responseBody) - case "verbose_json": - text, err = getTextFromVerboseJSON(responseBody) - case "vtt": - text, err = getTextFromVTT(responseBody) - default: - return openai.ErrorWrapper(errors.New("unexpected_response_format"), "unexpected_response_format", http.StatusInternalServerError) - } - if err != nil { - return openai.ErrorWrapper(err, "get_text_from_body_err", http.StatusInternalServerError) - } - quota = openai.CountTokenText(text, audioModel) - resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) - } if resp.StatusCode != http.StatusOK { if preConsumedQuota > 0 { // we need to roll back the pre-consumed quota diff --git a/relay/util/common.go b/relay/util/common.go index 9d13b12e..c5015b51 100644 --- a/relay/util/common.go +++ b/relay/util/common.go @@ -136,11 +136,13 @@ func GetFullRequestURL(baseURL string, requestURL string, channelType int) strin func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int, totalQuota int, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) { // quotaDelta is remaining quota to be consumed - err := model.PostConsumeTokenQuota(tokenId, quotaDelta) - if err != nil { - common.SysError("error consuming token remain quota: " + err.Error()) + if quotaDelta != 0 { + err := model.PostConsumeTokenQuota(tokenId, quotaDelta) + if err != nil { + common.SysError("error consuming token remain quota: " + err.Error()) + } } - err = model.CacheUpdateUserQuota(userId) + err := model.CacheUpdateUserQuota(userId) if err != nil { common.SysError("error update user quota cache: " + err.Error()) }