Make token calculation more accurate.

This commit is contained in:
quzard 2023-05-18 22:09:22 +08:00
parent 481ba41fbd
commit 1aa82b18b5

View File

@ -65,9 +65,36 @@ type StreamResponse struct {
} `json:"choices"`
}
var tokenEncoder, _ = tiktoken.GetEncoding("cl100k_base")
func countTokenMessages(messages []Message, model string) int {
// 获取模型的编码器
tokenEncoder, _ := tiktoken.EncodingForModel(model)
func countToken(text string) int {
// 参照官方的token计算cookbook
// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
// https://github.com/pkoukk/tiktoken-go/issues/6
var tokens_per_message int
if strings.HasPrefix(model, "gpt-3.5") {
tokens_per_message = 4 // every message follows <|start|>{role/name}\n{content}<|end|>\n
} else if strings.HasPrefix(model, "gpt-4") {
tokens_per_message = 3
} else {
tokens_per_message = 3
}
token := 0
for _, message := range messages {
token += tokens_per_message
token += len(tokenEncoder.Encode(message.Content, nil, nil))
token += len(tokenEncoder.Encode(message.Role, nil, nil))
}
// 经过测试这个assistant的token是算在prompt里面的,而不是算在Completion里面的
token += 3 // every reply is primed with <|start|>assistant<|message|>
return token
}
func countTokenText(text string, model string) int {
// 获取模型的编码器
tokenEncoder, _ := tiktoken.EncodingForModel(model)
token := tokenEncoder.Encode(text, nil, nil)
return len(token)
}
@ -149,11 +176,8 @@ func relayHelper(c *gin.Context) *OpenAIErrorWithStatusCode {
model_ = strings.TrimSuffix(model_, "-0314")
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task)
}
var promptText string
for _, message := range textRequest.Messages {
promptText += fmt.Sprintf("%s: %s\n", message.Role, message.Content)
}
promptTokens := countToken(promptText) + 3
promptTokens := countTokenMessages(textRequest.Messages, textRequest.Model)
preConsumedTokens := common.PreConsumedQuota
if textRequest.MaxTokens != 0 {
preConsumedTokens = promptTokens + textRequest.MaxTokens
@ -206,8 +230,7 @@ func relayHelper(c *gin.Context) *OpenAIErrorWithStatusCode {
completionRatio = 2
}
if isStream {
completionText := fmt.Sprintf("%s: %s\n", "assistant", streamResponseText)
quota = promptTokens + countToken(completionText)*completionRatio
quota = promptTokens + countTokenText(streamResponseText, textRequest.Model)*completionRatio
} else {
quota = textResponse.Usage.PromptTokens + textResponse.Usage.CompletionTokens*completionRatio
}