Change from thread idx based predicate to warp idx based #5821
Conversation
Review updated until commit 122ac43
Description
Enhancement
PR Reviewer Guide
Here are some key observations to aid the review process:
🧪 PR contains tests
⚡ Recommended focus areas for review
Error Handling
Test failures
- (Medium, 5) NVFuser internal assert: NumComputeWarps not constant in TmaPersistent tests

  | Test Name | GB200 | Source |
  |---|---|---|
  | TmaPersistentTestF.KernelReuse | ❌ | Link |
  | TmaPersistentTestP.TmaInnerPersistentRmsNorm/__bfloat_2048_1024 | ❌ | Link |
  | TmaPersistentTestP.TmaInnerPersistentRmsNorm/__bfloat_2048_2048 | ❌ | Link |
  | TmaPersistentTestP.TmaInnerPersistentRmsNorm/__bfloat_2048_4096 | ❌ | Link |
  | TmaPersistentTestP.TmaInnerPersistentRmsNorm/__bfloat_2048_8192 | ❌ | Link |

- (Medium, 2) Thunder nvFuser nanoGPT autograd scalar mismatch in thunder.tests.test_networks

  | Test Name | GB200 | H100 | Source |
  |---|---|---|---|
  | thunder.tests.test_networks.test_nanogpt_complete_autograd_nvfuser_cuda_thunder.dtypes.float32 | ❌ | ❌ | |

- (Medium, 2) nvFuser warp_specialize codegen assertion failure in tests.python.direct.test_tutorial

  | Test Name | GB200 | H100 | Source |
  |---|---|---|---|
  | tests.python.direct.test_tutorial.test_warp_specialized_circular_buffering_pointwise | ❌ | ❌ | |
Greptile Summary
This PR refactors warp-specialized kernel predicate generation to use uniform warp IDs instead of thread indices, providing significant performance improvements through compiler optimization.
Key Changes:
Performance Impact:
Architecture:
Confidence Score: 4/5
Important Files Changed
Sequence Diagram
```mermaid
sequenceDiagram
participant Compiler as NVIDIA Fuser Compiler
participant Alloc as AllocationInserter
participant PDM as ParallelDimensionMap
participant PC as PredicateCompute
participant CB as CircularBuffer
Compiler->>Alloc: Start allocation pass
activate Alloc
Alloc->>PDM: Check canUseWarpIdBasedPredicate()
activate PDM
PDM-->>Alloc: Returns true if consecutive warp IDs possible
deactivate PDM
alt Warp ID based predicates available
Alloc->>Alloc: Compute uniform warp ID<br/>(tid = tidx + tidy*bdimx + tidz*bdimx*bdimy)<br/>(warp_id = tid / 32)
Alloc->>Compiler: Store uniform_warp_id in GpuLower
Compiler->>PC: Generate predicates
activate PC
PC->>PC: selectFirstWarpElectSyncPredicate()
Note over PC: Use warp_id == 0
PC->>PC: createElectSyncPredicateAsync()
Note over PC: Use warp_id == num_compute_warps
PC->>PC: createMultipleExpressionElectSync()
Note over PC: Route to async/compute warp predicates
deactivate PC
Compiler->>CB: getAsyncWarpPredicate()
activate CB
CB->>PDM: Get getNumComputeWarps()
Note over CB: Use warp_id >= num_compute_warps
CB-->>Compiler: Return warp ID based predicate
deactivate CB
else No consecutive warp IDs (fallback)
Note over Alloc: Skip warp ID computation
Alloc-->>Compiler: uniform_warp_id = nullptr
Compiler->>PC: Generate predicates (fallback)
activate PC
PC->>PC: Use thread index based predicates<br/>(threadIdx.x < 32 for first warp)<br/>(linear_index >= computed_threshold for async)
deactivate PC
Compiler->>CB: getAsyncWarpPredicate (fallback)
activate CB
CB->>PDM: Use getWarpSpecializationPaddedVal()
Note over CB: Use parallel_index >= (raw - padding)
CB-->>Compiler: Return thread index based predicate
deactivate CB
end
Compiler->>Compiler: Generate optimized PTX/SASS<br/>with reduced WARPSYNC/VOTEU.ALL instructions
deactivate Alloc
```
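Condensed into code, the branch structure the diagram describes looks roughly like the following CUDA sketch; `warpSpecializedKernel` and `kNumComputeWarps` are illustrative names, not identifiers nvFuser actually emits:

```cuda
// Minimal sketch of warp-ID-based predicate routing (illustrative only).
__global__ void warpSpecializedKernel() {
  const unsigned kNumComputeWarps = 8;  // hypothetical; chosen by the schedule
  const unsigned tid = threadIdx.x + threadIdx.y * blockDim.x +
      threadIdx.z * blockDim.x * blockDim.y;  // linear thread index
  const unsigned warp_id = tid / 32;          // one value per warp

  if (warp_id >= kNumComputeWarps) {
    // Async warps: getAsyncWarpPredicate() selects these for TMA loads;
    // warp_id == kNumComputeWarps picks the first async warp for elect-sync.
  } else if (warp_id == 0) {
    // First compute warp: target of selectFirstWarpElectSyncPredicate().
  }
}
```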
!test
15 files reviewed, 1 comment
csrc/predicate_compute.cpp
Outdated
```cpp
// Val* zero = IrBuilder::create<Val>(0L, PrimDataType::UInt64);
// Val* warp_size = IrBuilder::create<Val>(32L, PrimDataType::UInt64);

// const ParallelDimensionMap& pdim_map =
//     GpuLower::current()->info().parallelDimensionMap();
// Val* async_warp_thread_index = pdim_map.getLinearThreadIndexAsync();
// Val* warp_id =
//     SimplifyingIrBuilder::divExpr(async_warp_thread_index, warp_size);
// // TODO Only select first warp now
// Val* select_warp = SimplifyingIrBuilder::eqExpr(warp_id, zero);

// // Use elect-sync if available
// if (pdim_map.canUseElectSyncInAsyncWarp()) {
//   return SimplifyingIrBuilder::logicalAndExpr(
//       select_warp, createElectSyncExpr());
// }

// // Warp Specialized ParallelType is ThreadIdx.x and it contains less than
// // 32 threads, so manually select first thread in warp.
// Val* thread_id =
//     SimplifyingIrBuilder::modExpr(async_warp_thread_index, warp_size);
// Val* select_thread = SimplifyingIrBuilder::eqExpr(thread_id, zero);
// return SimplifyingIrBuilder::logicalAndExpr(select_warp, select_thread);
```
style: Remove commented-out legacy code
Force-pushed from 0acf1bd to ded2693
!test
!test
!test
!test
```cpp
// Compute warp_id = tid / 32
Val* warp_size = IrBuilder::create<Val>(32L, DataType::Index);
Val* warp_id = SimplifyingIrBuilder::divExpr(tid, warp_size);
```
It seems like nothing guarantees this will be warp-uniform: the compiler cannot know the block size, so unless TIDz and TIDy are both >1 it won't know that tid is the linear thread ID. So do we need to do a warp broadcast? See #2323.
yes, we need something like

```cuda
// __shfl_sync helps PTXAS prove that every thread in the warp has the same
// uniform warp id.
__device__ __forceinline__ uint32_t getUniformWarpId() {
  const unsigned int tid = threadIdx.x + threadIdx.y * blockDim.x +
      threadIdx.z * blockDim.x * blockDim.y;
  const unsigned int warp_id = tid / 32;
  return __shfl_sync(0xFFFFFFFF, warp_id, 0);
}
```
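Note that the broadcast is semantically a no-op: warps are formed from consecutive linear thread IDs, so every lane of a warp already computes the same tid / 32, and taking lane 0's copy changes nothing at runtime. Its only purpose is to make that uniformity structurally visible to PTXAS, which otherwise cannot rule out a divergent warp_id without knowing the block dimensions.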
This PR is not ready yet.
Test:

```shell
NVFUSER_DUMP=ptx,sass,scheduler_params,launch_param,cuda_to_file ./test_nvfuser --gtest_filter=*ThunderRMSNormBwd*bfloat*16384* 2>&1 |tee new.log
```

SASS Code Comparison: old.log vs new3.log
Summary
Use `UniformWarpId()` instead of `threadIdx.x/y/z` for warp specialization predicates.
Key Differences
1. Warp ID Calculation
OLD (old.log): uses `threadIdx.x` directly (stored in R5).
NEW (new3.log): computes `R0 >> 5` (equivalent to `tid / 32`), then uses a `SHFL.IDX` instruction to ensure all threads in the warp have the same value.
2. Predicate Comparison
OLD (old.log): `threadIdx.x >= 256` (0x100)
NEW (new3.log): `UniformWarpId >= 8` (0x8); with 256 compute threads, 256 / 32 = 8 compute warps, so the two predicates draw the same boundary.
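As a concrete sketch of the equivalence (assuming a 1D block whose first 256 threads are compute threads, and reusing the `getUniformWarpId()` helper from the review discussion above):

```cuda
// Sketch: both predicate forms select exactly the same (async) warps under
// the 1D-block assumption above.
__device__ bool isAsyncOld() { return threadIdx.x >= 256; }       // thread-index based
__device__ bool isAsyncNew() { return getUniformWarpId() >= 8; }  // 256 / 32 == 8
```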
3. Code Size
`.L_x_42` / `.L_x_44`
4. Register Usage
Both versions use similar register allocation patterns, but:
5. Thread Index Reads
OLD (old.log):
NEW (new3.log): computes the full linear thread index:
`tid = threadIdx.z * blockDim.y * blockDim.x + threadIdx.y * blockDim.x + threadIdx.x`
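As a quick check of the formula: with blockDim = (128, 4, 1), the thread at (threadIdx.x = 32, threadIdx.y = 2, threadIdx.z = 0) gets tid = 0 * 4 * 128 + 2 * 128 + 32 = 288, which lands in warp 288 / 32 = 9.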
Benefits of the New Approach
The `SHFL.IDX` instruction explicitly proves to PTXAS that all threads in a warp have the same warp ID value.
Performance Impact
🚀 Instruction Count Improvements (MAJOR WINS!)
⚡ WARPSYNC Instruction Reduction
Where WARPSYNC was eliminated:
OLD version had extra WARPSYNC at:
NEW version simplified to:
💡 Why This Happens
The `SHFL.IDX` instruction explicitly proves to PTXAS that all threads in a warp have the same warp ID. This allows the compiler to:
- Eliminate VOTEU.ALL instructions (-8 instructions, -89%): when PTXAS knows the warp ID is uniform, it can drop the expensive warp-level ballot sequence (`ISETP` → `VOTEU.ALL` → check result) used to verify uniformity.
- Eliminate redundant WARPSYNC (-2 instructions, -40%): when PTXAS knows values are uniform, it doesn't need as many WARPSYNC barriers before ELECT operations.
- Simplify predicates (-5 ISETP instructions): fewer comparison operations are needed once uniformity is established early.
- Optimize control flow: better branch prediction and divergence handling.
📊 Overall Performance Impact
Cost vs Benefit Analysis:
- Added overhead: `SHF.R.U32.HI` + `SHFL.IDX` to compute and broadcast the uniform warp ID.
- Removed expensive instructions: VOTEU.ALL (-8), WARPSYNC (-2), and ISETP (-5), as detailed above.
Net Performance Gain: roughly 2-3 cycles of added overhead traded for 100-200+ cycles saved.
🎯 Summary
This is a clear win! Trading 2-3 cycles for 100-200+ cycles saved, with the biggest gains coming from the eliminated VOTEU.ALL and WARPSYNC instructions.