rag/internal/cmd/rpc.go

102 lines
2.5 KiB
Go
Raw Normal View History

2024-06-13 01:16:48 +00:00
package cmd
import (
2024-07-15 13:08:48 +00:00
"context"
2024-06-13 01:16:48 +00:00
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/auth"
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/logging"
2024-07-15 13:08:48 +00:00
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
2024-06-13 01:16:48 +00:00
"github.com/spf13/cobra"
"google.golang.org/grpc"
2024-07-15 13:08:48 +00:00
"google.golang.org/grpc/credentials/insecure"
2024-06-13 01:16:48 +00:00
"google.golang.org/grpc/reflection"
2024-07-15 13:08:48 +00:00
"leafdev.top/leaf/rag/api/library"
libraryGw "leafdev.top/leaf/rag/api/library"
2024-07-14 15:58:23 +00:00
grpc2 "leafdev.top/leaf/rag/internal/middleware/grpc"
"leafdev.top/leaf/rag/internal/providers/jwks"
2024-07-15 13:08:48 +00:00
"leafdev.top/leaf/rag/internal/services/libraryService"
2024-06-13 01:16:48 +00:00
"net"
2024-07-15 17:09:07 +00:00
"net/http"
"strings"
"sync"
2024-06-13 01:16:48 +00:00
)
var rpcCommand = &cobra.Command{
2024-07-14 09:44:49 +00:00
Use: "rpc",
2024-06-13 01:16:48 +00:00
Run: func(cmd *cobra.Command, args []string) {
2024-06-15 16:55:25 +00:00
jwks.InitJwksRefresh()
2024-07-15 13:08:48 +00:00
StartGRPC()
2024-06-13 01:16:48 +00:00
},
}
2024-07-15 13:08:48 +00:00
func StartGRPC() {
2024-07-14 09:44:49 +00:00
if config.ListenAddr.GRPC == "" {
config.ListenAddr.GRPC = "0.0.0.0:8081"
2024-06-13 01:16:48 +00:00
}
2024-07-15 17:09:07 +00:00
if config.ListenAddr.HTTP == "" {
config.ListenAddr.HTTP = "0.0.0.0:8080"
}
2024-06-13 01:16:48 +00:00
2024-07-14 09:44:49 +00:00
lis, err := net.Listen("tcp", config.ListenAddr.GRPC)
2024-06-13 01:16:48 +00:00
if err != nil {
2024-07-15 17:09:07 +00:00
panic("GRPC failed to listen: " + err.Error())
2024-06-13 01:16:48 +00:00
}
2024-07-15 17:09:07 +00:00
logger.Info("GRPC Server listening at " + config.ListenAddr.GRPC)
2024-06-13 01:16:48 +00:00
var opts = []grpc.ServerOption{
grpc.ChainUnaryInterceptor(
2024-06-15 16:55:25 +00:00
logging.UnaryServerInterceptor(grpc2.ZapLogInterceptor()),
auth.UnaryServerInterceptor(grpc2.JwtAuth),
2024-06-13 01:16:48 +00:00
),
grpc.ChainStreamInterceptor(
2024-06-15 16:55:25 +00:00
logging.StreamServerInterceptor(grpc2.ZapLogInterceptor()),
auth.StreamServerInterceptor(grpc2.JwtAuth),
2024-06-13 01:16:48 +00:00
),
}
grpcServer := grpc.NewServer(opts...)
reflection.Register(grpcServer)
2024-07-15 13:08:48 +00:00
library.RegisterLibraryServiceServer(grpcServer, libraryService.LibraryService{})
2024-06-13 01:16:48 +00:00
2024-07-15 17:09:07 +00:00
var wg = sync.WaitGroup{}
2024-06-13 01:16:48 +00:00
2024-07-15 17:09:07 +00:00
wg.Add(1)
// 同时启动 grpc 和 http
go func() {
err = grpcServer.Serve(lis)
if err != nil {
panic(err)
}
2024-07-15 13:08:48 +00:00
2024-07-15 17:09:07 +00:00
defer wg.Done()
}()
wg.Add(1)
go func() {
ctx := context.Background()
ctx, cancel := context.WithCancel(ctx)
defer cancel()
defer wg.Done()
2024-07-15 13:08:48 +00:00
2024-07-15 17:09:07 +00:00
mux := runtime.NewServeMux()
clientOpts := []grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())}
err = libraryGw.RegisterLibraryServiceHandlerFromEndpoint(ctx, mux, "127.0.0.1:"+getPortFromAddr(config.ListenAddr.GRPC), clientOpts)
if err != nil {
panic(err)
}
logger.Info("GRPC Gateway listening at " + config.ListenAddr.HTTP)
err = http.ListenAndServe(config.ListenAddr.HTTP, mux)
if err != nil {
panic(err)
return
}
}()
wg.Wait()
}
2024-07-15 13:08:48 +00:00
2024-07-15 17:09:07 +00:00
// getPortFromAddr returns the part of addr that follows its last ':'.
// When addr contains no ':' the whole string is returned unchanged.
func getPortFromAddr(addr string) string {
	sep := strings.LastIndexByte(addr, ':')
	if sep < 0 {
		return addr
	}
	return addr[sep+1:]
}