改进 jwt 验证

This commit is contained in:
Twilight 2024-07-14 18:14:01 +08:00
parent 0e421a31a9
commit ee1e339007
5 changed files with 40 additions and 14 deletions

View File

@ -10,5 +10,5 @@ import (
func InitApiRoutes() { func InitApiRoutes() {
var router = *providers.MustGet[gin.Engine]() var router = *providers.MustGet[gin.Engine]()
router.GET("/", http.MiddlewareJSONResponse, http.ValidateUser, user.CurrentUser) router.GET("/", http.MiddlewareJSONResponse, http.RequireJWTIDToken, user.CurrentUser)
} }

View File

@ -10,6 +10,8 @@ var AuthLogic = logic.NewAuthLogic()
func CurrentUser(c *gin.Context) { func CurrentUser(c *gin.Context) {
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"IP": c.ClientIP(), "IP": c.ClientIP(),
"User": AuthLogic.GinUser(c).Valid, "Valid": AuthLogic.GinUser(c).Valid,
"UserEmail": AuthLogic.GinUser(c).Token.Email,
"UserId": AuthLogic.GinUser(c).Token.Sub,
}) })
} }

View File

@ -15,13 +15,12 @@ import (
type AuthLogic struct { type AuthLogic struct {
} }
const AnonymousUser = "anonymous"
var ( var (
ErrNotValidToken = errors.New("无效的 JWT 令牌。") ErrNotValidToken = errors.New("无效的 JWT 令牌")
ErrJWTFormatError = errors.New("JWT 格式错误。") ErrJWTFormatError = errors.New("JWT 格式错误")
ErrNotBearerType = errors.New("不是 Bearer 类型。") ErrNotBearerType = errors.New("不是 Bearer 类型")
ErrEmptyResponse = errors.New("我们的服务器返回了空请求,可能某些环节出了问题。") ErrEmptyResponse = errors.New("我们的服务器返回了空请求,可能某些环节出了问题")
ErrTokenError = errors.New("token 类型错误")
config = *providers.MustGet[providers.GlobalConfig]() config = *providers.MustGet[providers.GlobalConfig]()
logger = *providers.MustGet[zap.Logger]() logger = *providers.MustGet[zap.Logger]()
) )
@ -30,8 +29,8 @@ func NewAuthLogic() *AuthLogic {
return &AuthLogic{} return &AuthLogic{}
} }
func (a *AuthLogic) GinMiddlewareAuth(c *gin.Context) (*types.User, error) { func (a *AuthLogic) GinMiddlewareAuth(tokenType types.JWTTokenTypes, c *gin.Context) (*types.User, error) {
var sub = AnonymousUser var sub = consts.AnonymousUser
var jwtIdToken = &types.User{} var jwtIdToken = &types.User{}
if config.DebugMode.Enable { if config.DebugMode.Enable {
@ -56,18 +55,31 @@ func (a *AuthLogic) GinMiddlewareAuth(c *gin.Context) (*types.User, error) {
token, err := jwks.ParseJWT(authSplit[1]) token, err := jwks.ParseJWT(authSplit[1])
if err != nil { if err != nil {
return nil, ErrJWTFormatError return nil, ErrNotValidToken
} }
sub, err = token.Claims.GetSubject() sub, err = token.Claims.GetSubject()
if err != nil { if err != nil {
return nil, ErrNotValidToken return nil, ErrNotValidToken
} }
// 如果 token.Header 中没有 typ
if token.Header["typ"] == "" {
return nil, ErrEmptyResponse
}
// 验证 token 类型
if tokenType != "" && tokenType.String() != token.Header["typ"] {
return nil, ErrTokenError
}
jwtIdToken.Valid = true
err = mapstructure.Decode(token.Claims, &jwtIdToken.Token) err = mapstructure.Decode(token.Claims, &jwtIdToken.Token)
if err != nil { if err != nil {
logger.Error("Failed to map token claims to JwtIDToken struct.\nError: " + err.Error()) logger.Error("Failed to map token claims to JwtIDToken struct.\nError: " + err.Error())
return nil, nil return nil, nil
} }
} }
return jwtIdToken, nil return jwtIdToken, nil

View File

@ -4,13 +4,14 @@ import (
"framework_v2/consts" "framework_v2/consts"
"framework_v2/internal/logic" "framework_v2/internal/logic"
"framework_v2/internal/providers/helper" "framework_v2/internal/providers/helper"
"framework_v2/types"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"net/http" "net/http"
) )
func ValidateUser(c *gin.Context) { func RequireJWTIDToken(c *gin.Context) {
auth := logic.NewAuthLogic() auth := logic.NewAuthLogic()
user, err := auth.GinMiddlewareAuth(c) user, err := auth.GinMiddlewareAuth(types.JWTIDToken, c)
if err != nil { if err != nil {
c.Abort() c.Abort()

View File

@ -27,3 +27,14 @@ type User struct {
Token UserTokenInfo Token UserTokenInfo
Valid bool Valid bool
} }
type JWTTokenTypes string
const (
JWTAccessToken JWTTokenTypes = "access_token"
JWTIDToken JWTTokenTypes = "id_token"
)
func (jwtTokenTypes JWTTokenTypes) String() string {
return string(jwtTokenTypes)
}