Merge remote-tracking branch 'upstream/main'
This commit is contained in:
parent
d79a7b5902
commit
e1d840e7dd
3
.gitignore
vendored
3
.gitignore
vendored
@ -5,4 +5,5 @@ upload
|
||||
*.db
|
||||
build
|
||||
*.db-journal
|
||||
logs
|
||||
logs
|
||||
.env
|
||||
|
@ -376,6 +376,7 @@ graph LR
|
||||
14. 编码器缓存设置:
|
||||
+ `TIKTOKEN_CACHE_DIR`:默认程序启动时会联网下载一些通用的词元的编码,如:`gpt-3.5-turbo`,在一些网络环境不稳定,或者离线情况,可能会导致启动有问题,可以配置此目录缓存数据,可迁移到离线环境。
|
||||
+ `DATA_GYM_CACHE_DIR`:目前该配置作用与 `TIKTOKEN_CACHE_DIR` 一致,但是优先级没有它高。
|
||||
15. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。
|
||||
|
||||
### 命令行参数
|
||||
1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。
|
||||
|
@ -21,13 +21,9 @@ var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens
|
||||
var DisplayInCurrencyEnabled = true
|
||||
var DisplayTokenStatEnabled = true
|
||||
|
||||
var UsingSQLite = false
|
||||
var UsingPostgreSQL = false
|
||||
|
||||
// Any options with "Secret", "Token" in its key won't be return by GetOptions
|
||||
|
||||
var SessionSecret = uuid.New().String()
|
||||
var SQLitePath = "one-api.db"
|
||||
|
||||
var OptionMap map[string]string
|
||||
var OptionMapRWMutex sync.RWMutex
|
||||
@ -109,6 +105,8 @@ var SyncFrequency = GetOrDefault("SYNC_FREQUENCY", 10*60) // unit is second
|
||||
var BatchUpdateEnabled = false
|
||||
var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5)
|
||||
|
||||
var RelayTimeout = GetOrDefault("RELAY_TIMEOUT", 0) // unit is second
|
||||
|
||||
const (
|
||||
RequestIdKey = "X-Oneapi-Request-Id"
|
||||
)
|
||||
|
6
common/database.go
Normal file
6
common/database.go
Normal file
@ -0,0 +1,6 @@
|
||||
package common
|
||||
|
||||
var UsingSQLite = false
|
||||
var UsingPostgreSQL = false
|
||||
|
||||
var SQLitePath = "one-api.db"
|
@ -46,6 +46,7 @@ var ModelRatio = map[string]float64{
|
||||
"claude-2": 5.51, // $11.02 / 1M tokens
|
||||
"ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens
|
||||
"ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens
|
||||
"ERNIE-Bot-4": 8.572, // ¥0.12 / 1k tokens
|
||||
"Embedding-V1": 0.1429, // ¥0.002 / 1k tokens
|
||||
"PaLM-2": 1,
|
||||
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
|
||||
|
@ -199,3 +199,11 @@ func GetOrDefault(env string, defaultValue int) int {
|
||||
func MessageWithRequestId(message string, id string) string {
|
||||
return fmt.Sprintf("%s (request id: %s)", message, id)
|
||||
}
|
||||
|
||||
func String2Int(str string) int {
|
||||
num, err := strconv.Atoi(str)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
return num
|
||||
}
|
||||
|
@ -306,6 +306,15 @@ func init() {
|
||||
Root: "ERNIE-Bot-turbo",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "ERNIE-Bot-4",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "baidu",
|
||||
Permission: permission,
|
||||
Root: "ERNIE-Bot-4",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "Embedding-V1",
|
||||
Object: "model",
|
||||
|
@ -32,7 +32,14 @@ var httpClient *http.Client
|
||||
var impatientHTTPClient *http.Client
|
||||
|
||||
func init() {
|
||||
httpClient = &http.Client{}
|
||||
if common.RelayTimeout == 0 {
|
||||
httpClient = &http.Client{}
|
||||
} else {
|
||||
httpClient = &http.Client{
|
||||
Timeout: time.Duration(common.RelayTimeout) * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
impatientHTTPClient = &http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
@ -152,6 +159,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"
|
||||
case "ERNIE-Bot-turbo":
|
||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
|
||||
case "ERNIE-Bot-4":
|
||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro"
|
||||
case "BLOOMZ-7B":
|
||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
|
||||
case "Embedding-V1":
|
||||
|
@ -553,6 +553,9 @@
|
||||
"请输入 access token,当前版本暂不支持自动刷新,请每 30 天更新一次": "Please enter the access token, the current version does not support automatic refresh, please update it every 30 days",
|
||||
"此项可选,用于通过Mirror站来进行 API 调用,请EnterMirror站地址,格式为:https://domain.com": "This is optional, used to make API calls through the Mirror site, please enter the Mirror site address, the format is: https://domain.com",
|
||||
|
||||
"非流式": "Non-streaming",
|
||||
"流式": "Streaming",
|
||||
|
||||
"新密码": "New password",
|
||||
"新密码已复制到剪贴板:": "New password copied to clipboard: ",
|
||||
"兑换失败,": "Redemption failed, ",
|
||||
|
@ -71,7 +71,13 @@ func Distribute() func(c *gin.Context) {
|
||||
}
|
||||
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, modelRequest.Stream)
|
||||
if err != nil {
|
||||
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
|
||||
streamText := "非流式"
|
||||
|
||||
if modelRequest.Stream {
|
||||
streamText = "流式"
|
||||
}
|
||||
|
||||
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道 (%s)", userGroup, modelRequest.Model, streamText)
|
||||
if channel != nil {
|
||||
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
|
||||
message = "数据库一致性已被破坏,请联系管理员"
|
||||
|
@ -17,35 +17,25 @@ type Ability struct {
|
||||
|
||||
func GetRandomSatisfiedChannel(group string, model string, stream bool) (*Channel, error) {
|
||||
ability := Ability{}
|
||||
groupCol := "`group`"
|
||||
trueVal := "1"
|
||||
if common.UsingPostgreSQL {
|
||||
groupCol = `"group"`
|
||||
trueVal = "true"
|
||||
}
|
||||
|
||||
var err error = nil
|
||||
|
||||
var cmdWhere *Ability
|
||||
cmdWhere := groupCol + " = ? and model = ? and enabled = " + trueVal
|
||||
|
||||
if stream {
|
||||
cmdWhere = &Ability{
|
||||
Group: group,
|
||||
Model: model,
|
||||
Enabled: true,
|
||||
AllowStreaming: common.ChannelAllowStreamEnabled,
|
||||
}
|
||||
cmdWhere += " and allow_streaming = 1"
|
||||
} else {
|
||||
cmdWhere = &Ability{
|
||||
Group: group,
|
||||
Model: model,
|
||||
Enabled: true,
|
||||
AllowNonStreaming: common.ChannelAllowNonStreamEnabled,
|
||||
}
|
||||
cmdWhere += " and allow_non_streaming = 1"
|
||||
}
|
||||
|
||||
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(cmdWhere)
|
||||
|
||||
cmd1 := "`group` = ? and model = ? and enabled = 1 and priority = (?)"
|
||||
|
||||
if common.UsingPostgreSQL {
|
||||
cmd1 = "\"group\" = ? and model = ? and enabled = 1 and priority = (?)"
|
||||
}
|
||||
|
||||
channelQuery := DB.Where(cmd1, group, model, maxPrioritySubQuery)
|
||||
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(cmdWhere, group, model)
|
||||
channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery)
|
||||
if common.UsingSQLite || common.UsingPostgreSQL {
|
||||
err = channelQuery.Order("RANDOM()").First(&ability).Error
|
||||
} else {
|
||||
|
@ -21,15 +21,19 @@ var (
|
||||
)
|
||||
|
||||
func CacheGetTokenByKey(key string) (*Token, error) {
|
||||
keyCol := "`key`"
|
||||
if common.UsingPostgreSQL {
|
||||
keyCol = `"key"`
|
||||
}
|
||||
var token Token
|
||||
|
||||
if !common.RedisEnabled {
|
||||
err := DB.Where(&Token{Key: key}).First(&token).Error
|
||||
err := DB.Where(keyCol+" = ?", key).First(&token).Error
|
||||
return &token, err
|
||||
}
|
||||
tokenObjectString, err := common.RedisGet(fmt.Sprintf("token:%s", key))
|
||||
if err != nil {
|
||||
err := DB.Where(&Token{Key: key}).First(&token).Error
|
||||
err := DB.Where(keyCol+" = ?", key).First(&token).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -59,19 +59,6 @@ func GetChannelById(id int, selectAll bool) (*Channel, error) {
|
||||
return &channel, err
|
||||
}
|
||||
|
||||
func GetRandomChannel() (*Channel, error) {
|
||||
channel := Channel{}
|
||||
var err error = nil
|
||||
if common.UsingPostgreSQL {
|
||||
err = DB.Where(&Channel{Status: common.ChannelStatusEnabled, Group: "default"}).Order("RANDOM()").Limit(1).First(&channel).Error
|
||||
} else if common.UsingSQLite {
|
||||
err = DB.Where(&Channel{Status: common.ChannelStatusEnabled, Group: "default"}).Order("RANDOM()").Limit(1).First(&channel).Error
|
||||
} else {
|
||||
err = DB.Where(&Channel{Status: common.ChannelStatusEnabled, Group: "default"}).Order("RAND()").Limit(1).First(&channel).Error
|
||||
}
|
||||
return &channel, err
|
||||
}
|
||||
|
||||
func BatchInsertChannels(channels []Channel) error {
|
||||
var err error
|
||||
err = DB.Create(&channels).Error
|
||||
|
@ -51,8 +51,13 @@ func Redeem(key string, userId int) (quota int, err error) {
|
||||
}
|
||||
redemption := &Redemption{}
|
||||
|
||||
keyCol := "`key`"
|
||||
if common.UsingPostgreSQL {
|
||||
keyCol = `"key"`
|
||||
}
|
||||
|
||||
err = DB.Transaction(func(tx *gorm.DB) error {
|
||||
err := tx.Set("gorm:query_option", "FOR UPDATE").Where(&Redemption{Key: key}).First(redemption).Error
|
||||
err := tx.Set("gorm:query_option", "FOR UPDATE").Where(keyCol+" = ?", key).First(redemption).Error
|
||||
if err != nil {
|
||||
return errors.New("无效的兑换码")
|
||||
}
|
||||
|
@ -293,7 +293,12 @@ func GetUserEmail(id int) (email string, err error) {
|
||||
}
|
||||
|
||||
func GetUserGroup(id int) (group string, err error) {
|
||||
err = DB.Model(&User{}).Where(&User{Id: id}).Select("group").Find(&group).Error
|
||||
groupCol := "`group`"
|
||||
if common.UsingPostgreSQL {
|
||||
groupCol = `"group"`
|
||||
}
|
||||
|
||||
err = DB.Model(&User{}).Where("id = ?", id).Select(groupCol).Find(&group).Error
|
||||
return group, err
|
||||
}
|
||||
|
||||
|
@ -68,7 +68,7 @@ const EditChannel = () => {
|
||||
localModels = ['PaLM-2'];
|
||||
break;
|
||||
case 15:
|
||||
localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'Embedding-V1'];
|
||||
localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'ERNIE-Bot-4', 'Embedding-V1'];
|
||||
break;
|
||||
case 17:
|
||||
localModels = ['qwen-turbo', 'qwen-plus', 'text-embedding-v1'];
|
||||
|
Loading…
Reference in New Issue
Block a user