From 1aa82b18b5974149108a7023c2f9b1916b6691c4 Mon Sep 17 00:00:00 2001
From: quzard <1191890118@qq.com>
Date: Thu, 18 May 2023 22:09:22 +0800
Subject: [PATCH] Make token calculation more accurate.

---
 controller/relay.go | 41 ++++++++++++++++++++++++++++++++---------
 1 file changed, 32 insertions(+), 9 deletions(-)

diff --git a/controller/relay.go b/controller/relay.go
index bc350f0d..db6298fa 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -65,9 +65,36 @@ type StreamResponse struct {
 	} `json:"choices"`
 }
 
-var tokenEncoder, _ = tiktoken.GetEncoding("cl100k_base")
+func countTokenMessages(messages []Message, model string) int {
+	// Get the token encoder for this model
+	tokenEncoder, _ := tiktoken.EncodingForModel(model)
 
-func countToken(text string) int {
+	// Follows the official token-counting cookbook:
+	// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
+	// https://github.com/pkoukk/tiktoken-go/issues/6
+	var tokens_per_message int
+	if strings.HasPrefix(model, "gpt-3.5") {
+		tokens_per_message = 4 // every message follows <|start|>{role/name}\n{content}<|end|>\n
+	} else if strings.HasPrefix(model, "gpt-4") {
+		tokens_per_message = 3
+	} else {
+		tokens_per_message = 3
+	}
+
+	token := 0
+	for _, message := range messages {
+		token += tokens_per_message
+		token += len(tokenEncoder.Encode(message.Content, nil, nil))
+		token += len(tokenEncoder.Encode(message.Role, nil, nil))
+	}
+	// Testing shows these assistant priming tokens are billed as part of the prompt, not the completion
+	token += 3 // every reply is primed with <|start|>assistant<|message|>
+	return token
+}
+
+func countTokenText(text string, model string) int {
+	// Get the token encoder for this model
+	tokenEncoder, _ := tiktoken.EncodingForModel(model)
 	token := tokenEncoder.Encode(text, nil, nil)
 	return len(token)
 }
@@ -149,11 +176,8 @@ func relayHelper(c *gin.Context) *OpenAIErrorWithStatusCode {
 		model_ = strings.TrimSuffix(model_, "-0314")
 		fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task)
 	}
-	var promptText string
-	for _, message := range textRequest.Messages {
-		promptText += fmt.Sprintf("%s: %s\n", message.Role, message.Content)
-	}
-	promptTokens := countToken(promptText) + 3
+
+	promptTokens := countTokenMessages(textRequest.Messages, textRequest.Model)
 	preConsumedTokens := common.PreConsumedQuota
 	if textRequest.MaxTokens != 0 {
 		preConsumedTokens = promptTokens + textRequest.MaxTokens
@@ -206,8 +230,7 @@ func relayHelper(c *gin.Context) *OpenAIErrorWithStatusCode {
 		completionRatio = 2
 	}
 	if isStream {
-		completionText := fmt.Sprintf("%s: %s\n", "assistant", streamResponseText)
-		quota = promptTokens + countToken(completionText)*completionRatio
+		quota = promptTokens + countTokenText(streamResponseText, textRequest.Model)*completionRatio
 	} else {
 		quota = textResponse.Usage.PromptTokens + textResponse.Usage.CompletionTokens*completionRatio
 	}
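
Reviewer note: the patch discards the error returned by tiktoken.EncodingForModel,
so a model name tiktoken-go does not recognize leaves tokenEncoder nil and the first
Encode call panics. What follows is a minimal standalone sketch, not the patch itself,
of the same cookbook-style counting with that case handled by falling back to the
cl100k_base encoding. The Message struct here merely mirrors the shape used in
relay.go, and the fallback is an addition assumed for illustration, not something
the patch does.

    package main

    import (
    	"fmt"
    	"strings"

    	"github.com/pkoukk/tiktoken-go"
    )

    // Message mirrors the struct used by relay.go (assumed shape).
    type Message struct {
    	Role    string `json:"role"`
    	Content string `json:"content"`
    }

    // countTokenMessages reproduces the cookbook-style counting from the patch,
    // but returns an error and falls back to cl100k_base for unknown models
    // instead of discarding the lookup failure (an addition for this sketch).
    func countTokenMessages(messages []Message, model string) (int, error) {
    	tokenEncoder, err := tiktoken.EncodingForModel(model)
    	if err != nil {
    		// Unknown model name: fall back to the cl100k_base encoding
    		// used by the gpt-3.5 and gpt-4 families.
    		tokenEncoder, err = tiktoken.GetEncoding("cl100k_base")
    		if err != nil {
    			return 0, err
    		}
    	}

    	tokensPerMessage := 3
    	if strings.HasPrefix(model, "gpt-3.5") {
    		tokensPerMessage = 4 // every message follows <|start|>{role/name}\n{content}<|end|>\n
    	}

    	tokens := 0
    	for _, message := range messages {
    		tokens += tokensPerMessage
    		tokens += len(tokenEncoder.Encode(message.Content, nil, nil))
    		tokens += len(tokenEncoder.Encode(message.Role, nil, nil))
    	}
    	tokens += 3 // every reply is primed with <|start|>assistant<|message|>
    	return tokens, nil
    }

    func main() {
    	messages := []Message{
    		{Role: "system", Content: "You are a helpful assistant."},
    		{Role: "user", Content: "Hello!"},
    	}
    	n, err := countTokenMessages(messages, "gpt-3.5-turbo")
    	if err != nil {
    		panic(err)
    	}
    	// Compare this against usage.prompt_tokens returned by the API.
    	fmt.Printf("prompt tokens: %d\n", n)
    }

tiktoken-go downloads the BPE dictionaries on first use, so running the sketch needs
network access (or an offline BPE loader configured up front).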