refactor: remove consumeQuota related logic (#738)

* feat: 删除relay-text中的consumeQuota变量

该变量始终为true,可以删除

* chore: remove useless code

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
This commit is contained in:
igophper 2023-11-24 20:42:29 +08:00 committed by GitHub
parent 495fc628e4
commit d85e356b6e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 88 additions and 106 deletions

View File

@ -33,15 +33,12 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
channelType := c.GetInt("channel") channelType := c.GetInt("channel")
channelId := c.GetInt("channel_id") channelId := c.GetInt("channel_id")
userId := c.GetInt("id") userId := c.GetInt("id")
consumeQuota := c.GetBool("consume_quota")
group := c.GetString("group") group := c.GetString("group")
var imageRequest ImageRequest var imageRequest ImageRequest
if consumeQuota { err := common.UnmarshalBodyReusable(c, &imageRequest)
err := common.UnmarshalBodyReusable(c, &imageRequest) if err != nil {
if err != nil { return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
}
} }
// Size validation // Size validation
@ -122,7 +119,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
quota := int(ratio*imageCostRatio*1000) * imageRequest.N quota := int(ratio*imageCostRatio*1000) * imageRequest.N
if consumeQuota && userQuota-quota < 0 { if userQuota-quota < 0 {
return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
} }
@ -151,43 +148,39 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
var textResponse ImageResponse var textResponse ImageResponse
defer func(ctx context.Context) { defer func(ctx context.Context) {
if consumeQuota { err := model.PostConsumeTokenQuota(tokenId, quota)
err := model.PostConsumeTokenQuota(tokenId, quota) if err != nil {
if err != nil { common.SysError("error consuming token remain quota: " + err.Error())
common.SysError("error consuming token remain quota: " + err.Error()) }
} err = model.CacheUpdateUserQuota(userId)
err = model.CacheUpdateUserQuota(userId) if err != nil {
if err != nil { common.SysError("error update user quota cache: " + err.Error())
common.SysError("error update user quota cache: " + err.Error()) }
} if quota != 0 {
if quota != 0 { tokenName := c.GetString("token_name")
tokenName := c.GetString("token_name") logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent)
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent) model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota) channelId := c.GetInt("channel_id")
channelId := c.GetInt("channel_id") model.UpdateChannelUsedQuota(channelId, quota)
model.UpdateChannelUsedQuota(channelId, quota)
}
} }
}(c.Request.Context()) }(c.Request.Context())
if consumeQuota { responseBody, err := io.ReadAll(resp.Body)
responseBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
}
err = json.Unmarshal(responseBody, &textResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
}
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
} }
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
}
err = json.Unmarshal(responseBody, &textResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
}
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
for k, v := range resp.Header { for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0]) c.Writer.Header().Set(k, v[0])

View File

@ -88,30 +88,29 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
return nil, responseText return nil, responseText
} }
func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) { func openaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
var textResponse TextResponse var textResponse TextResponse
if consumeQuota { responseBody, err := io.ReadAll(resp.Body)
responseBody, err := io.ReadAll(resp.Body) if err != nil {
if err != nil { return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &textResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if textResponse.Error.Type != "" {
return &OpenAIErrorWithStatusCode{
OpenAIError: textResponse.Error,
StatusCode: resp.StatusCode,
}, nil
}
// Reset response body
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
} }
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &textResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if textResponse.Error.Type != "" {
return &OpenAIErrorWithStatusCode{
OpenAIError: textResponse.Error,
StatusCode: resp.StatusCode,
}, nil
}
// Reset response body
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
// We shouldn't set the header before we parse the response body, because the parse part may fail. // We shouldn't set the header before we parse the response body, because the parse part may fail.
// And then we will have to send an error response, but in this case, the header has already been set. // And then we will have to send an error response, but in this case, the header has already been set.
// So the httpClient will be confused by the response. // So the httpClient will be confused by the response.
@ -120,7 +119,7 @@ func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promp
c.Writer.Header().Set(k, v[0]) c.Writer.Header().Set(k, v[0])
} }
c.Writer.WriteHeader(resp.StatusCode) c.Writer.WriteHeader(resp.StatusCode)
_, err := io.Copy(c.Writer, resp.Body) _, err = io.Copy(c.Writer, resp.Body)
if err != nil { if err != nil {
return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
} }

