diff --git a/Makefile b/Makefile index 88f3760..6ad240b 100644 --- a/Makefile +++ b/Makefile @@ -15,4 +15,5 @@ swag: hash-migration: atlas migrate hash --dir "file://internal/migrations" proto: + buf dep update buf generate \ No newline at end of file diff --git a/api/library/library.proto b/api/library/library.proto index 50001ce..941cc94 100644 --- a/api/library/library.proto +++ b/api/library/library.proto @@ -5,7 +5,7 @@ package LibraryService; option go_package = "leafdev.top/leaf/rag/proto/library"; import "google/api/annotations.proto"; -import "google.golang.org/grpc/health/grpc_health_v1"; +//import "google.golang.org/grpc/health/grpc_health_v1"; service LibraryService { rpc ListLibrary(ListLibraryRequest) returns (ListLibraryResponse) { diff --git a/internal/cmd/root.go b/internal/cmd/root.go index 3892533..a84856c 100644 --- a/internal/cmd/root.go +++ b/internal/cmd/root.go @@ -26,7 +26,7 @@ func Execute() { rootCmd.AddCommand(workerCommand) rootCmd.AddCommand(scheduleCommand) - rootCmd.AddCommand(httpCommand) + //rootCmd.AddCommand(httpCommand) if config.DebugMode.Enable { rootCmd.AddCommand(migrateCommand) diff --git a/internal/cmd/rpc.go b/internal/cmd/rpc.go index 1c2a78c..1a733f8 100644 --- a/internal/cmd/rpc.go +++ b/internal/cmd/rpc.go @@ -15,6 +15,9 @@ import ( "leafdev.top/leaf/rag/internal/providers/jwks" "leafdev.top/leaf/rag/internal/services/libraryService" "net" + "net/http" + "strings" + "sync" ) var rpcCommand = &cobra.Command{ @@ -30,12 +33,15 @@ func StartGRPC() { if config.ListenAddr.GRPC == "" { config.ListenAddr.GRPC = "0.0.0.0:8081" } + if config.ListenAddr.HTTP == "" { + config.ListenAddr.HTTP = "0.0.0.0:8080" + } lis, err := net.Listen("tcp", config.ListenAddr.GRPC) if err != nil { - panic("failed to listen: " + err.Error()) + panic("GRPC failed to listen: " + err.Error()) } - logger.Info("Server listening at " + config.ListenAddr.GRPC) + logger.Info("GRPC Server listening at " + config.ListenAddr.GRPC) var opts = []grpc.ServerOption{ grpc.ChainUnaryInterceptor( @@ -53,24 +59,43 @@ func StartGRPC() { library.RegisterLibraryServiceServer(grpcServer, libraryService.LibraryService{}) - err = grpcServer.Serve(lis) - if err != nil { - panic(err) - return - } + var wg = sync.WaitGroup{} - ctx := context.Background() - ctx, cancel := context.WithCancel(ctx) - defer cancel() + wg.Add(1) + // 同时启动 grpc 和 http + go func() { + err = grpcServer.Serve(lis) + if err != nil { + panic(err) + } - mux := runtime.NewServeMux() - clientOpts := []grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())} - err := gw.RegisterYourServiceHandlerFromEndpoint(ctx, mux, config.ListenAddr.GRPC, clientOpts) - if err != nil { - return err - } + defer wg.Done() + }() - // Start HTTP server (and proxy calls to gRPC server endpoint) - return http.ListenAndServe(":8081", mux) + wg.Add(1) + go func() { + ctx := context.Background() + ctx, cancel := context.WithCancel(ctx) + defer cancel() + defer wg.Done() + mux := runtime.NewServeMux() + clientOpts := []grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())} + err = libraryGw.RegisterLibraryServiceHandlerFromEndpoint(ctx, mux, "127.0.0.1:"+getPortFromAddr(config.ListenAddr.GRPC), clientOpts) + if err != nil { + panic(err) + } + logger.Info("GRPC Gateway listening at " + config.ListenAddr.HTTP) + err = http.ListenAndServe(config.ListenAddr.HTTP, mux) + if err != nil { + panic(err) + return + } + }() + + wg.Wait() +} + +func getPortFromAddr(addr string) string { + return addr[strings.LastIndex(addr, ":")+1:] } diff --git a/internal/middleware/grpc/auth.go b/internal/middleware/grpc/auth.go index c0128df..68926d7 100644 --- a/internal/middleware/grpc/auth.go +++ b/internal/middleware/grpc/auth.go @@ -7,6 +7,7 @@ import ( "github.com/mitchellh/mapstructure" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "leafdev.top/leaf/rag/consts" "leafdev.top/leaf/rag/internal/providers/jwks" "leafdev.top/leaf/rag/models" ) @@ -17,8 +18,8 @@ func JwtAuth(ctx context.Context) (context.Context, error) { return nil, err } - sub := "anonymous" - var jwtIdToken *models.User + sub := consts.AnonymousUser + var jwtIdToken = &models.User{} if config.DebugMode.Enable { jwtIdToken.Token.Sub = sub @@ -32,6 +33,18 @@ func JwtAuth(ctx context.Context) (context.Context, error) { return nil, status.Errorf(codes.Unauthenticated, "unable get token sub: %v", err) } + // 如果 token.Header 中没有 typ + if token.Header["typ"] == "" { + return nil, consts.ErrEmptyResponse + } + + // must id token + if token.Header["typ"] != "id_token" { + return nil, consts.ErrTokenError + } + + jwtIdToken.Valid = true + err = mapstructure.Decode(token.Claims, &jwtIdToken) if err != nil { logger.Error("Failed to map token claims to JwtIDToken struct.\nError: " + err.Error()) diff --git a/internal/providers/jwks/jwks.go b/internal/providers/jwks/jwks.go index 1d0f054..ca811e7 100644 --- a/internal/providers/jwks/jwks.go +++ b/internal/providers/jwks/jwks.go @@ -35,10 +35,9 @@ func RefreshJWKS() { } func ParseJWT(jwtB64 string) (*jwt.Token, error) { - //if Jwks.Keyfunc == nil { - // Logger.Error(ErrJWKSNotInitialized.Error()) - // return nil, ErrJWKSNotInitialized - //} + if Jwks.Keyfunc == nil { + return nil, ErrJWKSNotInitialized + } token, err := jwt.Parse(jwtB64, Jwks.Keyfunc)