Make token calculation more accurate.
parent 481ba41fbd
commit 1aa82b18b5
@@ -65,9 +65,36 @@ type StreamResponse struct {
 	} `json:"choices"`
 }
 
-var tokenEncoder, _ = tiktoken.GetEncoding("cl100k_base")
-
-func countToken(text string) int {
+func countTokenMessages(messages []Message, model string) int {
+	// Get the token encoder for the model
+	tokenEncoder, _ := tiktoken.EncodingForModel(model)
+	// Follows the official token-counting cookbook:
+	// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
+	// https://github.com/pkoukk/tiktoken-go/issues/6
+	var tokens_per_message int
+	if strings.HasPrefix(model, "gpt-3.5") {
+		tokens_per_message = 4 // every message follows <|start|>{role/name}\n{content}<|end|>\n
+	} else if strings.HasPrefix(model, "gpt-4") {
+		tokens_per_message = 3
+	} else {
+		tokens_per_message = 3
+	}
+
+	token := 0
+	for _, message := range messages {
+		token += tokens_per_message
+		token += len(tokenEncoder.Encode(message.Content, nil, nil))
+		token += len(tokenEncoder.Encode(message.Role, nil, nil))
+	}
+	// Testing shows these assistant priming tokens are billed as prompt tokens, not completion tokens
+	token += 3 // every reply is primed with <|start|>assistant<|message|>
+	return token
+}
+
+func countTokenText(text string, model string) int {
+	// Get the token encoder for the model
+	tokenEncoder, _ := tiktoken.EncodingForModel(model)
 	token := tokenEncoder.Encode(text, nil, nil)
 	return len(token)
 }
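For reference, the counting rule that the new countTokenMessages implements can be exercised standalone with the pkoukk/tiktoken-go package this diff uses. A minimal sketch, assuming gpt-3.5-style overhead (4 tokens per message); the Message type and the sample conversation are illustrative stand-ins, not code from the repository:

	package main

	import (
		"fmt"

		tiktoken "github.com/pkoukk/tiktoken-go"
	)

	// Message is a local stand-in for the repository's chat message type.
	type Message struct {
		Role    string `json:"role"`
		Content string `json:"content"`
	}

	func main() {
		messages := []Message{
			{Role: "system", Content: "You are a helpful assistant."},
			{Role: "user", Content: "Hello!"},
		}
		// cl100k_base is the encoding used by the gpt-3.5/gpt-4 families.
		encoder, err := tiktoken.GetEncoding("cl100k_base")
		if err != nil {
			panic(err)
		}
		tokens := 0
		for _, m := range messages {
			tokens += 4 // gpt-3.5 per-message overhead: <|start|>{role/name}\n{content}<|end|>\n
			tokens += len(encoder.Encode(m.Content, nil, nil))
			tokens += len(encoder.Encode(m.Role, nil, nil))
		}
		tokens += 3 // every reply is primed with <|start|>assistant<|message|>
		fmt.Println("prompt tokens:", tokens)
	}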
@@ -149,11 +176,8 @@ func relayHelper(c *gin.Context) *OpenAIErrorWithStatusCode {
 		model_ = strings.TrimSuffix(model_, "-0314")
 		fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task)
 	}
-	var promptText string
-	for _, message := range textRequest.Messages {
-		promptText += fmt.Sprintf("%s: %s\n", message.Role, message.Content)
-	}
-	promptTokens := countToken(promptText) + 3
+	promptTokens := countTokenMessages(textRequest.Messages, textRequest.Model)
 	preConsumedTokens := common.PreConsumedQuota
 	if textRequest.MaxTokens != 0 {
 		preConsumedTokens = promptTokens + textRequest.MaxTokens
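The removed block shows why the old count drifted: flattening every message into a "role: content\n" string counts the ": " and newline tokens that the API never actually sends, while missing the real per-message wrapper tokens. A hedged comparison sketch, reusing the Message type and encoder setup from the sketch above (both helper names here are illustrative, and countFlattened mirrors the removed code):

	// countFlattened reproduces the removed approach: count the joined
	// "role: content\n" string, plus 3 priming tokens.
	func countFlattened(enc *tiktoken.Tiktoken, messages []Message) int {
		var promptText string
		for _, m := range messages {
			promptText += fmt.Sprintf("%s: %s\n", m.Role, m.Content)
		}
		return len(enc.Encode(promptText, nil, nil)) + 3
	}

	// countPerMessage mirrors the new countTokenMessages rule for gpt-3.5 models.
	func countPerMessage(enc *tiktoken.Tiktoken, messages []Message) int {
		tokens := 0
		for _, m := range messages {
			tokens += 4
			tokens += len(enc.Encode(m.Content, nil, nil))
			tokens += len(enc.Encode(m.Role, nil, nil))
		}
		return tokens + 3
	}

On typical chats the two disagree by a few tokens per message, which is the drift this commit removes.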
@@ -206,8 +230,7 @@ func relayHelper(c *gin.Context) *OpenAIErrorWithStatusCode {
 		completionRatio = 2
 	}
 	if isStream {
-		completionText := fmt.Sprintf("%s: %s\n", "assistant", streamResponseText)
-		quota = promptTokens + countToken(completionText)*completionRatio
+		quota = promptTokens + countTokenText(streamResponseText, textRequest.Model)*completionRatio
 	} else {
 		quota = textResponse.Usage.PromptTokens + textResponse.Usage.CompletionTokens*completionRatio
 	}
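In the streaming branch there is no server-reported usage, so the quota is assembled locally. A worked example under assumed values: with promptTokens = 100 from countTokenMessages, a streamed reply that countTokenText measures at 50 tokens, and completionRatio = 2, the charged quota is 100 + 50*2 = 200.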