- Rename AppId and AppSecret on the server side to the standard names ClientId and ClientSecret. - Add a Well-Known configuration value that is stored in the database; the server itself does not use it, it is only stored and displayed in the web UI.
226 lines
5.2 KiB
Go
226 lines
5.2 KiB
Go
package auth
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"github.com/gin-contrib/sessions"
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/songquanpeng/one-api/common/config"
|
|
"github.com/songquanpeng/one-api/common/logger"
|
|
"github.com/songquanpeng/one-api/controller"
|
|
"github.com/songquanpeng/one-api/model"
|
|
"net/http"
|
|
"strconv"
|
|
"time"
|
|
)
|
|
|
|
// OidcResponse is the JSON document decoded from the reply of the OIDC token
// endpoint after an authorization-code exchange.
//
// Within this file only AccessToken is read (it becomes the Bearer credential
// for the userinfo request); the remaining fields are decoded but not used by
// the code visible here.
type OidcResponse struct {
	AccessToken  string `json:"access_token"`
	IDToken      string `json:"id_token"`
	RefreshToken string `json:"refresh_token"`
	TokenType    string `json:"token_type"`
	ExpiresIn    int    `json:"expires_in"`
	Scope        string `json:"scope"`
}
|
|
|
|
// OidcUser holds the claims decoded from the OIDC userinfo endpoint response.
//
// The handlers in this file use the claims as follows: OpenID (the "sub"
// claim) is stored as model.User.OidcId; Email, PreferredUsername and Name
// seed Email, Username and DisplayName when a new account is registered.
// Picture is decoded but not read by the code visible here.
type OidcUser struct {
	OpenID            string `json:"sub"`
	Email             string `json:"email"`
	Name              string `json:"name"`
	PreferredUsername string `json:"preferred_username"`
	Picture           string `json:"picture"`
}
|
|
|
|
func getOidcUserInfoByCode(code string) (*OidcUser, error) {
|
|
if code == "" {
|
|
return nil, errors.New("无效的参数")
|
|
}
|
|
values := map[string]string{
|
|
"client_id": config.OidcClientId,
|
|
"client_secret": config.OidcClientSecret,
|
|
"code": code,
|
|
"grant_type": "authorization_code",
|
|
"redirect_uri": fmt.Sprintf("%s/oauth/oidc", config.ServerAddress),
|
|
}
|
|
jsonData, err := json.Marshal(values)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
req, err := http.NewRequest("POST", config.OidcTokenEndpoint, bytes.NewBuffer(jsonData))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("Accept", "application/json")
|
|
client := http.Client{
|
|
Timeout: 5 * time.Second,
|
|
}
|
|
res, err := client.Do(req)
|
|
if err != nil {
|
|
logger.SysLog(err.Error())
|
|
return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!")
|
|
}
|
|
defer res.Body.Close()
|
|
var oidcResponse OidcResponse
|
|
err = json.NewDecoder(res.Body).Decode(&oidcResponse)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
req, err = http.NewRequest("GET", config.OidcUserinfoEndpoint, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
req.Header.Set("Authorization", "Bearer "+oidcResponse.AccessToken)
|
|
res2, err := client.Do(req)
|
|
if err != nil {
|
|
logger.SysLog(err.Error())
|
|
return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!")
|
|
}
|
|
var oidcUser OidcUser
|
|
err = json.NewDecoder(res2.Body).Decode(&oidcUser)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &oidcUser, nil
|
|
}
|
|
|
|
func OidcAuth(c *gin.Context) {
|
|
session := sessions.Default(c)
|
|
state := c.Query("state")
|
|
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
|
|
c.JSON(http.StatusForbidden, gin.H{
|
|
"success": false,
|
|
"message": "state is empty or not same",
|
|
})
|
|
return
|
|
}
|
|
username := session.Get("username")
|
|
if username != nil {
|
|
OidcBind(c)
|
|
return
|
|
}
|
|
if !config.OidcEnabled {
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"success": false,
|
|
"message": "管理员未开启通过 OIDC 登录以及注册",
|
|
})
|
|
return
|
|
}
|
|
code := c.Query("code")
|
|
oidcUser, err := getOidcUserInfoByCode(code)
|
|
if err != nil {
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"success": false,
|
|
"message": err.Error(),
|
|
})
|
|
return
|
|
}
|
|
user := model.User{
|
|
OidcId: oidcUser.OpenID,
|
|
}
|
|
if model.IsOidcIdAlreadyTaken(user.OidcId) {
|
|
err := user.FillUserByOidcId()
|
|
if err != nil {
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"success": false,
|
|
"message": err.Error(),
|
|
})
|
|
return
|
|
}
|
|
} else {
|
|
if config.RegisterEnabled {
|
|
user.Email = oidcUser.Email
|
|
if oidcUser.PreferredUsername != "" {
|
|
user.Username = oidcUser.PreferredUsername
|
|
} else {
|
|
user.Username = "oidc_" + strconv.Itoa(model.GetMaxUserId()+1)
|
|
}
|
|
if oidcUser.Name != "" {
|
|
user.DisplayName = oidcUser.Name
|
|
} else {
|
|
user.DisplayName = "OIDC User"
|
|
}
|
|
err := user.Insert(0)
|
|
if err != nil {
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"success": false,
|
|
"message": err.Error(),
|
|
})
|
|
return
|
|
}
|
|
} else {
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"success": false,
|
|
"message": "管理员关闭了新用户注册",
|
|
})
|
|
return
|
|
}
|
|
}
|
|
|
|
if user.Status != model.UserStatusEnabled {
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"message": "用户已被封禁",
|
|
"success": false,
|
|
})
|
|
return
|
|
}
|
|
controller.SetupLogin(&user, c)
|
|
}
|
|
|
|
func OidcBind(c *gin.Context) {
|
|
if !config.OidcEnabled {
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"success": false,
|
|
"message": "管理员未开启通过 OIDC 登录以及注册",
|
|
})
|
|
return
|
|
}
|
|
code := c.Query("code")
|
|
oidcUser, err := getOidcUserInfoByCode(code)
|
|
if err != nil {
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"success": false,
|
|
"message": err.Error(),
|
|
})
|
|
return
|
|
}
|
|
user := model.User{
|
|
OidcId: oidcUser.OpenID,
|
|
}
|
|
if model.IsOidcIdAlreadyTaken(user.OidcId) {
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"success": false,
|
|
"message": "该 OIDC 账户已被绑定",
|
|
})
|
|
return
|
|
}
|
|
session := sessions.Default(c)
|
|
id := session.Get("id")
|
|
// id := c.GetInt("id") // critical bug!
|
|
user.Id = id.(int)
|
|
err = user.FillUserById()
|
|
if err != nil {
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"success": false,
|
|
"message": err.Error(),
|
|
})
|
|
return
|
|
}
|
|
user.OidcId = oidcUser.OpenID
|
|
err = user.Update(false)
|
|
if err != nil {
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"success": false,
|
|
"message": err.Error(),
|
|
})
|
|
return
|
|
}
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"success": true,
|
|
"message": "bind",
|
|
})
|
|
return
|
|
}
|