diff --git a/common/config/config.go b/common/config/config.go index 11da0b96..f9e4a540 100644 --- a/common/config/config.go +++ b/common/config/config.go @@ -35,6 +35,7 @@ var PasswordLoginEnabled = true var PasswordRegisterEnabled = true var EmailVerificationEnabled = false var GitHubOAuthEnabled = false +var OidcEnabled = false var WeChatAuthEnabled = false var TurnstileCheckEnabled = false var RegisterEnabled = true @@ -70,6 +71,12 @@ var GitHubClientSecret = "" var LarkClientId = "" var LarkClientSecret = "" +var OidcAppId = "" +var OidcAppSecret = "" +var OidcAuthorizationEndpoint = "" +var OidcTokenEndpoint = "" +var OidcUserinfoEndpoint = "" + var WeChatServerAddress = "" var WeChatServerToken = "" var WeChatAccountQRCodeImageURL = "" diff --git a/controller/auth/oidc.go b/controller/auth/oidc.go new file mode 100644 index 00000000..02865c60 --- /dev/null +++ b/controller/auth/oidc.go @@ -0,0 +1,225 @@ +package auth + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "github.com/gin-contrib/sessions" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/controller" + "github.com/songquanpeng/one-api/model" + "net/http" + "strconv" + "time" +) + +type OidcResponse struct { + AccessToken string `json:"access_token"` + IDToken string `json:"id_token"` + RefreshToken string `json:"refresh_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + Scope string `json:"scope"` +} + +type OidcUser struct { + OpenID string `json:"sub"` + Email string `json:"email"` + Name string `json:"name"` + PreferredUsername string `json:"preferred_username"` + Picture string `json:"picture"` +} + +func getOidcUserInfoByCode(code string) (*OidcUser, error) { + if code == "" { + return nil, errors.New("无效的参数") + } + values := map[string]string{ + "client_id": config.OidcAppId, + "client_secret": config.OidcAppSecret, + "code": code, + "grant_type": "authorization_code", + "redirect_uri": fmt.Sprintf("%s/oauth/oidc", config.ServerAddress), + } + jsonData, err := json.Marshal(values) + if err != nil { + return nil, err + } + req, err := http.NewRequest("POST", config.OidcTokenEndpoint, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + client := http.Client{ + Timeout: 5 * time.Second, + } + res, err := client.Do(req) + if err != nil { + logger.SysLog(err.Error()) + return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!") + } + defer res.Body.Close() + var oidcResponse OidcResponse + err = json.NewDecoder(res.Body).Decode(&oidcResponse) + if err != nil { + return nil, err + } + req, err = http.NewRequest("GET", config.OidcUserinfoEndpoint, nil) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", "Bearer "+oidcResponse.AccessToken) + res2, err := client.Do(req) + if err != nil { + logger.SysLog(err.Error()) + return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!") + } + var oidcUser OidcUser + err = json.NewDecoder(res2.Body).Decode(&oidcUser) + if err != nil { + return nil, err + } + return &oidcUser, nil +} + +func OidcAuth(c *gin.Context) { + session := sessions.Default(c) + state := c.Query("state") + if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) { + c.JSON(http.StatusForbidden, gin.H{ + "success": false, + "message": "state is empty or not same", + }) + return + } + username := session.Get("username") + if username != nil { + OidcBind(c) + return + } + if !config.OidcEnabled { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "管理员未开启通过 OIDC 登录以及注册", + }) + return + } + code := c.Query("code") + oidcUser, err := getOidcUserInfoByCode(code) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + user := model.User{ + OidcId: oidcUser.OpenID, + } + if model.IsOidcIdAlreadyTaken(user.OidcId) { + err := user.FillUserByOidcId() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + } else { + if config.RegisterEnabled { + user.Email = oidcUser.Email + if oidcUser.PreferredUsername != "" { + user.Username = oidcUser.PreferredUsername + } else { + user.Username = "oidc_" + strconv.Itoa(model.GetMaxUserId()+1) + } + if oidcUser.Name != "" { + user.DisplayName = oidcUser.Name + } else { + user.DisplayName = "OIDC User" + } + err := user.Insert(0) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + } else { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "管理员关闭了新用户注册", + }) + return + } + } + + if user.Status != model.UserStatusEnabled { + c.JSON(http.StatusOK, gin.H{ + "message": "用户已被封禁", + "success": false, + }) + return + } + controller.SetupLogin(&user, c) +} + +func OidcBind(c *gin.Context) { + if !config.OidcEnabled { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "管理员未开启通过 OIDC 登录以及注册", + }) + return + } + code := c.Query("code") + oidcUser, err := getOidcUserInfoByCode(code) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + user := model.User{ + OidcId: oidcUser.OpenID, + } + if model.IsOidcIdAlreadyTaken(user.OidcId) { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "该 OIDC 账户已被绑定", + }) + return + } + session := sessions.Default(c) + id := session.Get("id") + // id := c.GetInt("id") // critical bug! + user.Id = id.(int) + err = user.FillUserById() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + user.OidcId = oidcUser.OpenID + err = user.Update(false) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "bind", + }) + return +} diff --git a/controller/misc.go b/controller/misc.go index 2928b8fb..0aef52c0 100644 --- a/controller/misc.go +++ b/controller/misc.go @@ -18,24 +18,29 @@ func GetStatus(c *gin.Context) { "success": true, "message": "", "data": gin.H{ - "version": common.Version, - "start_time": common.StartTime, - "email_verification": config.EmailVerificationEnabled, - "github_oauth": config.GitHubOAuthEnabled, - "github_client_id": config.GitHubClientId, - "lark_client_id": config.LarkClientId, - "system_name": config.SystemName, - "logo": config.Logo, - "footer_html": config.Footer, - "wechat_qrcode": config.WeChatAccountQRCodeImageURL, - "wechat_login": config.WeChatAuthEnabled, - "server_address": config.ServerAddress, - "turnstile_check": config.TurnstileCheckEnabled, - "turnstile_site_key": config.TurnstileSiteKey, - "top_up_link": config.TopUpLink, - "chat_link": config.ChatLink, - "quota_per_unit": config.QuotaPerUnit, - "display_in_currency": config.DisplayInCurrencyEnabled, + "version": common.Version, + "start_time": common.StartTime, + "email_verification": config.EmailVerificationEnabled, + "github_oauth": config.GitHubOAuthEnabled, + "github_client_id": config.GitHubClientId, + "lark_client_id": config.LarkClientId, + "system_name": config.SystemName, + "logo": config.Logo, + "footer_html": config.Footer, + "wechat_qrcode": config.WeChatAccountQRCodeImageURL, + "wechat_login": config.WeChatAuthEnabled, + "server_address": config.ServerAddress, + "turnstile_check": config.TurnstileCheckEnabled, + "turnstile_site_key": config.TurnstileSiteKey, + "top_up_link": config.TopUpLink, + "chat_link": config.ChatLink, + "quota_per_unit": config.QuotaPerUnit, + "display_in_currency": config.DisplayInCurrencyEnabled, + "oidc": config.OidcEnabled, + "oidc_app_id": config.OidcAppId, + "oidc_authorization_endpoint": config.OidcAuthorizationEndpoint, + "oidc_token_endpoint": config.OidcTokenEndpoint, + "oidc_userinfo_endpoint": config.OidcUserinfoEndpoint, }, }) return diff --git a/model/option.go b/model/option.go index bed8d4c3..fa9f9c98 100644 --- a/model/option.go +++ b/model/option.go @@ -28,6 +28,7 @@ func InitOptionMap() { config.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(config.PasswordRegisterEnabled) config.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(config.EmailVerificationEnabled) config.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(config.GitHubOAuthEnabled) + config.OptionMap["OidcEnabled"] = strconv.FormatBool(config.OidcEnabled) config.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(config.WeChatAuthEnabled) config.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(config.TurnstileCheckEnabled) config.OptionMap["RegisterEnabled"] = strconv.FormatBool(config.RegisterEnabled) @@ -130,6 +131,8 @@ func updateOptionMap(key string, value string) (err error) { config.EmailVerificationEnabled = boolValue case "GitHubOAuthEnabled": config.GitHubOAuthEnabled = boolValue + case "OidcEnabled": + config.OidcEnabled = boolValue case "WeChatAuthEnabled": config.WeChatAuthEnabled = boolValue case "TurnstileCheckEnabled": @@ -176,6 +179,16 @@ func updateOptionMap(key string, value string) (err error) { config.LarkClientId = value case "LarkClientSecret": config.LarkClientSecret = value + case "OidcAppId": + config.OidcAppId = value + case "OidcAppSecret": + config.OidcAppSecret = value + case "OidcAuthorizationEndpoint": + config.OidcAuthorizationEndpoint = value + case "OidcTokenEndpoint": + config.OidcTokenEndpoint = value + case "OidcUserinfoEndpoint": + config.OidcUserinfoEndpoint = value case "Footer": config.Footer = value case "SystemName": diff --git a/model/user.go b/model/user.go index 924d72f9..a964a0d7 100644 --- a/model/user.go +++ b/model/user.go @@ -39,6 +39,7 @@ type User struct { GitHubId string `json:"github_id" gorm:"column:github_id;index"` WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"` LarkId string `json:"lark_id" gorm:"column:lark_id;index"` + OidcId string `json:"oidc_id" gorm:"column:oidc_id;index"` VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database! AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management Quota int64 `json:"quota" gorm:"bigint;default:0"` @@ -245,6 +246,14 @@ func (user *User) FillUserByLarkId() error { return nil } +func (user *User) FillUserByOidcId() error { + if user.OidcId == "" { + return errors.New("oidc id 为空!") + } + DB.Where(User{OidcId: user.OidcId}).First(user) + return nil +} + func (user *User) FillUserByWeChatId() error { if user.WeChatId == "" { return errors.New("WeChat id 为空!") @@ -277,6 +286,10 @@ func IsLarkIdAlreadyTaken(githubId string) bool { return DB.Where("lark_id = ?", githubId).Find(&User{}).RowsAffected == 1 } +func IsOidcIdAlreadyTaken(oidcId string) bool { + return DB.Where("oidc_id = ?", oidcId).Find(&User{}).RowsAffected == 1 +} + func IsUsernameAlreadyTaken(username string) bool { return DB.Where("username = ?", username).Find(&User{}).RowsAffected == 1 } diff --git a/router/api.go b/router/api.go index d2ada4eb..6d00c6ea 100644 --- a/router/api.go +++ b/router/api.go @@ -23,6 +23,7 @@ func SetApiRouter(router *gin.Engine) { apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail) apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword) apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), auth.GitHubOAuth) + apiRouter.GET("/oauth/oidc", middleware.CriticalRateLimit(), auth.OidcAuth) apiRouter.GET("/oauth/lark", middleware.CriticalRateLimit(), auth.LarkOAuth) apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), auth.GenerateOAuthCode) apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), auth.WeChatAuth)