From 4908a9eddc64e9491462993fe0cd6bd103fe246e Mon Sep 17 00:00:00 2001
From: ckt1031 <65409152+ckt1031@users.noreply.github.com>
Date: Mon, 24 Jul 2023 12:20:52 +0800
Subject: [PATCH] feat: add Discord OAuth
---
common/constants.go | 4 +
controller/discord.go | 244 ++++++++++++++++++++++++++
controller/misc.go | 5 +-
controller/option.go | 11 +-
model/option.go | 9 +
model/user.go | 16 +-
router/api-router.go | 1 +
web/src/App.js | 9 +
web/src/components/DiscordOAuth.js | 57 ++++++
web/src/components/LoginForm.js | 24 ++-
web/src/components/PersonalSetting.js | 11 ++
web/src/components/SystemSetting.js | 62 +++++++
web/src/pages/Home/index.js | 6 +
web/src/pages/User/EditUser.js | 13 +-
14 files changed, 461 insertions(+), 11 deletions(-)
create mode 100644 controller/discord.go
create mode 100644 web/src/components/DiscordOAuth.js
diff --git a/common/constants.go b/common/constants.go
index 81f98163..41107ff0 100644
--- a/common/constants.go
+++ b/common/constants.go
@@ -38,6 +38,7 @@ var PasswordLoginEnabled = true
var PasswordRegisterEnabled = true
var EmailVerificationEnabled = false
var GitHubOAuthEnabled = false
+var DiscordOAuthEnabled = false
var WeChatAuthEnabled = false
var TurnstileCheckEnabled = false
var RegisterEnabled = true
@@ -53,6 +54,9 @@ var SMTPToken = ""
var GitHubClientId = ""
var GitHubClientSecret = ""
+var DiscordClientId = ""
+var DiscordClientSecret = ""
+
var WeChatServerAddress = ""
var WeChatServerToken = ""
var WeChatAccountQRCodeImageURL = ""
diff --git a/controller/discord.go b/controller/discord.go
new file mode 100644
index 00000000..752b937f
--- /dev/null
+++ b/controller/discord.go
@@ -0,0 +1,244 @@
+package controller
+
+import (
+ "bytes"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "net/http"
+ "one-api/common"
+ "one-api/model"
+ "strconv"
+
+ "github.com/gin-contrib/sessions"
+ "github.com/gin-gonic/gin"
+)
+
+type DiscordOAuthResponse struct {
+ AccessToken string `json:"access_token"`
+ Scope string `json:"scope"`
+ TokenType string `json:"token_type"`
+}
+
+type DiscordUser struct {
+ Id string `json:"id"`
+ Verified bool `json:"verified"`
+ Username string `json:"username"`
+}
+
+func getDiscordUserInfoByCode(codeFromURLParamaters string, host string) (*DiscordUser, error) {
+ if codeFromURLParamaters == "" {
+ return nil, errors.New("无效参数")
+ }
+
+ RequestClient := &http.Client{}
+
+ accessTokenBody := bytes.NewBuffer([]byte(fmt.Sprintf(
+ "client_id=%s&client_secret=%s&grant_type=authorization_code&redirect_uri=%s/oauth/discord&code=%s&scope=identify",
+ common.DiscordClientId, common.DiscordClientSecret, host, codeFromURLParamaters,
+ )))
+
+ req, _ := http.NewRequest("POST",
+ "https://discordapp.com/api/oauth2/token",
+ accessTokenBody,
+ )
+
+ req.Header = http.Header{
+ "Content-Type": []string{"application/x-www-form-urlencoded"},
+ "Accept": []string{"application/json"},
+ }
+
+ resp, err := RequestClient.Do(req)
+
+ if resp.StatusCode != 200 || err != nil {
+ return nil, errors.New("访问令牌无效")
+ }
+
+ var discordOAuthResponse DiscordOAuthResponse
+
+ json.NewDecoder(resp.Body).Decode(&discordOAuthResponse)
+
+ accessToken := fmt.Sprintf("Bearer %s", discordOAuthResponse.AccessToken)
+
+ // Get User Info
+ req, _ = http.NewRequest("GET", "https://discord.com/api/users/@me", nil)
+
+ req.Header = http.Header{
+ "Content-Type": []string{"application/json"},
+ "Authorization": []string{accessToken},
+ }
+
+ defer resp.Body.Close()
+
+ resp, err = RequestClient.Do(req)
+
+ if resp.StatusCode != 200 || err != nil {
+ return nil, errors.New("Discord 用户信息无效")
+ }
+
+ var discordUser DiscordUser
+
+ json.NewDecoder(resp.Body).Decode(&discordUser)
+
+ if err != nil {
+ return nil, err
+ }
+
+ if discordUser.Id == "" {
+ return nil, errors.New("返回值无效,用户字段为空,请稍后再试!")
+ }
+
+ if discordUser.Verified == false {
+ return nil, errors.New("Discord 帐户未经验证!")
+ }
+
+ defer resp.Body.Close()
+
+ return &discordUser, nil
+}
+
+func DiscordOAuth(c *gin.Context) {
+ session := sessions.Default(c)
+ username := session.Get("username")
+ if username != nil {
+ DiscordBind(c)
+ return
+ }
+
+ if !common.DiscordOAuthEnabled {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "管理员未开启通过 Discord 登录以及注册",
+ })
+ return
+ }
+ code := c.Query("code")
+
+ // Get protocal whether http or https and host
+ host := c.Request.Host
+ if c.Request.TLS == nil {
+ host = "http://" + host
+ } else {
+ host = "https://" + host
+ }
+
+ discordUser, err := getDiscordUserInfoByCode(code, host)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ user := model.User{
+ DiscordId: discordUser.Id,
+ }
+ if model.IsDiscordIdAlreadyTaken(user.DiscordId) {
+ err := user.FillUserByDiscordId()
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ } else {
+ if common.RegisterEnabled {
+ user.Username = "discord_" + strconv.Itoa(model.GetMaxUserId()+1)
+ if discordUser.Username != "" {
+ user.DisplayName = discordUser.Username
+ } else {
+ user.DisplayName = "Discord User"
+ }
+ user.Role = common.RoleCommonUser
+ user.Status = common.UserStatusEnabled
+
+ if err := user.Insert(0); err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ } else {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "管理员关闭了新用户注册",
+ })
+ return
+ }
+ }
+
+ if user.Status != common.UserStatusEnabled {
+ c.JSON(http.StatusOK, gin.H{
+ "message": "用户已被封禁",
+ "success": false,
+ })
+ return
+ }
+ setupLogin(&user, c)
+}
+
+func DiscordBind(c *gin.Context) {
+ if !common.DiscordOAuthEnabled {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "管理员未开启通过 Discord 登录以及注册",
+ })
+ return
+ }
+ code := c.Query("code")
+
+ // Get protocal whether http or https and host
+ host := c.Request.Host
+ if c.Request.TLS == nil {
+ host = "http://" + host
+ } else {
+ host = "https://" + host
+ }
+
+ discordUser, err := getDiscordUserInfoByCode(code, host)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ user := model.User{
+ DiscordId: discordUser.Id,
+ }
+ if model.IsDiscordIdAlreadyTaken(user.DiscordId) {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "该 Discord 账户已被绑定",
+ })
+ return
+ }
+ session := sessions.Default(c)
+ id := session.Get("id")
+ // id := c.GetInt("id") // critical bug!
+ user.Id = id.(int)
+ err = user.FillUserById()
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ user.DiscordId = discordUser.Id
+ err = user.Update(false)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "bind",
+ })
+ return
+}
diff --git a/controller/misc.go b/controller/misc.go
index 958a3716..c6634bd1 100644
--- a/controller/misc.go
+++ b/controller/misc.go
@@ -3,10 +3,11 @@ package controller
import (
"encoding/json"
"fmt"
- "github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"one-api/model"
+
+ "github.com/gin-gonic/gin"
)
func GetStatus(c *gin.Context) {
@@ -19,6 +20,8 @@ func GetStatus(c *gin.Context) {
"email_verification": common.EmailVerificationEnabled,
"github_oauth": common.GitHubOAuthEnabled,
"github_client_id": common.GitHubClientId,
+ "discord_oauth": common.DiscordOAuthEnabled,
+ "discord_client_id": common.DiscordClientId,
"system_name": common.SystemName,
"logo": common.Logo,
"footer_html": common.Footer,
diff --git a/controller/option.go b/controller/option.go
index abf0d5be..a5ee8f03 100644
--- a/controller/option.go
+++ b/controller/option.go
@@ -2,11 +2,12 @@ package controller
import (
"encoding/json"
- "github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"one-api/model"
"strings"
+
+ "github.com/gin-gonic/gin"
)
func GetOptions(c *gin.Context) {
@@ -49,6 +50,14 @@ func UpdateOption(c *gin.Context) {
})
return
}
+ case "DiscordOAuthEnabled":
+ if option.Value == "true" && common.DiscordClientId == "" {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "无法启用 Discord OAuth,请先填入 Discord Client ID 以及 Discord Client Secret!",
+ })
+ return
+ }
case "WeChatAuthEnabled":
if option.Value == "true" && common.WeChatServerAddress == "" {
c.JSON(http.StatusOK, gin.H{
diff --git a/model/option.go b/model/option.go
index e7bc6806..df93b704 100644
--- a/model/option.go
+++ b/model/option.go
@@ -30,6 +30,7 @@ func InitOptionMap() {
common.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(common.PasswordRegisterEnabled)
common.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(common.EmailVerificationEnabled)
common.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(common.GitHubOAuthEnabled)
+ common.OptionMap["DiscordOAuthEnabled"] = strconv.FormatBool(common.DiscordOAuthEnabled)
common.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(common.WeChatAuthEnabled)
common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled)
common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled)
@@ -53,6 +54,8 @@ func InitOptionMap() {
common.OptionMap["ServerAddress"] = ""
common.OptionMap["GitHubClientId"] = ""
common.OptionMap["GitHubClientSecret"] = ""
+ common.OptionMap["DiscordClientId"] = ""
+ common.OptionMap["DiscordClientSecret"] = ""
common.OptionMap["WeChatServerAddress"] = ""
common.OptionMap["WeChatServerToken"] = ""
common.OptionMap["WeChatAccountQRCodeImageURL"] = ""
@@ -135,6 +138,8 @@ func updateOptionMap(key string, value string) (err error) {
common.EmailVerificationEnabled = boolValue
case "GitHubOAuthEnabled":
common.GitHubOAuthEnabled = boolValue
+ case "DiscordOAuthEnabled":
+ common.DiscordOAuthEnabled = boolValue
case "WeChatAuthEnabled":
common.WeChatAuthEnabled = boolValue
case "TurnstileCheckEnabled":
@@ -171,6 +176,10 @@ func updateOptionMap(key string, value string) (err error) {
common.GitHubClientId = value
case "GitHubClientSecret":
common.GitHubClientSecret = value
+ case "DiscordClientId":
+ common.DiscordClientId = value
+ case "DiscordClientSecret":
+ common.DiscordClientSecret = value
case "Footer":
common.Footer = value
case "SystemName":
diff --git a/model/user.go b/model/user.go
index 7c771840..234b2d99 100644
--- a/model/user.go
+++ b/model/user.go
@@ -3,9 +3,10 @@ package model
import (
"errors"
"fmt"
- "gorm.io/gorm"
"one-api/common"
"strings"
+
+ "gorm.io/gorm"
)
// User if you add sensitive fields, don't forget to clean them in setupLogin function.
@@ -19,6 +20,7 @@ type User struct {
Status int `json:"status" gorm:"type:int;default:1"` // enabled, disabled
Email string `json:"email" gorm:"index" validate:"max=50"`
GitHubId string `json:"github_id" gorm:"column:github_id;index"`
+ DiscordId string `json:"discord_id" gorm:"column:discord_id;index"`
WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management
@@ -169,6 +171,14 @@ func (user *User) FillUserByGitHubId() error {
return nil
}
+func (user *User) FillUserByDiscordId() error {
+ if user.DiscordId == "" {
+ return errors.New("Discord id 为空!")
+ }
+ DB.Where(User{DiscordId: user.DiscordId}).First(user)
+ return nil
+}
+
func (user *User) FillUserByWeChatId() error {
if user.WeChatId == "" {
return errors.New("WeChat id 为空!")
@@ -193,6 +203,10 @@ func IsWeChatIdAlreadyTaken(wechatId string) bool {
return DB.Where("wechat_id = ?", wechatId).Find(&User{}).RowsAffected == 1
}
+func IsDiscordIdAlreadyTaken(discordId string) bool {
+ return DB.Where("discord_id = ?", discordId).Find(&User{}).RowsAffected == 1
+}
+
func IsGitHubIdAlreadyTaken(githubId string) bool {
return DB.Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1
}
diff --git a/router/api-router.go b/router/api-router.go
index 383133fa..4ec68ef9 100644
--- a/router/api-router.go
+++ b/router/api-router.go
@@ -21,6 +21,7 @@ func SetApiRouter(router *gin.Engine) {
apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail)
apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword)
apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), controller.GitHubOAuth)
+ apiRouter.GET("/oauth/discord", middleware.CriticalRateLimit(), controller.DiscordOAuth)
apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), controller.WeChatAuth)
apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.WeChatBind)
apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.EmailBind)
diff --git a/web/src/App.js b/web/src/App.js
index c967ce2c..5ca8988d 100644
--- a/web/src/App.js
+++ b/web/src/App.js
@@ -12,6 +12,7 @@ import AddUser from './pages/User/AddUser';
import { API, getLogo, getSystemName, showError, showNotice } from './helpers';
import PasswordResetForm from './components/PasswordResetForm';
import GitHubOAuth from './components/GitHubOAuth';
+import DiscordOAuth from './components/DiscordOAuth';
import PasswordResetConfirm from './components/PasswordResetConfirm';
import { UserContext } from './context/User';
import { StatusContext } from './context/Status';
@@ -239,6 +240,14 @@ function App() {
}
/>
+
+ Discord 身份验证: + {statusState?.status?.discord_oauth === true + ? '已启用' + : '未启用'} +
微信身份验证: {statusState?.status?.wechat_login === true diff --git a/web/src/pages/User/EditUser.js b/web/src/pages/User/EditUser.js index b1c77945..032673e0 100644 --- a/web/src/pages/User/EditUser.js +++ b/web/src/pages/User/EditUser.js @@ -13,13 +13,14 @@ const EditUser = () => { display_name: '', password: '', github_id: '', + discord_id: '', wechat_id: '', email: '', quota: 0, group: 'default' }); const [groupOptions, setGroupOptions] = useState([]); - const { username, display_name, password, github_id, wechat_id, email, quota, group } = + const { username, display_name, password, github_id, wechat_id, email, quota, discord_id } = inputs; const handleInputChange = (e, { name, value }) => { setInputs((inputs) => ({ ...inputs, [name]: value })); @@ -166,6 +167,16 @@ const EditUser = () => { readOnly /> +