diff --git a/controller/relay-utils.go b/controller/relay-utils.go index 1a9ee0d1..9010d275 100644 --- a/controller/relay-utils.go +++ b/controller/relay-utils.go @@ -15,6 +15,24 @@ var stopFinishReason = "stop" var tokenEncoderMap = map[string]*tiktoken.Tiktoken{} +func InitTokenEncoders() { + common.SysLog("initializing token encoders") + fallbackTokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo") + if err != nil { + common.FatalLog(fmt.Sprintf("failed to get fallback token encoder: %s", err.Error())) + } + for model, _ := range common.ModelRatio { + tokenEncoder, err := tiktoken.EncodingForModel(model) + if err != nil { + common.SysError(fmt.Sprintf("using fallback encoder for model %s", model)) + tokenEncoderMap[model] = fallbackTokenEncoder + continue + } + tokenEncoderMap[model] = tokenEncoder + } + common.SysLog("token encoders initialized") +} + func getTokenEncoder(model string) *tiktoken.Tiktoken { if tokenEncoder, ok := tokenEncoderMap[model]; ok { return tokenEncoder diff --git a/main.go b/main.go index f4d20373..9fb0a73e 100644 --- a/main.go +++ b/main.go @@ -77,6 +77,7 @@ func main() { } go controller.AutomaticallyTestChannels(frequency) } + controller.InitTokenEncoders() // Initialize HTTP server server := gin.Default()