chore: update impl

This commit is contained in:
JustSong 2023-11-17 20:49:33 +08:00
parent cfb1e2ac5b
commit d816c7aae1
5 changed files with 63 additions and 93 deletions

View File

@ -32,10 +32,10 @@ var ModelRatio = map[string]float64{
"text-davinci-003": 10, "text-davinci-003": 10,
"text-davinci-edit-001": 10, "text-davinci-edit-001": 10,
"code-davinci-edit-001": 10, "code-davinci-edit-001": 10,
"whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens "whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
"tts-1": 7.5, "tts-1": 7.5, // $0.015 / 1K characters
"tts-1-1106": 7.5, "tts-1-1106": 7.5,
"tts-1-hd": 15, "tts-1-hd": 15, // $0.030 / 1K characters
"tts-1-hd-1106": 15, "tts-1-hd-1106": 15,
"davinci": 10, "davinci": 10,
"curie": 10, "curie": 10,

View File

@ -5,14 +5,11 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "github.com/gin-gonic/gin"
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/model" "one-api/model"
"path"
"github.com/gin-gonic/gin"
) )
func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
@ -23,24 +20,17 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
channelId := c.GetInt("channel_id") channelId := c.GetInt("channel_id")
userId := c.GetInt("id") userId := c.GetInt("id")
group := c.GetString("group") group := c.GetString("group")
tokenName := c.GetString("token_name")
// Get last path of request URL
// Example: v1/audio/speech -> speech
requestPath := path.Base(c.Request.URL.Path) // speech
var ttsRequest TextToSpeechRequest var ttsRequest TextToSpeechRequest
if relayMode == RelayModeAudioSpeech {
if requestPath == "speech" {
// Read JSON // Read JSON
err := common.UnmarshalBodyReusable(c, &ttsRequest) err := common.UnmarshalBodyReusable(c, &ttsRequest)
// Check if JSON is valid // Check if JSON is valid
if err != nil { if err != nil {
return errorWrapper(err, "invalid_json", http.StatusBadRequest) return errorWrapper(err, "invalid_json", http.StatusBadRequest)
} }
audioModel = ttsRequest.Model audioModel = ttsRequest.Model
// Check if text is too long 4096 // Check if text is too long 4096
if len(ttsRequest.Input) > 4096 { if len(ttsRequest.Input) > 4096 {
return errorWrapper(errors.New("input is too long (over 4096 characters)"), "text_too_long", http.StatusBadRequest) return errorWrapper(errors.New("input is too long (over 4096 characters)"), "text_too_long", http.StatusBadRequest)
@ -53,19 +43,14 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
ratio := modelRatio * groupRatio ratio := modelRatio * groupRatio
preConsumedQuota := int(float64(preConsumedTokens) * ratio) preConsumedQuota := int(float64(preConsumedTokens) * ratio)
userQuota, err := model.CacheGetUserQuota(userId) userQuota, err := model.CacheGetUserQuota(userId)
if err != nil { if err != nil {
return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
} }
quota := 0 quota := 0
// Check if user quota is enough // Check if user quota is enough
if requestPath == "speech" { if relayMode == RelayModeAudioSpeech {
quota = int(float64(len(ttsRequest.Input)) * modelRatio * groupRatio) quota = int(float64(len(ttsRequest.Input)) * modelRatio * groupRatio)
fmt.Print(len(ttsRequest.Input), quota)
if quota > userQuota { if quota > userQuota {
return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
} }
@ -134,72 +119,31 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
} }
responseBody, err := io.ReadAll(resp.Body) if relayMode == RelayModeAudioSpeech {
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
}
if requestPath == "speech" {
defer func(ctx context.Context) { defer func(ctx context.Context) {
go func(quota int) { go postConsumeQuota(ctx, tokenId, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName)
err := model.PostConsumeTokenQuota(tokenId, quota)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}
err = model.CacheUpdateUserQuota(userId)
if err != nil {
common.SysError("error update user quota cache: " + err.Error())
}
if quota != 0 {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, audioModel, tokenName, quota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)
}
}(quota)
}(c.Request.Context()) }(c.Request.Context())
} else { } else {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
}
var whisperResponse WhisperResponse var whisperResponse WhisperResponse
defer func(ctx context.Context) {
go func() {
quota := countTokenText(whisperResponse.Text, audioModel)
quotaDelta := quota - preConsumedQuota
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}
err = model.CacheUpdateUserQuota(userId)
if err != nil {
common.SysError("error update user quota cache: " + err.Error())
}
if quota != 0 {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, audioModel, tokenName, quota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)
}
}()
}(c.Request.Context())
err = json.Unmarshal(responseBody, &whisperResponse) err = json.Unmarshal(responseBody, &whisperResponse)
if err != nil { if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
} }
defer func(ctx context.Context) {
quota := countTokenText(whisperResponse.Text, audioModel)
quotaDelta := quota - preConsumedQuota
go postConsumeQuota(ctx, tokenId, quotaDelta, userId, channelId, modelRatio, groupRatio, audioModel, tokenName)
}(c.Request.Context())
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
} }
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
for k, v := range resp.Header { for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0]) c.Writer.Header().Set(k, v[0])
} }

