diff --git a/common/blacklist/main.go b/common/blacklist/main.go new file mode 100644 index 00000000..f84ce6ae --- /dev/null +++ b/common/blacklist/main.go @@ -0,0 +1,29 @@ +package blacklist + +import ( + "fmt" + "sync" +) + +var blackList sync.Map + +func init() { + blackList = sync.Map{} +} + +func userId2Key(id int) string { + return fmt.Sprintf("userid_%d", id) +} + +func BanUser(id int) { + blackList.Store(userId2Key(id), true) +} + +func UnbanUser(id int) { + blackList.Delete(userId2Key(id)) +} + +func IsUserBanned(id int) bool { + _, ok := blackList.Load(userId2Key(id)) + return ok +} diff --git a/common/constants.go b/common/constants.go index 5bde1612..de71bc7a 100644 --- a/common/constants.go +++ b/common/constants.go @@ -15,6 +15,7 @@ const ( const ( UserStatusEnabled = 1 // don't use 0, 0 is the default value! UserStatusDisabled = 2 // also don't use 0 + UserStatusDeleted = 3 ) const ( diff --git a/middleware/auth.go b/middleware/auth.go index 9d25f395..30997efd 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -4,6 +4,7 @@ import ( "github.com/gin-contrib/sessions" "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/blacklist" "github.com/songquanpeng/one-api/model" "net/http" "strings" @@ -42,11 +43,14 @@ func authHelper(c *gin.Context, minRole int) { return } } - if status.(int) == common.UserStatusDisabled { + if status.(int) == common.UserStatusDisabled || blacklist.IsUserBanned(id.(int)) { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "用户已被封禁", }) + session := sessions.Default(c) + session.Clear() + _ = session.Save() c.Abort() return } @@ -99,7 +103,7 @@ func TokenAuth() func(c *gin.Context) { abortWithMessage(c, http.StatusInternalServerError, err.Error()) return } - if !userEnabled { + if !userEnabled || blacklist.IsUserBanned(token.UserId) { abortWithMessage(c, http.StatusForbidden, "用户已被封禁") return } diff --git a/model/user.go b/model/user.go index 6979c70b..dcbd6ff1 100644 --- a/model/user.go +++ b/model/user.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/blacklist" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" @@ -40,7 +41,7 @@ func GetMaxUserId() int { } func GetAllUsers(startIdx int, num int) (users []*User, err error) { - err = DB.Order("id desc").Limit(num).Offset(startIdx).Omit("password").Find(&users).Error + err = DB.Order("id desc").Limit(num).Offset(startIdx).Omit("password").Where("status != ?", common.UserStatusDeleted).Find(&users).Error return users, err } @@ -123,6 +124,11 @@ func (user *User) Update(updatePassword bool) error { return err } } + if user.Status == common.UserStatusDisabled { + blacklist.BanUser(user.Id) + } else if user.Status == common.UserStatusEnabled { + blacklist.UnbanUser(user.Id) + } err = DB.Model(user).Updates(user).Error return err } @@ -131,7 +137,10 @@ func (user *User) Delete() error { if user.Id == 0 { return errors.New("id 为空!") } - err := DB.Delete(user).Error + blacklist.BanUser(user.Id) + user.Username = fmt.Sprintf("deleted_%s", helper.GetUUID()) + user.Status = common.UserStatusDeleted + err := DB.Model(user).Updates(user).Error return err }