Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 54 additions & 1 deletion cmd/mithril/node/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (

"github.com/Overclock-Validator/mithril/pkg/accountsdb"
"github.com/Overclock-Validator/mithril/pkg/arena"
"github.com/Overclock-Validator/mithril/pkg/grpc"
"github.com/Overclock-Validator/mithril/pkg/mlog"
"github.com/Overclock-Validator/mithril/pkg/replay"
"github.com/Overclock-Validator/mithril/pkg/rpcserver"
Expand Down Expand Up @@ -88,6 +89,10 @@ var (
borrowedAccountArenaSize uint64

rpcPort int

// grpc flags
grpcPort int
enableGrpc bool
)

func init() {
Expand All @@ -114,6 +119,8 @@ func init() {
Verifier.Flags().BoolVar(&sbpf.UsePool, "use-pool", true, "Disable to allocate fresh slices")
Verifier.Flags().StringVar(&snapshotDlPath, "download-snapshot", "", "Path to download snapshot to")
Verifier.Flags().IntVar(&rpcPort, "rpc-server-port", 0, "RPC server port. Default off.")
Verifier.Flags().BoolVar(&enableGrpc, "enable-grpc", false, "Enable gRPC server. Default off.")
Verifier.Flags().IntVar(&grpcPort, "grpc-port", 50051, "gRPC server port. Default 50051.")

// flags for RPC catchup mode
CatchupRpc.Flags().StringVarP(&outputDir, "out", "o", "", "Output path for writing AccountsDB data to")
Expand All @@ -132,6 +139,8 @@ func init() {
CatchupRpc.Flags().StringVar(&blockDir, "blockdir", "/tmp/blocks", "Path containing slot.json files")
CatchupRpc.Flags().StringVar(&scratchDir, "scratchdir", "/tmp", "Path for downloads (e.g. snapshots) and other temp state")
CatchupRpc.Flags().IntVar(&rpcPort, "rpc-server-port", 0, "RPC server port. Default off.")
CatchupRpc.Flags().BoolVar(&enableGrpc, "enable-grpc", false, "Enable gRPC server. Default off.")
CatchupRpc.Flags().IntVar(&grpcPort, "grpc-port", 50051, "gRPC server port. Default 50051.")

// flags for Overcast catchup mode
CatchupOvercast.Flags().StringVarP(&outputDir, "out", "o", "", "Output path for writing AccountsDB data to")
Expand All @@ -151,6 +160,8 @@ func init() {
CatchupOvercast.Flags().StringVar(&blockDir, "blockdir", "/tmp/blocks", "Path containing slot.json files")
CatchupOvercast.Flags().StringVar(&scratchDir, "scratchdir", "/tmp", "Path for downloads (e.g. snapshots) and other temp state")
CatchupOvercast.Flags().IntVar(&rpcPort, "rpc-server-port", 0, "RPC server port. Default off.")
CatchupOvercast.Flags().BoolVar(&enableGrpc, "enable-grpc", false, "Enable gRPC server. Default off.")
CatchupOvercast.Flags().IntVar(&grpcPort, "grpc-port", 50051, "gRPC server port. Default 50051.")
}

func runVerifier(c *cobra.Command, args []string) {
Expand All @@ -164,7 +175,6 @@ func runVerifier(c *cobra.Command, args []string) {
if !loadFromSnapshot && !loadFromAccountsDb && snapshotDlPath == "" {
klog.Fatalf("must specify either to load from a snapshot, or load from an existing AccountsDB, or download a snapshot.")
}

var err error
var accountsDbDir string
var accountsDb *accountsdb.AccountsDb
Expand All @@ -185,6 +195,7 @@ func runVerifier(c *cobra.Command, args []string) {
defer pprof.StopCPUProfile()
}


if rpcEndpoint == "" {
rpcEndpoint = "https://api.mainnet-beta.solana.com"
}
Expand Down Expand Up @@ -289,6 +300,19 @@ func runVerifier(c *cobra.Command, args []string) {
rpcServer.Start()
mlog.Log.Infof("started RPC server on port %d", rpcPort)
}
if enableGrpc {

if grpcPort == 0 || grpcPort > 65535 || grpcPort < 0 {
grpcPort = 50051
}

grpcServer := grpc.NewGrpcServer(uint16(grpcPort), nil)
err := grpcServer.Start()
if err != nil {
klog.Fatalf("failed to start gRPC server: %v", err)
}
mlog.Log.Infof("started gRPC server on port %d", grpcPort)
}

replay.ReplayBlocks(c.Context(), accountsDb, accountsDbDir, manifest, uint64(startSlot), uint64(endSlot), rpcEndpoint, blockDir, int(txParallelism), false, false, dbgOpts, metricsWriter, rpcServer)
mlog.Log.Infof("done replaying, closing DB")
Expand Down Expand Up @@ -369,6 +393,20 @@ func runRpcCatchup(c *cobra.Command, args []string) {
mlog.Log.Infof("started RPC server on port %d", rpcPort)
}

if enableGrpc {

if grpcPort == 0 || grpcPort > 65535 || grpcPort < 0 {
grpcPort = 50051
}

grpcServer := grpc.NewGrpcServer(uint16(grpcPort), nil)
err := grpcServer.Start()
if err != nil {
klog.Fatalf("failed to start gRPC server: %v", err)
}
mlog.Log.Infof("started gRPC server on port %d", grpcPort)
}

replay.ReplayBlocks(c.Context(), accountsDb, outputDir, manifest, uint64(startSlot), uint64(endSlot), rpcEndpoint, blockDir, int(txParallelism), true, false, dbgOpts, metricsWriter, rpcServer)
mlog.Log.Infof("done replaying, closing DB")
accountsDb.CloseDb()
Expand Down Expand Up @@ -452,6 +490,21 @@ func runOvercastCatchup(c *cobra.Command, args []string) {
mlog.Log.Infof("started RPC server on port %d", rpcPort)
}

if enableGrpc {

if grpcPort == 0 || grpcPort > 65535 || grpcPort < 0 {
grpcPort = 50051
}

grpcServer := grpc.NewGrpcServer(uint16(grpcPort), nil)
err := grpcServer.Start()
if err != nil {
klog.Fatalf("failed to start gRPC server: %v", err)
}
mlog.Log.Infof("started gRPC server on port %d", grpcPort)
}


replay.ReplayBlocks(c.Context(), accountsDb, outputDir, manifest, uint64(startSlot), uint64(endSlot), rpcEndpoint, blockDir, int(txParallelism), true, true, dbgOpts, metricsWriter, rpcServer)
mlog.Log.Infof("done replaying, closing DB")
accountsDb.CloseDb()
Expand Down
264 changes: 264 additions & 0 deletions pkg/grpc/geyser_service.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,264 @@
package grpc

import (
	"context"
	"slices"
	"sync"

	"google.golang.org/grpc"

	b "github.com/Overclock-Validator/mithril/pkg/block"
	pb "github.com/rpcpool/yellowstone-grpc/examples/golang/proto"

	"github.com/gagliardetto/solana-go"
)

// GeyserService implements the Yellowstone Geyser gRPC streaming service.
// Replayed blocks are pushed into blockChan (see GetBlockChannel) and
// forwarded to subscribed clients by Subscribe.
type GeyserService struct {
	pb.UnimplementedGeyserServer

	// Channel for sending Blocks to client
	blockChan chan *b.Block
}


// NewGeyserService constructs a GeyserService whose block channel is
// buffered (capacity 100) and ready to accept replayed blocks.
func NewGeyserService() *GeyserService {
	svc := &GeyserService{}
	svc.blockChan = make(chan *b.Block, 100)
	return svc
}

// GetBlockChannel exposes the send side of the service's block channel.
// Producers (e.g. the replay loop) push blocks here; Subscribe filters and
// forwards them to connected clients.
func (s *GeyserService) GetBlockChannel() chan<- *b.Block {
	var sendOnly chan<- *b.Block = s.blockChan
	return sendOnly
}

// Ping answers a client liveness probe by echoing the request's counter.
func (s *GeyserService) Ping(ctx context.Context, req *pb.PingRequest) (*pb.PongResponse, error) {
	resp := &pb.PongResponse{Count: req.Count}
	return resp, nil
}

// Subscribe implements the Geyser bidirectional streaming RPC. It runs
// three goroutines — one receiving client requests, one writing updates to
// the stream, one forwarding blocks from blockChan — while the main loop
// applies new subscription requests and answers pings.
//
// Fixes vs. the previous revision:
//   - activeRequest was written by the main loop and read by the
//     block-forwarding goroutine without synchronization (a data race);
//     it is now guarded by reqMu.
//   - shouldSendBlock was never invoked, so client filters were ignored;
//     blocks are now filtered before conversion.
func (s *GeyserService) Subscribe(stream grpc.BidiStreamingServer[pb.SubscribeRequest, pb.SubscribeUpdate]) error {
	ctx := stream.Context()

	// Channel for receiving SubscribeRequest messages from the client.
	requestChan := make(chan *pb.SubscribeRequest, 10)

	// Channel for sending SubscribeUpdate messages to the client.
	updateChan := make(chan *pb.SubscribeUpdate, 100)

	// Channel used by the worker goroutines to report a terminal error.
	done := make(chan error, 1)

	// activeRequest holds the client's most recent subscription request.
	// It is written by the main loop and read by the block-forwarding
	// goroutine, so every access must hold reqMu.
	var (
		reqMu         sync.Mutex
		activeRequest *pb.SubscribeRequest
	)
	currentRequest := func() *pb.SubscribeRequest {
		reqMu.Lock()
		defer reqMu.Unlock()
		return activeRequest
	}

	// Receive SubscribeRequest messages from the client.
	go func() {
		for {
			req, err := stream.Recv()
			if err != nil {
				done <- err
				close(requestChan)
				return
			}
			requestChan <- req
		}
	}()

	// Send SubscribeUpdate messages to the client.
	go func() {
		for {
			select {
			case update := <-updateChan:
				if err := stream.Send(update); err != nil {
					done <- err
					return
				}
			case <-ctx.Done():
				done <- ctx.Err()
				return
			}
		}
	}()

	// Forward blocks to the client, applying the active request's filters.
	go func() {
		for {
			select {
			case blk, ok := <-s.blockChan:
				if !ok {
					return
				}
				req := currentRequest()
				// Filter before converting: conversion is the expensive
				// part, so skip it for blocks the client doesn't want.
				if !s.shouldSendBlock(blk, req) {
					continue
				}
				update := s.convertBlockToSubscribeUpdate(blk)
				if err := s.sendUpdate(update, req, updateChan, ctx); err != nil {
					done <- err
					return
				}
			case <-ctx.Done():
				return
			}
		}
	}()

	// Main loop: apply incoming requests and answer pings.
	for {
		select {
		case req, ok := <-requestChan:
			if !ok {
				return nil
			}

			// Handle ping/pong.
			if req.Ping != nil {
				pong := &pb.SubscribeUpdate{
					UpdateOneof: &pb.SubscribeUpdate_Pong{
						Pong: &pb.SubscribeUpdatePong{
							Id: req.Ping.Id,
						},
					},
				}
				if err := s.sendUpdate(pong, currentRequest(), updateChan, ctx); err != nil {
					return err
				}
				continue
			}

			reqMu.Lock()
			activeRequest = req
			reqMu.Unlock()

		case err := <-done:
			return err

		case <-ctx.Done():
			return ctx.Err()
		}
	}
}

// extractFilterIDs collects the names of the block filters present in the
// request; these IDs are echoed back on each update's Filters field so the
// client knows which of its filters matched.
func (s *GeyserService) extractFilterIDs(req *pb.SubscribeRequest) []string {
	if req == nil || len(req.Blocks) == 0 {
		return nil
	}
	ids := make([]string, 0, len(req.Blocks))
	for id := range req.Blocks {
		ids = append(ids, id)
	}
	return ids
}

// sendUpdate tags update with the request's filter IDs (when a request is
// active) and queues it on updateChan, aborting if ctx is cancelled before
// the channel accepts it.
func (s *GeyserService) sendUpdate(update *pb.SubscribeUpdate, req *pb.SubscribeRequest, updateChan chan<- *pb.SubscribeUpdate, ctx context.Context) error {
	if req != nil {
		update.Filters = s.extractFilterIDs(req)
	}

	select {
	case <-ctx.Done():
		return ctx.Err()
	case updateChan <- update:
	}
	return nil
}

// shouldSendBlock decides whether block should be delivered under the
// active subscription request. It is intended to run before the relatively
// expensive conversion to a SubscribeUpdate.
//
// Rules: no request at all means send everything; a request with no block
// filters means send nothing; otherwise defer to filter matching.
func (s *GeyserService) shouldSendBlock(block *b.Block, req *pb.SubscribeRequest) bool {
	switch {
	case req == nil:
		return true
	case len(req.Blocks) == 0:
		return false
	default:
		return s.matchesBlockFilterForBlock(block, req)
	}
}

// matchesBlockFilterForBlock reports whether block satisfies at least one
// of the block filters in req. A filter with an empty AccountInclude list
// matches every block; otherwise a filter matches when any account the
// block updated (directly or via epoch-boundary processing) appears in its
// AccountInclude list.
//
// Fix vs. the previous revision: the old code returned from inside the
// first loop iteration, so with multiple filters only one — chosen by Go's
// random map iteration order — was ever evaluated, making results
// nondeterministic. All filters are now checked (any-match semantics).
func (s *GeyserService) matchesBlockFilterForBlock(block *b.Block, req *pb.SubscribeRequest) bool {
	for _, filter := range req.Blocks {
		// A filter without account constraints matches all blocks.
		if len(filter.AccountInclude) == 0 {
			return true
		}
		if blockTouchesAccounts(block, filter.AccountInclude) {
			return true
		}
	}

	// No filters at all matches everything; otherwise nothing matched.
	return len(req.Blocks) == 0
}

// blockTouchesAccounts reports whether any account updated by the block —
// in UpdatedAccts, EpochUpdatedAccts, or ParentEpochUpdatedAccts — is
// present in include.
func blockTouchesAccounts(block *b.Block, include []string) bool {
	for _, pubkey := range block.UpdatedAccts {
		if slices.Contains(include, pubkey.String()) {
			return true
		}
	}
	for _, account := range block.EpochUpdatedAccts {
		if account != nil && slices.Contains(include, account.Key.String()) {
			return true
		}
	}
	for _, account := range block.ParentEpochUpdatedAccts {
		if account != nil && slices.Contains(include, account.Key.String()) {
			return true
		}
	}
	return false
}

// convertBlockToSubscribeUpdate wraps a replayed Block in a Geyser
// SubscribeUpdate carrying a SubscribeUpdateBlock payload. Only the updated
// account pubkeys are attached; full account data would have to be fetched
// separately.
//
// Fix vs. the previous revision: `pubkey[:]` sliced the range loop
// variable, which under pre-Go-1.22 loop semantics is a single reused
// array — every Accounts entry would alias the last key. Slicing the
// slice element by index is correct under all Go versions.
func (s *GeyserService) convertBlockToSubscribeUpdate(block *b.Block) *pb.SubscribeUpdate {
	blockUpdate := &pb.SubscribeUpdateBlock{
		Slot:                     block.Slot,
		Blockhash:                solana.Hash(block.Blockhash).String(),
		ParentSlot:               block.ParentSlot,
		ParentBlockhash:          solana.Hash(block.LastBlockhash).String(),
		ExecutedTransactionCount: uint64(len(block.Transactions)),
		UpdatedAccountCount:      uint64(len(block.UpdatedAccts)),
		EntriesCount:             uint64(len(block.Entries)),
	}

	// Block height and time are optional; attach them only when known.
	if block.BlockHeight > 0 {
		blockUpdate.BlockHeight = &pb.BlockHeight{
			BlockHeight: block.BlockHeight,
		}
	}
	if block.UnixTimestamp > 0 {
		blockUpdate.BlockTime = &pb.UnixTimestamp{
			Timestamp: block.UnixTimestamp,
		}
	}

	accounts := make([]*pb.SubscribeUpdateAccountInfo, 0, len(block.UpdatedAccts))
	for i := range block.UpdatedAccts {
		// Slice the element in place: no aliasing of a loop variable.
		accounts = append(accounts, &pb.SubscribeUpdateAccountInfo{
			Pubkey: block.UpdatedAccts[i][:],
		})
	}
	blockUpdate.Accounts = accounts

	return &pb.SubscribeUpdate{
		UpdateOneof: &pb.SubscribeUpdate_Block{
			Block: blockUpdate,
		},
	}
}
Loading