diff --git a/.gitignore b/.gitignore index 1b2cf071..9a5acecb 100644 --- a/.gitignore +++ b/.gitignore @@ -5,4 +5,5 @@ upload *.db build *.db-journal -logs \ No newline at end of file +logs +.env diff --git a/README.md b/README.md index 1a3a1bd0..513a0205 100644 --- a/README.md +++ b/README.md @@ -376,6 +376,7 @@ graph LR 14. 编码器缓存设置: + `TIKTOKEN_CACHE_DIR`:默认程序启动时会联网下载一些通用的词元的编码,如:`gpt-3.5-turbo`,在一些网络环境不稳定,或者离线情况,可能会导致启动有问题,可以配置此目录缓存数据,可迁移到离线环境。 + `DATA_GYM_CACHE_DIR`:目前该配置作用与 `TIKTOKEN_CACHE_DIR` 一致,但是优先级没有它高。 +15. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。 ### 命令行参数 1. `--port `: 指定服务器监听的端口号,默认为 `3000`。 diff --git a/common/constants.go b/common/constants.go index f6eb2c40..c4004b5b 100644 --- a/common/constants.go +++ b/common/constants.go @@ -21,13 +21,9 @@ var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens var DisplayInCurrencyEnabled = true var DisplayTokenStatEnabled = true -var UsingSQLite = false -var UsingPostgreSQL = false - // Any options with "Secret", "Token" in its key won't be return by GetOptions var SessionSecret = uuid.New().String() -var SQLitePath = "one-api.db" var OptionMap map[string]string var OptionMapRWMutex sync.RWMutex @@ -109,6 +105,8 @@ var SyncFrequency = GetOrDefault("SYNC_FREQUENCY", 10*60) // unit is second var BatchUpdateEnabled = false var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5) +var RelayTimeout = GetOrDefault("RELAY_TIMEOUT", 0) // unit is second + const ( RequestIdKey = "X-Oneapi-Request-Id" ) diff --git a/common/database.go b/common/database.go new file mode 100644 index 00000000..c7e9fd52 --- /dev/null +++ b/common/database.go @@ -0,0 +1,6 @@ +package common + +var UsingSQLite = false +var UsingPostgreSQL = false + +var SQLitePath = "one-api.db" diff --git a/common/model-ratio.go b/common/model-ratio.go index f1ce99a5..f52e1101 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -46,6 +46,7 @@ var ModelRatio = map[string]float64{ "claude-2": 5.51, // $11.02 / 1M tokens "ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens "ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens + "ERNIE-Bot-4": 8.572, // ¥0.12 / 1k tokens "Embedding-V1": 0.1429, // ¥0.002 / 1k tokens "PaLM-2": 1, "chatglm_pro": 0.7143, // ¥0.01 / 1k tokens diff --git a/common/utils.go b/common/utils.go index ab901b77..21bec8f5 100644 --- a/common/utils.go +++ b/common/utils.go @@ -199,3 +199,11 @@ func GetOrDefault(env string, defaultValue int) int { func MessageWithRequestId(message string, id string) string { return fmt.Sprintf("%s (request id: %s)", message, id) } + +func String2Int(str string) int { + num, err := strconv.Atoi(str) + if err != nil { + return 0 + } + return num +} diff --git a/controller/model.go b/controller/model.go index e9b64514..e5f3fdbe 100644 --- a/controller/model.go +++ b/controller/model.go @@ -306,6 +306,15 @@ func init() { Root: "ERNIE-Bot-turbo", Parent: nil, }, + { + Id: "ERNIE-Bot-4", + Object: "model", + Created: 1677649963, + OwnedBy: "baidu", + Permission: permission, + Root: "ERNIE-Bot-4", + Parent: nil, + }, { Id: "Embedding-V1", Object: "model", diff --git a/controller/relay-text.go b/controller/relay-text.go index 35e4dbc7..25f9e081 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -32,7 +32,14 @@ var httpClient *http.Client var impatientHTTPClient *http.Client func init() { - httpClient = &http.Client{} + if common.RelayTimeout == 0 { + httpClient = &http.Client{} + } else { + httpClient = &http.Client{ + Timeout: time.Duration(common.RelayTimeout) * time.Second, + } + } + impatientHTTPClient = &http.Client{ Timeout: 5 * time.Second, } @@ -152,6 +159,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions" case "ERNIE-Bot-turbo": fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant" + case "ERNIE-Bot-4": + fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro" case "BLOOMZ-7B": fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1" case "Embedding-V1": diff --git a/i18n/en.json b/i18n/en.json index 993ddd02..dcdb4b93 100644 --- a/i18n/en.json +++ b/i18n/en.json @@ -553,6 +553,9 @@ "请输入 access token,当前版本暂不支持自动刷新,请每 30 天更新一次": "Please enter the access token, the current version does not support automatic refresh, please update it every 30 days", "此项可选,用于通过Mirror站来进行 API 调用,请EnterMirror站地址,格式为:https://domain.com": "This is optional, used to make API calls through the Mirror site, please enter the Mirror site address, the format is: https://domain.com", + "非流式": "Non-streaming", + "流式": "Streaming", + "新密码": "New password", "新密码已复制到剪贴板:": "New password copied to clipboard: ", "兑换失败,": "Redemption failed, ", diff --git a/middleware/distributor.go b/middleware/distributor.go index a737fb02..a7243554 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -71,7 +71,13 @@ func Distribute() func(c *gin.Context) { } channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, modelRequest.Stream) if err != nil { - message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model) + streamText := "非流式" + + if modelRequest.Stream { + streamText = "流式" + } + + message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道 (%s)", userGroup, modelRequest.Model, streamText) if channel != nil { common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id)) message = "数据库一致性已被破坏,请联系管理员" diff --git a/model/ability.go b/model/ability.go index d59c40f8..01e5bb62 100644 --- a/model/ability.go +++ b/model/ability.go @@ -17,35 +17,25 @@ type Ability struct { func GetRandomSatisfiedChannel(group string, model string, stream bool) (*Channel, error) { ability := Ability{} + groupCol := "`group`" + trueVal := "1" + if common.UsingPostgreSQL { + groupCol = `"group"` + trueVal = "true" + } + var err error = nil - var cmdWhere *Ability + cmdWhere := groupCol + " = ? and model = ? and enabled = " + trueVal if stream { - cmdWhere = &Ability{ - Group: group, - Model: model, - Enabled: true, - AllowStreaming: common.ChannelAllowStreamEnabled, - } + cmdWhere += " and allow_streaming = 1" } else { - cmdWhere = &Ability{ - Group: group, - Model: model, - Enabled: true, - AllowNonStreaming: common.ChannelAllowNonStreamEnabled, - } + cmdWhere += " and allow_non_streaming = 1" } - maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(cmdWhere) - - cmd1 := "`group` = ? and model = ? and enabled = 1 and priority = (?)" - - if common.UsingPostgreSQL { - cmd1 = "\"group\" = ? and model = ? and enabled = 1 and priority = (?)" - } - - channelQuery := DB.Where(cmd1, group, model, maxPrioritySubQuery) + maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(cmdWhere, group, model) + channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery) if common.UsingSQLite || common.UsingPostgreSQL { err = channelQuery.Order("RANDOM()").First(&ability).Error } else { diff --git a/model/cache.go b/model/cache.go index c9e5b229..b00e1567 100644 --- a/model/cache.go +++ b/model/cache.go @@ -21,15 +21,19 @@ var ( ) func CacheGetTokenByKey(key string) (*Token, error) { + keyCol := "`key`" + if common.UsingPostgreSQL { + keyCol = `"key"` + } var token Token if !common.RedisEnabled { - err := DB.Where(&Token{Key: key}).First(&token).Error + err := DB.Where(keyCol+" = ?", key).First(&token).Error return &token, err } tokenObjectString, err := common.RedisGet(fmt.Sprintf("token:%s", key)) if err != nil { - err := DB.Where(&Token{Key: key}).First(&token).Error + err := DB.Where(keyCol+" = ?", key).First(&token).Error if err != nil { return nil, err } diff --git a/model/channel.go b/model/channel.go index 9e575e97..4bb4ebec 100644 --- a/model/channel.go +++ b/model/channel.go @@ -59,19 +59,6 @@ func GetChannelById(id int, selectAll bool) (*Channel, error) { return &channel, err } -func GetRandomChannel() (*Channel, error) { - channel := Channel{} - var err error = nil - if common.UsingPostgreSQL { - err = DB.Where(&Channel{Status: common.ChannelStatusEnabled, Group: "default"}).Order("RANDOM()").Limit(1).First(&channel).Error - } else if common.UsingSQLite { - err = DB.Where(&Channel{Status: common.ChannelStatusEnabled, Group: "default"}).Order("RANDOM()").Limit(1).First(&channel).Error - } else { - err = DB.Where(&Channel{Status: common.ChannelStatusEnabled, Group: "default"}).Order("RAND()").Limit(1).First(&channel).Error - } - return &channel, err -} - func BatchInsertChannels(channels []Channel) error { var err error err = DB.Create(&channels).Error diff --git a/model/redemption.go b/model/redemption.go index 9fa71971..f347345d 100644 --- a/model/redemption.go +++ b/model/redemption.go @@ -51,8 +51,13 @@ func Redeem(key string, userId int) (quota int, err error) { } redemption := &Redemption{} + keyCol := "`key`" + if common.UsingPostgreSQL { + keyCol = `"key"` + } + err = DB.Transaction(func(tx *gorm.DB) error { - err := tx.Set("gorm:query_option", "FOR UPDATE").Where(&Redemption{Key: key}).First(redemption).Error + err := tx.Set("gorm:query_option", "FOR UPDATE").Where(keyCol+" = ?", key).First(redemption).Error if err != nil { return errors.New("无效的兑换码") } diff --git a/model/user.go b/model/user.go index 6c555b17..eb87c178 100644 --- a/model/user.go +++ b/model/user.go @@ -293,7 +293,12 @@ func GetUserEmail(id int) (email string, err error) { } func GetUserGroup(id int) (group string, err error) { - err = DB.Model(&User{}).Where(&User{Id: id}).Select("group").Find(&group).Error + groupCol := "`group`" + if common.UsingPostgreSQL { + groupCol = `"group"` + } + + err = DB.Model(&User{}).Where("id = ?", id).Select(groupCol).Find(&group).Error return group, err } diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index 01843a29..0e5ddb6c 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -68,7 +68,7 @@ const EditChannel = () => { localModels = ['PaLM-2']; break; case 15: - localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'Embedding-V1']; + localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'ERNIE-Bot-4', 'Embedding-V1']; break; case 17: localModels = ['qwen-turbo', 'qwen-plus', 'text-embedding-v1'];