diff --git a/controller/relay-audio.go b/controller/relay-audio.go index e6f54f01..381c6feb 100644 --- a/controller/relay-audio.go +++ b/controller/relay-audio.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "io" "net/http" @@ -31,6 +32,9 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode if err != nil { return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) } + if userQuota-preConsumedQuota < 0 { + return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) + } err = model.CacheDecreaseUserQuota(userId, preConsumedQuota) if err != nil { return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) diff --git a/controller/relay-image.go b/controller/relay-image.go index fb30895c..998a7851 100644 --- a/controller/relay-image.go +++ b/controller/relay-image.go @@ -99,7 +99,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode quota := int(ratio*sizeRatio*1000) * imageRequest.N if consumeQuota && userQuota-quota < 0 { - return errorWrapper(err, "insufficient_user_quota", http.StatusForbidden) + return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) } req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) diff --git a/controller/relay-text.go b/controller/relay-text.go index 5a5f355b..3041e3a9 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -204,6 +204,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { if err != nil { return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) } + if userQuota-preConsumedQuota < 0 { + return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) + } err = model.CacheDecreaseUserQuota(userId, preConsumedQuota) if err != nil { return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)