fix: fix log recording & error handling for relay-audio

This commit is contained in:
JustSong 2023-11-26 12:05:16 +08:00
parent 9889377f0e
commit 0e73418cdf
2 changed files with 57 additions and 41 deletions

View File

@ -39,41 +39,40 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
} }
} }
preConsumedTokens := common.PreConsumedQuota
modelRatio := common.GetModelRatio(audioModel) modelRatio := common.GetModelRatio(audioModel)
groupRatio := common.GetGroupRatio(group) groupRatio := common.GetGroupRatio(group)
ratio := modelRatio * groupRatio ratio := modelRatio * groupRatio
preConsumedQuota := int(float64(preConsumedTokens) * ratio) var quota int
var preConsumedQuota int
switch relayMode {
case RelayModeAudioSpeech:
preConsumedQuota = int(float64(len(ttsRequest.Input)) * ratio)
quota = preConsumedQuota
default:
preConsumedQuota = int(float64(common.PreConsumedQuota) * ratio)
}
userQuota, err := model.CacheGetUserQuota(userId) userQuota, err := model.CacheGetUserQuota(userId)
if err != nil { if err != nil {
return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
} }
quota := 0
// Check if user quota is enough // Check if user quota is enough
if relayMode == RelayModeAudioSpeech { if userQuota-preConsumedQuota < 0 {
quota = int(float64(len(ttsRequest.Input)) * modelRatio * groupRatio) return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
if quota > userQuota { }
return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
} if err != nil {
} else { return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
if userQuota-preConsumedQuota < 0 { }
return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) if userQuota > 100*preConsumedQuota {
} // in this case, we do not pre-consume quota
err = model.CacheDecreaseUserQuota(userId, preConsumedQuota) // because the user has enough quota
preConsumedQuota = 0
}
if preConsumedQuota > 0 {
err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
if err != nil { if err != nil {
return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
}
if userQuota > 100*preConsumedQuota {
// in this case, we do not pre-consume quota
// because the user has enough quota
preConsumedQuota = 0
}
if preConsumedQuota > 0 {
err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
if err != nil {
return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
}
} }
} }
@ -141,11 +140,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
} }
if relayMode == RelayModeAudioSpeech { if relayMode != RelayModeAudioSpeech {
defer func(ctx context.Context) {
go postConsumeQuota(ctx, tokenId, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName)
}(c.Request.Context())
} else {
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
@ -159,13 +154,29 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
if err != nil { if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
} }
defer func(ctx context.Context) { quota = countTokenText(whisperResponse.Text, audioModel)
quota := countTokenText(whisperResponse.Text, audioModel)
quotaDelta := quota - preConsumedQuota
go postConsumeQuota(ctx, tokenId, quotaDelta, userId, channelId, modelRatio, groupRatio, audioModel, tokenName)
}(c.Request.Context())
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
} }
if resp.StatusCode != http.StatusOK {
if preConsumedQuota > 0 {
// we need to roll back the pre-consumed quota
defer func(ctx context.Context) {
go func() {
// negative means add quota back for token & user
err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota)
if err != nil {
common.LogError(ctx, fmt.Sprintf("error rollback pre-consumed quota: %s", err.Error()))
}
}()
}(c.Request.Context())
}
return relayErrorHandler(resp)
}
quotaDelta := quota - preConsumedQuota
defer func(ctx context.Context) {
go postConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName)
}(c.Request.Context())
for k, v := range resp.Header { for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0]) c.Writer.Header().Set(k, v[0])
} }

View File

@ -195,8 +195,9 @@ func getFullRequestURL(baseURL string, requestURL string, channelType int) strin
return fullRequestURL return fullRequestURL
} }
func postConsumeQuota(ctx context.Context, tokenId int, quota int, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) { func postConsumeQuota(ctx context.Context, tokenId int, quotaDelta int, totalQuota int, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) {
err := model.PostConsumeTokenQuota(tokenId, quota) // quotaDelta is remaining quota to be consumed
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
if err != nil { if err != nil {
common.SysError("error consuming token remain quota: " + err.Error()) common.SysError("error consuming token remain quota: " + err.Error())
} }
@ -204,10 +205,14 @@ func postConsumeQuota(ctx context.Context, tokenId int, quota int, userId int, c
if err != nil { if err != nil {
common.SysError("error update user quota cache: " + err.Error()) common.SysError("error update user quota cache: " + err.Error())
} }
if quota != 0 { // totalQuota is total quota consumed
if totalQuota != 0 {
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent) model.RecordConsumeLog(ctx, userId, channelId, totalQuota, 0, modelName, tokenName, totalQuota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota) model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota)
model.UpdateChannelUsedQuota(channelId, quota) model.UpdateChannelUsedQuota(channelId, totalQuota)
}
if totalQuota <= 0 {
common.LogError(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota))
} }
} }