diff --git a/controller/relay-audio.go b/controller/relay-audio.go index 89a311a0..5b8898a7 100644 --- a/controller/relay-audio.go +++ b/controller/relay-audio.go @@ -39,41 +39,40 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode } } - preConsumedTokens := common.PreConsumedQuota modelRatio := common.GetModelRatio(audioModel) groupRatio := common.GetGroupRatio(group) ratio := modelRatio * groupRatio - preConsumedQuota := int(float64(preConsumedTokens) * ratio) + var quota int + var preConsumedQuota int + switch relayMode { + case RelayModeAudioSpeech: + preConsumedQuota = int(float64(len(ttsRequest.Input)) * ratio) + quota = preConsumedQuota + default: + preConsumedQuota = int(float64(common.PreConsumedQuota) * ratio) + } userQuota, err := model.CacheGetUserQuota(userId) if err != nil { return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) } - quota := 0 // Check if user quota is enough - if relayMode == RelayModeAudioSpeech { - quota = int(float64(len(ttsRequest.Input)) * modelRatio * groupRatio) - if quota > userQuota { - return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) - } - } else { - if userQuota-preConsumedQuota < 0 { - return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) - } - err = model.CacheDecreaseUserQuota(userId, preConsumedQuota) + if userQuota-preConsumedQuota < 0 { + return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) + } + err = model.CacheDecreaseUserQuota(userId, preConsumedQuota) + if err != nil { + return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) + } + if userQuota > 100*preConsumedQuota { + // in this case, we do not pre-consume quota + // because the user has enough quota + preConsumedQuota = 0 + } + if preConsumedQuota > 0 { + err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) if err != nil { - return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) - } - if userQuota > 100*preConsumedQuota { - // in this case, we do not pre-consume quota - // because the user has enough quota - preConsumedQuota = 0 - } - if preConsumedQuota > 0 { - err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) - if err != nil { - return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) - } + return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) } } @@ -141,11 +140,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) } - if relayMode == RelayModeAudioSpeech { - defer func(ctx context.Context) { - go postConsumeQuota(ctx, tokenId, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName) - }(c.Request.Context()) - } else { + if relayMode != RelayModeAudioSpeech { responseBody, err := io.ReadAll(resp.Body) if err != nil { return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) @@ -159,13 +154,29 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode if err != nil { return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) } - defer func(ctx context.Context) { - quota := countTokenText(whisperResponse.Text, audioModel) - quotaDelta := quota - preConsumedQuota - go postConsumeQuota(ctx, tokenId, quotaDelta, userId, channelId, modelRatio, groupRatio, audioModel, tokenName) - }(c.Request.Context()) + quota = countTokenText(whisperResponse.Text, audioModel) resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) } + if resp.StatusCode != http.StatusOK { + if preConsumedQuota > 0 { + // we need to roll back the pre-consumed quota + defer func(ctx context.Context) { + go func() { + // negative means add quota back for token & user + err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota) + if err != nil { + common.LogError(ctx, fmt.Sprintf("error rollback pre-consumed quota: %s", err.Error())) + } + }() + }(c.Request.Context()) + } + return relayErrorHandler(resp) + } + quotaDelta := quota - preConsumedQuota + defer func(ctx context.Context) { + go postConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName) + }(c.Request.Context()) + for k, v := range resp.Header { c.Writer.Header().Set(k, v[0]) } diff --git a/controller/relay-utils.go b/controller/relay-utils.go index c7cd4766..391f28b4 100644 --- a/controller/relay-utils.go +++ b/controller/relay-utils.go @@ -195,8 +195,9 @@ func getFullRequestURL(baseURL string, requestURL string, channelType int) strin return fullRequestURL } -func postConsumeQuota(ctx context.Context, tokenId int, quota int, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) { - err := model.PostConsumeTokenQuota(tokenId, quota) +func postConsumeQuota(ctx context.Context, tokenId int, quotaDelta int, totalQuota int, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) { + // quotaDelta is remaining quota to be consumed + err := model.PostConsumeTokenQuota(tokenId, quotaDelta) if err != nil { common.SysError("error consuming token remain quota: " + err.Error()) } @@ -204,10 +205,14 @@ func postConsumeQuota(ctx context.Context, tokenId int, quota int, userId int, c if err != nil { common.SysError("error update user quota cache: " + err.Error()) } - if quota != 0 { + // totalQuota is total quota consumed + if totalQuota != 0 { logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) - model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent) - model.UpdateUserUsedQuotaAndRequestCount(userId, quota) - model.UpdateChannelUsedQuota(channelId, quota) + model.RecordConsumeLog(ctx, userId, channelId, totalQuota, 0, modelName, tokenName, totalQuota, logContent) + model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota) + model.UpdateChannelUsedQuota(channelId, totalQuota) + } + if totalQuota <= 0 { + common.LogError(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota)) } }