-
Notifications
You must be signed in to change notification settings - Fork 75
RaggedIterDomain partitioning with multi-dim extents #5823
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
|
Review updated until commit 1fde60b Description
|
| Relevant files | |||||
|---|---|---|---|---|---|
| Enhancement |
| ||||
| Tests |
| ||||
| Documentation |
|
PR Reviewer Guide
Here are some key observations to aid the review process:
| 🧪 PR contains tests |
| 🔒 No security concerns identified |
| ⚡ Recommended focus areas for review |
Incomplete Implementation
|
Greptile SummaryExtends
Confidence Score: 5/5
Important Files Changed
Sequence DiagramsequenceDiagram
participant User
participant asNested
participant Validation
participant partition
participant RaggedIterDomain
User->>asNested: asNested(data[D,tokens,hidden], extents[D,K], axis=1)
asNested->>Validation: Filter reduction dims from extents
Validation->>Validation: Check extents_ndim > 0
Validation->>Validation: If extents_ndim > 1: validate extents_ndim-1 == ragged_dim
Note over Validation: For 2D extents at axis 1: 2-1=1 == 1 ✓
asNested->>partition: partition(root_domain[axis], extents)
partition->>partition: Filter reduction dims from extents
partition->>partition: Get last dim extent for component_extent
Note over partition: component_extent = extents[..., K]
partition->>RaggedIterDomain: Create component_id with extent K
partition->>RaggedIterDomain: Create ragged_id with N-D extents
partition-->>asNested: Return (component_id, ragged_id)
asNested->>asNested: Build logical_domain replacing axis with [component, ragged]
asNested-->>User: TensorView[D, component=K, ragged, hidden]
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Additional Comments (2)
-
csrc/ir/internal_base_nodes.cpp, line 1126-1131 (link)logic: the
partitionfunction now supports multi-dimensional extents (as documented in lines 1037-1041 and the PR title), butcombinestill restricts extents to 1D only. this will cause runtime errors when trying to combine ragged domains created with 2D or 3D extents. -
csrc/ir/internal_base_nodes.cpp, line 989-992 (link)logic:
cloneWithoutRFactor(true)is called in several places (e.g.,csrc/scheduler/tools/loop_domain_scheduler.cpp:301,tests/cpp/test_id_model.cpp:2546). if any of these paths encounter aRaggedIterDomain, this will throw an exception. consider implementing the mapping logic or adding validation earlier to preventRaggedIterDomainfrom reaching code paths that require this functionality. are ragged domains expected to be used in contexts where cloneWithoutRFactor with mapping is called?
12 files reviewed, 2 comments
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
4 files reviewed, 3 comments
| fusion.zeroVal(), IrBuilder::create<Val>(-1, DataType::Index)) | ||
| .build(); | ||
|
|
||
| // Create 2D offset tensor (e.g., [num_gpus, num_experts + 1]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
style: inconsistent terminology: comment says "offset tensor" but parameter is called extents throughout the codebase
| EXPECT_TRUE(ragged_id->isA<RaggedIterDomain>()); | ||
| EXPECT_TRUE(ragged_id->extents() != nullptr); | ||
|
|
||
| // Verify the extents tensor is 2D (one dimension less in last dim than offsets) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
style: comment mentions "offsets" but the parameter is named extents throughout the API
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Additional Comments (1)
-
csrc/ir/internal_base_nodes.cpp, line 1131-1136 (link)logic:
combine()still requires 1D extents butpartition()now supports N-D extents. If you partition with 2D extents, combining will fail.Need to update
combine()to handle multi-dimensional extents by summing over the last dimension, similar to howpartition()extracts the last dimension as component extent.
4 files reviewed, 1 comment
|
!test |
|
!test |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Additional Comments (1)
-
csrc/ops/alias.cpp, line 1324-1327 (link)logic: duplicate validation check - lines 1319-1322 already perform this exact check
4 files reviewed, 1 comment
|
!test |
|
!test |
Stacked on #5716.
This PR just allows for using multi-dimensional extent tensors in
asNested.RaggedIterDomain::combineis not yet updated and will be the subject of a next PR.Remaining major TODOs:
combineto support multi-dim partitioningcombine