Conversation

@liqiangxl (Collaborator) commented Jan 16, 2026

Explicitly track the unpadded compute thread dimensions; the padded value can then be derived from them, and a padded value supplied in cparams can be validated against the derived one.

github-actions bot commented Jan 16, 2026

Description

  • Rename getThreadCountInDim to getStaticComputeThreadsInDim for clarity

  • Add compute_bdimx/y/z fields to CompileParams for unpadded thread counts

  • Update normalization scheduler to track both padded and unpadded thread dimensions

  • Use unpadded compute threads for register sharing calculations
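The relationship between the padded launch dimensions and the unpadded compute dimensions described above can be sketched as follows. This is a stand-alone analogue for illustration only, assuming the 128-thread async-warp padding mentioned later in the summary; `CompileParamsSketch` and `validatePaddedBdimx` are hypothetical names, not nvFuser's actual API:

```cpp
#include <cstdint>
#include <optional>

// Assumed padding for the async (load) warps in a warp-specialized kernel.
constexpr int64_t kWarpSpecializationPaddedThreads = 128;

// Stand-alone analogue of the two kinds of fields the PR adds: padded
// dimensions for the kernel launch vs. unpadded dimensions for computation.
struct CompileParamsSketch {
  std::optional<int64_t> bdimx;          // padded: used at launch
  std::optional<int64_t> compute_bdimx;  // unpadded: actual compute threads
};

// With the unpadded value tracked explicitly, the padded value can be derived
// from it, and a caller-supplied padded value cross-checked.
bool validatePaddedBdimx(const CompileParamsSketch& p) {
  if (!p.compute_bdimx.has_value() || !p.bdimx.has_value()) {
    return true;  // nothing to cross-check
  }
  return *p.bdimx == *p.compute_bdimx + kWarpSpecializationPaddedThreads;
}
```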

Changes walkthrough

Relevant files (Enhancement)

parallel_dimension_map.cpp: Rename function and use compute thread fields
(csrc/parallel_dimension_map.cpp)

  • Rename function getThreadCountInDim to getStaticComputeThreadsInDim
  • Change from using cparams.bdimx/y/z to cparams.compute_bdimx/y/z
  • Update both adjustMappingsForWarpPadding and
    adjustMappingsForWarpSpecialization callers
  • +9/-9

normalization_inner_tma.cpp: Track unpadded compute thread dimensions
(csrc/scheduler/normalization_inner_tma.cpp)

  • Add compute_bdimx/y/z variables to track unpadded thread counts
  • Set compute thread dimensions before warp specialization padding
  • Store compute thread counts in params->cparams.compute_bdimx/y/z
  • Use unpadded compute threads for register sharing calculation
  • +7/-2

parallel_dimension_map.h: Update function declaration
(csrc/parallel_dimension_map.h)

  • Update function declaration to match new name getStaticComputeThreadsInDim
  • +1/-1

executor_params.h: Add compute thread fields to CompileParams
(csrc/runtime/executor_params.h)

  • Add compute_bdimx/y/z optional fields to CompileParams struct
  • Update equality operator to include new compute thread fields
  • Add documentation comment for compute thread fields
  • +9/-1

PR Reviewer Guide

Here are some key observations to aid the review process:

🧪 No relevant tests
⚡ Recommended focus areas for review

Missing Error Handling

The new getStaticComputeThreadsInDim function doesn't have the same error handling as getThreadCountInDim: it lacks the NVF_ERROR check on GpuLower::hasCurrent() and doesn't handle the case where the compile-time CTA shape is not known. This could lead to runtime crashes or undefined behavior.

    int64_t ParallelDimensionMap::getStaticComputeThreadsInDim(ParallelType pt) {
      if (!dim_map_.contains(pt)) {
        return 1;
      }
      if (dim_map_.at(pt)->isConstScalar()) {
        return dim_map_.at(pt)->value().as<int64_t>();
      }
  // If dimension is dynamic but we have compile-time CTA shape available,
  // ... (excerpt truncated)
}
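The reviewer's point can be illustrated with a minimal, self-contained sketch of the missing guard. Here `g_has_current_lowering` and `std::logic_error` stand in for `GpuLower::hasCurrent()` and the NVF_ERROR macro; the map type is also a stand-in, so none of this is nvFuser's actual code:

```cpp
#include <cstdint>
#include <stdexcept>
#include <unordered_map>

// Hypothetical stand-in for GpuLower::hasCurrent(): whether a lowering
// context is active and the parallel dimension map may be queried.
static bool g_has_current_lowering = false;

// Guarded lookup: fail loudly (std::logic_error standing in for NVF_ERROR)
// rather than silently returning a bogus thread count.
int64_t staticComputeThreadsInDimGuarded(
    const std::unordered_map<int, int64_t>& dim_map, int pt) {
  if (!g_has_current_lowering) {
    throw std::logic_error("no active GpuLower; cannot query thread count");
  }
  auto it = dim_map.find(pt);
  if (it == dim_map.end()) {
    return 1;  // dimension not parallelized: extent 1
  }
  return it->second;
}
```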
Missing Documentation

The new compute_bdimx/y/z fields are added without documentation explaining their purpose and relationship to the existing bdimx/y/z fields. This could lead to confusion about when to use which field.

    // Threads used for computation, excluding warp specialization padding
    std::optional<int64_t> compute_bdimx = std::nullopt;
    std::optional<int64_t> compute_bdimy = std::nullopt;
    std::optional<int64_t> compute_bdimz = std::nullopt;
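The walkthrough also notes that the equality operator was extended to cover the new fields. A stand-alone sketch of why that matters: two parameter sets differing only in compute thread counts must not compare equal (for example, when compile parameters serve as a cache key). `Params` below is a hypothetical analogue, not the real CompileParams struct:

```cpp
#include <cstdint>
#include <optional>

// Minimal analogue of CompileParams with one padded and one unpadded field.
// Omitting compute_bdimx from operator== would make two configurations that
// share only the padded launch size wrongly compare equal.
struct Params {
  std::optional<int64_t> bdimx;          // padded launch dimension
  std::optional<int64_t> compute_bdimx;  // unpadded compute dimension

  bool operator==(const Params& o) const {
    return bdimx == o.bdimx && compute_bdimx == o.compute_bdimx;
  }
};
```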

Test failures

• (Medium, 1) Scalar numerical mismatch in thunder GPT nvFuser CUDA test

  Test name (H100): thunder.tests.test_networks.test_nanogpt_complete_autograd_nvfuser_cuda_thunder.dtypes.float32

• (Medium, 1) nvFuser PingPongCircularBuffering large numerical mismatch on H100

  Test name (H100): PingPongCircularBuffering.StageSlicePositionComputeAt/stage_slice_position_4

@liqiangxl (Collaborator, Author) commented:

!test

greptile-apps bot commented Jan 16, 2026

Greptile Summary

This PR separates the storage of padded and unpadded thread dimensions in compile parameters to fix register sharing calculations for warp-specialized kernels.

Key Changes:

• Added compute_bdimx, compute_bdimy, and compute_bdimz fields to CompileParams to store the unpadded thread counts used for actual computation
• Renamed getThreadCountInDim() to getStaticComputeThreadsInDim() and updated it to read from the new unpadded compute fields
• Updated the normalization inner TMA scheduler to capture unpadded thread dimensions before applying warp specialization padding
• The existing bdimx, bdimy, bdimz fields continue to store the total padded thread dimensions for kernel launch

Rationale:
In warp-specialized kernels, thread blocks are padded (e.g., adding 128 threads for async warps), but register sharing calculations need to use only the actual compute threads, not the padded total. Previously, the code attempted to compute unpadded values by subtracting kWarpSpecializationPaddedThreads, but that approach was error-prone. The new approach stores both values explicitly at the point where padding is applied.
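The register donation described above can be made concrete with some illustrative arithmetic. The constants, parameter names, and rounding below are assumptions for the sketch (PTX setmaxnreg does operate at a granularity of 8 registers), not nvFuser's actual getRegisterSharing() logic:

```cpp
#include <cstdint>

// Assumed constants for the sketch: 128 padded async threads, and the
// 8-register granularity that PTX setmaxnreg requires.
constexpr int64_t kWarpSpecializationPaddedThreads = 128;
constexpr int64_t kRegisterGranularity = 8;

// Given a per-thread register budget sized for the *padded* block, compute
// how many registers each compute thread may take if every async thread is
// reduced to a minimal count and donates the remainder. Dividing by the
// padded total instead of compute_threads would understate the budget, which
// is why the unpadded count must be tracked.
int64_t computeThreadRegs(
    int64_t regs_per_thread,  // budget if all padded threads shared equally
    int64_t compute_threads,  // unpadded compute thread count
    int64_t async_regs) {     // minimal registers kept per async thread
  int64_t total_threads = compute_threads + kWarpSpecializationPaddedThreads;
  int64_t total_regs = regs_per_thread * total_threads;
  int64_t donated = total_regs - async_regs * kWarpSpecializationPaddedThreads;
  // Round down to the granularity accepted by setmaxnreg.
  return (donated / compute_threads) / kRegisterGranularity *
      kRegisterGranularity;
}
```

For example, with 256 compute threads, a 168-registers-per-thread budget over the padded block, and async threads trimmed to 32 registers, each compute thread can be raised to 232 registers.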

Confidence Score: 5/5

• This PR is safe to merge with no identified issues
• The changes are well-structured and improve code correctness by explicitly tracking unpadded compute thread dimensions. The refactoring separates concerns clearly (padded dimensions for launch vs. unpadded dimensions for computation), follows existing code patterns, maintains backward compatibility by preserving the original bdim* fields, and the equality operator was properly updated to include the new fields
• No files require special attention

Important Files Changed

• csrc/runtime/executor_params.h: Added three new optional fields (compute_bdimx, compute_bdimy, compute_bdimz) to store unpadded thread dimensions for computation, and updated the equality operator to include these fields
• csrc/parallel_dimension_map.h: Renamed the method from getThreadCountInDim to getStaticComputeThreadsInDim to better reflect that it returns unpadded compute thread counts
• csrc/parallel_dimension_map.cpp: Updated the method implementation to use the compute_bdim* fields instead of the padded bdim* fields for register sharing calculations in warp specialization
• csrc/scheduler/normalization_inner_tma.cpp: Added initialization of compute_bdim* variables before warp padding is applied, and stores these unpadded values in cparams for correct register sharing computation

Sequence Diagram

    sequenceDiagram
        participant Scheduler as normalization_inner_tma.cpp
        participant CompileParams as executor_params.h (CompileParams)
        participant ParallelDimMap as parallel_dimension_map.cpp
        participant RegisterSharing as Register Sharing Logic

        Note over Scheduler: Initialize thread dimensions
        Scheduler->>Scheduler: compute_bdimx = bdimx (before padding)
        Scheduler->>Scheduler: compute_bdimy = 1
        Scheduler->>Scheduler: compute_bdimz = 1

        Note over Scheduler: Apply warp specialization padding
        alt Warp Specialization on TIDy
            Scheduler->>Scheduler: bdimy += 1 (add padding)
        else Warp Specialization on TIDx
            Scheduler->>Scheduler: bdimx += kWarpSpecializationPaddedThreads
        end

        Note over Scheduler: Store both padded and unpadded values
        Scheduler->>CompileParams: cparams.bdimx = bdimx (padded)
        Scheduler->>CompileParams: cparams.bdimy = bdimy (padded)
        Scheduler->>CompileParams: cparams.bdimz = bdimz (padded)
        Scheduler->>CompileParams: cparams.compute_bdimx = compute_bdimx (unpadded)
        Scheduler->>CompileParams: cparams.compute_bdimy = compute_bdimy (unpadded)
        Scheduler->>CompileParams: cparams.compute_bdimz = compute_bdimz (unpadded)

        Note over Scheduler: Calculate register sharing for compute threads
        Scheduler->>RegisterSharing: getRegisterSharing(reg_per_thread, <br/>compute_bdimx * compute_bdimy * compute_bdimz, <br/>kWarpSpecializationPaddedThreads)

        Note over ParallelDimMap: Later during lowering
        ParallelDimMap->>CompileParams: Read compute_bdim* values
        ParallelDimMap->>ParallelDimMap: getStaticComputeThreadsInDim(pt)<br/>returns unpadded thread count
        ParallelDimMap->>ParallelDimMap: Use unpadded values for <br/>register sharing calculations

@liqiangxl requested a review from rdspring1, January 16, 2026 00:24