增加 gRPC gateway

This commit is contained in:
ivamp 2024-07-16 01:09:07 +08:00
parent 7ed130a16e
commit 959b44aa32
6 changed files with 64 additions and 26 deletions

View File

@ -15,4 +15,5 @@ swag:
hash-migration:
atlas migrate hash --dir "file://internal/migrations"
proto:
buf dep update
buf generate

View File

@ -5,7 +5,7 @@ package LibraryService;
option go_package = "leafdev.top/leaf/rag/proto/library";
import "google/api/annotations.proto";
import "google.golang.org/grpc/health/grpc_health_v1";
//import "google.golang.org/grpc/health/grpc_health_v1";
service LibraryService {
rpc ListLibrary(ListLibraryRequest) returns (ListLibraryResponse) {

View File

@ -26,7 +26,7 @@ func Execute() {
rootCmd.AddCommand(workerCommand)
rootCmd.AddCommand(scheduleCommand)
rootCmd.AddCommand(httpCommand)
//rootCmd.AddCommand(httpCommand)
if config.DebugMode.Enable {
rootCmd.AddCommand(migrateCommand)

View File

@ -15,6 +15,9 @@ import (
"leafdev.top/leaf/rag/internal/providers/jwks"
"leafdev.top/leaf/rag/internal/services/libraryService"
"net"
"net/http"
"strings"
"sync"
)
var rpcCommand = &cobra.Command{
@ -30,12 +33,15 @@ func StartGRPC() {
if config.ListenAddr.GRPC == "" {
config.ListenAddr.GRPC = "0.0.0.0:8081"
}
if config.ListenAddr.HTTP == "" {
config.ListenAddr.HTTP = "0.0.0.0:8080"
}
lis, err := net.Listen("tcp", config.ListenAddr.GRPC)
if err != nil {
panic("failed to listen: " + err.Error())
panic("GRPC failed to listen: " + err.Error())
}
logger.Info("Server listening at " + config.ListenAddr.GRPC)
logger.Info("GRPC Server listening at " + config.ListenAddr.GRPC)
var opts = []grpc.ServerOption{
grpc.ChainUnaryInterceptor(
@ -53,24 +59,43 @@ func StartGRPC() {
library.RegisterLibraryServiceServer(grpcServer, libraryService.LibraryService{})
var wg = sync.WaitGroup{}
wg.Add(1)
// 同时启动 grpc 和 http
go func() {
err = grpcServer.Serve(lis)
if err != nil {
panic(err)
}
defer wg.Done()
}()
wg.Add(1)
go func() {
ctx := context.Background()
ctx, cancel := context.WithCancel(ctx)
defer cancel()
defer wg.Done()
mux := runtime.NewServeMux()
clientOpts := []grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())}
err = libraryGw.RegisterLibraryServiceHandlerFromEndpoint(ctx, mux, "127.0.0.1:"+getPortFromAddr(config.ListenAddr.GRPC), clientOpts)
if err != nil {
panic(err)
}
logger.Info("GRPC Gateway listening at " + config.ListenAddr.HTTP)
err = http.ListenAndServe(config.ListenAddr.HTTP, mux)
if err != nil {
panic(err)
return
}
}()
ctx := context.Background()
ctx, cancel := context.WithCancel(ctx)
defer cancel()
mux := runtime.NewServeMux()
clientOpts := []grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())}
err := gw.RegisterYourServiceHandlerFromEndpoint(ctx, mux, config.ListenAddr.GRPC, clientOpts)
if err != nil {
return err
}
// Start HTTP server (and proxy calls to gRPC server endpoint)
return http.ListenAndServe(":8081", mux)
wg.Wait()
}
func getPortFromAddr(addr string) string {
return addr[strings.LastIndex(addr, ":")+1:]
}

View File

@ -7,6 +7,7 @@ import (
"github.com/mitchellh/mapstructure"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"leafdev.top/leaf/rag/consts"
"leafdev.top/leaf/rag/internal/providers/jwks"
"leafdev.top/leaf/rag/models"
)
@ -17,8 +18,8 @@ func JwtAuth(ctx context.Context) (context.Context, error) {
return nil, err
}
sub := "anonymous"
var jwtIdToken *models.User
sub := consts.AnonymousUser
var jwtIdToken = &models.User{}
if config.DebugMode.Enable {
jwtIdToken.Token.Sub = sub
@ -32,6 +33,18 @@ func JwtAuth(ctx context.Context) (context.Context, error) {
return nil, status.Errorf(codes.Unauthenticated, "unable get token sub: %v", err)
}
// 如果 token.Header 中没有 typ
if token.Header["typ"] == "" {
return nil, consts.ErrEmptyResponse
}
// must id token
if token.Header["typ"] != "id_token" {
return nil, consts.ErrTokenError
}
jwtIdToken.Valid = true
err = mapstructure.Decode(token.Claims, &jwtIdToken)
if err != nil {
logger.Error("Failed to map token claims to JwtIDToken struct.\nError: " + err.Error())

View File

@ -35,10 +35,9 @@ func RefreshJWKS() {
}
func ParseJWT(jwtB64 string) (*jwt.Token, error) {
//if Jwks.Keyfunc == nil {
// Logger.Error(ErrJWKSNotInitialized.Error())
// return nil, ErrJWKSNotInitialized
//}
if Jwks.Keyfunc == nil {
return nil, ErrJWKSNotInitialized
}
token, err := jwt.Parse(jwtB64, Jwks.Keyfunc)