@naoyam
Collaborator

@naoyam naoyam commented Jan 14, 2026

Stacked on #5716.

This PR only enables multi-dimensional extent tensors in asNested. RaggedIterDomain::combine is not yet updated; that will be the subject of a follow-up PR.

Remaining major TODOs:

  • Extend combine to support multi-dim partitioning
  • Validate component iter domains in combine

@naoyam naoyam changed the base branch from main to ragged_combine January 14, 2026 19:57
@github-actions

github-actions bot commented Jan 14, 2026

Review updated until commit 1fde60b

Description

  • Enable multi-dimensional extents tensors in RaggedIterDomain::partition and asNested

  • Add shape validation for multi-dimensional extents (extents.ndim - 1 == ragged_dim)

  • Update component extent calculation to use last dimension of extents tensor

  • Add comprehensive tests for 2D extents and shape validation
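The shape rule in the bullets above can be illustrated with a small standalone model. This is only a sketch: `extentsShapeIsValid` and `componentExtent` are hypothetical names that do not exist in nvFuser, and shapes are modeled as plain vectors of sizes with reduction dimensions assumed to be already filtered out.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical standalone model of the shape rule; not nvFuser code.
// extents_shape: shape of the extents tensor, reduction dims already removed.
// ragged_dim: index of the data-tensor axis that is being partitioned.
bool extentsShapeIsValid(const std::vector<long>& extents_shape,
                         std::size_t ragged_dim) {
  assert(!extents_shape.empty());  // extents need at least one dimension
  const std::size_t ndim = extents_shape.size();
  // 1D extents are always accepted; N-D extents need ndim - 1 == ragged_dim.
  return ndim == 1 || ndim - 1 == ragged_dim;
}

// The component extent is the size of the last dimension of the extents tensor.
long componentExtent(const std::vector<long>& extents_shape) {
  assert(!extents_shape.empty());
  return extents_shape.back();
}
```

Under this model, extents of shape [D, K] are accepted for axis 1 of a [D, tokens, hidden] tensor and rejected for axis 0 or axis 2, while the component extent is K in every accepted case.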

Changes walkthrough

Relevant files

Enhancement

internal_base_nodes.cpp: Enable multi-dimensional extents in partition function
csrc/ir/internal_base_nodes.cpp (+27/-13)

  • Allow multi-dimensional extents in RaggedIterDomain::partition by filtering reduction dimensions
  • Update component extent calculation to use last dimension of N-D extents tensor
  • Add detailed comments explaining multi-dimensional extents support
  • Update combine function to filter reduction dimensions and improve error messages

alias.cpp: Support multi-dimensional extents in asNested
csrc/ops/alias.cpp (+27/-7)

  • Remove 1D extents restriction from asNested function
  • Add validation for multi-dimensional extents shape correspondence
  • Implement rule: extents.ndim - 1 == ragged_dim for N-D extents
  • Add comprehensive error messages for invalid shapes

Tests

test_ragged_iter_domain.cpp: Add tests for multi-dimensional extents support
tests/cpp/test_ragged_iter_domain.cpp (+49/-12)

  • Remove test expecting multi-dimensional extents to fail in partition
  • Add AsNested2DOffsets test for 2D extents functionality
  • Add AsNestedInvalidShape test for shape validation
  • Update tests to verify multi-dimensional extents behavior

Documentation

internal_base_nodes.h: Update documentation for multi-dimensional extents
csrc/ir/internal_base_nodes.h (+10/-8)

  • Update partition function documentation to explain multi-dimensional extents
  • Add examples for 1D and 2D extents tensor shapes
  • Clarify that outer dimensions correspond to tensor outer dimensions

    PR Reviewer Guide

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    🔒 No security concerns identified
    ⚡ Recommended focus areas for review
    Incomplete Implementation

    The PR description states that RaggedIterDomain::combine is not yet updated for multi-dimensional extents and will be the subject of a next PR. However, the partition function now creates RaggedIterDomain objects with multi-dimensional extents, but combine() will fail when trying to process them. This could lead to runtime errors if users try to combine ragged iter domains created with multi-dimensional extents. Consider adding a runtime check or documentation warning about this limitation.

    // The combined extent is the sum of all extents in the ragged dimension
    // For a 1D extents tensor [e0, e1, ..., en-1], the total is sum(extents)
    TensorView* extents_tv = ragged->extents();
    NVF_ERROR(extents_tv != nullptr, "combine: ragged extents tensor is null");
    
    // Filter out reduction dimensions before checking
    auto extents_no_reduction =
        extents_tv->getLogicalDomain() | TensorDomain::kNoReductions;
    // Multi-dimensional extents are not yet supported in combine
    auto extents_ndim = std::ranges::distance(extents_no_reduction);
    NVF_ERROR_EQ(
        extents_ndim,
        1,
        "combine: Multi-dimensional extents are not yet supported. ",
        "Expected 1D extents tensor, got ",
        extents_ndim,
        "D extents: ",
        extents_tv->toString());
    Shape Validation Logic

    The validation logic for multi-dimensional extents (extents.ndim - 1 == ragged_dim) seems correct, but the error message could be clearer about what the expected relationship is. The current message mentions the rule but doesn't clearly explain why this constraint exists or what it means for the tensor shapes.

    // Validate shape correspondence for multi-dimensional extents
    // For N-D extents, outer dimensions must match outer dimensions of input
    // tensor Rule: extents.ndim - 1 == ragged_dim (except 1D extents which are
    // always valid).
    if (extents_ndim > 1) {
      NVF_ERROR_EQ(
          extents_ndim - 1,
          ragged_dim,
          "asNested: Multi-dimensional extents require shape ",
          "[d0, d1, ..., d(axis-1), num_components]. ",
          "Got ",
          extents_ndim,
          "D extents for partitioning axis ",
          ragged_dim);
    }
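One way the message could spell out the expected relationship is sketched below. `describeExtentsRankMismatch` is a hypothetical helper written for illustration, not an nvFuser function, and the wording is only a suggestion.

```cpp
#include <cassert>
#include <sstream>
#include <string>

// Hypothetical helper, not part of nvFuser: builds a message that states the
// rule, the rank it implies for this axis, and the rank actually passed.
std::string describeExtentsRankMismatch(int extents_ndim, int ragged_dim) {
  assert(extents_ndim > 1);  // the rule only applies to multi-dim extents
  std::ostringstream msg;
  msg << "asNested: multi-dimensional extents must have one outer dimension "
      << "per tensor dimension before the partitioned axis, plus a final "
      << "dimension of size num_components. Partitioning axis " << ragged_dim
      << " therefore expects " << (ragged_dim + 1) << "D extents, but got "
      << extents_ndim << "D extents.";
  return msg.str();
}
```

Stating the expected rank directly (ragged_dim + 1) saves the reader from having to invert the rule extents.ndim - 1 == ragged_dim by hand.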

    @greptile-apps
    Contributor

    greptile-apps bot commented Jan 14, 2026

    Greptile Summary

    Extends RaggedIterDomain::partition() and asNested() to support multi-dimensional extent tensors while maintaining the existing 1D-only restriction for combine().

    • Removed 1D-only constraint from partition() and asNested(), allowing N-dimensional extent tensors
    • Added shape validation ensuring outer dimensions of multi-dim extents match input tensor structure: extents.ndim - 1 == ragged_dim
    • Last dimension of extents always defines number of components regardless of dimensionality
    • Updated combine() to explicitly reject multi-dim extents with clear error message (deferred to future PR)
    • Comprehensive test coverage for 2D extents and invalid shape scenarios

    Confidence Score: 5/5

    • Safe to merge - well-tested incremental feature addition with appropriate validation
    • The implementation correctly extends multi-dim extent support with proper validation logic, comprehensive test coverage including edge cases, and clear documentation. The staged approach (partition now, combine later) is explicitly documented.
    • No files require special attention

    Important Files Changed

    Filename | Overview
    csrc/ir/internal_base_nodes.cpp | Extends partition() to support multi-dimensional extents by using last dimension as component count; updates combine() validation to explicitly reject multi-dim extents
    csrc/ir/internal_base_nodes.h | Updated documentation for partition() to describe multi-dimensional extents support with examples
    csrc/ops/alias.cpp | Adds shape validation for multi-dim extents in asNested() requiring extents.ndim - 1 == ragged_dim; removes previous 1D-only restriction
    tests/cpp/test_ragged_iter_domain.cpp | Adds test for 2D extents partitioning and invalid shape validation; removes obsolete test for multi-dim extents failure

    Sequence Diagram

    sequenceDiagram
        participant User
        participant asNested
        participant Validation
        participant partition
        participant RaggedIterDomain
    
        User->>asNested: asNested(data[D,tokens,hidden], extents[D,K], axis=1)
        asNested->>Validation: Filter reduction dims from extents
        Validation->>Validation: Check extents_ndim > 0
        Validation->>Validation: If extents_ndim > 1: validate extents_ndim-1 == ragged_dim
        Note over Validation: For 2D extents at axis 1: 2-1=1 == 1 ✓
        asNested->>partition: partition(root_domain[axis], extents)
        partition->>partition: Filter reduction dims from extents
        partition->>partition: Get last dim extent for component_extent
        Note over partition: component_extent = size of the last extents dimension (K)
        partition->>RaggedIterDomain: Create component_id with extent K
        partition->>RaggedIterDomain: Create ragged_id with N-D extents
        partition-->>asNested: Return (component_id, ragged_id)
        asNested->>asNested: Build logical_domain replacing axis with [component, ragged]
        asNested-->>User: TensorView[D, component=K, ragged, hidden]
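The final step of the diagram, replacing the partitioned axis with a component dimension followed by a ragged dimension, can be modeled as follows. This is a sketch under simplifying assumptions: dimension names stand in for iter domains, and `nestedLogicalDomain` is a hypothetical function, not the nvFuser implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical model: each string names one dimension of the logical domain.
// Returns the domain with dims[axis] replaced by [component, ragged].
std::vector<std::string> nestedLogicalDomain(
    const std::vector<std::string>& dims, std::size_t axis) {
  assert(axis < dims.size());
  const auto split = dims.begin() + static_cast<std::ptrdiff_t>(axis);
  // Dimensions before the partitioned axis are kept as they are.
  std::vector<std::string> out(dims.begin(), split);
  // The partitioned axis becomes a component dimension and a ragged dimension.
  out.push_back("component");
  out.push_back("ragged");
  // Dimensions after the partitioned axis are kept as they are.
  out.insert(out.end(), split + 1, dims.end());
  return out;
}
```

For the call in the diagram, partitioning axis 1 of [D, tokens, hidden] gives [D, component, ragged, hidden], so the result has one more dimension than the input.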
    

    Contributor

    @greptile-apps greptile-apps bot left a comment

    Additional Comments (2)

    1. csrc/ir/internal_base_nodes.cpp, line 1126-1131 (link)

      logic: The partition function now supports multi-dimensional extents (as documented in lines 1037-1041 and the PR title), but combine still restricts extents to 1D. Combining ragged domains created with 2D or 3D extents will therefore fail at runtime.

    2. csrc/ir/internal_base_nodes.cpp, line 989-992 (link)

      logic: cloneWithoutRFactor(true) is called in several places (e.g., csrc/scheduler/tools/loop_domain_scheduler.cpp:301, tests/cpp/test_id_model.cpp:2546). If any of these paths encounters a RaggedIterDomain, the call will throw an exception. Consider implementing the mapping logic, or adding validation earlier so that a RaggedIterDomain cannot reach code paths that require this functionality. Are ragged domains expected to be used in contexts where cloneWithoutRFactor with mapping is called?

    12 files reviewed, 2 comments


    Contributor

    @greptile-apps greptile-apps bot left a comment

    4 files reviewed, 3 comments

    fusion.zeroVal(), IrBuilder::create<Val>(-1, DataType::Index))
    .build();

    // Create 2D offset tensor (e.g., [num_gpus, num_experts + 1])

    style: inconsistent terminology: comment says "offset tensor" but parameter is called extents throughout the codebase

    EXPECT_TRUE(ragged_id->isA<RaggedIterDomain>());
    EXPECT_TRUE(ragged_id->extents() != nullptr);

    // Verify the extents tensor is 2D (one dimension less in last dim than offsets)

    style: comment mentions "offsets" but the parameter is named extents throughout the API
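The distinction between the two terms matters because they have different lengths. The sketch below assumes the common convention that a row of offsets is the running sum of a row of extents, starting at 0, so that n extents correspond to n + 1 offsets. That convention is an assumption here and is not confirmed by this thread, and `offsetsToExtents` is a hypothetical helper.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: converts one row of offsets [o0, o1, ..., on] into
// the n extents [o1 - o0, ..., on - o(n-1)].
std::vector<long> offsetsToExtents(const std::vector<long>& offsets) {
  assert(!offsets.empty());  // even zero components need one offset
  std::vector<long> extents;
  extents.reserve(offsets.size() - 1);
  for (std::size_t i = 1; i < offsets.size(); ++i) {
    assert(offsets[i] >= offsets[i - 1]);  // offsets must be non-decreasing
    extents.push_back(offsets[i] - offsets[i - 1]);
  }
  return extents;
}
```

With this convention, an offsets tensor of shape [num_gpus, num_experts + 1] would correspond to an extents tensor of shape [num_gpus, num_experts].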

    Contributor

    @greptile-apps greptile-apps bot left a comment

    Additional Comments (1)

    1. csrc/ir/internal_base_nodes.cpp, line 1131-1136 (link)

      logic: combine() still requires 1D extents but partition() now supports N-D extents. If you partition with 2D extents, combining will fail.

      Need to update combine() to handle multi-dimensional extents by summing over the last dimension, similar to how partition() extracts the last dimension as component extent.
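The suggested reduction over the last dimension can be sketched on plain data. This only illustrates the arithmetic for 2D extents: `combinedExtentsPerRow` is a hypothetical function, and how combine() would represent a per-row total in the IR is not addressed in this thread.

```cpp
#include <cassert>
#include <numeric>
#include <vector>

// Hypothetical model of a 2D extents tensor as rows of component extents.
// Returns, for each outer index, the sum over the last dimension.
std::vector<long> combinedExtentsPerRow(
    const std::vector<std::vector<long>>& extents) {
  std::vector<long> totals;
  totals.reserve(extents.size());
  for (const auto& row : extents) {
    // Every row must have the same number of components.
    assert(row.size() == extents.front().size());
    totals.push_back(std::accumulate(row.begin(), row.end(), 0L));
  }
  return totals;
}
```

Unlike the 1D case, where the combined extent is a single sum, the totals here can differ from one outer index to the next.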

    4 files reviewed, 1 comment


    @naoyam
    Collaborator Author

    naoyam commented Jan 14, 2026

    !test

    @naoyam
    Collaborator Author

    naoyam commented Jan 15, 2026

    !test

    @naoyam naoyam changed the title from "[WIP] RaggedIterDomain partitioning with multi-dim extents" to "RaggedIterDomain partitioning with multi-dim extents" Jan 15, 2026
    Base automatically changed from ragged_combine to main January 16, 2026 17:20
    Contributor

    @greptile-apps greptile-apps bot left a comment

    Additional Comments (1)

    1. csrc/ops/alias.cpp, line 1324-1327 (link)

      logic: duplicate validation check - lines 1319-1322 already perform this exact check

    4 files reviewed, 1 comment


    @naoyam
    Collaborator Author

    naoyam commented Jan 16, 2026

    !test

    @naoyam
    Collaborator Author

    naoyam commented Jan 16, 2026

    !test

    @naoyam naoyam requested a review from wujingyue January 16, 2026 22:08