fix: make the token number calculation more accurate
This commit is contained in:
parent
1aa82b18b5
commit
519077185f
61
controller/relay-utils.go
Normal file
61
controller/relay-utils.go
Normal file
@ -0,0 +1,61 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/pkoukk/tiktoken-go"
|
||||
"one-api/common"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
|
||||
|
||||
func getTokenEncoder(model string) *tiktoken.Tiktoken {
|
||||
if tokenEncoder, ok := tokenEncoderMap[model]; ok {
|
||||
return tokenEncoder
|
||||
}
|
||||
tokenEncoder, err := tiktoken.EncodingForModel(model)
|
||||
if err != nil {
|
||||
common.FatalLog(fmt.Sprintf("failed to get token encoder for model %s: %s", model, err.Error()))
|
||||
}
|
||||
tokenEncoderMap[model] = tokenEncoder
|
||||
return tokenEncoder
|
||||
}
|
||||
|
||||
func countTokenMessages(messages []Message, model string) int {
|
||||
tokenEncoder := getTokenEncoder(model)
|
||||
// Reference:
|
||||
// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
|
||||
// https://github.com/pkoukk/tiktoken-go/issues/6
|
||||
//
|
||||
// Every message follows <|start|>{role/name}\n{content}<|end|>\n
|
||||
var tokensPerMessage int
|
||||
var tokensPerName int
|
||||
if strings.HasPrefix(model, "gpt-3.5") {
|
||||
tokensPerMessage = 4
|
||||
tokensPerName = -1 // If there's a name, the role is omitted
|
||||
} else if strings.HasPrefix(model, "gpt-4") {
|
||||
tokensPerMessage = 3
|
||||
tokensPerName = 1
|
||||
} else {
|
||||
tokensPerMessage = 3
|
||||
tokensPerName = 1
|
||||
}
|
||||
tokenNum := 0
|
||||
for _, message := range messages {
|
||||
tokenNum += tokensPerMessage
|
||||
tokenNum += len(tokenEncoder.Encode(message.Content, nil, nil))
|
||||
tokenNum += len(tokenEncoder.Encode(message.Role, nil, nil))
|
||||
if message.Name != "" {
|
||||
tokenNum += tokensPerName
|
||||
tokenNum += len(tokenEncoder.Encode(message.Name, nil, nil))
|
||||
}
|
||||
}
|
||||
tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
|
||||
return tokenNum
|
||||
}
|
||||
|
||||
func countTokenText(text string, model string) int {
|
||||
tokenEncoder := getTokenEncoder(model)
|
||||
token := tokenEncoder.Encode(text, nil, nil)
|
||||
return len(token)
|
||||
}
|
@ -6,7 +6,6 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pkoukk/tiktoken-go"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
@ -17,6 +16,7 @@ import (
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
type ChatRequest struct {
|
||||
@ -65,40 +65,6 @@ type StreamResponse struct {
|
||||
} `json:"choices"`
|
||||
}
|
||||
|
||||
func countTokenMessages(messages []Message, model string) int {
|
||||
// 获取模型的编码器
|
||||
tokenEncoder, _ := tiktoken.EncodingForModel(model)
|
||||
|
||||
// 参照官方的token计算cookbook
|
||||
// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
|
||||
// https://github.com/pkoukk/tiktoken-go/issues/6
|
||||
var tokens_per_message int
|
||||
if strings.HasPrefix(model, "gpt-3.5") {
|
||||
tokens_per_message = 4 // every message follows <|start|>{role/name}\n{content}<|end|>\n
|
||||
} else if strings.HasPrefix(model, "gpt-4") {
|
||||
tokens_per_message = 3
|
||||
} else {
|
||||
tokens_per_message = 3
|
||||
}
|
||||
|
||||
token := 0
|
||||
for _, message := range messages {
|
||||
token += tokens_per_message
|
||||
token += len(tokenEncoder.Encode(message.Content, nil, nil))
|
||||
token += len(tokenEncoder.Encode(message.Role, nil, nil))
|
||||
}
|
||||
// 经过测试这个assistant的token是算在prompt里面的,而不是算在Completion里面的
|
||||
token += 3 // every reply is primed with <|start|>assistant<|message|>
|
||||
return token
|
||||
}
|
||||
|
||||
func countTokenText(text string, model string) int {
|
||||
// 获取模型的编码器
|
||||
tokenEncoder, _ := tiktoken.EncodingForModel(model)
|
||||
token := tokenEncoder.Encode(text, nil, nil)
|
||||
return len(token)
|
||||
}
|
||||
|
||||
func Relay(c *gin.Context) {
|
||||
err := relayHelper(c)
|
||||
if err != nil {
|
||||
@ -230,7 +196,8 @@ func relayHelper(c *gin.Context) *OpenAIErrorWithStatusCode {
|
||||
completionRatio = 2
|
||||
}
|
||||
if isStream {
|
||||
quota = promptTokens + countTokenText(streamResponseText, textRequest.Model)*completionRatio
|
||||
responseTokens := countTokenText(streamResponseText, textRequest.Model)
|
||||
quota = promptTokens + responseTokens*completionRatio
|
||||
} else {
|
||||
quota = textResponse.Usage.PromptTokens + textResponse.Usage.CompletionTokens*completionRatio
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user