diff --git a/cli/export.go b/cli/export.go
new file mode 100644
index 00000000..cda9788b
--- /dev/null
+++ b/cli/export.go
@@ -0,0 +1,48 @@
+package cli
+
+import (
+ "encoding/json"
+ "one-api/common"
+ "one-api/relay/util"
+ "os"
+ "sort"
+)
+
+func ExportPrices() {
+ prices := util.GetPricesList("default")
+
+ if len(prices) == 0 {
+ common.SysError("No prices found")
+ return
+ }
+
+ // Sort prices by ChannelType
+ sort.Slice(prices, func(i, j int) bool {
+ if prices[i].ChannelType == prices[j].ChannelType {
+ return prices[i].Model < prices[j].Model
+ }
+ return prices[i].ChannelType < prices[j].ChannelType
+ })
+
+ // 导出到当前目录下的 prices.json 文件
+ file, err := os.Create("prices.json")
+ if err != nil {
+ common.SysError("Failed to create file: " + err.Error())
+ return
+ }
+ defer file.Close()
+
+ jsonData, err := json.MarshalIndent(prices, "", " ")
+ if err != nil {
+ common.SysError("Failed to encode prices: " + err.Error())
+ return
+ }
+
+ _, err = file.Write(jsonData)
+ if err != nil {
+ common.SysError("Failed to write to file: " + err.Error())
+ return
+ }
+
+ common.SysLog("Prices exported to prices.json")
+}
diff --git a/common/config/flag.go b/cli/flag.go
similarity index 82%
rename from common/config/flag.go
rename to cli/flag.go
index 85f56107..804f6cb4 100644
--- a/common/config/flag.go
+++ b/cli/flag.go
@@ -1,4 +1,4 @@
-package config
+package cli
import (
"flag"
@@ -14,10 +14,11 @@ var (
printVersion = flag.Bool("version", false, "print version and exit")
printHelp = flag.Bool("help", false, "print help and exit")
logDir = flag.String("log-dir", "", "specify the log directory")
- config = flag.String("config", "config.yaml", "specify the config.yaml path")
+ Config = flag.String("config", "config.yaml", "specify the config.yaml path")
+ export = flag.Bool("export", false, "Exports prices to a JSON file.")
)
-func flagConfig() {
+func FlagConfig() {
flag.Parse()
if *printVersion {
@@ -38,6 +39,11 @@ func flagConfig() {
viper.Set("log_dir", *logDir)
}
+ if *export {
+ ExportPrices()
+ os.Exit(0)
+ }
+
}
func help() {
diff --git a/common/config/config.go b/common/config/config.go
index e24e27a6..7d2634c2 100644
--- a/common/config/config.go
+++ b/common/config/config.go
@@ -4,13 +4,14 @@ import (
"strings"
"time"
+ "one-api/cli"
"one-api/common"
"github.com/spf13/viper"
)
func InitConf() {
- flagConfig()
+ cli.FlagConfig()
defaultConfig()
setConfigFile()
setEnv()
@@ -25,11 +26,11 @@ func InitConf() {
}
func setConfigFile() {
- if !common.IsFileExist(*config) {
+ if !common.IsFileExist(*cli.Config) {
return
}
- viper.SetConfigFile(*config)
+ viper.SetConfigFile(*cli.Config)
if err := viper.ReadInConfig(); err != nil {
panic(err)
}
@@ -51,4 +52,5 @@ func defaultConfig() {
viper.SetDefault("global.api_rate_limit", 180)
viper.SetDefault("global.web_rate_limit", 100)
viper.SetDefault("connect_timeout", 5)
+ viper.SetDefault("auto_price_updates", true)
}
diff --git a/common/model-ratio.go b/common/model-ratio.go
index 8c3bb11f..92b262e6 100644
--- a/common/model-ratio.go
+++ b/common/model-ratio.go
@@ -1,217 +1,5 @@
package common
-import (
- "encoding/json"
- "strings"
- "time"
-)
-
-type ModelType struct {
- Ratio []float64
- Type int
-}
-
-var ModelTypes map[string]ModelType
-
-// ModelRatio
-// https://platform.openai.com/docs/models/model-endpoint-compatibility
-// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Blfmc9dlf
-// https://openai.com/pricing
-// TODO: when a new api is enabled, check the pricing here
-// 1 === $0.002 / 1K tokens
-// 1 === ¥0.014 / 1k tokens
-var ModelRatio map[string][]float64
-
-func init() {
- ModelTypes = map[string]ModelType{
- // $0.03 / 1K tokens $0.06 / 1K tokens
- "gpt-4": {[]float64{15, 30}, ChannelTypeOpenAI},
- "gpt-4-0314": {[]float64{15, 30}, ChannelTypeOpenAI},
- "gpt-4-0613": {[]float64{15, 30}, ChannelTypeOpenAI},
- // $0.06 / 1K tokens $0.12 / 1K tokens
- "gpt-4-32k": {[]float64{30, 60}, ChannelTypeOpenAI},
- "gpt-4-32k-0314": {[]float64{30, 60}, ChannelTypeOpenAI},
- "gpt-4-32k-0613": {[]float64{30, 60}, ChannelTypeOpenAI},
- // $0.01 / 1K tokens $0.03 / 1K tokens
- "gpt-4-preview": {[]float64{5, 15}, ChannelTypeOpenAI},
- "gpt-4-1106-preview": {[]float64{5, 15}, ChannelTypeOpenAI},
- "gpt-4-0125-preview": {[]float64{5, 15}, ChannelTypeOpenAI},
- "gpt-4-turbo-preview": {[]float64{5, 15}, ChannelTypeOpenAI},
- "gpt-4-vision-preview": {[]float64{5, 15}, ChannelTypeOpenAI},
- // $0.0005 / 1K tokens $0.0015 / 1K tokens
- "gpt-3.5-turbo": {[]float64{0.25, 0.75}, ChannelTypeOpenAI},
- "gpt-3.5-turbo-0125": {[]float64{0.25, 0.75}, ChannelTypeOpenAI},
- // $0.0015 / 1K tokens $0.002 / 1K tokens
- "gpt-3.5-turbo-0301": {[]float64{0.75, 1}, ChannelTypeOpenAI},
- "gpt-3.5-turbo-0613": {[]float64{0.75, 1}, ChannelTypeOpenAI},
- "gpt-3.5-turbo-instruct": {[]float64{0.75, 1}, ChannelTypeOpenAI},
- // $0.003 / 1K tokens $0.004 / 1K tokens
- "gpt-3.5-turbo-16k": {[]float64{1.5, 2}, ChannelTypeOpenAI},
- "gpt-3.5-turbo-16k-0613": {[]float64{1.5, 2}, ChannelTypeOpenAI},
- // $0.001 / 1K tokens $0.002 / 1K tokens
- "gpt-3.5-turbo-1106": {[]float64{0.5, 1}, ChannelTypeOpenAI},
- // $0.0020 / 1K tokens
- "davinci-002": {[]float64{1, 1}, ChannelTypeOpenAI},
- // $0.0004 / 1K tokens
- "babbage-002": {[]float64{0.2, 0.2}, ChannelTypeOpenAI},
- "text-ada-001": {[]float64{0.2, 0.2}, ChannelTypeOpenAI},
- "text-babbage-001": {[]float64{0.25, 0.25}, ChannelTypeOpenAI},
- "text-curie-001": {[]float64{1, 1}, ChannelTypeOpenAI},
- "text-davinci-002": {[]float64{10, 10}, ChannelTypeOpenAI},
- "text-davinci-003": {[]float64{10, 10}, ChannelTypeOpenAI},
- "text-davinci-edit-001": {[]float64{10, 10}, ChannelTypeOpenAI},
- "code-davinci-edit-001": {[]float64{10, 10}, ChannelTypeOpenAI},
- // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
- "whisper-1": {[]float64{15, 15}, ChannelTypeOpenAI},
- // $0.015 / 1K characters
- "tts-1": {[]float64{7.5, 7.5}, ChannelTypeOpenAI},
- "tts-1-1106": {[]float64{7.5, 7.5}, ChannelTypeOpenAI},
- // $0.030 / 1K characters
- "tts-1-hd": {[]float64{15, 15}, ChannelTypeOpenAI},
- "tts-1-hd-1106": {[]float64{15, 15}, ChannelTypeOpenAI},
- "davinci": {[]float64{10, 10}, ChannelTypeOpenAI},
- "curie": {[]float64{10, 10}, ChannelTypeOpenAI},
- "babbage": {[]float64{10, 10}, ChannelTypeOpenAI},
- "ada": {[]float64{10, 10}, ChannelTypeOpenAI},
- "text-embedding-ada-002": {[]float64{0.05, 0.05}, ChannelTypeOpenAI},
- // $0.00002 / 1K tokens
- "text-embedding-3-small": {[]float64{0.01, 0.01}, ChannelTypeOpenAI},
- // $0.00013 / 1K tokens
- "text-embedding-3-large": {[]float64{0.065, 0.065}, ChannelTypeOpenAI},
- "text-search-ada-doc-001": {[]float64{10, 10}, ChannelTypeOpenAI},
- "text-moderation-stable": {[]float64{0.1, 0.1}, ChannelTypeOpenAI},
- "text-moderation-latest": {[]float64{0.1, 0.1}, ChannelTypeOpenAI},
- // $0.016 - $0.020 / image
- "dall-e-2": {[]float64{8, 8}, ChannelTypeOpenAI},
- // $0.040 - $0.120 / image
- "dall-e-3": {[]float64{20, 20}, ChannelTypeOpenAI},
-
- // $0.80/million tokens $2.40/million tokens
- "claude-instant-1.2": {[]float64{0.4, 1.2}, ChannelTypeAnthropic},
- // $8.00/million tokens $24.00/million tokens
- "claude-2.0": {[]float64{4, 12}, ChannelTypeAnthropic},
- "claude-2.1": {[]float64{4, 12}, ChannelTypeAnthropic},
- // $15 / M $75 / M
- "claude-3-opus-20240229": {[]float64{7.5, 22.5}, ChannelTypeAnthropic},
- // $3 / M $15 / M
- "claude-3-sonnet-20240229": {[]float64{1.3, 3.9}, ChannelTypeAnthropic},
- // $0.25 / M $1.25 / M 0.00025$ / 1k tokens 0.00125$ / 1k tokens
- "claude-3-haiku-20240307": {[]float64{0.125, 0.625}, ChannelTypeAnthropic},
-
- // ¥0.004 / 1k tokens ¥0.008 / 1k tokens
- "ERNIE-Speed": {[]float64{0.2857, 0.5714}, ChannelTypeBaidu},
- // ¥0.012 / 1k tokens ¥0.012 / 1k tokens
- "ERNIE-Bot": {[]float64{0.8572, 0.8572}, ChannelTypeBaidu},
- "ERNIE-3.5-8K": {[]float64{0.8572, 0.8572}, ChannelTypeBaidu},
- // 0.024元/千tokens 0.048元/千tokens
- "ERNIE-Bot-8k": {[]float64{1.7143, 3.4286}, ChannelTypeBaidu},
- // ¥0.008 / 1k tokens ¥0.008 / 1k tokens
- "ERNIE-Bot-turbo": {[]float64{0.5715, 0.5715}, ChannelTypeBaidu},
- // ¥0.12 / 1k tokens ¥0.12 / 1k tokens
- "ERNIE-Bot-4": {[]float64{8.572, 8.572}, ChannelTypeBaidu},
- "ERNIE-4.0": {[]float64{8.572, 8.572}, ChannelTypeBaidu},
- // ¥0.002 / 1k tokens
- "Embedding-V1": {[]float64{0.1429, 0.1429}, ChannelTypeBaidu},
- // ¥0.004 / 1k tokens
- "BLOOMZ-7B": {[]float64{0.2857, 0.2857}, ChannelTypeBaidu},
-
- "PaLM-2": {[]float64{1, 1}, ChannelTypePaLM},
- "gemini-pro": {[]float64{1, 1}, ChannelTypeGemini},
- "gemini-pro-vision": {[]float64{1, 1}, ChannelTypeGemini},
- "gemini-1.0-pro": {[]float64{1, 1}, ChannelTypeGemini},
- "gemini-1.5-pro": {[]float64{1, 1}, ChannelTypeGemini},
-
- // ¥0.005 / 1k tokens
- "glm-3-turbo": {[]float64{0.3572, 0.3572}, ChannelTypeZhipu},
- // ¥0.1 / 1k tokens
- "glm-4": {[]float64{7.143, 7.143}, ChannelTypeZhipu},
- "glm-4v": {[]float64{7.143, 7.143}, ChannelTypeZhipu},
- // ¥0.0005 / 1k tokens
- "embedding-2": {[]float64{0.0357, 0.0357}, ChannelTypeZhipu},
- // ¥0.25 / 1张图片
- "cogview-3": {[]float64{17.8571, 17.8571}, ChannelTypeZhipu},
-
- // ¥0.008 / 1k tokens
- "qwen-turbo": {[]float64{0.5715, 0.5715}, ChannelTypeAli},
- // ¥0.02 / 1k tokens
- "qwen-plus": {[]float64{1.4286, 1.4286}, ChannelTypeAli},
- "qwen-vl-max": {[]float64{1.4286, 1.4286}, ChannelTypeAli},
- // 0.12元/1,000tokens
- "qwen-max": {[]float64{8.5714, 8.5714}, ChannelTypeAli},
- "qwen-max-longcontext": {[]float64{8.5714, 8.5714}, ChannelTypeAli},
- // 0.008元/1,000tokens
- "qwen-vl-plus": {[]float64{0.5715, 0.5715}, ChannelTypeAli},
- // ¥0.0007 / 1k tokens
- "text-embedding-v1": {[]float64{0.05, 0.05}, ChannelTypeAli},
-
- // ¥0.018 / 1k tokens
- "SparkDesk": {[]float64{1.2858, 1.2858}, ChannelTypeXunfei},
- "SparkDesk-v1.1": {[]float64{1.2858, 1.2858}, ChannelTypeXunfei},
- "SparkDesk-v2.1": {[]float64{1.2858, 1.2858}, ChannelTypeXunfei},
- "SparkDesk-v3.1": {[]float64{1.2858, 1.2858}, ChannelTypeXunfei},
- "SparkDesk-v3.5": {[]float64{1.2858, 1.2858}, ChannelTypeXunfei},
-
- // ¥0.012 / 1k tokens
- "360GPT_S2_V9": {[]float64{0.8572, 0.8572}, ChannelType360},
- // ¥0.001 / 1k tokens
- "embedding-bert-512-v1": {[]float64{0.0715, 0.0715}, ChannelType360},
- "embedding_s1_v1": {[]float64{0.0715, 0.0715}, ChannelType360},
- "semantic_similarity_s1_v1": {[]float64{0.0715, 0.0715}, ChannelType360},
-
- // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
- "hunyuan": {[]float64{7.143, 7.143}, ChannelTypeTencent},
- // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
- // ¥0.01 / 1k tokens
- "ChatStd": {[]float64{0.7143, 0.7143}, ChannelTypeTencent},
- //¥0.1 / 1k tokens
- "ChatPro": {[]float64{7.143, 7.143}, ChannelTypeTencent},
-
- "Baichuan2-Turbo": {[]float64{0.5715, 0.5715}, ChannelTypeBaichuan}, // ¥0.008 / 1k tokens
- "Baichuan2-Turbo-192k": {[]float64{1.143, 1.143}, ChannelTypeBaichuan}, // ¥0.016 / 1k tokens
- "Baichuan2-53B": {[]float64{1.4286, 1.4286}, ChannelTypeBaichuan}, // ¥0.02 / 1k tokens
- "Baichuan-Text-Embedding": {[]float64{0.0357, 0.0357}, ChannelTypeBaichuan}, // ¥0.0005 / 1k tokens
-
- "abab5.5s-chat": {[]float64{0.3572, 0.3572}, ChannelTypeMiniMax}, // ¥0.005 / 1k tokens
- "abab5.5-chat": {[]float64{1.0714, 1.0714}, ChannelTypeMiniMax}, // ¥0.015 / 1k tokens
- "abab6-chat": {[]float64{14.2857, 14.2857}, ChannelTypeMiniMax}, // ¥0.2 / 1k tokens
- "embo-01": {[]float64{0.0357, 0.0357}, ChannelTypeMiniMax}, // ¥0.0005 / 1k tokens
-
- "deepseek-coder": {[]float64{0.75, 0.75}, ChannelTypeDeepseek}, // 暂定 $0.0015 / 1K tokens
- "deepseek-chat": {[]float64{0.75, 0.75}, ChannelTypeDeepseek}, // 暂定 $0.0015 / 1K tokens
-
- "moonshot-v1-8k": {[]float64{0.8572, 0.8572}, ChannelTypeMoonshot}, // ¥0.012 / 1K tokens
- "moonshot-v1-32k": {[]float64{1.7143, 1.7143}, ChannelTypeMoonshot}, // ¥0.024 / 1K tokens
- "moonshot-v1-128k": {[]float64{4.2857, 4.2857}, ChannelTypeMoonshot}, // ¥0.06 / 1K tokens
-
- "open-mistral-7b": {[]float64{0.125, 0.125}, ChannelTypeMistral}, // 0.25$ / 1M tokens 0.25$ / 1M tokens 0.00025$ / 1k tokens
- "open-mixtral-8x7b": {[]float64{0.35, 0.35}, ChannelTypeMistral}, // 0.7$ / 1M tokens 0.7$ / 1M tokens 0.0007$ / 1k tokens
- "mistral-small-latest": {[]float64{1, 3}, ChannelTypeMistral}, // 2$ / 1M tokens 6$ / 1M tokens 0.002$ / 1k tokens
- "mistral-medium-latest": {[]float64{1.35, 4.05}, ChannelTypeMistral}, // 2.7$ / 1M tokens 8.1$ / 1M tokens 0.0027$ / 1k tokens
- "mistral-large-latest": {[]float64{4, 12}, ChannelTypeMistral}, // 8$ / 1M tokens 24$ / 1M tokens 0.008$ / 1k tokens
- "mistral-embed": {[]float64{0.05, 0.05}, ChannelTypeMistral}, // 0.1$ / 1M tokens 0.1$ / 1M tokens 0.0001$ / 1k tokens
-
- // $0.70/$0.80 /1M Tokens 0.0007$ / 1k tokens
- "llama2-70b-4096": {[]float64{0.35, 0.4}, ChannelTypeGroq},
- // $0.10/$0.10 /1M Tokens 0.0001$ / 1k tokens
- "llama2-7b-2048": {[]float64{0.05, 0.05}, ChannelTypeGroq},
- "gemma-7b-it": {[]float64{0.05, 0.05}, ChannelTypeGroq},
- // $0.27/$0.27 /1M Tokens 0.00027$ / 1k tokens
- "mixtral-8x7b-32768": {[]float64{0.135, 0.135}, ChannelTypeGroq},
-
- // 2.5 元 / 1M tokens 0.0025 / 1k tokens
- "yi-34b-chat-0205": {[]float64{0.1786, 0.1786}, ChannelTypeLingyi},
- // 12 元 / 1M tokens 0.012 / 1k tokens
- "yi-34b-chat-200k": {[]float64{0.8571, 0.8571}, ChannelTypeLingyi},
- // 6 元 / 1M tokens 0.006 / 1k tokens
- "yi-vl-plus": {[]float64{0.4286, 0.4286}, ChannelTypeLingyi},
- }
-
- ModelRatio = make(map[string][]float64)
- for name, modelType := range ModelTypes {
- ModelRatio[name] = modelType.Ratio
- }
-}
-
var DalleSizeRatios = map[string]map[string]float64{
"dall-e-2": {
"256x256": 1,
@@ -234,104 +22,3 @@ var DalleImagePromptLengthLimitations = map[string]int{
"dall-e-2": 1000,
"dall-e-3": 4000,
}
-
-func ModelRatio2JSONString() string {
- jsonBytes, err := json.Marshal(ModelRatio)
- if err != nil {
- SysError("error marshalling model ratio: " + err.Error())
- }
- return string(jsonBytes)
-}
-
-func UpdateModelRatioByJSONString(jsonStr string) error {
- ModelRatio = make(map[string][]float64)
- return json.Unmarshal([]byte(jsonStr), &ModelRatio)
-}
-
-func MergeModelRatioByJSONString(jsonStr string) (newJsonStr string, err error) {
- isNew := false
- inputModelRatio := make(map[string][]float64)
- err = json.Unmarshal([]byte(jsonStr), &inputModelRatio)
- if err != nil {
- inputModelRatioOld := make(map[string]float64)
- err = json.Unmarshal([]byte(jsonStr), &inputModelRatioOld)
- if err != nil {
- return
- }
-
- inputModelRatio = UpdateModeRatioFormat(inputModelRatioOld)
- isNew = true
- }
-
- // 与现有的ModelRatio进行比较,如果有新增的模型,需要添加
- for key, value := range ModelRatio {
- if _, ok := inputModelRatio[key]; !ok {
- isNew = true
- inputModelRatio[key] = value
- }
- }
-
- if !isNew {
- return
- }
-
- var jsonBytes []byte
- jsonBytes, err = json.Marshal(inputModelRatio)
- if err != nil {
- SysError("error marshalling model ratio: " + err.Error())
- }
- newJsonStr = string(jsonBytes)
- return
-}
-
-func UpdateModeRatioFormat(modelRatioOld map[string]float64) map[string][]float64 {
- modelRatioNew := make(map[string][]float64)
- for key, value := range modelRatioOld {
- completionRatio := GetCompletionRatio(key) * value
- modelRatioNew[key] = []float64{value, completionRatio}
- }
- return modelRatioNew
-}
-
-func GetModelRatio(name string) []float64 {
- if strings.HasPrefix(name, "qwen-") && strings.HasSuffix(name, "-internet") {
- name = strings.TrimSuffix(name, "-internet")
- }
- ratio, ok := ModelRatio[name]
- if !ok {
- SysError("model ratio not found: " + name)
- return []float64{30, 30}
- }
- return ratio
-}
-
-func GetCompletionRatio(name string) float64 {
- if strings.HasPrefix(name, "gpt-3.5") {
- if strings.HasSuffix(name, "1106") {
- return 2
- }
- if name == "gpt-3.5-turbo" || name == "gpt-3.5-turbo-16k" {
- // TODO: clear this after 2023-12-11
- now := time.Now()
- // https://platform.openai.com/docs/models/continuous-model-upgrades
- // if after 2023-12-11, use 2
- if now.After(time.Date(2023, 12, 11, 0, 0, 0, 0, time.UTC)) {
- return 2
- }
- }
- return 1.333333
- }
- if strings.HasPrefix(name, "gpt-4") {
- if strings.HasSuffix(name, "preview") {
- return 3
- }
- return 2
- }
- if strings.HasPrefix(name, "claude-instant-1.2") {
- return 3.38
- }
- if strings.HasPrefix(name, "claude-2") {
- return 2.965517
- }
- return 1
-}
diff --git a/common/token.go b/common/token.go
index 9add2a0b..40ef8733 100644
--- a/common/token.go
+++ b/common/token.go
@@ -13,46 +13,46 @@ import (
)
var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
-var defaultTokenEncoder *tiktoken.Tiktoken
+var gpt35TokenEncoder *tiktoken.Tiktoken
+var gpt4TokenEncoder *tiktoken.Tiktoken
func InitTokenEncoders() {
SysLog("initializing token encoders")
- gpt35TokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo")
+ var err error
+ gpt35TokenEncoder, err = tiktoken.EncodingForModel("gpt-3.5-turbo")
if err != nil {
FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error()))
}
- defaultTokenEncoder = gpt35TokenEncoder
- gpt4TokenEncoder, err := tiktoken.EncodingForModel("gpt-4")
+
+ gpt4TokenEncoder, err = tiktoken.EncodingForModel("gpt-4")
if err != nil {
FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error()))
}
- for model := range ModelRatio {
- if strings.HasPrefix(model, "gpt-3.5") {
- tokenEncoderMap[model] = gpt35TokenEncoder
- } else if strings.HasPrefix(model, "gpt-4") {
- tokenEncoderMap[model] = gpt4TokenEncoder
- } else {
- tokenEncoderMap[model] = nil
- }
- }
+
SysLog("token encoders initialized")
}
func getTokenEncoder(model string) *tiktoken.Tiktoken {
tokenEncoder, ok := tokenEncoderMap[model]
- if ok && tokenEncoder != nil {
+ if ok {
return tokenEncoder
}
- if ok {
- tokenEncoder, err := tiktoken.EncodingForModel(model)
+
+ if strings.HasPrefix(model, "gpt-3.5") {
+ tokenEncoder = gpt35TokenEncoder
+ } else if strings.HasPrefix(model, "gpt-4") {
+ tokenEncoder = gpt4TokenEncoder
+ } else {
+ var err error
+ tokenEncoder, err = tiktoken.EncodingForModel(model)
if err != nil {
SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error()))
- tokenEncoder = defaultTokenEncoder
+ tokenEncoder = gpt35TokenEncoder
}
- tokenEncoderMap[model] = tokenEncoder
- return tokenEncoder
}
- return defaultTokenEncoder
+
+ tokenEncoderMap[model] = tokenEncoder
+ return tokenEncoder
}
func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
diff --git a/common/utils.go b/common/utils.go
index f9e574b9..4bae968c 100644
--- a/common/utils.go
+++ b/common/utils.go
@@ -212,3 +212,31 @@ func IsFileExist(path string) bool {
_, err := os.Stat(path)
return err == nil || os.IsExist(err)
}
+
+func Contains[T comparable](value T, slice []T) bool {
+ for _, item := range slice {
+ if item == value {
+ return true
+ }
+ }
+ return false
+}
+
+func Filter[T any](arr []T, f func(T) bool) []T {
+ var res []T
+ for _, v := range arr {
+ if f(v) {
+ res = append(res, v)
+ }
+ }
+ return res
+}
+
+func GetModelsWithMatch(modelList *[]string, modelName string) string {
+ for _, model := range *modelList {
+ if strings.HasPrefix(modelName, strings.TrimRight(model, "*")) {
+ return model
+ }
+ }
+ return ""
+}
diff --git a/config.example.yaml b/config.example.yaml
index cf750964..2ed668c4 100644
--- a/config.example.yaml
+++ b/config.example.yaml
@@ -18,6 +18,7 @@ frontend_base_url: "" # 设置之后将重定向页面请求到指定的地址
polling_interval: 0 # 批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。
batch_update_interval: 5 # 批量更新聚合的时间间隔,单位为秒,默认为 5。
batch_update_enabled: false # 启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 true 和 false,未设置则默认为 false
+auto_price_updates: true # 启用自动更新价格,可选值为 true 和 false,默认为 true
# 全局设置
global:
diff --git a/controller/pricing.go b/controller/pricing.go
new file mode 100644
index 00000000..8174832c
--- /dev/null
+++ b/controller/pricing.go
@@ -0,0 +1,192 @@
+package controller
+
+import (
+ "errors"
+ "net/http"
+ "one-api/common"
+ "one-api/model"
+ "one-api/relay/util"
+
+ "github.com/gin-gonic/gin"
+)
+
+func GetPricesList(c *gin.Context) {
+ pricesType := c.DefaultQuery("type", "db")
+
+ prices := util.GetPricesList(pricesType)
+
+ if len(prices) == 0 {
+ common.APIRespondWithError(c, http.StatusOK, errors.New("pricing data not found"))
+ return
+ }
+
+ if pricesType == "old" {
+ c.JSON(http.StatusOK, prices)
+ } else {
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": prices,
+ })
+ }
+}
+
+func GetAllModelList(c *gin.Context) {
+ prices := util.PricingInstance.GetAllPrices()
+ channelModel := model.ChannelGroup.Rule
+
+ modelsMap := make(map[string]bool)
+ for modelName := range prices {
+ modelsMap[modelName] = true
+ }
+
+ for _, modelMap := range channelModel {
+ for modelName := range modelMap {
+ if _, ok := prices[modelName]; !ok {
+ modelsMap[modelName] = true
+ }
+ }
+ }
+
+ var models []string
+ for modelName := range modelsMap {
+ models = append(models, modelName)
+ }
+
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": models,
+ })
+}
+
+func AddPrice(c *gin.Context) {
+ var price model.Price
+ if err := c.ShouldBindJSON(&price); err != nil {
+ common.APIRespondWithError(c, http.StatusOK, err)
+ return
+ }
+
+ if err := util.PricingInstance.AddPrice(&price); err != nil {
+ common.APIRespondWithError(c, http.StatusOK, err)
+ return
+ }
+
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ })
+}
+
+func UpdatePrice(c *gin.Context) {
+ modelName := c.Param("model")
+ if modelName == "" {
+ common.APIRespondWithError(c, http.StatusOK, errors.New("model name is required"))
+ return
+ }
+
+ var price model.Price
+ if err := c.ShouldBindJSON(&price); err != nil {
+ common.APIRespondWithError(c, http.StatusOK, err)
+ return
+ }
+
+ if err := util.PricingInstance.UpdatePrice(modelName, &price); err != nil {
+ common.APIRespondWithError(c, http.StatusOK, err)
+ return
+ }
+
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ })
+}
+
+func DeletePrice(c *gin.Context) {
+ modelName := c.Param("model")
+ if modelName == "" {
+ common.APIRespondWithError(c, http.StatusOK, errors.New("model name is required"))
+ return
+ }
+
+ if err := util.PricingInstance.DeletePrice(modelName); err != nil {
+ common.APIRespondWithError(c, http.StatusOK, err)
+ return
+ }
+
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ })
+}
+
+type PriceBatchRequest struct {
+ OriginalModels []string `json:"original_models"`
+ util.BatchPrices
+}
+
+func BatchSetPrices(c *gin.Context) {
+ pricesBatch := &PriceBatchRequest{}
+ if err := c.ShouldBindJSON(pricesBatch); err != nil {
+ common.APIRespondWithError(c, http.StatusOK, err)
+ return
+ }
+
+ if err := util.PricingInstance.BatchSetPrices(&pricesBatch.BatchPrices, pricesBatch.OriginalModels); err != nil {
+ common.APIRespondWithError(c, http.StatusOK, err)
+ return
+ }
+
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ })
+}
+
+type PriceBatchDeleteRequest struct {
+ Models []string `json:"models" binding:"required"`
+}
+
+func BatchDeletePrices(c *gin.Context) {
+ pricesBatch := &PriceBatchDeleteRequest{}
+ if err := c.ShouldBindJSON(pricesBatch); err != nil {
+ common.APIRespondWithError(c, http.StatusOK, err)
+ return
+ }
+
+ if err := util.PricingInstance.BatchDeletePrices(pricesBatch.Models); err != nil {
+ common.APIRespondWithError(c, http.StatusOK, err)
+ return
+ }
+
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ })
+}
+
+func SyncPricing(c *gin.Context) {
+ overwrite := c.DefaultQuery("overwrite", "false")
+
+ prices := make([]*model.Price, 0)
+ if err := c.ShouldBindJSON(&prices); err != nil {
+ common.APIRespondWithError(c, http.StatusOK, err)
+ return
+ }
+
+ if len(prices) == 0 {
+ common.APIRespondWithError(c, http.StatusOK, errors.New("prices is required"))
+ return
+ }
+
+ err := util.PricingInstance.SyncPricing(prices, overwrite == "true")
+ if err != nil {
+ common.APIRespondWithError(c, http.StatusOK, err)
+ return
+ }
+
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ })
+}
diff --git a/main.go b/main.go
index fe871c1f..18b31de8 100644
--- a/main.go
+++ b/main.go
@@ -10,6 +10,7 @@ import (
"one-api/controller"
"one-api/middleware"
"one-api/model"
+ "one-api/relay/util"
"one-api/router"
"github.com/gin-contrib/sessions"
@@ -35,6 +36,7 @@ func main() {
common.InitRedisClient()
// Initialize options
model.InitOptionMap()
+ util.NewPricing()
initMemoryCache()
initSync()
diff --git a/model/balancer.go b/model/balancer.go
index 6bc19682..44bdcb9e 100644
--- a/model/balancer.go
+++ b/model/balancer.go
@@ -18,6 +18,7 @@ type ChannelsChooser struct {
sync.RWMutex
Channels map[int]*ChannelChoice
Rule map[string]map[string][][]int // group -> model -> priority -> channelIds
+ Match []string
}
func (cc *ChannelsChooser) Cooldowns(channelId int) bool {
@@ -74,11 +75,15 @@ func (cc *ChannelsChooser) Next(group, modelName string) (*Channel, error) {
return nil, errors.New("group not found")
}
- if _, ok := cc.Rule[group][modelName]; !ok {
- return nil, errors.New("model not found")
+ channelsPriority, ok := cc.Rule[group][modelName]
+ if !ok {
+ matchModel := common.GetModelsWithMatch(&cc.Match, modelName)
+ channelsPriority, ok = cc.Rule[group][matchModel]
+ if !ok {
+ return nil, errors.New("model not found")
+ }
}
- channelsPriority := cc.Rule[group][modelName]
if len(channelsPriority) == 0 {
return nil, errors.New("channel not found")
}
@@ -123,6 +128,7 @@ func (cc *ChannelsChooser) Load() {
newGroup := make(map[string]map[string][][]int)
newChannels := make(map[int]*ChannelChoice)
+ newMatch := make(map[string]bool)
for _, channel := range channels {
if *channel.Weight == 0 {
@@ -143,6 +149,13 @@ func (cc *ChannelsChooser) Load() {
newGroup[ability.Group][ability.Model] = make([][]int, 0)
}
+ // 如果是以 *结尾的 model名称
+ if strings.HasSuffix(ability.Model, "*") {
+ if _, ok := newMatch[ability.Model]; !ok {
+ newMatch[ability.Model] = true
+ }
+ }
+
var priorityIds []int
// 逗号分割 ability.ChannelId
channelIds := strings.Split(ability.ChannelIds, ",")
@@ -153,9 +166,15 @@ func (cc *ChannelsChooser) Load() {
newGroup[ability.Group][ability.Model] = append(newGroup[ability.Group][ability.Model], priorityIds)
}
+ newMatchList := make([]string, 0, len(newMatch))
+ for match := range newMatch {
+ newMatchList = append(newMatchList, match)
+ }
+
cc.Lock()
cc.Rule = newGroup
cc.Channels = newChannels
+ cc.Match = newMatchList
cc.Unlock()
common.SysLog("channels Load success")
}
diff --git a/model/main.go b/model/main.go
index 90cd1b96..fcdb952e 100644
--- a/model/main.go
+++ b/model/main.go
@@ -135,6 +135,10 @@ func InitDB() (err error) {
if err != nil {
return err
}
+ err = db.AutoMigrate(&Price{})
+ if err != nil {
+ return err
+ }
common.SysLog("database migrated")
err = createRootAccountIfNeed()
return err
diff --git a/model/option.go b/model/option.go
index 45979df7..62714ce1 100644
--- a/model/option.go
+++ b/model/option.go
@@ -67,7 +67,6 @@ func InitOptionMap() {
common.OptionMap["QuotaForInvitee"] = strconv.Itoa(common.QuotaForInvitee)
common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold)
common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota)
- common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
common.OptionMap["GroupRatio"] = common.GroupRatio2JSONString()
common.OptionMap["TopUpLink"] = common.TopUpLink
common.OptionMap["ChatLink"] = common.ChatLink
@@ -76,28 +75,9 @@ func InitOptionMap() {
common.OptionMap["RetryCooldownSeconds"] = strconv.Itoa(common.RetryCooldownSeconds)
common.OptionMapRWMutex.Unlock()
- initModelRatio()
loadOptionsFromDatabase()
}
-func initModelRatio() {
- // 查询数据库中的ModelRatio
- option, err := GetOption("ModelRatio")
- if err != nil || option.Value == "" {
- return
- }
-
- newModelRatio, err := common.MergeModelRatioByJSONString(option.Value)
- if err != nil || newModelRatio == "" {
- return
- }
-
- // 更新数据库中的ModelRatio
- common.SysLog("update ModelRatio")
- UpdateOption("ModelRatio", newModelRatio)
-
-}
-
func loadOptionsFromDatabase() {
options, _ := AllOption()
for _, option := range options {
@@ -202,8 +182,6 @@ func updateOptionMap(key string, value string) (err error) {
switch key {
case "EmailDomainWhitelist":
common.EmailDomainWhitelist = strings.Split(value, ",")
- case "ModelRatio":
- err = common.UpdateModelRatioByJSONString(value)
case "GroupRatio":
err = common.UpdateGroupRatioByJSONString(value)
case "ChannelDisableThreshold":
diff --git a/model/price.go b/model/price.go
new file mode 100644
index 00000000..b0d0bd92
--- /dev/null
+++ b/model/price.go
@@ -0,0 +1,300 @@
+package model
+
+import (
+ "one-api/common"
+
+ "gorm.io/gorm"
+)
+
+const (
+ TokensPriceType = "tokens"
+ TimesPriceType = "times"
+ DefaultPrice = 30.0
+ DollarRate = 0.002
+ RMBRate = 0.014
+)
+
+type Price struct {
+ Model string `json:"model" gorm:"type:varchar(30);primaryKey" binding:"required"`
+ Type string `json:"type" gorm:"default:'tokens'" binding:"required"`
+ ChannelType int `json:"channel_type" gorm:"default:0" binding:"gte=0"`
+ Input float64 `json:"input" gorm:"default:0" binding:"gte=0"`
+ Output float64 `json:"output" gorm:"default:0" binding:"gte=0"`
+}
+
+func GetAllPrices() ([]*Price, error) {
+ var prices []*Price
+ if err := DB.Find(&prices).Error; err != nil {
+ return nil, err
+ }
+ return prices, nil
+}
+
+func (price *Price) Update(modelName string) error {
+ if err := DB.Model(price).Select("*").Where("model = ?", modelName).Updates(price).Error; err != nil {
+ return err
+ }
+
+ return nil
+}
+
+func (price *Price) Insert() error {
+ if err := DB.Create(price).Error; err != nil {
+ return err
+ }
+
+ return nil
+}
+
+func (price *Price) GetInput() float64 {
+ if price.Input <= 0 {
+ return 0
+ }
+ return price.Input
+}
+
+func (price *Price) GetOutput() float64 {
+ if price.Output <= 0 || price.Type == TimesPriceType {
+ return 0
+ }
+
+ return price.Output
+}
+
+func (price *Price) FetchInputCurrencyPrice(rate float64) float64 {
+ return price.GetInput() * rate
+}
+
+func (price *Price) FetchOutputCurrencyPrice(rate float64) float64 {
+ return price.GetOutput() * rate
+}
+
+func UpdatePrices(tx *gorm.DB, models []string, prices *Price) error {
+ err := tx.Model(Price{}).Where("model IN (?)", models).Select("*").Omit("model").Updates(
+ Price{
+ Type: prices.Type,
+ ChannelType: prices.ChannelType,
+ Input: prices.Input,
+ Output: prices.Output,
+ }).Error
+
+ return err
+}
+
+func DeletePrices(tx *gorm.DB, models []string) error {
+ err := tx.Where("model IN (?)", models).Delete(&Price{}).Error
+
+ return err
+}
+
+func InsertPrices(tx *gorm.DB, prices []*Price) error {
+ err := tx.CreateInBatches(prices, 100).Error
+ return err
+}
+
+func DeleteAllPrices(tx *gorm.DB) error {
+ err := tx.Where("1=1").Delete(&Price{}).Error
+ return err
+}
+
+func (price *Price) Delete() error {
+ err := DB.Delete(price).Error
+ if err != nil {
+ return err
+ }
+ return err
+}
+
+type ModelType struct {
+ Ratio []float64
+ Type int
+}
+
+// 1 === $0.002 / 1K tokens
+// 1 === ¥0.014 / 1k tokens
+func GetDefaultPrice() []*Price {
+ ModelTypes := map[string]ModelType{
+ // $0.03 / 1K tokens $0.06 / 1K tokens
+ "gpt-4": {[]float64{15, 30}, common.ChannelTypeOpenAI},
+ "gpt-4-0314": {[]float64{15, 30}, common.ChannelTypeOpenAI},
+ "gpt-4-0613": {[]float64{15, 30}, common.ChannelTypeOpenAI},
+ // $0.06 / 1K tokens $0.12 / 1K tokens
+ "gpt-4-32k": {[]float64{30, 60}, common.ChannelTypeOpenAI},
+ "gpt-4-32k-0314": {[]float64{30, 60}, common.ChannelTypeOpenAI},
+ "gpt-4-32k-0613": {[]float64{30, 60}, common.ChannelTypeOpenAI},
+ // $0.01 / 1K tokens $0.03 / 1K tokens
+ "gpt-4-preview": {[]float64{5, 15}, common.ChannelTypeOpenAI},
+ "gpt-4-1106-preview": {[]float64{5, 15}, common.ChannelTypeOpenAI},
+ "gpt-4-0125-preview": {[]float64{5, 15}, common.ChannelTypeOpenAI},
+ "gpt-4-turbo-preview": {[]float64{5, 15}, common.ChannelTypeOpenAI},
+ "gpt-4-vision-preview": {[]float64{5, 15}, common.ChannelTypeOpenAI},
+ // $0.0005 / 1K tokens $0.0015 / 1K tokens
+ "gpt-3.5-turbo": {[]float64{0.25, 0.75}, common.ChannelTypeOpenAI},
+ "gpt-3.5-turbo-0125": {[]float64{0.25, 0.75}, common.ChannelTypeOpenAI},
+ // $0.0015 / 1K tokens $0.002 / 1K tokens
+ "gpt-3.5-turbo-0301": {[]float64{0.75, 1}, common.ChannelTypeOpenAI},
+ "gpt-3.5-turbo-0613": {[]float64{0.75, 1}, common.ChannelTypeOpenAI},
+ "gpt-3.5-turbo-instruct": {[]float64{0.75, 1}, common.ChannelTypeOpenAI},
+ // $0.003 / 1K tokens $0.004 / 1K tokens
+ "gpt-3.5-turbo-16k": {[]float64{1.5, 2}, common.ChannelTypeOpenAI},
+ "gpt-3.5-turbo-16k-0613": {[]float64{1.5, 2}, common.ChannelTypeOpenAI},
+ // $0.001 / 1K tokens $0.002 / 1K tokens
+ "gpt-3.5-turbo-1106": {[]float64{0.5, 1}, common.ChannelTypeOpenAI},
+ // $0.0020 / 1K tokens
+ "davinci-002": {[]float64{1, 1}, common.ChannelTypeOpenAI},
+ // $0.0004 / 1K tokens
+ "babbage-002": {[]float64{0.2, 0.2}, common.ChannelTypeOpenAI},
+ // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
+ "whisper-1": {[]float64{15, 15}, common.ChannelTypeOpenAI},
+ // $0.015 / 1K characters
+ "tts-1": {[]float64{7.5, 7.5}, common.ChannelTypeOpenAI},
+ "tts-1-1106": {[]float64{7.5, 7.5}, common.ChannelTypeOpenAI},
+ // $0.030 / 1K characters
+ "tts-1-hd": {[]float64{15, 15}, common.ChannelTypeOpenAI},
+ "tts-1-hd-1106": {[]float64{15, 15}, common.ChannelTypeOpenAI},
+ "text-embedding-ada-002": {[]float64{0.05, 0.05}, common.ChannelTypeOpenAI},
+ // $0.00002 / 1K tokens
+ "text-embedding-3-small": {[]float64{0.01, 0.01}, common.ChannelTypeOpenAI},
+ // $0.00013 / 1K tokens
+ "text-embedding-3-large": {[]float64{0.065, 0.065}, common.ChannelTypeOpenAI},
+ "text-moderation-stable": {[]float64{0.1, 0.1}, common.ChannelTypeOpenAI},
+ "text-moderation-latest": {[]float64{0.1, 0.1}, common.ChannelTypeOpenAI},
+ // $0.016 - $0.020 / image
+ "dall-e-2": {[]float64{8, 8}, common.ChannelTypeOpenAI},
+ // $0.040 - $0.120 / image
+ "dall-e-3": {[]float64{20, 20}, common.ChannelTypeOpenAI},
+
+ // $0.80/million tokens $2.40/million tokens
+ "claude-instant-1.2": {[]float64{0.4, 1.2}, common.ChannelTypeAnthropic},
+ // $8.00/million tokens $24.00/million tokens
+ "claude-2.0": {[]float64{4, 12}, common.ChannelTypeAnthropic},
+ "claude-2.1": {[]float64{4, 12}, common.ChannelTypeAnthropic},
+ // $15 / M $75 / M
+ "claude-3-opus-20240229": {[]float64{7.5, 22.5}, common.ChannelTypeAnthropic},
+ // $3 / M $15 / M
+ "claude-3-sonnet-20240229": {[]float64{1.3, 3.9}, common.ChannelTypeAnthropic},
+ // $0.25 / M $1.25 / M 0.00025$ / 1k tokens 0.00125$ / 1k tokens
+ "claude-3-haiku-20240307": {[]float64{0.125, 0.625}, common.ChannelTypeAnthropic},
+
+ // ¥0.004 / 1k tokens ¥0.008 / 1k tokens
+ "ERNIE-Speed": {[]float64{0.2857, 0.5714}, common.ChannelTypeBaidu},
+ // ¥0.012 / 1k tokens ¥0.012 / 1k tokens
+ "ERNIE-Bot": {[]float64{0.8572, 0.8572}, common.ChannelTypeBaidu},
+ "ERNIE-3.5-8K": {[]float64{0.8572, 0.8572}, common.ChannelTypeBaidu},
+ // 0.024元/千tokens 0.048元/千tokens
+ "ERNIE-Bot-8k": {[]float64{1.7143, 3.4286}, common.ChannelTypeBaidu},
+ // ¥0.008 / 1k tokens ¥0.008 / 1k tokens
+ "ERNIE-Bot-turbo": {[]float64{0.5715, 0.5715}, common.ChannelTypeBaidu},
+ // ¥0.12 / 1k tokens ¥0.12 / 1k tokens
+ "ERNIE-Bot-4": {[]float64{8.572, 8.572}, common.ChannelTypeBaidu},
+ "ERNIE-4.0": {[]float64{8.572, 8.572}, common.ChannelTypeBaidu},
+ // ¥0.002 / 1k tokens
+ "Embedding-V1": {[]float64{0.1429, 0.1429}, common.ChannelTypeBaidu},
+ // ¥0.004 / 1k tokens
+ "BLOOMZ-7B": {[]float64{0.2857, 0.2857}, common.ChannelTypeBaidu},
+
+ "PaLM-2": {[]float64{1, 1}, common.ChannelTypePaLM},
+ "gemini-pro": {[]float64{1, 1}, common.ChannelTypeGemini},
+ "gemini-pro-vision": {[]float64{1, 1}, common.ChannelTypeGemini},
+ "gemini-1.0-pro": {[]float64{1, 1}, common.ChannelTypeGemini},
+ "gemini-1.5-pro": {[]float64{1, 1}, common.ChannelTypeGemini},
+
+ // ¥0.005 / 1k tokens
+ "glm-3-turbo": {[]float64{0.3572, 0.3572}, common.ChannelTypeZhipu},
+ // ¥0.1 / 1k tokens
+ "glm-4": {[]float64{7.143, 7.143}, common.ChannelTypeZhipu},
+ "glm-4v": {[]float64{7.143, 7.143}, common.ChannelTypeZhipu},
+ // ¥0.0005 / 1k tokens
+ "embedding-2": {[]float64{0.0357, 0.0357}, common.ChannelTypeZhipu},
+ // ¥0.25 / 1张图片
+ "cogview-3": {[]float64{17.8571, 17.8571}, common.ChannelTypeZhipu},
+
+ // ¥0.008 / 1k tokens
+ "qwen-turbo": {[]float64{0.5715, 0.5715}, common.ChannelTypeAli},
+ // ¥0.02 / 1k tokens
+ "qwen-plus": {[]float64{1.4286, 1.4286}, common.ChannelTypeAli},
+ "qwen-vl-max": {[]float64{1.4286, 1.4286}, common.ChannelTypeAli},
+ // 0.12元/1,000tokens
+ "qwen-max": {[]float64{8.5714, 8.5714}, common.ChannelTypeAli},
+ "qwen-max-longcontext": {[]float64{8.5714, 8.5714}, common.ChannelTypeAli},
+ // 0.008元/1,000tokens
+ "qwen-vl-plus": {[]float64{0.5715, 0.5715}, common.ChannelTypeAli},
+ // ¥0.0007 / 1k tokens
+ "text-embedding-v1": {[]float64{0.05, 0.05}, common.ChannelTypeAli},
+
+ // ¥0.018 / 1k tokens
+ "SparkDesk": {[]float64{1.2858, 1.2858}, common.ChannelTypeXunfei},
+ "SparkDesk-v1.1": {[]float64{1.2858, 1.2858}, common.ChannelTypeXunfei},
+ "SparkDesk-v2.1": {[]float64{1.2858, 1.2858}, common.ChannelTypeXunfei},
+ "SparkDesk-v3.1": {[]float64{1.2858, 1.2858}, common.ChannelTypeXunfei},
+ "SparkDesk-v3.5": {[]float64{1.2858, 1.2858}, common.ChannelTypeXunfei},
+
+ // ¥0.012 / 1k tokens
+ "360GPT_S2_V9": {[]float64{0.8572, 0.8572}, common.ChannelType360},
+ // ¥0.001 / 1k tokens
+ "embedding-bert-512-v1": {[]float64{0.0715, 0.0715}, common.ChannelType360},
+ "embedding_s1_v1": {[]float64{0.0715, 0.0715}, common.ChannelType360},
+ "semantic_similarity_s1_v1": {[]float64{0.0715, 0.0715}, common.ChannelType360},
+
+ // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
+ "hunyuan": {[]float64{7.143, 7.143}, common.ChannelTypeTencent},
+ // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
+ // ¥0.01 / 1k tokens
+ "ChatStd": {[]float64{0.7143, 0.7143}, common.ChannelTypeTencent},
+ //¥0.1 / 1k tokens
+ "ChatPro": {[]float64{7.143, 7.143}, common.ChannelTypeTencent},
+
+ "Baichuan2-Turbo": {[]float64{0.5715, 0.5715}, common.ChannelTypeBaichuan}, // ¥0.008 / 1k tokens
+ "Baichuan2-Turbo-192k": {[]float64{1.143, 1.143}, common.ChannelTypeBaichuan}, // ¥0.016 / 1k tokens
+ "Baichuan2-53B": {[]float64{1.4286, 1.4286}, common.ChannelTypeBaichuan}, // ¥0.02 / 1k tokens
+ "Baichuan-Text-Embedding": {[]float64{0.0357, 0.0357}, common.ChannelTypeBaichuan}, // ¥0.0005 / 1k tokens
+
+ "abab5.5s-chat": {[]float64{0.3572, 0.3572}, common.ChannelTypeMiniMax}, // ¥0.005 / 1k tokens
+ "abab5.5-chat": {[]float64{1.0714, 1.0714}, common.ChannelTypeMiniMax}, // ¥0.015 / 1k tokens
+ "abab6-chat": {[]float64{14.2857, 14.2857}, common.ChannelTypeMiniMax}, // ¥0.2 / 1k tokens
+ "embo-01": {[]float64{0.0357, 0.0357}, common.ChannelTypeMiniMax}, // ¥0.0005 / 1k tokens
+
+ "deepseek-coder": {[]float64{0.75, 0.75}, common.ChannelTypeDeepseek}, // 暂定 $0.0015 / 1K tokens
+ "deepseek-chat": {[]float64{0.75, 0.75}, common.ChannelTypeDeepseek}, // 暂定 $0.0015 / 1K tokens
+
+ "moonshot-v1-8k": {[]float64{0.8572, 0.8572}, common.ChannelTypeMoonshot}, // ¥0.012 / 1K tokens
+ "moonshot-v1-32k": {[]float64{1.7143, 1.7143}, common.ChannelTypeMoonshot}, // ¥0.024 / 1K tokens
+ "moonshot-v1-128k": {[]float64{4.2857, 4.2857}, common.ChannelTypeMoonshot}, // ¥0.06 / 1K tokens
+
+ "open-mistral-7b": {[]float64{0.125, 0.125}, common.ChannelTypeMistral}, // 0.25$ / 1M tokens 0.25$ / 1M tokens 0.00025$ / 1k tokens
+ "open-mixtral-8x7b": {[]float64{0.35, 0.35}, common.ChannelTypeMistral}, // 0.7$ / 1M tokens 0.7$ / 1M tokens 0.0007$ / 1k tokens
+ "mistral-small-latest": {[]float64{1, 3}, common.ChannelTypeMistral}, // 2$ / 1M tokens 6$ / 1M tokens 0.002$ / 1k tokens
+ "mistral-medium-latest": {[]float64{1.35, 4.05}, common.ChannelTypeMistral}, // 2.7$ / 1M tokens 8.1$ / 1M tokens 0.0027$ / 1k tokens
+ "mistral-large-latest": {[]float64{4, 12}, common.ChannelTypeMistral}, // 8$ / 1M tokens 24$ / 1M tokens 0.008$ / 1k tokens
+ "mistral-embed": {[]float64{0.05, 0.05}, common.ChannelTypeMistral}, // 0.1$ / 1M tokens 0.1$ / 1M tokens 0.0001$ / 1k tokens
+
+ // $0.70/$0.80 /1M Tokens 0.0007$ / 1k tokens
+ "llama2-70b-4096": {[]float64{0.35, 0.4}, common.ChannelTypeGroq},
+ // $0.10/$0.10 /1M Tokens 0.0001$ / 1k tokens
+ "llama2-7b-2048": {[]float64{0.05, 0.05}, common.ChannelTypeGroq},
+ "gemma-7b-it": {[]float64{0.05, 0.05}, common.ChannelTypeGroq},
+ // $0.27/$0.27 /1M Tokens 0.00027$ / 1k tokens
+ "mixtral-8x7b-32768": {[]float64{0.135, 0.135}, common.ChannelTypeGroq},
+
+ // 2.5 元 / 1M tokens 0.0025 / 1k tokens
+ "yi-34b-chat-0205": {[]float64{0.1786, 0.1786}, common.ChannelTypeLingyi},
+ // 12 元 / 1M tokens 0.012 / 1k tokens
+ "yi-34b-chat-200k": {[]float64{0.8571, 0.8571}, common.ChannelTypeLingyi},
+ // 6 元 / 1M tokens 0.006 / 1k tokens
+ "yi-vl-plus": {[]float64{0.4286, 0.4286}, common.ChannelTypeLingyi},
+ }
+
+ var prices []*Price
+
+ for model, modelType := range ModelTypes {
+ prices = append(prices, &Price{
+ Model: model,
+ Type: TokensPriceType,
+ ChannelType: modelType.Type,
+ Input: modelType.Ratio[0],
+ Output: modelType.Ratio[1],
+ })
+ }
+
+ return prices
+}
diff --git a/modelRatio.json b/modelRatio.json
deleted file mode 100644
index 1ae7da87..00000000
--- a/modelRatio.json
+++ /dev/null
@@ -1,98 +0,0 @@
-{
- "gpt-4": [15, 30],
- "gpt-4-0314": [15, 30],
- "gpt-4-0613": [15, 30],
- "gpt-4-32k": [30, 60],
- "gpt-4-32k-0314": [30, 60],
- "gpt-4-32k-0613": [30, 60],
- "gpt-4-preview": [5, 15],
- "gpt-4-1106-preview": [5, 15],
- "gpt-4-0125-preview": [5, 15],
- "gpt-4-vision-preview": [5, 15],
- "gpt-3.5-turbo": [0.25, 0.75],
- "gpt-3.5-turbo-0125": [0.25, 0.75],
- "gpt-3.5-turbo-0301": [0.75, 1],
- "gpt-3.5-turbo-0613": [0.75, 1],
- "gpt-3.5-turbo-instruct": [0.75, 1],
- "gpt-3.5-turbo-16k": [1.5, 2],
- "gpt-3.5-turbo-16k-0613": [1.5, 2],
- "gpt-3.5-turbo-1106": [0.5, 1],
- "davinci-002": [1, 1],
- "babbage-002": [0.2, 0.2],
- "text-ada-001": [0.2, 0.2],
- "text-babbage-001": [0.25, 0.25],
- "text-curie-001": [1, 1],
- "text-davinci-002": [10, 10],
- "text-davinci-003": [10, 10],
- "text-davinci-edit-001": [10, 10],
- "code-davinci-edit-001": [10, 10],
- "whisper-1": [15, 15],
- "tts-1": [7.5, 7.5],
- "tts-1-1106": [7.5, 7.5],
- "tts-1-hd": [15, 15],
- "tts-1-hd-1106": [15, 15],
- "davinci": [10, 10],
- "curie": [10, 10],
- "babbage": [10, 10],
- "ada": [10, 10],
- "text-embedding-ada-002": [0.05, 0.05],
- "text-embedding-3-small": [0.01, 0.01],
- "text-embedding-3-large": [0.065, 0.065],
- "text-search-ada-doc-001": [10, 10],
- "text-moderation-stable": [0.1, 0.1],
- "text-moderation-latest": [0.1, 0.1],
- "dall-e-2": [8, 8],
- "dall-e-3": [20, 20],
- "claude-instant-1.2": [0.4, 1.2],
- "claude-2.0": [4, 12],
- "claude-2.1": [4, 12],
- "claude-3-opus-20240229": [7.5, 22.5],
- "claude-3-sonnet-20240229": [1.3, 3.9],
- "ERNIE-Bot": [0.8572, 0.8572],
- "ERNIE-Bot-8k": [1.7143, 3.4286],
- "ERNIE-Bot-turbo": [0.5715, 0.5715],
- "ERNIE-Bot-4": [8.572, 8.572],
- "Embedding-V1": [0.1429, 0.1429],
- "PaLM-2": [1, 1],
- "gemini-pro": [1, 1],
- "gemini-pro-vision": [1, 1],
- "chatglm_turbo": [0.3572, 0.3572],
- "chatglm_std": [0.3572, 0.3572],
- "glm-3-turbo": [0.3572, 0.3572],
- "chatglm_pro": [0.7143, 0.7143],
- "chatglm_lite": [0.1429, 0.1429],
- "glm-4": [7.143, 7.143],
- "glm-4v": [7.143, 7.143],
- "embedding-2": [0.0357, 0.0357],
- "cogview-3": [17.8571, 17.8571],
- "qwen-turbo": [0.5715, 0.5715],
- "qwen-plus": [1.4286, 1.4286],
- "qwen-max": [1.4286, 1.4286],
- "qwen-max-longcontext": [1.4286, 1.4286],
- "qwen-vl": [0.5715, 0.5715],
- "qwen-vl-plus": [0.5715, 0.5715],
- "text-embedding-v1": [0.05, 0.05],
- "SparkDesk": [1.2858, 1.2858],
- "SparkDesk-v1.1": [1.2858, 1.2858],
- "SparkDesk-v2.1": [1.2858, 1.2858],
- "SparkDesk-v3.1": [1.2858, 1.2858],
- "SparkDesk-v3.5": [1.2858, 1.2858],
- "360GPT_S2_V9": [0.8572, 0.8572],
- "embedding-bert-512-v1": [0.0715, 0.0715],
- "embedding_s1_v1": [0.0715, 0.0715],
- "semantic_similarity_s1_v1": [0.0715, 0.0715],
- "hunyuan": [7.143, 7.143],
- "Baichuan2-Turbo": [0.5715, 0.5715],
- "Baichuan2-Turbo-192k": [1.143, 1.143],
- "Baichuan2-53B": [1.4286, 1.4286],
- "Baichuan-Text-Embedding": [0.0357, 0.0357],
- "abab5.5s-chat": [0.3572, 0.3572],
- "abab5.5-chat": [1.0714, 1.0714],
- "abab6-chat": [14.2857, 14.2857],
- "embo-01": [0.0357, 0.0357],
- "deepseek-coder": [0.75, 0.75],
- "deepseek-chat": [0.75, 0.75],
- "moonshot-v1-8k": [0.8572, 0.8572],
- "moonshot-v1-32k": [1.7143, 1.7143],
- "moonshot-v1-128k": [4.2857, 4.2857]
-}
diff --git a/relay/model.go b/relay/model.go
index 74e6a636..8bc61618 100644
--- a/relay/model.go
+++ b/relay/model.go
@@ -5,6 +5,7 @@ import (
"net/http"
"one-api/common"
"one-api/model"
+ "one-api/relay/util"
"one-api/types"
"sort"
@@ -13,8 +14,6 @@ import (
// https://platform.openai.com/docs/api-reference/models/list
-var unknownOwnedBy = "未知"
-
type OpenAIModelPermission struct {
Id string `json:"id"`
Object string `json:"object"`
@@ -30,6 +29,11 @@ type OpenAIModelPermission struct {
IsBlocking bool `json:"is_blocking"`
}
+type ModelPrice struct {
+ Type string `json:"type"`
+ Input string `json:"input"`
+ Output string `json:"output"`
+}
type OpenAIModels struct {
Id string `json:"id"`
Object string `json:"object"`
@@ -38,30 +42,7 @@ type OpenAIModels struct {
Permission *[]OpenAIModelPermission `json:"permission"`
Root *string `json:"root"`
Parent *string `json:"parent"`
-}
-
-var modelOwnedBy map[int]string
-
-func init() {
- modelOwnedBy = map[int]string{
- common.ChannelTypeOpenAI: "OpenAI",
- common.ChannelTypeAnthropic: "Anthropic",
- common.ChannelTypeBaidu: "Baidu",
- common.ChannelTypePaLM: "Google PaLM",
- common.ChannelTypeGemini: "Google Gemini",
- common.ChannelTypeZhipu: "Zhipu",
- common.ChannelTypeAli: "Ali",
- common.ChannelTypeXunfei: "Xunfei",
- common.ChannelType360: "360",
- common.ChannelTypeTencent: "Tencent",
- common.ChannelTypeBaichuan: "Baichuan",
- common.ChannelTypeMiniMax: "MiniMax",
- common.ChannelTypeDeepseek: "Deepseek",
- common.ChannelTypeMoonshot: "Moonshot",
- common.ChannelTypeMistral: "Mistral",
- common.ChannelTypeGroq: "Groq",
- common.ChannelTypeLingyi: "Lingyiwanwu",
- }
+ Price *ModelPrice `json:"price"`
}
func ListModels(c *gin.Context) {
@@ -83,17 +64,9 @@ func ListModels(c *gin.Context) {
}
sort.Strings(models)
- groupOpenAIModels := make([]OpenAIModels, 0, len(models))
- for _, modelId := range models {
- groupOpenAIModels = append(groupOpenAIModels, OpenAIModels{
- Id: modelId,
- Object: "model",
- Created: 1677649963,
- OwnedBy: getModelOwnedBy(modelId),
- Permission: nil,
- Root: nil,
- Parent: nil,
- })
+ var groupOpenAIModels []*OpenAIModels
+ for _, modelName := range models {
+ groupOpenAIModels = append(groupOpenAIModels, getOpenAIModelWithName(modelName))
}
// 根据 OwnedBy 排序
@@ -114,13 +87,14 @@ func ListModels(c *gin.Context) {
}
func ListModelsForAdmin(c *gin.Context) {
- openAIModels := make([]OpenAIModels, 0, len(common.ModelRatio))
- for modelId := range common.ModelRatio {
+ prices := util.PricingInstance.GetAllPrices()
+ var openAIModels []OpenAIModels
+ for modelId, price := range prices {
openAIModels = append(openAIModels, OpenAIModels{
Id: modelId,
Object: "model",
Created: 1677649963,
- OwnedBy: getModelOwnedBy(modelId),
+ OwnedBy: getModelOwnedBy(price.ChannelType),
Permission: nil,
Root: nil,
Parent: nil,
@@ -144,21 +118,13 @@ func ListModelsForAdmin(c *gin.Context) {
}
func RetrieveModel(c *gin.Context) {
- modelId := c.Param("model")
- ownedByName := getModelOwnedBy(modelId)
- if *ownedByName != unknownOwnedBy {
- c.JSON(200, OpenAIModels{
- Id: modelId,
- Object: "model",
- Created: 1677649963,
- OwnedBy: ownedByName,
- Permission: nil,
- Root: nil,
- Parent: nil,
- })
+ modelName := c.Param("model")
+ openaiModel := getOpenAIModelWithName(modelName)
+ if *openaiModel.OwnedBy != util.UnknownOwnedBy {
+ c.JSON(200, openaiModel)
} else {
openAIError := types.OpenAIError{
- Message: fmt.Sprintf("The model '%s' does not exist", modelId),
+ Message: fmt.Sprintf("The model '%s' does not exist", modelName),
Type: "invalid_request_error",
Param: "model",
Code: "model_not_found",
@@ -169,12 +135,32 @@ func RetrieveModel(c *gin.Context) {
}
}
-func getModelOwnedBy(modelId string) (ownedBy *string) {
- if modelType, ok := common.ModelTypes[modelId]; ok {
- if ownedByName, ok := modelOwnedBy[modelType.Type]; ok {
- return &ownedByName
- }
+func getModelOwnedBy(channelType int) (ownedBy *string) {
+ if ownedByName, ok := util.ModelOwnedBy[channelType]; ok {
+ return &ownedByName
}
- return &unknownOwnedBy
+ return &util.UnknownOwnedBy
+}
+
+func getOpenAIModelWithName(modelName string) *OpenAIModels {
+ price := util.PricingInstance.GetPrice(modelName)
+
+ return &OpenAIModels{
+ Id: modelName,
+ Object: "model",
+ Created: 1677649963,
+ OwnedBy: getModelOwnedBy(price.ChannelType),
+ Permission: nil,
+ Root: nil,
+ Parent: nil,
+ }
+}
+
+func GetModelOwnedBy(c *gin.Context) {
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": util.ModelOwnedBy,
+ })
}
diff --git a/relay/util/pricing.go b/relay/util/pricing.go
new file mode 100644
index 00000000..390acb8a
--- /dev/null
+++ b/relay/util/pricing.go
@@ -0,0 +1,402 @@
+package util
+
+import (
+ "encoding/json"
+ "errors"
+ "one-api/common"
+ "one-api/model"
+ "sort"
+ "strings"
+ "sync"
+
+ "github.com/spf13/viper"
+)
+
+// PricingInstance is the Pricing instance
+var PricingInstance *Pricing
+
+// Pricing is a struct that contains the pricing data
+type Pricing struct {
+ sync.RWMutex
+ Prices map[string]*model.Price `json:"models"`
+ Match []string `json:"-"`
+}
+
+type BatchPrices struct {
+ Models []string `json:"models" binding:"required"`
+ Price model.Price `json:"price" binding:"required"`
+}
+
+// NewPricing creates a new Pricing instance
+func NewPricing() {
+ common.SysLog("Initializing Pricing")
+
+ PricingInstance = &Pricing{
+ Prices: make(map[string]*model.Price),
+ Match: make([]string, 0),
+ }
+
+ err := PricingInstance.Init()
+
+ if err != nil {
+ common.SysError("Failed to initialize Pricing:" + err.Error())
+ return
+ }
+
+ // 初始化时,需要检测是否有更新
+ if viper.GetBool("auto_price_updates") {
+ common.SysLog("Checking for pricing updates")
+ prices := model.GetDefaultPrice()
+ PricingInstance.SyncPricing(prices, false)
+ common.SysLog("Pricing initialized")
+ }
+}
+
+// initializes the Pricing instance
+func (p *Pricing) Init() error {
+ prices, err := model.GetAllPrices()
+ if err != nil {
+ return err
+ }
+
+ if len(prices) == 0 {
+ return nil
+ }
+
+ newPrices := make(map[string]*model.Price)
+ newMatch := make(map[string]bool)
+
+ for _, price := range prices {
+ newPrices[price.Model] = price
+ if strings.HasSuffix(price.Model, "*") {
+ if _, ok := newMatch[price.Model]; !ok {
+ newMatch[price.Model] = true
+ }
+ }
+ }
+
+ var newMatchList []string
+ for match := range newMatch {
+ newMatchList = append(newMatchList, match)
+ }
+
+ p.Lock()
+ defer p.Unlock()
+
+ p.Prices = newPrices
+ p.Match = newMatchList
+
+ return nil
+}
+
+// GetPrice returns the price of a model
+func (p *Pricing) GetPrice(modelName string) *model.Price {
+ p.RLock()
+ defer p.RUnlock()
+
+ if price, ok := p.Prices[modelName]; ok {
+ return price
+ }
+
+ matchModel := common.GetModelsWithMatch(&p.Match, modelName)
+ if price, ok := p.Prices[matchModel]; ok {
+ return price
+ }
+
+ return &model.Price{
+ Type: model.TokensPriceType,
+ ChannelType: common.ChannelTypeUnknown,
+ Input: model.DefaultPrice,
+ Output: model.DefaultPrice,
+ }
+}
+
+func (p *Pricing) GetAllPrices() map[string]*model.Price {
+ return p.Prices
+}
+
+func (p *Pricing) GetAllPricesList() []*model.Price {
+ var prices []*model.Price
+ for _, price := range p.Prices {
+ prices = append(prices, price)
+ }
+
+ return prices
+}
+
+func (p *Pricing) updateRawPrice(modelName string, price *model.Price) error {
+ if _, ok := p.Prices[modelName]; !ok {
+ return errors.New("model not found")
+ }
+
+ if _, ok := p.Prices[price.Model]; modelName != price.Model && ok {
+ return errors.New("model names cannot be duplicated")
+ }
+
+ if err := p.deleteRawPrice(modelName); err != nil {
+ return err
+ }
+
+ return price.Insert()
+}
+
+// UpdatePrice updates the price of a model
+func (p *Pricing) UpdatePrice(modelName string, price *model.Price) error {
+
+ if err := p.updateRawPrice(modelName, price); err != nil {
+ return err
+ }
+
+ err := p.Init()
+
+ return err
+}
+
+func (p *Pricing) addRawPrice(price *model.Price) error {
+ if _, ok := p.Prices[price.Model]; ok {
+ return errors.New("model already exists")
+ }
+
+ return price.Insert()
+}
+
+// AddPrice adds a new price to the Pricing instance
+func (p *Pricing) AddPrice(price *model.Price) error {
+ if err := p.addRawPrice(price); err != nil {
+ return err
+ }
+
+ err := p.Init()
+
+ return err
+}
+
+func (p *Pricing) deleteRawPrice(modelName string) error {
+ item, ok := p.Prices[modelName]
+ if !ok {
+ return errors.New("model not found")
+ }
+
+ return item.Delete()
+}
+
+// DeletePrice deletes a price from the Pricing instance
+func (p *Pricing) DeletePrice(modelName string) error {
+ if err := p.deleteRawPrice(modelName); err != nil {
+ return err
+ }
+
+ err := p.Init()
+
+ return err
+}
+
+// SyncPricing syncs the pricing data
+func (p *Pricing) SyncPricing(pricing []*model.Price, overwrite bool) error {
+ var err error
+ if overwrite {
+ err = p.SyncPriceWithOverwrite(pricing)
+ } else {
+ err = p.SyncPriceWithoutOverwrite(pricing)
+ }
+
+ return err
+}
+
+// SyncPriceWithOverwrite syncs the pricing data with overwrite
+func (p *Pricing) SyncPriceWithOverwrite(pricing []*model.Price) error {
+ tx := model.DB.Begin()
+
+ err := model.DeleteAllPrices(tx)
+ if err != nil {
+ tx.Rollback()
+ return err
+ }
+
+ err = model.InsertPrices(tx, pricing)
+
+ if err != nil {
+ tx.Rollback()
+ return err
+ }
+
+ tx.Commit()
+
+ return p.Init()
+}
+
+// SyncPriceWithoutOverwrite syncs the pricing data without overwrite
+func (p *Pricing) SyncPriceWithoutOverwrite(pricing []*model.Price) error {
+ var newPrices []*model.Price
+
+ for _, price := range pricing {
+ if _, ok := p.Prices[price.Model]; !ok {
+ newPrices = append(newPrices, price)
+ }
+ }
+
+ if len(newPrices) == 0 {
+ return nil
+ }
+
+ tx := model.DB.Begin()
+ err := model.InsertPrices(tx, newPrices)
+
+ if err != nil {
+ tx.Rollback()
+ return err
+ }
+
+ tx.Commit()
+
+ return p.Init()
+}
+
+// BatchDeletePrices deletes the prices of multiple models
+func (p *Pricing) BatchDeletePrices(models []string) error {
+ tx := model.DB.Begin()
+
+ err := model.DeletePrices(tx, models)
+ if err != nil {
+ tx.Rollback()
+ return err
+ }
+
+ tx.Commit()
+
+ p.Lock()
+ defer p.Unlock()
+
+ for _, model := range models {
+ delete(p.Prices, model)
+ }
+
+ return nil
+}
+
+func (p *Pricing) BatchSetPrices(batchPrices *BatchPrices, originalModels []string) error {
+ // 查找需要删除的model
+ var deletePrices []string
+ var addPrices []*model.Price
+ var updatePrices []string
+
+ for _, model := range originalModels {
+ if !common.Contains(model, batchPrices.Models) {
+ deletePrices = append(deletePrices, model)
+ } else {
+ updatePrices = append(updatePrices, model)
+ }
+ }
+
+ for _, model := range batchPrices.Models {
+ if !common.Contains(model, originalModels) {
+ addPrice := batchPrices.Price
+ addPrice.Model = model
+ addPrices = append(addPrices, &addPrice)
+ }
+ }
+
+ tx := model.DB.Begin()
+ if len(addPrices) > 0 {
+ err := model.InsertPrices(tx, addPrices)
+ if err != nil {
+ tx.Rollback()
+ return err
+ }
+ }
+
+ if len(updatePrices) > 0 {
+ err := model.UpdatePrices(tx, updatePrices, &batchPrices.Price)
+ if err != nil {
+ tx.Rollback()
+ return err
+ }
+ }
+
+ if len(deletePrices) > 0 {
+ err := model.DeletePrices(tx, deletePrices)
+ if err != nil {
+ tx.Rollback()
+ return err
+ }
+
+ }
+ tx.Commit()
+
+ return p.Init()
+}
+
+func GetPricesList(pricingType string) []*model.Price {
+ var prices []*model.Price
+
+ switch pricingType {
+ case "default":
+ prices = model.GetDefaultPrice()
+ case "db":
+ prices = PricingInstance.GetAllPricesList()
+ case "old":
+ prices = GetOldPricesList()
+ default:
+ return nil
+ }
+
+ sort.Slice(prices, func(i, j int) bool {
+ if prices[i].ChannelType == prices[j].ChannelType {
+ return prices[i].Model < prices[j].Model
+ }
+ return prices[i].ChannelType < prices[j].ChannelType
+ })
+
+ return prices
+}
+
+func GetOldPricesList() []*model.Price {
+ oldDataJson, err := model.GetOption("ModelRatio")
+ if err != nil || oldDataJson.Value == "" {
+ return nil
+ }
+
+ oldData := make(map[string][]float64)
+ err = json.Unmarshal([]byte(oldDataJson.Value), &oldData)
+
+ if err != nil {
+ return nil
+ }
+
+ var prices []*model.Price
+ for modelName, oldPrice := range oldData {
+ price := PricingInstance.GetPrice(modelName)
+ prices = append(prices, &model.Price{
+ Model: modelName,
+ Type: model.TokensPriceType,
+ ChannelType: price.ChannelType,
+ Input: oldPrice[0],
+ Output: oldPrice[1],
+ })
+ }
+
+ return prices
+}
+
+// func ConvertBatchPrices(prices []*model.Price) []*BatchPrices {
+// batchPricesMap := make(map[string]*BatchPrices)
+// for _, price := range prices {
+// key := fmt.Sprintf("%s-%d-%g-%g", price.Type, price.ChannelType, price.Input, price.Output)
+// batchPrice, exists := batchPricesMap[key]
+// if exists {
+// batchPrice.Models = append(batchPrice.Models, price.Model)
+// } else {
+// batchPricesMap[key] = &BatchPrices{
+// Models: []string{price.Model},
+// Price: *price,
+// }
+// }
+// }
+
+// var batchPrices []*BatchPrices
+// for _, batchPrice := range batchPricesMap {
+// batchPrices = append(batchPrices, batchPrice)
+// }
+
+// return batchPrices
+// }
diff --git a/relay/util/quota.go b/relay/util/quota.go
index 0a7c2be0..0274f51c 100644
--- a/relay/util/quota.go
+++ b/relay/util/quota.go
@@ -15,17 +15,16 @@ import (
)
type Quota struct {
- modelName string
- promptTokens int
- preConsumedTokens int
- modelRatio []float64
- groupRatio float64
- ratio float64
- preConsumedQuota int
- userId int
- channelId int
- tokenId int
- HandelStatus bool
+ modelName string
+ promptTokens int
+ price model.Price
+ groupRatio float64
+ inputRatio float64
+ preConsumedQuota int
+ userId int
+ channelId int
+ tokenId int
+ HandelStatus bool
}
func NewQuota(c *gin.Context, modelName string, promptTokens int) (*Quota, *types.OpenAIErrorWithStatusCode) {
@@ -37,7 +36,16 @@ func NewQuota(c *gin.Context, modelName string, promptTokens int) (*Quota, *type
tokenId: c.GetInt("token_id"),
HandelStatus: false,
}
- quota.init(c.GetString("group"))
+
+ quota.price = *PricingInstance.GetPrice(quota.modelName)
+ quota.groupRatio = common.GetGroupRatio(c.GetString("group"))
+ quota.inputRatio = quota.price.GetInput() * quota.groupRatio
+
+ if quota.price.Type == model.TimesPriceType {
+ quota.preConsumedQuota = int(1000 * quota.inputRatio)
+ } else {
+ quota.preConsumedQuota = int(float64(quota.promptTokens+common.PreConsumedQuota) * quota.inputRatio)
+ }
errWithCode := quota.preQuotaConsumption()
if errWithCode != nil {
@@ -47,21 +55,6 @@ func NewQuota(c *gin.Context, modelName string, promptTokens int) (*Quota, *type
return quota, nil
}
-func (q *Quota) init(groupName string) {
- modelRatio := common.GetModelRatio(q.modelName)
- groupRatio := common.GetGroupRatio(groupName)
- preConsumedTokens := common.PreConsumedQuota
- ratio := modelRatio[0] * groupRatio
- preConsumedQuota := int(float64(q.promptTokens+preConsumedTokens) * ratio)
-
- q.preConsumedTokens = preConsumedTokens
- q.modelRatio = modelRatio
- q.groupRatio = groupRatio
- q.ratio = ratio
- q.preConsumedQuota = preConsumedQuota
-
-}
-
func (q *Quota) preQuotaConsumption() *types.OpenAIErrorWithStatusCode {
userQuota, err := model.CacheGetUserQuota(q.userId)
if err != nil {
@@ -97,11 +90,17 @@ func (q *Quota) preQuotaConsumption() *types.OpenAIErrorWithStatusCode {
func (q *Quota) completedQuotaConsumption(usage *types.Usage, tokenName string, ctx context.Context) error {
quota := 0
- completionRatio := q.modelRatio[1] * q.groupRatio
promptTokens := usage.PromptTokens
completionTokens := usage.CompletionTokens
- quota = int(math.Ceil(((float64(promptTokens) * q.ratio) + (float64(completionTokens) * completionRatio))))
- if q.ratio != 0 && quota <= 0 {
+
+ if q.price.Type == model.TimesPriceType {
+ quota = int(1000 * q.inputRatio)
+ } else {
+ completionRatio := q.price.GetOutput() * q.groupRatio
+ quota = int(math.Ceil(((float64(promptTokens) * q.inputRatio) + (float64(completionTokens) * completionRatio))))
+ }
+
+ if q.inputRatio != 0 && quota <= 0 {
quota = 1
}
totalTokens := promptTokens + completionTokens
@@ -129,13 +128,18 @@ func (q *Quota) completedQuotaConsumption(usage *types.Usage, tokenName string,
}
}
var modelRatioStr string
- if q.modelRatio[0] == q.modelRatio[1] {
- modelRatioStr = fmt.Sprintf("%.2f", q.modelRatio[0])
+ if q.price.Type == model.TimesPriceType {
+ modelRatioStr = fmt.Sprintf("$%g/次", q.price.FetchInputCurrencyPrice(model.DollarRate))
} else {
- modelRatioStr = fmt.Sprintf("%.2f (输入)/%.2f (输出)", q.modelRatio[0], q.modelRatio[1])
+ // 如果输入费率和输出费率一样,则只显示一个费率
+ if q.price.GetInput() == q.price.GetOutput() {
+ modelRatioStr = fmt.Sprintf("$%g/1k", q.price.FetchInputCurrencyPrice(model.DollarRate))
+ } else {
+ modelRatioStr = fmt.Sprintf("$%g/1k (输入) | $%g/1k (输出)", q.price.FetchInputCurrencyPrice(model.DollarRate), q.price.FetchOutputCurrencyPrice(model.DollarRate))
+ }
}
- logContent := fmt.Sprintf("模型倍率 %s,分组倍率 %.2f", modelRatioStr, q.groupRatio)
+ logContent := fmt.Sprintf("模型费率 %s,分组倍率 %.2f", modelRatioStr, q.groupRatio)
model.RecordConsumeLog(ctx, q.userId, q.channelId, promptTokens, completionTokens, q.modelName, tokenName, quota, logContent, requestTime)
model.UpdateUserUsedQuotaAndRequestCount(q.userId, quota)
model.UpdateChannelUsedQuota(q.channelId, quota)
diff --git a/relay/util/type.go b/relay/util/type.go
new file mode 100644
index 00000000..5c16d288
--- /dev/null
+++ b/relay/util/type.go
@@ -0,0 +1,28 @@
+package util
+
+import "one-api/common"
+
+var UnknownOwnedBy = "未知"
+var ModelOwnedBy map[int]string
+
+func init() {
+ ModelOwnedBy = map[int]string{
+ common.ChannelTypeOpenAI: "OpenAI",
+ common.ChannelTypeAnthropic: "Anthropic",
+ common.ChannelTypeBaidu: "Baidu",
+ common.ChannelTypePaLM: "Google PaLM",
+ common.ChannelTypeGemini: "Google Gemini",
+ common.ChannelTypeZhipu: "Zhipu",
+ common.ChannelTypeAli: "Ali",
+ common.ChannelTypeXunfei: "Xunfei",
+ common.ChannelType360: "360",
+ common.ChannelTypeTencent: "Tencent",
+ common.ChannelTypeBaichuan: "Baichuan",
+ common.ChannelTypeMiniMax: "MiniMax",
+ common.ChannelTypeDeepseek: "Deepseek",
+ common.ChannelTypeMoonshot: "Moonshot",
+ common.ChannelTypeMistral: "Mistral",
+ common.ChannelTypeGroq: "Groq",
+ common.ChannelTypeLingyi: "Lingyiwanwu",
+ }
+}
diff --git a/router/api-router.go b/router/api-router.go
index 57e81816..d1fef864 100644
--- a/router/api-router.go
+++ b/router/api-router.go
@@ -18,6 +18,8 @@ func SetApiRouter(router *gin.Engine) {
apiRouter.GET("/status", controller.GetStatus)
apiRouter.GET("/notice", controller.GetNotice)
apiRouter.GET("/about", controller.GetAbout)
+ apiRouter.GET("/prices", middleware.CORS(), controller.GetPricesList)
+ apiRouter.GET("/ownedby", relay.GetModelOwnedBy)
apiRouter.GET("/home_page_content", controller.GetHomePageContent)
apiRouter.GET("/verification", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendEmailVerification)
apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail)
@@ -129,6 +131,20 @@ func SetApiRouter(router *gin.Engine) {
analyticsRoute.GET("/channel_period", controller.GetChannelExpensesByPeriod)
analyticsRoute.GET("/redemption_period", controller.GetRedemptionStatisticsByPeriod)
}
+
+ pricesRoute := apiRouter.Group("/prices")
+ pricesRoute.Use(middleware.AdminAuth())
+ {
+ pricesRoute.GET("/model_list", controller.GetAllModelList)
+ pricesRoute.POST("/single", controller.AddPrice)
+ pricesRoute.PUT("/single/:model", controller.UpdatePrice)
+ pricesRoute.DELETE("/single/:model", controller.DeletePrice)
+ pricesRoute.POST("/multiple", controller.BatchSetPrices)
+ pricesRoute.PUT("/multiple/delete", controller.BatchDeletePrices)
+ pricesRoute.POST("/sync", controller.SyncPricing)
+
+ }
+
}
}
diff --git a/web/package.json b/web/package.json
index e376b4b8..cbb774de 100644
--- a/web/package.json
+++ b/web/package.json
@@ -9,7 +9,7 @@
"@emotion/react": "^11.9.3",
"@emotion/styled": "^11.9.3",
"@mui/icons-material": "^5.8.4",
- "@mui/lab": "^5.0.0-alpha.88",
+ "@mui/lab": "^5.0.0-alpha.169",
"@mui/material": "^5.8.6",
"@mui/system": "^5.8.6",
"@mui/utils": "^5.8.6",
diff --git a/web/src/menu-items/panel.js b/web/src/menu-items/panel.js
index 9bc945f7..a85f6a02 100644
--- a/web/src/menu-items/panel.js
+++ b/web/src/menu-items/panel.js
@@ -10,7 +10,8 @@ import {
IconUser,
IconUserScan,
IconActivity,
- IconBrandTelegram
+ IconBrandTelegram,
+ IconReceipt2
} from '@tabler/icons-react';
// constant
@@ -25,7 +26,8 @@ const icons = {
IconUser,
IconUserScan,
IconActivity,
- IconBrandTelegram
+ IconBrandTelegram,
+ IconReceipt2
};
// ==============================|| DASHBOARD MENU ITEMS ||============================== //
@@ -112,6 +114,15 @@ const panel = {
breadcrumbs: false,
isAdmin: true
},
+ {
+ id: 'pricing',
+ title: '模型价格',
+ type: 'item',
+ url: '/panel/pricing',
+ icon: icons.IconReceipt2,
+ breadcrumbs: false,
+ isAdmin: true
+ },
{
id: 'setting',
title: '设置',
diff --git a/web/src/routes/MainRoutes.js b/web/src/routes/MainRoutes.js
index f783cf21..45abc176 100644
--- a/web/src/routes/MainRoutes.js
+++ b/web/src/routes/MainRoutes.js
@@ -15,6 +15,7 @@ const Profile = Loadable(lazy(() => import('views/Profile')));
const NotFoundView = Loadable(lazy(() => import('views/Error')));
const Analytics = Loadable(lazy(() => import('views/Analytics')));
const Telegram = Loadable(lazy(() => import('views/Telegram')));
+const Pricing = Loadable(lazy(() => import('views/Pricing')));
// dashboard routing
const Dashboard = Loadable(lazy(() => import('views/Dashboard')));
@@ -76,6 +77,10 @@ const MainRoutes = {
{
path: 'telegram',
element: