diff --git a/README.md b/README.md index db43ba6e..596def13 100644 --- a/README.md +++ b/README.md @@ -105,6 +105,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 21. 支持 Cloudflare Turnstile 用户校验。 22. 支持用户管理,支持**多种用户登录注册方式**: + 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。 + + 支持使用飞书进行授权登录。 + [GitHub 开放授权](https://github.com/settings/applications/new)。 + 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。 23. 支持主题切换,设置环境变量 `THEME` 即可,默认为 `default`,欢迎 PR 更多主题,具体参考[此处](./web/README.md)。 diff --git a/common/config/config.go b/common/config/config.go index 3524183a..9fd7cba0 100644 --- a/common/config/config.go +++ b/common/config/config.go @@ -66,6 +66,9 @@ var SMTPToken = "" var GitHubClientId = "" var GitHubClientSecret = "" +var LarkClientId = "" +var LarkClientSecret = "" + var WeChatServerAddress = "" var WeChatServerToken = "" var WeChatAccountQRCodeImageURL = "" diff --git a/controller/github.go b/controller/auth/github.go similarity index 98% rename from controller/github.go rename to controller/auth/github.go index 7d7fa106..cf073133 100644 --- a/controller/github.go +++ b/controller/auth/github.go @@ -1,4 +1,4 @@ -package controller +package auth import ( "bytes" @@ -11,6 +11,7 @@ import ( "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/controller" "github.com/songquanpeng/one-api/model" "net/http" "strconv" @@ -159,7 +160,7 @@ func GitHubOAuth(c *gin.Context) { }) return } - setupLogin(&user, c) + controller.SetupLogin(&user, c) } func GitHubBind(c *gin.Context) { diff --git a/controller/auth/lark.go b/controller/auth/lark.go new file mode 100644 index 00000000..21446d46 --- /dev/null +++ b/controller/auth/lark.go @@ -0,0 +1,201 @@ +package auth + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "github.com/gin-contrib/sessions" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/controller" + "github.com/songquanpeng/one-api/model" + "net/http" + "strconv" + "time" +) + +type LarkOAuthResponse struct { + AccessToken string `json:"access_token"` +} + +type LarkUser struct { + Name string `json:"name"` + OpenID string `json:"open_id"` +} + +func getLarkUserInfoByCode(code string) (*LarkUser, error) { + if code == "" { + return nil, errors.New("无效的参数") + } + values := map[string]string{ + "client_id": config.LarkClientId, + "client_secret": config.LarkClientSecret, + "code": code, + "grant_type": "authorization_code", + "redirect_uri": fmt.Sprintf("%s/oauth/lark", config.ServerAddress), + } + jsonData, err := json.Marshal(values) + if err != nil { + return nil, err + } + req, err := http.NewRequest("POST", "https://passport.feishu.cn/suite/passport/oauth/token", bytes.NewBuffer(jsonData)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + client := http.Client{ + Timeout: 5 * time.Second, + } + res, err := client.Do(req) + if err != nil { + logger.SysLog(err.Error()) + return nil, errors.New("无法连接至飞书服务器,请稍后重试!") + } + defer res.Body.Close() + var oAuthResponse LarkOAuthResponse + err = json.NewDecoder(res.Body).Decode(&oAuthResponse) + if err != nil { + return nil, err + } + req, err = http.NewRequest("GET", "https://passport.feishu.cn/suite/passport/oauth/userinfo", nil) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken)) + res2, err := client.Do(req) + if err != nil { + logger.SysLog(err.Error()) + return nil, errors.New("无法连接至飞书服务器,请稍后重试!") + } + var larkUser LarkUser + err = json.NewDecoder(res2.Body).Decode(&larkUser) + if err != nil { + return nil, err + } + return &larkUser, nil +} + +func LarkOAuth(c *gin.Context) { + session := sessions.Default(c) + state := c.Query("state") + if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) { + c.JSON(http.StatusForbidden, gin.H{ + "success": false, + "message": "state is empty or not same", + }) + return + } + username := session.Get("username") + if username != nil { + LarkBind(c) + return + } + code := c.Query("code") + larkUser, err := getLarkUserInfoByCode(code) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + user := model.User{ + LarkId: larkUser.OpenID, + } + if model.IsLarkIdAlreadyTaken(user.LarkId) { + err := user.FillUserByLarkId() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + } else { + if config.RegisterEnabled { + user.Username = "lark_" + strconv.Itoa(model.GetMaxUserId()+1) + if larkUser.Name != "" { + user.DisplayName = larkUser.Name + } else { + user.DisplayName = "Lark User" + } + user.Role = common.RoleCommonUser + user.Status = common.UserStatusEnabled + + if err := user.Insert(0); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + } else { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "管理员关闭了新用户注册", + }) + return + } + } + + if user.Status != common.UserStatusEnabled { + c.JSON(http.StatusOK, gin.H{ + "message": "用户已被封禁", + "success": false, + }) + return + } + controller.SetupLogin(&user, c) +} + +func LarkBind(c *gin.Context) { + code := c.Query("code") + larkUser, err := getLarkUserInfoByCode(code) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + user := model.User{ + LarkId: larkUser.OpenID, + } + if model.IsLarkIdAlreadyTaken(user.LarkId) { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "该飞书账户已被绑定", + }) + return + } + session := sessions.Default(c) + id := session.Get("id") + // id := c.GetInt("id") // critical bug! + user.Id = id.(int) + err = user.FillUserById() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + user.LarkId = larkUser.OpenID + err = user.Update(false) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "bind", + }) + return +} diff --git a/controller/wechat.go b/controller/auth/wechat.go similarity index 97% rename from controller/wechat.go rename to controller/auth/wechat.go index 74be5604..80552c9a 100644 --- a/controller/wechat.go +++ b/controller/auth/wechat.go @@ -1,4 +1,4 @@ -package controller +package auth import ( "encoding/json" @@ -7,6 +7,7 @@ import ( "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/controller" "github.com/songquanpeng/one-api/model" "net/http" "strconv" @@ -109,7 +110,7 @@ func WeChatAuth(c *gin.Context) { }) return } - setupLogin(&user, c) + controller.SetupLogin(&user, c) } func WeChatBind(c *gin.Context) { diff --git a/controller/misc.go b/controller/misc.go index f27fdb12..2928b8fb 100644 --- a/controller/misc.go +++ b/controller/misc.go @@ -23,6 +23,7 @@ func GetStatus(c *gin.Context) { "email_verification": config.EmailVerificationEnabled, "github_oauth": config.GitHubOAuthEnabled, "github_client_id": config.GitHubClientId, + "lark_client_id": config.LarkClientId, "system_name": config.SystemName, "logo": config.Logo, "footer_html": config.Footer, diff --git a/controller/user.go b/controller/user.go index 61055878..e87a03a2 100644 --- a/controller/user.go +++ b/controller/user.go @@ -58,11 +58,11 @@ func Login(c *gin.Context) { }) return } - setupLogin(&user, c) + SetupLogin(&user, c) } // setup session & cookies and then return user info -func setupLogin(user *model.User, c *gin.Context) { +func SetupLogin(user *model.User, c *gin.Context) { session := sessions.Default(c) session.Set("id", user.Id) session.Set("username", user.Username) diff --git a/model/option.go b/model/option.go index 1d1c28b4..cee9bd3b 100644 --- a/model/option.go +++ b/model/option.go @@ -172,6 +172,10 @@ func updateOptionMap(key string, value string) (err error) { config.GitHubClientId = value case "GitHubClientSecret": config.GitHubClientSecret = value + case "LarkClientId": + config.LarkClientId = value + case "LarkClientSecret": + config.LarkClientSecret = value case "Footer": config.Footer = value case "SystemName": diff --git a/model/user.go b/model/user.go index 5e729b5e..42d8f7b1 100644 --- a/model/user.go +++ b/model/user.go @@ -24,6 +24,7 @@ type User struct { Email string `json:"email" gorm:"index" validate:"max=50"` GitHubId string `json:"github_id" gorm:"column:github_id;index"` WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"` + LarkId string `json:"lark_id" gorm:"column:lark_id;index"` VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database! AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management Quota int64 `json:"quota" gorm:"bigint;default:0"` @@ -41,21 +42,21 @@ func GetMaxUserId() int { } func GetAllUsers(startIdx int, num int, order string) (users []*User, err error) { - query := DB.Limit(num).Offset(startIdx).Omit("password").Where("status != ?", common.UserStatusDeleted) - - switch order { - case "quota": - query = query.Order("quota desc") - case "used_quota": - query = query.Order("used_quota desc") - case "request_count": - query = query.Order("request_count desc") - default: - query = query.Order("id desc") - } - - err = query.Find(&users).Error - return users, err + query := DB.Limit(num).Offset(startIdx).Omit("password").Where("status != ?", common.UserStatusDeleted) + + switch order { + case "quota": + query = query.Order("quota desc") + case "used_quota": + query = query.Order("used_quota desc") + case "request_count": + query = query.Order("request_count desc") + default: + query = query.Order("id desc") + } + + err = query.Find(&users).Error + return users, err } func SearchUsers(keyword string) (users []*User, err error) { @@ -206,6 +207,14 @@ func (user *User) FillUserByGitHubId() error { return nil } +func (user *User) FillUserByLarkId() error { + if user.LarkId == "" { + return errors.New("lark id 为空!") + } + DB.Where(User{LarkId: user.LarkId}).First(user) + return nil +} + func (user *User) FillUserByWeChatId() error { if user.WeChatId == "" { return errors.New("WeChat id 为空!") @@ -234,6 +243,10 @@ func IsGitHubIdAlreadyTaken(githubId string) bool { return DB.Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1 } +func IsLarkIdAlreadyTaken(githubId string) bool { + return DB.Where("lark_id = ?", githubId).Find(&User{}).RowsAffected == 1 +} + func IsUsernameAlreadyTaken(username string) bool { return DB.Where("username = ?", username).Find(&User{}).RowsAffected == 1 } diff --git a/router/api-router.go b/router/api-router.go index 1558640f..a36232b3 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -2,6 +2,7 @@ package router import ( "github.com/songquanpeng/one-api/controller" + "github.com/songquanpeng/one-api/controller/auth" "github.com/songquanpeng/one-api/middleware" "github.com/gin-contrib/gzip" @@ -21,10 +22,11 @@ func SetApiRouter(router *gin.Engine) { apiRouter.GET("/verification", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendEmailVerification) apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail) apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword) - apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), controller.GitHubOAuth) - apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), controller.GenerateOAuthCode) - apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), controller.WeChatAuth) - apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.WeChatBind) + apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), auth.GitHubOAuth) + apiRouter.GET("/oauth/lark", middleware.CriticalRateLimit(), auth.LarkOAuth) + apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), auth.GenerateOAuthCode) + apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), auth.WeChatAuth) + apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), auth.WeChatBind) apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.EmailBind) apiRouter.POST("/topup", middleware.AdminAuth(), controller.AdminTopUp) diff --git a/web/default/src/App.js b/web/default/src/App.js index 13c884dc..4ece4eeb 100644 --- a/web/default/src/App.js +++ b/web/default/src/App.js @@ -24,6 +24,7 @@ import EditRedemption from './pages/Redemption/EditRedemption'; import TopUp from './pages/TopUp'; import Log from './pages/Log'; import Chat from './pages/Chat'; +import LarkOAuth from './components/LarkOAuth'; const Home = lazy(() => import('./pages/Home')); const About = lazy(() => import('./pages/About')); @@ -239,6 +240,14 @@ function App() { } /> + }> + + + } + /> { + const [searchParams, setSearchParams] = useSearchParams(); + + const [userState, userDispatch] = useContext(UserContext); + const [prompt, setPrompt] = useState('处理中...'); + const [processing, setProcessing] = useState(true); + + let navigate = useNavigate(); + + const sendCode = async (code, state, count) => { + const res = await API.get(`/api/oauth/lark?code=${code}&state=${state}`); + const { success, message, data } = res.data; + if (success) { + if (message === 'bind') { + showSuccess('绑定成功!'); + navigate('/setting'); + } else { + userDispatch({ type: 'login', payload: data }); + localStorage.setItem('user', JSON.stringify(data)); + showSuccess('登录成功!'); + navigate('/'); + } + } else { + showError(message); + if (count === 0) { + setPrompt(`操作失败,重定向至登录界面中...`); + navigate('/setting'); // in case this is failed to bind lark + return; + } + count++; + setPrompt(`出现错误,第 ${count} 次重试中...`); + await new Promise((resolve) => setTimeout(resolve, count * 2000)); + await sendCode(code, state, count); + } + }; + + useEffect(() => { + let code = searchParams.get('code'); + let state = searchParams.get('state'); + sendCode(code, state, 0).then(); + }, []); + + return ( + + + {prompt} + + + ); +}; + +export default LarkOAuth; diff --git a/web/default/src/components/LoginForm.js b/web/default/src/components/LoginForm.js index b48f64c4..01408f56 100644 --- a/web/default/src/components/LoginForm.js +++ b/web/default/src/components/LoginForm.js @@ -3,7 +3,8 @@ import { Button, Divider, Form, Grid, Header, Image, Message, Modal, Segment } f import { Link, useNavigate, useSearchParams } from 'react-router-dom'; import { UserContext } from '../context/User'; import { API, getLogo, showError, showSuccess, showWarning } from '../helpers'; -import { onGitHubOAuthClicked } from './utils'; +import { onGitHubOAuthClicked, onLarkOAuthClicked } from './utils'; +import larkIcon from '../images/lark.svg'; const LoginForm = () => { const [inputs, setInputs] = useState({ @@ -124,7 +125,7 @@ const LoginForm = () => { 点击注册 - {status.github_oauth || status.wechat_login ? ( + {status.github_oauth || status.wechat_login || status.lark_client_id ? ( <> Or {status.github_oauth ? ( @@ -137,6 +138,18 @@ const LoginForm = () => { ) : ( <> )} + {status.lark_client_id ? ( + + ) : ( + <> + )} {status.wechat_login ? ( ) } + { + status.lark_client_id && ( + + ) + }