From 4908a9eddc64e9491462993fe0cd6bd103fe246e Mon Sep 17 00:00:00 2001 From: ckt1031 <65409152+ckt1031@users.noreply.github.com> Date: Mon, 24 Jul 2023 12:20:52 +0800 Subject: [PATCH] feat: add Discord OAuth --- common/constants.go | 4 + controller/discord.go | 244 ++++++++++++++++++++++++++ controller/misc.go | 5 +- controller/option.go | 11 +- model/option.go | 9 + model/user.go | 16 +- router/api-router.go | 1 + web/src/App.js | 9 + web/src/components/DiscordOAuth.js | 57 ++++++ web/src/components/LoginForm.js | 24 ++- web/src/components/PersonalSetting.js | 11 ++ web/src/components/SystemSetting.js | 62 +++++++ web/src/pages/Home/index.js | 6 + web/src/pages/User/EditUser.js | 13 +- 14 files changed, 461 insertions(+), 11 deletions(-) create mode 100644 controller/discord.go create mode 100644 web/src/components/DiscordOAuth.js diff --git a/common/constants.go b/common/constants.go index 81f98163..41107ff0 100644 --- a/common/constants.go +++ b/common/constants.go @@ -38,6 +38,7 @@ var PasswordLoginEnabled = true var PasswordRegisterEnabled = true var EmailVerificationEnabled = false var GitHubOAuthEnabled = false +var DiscordOAuthEnabled = false var WeChatAuthEnabled = false var TurnstileCheckEnabled = false var RegisterEnabled = true @@ -53,6 +54,9 @@ var SMTPToken = "" var GitHubClientId = "" var GitHubClientSecret = "" +var DiscordClientId = "" +var DiscordClientSecret = "" + var WeChatServerAddress = "" var WeChatServerToken = "" var WeChatAccountQRCodeImageURL = "" diff --git a/controller/discord.go b/controller/discord.go new file mode 100644 index 00000000..752b937f --- /dev/null +++ b/controller/discord.go @@ -0,0 +1,244 @@ +package controller + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "net/http" + "one-api/common" + "one-api/model" + "strconv" + + "github.com/gin-contrib/sessions" + "github.com/gin-gonic/gin" +) + +type DiscordOAuthResponse struct { + AccessToken string `json:"access_token"` + Scope string `json:"scope"` + TokenType string `json:"token_type"` +} + +type DiscordUser struct { + Id string `json:"id"` + Verified bool `json:"verified"` + Username string `json:"username"` +} + +func getDiscordUserInfoByCode(codeFromURLParamaters string, host string) (*DiscordUser, error) { + if codeFromURLParamaters == "" { + return nil, errors.New("无效参数") + } + + RequestClient := &http.Client{} + + accessTokenBody := bytes.NewBuffer([]byte(fmt.Sprintf( + "client_id=%s&client_secret=%s&grant_type=authorization_code&redirect_uri=%s/oauth/discord&code=%s&scope=identify", + common.DiscordClientId, common.DiscordClientSecret, host, codeFromURLParamaters, + ))) + + req, _ := http.NewRequest("POST", + "https://discordapp.com/api/oauth2/token", + accessTokenBody, + ) + + req.Header = http.Header{ + "Content-Type": []string{"application/x-www-form-urlencoded"}, + "Accept": []string{"application/json"}, + } + + resp, err := RequestClient.Do(req) + + if resp.StatusCode != 200 || err != nil { + return nil, errors.New("访问令牌无效") + } + + var discordOAuthResponse DiscordOAuthResponse + + json.NewDecoder(resp.Body).Decode(&discordOAuthResponse) + + accessToken := fmt.Sprintf("Bearer %s", discordOAuthResponse.AccessToken) + + // Get User Info + req, _ = http.NewRequest("GET", "https://discord.com/api/users/@me", nil) + + req.Header = http.Header{ + "Content-Type": []string{"application/json"}, + "Authorization": []string{accessToken}, + } + + defer resp.Body.Close() + + resp, err = RequestClient.Do(req) + + if resp.StatusCode != 200 || err != nil { + return nil, errors.New("Discord 用户信息无效") + } + + var discordUser DiscordUser + + json.NewDecoder(resp.Body).Decode(&discordUser) + + if err != nil { + return nil, err + } + + if discordUser.Id == "" { + return nil, errors.New("返回值无效,用户字段为空,请稍后再试!") + } + + if discordUser.Verified == false { + return nil, errors.New("Discord 帐户未经验证!") + } + + defer resp.Body.Close() + + return &discordUser, nil +} + +func DiscordOAuth(c *gin.Context) { + session := sessions.Default(c) + username := session.Get("username") + if username != nil { + DiscordBind(c) + return + } + + if !common.DiscordOAuthEnabled { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "管理员未开启通过 Discord 登录以及注册", + }) + return + } + code := c.Query("code") + + // Get protocal whether http or https and host + host := c.Request.Host + if c.Request.TLS == nil { + host = "http://" + host + } else { + host = "https://" + host + } + + discordUser, err := getDiscordUserInfoByCode(code, host) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + user := model.User{ + DiscordId: discordUser.Id, + } + if model.IsDiscordIdAlreadyTaken(user.DiscordId) { + err := user.FillUserByDiscordId() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + } else { + if common.RegisterEnabled { + user.Username = "discord_" + strconv.Itoa(model.GetMaxUserId()+1) + if discordUser.Username != "" { + user.DisplayName = discordUser.Username + } else { + user.DisplayName = "Discord User" + } + user.Role = common.RoleCommonUser + user.Status = common.UserStatusEnabled + + if err := user.Insert(0); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + } else { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "管理员关闭了新用户注册", + }) + return + } + } + + if user.Status != common.UserStatusEnabled { + c.JSON(http.StatusOK, gin.H{ + "message": "用户已被封禁", + "success": false, + }) + return + } + setupLogin(&user, c) +} + +func DiscordBind(c *gin.Context) { + if !common.DiscordOAuthEnabled { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "管理员未开启通过 Discord 登录以及注册", + }) + return + } + code := c.Query("code") + + // Get protocal whether http or https and host + host := c.Request.Host + if c.Request.TLS == nil { + host = "http://" + host + } else { + host = "https://" + host + } + + discordUser, err := getDiscordUserInfoByCode(code, host) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + user := model.User{ + DiscordId: discordUser.Id, + } + if model.IsDiscordIdAlreadyTaken(user.DiscordId) { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "该 Discord 账户已被绑定", + }) + return + } + session := sessions.Default(c) + id := session.Get("id") + // id := c.GetInt("id") // critical bug! + user.Id = id.(int) + err = user.FillUserById() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + user.DiscordId = discordUser.Id + err = user.Update(false) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "bind", + }) + return +} diff --git a/controller/misc.go b/controller/misc.go index 958a3716..c6634bd1 100644 --- a/controller/misc.go +++ b/controller/misc.go @@ -3,10 +3,11 @@ package controller import ( "encoding/json" "fmt" - "github.com/gin-gonic/gin" "net/http" "one-api/common" "one-api/model" + + "github.com/gin-gonic/gin" ) func GetStatus(c *gin.Context) { @@ -19,6 +20,8 @@ func GetStatus(c *gin.Context) { "email_verification": common.EmailVerificationEnabled, "github_oauth": common.GitHubOAuthEnabled, "github_client_id": common.GitHubClientId, + "discord_oauth": common.DiscordOAuthEnabled, + "discord_client_id": common.DiscordClientId, "system_name": common.SystemName, "logo": common.Logo, "footer_html": common.Footer, diff --git a/controller/option.go b/controller/option.go index abf0d5be..a5ee8f03 100644 --- a/controller/option.go +++ b/controller/option.go @@ -2,11 +2,12 @@ package controller import ( "encoding/json" - "github.com/gin-gonic/gin" "net/http" "one-api/common" "one-api/model" "strings" + + "github.com/gin-gonic/gin" ) func GetOptions(c *gin.Context) { @@ -49,6 +50,14 @@ func UpdateOption(c *gin.Context) { }) return } + case "DiscordOAuthEnabled": + if option.Value == "true" && common.DiscordClientId == "" { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "无法启用 Discord OAuth,请先填入 Discord Client ID 以及 Discord Client Secret!", + }) + return + } case "WeChatAuthEnabled": if option.Value == "true" && common.WeChatServerAddress == "" { c.JSON(http.StatusOK, gin.H{ diff --git a/model/option.go b/model/option.go index e7bc6806..df93b704 100644 --- a/model/option.go +++ b/model/option.go @@ -30,6 +30,7 @@ func InitOptionMap() { common.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(common.PasswordRegisterEnabled) common.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(common.EmailVerificationEnabled) common.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(common.GitHubOAuthEnabled) + common.OptionMap["DiscordOAuthEnabled"] = strconv.FormatBool(common.DiscordOAuthEnabled) common.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(common.WeChatAuthEnabled) common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled) common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled) @@ -53,6 +54,8 @@ func InitOptionMap() { common.OptionMap["ServerAddress"] = "" common.OptionMap["GitHubClientId"] = "" common.OptionMap["GitHubClientSecret"] = "" + common.OptionMap["DiscordClientId"] = "" + common.OptionMap["DiscordClientSecret"] = "" common.OptionMap["WeChatServerAddress"] = "" common.OptionMap["WeChatServerToken"] = "" common.OptionMap["WeChatAccountQRCodeImageURL"] = "" @@ -135,6 +138,8 @@ func updateOptionMap(key string, value string) (err error) { common.EmailVerificationEnabled = boolValue case "GitHubOAuthEnabled": common.GitHubOAuthEnabled = boolValue + case "DiscordOAuthEnabled": + common.DiscordOAuthEnabled = boolValue case "WeChatAuthEnabled": common.WeChatAuthEnabled = boolValue case "TurnstileCheckEnabled": @@ -171,6 +176,10 @@ func updateOptionMap(key string, value string) (err error) { common.GitHubClientId = value case "GitHubClientSecret": common.GitHubClientSecret = value + case "DiscordClientId": + common.DiscordClientId = value + case "DiscordClientSecret": + common.DiscordClientSecret = value case "Footer": common.Footer = value case "SystemName": diff --git a/model/user.go b/model/user.go index 7c771840..234b2d99 100644 --- a/model/user.go +++ b/model/user.go @@ -3,9 +3,10 @@ package model import ( "errors" "fmt" - "gorm.io/gorm" "one-api/common" "strings" + + "gorm.io/gorm" ) // User if you add sensitive fields, don't forget to clean them in setupLogin function. @@ -19,6 +20,7 @@ type User struct { Status int `json:"status" gorm:"type:int;default:1"` // enabled, disabled Email string `json:"email" gorm:"index" validate:"max=50"` GitHubId string `json:"github_id" gorm:"column:github_id;index"` + DiscordId string `json:"discord_id" gorm:"column:discord_id;index"` WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"` VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database! AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management @@ -169,6 +171,14 @@ func (user *User) FillUserByGitHubId() error { return nil } +func (user *User) FillUserByDiscordId() error { + if user.DiscordId == "" { + return errors.New("Discord id 为空!") + } + DB.Where(User{DiscordId: user.DiscordId}).First(user) + return nil +} + func (user *User) FillUserByWeChatId() error { if user.WeChatId == "" { return errors.New("WeChat id 为空!") @@ -193,6 +203,10 @@ func IsWeChatIdAlreadyTaken(wechatId string) bool { return DB.Where("wechat_id = ?", wechatId).Find(&User{}).RowsAffected == 1 } +func IsDiscordIdAlreadyTaken(discordId string) bool { + return DB.Where("discord_id = ?", discordId).Find(&User{}).RowsAffected == 1 +} + func IsGitHubIdAlreadyTaken(githubId string) bool { return DB.Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1 } diff --git a/router/api-router.go b/router/api-router.go index 383133fa..4ec68ef9 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -21,6 +21,7 @@ func SetApiRouter(router *gin.Engine) { apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail) apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword) apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), controller.GitHubOAuth) + apiRouter.GET("/oauth/discord", middleware.CriticalRateLimit(), controller.DiscordOAuth) apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), controller.WeChatAuth) apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.WeChatBind) apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.EmailBind) diff --git a/web/src/App.js b/web/src/App.js index c967ce2c..5ca8988d 100644 --- a/web/src/App.js +++ b/web/src/App.js @@ -12,6 +12,7 @@ import AddUser from './pages/User/AddUser'; import { API, getLogo, getSystemName, showError, showNotice } from './helpers'; import PasswordResetForm from './components/PasswordResetForm'; import GitHubOAuth from './components/GitHubOAuth'; +import DiscordOAuth from './components/DiscordOAuth'; import PasswordResetConfirm from './components/PasswordResetConfirm'; import { UserContext } from './context/User'; import { StatusContext } from './context/Status'; @@ -239,6 +240,14 @@ function App() { } /> + }> + + + } + /> { + const [searchParams, setSearchParams] = useSearchParams(); + + const [userState, userDispatch] = useContext(UserContext); + const [prompt, setPrompt] = useState('处理中...'); + const [processing, setProcessing] = useState(true); + + let navigate = useNavigate(); + + const sendCode = async (code, count) => { + const res = await API.get(`/api/oauth/discord?code=${code}`); + const { success, message, data } = res.data; + if (success) { + if (message === 'bind') { + showSuccess('绑定成功!'); + navigate('/setting'); + } else { + userDispatch({ type: 'login', payload: data }); + localStorage.setItem('user', JSON.stringify(data)); + showSuccess('登录成功!'); + navigate('/'); + } + } else { + showError(message); + if (count === 0) { + setPrompt(`操作失败,重定向至登录界面中...`); + navigate('/setting'); // in case this is failed to bind GitHub + return; + } + count++; + setPrompt(`出现错误,第 ${count} 次重试中...`); + await new Promise((resolve) => setTimeout(resolve, count * 2000)); + await sendCode(code, count); + } + }; + + useEffect(() => { + let code = searchParams.get('code'); + sendCode(code, 0).then(); + }, []); + + return ( + + + {prompt} + + + ); +}; + +export default DiscordOAuth; \ No newline at end of file diff --git a/web/src/components/LoginForm.js b/web/src/components/LoginForm.js index 110dad46..f7d9b4c0 100644 --- a/web/src/components/LoginForm.js +++ b/web/src/components/LoginForm.js @@ -37,6 +37,12 @@ const LoginForm = () => { ); }; + const onDiscordOAuthClicked = () => { + window.open( + `https://discord.com/oauth2/authorize?response_type=code&client_id=${status.discord_client_id}&redirect_uri=${window.location.origin}/oauth/discord&scope=identify`, + ); + }; + const onWeChatLoginClicked = () => { setShowWeChatLoginModal(true); }; @@ -123,28 +129,32 @@ const LoginForm = () => { 点击注册 - {status.github_oauth || status.wechat_login ? ( + {status.github_oauth || status.wechat_login || status.discord_oauth ? ( <> Or - {status.github_oauth ? ( + {status.discord_oauth && ( + ) } + { + status.discord_oauth && ( + + ) + }