package auth

import (
	"bytes"
	"encoding/json"
	"errors"
	"fmt"
	"net/http"
	"strconv"
	"time"

	"github.com/gin-contrib/sessions"
	"github.com/gin-gonic/gin"
	"github.com/songquanpeng/one-api/common/config"
	"github.com/songquanpeng/one-api/common/logger"
	"github.com/songquanpeng/one-api/common/random"
	"github.com/songquanpeng/one-api/controller"
	"github.com/songquanpeng/one-api/model"
)

// GitHubOAuthResponse is the JSON payload decoded from GitHub's
// access-token endpoint.
type GitHubOAuthResponse struct {
	AccessToken string `json:"access_token"`
	Scope       string `json:"scope"`
	TokenType   string `json:"token_type"`
}

// GitHubUser holds the fields decoded from the GitHub "/user" API response.
type GitHubUser struct {
	Login string `json:"login"`
	Name  string `json:"name"`
	Email string `json:"email"`
}

// getGitHubUserInfoByCode exchanges an OAuth authorization code for an access
// token and then uses that token to fetch the GitHub user it belongs to.
//
// It returns an error when code is empty, when either HTTP request cannot be
// built, sent or decoded, when GitHub does not hand back an access token, or
// when the returned user has an empty login.
func getGitHubUserInfoByCode(code string) (*GitHubUser, error) {
	if code == "" {
		return nil, errors.New("无效的参数")
	}

	payload, err := json.Marshal(map[string]string{
		"client_id":     config.GitHubClientId,
		"client_secret": config.GitHubClientSecret,
		"code":          code,
	})
	if err != nil {
		return nil, err
	}

	tokenReq, err := http.NewRequest(http.MethodPost, "https://github.com/login/oauth/access_token", bytes.NewBuffer(payload))
	if err != nil {
		return nil, err
	}
	tokenReq.Header.Set("Content-Type", "application/json")
	tokenReq.Header.Set("Accept", "application/json")

	client := http.Client{
		Timeout: 5 * time.Second,
	}

	tokenRes, err := client.Do(tokenReq)
	if err != nil {
		// The transport error is logged; the caller only gets a generic message.
		logger.SysLog(err.Error())
		return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!")
	}
	defer tokenRes.Body.Close()

	var oAuthResponse GitHubOAuthResponse
	if err = json.NewDecoder(tokenRes.Body).Decode(&oAuthResponse); err != nil {
		return nil, err
	}
	// Without a token the follow-up request can only fail, and it would fail
	// with a misleading "empty user" error, so stop here instead.
	if oAuthResponse.AccessToken == "" {
		return nil, errors.New("未能获取 GitHub 访问令牌,请稍后重试!")
	}

	userReq, err := http.NewRequest(http.MethodGet, "https://api.github.com/user", nil)
	if err != nil {
		return nil, err
	}
	userReq.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken))

	userRes, err := client.Do(userReq)
	if err != nil {
		logger.SysLog(err.Error())
		return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!")
	}
	defer userRes.Body.Close()

	var githubUser GitHubUser
	if err = json.NewDecoder(userRes.Body).Decode(&githubUser); err != nil {
		return nil, err
	}
	if githubUser.Login == "" {
		return nil, errors.New("返回值非法,用户字段为空,请稍后重试!")
	}
	return &githubUser, nil
}

