leaf-library-3/internal/api/grpc/interceptor/auth.go

150 lines
3.3 KiB
Go
Raw Normal View History

2024-12-05 17:44:29 +00:00
package interceptor
import (
"context"
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/auth"
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/logging"
"google.golang.org/grpc"
"leafdev.top/Leaf/leaf-library-3/internal/base/conf"
"leafdev.top/Leaf/leaf-library-3/internal/base/logger"
"leafdev.top/Leaf/leaf-library-3/internal/consts"
"leafdev.top/Leaf/leaf-library-3/internal/schema"
auth2 "leafdev.top/Leaf/leaf-library-3/internal/service/auth"
)
// ignoreAuthApis holds the fully-qualified gRPC method names that skip JWT
// authentication in both the unary and stream interceptors. Any method not
// present here (map lookup yields false) requires a valid bearer token.
var ignoreAuthApis = map[string]bool{
	// Server reflection. Newer grpc-go versions register both the v1 and the
	// legacy v1alpha reflection services, and reflection clients try v1 first
	// (falling back to v1alpha only on Unimplemented), so both are exempted.
	"/grpc.reflection.v1.ServerReflection/ServerReflectionInfo":      true,
	"/grpc.reflection.v1alpha.ServerReflection/ServerReflectionInfo": true,

	// Business APIs that are publicly accessible.
	"/api.v1.WorkspaceService/List": true,
}
// Auth provides gRPC unary and stream interceptors that authenticate
// requests with a JWT bearer token taken from incoming metadata.
type Auth struct {
	// authService validates token strings and produces the token value
	// stored in the request context.
	authService *auth2.Service
	// logger is used for debug-mode diagnostics.
	logger *logger.Logger
	// config supplies the Debug.Enabled switch that relaxes/logs auth.
	config *conf.Config
}

// NewAuth builds an Auth interceptor set from its dependencies.
func NewAuth(
	authService *auth2.Service,
	logger *logger.Logger,
	config *conf.Config,
) *Auth {
	a := new(Auth)
	a.authService = authService
	a.logger = logger
	a.config = config
	return a
}
// notRequireAuth reports whether fullMethodName appears in ignoreAuthApis and
// may therefore bypass JWT authentication. When debug mode is enabled, the
// decision for every call is also written to the debug log.
func (a *Auth) notRequireAuth(fullMethodName string) bool {
	skip := ignoreAuthApis[fullMethodName]
	if !a.config.Debug.Enabled {
		return skip
	}
	if skip {
		a.logger.Sugar.Debugf("[GRPC Auth] Ignore auth for Method: %s", fullMethodName)
	} else {
		a.logger.Sugar.Debugf("[GRPC Auth] Require auth for Method: %s", fullMethodName)
	}
	return skip
}
// authCtx extracts the "bearer" token from the incoming gRPC metadata,
// validates it through authService and returns a derived context that
// carries the token (under consts.AuthMiddlewareKey) plus a logging field
// holding the token subject.
//
// Errors:
//   - metadata extraction errors are returned as-is, except in debug mode,
//     where they are only logged and validation proceeds with an empty token
//     string (NOTE(review): whether AuthFromToken accepts "" depends on
//     auth2.Service — confirm);
//   - errors from AuthFromToken are returned unchanged;
//   - consts.ErrNotValidToken is returned when the token is not Valid.
func (a *Auth) authCtx(ctx context.Context) (context.Context, error) {
	tokenString, err := auth.AuthFromMD(ctx, "bearer")
	if err != nil {
		if !a.config.Debug.Enabled {
			return nil, err
		}
		// Debug mode: tolerate a missing/malformed header and keep going.
		tokenString = ""
		a.logger.Sugar.Debugf("[GRPC Auth] error, %s", err)
	}

	token, err := a.authService.AuthFromToken(schema.JWTIDToken, tokenString)
	switch {
	case err != nil:
		return nil, err
	case !token.Valid:
		return nil, consts.ErrNotValidToken
	}

	ctx = logging.InjectFields(ctx, logging.Fields{consts.AuthMiddlewareKey, token.Token.Sub})
	return context.WithValue(ctx, consts.AuthMiddlewareKey, token), nil
}
// UnaryJWTAuth returns a unary server interceptor that requires a valid JWT
// bearer token for every method not listed in ignoreAuthApis. Authenticated
// calls reach the handler with the context produced by authCtx; if the
// handler fails, its response value is discarded and only the error returned.
func (a *Auth) UnaryJWTAuth() grpc.UnaryServerInterceptor {
	return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
		// Public methods are passed straight through, results untouched.
		if a.notRequireAuth(info.FullMethod) {
			return handler(ctx, req)
		}

		authedCtx, err := a.authCtx(ctx)
		if err != nil {
			return nil, err
		}

		resp, err := handler(authedCtx, req)
		if err != nil {
			return nil, err
		}
		return resp, nil
	}
}
func (a *Auth) StreamJWTAuth() grpc.StreamServerInterceptor {
return func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
var ctx = ss.Context()
if a.notRequireAuth(info.FullMethod) {
return handler(srv, ss)
}
ctx, err := a.authCtx(ctx)
if err != nil {
return err
}
err = handler(srv, ss)
if err != nil {
return err
}
return nil
}
}
//
//func (a *Auth) JwtAuth(ctx context.Context) (context.Context, error) {
// tokenString, err := auth.AuthFromMD(ctx, "bearer")
// if err != nil {
// return nil, err
// }
//
// token, err := a.authService.AuthFromToken(schema.JWTIDToken, tokenString)
// if err != nil {
// return nil, err
// }
//
// if !token.Valid {
// return nil, consts.ErrNotValidToken
// }
//
// ctx = logging.InjectFields(ctx, logging.Fields{consts.AuthMiddlewareKey, token.Token.Sub})
//
// return context.WithValue(ctx, consts.AuthMiddlewareKey, token), nil
//}