From ee1e339007f86b44d5ebff11a83d3882470d3880 Mon Sep 17 00:00:00 2001 From: Twilight Date: Sun, 14 Jul 2024 18:14:01 +0800 Subject: [PATCH] =?UTF-8?q?=E6=94=B9=E8=BF=9B=20jwt=20=E9=AA=8C=E8=AF=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/bootstrap/router.go | 2 +- internal/handlers/controllers/user/main.go | 6 +++-- internal/logic/auth.go | 30 +++++++++++++++------- internal/middleware/http/jwt.go | 5 ++-- types/user.go | 11 ++++++++ 5 files changed, 40 insertions(+), 14 deletions(-) diff --git a/internal/bootstrap/router.go b/internal/bootstrap/router.go index 7627a37..fd15c80 100644 --- a/internal/bootstrap/router.go +++ b/internal/bootstrap/router.go @@ -10,5 +10,5 @@ import ( func InitApiRoutes() { var router = *providers.MustGet[gin.Engine]() - router.GET("/", http.MiddlewareJSONResponse, http.ValidateUser, user.CurrentUser) + router.GET("/", http.MiddlewareJSONResponse, http.RequireJWTIDToken, user.CurrentUser) } diff --git a/internal/handlers/controllers/user/main.go b/internal/handlers/controllers/user/main.go index 9e8b162..ee18ab5 100644 --- a/internal/handlers/controllers/user/main.go +++ b/internal/handlers/controllers/user/main.go @@ -9,7 +9,9 @@ var AuthLogic = logic.NewAuthLogic() func CurrentUser(c *gin.Context) { c.JSON(200, gin.H{ - "IP": c.ClientIP(), - "User": AuthLogic.GinUser(c).Valid, + "IP": c.ClientIP(), + "Valid": AuthLogic.GinUser(c).Valid, + "UserEmail": AuthLogic.GinUser(c).Token.Email, + "UserId": AuthLogic.GinUser(c).Token.Sub, }) } diff --git a/internal/logic/auth.go b/internal/logic/auth.go index 890acfa..6d65444 100644 --- a/internal/logic/auth.go +++ b/internal/logic/auth.go @@ -15,13 +15,12 @@ import ( type AuthLogic struct { } -const AnonymousUser = "anonymous" - var ( - ErrNotValidToken = errors.New("无效的 JWT 令牌。") - ErrJWTFormatError = errors.New("JWT 格式错误。") - ErrNotBearerType = errors.New("不是 Bearer 类型。") - ErrEmptyResponse = errors.New("我们的服务器返回了空请求,可能某些环节出了问题。") + ErrNotValidToken = errors.New("无效的 JWT 令牌") + ErrJWTFormatError = errors.New("JWT 格式错误") + ErrNotBearerType = errors.New("不是 Bearer 类型") + ErrEmptyResponse = errors.New("我们的服务器返回了空请求,可能某些环节出了问题") + ErrTokenError = errors.New("token 类型错误") config = *providers.MustGet[providers.GlobalConfig]() logger = *providers.MustGet[zap.Logger]() ) @@ -30,8 +29,8 @@ func NewAuthLogic() *AuthLogic { return &AuthLogic{} } -func (a *AuthLogic) GinMiddlewareAuth(c *gin.Context) (*types.User, error) { - var sub = AnonymousUser +func (a *AuthLogic) GinMiddlewareAuth(tokenType types.JWTTokenTypes, c *gin.Context) (*types.User, error) { + var sub = consts.AnonymousUser var jwtIdToken = &types.User{} if config.DebugMode.Enable { @@ -56,18 +55,31 @@ func (a *AuthLogic) GinMiddlewareAuth(c *gin.Context) (*types.User, error) { token, err := jwks.ParseJWT(authSplit[1]) if err != nil { - return nil, ErrJWTFormatError + return nil, ErrNotValidToken } sub, err = token.Claims.GetSubject() if err != nil { return nil, ErrNotValidToken } + // 如果 token.Header 中没有 typ + if token.Header["typ"] == "" { + return nil, ErrEmptyResponse + } + + // 验证 token 类型 + if tokenType != "" && tokenType.String() != token.Header["typ"] { + return nil, ErrTokenError + } + + jwtIdToken.Valid = true + err = mapstructure.Decode(token.Claims, &jwtIdToken.Token) if err != nil { logger.Error("Failed to map token claims to JwtIDToken struct.\nError: " + err.Error()) return nil, nil } + } return jwtIdToken, nil diff --git a/internal/middleware/http/jwt.go b/internal/middleware/http/jwt.go index 2d0ff9f..de23633 100644 --- a/internal/middleware/http/jwt.go +++ b/internal/middleware/http/jwt.go @@ -4,13 +4,14 @@ import ( "framework_v2/consts" "framework_v2/internal/logic" "framework_v2/internal/providers/helper" + "framework_v2/types" "github.com/gin-gonic/gin" "net/http" ) -func ValidateUser(c *gin.Context) { +func RequireJWTIDToken(c *gin.Context) { auth := logic.NewAuthLogic() - user, err := auth.GinMiddlewareAuth(c) + user, err := auth.GinMiddlewareAuth(types.JWTIDToken, c) if err != nil { c.Abort() diff --git a/types/user.go b/types/user.go index fca2bb9..ae7b657 100644 --- a/types/user.go +++ b/types/user.go @@ -27,3 +27,14 @@ type User struct { Token UserTokenInfo Valid bool } + +type JWTTokenTypes string + +const ( + JWTAccessToken JWTTokenTypes = "access_token" + JWTIDToken JWTTokenTypes = "id_token" +) + +func (jwtTokenTypes JWTTokenTypes) String() string { + return string(jwtTokenTypes) +}