rag/internal/middleware/grpc/auth.go
2024-06-16 00:55:25 +08:00

50 lines
1.3 KiB
Go

package grpc
import (
"context"
"framework_v2/internal/app/config"
"framework_v2/internal/app/jwks"
"framework_v2/internal/app/logger"
"framework_v2/internal/app/user"
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/auth"
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/logging"
"github.com/mitchellh/mapstructure"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// JwtAuth authenticates an incoming gRPC call from its "authorization: bearer <token>"
// metadata. Its signature matches go-grpc-middleware's auth.AuthFunc, so it is
// presumably registered with the auth unary/stream interceptors (confirm at the call site).
//
// Behavior:
//   - A bearer token must be present in the metadata in every mode; otherwise the error
//     from auth.AuthFromMD is returned unchanged.
//   - When config.Config.DebugMode.Enable is set, the token is NOT verified and the caller
//     is treated as subject "anonymous".
//   - Otherwise the token is parsed/verified by jwks.ParseJWT, its "sub" claim is read, and
//     the claims are decoded into a *user.UserTokenInfo.
//
// On success it returns a derived context that carries the subject as the "auth.sub"
// logging field and the *user.UserTokenInfo under the context key "auth". On failure it
// returns a nil context and a gRPC status error with code Unauthenticated.
func JwtAuth(ctx context.Context) (context.Context, error) {
	tokenString, err := auth.AuthFromMD(ctx, "bearer")
	if err != nil {
		return nil, err
	}

	sub := "anonymous"
	var jwtIdToken *user.UserTokenInfo
	if config.Config.DebugMode.Enable {
		// Debug mode: skip signature/claim verification entirely.
		jwtIdToken = &user.UserTokenInfo{
			Sub: sub,
		}
	} else {
		token, err := jwks.ParseJWT(tokenString)
		if err != nil {
			return nil, status.Errorf(codes.Unauthenticated, "invalid auth token: %v", err)
		}
		sub, err = token.Claims.GetSubject()
		if err != nil {
			return nil, status.Errorf(codes.Unauthenticated, "unable get token sub: %v", err)
		}
		// Decoding through &jwtIdToken lets mapstructure allocate the struct.
		err = mapstructure.Decode(token.Claims, &jwtIdToken)
		if err != nil {
			logger.Logger.Error("Failed to map token claims to JwtIDToken struct.\nError: " + err.Error())
			// Return a status error like the other failure paths. A bare error would
			// reach the client as codes.Unknown.
			return nil, status.Errorf(codes.Unauthenticated, "unable to decode token claims: %v", err)
		}
	}

	ctx = logging.InjectFields(ctx, logging.Fields{"auth.sub", sub})
	// NOTE(review): a plain string key can collide with other packages (staticcheck SA1029).
	// It is kept because downstream code presumably reads ctx.Value("auth").
	return context.WithValue(ctx, "auth", jwtIdToken), nil
}