Added support for Text-to-Speech models and

endpoints
This commit is contained in:
ckt1031 2023-11-17 16:22:56 +08:00
parent bc7c9105f4
commit cfb1e2ac5b
6 changed files with 161 additions and 46 deletions

View File

@ -33,6 +33,10 @@ var ModelRatio = map[string]float64{
"text-davinci-edit-001": 10, "text-davinci-edit-001": 10,
"code-davinci-edit-001": 10, "code-davinci-edit-001": 10,
"whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens "whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
"tts-1": 7.5,
"tts-1-1106": 7.5,
"tts-1-hd": 15,
"tts-1-hd-1106": 15,
"davinci": 10, "davinci": 10,
"curie": 10, "curie": 10,
"babbage": 10, "babbage": 10,

View File

@ -72,6 +72,42 @@ func init() {
Root: "whisper-1", Root: "whisper-1",
Parent: nil, Parent: nil,
}, },
{
Id: "tts-1",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "tts-1",
Parent: nil,
},
{
Id: "tts-1-1106",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "tts-1-1106",
Parent: nil,
},
{
Id: "tts-1-hd",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "tts-1-hd",
Parent: nil,
},
{
Id: "tts-1-hd-1106",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "tts-1-hd-1106",
Parent: nil,
},
{ {
Id: "gpt-3.5-turbo", Id: "gpt-3.5-turbo",
Object: "model", Object: "model",

View File

@ -6,11 +6,13 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/gin-gonic/gin"
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/model" "one-api/model"
"path"
"github.com/gin-gonic/gin"
) )
func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
@ -22,31 +24,69 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
userId := c.GetInt("id") userId := c.GetInt("id")
group := c.GetString("group") group := c.GetString("group")
// Get last path of request URL
// Example: v1/audio/speech -> speech
requestPath := path.Base(c.Request.URL.Path) // speech
var ttsRequest TextToSpeechRequest
if requestPath == "speech" {
// Read JSON
err := common.UnmarshalBodyReusable(c, &ttsRequest)
// Check if JSON is valid
if err != nil {
return errorWrapper(err, "invalid_json", http.StatusBadRequest)
}
audioModel = ttsRequest.Model
// Check if text is too long 4096
if len(ttsRequest.Input) > 4096 {
return errorWrapper(errors.New("input is too long (over 4096 characters)"), "text_too_long", http.StatusBadRequest)
}
}
preConsumedTokens := common.PreConsumedQuota preConsumedTokens := common.PreConsumedQuota
modelRatio := common.GetModelRatio(audioModel) modelRatio := common.GetModelRatio(audioModel)
groupRatio := common.GetGroupRatio(group) groupRatio := common.GetGroupRatio(group)
ratio := modelRatio * groupRatio ratio := modelRatio * groupRatio
preConsumedQuota := int(float64(preConsumedTokens) * ratio) preConsumedQuota := int(float64(preConsumedTokens) * ratio)
userQuota, err := model.CacheGetUserQuota(userId) userQuota, err := model.CacheGetUserQuota(userId)
if err != nil { if err != nil {
return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
} }
if userQuota-preConsumedQuota < 0 {
return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) quota := 0
}
err = model.CacheDecreaseUserQuota(userId, preConsumedQuota) // Check if user quota is enough
if err != nil { if requestPath == "speech" {
return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) quota = int(float64(len(ttsRequest.Input)) * modelRatio * groupRatio)
}
if userQuota > 100*preConsumedQuota { fmt.Print(len(ttsRequest.Input), quota)
// in this case, we do not pre-consume quota
// because the user has enough quota if quota > userQuota {
preConsumedQuota = 0 return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
} }
if preConsumedQuota > 0 { } else {
err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) if userQuota-preConsumedQuota < 0 {
return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
}
err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
if err != nil { if err != nil {
return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
}
if userQuota > 100*preConsumedQuota {
// in this case, we do not pre-consume quota
// because the user has enough quota
preConsumedQuota = 0
}
if preConsumedQuota > 0 {
err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
if err != nil {
return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
}
} }
} }
@ -93,30 +133,6 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
if err != nil { if err != nil {
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
} }
var audioResponse AudioResponse
defer func(ctx context.Context) {
go func() {
quota := countTokenText(audioResponse.Text, audioModel)
quotaDelta := quota - preConsumedQuota
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}
err = model.CacheUpdateUserQuota(userId)
if err != nil {
common.SysError("error update user quota cache: " + err.Error())
}
if quota != 0 {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, audioModel, tokenName, quota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)
}
}()
}(c.Request.Context())
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
@ -127,9 +143,59 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
if err != nil { if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
} }
err = json.Unmarshal(responseBody, &audioResponse)
if err != nil { if requestPath == "speech" {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) defer func(ctx context.Context) {
go func(quota int) {
err := model.PostConsumeTokenQuota(tokenId, quota)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}
err = model.CacheUpdateUserQuota(userId)
if err != nil {
common.SysError("error update user quota cache: " + err.Error())
}
if quota != 0 {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, audioModel, tokenName, quota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)
}
}(quota)
}(c.Request.Context())
} else {
var whisperResponse WhisperResponse
defer func(ctx context.Context) {
go func() {
quota := countTokenText(whisperResponse.Text, audioModel)
quotaDelta := quota - preConsumedQuota
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}
err = model.CacheUpdateUserQuota(userId)
if err != nil {
common.SysError("error update user quota cache: " + err.Error())
}
if quota != 0 {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, audioModel, tokenName, quota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)
}
}()
}(c.Request.Context())
err = json.Unmarshal(responseBody, &whisperResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
}
} }
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))

