♻️ refactor: Refactor price module (#123) (#109) (#128)

This commit is contained in:
Buer 2024-03-28 16:53:34 +08:00 committed by GitHub
parent 646cb74154
commit a58e538c26
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
32 changed files with 2361 additions and 663 deletions

48
cli/export.go Normal file
View File

@ -0,0 +1,48 @@
package cli
import (
"encoding/json"
"one-api/common"
"one-api/relay/util"
"os"
"sort"
)
func ExportPrices() {
prices := util.GetPricesList("default")
if len(prices) == 0 {
common.SysError("No prices found")
return
}
// Sort prices by ChannelType
sort.Slice(prices, func(i, j int) bool {
if prices[i].ChannelType == prices[j].ChannelType {
return prices[i].Model < prices[j].Model
}
return prices[i].ChannelType < prices[j].ChannelType
})
// 导出到当前目录下的 prices.json 文件
file, err := os.Create("prices.json")
if err != nil {
common.SysError("Failed to create file: " + err.Error())
return
}
defer file.Close()
jsonData, err := json.MarshalIndent(prices, "", " ")
if err != nil {
common.SysError("Failed to encode prices: " + err.Error())
return
}
_, err = file.Write(jsonData)
if err != nil {
common.SysError("Failed to write to file: " + err.Error())
return
}
common.SysLog("Prices exported to prices.json")
}

View File

@ -1,4 +1,4 @@
package config
package cli
import (
"flag"
@ -14,10 +14,11 @@ var (
printVersion = flag.Bool("version", false, "print version and exit")
printHelp = flag.Bool("help", false, "print help and exit")
logDir = flag.String("log-dir", "", "specify the log directory")
config = flag.String("config", "config.yaml", "specify the config.yaml path")
Config = flag.String("config", "config.yaml", "specify the config.yaml path")
export = flag.Bool("export", false, "Exports prices to a JSON file.")
)
func flagConfig() {
func FlagConfig() {
flag.Parse()
if *printVersion {
@ -38,6 +39,11 @@ func flagConfig() {
viper.Set("log_dir", *logDir)
}
if *export {
ExportPrices()
os.Exit(0)
}
}
func help() {

View File

@ -4,13 +4,14 @@ import (
"strings"
"time"
"one-api/cli"
"one-api/common"
"github.com/spf13/viper"
)
func InitConf() {
flagConfig()
cli.FlagConfig()
defaultConfig()
setConfigFile()
setEnv()
@ -25,11 +26,11 @@ func InitConf() {
}
func setConfigFile() {
if !common.IsFileExist(*config) {
if !common.IsFileExist(*cli.Config) {
return
}
viper.SetConfigFile(*config)
viper.SetConfigFile(*cli.Config)
if err := viper.ReadInConfig(); err != nil {
panic(err)
}
@ -51,4 +52,5 @@ func defaultConfig() {
viper.SetDefault("global.api_rate_limit", 180)
viper.SetDefault("global.web_rate_limit", 100)
viper.SetDefault("connect_timeout", 5)
viper.SetDefault("auto_price_updates", true)
}

View File

@ -1,217 +1,5 @@
package common
import (
"encoding/json"
"strings"
"time"
)
type ModelType struct {
Ratio []float64
Type int
}
var ModelTypes map[string]ModelType
// ModelRatio
// https://platform.openai.com/docs/models/model-endpoint-compatibility
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Blfmc9dlf
// https://openai.com/pricing
// TODO: when a new api is enabled, check the pricing here
// 1 === $0.002 / 1K tokens
// 1 === ¥0.014 / 1k tokens
var ModelRatio map[string][]float64
func init() {
ModelTypes = map[string]ModelType{
// $0.03 / 1K tokens $0.06 / 1K tokens
"gpt-4": {[]float64{15, 30}, ChannelTypeOpenAI},
"gpt-4-0314": {[]float64{15, 30}, ChannelTypeOpenAI},
"gpt-4-0613": {[]float64{15, 30}, ChannelTypeOpenAI},
// $0.06 / 1K tokens $0.12 / 1K tokens
"gpt-4-32k": {[]float64{30, 60}, ChannelTypeOpenAI},
"gpt-4-32k-0314": {[]float64{30, 60}, ChannelTypeOpenAI},
"gpt-4-32k-0613": {[]float64{30, 60}, ChannelTypeOpenAI},
// $0.01 / 1K tokens $0.03 / 1K tokens
"gpt-4-preview": {[]float64{5, 15}, ChannelTypeOpenAI},
"gpt-4-1106-preview": {[]float64{5, 15}, ChannelTypeOpenAI},
"gpt-4-0125-preview": {[]float64{5, 15}, ChannelTypeOpenAI},
"gpt-4-turbo-preview": {[]float64{5, 15}, ChannelTypeOpenAI},
"gpt-4-vision-preview": {[]float64{5, 15}, ChannelTypeOpenAI},
// $0.0005 / 1K tokens $0.0015 / 1K tokens
"gpt-3.5-turbo": {[]float64{0.25, 0.75}, ChannelTypeOpenAI},
"gpt-3.5-turbo-0125": {[]float64{0.25, 0.75}, ChannelTypeOpenAI},
// $0.0015 / 1K tokens $0.002 / 1K tokens
"gpt-3.5-turbo-0301": {[]float64{0.75, 1}, ChannelTypeOpenAI},
"gpt-3.5-turbo-0613": {[]float64{0.75, 1}, ChannelTypeOpenAI},
"gpt-3.5-turbo-instruct": {[]float64{0.75, 1}, ChannelTypeOpenAI},
// $0.003 / 1K tokens $0.004 / 1K tokens
"gpt-3.5-turbo-16k": {[]float64{1.5, 2}, ChannelTypeOpenAI},
"gpt-3.5-turbo-16k-0613": {[]float64{1.5, 2}, ChannelTypeOpenAI},
// $0.001 / 1K tokens $0.002 / 1K tokens
"gpt-3.5-turbo-1106": {[]float64{0.5, 1}, ChannelTypeOpenAI},
// $0.0020 / 1K tokens
"davinci-002": {[]float64{1, 1}, ChannelTypeOpenAI},
// $0.0004 / 1K tokens
"babbage-002": {[]float64{0.2, 0.2}, ChannelTypeOpenAI},
"text-ada-001": {[]float64{0.2, 0.2}, ChannelTypeOpenAI},
"text-babbage-001": {[]float64{0.25, 0.25}, ChannelTypeOpenAI},
"text-curie-001": {[]float64{1, 1}, ChannelTypeOpenAI},
"text-davinci-002": {[]float64{10, 10}, ChannelTypeOpenAI},
"text-davinci-003": {[]float64{10, 10}, ChannelTypeOpenAI},
"text-davinci-edit-001": {[]float64{10, 10}, ChannelTypeOpenAI},
"code-davinci-edit-001": {[]float64{10, 10}, ChannelTypeOpenAI},
// $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
"whisper-1": {[]float64{15, 15}, ChannelTypeOpenAI},
// $0.015 / 1K characters
"tts-1": {[]float64{7.5, 7.5}, ChannelTypeOpenAI},
"tts-1-1106": {[]float64{7.5, 7.5}, ChannelTypeOpenAI},
// $0.030 / 1K characters
"tts-1-hd": {[]float64{15, 15}, ChannelTypeOpenAI},
"tts-1-hd-1106": {[]float64{15, 15}, ChannelTypeOpenAI},
"davinci": {[]float64{10, 10}, ChannelTypeOpenAI},
"curie": {[]float64{10, 10}, ChannelTypeOpenAI},
"babbage": {[]float64{10, 10}, ChannelTypeOpenAI},
"ada": {[]float64{10, 10}, ChannelTypeOpenAI},
"text-embedding-ada-002": {[]float64{0.05, 0.05}, ChannelTypeOpenAI},
// $0.00002 / 1K tokens
"text-embedding-3-small": {[]float64{0.01, 0.01}, ChannelTypeOpenAI},
// $0.00013 / 1K tokens
"text-embedding-3-large": {[]float64{0.065, 0.065}, ChannelTypeOpenAI},
"text-search-ada-doc-001": {[]float64{10, 10}, ChannelTypeOpenAI},
"text-moderation-stable": {[]float64{0.1, 0.1}, ChannelTypeOpenAI},
"text-moderation-latest": {[]float64{0.1, 0.1}, ChannelTypeOpenAI},
// $0.016 - $0.020 / image
"dall-e-2": {[]float64{8, 8}, ChannelTypeOpenAI},
// $0.040 - $0.120 / image
"dall-e-3": {[]float64{20, 20}, ChannelTypeOpenAI},
// $0.80/million tokens $2.40/million tokens
"claude-instant-1.2": {[]float64{0.4, 1.2}, ChannelTypeAnthropic},
// $8.00/million tokens $24.00/million tokens
"claude-2.0": {[]float64{4, 12}, ChannelTypeAnthropic},
"claude-2.1": {[]float64{4, 12}, ChannelTypeAnthropic},
// $15 / M $75 / M
"claude-3-opus-20240229": {[]float64{7.5, 22.5}, ChannelTypeAnthropic},
// $3 / M $15 / M
"claude-3-sonnet-20240229": {[]float64{1.3, 3.9}, ChannelTypeAnthropic},
// $0.25 / M $1.25 / M 0.00025$ / 1k tokens 0.00125$ / 1k tokens
"claude-3-haiku-20240307": {[]float64{0.125, 0.625}, ChannelTypeAnthropic},
// ¥0.004 / 1k tokens ¥0.008 / 1k tokens
"ERNIE-Speed": {[]float64{0.2857, 0.5714}, ChannelTypeBaidu},
// ¥0.012 / 1k tokens ¥0.012 / 1k tokens
"ERNIE-Bot": {[]float64{0.8572, 0.8572}, ChannelTypeBaidu},
"ERNIE-3.5-8K": {[]float64{0.8572, 0.8572}, ChannelTypeBaidu},
// 0.024元/千tokens 0.048元/千tokens
"ERNIE-Bot-8k": {[]float64{1.7143, 3.4286}, ChannelTypeBaidu},
// ¥0.008 / 1k tokens ¥0.008 / 1k tokens
"ERNIE-Bot-turbo": {[]float64{0.5715, 0.5715}, ChannelTypeBaidu},
// ¥0.12 / 1k tokens ¥0.12 / 1k tokens
"ERNIE-Bot-4": {[]float64{8.572, 8.572}, ChannelTypeBaidu},
"ERNIE-4.0": {[]float64{8.572, 8.572}, ChannelTypeBaidu},
// ¥0.002 / 1k tokens
"Embedding-V1": {[]float64{0.1429, 0.1429}, ChannelTypeBaidu},
// ¥0.004 / 1k tokens
"BLOOMZ-7B": {[]float64{0.2857, 0.2857}, ChannelTypeBaidu},
"PaLM-2": {[]float64{1, 1}, ChannelTypePaLM},
"gemini-pro": {[]float64{1, 1}, ChannelTypeGemini},
"gemini-pro-vision": {[]float64{1, 1}, ChannelTypeGemini},
"gemini-1.0-pro": {[]float64{1, 1}, ChannelTypeGemini},
"gemini-1.5-pro": {[]float64{1, 1}, ChannelTypeGemini},
// ¥0.005 / 1k tokens
"glm-3-turbo": {[]float64{0.3572, 0.3572}, ChannelTypeZhipu},
// ¥0.1 / 1k tokens
"glm-4": {[]float64{7.143, 7.143}, ChannelTypeZhipu},
"glm-4v": {[]float64{7.143, 7.143}, ChannelTypeZhipu},
// ¥0.0005 / 1k tokens
"embedding-2": {[]float64{0.0357, 0.0357}, ChannelTypeZhipu},
// ¥0.25 / 1张图片
"cogview-3": {[]float64{17.8571, 17.8571}, ChannelTypeZhipu},
// ¥0.008 / 1k tokens
"qwen-turbo": {[]float64{0.5715, 0.5715}, ChannelTypeAli},
// ¥0.02 / 1k tokens
"qwen-plus": {[]float64{1.4286, 1.4286}, ChannelTypeAli},
"qwen-vl-max": {[]float64{1.4286, 1.4286}, ChannelTypeAli},
// 0.12元/1,000tokens
"qwen-max": {[]float64{8.5714, 8.5714}, ChannelTypeAli},
"qwen-max-longcontext": {[]float64{8.5714, 8.5714}, ChannelTypeAli},
// 0.008元/1,000tokens
"qwen-vl-plus": {[]float64{0.5715, 0.5715}, ChannelTypeAli},
// ¥0.0007 / 1k tokens
"text-embedding-v1": {[]float64{0.05, 0.05}, ChannelTypeAli},
// ¥0.018 / 1k tokens
"SparkDesk": {[]float64{1.2858, 1.2858}, ChannelTypeXunfei},
"SparkDesk-v1.1": {[]float64{1.2858, 1.2858}, ChannelTypeXunfei},
"SparkDesk-v2.1": {[]float64{1.2858, 1.2858}, ChannelTypeXunfei},
"SparkDesk-v3.1": {[]float64{1.2858, 1.2858}, ChannelTypeXunfei},
"SparkDesk-v3.5": {[]float64{1.2858, 1.2858}, ChannelTypeXunfei},
// ¥0.012 / 1k tokens
"360GPT_S2_V9": {[]float64{0.8572, 0.8572}, ChannelType360},
// ¥0.001 / 1k tokens
"embedding-bert-512-v1": {[]float64{0.0715, 0.0715}, ChannelType360},
"embedding_s1_v1": {[]float64{0.0715, 0.0715}, ChannelType360},
"semantic_similarity_s1_v1": {[]float64{0.0715, 0.0715}, ChannelType360},
// ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
"hunyuan": {[]float64{7.143, 7.143}, ChannelTypeTencent},
// https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
// ¥0.01 / 1k tokens
"ChatStd": {[]float64{0.7143, 0.7143}, ChannelTypeTencent},
//¥0.1 / 1k tokens
"ChatPro": {[]float64{7.143, 7.143}, ChannelTypeTencent},
"Baichuan2-Turbo": {[]float64{0.5715, 0.5715}, ChannelTypeBaichuan}, // ¥0.008 / 1k tokens
"Baichuan2-Turbo-192k": {[]float64{1.143, 1.143}, ChannelTypeBaichuan}, // ¥0.016 / 1k tokens
"Baichuan2-53B": {[]float64{1.4286, 1.4286}, ChannelTypeBaichuan}, // ¥0.02 / 1k tokens
"Baichuan-Text-Embedding": {[]float64{0.0357, 0.0357}, ChannelTypeBaichuan}, // ¥0.0005 / 1k tokens
"abab5.5s-chat": {[]float64{0.3572, 0.3572}, ChannelTypeMiniMax}, // ¥0.005 / 1k tokens
"abab5.5-chat": {[]float64{1.0714, 1.0714}, ChannelTypeMiniMax}, // ¥0.015 / 1k tokens
"abab6-chat": {[]float64{14.2857, 14.2857}, ChannelTypeMiniMax}, // ¥0.2 / 1k tokens
"embo-01": {[]float64{0.0357, 0.0357}, ChannelTypeMiniMax}, // ¥0.0005 / 1k tokens
"deepseek-coder": {[]float64{0.75, 0.75}, ChannelTypeDeepseek}, // 暂定 $0.0015 / 1K tokens
"deepseek-chat": {[]float64{0.75, 0.75}, ChannelTypeDeepseek}, // 暂定 $0.0015 / 1K tokens
"moonshot-v1-8k": {[]float64{0.8572, 0.8572}, ChannelTypeMoonshot}, // ¥0.012 / 1K tokens
"moonshot-v1-32k": {[]float64{1.7143, 1.7143}, ChannelTypeMoonshot}, // ¥0.024 / 1K tokens
"moonshot-v1-128k": {[]float64{4.2857, 4.2857}, ChannelTypeMoonshot}, // ¥0.06 / 1K tokens
"open-mistral-7b": {[]float64{0.125, 0.125}, ChannelTypeMistral}, // 0.25$ / 1M tokens 0.25$ / 1M tokens 0.00025$ / 1k tokens
"open-mixtral-8x7b": {[]float64{0.35, 0.35}, ChannelTypeMistral}, // 0.7$ / 1M tokens 0.7$ / 1M tokens 0.0007$ / 1k tokens
"mistral-small-latest": {[]float64{1, 3}, ChannelTypeMistral}, // 2$ / 1M tokens 6$ / 1M tokens 0.002$ / 1k tokens
"mistral-medium-latest": {[]float64{1.35, 4.05}, ChannelTypeMistral}, // 2.7$ / 1M tokens 8.1$ / 1M tokens 0.0027$ / 1k tokens
"mistral-large-latest": {[]float64{4, 12}, ChannelTypeMistral}, // 8$ / 1M tokens 24$ / 1M tokens 0.008$ / 1k tokens
"mistral-embed": {[]float64{0.05, 0.05}, ChannelTypeMistral}, // 0.1$ / 1M tokens 0.1$ / 1M tokens 0.0001$ / 1k tokens
// $0.70/$0.80 /1M Tokens 0.0007$ / 1k tokens
"llama2-70b-4096": {[]float64{0.35, 0.4}, ChannelTypeGroq},
// $0.10/$0.10 /1M Tokens 0.0001$ / 1k tokens
"llama2-7b-2048": {[]float64{0.05, 0.05}, ChannelTypeGroq},
"gemma-7b-it": {[]float64{0.05, 0.05}, ChannelTypeGroq},
// $0.27/$0.27 /1M Tokens 0.00027$ / 1k tokens
"mixtral-8x7b-32768": {[]float64{0.135, 0.135}, ChannelTypeGroq},
// 2.5 元 / 1M tokens 0.0025 / 1k tokens
"yi-34b-chat-0205": {[]float64{0.1786, 0.1786}, ChannelTypeLingyi},
// 12 元 / 1M tokens 0.012 / 1k tokens
"yi-34b-chat-200k": {[]float64{0.8571, 0.8571}, ChannelTypeLingyi},
// 6 元 / 1M tokens 0.006 / 1k tokens
"yi-vl-plus": {[]float64{0.4286, 0.4286}, ChannelTypeLingyi},
}
ModelRatio = make(map[string][]float64)
for name, modelType := range ModelTypes {
ModelRatio[name] = modelType.Ratio
}
}
var DalleSizeRatios = map[string]map[string]float64{
"dall-e-2": {
"256x256": 1,
@ -234,104 +22,3 @@ var DalleImagePromptLengthLimitations = map[string]int{
"dall-e-2": 1000,
"dall-e-3": 4000,
}
func ModelRatio2JSONString() string {
jsonBytes, err := json.Marshal(ModelRatio)
if err != nil {
SysError("error marshalling model ratio: " + err.Error())
}
return string(jsonBytes)
}
func UpdateModelRatioByJSONString(jsonStr string) error {
ModelRatio = make(map[string][]float64)
return json.Unmarshal([]byte(jsonStr), &ModelRatio)
}
func MergeModelRatioByJSONString(jsonStr string) (newJsonStr string, err error) {
isNew := false
inputModelRatio := make(map[string][]float64)
err = json.Unmarshal([]byte(jsonStr), &inputModelRatio)
if err != nil {
inputModelRatioOld := make(map[string]float64)
err = json.Unmarshal([]byte(jsonStr), &inputModelRatioOld)
if err != nil {
return
}
inputModelRatio = UpdateModeRatioFormat(inputModelRatioOld)
isNew = true
}
// 与现有的ModelRatio进行比较如果有新增的模型需要添加
for key, value := range ModelRatio {
if _, ok := inputModelRatio[key]; !ok {
isNew = true
inputModelRatio[key] = value
}
}
if !isNew {
return
}
var jsonBytes []byte
jsonBytes, err = json.Marshal(inputModelRatio)
if err != nil {
SysError("error marshalling model ratio: " + err.Error())
}
newJsonStr = string(jsonBytes)
return
}
func UpdateModeRatioFormat(modelRatioOld map[string]float64) map[string][]float64 {
modelRatioNew := make(map[string][]float64)
for key, value := range modelRatioOld {
completionRatio := GetCompletionRatio(key) * value
modelRatioNew[key] = []float64{value, completionRatio}
}
return modelRatioNew
}
func GetModelRatio(name string) []float64 {
if strings.HasPrefix(name, "qwen-") && strings.HasSuffix(name, "-internet") {
name = strings.TrimSuffix(name, "-internet")
}
ratio, ok := ModelRatio[name]
if !ok {
SysError("model ratio not found: " + name)
return []float64{30, 30}
}
return ratio
}
func GetCompletionRatio(name string) float64 {
if strings.HasPrefix(name, "gpt-3.5") {
if strings.HasSuffix(name, "1106") {
return 2
}
if name == "gpt-3.5-turbo" || name == "gpt-3.5-turbo-16k" {
// TODO: clear this after 2023-12-11
now := time.Now()
// https://platform.openai.com/docs/models/continuous-model-upgrades
// if after 2023-12-11, use 2
if now.After(time.Date(2023, 12, 11, 0, 0, 0, 0, time.UTC)) {
return 2
}
}
return 1.333333
}
if strings.HasPrefix(name, "gpt-4") {
if strings.HasSuffix(name, "preview") {
return 3
}
return 2
}
if strings.HasPrefix(name, "claude-instant-1.2") {
return 3.38
}
if strings.HasPrefix(name, "claude-2") {
return 2.965517
}
return 1
}

View File

@ -13,46 +13,46 @@ import (
)
var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
var defaultTokenEncoder *tiktoken.Tiktoken
var gpt35TokenEncoder *tiktoken.Tiktoken
var gpt4TokenEncoder *tiktoken.Tiktoken
func InitTokenEncoders() {
SysLog("initializing token encoders")
gpt35TokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo")
var err error
gpt35TokenEncoder, err = tiktoken.EncodingForModel("gpt-3.5-turbo")
if err != nil {
FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error()))
}
defaultTokenEncoder = gpt35TokenEncoder
gpt4TokenEncoder, err := tiktoken.EncodingForModel("gpt-4")
gpt4TokenEncoder, err = tiktoken.EncodingForModel("gpt-4")
if err != nil {
FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error()))
}
for model := range ModelRatio {
if strings.HasPrefix(model, "gpt-3.5") {
tokenEncoderMap[model] = gpt35TokenEncoder
} else if strings.HasPrefix(model, "gpt-4") {
tokenEncoderMap[model] = gpt4TokenEncoder
} else {
tokenEncoderMap[model] = nil
}
}
SysLog("token encoders initialized")
}
func getTokenEncoder(model string) *tiktoken.Tiktoken {
tokenEncoder, ok := tokenEncoderMap[model]
if ok && tokenEncoder != nil {
if ok {
return tokenEncoder
}
if ok {
tokenEncoder, err := tiktoken.EncodingForModel(model)
if strings.HasPrefix(model, "gpt-3.5") {
tokenEncoder = gpt35TokenEncoder
} else if strings.HasPrefix(model, "gpt-4") {
tokenEncoder = gpt4TokenEncoder
} else {
var err error
tokenEncoder, err = tiktoken.EncodingForModel(model)
if err != nil {
SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error()))
tokenEncoder = defaultTokenEncoder
tokenEncoder = gpt35TokenEncoder
}
tokenEncoderMap[model] = tokenEncoder
return tokenEncoder
}
return defaultTokenEncoder
tokenEncoderMap[model] = tokenEncoder
return tokenEncoder
}
func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {

View File

@ -212,3 +212,31 @@ func IsFileExist(path string) bool {
_, err := os.Stat(path)
return err == nil || os.IsExist(err)
}
func Contains[T comparable](value T, slice []T) bool {
for _, item := range slice {
if item == value {
return true
}
}
return false
}
func Filter[T any](arr []T, f func(T) bool) []T {
var res []T
for _, v := range arr {
if f(v) {
res = append(res, v)
}
}
return res
}
func GetModelsWithMatch(modelList *[]string, modelName string) string {
for _, model := range *modelList {
if strings.HasPrefix(modelName, strings.TrimRight(model, "*")) {
return model
}
}
return ""
}

View File

@ -18,6 +18,7 @@ frontend_base_url: "" # 设置之后将重定向页面请求到指定的地址
polling_interval: 0 # 批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。
batch_update_interval: 5 # 批量更新聚合的时间间隔,单位为秒,默认为 5。
batch_update_enabled: false # 启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 true 和 false未设置则默认为 false
auto_price_updates: true # 启用自动更新价格,可选值为 true 和 false默认为 true
# 全局设置
global:

192
controller/pricing.go Normal file
View File

@ -0,0 +1,192 @@
package controller
import (
"errors"
"net/http"
"one-api/common"
"one-api/model"
"one-api/relay/util"
"github.com/gin-gonic/gin"
)
func GetPricesList(c *gin.Context) {
pricesType := c.DefaultQuery("type", "db")
prices := util.GetPricesList(pricesType)
if len(prices) == 0 {
common.APIRespondWithError(c, http.StatusOK, errors.New("pricing data not found"))
return
}
if pricesType == "old" {
c.JSON(http.StatusOK, prices)
} else {
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": prices,
})
}
}
func GetAllModelList(c *gin.Context) {
prices := util.PricingInstance.GetAllPrices()
channelModel := model.ChannelGroup.Rule
modelsMap := make(map[string]bool)
for modelName := range prices {
modelsMap[modelName] = true
}
for _, modelMap := range channelModel {
for modelName := range modelMap {
if _, ok := prices[modelName]; !ok {
modelsMap[modelName] = true
}
}
}
var models []string
for modelName := range modelsMap {
models = append(models, modelName)
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": models,
})
}
func AddPrice(c *gin.Context) {
var price model.Price
if err := c.ShouldBindJSON(&price); err != nil {
common.APIRespondWithError(c, http.StatusOK, err)
return
}
if err := util.PricingInstance.AddPrice(&price); err != nil {
common.APIRespondWithError(c, http.StatusOK, err)
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
})
}
func UpdatePrice(c *gin.Context) {
modelName := c.Param("model")
if modelName == "" {
common.APIRespondWithError(c, http.StatusOK, errors.New("model name is required"))
return
}
var price model.Price
if err := c.ShouldBindJSON(&price); err != nil {
common.APIRespondWithError(c, http.StatusOK, err)
return
}
if err := util.PricingInstance.UpdatePrice(modelName, &price); err != nil {
common.APIRespondWithError(c, http.StatusOK, err)
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
})
}
func DeletePrice(c *gin.Context) {
modelName := c.Param("model")
if modelName == "" {
common.APIRespondWithError(c, http.StatusOK, errors.New("model name is required"))
return
}
if err := util.PricingInstance.DeletePrice(modelName); err != nil {
common.APIRespondWithError(c, http.StatusOK, err)
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
})
}
type PriceBatchRequest struct {
OriginalModels []string `json:"original_models"`
util.BatchPrices
}
func BatchSetPrices(c *gin.Context) {
pricesBatch := &PriceBatchRequest{}
if err := c.ShouldBindJSON(pricesBatch); err != nil {
common.APIRespondWithError(c, http.StatusOK, err)
return
}
if err := util.PricingInstance.BatchSetPrices(&pricesBatch.BatchPrices, pricesBatch.OriginalModels); err != nil {
common.APIRespondWithError(c, http.StatusOK, err)
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
})
}
type PriceBatchDeleteRequest struct {
Models []string `json:"models" binding:"required"`
}
func BatchDeletePrices(c *gin.Context) {
pricesBatch := &PriceBatchDeleteRequest{}
if err := c.ShouldBindJSON(pricesBatch); err != nil {
common.APIRespondWithError(c, http.StatusOK, err)
return
}
if err := util.PricingInstance.BatchDeletePrices(pricesBatch.Models); err != nil {
common.APIRespondWithError(c, http.StatusOK, err)
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
})
}
func SyncPricing(c *gin.Context) {
overwrite := c.DefaultQuery("overwrite", "false")
prices := make([]*model.Price, 0)
if err := c.ShouldBindJSON(&prices); err != nil {
common.APIRespondWithError(c, http.StatusOK, err)
return
}
if len(prices) == 0 {
common.APIRespondWithError(c, http.StatusOK, errors.New("prices is required"))
return
}
err := util.PricingInstance.SyncPricing(prices, overwrite == "true")
if err != nil {
common.APIRespondWithError(c, http.StatusOK, err)
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
})
}

