Conversation

@liqiangxl (Collaborator)

No description provided.

@github-actions

Description

  • Add circular buffer support for TMA inner persistent scheduler

  • Implement grouped reduction handling for warp specialized mode

  • Add inline optimization logic for TMA TVs and broadcast dependencies

  • Configure circular buffer stages and prefetch distance for TMA operations

Changes walkthrough

Relevant files
Enhancement
csrc/scheduler/normalization_inner_tma.cpp
Add circular buffer and warp specialization support

  • Added conditional logic for circular buffer options with grouped
    reduction support
  • Implemented TMA TV inline optimization at specific positions (after
    BIDx)
  • Added dependency handling for broadcast TVs in multi-reduction
    scenarios
  • Configured circular buffer stages and prefetch parameters for TMA
    operations
  • +59/-0   

    PR Reviewer Guide

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review
    Performance Impact Validation

    The new circular buffer and aggressive inlining logic (lines 395-428) introduces significant complexity for TMA operations. The nested dependency analysis and inlining decisions could have substantial performance implications that need thorough benchmarking and validation against existing performance baselines.

    if (params->circular_buffer_options.isEnable()) {
      // when warp specialized, the iteration domain of tma tv is scheduled as:
      // 1. GridStrideLoop
      // 2. BIDx
      // 3. Serial (Compute Warp Groups, TIDy in compute warp groups)
      // 4. Serial (Multiple TMAs share one mbarrier, serial or grouped reduction
    //            in compute warp groups)
      constexpr int64_t pos_after_bidx = 2;
      for (auto tv : tma_tvs) {
        inlineSelectedAt({tv}, tv, pos_after_bidx);
        exclude_tvs.insert(tv);
      }
    
      // Happens in layer norm where the result of the 1st reduction is used by
      // the 2nd reduction. Since each reduction is grouped in its iteration
      // dimension we can't inline deeper than the group position.
      if (group_pos > 0 && reduction_tvs.size() > 1) {
        for (auto tv1 : reduction_tvs) {
          for (auto tv2 : reduction_tvs) {
            if (tv1 == tv2) {
              continue;
            }
            auto all_vals = DependencyCheck::getAllValsBetween({tv1}, {tv2});
            auto gp_tvs = ir_utils::filterByType<TensorView>(all_vals);
            for (auto gp_tv : gp_tvs) {
              if (gp_tv->hasBroadcast() && !exclude_tvs.contains(gp_tv)) {
                inlineSelectedAt({gp_tv}, gp_tv, group_pos);
                exclude_tvs.insert(gp_tv);
              }
            }
          }
        }
      }
    }
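    The pairwise scan above touches a broadcast TV only if it lies on a producer-to-consumer path between two reductions (as with the broadcast mean feeding the variance reduction in layer norm). A minimal toy model of that reachability check, using a hypothetical `Node` graph rather than nvFuser's actual `DependencyCheck`/`TensorView` APIs:

    ```cpp
    #include <set>
    #include <string>
    #include <vector>

    // Toy stand-in for a TensorView in a fusion's dependency graph.
    struct Node {
      std::string name;
      bool is_broadcast = false;
      std::vector<Node*> consumers;
    };

    // All nodes reachable downstream of `start` (excluding start itself).
    std::set<Node*> reachableFrom(Node* start) {
      std::set<Node*> seen;
      std::vector<Node*> stack{start};
      while (!stack.empty()) {
        Node* n = stack.back();
        stack.pop_back();
        for (Node* c : n->consumers) {
          if (seen.insert(c).second) {
            stack.push_back(c);
          }
        }
      }
      return seen;
    }

    // Broadcast nodes on any path from one reduction to another: reachable
    // from `red1` and able to reach `red2`.
    std::set<std::string> broadcastsBetween(Node* red1, Node* red2) {
      std::set<std::string> out;
      for (Node* n : reachableFrom(red1)) {
        if (n->is_broadcast && reachableFrom(n).count(red2)) {
          out.insert(n->name);
        }
      }
      return out;
    }
    ```

    Because the PR runs this for every ordered pair of reduction TVs, the cost grows quadratically with the number of reductions, which is one reason the reviewer guide flags the dependency analysis as potentially expensive for complex fusion patterns.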
    Error Handling Robustness

    The code assumes specific conditions for group_pos values (lines 375-378) and performs dependency analysis without comprehensive error handling. The dependency checking and filtering operations could fail or be expensive for complex fusion patterns.

    NVF_CHECK_EQ(
        group_pos,
        -1,
        "Grouped reduction is only supported in warp specialized mode");
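    The invariant this check enforces can be stated as a small predicate. A hedged sketch (the function name and the boolean parameter are hypothetical, not nvFuser API; `group_pos == -1` is read as "no grouped reduction", per the check above):

    ```cpp
    #include <cstdint>

    // Grouped reduction (group_pos >= 0) is only legal when the scheduler
    // runs in warp specialized mode; otherwise group_pos must stay -1.
    bool groupPosConsistent(bool warp_specialized, int64_t group_pos) {
      return warp_specialized || group_pos == -1;
    }
    ```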
    Memory and Resource Management

    The circular buffer implementation (lines 433-444) could have significant memory implications, especially with multiple TMA TVs. The stage and prefetch distance parameters need validation to ensure they don't cause memory overflow or performance degradation.

    if (params->circular_buffer_options.isEnable()) {
      int64_t number_of_stages = params->circular_buffer_options.stage;
      int64_t prefetch_distance = params->circular_buffer_options.prefetch;
      CircularBufferType circular_buffer_type =
          params->circular_buffer_options.type;
      for (auto tv : tma_tvs) {
        if (tv->getComputeAtPosition() > 0) {
          tv->circularBuffer(
              number_of_stages, prefetch_distance, circular_buffer_type);
        }
      }
    }
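    The memory concern is concrete: each in-flight stage holds one TMA tile in shared memory, so the buffered footprint scales linearly with the stage count. A back-of-the-envelope sketch of the validation the reviewer guide asks for (these helpers are hypothetical, not nvFuser's actual checks; the usable shared-memory budget, e.g. roughly 228 KiB per SM on Hopper, is an assumption the caller must supply):

    ```cpp
    #include <cstdint>

    // Shared memory consumed by a circular-buffered TMA tensor:
    // one tile per in-flight stage.
    int64_t circularBufferBytes(int64_t stages, int64_t tile_bytes) {
      return stages * tile_bytes;
    }

    // The prefetch distance must leave the consumer at least one stage,
    // i.e. 0 <= prefetch < stages.
    bool isValidPrefetch(int64_t stages, int64_t prefetch) {
      return prefetch >= 0 && prefetch < stages;
    }

    // Whether the buffered tiles fit within a given shared-memory budget.
    bool fitsSharedMemory(
        int64_t stages, int64_t tile_bytes, int64_t smem_budget) {
      return circularBufferBytes(stages, tile_bytes) <= smem_budget;
    }
    ```

    For example, 4 stages of a 32 KiB tile need 128 KiB of shared memory, while 8 stages of the same tile would exceed a 228 KiB budget, so stage count and tile size have to be validated together.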

