diff --git a/cmd/mithril/node/node.go b/cmd/mithril/node/node.go index d0c7a5f9..353f6c41 100644 --- a/cmd/mithril/node/node.go +++ b/cmd/mithril/node/node.go @@ -16,6 +16,7 @@ import ( "runtime/pprof" "strconv" "strings" + "sync/atomic" "syscall" "time" @@ -83,9 +84,11 @@ var ( loadFromSnapshot bool loadFromAccountsDb bool - bootstrapMode string // "auto", "snapshot", or "accountsdb" + bootstrapMode string // "auto", "snapshot", "local-snapshot", or "accountsdb" snapshotArchivePath string incrementalSnapshotFilename string + localFullSnapshotPath string // Path to existing full snapshot file (for local-snapshot mode) + localIncrSnapshotPath string // Path to existing incremental snapshot file (optional) accountsPath string scratchDirectory string rpcEndpoints []string @@ -167,7 +170,9 @@ func init() { // flags for 'mithril run' (live full node mode) // [bootstrap] section flags - Run.Flags().StringVar(&bootstrapMode, "bootstrap-mode", "auto", "Bootstrap mode: 'auto' (use AccountsDB if exists, else snapshot), 'accountsdb' (require existing), 'snapshot' (rebuild from snapshot), 'new-snapshot' (always download fresh)") + Run.Flags().StringVar(&bootstrapMode, "bootstrap-mode", "auto", "Bootstrap mode: 'auto', 'accountsdb', 'snapshot', 'new-snapshot', or 'local-snapshot'") + Run.Flags().StringVar(&localFullSnapshotPath, "full-snapshot", "", "Path to existing full snapshot file (triggers local-snapshot mode)") + Run.Flags().StringVar(&localIncrSnapshotPath, "incremental-snapshot", "", "Path to existing incremental snapshot file (optional, used with --full-snapshot)") // [ledger] section flags Run.Flags().StringVarP(&accountsPath, "accounts-path", "o", "", "Output path for writing AccountsDB data to") @@ -1087,6 +1092,23 @@ func runLive(c *cobra.Command, args []string) { // Use configured snapshot directory (storage.snapshots / snapshot.download_path), not scratch snapshotDownloadPath := snapshotDlPath + // Auto-detect local-snapshot mode when --full-snapshot is provided + if localFullSnapshotPath != "" { + if bootstrapMode != "auto" && bootstrapMode != "local-snapshot" { + mlog.Log.Infof("WARNING: --full-snapshot provided, overriding --bootstrap-mode=%s with local-snapshot", bootstrapMode) + } + bootstrapMode = "local-snapshot" + // Validate the snapshot file exists + if _, err := os.Stat(localFullSnapshotPath); os.IsNotExist(err) { + klog.Fatalf("full snapshot file not found: %s", localFullSnapshotPath) + } + if localIncrSnapshotPath != "" { + if _, err := os.Stat(localIncrSnapshotPath); os.IsNotExist(err) { + klog.Fatalf("incremental snapshot file not found: %s", localIncrSnapshotPath) + } + } + } + // Prune old history entries if needed (keeps last 100) if accountsPath != "" { if err := state.PruneHistory(accountsPath); err != nil { @@ -1242,6 +1264,70 @@ func runLive(c *cobra.Command, args []string) { // Record bootstrap in history state.RecordBootstrap(accountsPath, manifest.Bank.Slot, "", replay.CurrentRunID, getVersion(), getCommit()) + case "local-snapshot": + // Mode: Build from user-provided snapshot files (no download) + if localFullSnapshotPath == "" { + klog.Fatalf("mode=local-snapshot requires --full-snapshot flag") + } + if accountsPath == "" { + klog.Fatalf("mode=local-snapshot requires --accounts-path flag") + } + + mlog.Log.Infof("mode=local-snapshot: building AccountsDB from local snapshot files") + mlog.Log.Infof(" full snapshot: %s", localFullSnapshotPath) + if localIncrSnapshotPath != "" { + mlog.Log.Infof(" incremental snapshot: %s", localIncrSnapshotPath) + } else { + mlog.Log.Infof(" incremental snapshot: (will fetch from network)") + } + + // Clean previous AccountsDB + if mithrilState != nil { + state.RecordRebuild(accountsPath, mithrilState.LastSlot, mithrilState.LastBankhash, getVersion(), getCommit(), "local-snapshot mode") + } else { + state.RecordRebuild(accountsPath, 0, "", getVersion(), getCommit(), "local-snapshot mode (no prior state)") + } + mlog.Log.Infof("cleaning up previous AccountsDB artifacts in %s", accountsPath) + snapshot.CleanAccountsDbDir(accountsPath) + + // Build from local snapshot + // BuildAccountsDb accepts an optional incremental as second param + var incrPath string + if localIncrSnapshotPath != "" { + // Both full and incremental provided + mlog.Log.Infof("building AccountsDB from full + incremental snapshots") + incrPath = localIncrSnapshotPath + } else { + // Only full provided - try to fetch incremental + mlog.Log.Infof("building AccountsDB from full snapshot (will attempt to fetch incremental)") + fullSlot := parseSlotFromSnapshotName(filepath.Base(localFullSnapshotPath)) + if fullSlot > 0 && snapshotDownloadPath != "" { + fetchedIncr, incrErr := downloadIncrementalForFullSnapshot(ctx, rpcEndpoints, fullSlot, snapshotDownloadPath) + if incrErr == nil && fetchedIncr != "" { + mlog.Log.Infof("found matching incremental snapshot at %s", fetchedIncr) + incrPath = fetchedIncr + } else if incrErr != nil { + mlog.Log.Infof("could not fetch incremental: %v (building from full snapshot only)", incrErr) + } + } + } + + accountsDb, manifest, err = snapshot.BuildAccountsDb(ctx, localFullSnapshotPath, incrPath, accountsPath) + if err != nil { + klog.Fatalf("failed to build AccountsDB from local snapshot: %v", err) + } + + // Write state file + var snapshotEpoch uint64 + if sealevel.SysvarCache.EpochSchedule.Sysvar != nil { + snapshotEpoch = sealevel.SysvarCache.EpochSchedule.Sysvar.GetEpoch(manifest.Bank.Slot) + } + mithrilState = state.NewReadyState(manifest.Bank.Slot, snapshotEpoch, "", "", 0, 0) + if err := mithrilState.Save(accountsPath); err != nil { + mlog.Log.Errorf("failed to save state file: %v", err) + } + state.RecordBootstrap(accountsPath, manifest.Bank.Slot, "", replay.CurrentRunID, getVersion(), getCommit()) + case "auto": fallthrough default: @@ -1773,9 +1859,11 @@ func printStartupInfo(commandName string) { case "auto": bootstrapDesc = "use existing AccountsDB if valid, else download snapshot" case "snapshot": - bootstrapDesc = "rebuild from local snapshot" + bootstrapDesc = "rebuild from snapshot (reuse if fresh)" case "new-snapshot": bootstrapDesc = "download fresh snapshot from network" + case "local-snapshot": + bootstrapDesc = "build from user-provided snapshot file" case "accountsdb": bootstrapDesc = "require existing AccountsDB" default: @@ -1968,7 +2056,8 @@ func detectExistingAccountsDB(path string) (bool, uint64) { return true, manifest.Bank.Slot } -// detectExistingSnapshots finds snapshot files in the given directory +// detectExistingSnapshots finds snapshot files in the given directory. +// It skips .partial files (incomplete downloads from crashed runs). func detectExistingSnapshots(dir string) []snapshotInfo { if dir == "" { return nil @@ -1987,6 +2076,11 @@ func detectExistingSnapshots(dir string) []snapshotInfo { } name := entry.Name() + // Skip partial downloads (incomplete files from crashed runs) + if strings.HasSuffix(name, ".partial") { + continue + } + // Full snapshot: snapshot-{slot}-{hash}.tar.zst if len(name) > 9 && name[:9] == "snapshot-" && filepath.Ext(name) == ".zst" { slot := parseSlotFromSnapshotName(name) @@ -2065,6 +2159,30 @@ func parseSlotFromIncrementalName(name string) uint64 { return 0 } +// downloadIncrementalForFullSnapshot attempts to download an incremental snapshot +// that builds on the given full snapshot slot. +func downloadIncrementalForFullSnapshot(ctx context.Context, rpcEndpoints []string, fullSlot uint64, downloadPath string) (string, error) { + if downloadPath == "" { + return "", fmt.Errorf("no download path specified") + } + + // Query current slot to get reference + currentSlot, err := queryCurrentSlot(ctx, rpcEndpoints) + if err != nil { + return "", fmt.Errorf("could not query current slot: %w", err) + } + + mlog.Log.Infof("searching for incremental snapshot (full=%d, current=%d)", fullSlot, currentSlot) + + // Use snapshotdl to find and download an incremental + incrPath, _, _, err := snapshotdl.DownloadIncrementalSnapshot(rpcEndpoints, downloadPath, int(currentSlot), int(fullSlot)) + if err != nil { + return "", fmt.Errorf("failed to download incremental: %w", err) + } + + return incrPath, nil +} + // detectFreshSnapshot checks for an existing snapshot file within the freshness threshold. // Returns the snapshotInfo if found, nil otherwise. func detectFreshSnapshot(snapshotDir string, fullThreshold int, rpcEndpoints []string, ctx context.Context) *snapshotInfo { @@ -2141,7 +2259,7 @@ func buildFromExistingSnapshot(ctx context.Context, snap *snapshotInfo, snapshot // Create progress display for extract dp := progress.NewDualProgress() - accountsDb, manifest, err := snapshot.BuildAccountsDbWithIncr(ctx, fullSnapshotPath, snapshotDir, int(snap.slot), int(snap.slot), accountsPath, rpcEndpoints, blockstorePath, snapCfg, dp) + accountsDb, manifest, err := snapshot.BuildAccountsDbWithIncr(ctx, fullSnapshotPath, snapshotDir, int(snap.slot), int(snap.slot), accountsPath, rpcEndpoints, blockstorePath, snapCfg, dp, nil) if err != nil { return nil, nil, fmt.Errorf("failed to build AccountsDB from snapshot: %w", err) } @@ -2150,38 +2268,150 @@ func buildFromExistingSnapshot(ctx context.Context, snap *snapshotInfo, snapshot return accountsDb, manifest, nil } -// downloadAndBuildFromSnapshot finds, downloads, and builds AccountsDB from a snapshot +// downloadAndBuildFromSnapshot finds, downloads, and builds AccountsDB from a snapshot. +// Supports interactive source switching during download - press 'n' to try the next source. func downloadAndBuildFromSnapshot(ctx context.Context, rpcEndpoints []string, snapshotDownloadPath, accountsPath, blockstorePath string) (*accountsdb.AccountsDb, *snapshot.SnapshotManifest, error) { snapCfg := buildSnapshotConfig(rpcEndpoints) - fullSnapshotDlStart := time.Now() - fullSnapshotInfo, err := snapshotdl.GetSnapshotURLWithInfo(ctx, snapCfg) + + // Set logging info for detailed speed test log + snapCfg.LogDir = logDir + snapCfg.RunID = replay.CurrentRunID + + // Get all ranked snapshot sources (runs Stage 1 + Stage 2 testing) + sourceSelector, err := snapshotdl.GetRankedSnapshotSources(ctx, snapCfg) if err != nil { - return nil, nil, fmt.Errorf("error getting snapshot URL: %w", err) + return nil, nil, fmt.Errorf("error getting snapshot sources: %w", err) + } + defer sourceSelector.Close() + + // Get initial source + currentSource := sourceSelector.Current() + if currentSource == nil { + return nil, nil, fmt.Errorf("no snapshot sources available") } - fullSnapshotURL := fullSnapshotInfo.URL - fullSnapshotSlot := fullSnapshotInfo.Slot // Print a clean summary of the selected snapshot source progress.PrintSnapshotSourceSummary( - fullSnapshotInfo.NodeIP, - fullSnapshotInfo.Slot, - fullSnapshotInfo.ReferenceSlot, - fullSnapshotInfo.NodeVersion, - fullSnapshotInfo.SpeedMBs, - fullSnapshotInfo.RTTMs, - time.Since(fullSnapshotDlStart), + currentSource.NodeIP, + currentSource.Slot, + currentSource.ReferenceSlot, + currentSource.Version, + currentSource.SpeedMBs, + currentSource.RTTMs, + sourceSelector.SearchTime, ) // Create progress display for snapshot download and extract dp := progress.NewDualProgress() - accountsDb, manifest, err := snapshot.BuildAccountsDbWithIncr(ctx, fullSnapshotURL, snapshotDownloadPath, fullSnapshotSlot, fullSnapshotSlot, accountsPath, rpcEndpoints, blockstorePath, snapCfg, dp) - if err != nil { - return nil, nil, fmt.Errorf("failed to build AccountsDB from snapshot: %w", err) - } - mlog.Log.Infof("finished building AccountsDB") + // Track if source switch was requested + var sourceSwitchRequested atomic.Bool - return accountsDb, manifest, nil + // Try sources with interactive switching support + for { + // Check if parent context was cancelled + if ctx.Err() != nil { + return nil, nil, ctx.Err() + } + + currentSource = sourceSelector.Current() + if currentSource == nil { + return nil, nil, fmt.Errorf("exhausted all %d snapshot sources", sourceSelector.TotalSources()) + } + + // Create a cancellable context for this download attempt + downloadCtx, cancelDownload := context.WithCancel(ctx) + + // Format source info for display + sourceInfo := fmt.Sprintf("Source %d/%d: %s (%.1f MB/s)", + sourceSelector.CurrentIndex()+1, + sourceSelector.TotalSources(), + currentSource.NodeIP, + currentSource.SpeedMBs, + ) + + // Enable source switching UI on progress bar + sourceSwitchRequested.Store(false) + dp.EnableSourceSwitching(sourceInfo, func() { + if sourceSelector.HasMore() { + sourceSwitchRequested.Store(true) + cancelDownload() + mlog.Log.Infof("User requested source switch - cancelling current download...") + } else { + mlog.Log.Infof("No more sources available to switch to") + } + }) + + // Attempt download from current source + accountsDb, manifest, err := snapshot.BuildAccountsDbWithIncr( + downloadCtx, + currentSource.URL, + snapshotDownloadPath, + currentSource.Slot, + currentSource.Slot, + accountsPath, + rpcEndpoints, + blockstorePath, + snapCfg, + dp, + sourceSelector, // Pass selector for cached incremental source lookup + ) + + // Disable source switching after this attempt + dp.DisableSourceSwitching() + cancelDownload() // Clean up context + + // Check results + if err == nil { + // Success! + mlog.Log.Infof("finished building AccountsDB") + return accountsDb, manifest, nil + } + + // Handle source switch request + if sourceSwitchRequested.Load() || (downloadCtx.Err() != nil && ctx.Err() == nil) { + // Source switch was requested or download was cancelled (but not parent ctx) + nextSource := sourceSelector.Next() + if nextSource == nil { + return nil, nil, fmt.Errorf("exhausted all %d snapshot sources after user-initiated switch", sourceSelector.TotalSources()) + } + + // Clean up any partial download + snapshot.CleanAccountsDbDir(accountsPath) + + // Update source info and continue + mlog.Log.Infof("Switching to source %d/%d: %s (%.1f MB/s)", + sourceSelector.CurrentIndex()+1, + sourceSelector.TotalSources(), + nextSource.NodeIP, + nextSource.SpeedMBs, + ) + continue + } + + // Actual error (not source switch) + // Check if parent context was cancelled + if ctx.Err() != nil { + return nil, nil, ctx.Err() + } + + // Try next source automatically on error + mlog.Log.Errorf("Download from %s failed: %v", currentSource.NodeIP, err) + nextSource := sourceSelector.Next() + if nextSource == nil { + return nil, nil, fmt.Errorf("failed to download from all %d sources, last error: %w", sourceSelector.TotalSources(), err) + } + + // Clean up any partial download + snapshot.CleanAccountsDbDir(accountsPath) + + mlog.Log.Infof("Trying next source %d/%d: %s (%.1f MB/s)", + sourceSelector.CurrentIndex()+1, + sourceSelector.TotalSources(), + nextSource.NodeIP, + nextSource.SpeedMBs, + ) + } } // killExistingMithrilProcesses finds and kills any other running mithril processes. diff --git a/pkg/progress/progress.go b/pkg/progress/progress.go index bd841ffa..fc17af70 100644 --- a/pkg/progress/progress.go +++ b/pkg/progress/progress.go @@ -248,6 +248,13 @@ type DualProgress struct { output io.Writer useColor bool downloadTotal int64 // cached download total for ratio calculation + + // Source switching support + sourceSwitchEnabled bool + sourceInfo string // e.g., "Source 1/5: 192.168.1.1" + onSourceSwitch func() // callback when user requests source switch + keyboardStopCh chan struct{} + oldTermState *term.State // saved terminal state for raw mode } // NewDualProgress creates a new dual progress display @@ -310,6 +317,7 @@ func (d *DualProgress) Start() { return } d.started = true + sourceEnabled := d.sourceSwitchEnabled d.mu.Unlock() // Print pipeline description using stages (same stage = parallel) @@ -323,6 +331,11 @@ func (d *DualProgress) Start() { fmt.Fprintf(d.output, "%s", colorReset) } + // Print source info if enabled (above the progress bars) + if sourceEnabled { + d.printSourceInfo() + } + // Print initial empty lines for progress bars (2 bars) fmt.Fprintln(d.output) fmt.Fprintln(d.output) @@ -356,13 +369,11 @@ func (d *DualProgress) updateLoop() { // render updates the display with both bars func (d *DualProgress) render() { - d.mu.Lock() - defer d.mu.Unlock() - downloadLine := d.Download.Render(d.useColor) extractLine := d.Extract.Render(d.useColor) if d.useColor { + // Always use 2-line mode (source info was printed once above, if enabled) // Single atomic write to avoid terminal buffering issues: // \r - carriage return to start of line (ensures clean positioning) // \x1b[2A - move cursor up 2 lines @@ -389,6 +400,9 @@ func (d *DualProgress) Stop() { d.done = true d.mu.Unlock() + // Disable source switching and restore terminal + d.DisableSourceSwitching() + close(d.stopCh) <-d.doneCh } @@ -422,6 +436,156 @@ func (d *DualProgress) IsInterrupted() bool { return d.interrupted } +// EnableSourceSwitching enables keyboard-based source switching during download. +// The onSwitch callback is called when the user presses 'n' or 's' to switch sources. +// sourceInfo is displayed ONCE above the progress bars (e.g., "Source 1/5: 192.168.1.1"). +// If called before Start(), the source info will be printed when Start() is called. +// If called after Start(), the source info is printed immediately above the progress bars. +func (d *DualProgress) EnableSourceSwitching(sourceInfo string, onSwitch func()) { + d.mu.Lock() + + alreadyStarted := d.started + + d.sourceSwitchEnabled = true + d.sourceInfo = sourceInfo + d.onSourceSwitch = onSwitch + d.keyboardStopCh = make(chan struct{}) + d.mu.Unlock() + + // If already started, print the source info line now (above the progress bars) + // This requires moving cursor up, printing, and adjusting + if alreadyStarted { + d.printSourceInfo() + } + // If not started yet, Start() will print the source info + + // Start keyboard listener if terminal supports it + if term.IsTerminal(int(os.Stdin.Fd())) { + go d.keyboardListener() + } +} + +// printSourceInfo prints the source info line once +func (d *DualProgress) printSourceInfo() { + d.mu.Lock() + sourceInfo := d.sourceInfo + d.mu.Unlock() + + if sourceInfo == "" { + return + } + + if d.useColor { + hint := fmt.Sprintf("%s[Press 'n' for next source, Ctrl+C to exit]%s", colorDim, colorReset) + sourceLine := fmt.Sprintf("%s%s%s %s", colorTeal, sourceInfo, colorReset, hint) + fmt.Fprintf(d.output, "%s\n", sourceLine) + } else { + fmt.Fprintf(d.output, "%s [Press 'n' for next source, Ctrl+C to exit]\n", sourceInfo) + } +} + +// UpdateSourceInfo updates the source info displayed during download +func (d *DualProgress) UpdateSourceInfo(sourceInfo string) { + d.mu.Lock() + defer d.mu.Unlock() + d.sourceInfo = sourceInfo +} + +// DisableSourceSwitching stops the keyboard listener and restores terminal state +func (d *DualProgress) DisableSourceSwitching() { + d.mu.Lock() + if !d.sourceSwitchEnabled { + d.mu.Unlock() + return + } + d.sourceSwitchEnabled = false + keyboardStopCh := d.keyboardStopCh + oldState := d.oldTermState + d.mu.Unlock() + + // Stop keyboard listener + if keyboardStopCh != nil { + close(keyboardStopCh) + } + + // Restore terminal state + if oldState != nil { + term.Restore(int(os.Stdin.Fd()), oldState) + } +} + +// keyboardListener runs in the background listening for keypresses +func (d *DualProgress) keyboardListener() { + // Set terminal to raw mode to read individual keypresses + fd := int(os.Stdin.Fd()) + oldState, err := term.MakeRaw(fd) + if err != nil { + // Can't set raw mode, skip keyboard listening + return + } + + d.mu.Lock() + d.oldTermState = oldState + d.mu.Unlock() + + // Restore terminal on exit + defer term.Restore(fd, oldState) + + // Create a single reader goroutine that sends keypresses on a channel + // This avoids spawning new goroutines for each read attempt + keyCh := make(chan byte, 8) + readerDone := make(chan struct{}) + + go func() { + defer close(readerDone) + buf := make([]byte, 1) + for { + n, err := os.Stdin.Read(buf) + if err != nil { + return + } + if n > 0 { + select { + case keyCh <- buf[0]: + default: + // Channel full, drop the key + } + } + } + }() + + // Process keypresses until stopped + for { + select { + case <-d.keyboardStopCh: + // Cleanup: we can't easily interrupt the blocked Read, + // but at least we stop processing and the goroutine will + // naturally exit when stdin is closed or on next keypress + return + case key := <-keyCh: + // Check for switch keys: 'n', 's', or space + if key == 'n' || key == 'N' || key == 's' || key == 'S' || key == ' ' { + d.mu.Lock() + onSwitch := d.onSourceSwitch + d.mu.Unlock() + if onSwitch != nil { + onSwitch() + } + } + // Check for Ctrl+C (0x03) + if key == 0x03 { + // In raw mode, Ctrl+C is captured as byte 0x03 instead of triggering SIGINT. + // We need to restore the terminal and send SIGINT ourselves so the normal + // signal handler (signal.NotifyContext) can process it. + term.Restore(fd, oldState) + // Send SIGINT to self - this will trigger context cancellation + syscall.Kill(syscall.Getpid(), syscall.SIGINT) + return + } + } + } +} + // IndexingProgress tracks shard flush progress type IndexingProgress struct { label string diff --git a/pkg/snapshot/bufmonreader.go b/pkg/snapshot/bufmonreader.go index 8fb54a7a..8a2432db 100644 --- a/pkg/snapshot/bufmonreader.go +++ b/pkg/snapshot/bufmonreader.go @@ -55,8 +55,14 @@ func NewBufMonReaderHTTP(ctx context.Context, url string) (*bufmonreader, error) return NewBufMonReaderHTTPWithSave(ctx, url, "") } +// PartialSuffix is appended to snapshot files during download to mark them as incomplete. +// Once download completes successfully, the file is atomically renamed to remove this suffix. +const PartialSuffix = ".partial" + // NewBufMonReaderHTTPWithSave streams from HTTP URL and optionally saves to disk. // If savePath is non-empty, the data will be written to disk while streaming. +// The file is initially written with a .partial suffix for safety - use FinalizePartialDownload +// after successful processing to rename it to the final path. // Returns: (*bufmonreader, error) func NewBufMonReaderHTTPWithSave(ctx context.Context, url string, savePath string) (*bufmonreader, error) { resp, err := http.Head(url) @@ -85,12 +91,14 @@ func NewBufMonReaderHTTPWithSave(ctx context.Context, url string, savePath strin var closer io.Closer = resp.Body // If savePath is provided, use TeeReader to write to disk while streaming + // Write to .partial file first for crash safety if savePath != "" { + partialPath := savePath + PartialSuffix // Note: Don't log here - caller logs before progress bar starts to avoid breaking cursor positioning - outFile, err := os.Create(savePath) + outFile, err := os.Create(partialPath) if err != nil { resp.Body.Close() - return nil, fmt.Errorf("creating save file %s: %v", savePath, err) + return nil, fmt.Errorf("creating save file %s: %v", partialPath, err) } // TeeReader splits the stream: data goes to both the tar reader AND the file @@ -109,6 +117,45 @@ func NewBufMonReaderHTTPWithSave(ctx context.Context, url string, savePath strin }, nil } +// FinalizePartialDownload atomically renames a completed .partial file to its final name. +// This should be called after successfully processing a snapshot that was saved with +// NewBufMonReaderHTTPWithSave. If savePath is empty, this is a no-op. +func FinalizePartialDownload(savePath string) error { + if savePath == "" { + return nil + } + partialPath := savePath + PartialSuffix + if _, err := os.Stat(partialPath); os.IsNotExist(err) { + // No partial file exists (maybe wasn't saving, or already finalized) + return nil + } + // Sync the file to ensure all data is flushed to disk before rename + f, err := os.Open(partialPath) + if err == nil { + f.Sync() + f.Close() + } + // Atomic rename + if err := os.Rename(partialPath, savePath); err != nil { + return fmt.Errorf("failed to finalize snapshot %s: %w", savePath, err) + } + mlog.Log.Infof("Finalized snapshot download: %s", savePath) + return nil +} + +// CleanupPartialDownload removes a .partial file if it exists. +// This should be called on error/cancellation to clean up incomplete downloads. +func CleanupPartialDownload(savePath string) { + if savePath == "" { + return + } + partialPath := savePath + PartialSuffix + if _, err := os.Stat(partialPath); err == nil { + mlog.Log.Infof("Cleaning up partial download: %s", partialPath) + os.Remove(partialPath) + } +} + // multiCloser closes multiple io.Closers type multiCloser struct { closers []io.Closer diff --git a/pkg/snapshot/build_db.go b/pkg/snapshot/build_db.go index 95708da9..6c05b715 100644 --- a/pkg/snapshot/build_db.go +++ b/pkg/snapshot/build_db.go @@ -56,6 +56,8 @@ func CleanAccountsDbDir(accountsDbDir string) { // maxSnapshots controls how many snapshots to keep: // - 0 = delete all snapshots (stream-only mode, used by new-snapshot bootstrap) // - N > 0 = keep N newest snapshots, delete the rest +// +// This function also always cleans up any .partial files (incomplete downloads from crashes). func CleanSnapshotDownloadDir(downloadPath string, maxSnapshots int) { if downloadPath == "" || maxSnapshots < 0 { return @@ -65,6 +67,24 @@ func CleanSnapshotDownloadDir(downloadPath string, maxSnapshots int) { return // Directory may not exist yet } + // Always clean up partial downloads first (crash recovery) + for _, entry := range entries { + name := entry.Name() + if strings.HasSuffix(name, PartialSuffix) { + path := filepath.Join(downloadPath, name) + mlog.Log.Infof("Cleaning up incomplete download from previous run: %s", name) + if err := os.Remove(path); err != nil { + mlog.Log.Errorf("Failed to remove partial download %s: %v", name, err) + } + } + } + + // Re-read entries after cleaning partials (in case we removed any) + entries, err = os.ReadDir(downloadPath) + if err != nil { + return + } + // Collect snapshot files with their info type snapshotFile struct { name string @@ -337,12 +357,8 @@ func readTar( // cleanupPartial deletes the partial download file if it exists cleanupPartial := func(reason string) { if savePath != "" { - if _, statErr := os.Stat(savePath); statErr == nil { - mlog.Log.Infof("Deleting partial download %s (%s)", savePath, reason) - if rmErr := os.Remove(savePath); rmErr != nil { - mlog.Log.Errorf("Failed to delete partial download %s: %v", savePath, rmErr) - } - } + mlog.Log.Infof("Cleaning up partial download (%s)", reason) + CleanupPartialDownload(savePath) } } @@ -389,6 +405,13 @@ func readTar( } } + // Successfully processed the entire tar - finalize the download by renaming from .partial + if err := FinalizePartialDownload(savePath); err != nil { + mlog.Log.Errorf("Failed to finalize snapshot download: %v", err) + // Don't return error here - the snapshot was processed successfully, + // the finalization failure just means we won't be able to reuse it + } + return nil } diff --git a/pkg/snapshot/build_db_with_incr.go b/pkg/snapshot/build_db_with_incr.go index d21b90f4..fdb8c235 100644 --- a/pkg/snapshot/build_db_with_incr.go +++ b/pkg/snapshot/build_db_with_incr.go @@ -35,6 +35,8 @@ func fmtDuration(d time.Duration) string { } // BuildAccountsDbWithIncr builds the accounts database from full + incremental snapshots. +// If sourceSelector is provided, it will be used to find incremental sources from cached +// Stage 2 results (much faster than full cluster search). Pass nil to use legacy behavior. func BuildAccountsDbWithIncr( ctx context.Context, fullSnapshotFile string, @@ -46,6 +48,7 @@ func BuildAccountsDbWithIncr( blockDir string, snapCfg snapshotdl.SnapshotConfig, dp *progress.DualProgress, + sourceSelector *snapshotdl.SourceSelector, ) (*accountsdb.AccountsDb, *SnapshotManifest, error) { // Clean any leftover artifacts from previous incomplete runs (e.g., Ctrl+C) CleanAccountsDbDir(accountsDbDir) @@ -137,72 +140,213 @@ func BuildAccountsDbWithIncr( // and flush once at the end. This avoids the bug where pools still reference // the old ShardLogger after reinit. - // Get incremental snapshot URL (tries same source first, then searches if needed) + // Get incremental snapshot - try cached sources first if available, then fall back to cluster search mlog.Log.Infof("finding incremental snapshot matching full slot %d...", fullSnapshotSlot) incrSnapshotDlStart := time.Now() - incrementalSnapshotPath, _, incrSlot, err := snapshotdl.GetIncrementalSnapshotURL(fullSnapshotFile, referenceSlot, fullSnapshotSlot, snapCfg) - if err != nil { - klog.Fatalf("error getting incremental snapshot URL: %s", err) + + // Try to get incremental from cached Stage 2 sources (much faster) + var incrSelector *snapshotdl.IncrementalSelector + if sourceSelector != nil { + incrSelector = sourceSelector.GetIncrementalSelector(ctx, fullSnapshotSlot, snapCfg.Verbose) } - mlog.Log.Infof("found incremental snapshot URL in %s: %s", fmtDuration(time.Since(incrSnapshotDlStart)), incrementalSnapshotPath) - - // Retry loop for incremental snapshot download - // If download fails mid-way (not context cancellation), re-discover sources and retry - maxIncrRetries := 3 - for incrAttempt := range maxIncrRetries { - if ctx.Err() != nil { - return nil, nil, fmt.Errorf("attempting to download incremental snapshot: %w", ctx.Err()) - } - if incrAttempt > 0 { - // Re-discover incremental snapshot URL (sources may have changed) - mlog.Log.Infof("Incremental download failed, re-discovering sources (attempt %d/%d)...", incrAttempt+1, maxIncrRetries) - incrementalSnapshotPath, _, incrSlot, err = snapshotdl.GetIncrementalSnapshotURL(fullSnapshotFile, referenceSlot, fullSnapshotSlot, snapCfg) - if err != nil { - mlog.Log.Errorf("Failed to re-discover incremental snapshot: %v", err) + + var incrementalSnapshotPath string + var incrSlot int + var incrSwitchRequested atomic.Bool + + if incrSelector != nil && incrSelector.TotalSources() > 0 { + // Use cached incremental sources with switching support + mlog.Log.Infof("using cached Stage 2 sources for incremental (found %d sources)", incrSelector.TotalSources()) + defer incrSelector.Close() + + for { + if ctx.Err() != nil { + return nil, nil, fmt.Errorf("downloading incremental snapshot: %w", ctx.Err()) + } + + currentIncr := incrSelector.Current() + if currentIncr == nil { + // Exhausted cached sources, fall back to full search + mlog.Log.Infof("exhausted %d cached incremental sources, falling back to cluster search...", incrSelector.TotalSources()) + incrSelector = nil + break + } + + incrementalSnapshotPath = currentIncr.URL + incrSlot = currentIncr.EndSlot + + // Create cancellable context for this download attempt + incrCtx, cancelIncr := context.WithCancel(ctx) + + // Show source info with switching hint + threshNote := "" + if !currentIncr.WithinThresh { + threshNote = " (outside threshold)" + } + mlog.Log.Infof("📸 Incremental source %d/%d: %s (slot %d, %d slots behind tip%s, %.1f MB/s)", + incrSelector.CurrentIndex()+1, + incrSelector.TotalSources(), + currentIncr.NodeIP, + currentIncr.EndSlot, + currentIncr.Age(), + threshNote, + currentIncr.SpeedMBs, + ) + + // Enable 'n' key switching for incremental if we have progress display + incrSwitchRequested.Store(false) + if dp != nil && incrSelector.HasMore() { + sourceInfo := fmt.Sprintf("Incr Source %d/%d: %s (%d slots behind)", + incrSelector.CurrentIndex()+1, + incrSelector.TotalSources(), + currentIncr.NodeIP, + currentIncr.Age(), + ) + dp.EnableSourceSwitching(sourceInfo, func() { + incrSwitchRequested.Store(true) + cancelIncr() + mlog.Log.Infof("User requested incremental source switch...") + }) + } + + // Try to download from this incremental source + incrSnapshotStart := time.Now() + incrementalManifestCopy, manifestErr := UnmarshalManifestFromSnapshot(incrCtx, incrementalSnapshotPath, accountsDbDir) + + if manifestErr == nil { + *incrementalManifest = *incrementalManifestCopy + mlog.Log.Infof("parsed manifest from incrementalFile=%s", incrementalSnapshotPath) + + // Determine save path + var incrSavePath string + if snapCfg.MaxFullSnapshots > 0 && (strings.HasPrefix(incrementalSnapshotPath, "http://") || strings.HasPrefix(incrementalSnapshotPath, "https://")) { + if snapshotDownloadPath != "" { + if err := os.MkdirAll(snapshotDownloadPath, 0o755); err != nil { + cancelIncr() + if dp != nil { + dp.DisableSourceSwitching() + } + return nil, nil, fmt.Errorf("failed to create snapshot download directory %s: %w", snapshotDownloadPath, err) + } + urlParts := strings.Split(incrementalSnapshotPath, "/") + filename := urlParts[len(urlParts)-1] + incrSavePath = filepath.Join(snapshotDownloadPath, filename) + mlog.Log.Infof("Will save incremental snapshot to %s while streaming", incrSavePath) + } + } + + err = readTar(incrCtx, wg, incrementalSnapshotPath, pools.appendVecCopying, readTarOptions{savePath: incrSavePath, isIncremental: true}) + wg.Wait() + + if dp != nil { + dp.DisableSourceSwitching() + } + cancelIncr() + + if err == nil { + mlog.Log.Infof("finished reading %s in %s", incrementalSnapshotPath, fmtDuration(time.Since(start))) + mlog.Log.Infof("done processing incremental snapshot in %s.", fmtDuration(time.Since(incrSnapshotStart))) + break // Success! + } + } else { + err = manifestErr + if dp != nil { + dp.DisableSourceSwitching() + } + cancelIncr() + } + + // Handle switch request or error + if incrSwitchRequested.Load() || (incrCtx.Err() != nil && ctx.Err() == nil) { + nextIncr := incrSelector.Next() + if nextIncr == nil { + mlog.Log.Infof("No more cached incremental sources, falling back to cluster search...") + incrSelector = nil + break + } + mlog.Log.Infof("Switching to incremental source %d/%d: %s (%d slots behind)", + incrSelector.CurrentIndex()+1, + incrSelector.TotalSources(), + nextIncr.NodeIP, + nextIncr.Age(), + ) continue } - mlog.Log.Infof("Found new incremental snapshot URL: %s (slot %d)", incrementalSnapshotPath, incrSlot) + + // Check parent context + if ctx.Err() != nil { + return nil, nil, ctx.Err() + } + + // Error - try next source + mlog.Log.Errorf("Incremental download from %s failed: %v", currentIncr.NodeIP, err) + nextIncr := incrSelector.Next() + if nextIncr == nil { + mlog.Log.Infof("No more cached incremental sources, falling back to cluster search...") + incrSelector = nil + break + } } + } - incrSnapshotStart := time.Now() - incrementalManifestCopy, err := UnmarshalManifestFromSnapshot(ctx, incrementalSnapshotPath, accountsDbDir) + // Fall back to full cluster search if needed (no cached sources or all exhausted) + if incrSelector == nil { + incrementalSnapshotPath, _, incrSlot, err = snapshotdl.GetIncrementalSnapshotURL(fullSnapshotFile, referenceSlot, fullSnapshotSlot, snapCfg) if err != nil { - mlog.Log.Errorf("reading incremental snapshot manifest: %v", err) - continue + klog.Fatalf("error getting incremental snapshot URL: %s", err) } - // Copy the manifest so the worker pool's pointer has the value. - *incrementalManifest = *incrementalManifestCopy - mlog.Log.Infof("parsed manifest from incrementalFile=%s", incrementalSnapshotPath) - - // Determine save path for incremental snapshot if streaming from HTTP - var incrSavePath string - if snapCfg.MaxFullSnapshots > 0 && (strings.HasPrefix(incrementalSnapshotPath, "http://") || strings.HasPrefix(incrementalSnapshotPath, "https://")) { - if snapshotDownloadPath != "" { - // Ensure snapshot download directory exists (may not exist if full was local) - if err := os.MkdirAll(snapshotDownloadPath, 0o755); err != nil { - return nil, nil, fmt.Errorf("failed to create snapshot download directory %s: %w", snapshotDownloadPath, err) + mlog.Log.Infof("found incremental snapshot URL in %s: %s", fmtDuration(time.Since(incrSnapshotDlStart)), incrementalSnapshotPath) + + // Retry loop for incremental snapshot download (legacy behavior) + maxIncrRetries := 3 + for incrAttempt := range maxIncrRetries { + if ctx.Err() != nil { + return nil, nil, fmt.Errorf("attempting to download incremental snapshot: %w", ctx.Err()) + } + if incrAttempt > 0 { + mlog.Log.Infof("Incremental download failed, re-discovering sources (attempt %d/%d)...", incrAttempt+1, maxIncrRetries) + incrementalSnapshotPath, _, incrSlot, err = snapshotdl.GetIncrementalSnapshotURL(fullSnapshotFile, referenceSlot, fullSnapshotSlot, snapCfg) + if err != nil { + mlog.Log.Errorf("Failed to re-discover incremental snapshot: %v", err) + continue } - // Extract filename from URL and create save path - urlParts := strings.Split(incrementalSnapshotPath, "/") - filename := urlParts[len(urlParts)-1] - incrSavePath = filepath.Join(snapshotDownloadPath, filename) - mlog.Log.Infof("Will save incremental snapshot to %s while streaming", incrSavePath) + mlog.Log.Infof("Found new incremental snapshot URL: %s (slot %d)", incrementalSnapshotPath, incrSlot) } - } - err = readTar(ctx, wg, incrementalSnapshotPath, pools.appendVecCopying, readTarOptions{savePath: incrSavePath, isIncremental: true}) - wg.Wait() - mlog.Log.Infof("finished reading %s in %s", incrementalSnapshotPath, fmtDuration(time.Since(start))) - mlog.Log.Infof("done processing incremental snapshot in %s.", fmtDuration(time.Since(incrSnapshotStart))) - // Check if we should retry - if err == nil { - break // Success + incrSnapshotStart := time.Now() + incrementalManifestCopy, err := UnmarshalManifestFromSnapshot(ctx, incrementalSnapshotPath, accountsDbDir) + if err != nil { + mlog.Log.Errorf("reading incremental snapshot manifest: %v", err) + continue + } + *incrementalManifest = *incrementalManifestCopy + mlog.Log.Infof("parsed manifest from incrementalFile=%s", incrementalSnapshotPath) + + var incrSavePath string + if snapCfg.MaxFullSnapshots > 0 && (strings.HasPrefix(incrementalSnapshotPath, "http://") || strings.HasPrefix(incrementalSnapshotPath, "https://")) { + if snapshotDownloadPath != "" { + if err := os.MkdirAll(snapshotDownloadPath, 0o755); err != nil { + return nil, nil, fmt.Errorf("failed to create snapshot download directory %s: %w", snapshotDownloadPath, err) + } + urlParts := strings.Split(incrementalSnapshotPath, "/") + filename := urlParts[len(urlParts)-1] + incrSavePath = filepath.Join(snapshotDownloadPath, filename) + mlog.Log.Infof("Will save incremental snapshot to %s while streaming", incrSavePath) + } + } + + err = readTar(ctx, wg, incrementalSnapshotPath, pools.appendVecCopying, readTarOptions{savePath: incrSavePath, isIncremental: true}) + wg.Wait() + mlog.Log.Infof("finished reading %s in %s", incrementalSnapshotPath, fmtDuration(time.Since(start))) + mlog.Log.Infof("done processing incremental snapshot in %s.", fmtDuration(time.Since(incrSnapshotStart))) + if err == nil { + break + } + mlog.Log.Errorf("Incremental download failed: %v", err) + } + if err != nil { + return nil, nil, err } - // Download failed mid-way, will retry with re-discovery - mlog.Log.Errorf("Incremental download failed: %v", err) - } - if err != nil { - return nil, nil, err } // Show indexing progress for shard flush diff --git a/pkg/snapshotdl/snapshotdl.go b/pkg/snapshotdl/snapshotdl.go index b8bd1f1f..8ac610a1 100644 --- a/pkg/snapshotdl/snapshotdl.go +++ b/pkg/snapshotdl/snapshotdl.go @@ -84,7 +84,8 @@ type SnapshotConfig struct { MinIncrementalSpeedMBs float64 // Minimum speed for incremental sources (MB/s, 0 = no minimum) // Logging - LogDir string // Directory for snapshot finder logs (default: /mnt/mithril-logs/snapshot-finder) + LogDir string // Directory for snapshot finder logs (default: /mnt/mithril-logs) + RunID string // Mithril run ID for log file naming } // DefaultSnapshotConfig returns production-ready defaults matching solana-snapshot-finder-go @@ -137,7 +138,8 @@ func DefaultSnapshotConfig() SnapshotConfig { MinIncrementalSpeedMBs: 2.0, // Minimum 2 MB/s for incrementals (~8min for 1GB) // Logging - LogDir: "/mnt/mithril-logs/snapshot-finder", + LogDir: "/mnt/mithril-logs", + RunID: "", // Set by caller } } diff --git a/pkg/snapshotdl/source_selector.go b/pkg/snapshotdl/source_selector.go new file mode 100644 index 00000000..5260b413 --- /dev/null +++ b/pkg/snapshotdl/source_selector.go @@ -0,0 +1,748 @@ +package snapshotdl + +import ( + "bufio" + "context" + "fmt" + "os" + "path/filepath" + "sort" + "strings" + "sync" + "time" + + "github.com/Overclock-Validator/mithril/pkg/mlog" + "github.com/Overclock-Validator/solana-snapshot-finder-go/pkg/config" + "github.com/Overclock-Validator/solana-snapshot-finder-go/pkg/rpc" + "github.com/Overclock-Validator/solana-snapshot-finder-go/pkg/snapshot" +) + +// RankedSource represents a snapshot source with its speed test results +type RankedSource struct { + URL string // HTTP URL for streaming + Slot int // Full snapshot slot + ReferenceSlot int // Current network slot (for calculating age) + NodeIP string // IP:port of the node + NodeRPC string // Full RPC URL (http://ip:port) + Version string // Solana version of the node + SpeedMBs float64 // Download speed in MB/s from Stage 2 testing + RTTMs int // Round-trip time in milliseconds + Rank int // Rank in the sorted list (1-based) +} + +// Age returns how many slots behind the snapshot is from the current tip +func (s *RankedSource) Age() int { + return s.ReferenceSlot - s.Slot +} + +// SourceSelector tracks alternative snapshot sources and allows switching between them +type SourceSelector struct { + sources []RankedSource + currentIndex int + mu sync.Mutex + switchCh chan struct{} // signal to switch sources + closed bool + SearchTime time.Duration // How long Stage 2 took + + // Cached Stage 1 results for incremental source lookup + // Stage 1 runs a short speed test on all nodes, so we have S1.MedianMBs for sorting. + // Stage 2 then selects the top nodes from Stage 1 for more accurate testing. + // For incremental lookup, we use all Stage 1 nodes (not just Stage 2 winners) + // to have a larger pool with matching base slots. + allStage1Nodes []rpc.RankedNode // All nodes that passed Stage 1 (with speed data) + referenceSlot int // Network slot at time of search + incrThreshold int // Incremental freshness threshold in slots +} + +// IncrementalSource represents an incremental snapshot source +type IncrementalSource struct { + URL string // HTTP URL for streaming + BaseSlot int // Base slot (must match full snapshot) + EndSlot int // End slot of incremental + ReferenceSlot int // Network slot at time of search + NodeIP string // IP:port of the node + NodeRPC string // Full RPC URL (http://ip:port) + Version string // Solana version of the node + SpeedMBs float64 // Download speed in MB/s from Stage 1/2 testing + RTTMs int // Round-trip time in milliseconds + Rank int // Rank in the sorted list (1-based) + WithinThresh bool // True if within incremental threshold +} + +// Age returns how many slots behind the incremental is from the tip +func (s *IncrementalSource) Age() int { + return s.ReferenceSlot - s.EndSlot +} + +// IncrementalSelector tracks incremental snapshot sources and allows switching +type IncrementalSelector struct { + sources []IncrementalSource + currentIndex int + mu sync.Mutex + switchCh chan struct{} + closed bool + baseSlot int // The full snapshot slot these incrementals are based on +} + +// NewSourceSelector creates a new source selector with the given ranked sources +func NewSourceSelector(sources []RankedSource) *SourceSelector { + return &SourceSelector{ + sources: sources, + currentIndex: 0, + switchCh: make(chan struct{}, 1), // buffered so RequestSwitch doesn't block + } +} + +// Current returns the currently selected source, or nil if exhausted +func (s *SourceSelector) Current() *RankedSource { + s.mu.Lock() + defer s.mu.Unlock() + + if s.currentIndex >= len(s.sources) { + return nil + } + return &s.sources[s.currentIndex] +} + +// CurrentIndex returns the current source index (0-based) +func (s *SourceSelector) CurrentIndex() int { + s.mu.Lock() + defer s.mu.Unlock() + return s.currentIndex +} + +// TotalSources returns the total number of available sources +func (s *SourceSelector) TotalSources() int { + return len(s.sources) +} + +// Next advances to the next source and returns it, or nil if no more sources +func (s *SourceSelector) Next() *RankedSource { + s.mu.Lock() + defer s.mu.Unlock() + + s.currentIndex++ + if s.currentIndex >= len(s.sources) { + return nil + } + return &s.sources[s.currentIndex] +} + +// HasMore returns true if there are more sources to try +func (s *SourceSelector) HasMore() bool { + s.mu.Lock() + defer s.mu.Unlock() + return s.currentIndex+1 < len(s.sources) +} + +// SwitchCh returns a channel that signals when the user requests a source switch +func (s *SourceSelector) SwitchCh() <-chan struct{} { + return s.switchCh +} + +// RequestSwitch signals that the user wants to switch to the next source +func (s *SourceSelector) RequestSwitch() { + s.mu.Lock() + if s.closed { + s.mu.Unlock() + return + } + s.mu.Unlock() + + // Non-blocking send (buffered channel) + select { + case s.switchCh <- struct{}{}: + default: + // Already has a pending switch request + } +} + +// Close closes the switch channel (call when download completes) +func (s *SourceSelector) Close() { + s.mu.Lock() + defer s.mu.Unlock() + if !s.closed { + s.closed = true + close(s.switchCh) + } +} + +// GetIncrementalSelector returns an IncrementalSelector for the given base slot +// using cached Stage 1 results. Returns nil if no matching incrementals found. +// This is much faster than doing a full cluster search. +// +// Stage 1 runs a short speed test on all nodes, so we have actual download speed +// data (S1.MedianMBs) for sorting, not just latency estimates. +func (s *SourceSelector) GetIncrementalSelector(ctx context.Context, baseSlot int, verbose bool) *IncrementalSelector { + if len(s.allStage1Nodes) == 0 { + return nil + } + + // Filter for nodes with matching incremental base slot + var matchingNodes []rpc.RankedNode + for _, node := range s.allStage1Nodes { + if node.Result.HasInc && node.Result.IncBase == int64(baseSlot) { + matchingNodes = append(matchingNodes, node) + } + } + + if len(matchingNodes) == 0 { + mlog.Log.Infof("No cached nodes have incremental with base slot %d", baseSlot) + return nil + } + + mlog.Log.Infof("Found %d Stage 1 nodes with incremental base slot %d", len(matchingNodes), baseSlot) + + // Sort by: within threshold first, then by end slot (freshest), then by Stage 1 speed + for i := 0; i < len(matchingNodes)-1; i++ { + for j := i + 1; j < len(matchingNodes); j++ { + iAge := s.referenceSlot - int(matchingNodes[i].Result.IncSlot) + jAge := s.referenceSlot - int(matchingNodes[j].Result.IncSlot) + iWithin := s.incrThreshold <= 0 || iAge <= s.incrThreshold + jWithin := s.incrThreshold <= 0 || jAge <= s.incrThreshold + + // Within threshold sorts first + if jWithin && !iWithin { + matchingNodes[i], matchingNodes[j] = matchingNodes[j], matchingNodes[i] + } else if jWithin == iWithin { + // Same threshold status: sort by end slot (higher = fresher) + if matchingNodes[j].Result.IncSlot > matchingNodes[i].Result.IncSlot { + matchingNodes[i], matchingNodes[j] = matchingNodes[j], matchingNodes[i] + } else if matchingNodes[j].Result.IncSlot == matchingNodes[i].Result.IncSlot { + // Same end slot: sort by Stage 1 speed (higher = faster) + if matchingNodes[j].S1.MedianMBs > matchingNodes[i].S1.MedianMBs { + matchingNodes[i], matchingNodes[j] = matchingNodes[j], matchingNodes[i] + } + } + } + } + } + + // Convert to IncrementalSource list (fetch URLs) + var sources []IncrementalSource + maxSources := 8 + if len(matchingNodes) < maxSources { + maxSources = len(matchingNodes) + } + + for i := 0; i < maxSources; i++ { + node := matchingNodes[i] + + // Get incremental URL from this node + urlInfo, err := snapshot.GetSnapshotURL(ctx, node.Result.RPC, "incremental") + if err != nil || urlInfo == nil || urlInfo.BaseSlot != baseSlot { + if verbose { + mlog.Log.Infof("Skipping %s: failed to get incremental URL or base mismatch: %v", node.Result.RPC, err) + } + continue + } + + age := s.referenceSlot - urlInfo.Slot + withinThresh := s.incrThreshold <= 0 || age <= s.incrThreshold + + // Extract IP from RPC URL + nodeIP := node.Result.RPC + if idx := strings.Index(nodeIP, "://"); idx != -1 { + nodeIP = nodeIP[idx+3:] + } + + sources = append(sources, IncrementalSource{ + URL: urlInfo.URL, + BaseSlot: urlInfo.BaseSlot, + EndSlot: urlInfo.Slot, + ReferenceSlot: s.referenceSlot, + NodeIP: nodeIP, + NodeRPC: node.Result.RPC, + Version: node.Result.Version, + SpeedMBs: node.S1.MedianMBs, // Use actual Stage 1 speed test result + RTTMs: int(node.Result.Latency), + Rank: len(sources) + 1, + WithinThresh: withinThresh, + }) + } + + if len(sources) == 0 { + return nil + } + + mlog.Log.Infof("Found %d cached incremental sources for base slot %d", len(sources), baseSlot) + + return &IncrementalSelector{ + sources: sources, + currentIndex: 0, + switchCh: make(chan struct{}, 1), + baseSlot: baseSlot, + } +} + +// NewIncrementalSelector creates a new incremental selector +func NewIncrementalSelector(sources []IncrementalSource, baseSlot int) *IncrementalSelector { + return &IncrementalSelector{ + sources: sources, + currentIndex: 0, + switchCh: make(chan struct{}, 1), + baseSlot: baseSlot, + } +} + +// Current returns the currently selected incremental source, or nil if exhausted +func (s *IncrementalSelector) Current() *IncrementalSource { + s.mu.Lock() + defer s.mu.Unlock() + + if s.currentIndex >= len(s.sources) { + return nil + } + return &s.sources[s.currentIndex] +} + +// CurrentIndex returns the current source index (0-based) +func (s *IncrementalSelector) CurrentIndex() int { + s.mu.Lock() + defer s.mu.Unlock() + return s.currentIndex +} + +// TotalSources returns the total number of available sources +func (s *IncrementalSelector) TotalSources() int { + return len(s.sources) +} + +// BaseSlot returns the full snapshot slot these incrementals are based on +func (s *IncrementalSelector) BaseSlot() int { + return s.baseSlot +} + +// Next advances to the next source and returns it, or nil if no more sources +func (s *IncrementalSelector) Next() *IncrementalSource { + s.mu.Lock() + defer s.mu.Unlock() + + s.currentIndex++ + if s.currentIndex >= len(s.sources) { + return nil + } + return &s.sources[s.currentIndex] +} + +// HasMore returns true if there are more sources to try +func (s *IncrementalSelector) HasMore() bool { + s.mu.Lock() + defer s.mu.Unlock() + return s.currentIndex+1 < len(s.sources) +} + +// SwitchCh returns a channel that signals when the user requests a source switch +func (s *IncrementalSelector) SwitchCh() <-chan struct{} { + return s.switchCh +} + +// RequestSwitch signals that the user wants to switch to the next source +func (s *IncrementalSelector) RequestSwitch() { + s.mu.Lock() + if s.closed { + s.mu.Unlock() + return + } + s.mu.Unlock() + + select { + case s.switchCh <- struct{}{}: + default: + } +} + +// Close closes the switch channel +func (s *IncrementalSelector) Close() { + s.mu.Lock() + defer s.mu.Unlock() + if !s.closed { + s.closed = true + close(s.switchCh) + } +} + +// rankedNodesToSources converts rpc.RankedNode list to RankedSource list. +// This fetches snapshot URLs from each node (the quick HTTP call, not the download). +func rankedNodesToSources(ctx context.Context, rankedNodes []rpc.RankedNode, referenceSlot int, maxSources int, verbose bool) []RankedSource { + var sources []RankedSource + + for i := 0; i < maxSources && i < len(rankedNodes); i++ { + rn := rankedNodes[i] + + // Get snapshot URL from this node (quick metadata fetch) + urlInfo, err := snapshot.GetSnapshotURL(ctx, rn.Result.RPC, "full") + if err != nil || urlInfo == nil { + if verbose { + mlog.Log.Infof("Skipping %s: failed to get snapshot URL: %v", rn.Result.RPC, err) + } + continue + } + + // Use Stage 2 min speed if available, otherwise fall back to Stage 1 median + speed := rn.S1.MedianMBs + if rn.S2.MinMBs > 0 { + speed = rn.S2.MinMBs + } + + // Extract IP from RPC URL + nodeIP := rn.Result.RPC + if idx := strings.Index(nodeIP, "://"); idx != -1 { + nodeIP = nodeIP[idx+3:] + } + + sources = append(sources, RankedSource{ + URL: urlInfo.URL, + Slot: urlInfo.Slot, + ReferenceSlot: referenceSlot, + NodeIP: nodeIP, + NodeRPC: rn.Result.RPC, + Version: rn.Result.Version, + SpeedMBs: speed, + RTTMs: int(rn.Result.Latency), + Rank: len(sources) + 1, + }) + } + + return sources +} + +// GetRankedSnapshotSources discovers and ranks all available snapshot sources. +// Returns a SourceSelector that can be used to switch between sources during download. +// This runs Stage 1 + Stage 2 testing and prints the candidates table. +func GetRankedSnapshotSources(ctx context.Context, snapCfg SnapshotConfig) (*SourceSelector, error) { + searchStart := time.Now() + cfg := snapCfg.toInternalConfig("") + + // Step 1: Get reference slot from multiple RPCs for reliability + referenceSlot, preferredRPC, err := rpc.GetReferenceSlotFromMultiple(cfg.RPCAddresses) + if err != nil { + return nil, fmt.Errorf("error getting reference slot: %w", err) + } + if snapCfg.Verbose { + mlog.Log.Infof("Reference slot: %d (from %s)", referenceSlot, preferredRPC) + } + + // Step 2: Fetch cluster nodes + nodes := rpc.FetchClusterNodes(cfg, preferredRPC) + if len(nodes) == 0 { + return nil, fmt.Errorf("no rpc nodes available from cluster") + } + + // Step 3: Evaluate nodes with version tracking and statistics + mlog.Log.Infof("Probing %d nodes for snapshot availability...", len(nodes)) + results, stats := rpc.EvaluateNodesWithVersionsAndStats(nodes, cfg, referenceSlot) + + // Step 3.5: Filter to only full snapshots that have matching incrementals somewhere + results, incBaseStats := filterByIncrementalBaseMatch(results) + + // Print Node Discovery Report (before speed testing) + if stats != nil { + stats.PrintNodeDiscoveryReport() + } + + // Print incremental base match stats + if incBaseStats.totalWithFull > 0 { + mlog.Log.Infof("Incremental base matching: %d/%d full snapshots have compatible incrementals (%d unique base slots)", + incBaseStats.afterIncBaseMatch, incBaseStats.totalWithFull, incBaseStats.matchingFullSlots) + if incBaseStats.afterIncBaseMatch < incBaseStats.totalWithFull { + mlog.Log.Infof(" (filtered %d sources with no matching incremental base)", + incBaseStats.totalWithFull-incBaseStats.afterIncBaseMatch) + } + } + + // Step 4: Sort and select best nodes by download speed (Stage 1 + Stage 2) + mlog.Log.Infof("Testing download speeds (Stage 1 + Stage 2)...") + _, rankedNodes, speedStats := rpc.SortBestNodesWithStats(results, cfg, stats, referenceSlot) + if len(rankedNodes) == 0 { + return nil, fmt.Errorf("no suitable nodes found with snapshots (check incremental base matching)") + } + + // Print Stage 2 candidates as a table + maxCandidates := 8 + if len(rankedNodes) < maxCandidates { + maxCandidates = len(rankedNodes) + } + candidates := make([]rpc.RankedNodeInfo, maxCandidates) + for i := 0; i < maxCandidates; i++ { + rn := rankedNodes[i] + candidates[i] = rpc.RankedNodeInfo{ + Rank: i + 1, + RPC: rn.Result.RPC, + Version: rn.Result.Version, + RTTMs: int(rn.Result.Latency), + SpeedS1: rn.S1.MedianMBs, + SpeedS2: rn.S2.MinMBs, + } + } + rpc.PrintStage2CandidatesTable(candidates) + + // Print Filter Pipeline with speed test stats + filterCfg := rpc.FilterConfig{ + MaxRTTMs: cfg.MaxRTTMs, + FullThreshold: cfg.FullThreshold, + IncThreshold: cfg.IncrementalThreshold, + MinVersion: cfg.MinNodeVersion, + AllowedVersions: cfg.AllowedNodeVersions, + } + if stats != nil { + stats.PrintFilterPipeline(filterCfg, speedStats) + } + + // Write detailed speed test log + searchDuration := time.Since(searchStart) + if snapCfg.LogDir != "" { + logPath, err := writeDetailedSpeedTestLog( + snapCfg.LogDir, + snapCfg.RunID, + cfg, + snapCfg, + referenceSlot, + incBaseStats, + rankedNodes, + searchDuration, + ) + if err != nil { + mlog.Log.Infof("Warning: failed to write speed test log: %v", err) + } else if logPath != "" { + mlog.Log.Infof("Detailed speed test log written to: %s", logPath) + } + } + + // Step 5: Build RankedSource list from ranked nodes + // Use all ranked nodes (not limited by MaxSnapshotURLAttempts for switching) + maxSources := len(rankedNodes) + if maxSources > 8 { + maxSources = 8 // Limit to top 8 for URL fetching (matching Stage 2 display) + } + + sources := rankedNodesToSources(ctx, rankedNodes, referenceSlot, maxSources, snapCfg.Verbose) + if len(sources) == 0 { + return nil, fmt.Errorf("failed to get snapshot URL from any ranked node") + } + + mlog.Log.Infof("Found %d ranked snapshot sources for selection", len(sources)) + + selector := NewSourceSelector(sources) + selector.SearchTime = time.Since(searchStart) + // Cache all Stage 1 nodes for incremental source lookup + // rankedNodes contains ALL nodes that passed Stage 1 triage (with S1.MedianMBs speed data) + // Stage 2 just selects the top ones for more accurate testing, but we keep the full list + // so we have more candidates with matching incremental base slots + selector.allStage1Nodes = rankedNodes + selector.referenceSlot = referenceSlot + selector.incrThreshold = cfg.IncrementalThreshold + return selector, nil +} + +// writeDetailedSpeedTestLog writes a comprehensive log file with all filtering parameters and results. +// The log is written to logDir/snapshot-search-{runID}-{timestamp}.log +func writeDetailedSpeedTestLog( + logDir string, + runID string, + cfg config.Config, + snapCfg SnapshotConfig, + referenceSlot int, + incBaseStats incBaseMatchStats, + rankedNodes []rpc.RankedNode, + searchDuration time.Duration, +) (string, error) { + if logDir == "" { + return "", nil // Logging disabled + } + + // Ensure directory exists + if err := os.MkdirAll(logDir, 0o755); err != nil { + return "", fmt.Errorf("failed to create log directory: %w", err) + } + + // Generate filename with timestamp and run ID + timestamp := time.Now().UTC().Format("2006-01-02_15-04-05") + var filename string + if runID != "" { + filename = fmt.Sprintf("snapshot-search-%s-%s.log", runID, timestamp) + } else { + filename = fmt.Sprintf("snapshot-search-%s.log", timestamp) + } + logPath := filepath.Join(logDir, filename) + + // Create log file + file, err := os.Create(logPath) + if err != nil { + return "", fmt.Errorf("failed to create log file: %w", err) + } + defer file.Close() + + w := bufio.NewWriter(file) + defer w.Flush() + + // Header + fmt.Fprintf(w, "================================================================================\n") + fmt.Fprintf(w, " MITHRIL SNAPSHOT SEARCH DETAILED LOG\n") + fmt.Fprintf(w, "================================================================================\n") + fmt.Fprintf(w, "Timestamp: %s UTC\n", time.Now().UTC().Format("2006-01-02 15:04:05")) + if runID != "" { + fmt.Fprintf(w, "Run ID: %s\n", runID) + } + fmt.Fprintf(w, "Reference Slot: %d\n", referenceSlot) + fmt.Fprintf(w, "Search Duration: %s\n", searchDuration.Round(time.Millisecond)) + fmt.Fprintf(w, "\n") + + // Configuration Parameters + fmt.Fprintf(w, "================================================================================\n") + fmt.Fprintf(w, " CONFIGURATION PARAMETERS\n") + fmt.Fprintf(w, "================================================================================\n") + fmt.Fprintf(w, "\n") + + fmt.Fprintf(w, "--- Node Filtering ---\n") + fmt.Fprintf(w, " Max RTT: %d ms\n", cfg.MaxRTTMs) + fmt.Fprintf(w, " TCP Timeout: %d ms\n", cfg.TCPTimeoutMs) + fmt.Fprintf(w, " Min Node Version: %s\n", cfg.MinNodeVersion) + if len(cfg.AllowedNodeVersions) > 0 { + fmt.Fprintf(w, " Allowed Versions: %v\n", cfg.AllowedNodeVersions) + } else { + fmt.Fprintf(w, " Allowed Versions: (all >= min version)\n") + } + fmt.Fprintf(w, "\n") + + fmt.Fprintf(w, "--- Snapshot Thresholds ---\n") + fmt.Fprintf(w, " Full Threshold: %d slots (~%.1f min)\n", cfg.FullThreshold, float64(cfg.FullThreshold)*0.4/60) + fmt.Fprintf(w, " Incremental Thresh: %d slots (~%.1f sec)\n", cfg.IncrementalThreshold, float64(cfg.IncrementalThreshold)*0.4) + fmt.Fprintf(w, " Safety Margin: %d slots\n", cfg.SafetyMarginSlots) + fmt.Fprintf(w, "\n") + + fmt.Fprintf(w, "--- Stage 1 (Fast Triage) ---\n") + fmt.Fprintf(w, " Warmup: %d KiB\n", cfg.Stage1WarmKiB) + fmt.Fprintf(w, " Window Size: %d KiB\n", cfg.Stage1WindowKiB) + fmt.Fprintf(w, " Windows: %d (total: %d KiB)\n", cfg.Stage1Windows, cfg.Stage1WindowKiB*int64(cfg.Stage1Windows)) + fmt.Fprintf(w, " Timeout: %d ms\n", cfg.Stage1TimeoutMS) + fmt.Fprintf(w, " Concurrency: %d (0=auto)\n", cfg.Stage1Concurrency) + fmt.Fprintf(w, "\n") + + fmt.Fprintf(w, "--- Stage 2 (Sustained Test) ---\n") + fmt.Fprintf(w, " Top K Candidates: %d\n", cfg.Stage2TopK) + fmt.Fprintf(w, " Warmup Duration: %d sec\n", cfg.Stage2WarmSec) + fmt.Fprintf(w, " Measure Duration: %d sec\n", cfg.Stage2MeasureSec) + fmt.Fprintf(w, " Min Ratio: %.0f%% (collapse threshold)\n", cfg.Stage2MinRatio*100) + if cfg.Stage2MinAbsMBs > 0 { + fmt.Fprintf(w, " Min Absolute Speed: %.1f MB/s\n", cfg.Stage2MinAbsMBs) + } else { + fmt.Fprintf(w, " Min Absolute Speed: (disabled)\n") + } + fmt.Fprintf(w, "\n") + + fmt.Fprintf(w, "--- Other Settings ---\n") + fmt.Fprintf(w, " Worker Count: %d\n", cfg.WorkerCount) + fmt.Fprintf(w, " Max Snapshot Attempts: %d\n", snapCfg.MaxSnapshotURLAttempts) + fmt.Fprintf(w, " Min Incr Speed: %.1f MB/s\n", snapCfg.MinIncrementalSpeedMBs) + fmt.Fprintf(w, "\n") + + // Incremental Base Matching Statistics + fmt.Fprintf(w, "================================================================================\n") + fmt.Fprintf(w, " INCREMENTAL BASE MATCHING STATISTICS\n") + fmt.Fprintf(w, "================================================================================\n") + fmt.Fprintf(w, "\n") + + fmt.Fprintf(w, "--- Incremental Base Matching ---\n") + fmt.Fprintf(w, " Nodes with full snapshot: %d\n", incBaseStats.totalWithFull) + fmt.Fprintf(w, " Nodes with any incremental: %d\n", incBaseStats.totalWithInc) + fmt.Fprintf(w, " Unique full snapshot slots: %d\n", incBaseStats.uniqueFullSlots) + fmt.Fprintf(w, " Unique incremental bases: %d\n", incBaseStats.uniqueIncBases) + fmt.Fprintf(w, " Full slots with matching inc: %d\n", incBaseStats.matchingFullSlots) + fmt.Fprintf(w, " Nodes after base matching: %d\n", incBaseStats.afterIncBaseMatch) + fmt.Fprintf(w, "\n") + + // Full Snapshot Slots Distribution + fmt.Fprintf(w, "================================================================================\n") + fmt.Fprintf(w, " FULL SNAPSHOT SLOTS DISTRIBUTION\n") + fmt.Fprintf(w, "================================================================================\n") + fmt.Fprintf(w, "\n") + + // Count nodes by full snapshot slot + slotCounts := make(map[int64]int) + for _, node := range rankedNodes { + if node.Result.FullSlot > 0 { + slotCounts[node.Result.FullSlot]++ + } + } + + // Sort slots by count (descending), then by slot (descending) + type slotCount struct { + slot int64 + count int + } + var sortedSlots []slotCount + for slot, count := range slotCounts { + sortedSlots = append(sortedSlots, slotCount{slot, count}) + } + sort.Slice(sortedSlots, func(i, j int) bool { + if sortedSlots[i].count != sortedSlots[j].count { + return sortedSlots[i].count > sortedSlots[j].count + } + return sortedSlots[i].slot > sortedSlots[j].slot + }) + + fmt.Fprintf(w, "Slot Age (slots) Nodes Notes\n") + fmt.Fprintf(w, "-----------------------------------------------------\n") + for _, sc := range sortedSlots { + age := referenceSlot - int(sc.slot) + notes := "" + if age > cfg.FullThreshold { + notes = "(outside threshold)" + } + fmt.Fprintf(w, "%-16d %10d %5d %s\n", sc.slot, age, sc.count, notes) + } + fmt.Fprintf(w, "\n") + + // Stage 2 Ranked Nodes + fmt.Fprintf(w, "================================================================================\n") + fmt.Fprintf(w, " STAGE 2 RANKED NODES (TOP %d)\n", len(rankedNodes)) + fmt.Fprintf(w, "================================================================================\n") + fmt.Fprintf(w, "\n") + + fmt.Fprintf(w, "Rank RPC Version RTT S1 MB/s S2 MB/s Full Slot\n") + fmt.Fprintf(w, "--------------------------------------------------------------------------------------------\n") + + maxShow := 50 + if len(rankedNodes) < maxShow { + maxShow = len(rankedNodes) + } + for i := 0; i < maxShow; i++ { + node := rankedNodes[i] + // Extract IP from RPC URL + nodeIP := node.Result.RPC + if idx := strings.Index(nodeIP, "://"); idx != -1 { + nodeIP = nodeIP[idx+3:] + } + // Truncate if too long + if len(nodeIP) > 30 { + nodeIP = nodeIP[:27] + "..." + } + + s2Speed := "-" + if node.S2.MinMBs > 0 { + s2Speed = fmt.Sprintf("%.1f", node.S2.MinMBs) + } + + fmt.Fprintf(w, "%4d %-30s %-10s %4dms %7.1f %7s %d\n", + i+1, + nodeIP, + node.Result.Version, + int(node.Result.Latency), + node.S1.MedianMBs, + s2Speed, + node.Result.FullSlot, + ) + } + if len(rankedNodes) > maxShow { + fmt.Fprintf(w, "... and %d more nodes\n", len(rankedNodes)-maxShow) + } + fmt.Fprintf(w, "\n") + + // Footer + fmt.Fprintf(w, "================================================================================\n") + fmt.Fprintf(w, " END OF LOG\n") + fmt.Fprintf(w, "================================================================================\n") + + return logPath, nil +}