View File

@ -51,14 +51,11 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
channelId := c.GetInt("channel_id") channelId := c.GetInt("channel_id")
tokenId := c.GetInt("token_id") tokenId := c.GetInt("token_id")
userId := c.GetInt("id") userId := c.GetInt("id")
consumeQuota := c.GetBool("consume_quota")
group := c.GetString("group") group := c.GetString("group")
var textRequest GeneralOpenAIRequest var textRequest GeneralOpenAIRequest
if consumeQuota || channelType == common.ChannelTypeAzure || channelType == common.ChannelTypePaLM { err := common.UnmarshalBodyReusable(c, &textRequest)
err := common.UnmarshalBodyReusable(c, &textRequest) if err != nil {
if err != nil { return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
}
} }
if relayMode == RelayModeModerations && textRequest.Model == "" { if relayMode == RelayModeModerations && textRequest.Model == "" {
textRequest.Model = "text-moderation-latest" textRequest.Model = "text-moderation-latest"
@ -235,7 +232,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
preConsumedQuota = 0 preConsumedQuota = 0
common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota)) common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota))
} }
if consumeQuota && preConsumedQuota > 0 { if preConsumedQuota > 0 {
err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
if err != nil { if err != nil {
return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
@ -414,37 +411,36 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
defer func(ctx context.Context) { defer func(ctx context.Context) {
// c.Writer.Flush() // c.Writer.Flush()
go func() { go func() {
if consumeQuota { quota := 0
quota := 0 completionRatio := common.GetCompletionRatio(textRequest.Model)
completionRatio := common.GetCompletionRatio(textRequest.Model) promptTokens = textResponse.Usage.PromptTokens
promptTokens = textResponse.Usage.PromptTokens completionTokens = textResponse.Usage.CompletionTokens
completionTokens = textResponse.Usage.CompletionTokens quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio))
quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio)) if ratio != 0 && quota <= 0 {
if ratio != 0 && quota <= 0 { quota = 1
quota = 1
}
totalTokens := promptTokens + completionTokens
if totalTokens == 0 {
// in this case, must be some error happened
// we cannot just return, because we may have to return the pre-consumed quota
quota = 0
}
quotaDelta := quota - preConsumedQuota
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
if err != nil {
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
}
err = model.CacheUpdateUserQuota(userId)
if err != nil {
common.LogError(ctx, "error update user quota cache: "+err.Error())
}
if quota != 0 {
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
model.UpdateChannelUsedQuota(channelId, quota)
}
} }
totalTokens := promptTokens + completionTokens
if totalTokens == 0 {
// in this case, must be some error happened
// we cannot just return, because we may have to return the pre-consumed quota
quota = 0
}
quotaDelta := quota - preConsumedQuota
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
if err != nil {
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
}
err = model.CacheUpdateUserQuota(userId)
if err != nil {
common.LogError(ctx, "error update user quota cache: "+err.Error())
}
if quota != 0 {
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
model.UpdateChannelUsedQuota(channelId, quota)
}
}() }()
}(c.Request.Context()) }(c.Request.Context())
switch apiType { switch apiType {
@ -458,7 +454,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
return nil return nil
} else { } else {
err, usage := openaiHandler(c, resp, consumeQuota, promptTokens, textRequest.Model) err, usage := openaiHandler(c, resp, promptTokens, textRequest.Model)
if err != nil { if err != nil {
return err return err
} }

View File

@ -106,12 +106,6 @@ func TokenAuth() func(c *gin.Context) {
c.Set("id", token.UserId) c.Set("id", token.UserId)
c.Set("token_id", token.Id) c.Set("token_id", token.Id)
c.Set("token_name", token.Name) c.Set("token_name", token.Name)
requestURL := c.Request.URL.String()
consumeQuota := true
if strings.HasPrefix(requestURL, "/v1/models") {
consumeQuota = false
}
c.Set("consume_quota", consumeQuota)
if len(parts) > 1 { if len(parts) > 1 {
if model.IsAdmin(token.UserId) { if model.IsAdmin(token.UserId) {
c.Set("channelId", parts[1]) c.Set("channelId", parts[1])