refactor: remove consumeQuota related logic (#738)
* feat: 删除relay-text中的consumeQuota变量

  该变量始终为true,可以删除

* chore: remove useless code

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
This commit is contained in:
parent
495fc628e4
commit
d85e356b6e
@ -33,15 +33,12 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
|
|||||||
channelType := c.GetInt("channel")
|
channelType := c.GetInt("channel")
|
||||||
channelId := c.GetInt("channel_id")
|
channelId := c.GetInt("channel_id")
|
||||||
userId := c.GetInt("id")
|
userId := c.GetInt("id")
|
||||||
consumeQuota := c.GetBool("consume_quota")
|
|
||||||
group := c.GetString("group")
|
group := c.GetString("group")
|
||||||
|
|
||||||
var imageRequest ImageRequest
|
var imageRequest ImageRequest
|
||||||
if consumeQuota {
|
err := common.UnmarshalBodyReusable(c, &imageRequest)
|
||||||
err := common.UnmarshalBodyReusable(c, &imageRequest)
|
if err != nil {
|
||||||
if err != nil {
|
return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
|
||||||
return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Size validation
|
// Size validation
|
||||||
@ -122,7 +119,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
|
|||||||
|
|
||||||
quota := int(ratio*imageCostRatio*1000) * imageRequest.N
|
quota := int(ratio*imageCostRatio*1000) * imageRequest.N
|
||||||
|
|
||||||
if consumeQuota && userQuota-quota < 0 {
|
if userQuota-quota < 0 {
|
||||||
return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
|
return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -151,43 +148,39 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
|
|||||||
var textResponse ImageResponse
|
var textResponse ImageResponse
|
||||||
|
|
||||||
defer func(ctx context.Context) {
|
defer func(ctx context.Context) {
|
||||||
if consumeQuota {
|
err := model.PostConsumeTokenQuota(tokenId, quota)
|
||||||
err := model.PostConsumeTokenQuota(tokenId, quota)
|
if err != nil {
|
||||||
if err != nil {
|
common.SysError("error consuming token remain quota: " + err.Error())
|
||||||
common.SysError("error consuming token remain quota: " + err.Error())
|
}
|
||||||
}
|
err = model.CacheUpdateUserQuota(userId)
|
||||||
err = model.CacheUpdateUserQuota(userId)
|
if err != nil {
|
||||||
if err != nil {
|
common.SysError("error update user quota cache: " + err.Error())
|
||||||
common.SysError("error update user quota cache: " + err.Error())
|
}
|
||||||
}
|
if quota != 0 {
|
||||||
if quota != 0 {
|
tokenName := c.GetString("token_name")
|
||||||
tokenName := c.GetString("token_name")
|
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
||||||
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent)
|
||||||
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent)
|
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
channelId := c.GetInt("channel_id")
|
||||||
channelId := c.GetInt("channel_id")
|
model.UpdateChannelUsedQuota(channelId, quota)
|
||||||
model.UpdateChannelUsedQuota(channelId, quota)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}(c.Request.Context())
|
}(c.Request.Context())
|
||||||
|
|
||||||
if consumeQuota {
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
||||||
}
|
|
||||||
err = resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
err = json.Unmarshal(responseBody, &textResponse)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
|
|
||||||
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
|
||||||
}
|
}
|
||||||
|
err = resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
err = json.Unmarshal(responseBody, &textResponse)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
||||||
|
|
||||||
for k, v := range resp.Header {
|
for k, v := range resp.Header {
|
||||||
c.Writer.Header().Set(k, v[0])
|
c.Writer.Header().Set(k, v[0])
|
||||||
|
@ -88,30 +88,29 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
|
|||||||
return nil, responseText
|
return nil, responseText
|
||||||
}
|
}
|
||||||
|
|
||||||
func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
|
func openaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||||
var textResponse TextResponse
|
var textResponse TextResponse
|
||||||
if consumeQuota {
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
if err != nil {
|
||||||
if err != nil {
|
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||||
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
err = resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
err = json.Unmarshal(responseBody, &textResponse)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
if textResponse.Error.Type != "" {
|
|
||||||
return &OpenAIErrorWithStatusCode{
|
|
||||||
OpenAIError: textResponse.Error,
|
|
||||||
StatusCode: resp.StatusCode,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
// Reset response body
|
|
||||||
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
|
||||||
}
|
}
|
||||||
|
err = resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
err = json.Unmarshal(responseBody, &textResponse)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
if textResponse.Error.Type != "" {
|
||||||
|
return &OpenAIErrorWithStatusCode{
|
||||||
|
OpenAIError: textResponse.Error,
|
||||||
|
StatusCode: resp.StatusCode,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
// Reset response body
|
||||||
|
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
||||||
|
|
||||||
// We shouldn't set the header before we parse the response body, because the parse part may fail.
|
// We shouldn't set the header before we parse the response body, because the parse part may fail.
|
||||||
// And then we will have to send an error response, but in this case, the header has already been set.
|
// And then we will have to send an error response, but in this case, the header has already been set.
|
||||||
// So the httpClient will be confused by the response.
|
// So the httpClient will be confused by the response.
|
||||||
@ -120,7 +119,7 @@ func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promp
|
|||||||
c.Writer.Header().Set(k, v[0])
|
c.Writer.Header().Set(k, v[0])
|
||||||
}
|
}
|
||||||
c.Writer.WriteHeader(resp.StatusCode)
|
c.Writer.WriteHeader(resp.StatusCode)
|
||||||
_, err := io.Copy(c.Writer, resp.Body)
|
_, err = io.Copy(c.Writer, resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
|
return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
|
||||||
}
|
}
|
||||||
|
@ -51,14 +51,11 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
channelId := c.GetInt("channel_id")
|
channelId := c.GetInt("channel_id")
|
||||||
tokenId := c.GetInt("token_id")
|
tokenId := c.GetInt("token_id")
|
||||||
userId := c.GetInt("id")
|
userId := c.GetInt("id")
|
||||||
consumeQuota := c.GetBool("consume_quota")
|
|
||||||
group := c.GetString("group")
|
group := c.GetString("group")
|
||||||
var textRequest GeneralOpenAIRequest
|
var textRequest GeneralOpenAIRequest
|
||||||
if consumeQuota || channelType == common.ChannelTypeAzure || channelType == common.ChannelTypePaLM {
|
err := common.UnmarshalBodyReusable(c, &textRequest)
|
||||||
err := common.UnmarshalBodyReusable(c, &textRequest)
|
if err != nil {
|
||||||
if err != nil {
|
return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
|
||||||
return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if relayMode == RelayModeModerations && textRequest.Model == "" {
|
if relayMode == RelayModeModerations && textRequest.Model == "" {
|
||||||
textRequest.Model = "text-moderation-latest"
|
textRequest.Model = "text-moderation-latest"
|
||||||
@ -235,7 +232,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
preConsumedQuota = 0
|
preConsumedQuota = 0
|
||||||
common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota))
|
common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota))
|
||||||
}
|
}
|
||||||
if consumeQuota && preConsumedQuota > 0 {
|
if preConsumedQuota > 0 {
|
||||||
err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
|
err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
|
return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
|
||||||
@ -414,37 +411,36 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
defer func(ctx context.Context) {
|
defer func(ctx context.Context) {
|
||||||
// c.Writer.Flush()
|
// c.Writer.Flush()
|
||||||
go func() {
|
go func() {
|
||||||
if consumeQuota {
|
quota := 0
|
||||||
quota := 0
|
completionRatio := common.GetCompletionRatio(textRequest.Model)
|
||||||
completionRatio := common.GetCompletionRatio(textRequest.Model)
|
promptTokens = textResponse.Usage.PromptTokens
|
||||||
promptTokens = textResponse.Usage.PromptTokens
|
completionTokens = textResponse.Usage.CompletionTokens
|
||||||
completionTokens = textResponse.Usage.CompletionTokens
|
quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio))
|
||||||
quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio))
|
if ratio != 0 && quota <= 0 {
|
||||||
if ratio != 0 && quota <= 0 {
|
quota = 1
|
||||||
quota = 1
|
|
||||||
}
|
|
||||||
totalTokens := promptTokens + completionTokens
|
|
||||||
if totalTokens == 0 {
|
|
||||||
// in this case, must be some error happened
|
|
||||||
// we cannot just return, because we may have to return the pre-consumed quota
|
|
||||||
quota = 0
|
|
||||||
}
|
|
||||||
quotaDelta := quota - preConsumedQuota
|
|
||||||
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
|
|
||||||
if err != nil {
|
|
||||||
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
|
|
||||||
}
|
|
||||||
err = model.CacheUpdateUserQuota(userId)
|
|
||||||
if err != nil {
|
|
||||||
common.LogError(ctx, "error update user quota cache: "+err.Error())
|
|
||||||
}
|
|
||||||
if quota != 0 {
|
|
||||||
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
|
||||||
model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
|
|
||||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
|
||||||
model.UpdateChannelUsedQuota(channelId, quota)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
totalTokens := promptTokens + completionTokens
|
||||||
|
if totalTokens == 0 {
|
||||||
|
// in this case, must be some error happened
|
||||||
|
// we cannot just return, because we may have to return the pre-consumed quota
|
||||||
|
quota = 0
|
||||||
|
}
|
||||||
|
quotaDelta := quota - preConsumedQuota
|
||||||
|
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
|
||||||
|
if err != nil {
|
||||||
|
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
|
||||||
|
}
|
||||||
|
err = model.CacheUpdateUserQuota(userId)
|
||||||
|
if err != nil {
|
||||||
|
common.LogError(ctx, "error update user quota cache: "+err.Error())
|
||||||
|
}
|
||||||
|
if quota != 0 {
|
||||||
|
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
||||||
|
model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
|
||||||
|
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||||
|
model.UpdateChannelUsedQuota(channelId, quota)
|
||||||
|
}
|
||||||
|
|
||||||
}()
|
}()
|
||||||
}(c.Request.Context())
|
}(c.Request.Context())
|
||||||
switch apiType {
|
switch apiType {
|
||||||
@ -458,7 +454,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
|
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
|
||||||
return nil
|
return nil
|
||||||
} else {
|
} else {
|
||||||
err, usage := openaiHandler(c, resp, consumeQuota, promptTokens, textRequest.Model)
|
err, usage := openaiHandler(c, resp, promptTokens, textRequest.Model)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -106,12 +106,6 @@ func TokenAuth() func(c *gin.Context) {
|
|||||||
c.Set("id", token.UserId)
|
c.Set("id", token.UserId)
|
||||||
c.Set("token_id", token.Id)
|
c.Set("token_id", token.Id)
|
||||||
c.Set("token_name", token.Name)
|
c.Set("token_name", token.Name)
|
||||||
requestURL := c.Request.URL.String()
|
|
||||||
consumeQuota := true
|
|
||||||
if strings.HasPrefix(requestURL, "/v1/models") {
|
|
||||||
consumeQuota = false
|
|
||||||
}
|
|
||||||
c.Set("consume_quota", consumeQuota)
|
|
||||||
if len(parts) > 1 {
|
if len(parts) > 1 {
|
||||||
if model.IsAdmin(token.UserId) {
|
if model.IsAdmin(token.UserId) {
|
||||||
c.Set("channelId", parts[1])
|
c.Set("channelId", parts[1])
|
||||||
|
Loading…
Reference in New Issue
Block a user