View File

@ -10,6 +10,7 @@ import (
"one-api/controller"
"one-api/middleware"
"one-api/model"
"one-api/relay/util"
"one-api/router"
"github.com/gin-contrib/sessions"
@ -35,6 +36,7 @@ func main() {
common.InitRedisClient()
// Initialize options
model.InitOptionMap()
util.NewPricing()
initMemoryCache()
initSync()

View File

@ -18,6 +18,7 @@ type ChannelsChooser struct {
sync.RWMutex
Channels map[int]*ChannelChoice
Rule map[string]map[string][][]int // group -> model -> priority -> channelIds
Match []string
}
func (cc *ChannelsChooser) Cooldowns(channelId int) bool {
@ -74,11 +75,15 @@ func (cc *ChannelsChooser) Next(group, modelName string) (*Channel, error) {
return nil, errors.New("group not found")
}
if _, ok := cc.Rule[group][modelName]; !ok {
return nil, errors.New("model not found")
channelsPriority, ok := cc.Rule[group][modelName]
if !ok {
matchModel := common.GetModelsWithMatch(&cc.Match, modelName)
channelsPriority, ok = cc.Rule[group][matchModel]
if !ok {
return nil, errors.New("model not found")
}
}
channelsPriority := cc.Rule[group][modelName]
if len(channelsPriority) == 0 {
return nil, errors.New("channel not found")
}
@ -123,6 +128,7 @@ func (cc *ChannelsChooser) Load() {
newGroup := make(map[string]map[string][][]int)
newChannels := make(map[int]*ChannelChoice)
newMatch := make(map[string]bool)
for _, channel := range channels {
if *channel.Weight == 0 {
@ -143,6 +149,13 @@ func (cc *ChannelsChooser) Load() {
newGroup[ability.Group][ability.Model] = make([][]int, 0)
}
// 如果是以 *结尾的 model名称
if strings.HasSuffix(ability.Model, "*") {
if _, ok := newMatch[ability.Model]; !ok {
newMatch[ability.Model] = true
}
}
var priorityIds []int
// 逗号分割 ability.ChannelId
channelIds := strings.Split(ability.ChannelIds, ",")
@ -153,9 +166,15 @@ func (cc *ChannelsChooser) Load() {
newGroup[ability.Group][ability.Model] = append(newGroup[ability.Group][ability.Model], priorityIds)
}
newMatchList := make([]string, 0, len(newMatch))
for match := range newMatch {
newMatchList = append(newMatchList, match)
}
cc.Lock()
cc.Rule = newGroup
cc.Channels = newChannels
cc.Match = newMatchList
cc.Unlock()
common.SysLog("channels Load success")
}

View File

@ -135,6 +135,10 @@ func InitDB() (err error) {
if err != nil {
return err
}
err = db.AutoMigrate(&Price{})
if err != nil {
return err
}
common.SysLog("database migrated")
err = createRootAccountIfNeed()
return err

View File

@ -67,7 +67,6 @@ func InitOptionMap() {
common.OptionMap["QuotaForInvitee"] = strconv.Itoa(common.QuotaForInvitee)
common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold)
common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota)
common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
common.OptionMap["GroupRatio"] = common.GroupRatio2JSONString()
common.OptionMap["TopUpLink"] = common.TopUpLink
common.OptionMap["ChatLink"] = common.ChatLink
@ -76,28 +75,9 @@ func InitOptionMap() {
common.OptionMap["RetryCooldownSeconds"] = strconv.Itoa(common.RetryCooldownSeconds)
common.OptionMapRWMutex.Unlock()
initModelRatio()
loadOptionsFromDatabase()
}
func initModelRatio() {
// 查询数据库中的ModelRatio
option, err := GetOption("ModelRatio")
if err != nil || option.Value == "" {
return
}
newModelRatio, err := common.MergeModelRatioByJSONString(option.Value)
if err != nil || newModelRatio == "" {
return
}
// 更新数据库中的ModelRatio
common.SysLog("update ModelRatio")
UpdateOption("ModelRatio", newModelRatio)
}
func loadOptionsFromDatabase() {
options, _ := AllOption()
for _, option := range options {
@ -202,8 +182,6 @@ func updateOptionMap(key string, value string) (err error) {
switch key {
case "EmailDomainWhitelist":
common.EmailDomainWhitelist = strings.Split(value, ",")
case "ModelRatio":
err = common.UpdateModelRatioByJSONString(value)
case "GroupRatio":
err = common.UpdateGroupRatioByJSONString(value)
case "ChannelDisableThreshold":

300
model/price.go Normal file
View File

@ -0,0 +1,300 @@
package model
import (
"one-api/common"
"gorm.io/gorm"
)
const (
TokensPriceType = "tokens"
TimesPriceType = "times"
DefaultPrice = 30.0
DollarRate = 0.002
RMBRate = 0.014
)
type Price struct {
Model string `json:"model" gorm:"type:varchar(30);primaryKey" binding:"required"`
Type string `json:"type" gorm:"default:'tokens'" binding:"required"`
ChannelType int `json:"channel_type" gorm:"default:0" binding:"gte=0"`
Input float64 `json:"input" gorm:"default:0" binding:"gte=0"`
Output float64 `json:"output" gorm:"default:0" binding:"gte=0"`
}
func GetAllPrices() ([]*Price, error) {
var prices []*Price
if err := DB.Find(&prices).Error; err != nil {
return nil, err
}
return prices, nil
}
func (price *Price) Update(modelName string) error {
if err := DB.Model(price).Select("*").Where("model = ?", modelName).Updates(price).Error; err != nil {
return err
}
return nil
}
func (price *Price) Insert() error {
if err := DB.Create(price).Error; err != nil {
return err
}
return nil
}
func (price *Price) GetInput() float64 {
if price.Input <= 0 {
return 0
}
return price.Input
}
func (price *Price) GetOutput() float64 {
if price.Output <= 0 || price.Type == TimesPriceType {
return 0
}
return price.Output
}
func (price *Price) FetchInputCurrencyPrice(rate float64) float64 {
return price.GetInput() * rate
}
func (price *Price) FetchOutputCurrencyPrice(rate float64) float64 {
return price.GetOutput() * rate
}
func UpdatePrices(tx *gorm.DB, models []string, prices *Price) error {
err := tx.Model(Price{}).Where("model IN (?)", models).Select("*").Omit("model").Updates(
Price{
Type: prices.Type,
ChannelType: prices.ChannelType,
Input: prices.Input,
Output: prices.Output,
}).Error
return err
}
func DeletePrices(tx *gorm.DB, models []string) error {
err := tx.Where("model IN (?)", models).Delete(&Price{}).Error
return err
}
func InsertPrices(tx *gorm.DB, prices []*Price) error {
err := tx.CreateInBatches(prices, 100).Error
return err
}
func DeleteAllPrices(tx *gorm.DB) error {
err := tx.Where("1=1").Delete(&Price{}).Error
return err
}
func (price *Price) Delete() error {
err := DB.Delete(price).Error
if err != nil {
return err
}
return err
}
type ModelType struct {
Ratio []float64
Type int
}
// 1 === $0.002 / 1K tokens
// 1 === ¥0.014 / 1k tokens
func GetDefaultPrice() []*Price {
ModelTypes := map[string]ModelType{
// $0.03 / 1K tokens $0.06 / 1K tokens
"gpt-4": {[]float64{15, 30}, common.ChannelTypeOpenAI},
"gpt-4-0314": {[]float64{15, 30}, common.ChannelTypeOpenAI},
"gpt-4-0613": {[]float64{15, 30}, common.ChannelTypeOpenAI},
// $0.06 / 1K tokens $0.12 / 1K tokens
"gpt-4-32k": {[]float64{30, 60}, common.ChannelTypeOpenAI},
"gpt-4-32k-0314": {[]float64{30, 60}, common.ChannelTypeOpenAI},
"gpt-4-32k-0613": {[]float64{30, 60}, common.ChannelTypeOpenAI},
// $0.01 / 1K tokens $0.03 / 1K tokens
"gpt-4-preview": {[]float64{5, 15}, common.ChannelTypeOpenAI},
"gpt-4-1106-preview": {[]float64{5, 15}, common.ChannelTypeOpenAI},
"gpt-4-0125-preview": {[]float64{5, 15}, common.ChannelTypeOpenAI},
"gpt-4-turbo-preview": {[]float64{5, 15}, common.ChannelTypeOpenAI},
"gpt-4-vision-preview": {[]float64{5, 15}, common.ChannelTypeOpenAI},
// $0.0005 / 1K tokens $0.0015 / 1K tokens
"gpt-3.5-turbo": {[]float64{0.25, 0.75}, common.ChannelTypeOpenAI},
"gpt-3.5-turbo-0125": {[]float64{0.25, 0.75}, common.ChannelTypeOpenAI},
// $0.0015 / 1K tokens $0.002 / 1K tokens
"gpt-3.5-turbo-0301": {[]float64{0.75, 1}, common.ChannelTypeOpenAI},
"gpt-3.5-turbo-0613": {[]float64{0.75, 1}, common.ChannelTypeOpenAI},
"gpt-3.5-turbo-instruct": {[]float64{0.75, 1}, common.ChannelTypeOpenAI},
// $0.003 / 1K tokens $0.004 / 1K tokens
"gpt-3.5-turbo-16k": {[]float64{1.5, 2}, common.ChannelTypeOpenAI},
"gpt-3.5-turbo-16k-0613": {[]float64{1.5, 2}, common.ChannelTypeOpenAI},
// $0.001 / 1K tokens $0.002 / 1K tokens
"gpt-3.5-turbo-1106": {[]float64{0.5, 1}, common.ChannelTypeOpenAI},
// $0.0020 / 1K tokens
"davinci-002": {[]float64{1, 1}, common.ChannelTypeOpenAI},
// $0.0004 / 1K tokens
"babbage-002": {[]float64{0.2, 0.2}, common.ChannelTypeOpenAI},
// $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
"whisper-1": {[]float64{15, 15}, common.ChannelTypeOpenAI},
// $0.015 / 1K characters
"tts-1": {[]float64{7.5, 7.5}, common.ChannelTypeOpenAI},
"tts-1-1106": {[]float64{7.5, 7.5}, common.ChannelTypeOpenAI},
// $0.030 / 1K characters
"tts-1-hd": {[]float64{15, 15}, common.ChannelTypeOpenAI},
"tts-1-hd-1106": {[]float64{15, 15}, common.ChannelTypeOpenAI},
"text-embedding-ada-002": {[]float64{0.05, 0.05}, common.ChannelTypeOpenAI},
// $0.00002 / 1K tokens
"text-embedding-3-small": {[]float64{0.01, 0.01}, common.ChannelTypeOpenAI},
// $0.00013 / 1K tokens
"text-embedding-3-large": {[]float64{0.065, 0.065}, common.ChannelTypeOpenAI},
"text-moderation-stable": {[]float64{0.1, 0.1}, common.ChannelTypeOpenAI},
"text-moderation-latest": {[]float64{0.1, 0.1}, common.ChannelTypeOpenAI},
// $0.016 - $0.020 / image
"dall-e-2": {[]float64{8, 8}, common.ChannelTypeOpenAI},
// $0.040 - $0.120 / image
"dall-e-3": {[]float64{20, 20}, common.ChannelTypeOpenAI},
// $0.80/million tokens $2.40/million tokens
"claude-instant-1.2": {[]float64{0.4, 1.2}, common.ChannelTypeAnthropic},
// $8.00/million tokens $24.00/million tokens
"claude-2.0": {[]float64{4, 12}, common.ChannelTypeAnthropic},
"claude-2.1": {[]float64{4, 12}, common.ChannelTypeAnthropic},
// $15 / M $75 / M
"claude-3-opus-20240229": {[]float64{7.5, 22.5}, common.ChannelTypeAnthropic},
// $3 / M $15 / M
"claude-3-sonnet-20240229": {[]float64{1.3, 3.9}, common.ChannelTypeAnthropic},
// $0.25 / M $1.25 / M 0.00025$ / 1k tokens 0.00125$ / 1k tokens
"claude-3-haiku-20240307": {[]float64{0.125, 0.625}, common.ChannelTypeAnthropic},
// ¥0.004 / 1k tokens ¥0.008 / 1k tokens
"ERNIE-Speed": {[]float64{0.2857, 0.5714}, common.ChannelTypeBaidu},
// ¥0.012 / 1k tokens ¥0.012 / 1k tokens
"ERNIE-Bot": {[]float64{0.8572, 0.8572}, common.ChannelTypeBaidu},
"ERNIE-3.5-8K": {[]float64{0.8572, 0.8572}, common.ChannelTypeBaidu},
// 0.024元/千tokens 0.048元/千tokens
"ERNIE-Bot-8k": {[]float64{1.7143, 3.4286}, common.ChannelTypeBaidu},
// ¥0.008 / 1k tokens ¥0.008 / 1k tokens
"ERNIE-Bot-turbo": {[]float64{0.5715, 0.5715}, common.ChannelTypeBaidu},
// ¥0.12 / 1k tokens ¥0.12 / 1k tokens
"ERNIE-Bot-4": {[]float64{8.572, 8.572}, common.ChannelTypeBaidu},
"ERNIE-4.0": {[]float64{8.572, 8.572}, common.ChannelTypeBaidu},
// ¥0.002 / 1k tokens
"Embedding-V1": {[]float64{0.1429, 0.1429}, common.ChannelTypeBaidu},
// ¥0.004 / 1k tokens
"BLOOMZ-7B": {[]float64{0.2857, 0.2857}, common.ChannelTypeBaidu},
"PaLM-2": {[]float64{1, 1}, common.ChannelTypePaLM},
"gemini-pro": {[]float64{1, 1}, common.ChannelTypeGemini},
"gemini-pro-vision": {[]float64{1, 1}, common.ChannelTypeGemini},
"gemini-1.0-pro": {[]float64{1, 1}, common.ChannelTypeGemini},
"gemini-1.5-pro": {[]float64{1, 1}, common.ChannelTypeGemini},
// ¥0.005 / 1k tokens
"glm-3-turbo": {[]float64{0.3572, 0.3572}, common.ChannelTypeZhipu},
// ¥0.1 / 1k tokens
"glm-4": {[]float64{7.143, 7.143}, common.ChannelTypeZhipu},
"glm-4v": {[]float64{7.143, 7.143}, common.ChannelTypeZhipu},
// ¥0.0005 / 1k tokens
"embedding-2": {[]float64{0.0357, 0.0357}, common.ChannelTypeZhipu},
// ¥0.25 / 1张图片
"cogview-3": {[]float64{17.8571, 17.8571}, common.ChannelTypeZhipu},
// ¥0.008 / 1k tokens
"qwen-turbo": {[]float64{0.5715, 0.5715}, common.ChannelTypeAli},
// ¥0.02 / 1k tokens
"qwen-plus": {[]float64{1.4286, 1.4286}, common.ChannelTypeAli},
"qwen-vl-max": {[]float64{1.4286, 1.4286}, common.ChannelTypeAli},
// 0.12元/1,000tokens
"qwen-max": {[]float64{8.5714, 8.5714}, common.ChannelTypeAli},
"qwen-max-longcontext": {[]float64{8.5714, 8.5714}, common.ChannelTypeAli},
// 0.008元/1,000tokens
"qwen-vl-plus": {[]float64{0.5715, 0.5715}, common.ChannelTypeAli},
// ¥0.0007 / 1k tokens
"text-embedding-v1": {[]float64{0.05, 0.05}, common.ChannelTypeAli},
// ¥0.018 / 1k tokens
"SparkDesk": {[]float64{1.2858, 1.2858}, common.ChannelTypeXunfei},
"SparkDesk-v1.1": {[]float64{1.2858, 1.2858}, common.ChannelTypeXunfei},
"SparkDesk-v2.1": {[]float64{1.2858, 1.2858}, common.ChannelTypeXunfei},
"SparkDesk-v3.1": {[]float64{1.2858, 1.2858}, common.ChannelTypeXunfei},
"SparkDesk-v3.5": {[]float64{1.2858, 1.2858}, common.ChannelTypeXunfei},
// ¥0.012 / 1k tokens
"360GPT_S2_V9": {[]float64{0.8572, 0.8572}, common.ChannelType360},
// ¥0.001 / 1k tokens
"embedding-bert-512-v1": {[]float64{0.0715, 0.0715}, common.ChannelType360},
"embedding_s1_v1": {[]float64{0.0715, 0.0715}, common.ChannelType360},
"semantic_similarity_s1_v1": {[]float64{0.0715, 0.0715}, common.ChannelType360},
// ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
"hunyuan": {[]float64{7.143, 7.143}, common.ChannelTypeTencent},
// https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
// ¥0.01 / 1k tokens
"ChatStd": {[]float64{0.7143, 0.7143}, common.ChannelTypeTencent},
//¥0.1 / 1k tokens
"ChatPro": {[]float64{7.143, 7.143}, common.ChannelTypeTencent},
"Baichuan2-Turbo": {[]float64{0.5715, 0.5715}, common.ChannelTypeBaichuan}, // ¥0.008 / 1k tokens
"Baichuan2-Turbo-192k": {[]float64{1.143, 1.143}, common.ChannelTypeBaichuan}, // ¥0.016 / 1k tokens
"Baichuan2-53B": {[]float64{1.4286, 1.4286}, common.ChannelTypeBaichuan}, // ¥0.02 / 1k tokens
"Baichuan-Text-Embedding": {[]float64{0.0357, 0.0357}, common.ChannelTypeBaichuan}, // ¥0.0005 / 1k tokens
"abab5.5s-chat": {[]float64{0.3572, 0.3572}, common.ChannelTypeMiniMax}, // ¥0.005 / 1k tokens
"abab5.5-chat": {[]float64{1.0714, 1.0714}, common.ChannelTypeMiniMax}, // ¥0.015 / 1k tokens
"abab6-chat": {[]float64{14.2857, 14.2857}, common.ChannelTypeMiniMax}, // ¥0.2 / 1k tokens
"embo-01": {[]float64{0.0357, 0.0357}, common.ChannelTypeMiniMax}, // ¥0.0005 / 1k tokens
"deepseek-coder": {[]float64{0.75, 0.75}, common.ChannelTypeDeepseek}, // 暂定 $0.0015 / 1K tokens
"deepseek-chat": {[]float64{0.75, 0.75}, common.ChannelTypeDeepseek}, // 暂定 $0.0015 / 1K tokens
"moonshot-v1-8k": {[]float64{0.8572, 0.8572}, common.ChannelTypeMoonshot}, // ¥0.012 / 1K tokens
"moonshot-v1-32k": {[]float64{1.7143, 1.7143}, common.ChannelTypeMoonshot}, // ¥0.024 / 1K tokens
"moonshot-v1-128k": {[]float64{4.2857, 4.2857}, common.ChannelTypeMoonshot}, // ¥0.06 / 1K tokens
"open-mistral-7b": {[]float64{0.125, 0.125}, common.ChannelTypeMistral}, // 0.25$ / 1M tokens 0.25$ / 1M tokens 0.00025$ / 1k tokens
"open-mixtral-8x7b": {[]float64{0.35, 0.35}, common.ChannelTypeMistral}, // 0.7$ / 1M tokens 0.7$ / 1M tokens 0.0007$ / 1k tokens
"mistral-small-latest": {[]float64{1, 3}, common.ChannelTypeMistral}, // 2$ / 1M tokens 6$ / 1M tokens 0.002$ / 1k tokens
"mistral-medium-latest": {[]float64{1.35, 4.05}, common.ChannelTypeMistral}, // 2.7$ / 1M tokens 8.1$ / 1M tokens 0.0027$ / 1k tokens
"mistral-large-latest": {[]float64{4, 12}, common.ChannelTypeMistral}, // 8$ / 1M tokens 24$ / 1M tokens 0.008$ / 1k tokens
"mistral-embed": {[]float64{0.05, 0.05}, common.ChannelTypeMistral}, // 0.1$ / 1M tokens 0.1$ / 1M tokens 0.0001$ / 1k tokens
// $0.70/$0.80 /1M Tokens 0.0007$ / 1k tokens
"llama2-70b-4096": {[]float64{0.35, 0.4}, common.ChannelTypeGroq},
// $0.10/$0.10 /1M Tokens 0.0001$ / 1k tokens
"llama2-7b-2048": {[]float64{0.05, 0.05}, common.ChannelTypeGroq},
"gemma-7b-it": {[]float64{0.05, 0.05}, common.ChannelTypeGroq},
// $0.27/$0.27 /1M Tokens 0.00027$ / 1k tokens
"mixtral-8x7b-32768": {[]float64{0.135, 0.135}, common.ChannelTypeGroq},
// 2.5 元 / 1M tokens 0.0025 / 1k tokens
"yi-34b-chat-0205": {[]float64{0.1786, 0.1786}, common.ChannelTypeLingyi},
// 12 元 / 1M tokens 0.012 / 1k tokens
"yi-34b-chat-200k": {[]float64{0.8571, 0.8571}, common.ChannelTypeLingyi},
// 6 元 / 1M tokens 0.006 / 1k tokens
"yi-vl-plus": {[]float64{0.4286, 0.4286}, common.ChannelTypeLingyi},
}
var prices []*Price
for model, modelType := range ModelTypes {
prices = append(prices, &Price{
Model: model,
Type: TokensPriceType,
ChannelType: modelType.Type,
Input: modelType.Ratio[0],
Output: modelType.Ratio[1],
})
}
return prices
}

View File

@ -1,98 +0,0 @@
{
"gpt-4": [15, 30],
"gpt-4-0314": [15, 30],
"gpt-4-0613": [15, 30],
"gpt-4-32k": [30, 60],
"gpt-4-32k-0314": [30, 60],
"gpt-4-32k-0613": [30, 60],
"gpt-4-preview": [5, 15],
"gpt-4-1106-preview": [5, 15],
"gpt-4-0125-preview": [5, 15],
"gpt-4-vision-preview": [5, 15],
"gpt-3.5-turbo": [0.25, 0.75],
"gpt-3.5-turbo-0125": [0.25, 0.75],
"gpt-3.5-turbo-0301": [0.75, 1],
"gpt-3.5-turbo-0613": [0.75, 1],
"gpt-3.5-turbo-instruct": [0.75, 1],
"gpt-3.5-turbo-16k": [1.5, 2],
"gpt-3.5-turbo-16k-0613": [1.5, 2],
"gpt-3.5-turbo-1106": [0.5, 1],
"davinci-002": [1, 1],
"babbage-002": [0.2, 0.2],
"text-ada-001": [0.2, 0.2],
"text-babbage-001": [0.25, 0.25],
"text-curie-001": [1, 1],
"text-davinci-002": [10, 10],
"text-davinci-003": [10, 10],
"text-davinci-edit-001": [10, 10],
"code-davinci-edit-001": [10, 10],
"whisper-1": [15, 15],
"tts-1": [7.5, 7.5],
"tts-1-1106": [7.5, 7.5],
"tts-1-hd": [15, 15],
"tts-1-hd-1106": [15, 15],
"davinci": [10, 10],
"curie": [10, 10],
"babbage": [10, 10],
"ada": [10, 10],
"text-embedding-ada-002": [0.05, 0.05],
"text-embedding-3-small": [0.01, 0.01],
"text-embedding-3-large": [0.065, 0.065],
"text-search-ada-doc-001": [10, 10],
"text-moderation-stable": [0.1, 0.1],
"text-moderation-latest": [0.1, 0.1],
"dall-e-2": [8, 8],
"dall-e-3": [20, 20],
"claude-instant-1.2": [0.4, 1.2],
"claude-2.0": [4, 12],
"claude-2.1": [4, 12],
"claude-3-opus-20240229": [7.5, 22.5],
"claude-3-sonnet-20240229": [1.3, 3.9],
"ERNIE-Bot": [0.8572, 0.8572],
"ERNIE-Bot-8k": [1.7143, 3.4286],
"ERNIE-Bot-turbo": [0.5715, 0.5715],
"ERNIE-Bot-4": [8.572, 8.572],
"Embedding-V1": [0.1429, 0.1429],
"PaLM-2": [1, 1],
"gemini-pro": [1, 1],
"gemini-pro-vision": [1, 1],
"chatglm_turbo": [0.3572, 0.3572],
"chatglm_std": [0.3572, 0.3572],
"glm-3-turbo": [0.3572, 0.3572],
"chatglm_pro": [0.7143, 0.7143],
"chatglm_lite": [0.1429, 0.1429],
"glm-4": [7.143, 7.143],
"glm-4v": [7.143, 7.143],
"embedding-2": [0.0357, 0.0357],
"cogview-3": [17.8571, 17.8571],
"qwen-turbo": [0.5715, 0.5715],
"qwen-plus": [1.4286, 1.4286],
"qwen-max": [1.4286, 1.4286],
"qwen-max-longcontext": [1.4286, 1.4286],
"qwen-vl": [0.5715, 0.5715],
"qwen-vl-plus": [0.5715, 0.5715],
"text-embedding-v1": [0.05, 0.05],
"SparkDesk": [1.2858, 1.2858],
"SparkDesk-v1.1": [1.2858, 1.2858],
"SparkDesk-v2.1": [1.2858, 1.2858],
"SparkDesk-v3.1": [1.2858, 1.2858],
"SparkDesk-v3.5": [1.2858, 1.2858],
"360GPT_S2_V9": [0.8572, 0.8572],
"embedding-bert-512-v1": [0.0715, 0.0715],
"embedding_s1_v1": [0.0715, 0.0715],
"semantic_similarity_s1_v1": [0.0715, 0.0715],
"hunyuan": [7.143, 7.143],
"Baichuan2-Turbo": [0.5715, 0.5715],
"Baichuan2-Turbo-192k": [1.143, 1.143],
"Baichuan2-53B": [1.4286, 1.4286],
"Baichuan-Text-Embedding": [0.0357, 0.0357],
"abab5.5s-chat": [0.3572, 0.3572],
"abab5.5-chat": [1.0714, 1.0714],
"abab6-chat": [14.2857, 14.2857],
"embo-01": [0.0357, 0.0357],
"deepseek-coder": [0.75, 0.75],
"deepseek-chat": [0.75, 0.75],
"moonshot-v1-8k": [0.8572, 0.8572],
"moonshot-v1-32k": [1.7143, 1.7143],
"moonshot-v1-128k": [4.2857, 4.2857]
}

View File

@ -5,6 +5,7 @@ import (
"net/http"
"one-api/common"
"one-api/model"
"one-api/relay/util"
"one-api/types"
"sort"
@ -13,8 +14,6 @@ import (
// https://platform.openai.com/docs/api-reference/models/list
var unknownOwnedBy = "未知"
type OpenAIModelPermission struct {
Id string `json:"id"`
Object string `json:"object"`
@ -30,6 +29,11 @@ type OpenAIModelPermission struct {
IsBlocking bool `json:"is_blocking"`
}
type ModelPrice struct {
Type string `json:"type"`
Input string `json:"input"`
Output string `json:"output"`
}
type OpenAIModels struct {
Id string `json:"id"`
Object string `json:"object"`
@ -38,30 +42,7 @@ type OpenAIModels struct {
Permission *[]OpenAIModelPermission `json:"permission"`
Root *string `json:"root"`
Parent *string `json:"parent"`
}
var modelOwnedBy map[int]string
func init() {
modelOwnedBy = map[int]string{
common.ChannelTypeOpenAI: "OpenAI",
common.ChannelTypeAnthropic: "Anthropic",
common.ChannelTypeBaidu: "Baidu",
common.ChannelTypePaLM: "Google PaLM",
common.ChannelTypeGemini: "Google Gemini",
common.ChannelTypeZhipu: "Zhipu",
common.ChannelTypeAli: "Ali",
common.ChannelTypeXunfei: "Xunfei",
common.ChannelType360: "360",
common.ChannelTypeTencent: "Tencent",
common.ChannelTypeBaichuan: "Baichuan",
common.ChannelTypeMiniMax: "MiniMax",
common.ChannelTypeDeepseek: "Deepseek",
common.ChannelTypeMoonshot: "Moonshot",
common.ChannelTypeMistral: "Mistral",
common.ChannelTypeGroq: "Groq",
common.ChannelTypeLingyi: "Lingyiwanwu",
}
Price *ModelPrice `json:"price"`
}
func ListModels(c *gin.Context) {
@ -83,17 +64,9 @@ func ListModels(c *gin.Context) {
}
sort.Strings(models)
groupOpenAIModels := make([]OpenAIModels, 0, len(models))
for _, modelId := range models {
groupOpenAIModels = append(groupOpenAIModels, OpenAIModels{
Id: modelId,
Object: "model",
Created: 1677649963,
OwnedBy: getModelOwnedBy(modelId),
Permission: nil,
Root: nil,
Parent: nil,
})
var groupOpenAIModels []*OpenAIModels
for _, modelName := range models {
groupOpenAIModels = append(groupOpenAIModels, getOpenAIModelWithName(modelName))
}
// 根据 OwnedBy 排序
@ -114,13 +87,14 @@ func ListModels(c *gin.Context) {
}
func ListModelsForAdmin(c *gin.Context) {
openAIModels := make([]OpenAIModels, 0, len(common.ModelRatio))
for modelId := range common.ModelRatio {
prices := util.PricingInstance.GetAllPrices()
var openAIModels []OpenAIModels
for modelId, price := range prices {
openAIModels = append(openAIModels, OpenAIModels{
Id: modelId,
Object: "model",
Created: 1677649963,
OwnedBy: getModelOwnedBy(modelId),
OwnedBy: getModelOwnedBy(price.ChannelType),
Permission: nil,
Root: nil,
Parent: nil,
@ -144,21 +118,13 @@ func ListModelsForAdmin(c *gin.Context) {
}
func RetrieveModel(c *gin.Context) {
modelId := c.Param("model")
ownedByName := getModelOwnedBy(modelId)
if *ownedByName != unknownOwnedBy {
c.JSON(200, OpenAIModels{
Id: modelId,
Object: "model",
Created: 1677649963,
OwnedBy: ownedByName,
Permission: nil,
Root: nil,
Parent: nil,
})
modelName := c.Param("model")
openaiModel := getOpenAIModelWithName(modelName)
if *openaiModel.OwnedBy != util.UnknownOwnedBy {
c.JSON(200, openaiModel)
} else {
openAIError := types.OpenAIError{
Message: fmt.Sprintf("The model '%s' does not exist", modelId),
Message: fmt.Sprintf("The model '%s' does not exist", modelName),
Type: "invalid_request_error",
Param: "model",
Code: "model_not_found",
@ -169,12 +135,32 @@ func RetrieveModel(c *gin.Context) {
}
}
func getModelOwnedBy(modelId string) (ownedBy *string) {
if modelType, ok := common.ModelTypes[modelId]; ok {
if ownedByName, ok := modelOwnedBy[modelType.Type]; ok {
return &ownedByName
}
func getModelOwnedBy(channelType int) (ownedBy *string) {
if ownedByName, ok := util.ModelOwnedBy[channelType]; ok {
return &ownedByName
}
return &unknownOwnedBy
return &util.UnknownOwnedBy
}
func getOpenAIModelWithName(modelName string) *OpenAIModels {
price := util.PricingInstance.GetPrice(modelName)
return &OpenAIModels{
Id: modelName,
Object: "model",
Created: 1677649963,
OwnedBy: getModelOwnedBy(price.ChannelType),
Permission: nil,
Root: nil,
Parent: nil,
}
}
func GetModelOwnedBy(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": util.ModelOwnedBy,
})
}

402
relay/util/pricing.go Normal file
View File

@ -0,0 +1,402 @@
package util
import (
"encoding/json"
"errors"
"one-api/common"
"one-api/model"
"sort"
"strings"
"sync"
"github.com/spf13/viper"
)
// PricingInstance is the Pricing instance
var PricingInstance *Pricing
// Pricing is a struct that contains the pricing data
type Pricing struct {
sync.RWMutex
Prices map[string]*model.Price `json:"models"`
Match []string `json:"-"`
}
type BatchPrices struct {
Models []string `json:"models" binding:"required"`
Price model.Price `json:"price" binding:"required"`
}
// NewPricing creates a new Pricing instance
func NewPricing() {
common.SysLog("Initializing Pricing")
PricingInstance = &Pricing{
Prices: make(map[string]*model.Price),
Match: make([]string, 0),
}
err := PricingInstance.Init()
if err != nil {
common.SysError("Failed to initialize Pricing:" + err.Error())
return
}
// 初始化时,需要检测是否有更新
if viper.GetBool("auto_price_updates") {
common.SysLog("Checking for pricing updates")
prices := model.GetDefaultPrice()
PricingInstance.SyncPricing(prices, false)
common.SysLog("Pricing initialized")
}
}
// initializes the Pricing instance
func (p *Pricing) Init() error {
prices, err := model.GetAllPrices()
if err != nil {
return err
}
if len(prices) == 0 {
return nil
}
newPrices := make(map[string]*model.Price)
newMatch := make(map[string]bool)
for _, price := range prices {
newPrices[price.Model] = price
if strings.HasSuffix(price.Model, "*") {
if _, ok := newMatch[price.Model]; !ok {
newMatch[price.Model] = true
}
}
}
var newMatchList []string
for match := range newMatch {
newMatchList = append(newMatchList, match)
}
p.Lock()
defer p.Unlock()
p.Prices = newPrices
p.Match = newMatchList
return nil
}
// GetPrice returns the price of a model
func (p *Pricing) GetPrice(modelName string) *model.Price {
p.RLock()
defer p.RUnlock()
if price, ok := p.Prices[modelName]; ok {
return price
}
matchModel := common.GetModelsWithMatch(&p.Match, modelName)
if price, ok := p.Prices[matchModel]; ok {
return price
}
return &model.Price{
Type: model.TokensPriceType,
ChannelType: common.ChannelTypeUnknown,
Input: model.DefaultPrice,
Output: model.DefaultPrice,
}
}
func (p *Pricing) GetAllPrices() map[string]*model.Price {
return p.Prices
}
func (p *Pricing) GetAllPricesList() []*model.Price {
var prices []*model.Price
for _, price := range p.Prices {
prices = append(prices, price)
}
return prices
}
func (p *Pricing) updateRawPrice(modelName string, price *model.Price) error {
if _, ok := p.Prices[modelName]; !ok {
return errors.New("model not found")
}
if _, ok := p.Prices[price.Model]; modelName != price.Model && ok {
return errors.New("model names cannot be duplicated")
}
if err := p.deleteRawPrice(modelName); err != nil {
return err
}
return price.Insert()
}
// UpdatePrice updates the price of a model
func (p *Pricing) UpdatePrice(modelName string, price *model.Price) error {
if err := p.updateRawPrice(modelName, price); err != nil {
return err
}
err := p.Init()
return err
}
func (p *Pricing) addRawPrice(price *model.Price) error {
if _, ok := p.Prices[price.Model]; ok {
return errors.New("model already exists")
}
return price.Insert()
}
// AddPrice adds a new price to the Pricing instance
func (p *Pricing) AddPrice(price *model.Price) error {
if err := p.addRawPrice(price); err != nil {
return err
}
err := p.Init()
return err
}
func (p *Pricing) deleteRawPrice(modelName string) error {
item, ok := p.Prices[modelName]
if !ok {
return errors.New("model not found")
}
return item.Delete()
}
// DeletePrice deletes a price from the Pricing instance
func (p *Pricing) DeletePrice(modelName string) error {
if err := p.deleteRawPrice(modelName); err != nil {
return err
}
err := p.Init()
return err
}
// SyncPricing syncs the pricing data
func (p *Pricing) SyncPricing(pricing []*model.Price, overwrite bool) error {
var err error
if overwrite {
err = p.SyncPriceWithOverwrite(pricing)
} else {
err = p.SyncPriceWithoutOverwrite(pricing)
}
return err
}
// SyncPriceWithOverwrite syncs the pricing data with overwrite
func (p *Pricing) SyncPriceWithOverwrite(pricing []*model.Price) error {
tx := model.DB.Begin()
err := model.DeleteAllPrices(tx)
if err != nil {
tx.Rollback()
return err
}
err = model.InsertPrices(tx, pricing)
if err != nil {
tx.Rollback()
return err
}
tx.Commit()
return p.Init()
}
// SyncPriceWithoutOverwrite syncs the pricing data without overwrite
func (p *Pricing) SyncPriceWithoutOverwrite(pricing []*model.Price) error {
var newPrices []*model.Price
for _, price := range pricing {
if _, ok := p.Prices[price.Model]; !ok {
newPrices = append(newPrices, price)
}
}
if len(newPrices) == 0 {
return nil
}
tx := model.DB.Begin()
err := model.InsertPrices(tx, newPrices)
if err != nil {
tx.Rollback()
return err
}
tx.Commit()
return p.Init()
}
// BatchDeletePrices deletes the prices of multiple models
func (p *Pricing) BatchDeletePrices(models []string) error {
tx := model.DB.Begin()
err := model.DeletePrices(tx, models)
if err != nil {
tx.Rollback()
return err
}
tx.Commit()
p.Lock()
defer p.Unlock()
for _, model := range models {
delete(p.Prices, model)
}
return nil
}
func (p *Pricing) BatchSetPrices(batchPrices *BatchPrices, originalModels []string) error {
// 查找需要删除的model
var deletePrices []string
var addPrices []*model.Price
var updatePrices []string
for _, model := range originalModels {
if !common.Contains(model, batchPrices.Models) {
deletePrices = append(deletePrices, model)
} else {
updatePrices = append(updatePrices, model)
}
}
for _, model := range batchPrices.Models {
if !common.Contains(model, originalModels) {
addPrice := batchPrices.Price
addPrice.Model = model
addPrices = append(addPrices, &addPrice)
}
}
tx := model.DB.Begin()
if len(addPrices) > 0 {
err := model.InsertPrices(tx, addPrices)
if err != nil {
tx.Rollback()
return err
}
}
if len(updatePrices) > 0 {
err := model.UpdatePrices(tx, updatePrices, &batchPrices.Price)
if err != nil {
tx.Rollback()
return err
}
}
if len(deletePrices) > 0 {
err := model.DeletePrices(tx, deletePrices)
if err != nil {
tx.Rollback()
return err
}
}
tx.Commit()
return p.Init()
}
func GetPricesList(pricingType string) []*model.Price {
var prices []*model.Price
switch pricingType {
case "default":
prices = model.GetDefaultPrice()
case "db":
prices = PricingInstance.GetAllPricesList()
case "old":
prices = GetOldPricesList()
default:
return nil
}
sort.Slice(prices, func(i, j int) bool {
if prices[i].ChannelType == prices[j].ChannelType {
return prices[i].Model < prices[j].Model
}
return prices[i].ChannelType < prices[j].ChannelType
})
return prices
}
func GetOldPricesList() []*model.Price {
oldDataJson, err := model.GetOption("ModelRatio")
if err != nil || oldDataJson.Value == "" {
return nil
}
oldData := make(map[string][]float64)
err = json.Unmarshal([]byte(oldDataJson.Value), &oldData)
if err != nil {
return nil
}
var prices []*model.Price
for modelName, oldPrice := range oldData {
price := PricingInstance.GetPrice(modelName)
prices = append(prices, &model.Price{
Model: modelName,
Type: model.TokensPriceType,
ChannelType: price.ChannelType,
Input: oldPrice[0],
Output: oldPrice[1],
})
}
return prices
}
// func ConvertBatchPrices(prices []*model.Price) []*BatchPrices {
// batchPricesMap := make(map[string]*BatchPrices)
// for _, price := range prices {
// key := fmt.Sprintf("%s-%d-%g-%g", price.Type, price.ChannelType, price.Input, price.Output)
// batchPrice, exists := batchPricesMap[key]
// if exists {
// batchPrice.Models = append(batchPrice.Models, price.Model)
// } else {
// batchPricesMap[key] = &BatchPrices{
// Models: []string{price.Model},
// Price: *price,
// }
// }
// }
// var batchPrices []*BatchPrices
// for _, batchPrice := range batchPricesMap {
// batchPrices = append(batchPrices, batchPrice)
// }
// return batchPrices
// }

View File

@ -15,17 +15,16 @@ import (
)
type Quota struct {
modelName string
promptTokens int
preConsumedTokens int
modelRatio []float64
groupRatio float64
ratio float64
preConsumedQuota int
userId int
channelId int
tokenId int
HandelStatus bool
modelName string
promptTokens int
price model.Price
groupRatio float64
inputRatio float64
preConsumedQuota int
userId int
channelId int
tokenId int
HandelStatus bool
}
func NewQuota(c *gin.Context, modelName string, promptTokens int) (*Quota, *types.OpenAIErrorWithStatusCode) {
@ -37,7 +36,16 @@ func NewQuota(c *gin.Context, modelName string, promptTokens int) (*Quota, *type
tokenId: c.GetInt("token_id"),
HandelStatus: false,
}
quota.init(c.GetString("group"))
quota.price = *PricingInstance.GetPrice(quota.modelName)
quota.groupRatio = common.GetGroupRatio(c.GetString("group"))
quota.inputRatio = quota.price.GetInput() * quota.groupRatio
if quota.price.Type == model.TimesPriceType {
quota.preConsumedQuota = int(1000 * quota.inputRatio)
} else {
quota.preConsumedQuota = int(float64(quota.promptTokens+common.PreConsumedQuota) * quota.inputRatio)
}
errWithCode := quota.preQuotaConsumption()
if errWithCode != nil {
@ -47,21 +55,6 @@ func NewQuota(c *gin.Context, modelName string, promptTokens int) (*Quota, *type
return quota, nil
}
func (q *Quota) init(groupName string) {
modelRatio := common.GetModelRatio(q.modelName)
groupRatio := common.GetGroupRatio(groupName)
preConsumedTokens := common.PreConsumedQuota
ratio := modelRatio[0] * groupRatio
preConsumedQuota := int(float64(q.promptTokens+preConsumedTokens) * ratio)
q.preConsumedTokens = preConsumedTokens
q.modelRatio = modelRatio
q.groupRatio = groupRatio
q.ratio = ratio
q.preConsumedQuota = preConsumedQuota
}
func (q *Quota) preQuotaConsumption() *types.OpenAIErrorWithStatusCode {
userQuota, err := model.CacheGetUserQuota(q.userId)
if err != nil {
@ -97,11 +90,17 @@ func (q *Quota) preQuotaConsumption() *types.OpenAIErrorWithStatusCode {
func (q *Quota) completedQuotaConsumption(usage *types.Usage, tokenName string, ctx context.Context) error {
quota := 0
completionRatio := q.modelRatio[1] * q.groupRatio
promptTokens := usage.PromptTokens
completionTokens := usage.CompletionTokens
quota = int(math.Ceil(((float64(promptTokens) * q.ratio) + (float64(completionTokens) * completionRatio))))
if q.ratio != 0 && quota <= 0 {
if q.price.Type == model.TimesPriceType {
quota = int(1000 * q.inputRatio)
} else {
completionRatio := q.price.GetOutput() * q.groupRatio
quota = int(math.Ceil(((float64(promptTokens) * q.inputRatio) + (float64(completionTokens) * completionRatio))))
}
if q.inputRatio != 0 && quota <= 0 {
quota = 1
}
totalTokens := promptTokens + completionTokens
@ -129,13 +128,18 @@ func (q *Quota) completedQuotaConsumption(usage *types.Usage, tokenName string,
}
}
var modelRatioStr string
if q.modelRatio[0] == q.modelRatio[1] {
modelRatioStr = fmt.Sprintf("%.2f", q.modelRatio[0])
if q.price.Type == model.TimesPriceType {
modelRatioStr = fmt.Sprintf("$%g/次", q.price.FetchInputCurrencyPrice(model.DollarRate))
} else {
modelRatioStr = fmt.Sprintf("%.2f (输入)/%.2f (输出)", q.modelRatio[0], q.modelRatio[1])
// 如果输入费率和输出费率一样,则只显示一个费率
if q.price.GetInput() == q.price.GetOutput() {
modelRatioStr = fmt.Sprintf("$%g/1k", q.price.FetchInputCurrencyPrice(model.DollarRate))
} else {
modelRatioStr = fmt.Sprintf("$%g/1k (输入) | $%g/1k (输出)", q.price.FetchInputCurrencyPrice(model.DollarRate), q.price.FetchOutputCurrencyPrice(model.DollarRate))
}
}
logContent := fmt.Sprintf("模型倍率 %s分组倍率 %.2f", modelRatioStr, q.groupRatio)
logContent := fmt.Sprintf("模型率 %s分组倍率 %.2f", modelRatioStr, q.groupRatio)
model.RecordConsumeLog(ctx, q.userId, q.channelId, promptTokens, completionTokens, q.modelName, tokenName, quota, logContent, requestTime)
model.UpdateUserUsedQuotaAndRequestCount(q.userId, quota)
model.UpdateChannelUsedQuota(q.channelId, quota)

28
relay/util/type.go Normal file
View File

@ -0,0 +1,28 @@
package util
import "one-api/common"
var UnknownOwnedBy = "未知"
var ModelOwnedBy map[int]string
func init() {
ModelOwnedBy = map[int]string{
common.ChannelTypeOpenAI: "OpenAI",
common.ChannelTypeAnthropic: "Anthropic",
common.ChannelTypeBaidu: "Baidu",
common.ChannelTypePaLM: "Google PaLM",
common.ChannelTypeGemini: "Google Gemini",
common.ChannelTypeZhipu: "Zhipu",
common.ChannelTypeAli: "Ali",
common.ChannelTypeXunfei: "Xunfei",
common.ChannelType360: "360",
common.ChannelTypeTencent: "Tencent",
common.ChannelTypeBaichuan: "Baichuan",
common.ChannelTypeMiniMax: "MiniMax",
common.ChannelTypeDeepseek: "Deepseek",
common.ChannelTypeMoonshot: "Moonshot",
common.ChannelTypeMistral: "Mistral",
common.ChannelTypeGroq: "Groq",
common.ChannelTypeLingyi: "Lingyiwanwu",
}
}

View File

@ -18,6 +18,8 @@ func SetApiRouter(router *gin.Engine) {
apiRouter.GET("/status", controller.GetStatus)
apiRouter.GET("/notice", controller.GetNotice)
apiRouter.GET("/about", controller.GetAbout)
apiRouter.GET("/prices", middleware.CORS(), controller.GetPricesList)
apiRouter.GET("/ownedby", relay.GetModelOwnedBy)
apiRouter.GET("/home_page_content", controller.GetHomePageContent)
apiRouter.GET("/verification", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendEmailVerification)
apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail)
@ -129,6 +131,20 @@ func SetApiRouter(router *gin.Engine) {
analyticsRoute.GET("/channel_period", controller.GetChannelExpensesByPeriod)
analyticsRoute.GET("/redemption_period", controller.GetRedemptionStatisticsByPeriod)
}
pricesRoute := apiRouter.Group("/prices")
pricesRoute.Use(middleware.AdminAuth())
{
pricesRoute.GET("/model_list", controller.GetAllModelList)
pricesRoute.POST("/single", controller.AddPrice)
pricesRoute.PUT("/single/:model", controller.UpdatePrice)
pricesRoute.DELETE("/single/:model", controller.DeletePrice)
pricesRoute.POST("/multiple", controller.BatchSetPrices)
pricesRoute.PUT("/multiple/delete", controller.BatchDeletePrices)
pricesRoute.POST("/sync", controller.SyncPricing)
}
}
}

View File

@ -9,7 +9,7 @@
"@emotion/react": "^11.9.3",
"@emotion/styled": "^11.9.3",
"@mui/icons-material": "^5.8.4",
"@mui/lab": "^5.0.0-alpha.88",
"@mui/lab": "^5.0.0-alpha.169",
"@mui/material": "^5.8.6",
"@mui/system": "^5.8.6",
"@mui/utils": "^5.8.6",

View File

@ -10,7 +10,8 @@ import {
IconUser,
IconUserScan,
IconActivity,
IconBrandTelegram
IconBrandTelegram,
IconReceipt2
} from '@tabler/icons-react';
// constant
@ -25,7 +26,8 @@ const icons = {
IconUser,
IconUserScan,
IconActivity,
IconBrandTelegram
IconBrandTelegram,
IconReceipt2
};
// ==============================|| DASHBOARD MENU ITEMS ||============================== //
@ -112,6 +114,15 @@ const panel = {
breadcrumbs: false,
isAdmin: true
},
{
id: 'pricing',
title: '模型价格',
type: 'item',
url: '/panel/pricing',
icon: icons.IconReceipt2,
breadcrumbs: false,
isAdmin: true
},
{
id: 'setting',
title: '设置',

View File

@ -15,6 +15,7 @@ const Profile = Loadable(lazy(() => import('views/Profile')));
const NotFoundView = Loadable(lazy(() => import('views/Error')));
const Analytics = Loadable(lazy(() => import('views/Analytics')));
const Telegram = Loadable(lazy(() => import('views/Telegram')));
const Pricing = Loadable(lazy(() => import('views/Pricing')));
// dashboard routing
const Dashboard = Loadable(lazy(() => import('views/Dashboard')));
@ -76,6 +77,10 @@ const MainRoutes = {
{
path: 'telegram',
element: <Telegram />
},
{
path: 'pricing',
element: <Pricing />
}
]
};

View File

@ -32,7 +32,8 @@ const defaultConfig = {
other: '',
proxy: '单独设置代理地址支持http和socks5例如http://127.0.0.1:1080',
test_model: '用于测试使用的模型,为空时无法测速,如gpt-3.5-turbo',
models: '请选择该渠道所支持的模型',
models:
'请选择该渠道所支持的模型,你也可以输入通配符*来匹配模型例如gpt-3.5*表示支持所有gpt-3.5开头的模型,*号只能在最后一位使用前面必须有字符例如gpt-3.5*是正确的,*gpt-3.5是错误的',
model_mapping:
'请输入要修改的模型映射关系格式为api请求模型ID:实际转发给渠道的模型ID使用JSON数组表示例如{"gpt-3.5": "gpt-35"}',
groups: '请选择该渠道所支持的用户组'

View File

@ -198,12 +198,12 @@ export default function Log() {
},
{
id: 'message',
label: '提示',
label: '输入',
disableSort: true
},
{
id: 'completion',
label: '补全',
label: '输出',
disableSort: true
},
{

View File

@ -0,0 +1,195 @@
import PropTypes from 'prop-types';
import { useState, useEffect } from 'react';
import {
Dialog,
DialogTitle,
DialogContent,
DialogActions,
Divider,
Button,
TextField,
Grid,
FormControl,
Alert,
Stack,
Typography
} from '@mui/material';
import { API } from 'utils/api';
import { showError, showSuccess } from 'utils/common';
import LoadingButton from '@mui/lab/LoadingButton';
import Label from 'ui-component/Label';
export const CheckUpdates = ({ open, onCancel, onOk, row }) => {
const [url, setUrl] = useState('https://raw.githubusercontent.com/MartialBE/one-api/prices/prices.json');
const [loading, setLoading] = useState(false);
const [updateLoading, setUpdateLoading] = useState(false);
const [newPricing, setNewPricing] = useState([]);
const [addModel, setAddModel] = useState([]);
const [diffModel, setDiffModel] = useState([]);
const handleCheckUpdates = async () => {
setLoading(true);
try {
const res = await API.get(url);
// 检测是否是一个列表
if (!Array.isArray(res.data)) {
showError('数据格式不正确');
} else {
setNewPricing(res.data);
}
} catch (err) {
console.error(err);
}
setLoading(false);
};
const syncPricing = async (overwrite) => {
setUpdateLoading(true);
if (!newPricing.length) {
showError('请先获取数据');
return;
}
if (!overwrite && !addModel.length) {
showError('没有新增模型');
return;
}
try {
overwrite = overwrite ? 'true' : 'false';
const res = await API.post('/api/prices/sync?overwrite=' + overwrite, newPricing);
const { success, message } = res.data;
if (success) {
showSuccess('操作成功完成!');
onOk(true);
} else {
showError(message);
}
} catch (err) {
console.error(err);
}
setUpdateLoading(false);
};
useEffect(() => {
const newModels = newPricing.filter((np) => !row.some((r) => r.model === np.model));
const changeModel = row.filter((r) =>
newPricing.some((np) => np.model === r.model && (np.input !== r.input || np.output !== r.output))
);
if (newModels.length > 0) {
const newModelsList = newModels.map((model) => model.model);
setAddModel(newModelsList);
} else {
setAddModel('');
}
if (changeModel.length > 0) {
const changeModelList = changeModel.map((model) => {
const newModel = newPricing.find((np) => np.model === model.model);
let changes = '';
if (model.input !== newModel.input) {
changes += `输入倍率由 ${model.input} 变为 ${newModel.input},`;
}
if (model.output !== newModel.output) {
changes += `输出倍率由 ${model.output} 变为 ${newModel.output}`;
}
return `${model.model}:${changes}`;
});
setDiffModel(changeModelList);
} else {
setDiffModel('');
}
}, [row, newPricing]);
return (
<Dialog open={open} onClose={onCancel} fullWidth maxWidth={'md'}>
<DialogTitle sx={{ margin: '0px', fontWeight: 700, lineHeight: '1.55556', padding: '24px', fontSize: '1.125rem' }}>
检查更新
</DialogTitle>
<Divider />
<DialogContent>
<Grid container justifyContent="center" alignItems="center" spacing={2}>
<Grid item xs={12} md={10}>
<FormControl fullWidth component="fieldset">
<TextField label="URL" variant="outlined" value={url} onChange={(e) => setUrl(e.target.value)} />
</FormControl>
</Grid>
<Grid item xs={12} md={2}>
<LoadingButton variant="contained" color="primary" onClick={handleCheckUpdates} loading={loading}>
获取数据
</LoadingButton>
</Grid>
{newPricing.length > 0 && (
<Grid item xs={12}>
{!addModel.length && !diffModel.length && <Alert severity="success">无更新</Alert>}
{addModel.length > 0 && (
<Alert severity="warning">
新增模型
<Stack direction="row" spacing={1} flexWrap="wrap">
{addModel.map((model) => (
<Label color="info" key={model} variant="outlined">
{model}
</Label>
))}
</Stack>
</Alert>
)}
{diffModel.length > 0 && (
<Alert severity="warning">
价格变动模型(仅供参考如果你自己修改了对应模型的价格请忽略)
{diffModel.map((model) => (
<Typography variant="button" display="block" gutterBottom key={model}>
{model}
</Typography>
))}
</Alert>
)}
<Alert severity="warning">
注意:
你可以选择覆盖或者仅添加新增如果你选择覆盖将会删除你自己添加的模型价格完全使用远程配置如果你选择仅添加新增将会只会添加
新增模型的价格
</Alert>
<Stack direction="row" justifyContent="center" spacing={1} flexWrap="wrap">
<LoadingButton
variant="contained"
color="primary"
onClick={() => {
syncPricing(true);
}}
loading={updateLoading}
>
覆盖数据
</LoadingButton>
<LoadingButton
variant="contained"
color="primary"
onClick={() => {
syncPricing(false);
}}
loading={updateLoading}
>
仅添加新增
</LoadingButton>
</Stack>
</Grid>
)}
</Grid>
</DialogContent>
<DialogActions>
<Button onClick={onCancel} color="primary">
取消
</Button>
</DialogActions>
</Dialog>
);
};
CheckUpdates.propTypes = {
open: PropTypes.bool,
row: PropTypes.array,
onCancel: PropTypes.func,
onOk: PropTypes.func
};

View File

@ -0,0 +1,299 @@
import PropTypes from 'prop-types';
import * as Yup from 'yup';
import { Formik } from 'formik';
import { useTheme } from '@mui/material/styles';
import { useState, useEffect } from 'react';
import {
Dialog,
DialogTitle,
DialogContent,
DialogActions,
Button,
Divider,
FormControl,
InputLabel,
OutlinedInput,
InputAdornment,
FormHelperText,
Select,
Autocomplete,
TextField,
Checkbox,
MenuItem
} from '@mui/material';
import { showSuccess, showError } from 'utils/common';
import { API } from 'utils/api';
import { createFilterOptions } from '@mui/material/Autocomplete';
import { ValueFormatter, priceType } from './util';
import CheckBoxOutlineBlankIcon from '@mui/icons-material/CheckBoxOutlineBlank';
import CheckBoxIcon from '@mui/icons-material/CheckBox';
const icon = <CheckBoxOutlineBlankIcon fontSize="small" />;
const checkedIcon = <CheckBoxIcon fontSize="small" />;
const filter = createFilterOptions();
const validationSchema = Yup.object().shape({
is_edit: Yup.boolean(),
type: Yup.string().oneOf(['tokens', 'times'], '类型 错误').required('类型 不能为空'),
channel_type: Yup.number().min(1, '渠道类型 错误').required('渠道类型 不能为空'),
input: Yup.number().required('输入倍率 不能为空'),
output: Yup.number().required('输出倍率 不能为空'),
models: Yup.array().min(1, '模型 不能为空')
});
const originInputs = {
is_edit: false,
type: 'tokens',
channel_type: 1,
input: 0,
output: 0,
models: []
};
const EditModal = ({ open, pricesItem, onCancel, onOk, ownedby, noPriceModel }) => {
const theme = useTheme();
const [inputs, setInputs] = useState(originInputs);
const [selectModel, setSelectModel] = useState([]);
const submit = async (values, { setErrors, setStatus, setSubmitting }) => {
setSubmitting(true);
try {
const res = await API.post(`/api/prices/multiple`, {
original_models: inputs.models,
models: values.models,
price: {
model: 'batch',
type: values.type,
channel_type: values.channel_type,
input: values.input,
output: values.output
}
});
const { success, message } = res.data;
if (success) {
showSuccess('保存成功!');
setSubmitting(false);
setStatus({ success: true });
onOk(true);
return;
} else {
setStatus({ success: false });
showError(message);
setErrors({ submit: message });
}
} catch (error) {
setStatus({ success: false });
showError(error.message);
setErrors({ submit: error.message });
return;
}
onOk();
};
useEffect(() => {
if (pricesItem) {
setSelectModel(pricesItem.models.concat(noPriceModel));
} else {
setSelectModel(noPriceModel);
}
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [pricesItem, noPriceModel]);
useEffect(() => {
if (pricesItem) {
setInputs(pricesItem);
} else {
setInputs(originInputs);
}
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [pricesItem]);
return (
<Dialog open={open} onClose={onCancel} fullWidth maxWidth={'md'}>
<DialogTitle sx={{ margin: '0px', fontWeight: 700, lineHeight: '1.55556', padding: '24px', fontSize: '1.125rem' }}>
{pricesItem ? '编辑' : '新建'}
</DialogTitle>
<Divider />
<DialogContent>
<Formik initialValues={inputs} enableReinitialize validationSchema={validationSchema} onSubmit={submit}>
{({ errors, handleBlur, handleChange, handleSubmit, touched, values, isSubmitting }) => (
<form noValidate onSubmit={handleSubmit}>
<FormControl fullWidth error={Boolean(touched.type && errors.type)} sx={{ ...theme.typography.otherInput }}>
<InputLabel htmlFor="type-label">名称</InputLabel>
<Select
id="type-label"
label="类型"
value={values.type}
name="type"
onBlur={handleBlur}
onChange={handleChange}
MenuProps={{
PaperProps: {
style: {
maxHeight: 200
}
}
}}
>
{Object.values(priceType).map((option) => {
return (
<MenuItem key={option.value} value={option.value}>
{option.label}
</MenuItem>
);
})}
</Select>
{touched.type && errors.type && (
<FormHelperText error id="helper-tex-type-label">
{errors.type}
</FormHelperText>
)}
</FormControl>
<FormControl fullWidth error={Boolean(touched.channel_type && errors.channel_type)} sx={{ ...theme.typography.otherInput }}>
<InputLabel htmlFor="channel_type-label">渠道类型</InputLabel>
<Select
id="channel_type-label"
label="渠道类型"
value={values.channel_type}
name="channel_type"
onBlur={handleBlur}
onChange={handleChange}
MenuProps={{
PaperProps: {
style: {
maxHeight: 200
}
}
}}
>
{Object.values(ownedby).map((option) => {
return (
<MenuItem key={option.value} value={option.value}>
{option.label}
</MenuItem>
);
})}
</Select>
{touched.channel_type && errors.channel_type && (
<FormHelperText error id="helper-tex-channel_type-label">
{errors.channel_type}
</FormHelperText>
)}
</FormControl>
<FormControl fullWidth error={Boolean(touched.input && errors.input)} sx={{ ...theme.typography.otherInput }}>
<InputLabel htmlFor="channel-input-label">输入倍率</InputLabel>
<OutlinedInput
id="channel-input-label"
label="输入倍率"
type="number"
value={values.input}
name="input"
endAdornment={<InputAdornment position="end">{ValueFormatter(values.input)}</InputAdornment>}
onBlur={handleBlur}
onChange={handleChange}
aria-describedby="helper-text-channel-input-label"
/>
{touched.input && errors.input && (
<FormHelperText error id="helper-tex-channel-input-label">
{errors.input}
</FormHelperText>
)}
</FormControl>
<FormControl fullWidth error={Boolean(touched.output && errors.output)} sx={{ ...theme.typography.otherInput }}>
<InputLabel htmlFor="channel-output-label">输出倍率</InputLabel>
<OutlinedInput
id="channel-output-label"
label="输出倍率"
type="number"
value={values.output}
name="output"
endAdornment={<InputAdornment position="end">{ValueFormatter(values.output)}</InputAdornment>}
onBlur={handleBlur}
onChange={handleChange}
aria-describedby="helper-text-channel-output-label"
/>
{touched.output && errors.output && (
<FormHelperText error id="helper-tex-channel-output-label">
{errors.output}
</FormHelperText>
)}
</FormControl>
<FormControl fullWidth sx={{ ...theme.typography.otherInput }}>
<Autocomplete
multiple
freeSolo
id="channel-models-label"
options={selectModel}
value={values.models}
onChange={(e, value) => {
const event = {
target: {
name: 'models',
value: value
}
};
handleChange(event);
}}
onBlur={handleBlur}
// filterSelectedOptions
disableCloseOnSelect
renderInput={(params) => <TextField {...params} name="models" error={Boolean(errors.models)} label="模型" />}
filterOptions={(options, params) => {
const filtered = filter(options, params);
const { inputValue } = params;
const isExisting = options.some((option) => inputValue === option);
if (inputValue !== '' && !isExisting) {
filtered.push(inputValue);
}
return filtered;
}}
renderOption={(props, option, { selected }) => (
<li {...props}>
<Checkbox icon={icon} checkedIcon={checkedIcon} style={{ marginRight: 8 }} checked={selected} />
{option}
</li>
)}
/>
{errors.models ? (
<FormHelperText error id="helper-tex-channel-models-label">
{errors.models}
</FormHelperText>
) : (
<FormHelperText id="helper-tex-channel-models-label">
{' '}
请选择该价格所支持的模型,你也可以输入通配符*来匹配模型例如gpt-3.5*表示支持所有gpt-3.5开头的模型*号只能在最后一位使用前面必须有字符例如gpt-3.5*是正确的*gpt-3.5是错误的{' '}
</FormHelperText>
)}
</FormControl>
<DialogActions>
<Button onClick={onCancel}>取消</Button>
<Button disableElevation disabled={isSubmitting} type="submit" variant="contained" color="primary">
提交
</Button>
</DialogActions>
</form>
)}
</Formik>
</DialogContent>
</Dialog>
);
};
export default EditModal;
EditModal.propTypes = {
open: PropTypes.bool,
pricesItem: PropTypes.object,
onCancel: PropTypes.func,
onOk: PropTypes.func,
ownedby: PropTypes.array,
noPriceModel: PropTypes.array
};

View File

@ -0,0 +1,156 @@
import PropTypes from 'prop-types';
import { useState } from 'react';
import {
Popover,
TableRow,
MenuItem,
TableCell,
IconButton,
Dialog,
DialogActions,
DialogContent,
DialogContentText,
DialogTitle,
Collapse,
Grid,
Box,
Typography,
Button
} from '@mui/material';
import { IconDotsVertical, IconEdit, IconTrash } from '@tabler/icons-react';
import { ValueFormatter, priceType } from './util';
import KeyboardArrowDownIcon from '@mui/icons-material/KeyboardArrowDown';
import KeyboardArrowUpIcon from '@mui/icons-material/KeyboardArrowUp';
import Label from 'ui-component/Label';
import { copy } from 'utils/common';
export default function PricesTableRow({ item, managePrices, handleOpenModal, setModalPricesItem, ownedby }) {
const [open, setOpen] = useState(null);
const [openRow, setOpenRow] = useState(false);
const [openDelete, setOpenDelete] = useState(false);
const type_label = priceType.find((pt) => pt.value === item.type);
const channel_label = ownedby.find((ob) => ob.value === item.channel_type);
const handleDeleteOpen = () => {
handleCloseMenu();
setOpenDelete(true);
};
const handleDeleteClose = () => {
setOpenDelete(false);
};
const handleOpenMenu = (event) => {
setOpen(event.currentTarget);
};
const handleCloseMenu = () => {
setOpen(null);
};
const handleDelete = async () => {
handleDeleteClose();
await managePrices(item, 'delete', '');
};
return (
<>
<TableRow tabIndex={item.id} onClick={() => setOpenRow(!openRow)}>
<TableCell>
<IconButton aria-label="expand row" size="small" onClick={() => setOpenRow(!openRow)}>
{openRow ? <KeyboardArrowUpIcon /> : <KeyboardArrowDownIcon />}
</IconButton>
</TableCell>
<TableCell>{type_label?.label}</TableCell>
<TableCell>{channel_label?.label}</TableCell>
<TableCell>{ValueFormatter(item.input)}</TableCell>
<TableCell>{ValueFormatter(item.output)}</TableCell>
<TableCell>{item.models.length}</TableCell>
<TableCell onClick={(event) => event.stopPropagation()}>
<IconButton onClick={handleOpenMenu} sx={{ color: 'rgb(99, 115, 129)' }}>
<IconDotsVertical />
</IconButton>
</TableCell>
</TableRow>
<TableRow>
<TableCell style={{ paddingBottom: 0, paddingTop: 0, textAlign: 'left' }} colSpan={10}>
<Collapse in={openRow} timeout="auto" unmountOnExit>
<Grid container spacing={1}>
<Grid item xs={12}>
<Box sx={{ display: 'flex', flexWrap: 'wrap', gap: '10px', margin: 1 }}>
<Typography variant="h6" gutterBottom component="div">
可用模型:
</Typography>
{item.models.map((model) => (
<Label
variant="outlined"
color="primary"
key={model}
onClick={() => {
copy(model, '模型名称');
}}
>
{model}
</Label>
))}
</Box>
</Grid>
</Grid>
</Collapse>
</TableCell>
</TableRow>
<Popover
open={!!open}
anchorEl={open}
onClose={handleCloseMenu}
anchorOrigin={{ vertical: 'top', horizontal: 'left' }}
transformOrigin={{ vertical: 'top', horizontal: 'right' }}
PaperProps={{
sx: { width: 140 }
}}
>
<MenuItem
onClick={() => {
handleCloseMenu();
handleOpenModal();
setModalPricesItem(item);
}}
>
<IconEdit style={{ marginRight: '16px' }} />
编辑
</MenuItem>
<MenuItem onClick={handleDeleteOpen} sx={{ color: 'error.main' }}>
<IconTrash style={{ marginRight: '16px' }} />
删除
</MenuItem>
</Popover>
<Dialog open={openDelete} onClose={handleDeleteClose}>
<DialogTitle>删除价格组</DialogTitle>
<DialogContent>
<DialogContentText>是否删除价格组</DialogContentText>
</DialogContent>
<DialogActions>
<Button onClick={handleDeleteClose}>关闭</Button>
<Button onClick={handleDelete} sx={{ color: 'error.main' }} autoFocus>
删除
</Button>
</DialogActions>
</Dialog>
</>
);
}
PricesTableRow.propTypes = {
item: PropTypes.object,
managePrices: PropTypes.func,
handleOpenModal: PropTypes.func,
setModalPricesItem: PropTypes.func,
priceType: PropTypes.array,
ownedby: PropTypes.array
};

View File

@ -0,0 +1,11 @@
export const priceType = [
{ value: 'tokens', label: '按Token收费' },
{ value: 'times', label: '按次收费' }
];
export function ValueFormatter(value) {
if (value == null) {
return '';
}
return `$${parseFloat(value * 0.002).toFixed(4)} / ¥${parseFloat(value * 0.014).toFixed(4)}`;
}

View File

@ -0,0 +1,221 @@
import { useState, useEffect, useMemo, useCallback } from 'react';
import PropTypes from 'prop-types';
import { Tabs, Tab, Box, Card, Alert, Stack, Button } from '@mui/material';
import { IconTag, IconTags } from '@tabler/icons-react';
import Single from './single';
import Multiple from './multiple';
import { useLocation, useNavigate } from 'react-router-dom';
import AdminContainer from 'ui-component/AdminContainer';
import { API } from 'utils/api';
import { showError } from 'utils/common';
import { CheckUpdates } from './component/CheckUpdates';
function CustomTabPanel(props) {
const { children, value, index, ...other } = props;
return (
<div role="tabpanel" hidden={value !== index} id={`pricing-tabpanel-${index}`} aria-labelledby={`pricing-tab-${index}`} {...other}>
{value === index && <Box sx={{ p: 3 }}>{children}</Box>}
</div>
);
}
CustomTabPanel.propTypes = {
children: PropTypes.node,
index: PropTypes.number.isRequired,
value: PropTypes.number.isRequired
};
function a11yProps(index) {
return {
id: `pricing-tab-${index}`,
'aria-controls': `pricing-tabpanel-${index}`
};
}
const Pricing = () => {
const [ownedby, setOwnedby] = useState([]);
const [modelList, setModelList] = useState([]);
const [openModal, setOpenModal] = useState(false);
const [errPrices, setErrPrices] = useState('');
const [prices, setPrices] = useState([]);
const [noPriceModel, setNoPriceModel] = useState([]);
const location = useLocation();
const navigate = useNavigate();
const hash = location.hash.replace('#', '');
const tabMap = useMemo(
() => ({
single: 0,
multiple: 1
}),
[]
);
const [value, setValue] = useState(tabMap[hash] || 0);
const handleChange = (event, newValue) => {
setValue(newValue);
const hashArray = Object.keys(tabMap);
navigate(`#${hashArray[newValue]}`);
};
const reloadData = () => {
fetchModelList();
fetchPrices();
};
const handleOkModal = (status) => {
if (status === true) {
reloadData();
setOpenModal(false);
}
};
useEffect(() => {
const missingModels = modelList.filter((model) => !prices.some((price) => price.model === model));
setNoPriceModel(missingModels);
}, [modelList, prices]);
useEffect(() => {
// check if there is any price that is not valid
const invalidPrices = prices.filter((price) => price.channel_type <= 0);
if (invalidPrices.length > 0) {
setErrPrices(invalidPrices.map((price) => price.model).join(', '));
} else {
setErrPrices('');
}
}, [prices]);
const fetchOwnedby = useCallback(async () => {
try {
const res = await API.get('/api/ownedby');
const { success, message, data } = res.data;
if (success) {
let ownedbyList = [];
for (let key in data) {
ownedbyList.push({ value: parseInt(key), label: data[key] });
}
setOwnedby(ownedbyList);
} else {
showError(message);
}
} catch (error) {
console.error(error);
}
}, []);
const fetchModelList = useCallback(async () => {
try {
const res = await API.get('/api/prices/model_list');
const { success, message, data } = res.data;
if (success) {
setModelList(data);
} else {
showError(message);
}
} catch (error) {
console.error(error);
}
}, []);
const fetchPrices = useCallback(async () => {
try {
const res = await API.get('/api/prices');
const { success, message, data } = res.data;
if (success) {
setPrices(data);
} else {
showError(message);
}
} catch (error) {
console.error(error);
}
}, []);
useEffect(() => {
const handleHashChange = () => {
const hash = location.hash.replace('#', '');
setValue(tabMap[hash] || 0);
};
window.addEventListener('hashchange', handleHashChange);
return () => {
window.removeEventListener('hashchange', handleHashChange);
};
}, [location, tabMap, fetchOwnedby]);
useEffect(() => {
const fetchData = async () => {
try {
await Promise.all([fetchOwnedby(), fetchModelList()]);
fetchPrices();
} catch (error) {
console.error(error);
}
};
fetchData();
}, [fetchOwnedby, fetchModelList, fetchPrices]);
return (
<Stack spacing={3}>
<Alert severity="info">
<b>美元</b>1 === $0.002 / 1K tokens <b>人民币</b> 1 === ¥0.014 / 1k tokens
<br /> <b>例如</b><br /> gpt-4 输入 $0.03 / 1K tokens 完成$0.06 / 1K tokens <br />
0.03 / 0.002 = 15, 0.06 / 0.002 = 30即输入倍率为 15完成倍率为 30
</Alert>
{noPriceModel.length > 0 && (
<Alert severity="warning">
<b>存在未配置价格的模型请及时配置价格</b>
{noPriceModel.map((model) => (
<span key={model}>{model}, </span>
))}
</Alert>
)}
{errPrices && (
<Alert severity="warning">
<b>存在供应商类型错误的模型请及时配置</b>{errPrices}
</Alert>
)}
<Stack direction="row" alignItems="center" justifyContent="flex-end" mb={5} spacing={2}>
<Button
variant="contained"
onClick={() => {
setOpenModal(true);
}}
>
更新价格
</Button>
</Stack>
<Card>
<AdminContainer>
<Box sx={{ width: '100%' }}>
<Box sx={{ borderBottom: 1, borderColor: 'divider' }}>
<Tabs value={value} onChange={handleChange} variant="scrollable" scrollButtons="auto">
<Tab label="单条操作" {...a11yProps(0)} icon={<IconTag />} iconPosition="start" />
<Tab label="合并操作" {...a11yProps(1)} icon={<IconTags />} iconPosition="start" />
</Tabs>
</Box>
<CustomTabPanel value={value} index={0}>
<Single ownedby={ownedby} reloadData={reloadData} prices={prices} />
</CustomTabPanel>
<CustomTabPanel value={value} index={1}>
<Multiple ownedby={ownedby} reloadData={reloadData} prices={prices} noPriceModel={noPriceModel} />
</CustomTabPanel>
</Box>
</AdminContainer>
</Card>
<CheckUpdates
open={openModal}
onCancel={() => {
setOpenModal(false);
}}
row={prices}
onOk={handleOkModal}
/>
</Stack>
);
};
export default Pricing;

View File

@ -0,0 +1,149 @@
import PropTypes from 'prop-types';
import { useState, useEffect } from 'react';
import { showError, showSuccess } from 'utils/common';
import Table from '@mui/material/Table';
import TableBody from '@mui/material/TableBody';
import TableContainer from '@mui/material/TableContainer';
import PerfectScrollbar from 'react-perfect-scrollbar';
import { Button, Card, Stack } from '@mui/material';
import PricesTableRow from './component/TableRow';
import KeywordTableHead from 'ui-component/TableHead';
import { API } from 'utils/api';
import { IconRefresh, IconPlus } from '@tabler/icons-react';
import EditeModal from './component/EditModal';
// ----------------------------------------------------------------------
export default function Multiple({ ownedby, prices, reloadData, noPriceModel }) {
const [rows, setRows] = useState([]);
const [openModal, setOpenModal] = useState(false);
const [editPricesItem, setEditPricesItem] = useState(null);
// 处理刷新
const handleRefresh = async () => {
reloadData();
};
useEffect(() => {
const grouped = prices.reduce((acc, item, index) => {
const key = `${item.type}-${item.channel_type}-${item.input}-${item.output}`;
if (!acc[key]) {
acc[key] = {
...item,
models: [item.model],
id: index + 1
};
} else {
acc[key].models.push(item.model);
}
return acc;
}, {});
setRows(Object.values(grouped));
}, [prices]);
const managePrices = async (item, action) => {
let res;
try {
switch (action) {
case 'delete':
res = await API.put('/api/prices/multiple/delete', {
models: item.models
});
break;
}
const { success, message } = res.data;
if (success) {
showSuccess('操作成功完成!');
if (action === 'delete') {
await handleRefresh();
}
} else {
showError(message);
}
return res.data;
} catch (error) {
return;
}
};
const handleOpenModal = (item) => {
setEditPricesItem(item);
setOpenModal(true);
};
const handleCloseModal = () => {
setOpenModal(false);
setEditPricesItem(null);
};
const handleOkModal = (status) => {
if (status === true) {
handleCloseModal();
handleRefresh();
}
};
return (
<>
<Stack direction="row" alignItems="center" justifyContent="flex-start" mb={5} spacing={2}>
<Button variant="contained" color="primary" startIcon={<IconPlus />} onClick={() => handleOpenModal(0)}>
新建
</Button>
<Button variant="contained" onClick={handleRefresh} startIcon={<IconRefresh width={'18px'} />}>
刷新
</Button>
</Stack>
<Card>
<PerfectScrollbar component="div">
<TableContainer sx={{ overflow: 'unset' }}>
<Table sx={{ minWidth: 800 }}>
<KeywordTableHead
headLabel={[
{ id: 'collapse', label: '', disableSort: true },
{ id: 'type', label: '类型', disableSort: true },
{ id: 'channel_type', label: '供应商', disableSort: true },
{ id: 'input', label: '输入倍率', disableSort: true },
{ id: 'output', label: '输出倍率', disableSort: true },
{ id: 'count', label: '模型数量', disableSort: true },
{ id: 'action', label: '操作', disableSort: true }
]}
/>
<TableBody>
{rows.map((row) => (
<PricesTableRow
item={row}
managePrices={managePrices}
key={row.id}
handleOpenModal={handleOpenModal}
setModalPricesItem={setEditPricesItem}
ownedby={ownedby}
/>
))}
</TableBody>
</Table>
</TableContainer>
</PerfectScrollbar>
</Card>
<EditeModal
open={openModal}
onCancel={handleCloseModal}
onOk={handleOkModal}
pricesItem={editPricesItem}
ownedby={ownedby}
noPriceModel={noPriceModel}
/>
</>
);
}
Multiple.propTypes = {
prices: PropTypes.array,
ownedby: PropTypes.array,
reloadData: PropTypes.func,
noPriceModel: PropTypes.array
};

View File

@ -1,19 +1,30 @@
import PropTypes from 'prop-types';
import { useState, useEffect, useMemo, useCallback } from 'react';
import { GridRowModes, DataGrid, GridToolbarContainer, GridActionsCellItem } from '@mui/x-data-grid';
import { Box, Button } from '@mui/material';
import { Box, Button, Dialog, DialogActions, DialogContent, DialogTitle } from '@mui/material';
import AddIcon from '@mui/icons-material/Add';
import EditIcon from '@mui/icons-material/Edit';
import DeleteIcon from '@mui/icons-material/DeleteOutlined';
import SaveIcon from '@mui/icons-material/Save';
import CancelIcon from '@mui/icons-material/Close';
import { showError } from 'utils/common';
import { showError, showSuccess } from 'utils/common';
import { API } from 'utils/api';
import { ValueFormatter, priceType } from './component/util';
function validation(row, rows) {
if (row.model === '') {
return '模型名称不能为空';
}
// 判断 type 是否是 等于 tokens || times
if (row.type !== 'tokens' && row.type !== 'times') {
return '类型只能是tokens或times';
}
if (row.channel_type <= 0) {
return '所属渠道类型错误';
}
// 判断 model是否是唯一值
if (rows.filter((r) => r.model === row.model && (row.isNew || r.id !== row.id)).length > 0) {
return '模型名称不能重复';
@ -22,8 +33,8 @@ function validation(row, rows) {
if (row.input === '' || row.input < 0) {
return '输入倍率必须大于等于0';
}
if (row.complete === '' || row.complete < 0) {
return '完成倍率必须大于等于0';
if (row.output === '' || row.output < 0) {
return '输出倍率必须大于等于0';
}
return false;
}
@ -35,7 +46,7 @@ function randomId() {
function EditToolbar({ setRows, setRowModesModel }) {
const handleClick = () => {
const id = randomId();
setRows((oldRows) => [{ id, model: '', input: 0, complete: 0, isNew: true }, ...oldRows]);
setRows((oldRows) => [{ id, model: '', type: 'tokens', channel_type: 1, input: 0, output: 0, isNew: true }, ...oldRows]);
setRowModesModel((oldModel) => ({
[id]: { mode: GridRowModes.Edit, fieldToFocus: 'name' },
...oldModel
@ -56,19 +67,33 @@ EditToolbar.propTypes = {
setRowModesModel: PropTypes.func.isRequired
};
const ModelRationDataGrid = ({ ratio, onChange }) => {
const Single = ({ ownedby, prices, reloadData }) => {
const [rows, setRows] = useState([]);
const [rowModesModel, setRowModesModel] = useState({});
const [selectedRow, setSelectedRow] = useState(null);
const setRatio = useCallback(
(ratioRow) => {
let ratioJson = {};
ratioRow.forEach((row) => {
ratioJson[row.model] = [row.input, row.complete];
});
onChange({ target: { name: 'ModelRatio', value: JSON.stringify(ratioJson, null, 2) } });
const addOrUpdatePirces = useCallback(
async (newRow, oldRow, reject, resolve) => {
try {
let res;
if (oldRow.model == '') {
res = await API.post('/api/prices/single', newRow);
} else {
res = await API.put('/api/prices/single/' + oldRow.model, newRow);
}
const { success, message } = res.data;
if (success) {
showSuccess('保存成功');
resolve(newRow);
reloadData();
} else {
reject(new Error(message));
}
} catch (error) {
reject(new Error(error));
}
},
[onChange]
[reloadData]
);
const handleEditClick = useCallback(
@ -87,11 +112,21 @@ const ModelRationDataGrid = ({ ratio, onChange }) => {
const handleDeleteClick = useCallback(
(id) => () => {
setRatio(rows.filter((row) => row.id !== id));
setSelectedRow(rows.find((row) => row.id === id));
},
[rows, setRatio]
[rows]
);
const handleClose = () => {
setSelectedRow(null);
};
const handleConfirmDelete = async () => {
// 执行删除操作
await deletePirces(selectedRow.model);
setSelectedRow(null);
};
const handleCancelClick = useCallback(
(id) => () => {
setRowModesModel({
@ -107,18 +142,30 @@ const ModelRationDataGrid = ({ ratio, onChange }) => {
[rowModesModel, rows]
);
const processRowUpdate = (newRow, oldRows) => {
if (!newRow.isNew && newRow.model === oldRows.model && newRow.input === oldRows.input && newRow.complete === oldRows.complete) {
return oldRows;
}
const updatedRow = { ...newRow, isNew: false };
const error = validation(updatedRow, rows);
if (error) {
return Promise.reject(new Error(error));
}
setRatio(rows.map((row) => (row.id === newRow.id ? updatedRow : row)));
return updatedRow;
};
const processRowUpdate = useCallback(
(newRow, oldRows) =>
new Promise((resolve, reject) => {
if (
!newRow.isNew &&
newRow.model === oldRows.model &&
newRow.input === oldRows.input &&
newRow.output === oldRows.output &&
newRow.type === oldRows.type &&
newRow.channel_type === oldRows.channel_type
) {
return resolve(oldRows);
}
const updatedRow = { ...newRow, isNew: false };
const error = validation(updatedRow, rows);
if (error) {
return reject(new Error(error));
}
const response = addOrUpdatePirces(updatedRow, oldRows, reject, resolve);
return response;
}),
[rows, addOrUpdatePirces]
);
const handleProcessRowUpdateError = useCallback((error) => {
showError(error.message);
@ -138,6 +185,26 @@ const ModelRationDataGrid = ({ ratio, onChange }) => {
editable: true,
hideable: false
},
{
field: 'type',
sortable: true,
headerName: '类型',
width: 220,
type: 'singleSelect',
valueOptions: priceType,
editable: true,
hideable: false
},
{
field: 'channel_type',
sortable: true,
headerName: '供应商',
width: 220,
type: 'singleSelect',
valueOptions: ownedby,
editable: true,
hideable: false
},
{
field: 'input',
sortable: false,
@ -145,27 +212,17 @@ const ModelRationDataGrid = ({ ratio, onChange }) => {
width: 150,
type: 'number',
editable: true,
valueFormatter: (params) => {
if (params.value == null) {
return '';
}
return `$${parseFloat(params.value * 0.002).toFixed(4)} / ¥${parseFloat(params.value * 0.014).toFixed(4)}`;
},
valueFormatter: (params) => ValueFormatter(params.value),
hideable: false
},
{
field: 'complete',
field: 'output',
sortable: false,
headerName: '完成倍率',
headerName: '输出倍率',
width: 150,
type: 'number',
editable: true,
valueFormatter: (params) => {
if (params.value == null) {
return '';
}
return `$${parseFloat(params.value * 0.002).toFixed(4)} / ¥${parseFloat(params.value * 0.014).toFixed(4)}`;
},
valueFormatter: (params) => ValueFormatter(params.value),
hideable: false
},
{
@ -220,18 +277,32 @@ const ModelRationDataGrid = ({ ratio, onChange }) => {
}
}
],
[handleEditClick, handleSaveClick, handleDeleteClick, handleCancelClick, rowModesModel]
[handleCancelClick, handleDeleteClick, handleEditClick, handleSaveClick, rowModesModel, ownedby]
);
const deletePirces = async (modelName) => {
try {
const res = await API.delete('/api/prices/single/' + modelName);
const { success, message } = res.data;
if (success) {
showSuccess('保存成功');
await reloadData();
} else {
showError(message);
}
} catch (error) {
console.error(error);
}
};
useEffect(() => {
let modelRatioList = [];
let itemJson = JSON.parse(ratio);
let id = 0;
for (let key in itemJson) {
modelRatioList.push({ id: id++, model: key, input: itemJson[key][0], complete: itemJson[key][1] });
for (let key in prices) {
modelRatioList.push({ id: id++, ...prices[key] });
}
setRows(modelRatioList);
}, [ratio]);
}, [prices]);
return (
<Box
@ -256,6 +327,14 @@ const ModelRationDataGrid = ({ ratio, onChange }) => {
onRowModesModelChange={handleRowModesModelChange}
processRowUpdate={processRowUpdate}
onProcessRowUpdateError={handleProcessRowUpdateError}
// onCellDoubleClick={(params, event) => {
// event.defaultMuiPrevented = true;
// }}
onRowEditStop={(params, event) => {
if (params.reason === 'rowFocusOut') {
event.defaultMuiPrevented = true;
}
}}
slots={{
toolbar: EditToolbar
}}
@ -263,13 +342,27 @@ const ModelRationDataGrid = ({ ratio, onChange }) => {
toolbar: { setRows, setRowModesModel }
}}
/>
<Dialog
maxWidth="xs"
// TransitionProps={{ onEntered: handleEntered }}
open={!!selectedRow}
>
<DialogTitle>确定删除?</DialogTitle>
<DialogContent dividers>{`确定删除 ${selectedRow?.model} 吗?`}</DialogContent>
<DialogActions>
<Button onClick={handleClose}>取消</Button>
<Button onClick={handleConfirmDelete}>删除</Button>
</DialogActions>
</Dialog>
</Box>
);
};
ModelRationDataGrid.propTypes = {
ratio: PropTypes.string.isRequired,
onChange: PropTypes.func.isRequired
};
export default Single;
export default ModelRationDataGrid;
Single.propTypes = {
prices: PropTypes.array,
ownedby: PropTypes.array,
reloadData: PropTypes.func
};

View File

@ -1,12 +1,11 @@
import { useState, useEffect } from 'react';
import SubCard from 'ui-component/cards/SubCard';
import { Stack, FormControl, InputLabel, OutlinedInput, Checkbox, Button, FormControlLabel, TextField, Alert, Switch } from '@mui/material';
import { Stack, FormControl, InputLabel, OutlinedInput, Checkbox, Button, FormControlLabel, TextField } from '@mui/material';
import { showSuccess, showError, verifyJSON } from 'utils/common';
import { API } from 'utils/api';
import { AdapterDayjs } from '@mui/x-date-pickers/AdapterDayjs';
import { LocalizationProvider } from '@mui/x-date-pickers/LocalizationProvider';
import { DateTimePicker } from '@mui/x-date-pickers/DateTimePicker';
import ModelRationDataGrid from './ModelRationDataGrid';
import dayjs from 'dayjs';
require('dayjs/locale/zh-cn');
@ -18,7 +17,6 @@ const OperationSetting = () => {
QuotaForInvitee: 0,
QuotaRemindThreshold: 0,
PreConsumedQuota: 0,
ModelRatio: '',
GroupRatio: '',
TopUpLink: '',
ChatLink: '',
@ -34,7 +32,6 @@ const OperationSetting = () => {
RetryCooldownSeconds: 0
});
const [originInputs, setOriginInputs] = useState({});
const [newModelRatioView, setNewModelRatioView] = useState(false);
let [loading, setLoading] = useState(false);
let [historyTimestamp, setHistoryTimestamp] = useState(now.getTime() / 1000 - 30 * 24 * 3600); // a month ago new Date().getTime() / 1000 + 3600
@ -45,7 +42,7 @@ const OperationSetting = () => {
if (success) {
let newInputs = {};
data.forEach((item) => {
if (item.key === 'ModelRatio' || item.key === 'GroupRatio') {
if (item.key === 'GroupRatio') {
item.value = JSON.stringify(JSON.parse(item.value), null, 2);
}
newInputs[item.key] = item.value;
@ -110,13 +107,6 @@ const OperationSetting = () => {
}
break;
case 'ratio':
if (originInputs['ModelRatio'] !== inputs.ModelRatio) {
if (!verifyJSON(inputs.ModelRatio)) {
showError('模型倍率不是合法的 JSON 字符串');
return;
}
await updateOption('ModelRatio', inputs.ModelRatio);
}
if (originInputs['GroupRatio'] !== inputs.GroupRatio) {
if (!verifyJSON(inputs.GroupRatio)) {
showError('分组倍率不是合法的 JSON 字符串');
@ -469,44 +459,6 @@ const OperationSetting = () => {
/>
</FormControl>
<FormControl fullWidth>
<Alert severity="info">
配置格式为 JSON 文本键为模型名称值第一位为输入倍率第二位为完成倍率如果只有单一倍率则两者值相同
<br /> <b>美元</b>1 === $0.002 / 1K tokens <b>人民币</b> 1 === ¥0.014 / 1k tokens
<br /> <b>例如</b><br /> gpt-4 输入 $0.03 / 1K tokens 完成$0.06 / 1K tokens <br />
0.03 / 0.002 = 15, 0.06 / 0.002 = 30即输入倍率为 15完成倍率为 30
</Alert>
<FormControlLabel
control={
<Switch
checked={newModelRatioView}
onChange={() => {
setNewModelRatioView(!newModelRatioView);
}}
/>
}
label="使用新编辑器"
/>
</FormControl>
{newModelRatioView ? (
<ModelRationDataGrid ratio={inputs.ModelRatio} onChange={handleInputChange} />
) : (
<FormControl fullWidth>
<TextField
multiline
maxRows={15}
id="channel-ModelRatio-label"
label="模型倍率"
value={inputs.ModelRatio}
name="ModelRatio"
onChange={handleInputChange}
aria-describedby="helper-text-channel-ModelRatio-label"
minRows={5}
placeholder="为一个 JSON 文本,键为模型名称,值为倍率"
/>
</FormControl>
)}
<Button
variant="contained"
onClick={() => {