@liqiangxl
Collaborator

No description provided.

@github-actions

github-actions bot commented Jan 15, 2026

Review updated until commit 1c3c8a1

Description

  • Add validation to ensure bdimx is multiple of 32 for warp specialization on TIDx

  • Prevent warp divergence caused by thread linearization splitting warps across roles

  • Remove special case handling for TIDx==32 in circular buffering tests

  • Add test case to verify new warp divergence error checking

Changes walkthrough

Relevant files

Bug fix
parallel_dimension_map.cpp
Add warp divergence validation for TIDx specialization

csrc/parallel_dimension_map.cpp (+37/-0)

  • Added validation logic for warp specialization on TIDx requiring bdimx
    multiples of 32
  • Added error checking for both original and padded thread counts
  • Explained thread linearization issue that splits warps across
    producer/consumer boundaries
  • Prevents warp specialization failure when bdimx is not divisible by 32

Tests
test_circular_buffering.cpp
Update tests for warp specialization validation

tests/cpp/test_circular_buffering.cpp (+16/-16)

  • Removed conditional register sharing logic for TIDx==32 case
  • Simplified circular buffer setup to use consistent approach
  • Added test case to verify warp divergence error when padded threads %
    32 != 0
  • Enhanced test coverage for new validation requirements

    PR Reviewer Guide

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review
    Validation Logic

    The new validation logic correctly checks both original and padded thread counts for TIDx warp specialization. The error messages are comprehensive and provide clear guidance. However, verify that the validation doesn't impact performance for valid use cases and that the error conditions are properly tested across different CTA shapes.

    if (ws_pt == ParallelType::TIDx) {
      int64_t original_tidx = getThreadCountInDim(ws_pt);
      NVF_ERROR(
          original_tidx % 32 == 0,
          "Warp specialization on TIDx requires bdimx to be a multiple of 32 ",
          "to avoid splitting warps across producer/consumer boundaries. ",
          "Got bdimx = ",
          original_tidx,
          " with CTA shape (",
          original_tidx,
          ", ",
          getThreadCountInDim(ParallelType::TIDy),
          ", ",
          getThreadCountInDim(ParallelType::TIDz),
          ")");
      NVF_ERROR(
          after_pad % 32 == 0,
          "Warp specialization on TIDx requires padded bdimx to be a multiple of "
          "32 to avoid warp diverge. "
          "Got padded bdimx = ",
          after_pad,
          " (original: ",
          original_tidx,
          ", padding: ",
          ws_num_threads_pad,
          ")");
    }
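As a rough illustration of how the two checks interact, the sketch below reproduces the arithmetic outside nvFuser. It assumes, following the sequence diagram in the Greptile summary, that the padding is 128 / (bdimy * bdimz) and that after_pad is the original bdimx plus that padding. The enum and function names are hypothetical and are not part of the PR.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical result type for illustration; the PR itself raises NVF_ERROR.
enum class TidxCheck { Ok, OriginalNotMultipleOf32, PaddedNotMultipleOf32 };

// Sketch of the validation order quoted above for warp specialization on TIDx.
// Assumption: padding = 128 / (bdimy * bdimz), as described in the PR summary.
TidxCheck checkTidxWarpSpecialization(
    int64_t bdimx,
    int64_t bdimy,
    int64_t bdimz) {
  assert(bdimx > 0 && bdimy > 0 && bdimz > 0);
  const int64_t other_active_threads = bdimy * bdimz;
  const int64_t ws_num_threads_pad = 128 / other_active_threads;
  const int64_t after_pad = bdimx + ws_num_threads_pad;
  // First check: the original bdimx must already fill whole warps.
  if (bdimx % 32 != 0) {
    return TidxCheck::OriginalNotMultipleOf32;
  }
  // Second check: the padded bdimx must also fill whole warps.
  if (after_pad % 32 != 0) {
    return TidxCheck::PaddedNotMultipleOf32;
  }
  return TidxCheck::Ok;
}
```

Under these assumptions, a CTA shape of (32, 4, 2) gets 16 padded threads, so the padded bdimx is 48 and the second check fires, while (32, 4, 1) pads to 64 and passes.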
    Test Coverage

    The test changes simplify the logic while adding proper error validation for the new warp divergence checks. Ensure the test adequately covers edge cases where getTmaPadThreads(ws_pt, bdim) % 32 != 0, and verify that the error message matching is robust across different compiler versions and CUDA implementations.

    // If ws_pt == ParallelType::TIDx and CTA shape is (32, 4, 2), padded
    // threads in x dim is 16, will cause warp divergence due to thread
    // linearization.
    if (ws_pt == ParallelType::TIDx &&
        getTmaPadThreads(ws_pt, bdim) % 32 != 0) {
      const char* err_msg =
          R"(Warp specialization on TIDx requires padded bdimx to be a multiple of 32)";
      const char* str_match_pointer = strstr(e.what(), err_msg);
      ASSERT_TRUE(str_match_pointer != nullptr);
      return;
    }

    @greptile-apps
    Contributor

    greptile-apps bot commented Jan 15, 2026

    Greptile Summary

    This PR adds validation to prevent warp divergence in warp-specialized kernels when using TIDx parallelization. The issue occurs due to CUDA's thread linearization formula (tidx + tidy * bdimx + tidz * bdimx * bdimy): when bdimx is not a multiple of 32, consecutive linear thread IDs can wrap to the next tidy value mid-warp, causing threads within the same warp to be assigned different roles (producer vs consumer).

    • Added two validation checks in adjustMappingsForWarpSpecialization() to ensure both original and padded bdimx are multiples of 32 when using TIDx warp specialization
    • Updated test to properly handle the new validation error, removing workaround code that avoided register sharing for problematic configurations
    • The fix prevents subtle correctness issues where warps would be split across producer/consumer boundaries
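The effect of the linearization formula can be checked with a small host-side sketch. It enumerates every thread of a padded CTA, computes its linear ID and warp index, and counts the warps that contain both roles. The role assignment is an assumption made for illustration (threads with tidx below the original bdimx are consumers, threads in the padded range are producers), and the function name is hypothetical.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper, not part of nvFuser: counts warps whose threads would
// be assigned more than one role.
// Assumption: tidx < original_bdimx means consumer, otherwise producer.
int64_t countMixedWarps(
    int64_t original_bdimx,
    int64_t padded_bdimx,
    int64_t bdimy,
    int64_t bdimz) {
  assert(original_bdimx > 0 && original_bdimx <= padded_bdimx);
  assert(bdimy > 0 && bdimz > 0);
  constexpr int64_t kWarpSize = 32;
  const int64_t total_threads = padded_bdimx * bdimy * bdimz;
  const int64_t num_warps = (total_threads + kWarpSize - 1) / kWarpSize;
  std::vector<bool> has_consumer(num_warps, false);
  std::vector<bool> has_producer(num_warps, false);
  for (int64_t tidz = 0; tidz < bdimz; ++tidz) {
    for (int64_t tidy = 0; tidy < bdimy; ++tidy) {
      for (int64_t tidx = 0; tidx < padded_bdimx; ++tidx) {
        // CUDA linearization: tidx + tidy * bdimx + tidz * bdimx * bdimy.
        const int64_t linear =
            tidx + tidy * padded_bdimx + tidz * padded_bdimx * bdimy;
        const int64_t warp = linear / kWarpSize;
        if (tidx < original_bdimx) {
          has_consumer[warp] = true;
        } else {
          has_producer[warp] = true;
        }
      }
    }
  }
  int64_t mixed = 0;
  for (int64_t w = 0; w < num_warps; ++w) {
    if (has_consumer[w] && has_producer[w]) {
      ++mixed;
    }
  }
  return mixed;
}
```

Under this role assumption, padding bdimx from 32 to 48 with bdimy = 4 and bdimz = 2 produces mixed warps, padding from 32 to 64 produces none, and padding from 48 to 64 still produces mixed warps even though the padded size is a multiple of 32. That last case is why both the original and the padded bdimx are checked.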

    Confidence Score: 4/5

    • This PR is safe to merge with minimal risk - adds important correctness validation
    • The validation logic is sound and addresses a real correctness issue with warp divergence. The test updates properly handle the new error cases, though the test condition could be slightly more robust for future test cases with non-32-multiple original dimensions
    • No files require special attention - both changes are straightforward validation additions

    Important Files Changed

    Filename Overview
    csrc/parallel_dimension_map.cpp Added validation to prevent warp divergence in TIDx warp specialization by requiring both original and padded bdimx to be multiples of 32
    tests/cpp/test_circular_buffering.cpp Updated test to handle new validation error, though condition could be more precise by checking after_pad rather than just pad amount

    Sequence Diagram

    sequenceDiagram
        participant Test as Test Suite
        participant KE as KernelExecutor
        participant PDM as ParallelDimensionMap
        participant Validation as Warp Specialization Validation
        
        Test->>Test: Define CTA shape (e.g., dim3(32,4,2))
        Test->>Test: Set ws_pt = ParallelType::TIDx
        Test->>KE: compile(fusion, inputs)
        KE->>PDM: adjustMappingsForWarpSpecialization()
        PDM->>PDM: Calculate other_active_threads = bdimy * bdimz
        PDM->>PDM: Calculate ws_num_threads_pad = 128 / other_active_threads
        PDM->>PDM: Calculate after_pad = original_tidx + ws_num_threads_pad
        
        alt ws_pt == TIDx
            PDM->>Validation: Check original_tidx % 32 == 0
            alt Check fails
                Validation-->>KE: Throw error: bdimx must be multiple of 32
                KE-->>Test: Exception propagated
            end
            PDM->>Validation: Check after_pad % 32 == 0
            alt Check fails
                Validation-->>KE: Throw error: padded bdimx must be multiple of 32
                KE-->>Test: Exception propagated
                Test->>Test: Verify expected error message
            end
        end
        
        PDM->>PDM: Apply padding to dimension map
        KE-->>Test: Compilation successful
    

    @liqiangxl
    Collaborator Author

    !test

    Contributor

    @greptile-apps greptile-apps bot left a comment

    2 files reviewed, 1 comment

    Comment on lines +2542 to +2543
    if (ws_pt == ParallelType::TIDx &&
        getTmaPadThreads(ws_pt, bdim) % 32 != 0) {
    style: the condition checks the padding amount, while the validation checks the total (original + padding). The two agree whenever original_tidx is a multiple of 32, as in all current test cases; for example, dim3(96, 8, 1) gives original=96, pad=16, after_pad=112, and neither pad nor after_pad is divisible by 32. They would diverge only for a case whose original bdimx is not a multiple of 32, so checking the total keeps the test condition aligned with the validation.

    Suggested change
    - if (ws_pt == ParallelType::TIDx &&
    -     getTmaPadThreads(ws_pt, bdim) % 32 != 0) {
    + if (ws_pt == ParallelType::TIDx &&
    +     (bdim.x + getTmaPadThreads(ws_pt, bdim)) % 32 != 0) {

    is the test suite intended to only cover cases where original bdimx is a multiple of 32?
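To see where the two test conditions can differ, the following sketch compares them directly. The function names are hypothetical, and pad stands in for the value returned by getTmaPadThreads(ws_pt, bdim).

```cpp
#include <cassert>
#include <cstdint>

// Condition currently used by the test: looks only at the padding amount.
bool padOnlyPredicate(int64_t pad) {
  assert(pad >= 0);
  return pad % 32 != 0;
}

// Condition from the suggested change: looks at original + padding.
bool totalPredicate(int64_t bdimx, int64_t pad) {
  assert(bdimx > 0 && pad >= 0);
  return (bdimx + pad) % 32 != 0;
}

// True when both conditions make the same decision for this shape.
bool predicatesAgree(int64_t bdimx, int64_t pad) {
  return padOnlyPredicate(pad) == totalPredicate(bdimx, pad);
}
```

With these definitions the two predicates agree for every bdimx that is a multiple of 32, since (bdimx + pad) % 32 then equals pad % 32; bdimx = 32 or 96 with pad = 16 are examples. They differ for bdimx = 48 with pad = 16, where the total is 64. In that case the validation quoted earlier would already reject the original bdimx, with a different error message.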