View File

@ -83,10 +83,18 @@ type ImageRequest struct {
Size string `json:"size"` Size string `json:"size"`
} }
type AudioResponse struct { type WhisperResponse struct {
Text string `json:"text,omitempty"` Text string `json:"text,omitempty"`
} }
type TextToSpeechRequest struct {
Model string `json:"model" binding:"required"`
Input string `json:"input" binding:"required"`
Voice string `json:"voice" binding:"required"`
Speed int `json:"speed"`
ReponseFormat string `json:"response_format"`
}
type Usage struct { type Usage struct {
PromptTokens int `json:"prompt_tokens"` PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"` CompletionTokens int `json:"completion_tokens"`

View File

@ -41,7 +41,7 @@ func Distribute() func(c *gin.Context) {
// Select a channel for the user // Select a channel for the user
var modelRequest ModelRequest var modelRequest ModelRequest
var err error var err error
if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") && !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
err = common.UnmarshalBodyReusable(c, &modelRequest) err = common.UnmarshalBodyReusable(c, &modelRequest)
} }
if err != nil { if err != nil {
@ -63,7 +63,7 @@ func Distribute() func(c *gin.Context) {
modelRequest.Model = "dall-e" modelRequest.Model = "dall-e"
} }
} }
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
if modelRequest.Model == "" { if modelRequest.Model == "" {
modelRequest.Model = "whisper-1" modelRequest.Model = "whisper-1"
} }

View File

@ -29,6 +29,7 @@ func SetRelayRouter(router *gin.Engine) {
relayV1Router.POST("/engines/:model/embeddings", controller.Relay) relayV1Router.POST("/engines/:model/embeddings", controller.Relay)
relayV1Router.POST("/audio/transcriptions", controller.Relay) relayV1Router.POST("/audio/transcriptions", controller.Relay)
relayV1Router.POST("/audio/translations", controller.Relay) relayV1Router.POST("/audio/translations", controller.Relay)
relayV1Router.POST("/audio/speech", controller.Relay)
relayV1Router.GET("/files", controller.RelayNotImplemented) relayV1Router.GET("/files", controller.RelayNotImplemented)
relayV1Router.POST("/files", controller.RelayNotImplemented) relayV1Router.POST("/files", controller.RelayNotImplemented)
relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented) relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented)