From 519077185f7316d861fcdde9717ff18a52af3ec0 Mon Sep 17 00:00:00 2001
From: JustSong
Date: Fri, 19 May 2023 09:40:21 +0800
Subject: [PATCH] fix: make the token number calculation more accurate

---
 controller/relay-utils.go | 61 +++++++++++++++++++++++++++++++++++++++
 controller/relay.go       | 39 ++-----------------------
 2 files changed, 64 insertions(+), 36 deletions(-)
 create mode 100644 controller/relay-utils.go

diff --git a/controller/relay-utils.go b/controller/relay-utils.go
new file mode 100644
index 00000000..a202e69b
--- /dev/null
+++ b/controller/relay-utils.go
@@ -0,0 +1,61 @@
+package controller
+
+import (
+	"fmt"
+	"github.com/pkoukk/tiktoken-go"
+	"one-api/common"
+	"strings"
+)
+
+var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
+
+func getTokenEncoder(model string) *tiktoken.Tiktoken {
+	if tokenEncoder, ok := tokenEncoderMap[model]; ok {
+		return tokenEncoder
+	}
+	tokenEncoder, err := tiktoken.EncodingForModel(model)
+	if err != nil {
+		common.FatalLog(fmt.Sprintf("failed to get token encoder for model %s: %s", model, err.Error()))
+	}
+	tokenEncoderMap[model] = tokenEncoder
+	return tokenEncoder
+}
+
+func countTokenMessages(messages []Message, model string) int {
+	tokenEncoder := getTokenEncoder(model)
+	// Reference:
+	// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
+	// https://github.com/pkoukk/tiktoken-go/issues/6
+	//
+	// Every message follows <|start|>{role/name}\n{content}<|end|>\n
+	var tokensPerMessage int
+	var tokensPerName int
+	if strings.HasPrefix(model, "gpt-3.5") {
+		tokensPerMessage = 4
+		tokensPerName = -1 // If there's a name, the role is omitted
+	} else if strings.HasPrefix(model, "gpt-4") {
+		tokensPerMessage = 3
+		tokensPerName = 1
+	} else {
+		tokensPerMessage = 3
+		tokensPerName = 1
+	}
+	tokenNum := 0
+	for _, message := range messages {
+		tokenNum += tokensPerMessage
+		tokenNum += len(tokenEncoder.Encode(message.Content, nil, nil))
+		tokenNum += len(tokenEncoder.Encode(message.Role, nil, nil))
+		if message.Name != "" {
+			tokenNum += tokensPerName
+			tokenNum += len(tokenEncoder.Encode(message.Name, nil, nil))
+		}
+	}
+	tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
+	return tokenNum
+}
+
+func countTokenText(text string, model string) int {
+	tokenEncoder := getTokenEncoder(model)
+	token := tokenEncoder.Encode(text, nil, nil)
+	return len(token)
+}
diff --git a/controller/relay.go b/controller/relay.go
index db6298fa..d84a741c 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -6,7 +6,6 @@ import (
 	"encoding/json"
 	"fmt"
 	"github.com/gin-gonic/gin"
-	"github.com/pkoukk/tiktoken-go"
 	"io"
 	"net/http"
 	"one-api/common"
@@ -17,6 +16,7 @@
 type Message struct {
 	Role    string `json:"role"`
 	Content string `json:"content"`
+	Name    string `json:"name"`
 }
 
 type ChatRequest struct {
@@ -65,40 +65,6 @@ type StreamResponse struct {
 	} `json:"choices"`
 }
 
-func countTokenMessages(messages []Message, model string) int {
-	// Get the token encoder for the model
-	tokenEncoder, _ := tiktoken.EncodingForModel(model)
-
-	// Follows the official token-counting cookbook:
-	// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
-	// https://github.com/pkoukk/tiktoken-go/issues/6
-	var tokens_per_message int
-	if strings.HasPrefix(model, "gpt-3.5") {
-		tokens_per_message = 4 // every message follows <|start|>{role/name}\n{content}<|end|>\n
-	} else if strings.HasPrefix(model, "gpt-4") {
-		tokens_per_message = 3
-	} else {
-		tokens_per_message = 3
-	}
-
-	token := 0
-	for _, message := range messages {
-		token += tokens_per_message
-		token += len(tokenEncoder.Encode(message.Content, nil, nil))
-		token += len(tokenEncoder.Encode(message.Role, nil, nil))
-	}
-	// Testing shows the tokens of this assistant primer are counted in the prompt, not in the completion
-	token += 3 // every reply is primed with <|start|>assistant<|message|>
-	return token
-}
-
-func countTokenText(text string, model string) int {
-	// Get the token encoder for the model
-	tokenEncoder, _ := tiktoken.EncodingForModel(model)
-	token := tokenEncoder.Encode(text, nil, nil)
-	return len(token)
-}
-
 func Relay(c *gin.Context) {
 	err := relayHelper(c)
 	if err != nil {
@@ -230,7 +196,8 @@ func relayHelper(c *gin.Context) *OpenAIErrorWithStatusCode {
 		completionRatio = 2
 	}
 	if isStream {
-		quota = promptTokens + countTokenText(streamResponseText, textRequest.Model)*completionRatio
+		responseTokens := countTokenText(streamResponseText, textRequest.Model)
+		quota = promptTokens + responseTokens*completionRatio
 	} else {
 		quota = textResponse.Usage.PromptTokens + textResponse.Usage.CompletionTokens*completionRatio
 	}
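
With this change, the streaming path bills quota as promptTokens + responseTokens*completionRatio. For example, with promptTokens = 100, a streamed response of 50 tokens, and completionRatio = 2, the charged quota is 100 + 50*2 = 200; the non-streaming path applies the same ratio to the upstream Usage.CompletionTokens instead.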
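
The per-message accounting that the new countTokenMessages implements can be exercised standalone. Below is a minimal sketch, assuming a throwaway main package; the sample messages and the local Message mirror are illustrative only, not part of the patch.

package main

import (
	"fmt"

	"github.com/pkoukk/tiktoken-go"
)

// Message mirrors the struct in controller/relay.go, including the new Name field.
type Message struct {
	Role    string `json:"role"`
	Content string `json:"content"`
	Name    string `json:"name"`
}

func main() {
	messages := []Message{
		{Role: "system", Content: "You are a helpful assistant."},
		{Role: "user", Content: "Hello!", Name: "alice"},
	}
	encoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo")
	if err != nil {
		panic(err)
	}
	// Per the cookbook: gpt-3.5 carries 4 tokens of per-message overhead,
	// and a present name replaces the role (hence the -1 adjustment).
	tokensPerMessage, tokensPerName := 4, -1
	tokenNum := 0
	for _, m := range messages {
		tokenNum += tokensPerMessage
		tokenNum += len(encoder.Encode(m.Content, nil, nil))
		tokenNum += len(encoder.Encode(m.Role, nil, nil))
		if m.Name != "" {
			tokenNum += tokensPerName
			tokenNum += len(encoder.Encode(m.Name, nil, nil))
		}
	}
	tokenNum += 3 // every reply is primed with <|start|>assistant<|message|>
	fmt.Println("prompt tokens:", tokenNum)
}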
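
One caveat in the new relay-utils.go: tokenEncoderMap is read and written from request handlers with no synchronization, and concurrent writes to a plain Go map can panic under load. A possible guard is sketched below with a sync.RWMutex; the lock, and the name tokenEncoderLock, are an assumption of this sketch, not something the patch contains.

var (
	tokenEncoderMap  = map[string]*tiktoken.Tiktoken{}
	tokenEncoderLock sync.RWMutex // added alongside a "sync" import
)

func getTokenEncoder(model string) *tiktoken.Tiktoken {
	// Fast path: shared read lock for cache hits.
	tokenEncoderLock.RLock()
	tokenEncoder, ok := tokenEncoderMap[model]
	tokenEncoderLock.RUnlock()
	if ok {
		return tokenEncoder
	}
	tokenEncoder, err := tiktoken.EncodingForModel(model)
	if err != nil {
		common.FatalLog(fmt.Sprintf("failed to get token encoder for model %s: %s", model, err.Error()))
	}
	// Slow path: exclusive lock for the one-time insert.
	tokenEncoderLock.Lock()
	tokenEncoderMap[model] = tokenEncoder
	tokenEncoderLock.Unlock()
	return tokenEncoder
}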