avoid warp diverge in warp specialized kernel #5830
Review updated until commit 1c3c8a1

| Relevant files |
|---|
| Bug fix |
| Tests |
PR Reviewer Guide
Here are some key observations to aid the review process:
| 🧪 PR contains tests |
| ⚡ Recommended focus areas for review: Validation Logic |
Greptile Summary
This PR adds validation to prevent warp divergence in warp-specialized kernels when using TIDx parallelization. The issue occurs due to CUDA's thread linearization formula (
Confidence Score: 4/5
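As background for the summary above: in CUDA, a thread at index (x, y, z) in a block of shape (Dx, Dy, Dz) has linear ID x + y*Dx + z*Dx*Dy, and warps are formed from 32 consecutive linear IDs. The sketch below is a host-side C++ illustration, not nvFuser code. `Dim3`, `linearThreadId`, `warpId`, and `splitDivergesWithinWarp` are hypothetical names introduced here, and the sketch assumes the warp-specialized role is selected by comparing `threadIdx.x` against a split point.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical host-side stand-in for CUDA's dim3.
struct Dim3 {
  int64_t x, y, z;
};

constexpr int64_t kWarpSize = 32;

// CUDA's linearization: x + y * Dx + z * Dx * Dy.
constexpr int64_t linearThreadId(Dim3 tid, Dim3 bdim) {
  return tid.x + tid.y * bdim.x + tid.z * bdim.x * bdim.y;
}

// A warp holds 32 consecutive linear thread IDs.
constexpr int64_t warpId(Dim3 tid, Dim3 bdim) {
  return linearThreadId(tid, bdim) / kWarpSize;
}

// Assumption: threads with tid.x < split_x take one role and the rest take
// the other. Returns true if any warp contains threads of both roles.
inline bool splitDivergesWithinWarp(Dim3 bdim, int64_t split_x) {
  assert(bdim.x > 0 && bdim.y > 0 && bdim.z > 0);
  const int64_t total = bdim.x * bdim.y * bdim.z;
  // One bitmask per warp: bit 0 = first role seen, bit 1 = second role seen.
  std::vector<int> roles_seen((total + kWarpSize - 1) / kWarpSize, 0);
  for (int64_t z = 0; z < bdim.z; ++z) {
    for (int64_t y = 0; y < bdim.y; ++y) {
      for (int64_t x = 0; x < bdim.x; ++x) {
        roles_seen[warpId(Dim3{x, y, z}, bdim)] |= (x < split_x) ? 1 : 2;
      }
    }
  }
  for (int seen : roles_seen) {
    if (seen == 3) {
      return true;
    }
  }
  return false;
}
```

Under these assumptions, a block of shape (48, 4, 2) split at x = 32 places both roles in the same warp, while a block of shape (64, 4, 2) split at x = 32 keeps every warp in a single role.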
Important Files Changed
Sequence Diagram
sequenceDiagram
participant Test as Test Suite
participant KE as KernelExecutor
participant PDM as ParallelDimensionMap
participant Validation as Warp Specialization Validation
Test->>Test: Define CTA shape (e.g., dim3(32,4,2))
Test->>Test: Set ws_pt = ParallelType::TIDx
Test->>KE: compile(fusion, inputs)
KE->>PDM: adjustMappingsForWarpSpecialization()
PDM->>PDM: Calculate other_active_threads = bdimy * bdimz
PDM->>PDM: Calculate ws_num_threads_pad = 128 / other_active_threads
PDM->>PDM: Calculate after_pad = original_tidx + ws_num_threads_pad
alt ws_pt == TIDx
PDM->>Validation: Check original_tidx % 32 == 0
alt Check fails
Validation-->>KE: Throw error: bdimx must be multiple of 32
KE-->>Test: Exception propagated
end
PDM->>Validation: Check after_pad % 32 == 0
alt Check fails
Validation-->>KE: Throw error: padded bdimx must be multiple of 32
KE-->>Test: Exception propagated
Test->>Test: Verify expected error message
end
end
PDM->>PDM: Apply padding to dimension map
KE-->>Test: Compilation successful
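The validation flow in the diagram can be sketched as a standalone function. This is a minimal illustration rather than the actual `ParallelDimensionMap` code: `CtaShape` and `paddedTidxOrThrow` are hypothetical names, the constant 128 and the integer division are taken directly from the diagram, and the error strings only paraphrase the diagram's messages.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <string>

// Hypothetical CTA shape type for this sketch.
struct CtaShape {
  int64_t x, y, z;
};

// Mirrors the diagram's steps for ws_pt == TIDx. Returns the padded bdimx,
// or throws if either check from the diagram fails.
inline int64_t paddedTidxOrThrow(CtaShape bdim) {
  const int64_t other_active_threads = bdim.y * bdim.z;
  assert(other_active_threads > 0);  // assumption: bdimy and bdimz are positive
  const int64_t ws_num_threads_pad = 128 / other_active_threads;
  const int64_t original_tidx = bdim.x;
  const int64_t after_pad = original_tidx + ws_num_threads_pad;

  // First check: the original bdimx must end on a warp boundary.
  if (original_tidx % 32 != 0) {
    throw std::invalid_argument(
        "bdimx must be multiple of 32, got " + std::to_string(original_tidx));
  }
  // Second check: the padded bdimx must also end on a warp boundary.
  if (after_pad % 32 != 0) {
    throw std::invalid_argument(
        "padded bdimx must be multiple of 32, got " +
        std::to_string(after_pad));
  }
  return after_pad;
}
```

With the diagram's example shape dim3(32, 4, 2), the padding is 128 / 8 = 16 and the padded bdimx is 48, so the second check fails in this sketch.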
!test
2 files reviewed, 1 comment
if (ws_pt == ParallelType::TIDx &&
    getTmaPadThreads(ws_pt, bdim) % 32 != 0) {
style: the condition checks the padding amount, but the validation checks the total (original + padding). This works for the current test cases, where original_tidx is always a multiple of 32, but the two would disagree if the test added a case where original_tidx is not a multiple of 32. In a case like dim3(96, 8, 1), original=96 (divisible by 32) and pad=16 (not divisible), so after_pad=112 is also not divisible by 32 and both checks agree.
Suggested change:
- if (ws_pt == ParallelType::TIDx &&
-     getTmaPadThreads(ws_pt, bdim) % 32 != 0) {
+ if (ws_pt == ParallelType::TIDx &&
+     (bdim.x + getTmaPadThreads(ws_pt, bdim)) % 32 != 0) {
Is the test suite intended to cover only cases where the original bdimx is a multiple of 32?
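To make the comparison concrete, the three predicates can be written side by side. This is a sketch under a stated assumption: `assumedPadThreads` is a hypothetical stand-in that returns 128 / (bdimy * bdimz), matching the sequence diagram, and it may not match what `getTmaPadThreads` actually computes. `Bdim` and the predicate names are likewise introduced only for illustration.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical block-shape type for this sketch.
struct Bdim {
  int64_t x, y, z;
};

// Assumed padding for TIDx: 128 / (bdimy * bdimz), as in the diagram.
inline int64_t assumedPadThreads(Bdim b) {
  assert(b.y > 0 && b.z > 0);
  return 128 / (b.y * b.z);
}

// Current test condition: looks at the padding amount only.
inline bool padOnlyExpectsError(Bdim b) {
  return assumedPadThreads(b) % 32 != 0;
}

// Suggested test condition: looks at original + padding.
inline bool totalExpectsError(Bdim b) {
  return (b.x + assumedPadThreads(b)) % 32 != 0;
}

// Validation as drawn in the diagram: both the original and the padded
// bdimx must be multiples of 32.
inline bool validationThrows(Bdim b) {
  return b.x % 32 != 0 || (b.x + assumedPadThreads(b)) % 32 != 0;
}
```

Under this assumption, all three predicates agree for dim3(96, 8, 1). For a hypothetical dim3(48, 1, 1), the padding is 128, so the pad-only condition expects no error while the total-based condition and the validation both report one.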
No description provided.