View File

@ -1,6 +1,7 @@
package controller package controller
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@ -8,6 +9,7 @@ import (
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/model"
"strconv" "strconv"
"strings" "strings"
) )
@ -186,3 +188,20 @@ func getFullRequestURL(baseURL string, requestURL string, channelType int) strin
} }
return fullRequestURL return fullRequestURL
} }
func postConsumeQuota(ctx context.Context, tokenId int, quota int, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) {
err := model.PostConsumeTokenQuota(tokenId, quota)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}
err = model.CacheUpdateUserQuota(userId)
if err != nil {
common.SysError("error update user quota cache: " + err.Error())
}
if quota != 0 {
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
model.UpdateChannelUsedQuota(channelId, quota)
}
}

View File

@ -24,7 +24,9 @@ const (
RelayModeModerations RelayModeModerations
RelayModeImagesGenerations RelayModeImagesGenerations
RelayModeEdits RelayModeEdits
RelayModeAudio RelayModeAudioSpeech
RelayModeAudioTranscription
RelayModeAudioTranslation
) )
// https://platform.openai.com/docs/api-reference/chat // https://platform.openai.com/docs/api-reference/chat
@ -88,11 +90,11 @@ type WhisperResponse struct {
} }
type TextToSpeechRequest struct { type TextToSpeechRequest struct {
Model string `json:"model" binding:"required"` Model string `json:"model" binding:"required"`
Input string `json:"input" binding:"required"` Input string `json:"input" binding:"required"`
Voice string `json:"voice" binding:"required"` Voice string `json:"voice" binding:"required"`
Speed int `json:"speed"` Speed float64 `json:"speed"`
ReponseFormat string `json:"response_format"` ResponseFormat string `json:"response_format"`
} }
type Usage struct { type Usage struct {
@ -191,14 +193,22 @@ func Relay(c *gin.Context) {
relayMode = RelayModeImagesGenerations relayMode = RelayModeImagesGenerations
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") { } else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") {
relayMode = RelayModeEdits relayMode = RelayModeEdits
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
relayMode = RelayModeAudio relayMode = RelayModeAudioSpeech
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcription") {
relayMode = RelayModeAudioTranscription
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translation") {
relayMode = RelayModeAudioTranslation
} }
var err *OpenAIErrorWithStatusCode var err *OpenAIErrorWithStatusCode
switch relayMode { switch relayMode {
case RelayModeImagesGenerations: case RelayModeImagesGenerations:
err = relayImageHelper(c, relayMode) err = relayImageHelper(c, relayMode)
case RelayModeAudio: case RelayModeAudioSpeech:
fallthrough
case RelayModeAudioTranslation:
fallthrough
case RelayModeAudioTranscription:
err = relayAudioHelper(c, relayMode) err = relayAudioHelper(c, relayMode)
default: default:
err = relayTextHelper(c, relayMode) err = relayTextHelper(c, relayMode)

View File

@ -40,10 +40,7 @@ func Distribute() func(c *gin.Context) {
} else { } else {
// Select a channel for the user // Select a channel for the user
var modelRequest ModelRequest var modelRequest ModelRequest
var err error err := common.UnmarshalBodyReusable(c, &modelRequest)
if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") && !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
err = common.UnmarshalBodyReusable(c, &modelRequest)
}
if err != nil { if err != nil {
abortWithMessage(c, http.StatusBadRequest, "无效的请求") abortWithMessage(c, http.StatusBadRequest, "无效的请求")
return return
@ -60,7 +57,7 @@ func Distribute() func(c *gin.Context) {
} }
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
if modelRequest.Model == "" { if modelRequest.Model == "" {
modelRequest.Model = "dall-e" modelRequest.Model = "dall-e-2"
} }
} }
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") { if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {