diff --git a/common/constants.go b/common/constants.go
index a6575ca6..471f6ff2 100644
--- a/common/constants.go
+++ b/common/constants.go
@@ -55,6 +55,8 @@ var TurnstileSiteKey = ""
var TurnstileSecretKey = ""
var QuotaForNewUser = 0
+var QuotaForInviter = 0
+var QuotaForInvitee = 0
var ChannelDisableThreshold = 5.0
var AutomaticDisableChannelEnabled = false
var QuotaRemindThreshold = 1000
diff --git a/common/utils.go b/common/utils.go
index de15ce17..1329c1a0 100644
--- a/common/utils.go
+++ b/common/utils.go
@@ -157,6 +157,15 @@ func GenerateKey() string {
return string(key)
}
+func GetRandomString(length int) string {
+ rand.Seed(time.Now().UnixNano())
+ key := make([]byte, length)
+ for i := 0; i < length; i++ {
+ key[i] = keyChars[rand.Intn(len(keyChars))]
+ }
+ return string(key)
+}
+
func GetTimestamp() int64 {
return time.Now().Unix()
}
diff --git a/controller/github.go b/controller/github.go
index 93c2e8d3..e1c64130 100644
--- a/controller/github.go
+++ b/controller/github.go
@@ -125,7 +125,7 @@ func GitHubOAuth(c *gin.Context) {
user.Role = common.RoleCommonUser
user.Status = common.UserStatusEnabled
- if err := user.Insert(); err != nil {
+ if err := user.Insert(0); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
diff --git a/controller/user.go b/controller/user.go
index 09eaccd1..89e32096 100644
--- a/controller/user.go
+++ b/controller/user.go
@@ -150,15 +150,18 @@ func Register(c *gin.Context) {
return
}
}
+ affCode := user.AffCode // this code is the inviter's code, not the user's own code
+ inviterId, _ := model.GetUserIdByAffCode(affCode)
cleanUser := model.User{
Username: user.Username,
Password: user.Password,
DisplayName: user.Username,
+ InviterId: inviterId,
}
if common.EmailVerificationEnabled {
cleanUser.Email = user.Email
}
- if err := cleanUser.Insert(); err != nil {
+ if err := cleanUser.Insert(inviterId); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
@@ -280,6 +283,34 @@ func GenerateAccessToken(c *gin.Context) {
return
}
+func GetAffCode(c *gin.Context) {
+ id := c.GetInt("id")
+ user, err := model.GetUserById(id, true)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ if user.AffCode == "" {
+ user.AffCode = common.GetRandomString(4)
+ if err := user.Update(false); err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ }
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": user.AffCode,
+ })
+ return
+}
+
func GetSelf(c *gin.Context) {
id := c.GetInt("id")
user, err := model.GetUserById(id, false)
@@ -495,7 +526,7 @@ func CreateUser(c *gin.Context) {
Password: user.Password,
DisplayName: user.DisplayName,
}
- if err := cleanUser.Insert(); err != nil {
+ if err := cleanUser.Insert(0); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
diff --git a/controller/wechat.go b/controller/wechat.go
index 5620e8d3..ff4c9fb6 100644
--- a/controller/wechat.go
+++ b/controller/wechat.go
@@ -85,7 +85,7 @@ func WeChatAuth(c *gin.Context) {
user.Role = common.RoleCommonUser
user.Status = common.UserStatusEnabled
- if err := user.Insert(); err != nil {
+ if err := user.Insert(0); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
diff --git a/model/option.go b/model/option.go
index 101f694d..32d655ac 100644
--- a/model/option.go
+++ b/model/option.go
@@ -56,6 +56,8 @@ func InitOptionMap() {
common.OptionMap["TurnstileSiteKey"] = ""
common.OptionMap["TurnstileSecretKey"] = ""
common.OptionMap["QuotaForNewUser"] = strconv.Itoa(common.QuotaForNewUser)
+ common.OptionMap["QuotaForInviter"] = strconv.Itoa(common.QuotaForInviter)
+ common.OptionMap["QuotaForInvitee"] = strconv.Itoa(common.QuotaForInvitee)
common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold)
common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota)
common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
@@ -175,6 +177,10 @@ func updateOptionMap(key string, value string) (err error) {
common.TurnstileSecretKey = value
case "QuotaForNewUser":
common.QuotaForNewUser, _ = strconv.Atoi(value)
+ case "QuotaForInviter":
+ common.QuotaForInviter, _ = strconv.Atoi(value)
+ case "QuotaForInvitee":
+ common.QuotaForInvitee, _ = strconv.Atoi(value)
case "QuotaRemindThreshold":
common.QuotaRemindThreshold, _ = strconv.Atoi(value)
case "PreConsumedQuota":
diff --git a/model/user.go b/model/user.go
index a8fb7842..5205662e 100644
--- a/model/user.go
+++ b/model/user.go
@@ -26,6 +26,8 @@ type User struct {
UsedQuota int `json:"used_quota" gorm:"type:int;default:0;column:used_quota"` // used quota
RequestCount int `json:"request_count" gorm:"type:int;default:0;"` // request number
Group string `json:"group" gorm:"type:varchar(32);default:'default'"`
+ AffCode string `json:"aff_code" gorm:"type:varchar(32);column:aff_code;uniqueIndex"`
+ InviterId int `json:"inviter_id" gorm:"type:int;column:inviter_id;index"`
}
func GetMaxUserId() int {
@@ -58,6 +60,15 @@ func GetUserById(id int, selectAll bool) (*User, error) {
return &user, err
}
+func GetUserIdByAffCode(affCode string) (int, error) {
+ if affCode == "" {
+ return 0, errors.New("affCode 为空!")
+ }
+ var user User
+ err := DB.Select("id").First(&user, "aff_code = ?", affCode).Error
+ return user.Id, err
+}
+
func DeleteUserById(id int) (err error) {
if id == 0 {
return errors.New("id 为空!")
@@ -66,7 +77,7 @@ func DeleteUserById(id int) (err error) {
return user.Delete()
}
-func (user *User) Insert() error {
+func (user *User) Insert(inviterId int) error {
var err error
if user.Password != "" {
user.Password, err = common.Password2Hash(user.Password)
@@ -76,6 +87,7 @@ func (user *User) Insert() error {
}
user.Quota = common.QuotaForNewUser
user.AccessToken = common.GetUUID()
+ user.AffCode = common.GetRandomString(4)
result := DB.Create(user)
if result.Error != nil {
return result.Error
@@ -83,6 +95,16 @@ func (user *User) Insert() error {
if common.QuotaForNewUser > 0 {
RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %d 点额度", common.QuotaForNewUser))
}
+ if inviterId != 0 {
+ if common.QuotaForInvitee > 0 {
+ _ = IncreaseUserQuota(user.Id, common.QuotaForInvitee)
+ RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %d 点额度", common.QuotaForInvitee))
+ }
+ if common.QuotaForInviter > 0 {
+ _ = IncreaseUserQuota(inviterId, common.QuotaForInviter)
+ RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %d 点额度", common.QuotaForInviter))
+ }
+ }
return nil
}
diff --git a/router/api-router.go b/router/api-router.go
index 062ccac1..2e5cd7d4 100644
--- a/router/api-router.go
+++ b/router/api-router.go
@@ -37,6 +37,7 @@ func SetApiRouter(router *gin.Engine) {
selfRoute.PUT("/self", controller.UpdateSelf)
selfRoute.DELETE("/self", controller.DeleteSelf)
selfRoute.GET("/token", controller.GenerateAccessToken)
+ selfRoute.GET("/aff", controller.GetAffCode)
selfRoute.POST("/topup", controller.TopUp)
}
diff --git a/web/src/components/PersonalSetting.js b/web/src/components/PersonalSetting.js
index d3216811..6c6ea49f 100644
--- a/web/src/components/PersonalSetting.js
+++ b/web/src/components/PersonalSetting.js
@@ -1,7 +1,7 @@
import React, { useEffect, useState } from 'react';
import { Button, Divider, Form, Header, Image, Message, Modal } from 'semantic-ui-react';
import { Link } from 'react-router-dom';
-import { API, copy, showError, showInfo, showSuccess } from '../helpers';
+import { API, copy, showError, showInfo, showNotice, showSuccess } from '../helpers';
import Turnstile from 'react-turnstile';
const PersonalSetting = () => {
@@ -45,6 +45,18 @@ const PersonalSetting = () => {
}
};
+ const getAffLink = async () => {
+ const res = await API.get('/api/user/aff');
+ const { success, message, data } = res.data;
+ if (success) {
+ let link = `${window.location.origin}/register?aff=${data}`;
+ await copy(link);
+ showNotice(`邀请链接已复制到剪切板:${link}`);
+ } else {
+ showError(message);
+ }
+ };
+
const bindWeChat = async () => {
if (inputs.wechat_verification_code === '') return;
const res = await API.get(
@@ -110,6 +122,7 @@ const PersonalSetting = () => {
更新个人信息
+