Make token calculation more accurate.
This commit is contained in:
parent
481ba41fbd
commit
1aa82b18b5
@ -65,9 +65,36 @@ type StreamResponse struct {
|
||||
} `json:"choices"`
|
||||
}
|
||||
|
||||
var tokenEncoder, _ = tiktoken.GetEncoding("cl100k_base")
|
||||
func countTokenMessages(messages []Message, model string) int {
|
||||
// 获取模型的编码器
|
||||
tokenEncoder, _ := tiktoken.EncodingForModel(model)
|
||||
|
||||
func countToken(text string) int {
|
||||
// 参照官方的token计算cookbook
|
||||
// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
|
||||
// https://github.com/pkoukk/tiktoken-go/issues/6
|
||||
var tokens_per_message int
|
||||
if strings.HasPrefix(model, "gpt-3.5") {
|
||||
tokens_per_message = 4 // every message follows <|start|>{role/name}\n{content}<|end|>\n
|
||||
} else if strings.HasPrefix(model, "gpt-4") {
|
||||
tokens_per_message = 3
|
||||
} else {
|
||||
tokens_per_message = 3
|
||||
}
|
||||
|
||||
token := 0
|
||||
for _, message := range messages {
|
||||
token += tokens_per_message
|
||||
token += len(tokenEncoder.Encode(message.Content, nil, nil))
|
||||
token += len(tokenEncoder.Encode(message.Role, nil, nil))
|
||||
}
|
||||
// 经过测试这个assistant的token是算在prompt里面的,而不是算在Completion里面的
|
||||
token += 3 // every reply is primed with <|start|>assistant<|message|>
|
||||
return token
|
||||
}
|
||||
|
||||
func countTokenText(text string, model string) int {
|
||||
// 获取模型的编码器
|
||||
tokenEncoder, _ := tiktoken.EncodingForModel(model)
|
||||
token := tokenEncoder.Encode(text, nil, nil)
|
||||
return len(token)
|
||||
}
|
||||
@ -149,11 +176,8 @@ func relayHelper(c *gin.Context) *OpenAIErrorWithStatusCode {
|
||||
model_ = strings.TrimSuffix(model_, "-0314")
|
||||
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task)
|
||||
}
|
||||
var promptText string
|
||||
for _, message := range textRequest.Messages {
|
||||
promptText += fmt.Sprintf("%s: %s\n", message.Role, message.Content)
|
||||
}
|
||||
promptTokens := countToken(promptText) + 3
|
||||
|
||||
promptTokens := countTokenMessages(textRequest.Messages, textRequest.Model)
|
||||
preConsumedTokens := common.PreConsumedQuota
|
||||
if textRequest.MaxTokens != 0 {
|
||||
preConsumedTokens = promptTokens + textRequest.MaxTokens
|
||||
@ -206,8 +230,7 @@ func relayHelper(c *gin.Context) *OpenAIErrorWithStatusCode {
|
||||
completionRatio = 2
|
||||
}
|
||||
if isStream {
|
||||
completionText := fmt.Sprintf("%s: %s\n", "assistant", streamResponseText)
|
||||
quota = promptTokens + countToken(completionText)*completionRatio
|
||||
quota = promptTokens + countTokenText(streamResponseText, textRequest.Model)*completionRatio
|
||||
} else {
|
||||
quota = textResponse.Usage.PromptTokens + textResponse.Usage.CompletionTokens*completionRatio
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user