commit 1b60505289
Qiying Wang, 2024-01-20 12:39:11 +08:00 (committed by GitHub)
7 changed files with 119 additions and 59 deletions

.gitignore (vendored, 3 changes)

@@ -6,4 +6,5 @@ upload
 build
 *.db-journal
 logs
-data
+data
+node_modules

common/audio/audio.go (new file, 19 lines)

@@ -0,0 +1,19 @@
package audio

import (
	"bytes"
	"context"
	"os/exec"
	"strconv"
)

// GetAudioDuration returns the duration of an audio file in seconds.
func GetAudioDuration(ctx context.Context, filename string) (float64, error) {
	// ffprobe -v error -show_entries format=duration -of default=noprint_wrappers=1:nokey=1 {{input}}
	c := exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", filename)
	output, err := c.Output()
	if err != nil {
		return 0, err
	}
	return strconv.ParseFloat(string(bytes.TrimSpace(output)), 64)
}
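GetAudioDuration shells out to ffprobe, so the binary has to be installed and on PATH wherever one-api runs. A minimal caller sketch (the file path and timeout are illustrative, not part of the commit):

```go
package main

import (
	"context"
	"fmt"
	"log"
	"time"

	"one-api/common/audio"
)

func main() {
	// Bound the ffprobe call so a hung process cannot stall the request.
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	// "sample.mp3" is a placeholder path used only for illustration.
	seconds, err := audio.GetAudioDuration(ctx, "sample.mp3")
	if err != nil {
		log.Fatalf("ffprobe failed: %v", err)
	}
	fmt.Printf("duration: %.2f seconds\n", seconds)
}
```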

common/file.go (new file, 20 lines)

@@ -0,0 +1,20 @@
package common

import (
	"io"
	"os"
)

// SaveTmpFile saves data to a temporary file. The filename will be appended with a random string.
func SaveTmpFile(filename string, data io.Reader) (string, error) {
	f, err := os.CreateTemp(os.TempDir(), filename)
	if err != nil {
		return "", err
	}
	defer f.Close()
	_, err = io.Copy(f, data)
	if err != nil {
		return "", err
	}
	return f.Name(), nil
}
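SaveTmpFile hands ownership of the temporary file to the caller; nothing deletes it automatically, which is why the relay code below pairs it with defer os.Remove. A small usage sketch (the reader contents are made up):

```go
package main

import (
	"log"
	"os"
	"strings"

	"one-api/common"
)

func main() {
	// Any io.Reader works; a string reader stands in for an uploaded audio part.
	path, err := common.SaveTmpFile("upload.mp3", strings.NewReader("fake audio bytes"))
	if err != nil {
		log.Fatal(err)
	}
	// The caller must clean up, as countAudioTokens does with defer os.Remove(f).
	defer os.Remove(path)

	log.Println("saved to", path)
}
```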

common/http.go (new file, 22 lines)

@@ -0,0 +1,22 @@
package common

import (
	"bytes"
	"io"
	"net/http"
)

func CloneRequest(old *http.Request) *http.Request {
	req := old.Clone(old.Context())
	oldBody, err := io.ReadAll(old.Body)
	if err != nil {
		return nil
	}
	err = old.Body.Close()
	if err != nil {
		return nil
	}
	old.Body = io.NopCloser(bytes.NewBuffer(oldBody))
	req.Body = io.NopCloser(bytes.NewBuffer(oldBody))
	return req
}
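CloneRequest buffers the body so it can be read twice: once by the clone (for token counting) and again by the original request when it is forwarded upstream. Note that it returns nil if the body cannot be read or closed, so callers should check for that. A rough sketch with a stand-in request:

```go
package main

import (
	"fmt"
	"io"
	"log"
	"net/http"
	"strings"

	"one-api/common"
)

func main() {
	// Stand-in for the incoming client request handled by the relay.
	orig, err := http.NewRequest(http.MethodPost, "http://example.com/v1/audio/transcriptions",
		strings.NewReader("multipart body"))
	if err != nil {
		log.Fatal(err)
	}

	clone := common.CloneRequest(orig)
	if clone == nil {
		log.Fatal("CloneRequest returned nil: body could not be read")
	}

	// Draining the clone does not consume the original body.
	_, _ = io.ReadAll(clone.Body)
	rest, _ := io.ReadAll(orig.Body)
	fmt.Println(string(rest)) // prints "multipart body"
}
```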

common/model-ratio.go

@@ -61,7 +61,7 @@ var ModelRatio = map[string]float64{
 	"text-davinci-003":      10,
 	"text-davinci-edit-001": 10,
 	"code-davinci-edit-001": 10,
-	"whisper-1":             15,  // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
+	"whisper-1":             1,   // $0.006 / minute -> $0.002 / 20 seconds -> $0.002 / 1K tokens
 	"tts-1":                 7.5, // $0.015 / 1K characters
 	"tts-1-1106":            7.5,
 	"tts-1-hd":              15,  // $0.030 / 1K characters
@@ -157,5 +157,8 @@ func GetCompletionRatio(name string) float64 {
 	if strings.HasPrefix(name, "claude-2") {
 		return 2.965517
 	}
+	if strings.HasPrefix(name, "whisper-1") {
+		return 0 // only count input audio duration
+	}
 	return 1
 }
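A quick sanity check on the new numbers, assuming (as the comments above state) that a model ratio of 1 corresponds to $0.002 per 1K tokens and that audio is counted at 50 tokens per second; the completion ratio of 0 means only the input duration is charged:

```go
package main

import "fmt"

func main() {
	const (
		usdPer1KTokensAtRatio1 = 0.002     // assumed base rate implied by the comments above
		tokensPerSecond        = 1000 / 20 // 50
		whisperModelRatio      = 1.0       // new "whisper-1" ratio
	)

	seconds := 60.0 // one minute of audio
	tokens := seconds * tokensPerSecond
	cost := tokens / 1000 * usdPer1KTokensAtRatio1 * whisperModelRatio

	fmt.Printf("%.0f tokens -> $%.3f\n", tokens, cost) // 3000 tokens -> $0.006, matching $0.006/minute
}
```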

relay audio controller (RelayAudioHelper)

@@ -7,17 +7,45 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
-	"github.com/gin-gonic/gin"
 	"io"
+	"math"
 	"net/http"
 	"one-api/common"
+	"one-api/common/audio"
 	"one-api/model"
 	"one-api/relay/channel/openai"
 	"one-api/relay/constant"
 	"one-api/relay/util"
+	"os"
 	"strings"
+
+	"github.com/gin-gonic/gin"
 )
 
+const (
+	TokensPerSecond = 1000 / 20 // $0.006 / minute -> $0.002 / 20 seconds -> $0.002 / 1K tokens
+)
+
+func countAudioTokens(req *http.Request) (int, error) {
+	cloned := common.CloneRequest(req)
+	defer cloned.Body.Close()
+
+	file, header, err := cloned.FormFile("file")
+	if err != nil {
+		return 0, err
+	}
+	defer file.Close()
+
+	f, err := common.SaveTmpFile(header.Filename, file)
+	if err != nil {
+		return 0, err
+	}
+	defer os.Remove(f)
+
+	duration, err := audio.GetAudioDuration(cloned.Context(), f)
+	if err != nil {
+		return 0, err
+	}
+	return int(math.Ceil(duration)) * TokensPerSecond, nil
+}
+
 func RelayAudioHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode {
 	audioModel := "whisper-1"
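With TokensPerSecond = 1000/20 = 50, countAudioTokens charges 50 tokens for every started second of audio (the ffprobe duration is rounded up). For example, with a hypothetical 125.3-second upload:

```go
package main

import (
	"fmt"
	"math"
)

func main() {
	const tokensPerSecond = 1000 / 20 // 50, as in the const block above

	duration := 125.3 // hypothetical ffprobe result, in seconds
	tokens := int(math.Ceil(duration)) * tokensPerSecond

	fmt.Println(tokens) // 126 * 50 = 6300 input tokens counted for quota
}
```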
@@ -28,8 +56,15 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode
 	group := c.GetString("group")
 	tokenName := c.GetString("token_name")
 
+	var inputTokens int
 	var ttsRequest openai.TextToSpeechRequest
-	if relayMode == constant.RelayModeAudioSpeech {
+	modelRatio := common.GetModelRatio(audioModel)
+	groupRatio := common.GetGroupRatio(group)
+	ratio := modelRatio * groupRatio
+	var quota int
+	var preConsumedQuota int
+	switch relayMode {
+	case constant.RelayModeAudioSpeech:
 		// Read JSON
 		err := common.UnmarshalBodyReusable(c, &ttsRequest)
 		// Check if JSON is valid
@@ -41,20 +76,17 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode
 		if len(ttsRequest.Input) > 4096 {
 			return openai.ErrorWrapper(errors.New("input is too long (over 4096 characters)"), "text_too_long", http.StatusBadRequest)
 		}
-	}
-	modelRatio := common.GetModelRatio(audioModel)
-	groupRatio := common.GetGroupRatio(group)
-	ratio := modelRatio * groupRatio
-	var quota int
-	var preConsumedQuota int
-	switch relayMode {
-	case constant.RelayModeAudioSpeech:
-		preConsumedQuota = int(float64(len(ttsRequest.Input)) * ratio)
-		quota = preConsumedQuota
+		inputTokens = len(ttsRequest.Input)
 	default:
-		preConsumedQuota = int(float64(common.PreConsumedQuota) * ratio)
+		// whisper-1 audio transcription
+		audioTokens, err := countAudioTokens(c.Request)
+		if err != nil {
+			return openai.ErrorWrapper(err, "get_audio_duration_failed", http.StatusInternalServerError)
+		}
+		inputTokens = audioTokens
 	}
+	preConsumedQuota = int(float64(inputTokens) * ratio)
+	quota = preConsumedQuota
 	userQuota, err := model.CacheGetUserQuota(userId)
 	if err != nil {
 		return openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
@@ -112,7 +144,6 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode
 		return openai.ErrorWrapper(err, "new_request_body_failed", http.StatusInternalServerError)
 	}
 	c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody.Bytes()))
-	responseFormat := c.DefaultPostForm("response_format", "json")
 
 	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
 	if err != nil {
@@ -145,44 +176,6 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode
 		return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
 	}
-	if relayMode != constant.RelayModeAudioSpeech {
-		responseBody, err := io.ReadAll(resp.Body)
-		if err != nil {
-			return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
-		}
-		err = resp.Body.Close()
-		if err != nil {
-			return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
-		}
-		var openAIErr openai.SlimTextResponse
-		if err = json.Unmarshal(responseBody, &openAIErr); err == nil {
-			if openAIErr.Error.Message != "" {
-				return openai.ErrorWrapper(fmt.Errorf("type %s, code %v, message %s", openAIErr.Error.Type, openAIErr.Error.Code, openAIErr.Error.Message), "request_error", http.StatusInternalServerError)
-			}
-		}
-		var text string
-		switch responseFormat {
-		case "json":
-			text, err = getTextFromJSON(responseBody)
-		case "text":
-			text, err = getTextFromText(responseBody)
-		case "srt":
-			text, err = getTextFromSRT(responseBody)
-		case "verbose_json":
-			text, err = getTextFromVerboseJSON(responseBody)
-		case "vtt":
-			text, err = getTextFromVTT(responseBody)
-		default:
-			return openai.ErrorWrapper(errors.New("unexpected_response_format"), "unexpected_response_format", http.StatusInternalServerError)
-		}
-		if err != nil {
-			return openai.ErrorWrapper(err, "get_text_from_body_err", http.StatusInternalServerError)
-		}
-		quota = openai.CountTokenText(text, audioModel)
-		resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
-	}
 	if resp.StatusCode != http.StatusOK {
 		if preConsumedQuota > 0 {
 			// we need to roll back the pre-consumed quota

@@ -136,11 +136,13 @@ func GetFullRequestURL(baseURL string, requestURL string, channelType int) string {
 func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int, totalQuota int, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) {
 	// quotaDelta is remaining quota to be consumed
-	err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
-	if err != nil {
-		common.SysError("error consuming token remain quota: " + err.Error())
+	if quotaDelta != 0 {
+		err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
+		if err != nil {
+			common.SysError("error consuming token remain quota: " + err.Error())
+		}
 	}
-	err = model.CacheUpdateUserQuota(userId)
+	err := model.CacheUpdateUserQuota(userId)
 	if err != nil {
 		common.SysError("error update user quota cache: " + err.Error())
 	}
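The new quotaDelta != 0 guard matters for the audio path: the helper above sets quota = preConsumedQuota, so when the caller later settles the "remaining quota to be consumed" (per the comment), that delta is zero and the token-quota update can be skipped. A sketch of the idea, with illustrative numbers:

```go
package main

import "fmt"

func main() {
	// In the audio path the full quota is now known before the upstream call,
	// so the amount left to settle afterwards is zero.
	preConsumedQuota := 6300
	quota := 6300

	quotaDelta := quota - preConsumedQuota // remaining quota to be consumed
	if quotaDelta != 0 {
		fmt.Println("would call model.PostConsumeTokenQuota with", quotaDelta)
	} else {
		fmt.Println("nothing left to settle; skip the token quota update")
	}
}
```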