feat(audio): count whisper-1 quota by audio duration
This commit is contained in:
parent
eed9f5fdf0
commit
89799d84c5
1
.gitignore
vendored
1
.gitignore
vendored
@ -7,3 +7,4 @@ build
|
|||||||
*.db-journal
|
*.db-journal
|
||||||
logs
|
logs
|
||||||
data
|
data
|
||||||
|
node_modules
|
||||||
|
19
common/audio/audio.go
Normal file
19
common/audio/audio.go
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
package audio
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"os/exec"
|
||||||
|
"strconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetAudioDuration reports the length, in seconds, of the audio file at
// filename by running the external ffprobe program. The ffprobe process is
// bound to ctx. An error is returned when ffprobe cannot be run or exits
// unsuccessfully, or when its trimmed output does not parse as a float.
func GetAudioDuration(ctx context.Context, filename string) (float64, error) {
	// Equivalent shell command:
	//   ffprobe -v error -show_entries format=duration -of default=noprint_wrappers=1:nokey=1 {{input}}
	args := []string{
		"-v", "error",
		"-show_entries", "format=duration",
		"-of", "default=noprint_wrappers=1:nokey=1",
		filename,
	}
	raw, err := exec.CommandContext(ctx, "ffprobe", args...).Output()
	if err != nil {
		return 0, err
	}
	// ffprobe terminates the value with a newline; strip it before parsing.
	seconds := string(bytes.TrimSpace(raw))
	return strconv.ParseFloat(seconds, 64)
}
|
20
common/file.go
Normal file
20
common/file.go
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SaveTmpFile copies data into a newly created file in the system temporary
// directory and returns that file's path.
//
// filename is passed to os.CreateTemp as the name pattern: a random string is
// appended to it (or substituted for its last "*", if it contains one), and a
// pattern containing a path separator is rejected with an error.
//
// On success the caller owns the returned file and is responsible for
// removing it. If copying or closing fails, the partially written file is
// removed and an empty path is returned together with the error.
func SaveTmpFile(filename string, data io.Reader) (string, error) {
	f, err := os.CreateTemp(os.TempDir(), filename)
	if err != nil {
		return "", err
	}
	name := f.Name()

	_, err = io.Copy(f, data)
	// Always close the file; a close failure after a successful copy can mean
	// the data never reached disk, so it is reported as well.
	if closeErr := f.Close(); err == nil {
		err = closeErr
	}
	if err != nil {
		// Best-effort cleanup so a failed save does not leave a stray file;
		// the copy/close error is the one worth reporting.
		_ = os.Remove(name)
		return "", err
	}
	return name, nil
}
|
22
common/http.go
Normal file
22
common/http.go
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CloneRequest returns a copy of old whose body can be read independently of
// old's body.
//
// The body of old is read fully into memory and closed; old.Body is then
// replaced with an in-memory reader over those bytes, and the clone receives
// its own reader over the same bytes. If old has no body (old.Body is nil),
// the clone is returned without touching any body. CloneRequest returns nil
// when reading or closing the original body fails, so callers must check the
// result before using it.
func CloneRequest(old *http.Request) *http.Request {
	req := old.Clone(old.Context())
	if old.Body == nil {
		// Nothing to duplicate; io.ReadAll on a nil reader would panic.
		return req
	}
	oldBody, err := io.ReadAll(old.Body)
	if err != nil {
		return nil
	}
	if err = old.Body.Close(); err != nil {
		return nil
	}
	// Separate readers keep independent read positions over the shared,
	// read-only byte slice.
	old.Body = io.NopCloser(bytes.NewReader(oldBody))
	req.Body = io.NopCloser(bytes.NewReader(oldBody))
	return req
}
|
@ -61,7 +61,7 @@ var ModelRatio = map[string]float64{
|
|||||||
"text-davinci-003": 10,
|
"text-davinci-003": 10,
|
||||||
"text-davinci-edit-001": 10,
|
"text-davinci-edit-001": 10,
|
||||||
"code-davinci-edit-001": 10,
|
"code-davinci-edit-001": 10,
|
||||||
"whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
|
"whisper-1": 1, // $0.006 / minute -> $0.002 / 20 seconds -> $0.002 / 1K tokens
|
||||||
"tts-1": 7.5, // $0.015 / 1K characters
|
"tts-1": 7.5, // $0.015 / 1K characters
|
||||||
"tts-1-1106": 7.5,
|
"tts-1-1106": 7.5,
|
||||||
"tts-1-hd": 15, // $0.030 / 1K characters
|
"tts-1-hd": 15, // $0.030 / 1K characters
|
||||||
@ -157,5 +157,8 @@ func GetCompletionRatio(name string) float64 {
|
|||||||
if strings.HasPrefix(name, "claude-2") {
|
if strings.HasPrefix(name, "claude-2") {
|
||||||
return 2.965517
|
return 2.965517
|
||||||
}
|
}
|
||||||
|
if strings.HasPrefix(name, "whisper-1") {
|
||||||
|
return 0 // only count input audio duration
|
||||||
|
}
|
||||||
return 1
|
return 1
|
||||||
}
|
}
|
||||||
|
@ -7,17 +7,45 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"io"
|
"io"
|
||||||
|
"math"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/common/audio"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/relay/channel/openai"
|
"one-api/relay/channel/openai"
|
||||||
"one-api/relay/constant"
|
"one-api/relay/constant"
|
||||||
"one-api/relay/util"
|
"one-api/relay/util"
|
||||||
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
TokensPerSecond = 1000 / 20 // $0.006 / minute -> $0.002 / 20 seconds -> $0.002 / 1K tokens
|
||||||
|
)
|
||||||
|
|
||||||
|
func countAudioTokens(req *http.Request) (int, error) {
|
||||||
|
cloned := common.CloneRequest(req)
|
||||||
|
defer cloned.Body.Close()
|
||||||
|
file, header, err := cloned.FormFile("file")
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
f, err := common.SaveTmpFile(header.Filename, file)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
defer os.Remove(f)
|
||||||
|
duration, err := audio.GetAudioDuration(cloned.Context(), f)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return int(math.Ceil(duration)) * TokensPerSecond, nil
|
||||||
|
}
|
||||||
|
|
||||||
func RelayAudioHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode {
|
func RelayAudioHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode {
|
||||||
audioModel := "whisper-1"
|
audioModel := "whisper-1"
|
||||||
|
|
||||||
@ -28,8 +56,15 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode
|
|||||||
group := c.GetString("group")
|
group := c.GetString("group")
|
||||||
tokenName := c.GetString("token_name")
|
tokenName := c.GetString("token_name")
|
||||||
|
|
||||||
|
var inputTokens int
|
||||||
var ttsRequest openai.TextToSpeechRequest
|
var ttsRequest openai.TextToSpeechRequest
|
||||||
if relayMode == constant.RelayModeAudioSpeech {
|
modelRatio := common.GetModelRatio(audioModel)
|
||||||
|
groupRatio := common.GetGroupRatio(group)
|
||||||
|
ratio := modelRatio * groupRatio
|
||||||
|
var quota int
|
||||||
|
var preConsumedQuota int
|
||||||
|
switch relayMode {
|
||||||
|
case constant.RelayModeAudioSpeech:
|
||||||
// Read JSON
|
// Read JSON
|
||||||
err := common.UnmarshalBodyReusable(c, &ttsRequest)
|
err := common.UnmarshalBodyReusable(c, &ttsRequest)
|
||||||
// Check if JSON is valid
|
// Check if JSON is valid
|
||||||
@ -41,20 +76,17 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode
|
|||||||
if len(ttsRequest.Input) > 4096 {
|
if len(ttsRequest.Input) > 4096 {
|
||||||
return openai.ErrorWrapper(errors.New("input is too long (over 4096 characters)"), "text_too_long", http.StatusBadRequest)
|
return openai.ErrorWrapper(errors.New("input is too long (over 4096 characters)"), "text_too_long", http.StatusBadRequest)
|
||||||
}
|
}
|
||||||
}
|
inputTokens = len(ttsRequest.Input)
|
||||||
|
|
||||||
modelRatio := common.GetModelRatio(audioModel)
|
|
||||||
groupRatio := common.GetGroupRatio(group)
|
|
||||||
ratio := modelRatio * groupRatio
|
|
||||||
var quota int
|
|
||||||
var preConsumedQuota int
|
|
||||||
switch relayMode {
|
|
||||||
case constant.RelayModeAudioSpeech:
|
|
||||||
preConsumedQuota = int(float64(len(ttsRequest.Input)) * ratio)
|
|
||||||
quota = preConsumedQuota
|
|
||||||
default:
|
default:
|
||||||
preConsumedQuota = int(float64(common.PreConsumedQuota) * ratio)
|
// whisper-1 audio transcription
|
||||||
|
audioTokens, err := countAudioTokens(c.Request)
|
||||||
|
if err != nil {
|
||||||
|
return openai.ErrorWrapper(err, "get_audio_duration_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
inputTokens = audioTokens
|
||||||
}
|
}
|
||||||
|
preConsumedQuota = int(float64(inputTokens) * ratio)
|
||||||
|
quota = preConsumedQuota
|
||||||
userQuota, err := model.CacheGetUserQuota(userId)
|
userQuota, err := model.CacheGetUserQuota(userId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
|
return openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
|
||||||
@ -112,7 +144,6 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode
|
|||||||
return openai.ErrorWrapper(err, "new_request_body_failed", http.StatusInternalServerError)
|
return openai.ErrorWrapper(err, "new_request_body_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody.Bytes()))
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody.Bytes()))
|
||||||
responseFormat := c.DefaultPostForm("response_format", "json")
|
|
||||||
|
|
||||||
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -145,44 +176,6 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode
|
|||||||
return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
|
return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
|
||||||
if relayMode != constant.RelayModeAudioSpeech {
|
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
err = resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
|
|
||||||
var openAIErr openai.SlimTextResponse
|
|
||||||
if err = json.Unmarshal(responseBody, &openAIErr); err == nil {
|
|
||||||
if openAIErr.Error.Message != "" {
|
|
||||||
return openai.ErrorWrapper(fmt.Errorf("type %s, code %v, message %s", openAIErr.Error.Type, openAIErr.Error.Code, openAIErr.Error.Message), "request_error", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var text string
|
|
||||||
switch responseFormat {
|
|
||||||
case "json":
|
|
||||||
text, err = getTextFromJSON(responseBody)
|
|
||||||
case "text":
|
|
||||||
text, err = getTextFromText(responseBody)
|
|
||||||
case "srt":
|
|
||||||
text, err = getTextFromSRT(responseBody)
|
|
||||||
case "verbose_json":
|
|
||||||
text, err = getTextFromVerboseJSON(responseBody)
|
|
||||||
case "vtt":
|
|
||||||
text, err = getTextFromVTT(responseBody)
|
|
||||||
default:
|
|
||||||
return openai.ErrorWrapper(errors.New("unexpected_response_format"), "unexpected_response_format", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return openai.ErrorWrapper(err, "get_text_from_body_err", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
quota = openai.CountTokenText(text, audioModel)
|
|
||||||
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
|
||||||
}
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
if preConsumedQuota > 0 {
|
if preConsumedQuota > 0 {
|
||||||
// we need to roll back the pre-consumed quota
|
// we need to roll back the pre-consumed quota
|
||||||
|
@ -136,11 +136,13 @@ func GetFullRequestURL(baseURL string, requestURL string, channelType int) strin
|
|||||||
|
|
||||||
func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int, totalQuota int, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) {
|
func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int, totalQuota int, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) {
|
||||||
// quotaDelta is remaining quota to be consumed
|
// quotaDelta is remaining quota to be consumed
|
||||||
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
|
if quotaDelta != 0 {
|
||||||
if err != nil {
|
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
|
||||||
common.SysError("error consuming token remain quota: " + err.Error())
|
if err != nil {
|
||||||
|
common.SysError("error consuming token remain quota: " + err.Error())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
err = model.CacheUpdateUserQuota(userId)
|
err := model.CacheUpdateUserQuota(userId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("error update user quota cache: " + err.Error())
|
common.SysError("error update user quota cache: " + err.Error())
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user