From c144c64fff794850f4da00159850ddd19abb755b Mon Sep 17 00:00:00 2001 From: ckt1031 <65409152+ckt1031@users.noreply.github.com> Date: Mon, 24 Jul 2023 20:09:52 +0800 Subject: [PATCH] feat: support Google OAuth --- common/constants.go | 4 + controller/google.go | 242 ++++++++++++++++++++++++++ controller/misc.go | 5 +- controller/option.go | 11 +- model/option.go | 9 + model/user.go | 16 +- router/api-router.go | 1 + web/src/App.js | 9 + web/src/components/GoogleOAuth.js | 57 ++++++ web/src/components/LoginForm.js | 24 ++- web/src/components/PersonalSetting.js | 11 ++ web/src/components/SystemSetting.js | 62 +++++++ web/src/pages/Home/index.js | 6 + web/src/pages/User/EditUser.js | 13 +- 14 files changed, 459 insertions(+), 11 deletions(-) create mode 100644 controller/google.go create mode 100644 web/src/components/GoogleOAuth.js diff --git a/common/constants.go b/common/constants.go index 81f98163..a7110b3f 100644 --- a/common/constants.go +++ b/common/constants.go @@ -39,6 +39,7 @@ var PasswordRegisterEnabled = true var EmailVerificationEnabled = false var GitHubOAuthEnabled = false var WeChatAuthEnabled = false +var GoogleOAuthEnabled = false var TurnstileCheckEnabled = false var RegisterEnabled = true @@ -57,6 +58,9 @@ var WeChatServerAddress = "" var WeChatServerToken = "" var WeChatAccountQRCodeImageURL = "" +var GoogleClientId = "" +var GoogleClientSecret = "" + var TurnstileSiteKey = "" var TurnstileSecretKey = "" diff --git a/controller/google.go b/controller/google.go new file mode 100644 index 00000000..c5c0fa11 --- /dev/null +++ b/controller/google.go @@ -0,0 +1,242 @@ +package controller + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "net/http" + "one-api/common" + "one-api/model" + "strconv" + + "github.com/gin-contrib/sessions" + "github.com/gin-gonic/gin" +) + +type GoogleAccessTokenResponse struct { + AccessToken string `json:"access_token"` + ExpiresIn int `json:"expires_in"` + TokenType string `json:"token_type"` + Scope string `json:"scope"` + RefreshToken string `json:"refresh_token"` +} + +type GoogleUser struct { + Sub string `json:"sub"` + Name string `json:"name"` +} + +func getGoogleUserInfoByCode(codeFromURLParamaters string, host string) (*GoogleUser, error) { + if codeFromURLParamaters == "" { + return nil, errors.New("无效参数") + } + + RequestClient := &http.Client{} + + accessTokenBody := bytes.NewBuffer([]byte(fmt.Sprintf( + "code=%s&client_id=%s&client_secret=%s&redirect_uri=%s/oauth/google&grant_type=authorization_code", + codeFromURLParamaters, common.GoogleClientId, common.GoogleClientSecret, host, + ))) + + req, _ := http.NewRequest("POST", + "https://oauth2.googleapis.com/token", + accessTokenBody, + ) + + req.Header = http.Header{ + "Content-Type": []string{"application/x-www-form-urlencoded"}, + "Accept": []string{"application/json"}, + } + + resp, err := RequestClient.Do(req) + + if resp.StatusCode != 200 || err != nil { + return nil, errors.New("访问令牌无效") + } + + var googleTokenResponse GoogleAccessTokenResponse + + json.NewDecoder(resp.Body).Decode(&googleTokenResponse) + + accessToken := "Bearer " + googleTokenResponse.AccessToken + + // Get User Info + req, _ = http.NewRequest("GET", "https://www.googleapis.com/oauth2/v3/userinfo", nil) + + req.Header = http.Header{ + "Content-Type": []string{"application/json"}, + "Authorization": []string{accessToken}, + } + + defer resp.Body.Close() + + resp, err = RequestClient.Do(req) + + if resp.StatusCode != 200 || err != nil { + return nil, errors.New("Google 用户信息无效") + } + + var googleUser GoogleUser + + // Parse json to googleUser + err = json.NewDecoder(resp.Body).Decode(&googleUser) + + if err != nil { + return nil, err + } + + if googleUser.Sub == "" { + return nil, errors.New("返回值无效,用户字段为空,请稍后再试!") + } + + defer resp.Body.Close() + + return &googleUser, nil +} + +func GoogleOAuth(c *gin.Context) { + session := sessions.Default(c) + username := session.Get("username") + if username != nil { + GoogleBind(c) + return + } + + if !common.GoogleOAuthEnabled { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "管理员未开启通过 Google 登录以及注册", + }) + return + } + code := c.Query("code") + + // Get protocal whether http or https and host + host := c.Request.Host + if c.Request.TLS == nil { + host = "http://" + host + } else { + host = "https://" + host + } + + googleUser, err := getGoogleUserInfoByCode(code, host) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + user := model.User{ + GoogleId: googleUser.Sub, + } + if model.IsGoogleIdAlreadyTaken(user.GoogleId) { + err := user.FillUserByGoogleId() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + } else { + if common.RegisterEnabled { + user.Username = "google_" + strconv.Itoa(model.GetMaxUserId()+1) + if googleUser.Name != "" { + user.DisplayName = googleUser.Name + } else { + user.DisplayName = "Google User" + } + user.Role = common.RoleCommonUser + user.Status = common.UserStatusEnabled + + if err := user.Insert(0); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + } else { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "管理员关闭了新用户注册", + }) + return + } + } + + if user.Status != common.UserStatusEnabled { + c.JSON(http.StatusOK, gin.H{ + "message": "用户已被封禁", + "success": false, + }) + return + } + setupLogin(&user, c) +} + +func GoogleBind(c *gin.Context) { + if !common.GoogleOAuthEnabled { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "管理员未开启通过 Google 登录以及注册", + }) + return + } + code := c.Query("code") + + // Get protocal whether http or https and host + host := c.Request.Host + if c.Request.TLS == nil { + host = "http://" + host + } else { + host = "https://" + host + } + + googleUser, err := getGoogleUserInfoByCode(code, host) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + user := model.User{ + GoogleId: googleUser.Sub, + } + if model.IsGoogleIdAlreadyTaken(user.GoogleId) { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "该 Google 账户已被绑定", + }) + return + } + session := sessions.Default(c) + id := session.Get("id") + // id := c.GetInt("id") // critical bug! + user.Id = id.(int) + err = user.FillUserById() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + user.GoogleId = googleUser.Sub + err = user.Update(false) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "bind", + }) + return +} diff --git a/controller/misc.go b/controller/misc.go index 958a3716..078e0a5d 100644 --- a/controller/misc.go +++ b/controller/misc.go @@ -3,10 +3,11 @@ package controller import ( "encoding/json" "fmt" - "github.com/gin-gonic/gin" "net/http" "one-api/common" "one-api/model" + + "github.com/gin-gonic/gin" ) func GetStatus(c *gin.Context) { @@ -19,6 +20,8 @@ func GetStatus(c *gin.Context) { "email_verification": common.EmailVerificationEnabled, "github_oauth": common.GitHubOAuthEnabled, "github_client_id": common.GitHubClientId, + "google_oauth": common.GoogleOAuthEnabled, + "google_client_id": common.GoogleClientId, "system_name": common.SystemName, "logo": common.Logo, "footer_html": common.Footer, diff --git a/controller/option.go b/controller/option.go index abf0d5be..01fae423 100644 --- a/controller/option.go +++ b/controller/option.go @@ -2,11 +2,12 @@ package controller import ( "encoding/json" - "github.com/gin-gonic/gin" "net/http" "one-api/common" "one-api/model" "strings" + + "github.com/gin-gonic/gin" ) func GetOptions(c *gin.Context) { @@ -57,6 +58,14 @@ func UpdateOption(c *gin.Context) { }) return } + case "GoogleOAuthEnabled": + if option.Value == "true" && common.GoogleClientId == "" { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "无法启用 Google OAuth,请先填入 Google Client ID 以及 Google Client Secret!", + }) + return + } case "TurnstileCheckEnabled": if option.Value == "true" && common.TurnstileSiteKey == "" { c.JSON(http.StatusOK, gin.H{ diff --git a/model/option.go b/model/option.go index e7bc6806..a3fe35ac 100644 --- a/model/option.go +++ b/model/option.go @@ -31,6 +31,7 @@ func InitOptionMap() { common.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(common.EmailVerificationEnabled) common.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(common.GitHubOAuthEnabled) common.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(common.WeChatAuthEnabled) + common.OptionMap["GoogleOAuthEnabled"] = strconv.FormatBool(common.GoogleOAuthEnabled) common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled) common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled) common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled) @@ -56,6 +57,8 @@ func InitOptionMap() { common.OptionMap["WeChatServerAddress"] = "" common.OptionMap["WeChatServerToken"] = "" common.OptionMap["WeChatAccountQRCodeImageURL"] = "" + common.OptionMap["GoogleClientId"] = "" + common.OptionMap["GoogleClientSecret"] = "" common.OptionMap["TurnstileSiteKey"] = "" common.OptionMap["TurnstileSecretKey"] = "" common.OptionMap["QuotaForNewUser"] = strconv.Itoa(common.QuotaForNewUser) @@ -137,6 +140,8 @@ func updateOptionMap(key string, value string) (err error) { common.GitHubOAuthEnabled = boolValue case "WeChatAuthEnabled": common.WeChatAuthEnabled = boolValue + case "GoogleOAuthEnabled": + common.GoogleOAuthEnabled = boolValue case "TurnstileCheckEnabled": common.TurnstileCheckEnabled = boolValue case "RegisterEnabled": @@ -183,6 +188,10 @@ func updateOptionMap(key string, value string) (err error) { common.WeChatServerToken = value case "WeChatAccountQRCodeImageURL": common.WeChatAccountQRCodeImageURL = value + case "GoogleClientId": + common.GoogleClientId = value + case "GoogleClientSecret": + common.GoogleClientSecret = value case "TurnstileSiteKey": common.TurnstileSiteKey = value case "TurnstileSecretKey": diff --git a/model/user.go b/model/user.go index 7c771840..6ccf9001 100644 --- a/model/user.go +++ b/model/user.go @@ -3,9 +3,10 @@ package model import ( "errors" "fmt" - "gorm.io/gorm" "one-api/common" "strings" + + "gorm.io/gorm" ) // User if you add sensitive fields, don't forget to clean them in setupLogin function. @@ -20,6 +21,7 @@ type User struct { Email string `json:"email" gorm:"index" validate:"max=50"` GitHubId string `json:"github_id" gorm:"column:github_id;index"` WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"` + GoogleId string `json:"google_id" gorm:"column:google_id;index"` VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database! AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management Quota int `json:"quota" gorm:"type:int;default:0"` @@ -177,6 +179,14 @@ func (user *User) FillUserByWeChatId() error { return nil } +func (user *User) FillUserByGoogleId() error { + if user.WeChatId == "" { + return errors.New("Google id 为空!") + } + DB.Where(User{GoogleId: user.GoogleId}).First(user) + return nil +} + func (user *User) FillUserByUsername() error { if user.Username == "" { return errors.New("username 为空!") @@ -193,6 +203,10 @@ func IsWeChatIdAlreadyTaken(wechatId string) bool { return DB.Where("wechat_id = ?", wechatId).Find(&User{}).RowsAffected == 1 } +func IsGoogleIdAlreadyTaken(googleId string) bool { + return DB.Where("google_id = ?", googleId).Find(&User{}).RowsAffected == 1 +} + func IsGitHubIdAlreadyTaken(githubId string) bool { return DB.Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1 } diff --git a/router/api-router.go b/router/api-router.go index 383133fa..94898702 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -22,6 +22,7 @@ func SetApiRouter(router *gin.Engine) { apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword) apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), controller.GitHubOAuth) apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), controller.WeChatAuth) + apiRouter.GET("/oauth/google", middleware.CriticalRateLimit(), controller.GoogleOAuth) apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.WeChatBind) apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.EmailBind) diff --git a/web/src/App.js b/web/src/App.js index c967ce2c..2a605530 100644 --- a/web/src/App.js +++ b/web/src/App.js @@ -12,6 +12,7 @@ import AddUser from './pages/User/AddUser'; import { API, getLogo, getSystemName, showError, showNotice } from './helpers'; import PasswordResetForm from './components/PasswordResetForm'; import GitHubOAuth from './components/GitHubOAuth'; +import GoogleOAuth from './components/GoogleOAuth'; import PasswordResetConfirm from './components/PasswordResetConfirm'; import { UserContext } from './context/User'; import { StatusContext } from './context/Status'; @@ -239,6 +240,14 @@ function App() { } /> + }> + + + } + /> { + const [searchParams, setSearchParams] = useSearchParams(); + + const [userState, userDispatch] = useContext(UserContext); + const [prompt, setPrompt] = useState('处理中...'); + const [processing, setProcessing] = useState(true); + + let navigate = useNavigate(); + + const sendCode = async (code, count) => { + const res = await API.get(`/api/oauth/google?code=${code}`); + const { success, message, data } = res.data; + if (success) { + if (message === 'bind') { + showSuccess('绑定成功!'); + navigate('/setting'); + } else { + userDispatch({ type: 'login', payload: data }); + localStorage.setItem('user', JSON.stringify(data)); + showSuccess('登录成功!'); + navigate('/'); + } + } else { + showError(message); + if (count === 0) { + setPrompt(`操作失败,重定向至登录界面中...`); + navigate('/setting'); // in case this is failed to bind GitHub + return; + } + count++; + setPrompt(`出现错误,第 ${count} 次重试中...`); + await new Promise((resolve) => setTimeout(resolve, count * 2000)); + await sendCode(code, count); + } + }; + + useEffect(() => { + let code = searchParams.get('code'); + sendCode(code, 0).then(); + }, []); + + return ( + + + {prompt} + + + ); +}; + +export default GoogleOAuth; \ No newline at end of file diff --git a/web/src/components/LoginForm.js b/web/src/components/LoginForm.js index 110dad46..7f7dc5b9 100644 --- a/web/src/components/LoginForm.js +++ b/web/src/components/LoginForm.js @@ -31,6 +31,12 @@ const LoginForm = () => { const [showWeChatLoginModal, setShowWeChatLoginModal] = useState(false); + const openGoogleOAuth = () => { + window.open( + `https://accounts.google.com/o/oauth2/v2/auth?client_id=${status.google_client_id}&redirect_uri=${window.location.origin}/oauth/google&response_type=code&scope=profile` + ); + }; + const onGitHubOAuthClicked = () => { window.open( `https://github.com/login/oauth/authorize?client_id=${status.github_client_id}&scope=user:email` @@ -123,28 +129,32 @@ const LoginForm = () => { 点击注册 - {status.github_oauth || status.wechat_login ? ( + {status.github_oauth || status.wechat_login || status.google_oauth ? ( <> Or - {status.github_oauth ? ( + {status.github_oauth && ( ) } + { + status.google_oauth && ( + + ) + }