diff --git a/common/constants.go b/common/constants.go index 08754213..a97fda0e 100644 --- a/common/constants.go +++ b/common/constants.go @@ -67,6 +67,7 @@ var ChannelDisableThreshold = 5.0 var AutomaticDisableChannelEnabled = false var QuotaRemindThreshold = 1000 var PreConsumedQuota = 500 +var ApproximateTokenEnabled = false var RootUserEmail = "" diff --git a/controller/relay-utils.go b/controller/relay-utils.go index 45f67ac5..35c7fa82 100644 --- a/controller/relay-utils.go +++ b/controller/relay-utils.go @@ -24,6 +24,13 @@ func getTokenEncoder(model string) *tiktoken.Tiktoken { return tokenEncoder } +func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int { + if common.ApproximateTokenEnabled { + return int(float64(len(text)) * 0.38) + } + return len(tokenEncoder.Encode(text, nil, nil)) +} + func countTokenMessages(messages []Message, model string) int { tokenEncoder := getTokenEncoder(model) // Reference: @@ -43,11 +50,11 @@ func countTokenMessages(messages []Message, model string) int { tokenNum := 0 for _, message := range messages { tokenNum += tokensPerMessage - tokenNum += len(tokenEncoder.Encode(message.Content, nil, nil)) - tokenNum += len(tokenEncoder.Encode(message.Role, nil, nil)) + tokenNum += getTokenNum(tokenEncoder, message.Content) + tokenNum += getTokenNum(tokenEncoder, message.Role) if message.Name != nil { tokenNum += tokensPerName - tokenNum += len(tokenEncoder.Encode(*message.Name, nil, nil)) + tokenNum += getTokenNum(tokenEncoder, *message.Name) } } tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|> @@ -70,8 +77,7 @@ func countTokenInput(input any, model string) int { func countTokenText(text string, model string) int { tokenEncoder := getTokenEncoder(model) - token := tokenEncoder.Encode(text, nil, nil) - return len(token) + return getTokenNum(tokenEncoder, text) } func errorWrapper(err error, code string, statusCode int) *OpenAIErrorWithStatusCode { diff --git a/model/option.go b/model/option.go index f41cf954..35aeec4c 100644 --- a/model/option.go +++ b/model/option.go @@ -34,6 +34,7 @@ func InitOptionMap() { common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled) common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled) common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled) + common.OptionMap["ApproximateTokenEnabled"] = strconv.FormatBool(common.ApproximateTokenEnabled) common.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(common.LogConsumeEnabled) common.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(common.DisplayInCurrencyEnabled) common.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(common.DisplayTokenStatEnabled) @@ -141,6 +142,8 @@ func updateOptionMap(key string, value string) (err error) { common.RegisterEnabled = boolValue case "AutomaticDisableChannelEnabled": common.AutomaticDisableChannelEnabled = boolValue + case "ApproximateTokenEnabled": + common.ApproximateTokenEnabled = boolValue case "LogConsumeEnabled": common.LogConsumeEnabled = boolValue case "DisplayInCurrencyEnabled": diff --git a/web/src/components/OperationSetting.js b/web/src/components/OperationSetting.js index 3cc5ba99..69100c85 100644 --- a/web/src/components/OperationSetting.js +++ b/web/src/components/OperationSetting.js @@ -18,7 +18,8 @@ const OperationSetting = () => { ChannelDisableThreshold: 0, LogConsumeEnabled: '', DisplayInCurrencyEnabled: '', - DisplayTokenStatEnabled: '' + DisplayTokenStatEnabled: '', + ApproximateTokenEnabled: '', }); const [originInputs, setOriginInputs] = useState({}); let [loading, setLoading] = useState(false); @@ -181,6 +182,12 @@ const OperationSetting = () => { name='DisplayTokenStatEnabled' onChange={handleInputChange} /> +