update or validate parallel dimension map with known cta shape #5835
base: llu/use_unpadded_cta_shape
Conversation
Description
Type: Enhancement
Relevant files:
PR Reviewer Guide
Here are some key observations to aid the review process:
🧪 PR contains tests
⚡ Recommended focus areas for review
- Missing validation for block dimension types
Test failures
- (Medium, 1) Thunder–Torch scalar mismatch in nanogpt autograd test (thunder.tests.test_networks)

| Test Name | H100 | Source |
|---|---|---|
| thunder.tests.test_networks.test_nanogpt_complete_autograd_nvfuser_cuda_thunder.dtypes.float32 | ❌ | |
Greptile Summary
This PR separates the concept of "compute threads" from "total CTA threads" to correctly handle warp specialization padding in the parallel dimension map.
Key Changes:
Impact:
Issue Found:
Confidence Score: 4/5
Important Files Changed
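To make the compute-vs-total distinction concrete, here is a minimal C++ sketch of the scheduler-side bookkeeping described in the summary and in the sequence diagram below. The struct, field names, and exact padding amounts (one extra TIDy row, or 128 extra TIDx threads) are illustrative stand-ins taken from the diagram, not the literal nvFuser code.

```cpp
#include <cstdint>

// Illustrative stand-in for nvFuser's parallel types.
enum class ParallelType { TIDx, TIDy, TIDz };

struct CtaDims {
  // Padded CTA shape handed to the launch (what bdimx/y/z hold in CompileParams).
  int64_t bdimx = 1, bdimy = 1, bdimz = 1;
  // Unpadded "compute" shape, i.e. only the threads that run the compute branch.
  int64_t compute_bdimx = 1, compute_bdimy = 1, compute_bdimz = 1;
};

// Record the unpadded dims first, then grow only the total CTA shape to make
// room for the extra warp-specialized (async/load) warps.
CtaDims padForWarpSpecialization(CtaDims dims, ParallelType ws_pt) {
  dims.compute_bdimx = dims.bdimx;
  dims.compute_bdimy = dims.bdimy;
  dims.compute_bdimz = dims.bdimz;
  switch (ws_pt) {
    case ParallelType::TIDy:
      dims.bdimy += 1;    // one extra row of threads on TIDy
      break;
    case ParallelType::TIDx:
      dims.bdimx += 128;  // one extra warp group (128 threads) on TIDx
      break;
    default:
      break;
  }
  return dims;
}
```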
Sequence Diagram
```mermaid
sequenceDiagram
participant Scheduler as normalization_inner_tma
participant Params as CompileParams
participant PDM as ParallelDimensionMap
participant GPU as GpuLower
Note over Scheduler: Calculate CTA dimensions
Scheduler->>Scheduler: compute_bdimx = bdimx (before padding)
Scheduler->>Scheduler: compute_bdimy = bdimy
Scheduler->>Scheduler: compute_bdimz = bdimz
alt Warp Specialization on TIDy
Scheduler->>Scheduler: bdimy += 1 (add padding)
else Warp Specialization on TIDx
Scheduler->>Scheduler: bdimx += 128 (add padding)
end
Scheduler->>Params: Set bdimx/y/z (padded values)
Scheduler->>Params: Set compute_bdimx/y/z (unpadded)
Note over PDM: Parallel Dimension Mapping
PDM->>PDM: Build dim_map from fusion
PDM->>PDM: adjustMappingsForWarpSpecialization()
loop For each non-warp-specialized parallel type
PDM->>GPU: getStaticComputeThreadsInDim(pt)
GPU->>Params: Read compute_bdimx/y/z
Params-->>GPU: Return compute thread count
GPU-->>PDM: Return thread count
alt dim_map[pt] is const scalar
PDM->>PDM: Validate matches thread_count
else dim_map[pt] is dynamic
PDM->>PDM: Update dim_map[pt] = thread_count
end
end
PDM->>PDM: Calculate padding and apply to ws_pt
```
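The validate-or-update loop from the diagram can be pictured roughly as follows. This is a simplified sketch that uses standard containers and a plain assertion in place of nvFuser's IR types, Val evaluation, and NVF_ERROR; names such as getStaticComputeThreadsInDim follow the diagram rather than the exact implementation.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <optional>

enum class ParallelType { TIDx, TIDy, TIDz };

// Stand-in for a dimension entry: either a known constant extent or dynamic.
using DimExtent = std::optional<int64_t>;

// Stand-in for GpuLower reading compute_bdimx/y/z from the compile params.
int64_t getStaticComputeThreadsInDim(
    ParallelType pt, const std::map<ParallelType, int64_t>& compute_dims) {
  return compute_dims.at(pt);
}

// For every non-warp-specialized thread dimension: if its extent is already a
// known constant, validate it against the compute thread count from the
// compile params; if it is dynamic, update it with that known count.
void updateOrValidateDimMap(
    std::map<ParallelType, DimExtent>& dim_map,
    const std::map<ParallelType, int64_t>& compute_dims,
    ParallelType ws_pt) {
  for (auto& [pt, extent] : dim_map) {
    if (pt == ws_pt) {
      continue;  // padding for the warp-specialized dimension is applied later
    }
    const int64_t thread_count_for_pt =
        getStaticComputeThreadsInDim(pt, compute_dims);
    if (extent.has_value()) {
      // Const scalar: must match the known CTA shape.
      assert(*extent == thread_count_for_pt);
    } else {
      // Dynamic: fill in the known CTA shape.
      extent = thread_count_for_pt;
    }
  }
}
```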
4 files reviewed, 1 comment
```cpp
NVF_ERROR(
    dim_map_.at(pt)->evaluate().as<int64_t>() == thread_count_for_pt);
```
syntax: Missing error message in NVF_ERROR. Should include details about the mismatch.
Suggested change:
```diff
- NVF_ERROR(
-     dim_map_.at(pt)->evaluate().as<int64_t>() == thread_count_for_pt);
+ NVF_ERROR(
+     dim_map_.at(pt)->evaluate().as<int64_t>() == thread_count_for_pt,
+     "Parallel dimension mismatch for ",
+     pt,
+     ": expected ",
+     thread_count_for_pt,
+     " from compile params, but dim_map has ",
+     dim_map_.at(pt)->evaluate().as<int64_t>());
```
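As the suggestion shows, NVF_ERROR accepts additional message arguments after the condition; including both the thread count from the compile params and the value already in dim_map means a mismatch can be diagnosed directly from the error text, without re-running under a debugger.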
!test |
Before:
After: