118 lines
2.7 KiB
Go
118 lines
2.7 KiB
Go
package auth
|
|
|
|
import (
|
|
"context"
|
|
"leafdev.top/Leaf/leaf-library-3/internal/constants"
|
|
"leafdev.top/Leaf/leaf-library-3/internal/dto/user"
|
|
errs2 "leafdev.top/Leaf/leaf-library-3/internal/errs"
|
|
|
|
"github.com/gofiber/fiber/v2"
|
|
"github.com/mitchellh/mapstructure"
|
|
)
|
|
|
|
func (a *Service) AuthFromToken(tokenType constants.JwtTokenTypes, token string) (*user.User, error) {
|
|
if a.config.Debug.Enabled {
|
|
return a.parseUserJWT(tokenType, "")
|
|
}
|
|
|
|
return a.parseUserJWT(tokenType, token)
|
|
}
|
|
|
|
func (a *Service) GetUserFromIdToken(idToken string) (*user.User, error) {
|
|
return a.parseUserJWT(constants.JwtTokenTypeIDToken, idToken)
|
|
}
|
|
|
|
func (a *Service) GetUser(ctx *fiber.Ctx) *user.User {
|
|
userCtx := ctx.Locals(constants.AuthMiddlewareKey)
|
|
|
|
u, ok := userCtx.(*user.User)
|
|
u.ID = u.Token.Sub
|
|
|
|
if !ok {
|
|
panic("User context is not valid")
|
|
}
|
|
|
|
return u
|
|
}
|
|
|
|
func (a *Service) GetCtx(ctx context.Context) *user.User {
|
|
userCtx := ctx.Value(constants.AuthMiddlewareKey)
|
|
|
|
u, ok := userCtx.(*user.User)
|
|
u.ID = u.Token.Sub
|
|
|
|
if !ok {
|
|
panic("User context is not valid")
|
|
}
|
|
|
|
return u
|
|
}
|
|
|
|
func (a *Service) GetUserSafe(ctx *fiber.Ctx) (*user.User, bool) {
|
|
userCtx := ctx.Locals(constants.AuthMiddlewareKey)
|
|
|
|
u, ok := userCtx.(*user.User)
|
|
u.ID = u.Token.Sub
|
|
|
|
return u, ok
|
|
}
|
|
|
|
func (a *Service) GetCtxSafe(ctx context.Context) (*user.User, bool) {
|
|
userCtx := ctx.Value(constants.AuthMiddlewareKey)
|
|
|
|
u, ok := userCtx.(*user.User)
|
|
u.ID = u.Token.Sub
|
|
|
|
return u, ok
|
|
}
|
|
|
|
func (a *Service) SetUser(ctx context.Context, user *user.User) context.Context {
|
|
return context.WithValue(ctx, constants.AuthMiddlewareKey, user)
|
|
}
|
|
|
|
func (a *Service) parseUserJWT(tokenType constants.JwtTokenTypes, jwtToken string) (*user.User, error) {
|
|
var sub = user.AnonymousUser
|
|
var jwtIdToken = new(user.User)
|
|
|
|
if a.config.Debug.Enabled {
|
|
jwtIdToken.Token.Sub = sub
|
|
jwtIdToken.Valid = true
|
|
return jwtIdToken, nil
|
|
} else {
|
|
token, err := a.jwks.ParseJWT(jwtToken)
|
|
if err != nil {
|
|
return nil, errs2.NotValidToken
|
|
}
|
|
|
|
subStr, err := token.Claims.GetSubject()
|
|
if err != nil {
|
|
return nil, errs2.NotValidToken
|
|
}
|
|
|
|
sub = user.ID(subStr)
|
|
|
|
// 如果 token.Header 中没有 typ
|
|
if token.Header["typ"] == "" {
|
|
return nil, errs2.EmptyResponse
|
|
}
|
|
|
|
// 验证 token 类型
|
|
if tokenType != "" && tokenType.String() != token.Header["typ"] {
|
|
return nil, errs2.TokenError
|
|
}
|
|
|
|
jwtIdToken.Valid = true
|
|
|
|
err = mapstructure.Decode(token.Claims, &jwtIdToken.Token)
|
|
if err != nil {
|
|
a.logger.Logger.Error("Failed to map token claims to JwtIDToken struct.\nError: " + err.Error())
|
|
return nil, nil
|
|
}
|
|
|
|
// 手动指定,因为 mapstructure 无法转换 UserID 类型
|
|
jwtIdToken.Token.Sub = sub
|
|
}
|
|
|
|
return jwtIdToken, nil
|
|
}
|