// GitHubOAuth handles the GitHub OAuth callback request.
//
// The "state" query parameter must be non-empty and equal to the
// "oauth_state" value stored in the session, otherwise the request is
// rejected with 403. When the session already carries a "username" the
// request is delegated to GitHubBind. Otherwise the GitHub account is looked
// up: a known account is loaded, an unknown one is registered when
// config.RegisterEnabled is set. The resulting user, if enabled, is passed to
// controller.SetupLogin.
func GitHubOAuth(c *gin.Context) {
	session := sessions.Default(c)
	state := c.Query("state")
	// Comma-ok assertion: a missing or non-string session value is treated as
	// a state mismatch instead of panicking.
	savedState, ok := session.Get("oauth_state").(string)
	if state == "" || !ok || state != savedState {
		c.JSON(http.StatusForbidden, gin.H{
			"success": false,
			"message": "state is empty or not same",
		})
		return
	}

	if username := session.Get("username"); username != nil {
		GitHubBind(c)
		return
	}

	if !config.GitHubOAuthEnabled {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": "管理员未开启通过 GitHub 登录以及注册",
		})
		return
	}

	code := c.Query("code")
	githubUser, err := getGitHubUserInfoByCode(code)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}

	user := model.User{
		GitHubId: githubUser.Login,
	}
	if model.IsGitHubIdAlreadyTaken(user.GitHubId) {
		// Known GitHub account: load the user it is bound to.
		if err := user.FillUserByGitHubId(); err != nil {
			c.JSON(http.StatusOK, gin.H{
				"success": false,
				"message": err.Error(),
			})
			return
		}
	} else {
		// Unknown GitHub account: create a new user if registration is open.
		if !config.RegisterEnabled {
			c.JSON(http.StatusOK, gin.H{
				"success": false,
				"message": "管理员关闭了新用户注册",
			})
			return
		}
		user.Username = "github_" + strconv.Itoa(model.GetMaxUserId()+1)
		if githubUser.Name != "" {
			user.DisplayName = githubUser.Name
		} else {
			user.DisplayName = "GitHub User"
		}
		user.Email = githubUser.Email
		user.Role = model.RoleCommonUser
		user.Status = model.UserStatusEnabled
		if err := user.Insert(0); err != nil {
			c.JSON(http.StatusOK, gin.H{
				"success": false,
				"message": err.Error(),
			})
			return
		}
	}

	if user.Status != model.UserStatusEnabled {
		c.JSON(http.StatusOK, gin.H{
			"message": "用户已被封禁",
			"success": false,
		})
		return
	}
	controller.SetupLogin(&user, c)
}

// GitHubBind attaches the GitHub account identified by the "code" query
// parameter to the user whose id is stored in the session.
//
// It responds with success=false when GitHub OAuth is disabled, when the code
// cannot be resolved to a GitHub user, when that GitHub account is already
// bound, when the session carries no usable id, or when loading or updating
// the user fails.
func GitHubBind(c *gin.Context) {
	if !config.GitHubOAuthEnabled {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": "管理员未开启通过 GitHub 登录以及注册",
		})
		return
	}

	code := c.Query("code")
	githubUser, err := getGitHubUserInfoByCode(code)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}

	user := model.User{
		GitHubId: githubUser.Login,
	}
	if model.IsGitHubIdAlreadyTaken(user.GitHubId) {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": "该 GitHub 账户已被绑定",
		})
		return
	}

	// The id is read from the session on purpose: the original author flagged
	// using c.GetInt("id") here as a critical bug.
	session := sessions.Default(c)
	id, ok := session.Get("id").(int)
	if !ok {
		// Missing or non-int id: answer with a failure instead of panicking.
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": "未登录或会话已失效,请重新登录",
		})
		return
	}
	user.Id = id
	if err = user.FillUserById(); err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}

	// FillUserById has just reloaded the user, so set the GitHub id again
	// before saving.
	user.GitHubId = githubUser.Login
	if err = user.Update(false); err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}

	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "bind",
	})
}

// GenerateOAuthCode creates a random OAuth state value, stores it in the
// session under "oauth_state" and returns it to the client in the "data"
// field. GitHubOAuth later compares the callback's "state" against it.
func GenerateOAuthCode(c *gin.Context) {
	session := sessions.Default(c)
	state := random.GetRandomString(12)
	session.Set("oauth_state", state)
	if err := session.Save(); err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
		"data":    state,
	})
}