From 76592f5ed4e77746c2a8ed9430755ae9faaff20e Mon Sep 17 00:00:00 2001 From: Liqiang Lu Date: Thu, 15 Jan 2026 16:01:52 -0800 Subject: [PATCH] use unpadded compute threads --- csrc/parallel_dimension_map.cpp | 18 +++++++++--------- csrc/parallel_dimension_map.h | 2 +- csrc/runtime/executor_params.h | 10 +++++++++- csrc/scheduler/normalization_inner_tma.cpp | 9 +++++++-- 4 files changed, 26 insertions(+), 13 deletions(-) diff --git a/csrc/parallel_dimension_map.cpp b/csrc/parallel_dimension_map.cpp index b8f0b24f114..abe9f001f2e 100644 --- a/csrc/parallel_dimension_map.cpp +++ b/csrc/parallel_dimension_map.cpp @@ -152,7 +152,7 @@ void ParallelDimensionMap::adjustMappingsForWarpPadding() { exact_types_.erase(ParallelType::TIDx); } -int64_t ParallelDimensionMap::getThreadCountInDim(ParallelType pt) { +int64_t ParallelDimensionMap::getStaticComputeThreadsInDim(ParallelType pt) { if (!dim_map_.contains(pt)) { return 1; } @@ -163,12 +163,12 @@ int64_t ParallelDimensionMap::getThreadCountInDim(ParallelType pt) { // use the actual compile parameter value NVF_ERROR(GpuLower::hasCurrent()); const auto& cparams = GpuLower::current()->compileParams(); - if (pt == ParallelType::TIDx && cparams.bdimx.has_value()) { - return cparams.bdimx.value(); - } else if (pt == ParallelType::TIDy && cparams.bdimy.has_value()) { - return cparams.bdimy.value(); - } else if (pt == ParallelType::TIDz && cparams.bdimz.has_value()) { - return cparams.bdimz.value(); + if (pt == ParallelType::TIDx && cparams.compute_bdimx.has_value()) { + return cparams.compute_bdimx.value(); + } else if (pt == ParallelType::TIDy && cparams.compute_bdimy.has_value()) { + return cparams.compute_bdimy.value(); + } else if (pt == ParallelType::TIDz && cparams.compute_bdimz.has_value()) { + return cparams.compute_bdimz.value(); } // Return -1 for dynamic dimensions when compile-time CTA shape is not known, // this disables register sharing on dynamic dimensions since we can't @@ -192,7 +192,7 @@ void 
ParallelDimensionMap::adjustMappingsForWarpSpecialization() { if (pt == ws_pt) { continue; } - int64_t thread_count_for_pt = getThreadCountInDim(pt); + int64_t thread_count_for_pt = getStaticComputeThreadsInDim(pt); NVF_ERROR( thread_count_for_pt != -1, "Detected dynamic size for parallel type ", @@ -208,7 +208,7 @@ void ParallelDimensionMap::adjustMappingsForWarpSpecialization() { "The # active threads in other thread dimensions is not evenly ", "divisible with 128 threads."); int64_t ws_num_threads_pad = 128 / other_active_pts_threads; - int64_t after_pad = getThreadCountInDim(ws_pt) + ws_num_threads_pad; + int64_t after_pad = getStaticComputeThreadsInDim(ws_pt) + ws_num_threads_pad; NVF_ERROR( (after_pad * other_active_pts_threads) % 128 == 0, "Illegal register sharing on ", diff --git a/csrc/parallel_dimension_map.h b/csrc/parallel_dimension_map.h index 00f4165e298..db7a32360df 100644 --- a/csrc/parallel_dimension_map.h +++ b/csrc/parallel_dimension_map.h @@ -99,7 +99,7 @@ class ParallelDimensionMap { private: //! Get number of threads for ParallelType axis //! Not used: 1, Const: n, Dynamic: -1 - int64_t getThreadCountInDim(ParallelType pt); + int64_t getStaticComputeThreadsInDim(ParallelType pt); //! TIDx may need to be marked as non-exact as it may be padded to a //! multiple of the warp size. 
diff --git a/csrc/runtime/executor_params.h b/csrc/runtime/executor_params.h index 6c5338c105a..983b042cf41 100644 --- a/csrc/runtime/executor_params.h +++ b/csrc/runtime/executor_params.h @@ -34,6 +34,11 @@ struct CompileParams { std::optional<int64_t> bdimy = std::nullopt; std::optional<int64_t> bdimz = std::nullopt; + // Threads used for computation, excluding warp specialization padding + std::optional<int64_t> compute_bdimx = std::nullopt; + std::optional<int64_t> compute_bdimy = std::nullopt; + std::optional<int64_t> compute_bdimz = std::nullopt; + bool operator==(const CompileParams& other) const { // Disallow comparison if the index type is nullopt NVF_ERROR( @@ -46,7 +51,10 @@ struct CompileParams { maxrregcount == other.maxrregcount && enable_magic_zero == other.enable_magic_zero && device == other.device && include_paths == other.include_paths && - bdimx == other.bdimx && bdimy == other.bdimy && bdimz == other.bdimz; + bdimx == other.bdimx && bdimy == other.bdimy && bdimz == other.bdimz && + compute_bdimx == other.compute_bdimx && + compute_bdimy == other.compute_bdimy && + compute_bdimz == other.compute_bdimz; } bool operator!=(const CompileParams& other) const { diff --git a/csrc/scheduler/normalization_inner_tma.cpp b/csrc/scheduler/normalization_inner_tma.cpp index 6dc8a778038..cde281a1804 100644 --- a/csrc/scheduler/normalization_inner_tma.cpp +++ b/csrc/scheduler/normalization_inner_tma.cpp @@ -85,7 +85,7 @@ std::unique_ptr<ReductionParams> getInnerPersistentHeuristics( bdimx = ceilDiv(after_vect, params->persistent_batch_size); bdimx = bdimx % warp_size == 0 ?
bdimx : bdimx + warp_size - bdimx % warp_size; - + int64_t compute_bdimx = bdimx, compute_bdimy = 1, compute_bdimz = 1; // set warp specialized circular buffer options // don't use warp specialized if the total iteration count is too small // TODO: heuristic tuning determine when to use warp specialized version @@ -105,6 +105,8 @@ std::unique_ptr<ReductionParams> getInnerPersistentHeuristics( gdimx = sm_count; bdimy = n_compute_warp_groups; bdimz = 1; // warp specialized kernel requires static CTA shape + compute_bdimy = bdimy; + compute_bdimz = bdimz; params->n_grouped_rows = n_rows_per_compute_warp_group; ParallelType ws_pt = bdimy > 1 ? ParallelType::TIDy : ParallelType::TIDx; WarpSpecialized ws(ws_pt); @@ -123,7 +125,7 @@ std::unique_ptr<ReductionParams> getInnerPersistentHeuristics( if (total_threads > 256) { int64_t reg_per_thread = getRegPerThreadGivenThreadsPerSM(total_threads); int64_t computation_threads = - total_threads - kWarpSpecializationPaddedThreads; + compute_bdimx * compute_bdimy * compute_bdimz; ws.num_registers = scheduler_utils::getRegisterSharing( reg_per_thread, computation_threads, @@ -146,6 +148,9 @@ std::unique_ptr<ReductionParams> getInnerPersistentHeuristics( params->cparams.bdimx = bdimx; params->cparams.bdimy = bdimy; params->cparams.bdimz = bdimz; + params->cparams.compute_bdimx = compute_bdimx; + params->cparams.compute_bdimy = compute_bdimy; + params->cparams.compute_bdimz = compute_bdimz; } // Set index type