
Conversation

@liqiangxl (Collaborator) commented Jan 16, 2026

Before:

[ RUN      ] TmaPersistentTestP.TmaInnerPersistentRmsNorm/__bfloat_2048_4096
Parallel dimension map:
blockIdx.x: 152, exact
blockIdx.y: unused
blockIdx.z: unused
threadIdx.x: ( ( ceilDiv(( ceilDiv(( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[1] ), 8) ), 4) ), 32) ) * 32 ), non-exact
threadIdx.y: ( 2 + 1 ), non-exact
threadIdx.z: unused

CUDA code:
// no register sharing due to dynamic `threadIdx.x`
mbarrier::init(toSmem((&T22[(i36 + 3LL)])), __to_uint32(((ceilDiv(i8, 32)) * 32)));

After:

Parallel dimension map:
blockIdx.x: 152, exact
blockIdx.y: unused
blockIdx.z: unused
threadIdx.x: 128, non-exact
threadIdx.y: ( 2 + 1 ), non-exact
threadIdx.z: unused

CUDA code:
// able to use register sharing
increaseRegisters<232>();
mbarrier::init(toSmem((&T22[(i36 + 3LL)])), 128U);
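For context on why this matters: register sharing is expressed through Hopper's setmaxnreg PTX instruction, which takes a compile-time immediate register count and requires a known warp layout within the CTA, so the lowering can only emit it once threadIdx.x resolves to a constant such as 128. Below is a minimal CUDA-level sketch of the pattern; the producer/consumer split on threadIdx.y == 2 follows the "( 2 + 1 )" dump above, but the 40/232 register counts and the kernel shape are illustrative assumptions, not the PR's generated code.

// Sketch of warp specialization with register sharing (requires sm_90a).
// Assumes blockDim == (128, 3, 1): rows 0-1 compute, row 2 is the padded
// producer row that issues the async (TMA) loads.
__global__ void warp_specialized_kernel() {
  if (threadIdx.y == 2) {
    // Producer warps: release registers back to the CTA pool before
    // issuing loads; the immediate must be a compile-time constant.
    asm volatile("setmaxnreg.dec.sync.aligned.u32 40;");
    // ... issue TMA loads, arrive on mbarriers ...
  } else {
    // Compute warps: claim the released registers, analogous to the
    // increaseRegisters<232>() call in the generated code above.
    asm volatile("setmaxnreg.inc.sync.aligned.u32 232;");
    // ... wait on mbarriers, compute ...
  }
}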

@liqiangxl liqiangxl changed the base branch from main to llu/use_unpadded_cta_shape January 16, 2026 00:15
github-actions bot commented Jan 16, 2026

Description

  • Add validation for block dimensions (bdimx, bdimy, bdimz) in warp specialization

  • Check const scalar block dimensions match expected thread count

  • Update non-const scalar block dimensions to expected thread count

  • Ensure proper dimension mapping for warp specialization kernels

Changes walkthrough

Relevant files (Enhancement):

csrc/parallel_dimension_map.cpp: Add block dimension validation in warp specialization

  • Added validation logic for bdimx, bdimy, bdimz in warp specialization
  • Check if const scalar dimensions match thread_count_for_pt
  • Update non-const scalar dimensions to thread_count_for_pt
  • Enhanced error handling for dynamic block dimensions
  • +13/-0   

    PR Reviewer Guide

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review
    Missing validation for block dimension types

    The code adds logic to handle bdimx, bdimy, bdimz but doesn't explicitly validate that 'pt' is actually one of these block dimensions. While the comment suggests this is the intent, there's no runtime check to ensure 'pt' is a block dimension before performing operations. Consider adding explicit validation or ensuring this is guaranteed by the calling context.

    // If bdimx, bdimy, or bdimz is used
    // If it is const scalar, check if it is equal to thread_count_for_pt
    // If it is not const scalar, update the dimension to thread_count_for_pt
    if (dim_map_.contains(pt)) {
      if (dim_map_.at(pt)->isConstScalar()) {
        NVF_ERROR(
            dim_map_.at(pt)->evaluate().as<int64_t>() == thread_count_for_pt);
      } else {
        dim_map_[pt] =
            IrBuilder::create<Val>(thread_count_for_pt, DataType::Index);
      }
    }
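    One way to make that guarantee explicit is a guard along these lines, sketched here under the assumption that nvFuser's isParallelTypeThreadDim helper is available at this point (the message text is illustrative):

    // Hypothetical guard: bdimx/y/z correspond to TIDx/TIDy/TIDz, so reject
    // any other parallel type before touching the dimension map.
    NVF_ERROR(
        isParallelTypeThreadDim(pt),
        "Expected a thread-block dimension (TIDx/TIDy/TIDz), got ",
        pt);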
    Potential division by zero or invalid thread count

    The code uses 'thread_count_for_pt' directly without validating that it's a positive value. While this value comes from earlier in the function, adding a basic validation could prevent potential issues if the thread count calculation changes in the future.

    dim_map_[pt] =
        IrBuilder::create<Val>(thread_count_for_pt, DataType::Index);
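    A corresponding guard, again as an illustrative sketch rather than code from the PR:

    // Hypothetical check: a CTA dimension of zero or less is never valid.
    NVF_ERROR(
        thread_count_for_pt > 0,
        "Expected a positive compute thread count for ",
        pt,
        ", got ",
        thread_count_for_pt);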

    Test failures

    • (Medium, 1) Thunder–Torch scalar mismatch in nanogpt autograd test (thunder.tests.test_networks)

      Test name: thunder.tests.test_networks.test_nanogpt_complete_autograd_nvfuser_cuda_thunder.dtypes.float32 (H100)

    greptile-apps bot (Contributor) commented Jan 16, 2026

    Greptile Summary

    This PR separates the concept of "compute threads" from "total CTA threads" to correctly handle warp specialization padding in the parallel dimension map.

    Key Changes:

    • Added compute_bdimx/y/z fields to CompileParams to track threads used for computation (excluding warp specialization padding)
    • Renamed getThreadCountInDim() to getStaticComputeThreadsInDim() and updated it to use compute_bdim* values when available
    • Enhanced adjustMappingsForWarpSpecialization() to validate or update the parallel dimension map based on known CTA shape from compile params
    • Modified the normalization scheduler to track and set both padded and unpadded thread counts

    Impact:
    The changes ensure that register sharing calculations and parallel dimension tracking correctly account for the actual compute threads when warp specialization adds padding threads for async loads. This prevents incorrect thread count assumptions in kernels using warp specialization.
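    To make the shape of the change concrete, here is a hedged sketch of the new fields and the renamed accessor. The field names follow the summary above, but the std::optional representation, the free-function signature, and the stripped-down ParallelType enum are assumptions for illustration, not the literal diff:

    #include <cstdint>
    #include <optional>

    enum class ParallelType { TIDx, TIDy, TIDz /* ... */ };

    struct CompileParams {
      // ... existing members ...
      // Threads used for computation, excluding warp specialization padding.
      std::optional<int64_t> compute_bdimx;
      std::optional<int64_t> compute_bdimy;
      std::optional<int64_t> compute_bdimz;
    };

    // Sketch of the renamed accessor: prefer the compute-only extent when
    // the scheduler provided one; callers fall back to the fusion-derived
    // extent otherwise.
    inline std::optional<int64_t> getStaticComputeThreadsInDim(
        const CompileParams& cparams, ParallelType pt) {
      switch (pt) {
        case ParallelType::TIDx:
          return cparams.compute_bdimx;
        case ParallelType::TIDy:
          return cparams.compute_bdimy;
        case ParallelType::TIDz:
          return cparams.compute_bdimz;
      }
      return std::nullopt;
    }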

    Issue Found:

    • Missing error message in validation check (line 208-209 of parallel_dimension_map.cpp)

    Confidence Score: 4/5

    • This PR is safe to merge with one minor issue that should be addressed
    • The implementation correctly separates compute threads from padded CTA threads, which is crucial for warp specialization. The logic is sound and the changes are well contained, but the validation check omits an error message, which makes failures harder to debug.
    • Pay attention to csrc/parallel_dimension_map.cpp for the missing error message in the validation check

    Important Files Changed

    • csrc/runtime/executor_params.h: Added compute_bdimx/y/z fields to distinguish compute threads from total CTA threads with warp specialization padding
    • csrc/parallel_dimension_map.h: Renamed getThreadCountInDim to getStaticComputeThreadsInDim for clarity on compute-only threads
    • csrc/parallel_dimension_map.cpp: Updated to use compute_bdim* fields and added validation/update logic for parallel dimensions with known CTA shape
    • csrc/scheduler/normalization_inner_tma.cpp: Tracks compute thread counts separately from padded dimensions and sets compute_bdim* compile params

    Sequence Diagram

    sequenceDiagram
        participant Scheduler as normalization_inner_tma
        participant Params as CompileParams
        participant PDM as ParallelDimensionMap
        participant GPU as GpuLower
        
        Note over Scheduler: Calculate CTA dimensions
        Scheduler->>Scheduler: compute_bdimx = bdimx (before padding)
        Scheduler->>Scheduler: compute_bdimy = bdimy
        Scheduler->>Scheduler: compute_bdimz = bdimz
        
        alt Warp Specialization on TIDy
            Scheduler->>Scheduler: bdimy += 1 (add padding)
        else Warp Specialization on TIDx
            Scheduler->>Scheduler: bdimx += 128 (add padding)
        end
        
        Scheduler->>Params: Set bdimx/y/z (padded values)
        Scheduler->>Params: Set compute_bdimx/y/z (unpadded)
        
        Note over PDM: Parallel Dimension Mapping
        PDM->>PDM: Build dim_map from fusion
        PDM->>PDM: adjustMappingsForWarpSpecialization()
        
        loop For each non-warp-specialized parallel type
            PDM->>GPU: getStaticComputeThreadsInDim(pt)
            GPU->>Params: Read compute_bdimx/y/z
            Params-->>GPU: Return compute thread count
            GPU-->>PDM: Return thread count
            
            alt dim_map[pt] is const scalar
                PDM->>PDM: Validate matches thread_count
            else dim_map[pt] is dynamic
                PDM->>PDM: Update dim_map[pt] = thread_count
            end
        end
        
        PDM->>PDM: Calculate padding and apply to ws_pt
    
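    In code terms, the scheduler-side steps in the diagram amount to bookkeeping along these lines, reusing the assumed CompileParams and ParallelType sketch above; the helper name and signature are hypothetical, and the real logic lives in csrc/scheduler/normalization_inner_tma.cpp:

    // Hypothetical helper mirroring the diagram: record the compute-only
    // shape, then pad the warp-specialized dimension for the launch.
    void setCtaShape(
        CompileParams& cparams,
        ParallelType ws_pt,
        int64_t& bdimx,  // in: compute shape; out: padded launch shape
        int64_t& bdimy,
        int64_t& bdimz) {
      // Unpadded compute shape, captured before padding.
      cparams.compute_bdimx = bdimx;
      cparams.compute_bdimy = bdimy;
      cparams.compute_bdimz = bdimz;
      // Pad for the async-load (producer) warps.
      if (ws_pt == ParallelType::TIDy) {
        bdimy += 1;  // cf. "threadIdx.y: ( 2 + 1 )" in the example above
      } else if (ws_pt == ParallelType::TIDx) {
        bdimx += 128;  // one extra warp group
      }
    }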

    @greptile-apps greptile-apps bot left a comment

    4 files reviewed, 1 comment

    Comment on lines +208 to +209
    NVF_ERROR(
        dim_map_.at(pt)->evaluate().as<int64_t>() == thread_count_for_pt);

    syntax: Missing error message in NVF_ERROR. Should include details about the mismatch.

    Suggested change
    Before:
    NVF_ERROR(
        dim_map_.at(pt)->evaluate().as<int64_t>() == thread_count_for_pt);

    After:
    NVF_ERROR(
        dim_map_.at(pt)->evaluate().as<int64_t>() == thread_count_for_pt,
        "Parallel dimension mismatch for ",
        pt,
        ": expected ",
        thread_count_for_pt,
        " from compile params, but dim_map has ",
        dim_map_.at(pt)->evaluate().as<int64_t>());

    @liqiangxl (Collaborator, Author) commented

    !test
