增加 gRPC gateway

This commit is contained in:
ivamp 2024-07-16 01:09:07 +08:00
parent 7ed130a16e
commit 959b44aa32
6 changed files with 64 additions and 26 deletions

View File

@ -15,4 +15,5 @@ swag:
hash-migration: hash-migration:
atlas migrate hash --dir "file://internal/migrations" atlas migrate hash --dir "file://internal/migrations"
proto: proto:
buf dep update
buf generate buf generate

View File

@ -5,7 +5,7 @@ package LibraryService;
option go_package = "leafdev.top/leaf/rag/proto/library"; option go_package = "leafdev.top/leaf/rag/proto/library";
import "google/api/annotations.proto"; import "google/api/annotations.proto";
import "google.golang.org/grpc/health/grpc_health_v1"; //import "google.golang.org/grpc/health/grpc_health_v1";
service LibraryService { service LibraryService {
rpc ListLibrary(ListLibraryRequest) returns (ListLibraryResponse) { rpc ListLibrary(ListLibraryRequest) returns (ListLibraryResponse) {

View File

@ -26,7 +26,7 @@ func Execute() {
rootCmd.AddCommand(workerCommand) rootCmd.AddCommand(workerCommand)
rootCmd.AddCommand(scheduleCommand) rootCmd.AddCommand(scheduleCommand)
rootCmd.AddCommand(httpCommand) //rootCmd.AddCommand(httpCommand)
if config.DebugMode.Enable { if config.DebugMode.Enable {
rootCmd.AddCommand(migrateCommand) rootCmd.AddCommand(migrateCommand)

View File

@ -15,6 +15,9 @@ import (
"leafdev.top/leaf/rag/internal/providers/jwks" "leafdev.top/leaf/rag/internal/providers/jwks"
"leafdev.top/leaf/rag/internal/services/libraryService" "leafdev.top/leaf/rag/internal/services/libraryService"
"net" "net"
"net/http"
"strings"
"sync"
) )
var rpcCommand = &cobra.Command{ var rpcCommand = &cobra.Command{
@ -30,12 +33,15 @@ func StartGRPC() {
if config.ListenAddr.GRPC == "" { if config.ListenAddr.GRPC == "" {
config.ListenAddr.GRPC = "0.0.0.0:8081" config.ListenAddr.GRPC = "0.0.0.0:8081"
} }
if config.ListenAddr.HTTP == "" {
config.ListenAddr.HTTP = "0.0.0.0:8080"
}
lis, err := net.Listen("tcp", config.ListenAddr.GRPC) lis, err := net.Listen("tcp", config.ListenAddr.GRPC)
if err != nil { if err != nil {
panic("failed to listen: " + err.Error()) panic("GRPC failed to listen: " + err.Error())
} }
logger.Info("Server listening at " + config.ListenAddr.GRPC) logger.Info("GRPC Server listening at " + config.ListenAddr.GRPC)
var opts = []grpc.ServerOption{ var opts = []grpc.ServerOption{
grpc.ChainUnaryInterceptor( grpc.ChainUnaryInterceptor(
@ -53,24 +59,43 @@ func StartGRPC() {
library.RegisterLibraryServiceServer(grpcServer, libraryService.LibraryService{}) library.RegisterLibraryServiceServer(grpcServer, libraryService.LibraryService{})
var wg = sync.WaitGroup{}
wg.Add(1)
// 同时启动 grpc 和 http
go func() {
err = grpcServer.Serve(lis) err = grpcServer.Serve(lis)
if err != nil {
panic(err)
}
defer wg.Done()
}()
wg.Add(1)
go func() {
ctx := context.Background()
ctx, cancel := context.WithCancel(ctx)
defer cancel()
defer wg.Done()
mux := runtime.NewServeMux()
clientOpts := []grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())}
err = libraryGw.RegisterLibraryServiceHandlerFromEndpoint(ctx, mux, "127.0.0.1:"+getPortFromAddr(config.ListenAddr.GRPC), clientOpts)
if err != nil {
panic(err)
}
logger.Info("GRPC Gateway listening at " + config.ListenAddr.HTTP)
err = http.ListenAndServe(config.ListenAddr.HTTP, mux)
if err != nil { if err != nil {
panic(err) panic(err)
return return
} }
}()
ctx := context.Background() wg.Wait()
ctx, cancel := context.WithCancel(ctx)
defer cancel()
mux := runtime.NewServeMux()
clientOpts := []grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())}
err := gw.RegisterYourServiceHandlerFromEndpoint(ctx, mux, config.ListenAddr.GRPC, clientOpts)
if err != nil {
return err
} }
// Start HTTP server (and proxy calls to gRPC server endpoint) func getPortFromAddr(addr string) string {
return http.ListenAndServe(":8081", mux) return addr[strings.LastIndex(addr, ":")+1:]
} }

View File

@ -7,6 +7,7 @@ import (
"github.com/mitchellh/mapstructure" "github.com/mitchellh/mapstructure"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"leafdev.top/leaf/rag/consts"
"leafdev.top/leaf/rag/internal/providers/jwks" "leafdev.top/leaf/rag/internal/providers/jwks"
"leafdev.top/leaf/rag/models" "leafdev.top/leaf/rag/models"
) )
@ -17,8 +18,8 @@ func JwtAuth(ctx context.Context) (context.Context, error) {
return nil, err return nil, err
} }
sub := "anonymous" sub := consts.AnonymousUser
var jwtIdToken *models.User var jwtIdToken = &models.User{}
if config.DebugMode.Enable { if config.DebugMode.Enable {
jwtIdToken.Token.Sub = sub jwtIdToken.Token.Sub = sub
@ -32,6 +33,18 @@ func JwtAuth(ctx context.Context) (context.Context, error) {
return nil, status.Errorf(codes.Unauthenticated, "unable get token sub: %v", err) return nil, status.Errorf(codes.Unauthenticated, "unable get token sub: %v", err)
} }
// 如果 token.Header 中没有 typ
if token.Header["typ"] == "" {
return nil, consts.ErrEmptyResponse
}
// must id token
if token.Header["typ"] != "id_token" {
return nil, consts.ErrTokenError
}
jwtIdToken.Valid = true
err = mapstructure.Decode(token.Claims, &jwtIdToken) err = mapstructure.Decode(token.Claims, &jwtIdToken)
if err != nil { if err != nil {
logger.Error("Failed to map token claims to JwtIDToken struct.\nError: " + err.Error()) logger.Error("Failed to map token claims to JwtIDToken struct.\nError: " + err.Error())

View File

@ -35,10 +35,9 @@ func RefreshJWKS() {
} }
func ParseJWT(jwtB64 string) (*jwt.Token, error) { func ParseJWT(jwtB64 string) (*jwt.Token, error) {
//if Jwks.Keyfunc == nil { if Jwks.Keyfunc == nil {
// Logger.Error(ErrJWKSNotInitialized.Error()) return nil, ErrJWKSNotInitialized
// return nil, ErrJWKSNotInitialized }
//}
token, err := jwt.Parse(jwtB64, Jwks.Keyfunc) token, err := jwt.Parse(jwtB64, Jwks.Keyfunc)