From 40d7b692529b756a39465644ea175208ac482b3f Mon Sep 17 00:00:00 2001 From: JustSong Date: Fri, 24 Nov 2023 20:40:38 +0800 Subject: [PATCH] chore: remove useless code --- controller/relay-image.go | 71 ++++++++++++++++++--------------------- middleware/auth.go | 6 ---- 2 files changed, 32 insertions(+), 45 deletions(-) diff --git a/controller/relay-image.go b/controller/relay-image.go index 1d1b71ba..0ff18309 100644 --- a/controller/relay-image.go +++ b/controller/relay-image.go @@ -33,15 +33,12 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode channelType := c.GetInt("channel") channelId := c.GetInt("channel_id") userId := c.GetInt("id") - consumeQuota := c.GetBool("consume_quota") group := c.GetString("group") var imageRequest ImageRequest - if consumeQuota { - err := common.UnmarshalBodyReusable(c, &imageRequest) - if err != nil { - return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) - } + err := common.UnmarshalBodyReusable(c, &imageRequest) + if err != nil { + return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) } // Size validation @@ -122,7 +119,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode quota := int(ratio*imageCostRatio*1000) * imageRequest.N - if consumeQuota && userQuota-quota < 0 { + if userQuota-quota < 0 { return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) } @@ -151,43 +148,39 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode var textResponse ImageResponse defer func(ctx context.Context) { - if consumeQuota { - err := model.PostConsumeTokenQuota(tokenId, quota) - if err != nil { - common.SysError("error consuming token remain quota: " + err.Error()) - } - err = model.CacheUpdateUserQuota(userId) - if err != nil { - common.SysError("error update user quota cache: " + err.Error()) - } - if quota != 0 { - tokenName := c.GetString("token_name") - logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) - model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent) - model.UpdateUserUsedQuotaAndRequestCount(userId, quota) - channelId := c.GetInt("channel_id") - model.UpdateChannelUsedQuota(channelId, quota) - } + err := model.PostConsumeTokenQuota(tokenId, quota) + if err != nil { + common.SysError("error consuming token remain quota: " + err.Error()) + } + err = model.CacheUpdateUserQuota(userId) + if err != nil { + common.SysError("error update user quota cache: " + err.Error()) + } + if quota != 0 { + tokenName := c.GetString("token_name") + logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) + model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent) + model.UpdateUserUsedQuotaAndRequestCount(userId, quota) + channelId := c.GetInt("channel_id") + model.UpdateChannelUsedQuota(channelId, quota) } }(c.Request.Context()) - if consumeQuota { - responseBody, err := io.ReadAll(resp.Body) + responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) - } - err = resp.Body.Close() - if err != nil { - return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) - } - err = json.Unmarshal(responseBody, &textResponse) - if err != nil { - return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) - } - - resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) + if err != nil { + return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) } + err = resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) + } + err = json.Unmarshal(responseBody, &textResponse) + if err != nil { + return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) + } + + resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) for k, v := range resp.Header { c.Writer.Header().Set(k, v[0]) diff --git a/middleware/auth.go b/middleware/auth.go index b0803612..ad7e64b7 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -106,12 +106,6 @@ func TokenAuth() func(c *gin.Context) { c.Set("id", token.UserId) c.Set("token_id", token.Id) c.Set("token_name", token.Name) - requestURL := c.Request.URL.String() - consumeQuota := true - if strings.HasPrefix(requestURL, "/v1/models") { - consumeQuota = false - } - c.Set("consume_quota", consumeQuota) if len(parts) > 1 { if model.IsAdmin(token.UserId) { c.Set("channelId", parts[1])