diff --git a/common/constants.go b/common/constants.go index a1096901..abf77b22 100644 --- a/common/constants.go +++ b/common/constants.go @@ -40,6 +40,7 @@ var EmailVerificationEnabled = false var GitHubOAuthEnabled = false var DiscordOAuthEnabled = false var WeChatAuthEnabled = false +var GoogleOAuthEnabled = false var TurnstileCheckEnabled = false var RegisterEnabled = true @@ -61,6 +62,9 @@ var WeChatServerAddress = "" var WeChatServerToken = "" var WeChatAccountQRCodeImageURL = "" +var GoogleClientId = "" +var GoogleClientSecret = "" + var TurnstileSiteKey = "" var TurnstileSecretKey = "" diff --git a/controller/google.go b/controller/google.go new file mode 100644 index 00000000..c5c0fa11 --- /dev/null +++ b/controller/google.go @@ -0,0 +1,242 @@ +package controller + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "net/http" + "one-api/common" + "one-api/model" + "strconv" + + "github.com/gin-contrib/sessions" + "github.com/gin-gonic/gin" +) + +type GoogleAccessTokenResponse struct { + AccessToken string `json:"access_token"` + ExpiresIn int `json:"expires_in"` + TokenType string `json:"token_type"` + Scope string `json:"scope"` + RefreshToken string `json:"refresh_token"` +} + +type GoogleUser struct { + Sub string `json:"sub"` + Name string `json:"name"` +} + +func getGoogleUserInfoByCode(codeFromURLParamaters string, host string) (*GoogleUser, error) { + if codeFromURLParamaters == "" { + return nil, errors.New("无效参数") + } + + RequestClient := &http.Client{} + + accessTokenBody := bytes.NewBuffer([]byte(fmt.Sprintf( + "code=%s&client_id=%s&client_secret=%s&redirect_uri=%s/oauth/google&grant_type=authorization_code", + codeFromURLParamaters, common.GoogleClientId, common.GoogleClientSecret, host, + ))) + + req, _ := http.NewRequest("POST", + "https://oauth2.googleapis.com/token", + accessTokenBody, + ) + + req.Header = http.Header{ + "Content-Type": []string{"application/x-www-form-urlencoded"}, + "Accept": []string{"application/json"}, + } + + resp, err := RequestClient.Do(req) + + if resp.StatusCode != 200 || err != nil { + return nil, errors.New("访问令牌无效") + } + + var googleTokenResponse GoogleAccessTokenResponse + + json.NewDecoder(resp.Body).Decode(&googleTokenResponse) + + accessToken := "Bearer " + googleTokenResponse.AccessToken + + // Get User Info + req, _ = http.NewRequest("GET", "https://www.googleapis.com/oauth2/v3/userinfo", nil) + + req.Header = http.Header{ + "Content-Type": []string{"application/json"}, + "Authorization": []string{accessToken}, + } + + defer resp.Body.Close() + + resp, err = RequestClient.Do(req) + + if resp.StatusCode != 200 || err != nil { + return nil, errors.New("Google 用户信息无效") + } + + var googleUser GoogleUser + + // Parse json to googleUser + err = json.NewDecoder(resp.Body).Decode(&googleUser) + + if err != nil { + return nil, err + } + + if googleUser.Sub == "" { + return nil, errors.New("返回值无效,用户字段为空,请稍后再试!") + } + + defer resp.Body.Close() + + return &googleUser, nil +} + +func GoogleOAuth(c *gin.Context) { + session := sessions.Default(c) + username := session.Get("username") + if username != nil { + GoogleBind(c) + return + } + + if !common.GoogleOAuthEnabled { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "管理员未开启通过 Google 登录以及注册", + }) + return + } + code := c.Query("code") + + // Get protocal whether http or https and host + host := c.Request.Host + if c.Request.TLS == nil { + host = "http://" + host + } else { + host = "https://" + host + } + + googleUser, err := getGoogleUserInfoByCode(code, host) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + user := model.User{ + GoogleId: googleUser.Sub, + } + if model.IsGoogleIdAlreadyTaken(user.GoogleId) { + err := user.FillUserByGoogleId() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + } else { + if common.RegisterEnabled { + user.Username = "google_" + strconv.Itoa(model.GetMaxUserId()+1) + if googleUser.Name != "" { + user.DisplayName = googleUser.Name + } else { + user.DisplayName = "Google User" + } + user.Role = common.RoleCommonUser + user.Status = common.UserStatusEnabled + + if err := user.Insert(0); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + } else { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "管理员关闭了新用户注册", + }) + return + } + } + + if user.Status != common.UserStatusEnabled { + c.JSON(http.StatusOK, gin.H{ + "message": "用户已被封禁", + "success": false, + }) + return + } + setupLogin(&user, c) +} + +func GoogleBind(c *gin.Context) { + if !common.GoogleOAuthEnabled { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "管理员未开启通过 Google 登录以及注册", + }) + return + } + code := c.Query("code") + + // Get protocal whether http or https and host + host := c.Request.Host + if c.Request.TLS == nil { + host = "http://" + host + } else { + host = "https://" + host + } + + googleUser, err := getGoogleUserInfoByCode(code, host) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + user := model.User{ + GoogleId: googleUser.Sub, + } + if model.IsGoogleIdAlreadyTaken(user.GoogleId) { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "该 Google 账户已被绑定", + }) + return + } + session := sessions.Default(c) + id := session.Get("id") + // id := c.GetInt("id") // critical bug! + user.Id = id.(int) + err = user.FillUserById() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + user.GoogleId = googleUser.Sub + err = user.Update(false) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "bind", + }) + return +} diff --git a/controller/misc.go b/controller/misc.go index c6634bd1..b25bac2f 100644 --- a/controller/misc.go +++ b/controller/misc.go @@ -22,6 +22,8 @@ func GetStatus(c *gin.Context) { "github_client_id": common.GitHubClientId, "discord_oauth": common.DiscordOAuthEnabled, "discord_client_id": common.DiscordClientId, + "google_oauth": common.GoogleOAuthEnabled, + "google_client_id": common.GoogleClientId, "system_name": common.SystemName, "logo": common.Logo, "footer_html": common.Footer, diff --git a/controller/option.go b/controller/option.go index a5ee8f03..c4758e7c 100644 --- a/controller/option.go +++ b/controller/option.go @@ -66,6 +66,14 @@ func UpdateOption(c *gin.Context) { }) return } + case "GoogleOAuthEnabled": + if option.Value == "true" && common.GoogleClientId == "" { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "无法启用 Google OAuth,请先填入 Google Client ID 以及 Google Client Secret!", + }) + return + } case "TurnstileCheckEnabled": if option.Value == "true" && common.TurnstileSiteKey == "" { c.JSON(http.StatusOK, gin.H{ diff --git a/model/option.go b/model/option.go index df93b704..512a81b4 100644 --- a/model/option.go +++ b/model/option.go @@ -32,6 +32,7 @@ func InitOptionMap() { common.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(common.GitHubOAuthEnabled) common.OptionMap["DiscordOAuthEnabled"] = strconv.FormatBool(common.DiscordOAuthEnabled) common.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(common.WeChatAuthEnabled) + common.OptionMap["GoogleOAuthEnabled"] = strconv.FormatBool(common.GoogleOAuthEnabled) common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled) common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled) common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled) @@ -59,6 +60,8 @@ func InitOptionMap() { common.OptionMap["WeChatServerAddress"] = "" common.OptionMap["WeChatServerToken"] = "" common.OptionMap["WeChatAccountQRCodeImageURL"] = "" + common.OptionMap["GoogleClientId"] = "" + common.OptionMap["GoogleClientSecret"] = "" common.OptionMap["TurnstileSiteKey"] = "" common.OptionMap["TurnstileSecretKey"] = "" common.OptionMap["QuotaForNewUser"] = strconv.Itoa(common.QuotaForNewUser) @@ -142,6 +145,8 @@ func updateOptionMap(key string, value string) (err error) { common.DiscordOAuthEnabled = boolValue case "WeChatAuthEnabled": common.WeChatAuthEnabled = boolValue + case "GoogleOAuthEnabled": + common.GoogleOAuthEnabled = boolValue case "TurnstileCheckEnabled": common.TurnstileCheckEnabled = boolValue case "RegisterEnabled": @@ -192,6 +197,10 @@ func updateOptionMap(key string, value string) (err error) { common.WeChatServerToken = value case "WeChatAccountQRCodeImageURL": common.WeChatAccountQRCodeImageURL = value + case "GoogleClientId": + common.GoogleClientId = value + case "GoogleClientSecret": + common.GoogleClientSecret = value case "TurnstileSiteKey": common.TurnstileSiteKey = value case "TurnstileSecretKey": diff --git a/model/user.go b/model/user.go index 234b2d99..325910c8 100644 --- a/model/user.go +++ b/model/user.go @@ -22,6 +22,7 @@ type User struct { GitHubId string `json:"github_id" gorm:"column:github_id;index"` DiscordId string `json:"discord_id" gorm:"column:discord_id;index"` WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"` + GoogleId string `json:"google_id" gorm:"column:google_id;index"` VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database! AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management Quota int `json:"quota" gorm:"type:int;default:0"` @@ -187,6 +188,14 @@ func (user *User) FillUserByWeChatId() error { return nil } +func (user *User) FillUserByGoogleId() error { + if user.WeChatId == "" { + return errors.New("Google id 为空!") + } + DB.Where(User{GoogleId: user.GoogleId}).First(user) + return nil +} + func (user *User) FillUserByUsername() error { if user.Username == "" { return errors.New("username 为空!") @@ -207,6 +216,10 @@ func IsDiscordIdAlreadyTaken(discordId string) bool { return DB.Where("discord_id = ?", discordId).Find(&User{}).RowsAffected == 1 } +func IsGoogleIdAlreadyTaken(googleId string) bool { + return DB.Where("google_id = ?", googleId).Find(&User{}).RowsAffected == 1 +} + func IsGitHubIdAlreadyTaken(githubId string) bool { return DB.Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1 } diff --git a/router/api-router.go b/router/api-router.go index 9036341b..6e8572c2 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -23,6 +23,7 @@ func SetApiRouter(router *gin.Engine) { apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), controller.GitHubOAuth) apiRouter.GET("/oauth/discord", middleware.CriticalRateLimit(), controller.DiscordOAuth) apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), controller.WeChatAuth) + apiRouter.GET("/oauth/google", middleware.CriticalRateLimit(), controller.GoogleOAuth) apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.WeChatBind) apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.EmailBind) diff --git a/web/src/App.js b/web/src/App.js index 5ca8988d..c7601040 100644 --- a/web/src/App.js +++ b/web/src/App.js @@ -13,6 +13,7 @@ import { API, getLogo, getSystemName, showError, showNotice } from './helpers'; import PasswordResetForm from './components/PasswordResetForm'; import GitHubOAuth from './components/GitHubOAuth'; import DiscordOAuth from './components/DiscordOAuth'; +import GoogleOAuth from './components/GoogleOAuth'; import PasswordResetConfirm from './components/PasswordResetConfirm'; import { UserContext } from './context/User'; import { StatusContext } from './context/Status'; @@ -241,6 +242,7 @@ function App() { } /> }> @@ -248,6 +250,15 @@ function App() { } /> + }> + + support-google-oauth + + } + /> - }> - - - + + }> + + + } /> { + const [searchParams, setSearchParams] = useSearchParams(); + + const [userState, userDispatch] = useContext(UserContext); + const [prompt, setPrompt] = useState('处理中...'); + const [processing, setProcessing] = useState(true); + + let navigate = useNavigate(); + + const sendCode = async (code, count) => { + const res = await API.get(`/api/oauth/google?code=${code}`); + const { success, message, data } = res.data; + if (success) { + if (message === 'bind') { + showSuccess('绑定成功!'); + navigate('/setting'); + } else { + userDispatch({ type: 'login', payload: data }); + localStorage.setItem('user', JSON.stringify(data)); + showSuccess('登录成功!'); + navigate('/'); + } + } else { + showError(message); + if (count === 0) { + setPrompt(`操作失败,重定向至登录界面中...`); + navigate('/setting'); // in case this is failed to bind GitHub + return; + } + count++; + setPrompt(`出现错误,第 ${count} 次重试中...`); + await new Promise((resolve) => setTimeout(resolve, count * 2000)); + await sendCode(code, count); + } + }; + + useEffect(() => { + let code = searchParams.get('code'); + sendCode(code, 0).then(); + }, []); + + return ( + + + {prompt} + + + ); +}; + +export default GoogleOAuth; \ No newline at end of file diff --git a/web/src/components/LoginForm.js b/web/src/components/LoginForm.js index 9fc1d160..46664323 100644 --- a/web/src/components/LoginForm.js +++ b/web/src/components/LoginForm.js @@ -40,6 +40,12 @@ const LoginForm = () => { const [showWeChatLoginModal, setShowWeChatLoginModal] = useState(false); + const openGoogleOAuth = () => { + window.open( + `https://accounts.google.com/o/oauth2/v2/auth?client_id=${status.google_client_id}&redirect_uri=${window.location.origin}/oauth/google&response_type=code&scope=profile` + ); + }; + const onGitHubOAuthClicked = () => { window.open( `https://github.com/login/oauth/authorize?client_id=${status.github_client_id}&scope=user:email` @@ -153,7 +159,7 @@ const LoginForm = () => { 点击注册 - {status.github_oauth || status.wechat_login || status.discord_oauth ? ( + {status.github_oauth || status.wechat_login || status.discord_oauth || status.google_oauth ? ( <> Or {status.discord_oauth && ( @@ -180,6 +186,14 @@ const LoginForm = () => { onClick={onWeChatLoginClicked} /> )} + {status.google_oauth && ( + ) } + { + status.google_oauth && ( + + + ) + }