18 changes: 9 additions & 9 deletions csrc/parallel_dimension_map.cpp
@@ -152,7 +152,7 @@ void ParallelDimensionMap::adjustMappingsForWarpPadding() {
   exact_types_.erase(ParallelType::TIDx);
 }
 
-int64_t ParallelDimensionMap::getThreadCountInDim(ParallelType pt) {
+int64_t ParallelDimensionMap::getStaticComputeThreadsInDim(ParallelType pt) {
   if (!dim_map_.contains(pt)) {
     return 1;
   }
@@ -163,12 +163,12 @@ int64_t ParallelDimensionMap::getThreadCountInDim(ParallelType pt) {
   // use the actual compile parameter value
   NVF_ERROR(GpuLower::hasCurrent());
   const auto& cparams = GpuLower::current()->compileParams();
-  if (pt == ParallelType::TIDx && cparams.bdimx.has_value()) {
-    return cparams.bdimx.value();
-  } else if (pt == ParallelType::TIDy && cparams.bdimy.has_value()) {
-    return cparams.bdimy.value();
-  } else if (pt == ParallelType::TIDz && cparams.bdimz.has_value()) {
-    return cparams.bdimz.value();
+  if (pt == ParallelType::TIDx && cparams.compute_bdimx.has_value()) {
+    return cparams.compute_bdimx.value();
+  } else if (pt == ParallelType::TIDy && cparams.compute_bdimy.has_value()) {
+    return cparams.compute_bdimy.value();
+  } else if (pt == ParallelType::TIDz && cparams.compute_bdimz.has_value()) {
+    return cparams.compute_bdimz.value();
   }
   // Return -1 for dynamic dimensions when compile-time CTA shape is not known,
   // this disables register sharing on dynamic dimensions since we can't
@@ -192,7 +192,7 @@ void ParallelDimensionMap::adjustMappingsForWarpSpecialization() {
     if (pt == ws_pt) {
       continue;
     }
-    int64_t thread_count_for_pt = getThreadCountInDim(pt);
+    int64_t thread_count_for_pt = getStaticComputeThreadsInDim(pt);
     NVF_ERROR(
         thread_count_for_pt != -1,
         "Detected dynamic size for parallel type ",
@@ -208,7 +208,7 @@ void ParallelDimensionMap::adjustMappingsForWarpSpecialization() {
       "The # active threads in other thread dimensions is not evenly ",
       "divisible with 128 threads.");
   int64_t ws_num_threads_pad = 128 / other_active_pts_threads;
-  int64_t after_pad = getThreadCountInDim(ws_pt) + ws_num_threads_pad;
+  int64_t after_pad = getStaticComputeThreadsInDim(ws_pt) + ws_num_threads_pad;
   NVF_ERROR(
       (after_pad * other_active_pts_threads) % 128 == 0,
       "Illegal register sharing on ",
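A minimal standalone sketch of the padding arithmetic the checks above enforce, with illustrative numbers rather than anything taken from this PR: TIDy is warp specialized, the compute CTA shape is 128 x 2 x 1, so the other active dimensions contribute 128 threads and the warp-specialized dimension is padded by one slice to add a full 128-thread group.

// Standalone sketch (not nvFuser code); values below are illustrative.
#include <cassert>
#include <cstdint>
#include <iostream>

int main() {
  const int64_t other_active_pts_threads = 128; // TIDx * TIDz compute threads
  assert(128 % other_active_pts_threads == 0);  // must divide 128 evenly

  // Pad the warp-specialized dimension so the padded region forms one full
  // 128-thread group (e.g. for the load/TMA branch).
  const int64_t ws_num_threads_pad = 128 / other_active_pts_threads;        // 1
  const int64_t compute_threads_in_ws_dim = 2;                              // TIDy
  const int64_t after_pad = compute_threads_in_ws_dim + ws_num_threads_pad; // 3

  // The padded CTA must remain a multiple of 128 threads.
  assert((after_pad * other_active_pts_threads) % 128 == 0); // 3 * 128 = 384
  std::cout << "padded CTA threads: " << after_pad * other_active_pts_threads
            << "\n";
}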
2 changes: 1 addition & 1 deletion csrc/parallel_dimension_map.h
@@ -99,7 +99,7 @@ class ParallelDimensionMap {
  private:
   //! Get number of threads for ParallelType axis
   //! Not used: 1, Const: n, Dynamic: -1
-  int64_t getThreadCountInDim(ParallelType pt);
+  int64_t getStaticComputeThreadsInDim(ParallelType pt);
 
   //! TIDx may need to be marked as non-exact as it may be padded to a
   //! multiple of the warp size.
10 changes: 9 additions & 1 deletion csrc/runtime/executor_params.h
@@ -34,6 +34,11 @@ struct CompileParams {
   std::optional<int64_t> bdimy = std::nullopt;
   std::optional<int64_t> bdimz = std::nullopt;
 
+  // Threads used for computation, excluding warp specialization padding
+  std::optional<int64_t> compute_bdimx = std::nullopt;
+  std::optional<int64_t> compute_bdimy = std::nullopt;
+  std::optional<int64_t> compute_bdimz = std::nullopt;
+
   bool operator==(const CompileParams& other) const {
     // Disallow comparison if the index type is nullopt
     NVF_ERROR(
@@ -46,7 +51,10 @@ struct CompileParams {
         maxrregcount == other.maxrregcount &&
         enable_magic_zero == other.enable_magic_zero &&
         device == other.device && include_paths == other.include_paths &&
-        bdimx == other.bdimx && bdimy == other.bdimy && bdimz == other.bdimz;
+        bdimx == other.bdimx && bdimy == other.bdimy && bdimz == other.bdimz &&
+        compute_bdimx == other.compute_bdimx &&
+        compute_bdimy == other.compute_bdimy &&
+        compute_bdimz == other.compute_bdimz;
   }
 
   bool operator!=(const CompileParams& other) const {
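A minimal mock (not the real CompileParams, which also carries the index type, register count, and so on) illustrating what the operator== extension above changes: two parameter sets with identical launch bounds but different compute thread counts now compare unequal.

// Mock for illustration only; field names mirror the diff above.
#include <cassert>
#include <cstdint>
#include <optional>

struct MockCompileParams {
  std::optional<int64_t> bdimx, bdimy, bdimz;
  std::optional<int64_t> compute_bdimx, compute_bdimy, compute_bdimz;

  bool operator==(const MockCompileParams& o) const {
    return bdimx == o.bdimx && bdimy == o.bdimy && bdimz == o.bdimz &&
        compute_bdimx == o.compute_bdimx && compute_bdimy == o.compute_bdimy &&
        compute_bdimz == o.compute_bdimz;
  }
};

int main() {
  MockCompileParams a{128, 3, 1, 128, 2, 1}; // padded bdimy=3, compute bdimy=2
  MockCompileParams b{128, 3, 1, 128, 3, 1}; // same launch bounds, no padding
  assert(!(a == b)); // distinguished only by the new compute_bdim* fields
}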
9 changes: 7 additions & 2 deletions csrc/scheduler/normalization_inner_tma.cpp
@@ -85,7 +85,7 @@ std::unique_ptr<InnerNormTmaParams> getInnerPersistentHeuristics(
   bdimx = ceilDiv(after_vect, params->persistent_batch_size);
   bdimx =
       bdimx % warp_size == 0 ? bdimx : bdimx + warp_size - bdimx % warp_size;
-
+  int64_t compute_bdimx = bdimx, compute_bdimy = 1, compute_bdimz = 1;
   // set warp specialized circular buffer options
   // don't use warp specialized if the total iteration count is too small
   // TODO: heuristic tuning determine when to use warp specialized version
@@ -105,6 +105,8 @@ std::unique_ptr<InnerNormTmaParams> getInnerPersistentHeuristics(
     gdimx = sm_count;
     bdimy = n_compute_warp_groups;
     bdimz = 1; // warp specialized kernel requires static CTA shape
+    compute_bdimy = bdimy;
+    compute_bdimz = bdimz;
     params->n_grouped_rows = n_rows_per_compute_warp_group;
     ParallelType ws_pt = bdimy > 1 ? ParallelType::TIDy : ParallelType::TIDx;
     WarpSpecialized ws(ws_pt);
@@ -123,7 +125,7 @@ std::unique_ptr<InnerNormTmaParams> getInnerPersistentHeuristics(
     if (total_threads > 256) {
       int64_t reg_per_thread = getRegPerThreadGivenThreadsPerSM(total_threads);
       int64_t computation_threads =
-          total_threads - kWarpSpecializationPaddedThreads;
+          compute_bdimx * compute_bdimy * compute_bdimz;
       ws.num_registers = scheduler_utils::getRegisterSharing(
           reg_per_thread,
           computation_threads,
@@ -146,6 +148,9 @@ std::unique_ptr<InnerNormTmaParams> getInnerPersistentHeuristics(
     params->cparams.bdimx = bdimx;
     params->cparams.bdimy = bdimy;
     params->cparams.bdimz = bdimz;
+    params->cparams.compute_bdimx = compute_bdimx;
+    params->cparams.compute_bdimy = compute_bdimy;
+    params->cparams.compute_bdimz = compute_bdimz;
   }
 
   // Set index type
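Illustrative arithmetic, not nvFuser code, for the computation_threads change above: with the kind of shape this heuristic produces, the old expression (launched threads minus a fixed pad) and the new one (product of the compute CTA dimensions) give the same count; the numbers and the 128 value assumed for kWarpSpecializationPaddedThreads are assumptions for the example, not taken from the PR.

// Illustrative sketch: two compute warp groups of 128 threads each, plus one
// 128-thread padded group for the load branch, so 384 threads are launched.
#include <cassert>
#include <cstdint>
#include <iostream>

int main() {
  const int64_t total_threads = 384; // launched threads, including the pad
  const int64_t compute_bdimx = 128, compute_bdimy = 2, compute_bdimz = 1;
  const int64_t kWarpSpecializationPaddedThreads = 128; // assumed value

  // Old expression: subtract the assumed fixed pad from the launched total.
  const int64_t old_computation_threads =
      total_threads - kWarpSpecializationPaddedThreads; // 256
  // New expression: use the compute CTA shape recorded in CompileParams.
  const int64_t new_computation_threads =
      compute_bdimx * compute_bdimy * compute_bdimz; // 256

  assert(old_computation_threads == new_computation_threads);
  std::cout << "computation threads: " << new_computation_threads << "\n";
}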