diff --git a/cli/export.go b/cli/export.go new file mode 100644 index 00000000..cda9788b --- /dev/null +++ b/cli/export.go @@ -0,0 +1,48 @@ +package cli + +import ( + "encoding/json" + "one-api/common" + "one-api/relay/util" + "os" + "sort" +) + +func ExportPrices() { + prices := util.GetPricesList("default") + + if len(prices) == 0 { + common.SysError("No prices found") + return + } + + // Sort prices by ChannelType + sort.Slice(prices, func(i, j int) bool { + if prices[i].ChannelType == prices[j].ChannelType { + return prices[i].Model < prices[j].Model + } + return prices[i].ChannelType < prices[j].ChannelType + }) + + // 导出到当前目录下的 prices.json 文件 + file, err := os.Create("prices.json") + if err != nil { + common.SysError("Failed to create file: " + err.Error()) + return + } + defer file.Close() + + jsonData, err := json.MarshalIndent(prices, "", " ") + if err != nil { + common.SysError("Failed to encode prices: " + err.Error()) + return + } + + _, err = file.Write(jsonData) + if err != nil { + common.SysError("Failed to write to file: " + err.Error()) + return + } + + common.SysLog("Prices exported to prices.json") +} diff --git a/common/config/flag.go b/cli/flag.go similarity index 82% rename from common/config/flag.go rename to cli/flag.go index 85f56107..804f6cb4 100644 --- a/common/config/flag.go +++ b/cli/flag.go @@ -1,4 +1,4 @@ -package config +package cli import ( "flag" @@ -14,10 +14,11 @@ var ( printVersion = flag.Bool("version", false, "print version and exit") printHelp = flag.Bool("help", false, "print help and exit") logDir = flag.String("log-dir", "", "specify the log directory") - config = flag.String("config", "config.yaml", "specify the config.yaml path") + Config = flag.String("config", "config.yaml", "specify the config.yaml path") + export = flag.Bool("export", false, "Exports prices to a JSON file.") ) -func flagConfig() { +func FlagConfig() { flag.Parse() if *printVersion { @@ -38,6 +39,11 @@ func flagConfig() { viper.Set("log_dir", *logDir) } + if *export { + ExportPrices() + os.Exit(0) + } + } func help() { diff --git a/common/config/config.go b/common/config/config.go index e24e27a6..7d2634c2 100644 --- a/common/config/config.go +++ b/common/config/config.go @@ -4,13 +4,14 @@ import ( "strings" "time" + "one-api/cli" "one-api/common" "github.com/spf13/viper" ) func InitConf() { - flagConfig() + cli.FlagConfig() defaultConfig() setConfigFile() setEnv() @@ -25,11 +26,11 @@ func InitConf() { } func setConfigFile() { - if !common.IsFileExist(*config) { + if !common.IsFileExist(*cli.Config) { return } - viper.SetConfigFile(*config) + viper.SetConfigFile(*cli.Config) if err := viper.ReadInConfig(); err != nil { panic(err) } @@ -51,4 +52,5 @@ func defaultConfig() { viper.SetDefault("global.api_rate_limit", 180) viper.SetDefault("global.web_rate_limit", 100) viper.SetDefault("connect_timeout", 5) + viper.SetDefault("auto_price_updates", true) } diff --git a/common/model-ratio.go b/common/model-ratio.go index 8c3bb11f..92b262e6 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -1,217 +1,5 @@ package common -import ( - "encoding/json" - "strings" - "time" -) - -type ModelType struct { - Ratio []float64 - Type int -} - -var ModelTypes map[string]ModelType - -// ModelRatio -// https://platform.openai.com/docs/models/model-endpoint-compatibility -// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Blfmc9dlf -// https://openai.com/pricing -// TODO: when a new api is enabled, check the pricing here -// 1 === $0.002 / 1K tokens -// 1 === ¥0.014 / 1k tokens -var ModelRatio map[string][]float64 - -func init() { - ModelTypes = map[string]ModelType{ - // $0.03 / 1K tokens $0.06 / 1K tokens - "gpt-4": {[]float64{15, 30}, ChannelTypeOpenAI}, - "gpt-4-0314": {[]float64{15, 30}, ChannelTypeOpenAI}, - "gpt-4-0613": {[]float64{15, 30}, ChannelTypeOpenAI}, - // $0.06 / 1K tokens $0.12 / 1K tokens - "gpt-4-32k": {[]float64{30, 60}, ChannelTypeOpenAI}, - "gpt-4-32k-0314": {[]float64{30, 60}, ChannelTypeOpenAI}, - "gpt-4-32k-0613": {[]float64{30, 60}, ChannelTypeOpenAI}, - // $0.01 / 1K tokens $0.03 / 1K tokens - "gpt-4-preview": {[]float64{5, 15}, ChannelTypeOpenAI}, - "gpt-4-1106-preview": {[]float64{5, 15}, ChannelTypeOpenAI}, - "gpt-4-0125-preview": {[]float64{5, 15}, ChannelTypeOpenAI}, - "gpt-4-turbo-preview": {[]float64{5, 15}, ChannelTypeOpenAI}, - "gpt-4-vision-preview": {[]float64{5, 15}, ChannelTypeOpenAI}, - // $0.0005 / 1K tokens $0.0015 / 1K tokens - "gpt-3.5-turbo": {[]float64{0.25, 0.75}, ChannelTypeOpenAI}, - "gpt-3.5-turbo-0125": {[]float64{0.25, 0.75}, ChannelTypeOpenAI}, - // $0.0015 / 1K tokens $0.002 / 1K tokens - "gpt-3.5-turbo-0301": {[]float64{0.75, 1}, ChannelTypeOpenAI}, - "gpt-3.5-turbo-0613": {[]float64{0.75, 1}, ChannelTypeOpenAI}, - "gpt-3.5-turbo-instruct": {[]float64{0.75, 1}, ChannelTypeOpenAI}, - // $0.003 / 1K tokens $0.004 / 1K tokens - "gpt-3.5-turbo-16k": {[]float64{1.5, 2}, ChannelTypeOpenAI}, - "gpt-3.5-turbo-16k-0613": {[]float64{1.5, 2}, ChannelTypeOpenAI}, - // $0.001 / 1K tokens $0.002 / 1K tokens - "gpt-3.5-turbo-1106": {[]float64{0.5, 1}, ChannelTypeOpenAI}, - // $0.0020 / 1K tokens - "davinci-002": {[]float64{1, 1}, ChannelTypeOpenAI}, - // $0.0004 / 1K tokens - "babbage-002": {[]float64{0.2, 0.2}, ChannelTypeOpenAI}, - "text-ada-001": {[]float64{0.2, 0.2}, ChannelTypeOpenAI}, - "text-babbage-001": {[]float64{0.25, 0.25}, ChannelTypeOpenAI}, - "text-curie-001": {[]float64{1, 1}, ChannelTypeOpenAI}, - "text-davinci-002": {[]float64{10, 10}, ChannelTypeOpenAI}, - "text-davinci-003": {[]float64{10, 10}, ChannelTypeOpenAI}, - "text-davinci-edit-001": {[]float64{10, 10}, ChannelTypeOpenAI}, - "code-davinci-edit-001": {[]float64{10, 10}, ChannelTypeOpenAI}, - // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens - "whisper-1": {[]float64{15, 15}, ChannelTypeOpenAI}, - // $0.015 / 1K characters - "tts-1": {[]float64{7.5, 7.5}, ChannelTypeOpenAI}, - "tts-1-1106": {[]float64{7.5, 7.5}, ChannelTypeOpenAI}, - // $0.030 / 1K characters - "tts-1-hd": {[]float64{15, 15}, ChannelTypeOpenAI}, - "tts-1-hd-1106": {[]float64{15, 15}, ChannelTypeOpenAI}, - "davinci": {[]float64{10, 10}, ChannelTypeOpenAI}, - "curie": {[]float64{10, 10}, ChannelTypeOpenAI}, - "babbage": {[]float64{10, 10}, ChannelTypeOpenAI}, - "ada": {[]float64{10, 10}, ChannelTypeOpenAI}, - "text-embedding-ada-002": {[]float64{0.05, 0.05}, ChannelTypeOpenAI}, - // $0.00002 / 1K tokens - "text-embedding-3-small": {[]float64{0.01, 0.01}, ChannelTypeOpenAI}, - // $0.00013 / 1K tokens - "text-embedding-3-large": {[]float64{0.065, 0.065}, ChannelTypeOpenAI}, - "text-search-ada-doc-001": {[]float64{10, 10}, ChannelTypeOpenAI}, - "text-moderation-stable": {[]float64{0.1, 0.1}, ChannelTypeOpenAI}, - "text-moderation-latest": {[]float64{0.1, 0.1}, ChannelTypeOpenAI}, - // $0.016 - $0.020 / image - "dall-e-2": {[]float64{8, 8}, ChannelTypeOpenAI}, - // $0.040 - $0.120 / image - "dall-e-3": {[]float64{20, 20}, ChannelTypeOpenAI}, - - // $0.80/million tokens $2.40/million tokens - "claude-instant-1.2": {[]float64{0.4, 1.2}, ChannelTypeAnthropic}, - // $8.00/million tokens $24.00/million tokens - "claude-2.0": {[]float64{4, 12}, ChannelTypeAnthropic}, - "claude-2.1": {[]float64{4, 12}, ChannelTypeAnthropic}, - // $15 / M $75 / M - "claude-3-opus-20240229": {[]float64{7.5, 22.5}, ChannelTypeAnthropic}, - // $3 / M $15 / M - "claude-3-sonnet-20240229": {[]float64{1.3, 3.9}, ChannelTypeAnthropic}, - // $0.25 / M $1.25 / M 0.00025$ / 1k tokens 0.00125$ / 1k tokens - "claude-3-haiku-20240307": {[]float64{0.125, 0.625}, ChannelTypeAnthropic}, - - // ¥0.004 / 1k tokens ¥0.008 / 1k tokens - "ERNIE-Speed": {[]float64{0.2857, 0.5714}, ChannelTypeBaidu}, - // ¥0.012 / 1k tokens ¥0.012 / 1k tokens - "ERNIE-Bot": {[]float64{0.8572, 0.8572}, ChannelTypeBaidu}, - "ERNIE-3.5-8K": {[]float64{0.8572, 0.8572}, ChannelTypeBaidu}, - // 0.024元/千tokens 0.048元/千tokens - "ERNIE-Bot-8k": {[]float64{1.7143, 3.4286}, ChannelTypeBaidu}, - // ¥0.008 / 1k tokens ¥0.008 / 1k tokens - "ERNIE-Bot-turbo": {[]float64{0.5715, 0.5715}, ChannelTypeBaidu}, - // ¥0.12 / 1k tokens ¥0.12 / 1k tokens - "ERNIE-Bot-4": {[]float64{8.572, 8.572}, ChannelTypeBaidu}, - "ERNIE-4.0": {[]float64{8.572, 8.572}, ChannelTypeBaidu}, - // ¥0.002 / 1k tokens - "Embedding-V1": {[]float64{0.1429, 0.1429}, ChannelTypeBaidu}, - // ¥0.004 / 1k tokens - "BLOOMZ-7B": {[]float64{0.2857, 0.2857}, ChannelTypeBaidu}, - - "PaLM-2": {[]float64{1, 1}, ChannelTypePaLM}, - "gemini-pro": {[]float64{1, 1}, ChannelTypeGemini}, - "gemini-pro-vision": {[]float64{1, 1}, ChannelTypeGemini}, - "gemini-1.0-pro": {[]float64{1, 1}, ChannelTypeGemini}, - "gemini-1.5-pro": {[]float64{1, 1}, ChannelTypeGemini}, - - // ¥0.005 / 1k tokens - "glm-3-turbo": {[]float64{0.3572, 0.3572}, ChannelTypeZhipu}, - // ¥0.1 / 1k tokens - "glm-4": {[]float64{7.143, 7.143}, ChannelTypeZhipu}, - "glm-4v": {[]float64{7.143, 7.143}, ChannelTypeZhipu}, - // ¥0.0005 / 1k tokens - "embedding-2": {[]float64{0.0357, 0.0357}, ChannelTypeZhipu}, - // ¥0.25 / 1张图片 - "cogview-3": {[]float64{17.8571, 17.8571}, ChannelTypeZhipu}, - - // ¥0.008 / 1k tokens - "qwen-turbo": {[]float64{0.5715, 0.5715}, ChannelTypeAli}, - // ¥0.02 / 1k tokens - "qwen-plus": {[]float64{1.4286, 1.4286}, ChannelTypeAli}, - "qwen-vl-max": {[]float64{1.4286, 1.4286}, ChannelTypeAli}, - // 0.12元/1,000tokens - "qwen-max": {[]float64{8.5714, 8.5714}, ChannelTypeAli}, - "qwen-max-longcontext": {[]float64{8.5714, 8.5714}, ChannelTypeAli}, - // 0.008元/1,000tokens - "qwen-vl-plus": {[]float64{0.5715, 0.5715}, ChannelTypeAli}, - // ¥0.0007 / 1k tokens - "text-embedding-v1": {[]float64{0.05, 0.05}, ChannelTypeAli}, - - // ¥0.018 / 1k tokens - "SparkDesk": {[]float64{1.2858, 1.2858}, ChannelTypeXunfei}, - "SparkDesk-v1.1": {[]float64{1.2858, 1.2858}, ChannelTypeXunfei}, - "SparkDesk-v2.1": {[]float64{1.2858, 1.2858}, ChannelTypeXunfei}, - "SparkDesk-v3.1": {[]float64{1.2858, 1.2858}, ChannelTypeXunfei}, - "SparkDesk-v3.5": {[]float64{1.2858, 1.2858}, ChannelTypeXunfei}, - - // ¥0.012 / 1k tokens - "360GPT_S2_V9": {[]float64{0.8572, 0.8572}, ChannelType360}, - // ¥0.001 / 1k tokens - "embedding-bert-512-v1": {[]float64{0.0715, 0.0715}, ChannelType360}, - "embedding_s1_v1": {[]float64{0.0715, 0.0715}, ChannelType360}, - "semantic_similarity_s1_v1": {[]float64{0.0715, 0.0715}, ChannelType360}, - - // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0 - "hunyuan": {[]float64{7.143, 7.143}, ChannelTypeTencent}, - // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0 - // ¥0.01 / 1k tokens - "ChatStd": {[]float64{0.7143, 0.7143}, ChannelTypeTencent}, - //¥0.1 / 1k tokens - "ChatPro": {[]float64{7.143, 7.143}, ChannelTypeTencent}, - - "Baichuan2-Turbo": {[]float64{0.5715, 0.5715}, ChannelTypeBaichuan}, // ¥0.008 / 1k tokens - "Baichuan2-Turbo-192k": {[]float64{1.143, 1.143}, ChannelTypeBaichuan}, // ¥0.016 / 1k tokens - "Baichuan2-53B": {[]float64{1.4286, 1.4286}, ChannelTypeBaichuan}, // ¥0.02 / 1k tokens - "Baichuan-Text-Embedding": {[]float64{0.0357, 0.0357}, ChannelTypeBaichuan}, // ¥0.0005 / 1k tokens - - "abab5.5s-chat": {[]float64{0.3572, 0.3572}, ChannelTypeMiniMax}, // ¥0.005 / 1k tokens - "abab5.5-chat": {[]float64{1.0714, 1.0714}, ChannelTypeMiniMax}, // ¥0.015 / 1k tokens - "abab6-chat": {[]float64{14.2857, 14.2857}, ChannelTypeMiniMax}, // ¥0.2 / 1k tokens - "embo-01": {[]float64{0.0357, 0.0357}, ChannelTypeMiniMax}, // ¥0.0005 / 1k tokens - - "deepseek-coder": {[]float64{0.75, 0.75}, ChannelTypeDeepseek}, // 暂定 $0.0015 / 1K tokens - "deepseek-chat": {[]float64{0.75, 0.75}, ChannelTypeDeepseek}, // 暂定 $0.0015 / 1K tokens - - "moonshot-v1-8k": {[]float64{0.8572, 0.8572}, ChannelTypeMoonshot}, // ¥0.012 / 1K tokens - "moonshot-v1-32k": {[]float64{1.7143, 1.7143}, ChannelTypeMoonshot}, // ¥0.024 / 1K tokens - "moonshot-v1-128k": {[]float64{4.2857, 4.2857}, ChannelTypeMoonshot}, // ¥0.06 / 1K tokens - - "open-mistral-7b": {[]float64{0.125, 0.125}, ChannelTypeMistral}, // 0.25$ / 1M tokens 0.25$ / 1M tokens 0.00025$ / 1k tokens - "open-mixtral-8x7b": {[]float64{0.35, 0.35}, ChannelTypeMistral}, // 0.7$ / 1M tokens 0.7$ / 1M tokens 0.0007$ / 1k tokens - "mistral-small-latest": {[]float64{1, 3}, ChannelTypeMistral}, // 2$ / 1M tokens 6$ / 1M tokens 0.002$ / 1k tokens - "mistral-medium-latest": {[]float64{1.35, 4.05}, ChannelTypeMistral}, // 2.7$ / 1M tokens 8.1$ / 1M tokens 0.0027$ / 1k tokens - "mistral-large-latest": {[]float64{4, 12}, ChannelTypeMistral}, // 8$ / 1M tokens 24$ / 1M tokens 0.008$ / 1k tokens - "mistral-embed": {[]float64{0.05, 0.05}, ChannelTypeMistral}, // 0.1$ / 1M tokens 0.1$ / 1M tokens 0.0001$ / 1k tokens - - // $0.70/$0.80 /1M Tokens 0.0007$ / 1k tokens - "llama2-70b-4096": {[]float64{0.35, 0.4}, ChannelTypeGroq}, - // $0.10/$0.10 /1M Tokens 0.0001$ / 1k tokens - "llama2-7b-2048": {[]float64{0.05, 0.05}, ChannelTypeGroq}, - "gemma-7b-it": {[]float64{0.05, 0.05}, ChannelTypeGroq}, - // $0.27/$0.27 /1M Tokens 0.00027$ / 1k tokens - "mixtral-8x7b-32768": {[]float64{0.135, 0.135}, ChannelTypeGroq}, - - // 2.5 元 / 1M tokens 0.0025 / 1k tokens - "yi-34b-chat-0205": {[]float64{0.1786, 0.1786}, ChannelTypeLingyi}, - // 12 元 / 1M tokens 0.012 / 1k tokens - "yi-34b-chat-200k": {[]float64{0.8571, 0.8571}, ChannelTypeLingyi}, - // 6 元 / 1M tokens 0.006 / 1k tokens - "yi-vl-plus": {[]float64{0.4286, 0.4286}, ChannelTypeLingyi}, - } - - ModelRatio = make(map[string][]float64) - for name, modelType := range ModelTypes { - ModelRatio[name] = modelType.Ratio - } -} - var DalleSizeRatios = map[string]map[string]float64{ "dall-e-2": { "256x256": 1, @@ -234,104 +22,3 @@ var DalleImagePromptLengthLimitations = map[string]int{ "dall-e-2": 1000, "dall-e-3": 4000, } - -func ModelRatio2JSONString() string { - jsonBytes, err := json.Marshal(ModelRatio) - if err != nil { - SysError("error marshalling model ratio: " + err.Error()) - } - return string(jsonBytes) -} - -func UpdateModelRatioByJSONString(jsonStr string) error { - ModelRatio = make(map[string][]float64) - return json.Unmarshal([]byte(jsonStr), &ModelRatio) -} - -func MergeModelRatioByJSONString(jsonStr string) (newJsonStr string, err error) { - isNew := false - inputModelRatio := make(map[string][]float64) - err = json.Unmarshal([]byte(jsonStr), &inputModelRatio) - if err != nil { - inputModelRatioOld := make(map[string]float64) - err = json.Unmarshal([]byte(jsonStr), &inputModelRatioOld) - if err != nil { - return - } - - inputModelRatio = UpdateModeRatioFormat(inputModelRatioOld) - isNew = true - } - - // 与现有的ModelRatio进行比较,如果有新增的模型,需要添加 - for key, value := range ModelRatio { - if _, ok := inputModelRatio[key]; !ok { - isNew = true - inputModelRatio[key] = value - } - } - - if !isNew { - return - } - - var jsonBytes []byte - jsonBytes, err = json.Marshal(inputModelRatio) - if err != nil { - SysError("error marshalling model ratio: " + err.Error()) - } - newJsonStr = string(jsonBytes) - return -} - -func UpdateModeRatioFormat(modelRatioOld map[string]float64) map[string][]float64 { - modelRatioNew := make(map[string][]float64) - for key, value := range modelRatioOld { - completionRatio := GetCompletionRatio(key) * value - modelRatioNew[key] = []float64{value, completionRatio} - } - return modelRatioNew -} - -func GetModelRatio(name string) []float64 { - if strings.HasPrefix(name, "qwen-") && strings.HasSuffix(name, "-internet") { - name = strings.TrimSuffix(name, "-internet") - } - ratio, ok := ModelRatio[name] - if !ok { - SysError("model ratio not found: " + name) - return []float64{30, 30} - } - return ratio -} - -func GetCompletionRatio(name string) float64 { - if strings.HasPrefix(name, "gpt-3.5") { - if strings.HasSuffix(name, "1106") { - return 2 - } - if name == "gpt-3.5-turbo" || name == "gpt-3.5-turbo-16k" { - // TODO: clear this after 2023-12-11 - now := time.Now() - // https://platform.openai.com/docs/models/continuous-model-upgrades - // if after 2023-12-11, use 2 - if now.After(time.Date(2023, 12, 11, 0, 0, 0, 0, time.UTC)) { - return 2 - } - } - return 1.333333 - } - if strings.HasPrefix(name, "gpt-4") { - if strings.HasSuffix(name, "preview") { - return 3 - } - return 2 - } - if strings.HasPrefix(name, "claude-instant-1.2") { - return 3.38 - } - if strings.HasPrefix(name, "claude-2") { - return 2.965517 - } - return 1 -} diff --git a/common/token.go b/common/token.go index 9add2a0b..40ef8733 100644 --- a/common/token.go +++ b/common/token.go @@ -13,46 +13,46 @@ import ( ) var tokenEncoderMap = map[string]*tiktoken.Tiktoken{} -var defaultTokenEncoder *tiktoken.Tiktoken +var gpt35TokenEncoder *tiktoken.Tiktoken +var gpt4TokenEncoder *tiktoken.Tiktoken func InitTokenEncoders() { SysLog("initializing token encoders") - gpt35TokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo") + var err error + gpt35TokenEncoder, err = tiktoken.EncodingForModel("gpt-3.5-turbo") if err != nil { FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error())) } - defaultTokenEncoder = gpt35TokenEncoder - gpt4TokenEncoder, err := tiktoken.EncodingForModel("gpt-4") + + gpt4TokenEncoder, err = tiktoken.EncodingForModel("gpt-4") if err != nil { FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error())) } - for model := range ModelRatio { - if strings.HasPrefix(model, "gpt-3.5") { - tokenEncoderMap[model] = gpt35TokenEncoder - } else if strings.HasPrefix(model, "gpt-4") { - tokenEncoderMap[model] = gpt4TokenEncoder - } else { - tokenEncoderMap[model] = nil - } - } + SysLog("token encoders initialized") } func getTokenEncoder(model string) *tiktoken.Tiktoken { tokenEncoder, ok := tokenEncoderMap[model] - if ok && tokenEncoder != nil { + if ok { return tokenEncoder } - if ok { - tokenEncoder, err := tiktoken.EncodingForModel(model) + + if strings.HasPrefix(model, "gpt-3.5") { + tokenEncoder = gpt35TokenEncoder + } else if strings.HasPrefix(model, "gpt-4") { + tokenEncoder = gpt4TokenEncoder + } else { + var err error + tokenEncoder, err = tiktoken.EncodingForModel(model) if err != nil { SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error())) - tokenEncoder = defaultTokenEncoder + tokenEncoder = gpt35TokenEncoder } - tokenEncoderMap[model] = tokenEncoder - return tokenEncoder } - return defaultTokenEncoder + + tokenEncoderMap[model] = tokenEncoder + return tokenEncoder } func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int { diff --git a/common/utils.go b/common/utils.go index f9e574b9..4bae968c 100644 --- a/common/utils.go +++ b/common/utils.go @@ -212,3 +212,31 @@ func IsFileExist(path string) bool { _, err := os.Stat(path) return err == nil || os.IsExist(err) } + +func Contains[T comparable](value T, slice []T) bool { + for _, item := range slice { + if item == value { + return true + } + } + return false +} + +func Filter[T any](arr []T, f func(T) bool) []T { + var res []T + for _, v := range arr { + if f(v) { + res = append(res, v) + } + } + return res +} + +func GetModelsWithMatch(modelList *[]string, modelName string) string { + for _, model := range *modelList { + if strings.HasPrefix(modelName, strings.TrimRight(model, "*")) { + return model + } + } + return "" +} diff --git a/config.example.yaml b/config.example.yaml index cf750964..2ed668c4 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -18,6 +18,7 @@ frontend_base_url: "" # 设置之后将重定向页面请求到指定的地址 polling_interval: 0 # 批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。 batch_update_interval: 5 # 批量更新聚合的时间间隔,单位为秒,默认为 5。 batch_update_enabled: false # 启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 true 和 false,未设置则默认为 false +auto_price_updates: true # 启用自动更新价格,可选值为 true 和 false,默认为 true # 全局设置 global: diff --git a/controller/pricing.go b/controller/pricing.go new file mode 100644 index 00000000..8174832c --- /dev/null +++ b/controller/pricing.go @@ -0,0 +1,192 @@ +package controller + +import ( + "errors" + "net/http" + "one-api/common" + "one-api/model" + "one-api/relay/util" + + "github.com/gin-gonic/gin" +) + +func GetPricesList(c *gin.Context) { + pricesType := c.DefaultQuery("type", "db") + + prices := util.GetPricesList(pricesType) + + if len(prices) == 0 { + common.APIRespondWithError(c, http.StatusOK, errors.New("pricing data not found")) + return + } + + if pricesType == "old" { + c.JSON(http.StatusOK, prices) + } else { + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": prices, + }) + } +} + +func GetAllModelList(c *gin.Context) { + prices := util.PricingInstance.GetAllPrices() + channelModel := model.ChannelGroup.Rule + + modelsMap := make(map[string]bool) + for modelName := range prices { + modelsMap[modelName] = true + } + + for _, modelMap := range channelModel { + for modelName := range modelMap { + if _, ok := prices[modelName]; !ok { + modelsMap[modelName] = true + } + } + } + + var models []string + for modelName := range modelsMap { + models = append(models, modelName) + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": models, + }) +} + +func AddPrice(c *gin.Context) { + var price model.Price + if err := c.ShouldBindJSON(&price); err != nil { + common.APIRespondWithError(c, http.StatusOK, err) + return + } + + if err := util.PricingInstance.AddPrice(&price); err != nil { + common.APIRespondWithError(c, http.StatusOK, err) + return + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + }) +} + +func UpdatePrice(c *gin.Context) { + modelName := c.Param("model") + if modelName == "" { + common.APIRespondWithError(c, http.StatusOK, errors.New("model name is required")) + return + } + + var price model.Price + if err := c.ShouldBindJSON(&price); err != nil { + common.APIRespondWithError(c, http.StatusOK, err) + return + } + + if err := util.PricingInstance.UpdatePrice(modelName, &price); err != nil { + common.APIRespondWithError(c, http.StatusOK, err) + return + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + }) +} + +func DeletePrice(c *gin.Context) { + modelName := c.Param("model") + if modelName == "" { + common.APIRespondWithError(c, http.StatusOK, errors.New("model name is required")) + return + } + + if err := util.PricingInstance.DeletePrice(modelName); err != nil { + common.APIRespondWithError(c, http.StatusOK, err) + return + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + }) +} + +type PriceBatchRequest struct { + OriginalModels []string `json:"original_models"` + util.BatchPrices +} + +func BatchSetPrices(c *gin.Context) { + pricesBatch := &PriceBatchRequest{} + if err := c.ShouldBindJSON(pricesBatch); err != nil { + common.APIRespondWithError(c, http.StatusOK, err) + return + } + + if err := util.PricingInstance.BatchSetPrices(&pricesBatch.BatchPrices, pricesBatch.OriginalModels); err != nil { + common.APIRespondWithError(c, http.StatusOK, err) + return + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + }) +} + +type PriceBatchDeleteRequest struct { + Models []string `json:"models" binding:"required"` +} + +func BatchDeletePrices(c *gin.Context) { + pricesBatch := &PriceBatchDeleteRequest{} + if err := c.ShouldBindJSON(pricesBatch); err != nil { + common.APIRespondWithError(c, http.StatusOK, err) + return + } + + if err := util.PricingInstance.BatchDeletePrices(pricesBatch.Models); err != nil { + common.APIRespondWithError(c, http.StatusOK, err) + return + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + }) +} + +func SyncPricing(c *gin.Context) { + overwrite := c.DefaultQuery("overwrite", "false") + + prices := make([]*model.Price, 0) + if err := c.ShouldBindJSON(&prices); err != nil { + common.APIRespondWithError(c, http.StatusOK, err) + return + } + + if len(prices) == 0 { + common.APIRespondWithError(c, http.StatusOK, errors.New("prices is required")) + return + } + + err := util.PricingInstance.SyncPricing(prices, overwrite == "true") + if err != nil { + common.APIRespondWithError(c, http.StatusOK, err) + return + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + }) +} diff --git a/main.go b/main.go index fe871c1f..18b31de8 100644 --- a/main.go +++ b/main.go @@ -10,6 +10,7 @@ import ( "one-api/controller" "one-api/middleware" "one-api/model" + "one-api/relay/util" "one-api/router" "github.com/gin-contrib/sessions" @@ -35,6 +36,7 @@ func main() { common.InitRedisClient() // Initialize options model.InitOptionMap() + util.NewPricing() initMemoryCache() initSync() diff --git a/model/balancer.go b/model/balancer.go index 6bc19682..44bdcb9e 100644 --- a/model/balancer.go +++ b/model/balancer.go @@ -18,6 +18,7 @@ type ChannelsChooser struct { sync.RWMutex Channels map[int]*ChannelChoice Rule map[string]map[string][][]int // group -> model -> priority -> channelIds + Match []string } func (cc *ChannelsChooser) Cooldowns(channelId int) bool { @@ -74,11 +75,15 @@ func (cc *ChannelsChooser) Next(group, modelName string) (*Channel, error) { return nil, errors.New("group not found") } - if _, ok := cc.Rule[group][modelName]; !ok { - return nil, errors.New("model not found") + channelsPriority, ok := cc.Rule[group][modelName] + if !ok { + matchModel := common.GetModelsWithMatch(&cc.Match, modelName) + channelsPriority, ok = cc.Rule[group][matchModel] + if !ok { + return nil, errors.New("model not found") + } } - channelsPriority := cc.Rule[group][modelName] if len(channelsPriority) == 0 { return nil, errors.New("channel not found") } @@ -123,6 +128,7 @@ func (cc *ChannelsChooser) Load() { newGroup := make(map[string]map[string][][]int) newChannels := make(map[int]*ChannelChoice) + newMatch := make(map[string]bool) for _, channel := range channels { if *channel.Weight == 0 { @@ -143,6 +149,13 @@ func (cc *ChannelsChooser) Load() { newGroup[ability.Group][ability.Model] = make([][]int, 0) } + // 如果是以 *结尾的 model名称 + if strings.HasSuffix(ability.Model, "*") { + if _, ok := newMatch[ability.Model]; !ok { + newMatch[ability.Model] = true + } + } + var priorityIds []int // 逗号分割 ability.ChannelId channelIds := strings.Split(ability.ChannelIds, ",") @@ -153,9 +166,15 @@ func (cc *ChannelsChooser) Load() { newGroup[ability.Group][ability.Model] = append(newGroup[ability.Group][ability.Model], priorityIds) } + newMatchList := make([]string, 0, len(newMatch)) + for match := range newMatch { + newMatchList = append(newMatchList, match) + } + cc.Lock() cc.Rule = newGroup cc.Channels = newChannels + cc.Match = newMatchList cc.Unlock() common.SysLog("channels Load success") } diff --git a/model/main.go b/model/main.go index 90cd1b96..fcdb952e 100644 --- a/model/main.go +++ b/model/main.go @@ -135,6 +135,10 @@ func InitDB() (err error) { if err != nil { return err } + err = db.AutoMigrate(&Price{}) + if err != nil { + return err + } common.SysLog("database migrated") err = createRootAccountIfNeed() return err diff --git a/model/option.go b/model/option.go index 45979df7..62714ce1 100644 --- a/model/option.go +++ b/model/option.go @@ -67,7 +67,6 @@ func InitOptionMap() { common.OptionMap["QuotaForInvitee"] = strconv.Itoa(common.QuotaForInvitee) common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold) common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota) - common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString() common.OptionMap["GroupRatio"] = common.GroupRatio2JSONString() common.OptionMap["TopUpLink"] = common.TopUpLink common.OptionMap["ChatLink"] = common.ChatLink @@ -76,28 +75,9 @@ func InitOptionMap() { common.OptionMap["RetryCooldownSeconds"] = strconv.Itoa(common.RetryCooldownSeconds) common.OptionMapRWMutex.Unlock() - initModelRatio() loadOptionsFromDatabase() } -func initModelRatio() { - // 查询数据库中的ModelRatio - option, err := GetOption("ModelRatio") - if err != nil || option.Value == "" { - return - } - - newModelRatio, err := common.MergeModelRatioByJSONString(option.Value) - if err != nil || newModelRatio == "" { - return - } - - // 更新数据库中的ModelRatio - common.SysLog("update ModelRatio") - UpdateOption("ModelRatio", newModelRatio) - -} - func loadOptionsFromDatabase() { options, _ := AllOption() for _, option := range options { @@ -202,8 +182,6 @@ func updateOptionMap(key string, value string) (err error) { switch key { case "EmailDomainWhitelist": common.EmailDomainWhitelist = strings.Split(value, ",") - case "ModelRatio": - err = common.UpdateModelRatioByJSONString(value) case "GroupRatio": err = common.UpdateGroupRatioByJSONString(value) case "ChannelDisableThreshold": diff --git a/model/price.go b/model/price.go new file mode 100644 index 00000000..b0d0bd92 --- /dev/null +++ b/model/price.go @@ -0,0 +1,300 @@ +package model + +import ( + "one-api/common" + + "gorm.io/gorm" +) + +const ( + TokensPriceType = "tokens" + TimesPriceType = "times" + DefaultPrice = 30.0 + DollarRate = 0.002 + RMBRate = 0.014 +) + +type Price struct { + Model string `json:"model" gorm:"type:varchar(30);primaryKey" binding:"required"` + Type string `json:"type" gorm:"default:'tokens'" binding:"required"` + ChannelType int `json:"channel_type" gorm:"default:0" binding:"gte=0"` + Input float64 `json:"input" gorm:"default:0" binding:"gte=0"` + Output float64 `json:"output" gorm:"default:0" binding:"gte=0"` +} + +func GetAllPrices() ([]*Price, error) { + var prices []*Price + if err := DB.Find(&prices).Error; err != nil { + return nil, err + } + return prices, nil +} + +func (price *Price) Update(modelName string) error { + if err := DB.Model(price).Select("*").Where("model = ?", modelName).Updates(price).Error; err != nil { + return err + } + + return nil +} + +func (price *Price) Insert() error { + if err := DB.Create(price).Error; err != nil { + return err + } + + return nil +} + +func (price *Price) GetInput() float64 { + if price.Input <= 0 { + return 0 + } + return price.Input +} + +func (price *Price) GetOutput() float64 { + if price.Output <= 0 || price.Type == TimesPriceType { + return 0 + } + + return price.Output +} + +func (price *Price) FetchInputCurrencyPrice(rate float64) float64 { + return price.GetInput() * rate +} + +func (price *Price) FetchOutputCurrencyPrice(rate float64) float64 { + return price.GetOutput() * rate +} + +func UpdatePrices(tx *gorm.DB, models []string, prices *Price) error { + err := tx.Model(Price{}).Where("model IN (?)", models).Select("*").Omit("model").Updates( + Price{ + Type: prices.Type, + ChannelType: prices.ChannelType, + Input: prices.Input, + Output: prices.Output, + }).Error + + return err +} + +func DeletePrices(tx *gorm.DB, models []string) error { + err := tx.Where("model IN (?)", models).Delete(&Price{}).Error + + return err +} + +func InsertPrices(tx *gorm.DB, prices []*Price) error { + err := tx.CreateInBatches(prices, 100).Error + return err +} + +func DeleteAllPrices(tx *gorm.DB) error { + err := tx.Where("1=1").Delete(&Price{}).Error + return err +} + +func (price *Price) Delete() error { + err := DB.Delete(price).Error + if err != nil { + return err + } + return err +} + +type ModelType struct { + Ratio []float64 + Type int +} + +// 1 === $0.002 / 1K tokens +// 1 === ¥0.014 / 1k tokens +func GetDefaultPrice() []*Price { + ModelTypes := map[string]ModelType{ + // $0.03 / 1K tokens $0.06 / 1K tokens + "gpt-4": {[]float64{15, 30}, common.ChannelTypeOpenAI}, + "gpt-4-0314": {[]float64{15, 30}, common.ChannelTypeOpenAI}, + "gpt-4-0613": {[]float64{15, 30}, common.ChannelTypeOpenAI}, + // $0.06 / 1K tokens $0.12 / 1K tokens + "gpt-4-32k": {[]float64{30, 60}, common.ChannelTypeOpenAI}, + "gpt-4-32k-0314": {[]float64{30, 60}, common.ChannelTypeOpenAI}, + "gpt-4-32k-0613": {[]float64{30, 60}, common.ChannelTypeOpenAI}, + // $0.01 / 1K tokens $0.03 / 1K tokens + "gpt-4-preview": {[]float64{5, 15}, common.ChannelTypeOpenAI}, + "gpt-4-1106-preview": {[]float64{5, 15}, common.ChannelTypeOpenAI}, + "gpt-4-0125-preview": {[]float64{5, 15}, common.ChannelTypeOpenAI}, + "gpt-4-turbo-preview": {[]float64{5, 15}, common.ChannelTypeOpenAI}, + "gpt-4-vision-preview": {[]float64{5, 15}, common.ChannelTypeOpenAI}, + // $0.0005 / 1K tokens $0.0015 / 1K tokens + "gpt-3.5-turbo": {[]float64{0.25, 0.75}, common.ChannelTypeOpenAI}, + "gpt-3.5-turbo-0125": {[]float64{0.25, 0.75}, common.ChannelTypeOpenAI}, + // $0.0015 / 1K tokens $0.002 / 1K tokens + "gpt-3.5-turbo-0301": {[]float64{0.75, 1}, common.ChannelTypeOpenAI}, + "gpt-3.5-turbo-0613": {[]float64{0.75, 1}, common.ChannelTypeOpenAI}, + "gpt-3.5-turbo-instruct": {[]float64{0.75, 1}, common.ChannelTypeOpenAI}, + // $0.003 / 1K tokens $0.004 / 1K tokens + "gpt-3.5-turbo-16k": {[]float64{1.5, 2}, common.ChannelTypeOpenAI}, + "gpt-3.5-turbo-16k-0613": {[]float64{1.5, 2}, common.ChannelTypeOpenAI}, + // $0.001 / 1K tokens $0.002 / 1K tokens + "gpt-3.5-turbo-1106": {[]float64{0.5, 1}, common.ChannelTypeOpenAI}, + // $0.0020 / 1K tokens + "davinci-002": {[]float64{1, 1}, common.ChannelTypeOpenAI}, + // $0.0004 / 1K tokens + "babbage-002": {[]float64{0.2, 0.2}, common.ChannelTypeOpenAI}, + // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens + "whisper-1": {[]float64{15, 15}, common.ChannelTypeOpenAI}, + // $0.015 / 1K characters + "tts-1": {[]float64{7.5, 7.5}, common.ChannelTypeOpenAI}, + "tts-1-1106": {[]float64{7.5, 7.5}, common.ChannelTypeOpenAI}, + // $0.030 / 1K characters + "tts-1-hd": {[]float64{15, 15}, common.ChannelTypeOpenAI}, + "tts-1-hd-1106": {[]float64{15, 15}, common.ChannelTypeOpenAI}, + "text-embedding-ada-002": {[]float64{0.05, 0.05}, common.ChannelTypeOpenAI}, + // $0.00002 / 1K tokens + "text-embedding-3-small": {[]float64{0.01, 0.01}, common.ChannelTypeOpenAI}, + // $0.00013 / 1K tokens + "text-embedding-3-large": {[]float64{0.065, 0.065}, common.ChannelTypeOpenAI}, + "text-moderation-stable": {[]float64{0.1, 0.1}, common.ChannelTypeOpenAI}, + "text-moderation-latest": {[]float64{0.1, 0.1}, common.ChannelTypeOpenAI}, + // $0.016 - $0.020 / image + "dall-e-2": {[]float64{8, 8}, common.ChannelTypeOpenAI}, + // $0.040 - $0.120 / image + "dall-e-3": {[]float64{20, 20}, common.ChannelTypeOpenAI}, + + // $0.80/million tokens $2.40/million tokens + "claude-instant-1.2": {[]float64{0.4, 1.2}, common.ChannelTypeAnthropic}, + // $8.00/million tokens $24.00/million tokens + "claude-2.0": {[]float64{4, 12}, common.ChannelTypeAnthropic}, + "claude-2.1": {[]float64{4, 12}, common.ChannelTypeAnthropic}, + // $15 / M $75 / M + "claude-3-opus-20240229": {[]float64{7.5, 22.5}, common.ChannelTypeAnthropic}, + // $3 / M $15 / M + "claude-3-sonnet-20240229": {[]float64{1.3, 3.9}, common.ChannelTypeAnthropic}, + // $0.25 / M $1.25 / M 0.00025$ / 1k tokens 0.00125$ / 1k tokens + "claude-3-haiku-20240307": {[]float64{0.125, 0.625}, common.ChannelTypeAnthropic}, + + // ¥0.004 / 1k tokens ¥0.008 / 1k tokens + "ERNIE-Speed": {[]float64{0.2857, 0.5714}, common.ChannelTypeBaidu}, + // ¥0.012 / 1k tokens ¥0.012 / 1k tokens + "ERNIE-Bot": {[]float64{0.8572, 0.8572}, common.ChannelTypeBaidu}, + "ERNIE-3.5-8K": {[]float64{0.8572, 0.8572}, common.ChannelTypeBaidu}, + // 0.024元/千tokens 0.048元/千tokens + "ERNIE-Bot-8k": {[]float64{1.7143, 3.4286}, common.ChannelTypeBaidu}, + // ¥0.008 / 1k tokens ¥0.008 / 1k tokens + "ERNIE-Bot-turbo": {[]float64{0.5715, 0.5715}, common.ChannelTypeBaidu}, + // ¥0.12 / 1k tokens ¥0.12 / 1k tokens + "ERNIE-Bot-4": {[]float64{8.572, 8.572}, common.ChannelTypeBaidu}, + "ERNIE-4.0": {[]float64{8.572, 8.572}, common.ChannelTypeBaidu}, + // ¥0.002 / 1k tokens + "Embedding-V1": {[]float64{0.1429, 0.1429}, common.ChannelTypeBaidu}, + // ¥0.004 / 1k tokens + "BLOOMZ-7B": {[]float64{0.2857, 0.2857}, common.ChannelTypeBaidu}, + + "PaLM-2": {[]float64{1, 1}, common.ChannelTypePaLM}, + "gemini-pro": {[]float64{1, 1}, common.ChannelTypeGemini}, + "gemini-pro-vision": {[]float64{1, 1}, common.ChannelTypeGemini}, + "gemini-1.0-pro": {[]float64{1, 1}, common.ChannelTypeGemini}, + "gemini-1.5-pro": {[]float64{1, 1}, common.ChannelTypeGemini}, + + // ¥0.005 / 1k tokens + "glm-3-turbo": {[]float64{0.3572, 0.3572}, common.ChannelTypeZhipu}, + // ¥0.1 / 1k tokens + "glm-4": {[]float64{7.143, 7.143}, common.ChannelTypeZhipu}, + "glm-4v": {[]float64{7.143, 7.143}, common.ChannelTypeZhipu}, + // ¥0.0005 / 1k tokens + "embedding-2": {[]float64{0.0357, 0.0357}, common.ChannelTypeZhipu}, + // ¥0.25 / 1张图片 + "cogview-3": {[]float64{17.8571, 17.8571}, common.ChannelTypeZhipu}, + + // ¥0.008 / 1k tokens + "qwen-turbo": {[]float64{0.5715, 0.5715}, common.ChannelTypeAli}, + // ¥0.02 / 1k tokens + "qwen-plus": {[]float64{1.4286, 1.4286}, common.ChannelTypeAli}, + "qwen-vl-max": {[]float64{1.4286, 1.4286}, common.ChannelTypeAli}, + // 0.12元/1,000tokens + "qwen-max": {[]float64{8.5714, 8.5714}, common.ChannelTypeAli}, + "qwen-max-longcontext": {[]float64{8.5714, 8.5714}, common.ChannelTypeAli}, + // 0.008元/1,000tokens + "qwen-vl-plus": {[]float64{0.5715, 0.5715}, common.ChannelTypeAli}, + // ¥0.0007 / 1k tokens + "text-embedding-v1": {[]float64{0.05, 0.05}, common.ChannelTypeAli}, + + // ¥0.018 / 1k tokens + "SparkDesk": {[]float64{1.2858, 1.2858}, common.ChannelTypeXunfei}, + "SparkDesk-v1.1": {[]float64{1.2858, 1.2858}, common.ChannelTypeXunfei}, + "SparkDesk-v2.1": {[]float64{1.2858, 1.2858}, common.ChannelTypeXunfei}, + "SparkDesk-v3.1": {[]float64{1.2858, 1.2858}, common.ChannelTypeXunfei}, + "SparkDesk-v3.5": {[]float64{1.2858, 1.2858}, common.ChannelTypeXunfei}, + + // ¥0.012 / 1k tokens + "360GPT_S2_V9": {[]float64{0.8572, 0.8572}, common.ChannelType360}, + // ¥0.001 / 1k tokens + "embedding-bert-512-v1": {[]float64{0.0715, 0.0715}, common.ChannelType360}, + "embedding_s1_v1": {[]float64{0.0715, 0.0715}, common.ChannelType360}, + "semantic_similarity_s1_v1": {[]float64{0.0715, 0.0715}, common.ChannelType360}, + + // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0 + "hunyuan": {[]float64{7.143, 7.143}, common.ChannelTypeTencent}, + // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0 + // ¥0.01 / 1k tokens + "ChatStd": {[]float64{0.7143, 0.7143}, common.ChannelTypeTencent}, + //¥0.1 / 1k tokens + "ChatPro": {[]float64{7.143, 7.143}, common.ChannelTypeTencent}, + + "Baichuan2-Turbo": {[]float64{0.5715, 0.5715}, common.ChannelTypeBaichuan}, // ¥0.008 / 1k tokens + "Baichuan2-Turbo-192k": {[]float64{1.143, 1.143}, common.ChannelTypeBaichuan}, // ¥0.016 / 1k tokens + "Baichuan2-53B": {[]float64{1.4286, 1.4286}, common.ChannelTypeBaichuan}, // ¥0.02 / 1k tokens + "Baichuan-Text-Embedding": {[]float64{0.0357, 0.0357}, common.ChannelTypeBaichuan}, // ¥0.0005 / 1k tokens + + "abab5.5s-chat": {[]float64{0.3572, 0.3572}, common.ChannelTypeMiniMax}, // ¥0.005 / 1k tokens + "abab5.5-chat": {[]float64{1.0714, 1.0714}, common.ChannelTypeMiniMax}, // ¥0.015 / 1k tokens + "abab6-chat": {[]float64{14.2857, 14.2857}, common.ChannelTypeMiniMax}, // ¥0.2 / 1k tokens + "embo-01": {[]float64{0.0357, 0.0357}, common.ChannelTypeMiniMax}, // ¥0.0005 / 1k tokens + + "deepseek-coder": {[]float64{0.75, 0.75}, common.ChannelTypeDeepseek}, // 暂定 $0.0015 / 1K tokens + "deepseek-chat": {[]float64{0.75, 0.75}, common.ChannelTypeDeepseek}, // 暂定 $0.0015 / 1K tokens + + "moonshot-v1-8k": {[]float64{0.8572, 0.8572}, common.ChannelTypeMoonshot}, // ¥0.012 / 1K tokens + "moonshot-v1-32k": {[]float64{1.7143, 1.7143}, common.ChannelTypeMoonshot}, // ¥0.024 / 1K tokens + "moonshot-v1-128k": {[]float64{4.2857, 4.2857}, common.ChannelTypeMoonshot}, // ¥0.06 / 1K tokens + + "open-mistral-7b": {[]float64{0.125, 0.125}, common.ChannelTypeMistral}, // 0.25$ / 1M tokens 0.25$ / 1M tokens 0.00025$ / 1k tokens + "open-mixtral-8x7b": {[]float64{0.35, 0.35}, common.ChannelTypeMistral}, // 0.7$ / 1M tokens 0.7$ / 1M tokens 0.0007$ / 1k tokens + "mistral-small-latest": {[]float64{1, 3}, common.ChannelTypeMistral}, // 2$ / 1M tokens 6$ / 1M tokens 0.002$ / 1k tokens + "mistral-medium-latest": {[]float64{1.35, 4.05}, common.ChannelTypeMistral}, // 2.7$ / 1M tokens 8.1$ / 1M tokens 0.0027$ / 1k tokens + "mistral-large-latest": {[]float64{4, 12}, common.ChannelTypeMistral}, // 8$ / 1M tokens 24$ / 1M tokens 0.008$ / 1k tokens + "mistral-embed": {[]float64{0.05, 0.05}, common.ChannelTypeMistral}, // 0.1$ / 1M tokens 0.1$ / 1M tokens 0.0001$ / 1k tokens + + // $0.70/$0.80 /1M Tokens 0.0007$ / 1k tokens + "llama2-70b-4096": {[]float64{0.35, 0.4}, common.ChannelTypeGroq}, + // $0.10/$0.10 /1M Tokens 0.0001$ / 1k tokens + "llama2-7b-2048": {[]float64{0.05, 0.05}, common.ChannelTypeGroq}, + "gemma-7b-it": {[]float64{0.05, 0.05}, common.ChannelTypeGroq}, + // $0.27/$0.27 /1M Tokens 0.00027$ / 1k tokens + "mixtral-8x7b-32768": {[]float64{0.135, 0.135}, common.ChannelTypeGroq}, + + // 2.5 元 / 1M tokens 0.0025 / 1k tokens + "yi-34b-chat-0205": {[]float64{0.1786, 0.1786}, common.ChannelTypeLingyi}, + // 12 元 / 1M tokens 0.012 / 1k tokens + "yi-34b-chat-200k": {[]float64{0.8571, 0.8571}, common.ChannelTypeLingyi}, + // 6 元 / 1M tokens 0.006 / 1k tokens + "yi-vl-plus": {[]float64{0.4286, 0.4286}, common.ChannelTypeLingyi}, + } + + var prices []*Price + + for model, modelType := range ModelTypes { + prices = append(prices, &Price{ + Model: model, + Type: TokensPriceType, + ChannelType: modelType.Type, + Input: modelType.Ratio[0], + Output: modelType.Ratio[1], + }) + } + + return prices +} diff --git a/modelRatio.json b/modelRatio.json deleted file mode 100644 index 1ae7da87..00000000 --- a/modelRatio.json +++ /dev/null @@ -1,98 +0,0 @@ -{ - "gpt-4": [15, 30], - "gpt-4-0314": [15, 30], - "gpt-4-0613": [15, 30], - "gpt-4-32k": [30, 60], - "gpt-4-32k-0314": [30, 60], - "gpt-4-32k-0613": [30, 60], - "gpt-4-preview": [5, 15], - "gpt-4-1106-preview": [5, 15], - "gpt-4-0125-preview": [5, 15], - "gpt-4-vision-preview": [5, 15], - "gpt-3.5-turbo": [0.25, 0.75], - "gpt-3.5-turbo-0125": [0.25, 0.75], - "gpt-3.5-turbo-0301": [0.75, 1], - "gpt-3.5-turbo-0613": [0.75, 1], - "gpt-3.5-turbo-instruct": [0.75, 1], - "gpt-3.5-turbo-16k": [1.5, 2], - "gpt-3.5-turbo-16k-0613": [1.5, 2], - "gpt-3.5-turbo-1106": [0.5, 1], - "davinci-002": [1, 1], - "babbage-002": [0.2, 0.2], - "text-ada-001": [0.2, 0.2], - "text-babbage-001": [0.25, 0.25], - "text-curie-001": [1, 1], - "text-davinci-002": [10, 10], - "text-davinci-003": [10, 10], - "text-davinci-edit-001": [10, 10], - "code-davinci-edit-001": [10, 10], - "whisper-1": [15, 15], - "tts-1": [7.5, 7.5], - "tts-1-1106": [7.5, 7.5], - "tts-1-hd": [15, 15], - "tts-1-hd-1106": [15, 15], - "davinci": [10, 10], - "curie": [10, 10], - "babbage": [10, 10], - "ada": [10, 10], - "text-embedding-ada-002": [0.05, 0.05], - "text-embedding-3-small": [0.01, 0.01], - "text-embedding-3-large": [0.065, 0.065], - "text-search-ada-doc-001": [10, 10], - "text-moderation-stable": [0.1, 0.1], - "text-moderation-latest": [0.1, 0.1], - "dall-e-2": [8, 8], - "dall-e-3": [20, 20], - "claude-instant-1.2": [0.4, 1.2], - "claude-2.0": [4, 12], - "claude-2.1": [4, 12], - "claude-3-opus-20240229": [7.5, 22.5], - "claude-3-sonnet-20240229": [1.3, 3.9], - "ERNIE-Bot": [0.8572, 0.8572], - "ERNIE-Bot-8k": [1.7143, 3.4286], - "ERNIE-Bot-turbo": [0.5715, 0.5715], - "ERNIE-Bot-4": [8.572, 8.572], - "Embedding-V1": [0.1429, 0.1429], - "PaLM-2": [1, 1], - "gemini-pro": [1, 1], - "gemini-pro-vision": [1, 1], - "chatglm_turbo": [0.3572, 0.3572], - "chatglm_std": [0.3572, 0.3572], - "glm-3-turbo": [0.3572, 0.3572], - "chatglm_pro": [0.7143, 0.7143], - "chatglm_lite": [0.1429, 0.1429], - "glm-4": [7.143, 7.143], - "glm-4v": [7.143, 7.143], - "embedding-2": [0.0357, 0.0357], - "cogview-3": [17.8571, 17.8571], - "qwen-turbo": [0.5715, 0.5715], - "qwen-plus": [1.4286, 1.4286], - "qwen-max": [1.4286, 1.4286], - "qwen-max-longcontext": [1.4286, 1.4286], - "qwen-vl": [0.5715, 0.5715], - "qwen-vl-plus": [0.5715, 0.5715], - "text-embedding-v1": [0.05, 0.05], - "SparkDesk": [1.2858, 1.2858], - "SparkDesk-v1.1": [1.2858, 1.2858], - "SparkDesk-v2.1": [1.2858, 1.2858], - "SparkDesk-v3.1": [1.2858, 1.2858], - "SparkDesk-v3.5": [1.2858, 1.2858], - "360GPT_S2_V9": [0.8572, 0.8572], - "embedding-bert-512-v1": [0.0715, 0.0715], - "embedding_s1_v1": [0.0715, 0.0715], - "semantic_similarity_s1_v1": [0.0715, 0.0715], - "hunyuan": [7.143, 7.143], - "Baichuan2-Turbo": [0.5715, 0.5715], - "Baichuan2-Turbo-192k": [1.143, 1.143], - "Baichuan2-53B": [1.4286, 1.4286], - "Baichuan-Text-Embedding": [0.0357, 0.0357], - "abab5.5s-chat": [0.3572, 0.3572], - "abab5.5-chat": [1.0714, 1.0714], - "abab6-chat": [14.2857, 14.2857], - "embo-01": [0.0357, 0.0357], - "deepseek-coder": [0.75, 0.75], - "deepseek-chat": [0.75, 0.75], - "moonshot-v1-8k": [0.8572, 0.8572], - "moonshot-v1-32k": [1.7143, 1.7143], - "moonshot-v1-128k": [4.2857, 4.2857] -} diff --git a/relay/model.go b/relay/model.go index 74e6a636..8bc61618 100644 --- a/relay/model.go +++ b/relay/model.go @@ -5,6 +5,7 @@ import ( "net/http" "one-api/common" "one-api/model" + "one-api/relay/util" "one-api/types" "sort" @@ -13,8 +14,6 @@ import ( // https://platform.openai.com/docs/api-reference/models/list -var unknownOwnedBy = "未知" - type OpenAIModelPermission struct { Id string `json:"id"` Object string `json:"object"` @@ -30,6 +29,11 @@ type OpenAIModelPermission struct { IsBlocking bool `json:"is_blocking"` } +type ModelPrice struct { + Type string `json:"type"` + Input string `json:"input"` + Output string `json:"output"` +} type OpenAIModels struct { Id string `json:"id"` Object string `json:"object"` @@ -38,30 +42,7 @@ type OpenAIModels struct { Permission *[]OpenAIModelPermission `json:"permission"` Root *string `json:"root"` Parent *string `json:"parent"` -} - -var modelOwnedBy map[int]string - -func init() { - modelOwnedBy = map[int]string{ - common.ChannelTypeOpenAI: "OpenAI", - common.ChannelTypeAnthropic: "Anthropic", - common.ChannelTypeBaidu: "Baidu", - common.ChannelTypePaLM: "Google PaLM", - common.ChannelTypeGemini: "Google Gemini", - common.ChannelTypeZhipu: "Zhipu", - common.ChannelTypeAli: "Ali", - common.ChannelTypeXunfei: "Xunfei", - common.ChannelType360: "360", - common.ChannelTypeTencent: "Tencent", - common.ChannelTypeBaichuan: "Baichuan", - common.ChannelTypeMiniMax: "MiniMax", - common.ChannelTypeDeepseek: "Deepseek", - common.ChannelTypeMoonshot: "Moonshot", - common.ChannelTypeMistral: "Mistral", - common.ChannelTypeGroq: "Groq", - common.ChannelTypeLingyi: "Lingyiwanwu", - } + Price *ModelPrice `json:"price"` } func ListModels(c *gin.Context) { @@ -83,17 +64,9 @@ func ListModels(c *gin.Context) { } sort.Strings(models) - groupOpenAIModels := make([]OpenAIModels, 0, len(models)) - for _, modelId := range models { - groupOpenAIModels = append(groupOpenAIModels, OpenAIModels{ - Id: modelId, - Object: "model", - Created: 1677649963, - OwnedBy: getModelOwnedBy(modelId), - Permission: nil, - Root: nil, - Parent: nil, - }) + var groupOpenAIModels []*OpenAIModels + for _, modelName := range models { + groupOpenAIModels = append(groupOpenAIModels, getOpenAIModelWithName(modelName)) } // 根据 OwnedBy 排序 @@ -114,13 +87,14 @@ func ListModels(c *gin.Context) { } func ListModelsForAdmin(c *gin.Context) { - openAIModels := make([]OpenAIModels, 0, len(common.ModelRatio)) - for modelId := range common.ModelRatio { + prices := util.PricingInstance.GetAllPrices() + var openAIModels []OpenAIModels + for modelId, price := range prices { openAIModels = append(openAIModels, OpenAIModels{ Id: modelId, Object: "model", Created: 1677649963, - OwnedBy: getModelOwnedBy(modelId), + OwnedBy: getModelOwnedBy(price.ChannelType), Permission: nil, Root: nil, Parent: nil, @@ -144,21 +118,13 @@ func ListModelsForAdmin(c *gin.Context) { } func RetrieveModel(c *gin.Context) { - modelId := c.Param("model") - ownedByName := getModelOwnedBy(modelId) - if *ownedByName != unknownOwnedBy { - c.JSON(200, OpenAIModels{ - Id: modelId, - Object: "model", - Created: 1677649963, - OwnedBy: ownedByName, - Permission: nil, - Root: nil, - Parent: nil, - }) + modelName := c.Param("model") + openaiModel := getOpenAIModelWithName(modelName) + if *openaiModel.OwnedBy != util.UnknownOwnedBy { + c.JSON(200, openaiModel) } else { openAIError := types.OpenAIError{ - Message: fmt.Sprintf("The model '%s' does not exist", modelId), + Message: fmt.Sprintf("The model '%s' does not exist", modelName), Type: "invalid_request_error", Param: "model", Code: "model_not_found", @@ -169,12 +135,32 @@ func RetrieveModel(c *gin.Context) { } } -func getModelOwnedBy(modelId string) (ownedBy *string) { - if modelType, ok := common.ModelTypes[modelId]; ok { - if ownedByName, ok := modelOwnedBy[modelType.Type]; ok { - return &ownedByName - } +func getModelOwnedBy(channelType int) (ownedBy *string) { + if ownedByName, ok := util.ModelOwnedBy[channelType]; ok { + return &ownedByName } - return &unknownOwnedBy + return &util.UnknownOwnedBy +} + +func getOpenAIModelWithName(modelName string) *OpenAIModels { + price := util.PricingInstance.GetPrice(modelName) + + return &OpenAIModels{ + Id: modelName, + Object: "model", + Created: 1677649963, + OwnedBy: getModelOwnedBy(price.ChannelType), + Permission: nil, + Root: nil, + Parent: nil, + } +} + +func GetModelOwnedBy(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": util.ModelOwnedBy, + }) } diff --git a/relay/util/pricing.go b/relay/util/pricing.go new file mode 100644 index 00000000..390acb8a --- /dev/null +++ b/relay/util/pricing.go @@ -0,0 +1,402 @@ +package util + +import ( + "encoding/json" + "errors" + "one-api/common" + "one-api/model" + "sort" + "strings" + "sync" + + "github.com/spf13/viper" +) + +// PricingInstance is the Pricing instance +var PricingInstance *Pricing + +// Pricing is a struct that contains the pricing data +type Pricing struct { + sync.RWMutex + Prices map[string]*model.Price `json:"models"` + Match []string `json:"-"` +} + +type BatchPrices struct { + Models []string `json:"models" binding:"required"` + Price model.Price `json:"price" binding:"required"` +} + +// NewPricing creates a new Pricing instance +func NewPricing() { + common.SysLog("Initializing Pricing") + + PricingInstance = &Pricing{ + Prices: make(map[string]*model.Price), + Match: make([]string, 0), + } + + err := PricingInstance.Init() + + if err != nil { + common.SysError("Failed to initialize Pricing:" + err.Error()) + return + } + + // 初始化时,需要检测是否有更新 + if viper.GetBool("auto_price_updates") { + common.SysLog("Checking for pricing updates") + prices := model.GetDefaultPrice() + PricingInstance.SyncPricing(prices, false) + common.SysLog("Pricing initialized") + } +} + +// initializes the Pricing instance +func (p *Pricing) Init() error { + prices, err := model.GetAllPrices() + if err != nil { + return err + } + + if len(prices) == 0 { + return nil + } + + newPrices := make(map[string]*model.Price) + newMatch := make(map[string]bool) + + for _, price := range prices { + newPrices[price.Model] = price + if strings.HasSuffix(price.Model, "*") { + if _, ok := newMatch[price.Model]; !ok { + newMatch[price.Model] = true + } + } + } + + var newMatchList []string + for match := range newMatch { + newMatchList = append(newMatchList, match) + } + + p.Lock() + defer p.Unlock() + + p.Prices = newPrices + p.Match = newMatchList + + return nil +} + +// GetPrice returns the price of a model +func (p *Pricing) GetPrice(modelName string) *model.Price { + p.RLock() + defer p.RUnlock() + + if price, ok := p.Prices[modelName]; ok { + return price + } + + matchModel := common.GetModelsWithMatch(&p.Match, modelName) + if price, ok := p.Prices[matchModel]; ok { + return price + } + + return &model.Price{ + Type: model.TokensPriceType, + ChannelType: common.ChannelTypeUnknown, + Input: model.DefaultPrice, + Output: model.DefaultPrice, + } +} + +func (p *Pricing) GetAllPrices() map[string]*model.Price { + return p.Prices +} + +func (p *Pricing) GetAllPricesList() []*model.Price { + var prices []*model.Price + for _, price := range p.Prices { + prices = append(prices, price) + } + + return prices +} + +func (p *Pricing) updateRawPrice(modelName string, price *model.Price) error { + if _, ok := p.Prices[modelName]; !ok { + return errors.New("model not found") + } + + if _, ok := p.Prices[price.Model]; modelName != price.Model && ok { + return errors.New("model names cannot be duplicated") + } + + if err := p.deleteRawPrice(modelName); err != nil { + return err + } + + return price.Insert() +} + +// UpdatePrice updates the price of a model +func (p *Pricing) UpdatePrice(modelName string, price *model.Price) error { + + if err := p.updateRawPrice(modelName, price); err != nil { + return err + } + + err := p.Init() + + return err +} + +func (p *Pricing) addRawPrice(price *model.Price) error { + if _, ok := p.Prices[price.Model]; ok { + return errors.New("model already exists") + } + + return price.Insert() +} + +// AddPrice adds a new price to the Pricing instance +func (p *Pricing) AddPrice(price *model.Price) error { + if err := p.addRawPrice(price); err != nil { + return err + } + + err := p.Init() + + return err +} + +func (p *Pricing) deleteRawPrice(modelName string) error { + item, ok := p.Prices[modelName] + if !ok { + return errors.New("model not found") + } + + return item.Delete() +} + +// DeletePrice deletes a price from the Pricing instance +func (p *Pricing) DeletePrice(modelName string) error { + if err := p.deleteRawPrice(modelName); err != nil { + return err + } + + err := p.Init() + + return err +} + +// SyncPricing syncs the pricing data +func (p *Pricing) SyncPricing(pricing []*model.Price, overwrite bool) error { + var err error + if overwrite { + err = p.SyncPriceWithOverwrite(pricing) + } else { + err = p.SyncPriceWithoutOverwrite(pricing) + } + + return err +} + +// SyncPriceWithOverwrite syncs the pricing data with overwrite +func (p *Pricing) SyncPriceWithOverwrite(pricing []*model.Price) error { + tx := model.DB.Begin() + + err := model.DeleteAllPrices(tx) + if err != nil { + tx.Rollback() + return err + } + + err = model.InsertPrices(tx, pricing) + + if err != nil { + tx.Rollback() + return err + } + + tx.Commit() + + return p.Init() +} + +// SyncPriceWithoutOverwrite syncs the pricing data without overwrite +func (p *Pricing) SyncPriceWithoutOverwrite(pricing []*model.Price) error { + var newPrices []*model.Price + + for _, price := range pricing { + if _, ok := p.Prices[price.Model]; !ok { + newPrices = append(newPrices, price) + } + } + + if len(newPrices) == 0 { + return nil + } + + tx := model.DB.Begin() + err := model.InsertPrices(tx, newPrices) + + if err != nil { + tx.Rollback() + return err + } + + tx.Commit() + + return p.Init() +} + +// BatchDeletePrices deletes the prices of multiple models +func (p *Pricing) BatchDeletePrices(models []string) error { + tx := model.DB.Begin() + + err := model.DeletePrices(tx, models) + if err != nil { + tx.Rollback() + return err + } + + tx.Commit() + + p.Lock() + defer p.Unlock() + + for _, model := range models { + delete(p.Prices, model) + } + + return nil +} + +func (p *Pricing) BatchSetPrices(batchPrices *BatchPrices, originalModels []string) error { + // 查找需要删除的model + var deletePrices []string + var addPrices []*model.Price + var updatePrices []string + + for _, model := range originalModels { + if !common.Contains(model, batchPrices.Models) { + deletePrices = append(deletePrices, model) + } else { + updatePrices = append(updatePrices, model) + } + } + + for _, model := range batchPrices.Models { + if !common.Contains(model, originalModels) { + addPrice := batchPrices.Price + addPrice.Model = model + addPrices = append(addPrices, &addPrice) + } + } + + tx := model.DB.Begin() + if len(addPrices) > 0 { + err := model.InsertPrices(tx, addPrices) + if err != nil { + tx.Rollback() + return err + } + } + + if len(updatePrices) > 0 { + err := model.UpdatePrices(tx, updatePrices, &batchPrices.Price) + if err != nil { + tx.Rollback() + return err + } + } + + if len(deletePrices) > 0 { + err := model.DeletePrices(tx, deletePrices) + if err != nil { + tx.Rollback() + return err + } + + } + tx.Commit() + + return p.Init() +} + +func GetPricesList(pricingType string) []*model.Price { + var prices []*model.Price + + switch pricingType { + case "default": + prices = model.GetDefaultPrice() + case "db": + prices = PricingInstance.GetAllPricesList() + case "old": + prices = GetOldPricesList() + default: + return nil + } + + sort.Slice(prices, func(i, j int) bool { + if prices[i].ChannelType == prices[j].ChannelType { + return prices[i].Model < prices[j].Model + } + return prices[i].ChannelType < prices[j].ChannelType + }) + + return prices +} + +func GetOldPricesList() []*model.Price { + oldDataJson, err := model.GetOption("ModelRatio") + if err != nil || oldDataJson.Value == "" { + return nil + } + + oldData := make(map[string][]float64) + err = json.Unmarshal([]byte(oldDataJson.Value), &oldData) + + if err != nil { + return nil + } + + var prices []*model.Price + for modelName, oldPrice := range oldData { + price := PricingInstance.GetPrice(modelName) + prices = append(prices, &model.Price{ + Model: modelName, + Type: model.TokensPriceType, + ChannelType: price.ChannelType, + Input: oldPrice[0], + Output: oldPrice[1], + }) + } + + return prices +} + +// func ConvertBatchPrices(prices []*model.Price) []*BatchPrices { +// batchPricesMap := make(map[string]*BatchPrices) +// for _, price := range prices { +// key := fmt.Sprintf("%s-%d-%g-%g", price.Type, price.ChannelType, price.Input, price.Output) +// batchPrice, exists := batchPricesMap[key] +// if exists { +// batchPrice.Models = append(batchPrice.Models, price.Model) +// } else { +// batchPricesMap[key] = &BatchPrices{ +// Models: []string{price.Model}, +// Price: *price, +// } +// } +// } + +// var batchPrices []*BatchPrices +// for _, batchPrice := range batchPricesMap { +// batchPrices = append(batchPrices, batchPrice) +// } + +// return batchPrices +// } diff --git a/relay/util/quota.go b/relay/util/quota.go index 0a7c2be0..0274f51c 100644 --- a/relay/util/quota.go +++ b/relay/util/quota.go @@ -15,17 +15,16 @@ import ( ) type Quota struct { - modelName string - promptTokens int - preConsumedTokens int - modelRatio []float64 - groupRatio float64 - ratio float64 - preConsumedQuota int - userId int - channelId int - tokenId int - HandelStatus bool + modelName string + promptTokens int + price model.Price + groupRatio float64 + inputRatio float64 + preConsumedQuota int + userId int + channelId int + tokenId int + HandelStatus bool } func NewQuota(c *gin.Context, modelName string, promptTokens int) (*Quota, *types.OpenAIErrorWithStatusCode) { @@ -37,7 +36,16 @@ func NewQuota(c *gin.Context, modelName string, promptTokens int) (*Quota, *type tokenId: c.GetInt("token_id"), HandelStatus: false, } - quota.init(c.GetString("group")) + + quota.price = *PricingInstance.GetPrice(quota.modelName) + quota.groupRatio = common.GetGroupRatio(c.GetString("group")) + quota.inputRatio = quota.price.GetInput() * quota.groupRatio + + if quota.price.Type == model.TimesPriceType { + quota.preConsumedQuota = int(1000 * quota.inputRatio) + } else { + quota.preConsumedQuota = int(float64(quota.promptTokens+common.PreConsumedQuota) * quota.inputRatio) + } errWithCode := quota.preQuotaConsumption() if errWithCode != nil { @@ -47,21 +55,6 @@ func NewQuota(c *gin.Context, modelName string, promptTokens int) (*Quota, *type return quota, nil } -func (q *Quota) init(groupName string) { - modelRatio := common.GetModelRatio(q.modelName) - groupRatio := common.GetGroupRatio(groupName) - preConsumedTokens := common.PreConsumedQuota - ratio := modelRatio[0] * groupRatio - preConsumedQuota := int(float64(q.promptTokens+preConsumedTokens) * ratio) - - q.preConsumedTokens = preConsumedTokens - q.modelRatio = modelRatio - q.groupRatio = groupRatio - q.ratio = ratio - q.preConsumedQuota = preConsumedQuota - -} - func (q *Quota) preQuotaConsumption() *types.OpenAIErrorWithStatusCode { userQuota, err := model.CacheGetUserQuota(q.userId) if err != nil { @@ -97,11 +90,17 @@ func (q *Quota) preQuotaConsumption() *types.OpenAIErrorWithStatusCode { func (q *Quota) completedQuotaConsumption(usage *types.Usage, tokenName string, ctx context.Context) error { quota := 0 - completionRatio := q.modelRatio[1] * q.groupRatio promptTokens := usage.PromptTokens completionTokens := usage.CompletionTokens - quota = int(math.Ceil(((float64(promptTokens) * q.ratio) + (float64(completionTokens) * completionRatio)))) - if q.ratio != 0 && quota <= 0 { + + if q.price.Type == model.TimesPriceType { + quota = int(1000 * q.inputRatio) + } else { + completionRatio := q.price.GetOutput() * q.groupRatio + quota = int(math.Ceil(((float64(promptTokens) * q.inputRatio) + (float64(completionTokens) * completionRatio)))) + } + + if q.inputRatio != 0 && quota <= 0 { quota = 1 } totalTokens := promptTokens + completionTokens @@ -129,13 +128,18 @@ func (q *Quota) completedQuotaConsumption(usage *types.Usage, tokenName string, } } var modelRatioStr string - if q.modelRatio[0] == q.modelRatio[1] { - modelRatioStr = fmt.Sprintf("%.2f", q.modelRatio[0]) + if q.price.Type == model.TimesPriceType { + modelRatioStr = fmt.Sprintf("$%g/次", q.price.FetchInputCurrencyPrice(model.DollarRate)) } else { - modelRatioStr = fmt.Sprintf("%.2f (输入)/%.2f (输出)", q.modelRatio[0], q.modelRatio[1]) + // 如果输入费率和输出费率一样,则只显示一个费率 + if q.price.GetInput() == q.price.GetOutput() { + modelRatioStr = fmt.Sprintf("$%g/1k", q.price.FetchInputCurrencyPrice(model.DollarRate)) + } else { + modelRatioStr = fmt.Sprintf("$%g/1k (输入) | $%g/1k (输出)", q.price.FetchInputCurrencyPrice(model.DollarRate), q.price.FetchOutputCurrencyPrice(model.DollarRate)) + } } - logContent := fmt.Sprintf("模型倍率 %s,分组倍率 %.2f", modelRatioStr, q.groupRatio) + logContent := fmt.Sprintf("模型费率 %s,分组倍率 %.2f", modelRatioStr, q.groupRatio) model.RecordConsumeLog(ctx, q.userId, q.channelId, promptTokens, completionTokens, q.modelName, tokenName, quota, logContent, requestTime) model.UpdateUserUsedQuotaAndRequestCount(q.userId, quota) model.UpdateChannelUsedQuota(q.channelId, quota) diff --git a/relay/util/type.go b/relay/util/type.go new file mode 100644 index 00000000..5c16d288 --- /dev/null +++ b/relay/util/type.go @@ -0,0 +1,28 @@ +package util + +import "one-api/common" + +var UnknownOwnedBy = "未知" +var ModelOwnedBy map[int]string + +func init() { + ModelOwnedBy = map[int]string{ + common.ChannelTypeOpenAI: "OpenAI", + common.ChannelTypeAnthropic: "Anthropic", + common.ChannelTypeBaidu: "Baidu", + common.ChannelTypePaLM: "Google PaLM", + common.ChannelTypeGemini: "Google Gemini", + common.ChannelTypeZhipu: "Zhipu", + common.ChannelTypeAli: "Ali", + common.ChannelTypeXunfei: "Xunfei", + common.ChannelType360: "360", + common.ChannelTypeTencent: "Tencent", + common.ChannelTypeBaichuan: "Baichuan", + common.ChannelTypeMiniMax: "MiniMax", + common.ChannelTypeDeepseek: "Deepseek", + common.ChannelTypeMoonshot: "Moonshot", + common.ChannelTypeMistral: "Mistral", + common.ChannelTypeGroq: "Groq", + common.ChannelTypeLingyi: "Lingyiwanwu", + } +} diff --git a/router/api-router.go b/router/api-router.go index 57e81816..d1fef864 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -18,6 +18,8 @@ func SetApiRouter(router *gin.Engine) { apiRouter.GET("/status", controller.GetStatus) apiRouter.GET("/notice", controller.GetNotice) apiRouter.GET("/about", controller.GetAbout) + apiRouter.GET("/prices", middleware.CORS(), controller.GetPricesList) + apiRouter.GET("/ownedby", relay.GetModelOwnedBy) apiRouter.GET("/home_page_content", controller.GetHomePageContent) apiRouter.GET("/verification", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendEmailVerification) apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail) @@ -129,6 +131,20 @@ func SetApiRouter(router *gin.Engine) { analyticsRoute.GET("/channel_period", controller.GetChannelExpensesByPeriod) analyticsRoute.GET("/redemption_period", controller.GetRedemptionStatisticsByPeriod) } + + pricesRoute := apiRouter.Group("/prices") + pricesRoute.Use(middleware.AdminAuth()) + { + pricesRoute.GET("/model_list", controller.GetAllModelList) + pricesRoute.POST("/single", controller.AddPrice) + pricesRoute.PUT("/single/:model", controller.UpdatePrice) + pricesRoute.DELETE("/single/:model", controller.DeletePrice) + pricesRoute.POST("/multiple", controller.BatchSetPrices) + pricesRoute.PUT("/multiple/delete", controller.BatchDeletePrices) + pricesRoute.POST("/sync", controller.SyncPricing) + + } + } } diff --git a/web/package.json b/web/package.json index e376b4b8..cbb774de 100644 --- a/web/package.json +++ b/web/package.json @@ -9,7 +9,7 @@ "@emotion/react": "^11.9.3", "@emotion/styled": "^11.9.3", "@mui/icons-material": "^5.8.4", - "@mui/lab": "^5.0.0-alpha.88", + "@mui/lab": "^5.0.0-alpha.169", "@mui/material": "^5.8.6", "@mui/system": "^5.8.6", "@mui/utils": "^5.8.6", diff --git a/web/src/menu-items/panel.js b/web/src/menu-items/panel.js index 9bc945f7..a85f6a02 100644 --- a/web/src/menu-items/panel.js +++ b/web/src/menu-items/panel.js @@ -10,7 +10,8 @@ import { IconUser, IconUserScan, IconActivity, - IconBrandTelegram + IconBrandTelegram, + IconReceipt2 } from '@tabler/icons-react'; // constant @@ -25,7 +26,8 @@ const icons = { IconUser, IconUserScan, IconActivity, - IconBrandTelegram + IconBrandTelegram, + IconReceipt2 }; // ==============================|| DASHBOARD MENU ITEMS ||============================== // @@ -112,6 +114,15 @@ const panel = { breadcrumbs: false, isAdmin: true }, + { + id: 'pricing', + title: '模型价格', + type: 'item', + url: '/panel/pricing', + icon: icons.IconReceipt2, + breadcrumbs: false, + isAdmin: true + }, { id: 'setting', title: '设置', diff --git a/web/src/routes/MainRoutes.js b/web/src/routes/MainRoutes.js index f783cf21..45abc176 100644 --- a/web/src/routes/MainRoutes.js +++ b/web/src/routes/MainRoutes.js @@ -15,6 +15,7 @@ const Profile = Loadable(lazy(() => import('views/Profile'))); const NotFoundView = Loadable(lazy(() => import('views/Error'))); const Analytics = Loadable(lazy(() => import('views/Analytics'))); const Telegram = Loadable(lazy(() => import('views/Telegram'))); +const Pricing = Loadable(lazy(() => import('views/Pricing'))); // dashboard routing const Dashboard = Loadable(lazy(() => import('views/Dashboard'))); @@ -76,6 +77,10 @@ const MainRoutes = { { path: 'telegram', element: + }, + { + path: 'pricing', + element: } ] }; diff --git a/web/src/views/Channel/type/Config.js b/web/src/views/Channel/type/Config.js index 23df06eb..c224fb15 100644 --- a/web/src/views/Channel/type/Config.js +++ b/web/src/views/Channel/type/Config.js @@ -32,7 +32,8 @@ const defaultConfig = { other: '', proxy: '单独设置代理地址,支持http和socks5,例如:http://127.0.0.1:1080', test_model: '用于测试使用的模型,为空时无法测速,如:gpt-3.5-turbo', - models: '请选择该渠道所支持的模型', + models: + '请选择该渠道所支持的模型,你也可以输入通配符*来匹配模型,例如:gpt-3.5*,表示支持所有gpt-3.5开头的模型,*号只能在最后一位使用,前面必须有字符,例如:gpt-3.5*是正确的,*gpt-3.5是错误的', model_mapping: '请输入要修改的模型映射关系,格式为:api请求模型ID:实际转发给渠道的模型ID,使用JSON数组表示,例如:{"gpt-3.5": "gpt-35"}', groups: '请选择该渠道所支持的用户组' diff --git a/web/src/views/Log/index.js b/web/src/views/Log/index.js index b4b6e973..742758de 100644 --- a/web/src/views/Log/index.js +++ b/web/src/views/Log/index.js @@ -198,12 +198,12 @@ export default function Log() { }, { id: 'message', - label: '提示', + label: '输入', disableSort: true }, { id: 'completion', - label: '补全', + label: '输出', disableSort: true }, { diff --git a/web/src/views/Pricing/component/CheckUpdates.js b/web/src/views/Pricing/component/CheckUpdates.js new file mode 100644 index 00000000..6092734c --- /dev/null +++ b/web/src/views/Pricing/component/CheckUpdates.js @@ -0,0 +1,195 @@ +import PropTypes from 'prop-types'; +import { useState, useEffect } from 'react'; +import { + Dialog, + DialogTitle, + DialogContent, + DialogActions, + Divider, + Button, + TextField, + Grid, + FormControl, + Alert, + Stack, + Typography +} from '@mui/material'; +import { API } from 'utils/api'; +import { showError, showSuccess } from 'utils/common'; +import LoadingButton from '@mui/lab/LoadingButton'; +import Label from 'ui-component/Label'; + +export const CheckUpdates = ({ open, onCancel, onOk, row }) => { + const [url, setUrl] = useState('https://raw.githubusercontent.com/MartialBE/one-api/prices/prices.json'); + const [loading, setLoading] = useState(false); + const [updateLoading, setUpdateLoading] = useState(false); + const [newPricing, setNewPricing] = useState([]); + const [addModel, setAddModel] = useState([]); + const [diffModel, setDiffModel] = useState([]); + + const handleCheckUpdates = async () => { + setLoading(true); + try { + const res = await API.get(url); + // 检测是否是一个列表 + if (!Array.isArray(res.data)) { + showError('数据格式不正确'); + } else { + setNewPricing(res.data); + } + } catch (err) { + console.error(err); + } + setLoading(false); + }; + + const syncPricing = async (overwrite) => { + setUpdateLoading(true); + if (!newPricing.length) { + showError('请先获取数据'); + return; + } + + if (!overwrite && !addModel.length) { + showError('没有新增模型'); + return; + } + try { + overwrite = overwrite ? 'true' : 'false'; + const res = await API.post('/api/prices/sync?overwrite=' + overwrite, newPricing); + const { success, message } = res.data; + if (success) { + showSuccess('操作成功完成!'); + onOk(true); + } else { + showError(message); + } + } catch (err) { + console.error(err); + } + setUpdateLoading(false); + }; + + useEffect(() => { + const newModels = newPricing.filter((np) => !row.some((r) => r.model === np.model)); + + const changeModel = row.filter((r) => + newPricing.some((np) => np.model === r.model && (np.input !== r.input || np.output !== r.output)) + ); + + if (newModels.length > 0) { + const newModelsList = newModels.map((model) => model.model); + setAddModel(newModelsList); + } else { + setAddModel(''); + } + + if (changeModel.length > 0) { + const changeModelList = changeModel.map((model) => { + const newModel = newPricing.find((np) => np.model === model.model); + let changes = ''; + if (model.input !== newModel.input) { + changes += `输入倍率由 ${model.input} 变为 ${newModel.input},`; + } + if (model.output !== newModel.output) { + changes += `输出倍率由 ${model.output} 变为 ${newModel.output}`; + } + return `${model.model}:${changes}`; + }); + setDiffModel(changeModelList); + } else { + setDiffModel(''); + } + }, [row, newPricing]); + + return ( + + + 检查更新 + + + + + + + setUrl(e.target.value)} /> + + + + + 获取数据 + + + {newPricing.length > 0 && ( + + {!addModel.length && !diffModel.length && 无更新} + + {addModel.length > 0 && ( + + 新增模型: + + {addModel.map((model) => ( + + ))} + + + )} + + {diffModel.length > 0 && ( + + 价格变动模型(仅供参考,如果你自己修改了对应模型的价格请忽略): + {diffModel.map((model) => ( + + {model} + + ))} + + )} + + 注意: + 你可以选择覆盖或者仅添加新增,如果你选择覆盖,将会删除你自己添加的模型价格,完全使用远程配置,如果你选择仅添加新增,将会只会添加 + 新增模型的价格 + + + { + syncPricing(true); + }} + loading={updateLoading} + > + 覆盖数据 + + { + syncPricing(false); + }} + loading={updateLoading} + > + 仅添加新增 + + + + )} + + + + + + + ); +}; + +CheckUpdates.propTypes = { + open: PropTypes.bool, + row: PropTypes.array, + onCancel: PropTypes.func, + onOk: PropTypes.func +}; diff --git a/web/src/views/Pricing/component/EditModal.js b/web/src/views/Pricing/component/EditModal.js new file mode 100644 index 00000000..c0120e04 --- /dev/null +++ b/web/src/views/Pricing/component/EditModal.js @@ -0,0 +1,299 @@ +import PropTypes from 'prop-types'; +import * as Yup from 'yup'; +import { Formik } from 'formik'; +import { useTheme } from '@mui/material/styles'; +import { useState, useEffect } from 'react'; +import { + Dialog, + DialogTitle, + DialogContent, + DialogActions, + Button, + Divider, + FormControl, + InputLabel, + OutlinedInput, + InputAdornment, + FormHelperText, + Select, + Autocomplete, + TextField, + Checkbox, + MenuItem +} from '@mui/material'; + +import { showSuccess, showError } from 'utils/common'; +import { API } from 'utils/api'; +import { createFilterOptions } from '@mui/material/Autocomplete'; +import { ValueFormatter, priceType } from './util'; +import CheckBoxOutlineBlankIcon from '@mui/icons-material/CheckBoxOutlineBlank'; +import CheckBoxIcon from '@mui/icons-material/CheckBox'; + +const icon = ; +const checkedIcon = ; + +const filter = createFilterOptions(); +const validationSchema = Yup.object().shape({ + is_edit: Yup.boolean(), + type: Yup.string().oneOf(['tokens', 'times'], '类型 错误').required('类型 不能为空'), + channel_type: Yup.number().min(1, '渠道类型 错误').required('渠道类型 不能为空'), + input: Yup.number().required('输入倍率 不能为空'), + output: Yup.number().required('输出倍率 不能为空'), + models: Yup.array().min(1, '模型 不能为空') +}); + +const originInputs = { + is_edit: false, + type: 'tokens', + channel_type: 1, + input: 0, + output: 0, + models: [] +}; + +const EditModal = ({ open, pricesItem, onCancel, onOk, ownedby, noPriceModel }) => { + const theme = useTheme(); + const [inputs, setInputs] = useState(originInputs); + const [selectModel, setSelectModel] = useState([]); + + const submit = async (values, { setErrors, setStatus, setSubmitting }) => { + setSubmitting(true); + try { + const res = await API.post(`/api/prices/multiple`, { + original_models: inputs.models, + models: values.models, + price: { + model: 'batch', + type: values.type, + channel_type: values.channel_type, + input: values.input, + output: values.output + } + }); + const { success, message } = res.data; + if (success) { + showSuccess('保存成功!'); + setSubmitting(false); + setStatus({ success: true }); + onOk(true); + return; + } else { + setStatus({ success: false }); + showError(message); + setErrors({ submit: message }); + } + } catch (error) { + setStatus({ success: false }); + showError(error.message); + setErrors({ submit: error.message }); + return; + } + onOk(); + }; + + useEffect(() => { + if (pricesItem) { + setSelectModel(pricesItem.models.concat(noPriceModel)); + } else { + setSelectModel(noPriceModel); + } + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [pricesItem, noPriceModel]); + + useEffect(() => { + if (pricesItem) { + setInputs(pricesItem); + } else { + setInputs(originInputs); + } + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [pricesItem]); + + return ( + + + {pricesItem ? '编辑' : '新建'} + + + + + {({ errors, handleBlur, handleChange, handleSubmit, touched, values, isSubmitting }) => ( +
+ + 名称 + + {touched.type && errors.type && ( + + {errors.type} + + )} + + + + 渠道类型 + + {touched.channel_type && errors.channel_type && ( + + {errors.channel_type} + + )} + + + + 输入倍率 + {ValueFormatter(values.input)}} + onBlur={handleBlur} + onChange={handleChange} + aria-describedby="helper-text-channel-input-label" + /> + + {touched.input && errors.input && ( + + {errors.input} + + )} + + + + 输出倍率 + {ValueFormatter(values.output)}} + onBlur={handleBlur} + onChange={handleChange} + aria-describedby="helper-text-channel-output-label" + /> + + {touched.output && errors.output && ( + + {errors.output} + + )} + + + + { + const event = { + target: { + name: 'models', + value: value + } + }; + handleChange(event); + }} + onBlur={handleBlur} + // filterSelectedOptions + disableCloseOnSelect + renderInput={(params) => } + filterOptions={(options, params) => { + const filtered = filter(options, params); + const { inputValue } = params; + const isExisting = options.some((option) => inputValue === option); + if (inputValue !== '' && !isExisting) { + filtered.push(inputValue); + } + return filtered; + }} + renderOption={(props, option, { selected }) => ( +
  • + + {option} +
  • + )} + /> + {errors.models ? ( + + {errors.models} + + ) : ( + + {' '} + 请选择该价格所支持的模型,你也可以输入通配符*来匹配模型,例如:gpt-3.5*,表示支持所有gpt-3.5开头的模型,*号只能在最后一位使用,前面必须有字符,例如:gpt-3.5*是正确的,*gpt-3.5是错误的{' '} + + )} +
    + + + + + +
    + )} +
    +
    +
    + ); +}; + +export default EditModal; + +EditModal.propTypes = { + open: PropTypes.bool, + pricesItem: PropTypes.object, + onCancel: PropTypes.func, + onOk: PropTypes.func, + ownedby: PropTypes.array, + noPriceModel: PropTypes.array +}; diff --git a/web/src/views/Pricing/component/TableRow.js b/web/src/views/Pricing/component/TableRow.js new file mode 100644 index 00000000..43664c31 --- /dev/null +++ b/web/src/views/Pricing/component/TableRow.js @@ -0,0 +1,156 @@ +import PropTypes from 'prop-types'; +import { useState } from 'react'; + +import { + Popover, + TableRow, + MenuItem, + TableCell, + IconButton, + Dialog, + DialogActions, + DialogContent, + DialogContentText, + DialogTitle, + Collapse, + Grid, + Box, + Typography, + Button +} from '@mui/material'; + +import { IconDotsVertical, IconEdit, IconTrash } from '@tabler/icons-react'; +import { ValueFormatter, priceType } from './util'; +import KeyboardArrowDownIcon from '@mui/icons-material/KeyboardArrowDown'; +import KeyboardArrowUpIcon from '@mui/icons-material/KeyboardArrowUp'; +import Label from 'ui-component/Label'; +import { copy } from 'utils/common'; + +export default function PricesTableRow({ item, managePrices, handleOpenModal, setModalPricesItem, ownedby }) { + const [open, setOpen] = useState(null); + const [openRow, setOpenRow] = useState(false); + const [openDelete, setOpenDelete] = useState(false); + const type_label = priceType.find((pt) => pt.value === item.type); + const channel_label = ownedby.find((ob) => ob.value === item.channel_type); + const handleDeleteOpen = () => { + handleCloseMenu(); + setOpenDelete(true); + }; + + const handleDeleteClose = () => { + setOpenDelete(false); + }; + + const handleOpenMenu = (event) => { + setOpen(event.currentTarget); + }; + + const handleCloseMenu = () => { + setOpen(null); + }; + + const handleDelete = async () => { + handleDeleteClose(); + await managePrices(item, 'delete', ''); + }; + + return ( + <> + setOpenRow(!openRow)}> + + setOpenRow(!openRow)}> + {openRow ? : } + + + + {type_label?.label} + + {channel_label?.label} + {ValueFormatter(item.input)} + {ValueFormatter(item.output)} + {item.models.length} + + event.stopPropagation()}> + + + + + + + + + + + + + + 可用模型: + + {item.models.map((model) => ( + + ))} + + + + + + + + { + handleCloseMenu(); + handleOpenModal(); + setModalPricesItem(item); + }} + > + + 编辑 + + + + 删除 + + + + + 删除价格组 + + 是否删除价格组? + + + + + + + + ); +} + +PricesTableRow.propTypes = { + item: PropTypes.object, + managePrices: PropTypes.func, + handleOpenModal: PropTypes.func, + setModalPricesItem: PropTypes.func, + priceType: PropTypes.array, + ownedby: PropTypes.array +}; diff --git a/web/src/views/Pricing/component/util.js b/web/src/views/Pricing/component/util.js new file mode 100644 index 00000000..3ff0f0d8 --- /dev/null +++ b/web/src/views/Pricing/component/util.js @@ -0,0 +1,11 @@ +export const priceType = [ + { value: 'tokens', label: '按Token收费' }, + { value: 'times', label: '按次收费' } +]; + +export function ValueFormatter(value) { + if (value == null) { + return ''; + } + return `$${parseFloat(value * 0.002).toFixed(4)} / ¥${parseFloat(value * 0.014).toFixed(4)}`; +} diff --git a/web/src/views/Pricing/index.js b/web/src/views/Pricing/index.js new file mode 100644 index 00000000..4eca5b0b --- /dev/null +++ b/web/src/views/Pricing/index.js @@ -0,0 +1,221 @@ +import { useState, useEffect, useMemo, useCallback } from 'react'; +import PropTypes from 'prop-types'; +import { Tabs, Tab, Box, Card, Alert, Stack, Button } from '@mui/material'; +import { IconTag, IconTags } from '@tabler/icons-react'; +import Single from './single'; +import Multiple from './multiple'; +import { useLocation, useNavigate } from 'react-router-dom'; +import AdminContainer from 'ui-component/AdminContainer'; +import { API } from 'utils/api'; +import { showError } from 'utils/common'; +import { CheckUpdates } from './component/CheckUpdates'; + +function CustomTabPanel(props) { + const { children, value, index, ...other } = props; + + return ( + + ); +} + +CustomTabPanel.propTypes = { + children: PropTypes.node, + index: PropTypes.number.isRequired, + value: PropTypes.number.isRequired +}; + +function a11yProps(index) { + return { + id: `pricing-tab-${index}`, + 'aria-controls': `pricing-tabpanel-${index}` + }; +} + +const Pricing = () => { + const [ownedby, setOwnedby] = useState([]); + const [modelList, setModelList] = useState([]); + const [openModal, setOpenModal] = useState(false); + const [errPrices, setErrPrices] = useState(''); + const [prices, setPrices] = useState([]); + const [noPriceModel, setNoPriceModel] = useState([]); + + const location = useLocation(); + const navigate = useNavigate(); + const hash = location.hash.replace('#', ''); + const tabMap = useMemo( + () => ({ + single: 0, + multiple: 1 + }), + [] + ); + const [value, setValue] = useState(tabMap[hash] || 0); + + const handleChange = (event, newValue) => { + setValue(newValue); + const hashArray = Object.keys(tabMap); + navigate(`#${hashArray[newValue]}`); + }; + + const reloadData = () => { + fetchModelList(); + fetchPrices(); + }; + + const handleOkModal = (status) => { + if (status === true) { + reloadData(); + setOpenModal(false); + } + }; + + useEffect(() => { + const missingModels = modelList.filter((model) => !prices.some((price) => price.model === model)); + setNoPriceModel(missingModels); + }, [modelList, prices]); + + useEffect(() => { + // check if there is any price that is not valid + const invalidPrices = prices.filter((price) => price.channel_type <= 0); + if (invalidPrices.length > 0) { + setErrPrices(invalidPrices.map((price) => price.model).join(', ')); + } else { + setErrPrices(''); + } + }, [prices]); + + const fetchOwnedby = useCallback(async () => { + try { + const res = await API.get('/api/ownedby'); + const { success, message, data } = res.data; + if (success) { + let ownedbyList = []; + for (let key in data) { + ownedbyList.push({ value: parseInt(key), label: data[key] }); + } + setOwnedby(ownedbyList); + } else { + showError(message); + } + } catch (error) { + console.error(error); + } + }, []); + + const fetchModelList = useCallback(async () => { + try { + const res = await API.get('/api/prices/model_list'); + const { success, message, data } = res.data; + if (success) { + setModelList(data); + } else { + showError(message); + } + } catch (error) { + console.error(error); + } + }, []); + + const fetchPrices = useCallback(async () => { + try { + const res = await API.get('/api/prices'); + const { success, message, data } = res.data; + if (success) { + setPrices(data); + } else { + showError(message); + } + } catch (error) { + console.error(error); + } + }, []); + + useEffect(() => { + const handleHashChange = () => { + const hash = location.hash.replace('#', ''); + setValue(tabMap[hash] || 0); + }; + window.addEventListener('hashchange', handleHashChange); + return () => { + window.removeEventListener('hashchange', handleHashChange); + }; + }, [location, tabMap, fetchOwnedby]); + + useEffect(() => { + const fetchData = async () => { + try { + await Promise.all([fetchOwnedby(), fetchModelList()]); + fetchPrices(); + } catch (error) { + console.error(error); + } + }; + + fetchData(); + }, [fetchOwnedby, fetchModelList, fetchPrices]); + + return ( + + + 美元:1 === $0.002 / 1K tokens 人民币: 1 === ¥0.014 / 1k tokens +
    例如
    gpt-4 输入: $0.03 / 1K tokens 完成:$0.06 / 1K tokens
    + 0.03 / 0.002 = 15, 0.06 / 0.002 = 30,即输入倍率为 15,完成倍率为 30 +
    + + {noPriceModel.length > 0 && ( + + 存在未配置价格的模型,请及时配置价格: + {noPriceModel.map((model) => ( + {model}, + ))} + + )} + + {errPrices && ( + + 存在供应商类型错误的模型,请及时配置:{errPrices} + + )} + + + + + + + + + } iconPosition="start" /> + } iconPosition="start" /> + + + + + + + + + + + + { + setOpenModal(false); + }} + row={prices} + onOk={handleOkModal} + /> +
    + ); +}; + +export default Pricing; diff --git a/web/src/views/Pricing/multiple.js b/web/src/views/Pricing/multiple.js new file mode 100644 index 00000000..adaa4145 --- /dev/null +++ b/web/src/views/Pricing/multiple.js @@ -0,0 +1,149 @@ +import PropTypes from 'prop-types'; +import { useState, useEffect } from 'react'; +import { showError, showSuccess } from 'utils/common'; + +import Table from '@mui/material/Table'; +import TableBody from '@mui/material/TableBody'; +import TableContainer from '@mui/material/TableContainer'; +import PerfectScrollbar from 'react-perfect-scrollbar'; + +import { Button, Card, Stack } from '@mui/material'; +import PricesTableRow from './component/TableRow'; +import KeywordTableHead from 'ui-component/TableHead'; +import { API } from 'utils/api'; +import { IconRefresh, IconPlus } from '@tabler/icons-react'; +import EditeModal from './component/EditModal'; + +// ---------------------------------------------------------------------- +export default function Multiple({ ownedby, prices, reloadData, noPriceModel }) { + const [rows, setRows] = useState([]); + + const [openModal, setOpenModal] = useState(false); + const [editPricesItem, setEditPricesItem] = useState(null); + + // 处理刷新 + const handleRefresh = async () => { + reloadData(); + }; + + useEffect(() => { + const grouped = prices.reduce((acc, item, index) => { + const key = `${item.type}-${item.channel_type}-${item.input}-${item.output}`; + + if (!acc[key]) { + acc[key] = { + ...item, + models: [item.model], + id: index + 1 + }; + } else { + acc[key].models.push(item.model); + } + return acc; + }, {}); + + setRows(Object.values(grouped)); + }, [prices]); + + const managePrices = async (item, action) => { + let res; + try { + switch (action) { + case 'delete': + res = await API.put('/api/prices/multiple/delete', { + models: item.models + }); + break; + } + const { success, message } = res.data; + if (success) { + showSuccess('操作成功完成!'); + if (action === 'delete') { + await handleRefresh(); + } + } else { + showError(message); + } + + return res.data; + } catch (error) { + return; + } + }; + + const handleOpenModal = (item) => { + setEditPricesItem(item); + setOpenModal(true); + }; + + const handleCloseModal = () => { + setOpenModal(false); + setEditPricesItem(null); + }; + + const handleOkModal = (status) => { + if (status === true) { + handleCloseModal(); + handleRefresh(); + } + }; + + return ( + <> + + + + + + + + + + + {rows.map((row) => ( + + ))} + +
    +
    +
    +
    + + + ); +} + +Multiple.propTypes = { + prices: PropTypes.array, + ownedby: PropTypes.array, + reloadData: PropTypes.func, + noPriceModel: PropTypes.array +}; diff --git a/web/src/views/Setting/component/ModelRationDataGrid.js b/web/src/views/Pricing/single.js similarity index 54% rename from web/src/views/Setting/component/ModelRationDataGrid.js rename to web/src/views/Pricing/single.js index edac3ff0..ef37348a 100644 --- a/web/src/views/Setting/component/ModelRationDataGrid.js +++ b/web/src/views/Pricing/single.js @@ -1,19 +1,30 @@ import PropTypes from 'prop-types'; import { useState, useEffect, useMemo, useCallback } from 'react'; import { GridRowModes, DataGrid, GridToolbarContainer, GridActionsCellItem } from '@mui/x-data-grid'; -import { Box, Button } from '@mui/material'; +import { Box, Button, Dialog, DialogActions, DialogContent, DialogTitle } from '@mui/material'; import AddIcon from '@mui/icons-material/Add'; import EditIcon from '@mui/icons-material/Edit'; import DeleteIcon from '@mui/icons-material/DeleteOutlined'; import SaveIcon from '@mui/icons-material/Save'; import CancelIcon from '@mui/icons-material/Close'; -import { showError } from 'utils/common'; +import { showError, showSuccess } from 'utils/common'; +import { API } from 'utils/api'; +import { ValueFormatter, priceType } from './component/util'; function validation(row, rows) { if (row.model === '') { return '模型名称不能为空'; } + // 判断 type 是否是 等于 tokens || times + if (row.type !== 'tokens' && row.type !== 'times') { + return '类型只能是tokens或times'; + } + + if (row.channel_type <= 0) { + return '所属渠道类型错误'; + } + // 判断 model是否是唯一值 if (rows.filter((r) => r.model === row.model && (row.isNew || r.id !== row.id)).length > 0) { return '模型名称不能重复'; @@ -22,8 +33,8 @@ function validation(row, rows) { if (row.input === '' || row.input < 0) { return '输入倍率必须大于等于0'; } - if (row.complete === '' || row.complete < 0) { - return '完成倍率必须大于等于0'; + if (row.output === '' || row.output < 0) { + return '输出倍率必须大于等于0'; } return false; } @@ -35,7 +46,7 @@ function randomId() { function EditToolbar({ setRows, setRowModesModel }) { const handleClick = () => { const id = randomId(); - setRows((oldRows) => [{ id, model: '', input: 0, complete: 0, isNew: true }, ...oldRows]); + setRows((oldRows) => [{ id, model: '', type: 'tokens', channel_type: 1, input: 0, output: 0, isNew: true }, ...oldRows]); setRowModesModel((oldModel) => ({ [id]: { mode: GridRowModes.Edit, fieldToFocus: 'name' }, ...oldModel @@ -56,19 +67,33 @@ EditToolbar.propTypes = { setRowModesModel: PropTypes.func.isRequired }; -const ModelRationDataGrid = ({ ratio, onChange }) => { +const Single = ({ ownedby, prices, reloadData }) => { const [rows, setRows] = useState([]); const [rowModesModel, setRowModesModel] = useState({}); + const [selectedRow, setSelectedRow] = useState(null); - const setRatio = useCallback( - (ratioRow) => { - let ratioJson = {}; - ratioRow.forEach((row) => { - ratioJson[row.model] = [row.input, row.complete]; - }); - onChange({ target: { name: 'ModelRatio', value: JSON.stringify(ratioJson, null, 2) } }); + const addOrUpdatePirces = useCallback( + async (newRow, oldRow, reject, resolve) => { + try { + let res; + if (oldRow.model == '') { + res = await API.post('/api/prices/single', newRow); + } else { + res = await API.put('/api/prices/single/' + oldRow.model, newRow); + } + const { success, message } = res.data; + if (success) { + showSuccess('保存成功'); + resolve(newRow); + reloadData(); + } else { + reject(new Error(message)); + } + } catch (error) { + reject(new Error(error)); + } }, - [onChange] + [reloadData] ); const handleEditClick = useCallback( @@ -87,11 +112,21 @@ const ModelRationDataGrid = ({ ratio, onChange }) => { const handleDeleteClick = useCallback( (id) => () => { - setRatio(rows.filter((row) => row.id !== id)); + setSelectedRow(rows.find((row) => row.id === id)); }, - [rows, setRatio] + [rows] ); + const handleClose = () => { + setSelectedRow(null); + }; + + const handleConfirmDelete = async () => { + // 执行删除操作 + await deletePirces(selectedRow.model); + setSelectedRow(null); + }; + const handleCancelClick = useCallback( (id) => () => { setRowModesModel({ @@ -107,18 +142,30 @@ const ModelRationDataGrid = ({ ratio, onChange }) => { [rowModesModel, rows] ); - const processRowUpdate = (newRow, oldRows) => { - if (!newRow.isNew && newRow.model === oldRows.model && newRow.input === oldRows.input && newRow.complete === oldRows.complete) { - return oldRows; - } - const updatedRow = { ...newRow, isNew: false }; - const error = validation(updatedRow, rows); - if (error) { - return Promise.reject(new Error(error)); - } - setRatio(rows.map((row) => (row.id === newRow.id ? updatedRow : row))); - return updatedRow; - }; + const processRowUpdate = useCallback( + (newRow, oldRows) => + new Promise((resolve, reject) => { + if ( + !newRow.isNew && + newRow.model === oldRows.model && + newRow.input === oldRows.input && + newRow.output === oldRows.output && + newRow.type === oldRows.type && + newRow.channel_type === oldRows.channel_type + ) { + return resolve(oldRows); + } + const updatedRow = { ...newRow, isNew: false }; + const error = validation(updatedRow, rows); + if (error) { + return reject(new Error(error)); + } + + const response = addOrUpdatePirces(updatedRow, oldRows, reject, resolve); + return response; + }), + [rows, addOrUpdatePirces] + ); const handleProcessRowUpdateError = useCallback((error) => { showError(error.message); @@ -138,6 +185,26 @@ const ModelRationDataGrid = ({ ratio, onChange }) => { editable: true, hideable: false }, + { + field: 'type', + sortable: true, + headerName: '类型', + width: 220, + type: 'singleSelect', + valueOptions: priceType, + editable: true, + hideable: false + }, + { + field: 'channel_type', + sortable: true, + headerName: '供应商', + width: 220, + type: 'singleSelect', + valueOptions: ownedby, + editable: true, + hideable: false + }, { field: 'input', sortable: false, @@ -145,27 +212,17 @@ const ModelRationDataGrid = ({ ratio, onChange }) => { width: 150, type: 'number', editable: true, - valueFormatter: (params) => { - if (params.value == null) { - return ''; - } - return `$${parseFloat(params.value * 0.002).toFixed(4)} / ¥${parseFloat(params.value * 0.014).toFixed(4)}`; - }, + valueFormatter: (params) => ValueFormatter(params.value), hideable: false }, { - field: 'complete', + field: 'output', sortable: false, - headerName: '完成倍率', + headerName: '输出倍率', width: 150, type: 'number', editable: true, - valueFormatter: (params) => { - if (params.value == null) { - return ''; - } - return `$${parseFloat(params.value * 0.002).toFixed(4)} / ¥${parseFloat(params.value * 0.014).toFixed(4)}`; - }, + valueFormatter: (params) => ValueFormatter(params.value), hideable: false }, { @@ -220,18 +277,32 @@ const ModelRationDataGrid = ({ ratio, onChange }) => { } } ], - [handleEditClick, handleSaveClick, handleDeleteClick, handleCancelClick, rowModesModel] + [handleCancelClick, handleDeleteClick, handleEditClick, handleSaveClick, rowModesModel, ownedby] ); + const deletePirces = async (modelName) => { + try { + const res = await API.delete('/api/prices/single/' + modelName); + const { success, message } = res.data; + if (success) { + showSuccess('保存成功'); + await reloadData(); + } else { + showError(message); + } + } catch (error) { + console.error(error); + } + }; + useEffect(() => { let modelRatioList = []; - let itemJson = JSON.parse(ratio); let id = 0; - for (let key in itemJson) { - modelRatioList.push({ id: id++, model: key, input: itemJson[key][0], complete: itemJson[key][1] }); + for (let key in prices) { + modelRatioList.push({ id: id++, ...prices[key] }); } setRows(modelRatioList); - }, [ratio]); + }, [prices]); return ( { onRowModesModelChange={handleRowModesModelChange} processRowUpdate={processRowUpdate} onProcessRowUpdateError={handleProcessRowUpdateError} + // onCellDoubleClick={(params, event) => { + // event.defaultMuiPrevented = true; + // }} + onRowEditStop={(params, event) => { + if (params.reason === 'rowFocusOut') { + event.defaultMuiPrevented = true; + } + }} slots={{ toolbar: EditToolbar }} @@ -263,13 +342,27 @@ const ModelRationDataGrid = ({ ratio, onChange }) => { toolbar: { setRows, setRowModesModel } }} /> + + + 确定删除? + {`确定删除 ${selectedRow?.model} 吗?`} + + + + + ); }; -ModelRationDataGrid.propTypes = { - ratio: PropTypes.string.isRequired, - onChange: PropTypes.func.isRequired -}; +export default Single; -export default ModelRationDataGrid; +Single.propTypes = { + prices: PropTypes.array, + ownedby: PropTypes.array, + reloadData: PropTypes.func +}; diff --git a/web/src/views/Setting/component/OperationSetting.js b/web/src/views/Setting/component/OperationSetting.js index 96680c71..85cfe0ef 100644 --- a/web/src/views/Setting/component/OperationSetting.js +++ b/web/src/views/Setting/component/OperationSetting.js @@ -1,12 +1,11 @@ import { useState, useEffect } from 'react'; import SubCard from 'ui-component/cards/SubCard'; -import { Stack, FormControl, InputLabel, OutlinedInput, Checkbox, Button, FormControlLabel, TextField, Alert, Switch } from '@mui/material'; +import { Stack, FormControl, InputLabel, OutlinedInput, Checkbox, Button, FormControlLabel, TextField } from '@mui/material'; import { showSuccess, showError, verifyJSON } from 'utils/common'; import { API } from 'utils/api'; import { AdapterDayjs } from '@mui/x-date-pickers/AdapterDayjs'; import { LocalizationProvider } from '@mui/x-date-pickers/LocalizationProvider'; import { DateTimePicker } from '@mui/x-date-pickers/DateTimePicker'; -import ModelRationDataGrid from './ModelRationDataGrid'; import dayjs from 'dayjs'; require('dayjs/locale/zh-cn'); @@ -18,7 +17,6 @@ const OperationSetting = () => { QuotaForInvitee: 0, QuotaRemindThreshold: 0, PreConsumedQuota: 0, - ModelRatio: '', GroupRatio: '', TopUpLink: '', ChatLink: '', @@ -34,7 +32,6 @@ const OperationSetting = () => { RetryCooldownSeconds: 0 }); const [originInputs, setOriginInputs] = useState({}); - const [newModelRatioView, setNewModelRatioView] = useState(false); let [loading, setLoading] = useState(false); let [historyTimestamp, setHistoryTimestamp] = useState(now.getTime() / 1000 - 30 * 24 * 3600); // a month ago new Date().getTime() / 1000 + 3600 @@ -45,7 +42,7 @@ const OperationSetting = () => { if (success) { let newInputs = {}; data.forEach((item) => { - if (item.key === 'ModelRatio' || item.key === 'GroupRatio') { + if (item.key === 'GroupRatio') { item.value = JSON.stringify(JSON.parse(item.value), null, 2); } newInputs[item.key] = item.value; @@ -110,13 +107,6 @@ const OperationSetting = () => { } break; case 'ratio': - if (originInputs['ModelRatio'] !== inputs.ModelRatio) { - if (!verifyJSON(inputs.ModelRatio)) { - showError('模型倍率不是合法的 JSON 字符串'); - return; - } - await updateOption('ModelRatio', inputs.ModelRatio); - } if (originInputs['GroupRatio'] !== inputs.GroupRatio) { if (!verifyJSON(inputs.GroupRatio)) { showError('分组倍率不是合法的 JSON 字符串'); @@ -469,44 +459,6 @@ const OperationSetting = () => { /> - - - 配置格式为 JSON 文本,键为模型名称;值第一位为输入倍率,第二位为完成倍率,如果只有单一倍率则两者值相同。 -
    美元:1 === $0.002 / 1K tokens 人民币: 1 === ¥0.014 / 1k tokens -
    例如
    gpt-4 输入: $0.03 / 1K tokens 完成:$0.06 / 1K tokens
    - 0.03 / 0.002 = 15, 0.06 / 0.002 = 30,即输入倍率为 15,完成倍率为 30 -
    - { - setNewModelRatioView(!newModelRatioView); - }} - /> - } - label="使用新编辑器" - /> -
    - - {newModelRatioView ? ( - - ) : ( - - - - )}