diff --git a/common/constants.go b/common/constants.go index 635e90fd..e99e9a05 100644 --- a/common/constants.go +++ b/common/constants.go @@ -49,11 +49,6 @@ var TurnstileSecretKey = "" var QuotaForNewUser = 100 -// https://platform.openai.com/docs/models/model-endpoint-compatibility -var RatioGPT3dot5 float64 = 2 -var RatioGPT4 float64 = 30 -var RatioGPT4_32k float64 = 60 - const ( RoleGuestUser = 0 RoleCommonUser = 1 diff --git a/common/model-ratio.go b/common/model-ratio.go new file mode 100644 index 00000000..23b34cc7 --- /dev/null +++ b/common/model-ratio.go @@ -0,0 +1,52 @@ +package common + +import "encoding/json" + +// https://platform.openai.com/docs/models/model-endpoint-compatibility +// https://openai.com/pricing +// TODO: when a new api is enabled, check the pricing here +var ModelRatio = map[string]float64{ + "gpt-4": 15, + "gpt-4-0314": 15, + "gpt-4-32k": 30, + "gpt-4-32k-0314": 30, + "gpt-3.5-turbo": 1, + "gpt-3.5-turbo-0301": 1, + "text-ada-001": 0.2, + "text-babbage-001": 0.25, + "text-curie-001": 1, + "text-davinci-002": 10, + "text-davinci-003": 10, + "text-davinci-edit-001": 10, + "code-davinci-edit-001": 10, + "whisper-1": 10, + "davinci": 10, + "curie": 10, + "babbage": 10, + "ada": 10, + "text-embedding-ada-002": 0.25, + "text-search-ada-doc-001": 10, + "text-moderation-stable": 10, + "text-moderation-latest": 10, +} + +func ModelRatio2JSONString() string { + jsonBytes, err := json.Marshal(ModelRatio) + if err != nil { + SysError("Error marshalling model ratio: " + err.Error()) + } + return string(jsonBytes) +} + +func UpdateModelRatioByJSONString(jsonStr string) error { + return json.Unmarshal([]byte(jsonStr), &ModelRatio) +} + +func GetModelRatio(name string) float64 { + ratio, ok := ModelRatio[name] + if !ok { + SysError("Model ratio not found: " + name) + return 1 + } + return ratio +} diff --git a/controller/relay.go b/controller/relay.go index 595cc298..c41f4d8a 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -118,24 +118,22 @@ func relayHelper(c *gin.Context) error { defer func() { if consumeQuota { quota := 0 + usingGPT4 := strings.HasPrefix(textRequest.Model, "gpt-4") + completionRatio := 1 + if usingGPT4 { + completionRatio = 2 + } if isStream { - var text string + var promptText string for _, message := range textRequest.Messages { - text += fmt.Sprintf("%s: %s\n", message.Role, message.Content) + promptText += fmt.Sprintf("%s: %s\n", message.Role, message.Content) } - text += fmt.Sprintf("%s: %s\n", "assistant", streamResponseText) - quota = countToken(text) + 3 + completionText := fmt.Sprintf("%s: %s\n", "assistant", streamResponseText) + quota = countToken(promptText) + countToken(completionText)*completionRatio + 3 } else { - quota = textResponse.Usage.TotalTokens - } - ratio := common.RatioGPT3dot5 - if strings.HasPrefix(textRequest.Model, "gpt-4-32k") { - ratio = common.RatioGPT4_32k - } else if strings.HasPrefix(textRequest.Model, "gpt-4") { - ratio = common.RatioGPT4 - } else { - ratio = common.RatioGPT3dot5 + quota = textResponse.Usage.PromptTokens + textResponse.Usage.CompletionTokens*completionRatio } + ratio := common.GetModelRatio(textRequest.Model) quota = int(float64(quota) * ratio) err := model.DecreaseTokenQuota(tokenId, quota) if err != nil { diff --git a/model/option.go b/model/option.go index bf7744a4..45a56ba9 100644 --- a/model/option.go +++ b/model/option.go @@ -47,9 +47,7 @@ func InitOptionMap() { common.OptionMap["TurnstileSiteKey"] = "" common.OptionMap["TurnstileSecretKey"] = "" common.OptionMap["QuotaForNewUser"] = strconv.Itoa(common.QuotaForNewUser) - common.OptionMap["RatioGPT3dot5"] = strconv.FormatFloat(common.RatioGPT3dot5, 'f', -1, 64) - common.OptionMap["RatioGPT4"] = strconv.FormatFloat(common.RatioGPT4, 'f', -1, 64) - common.OptionMap["RatioGPT4_32k"] = strconv.FormatFloat(common.RatioGPT4_32k, 'f', -1, 64) + common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString() common.OptionMap["TopUpLink"] = common.TopUpLink common.OptionMapRWMutex.Unlock() options, _ := AllOption() @@ -75,7 +73,7 @@ func UpdateOption(key string, value string) error { return nil } -func updateOptionMap(key string, value string) { +func updateOptionMap(key string, value string) (err error) { common.OptionMapRWMutex.Lock() defer common.OptionMapRWMutex.Unlock() common.OptionMap[key] = value @@ -138,13 +136,10 @@ func updateOptionMap(key string, value string) { common.TurnstileSecretKey = value case "QuotaForNewUser": common.QuotaForNewUser, _ = strconv.Atoi(value) - case "RatioGPT3dot5": - common.RatioGPT3dot5, _ = strconv.ParseFloat(value, 64) - case "RatioGPT4": - common.RatioGPT4, _ = strconv.ParseFloat(value, 64) - case "RatioGPT4_32k": - common.RatioGPT4_32k, _ = strconv.ParseFloat(value, 64) + case "ModelRatio": + err = common.UpdateModelRatioByJSONString(value) case "TopUpLink": common.TopUpLink = value } + return err } diff --git a/web/src/components/SystemSetting.js b/web/src/components/SystemSetting.js index a7db156b..303faaf7 100644 --- a/web/src/components/SystemSetting.js +++ b/web/src/components/SystemSetting.js @@ -1,6 +1,6 @@ import React, { useEffect, useState } from 'react'; import { Divider, Form, Grid, Header, Message } from 'semantic-ui-react'; -import { API, removeTrailingSlash, showError } from '../helpers'; +import { API, removeTrailingSlash, showError, verifyJSON } from '../helpers'; const SystemSetting = () => { let [inputs, setInputs] = useState({ @@ -25,9 +25,7 @@ const SystemSetting = () => { TurnstileSecretKey: '', RegisterEnabled: '', QuotaForNewUser: 0, - RatioGPT3dot5: 2, - RatioGPT4: 30, - RatioGPT4_32k: 60, + ModelRatio: '', TopUpLink: '' }); let originInputs = {}; @@ -93,7 +91,7 @@ const SystemSetting = () => { name === 'TurnstileSiteKey' || name === 'TurnstileSecretKey' || name === 'QuotaForNewUser' || - name.startsWith('Ratio') || + name === 'ModelRatio' || name === 'TopUpLink' ) { setInputs((inputs) => ({ ...inputs, [name]: value })); @@ -111,19 +109,17 @@ const SystemSetting = () => { if (originInputs['QuotaForNewUser'] !== inputs.QuotaForNewUser) { await updateOption('QuotaForNewUser', inputs.QuotaForNewUser); } - if (originInputs['RatioGPT3dot5'] !== inputs.RatioGPT3dot5) { - await updateOption('RatioGPT3dot5', inputs.RatioGPT3dot5); - } - if (originInputs['RatioGPT4'] !== inputs.RatioGPT4) { - await updateOption('RatioGPT4', inputs.RatioGPT4); - } - if (originInputs['RatioGPT4_32k'] !== inputs.RatioGPT4_32k) { - await updateOption('RatioGPT4_32k', inputs.RatioGPT4_32k); + if (originInputs['ModelRatio'] !== inputs.ModelRatio) { + if (!verifyJSON(inputs.ModelRatio)) { + showError('模型倍率不是合法的 JSON 字符串'); + return; + } + await updateOption('ModelRatio', inputs.ModelRatio); } if (originInputs['TopUpLink'] !== inputs.TopUpLink) { await updateOption('TopUpLink', inputs.TopUpLink); } - } + }; const submitSMTP = async () => { if (originInputs['SMTPServer'] !== inputs.SMTPServer) { @@ -278,39 +274,15 @@ const SystemSetting = () => { placeholder='例如发卡网站的购买链接' /> - - + - - 保存运营设置 diff --git a/web/src/helpers/utils.js b/web/src/helpers/utils.js index e2476263..f063d571 100644 --- a/web/src/helpers/utils.js +++ b/web/src/helpers/utils.js @@ -152,4 +152,13 @@ export function downloadTextAsFile(text, filename) { a.href = url; a.download = filename; a.click(); -} \ No newline at end of file +} + +export const verifyJSON = (str) => { + try { + JSON.parse(str); + } catch (e) { + return false; + } + return true; +}; \ No newline at end of file