rag/internal/middleware/grpc/auth.go

67 lines
1.6 KiB
Go
Raw Normal View History

2024-06-15 16:55:25 +00:00
package grpc
2024-06-13 01:16:48 +00:00
import (
"context"
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/auth"
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/logging"
"github.com/mitchellh/mapstructure"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
2024-07-15 17:09:07 +00:00
"leafdev.top/leaf/rag/consts"
2024-07-14 15:58:23 +00:00
"leafdev.top/leaf/rag/internal/providers/jwks"
"leafdev.top/leaf/rag/models"
2024-06-13 01:16:48 +00:00
)
func JwtAuth(ctx context.Context) (context.Context, error) {
tokenString, err := auth.AuthFromMD(ctx, "bearer")
if err != nil {
return nil, err
}
2024-07-15 17:09:07 +00:00
sub := consts.AnonymousUser
2024-07-15 17:48:05 +00:00
var jwtIdToken = models.User{}
2024-06-13 01:16:48 +00:00
2024-07-14 09:44:49 +00:00
if config.DebugMode.Enable {
2024-06-16 06:00:31 +00:00
jwtIdToken.Token.Sub = sub
2024-06-13 01:16:48 +00:00
} else {
2024-06-15 16:55:25 +00:00
token, err := jwks.ParseJWT(tokenString)
2024-06-13 01:16:48 +00:00
if err != nil {
return nil, status.Errorf(codes.Unauthenticated, "invalid auth token: %v", err)
}
sub, err = token.Claims.GetSubject()
if err != nil {
return nil, status.Errorf(codes.Unauthenticated, "unable get token sub: %v", err)
}
2024-07-15 17:09:07 +00:00
// 如果 token.Header 中没有 typ
if token.Header["typ"] == "" {
return nil, consts.ErrEmptyResponse
}
// must id token
if token.Header["typ"] != "id_token" {
return nil, consts.ErrTokenError
}
jwtIdToken.Valid = true
2024-07-15 17:48:05 +00:00
err = mapstructure.Decode(token.Claims, &jwtIdToken.Token)
2024-06-13 01:16:48 +00:00
if err != nil {
2024-07-14 09:44:49 +00:00
logger.Error("Failed to map token claims to JwtIDToken struct.\nError: " + err.Error())
2024-06-13 01:16:48 +00:00
return nil, err
}
}
2024-07-15 17:48:05 +00:00
ctx = logging.InjectFields(ctx, logging.Fields{consts.AuthMiddlewareKey, sub})
2024-06-13 01:16:48 +00:00
2024-07-15 17:48:05 +00:00
return context.WithValue(ctx, consts.AuthMiddlewareKey, &jwtIdToken), nil
2024-06-13 01:16:48 +00:00
}
// Gate is a placeholder per-method authorization hook. It is intended to
// inspect which RPC method is being called, but no checks are implemented
// yet: it admits every call, returning the given context unchanged and a nil
// error.
func Gate(ctx context.Context) (context.Context, error) {
	// TODO: check the invoked method here.
	passthrough := ctx
	return passthrough, nil
}