fix: make the token number calculation more accurate
This commit is contained in:
parent
1aa82b18b5
commit
519077185f
61
controller/relay-utils.go
Normal file
61
controller/relay-utils.go
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"github.com/pkoukk/tiktoken-go"
|
||||||
|
"one-api/common"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
|
||||||
|
|
||||||
|
func getTokenEncoder(model string) *tiktoken.Tiktoken {
|
||||||
|
if tokenEncoder, ok := tokenEncoderMap[model]; ok {
|
||||||
|
return tokenEncoder
|
||||||
|
}
|
||||||
|
tokenEncoder, err := tiktoken.EncodingForModel(model)
|
||||||
|
if err != nil {
|
||||||
|
common.FatalLog(fmt.Sprintf("failed to get token encoder for model %s: %s", model, err.Error()))
|
||||||
|
}
|
||||||
|
tokenEncoderMap[model] = tokenEncoder
|
||||||
|
return tokenEncoder
|
||||||
|
}
|
||||||
|
|
||||||
|
func countTokenMessages(messages []Message, model string) int {
|
||||||
|
tokenEncoder := getTokenEncoder(model)
|
||||||
|
// Reference:
|
||||||
|
// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
|
||||||
|
// https://github.com/pkoukk/tiktoken-go/issues/6
|
||||||
|
//
|
||||||
|
// Every message follows <|start|>{role/name}\n{content}<|end|>\n
|
||||||
|
var tokensPerMessage int
|
||||||
|
var tokensPerName int
|
||||||
|
if strings.HasPrefix(model, "gpt-3.5") {
|
||||||
|
tokensPerMessage = 4
|
||||||
|
tokensPerName = -1 // If there's a name, the role is omitted
|
||||||
|
} else if strings.HasPrefix(model, "gpt-4") {
|
||||||
|
tokensPerMessage = 3
|
||||||
|
tokensPerName = 1
|
||||||
|
} else {
|
||||||
|
tokensPerMessage = 3
|
||||||
|
tokensPerName = 1
|
||||||
|
}
|
||||||
|
tokenNum := 0
|
||||||
|
for _, message := range messages {
|
||||||
|
tokenNum += tokensPerMessage
|
||||||
|
tokenNum += len(tokenEncoder.Encode(message.Content, nil, nil))
|
||||||
|
tokenNum += len(tokenEncoder.Encode(message.Role, nil, nil))
|
||||||
|
if message.Name != "" {
|
||||||
|
tokenNum += tokensPerName
|
||||||
|
tokenNum += len(tokenEncoder.Encode(message.Name, nil, nil))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
|
||||||
|
return tokenNum
|
||||||
|
}
|
||||||
|
|
||||||
|
func countTokenText(text string, model string) int {
|
||||||
|
tokenEncoder := getTokenEncoder(model)
|
||||||
|
token := tokenEncoder.Encode(text, nil, nil)
|
||||||
|
return len(token)
|
||||||
|
}
|
@ -6,7 +6,6 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/pkoukk/tiktoken-go"
|
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
@ -17,6 +16,7 @@ import (
|
|||||||
type Message struct {
|
type Message struct {
|
||||||
Role string `json:"role"`
|
Role string `json:"role"`
|
||||||
Content string `json:"content"`
|
Content string `json:"content"`
|
||||||
|
Name string `json:"name"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type ChatRequest struct {
|
type ChatRequest struct {
|
||||||
@ -65,40 +65,6 @@ type StreamResponse struct {
|
|||||||
} `json:"choices"`
|
} `json:"choices"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func countTokenMessages(messages []Message, model string) int {
|
|
||||||
// 获取模型的编码器
|
|
||||||
tokenEncoder, _ := tiktoken.EncodingForModel(model)
|
|
||||||
|
|
||||||
// 参照官方的token计算cookbook
|
|
||||||
// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
|
|
||||||
// https://github.com/pkoukk/tiktoken-go/issues/6
|
|
||||||
var tokens_per_message int
|
|
||||||
if strings.HasPrefix(model, "gpt-3.5") {
|
|
||||||
tokens_per_message = 4 // every message follows <|start|>{role/name}\n{content}<|end|>\n
|
|
||||||
} else if strings.HasPrefix(model, "gpt-4") {
|
|
||||||
tokens_per_message = 3
|
|
||||||
} else {
|
|
||||||
tokens_per_message = 3
|
|
||||||
}
|
|
||||||
|
|
||||||
token := 0
|
|
||||||
for _, message := range messages {
|
|
||||||
token += tokens_per_message
|
|
||||||
token += len(tokenEncoder.Encode(message.Content, nil, nil))
|
|
||||||
token += len(tokenEncoder.Encode(message.Role, nil, nil))
|
|
||||||
}
|
|
||||||
// 经过测试这个assistant的token是算在prompt里面的,而不是算在Completion里面的
|
|
||||||
token += 3 // every reply is primed with <|start|>assistant<|message|>
|
|
||||||
return token
|
|
||||||
}
|
|
||||||
|
|
||||||
func countTokenText(text string, model string) int {
|
|
||||||
// 获取模型的编码器
|
|
||||||
tokenEncoder, _ := tiktoken.EncodingForModel(model)
|
|
||||||
token := tokenEncoder.Encode(text, nil, nil)
|
|
||||||
return len(token)
|
|
||||||
}
|
|
||||||
|
|
||||||
func Relay(c *gin.Context) {
|
func Relay(c *gin.Context) {
|
||||||
err := relayHelper(c)
|
err := relayHelper(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -230,7 +196,8 @@ func relayHelper(c *gin.Context) *OpenAIErrorWithStatusCode {
|
|||||||
completionRatio = 2
|
completionRatio = 2
|
||||||
}
|
}
|
||||||
if isStream {
|
if isStream {
|
||||||
quota = promptTokens + countTokenText(streamResponseText, textRequest.Model)*completionRatio
|
responseTokens := countTokenText(streamResponseText, textRequest.Model)
|
||||||
|
quota = promptTokens + responseTokens*completionRatio
|
||||||
} else {
|
} else {
|
||||||
quota = textResponse.Usage.PromptTokens + textResponse.Usage.CompletionTokens*completionRatio
|
quota = textResponse.Usage.PromptTokens + textResponse.Usage.CompletionTokens*completionRatio
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user