From c9ac5e391ff0b5885e23d89158edea5c45dd15ad Mon Sep 17 00:00:00 2001 From: JustSong Date: Tue, 16 May 2023 16:18:35 +0800 Subject: [PATCH] feat: support max_tokens now (#52) --- controller/relay.go | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/controller/relay.go b/controller/relay.go index 590842f4..93cabb58 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -26,9 +26,10 @@ type ChatRequest struct { } type TextRequest struct { - Model string `json:"model"` - Messages []Message `json:"messages"` - Prompt string `json:"prompt"` + Model string `json:"model"` + Messages []Message `json:"messages"` + Prompt string `json:"prompt"` + MaxTokens int `json:"max_tokens"` //Stream bool `json:"stream"` } @@ -128,8 +129,17 @@ func relayHelper(c *gin.Context) error { model_ = strings.TrimSuffix(model_, "-0314") fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task) } + var promptText string + for _, message := range textRequest.Messages { + promptText += fmt.Sprintf("%s: %s\n", message.Role, message.Content) + } + promptTokens := countToken(promptText) + 3 + preConsumedTokens := common.PreConsumedQuota + if textRequest.MaxTokens != 0 { + preConsumedTokens = promptTokens + textRequest.MaxTokens + } ratio := common.GetModelRatio(textRequest.Model) - preConsumedQuota := int(float64(common.PreConsumedQuota) * ratio) + preConsumedQuota := int(float64(preConsumedTokens) * ratio) if consumeQuota { err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) if err != nil { @@ -176,12 +186,8 @@ func relayHelper(c *gin.Context) error { completionRatio = 2 } if isStream { - var promptText string - for _, message := range textRequest.Messages { - promptText += fmt.Sprintf("%s: %s\n", message.Role, message.Content) - } completionText := fmt.Sprintf("%s: %s\n", "assistant", streamResponseText) - quota = countToken(promptText) + countToken(completionText)*completionRatio + 3 + quota = promptTokens + countToken(completionText)*completionRatio } else { quota = textResponse.Usage.PromptTokens + textResponse.Usage.CompletionTokens*completionRatio }