From 99c8c7750439493f4ccd7b0a945506fff41f5c24 Mon Sep 17 00:00:00 2001 From: OnEvent <35189106+Oniokey@users.noreply.github.com> Date: Sat, 21 Sep 2024 23:03:20 +0800 Subject: [PATCH] feat: add oidc support (#1725) * feat: add the ui for configuring the third-party standard OAuth2.0/OIDC. - update SystemSetting.js - add setup ui - add configuration * feat: add the ui for "allow the OAuth 2.0 to login" - update SystemSetting.js * feat: add OAuth 2.0 web ui and its process functions - update common.js - update AuthLogin.js - update config.js * fix: missing "Userinfo" endpoint configuration entry, used by OAuth clients to request user information from the IdP. - update config.js - update SystemSetting.js * feat: updated the icons for Lark and OIDC to match the style of the icons for WeChat, EMail, GitHub. - update lark.svg - new oidc.svg * refactor: Changing OAuth 2.0 to OIDC * feat: add OIDC login method * feat: Add support for OIDC login to the backend * fix: Change the AppId and AppSecret on the Web UI to the standard usage: ClientId, ClientSecret. * feat: Support quick configuration of OIDC through Well-Known Discovery Endpoint * feat: Standardize terminology, add well-known configuration - Change the AppId and AppSecret on the Server End to the standard usage: ClientId, ClientSecret. - add Well-Known configuration to store in database, no actual use in server end but store and display in web ui only --- common/config/config.go | 8 + controller/auth/oidc.go | 225 ++++++++++++++++++ controller/misc.go | 42 ++-- model/option.go | 15 ++ model/user.go | 13 + router/api.go | 1 + web/berry/src/assets/images/icons/lark.svg | 6 +- web/berry/src/assets/images/icons/oidc.svg | 7 + web/berry/src/config.js | 7 +- web/berry/src/hooks/useLogin.js | 24 +- web/berry/src/routes/OtherRoutes.js | 5 + web/berry/src/utils/common.js | 15 ++ .../views/Authentication/Auth/OidcOAuth.js | 94 ++++++++ .../Authentication/AuthForms/AuthLogin.js | 28 ++- web/berry/src/views/Profile/index.js | 22 +- .../views/Setting/component/SystemSetting.js | 173 +++++++++++++- 16 files changed, 659 insertions(+), 26 deletions(-) create mode 100644 controller/auth/oidc.go create mode 100644 web/berry/src/assets/images/icons/oidc.svg create mode 100644 web/berry/src/views/Authentication/Auth/OidcOAuth.js diff --git a/common/config/config.go b/common/config/config.go index 11da0b96..43f56862 100644 --- a/common/config/config.go +++ b/common/config/config.go @@ -35,6 +35,7 @@ var PasswordLoginEnabled = true var PasswordRegisterEnabled = true var EmailVerificationEnabled = false var GitHubOAuthEnabled = false +var OidcEnabled = false var WeChatAuthEnabled = false var TurnstileCheckEnabled = false var RegisterEnabled = true @@ -70,6 +71,13 @@ var GitHubClientSecret = "" var LarkClientId = "" var LarkClientSecret = "" +var OidcClientId = "" +var OidcClientSecret = "" +var OidcWellKnown = "" +var OidcAuthorizationEndpoint = "" +var OidcTokenEndpoint = "" +var OidcUserinfoEndpoint = "" + var WeChatServerAddress = "" var WeChatServerToken = "" var WeChatAccountQRCodeImageURL = "" diff --git a/controller/auth/oidc.go b/controller/auth/oidc.go new file mode 100644 index 00000000..7b4ad4b9 --- /dev/null +++ b/controller/auth/oidc.go @@ -0,0 +1,225 @@ +package auth + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "github.com/gin-contrib/sessions" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/controller" + "github.com/songquanpeng/one-api/model" + "net/http" + "strconv" + "time" +) + +type OidcResponse struct { + AccessToken string `json:"access_token"` + IDToken string `json:"id_token"` + RefreshToken string `json:"refresh_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + Scope string `json:"scope"` +} + +type OidcUser struct { + OpenID string `json:"sub"` + Email string `json:"email"` + Name string `json:"name"` + PreferredUsername string `json:"preferred_username"` + Picture string `json:"picture"` +} + +func getOidcUserInfoByCode(code string) (*OidcUser, error) { + if code == "" { + return nil, errors.New("无效的参数") + } + values := map[string]string{ + "client_id": config.OidcClientId, + "client_secret": config.OidcClientSecret, + "code": code, + "grant_type": "authorization_code", + "redirect_uri": fmt.Sprintf("%s/oauth/oidc", config.ServerAddress), + } + jsonData, err := json.Marshal(values) + if err != nil { + return nil, err + } + req, err := http.NewRequest("POST", config.OidcTokenEndpoint, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + client := http.Client{ + Timeout: 5 * time.Second, + } + res, err := client.Do(req) + if err != nil { + logger.SysLog(err.Error()) + return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!") + } + defer res.Body.Close() + var oidcResponse OidcResponse + err = json.NewDecoder(res.Body).Decode(&oidcResponse) + if err != nil { + return nil, err + } + req, err = http.NewRequest("GET", config.OidcUserinfoEndpoint, nil) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", "Bearer "+oidcResponse.AccessToken) + res2, err := client.Do(req) + if err != nil { + logger.SysLog(err.Error()) + return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!") + } + var oidcUser OidcUser + err = json.NewDecoder(res2.Body).Decode(&oidcUser) + if err != nil { + return nil, err + } + return &oidcUser, nil +} + +func OidcAuth(c *gin.Context) { + session := sessions.Default(c) + state := c.Query("state") + if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) { + c.JSON(http.StatusForbidden, gin.H{ + "success": false, + "message": "state is empty or not same", + }) + return + } + username := session.Get("username") + if username != nil { + OidcBind(c) + return + } + if !config.OidcEnabled { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "管理员未开启通过 OIDC 登录以及注册", + }) + return + } + code := c.Query("code") + oidcUser, err := getOidcUserInfoByCode(code) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + user := model.User{ + OidcId: oidcUser.OpenID, + } + if model.IsOidcIdAlreadyTaken(user.OidcId) { + err := user.FillUserByOidcId() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + } else { + if config.RegisterEnabled { + user.Email = oidcUser.Email + if oidcUser.PreferredUsername != "" { + user.Username = oidcUser.PreferredUsername + } else { + user.Username = "oidc_" + strconv.Itoa(model.GetMaxUserId()+1) + } + if oidcUser.Name != "" { + user.DisplayName = oidcUser.Name + } else { + user.DisplayName = "OIDC User" + } + err := user.Insert(0) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + } else { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "管理员关闭了新用户注册", + }) + return + } + } + + if user.Status != model.UserStatusEnabled { + c.JSON(http.StatusOK, gin.H{ + "message": "用户已被封禁", + "success": false, + }) + return + } + controller.SetupLogin(&user, c) +} + +func OidcBind(c *gin.Context) { + if !config.OidcEnabled { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "管理员未开启通过 OIDC 登录以及注册", + }) + return + } + code := c.Query("code") + oidcUser, err := getOidcUserInfoByCode(code) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + user := model.User{ + OidcId: oidcUser.OpenID, + } + if model.IsOidcIdAlreadyTaken(user.OidcId) { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "该 OIDC 账户已被绑定", + }) + return + } + session := sessions.Default(c) + id := session.Get("id") + // id := c.GetInt("id") // critical bug! + user.Id = id.(int) + err = user.FillUserById() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + user.OidcId = oidcUser.OpenID + err = user.Update(false) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "bind", + }) + return +} diff --git a/controller/misc.go b/controller/misc.go index 2928b8fb..ae900870 100644 --- a/controller/misc.go +++ b/controller/misc.go @@ -18,24 +18,30 @@ func GetStatus(c *gin.Context) { "success": true, "message": "", "data": gin.H{ - "version": common.Version, - "start_time": common.StartTime, - "email_verification": config.EmailVerificationEnabled, - "github_oauth": config.GitHubOAuthEnabled, - "github_client_id": config.GitHubClientId, - "lark_client_id": config.LarkClientId, - "system_name": config.SystemName, - "logo": config.Logo, - "footer_html": config.Footer, - "wechat_qrcode": config.WeChatAccountQRCodeImageURL, - "wechat_login": config.WeChatAuthEnabled, - "server_address": config.ServerAddress, - "turnstile_check": config.TurnstileCheckEnabled, - "turnstile_site_key": config.TurnstileSiteKey, - "top_up_link": config.TopUpLink, - "chat_link": config.ChatLink, - "quota_per_unit": config.QuotaPerUnit, - "display_in_currency": config.DisplayInCurrencyEnabled, + "version": common.Version, + "start_time": common.StartTime, + "email_verification": config.EmailVerificationEnabled, + "github_oauth": config.GitHubOAuthEnabled, + "github_client_id": config.GitHubClientId, + "lark_client_id": config.LarkClientId, + "system_name": config.SystemName, + "logo": config.Logo, + "footer_html": config.Footer, + "wechat_qrcode": config.WeChatAccountQRCodeImageURL, + "wechat_login": config.WeChatAuthEnabled, + "server_address": config.ServerAddress, + "turnstile_check": config.TurnstileCheckEnabled, + "turnstile_site_key": config.TurnstileSiteKey, + "top_up_link": config.TopUpLink, + "chat_link": config.ChatLink, + "quota_per_unit": config.QuotaPerUnit, + "display_in_currency": config.DisplayInCurrencyEnabled, + "oidc": config.OidcEnabled, + "oidc_client_id": config.OidcClientId, + "oidc_well_known": config.OidcWellKnown, + "oidc_authorization_endpoint": config.OidcAuthorizationEndpoint, + "oidc_token_endpoint": config.OidcTokenEndpoint, + "oidc_userinfo_endpoint": config.OidcUserinfoEndpoint, }, }) return diff --git a/model/option.go b/model/option.go index bed8d4c3..8fd30aee 100644 --- a/model/option.go +++ b/model/option.go @@ -28,6 +28,7 @@ func InitOptionMap() { config.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(config.PasswordRegisterEnabled) config.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(config.EmailVerificationEnabled) config.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(config.GitHubOAuthEnabled) + config.OptionMap["OidcEnabled"] = strconv.FormatBool(config.OidcEnabled) config.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(config.WeChatAuthEnabled) config.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(config.TurnstileCheckEnabled) config.OptionMap["RegisterEnabled"] = strconv.FormatBool(config.RegisterEnabled) @@ -130,6 +131,8 @@ func updateOptionMap(key string, value string) (err error) { config.EmailVerificationEnabled = boolValue case "GitHubOAuthEnabled": config.GitHubOAuthEnabled = boolValue + case "OidcEnabled": + config.OidcEnabled = boolValue case "WeChatAuthEnabled": config.WeChatAuthEnabled = boolValue case "TurnstileCheckEnabled": @@ -176,6 +179,18 @@ func updateOptionMap(key string, value string) (err error) { config.LarkClientId = value case "LarkClientSecret": config.LarkClientSecret = value + case "OidcClientId": + config.OidcClientId = value + case "OidcClientSecret": + config.OidcClientSecret = value + case "OidcWellKnown": + config.OidcWellKnown = value + case "OidcAuthorizationEndpoint": + config.OidcAuthorizationEndpoint = value + case "OidcTokenEndpoint": + config.OidcTokenEndpoint = value + case "OidcUserinfoEndpoint": + config.OidcUserinfoEndpoint = value case "Footer": config.Footer = value case "SystemName": diff --git a/model/user.go b/model/user.go index 924d72f9..a964a0d7 100644 --- a/model/user.go +++ b/model/user.go @@ -39,6 +39,7 @@ type User struct { GitHubId string `json:"github_id" gorm:"column:github_id;index"` WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"` LarkId string `json:"lark_id" gorm:"column:lark_id;index"` + OidcId string `json:"oidc_id" gorm:"column:oidc_id;index"` VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database! AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management Quota int64 `json:"quota" gorm:"bigint;default:0"` @@ -245,6 +246,14 @@ func (user *User) FillUserByLarkId() error { return nil } +func (user *User) FillUserByOidcId() error { + if user.OidcId == "" { + return errors.New("oidc id 为空!") + } + DB.Where(User{OidcId: user.OidcId}).First(user) + return nil +} + func (user *User) FillUserByWeChatId() error { if user.WeChatId == "" { return errors.New("WeChat id 为空!") @@ -277,6 +286,10 @@ func IsLarkIdAlreadyTaken(githubId string) bool { return DB.Where("lark_id = ?", githubId).Find(&User{}).RowsAffected == 1 } +func IsOidcIdAlreadyTaken(oidcId string) bool { + return DB.Where("oidc_id = ?", oidcId).Find(&User{}).RowsAffected == 1 +} + func IsUsernameAlreadyTaken(username string) bool { return DB.Where("username = ?", username).Find(&User{}).RowsAffected == 1 } diff --git a/router/api.go b/router/api.go index d2ada4eb..6d00c6ea 100644 --- a/router/api.go +++ b/router/api.go @@ -23,6 +23,7 @@ func SetApiRouter(router *gin.Engine) { apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail) apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword) apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), auth.GitHubOAuth) + apiRouter.GET("/oauth/oidc", middleware.CriticalRateLimit(), auth.OidcAuth) apiRouter.GET("/oauth/lark", middleware.CriticalRateLimit(), auth.LarkOAuth) apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), auth.GenerateOAuthCode) apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), auth.WeChatAuth) diff --git a/web/berry/src/assets/images/icons/lark.svg b/web/berry/src/assets/images/icons/lark.svg index 239e1bef..79688e2a 100644 --- a/web/berry/src/assets/images/icons/lark.svg +++ b/web/berry/src/assets/images/icons/lark.svg @@ -1 +1,5 @@ - \ No newline at end of file + + + diff --git a/web/berry/src/assets/images/icons/oidc.svg b/web/berry/src/assets/images/icons/oidc.svg new file mode 100644 index 00000000..96e01f81 --- /dev/null +++ b/web/berry/src/assets/images/icons/oidc.svg @@ -0,0 +1,7 @@ + + + + diff --git a/web/berry/src/config.js b/web/berry/src/config.js index eeeda99a..8c1faf9b 100644 --- a/web/berry/src/config.js +++ b/web/berry/src/config.js @@ -22,7 +22,12 @@ const config = { turnstile_site_key: '', version: '', wechat_login: false, - wechat_qrcode: '' + wechat_qrcode: '', + oidc: false, + oidc_client_id: '', + oidc_authorization_endpoint: '', + oidc_token_endpoint: '', + oidc_userinfo_endpoint: '', } }; diff --git a/web/berry/src/hooks/useLogin.js b/web/berry/src/hooks/useLogin.js index 39d8b407..6d89727d 100644 --- a/web/berry/src/hooks/useLogin.js +++ b/web/berry/src/hooks/useLogin.js @@ -70,6 +70,28 @@ const useLogin = () => { } }; + const oidcLogin = async (code, state) => { + try { + const res = await API.get(`/api/oauth/oidc?code=${code}&state=${state}`); + const { success, message, data } = res.data; + if (success) { + if (message === 'bind') { + showSuccess('绑定成功!'); + navigate('/panel'); + } else { + dispatch({ type: LOGIN, payload: data }); + localStorage.setItem('user', JSON.stringify(data)); + showSuccess('登录成功!'); + navigate('/panel'); + } + } + return { success, message }; + } catch (err) { + // 请求失败,设置错误信息 + return { success: false, message: '' }; + } + } + const wechatLogin = async (code) => { try { const res = await API.get(`/api/oauth/wechat?code=${code}`); @@ -94,7 +116,7 @@ const useLogin = () => { navigate('/'); }; - return { login, logout, githubLogin, wechatLogin, larkLogin }; + return { login, logout, githubLogin, wechatLogin, larkLogin,oidcLogin }; }; export default useLogin; diff --git a/web/berry/src/routes/OtherRoutes.js b/web/berry/src/routes/OtherRoutes.js index 58c0b660..a4bdb5d3 100644 --- a/web/berry/src/routes/OtherRoutes.js +++ b/web/berry/src/routes/OtherRoutes.js @@ -9,6 +9,7 @@ const AuthLogin = Loadable(lazy(() => import('views/Authentication/Auth/Login')) const AuthRegister = Loadable(lazy(() => import('views/Authentication/Auth/Register'))); const GitHubOAuth = Loadable(lazy(() => import('views/Authentication/Auth/GitHubOAuth'))); const LarkOAuth = Loadable(lazy(() => import('views/Authentication/Auth/LarkOAuth'))); +const OidcOAuth = Loadable(lazy(() => import('views/Authentication/Auth/OidcOAuth'))); const ForgetPassword = Loadable(lazy(() => import('views/Authentication/Auth/ForgetPassword'))); const ResetPassword = Loadable(lazy(() => import('views/Authentication/Auth/ResetPassword'))); const Home = Loadable(lazy(() => import('views/Home'))); @@ -53,6 +54,10 @@ const OtherRoutes = { path: '/oauth/lark', element: }, + { + path: 'oauth/oidc', + element: + }, { path: '/404', element: diff --git a/web/berry/src/utils/common.js b/web/berry/src/utils/common.js index d74d032e..f9c2896c 100644 --- a/web/berry/src/utils/common.js +++ b/web/berry/src/utils/common.js @@ -98,6 +98,21 @@ export async function onLarkOAuthClicked(lark_client_id) { window.open(`https://open.feishu.cn/open-apis/authen/v1/index?redirect_uri=${redirect_uri}&app_id=${lark_client_id}&state=${state}`); } +export async function onOidcClicked(auth_url, client_id, openInNewTab = false) { + const state = await getOAuthState(); + if (!state) return; + const redirect_uri = `${window.location.origin}/oauth/oidc`; + const response_type = "code"; + const scope = "openid profile email"; + const url = `${auth_url}?client_id=${client_id}&redirect_uri=${redirect_uri}&response_type=${response_type}&scope=${scope}&state=${state}`; + if (openInNewTab) { + window.open(url); + } else + { + window.location.href = url; + } +} + export function isAdmin() { let user = localStorage.getItem('user'); if (!user) return false; diff --git a/web/berry/src/views/Authentication/Auth/OidcOAuth.js b/web/berry/src/views/Authentication/Auth/OidcOAuth.js new file mode 100644 index 00000000..55d9372d --- /dev/null +++ b/web/berry/src/views/Authentication/Auth/OidcOAuth.js @@ -0,0 +1,94 @@ +import { Link, useNavigate, useSearchParams } from 'react-router-dom'; +import React, { useEffect, useState } from 'react'; +import { showError } from 'utils/common'; +import useLogin from 'hooks/useLogin'; + +// material-ui +import { useTheme } from '@mui/material/styles'; +import { Grid, Stack, Typography, useMediaQuery, CircularProgress } from '@mui/material'; + +// project imports +import AuthWrapper from '../AuthWrapper'; +import AuthCardWrapper from '../AuthCardWrapper'; +import Logo from 'ui-component/Logo'; + +// assets + +// ================================|| AUTH3 - LOGIN ||================================ // + +const OidcOAuth = () => { + const theme = useTheme(); + const matchDownSM = useMediaQuery(theme.breakpoints.down('md')); + + const [searchParams] = useSearchParams(); + const [prompt, setPrompt] = useState('处理中...'); + const { oidcLogin } = useLogin(); + + let navigate = useNavigate(); + + const sendCode = async (code, state, count) => { + const { success, message } = await oidcLogin(code, state); + if (!success) { + if (message) { + showError(message); + } + if (count === 0) { + setPrompt(`操作失败,重定向至登录界面中...`); + await new Promise((resolve) => setTimeout(resolve, 2000)); + navigate('/login'); + return; + } + count++; + setPrompt(`出现错误,第 ${count} 次重试中...`); + await new Promise((resolve) => setTimeout(resolve, 2000)); + await sendCode(code, state, count); + } + }; + + useEffect(() => { + let code = searchParams.get('code'); + let state = searchParams.get('state'); + sendCode(code, state, 0).then(); + }, []); + + return ( + + + + + + + + + + + + + + + + + + OIDC 登录 + + + + + + + + + {prompt} + + + + + + + + + + ); +}; + +export default OidcOAuth; diff --git a/web/berry/src/views/Authentication/AuthForms/AuthLogin.js b/web/berry/src/views/Authentication/AuthForms/AuthLogin.js index bc7a35c0..7efd0362 100644 --- a/web/berry/src/views/Authentication/AuthForms/AuthLogin.js +++ b/web/berry/src/views/Authentication/AuthForms/AuthLogin.js @@ -36,7 +36,8 @@ import VisibilityOff from '@mui/icons-material/VisibilityOff'; import Github from 'assets/images/icons/github.svg'; import Wechat from 'assets/images/icons/wechat.svg'; import Lark from 'assets/images/icons/lark.svg'; -import { onGitHubOAuthClicked, onLarkOAuthClicked } from 'utils/common'; +import OIDC from 'assets/images/icons/oidc.svg'; +import { onGitHubOAuthClicked, onLarkOAuthClicked, onOidcClicked } from 'utils/common'; // ============================|| FIREBASE - LOGIN ||============================ // @@ -50,7 +51,7 @@ const LoginForm = ({ ...others }) => { // const [checked, setChecked] = useState(true); let tripartiteLogin = false; - if (siteInfo.github_oauth || siteInfo.wechat_login || siteInfo.lark_client_id) { + if (siteInfo.github_oauth || siteInfo.wechat_login || siteInfo.lark_client_id || siteInfo.oidc) { tripartiteLogin = true; } @@ -145,6 +146,29 @@ const LoginForm = ({ ...others }) => { )} + {siteInfo.oidc && ( + + + + + + )} 8) { + oidc_id = inputs.oidc_id.slice(0, 6) + '...' + inputs.oidc_id.slice(-6); + } + return oidc_id; + } + return ( <> @@ -141,6 +151,9 @@ export default function Profile() { + @@ -216,6 +229,13 @@ export default function Profile() { )} + {status.oidc && !inputs.oidc_id && ( + + + + )} + + + +