perf: lazy initialization for token encoders (close #566)

This commit is contained in:
JustSong 2023-09-29 17:56:11 +08:00
parent 197d1d7a9d
commit 594f06e7b0

View File

@ -9,44 +9,53 @@ import (
	"net/http"
	"one-api/common"
	"strconv"
	"strings"
	"sync"
)
var stopFinishReason = "stop" var stopFinishReason = "stop"
// tokenEncoderMap won't grow after initialization
var tokenEncoderMap = map[string]*tiktoken.Tiktoken{} var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
var defaultTokenEncoder *tiktoken.Tiktoken
func InitTokenEncoders() { func InitTokenEncoders() {
common.SysLog("initializing token encoders") common.SysLog("initializing token encoders")
fallbackTokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo") gpt35TokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo")
if err != nil { if err != nil {
common.FatalLog(fmt.Sprintf("failed to get fallback token encoder: %s", err.Error())) common.FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error()))
}
defaultTokenEncoder = gpt35TokenEncoder
gpt4TokenEncoder, err := tiktoken.EncodingForModel("gpt-4")
if err != nil {
common.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error()))
} }
for model, _ := range common.ModelRatio { for model, _ := range common.ModelRatio {
tokenEncoder, err := tiktoken.EncodingForModel(model) if strings.HasPrefix(model, "gpt-3.5") {
if err != nil { tokenEncoderMap[model] = gpt35TokenEncoder
common.SysError(fmt.Sprintf("using fallback encoder for model %s", model)) } else if strings.HasPrefix(model, "gpt-4") {
tokenEncoderMap[model] = fallbackTokenEncoder tokenEncoderMap[model] = gpt4TokenEncoder
continue } else {
tokenEncoderMap[model] = nil
} }
tokenEncoderMap[model] = tokenEncoder
} }
common.SysLog("token encoders initialized") common.SysLog("token encoders initialized")
} }
func getTokenEncoder(model string) *tiktoken.Tiktoken { func getTokenEncoder(model string) *tiktoken.Tiktoken {
if tokenEncoder, ok := tokenEncoderMap[model]; ok { tokenEncoder, ok := tokenEncoderMap[model]
if ok && tokenEncoder != nil {
return tokenEncoder return tokenEncoder
} }
tokenEncoder, err := tiktoken.EncodingForModel(model) if ok {
if err != nil { tokenEncoder, err := tiktoken.EncodingForModel(model)
common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error()))
tokenEncoder, err = tiktoken.EncodingForModel("gpt-3.5-turbo")
if err != nil { if err != nil {
common.FatalLog(fmt.Sprintf("failed to get token encoder for model gpt-3.5-turbo: %s", err.Error())) common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error()))
tokenEncoder = defaultTokenEncoder
} }
tokenEncoderMap[model] = tokenEncoder
return tokenEncoder
} }
tokenEncoderMap[model] = tokenEncoder return defaultTokenEncoder
return tokenEncoder
} }
func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int { func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {