diff --git a/common/config/key.go b/common/config/key.go new file mode 100644 index 00000000..4b503c2d --- /dev/null +++ b/common/config/key.go @@ -0,0 +1,9 @@ +package config + +const ( + KeyPrefix = "cfg_" + + KeyAPIVersion = KeyPrefix + "api_version" + KeyLibraryID = KeyPrefix + "library_id" + KeyPlugin = KeyPrefix + "plugin" +) diff --git a/common/constants.go b/common/constants.go index 95b29683..87221b61 100644 --- a/common/constants.go +++ b/common/constants.go @@ -4,80 +4,3 @@ import "time" var StartTime = time.Now().Unix() // unit: second var Version = "v0.0.0" // this hard coding will be replaced automatically when building, no need to manually change - -const ( - RoleGuestUser = 0 - RoleCommonUser = 1 - RoleAdminUser = 10 - RoleRootUser = 100 -) - -const ( - UserStatusEnabled = 1 // don't use 0, 0 is the default value! - UserStatusDisabled = 2 // also don't use 0 - UserStatusDeleted = 3 -) - -const ( - TokenStatusEnabled = 1 // don't use 0, 0 is the default value! - TokenStatusDisabled = 2 // also don't use 0 - TokenStatusExpired = 3 - TokenStatusExhausted = 4 -) - -const ( - RedemptionCodeStatusEnabled = 1 // don't use 0, 0 is the default value! - RedemptionCodeStatusDisabled = 2 // also don't use 0 - RedemptionCodeStatusUsed = 3 // also don't use 0 -) - -const ( - ChannelStatusUnknown = 0 - ChannelStatusEnabled = 1 // don't use 0, 0 is the default value! - ChannelStatusManuallyDisabled = 2 // also don't use 0 - ChannelStatusAutoDisabled = 3 -) - -var ChannelBaseURLs = []string{ - "", // 0 - "https://api.openai.com", // 1 - "https://oa.api2d.net", // 2 - "", // 3 - "https://api.closeai-proxy.xyz", // 4 - "https://api.openai-sb.com", // 5 - "https://api.openaimax.com", // 6 - "https://api.ohmygpt.com", // 7 - "", // 8 - "https://api.caipacity.com", // 9 - "https://api.aiproxy.io", // 10 - "https://generativelanguage.googleapis.com", // 11 - "https://api.api2gpt.com", // 12 - "https://api.aigc2d.com", // 13 - "https://api.anthropic.com", // 14 - "https://aip.baidubce.com", // 15 - "https://open.bigmodel.cn", // 16 - "https://dashscope.aliyuncs.com", // 17 - "", // 18 - "https://ai.360.cn", // 19 - "https://openrouter.ai/api", // 20 - "https://api.aiproxy.io", // 21 - "https://fastgpt.run/api/openapi", // 22 - "https://hunyuan.cloud.tencent.com", // 23 - "https://generativelanguage.googleapis.com", // 24 - "https://api.moonshot.cn", // 25 - "https://api.baichuan-ai.com", // 26 - "https://api.minimax.chat", // 27 - "https://api.mistral.ai", // 28 - "https://api.groq.com/openai", // 29 - "http://localhost:11434", // 30 - "https://api.lingyiwanwu.com", // 31 - "https://api.stepfun.com", // 32 -} - -const ( - ConfigKeyPrefix = "cfg_" - - ConfigKeyAPIVersion = ConfigKeyPrefix + "api_version" - ConfigKeyLibraryID = ConfigKeyPrefix + "library_id" - ConfigKeyPlugin = ConfigKeyPrefix + "plugin" -) diff --git a/controller/auth/github.go b/controller/auth/github.go index 96298ce4..15542655 100644 --- a/controller/auth/github.go +++ b/controller/auth/github.go @@ -7,7 +7,6 @@ import ( "fmt" "github.com/gin-contrib/sessions" "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/random" @@ -134,8 +133,8 @@ func GitHubOAuth(c *gin.Context) { user.DisplayName = "GitHub User" } user.Email = githubUser.Email - user.Role = common.RoleCommonUser - user.Status = common.UserStatusEnabled + user.Role = model.RoleCommonUser + user.Status = model.UserStatusEnabled if err := user.Insert(0); err != nil { c.JSON(http.StatusOK, gin.H{ @@ -153,7 +152,7 @@ func GitHubOAuth(c *gin.Context) { } } - if user.Status != common.UserStatusEnabled { + if user.Status != model.UserStatusEnabled { c.JSON(http.StatusOK, gin.H{ "message": "用户已被封禁", "success": false, diff --git a/controller/auth/lark.go b/controller/auth/lark.go index 21446d46..eb06dde9 100644 --- a/controller/auth/lark.go +++ b/controller/auth/lark.go @@ -7,7 +7,6 @@ import ( "fmt" "github.com/gin-contrib/sessions" "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/controller" @@ -123,8 +122,8 @@ func LarkOAuth(c *gin.Context) { } else { user.DisplayName = "Lark User" } - user.Role = common.RoleCommonUser - user.Status = common.UserStatusEnabled + user.Role = model.RoleCommonUser + user.Status = model.UserStatusEnabled if err := user.Insert(0); err != nil { c.JSON(http.StatusOK, gin.H{ @@ -142,7 +141,7 @@ func LarkOAuth(c *gin.Context) { } } - if user.Status != common.UserStatusEnabled { + if user.Status != model.UserStatusEnabled { c.JSON(http.StatusOK, gin.H{ "message": "用户已被封禁", "success": false, diff --git a/controller/auth/wechat.go b/controller/auth/wechat.go index 80552c9a..a64746c9 100644 --- a/controller/auth/wechat.go +++ b/controller/auth/wechat.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/controller" "github.com/songquanpeng/one-api/model" @@ -84,8 +83,8 @@ func WeChatAuth(c *gin.Context) { if config.RegisterEnabled { user.Username = "wechat_" + strconv.Itoa(model.GetMaxUserId()+1) user.DisplayName = "WeChat User" - user.Role = common.RoleCommonUser - user.Status = common.UserStatusEnabled + user.Role = model.RoleCommonUser + user.Status = model.UserStatusEnabled if err := user.Insert(0); err != nil { c.JSON(http.StatusOK, gin.H{ @@ -103,7 +102,7 @@ func WeChatAuth(c *gin.Context) { } } - if user.Status != common.UserStatusEnabled { + if user.Status != model.UserStatusEnabled { c.JSON(http.StatusOK, gin.H{ "message": "用户已被封禁", "success": false, diff --git a/controller/channel-billing.go b/controller/channel-billing.go index aec79188..b7ac61fd 100644 --- a/controller/channel-billing.go +++ b/controller/channel-billing.go @@ -4,7 +4,6 @@ import ( "encoding/json" "errors" "fmt" - "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/model" @@ -205,7 +204,7 @@ func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) { } func updateChannelBalance(channel *model.Channel) (float64, error) { - baseURL := common.ChannelBaseURLs[channel.Type] + baseURL := channeltype.ChannelBaseURLs[channel.Type] if channel.GetBaseURL() == "" { channel.BaseURL = &baseURL } @@ -302,7 +301,7 @@ func updateAllChannelsBalance() error { return err } for _, channel := range channels { - if channel.Status != common.ChannelStatusEnabled { + if channel.Status != model.ChannelStatusEnabled { continue } // TODO: support Azure diff --git a/controller/channel-test.go b/controller/channel-test.go index a6a76101..ddbe0b4a 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -5,7 +5,6 @@ import ( "encoding/json" "errors" "fmt" - "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/message" @@ -173,7 +172,7 @@ func testChannels(notify bool, scope string) error { } go func() { for _, channel := range channels { - isChannelEnabled := channel.Status == common.ChannelStatusEnabled + isChannelEnabled := channel.Status == model.ChannelStatusEnabled tik := time.Now() err, openaiErr := testChannel(channel) tok := time.Now() diff --git a/controller/token.go b/controller/token.go index 0e6c8d63..557b5ce1 100644 --- a/controller/token.go +++ b/controller/token.go @@ -3,7 +3,6 @@ package controller import ( "fmt" "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/network" @@ -213,15 +212,15 @@ func UpdateToken(c *gin.Context) { }) return } - if token.Status == common.TokenStatusEnabled { - if cleanToken.Status == common.TokenStatusExpired && cleanToken.ExpiredTime <= helper.GetTimestamp() && cleanToken.ExpiredTime != -1 { + if token.Status == model.TokenStatusEnabled { + if cleanToken.Status == model.TokenStatusExpired && cleanToken.ExpiredTime <= helper.GetTimestamp() && cleanToken.ExpiredTime != -1 { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "令牌已过期,无法启用,请先修改令牌过期时间,或者设置为永不过期", }) return } - if cleanToken.Status == common.TokenStatusExhausted && cleanToken.RemainQuota <= 0 && !cleanToken.UnlimitedQuota { + if cleanToken.Status == model.TokenStatusExhausted && cleanToken.RemainQuota <= 0 && !cleanToken.UnlimitedQuota { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "令牌可用额度已用尽,无法启用,请先修改令牌剩余额度,或者设置为无限额度", diff --git a/controller/user.go b/controller/user.go index 4356799b..44b4f793 100644 --- a/controller/user.go +++ b/controller/user.go @@ -239,7 +239,7 @@ func GetUser(c *gin.Context) { return } myRole := c.GetInt("role") - if myRole <= user.Role && myRole != common.RoleRootUser { + if myRole <= user.Role && myRole != model.RoleRootUser { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "无权获取同级或更高等级用户的信息", @@ -388,14 +388,14 @@ func UpdateUser(c *gin.Context) { return } myRole := c.GetInt("role") - if myRole <= originUser.Role && myRole != common.RoleRootUser { + if myRole <= originUser.Role && myRole != model.RoleRootUser { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "无权更新同权限等级或更高权限等级的用户信息", }) return } - if myRole <= updatedUser.Role && myRole != common.RoleRootUser { + if myRole <= updatedUser.Role && myRole != model.RoleRootUser { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "无权将其他用户权限等级提升到大于等于自己的权限等级", @@ -509,7 +509,7 @@ func DeleteSelf(c *gin.Context) { id := c.GetInt("id") user, _ := model.GetUserById(id, false) - if user.Role == common.RoleRootUser { + if user.Role == model.RoleRootUser { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "不能删除超级管理员账户", @@ -611,7 +611,7 @@ func ManageUser(c *gin.Context) { return } myRole := c.GetInt("role") - if myRole <= user.Role && myRole != common.RoleRootUser { + if myRole <= user.Role && myRole != model.RoleRootUser { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "无权更新同权限等级或更高权限等级的用户信息", @@ -620,8 +620,8 @@ func ManageUser(c *gin.Context) { } switch req.Action { case "disable": - user.Status = common.UserStatusDisabled - if user.Role == common.RoleRootUser { + user.Status = model.UserStatusDisabled + if user.Role == model.RoleRootUser { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "无法禁用超级管理员用户", @@ -629,9 +629,9 @@ func ManageUser(c *gin.Context) { return } case "enable": - user.Status = common.UserStatusEnabled + user.Status = model.UserStatusEnabled case "delete": - if user.Role == common.RoleRootUser { + if user.Role == model.RoleRootUser { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "无法删除超级管理员用户", @@ -646,37 +646,37 @@ func ManageUser(c *gin.Context) { return } case "promote": - if myRole != common.RoleRootUser { + if myRole != model.RoleRootUser { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "普通管理员用户无法提升其他用户为管理员", }) return } - if user.Role >= common.RoleAdminUser { + if user.Role >= model.RoleAdminUser { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "该用户已经是管理员", }) return } - user.Role = common.RoleAdminUser + user.Role = model.RoleAdminUser case "demote": - if user.Role == common.RoleRootUser { + if user.Role == model.RoleRootUser { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "无法降级超级管理员用户", }) return } - if user.Role == common.RoleCommonUser { + if user.Role == model.RoleCommonUser { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "该用户已经是普通用户", }) return } - user.Role = common.RoleCommonUser + user.Role = model.RoleCommonUser } if err := user.Update(false); err != nil { @@ -730,7 +730,7 @@ func EmailBind(c *gin.Context) { }) return } - if user.Role == common.RoleRootUser { + if user.Role == model.RoleRootUser { config.RootUserEmail = email } c.JSON(http.StatusOK, gin.H{ diff --git a/middleware/auth.go b/middleware/auth.go index 223cef3d..01b2cce3 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -4,7 +4,6 @@ import ( "fmt" "github.com/gin-contrib/sessions" "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/blacklist" "github.com/songquanpeng/one-api/common/network" "github.com/songquanpeng/one-api/model" @@ -45,7 +44,7 @@ func authHelper(c *gin.Context, minRole int) { return } } - if status.(int) == common.UserStatusDisabled || blacklist.IsUserBanned(id.(int)) { + if status.(int) == model.UserStatusDisabled || blacklist.IsUserBanned(id.(int)) { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "用户已被封禁", @@ -72,19 +71,19 @@ func authHelper(c *gin.Context, minRole int) { func UserAuth() func(c *gin.Context) { return func(c *gin.Context) { - authHelper(c, common.RoleCommonUser) + authHelper(c, model.RoleCommonUser) } } func AdminAuth() func(c *gin.Context) { return func(c *gin.Context) { - authHelper(c, common.RoleAdminUser) + authHelper(c, model.RoleAdminUser) } } func RootAuth() func(c *gin.Context) { return func(c *gin.Context) { - authHelper(c, common.RoleRootUser) + authHelper(c, model.RoleRootUser) } } diff --git a/middleware/distributor.go b/middleware/distributor.go index 29a1d5b3..6e0d2718 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -3,7 +3,7 @@ package middleware import ( "fmt" "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/relay/channeltype" @@ -34,7 +34,7 @@ func Distribute() func(c *gin.Context) { abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id") return } - if channel.Status != common.ChannelStatusEnabled { + if channel.Status != model.ChannelStatusEnabled { abortWithMessage(c, http.StatusForbidden, "该渠道已被禁用") return } @@ -68,18 +68,18 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode // this is for backward compatibility switch channel.Type { case channeltype.Azure: - c.Set(common.ConfigKeyAPIVersion, channel.Other) + c.Set(config.KeyAPIVersion, channel.Other) case channeltype.Xunfei: - c.Set(common.ConfigKeyAPIVersion, channel.Other) + c.Set(config.KeyAPIVersion, channel.Other) case channeltype.Gemini: - c.Set(common.ConfigKeyAPIVersion, channel.Other) + c.Set(config.KeyAPIVersion, channel.Other) case channeltype.AIProxyLibrary: - c.Set(common.ConfigKeyLibraryID, channel.Other) + c.Set(config.KeyLibraryID, channel.Other) case channeltype.Ali: - c.Set(common.ConfigKeyPlugin, channel.Other) + c.Set(config.KeyPlugin, channel.Other) } cfg, _ := channel.LoadConfig() for k, v := range cfg { - c.Set(common.ConfigKeyPrefix+k, v) + c.Set(config.KeyPrefix+k, v) } } diff --git a/model/ability.go b/model/ability.go index 4a48bc51..2db72518 100644 --- a/model/ability.go +++ b/model/ability.go @@ -57,7 +57,7 @@ func (channel *Channel) AddAbilities() error { Group: group, Model: model, ChannelId: channel.Id, - Enabled: channel.Status == common.ChannelStatusEnabled, + Enabled: channel.Status == ChannelStatusEnabled, Priority: channel.Priority, } abilities = append(abilities, ability) diff --git a/model/cache.go b/model/cache.go index b80680d3..cfb0f8a4 100644 --- a/model/cache.go +++ b/model/cache.go @@ -173,7 +173,7 @@ var channelSyncLock sync.RWMutex func InitChannelCache() { newChannelId2channel := make(map[int]*Channel) var channels []*Channel - DB.Where("status = ?", common.ChannelStatusEnabled).Find(&channels) + DB.Where("status = ?", ChannelStatusEnabled).Find(&channels) for _, channel := range channels { newChannelId2channel[channel.Id] = channel } diff --git a/model/channel.go b/model/channel.go index fc4905b1..e667f7e7 100644 --- a/model/channel.go +++ b/model/channel.go @@ -3,13 +3,19 @@ package model import ( "encoding/json" "fmt" - "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" "gorm.io/gorm" ) +const ( + ChannelStatusUnknown = 0 + ChannelStatusEnabled = 1 // don't use 0, 0 is the default value! + ChannelStatusManuallyDisabled = 2 // also don't use 0 + ChannelStatusAutoDisabled = 3 +) + type Channel struct { Id int `json:"id"` Type int `json:"type" gorm:"default:0"` @@ -39,7 +45,7 @@ func GetAllChannels(startIdx int, num int, scope string) ([]*Channel, error) { case "all": err = DB.Order("id desc").Find(&channels).Error case "disabled": - err = DB.Order("id desc").Where("status = ? or status = ?", common.ChannelStatusAutoDisabled, common.ChannelStatusManuallyDisabled).Find(&channels).Error + err = DB.Order("id desc").Where("status = ? or status = ?", ChannelStatusAutoDisabled, ChannelStatusManuallyDisabled).Find(&channels).Error default: err = DB.Order("id desc").Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error } @@ -168,7 +174,7 @@ func (channel *Channel) LoadConfig() (map[string]string, error) { } func UpdateChannelStatusById(id int, status int) { - err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled) + err := UpdateAbilityStatus(id, status == ChannelStatusEnabled) if err != nil { logger.SysError("failed to update ability status: " + err.Error()) } @@ -199,6 +205,6 @@ func DeleteChannelByStatus(status int64) (int64, error) { } func DeleteDisabledChannel() (int64, error) { - result := DB.Where("status = ? or status = ?", common.ChannelStatusAutoDisabled, common.ChannelStatusManuallyDisabled).Delete(&Channel{}) + result := DB.Where("status = ? or status = ?", ChannelStatusAutoDisabled, ChannelStatusManuallyDisabled).Delete(&Channel{}) return result.RowsAffected, result.Error } diff --git a/model/main.go b/model/main.go index 12679ab9..4b5323c4 100644 --- a/model/main.go +++ b/model/main.go @@ -32,8 +32,8 @@ func CreateRootAccountIfNeed() error { rootUser := User{ Username: "root", Password: hashedPassword, - Role: common.RoleRootUser, - Status: common.UserStatusEnabled, + Role: RoleRootUser, + Status: UserStatusEnabled, DisplayName: "Root User", AccessToken: random.GetUUID(), Quota: 500000000000000, @@ -45,7 +45,7 @@ func CreateRootAccountIfNeed() error { Id: 1, UserId: rootUser.Id, Key: config.InitialRootToken, - Status: common.TokenStatusEnabled, + Status: TokenStatusEnabled, Name: "Initial Root Token", CreatedTime: helper.GetTimestamp(), AccessedTime: helper.GetTimestamp(), diff --git a/model/redemption.go b/model/redemption.go index 79a5b8a9..45871a71 100644 --- a/model/redemption.go +++ b/model/redemption.go @@ -8,6 +8,12 @@ import ( "gorm.io/gorm" ) +const ( + RedemptionCodeStatusEnabled = 1 // don't use 0, 0 is the default value! + RedemptionCodeStatusDisabled = 2 // also don't use 0 + RedemptionCodeStatusUsed = 3 // also don't use 0 +) + type Redemption struct { Id int `json:"id"` UserId int `json:"user_id"` @@ -61,7 +67,7 @@ func Redeem(key string, userId int) (quota int64, err error) { if err != nil { return errors.New("无效的兑换码") } - if redemption.Status != common.RedemptionCodeStatusEnabled { + if redemption.Status != RedemptionCodeStatusEnabled { return errors.New("该兑换码已被使用") } err = tx.Model(&User{}).Where("id = ?", userId).Update("quota", gorm.Expr("quota + ?", redemption.Quota)).Error @@ -69,7 +75,7 @@ func Redeem(key string, userId int) (quota int64, err error) { return err } redemption.RedeemedTime = helper.GetTimestamp() - redemption.Status = common.RedemptionCodeStatusUsed + redemption.Status = RedemptionCodeStatusUsed err = tx.Save(redemption).Error return err }) diff --git a/model/token.go b/model/token.go index 20228ec5..96e6b491 100644 --- a/model/token.go +++ b/model/token.go @@ -11,6 +11,13 @@ import ( "gorm.io/gorm" ) +const ( + TokenStatusEnabled = 1 // don't use 0, 0 is the default value! + TokenStatusDisabled = 2 // also don't use 0 + TokenStatusExpired = 3 + TokenStatusExhausted = 4 +) + type Token struct { Id int `json:"id"` UserId int `json:"user_id"` @@ -62,17 +69,17 @@ func ValidateUserToken(key string) (token *Token, err error) { } return nil, errors.New("令牌验证失败") } - if token.Status == common.TokenStatusExhausted { + if token.Status == TokenStatusExhausted { return nil, fmt.Errorf("令牌 %s(#%d)额度已用尽", token.Name, token.Id) - } else if token.Status == common.TokenStatusExpired { + } else if token.Status == TokenStatusExpired { return nil, errors.New("该令牌已过期") } - if token.Status != common.TokenStatusEnabled { + if token.Status != TokenStatusEnabled { return nil, errors.New("该令牌状态不可用") } if token.ExpiredTime != -1 && token.ExpiredTime < helper.GetTimestamp() { if !common.RedisEnabled { - token.Status = common.TokenStatusExpired + token.Status = TokenStatusExpired err := token.SelectUpdate() if err != nil { logger.SysError("failed to update token status" + err.Error()) @@ -83,7 +90,7 @@ func ValidateUserToken(key string) (token *Token, err error) { if !token.UnlimitedQuota && token.RemainQuota <= 0 { if !common.RedisEnabled { // in this case, we can make sure the token is exhausted - token.Status = common.TokenStatusExhausted + token.Status = TokenStatusExhausted err := token.SelectUpdate() if err != nil { logger.SysError("failed to update token status" + err.Error()) diff --git a/model/user.go b/model/user.go index a00e98dc..1dc633b1 100644 --- a/model/user.go +++ b/model/user.go @@ -12,6 +12,19 @@ import ( "strings" ) +const ( + RoleGuestUser = 0 + RoleCommonUser = 1 + RoleAdminUser = 10 + RoleRootUser = 100 +) + +const ( + UserStatusEnabled = 1 // don't use 0, 0 is the default value! + UserStatusDisabled = 2 // also don't use 0 + UserStatusDeleted = 3 +) + // User if you add sensitive fields, don't forget to clean them in setupLogin function. // Otherwise, the sensitive information will be saved on local storage in plain text! type User struct { @@ -42,7 +55,7 @@ func GetMaxUserId() int { } func GetAllUsers(startIdx int, num int, order string) (users []*User, err error) { - query := DB.Limit(num).Offset(startIdx).Omit("password").Where("status != ?", common.UserStatusDeleted) + query := DB.Limit(num).Offset(startIdx).Omit("password").Where("status != ?", UserStatusDeleted) switch order { case "quota": @@ -138,9 +151,9 @@ func (user *User) Update(updatePassword bool) error { return err } } - if user.Status == common.UserStatusDisabled { + if user.Status == UserStatusDisabled { blacklist.BanUser(user.Id) - } else if user.Status == common.UserStatusEnabled { + } else if user.Status == UserStatusEnabled { blacklist.UnbanUser(user.Id) } err = DB.Model(user).Updates(user).Error @@ -153,7 +166,7 @@ func (user *User) Delete() error { } blacklist.BanUser(user.Id) user.Username = fmt.Sprintf("deleted_%s", random.GetUUID()) - user.Status = common.UserStatusDeleted + user.Status = UserStatusDeleted err := DB.Model(user).Updates(user).Error return err } @@ -177,7 +190,7 @@ func (user *User) ValidateAndFill() (err error) { } } okay := common.ValidatePasswordAndHash(password, user.Password) - if !okay || user.Status != common.UserStatusEnabled { + if !okay || user.Status != UserStatusEnabled { return errors.New("用户名或密码错误,或用户已被封禁") } return nil @@ -273,7 +286,7 @@ func IsAdmin(userId int) bool { logger.SysError("no such user " + err.Error()) return false } - return user.Role >= common.RoleAdminUser + return user.Role >= RoleAdminUser } func IsUserEnabled(userId int) (bool, error) { @@ -285,7 +298,7 @@ func IsUserEnabled(userId int) (bool, error) { if err != nil { return false, err } - return user.Status == common.UserStatusEnabled, nil + return user.Status == UserStatusEnabled, nil } func ValidateAccessToken(token string) (user *User) { @@ -358,7 +371,7 @@ func decreaseUserQuota(id int, quota int64) (err error) { } func GetRootUserEmail() (email string) { - DB.Model(&User{}).Where("role = ?", common.RoleRootUser).Select("email").Find(&email) + DB.Model(&User{}).Where("role = ?", RoleRootUser).Select("email").Find(&email) return email } diff --git a/monitor/channel.go b/monitor/channel.go index ad82d2f5..7e5dc58a 100644 --- a/monitor/channel.go +++ b/monitor/channel.go @@ -2,7 +2,6 @@ package monitor import ( "fmt" - "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/message" @@ -29,7 +28,7 @@ func notifyRootUser(subject string, content string) { // DisableChannel disable & notify func DisableChannel(channelId int, channelName string, reason string) { - model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled) + model.UpdateChannelStatusById(channelId, model.ChannelStatusAutoDisabled) logger.SysLog(fmt.Sprintf("channel #%d has been disabled: %s", channelId, reason)) subject := fmt.Sprintf("渠道「%s」(#%d)已被禁用", channelName, channelId) content := fmt.Sprintf("渠道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason) @@ -37,7 +36,7 @@ func DisableChannel(channelId int, channelName string, reason string) { } func MetricDisableChannel(channelId int, successRate float64) { - model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled) + model.UpdateChannelStatusById(channelId, model.ChannelStatusAutoDisabled) logger.SysLog(fmt.Sprintf("channel #%d has been disabled due to low success rate: %.2f", channelId, successRate*100)) subject := fmt.Sprintf("渠道 #%d 已被禁用", channelId) content := fmt.Sprintf("该渠道(#%d)在最近 %d 次调用中成功率为 %.2f%%,低于阈值 %.2f%%,因此被系统自动禁用。", @@ -47,7 +46,7 @@ func MetricDisableChannel(channelId int, successRate float64) { // EnableChannel enable & notify func EnableChannel(channelId int, channelName string) { - model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled) + model.UpdateChannelStatusById(channelId, model.ChannelStatusEnabled) logger.SysLog(fmt.Sprintf("channel #%d has been enabled", channelId)) subject := fmt.Sprintf("渠道「%s」(#%d)已被启用", channelName, channelId) content := fmt.Sprintf("渠道「%s」(#%d)已被启用", channelName, channelId) diff --git a/relay/adaptor/aiproxy/adaptor.go b/relay/adaptor/aiproxy/adaptor.go index 479efaed..7ad6225a 100644 --- a/relay/adaptor/aiproxy/adaptor.go +++ b/relay/adaptor/aiproxy/adaptor.go @@ -4,7 +4,7 @@ import ( "errors" "fmt" "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/relay/adaptor" "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" @@ -34,7 +34,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G return nil, errors.New("request is nil") } aiProxyLibraryRequest := ConvertRequest(*request) - aiProxyLibraryRequest.LibraryId = c.GetString(common.ConfigKeyLibraryID) + aiProxyLibraryRequest.LibraryId = c.GetString(config.KeyLibraryID) return aiProxyLibraryRequest, nil } diff --git a/relay/adaptor/ali/adaptor.go b/relay/adaptor/ali/adaptor.go index e5caed75..21b5e8b8 100644 --- a/relay/adaptor/ali/adaptor.go +++ b/relay/adaptor/ali/adaptor.go @@ -4,7 +4,7 @@ import ( "errors" "fmt" "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/relay/adaptor" "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" @@ -47,8 +47,8 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *me if meta.Mode == relaymode.ImagesGenerations { req.Header.Set("X-DashScope-Async", "enable") } - if c.GetString(common.ConfigKeyPlugin) != "" { - req.Header.Set("X-DashScope-Plugin", c.GetString(common.ConfigKeyPlugin)) + if c.GetString(config.KeyPlugin) != "" { + req.Header.Set("X-DashScope-Plugin", c.GetString(config.KeyPlugin)) } return nil } diff --git a/relay/adaptor/azure/helper.go b/relay/adaptor/azure/helper.go index 29004d27..dd207f37 100644 --- a/relay/adaptor/azure/helper.go +++ b/relay/adaptor/azure/helper.go @@ -2,14 +2,14 @@ package azure import ( "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/config" ) func GetAPIVersion(c *gin.Context) string { query := c.Request.URL.Query() apiVersion := query.Get("api-version") if apiVersion == "" { - apiVersion = c.GetString(common.ConfigKeyAPIVersion) + apiVersion = c.GetString(config.KeyAPIVersion) } return apiVersion } diff --git a/relay/adaptor/xunfei/main.go b/relay/adaptor/xunfei/main.go index 710ea8bb..369e6227 100644 --- a/relay/adaptor/xunfei/main.go +++ b/relay/adaptor/xunfei/main.go @@ -9,6 +9,7 @@ import ( "github.com/gin-gonic/gin" "github.com/gorilla/websocket" "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/random" @@ -279,7 +280,7 @@ func getAPIVersion(c *gin.Context, modelName string) string { return apiVersion } - apiVersion = c.GetString(common.ConfigKeyAPIVersion) + apiVersion = c.GetString(config.KeyAPIVersion) if apiVersion != "" { return apiVersion } diff --git a/relay/channeltype/url.go b/relay/channeltype/url.go new file mode 100644 index 00000000..eec59116 --- /dev/null +++ b/relay/channeltype/url.go @@ -0,0 +1,43 @@ +package channeltype + +var ChannelBaseURLs = []string{ + "", // 0 + "https://api.openai.com", // 1 + "https://oa.api2d.net", // 2 + "", // 3 + "https://api.closeai-proxy.xyz", // 4 + "https://api.openai-sb.com", // 5 + "https://api.openaimax.com", // 6 + "https://api.ohmygpt.com", // 7 + "", // 8 + "https://api.caipacity.com", // 9 + "https://api.aiproxy.io", // 10 + "https://generativelanguage.googleapis.com", // 11 + "https://api.api2gpt.com", // 12 + "https://api.aigc2d.com", // 13 + "https://api.anthropic.com", // 14 + "https://aip.baidubce.com", // 15 + "https://open.bigmodel.cn", // 16 + "https://dashscope.aliyuncs.com", // 17 + "", // 18 + "https://ai.360.cn", // 19 + "https://openrouter.ai/api", // 20 + "https://api.aiproxy.io", // 21 + "https://fastgpt.run/api/openapi", // 22 + "https://hunyuan.cloud.tencent.com", // 23 + "https://generativelanguage.googleapis.com", // 24 + "https://api.moonshot.cn", // 25 + "https://api.baichuan-ai.com", // 26 + "https://api.minimax.chat", // 27 + "https://api.mistral.ai", // 28 + "https://api.groq.com/openai", // 29 + "http://localhost:11434", // 30 + "https://api.lingyiwanwu.com", // 31 + "https://api.stepfun.com", // 32 +} + +func init() { + if len(ChannelBaseURLs) != Dummy { + panic("channel base urls length not match") + } +} diff --git a/relay/controller/audio.go b/relay/controller/audio.go index ce972f88..9d8cfef5 100644 --- a/relay/controller/audio.go +++ b/relay/controller/audio.go @@ -119,7 +119,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus } } - baseURL := common.ChannelBaseURLs[channelType] + baseURL := channeltype.ChannelBaseURLs[channelType] requestURL := c.Request.URL.String() if c.GetString("base_url") != "" { baseURL = c.GetString("base_url") diff --git a/relay/meta/relay_meta.go b/relay/meta/relay_meta.go index 6fb5592a..22ef1567 100644 --- a/relay/meta/relay_meta.go +++ b/relay/meta/relay_meta.go @@ -2,7 +2,7 @@ package meta import ( "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/relay/adaptor/azure" "github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/relaymode" @@ -41,7 +41,7 @@ func GetByContext(c *gin.Context) *Meta { Group: c.GetString("group"), ModelMapping: c.GetStringMapString("model_mapping"), BaseURL: c.GetString("base_url"), - APIVersion: c.GetString(common.ConfigKeyAPIVersion), + APIVersion: c.GetString(config.KeyAPIVersion), APIKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), Config: nil, RequestURLPath: c.Request.URL.String(), @@ -50,7 +50,7 @@ func GetByContext(c *gin.Context) *Meta { meta.APIVersion = azure.GetAPIVersion(c) } if meta.BaseURL == "" { - meta.BaseURL = common.ChannelBaseURLs[meta.ChannelType] + meta.BaseURL = channeltype.ChannelBaseURLs[meta.ChannelType] } meta.APIType = channeltype.ToAPIType(meta.ChannelType) return &meta diff --git a/router/api-router.go b/router/api.go similarity index 100% rename from router/api-router.go rename to router/api.go diff --git a/router/relay-router.go b/router/relay.go similarity index 100% rename from router/relay-router.go rename to router/relay.go diff --git a/router/web-router.go b/router/web.go similarity index 100% rename from router/web-router.go rename to router/web.go