package logic

import (
	"strings"

	"github.com/gin-gonic/gin"
	"github.com/mitchellh/mapstructure"
	"go.uber.org/zap"
	"leafdev.top/leaf/rag/consts"
	"leafdev.top/leaf/rag/internal/providers"
	"leafdev.top/leaf/rag/internal/providers/jwks"
	"leafdev.top/leaf/rag/models"
)

// AuthLogic authenticates incoming gin requests and exposes the resulting
// user. It has no fields, so the zero value is usable.
type AuthLogic struct{}

var (
	// config and logger are copied out of the providers registry once, when
	// the package is initialized. MustGet presumably panics if the provider is
	// not registered (Must* convention) — confirm in internal/providers.
	config = *providers.MustGet[providers.GlobalConfig]()
	logger = *providers.MustGet[zap.Logger]()
)

// NewAuthLogic returns a new AuthLogic.
func NewAuthLogic() *AuthLogic {
	return &AuthLogic{}
}

// GinMiddlewareAuth authenticates the request in c and returns the user it
// represents.
//
// When config.DebugMode.Enable is set, the request is not inspected at all:
// a user with Token.Sub = consts.AnonymousUser and Valid = true is returned.
//
// Otherwise the consts.AuthHeader header must be exactly
// "<consts.AuthPrefix> <jwt>" (two parts separated by a single space). The JWT
// is parsed with jwks.ParseJWT, Claims.GetSubject must succeed, the "typ"
// header must be a non-empty string and, when tokenType is non-empty, must
// equal tokenType.String(). The claims are then decoded into the returned
// user's Token field and Valid is set to true.
//
// Errors returned:
//   - consts.ErrJWTFormatError: header missing or not exactly two parts
//   - consts.ErrNotBearerType: first part is not consts.AuthPrefix
//   - consts.ErrNotValidToken: parse failure, GetSubject failure, or claims
//     that cannot be decoded into the user's Token
//   - consts.ErrEmptyResponse: "typ" header missing, not a string, or empty
//   - consts.ErrTokenError: "typ" does not match tokenType
//
// On a non-nil error the returned user is always nil.
func (a *AuthLogic) GinMiddlewareAuth(tokenType models.JWTTokenTypes, c *gin.Context) (*models.User, error) {
	user := &models.User{}

	if config.DebugMode.Enable {
		// Debug mode bypasses authentication entirely.
		user.Token.Sub = consts.AnonymousUser
		user.Valid = true
		return user, nil
	}

	authorization := c.Request.Header.Get(consts.AuthHeader)
	if authorization == "" {
		return nil, consts.ErrJWTFormatError
	}

	parts := strings.Split(authorization, " ")
	if len(parts) != 2 {
		return nil, consts.ErrJWTFormatError
	}
	if parts[0] != consts.AuthPrefix {
		return nil, consts.ErrNotBearerType
	}

	token, err := jwks.ParseJWT(parts[1])
	if err != nil {
		return nil, consts.ErrNotValidToken
	}

	// The subject value is not used here; this only rejects tokens for which
	// GetSubject reports an error.
	if _, err = token.Claims.GetSubject(); err != nil {
		return nil, consts.ErrNotValidToken
	}

	// Assuming token is a golang-jwt *Token, Header is map[string]any: a
	// missing key yields nil, which a plain `== ""` comparison would not
	// catch. Require a non-empty string explicitly.
	typ, ok := token.Header["typ"].(string)
	if !ok || typ == "" {
		return nil, consts.ErrEmptyResponse
	}
	if tokenType != "" && tokenType.String() != typ {
		return nil, consts.ErrTokenError
	}

	if err = mapstructure.Decode(token.Claims, &user.Token); err != nil {
		logger.Error("failed to map token claims to JwtIDToken struct", zap.Error(err))
		// Never hand callers a nil user together with a nil error.
		return nil, consts.ErrNotValidToken
	}

	// Only mark the user valid once the claims were decoded successfully.
	user.Valid = true
	return user, nil
}

// GinUser returns the *models.User stored in c under
// consts.AuthMiddlewareKey. It returns nil when nothing is stored under that
// key or the stored value is not a *models.User (presumably because the auth
// middleware did not run for this route).
func (a *AuthLogic) GinUser(c *gin.Context) *models.User {
	v, _ := c.Get(consts.AuthMiddlewareKey)
	// Comma-ok assertion: a nil or foreign value yields nil instead of a panic.
	user, _ := v.(*models.User)
	return user
}

// GetUserId returns Token.Sub of the user stored in ctx under
// consts.AuthMiddlewareKey. It panics if no *models.User is stored there,
// so it must only be called from handlers behind the auth middleware.
func GetUserId(ctx *gin.Context) string {
	user := (&AuthLogic{}).GinUser(ctx)
	if user == nil {
		panic("logic: GetUserId called without an authenticated user in the gin context")
	}
	return user.Token.Sub
}