diff --git a/cmd/mithril/node/node.go b/cmd/mithril/node/node.go
index f6a6b013..5735b265 100644
--- a/cmd/mithril/node/node.go
+++ b/cmd/mithril/node/node.go
@@ -18,6 +18,7 @@ import (
 	"github.com/Overclock-Validator/mithril/pkg/accountsdb"
 	"github.com/Overclock-Validator/mithril/pkg/arena"
+	"github.com/Overclock-Validator/mithril/pkg/grpc"
 	"github.com/Overclock-Validator/mithril/pkg/mlog"
 	"github.com/Overclock-Validator/mithril/pkg/replay"
 	"github.com/Overclock-Validator/mithril/pkg/rpcserver"
@@ -88,6 +89,10 @@ var (
 	borrowedAccountArenaSize uint64
 	rpcPort                  int
+
+	// gRPC flags
+	grpcPort   int
+	enableGrpc bool
 )
 
 func init() {
@@ -114,6 +119,8 @@ func init() {
 	Verifier.Flags().BoolVar(&sbpf.UsePool, "use-pool", true, "Disable to allocate fresh slices")
 	Verifier.Flags().StringVar(&snapshotDlPath, "download-snapshot", "", "Path to download snapshot to")
 	Verifier.Flags().IntVar(&rpcPort, "rpc-server-port", 0, "RPC server port. Default off.")
+	Verifier.Flags().BoolVar(&enableGrpc, "enable-grpc", false, "Enable gRPC server. Default off.")
+	Verifier.Flags().IntVar(&grpcPort, "grpc-port", 50051, "gRPC server port. Default 50051.")
 
 	// flags for RPC catchup mode
 	CatchupRpc.Flags().StringVarP(&outputDir, "out", "o", "", "Output path for writing AccountsDB data to")
@@ -132,6 +139,8 @@ func init() {
 	CatchupRpc.Flags().StringVar(&blockDir, "blockdir", "/tmp/blocks", "Path containing slot.json files")
 	CatchupRpc.Flags().StringVar(&scratchDir, "scratchdir", "/tmp", "Path for downloads (e.g. snapshots) and other temp state")
 	CatchupRpc.Flags().IntVar(&rpcPort, "rpc-server-port", 0, "RPC server port. Default off.")
+	CatchupRpc.Flags().BoolVar(&enableGrpc, "enable-grpc", false, "Enable gRPC server. Default off.")
+	CatchupRpc.Flags().IntVar(&grpcPort, "grpc-port", 50051, "gRPC server port. Default 50051.")
 
 	// flags for Overcast catchup mode
 	CatchupOvercast.Flags().StringVarP(&outputDir, "out", "o", "", "Output path for writing AccountsDB data to")
@@ -151,6 +160,8 @@ func init() {
 	CatchupOvercast.Flags().StringVar(&blockDir, "blockdir", "/tmp/blocks", "Path containing slot.json files")
 	CatchupOvercast.Flags().StringVar(&scratchDir, "scratchdir", "/tmp", "Path for downloads (e.g. snapshots) and other temp state")
 	CatchupOvercast.Flags().IntVar(&rpcPort, "rpc-server-port", 0, "RPC server port. Default off.")
+	CatchupOvercast.Flags().BoolVar(&enableGrpc, "enable-grpc", false, "Enable gRPC server. Default off.")
+	CatchupOvercast.Flags().IntVar(&grpcPort, "grpc-port", 50051, "gRPC server port. Default 50051.")
 }
 
 func runVerifier(c *cobra.Command, args []string) {
@@ -164,7 +175,6 @@ func runVerifier(c *cobra.Command, args []string) {
 	if !loadFromSnapshot && !loadFromAccountsDb && snapshotDlPath == "" {
 		klog.Fatalf("must specify either to load from a snapshot, or load from an existing AccountsDB, or download a snapshot.")
 	}
-	var err error
 
 	var accountsDbDir string
 	var accountsDb *accountsdb.AccountsDb
@@ -289,6 +299,20 @@ func runVerifier(c *cobra.Command, args []string) {
 		rpcServer.Start()
 		mlog.Log.Infof("started RPC server on port %d", rpcPort)
 	}
+	if enableGrpc {
+		if grpcPort <= 0 || grpcPort > 65535 {
+			grpcPort = 50051
+		}
+		grpcServer := grpc.NewGrpcServer(uint16(grpcPort), nil)
+		// Start wraps grpc.Server.Serve, which blocks until shutdown,
+		// so run it in a goroutine.
+		go func() {
+			if err := grpcServer.Start(); err != nil {
+				klog.Fatalf("failed to start gRPC server: %v", err)
+			}
+		}()
+		mlog.Log.Infof("started gRPC server on port %d", grpcPort)
+	}
 
 	replay.ReplayBlocks(c.Context(), accountsDb, accountsDbDir, manifest, uint64(startSlot), uint64(endSlot), rpcEndpoint, blockDir, int(txParallelism), false, false, dbgOpts, metricsWriter, rpcServer)
 	mlog.Log.Infof("done replaying, closing DB")
@@ -369,6 +392,20 @@ func runRpcCatchup(c *cobra.Command, args []string) {
 		mlog.Log.Infof("started RPC server on port %d", rpcPort)
 	}
 
+	if enableGrpc {
+		if grpcPort <= 0 || grpcPort > 65535 {
+			grpcPort = 50051
+		}
+		grpcServer := grpc.NewGrpcServer(uint16(grpcPort), nil)
+		// Start wraps grpc.Server.Serve, which blocks until shutdown,
+		// so run it in a goroutine.
+		go func() {
+			if err := grpcServer.Start(); err != nil {
+				klog.Fatalf("failed to start gRPC server: %v", err)
+			}
+		}()
+		mlog.Log.Infof("started gRPC server on port %d", grpcPort)
+	}
 
 	replay.ReplayBlocks(c.Context(), accountsDb, outputDir, manifest, uint64(startSlot), uint64(endSlot), rpcEndpoint, blockDir, int(txParallelism), true, false, dbgOpts, metricsWriter, rpcServer)
 	mlog.Log.Infof("done replaying, closing DB")
 	accountsDb.CloseDb()
@@ -452,6 +489,20 @@ func runOvercastCatchup(c *cobra.Command, args []string) {
 		mlog.Log.Infof("started RPC server on port %d", rpcPort)
 	}
 
+	if enableGrpc {
+		if grpcPort <= 0 || grpcPort > 65535 {
+			grpcPort = 50051
+		}
+		grpcServer := grpc.NewGrpcServer(uint16(grpcPort), nil)
+		// Start wraps grpc.Server.Serve, which blocks until shutdown,
+		// so run it in a goroutine.
+		go func() {
+			if err := grpcServer.Start(); err != nil {
+				klog.Fatalf("failed to start gRPC server: %v", err)
+			}
+		}()
+		mlog.Log.Infof("started gRPC server on port %d", grpcPort)
+	}
 
 	replay.ReplayBlocks(c.Context(), accountsDb, outputDir, manifest, uint64(startSlot), uint64(endSlot), rpcEndpoint, blockDir, int(txParallelism), true, true, dbgOpts, metricsWriter, rpcServer)
 	mlog.Log.Infof("done replaying, closing DB")
 	accountsDb.CloseDb()
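Note on the wiring above: all three run functions now start the gRPC server, but none of them connects its block channel to the replay pipeline, so Geyser subscribers would never receive block updates from this patch alone. Below is a minimal sketch of how the hookup might look. It is not part of the patch: wireGeyser, the returned callback, and the klog import path are illustrative assumptions.

    // Sketch only: wireGeyser and the returned callback are illustrative,
    // not part of this patch; the replay side would need to invoke the
    // callback for every finished block.
    package main

    import (
        "github.com/Overclock-Validator/mithril/pkg/block"
        "github.com/Overclock-Validator/mithril/pkg/grpc"

        "k8s.io/klog/v2" // assumed import path for the klog used in node.go
    )

    // wireGeyser starts the gRPC server and returns a callback that replay
    // code could invoke for every finished block.
    func wireGeyser(grpcPort uint16) func(*block.Block) {
        server := grpc.NewGrpcServer(grpcPort, nil)
        go func() {
            if err := server.Start(); err != nil {
                klog.Fatalf("gRPC server error: %v", err)
            }
        }()
        blockChan := server.GetBlockChannel()
        return func(blk *block.Block) {
            select {
            case blockChan <- blk: // forwarded to subscribed Geyser clients
            default: // drop rather than stall replay when the 100-block buffer is full
            }
        }
    }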
diff --git a/pkg/grpc/geyser_service.go b/pkg/grpc/geyser_service.go
new file mode 100644
index 00000000..01111847
--- /dev/null
+++ b/pkg/grpc/geyser_service.go
@@ -0,0 +1,264 @@
+package grpc
+
+import (
+	"context"
+	"slices"
+	"sync/atomic"
+
+	"google.golang.org/grpc"
+
+	b "github.com/Overclock-Validator/mithril/pkg/block"
+	pb "github.com/rpcpool/yellowstone-grpc/examples/golang/proto"
+
+	"github.com/gagliardetto/solana-go"
+)
+
+type GeyserService struct {
+	pb.UnimplementedGeyserServer
+
+	// Channel for sending Blocks to clients
+	blockChan chan *b.Block
+}
+
+func NewGeyserService() *GeyserService {
+	return &GeyserService{
+		blockChan: make(chan *b.Block, 100),
+	}
+}
+
+// GetBlockChannel returns the channel where Block messages can be sent.
+// External code can send blocks to this channel, and they will be filtered
+// and forwarded to clients.
+func (s *GeyserService) GetBlockChannel() chan<- *b.Block {
+	return s.blockChan
+}
+
+func (s *GeyserService) Ping(ctx context.Context, req *pb.PingRequest) (*pb.PongResponse, error) {
+	return &pb.PongResponse{
+		Count: req.Count,
+	}, nil
+}
+
+func (s *GeyserService) Subscribe(stream grpc.BidiStreamingServer[pb.SubscribeRequest, pb.SubscribeUpdate]) error {
+	ctx := stream.Context()
+
+	// Channel for receiving SubscribeRequest messages from the client
+	requestChan := make(chan *pb.SubscribeRequest, 10)
+
+	// Channel for sending SubscribeUpdate messages to the client
+	updateChan := make(chan *pb.SubscribeUpdate, 100)
+
+	// Channel to signal when we're done
+	done := make(chan error, 1)
+
+	// The active request is written by the main loop below and read by the
+	// block-forwarding goroutine, so it is held in an atomic pointer.
+	var activeRequest atomic.Pointer[pb.SubscribeRequest]
+
+	// Goroutine to receive SubscribeRequest messages from the client
+	go func() {
+		for {
+			req, err := stream.Recv()
+			if err != nil {
+				done <- err
+				close(requestChan)
+				return
+			}
+			requestChan <- req
+		}
+	}()
+
+	// Goroutine to send SubscribeUpdate messages to the client
+	go func() {
+		for {
+			select {
+			case update := <-updateChan:
+				if err := stream.Send(update); err != nil {
+					done <- err
+					return
+				}
+			case <-ctx.Done():
+				done <- ctx.Err()
+				return
+			}
+		}
+	}()
+
+	// Goroutine to filter and forward blocks to the client
+	go func() {
+		for {
+			select {
+			case block, ok := <-s.blockChan:
+				if !ok {
+					return
+				}
+				req := activeRequest.Load()
+				if !s.shouldSendBlock(block, req) {
+					continue
+				}
+				update := s.convertBlockToSubscribeUpdate(block)
+				if err := s.sendUpdate(update, req, updateChan, ctx); err != nil {
+					done <- err
+					return
+				}
+			case <-ctx.Done():
+				return
+			}
+		}
+	}()
+
+	// Main loop: process requests from the client
+	for {
+		select {
+		case req, ok := <-requestChan:
+			if !ok {
+				return nil
+			}
+
+			// Handle ping/pong
+			if req.Ping != nil {
+				pong := &pb.SubscribeUpdate{
+					UpdateOneof: &pb.SubscribeUpdate_Pong{
+						Pong: &pb.SubscribeUpdatePong{
+							Id: req.Ping.Id,
+						},
+					},
+				}
+				if err := s.sendUpdate(pong, activeRequest.Load(), updateChan, ctx); err != nil {
+					return err
+				}
+				continue
+			}
+
+			activeRequest.Store(req)
+
+		case err := <-done:
+			return err
+
+		case <-ctx.Done():
+			return ctx.Err()
+		}
+	}
+}
+
+func (s *GeyserService) extractFilterIDs(req *pb.SubscribeRequest) []string {
+	var filterIDs []string
+
+	if req != nil && len(req.Blocks) > 0 {
+		for id := range req.Blocks {
+			filterIDs = append(filterIDs, id)
+		}
+	}
+
+	return filterIDs
+}
+
+// sendUpdate sends an update to the client channel with proper error handling
+func (s *GeyserService) sendUpdate(update *pb.SubscribeUpdate, req *pb.SubscribeRequest, updateChan chan<- *pb.SubscribeUpdate, ctx context.Context) error {
+	if req != nil {
+		update.Filters = s.extractFilterIDs(req)
+	}
+	select {
+	case updateChan <- update:
+		return nil
+	case <-ctx.Done():
+		return ctx.Err()
+	}
+}
+
+// shouldSendBlock checks if a Block should be sent based on the active request filters.
+// This filters the block BEFORE converting to SubscribeUpdate for better performance.
+func (s *GeyserService) shouldSendBlock(block *b.Block, req *pb.SubscribeRequest) bool {
+	if req == nil {
+		return true
+	}
+
+	if len(req.Blocks) == 0 {
+		return false
+	}
+
+	return s.matchesBlockFilterForBlock(block, req)
+}
+
+// matchesBlockFilterForBlock reports whether a Block matches any of the
+// request's block filters. A filter with no account_include list matches
+// every block.
+func (s *GeyserService) matchesBlockFilterForBlock(block *b.Block, req *pb.SubscribeRequest) bool {
+	for _, filter := range req.Blocks {
+		if len(filter.AccountInclude) == 0 {
+			// No account constraints: this filter matches all blocks
+			return true
+		}
+
+		// Check if any account in block.UpdatedAccts matches
+		for _, pubkey := range block.UpdatedAccts {
+			if slices.Contains(filter.AccountInclude, pubkey.String()) {
+				return true
+			}
+		}
+
+		// Check EpochUpdatedAccts if available
+		for _, account := range block.EpochUpdatedAccts {
+			if account != nil && slices.Contains(filter.AccountInclude, account.Key.String()) {
+				return true
+			}
+		}
+
+		// Check ParentEpochUpdatedAccts if available
+		for _, account := range block.ParentEpochUpdatedAccts {
+			if account != nil && slices.Contains(filter.AccountInclude, account.Key.String()) {
+				return true
+			}
+		}
+	}
+
+	// No filter matched this block
+	return false
+}
+
+// convertBlockToSubscribeUpdate converts a Block to SubscribeUpdate with Block update
+func (s *GeyserService) convertBlockToSubscribeUpdate(block *b.Block) *pb.SubscribeUpdate {
+	blockUpdate := &pb.SubscribeUpdateBlock{
+		Slot:                     block.Slot,
+		Blockhash:                solana.Hash(block.Blockhash).String(),
+		ParentSlot:               block.ParentSlot,
+		ParentBlockhash:          solana.Hash(block.LastBlockhash).String(),
+		ExecutedTransactionCount: uint64(len(block.Transactions)),
+		UpdatedAccountCount:      uint64(len(block.UpdatedAccts)),
+		EntriesCount:             uint64(len(block.Entries)),
+	}
+
+	// Set block height if available
+	if block.BlockHeight > 0 {
+		blockUpdate.BlockHeight = &pb.BlockHeight{
+			BlockHeight: block.BlockHeight,
+		}
+	}
+
+	// Set block time if available
+	if block.UnixTimestamp > 0 {
+		blockUpdate.BlockTime = &pb.UnixTimestamp{
+			Timestamp: block.UnixTimestamp,
+		}
+	}
+
+	// Convert accounts if available (from UpdatedAccts)
+	// Note: this is a simplified conversion - full account data would need to be fetched
+	accounts := make([]*pb.SubscribeUpdateAccountInfo, 0, len(block.UpdatedAccts))
+	for _, pubkey := range block.UpdatedAccts {
+		accounts = append(accounts, &pb.SubscribeUpdateAccountInfo{
+			Pubkey: pubkey[:],
+		})
+	}
+	blockUpdate.Accounts = accounts
+
+	return &pb.SubscribeUpdate{
+		UpdateOneof: &pb.SubscribeUpdate_Block{
+			Block: blockUpdate,
+		},
+	}
+}
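A quick way to pin down the filter semantics implemented by shouldSendBlock and matchesBlockFilterForBlock is a small test. The sketch below is not part of the patch, and it assumes Block.UpdatedAccts is a []solana.PublicKey, which is consistent with the pubkey.String() and pubkey[:] usage in the service code:

    package grpc

    import (
        "testing"

        b "github.com/Overclock-Validator/mithril/pkg/block"
        pb "github.com/rpcpool/yellowstone-grpc/examples/golang/proto"

        "github.com/gagliardetto/solana-go"
    )

    // Sketch: a nil request matches everything, an empty Blocks map matches
    // nothing, and an account_include filter matches only blocks that
    // touched the listed account.
    func TestShouldSendBlockSemantics(t *testing.T) {
        s := NewGeyserService()
        key := solana.MustPublicKeyFromBase58("Vote111111111111111111111111111111111111111")
        blk := &b.Block{UpdatedAccts: []solana.PublicKey{key}} // assumed field type

        if !s.shouldSendBlock(blk, nil) {
            t.Error("nil request should match all blocks")
        }
        if s.shouldSendBlock(blk, &pb.SubscribeRequest{}) {
            t.Error("request without block filters should match nothing")
        }
        req := &pb.SubscribeRequest{
            Blocks: map[string]*pb.SubscribeRequestFilterBlocks{
                "votes": {AccountInclude: []string{key.String()}},
            },
        }
        if !s.shouldSendBlock(blk, req) {
            t.Error("account_include filter should match a block touching that account")
        }
    }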
diff --git a/pkg/grpc/grpc.go b/pkg/grpc/grpc.go
new file mode 100644
index 00000000..4e78c69f
--- /dev/null
+++ b/pkg/grpc/grpc.go
@@ -0,0 +1,65 @@
+package grpc
+
+import (
+	"fmt"
+	"log"
+	"net"
+
+	b "github.com/Overclock-Validator/mithril/pkg/block"
+	pb "github.com/rpcpool/yellowstone-grpc/examples/golang/proto"
+	"google.golang.org/grpc"
+)
+
+type GrpcServer struct {
+	server        *grpc.Server
+	listener      net.Listener
+	port          uint16
+	geyserService *GeyserService
+}
+
+func NewGrpcServer(port uint16, opts []grpc.ServerOption) *GrpcServer {
+	lis, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", port))
+	if err != nil {
+		log.Fatalf("failed to listen: %v", err)
+	}
+	server := grpc.NewServer(opts...)
+
+	return &GrpcServer{
+		port:          port,
+		listener:      lis,
+		server:        server,
+		geyserService: NewGeyserService(),
+	}
+}
+
+// Start registers the Geyser service and serves on the listener.
+// Note that Serve blocks until the server is stopped, so callers that need
+// to keep running should invoke Start from a goroutine.
+func (s *GrpcServer) Start() error {
+	pb.RegisterGeyserServer(s.server, s.geyserService)
+	return s.server.Serve(s.listener)
+}
+
+func (s *GrpcServer) GracefulStop() {
+	s.server.GracefulStop()
+}
+
+// GetBlockChannel returns the channel where Block messages can be sent.
+// External packages can send blocks to this channel, and they will be
+// filtered and forwarded to clients.
+//
+// Example usage from another package:
+//
+//	server := grpc.NewGrpcServer(50051, nil)
+//	blockChan := server.GetBlockChannel()
+//
+//	// Send a block (will be converted to SubscribeUpdate and filtered)
+//	block := &block.Block{
+//		Slot:      12345,
+//		Blockhash: [32]byte{...},
+//		// ... other fields
+//	}
+//	blockChan <- block
+func (s *GrpcServer) GetBlockChannel() chan<- *b.Block {
+	return s.geyserService.GetBlockChannel()
+}
diff --git a/pkg/grpc/grpc_test.go b/pkg/grpc/grpc_test.go
new file mode 100644
index 00000000..03062f71
--- /dev/null
+++ b/pkg/grpc/grpc_test.go
@@ -0,0 +1,125 @@
+package grpc
+
+import (
+	"context"
+	"fmt"
+	"testing"
+	"time"
+
+	pb "github.com/rpcpool/yellowstone-grpc/examples/golang/proto"
+	"google.golang.org/grpc"
+	"google.golang.org/grpc/credentials/insecure"
+)
+
+func TestGrpcServer(t *testing.T) {
+	// Use a test port
+	port := uint16(50052)
+
+	// Create a new gRPC server
+	server := NewGrpcServer(port, nil)
+
+	// Start server in a goroutine
+	errChan := make(chan error, 1)
+	go func() {
+		errChan <- server.Start()
+	}()
+
+	// Give the server a moment to start
+	time.Sleep(100 * time.Millisecond)
+
+	// Try to connect to the server to verify it's running
+	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+	defer cancel()
+
+	conn, err := grpc.NewClient(
+		fmt.Sprintf("localhost:%d", port),
+		grpc.WithTransportCredentials(insecure.NewCredentials()),
+	)
+	if err != nil {
+		t.Fatalf("Failed to connect to server: %v", err)
+	}
+	defer conn.Close()
+
+	// Trigger connection establishment and wait for the state to change,
+	// which verifies the server is reachable
+	conn.Connect()
+	if !conn.WaitForStateChange(ctx, conn.GetState()) {
+		t.Log("connection state did not change before timeout")
+	}
+
+	// Stop the server
+	server.server.Stop()
+
+	// Wait for server to stop
+	select {
+	case err := <-errChan:
+		if err != nil && err != grpc.ErrServerStopped {
+			t.Logf("Server stopped with error: %v", err)
+		}
+	case <-time.After(2 * time.Second):
+		t.Log("Server stopped")
+	}
+}
+
+func TestPingPong(t *testing.T) {
+	// Use a test port
+	port := uint16(50053)
+
+	// Create a new gRPC server (Geyser service is automatically created and registered)
+	server := NewGrpcServer(port, nil)
+
+	// Start server in a goroutine
+	errChan := make(chan error, 1)
+	go func() {
+		errChan <- server.Start()
+	}()
+
+	// Give the server a moment to start
+	time.Sleep(100 * time.Millisecond)
+
+	// Connect to the server
+	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+	defer cancel()
+
+	conn, err := grpc.NewClient(
+		fmt.Sprintf("localhost:%d", port),
+		grpc.WithTransportCredentials(insecure.NewCredentials()),
+	)
+	if err != nil {
+		t.Fatalf("Failed to connect to server: %v", err)
+	}
+	defer conn.Close()
+
+	// Create Geyser client
+	client := pb.NewGeyserClient(conn)
+
+	// Test Ping with count = 42
+	pingReq := &pb.PingRequest{
+		Count: 42,
+	}
+
+	pongResp, err := client.Ping(ctx, pingReq)
+	if err != nil {
+		t.Fatalf("Ping failed: %v", err)
+	}
+
+	// Verify the response
+	if pongResp.Count != 42 {
+		t.Errorf("Expected count 42, got %d", pongResp.Count)
+	}
+
+	t.Logf("Ping/Pong successful: sent count=%d, received count=%d", pingReq.Count, pongResp.Count)
+
+	// Stop the server
+	server.server.Stop()
+
+	// Wait for server to stop
+	select {
+	case err := <-errChan:
+		if err != nil && err != grpc.ErrServerStopped {
+			t.Logf("Server stopped with error: %v", err)
+		}
+	case <-time.After(2 * time.Second):
+		t.Log("Server stopped")
+	}
+}
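For completeness, a minimal client sketch that exercises the server end to end, built on the same yellowstone proto package. The address, filter ID, and log format are illustrative; registering a block filter with no account constraints subscribes to every block:

    package main

    import (
        "context"
        "io"
        "log"

        pb "github.com/rpcpool/yellowstone-grpc/examples/golang/proto"
        "google.golang.org/grpc"
        "google.golang.org/grpc/credentials/insecure"
    )

    func main() {
        conn, err := grpc.NewClient("localhost:50051",
            grpc.WithTransportCredentials(insecure.NewCredentials()))
        if err != nil {
            log.Fatalf("connect: %v", err)
        }
        defer conn.Close()

        stream, err := pb.NewGeyserClient(conn).Subscribe(context.Background())
        if err != nil {
            log.Fatalf("subscribe: %v", err)
        }

        // A block filter with no account constraints matches every block;
        // the map key is an arbitrary filter ID echoed back in update.Filters.
        req := &pb.SubscribeRequest{
            Blocks: map[string]*pb.SubscribeRequestFilterBlocks{
                "all-blocks": {},
            },
        }
        if err := stream.Send(req); err != nil {
            log.Fatalf("send subscribe request: %v", err)
        }

        for {
            update, err := stream.Recv()
            if err == io.EOF {
                return
            }
            if err != nil {
                log.Fatalf("recv: %v", err)
            }
            if blk := update.GetBlock(); blk != nil {
                log.Printf("slot %d: %d txs, %d updated accounts (filters=%v)",
                    blk.Slot, blk.ExecutedTransactionCount, blk.UpdatedAccountCount, update.Filters)
            }
        }
    }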