ai-gateway/model/user.go
2023-08-13 00:51:48 +08:00

311 lines
9.2 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package model
import (
"errors"
"fmt"
"gorm.io/gorm"
"one-api/common"
"strings"
)
// User is the GORM model for a user account; its fields are also exposed as JSON.
// If you add sensitive fields, don't forget to clean them in setupLogin function.
// Otherwise, the sensitive information will be saved on local storage in plain text!
type User struct {
Id int `json:"id"`
// Username has a unique, indexed column; the validate tag caps it at 12 characters.
Username string `json:"username" gorm:"unique;indexed" validate:"max=12"`
// Password is NOT NULL in the schema; the validate tag requires 8 to 20 characters.
Password string `json:"password" gorm:"not null;" validate:"min=8,max=20"`
DisplayName string `json:"display_name" gorm:"index" validate:"max=20"`
Role int `json:"role" gorm:"type:int;default:1"` // admin, common
Status int `json:"status" gorm:"type:int;default:1"` // enabled, disabled
// Email is indexed but not declared unique.
Email string `json:"email" gorm:"index" validate:"max=50"`
// GitHubId and WeChatId are indexed but not declared unique.
GitHubId string `json:"github_id" gorm:"column:github_id;index"`
WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management
Quota int `json:"quota" gorm:"type:int;default:0"`
UsedQuota int `json:"used_quota" gorm:"type:int;default:0;column:used_quota"` // used quota
RequestCount int `json:"request_count" gorm:"type:int;default:0;"` // request number
// Group defaults to 'default' at the column level.
Group string `json:"group" gorm:"type:varchar(32);default:'default'"`
// AffCode is the user's own invitation code (unique index); InviterId refers to the inviting user.
AffCode string `json:"aff_code" gorm:"type:varchar(32);column:aff_code;uniqueIndex"`
InviterId int `json:"inviter_id" gorm:"type:int;column:inviter_id;index"`
}
// GetMaxUserId returns the Id of the record loaded by DB.Last.
// The query result is not inspected, so the zero Id is returned when
// nothing was loaded.
func GetMaxUserId() int {
	latest := User{}
	DB.Last(&latest)
	return latest.Id
}
// GetAllUsers returns up to num users ordered by id descending, skipping the
// first startIdx rows. The password column is omitted from the result.
func GetAllUsers(startIdx int, num int) (users []*User, err error) {
	page := DB.Order("id desc").Limit(num).Offset(startIdx)
	result := page.Omit("password").Find(&users)
	return users, result.Error
}
// SearchUsers returns the users whose id equals keyword, or whose username,
// email or display name starts with keyword. The password column is omitted.
func SearchUsers(keyword string) (users []*User, err error) {
	const condition = "id = ? or username LIKE ? or email LIKE ? or display_name LIKE ?"
	prefix := keyword + "%"
	result := DB.Omit("password").Where(condition, keyword, prefix, prefix, prefix).Find(&users)
	return users, result.Error
}
// GetUserById loads the user with the given id.
// When selectAll is false the password column is omitted from the query.
// A zero id is rejected without querying. On a query error the returned
// pointer is still non-nil and must not be trusted.
func GetUserById(id int, selectAll bool) (*User, error) {
	if id == 0 {
		return nil, errors.New("id 为空!")
	}
	found := &User{Id: id}
	if selectAll {
		return found, DB.First(found, "id = ?", id).Error
	}
	return found, DB.Omit("password").First(found, "id = ?", id).Error
}
// GetUserIdByAffCode returns the id of the user owning the given affiliate
// code. An empty code is rejected without querying.
func GetUserIdByAffCode(affCode string) (int, error) {
	if len(affCode) == 0 {
		return 0, errors.New("affCode 为空!")
	}
	owner := User{}
	result := DB.Select("id").First(&owner, "aff_code = ?", affCode)
	return owner.Id, result.Error
}
// DeleteUserById deletes the user with the given id by delegating to
// (*User).Delete. A zero id is rejected.
func DeleteUserById(id int) (err error) {
	if id == 0 {
		err = errors.New("id 为空!")
		return err
	}
	target := &User{Id: id}
	err = target.Delete()
	return err
}
// Insert creates the user row.
//
// Before saving it hashes a non-empty password, sets the quota to
// common.QuotaForNewUser and assigns a fresh access token and affiliate code.
// When inviterId is non-zero, the configured invitee and inviter bonuses are
// granted afterwards.
//
// Bonus grants are best-effort: the user already exists at that point, so a
// failed grant is reported through common.SysError and does not fail Insert.
// The bonus log entry is written only when the quota was actually increased.
func (user *User) Insert(inviterId int) error {
	if user.Password != "" {
		hashed, err := common.Password2Hash(user.Password)
		if err != nil {
			// Leave user.Password untouched so the caller's value is not clobbered.
			return err
		}
		user.Password = hashed
	}
	user.Quota = common.QuotaForNewUser
	user.AccessToken = common.GetUUID()
	user.AffCode = common.GetRandomString(4)
	if err := DB.Create(user).Error; err != nil {
		return err
	}
	if common.QuotaForNewUser > 0 {
		RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", common.LogQuota(common.QuotaForNewUser)))
	}
	if inviterId == 0 {
		return nil
	}
	if common.QuotaForInvitee > 0 {
		if err := IncreaseUserQuota(user.Id, common.QuotaForInvitee); err != nil {
			common.SysError("failed to grant invitee quota: " + err.Error())
		} else {
			RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(common.QuotaForInvitee)))
		}
	}
	if common.QuotaForInviter > 0 {
		if err := IncreaseUserQuota(inviterId, common.QuotaForInviter); err != nil {
			common.SysError("failed to grant inviter quota: " + err.Error())
		} else {
			RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(common.QuotaForInviter)))
		}
	}
	return nil
}
// Update saves the user's fields via DB.Model(user).Updates(user).
// When updatePassword is true, user.Password is replaced by the result of
// common.Password2Hash first, and a hashing error aborts the update.
func (user *User) Update(updatePassword bool) error {
	if updatePassword {
		var hashErr error
		user.Password, hashErr = common.Password2Hash(user.Password)
		if hashErr != nil {
			return hashErr
		}
	}
	result := DB.Model(user).Updates(user)
	return result.Error
}
// Delete removes this user via DB.Delete. A user whose Id is zero is
// rejected without touching the database.
func (user *User) Delete() error {
	if user.Id != 0 {
		result := DB.Delete(user)
		return result.Error
	}
	return errors.New("id 为空!")
}
// ValidateAndFill checks the supplied username/password pair and, on success,
// leaves user populated with the stored record.
//
// It returns an error when the username or password is empty, when the lookup
// fails, when the password does not match the stored hash, or when the account
// is not enabled. All failures after the empty check share one generic message
// so the response does not reveal which condition failed.
func (user *User) ValidateAndFill() (err error) {
	// Keep the plaintext: First overwrites user.Password with the stored hash.
	password := user.Password
	if user.Username == "" || password == "" {
		return errors.New("用户名或密码为空")
	}
	// When querying with struct, GORM will only query with non-zero fields,
	// that means if your fields value is 0, '', false or other zero values,
	// it won't be used to build query conditions.
	if lookupErr := DB.Where(User{Username: user.Username}).First(user).Error; lookupErr != nil {
		// Reject explicitly instead of validating against caller-supplied data.
		return errors.New("用户名或密码错误,或用户已被封禁")
	}
	if !common.ValidatePasswordAndHash(password, user.Password) || user.Status != common.UserStatusEnabled {
		return errors.New("用户名或密码错误,或用户已被封禁")
	}
	return nil
}
// FillUserById loads the record whose id equals user.Id into user.
// It returns an error when user.Id is zero or when the lookup fails
// (including when no matching record exists), rather than silently
// leaving user unfilled.
func (user *User) FillUserById() error {
	if user.Id == 0 {
		return errors.New("id 为空!")
	}
	return DB.Where(User{Id: user.Id}).First(user).Error
}
// FillUserByEmail loads the first record whose email equals user.Email into
// user. It returns an error when user.Email is empty or when the lookup fails
// (including when no matching record exists), rather than silently leaving
// user unfilled.
func (user *User) FillUserByEmail() error {
	if user.Email == "" {
		return errors.New("email 为空!")
	}
	return DB.Where(User{Email: user.Email}).First(user).Error
}
// FillUserByGitHubId loads the first record whose GitHub id equals
// user.GitHubId into user. It returns an error when user.GitHubId is empty or
// when the lookup fails (including when no matching record exists), rather
// than silently leaving user unfilled.
func (user *User) FillUserByGitHubId() error {
	if user.GitHubId == "" {
		return errors.New("GitHub id 为空!")
	}
	return DB.Where(User{GitHubId: user.GitHubId}).First(user).Error
}
// FillUserByWeChatId loads the first record whose WeChat id equals
// user.WeChatId into user. It returns an error when user.WeChatId is empty or
// when the lookup fails (including when no matching record exists), rather
// than silently leaving user unfilled.
func (user *User) FillUserByWeChatId() error {
	if user.WeChatId == "" {
		return errors.New("WeChat id 为空!")
	}
	return DB.Where(User{WeChatId: user.WeChatId}).First(user).Error
}
// FillUserByUsername loads the record whose username equals user.Username
// into user. It returns an error when user.Username is empty or when the
// lookup fails (including when no matching record exists), rather than
// silently leaving user unfilled.
func (user *User) FillUserByUsername() error {
	if user.Username == "" {
		return errors.New("username 为空!")
	}
	return DB.Where(User{Username: user.Username}).First(user).Error
}
// IsEmailAlreadyTaken reports whether at least one user has the given email.
// The email column is indexed but not unique, so several rows may match; the
// query is limited to one row and any hit counts as taken.
func IsEmailAlreadyTaken(email string) bool {
	return DB.Where("email = ?", email).Limit(1).Find(&User{}).RowsAffected > 0
}
// IsWeChatIdAlreadyTaken reports whether at least one user has the given
// WeChat id. The wechat_id column is indexed but not unique, so several rows
// may match; the query is limited to one row and any hit counts as taken.
func IsWeChatIdAlreadyTaken(wechatId string) bool {
	return DB.Where("wechat_id = ?", wechatId).Limit(1).Find(&User{}).RowsAffected > 0
}
// IsGitHubIdAlreadyTaken reports whether at least one user has the given
// GitHub id. The github_id column is indexed but not unique, so several rows
// may match; the query is limited to one row and any hit counts as taken.
func IsGitHubIdAlreadyTaken(githubId string) bool {
	return DB.Where("github_id = ?", githubId).Limit(1).Find(&User{}).RowsAffected > 0
}
func IsUsernameAlreadyTaken(username string) bool {
return DB.Where("username = ?", username).Find(&User{}).RowsAffected == 1
}
// ResetUserPasswordByEmail hashes password with common.Password2Hash and
// stores it on every user row whose email matches. Empty arguments are
// rejected, and a hashing error is returned before any update is issued.
func ResetUserPasswordByEmail(email string, password string) error {
	if len(email) == 0 || len(password) == 0 {
		return errors.New("邮箱地址或密码为空!")
	}
	hashedPassword, hashErr := common.Password2Hash(password)
	if hashErr != nil {
		return hashErr
	}
	return DB.Model(&User{}).Where("email = ?", email).Update("password", hashedPassword).Error
}
// IsAdmin reports whether the user's role is at least common.RoleAdminUser.
// A zero id yields false without querying; a query error is logged through
// common.SysError and also yields false.
func IsAdmin(userId int) bool {
	if userId == 0 {
		return false
	}
	account := User{}
	if err := DB.Where("id = ?", userId).Select("role").Find(&account).Error; err != nil {
		common.SysError("no such user " + err.Error())
		return false
	}
	return account.Role >= common.RoleAdminUser
}
// IsUserEnabled reports whether the user's status equals
// common.UserStatusEnabled. A zero id yields false without querying; a query
// error is logged through common.SysError and also yields false.
func IsUserEnabled(userId int) bool {
	if userId == 0 {
		return false
	}
	account := User{}
	if err := DB.Where("id = ?", userId).Select("status").Find(&account).Error; err != nil {
		common.SysError("no such user " + err.Error())
		return false
	}
	return account.Status == common.UserStatusEnabled
}
// ValidateAccessToken returns the user owning the given access token, or nil
// when the token is empty or does not match exactly one user.
//
// An optional leading "Bearer " is stripped. Only a prefix is removed, so the
// substring is never cut out of the middle of a token, and a header that
// contains nothing but "Bearer " is rejected without a lookup on the empty
// string.
func ValidateAccessToken(token string) (user *User) {
	token = strings.TrimPrefix(token, "Bearer ")
	if token == "" {
		return nil
	}
	user = &User{}
	if DB.Where("access_token = ?", token).First(user).RowsAffected == 1 {
		return user
	}
	return nil
}
// GetUserQuota reads the quota column of the user with the given id.
func GetUserQuota(id int) (quota int, err error) {
	result := DB.Model(&User{}).Where("id = ?", id).Select("quota").Find(&quota)
	return quota, result.Error
}
// GetUserUsedQuota reads the used_quota column of the user with the given id.
func GetUserUsedQuota(id int) (quota int, err error) {
	result := DB.Model(&User{}).Where("id = ?", id).Select("used_quota").Find(&quota)
	return quota, result.Error
}
// GetUserEmail reads the email column of the user with the given id.
func GetUserEmail(id int) (email string, err error) {
	result := DB.Model(&User{}).Where("id = ?", id).Select("email").Find(&email)
	return email, result.Error
}
// GetUserGroup reads the group column of the user with the given id.
// The column name is backtick-quoted in the SELECT clause.
func GetUserGroup(id int) (group string, err error) {
	result := DB.Model(&User{}).Where("id = ?", id).Select("`group`").Find(&group)
	return group, result.Error
}
func IncreaseUserQuota(id int, quota int) (err error) {
if quota < 0 {
return errors.New("quota 不能为负数!")
}
err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota + ?", quota)).Error
return err
}
func DecreaseUserQuota(id int, quota int) (err error) {
if quota < 0 {
return errors.New("quota 不能为负数!")
}
err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error
return err
}
// GetRootUserEmail reads the email column for users whose role equals
// common.RoleRootUser. The query error is not inspected, so an empty string
// is returned when nothing was read.
func GetRootUserEmail() (email string) {
	rootUsers := DB.Model(&User{}).Where("role = ?", common.RoleRootUser)
	rootUsers.Select("email").Find(&email)
	return email
}
func UpdateUserUsedQuotaAndRequestCount(id int, quota int) {
err := DB.Model(&User{}).Where("id = ?", id).Updates(
map[string]interface{}{
"used_quota": gorm.Expr("used_quota + ?", quota),
"request_count": gorm.Expr("request_count + ?", 1),
},
).Error
if err != nil {
common.SysError("failed to update user used quota and request count: " + err.Error())
}
}
// GetUsernameById reads the username column of the user with the given id.
// The query error is not inspected, so an empty string is returned when
// nothing was read.
func GetUsernameById(id int) (username string) {
	byId := DB.Model(&User{}).Where("id = ?", id)
	byId.Select("username").Find(&username)
	return username
}