package grpc import ( "context" "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/auth" "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/logging" "github.com/mitchellh/mapstructure" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "leafdev.top/leaf/rag/consts" "leafdev.top/leaf/rag/internal/providers/jwks" "leafdev.top/leaf/rag/models" ) func JwtAuth(ctx context.Context) (context.Context, error) { tokenString, err := auth.AuthFromMD(ctx, "bearer") if err != nil { return nil, err } sub := consts.AnonymousUser var jwtIdToken = models.User{} if config.DebugMode.Enable { jwtIdToken.Token.Sub = sub } else { token, err := jwks.ParseJWT(tokenString) if err != nil { return nil, status.Errorf(codes.Unauthenticated, "invalid auth token: %v", err) } sub, err = token.Claims.GetSubject() if err != nil { return nil, status.Errorf(codes.Unauthenticated, "unable get token sub: %v", err) } // 如果 token.Header 中没有 typ if token.Header["typ"] == "" { return nil, consts.ErrEmptyResponse } // must id token if token.Header["typ"] != "id_token" { return nil, consts.ErrTokenError } jwtIdToken.Valid = true err = mapstructure.Decode(token.Claims, &jwtIdToken.Token) if err != nil { logger.Error("Failed to map token claims to JwtIDToken struct.\nError: " + err.Error()) return nil, err } } ctx = logging.InjectFields(ctx, logging.Fields{consts.AuthMiddlewareKey, sub}) return context.WithValue(ctx, consts.AuthMiddlewareKey, &jwtIdToken), nil }