fix: fix log recording & error handling for relay-audio
This commit is contained in:
parent
9889377f0e
commit
0e73418cdf
@ -39,41 +39,40 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
preConsumedTokens := common.PreConsumedQuota
|
|
||||||
modelRatio := common.GetModelRatio(audioModel)
|
modelRatio := common.GetModelRatio(audioModel)
|
||||||
groupRatio := common.GetGroupRatio(group)
|
groupRatio := common.GetGroupRatio(group)
|
||||||
ratio := modelRatio * groupRatio
|
ratio := modelRatio * groupRatio
|
||||||
preConsumedQuota := int(float64(preConsumedTokens) * ratio)
|
var quota int
|
||||||
|
var preConsumedQuota int
|
||||||
|
switch relayMode {
|
||||||
|
case RelayModeAudioSpeech:
|
||||||
|
preConsumedQuota = int(float64(len(ttsRequest.Input)) * ratio)
|
||||||
|
quota = preConsumedQuota
|
||||||
|
default:
|
||||||
|
preConsumedQuota = int(float64(common.PreConsumedQuota) * ratio)
|
||||||
|
}
|
||||||
userQuota, err := model.CacheGetUserQuota(userId)
|
userQuota, err := model.CacheGetUserQuota(userId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
|
return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
|
||||||
quota := 0
|
|
||||||
// Check if user quota is enough
|
// Check if user quota is enough
|
||||||
if relayMode == RelayModeAudioSpeech {
|
if userQuota-preConsumedQuota < 0 {
|
||||||
quota = int(float64(len(ttsRequest.Input)) * modelRatio * groupRatio)
|
return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
|
||||||
if quota > userQuota {
|
}
|
||||||
return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
|
err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
|
||||||
}
|
if err != nil {
|
||||||
} else {
|
return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
|
||||||
if userQuota-preConsumedQuota < 0 {
|
}
|
||||||
return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
|
if userQuota > 100*preConsumedQuota {
|
||||||
}
|
// in this case, we do not pre-consume quota
|
||||||
err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
|
// because the user has enough quota
|
||||||
|
preConsumedQuota = 0
|
||||||
|
}
|
||||||
|
if preConsumedQuota > 0 {
|
||||||
|
err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
|
return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
|
||||||
}
|
|
||||||
if userQuota > 100*preConsumedQuota {
|
|
||||||
// in this case, we do not pre-consume quota
|
|
||||||
// because the user has enough quota
|
|
||||||
preConsumedQuota = 0
|
|
||||||
}
|
|
||||||
if preConsumedQuota > 0 {
|
|
||||||
err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -141,11 +140,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
|
|||||||
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
|
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
|
||||||
if relayMode == RelayModeAudioSpeech {
|
if relayMode != RelayModeAudioSpeech {
|
||||||
defer func(ctx context.Context) {
|
|
||||||
go postConsumeQuota(ctx, tokenId, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName)
|
|
||||||
}(c.Request.Context())
|
|
||||||
} else {
|
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
||||||
@ -159,13 +154,29 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
|
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
defer func(ctx context.Context) {
|
quota = countTokenText(whisperResponse.Text, audioModel)
|
||||||
quota := countTokenText(whisperResponse.Text, audioModel)
|
|
||||||
quotaDelta := quota - preConsumedQuota
|
|
||||||
go postConsumeQuota(ctx, tokenId, quotaDelta, userId, channelId, modelRatio, groupRatio, audioModel, tokenName)
|
|
||||||
}(c.Request.Context())
|
|
||||||
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
||||||
}
|
}
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
if preConsumedQuota > 0 {
|
||||||
|
// we need to roll back the pre-consumed quota
|
||||||
|
defer func(ctx context.Context) {
|
||||||
|
go func() {
|
||||||
|
// negative means add quota back for token & user
|
||||||
|
err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota)
|
||||||
|
if err != nil {
|
||||||
|
common.LogError(ctx, fmt.Sprintf("error rollback pre-consumed quota: %s", err.Error()))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}(c.Request.Context())
|
||||||
|
}
|
||||||
|
return relayErrorHandler(resp)
|
||||||
|
}
|
||||||
|
quotaDelta := quota - preConsumedQuota
|
||||||
|
defer func(ctx context.Context) {
|
||||||
|
go postConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName)
|
||||||
|
}(c.Request.Context())
|
||||||
|
|
||||||
for k, v := range resp.Header {
|
for k, v := range resp.Header {
|
||||||
c.Writer.Header().Set(k, v[0])
|
c.Writer.Header().Set(k, v[0])
|
||||||
}
|
}
|
||||||
|
@ -195,8 +195,9 @@ func getFullRequestURL(baseURL string, requestURL string, channelType int) strin
|
|||||||
return fullRequestURL
|
return fullRequestURL
|
||||||
}
|
}
|
||||||
|
|
||||||
func postConsumeQuota(ctx context.Context, tokenId int, quota int, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) {
|
func postConsumeQuota(ctx context.Context, tokenId int, quotaDelta int, totalQuota int, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) {
|
||||||
err := model.PostConsumeTokenQuota(tokenId, quota)
|
// quotaDelta is remaining quota to be consumed
|
||||||
|
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("error consuming token remain quota: " + err.Error())
|
common.SysError("error consuming token remain quota: " + err.Error())
|
||||||
}
|
}
|
||||||
@ -204,10 +205,14 @@ func postConsumeQuota(ctx context.Context, tokenId int, quota int, userId int, c
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("error update user quota cache: " + err.Error())
|
common.SysError("error update user quota cache: " + err.Error())
|
||||||
}
|
}
|
||||||
if quota != 0 {
|
// totalQuota is total quota consumed
|
||||||
|
if totalQuota != 0 {
|
||||||
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
||||||
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent)
|
model.RecordConsumeLog(ctx, userId, channelId, totalQuota, 0, modelName, tokenName, totalQuota, logContent)
|
||||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota)
|
||||||
model.UpdateChannelUsedQuota(channelId, quota)
|
model.UpdateChannelUsedQuota(channelId, totalQuota)
|
||||||
|
}
|
||||||
|
if totalQuota <= 0 {
|
||||||
|
common.LogError(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user