From 13876d0521e20f84637dd34e0cb2a80b934c00c9 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 7 Jan 2026 14:19:47 -0800 Subject: [PATCH 01/13] PR0: Adding runtime function for GroupedBlockQuantizeOp 1. refactor existing block_layout op and block_quantization_kernel to re-use existing runtime functions; 2. added runtime function for GroupedBlockQuantizeOp --- CMakeLists.txt | 1 - csrc/codegen.cpp | 2 +- csrc/runtime/compiled_kernel.cpp | 5 +- runtime/block_layout.cu | 102 --------- runtime/block_quantization_kernels.cu | 302 +++++++++++++++++++------- 5 files changed, 231 insertions(+), 181 deletions(-) delete mode 100644 runtime/block_layout.cu diff --git a/CMakeLists.txt b/CMakeLists.txt index 61006e83cbf..4c0d36c46ce 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1436,7 +1436,6 @@ list(APPEND NVFUSER_RUNTIME_FILES ${NVFUSER_ROOT}/runtime/block_sync_atomic.cu ${NVFUSER_ROOT}/runtime/block_sync_default.cu ${NVFUSER_ROOT}/runtime/block_welford_outer.cu - ${NVFUSER_ROOT}/runtime/block_layout.cu ${NVFUSER_ROOT}/runtime/block_quantization_kernels.cu ${NVFUSER_ROOT}/runtime/broadcast.cu ${NVFUSER_ROOT}/runtime/casts.cu diff --git a/csrc/codegen.cpp b/csrc/codegen.cpp index 810837c39e8..e12ed646540 100644 --- a/csrc/codegen.cpp +++ b/csrc/codegen.cpp @@ -4677,7 +4677,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { func_args.arg(genInline(layout_op->g())); indent() << genCall( - "block_layout::preprocessGroupedMatmulInputSf", + "bq::preprocessGroupedMatmulInputSf", template_args, func_args) << ";\n"; diff --git a/csrc/runtime/compiled_kernel.cpp b/csrc/runtime/compiled_kernel.cpp index bc9eefebcc1..4ce528a510e 100644 --- a/csrc/runtime/compiled_kernel.cpp +++ b/csrc/runtime/compiled_kernel.cpp @@ -1086,11 +1086,8 @@ std::string _getStructuredCode( if (has_topk) { code += nvfuser_resources::topk_cu; } - if (has_block_layout) { - code += nvfuser_resources::block_layout_cu; - } - if (has_block_quantize_op) { + if (has_block_layout || has_block_quantize_op) { code += nvfuser_resources::block_quantization_kernels_cu; } diff --git a/runtime/block_layout.cu b/runtime/block_layout.cu deleted file mode 100644 index 22935dcb57a..00000000000 --- a/runtime/block_layout.cu +++ /dev/null @@ -1,102 +0,0 @@ -// clang-format off -/* - * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. - * All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - */ -// clang-format on - -namespace nvf::block_layout { - -namespace { - -// TODO: support vectorized store -template -__device__ nvfuser_index_t outputOffsetAfterSwizzlePadding( - const nvfuser_index_t row_idx, - const nvfuser_index_t col_idx, - const nvfuser_index_t padded_col_size) { - constexpr nvfuser_index_t BLOCK_ROW_SIZE = BLOCK_ROW_OUTER * BLOCK_ROW_INNER; - - /* logical dimension of matrix [ row_size, col_size] - * - * while logical domain after padding can be viewed as - * [ (row_tile*BLOCK_ROW_INNER*BLOCK_ROW_OUTER), (col_tile*BLOCK_COL) ] - * where - * row_tile = ceilDiv(row_size / BLOCK_ROW_OUTER * BLOCK_ROW_INNER) - * col_tile = ceilDiv(col_size / BLOCK_COL) - */ - - // we first convert `row_idx` and `col_idx` to the logical index on the 5d - // tensor. - nvfuser_index_t row_tile_idx = row_idx / BLOCK_ROW_SIZE; - nvfuser_index_t row_block_idx = row_idx % BLOCK_ROW_SIZE; - nvfuser_index_t row_block_inner_idx = row_block_idx / BLOCK_ROW_OUTER; - nvfuser_index_t row_block_outer_idx = row_block_idx % BLOCK_ROW_OUTER; - nvfuser_index_t col_tile_idx = col_idx / BLOCK_COL; - nvfuser_index_t col_block_idx = col_idx % BLOCK_COL; - - /* layout for matrix [ row_size, col_size] - * it is viewed - * [row_tile, BLOCK_ROW_INNER, BLOCK_ROW_OUTER, col_tile, BLOCK_COL] - * then transposed with axis (1, 3) - * [row_tile, col_tile, BLOCK_ROW_OUTER, BLOCK_ROW_INNER, BLOCK_COL] - * and then made contiguous - * So we can compute the corresponding stride for each dimension - */ - constexpr nvfuser_index_t COL_TILE_STRIDE = BLOCK_ROW_SIZE * BLOCK_COL; - constexpr nvfuser_index_t BLOCK_ROW_OUTER_STRIDE = - BLOCK_ROW_INNER * BLOCK_COL; - constexpr nvfuser_index_t BLOCK_ROW_INNER_STRIDE = BLOCK_COL; - - return row_tile_idx * padded_col_size * BLOCK_ROW_SIZE + - col_tile_idx * COL_TILE_STRIDE + - row_block_outer_idx * BLOCK_ROW_OUTER_STRIDE + - row_block_inner_idx * BLOCK_ROW_INNER_STRIDE + col_block_idx; -} - -} // namespace - -template < - typename T, - typename Index_T, - int BLOCK_ROW_OUTER, - int BLOCK_ROW_INNER, - int BLOCK_COL, - int UNROLL_FACTOR> -__device__ void preprocessGroupedMatmulInputSf( - T* output, - const T* input, - const nvfuser_index_t row_idx, - const nvfuser_index_t col_idx, - const Index_T* input_offsets, - const Index_T* output_offsets, - const nvfuser_index_t col_size, - const nvfuser_index_t group_size) { - // find corresponding expert_id - int expert_id = group_size - 1; - for (int i = 1; i < group_size; ++i) { - if (row_idx < input_offsets[i]) { - expert_id = i - 1; - break; - } - } - - // row idx for current group - nvfuser_index_t c_row_idx = row_idx - input_offsets[expert_id]; - // compute output group offset for current group - nvfuser_index_t padded_col_size = - (col_size + BLOCK_COL - 1) / BLOCK_COL * BLOCK_COL; - T* out_group_offset = output + output_offsets[expert_id] * padded_col_size; - - // TODO: vectorized load/store instead of for loop - for (int i = 0; i < UNROLL_FACTOR && col_idx + i < col_size; ++i) { - nvfuser_index_t index = outputOffsetAfterSwizzlePadding< - BLOCK_ROW_OUTER, - BLOCK_ROW_INNER, - BLOCK_COL>(c_row_idx, col_idx + i, padded_col_size); - out_group_offset[index] = input[i]; - } -} - -} // namespace nvf::block_layout diff --git a/runtime/block_quantization_kernels.cu b/runtime/block_quantization_kernels.cu index c950eeb31cb..a7db958ba5f 100644 --- a/runtime/block_quantization_kernels.cu +++ b/runtime/block_quantization_kernels.cu @@ -9,6 +9,55 @@ namespace nvf { namespace bq { +namespace { + +// TODO: support vectorized store +template +__device__ nvfuser_index_t outputOffsetAfterSwizzlePadding( + const nvfuser_index_t row_idx, + const nvfuser_index_t col_idx, + const nvfuser_index_t padded_col_size) { + constexpr nvfuser_index_t BLOCK_ROW_SIZE = BLOCK_ROW_OUTER * BLOCK_ROW_INNER; + + /* logical dimension of matrix [ row_size, col_size] + * + * while logical domain after padding can be viewed as + * [ (row_tile*BLOCK_ROW_INNER*BLOCK_ROW_OUTER), (col_tile*BLOCK_COL) ] + * where + * row_tile = ceilDiv(row_size / BLOCK_ROW_OUTER * BLOCK_ROW_INNER) + * col_tile = ceilDiv(col_size / BLOCK_COL) + */ + + // we first convert `row_idx` and `col_idx` to the logical index on the 5d + // tensor. + nvfuser_index_t row_tile_idx = row_idx / BLOCK_ROW_SIZE; + nvfuser_index_t row_block_idx = row_idx % BLOCK_ROW_SIZE; + nvfuser_index_t row_block_inner_idx = row_block_idx / BLOCK_ROW_OUTER; + nvfuser_index_t row_block_outer_idx = row_block_idx % BLOCK_ROW_OUTER; + nvfuser_index_t col_tile_idx = col_idx / BLOCK_COL; + nvfuser_index_t col_block_idx = col_idx % BLOCK_COL; + + /* layout for matrix [ row_size, col_size] + * it is viewed + * [row_tile, BLOCK_ROW_INNER, BLOCK_ROW_OUTER, col_tile, BLOCK_COL] + * then transposed with axis (1, 3) + * [row_tile, col_tile, BLOCK_ROW_OUTER, BLOCK_ROW_INNER, BLOCK_COL] + * and then made contiguous + * So we can compute the corresponding stride for each dimension + */ + constexpr nvfuser_index_t COL_TILE_STRIDE = BLOCK_ROW_SIZE * BLOCK_COL; + constexpr nvfuser_index_t BLOCK_ROW_OUTER_STRIDE = + BLOCK_ROW_INNER * BLOCK_COL; + constexpr nvfuser_index_t BLOCK_ROW_INNER_STRIDE = BLOCK_COL; + + return row_tile_idx * padded_col_size * BLOCK_ROW_SIZE + + col_tile_idx * COL_TILE_STRIDE + + row_block_outer_idx * BLOCK_ROW_OUTER_STRIDE + + row_block_inner_idx * BLOCK_ROW_INNER_STRIDE + col_block_idx; +} + +} // namespace + // This helper function finds the max of NUM_ELEMENTS (2, 4, or 8) values // using the same number of threads. template @@ -75,6 +124,63 @@ __device__ __inline__ void convertToFloatAndComputeLocalMax( } } +// Fast reciprocal of 2^biased_exp using bit manipulation +// Returns 1.0 for biased_exp==0, otherwise returns 2^(-biased_exp) +constexpr uint32_t FP32_MANTISSA_BITS = 23; +__device__ __forceinline__ float exp2f_rcp(uint8_t biased_exp) { + return (biased_exp == 0) + ? 1 + : __int_as_float( + (254 - biased_exp) + << FP32_MANTISSA_BITS); // 127 - (biased_exp - 127) +} + +template < + int ITEMS_PER_THREAD, + typename T, + int ALIGNMENT_1, + int ALIGNMENT_2, + int BLOCK_SCALE_DIM, + int BLOCK_SCALE_ALLOC> +__device__ void block_quantize_to_mxfp8( + const Array& input, + Array<__e4m3, ITEMS_PER_THREAD, ALIGNMENT_2>& output, + Tensor<__e8m0, BLOCK_SCALE_DIM, BLOCK_SCALE_ALLOC>& block_scales, + nvfuser_index_t logical_index) { + // Number of threads involved in computing one block scaling factor + constexpr int THREADS_PER_SCALING_FACTOR = 32 / ITEMS_PER_THREAD; + + Array vec_in; + float local_max; + convertToFloatAndComputeLocalMax( + input, vec_in, local_max); + + // Compute the max accross 32/ITEMS_PER_THREAD threads + // This assumes each thread has already computed is local max of 2, 4 (fp32) + // or 2,4, 8 (bf16/fp16) elements. + reduceAcrossThreads(local_max); + float block_max = local_max; + + static constexpr float max_norm_rcp = 1.0f / 448; + __e8m0 exponent = __float2e8m0(block_max * max_norm_rcp); + + // Write out the block scaling factor to global memory. + // This assumes block_size (32) elements in the input were contiguous. + // Only one block scaling factor is written out per 32(assumed block size) + // elements. + int offset = logical_index / 32; + if (threadIdx.x % THREADS_PER_SCALING_FACTOR == 0) { + block_scales[offset] = exponent; + } + + const float block_scale_inverse = exp2f_rcp(exponent.raw()); + +#pragma unroll + for (int i = 0; i < ITEMS_PER_THREAD; ++i) { + output[i] = __float2e4m3(vec_in[i] * block_scale_inverse); + } +} + // A runtime function to compute quantized nvfp4 output (output) and fp8 block // scaling (block_scales) factors from fp32, fp16, bf16 inputs (input). // The function is templatized over input type T (float, __half, __bfloat). @@ -89,18 +195,12 @@ template < int ALIGNMENT_2, int BLOCK_SCALE_DIM, int BLOCK_SCALE_ALLOC> -__device__ void block_quantize_to_nvfp4( +__device__ void block_quantize_to_nvfp4_util( const Array& input, Array<__e2m1, ITEMS_PER_THREAD, ALIGNMENT_2>& output, Tensor<__e4m3, BLOCK_SCALE_DIM, BLOCK_SCALE_ALLOC>& block_scales, - nvfuser_index_t logical_index, - Tensor global_scale, - int64_t fp8_scaling_factors_inner_dim = -1, - int64_t alloc_dim0 = -1, - int64_t alloc_dim1 = -1, - int64_t alloc_dim2 = -1, - int64_t alloc_dim3 = -1, - int64_t alloc_dim4 = -1) { + Tensor& global_scale, + int64_t offset) { // Number of threads involved in computing one block scaling factor constexpr int THREADS_PER_SCALING_FACTOR = 16 / ITEMS_PER_THREAD; @@ -135,6 +235,49 @@ __device__ void block_quantize_to_nvfp4( scaled_max = fminf(1.0f / scaled_max, float_max); } + if (threadIdx.x % THREADS_PER_SCALING_FACTOR == 0) { + block_scales[offset] = clamped_max_fp8; + } + + Array scaled_vals; +#pragma unroll + for (int i = 0; i < ITEMS_PER_THREAD; ++i) { + scaled_vals[i] = vec_in[i] * scaled_max; + } + + Array<__e2m1, ITEMS_PER_THREAD, 1> fp4_vals; + *reinterpret_cast*>( + &fp4_vals[0]) = + __float2e2m1( + *reinterpret_cast*>( + &scaled_vals[0])); + +#pragma unroll + for (int i = 0; i < ITEMS_PER_THREAD; ++i) { + output[i] = fp4_vals[i]; + } +} + +template < + bool USE_GLOBAL_SCALE, + int ITEMS_PER_THREAD, + typename T, + int ALIGNMENT_1, + int ALIGNMENT_2, + int BLOCK_SCALE_DIM, + int BLOCK_SCALE_ALLOC> +__device__ void block_quantize_to_nvfp4( + const Array& input, + Array<__e2m1, ITEMS_PER_THREAD, ALIGNMENT_2>& output, + Tensor<__e4m3, BLOCK_SCALE_DIM, BLOCK_SCALE_ALLOC>& block_scales, + nvfuser_index_t logical_index, + Tensor global_scale, + int64_t fp8_scaling_factors_inner_dim = -1, + int64_t alloc_dim0 = -1, + int64_t alloc_dim1 = -1, + int64_t alloc_dim2 = -1, + int64_t alloc_dim3 = -1, + int64_t alloc_dim4 = -1) { // Write out the block scaling factor to global memory. // This assumes 16 elements in the input were contiguous. // Only one block scaling factor is written out per 16(assumed block size) @@ -167,85 +310,98 @@ __device__ void block_quantize_to_nvfp4( offset = pos_4 * stride_4 + pos_3 * stride_3 + pos_2 * stride_2 + pos_1 * stride_1 + pos_0 * stride_0; } + block_quantize_to_nvfp4_util(input, output, block_scales, global_scale, offset); +} - if (threadIdx.x % THREADS_PER_SCALING_FACTOR == 0) { - block_scales[offset] = clamped_max_fp8; - } - - Array scaled_vals; -#pragma unroll - for (int i = 0; i < ITEMS_PER_THREAD; ++i) { - scaled_vals[i] = vec_in[i] * scaled_max; +template < + typename T, + typename Index_T, + int BLOCK_ROW_OUTER, + int BLOCK_ROW_INNER, + int BLOCK_COL, + int UNROLL_FACTOR> +__device__ void preprocessGroupedMatmulInputSf( + T* output, + const T* input, + const nvfuser_index_t row_idx, + const nvfuser_index_t col_idx, + const Index_T* input_offsets, + const Index_T* output_offsets, + const nvfuser_index_t col_size, + const nvfuser_index_t group_size) { + // find corresponding expert_id + int expert_id = group_size - 1; + for (int i = 1; i < group_size; ++i) { + if (row_idx < input_offsets[i]) { + expert_id = i - 1; + break; + } } - Array<__e2m1, ITEMS_PER_THREAD, 1> fp4_vals; - *reinterpret_cast*>( - &fp4_vals[0]) = - __float2e2m1( - *reinterpret_cast*>( - &scaled_vals[0])); - -#pragma unroll - for (int i = 0; i < ITEMS_PER_THREAD; ++i) { - output[i] = fp4_vals[i]; + // row idx for current group + nvfuser_index_t c_row_idx = row_idx - input_offsets[expert_id]; + // compute output group offset for current group + nvfuser_index_t padded_col_size = + (col_size + BLOCK_COL - 1) / BLOCK_COL * BLOCK_COL; + T* out_group_offset = output + output_offsets[expert_id] * padded_col_size; + + // TODO: vectorized load/store instead of for loop + for (int i = 0; i < UNROLL_FACTOR && col_idx + i < col_size; ++i) { + nvfuser_index_t index = outputOffsetAfterSwizzlePadding< + BLOCK_ROW_OUTER, + BLOCK_ROW_INNER, + BLOCK_COL>(c_row_idx, col_idx + i, padded_col_size); + out_group_offset[index] = input[i]; } } -// Fast reciprocal of 2^biased_exp using bit manipulation -// Returns 1.0 for biased_exp==0, otherwise returns 2^(-biased_exp) -constexpr uint32_t FP32_MANTISSA_BITS = 23; -__device__ __forceinline__ float exp2f_rcp(uint8_t biased_exp) { - return (biased_exp == 0) - ? 1 - : __int_as_float( - (254 - biased_exp) - << FP32_MANTISSA_BITS); // 127 - (biased_exp - 127) -} - template < + bool USE_GLOBAL_SCALE, + int BLOCK_ROW_OUTER, + int BLOCK_ROW_INNER, + int BLOCK_COL, int ITEMS_PER_THREAD, typename T, + typename Index_T, int ALIGNMENT_1, int ALIGNMENT_2, int BLOCK_SCALE_DIM, int BLOCK_SCALE_ALLOC> -__device__ void block_quantize_to_mxfp8( +__device__ void grouped_block_quantize_to_nvfp4( const Array& input, - Array<__e4m3, ITEMS_PER_THREAD, ALIGNMENT_2>& output, - Tensor<__e8m0, BLOCK_SCALE_DIM, BLOCK_SCALE_ALLOC>& block_scales, - nvfuser_index_t logical_index) { - // Number of threads involved in computing one block scaling factor - constexpr int THREADS_PER_SCALING_FACTOR = 32 / ITEMS_PER_THREAD; - - Array vec_in; - float local_max; - convertToFloatAndComputeLocalMax( - input, vec_in, local_max); - - // Compute the max accross 32/ITEMS_PER_THREAD threads - // This assumes each thread has already computed is local max of 2, 4 (fp32) - // or 2,4, 8 (bf16/fp16) elements. - reduceAcrossThreads(local_max); - float block_max = local_max; - - static constexpr float max_norm_rcp = 1.0f / 448; - __e8m0 exponent = __float2e8m0(block_max * max_norm_rcp); - - // Write out the block scaling factor to global memory. - // This assumes block_size (32) elements in the input were contiguous. - // Only one block scaling factor is written out per 32(assumed block size) - // elements. - int offset = logical_index / 32; - if (threadIdx.x % THREADS_PER_SCALING_FACTOR == 0) { - block_scales[offset] = exponent; - } - - const float block_scale_inverse = exp2f_rcp(exponent.raw()); - -#pragma unroll - for (int i = 0; i < ITEMS_PER_THREAD; ++i) { - output[i] = __float2e4m3(vec_in[i] * block_scale_inverse); + Array<__e2m1, ITEMS_PER_THREAD, ALIGNMENT_2>& output, + Tensor<__e4m3, BLOCK_SCALE_DIM, BLOCK_SCALE_ALLOC>& block_scales, + const nvfuser_index_t row_idx, + const nvfuser_index_t col_idx, + const Index_T* input_offsets, + const Index_T* output_offsets, + const nvfuser_index_t col_size, + const nvfuser_index_t group_size, + Tensor global_scale) { + // find corresponding expert_id + int expert_id = group_size - 1; + for (int i = 1; i < group_size; ++i) { + if (row_idx < input_offsets[i]) { + expert_id = i - 1; + break; + } } + // NOTE: col_size and col_idx needs to be divided by block size for scaling factor + constexpr nvfuser_index_t BLOCK_SIZE = 16; + // row idx for current group + nvfuser_index_t c_row_idx = row_idx - input_offsets[expert_id]; + // compute output group offset for current group + nvfuser_index_t padded_col_size = + (col_size / BLOCK_SIZE + BLOCK_COL - 1) / BLOCK_COL * BLOCK_COL; + nvfuser_index_t out_group_offset = output_offsets[expert_id] * padded_col_size; + // compute the offset + nvfuser_index_t index = outputOffsetAfterSwizzlePadding< + BLOCK_ROW_OUTER, + BLOCK_ROW_INNER, + BLOCK_COL>(c_row_idx, col_idx / BLOCK_SIZE, padded_col_size); + nvfuser_index_t offset = out_group_offset + index; + + block_quantize_to_nvfp4_util(input, output, block_scales, global_scale, offset); } } // namespace bq From c2cf0236bf03e7c6574ec1d28a4b2caae34290d3 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 7 Jan 2026 14:44:50 -0800 Subject: [PATCH 02/13] PR1: adding codegen support for GroupedBlockQuantizationOp --- csrc/codegen.cpp | 110 ++++++++ .../analysis/non_divisible_split.cpp | 7 +- .../analysis/sync_information.cpp | 15 +- .../analysis/trivial_broadcast.cpp | 11 + .../device_lower/analysis/trivial_broadcast.h | 2 + csrc/device_lower/pass/index.cpp | 64 +++++ csrc/device_lower/pass/index.h | 1 + csrc/device_lower/utils.cpp | 1 + csrc/device_lower/validation.cpp | 238 +++++++++++++++++- csrc/dispatch.h | 1 + csrc/fusion_segmenter.cpp | 7 +- csrc/ir/composite_nodes.cpp | 58 +++++ csrc/ir/composite_nodes.h | 92 +++++++ csrc/ir/utils.cpp | 6 +- csrc/kernel.cpp | 4 + csrc/logical_domain_map.cpp | 38 ++- csrc/logical_domain_map.h | 4 + csrc/ops/arith.cpp | 141 +++++++++++ csrc/ops/arith.h | 9 + csrc/scheduler/pointwise.cpp | 24 +- csrc/scheduler/pointwise_non_tma.cpp | 9 +- csrc/scheduler/registry_utils.cpp | 8 + csrc/scheduler/tools/domain_map.cpp | 13 + csrc/scheduler/utils.cpp | 19 +- csrc/tensor_metadata.cpp | 6 + tests/cpp/test_layout_op.cpp | 70 +++++- 26 files changed, 930 insertions(+), 28 deletions(-) diff --git a/csrc/codegen.cpp b/csrc/codegen.cpp index e12ed646540..fcbd320a7d3 100644 --- a/csrc/codegen.cpp +++ b/csrc/codegen.cpp @@ -1898,6 +1898,116 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { indent() << genCall(fn_call, template_args, func_args) << ";\n"; } + // Special handling of GroupedBlockQuantizationOp to call the runtime + // function. + void handle(const GroupedBlockQuantizationOp* bqop) final { + // This operator is plumbed down to a runtime function call. + // One of the assumptions is that the device runtime expects + // n consecutive inputs per thread. Where n can be 2 or 4 for Float, and 2, + // 4, or 8 for Half. We achieve this by having the quantized output tv + // scheduled to have the inner dimension grouped by 2/4/8. + auto output = bqop->quantizedOutput()->as()->view(); + auto output_dtype = output->getDataType(); + + // Extract group size from the loop domain + int64_t group_size = 1; + const auto& loop_domain = output->getLoopDomain(); + for (const auto* domain : loop_domain) { + if (domain->getParallelType() == ParallelType::Group && + domain->extent()->isConstInt()) { + group_size = domain->extent()->evaluate().as(); + break; + } + } + + // Validate group size based on input data type + const auto input_dtype = + bqop->in()->as()->view()->getDataType().value(); + const bool is_half_precision = + (input_dtype == DataType::BFloat16 || input_dtype == DataType::Half); + const bool is_valid_group_size = is_half_precision + ? (group_size == 2 || group_size == 4 || group_size == 8) + : (group_size == 2 || group_size == 4); + + NVF_ERROR( + is_valid_group_size, + "Group size should be ", + is_half_precision ? "2, 4 or 8" : "2 or 4", + " for GroupedBlockQuantizationOp with input type ", + input_dtype, + ". Found: ", + group_size, + ". Expr: ", + bqop->toString()); + + // Build template arguments + ArgumentBuilder template_args; + // No global scale is required when quantizing to mxfp8 + if (output_dtype == DataType::Float4_e2m1fn) { + template_args.arg(bqop->hasGlobalScale()); + } + switch (bqop->layout()) { + case BlockScalingFactorLayout::Block128x4: + template_args.arg(32); // block_row_outer + template_args.arg(4); // block_row_inner + template_args.arg(4); // block_col + break; + default: + NVF_THROW("unrecognized layout"); + break; + } + template_args.arg(group_size); // ITEMS_PER_THREAD + + // Build function arguments + ArgumentBuilder func_args; + func_args.arg(genInline( + bqop->input(0)->as()->view())); // input data + func_args.arg(genInline(output)); // quantized output + func_args.arg(genInline( + bqop->blockScales()->as()->view())); // block scales + + // generate logical index for runtime function + func_args.arg(genInline(bqop->attributeVal(2))); + func_args.arg(genInline(bqop->attributeVal(3))); + func_args.arg("&").append(genVariableName(bqop->inputOffsets()) + "[0]"); + func_args.arg("&").append(genVariableName(bqop->outputOffsets()) + "[0]"); + func_args.arg(genInline(bqop->k())); + func_args.arg(genInline(bqop->g())); + + if (output_dtype == DataType::Float4_e2m1fn) { + func_args.arg( + bqop->hasGlobalScale() ? genInline(bqop->globalScale()) : "{}"); + } + + // Add swizzled allocation domain parameters if needed + // This is always skipped when quantizing to mxfp8 + auto block_scales_tv = bqop->blockScales()->as()->view(); + if (block_scales_tv->hasAllocation()) { + auto logical_domain = + TensorDomain::noReductions(block_scales_tv->getLogicalDomain()); + auto allocation_domain = + TensorDomain::noReductions(block_scales_tv->getAllocationDomain()); + + // Swizzled layout: 2D logical -> 5D allocation + if (logical_domain.size() == 2 && allocation_domain.size() == 5) { + // Add logical domain extent of the inner dimension + func_args.arg(genInline(logical_domain[1]->extent())); + + // Add all allocation domain extents + for (const auto* alloc_id : allocation_domain) { + func_args.arg(genInline(alloc_id->extent())); + } + } + } + + auto fn_call = output_dtype == DataType::Float4_e2m1fn + ? "bq::grouped_block_quantize_to_nvfp4" + : "bq::grouped_block_quantize_to_mxfp8"; + + // Generate the function call + indent() << genCall(fn_call, template_args, func_args) << ";\n"; + } + std::string genReductionOp(BinaryOpType op_type, DataType data_type) { std::stringstream lambda; lambda << "[](" << data_type << " &a, " << data_type << " b) " diff --git a/csrc/device_lower/analysis/non_divisible_split.cpp b/csrc/device_lower/analysis/non_divisible_split.cpp index 0effe1618e1..8a46bfa352e 100644 --- a/csrc/device_lower/analysis/non_divisible_split.cpp +++ b/csrc/device_lower/analysis/non_divisible_split.cpp @@ -218,7 +218,12 @@ NonDivisiblePredicateInfo::NonDivisiblePredicateInfo(Fusion* fusion) { // mapped to any ID of the input or sibling output. if (def == nullptr || (tv->definition()->isA() && - tv == tv->definition()->as()->blockScales())) { + tv == tv->definition()->as()->blockScales()) || + (tv->definition()->isA() && + tv == + tv->definition() + ->as() + ->blockScales())) { continue; } diff --git a/csrc/device_lower/analysis/sync_information.cpp b/csrc/device_lower/analysis/sync_information.cpp index 178408860e2..bba2ff3e2a7 100644 --- a/csrc/device_lower/analysis/sync_information.cpp +++ b/csrc/device_lower/analysis/sync_information.cpp @@ -299,11 +299,16 @@ SyncMap::SyncMap(Fusion* fusion, bool error_on_failure) { // sync/predication is handled there. if ((parallel_type == ParallelType::BIDx || parallel_type == ParallelType::TIDx) && - (consumer->definition()->isA() && - consumer == - consumer->definition() - ->as() - ->blockScales())) { + ((consumer->definition()->isA() && + consumer == + consumer->definition() + ->as() + ->blockScales()) || + (consumer->definition()->isA() && + consumer == + consumer->definition() + ->as() + ->blockScales()))) { continue; } diff --git a/csrc/device_lower/analysis/trivial_broadcast.cpp b/csrc/device_lower/analysis/trivial_broadcast.cpp index 34655598e25..2c0938f79ec 100644 --- a/csrc/device_lower/analysis/trivial_broadcast.cpp +++ b/csrc/device_lower/analysis/trivial_broadcast.cpp @@ -125,6 +125,17 @@ void ConcretizedBroadcastDomains::handle(BlockQuantizationOp* bq) { } } +// GroupedBlockQuantizationOp introduces broadcast domains in the block scales +// output +void ConcretizedBroadcastDomains::handle(GroupedBlockQuantizationOp* bq) { + auto out = bq->blockScales()->as(); + auto bcast_id = out->getLogicalDomain().back(); + if (bcast_id->isBroadcast()) { + broadcast_origin_map_.emplace( + bcast_id, std::unordered_set({bcast_id})); + } +} + void ConcretizedBroadcastDomains::dispatch(Expr* expr) { IterVisitor::dispatch(expr); diff --git a/csrc/device_lower/analysis/trivial_broadcast.h b/csrc/device_lower/analysis/trivial_broadcast.h index b002d73e976..d17a0d0bd94 100644 --- a/csrc/device_lower/analysis/trivial_broadcast.h +++ b/csrc/device_lower/analysis/trivial_broadcast.h @@ -53,6 +53,8 @@ class NVF_API ConcretizedBroadcastDomains : private IterVisitor { void handle(BlockQuantizationOp* bq) final; + void handle(GroupedBlockQuantizationOp* bq) final; + void dispatch(Expr* expr) final; void markAsConcretized( diff --git a/csrc/device_lower/pass/index.cpp b/csrc/device_lower/pass/index.cpp index a2d2877402e..a6555c402f5 100644 --- a/csrc/device_lower/pass/index.cpp +++ b/csrc/device_lower/pass/index.cpp @@ -437,6 +437,70 @@ void IndexLowering::handle(const BlockQuantizationOp* bqop) { GpuLower::current()->propagateExprInfo(bqop, back()); } +void IndexLowering::handle(const GroupedBlockQuantizationOp* grouped_bqop) { + const auto in = IrBuilder::create( + grouped_bqop->in()->as(), grouped_bqop->fusion()->zeroVal()); + + const auto out_scales = IrBuilder::create( + grouped_bqop->blockScales()->as(), + grouped_bqop->fusion()->zeroVal()); + const auto out_quantized = IrBuilder::create( + grouped_bqop->quantizedOutput()->as(), + grouped_bqop->fusion()->zeroVal()); + + // The GroupedBlockQuantizationOp funnels down to a runtime function. + // We pass the index for the block scaling factors output. We compute + // the index bases on the logical indices of the quantized output tensor. + // Then inside the runtime function, we divide this linearized index by 16 + // (the block size) to get the index for the scaling factors. + // We get the linearized index as follows: + // We get the logical indices for the quantized output. + // We then multiply and accumulate them using the logical extents of the + // quantized output tensor to get the linearized index. + std::vector logical_index = Index::getConsumerPerDimLogicalIndex( + grouped_bqop->quantizedOutput()->as(), for_loops_); + + NVF_ERROR( + logical_index.size() == 2, + "only matrices are supported in GroupedBlockQuantizationOp"); + + // As part of runtime validation + // make sure that the inner dimension of the input is divisible by block size. + auto* inner_id = + grouped_bqop->in()->as()->getLogicalDomain().back(); + Val* is_divisible = SimplifyingIrBuilder::eqExpr( + SimplifyingIrBuilder::modExpr( + inner_id->extent(), + IrBuilder::create(grouped_bqop->blockSize(), DataType::Index)), + grouped_bqop->fusion()->zeroVal()); + + NVFUSER_LOWER_VALIDATE( + is_divisible, + "Inner dimension of GroupedBlockQuantizationOp input must be divisible " + "by block " + "size (", + grouped_bqop->blockSize(), + "), but got extent ", + inner_id->extent()->toInlineString(), + " in ", + grouped_bqop->toString()); + + pushBack(IrBuilder::create( + out_scales, + out_quantized, + in, + grouped_bqop->inputOffsets(), + grouped_bqop->outputOffsets(), + grouped_bqop->layout(), + grouped_bqop->k(), + grouped_bqop->g(), + grouped_bqop->globalScale(), + 16, + logical_index[0], + logical_index[1])); + GpuLower::current()->propagateExprInfo(grouped_bqop, back()); +} + void IndexLowering::handle(const SelectOp* sop) { auto lowered_index = lowerSrcIndex(sop->input(1), sop->output(0)); auto lowered_index_cast = lowered_index; diff --git a/csrc/device_lower/pass/index.h b/csrc/device_lower/pass/index.h index b47a9f9b36a..be4a33d25aa 100644 --- a/csrc/device_lower/pass/index.h +++ b/csrc/device_lower/pass/index.h @@ -58,6 +58,7 @@ class IndexLowering : private OptOutConstDispatch { void handle(const ArgsortOp*) final; void handle(const TopKOp*) final; void handle(const BlockQuantizationOp*) final; + void handle(const GroupedBlockQuantizationOp*) final; void handle(const RNGOp*) final; void handle(const ReductionOp*) final; void handle(const GroupedReductionOp*) final; diff --git a/csrc/device_lower/utils.cpp b/csrc/device_lower/utils.cpp index efb5933aee7..d305348384a 100644 --- a/csrc/device_lower/utils.cpp +++ b/csrc/device_lower/utils.cpp @@ -152,6 +152,7 @@ bool isTvOp(const Expr* expr) { ScanOp, PreprocessGroupedMatmulInputSf, BlockQuantizationOp, + GroupedBlockQuantizationOp, LaunchDependentGridOp, WaitForPriorGridOp, kir::AllocTMem, diff --git a/csrc/device_lower/validation.cpp b/csrc/device_lower/validation.cpp index f786f6847ac..3b7952ea266 100644 --- a/csrc/device_lower/validation.cpp +++ b/csrc/device_lower/validation.cpp @@ -862,6 +862,241 @@ class ExprValidator : public OptOutDispatch { "contiguous IDs from the logical domain for BlockQuantizationOp: ", quantized_output->toString()); } + + void handle(GroupedBlockQuantizationOp* bqop) final { + auto inp_tv = bqop->input(0)->as(); + auto quantized_output = bqop->quantizedOutput()->as(); + auto block_scaling_factor = bqop->blockScales()->as(); + auto output_dtype = quantized_output->dtype(); + + NVF_ERROR_EQ( + inp_tv->getMemoryType(), + MemoryType::Local, + "Input must be a local memory tensor. Found: ", + inp_tv->getMemoryType()); + + NVF_ERROR_EQ( + quantized_output->getMemoryType(), + MemoryType::Local, + "Quantized output must be a local memory tensor. Found: ", + quantized_output->getMemoryType()); + + NVF_ERROR_EQ( + block_scaling_factor->getMemoryType(), + MemoryType::Global, + "Block scaling factor must be a global memory tensor. Found: ", + block_scaling_factor->getMemoryType()); + + if (output_dtype == DataType::Float8_e4m3fn) { + NVF_ERROR( + !bqop->hasGlobalScale(), + "Global scale is not supported when quantizing to Float8_e4m3fn."); + + NVF_ERROR( + !block_scaling_factor->hasAllocation(), + "Block scaling factor must not have an allocation domain when " + "quantizing to Float8_e4m3fn."); + } + + if (bqop->hasGlobalScale()) { + auto global_scale = bqop->globalScale()->as(); + + NVF_ERROR_EQ( + global_scale->getMemoryType(), + MemoryType::Global, + "Global scaling factor must be a global memory tensor. Found: ", + global_scale->getMemoryType()); + + NVF_ERROR_EQ( + global_scale->dtype(), + DataType::Float, + "Global scaling factor must be of type float. Found: ", + global_scale->dtype()); + } + + // Outputs have the same allocation domain + // as the logical domain - no allocation domain. + NVF_ERROR( + !quantized_output->hasAllocation(), + "Quantized output must not have an allocation domain."); + + IterDomain* grouped_id = nullptr; + IterDomain* thread_x = nullptr; + IterDomain* block_x = nullptr; + IterDomain* thread_z = nullptr; + IterDomain* block_z = nullptr; + + for (const auto& loop_id : quantized_output->getLoopDomain()) { + if (loop_id->getParallelType() == ParallelType::Group) { + grouped_id = loop_id; + } else if (loop_id->getParallelType() == ParallelType::TIDx) { + thread_x = loop_id; + } else if (loop_id->getParallelType() == ParallelType::BIDx) { + block_x = loop_id; + } else if (loop_id->getParallelType() == ParallelType::TIDz) { + thread_z = loop_id; + } else if (loop_id->getParallelType() == ParallelType::BIDz) { + block_z = loop_id; + } else if ( + loop_id->getParallelType() == ParallelType::Serial || + loop_id->getParallelType() == ParallelType::Unswitch || + loop_id->getParallelType() == ParallelType::Unroll) { + // Check this is ID has a constant extent and is 1 + NVF_ERROR( + loop_id->extent()->isConstInt(), + "Expected constant extent for Serial/Unswitch/Unroll ID in " + "GroupedBlockQuantizationOp"); + NVF_ERROR_EQ( + loop_id->extent()->evaluate().as(), + 1, + "Expected non-TID/BID/Group ID to have extent of 1 for " + "GroupedBlockQuantizationOp: ", + bqop->toString()); + } + } + + NVF_ERROR( + grouped_id != nullptr, + "One of the output IDs must be grouped for " + "GroupedBlockQuantizationOp: ", + bqop->toString()); + + NVF_ERROR( + thread_x != nullptr && block_x != nullptr, + "Need to have both TIDx and BIDx when using " + "GroupedBlockQuantizationOp: ", + bqop->toString()); + + NVF_ERROR( + !thread_z && !block_z, + "Parallelization along z axis is not supported for " + "GroupedBlockQuantizationOp: ", + bqop->toString()); + + auto inner_extent = grouped_id->extent()->evaluate().as(); + auto input_dtype = inp_tv->dtype(); + + NVF_ERROR( + ((inner_extent == 4 || inner_extent == 2) && + input_dtype == DataType::Float) || + ((inner_extent == 8 || inner_extent == 4 || inner_extent == 2) && + (input_dtype == DataType::BFloat16 || + input_dtype == DataType::Half)), + "The group dimension must be 2/4 (FP32) or 2/4/8 " + "(BF16). Found: ", + inner_extent, + ". Expr: ", + bqop->toString()); + + // M K + // │ │ + // ▼ ▼ + // ┌────────────┐ + // │ merge │ + // └─────┬──────┘ + // │ + // ▼ + // M*K + // ┌──────────┐ + // │ split ┼──┐ + // └─┬────────┘ │ + // ▼ ▼ + // (M*K)/4 4(G) + // ┌────────┐ + // │ split ┼────┐ + // └─┬──────┘ │ + // ▼ ▼ + // (M*K)/4 1(U) + // ┌─────────┐ + // │ split │ + // ┌─┼ ┼───┐ + // │ └─────────┘ │ + // ▼ ▼ + // (M*K)/4/128 128(Tx) + + // Next we check the following scheduling requirements for + // GroupedBlockQuantizationOp - the above figure is an example of a valid + // schedule. + // 1. The Group ID must be derived from the innermost logical IDs + // 2. TIDx must follow the Group ID in the schedule -- that is when derived + // from the logical domain, group ID must be inner-most, the next + // "inner-most" should be TIDx (unless there is an ID with a unit trip + // count) + // 3. All merges involved from logical domains to group and thread ID must + // combine contiguous IDs + + auto transform_exprs = DependencyCheck::getAllExprsBetween( + {quantized_output->getLogicalDomain().begin(), + quantized_output->getLogicalDomain().end()}, + {quantized_output->getLoopDomain().begin(), + quantized_output->getLoopDomain().end()}); + + std::vector ids_to_transform = + quantized_output->getLogicalDomain(); + + std::deque frontier( + quantized_output->getLogicalDomain().begin(), + quantized_output->getLogicalDomain().end()); + + // This will get the xforms from logical to loop and apply them on the + // logical domain. We will get a loop domain minus the reordering. + // This pass also removes all IDs from frontier that were derived using + // non-contiguous merges. + scheduler_utils::applyTransforms( + ids_to_transform, transform_exprs, [&frontier](Expr* expr) { + traverseFrontierWithContiguityCheck(frontier, expr); + }); + + // The grouped ID must correspond to the innermost loop-like domain + NVF_ERROR( + ids_to_transform.back() == grouped_id, + "The grouped ID must correspond to the innermost of all splits " + "from logical domains to loop domains for GroupedBlockQuantizationOp. " + "TV: ", + quantized_output->toString()); + + // Iterate from the back to find TIDx, skipping group_id (last element) + // Ensure all IDs between group_id and TIDx have extent 1 + bool found_tidx = false; + for (auto it = ids_to_transform.rbegin() + 1; it != ids_to_transform.rend(); + ++it) { + if (*it == thread_x) { + found_tidx = true; + break; + } + // All non-TIDx IDs between Group ID and TIDx must have extent of 1 + NVF_ERROR( + (*it)->extent()->isConstInt() && + (*it)->extent()->evaluate().as() == 1, + "Expected IDs between Group ID and TIDx to have extent of 1 for " + "GroupedBlockQuantizationOp: ", + quantized_output->toString()); + } + + NVF_ERROR( + found_tidx, + "TIDx must follow the Group ID in the schedule for " + "GroupedBlockQuantizationOp: ", + quantized_output->toString()); + + // Check if grouped_id in frontier + auto grouped_it = std::ranges::find(frontier, grouped_id); + NVF_ERROR( + grouped_it != frontier.end(), + "All merge operations deriving the grouped ID must combine " + "contiguous IDs from the logical domain for " + "GroupedBlockQuantizationOp: ", + quantized_output->toString()); + // Do the same for thread_x + auto threadx_it = + std::ranges::find(frontier.begin(), frontier.end(), thread_x); + NVF_ERROR( + threadx_it != frontier.end(), + "All merge operations deriving the TIDx ID must combine " + "contiguous IDs from the logical domain for " + "GroupedBlockQuantizationOp: ", + quantized_output->toString()); + } }; } // namespace @@ -1869,7 +2104,8 @@ void validateAndConvertIterDomainGrouping(Fusion* fusion) { def->isA() || def->isA() || def->isA() || def->isA() || def->isA() || def->isA() || def->isA() || - def->isA(), + def->isA() || + def->isA(), "Invalid use of ParallelType::Group: ", def->toString()); diff --git a/csrc/dispatch.h b/csrc/dispatch.h index 822ababb149..f46423b498a 100644 --- a/csrc/dispatch.h +++ b/csrc/dispatch.h @@ -112,6 +112,7 @@ class Val; f(CutlassNvfp4GroupedMmaOp); \ f(PreprocessGroupedMatmulInputSf); \ f(BlockQuantizationOp); \ + f(GroupedBlockQuantizationOp); \ f(TopKOp); \ f(ScanOp); \ f(Merge); \ diff --git a/csrc/fusion_segmenter.cpp b/csrc/fusion_segmenter.cpp index d0dc64cbe2a..f472d41f061 100644 --- a/csrc/fusion_segmenter.cpp +++ b/csrc/fusion_segmenter.cpp @@ -1878,7 +1878,12 @@ std::pair> SegmentedFusion::makeFusion( for (auto inp : getAllInputs(sg)) { auto clone_tv = complete_to_segment_map.clone(inp); fusion_segment->addInput(clone_tv); - if (inp->isDefinitionType()) { + if (inp->isDefinitionType() || + (inp->isDefinitionType() && + (inp == + inp->definition() + ->as() + ->blockScales()))) { // NOTE: inp is an input to fusion segment. // // There's no point of replaying allocation domain if we cannot index into diff --git a/csrc/ir/composite_nodes.cpp b/csrc/ir/composite_nodes.cpp index b3ec5e39c3a..175247d7d0e 100644 --- a/csrc/ir/composite_nodes.cpp +++ b/csrc/ir/composite_nodes.cpp @@ -1890,4 +1890,62 @@ std::vector BlockQuantizationOp::evaluate( NVFUSER_DEFINE_CLONE_AND_CREATE(BlockQuantizationOp) +GroupedBlockQuantizationOp::GroupedBlockQuantizationOp( + IrBuilderPasskey passkey, + Val* output_scales, + Val* output, + Val* input, + Val* input_offsets, + Val* output_offsets, + BlockScalingFactorLayout layout, + Val* k, + Val* g, + Val* global_scale, + int64_t block_size, + Val* row_idx, + Val* col_idx) + : Expr(passkey) { + addOutput(output); + addOutput(output_scales); + addInput(input); + addInput(input_offsets); + addInput(output_offsets); + addInput(k); + addInput(g); + if (global_scale) { + addInput(global_scale); + } + addDataAttribute(block_size); + addDataAttribute(layout); + if (row_idx != nullptr) { + addAttribute(row_idx); + } + if (col_idx != nullptr) { + addAttribute(col_idx); + } +} + +std::string GroupedBlockQuantizationOp::toString(int indent_size) const { + // FIXME(jiej): update this to print out additional stuff. + std::stringstream ss; + indent(ss, indent_size) << "(" << blockScales()->toString() << ",\n " + << quantizedOutput()->toString() << ")\n" + << " = grouped_block_quantize(" << in()->toString() + << ")\n"; + return ss.str(); +} + +std::string GroupedBlockQuantizationOp::toInlineString(int indent_size) const { + NVF_CHECK(false, "GroupedBlockQuantizationOp can not be printed inline"); +} + +std::vector GroupedBlockQuantizationOp::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + // This is a placeholder, currently we don't have a fallback kernel available + NVF_THROW("GroupedBlockQuantizationOp evaluation not yet implemented"); +} + +NVFUSER_DEFINE_CLONE_AND_CREATE(GroupedBlockQuantizationOp) + } // namespace nvfuser diff --git a/csrc/ir/composite_nodes.h b/csrc/ir/composite_nodes.h index 13bf54de515..1019526add8 100644 --- a/csrc/ir/composite_nodes.h +++ b/csrc/ir/composite_nodes.h @@ -1107,4 +1107,96 @@ class BlockQuantizationOp : public Expr { const std::vector& inputs) const override; }; +class GroupedBlockQuantizationOp : public Expr { + public: + using Expr::Expr; + + // This op takes in a high precision input(input) + // and returns the quantized output(output) along with the block scaling + // factors (output_scales). It can also take as an optional input the global + // scaling factor and block size (though we currently only support 16). + // logical_index is used for internal implemtation. This op is currently + // implemented via a runtime function. During index computation, we compute + // the index of the output_scales and pass it to the runtime function. + GroupedBlockQuantizationOp( + IrBuilderPasskey, + Val* output_scales, + Val* output, + Val* input, + Val* input_offsets, + Val* output_offsets, + BlockScalingFactorLayout layout, + Val* k, + Val* g, + Val* global_scale = nullptr, + int64_t block_size = 16, + Val* row_idx = nullptr, + Val* col_idx = nullptr); + + NVFUSER_DECLARE_CLONE_AND_CREATE + + Val* blockScales() const { + return output(1); + } + + Val* quantizedOutput() const { + return output(0); + } + + Val* in() const { + return input(0); + } + + int64_t blockSize() const { + return attribute(0); + } + + bool hasGlobalScale() const { + if (inputs().size() > 5) { + return true; + } + return false; + } + + Val* globalScale() const { + if (hasGlobalScale()) { + return input(5); + } + return nullptr; + } + + const char* getOpString() const override { + return "GroupedBlockQuantizationOp"; + } + + TensorView* inputOffsets() const { + return input(1)->as(); + } + + TensorView* outputOffsets() const { + return input(2)->as(); + } + + // get scalar - column size + Val* k() const { + return input(3); + } + + // get scalar - number of groups + Val* g() const { + return input(4); + } + + // get enum - block scaling factor layout + BlockScalingFactorLayout layout() const { + return attribute(1); + } + + std::string toString(int indent_size = 0) const override; + std::string toInlineString(int indent_size = 0) const override; + std::vector evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const override; +}; + } // namespace nvfuser diff --git a/csrc/ir/utils.cpp b/csrc/ir/utils.cpp index 9418b137688..d24fdd1c879 100644 --- a/csrc/ir/utils.cpp +++ b/csrc/ir/utils.cpp @@ -1337,7 +1337,11 @@ bool hasTrivialAllocationDomain(const TensorView* tv) { alloc | TensorDomain::kNoReductions | TensorDomain::kNoBroadcasts); } bool hasUniformSiblings(Expr* expr) { - return !expr->isOneOf(); + return !expr->isOneOf< + SdpaFwdOp, + SdpaBwdOp, + BlockQuantizationOp, + GroupedBlockQuantizationOp>(); } bool mayRequireAllocation(const TensorView* tv, IterDomain* id) { diff --git a/csrc/kernel.cpp b/csrc/kernel.cpp index 55552d2638d..76cea7675bd 100644 --- a/csrc/kernel.cpp +++ b/csrc/kernel.cpp @@ -324,6 +324,10 @@ class KernelIrScanner : private IrVisitor { summary_.has_block_quantize_op = true; } + void handle(GroupedBlockQuantizationOp* bqop) final { + summary_.has_block_quantize_op = true; + } + void handle(ScanOp* scan) final { summary_.has_scan = true; } diff --git a/csrc/logical_domain_map.cpp b/csrc/logical_domain_map.cpp index 26f6200a757..56ea5b324ef 100644 --- a/csrc/logical_domain_map.cpp +++ b/csrc/logical_domain_map.cpp @@ -141,14 +141,33 @@ std::pair, bool> getNonMappingDomainInfo( // as it's extent is reduced by a factor of the block size // for example [i0, i1] => [i0, i1/16] where 16 is the block size. // Make sure the producer isn't the global scale. - if (consumer_tv == - consumer_tv->definition() - ->as() - ->blockScales() && - producer_tv != - consumer_tv->definition() - ->as() - ->globalScale()) { + Val* block_scales = + consumer_tv->definition()->as()->blockScales(); + Val* global_scale = + consumer_tv->definition()->as()->globalScale(); + + if (consumer_tv == block_scales && producer_tv != global_scale) { + auto producer_logical = + TensorDomain::noReductions(producer_tv->getLogicalDomain()); + auto last_logical_dim = producer_logical.size() - 1; + non_mapping_ids.insert(producer_logical.at(last_logical_dim)); + // We are mapping everything but the last ID. + has_consumer_id = true; + } + } else if ( + auto grouped_bqop = dynamic_cast( + consumer_tv->definition())) { + if (producer_tv != grouped_bqop->in()) { + auto producer_logical = + TensorDomain::noReductions(producer_tv->getLogicalDomain()); + non_mapping_ids.insert(producer_logical.begin(), producer_logical.end()); + // we are not mapping anything, `has_consumer_id` doesn't matter. + has_consumer_id = false; + } else if (consumer_tv == grouped_bqop->blockScales()) { + // We don't map the inner-most dimension of the block scaling factors + // as it's extent is reduced by a factor of the block size + // for example [i0, i1] => [i0, i1/16] where 16 is the block size. + // Make sure the producer isn't the global scale. auto producer_logical = TensorDomain::noReductions(producer_tv->getLogicalDomain()); auto last_logical_dim = producer_logical.size() - 1; @@ -1387,7 +1406,8 @@ void ComputeAtLogicalDomainMapBuilder::mapPointwiseLikeOp(Expr* expr) { NVF_ERROR( expr->isA() || expr->isA() || expr->isA() || expr->isA() || - expr->isA(), + expr->isA() || + expr->isA(), "Unknown multi-output Expr type ", expr->getOpString(), " is found"); diff --git a/csrc/logical_domain_map.h b/csrc/logical_domain_map.h index 3d76758ddb6..133d526fbf6 100644 --- a/csrc/logical_domain_map.h +++ b/csrc/logical_domain_map.h @@ -550,6 +550,10 @@ class ComputeAtLogicalDomainMapBuilder : private BackwardVisitor { mapPointwiseLikeOp(op); } + void handle(GroupedBlockQuantizationOp* op) override { + mapPointwiseLikeOp(op); + } + void handle(TensorView* tv) override; //! Maps all pending mappings. diff --git a/csrc/ops/arith.cpp b/csrc/ops/arith.cpp index 64f68cbf1a4..266bb30e616 100644 --- a/csrc/ops/arith.cpp +++ b/csrc/ops/arith.cpp @@ -16,6 +16,7 @@ #include #include #include +#include #include #include #include @@ -2753,4 +2754,144 @@ BlockQuantizationResults blockQuantize( return BlockQuantizationResults(quantized_tensor, block_scales); } +BlockQuantizationResults groupedBlockQuantize( + TensorView* input, + TensorView* input_offsets, + TensorView* output_offsets, + BlockScalingFactorLayout layout, + TensorView* global_scaling_factor, + int64_t block_size, + DataType out_dtype) { + NVF_CHECK( + out_dtype == DataType::Float4_e2m1fn || + out_dtype == DataType::Float8_e4m3fn, + "Currently only output data type of Float4_e2m1fn or Float8_e4m3fn is " + "supported"); + if (out_dtype == DataType::Float4_e2m1fn) { + NVF_ERROR_EQ( + block_size, + 16, + "Block size must be 16 for Float4_e2m1fn, got ", + block_size); + } else if (out_dtype == DataType::Float8_e4m3fn) { + NVF_ERROR_EQ( + block_size, + 32, + "Block size must be 32 for Float8_e4m3fn, got ", + block_size); + NVF_CHECK( + !global_scaling_factor, + "global_scaling_factor must be nullptr for Float8_e4m3fn"); + } + + // Validate input data type + // We'll only support FP32 or BF16/FP16 + NVF_CHECK( + input->getDataType().value() == DataType::Float || + input->getDataType().value() == DataType::BFloat16 || + input->getDataType().value() == DataType::Half, + "Grouped block quantization expects floating point input but got ", + input->getDataType().value()); + + // Check that if global_scaling_factor in non-null + // then it is a scalar float TensorView + if (global_scaling_factor != nullptr) { + NVF_CHECK( + TensorDomain::noReductions(global_scaling_factor->getLogicalDomain()) + .empty(), + "Global scaling factor for grouped block quantization must be a scalar " + "tensor"); + NVF_CHECK( + global_scaling_factor->getDataType().value() == DataType::Float, + "Global scaling factor for grouped block quantization must be of float " + "data " + "type"); + } + + auto inp_domain = TensorDomain::noReductions(input->getLogicalDomain()); + + // Validate input tensor is not zero-dimensional + NVF_CHECK( + !inp_domain.empty(), + "Grouped block quantization does not support zero-dimensional tensors"); + + // Create output domain for quantized tensor (same shape as input) + std::vector quantized_out_domain; + quantized_out_domain.reserve(inp_domain.size()); + + for (auto inp_domain_ptr : inp_domain) { + quantized_out_domain.push_back(inp_domain_ptr->cloneWithoutRFactor()); + } + + // Create output tensors + TensorView* quantized_tensor = IrBuilder::create( + IrBuilder::create( + quantized_out_domain, + TensorDomain::getContiguityFilledWith(quantized_out_domain, true)), + out_dtype); + + // Create output blocked scaling factor + auto block_scales_dtype = (out_dtype == DataType::Float4_e2m1fn) + ? DataType::Float8_e4m3fn + : DataType::Float8_e8m0fnu; + NVF_ERROR_EQ(inp_domain.size(), 2); + + // This is used for both root and loop domain on output + // maps directly to input's logical domain. + std::vector scales_out_domain; + scales_out_domain.reserve(inp_domain.size()); + + for (auto inp_id : inp_domain) { + if (inp_id == inp_domain.back()) { + scales_out_domain.push_back( + IterDomainBuilder( + inp_id->start(), + SimplifyingIrBuilder::divExpr( + inp_id->extent(), + IrBuilder::create(block_size, DataType::Index))) + .build()); + + } else { + scales_out_domain.push_back(inp_id->cloneWithoutRFactor()); + } + } + + std::vector offset_logical_dom = + TensorDomain::noReductions(input_offsets->getLogicalDomain()); + Val* num_groups = offset_logical_dom[0]->extent(); + + // Create the allocation domain of output. + std::vector out_alloc_dom = + layoutAllocationDomain(scales_out_domain, num_groups, layout); + + // Create block scaling factors + TensorView* block_scales = IrBuilder::create( + IrBuilder::create( + /*root_domain=*/std::vector(), + /*logical_domain=*/scales_out_domain, + /*allocation=*/out_alloc_dom, + /*loop_domain=*/scales_out_domain, + /*alternate_loop_domain=*/std::nullopt, + /*contiguity=*/ + TensorDomain::getContiguityFilledWith(out_alloc_dom, true), + /*additional_ids=*/std::vector(), + /*skip_checks=*/true), + block_scales_dtype); + + // Create the grouped block quantization operation + IrBuilder::create( + block_scales, + quantized_tensor, + input, + input_offsets, + output_offsets, + layout, + inp_domain[1]->getMaybeExpandedExtent(), + num_groups, + global_scaling_factor, + block_size); + + return BlockQuantizationResults(quantized_tensor, block_scales); +} + } // namespace nvfuser diff --git a/csrc/ops/arith.h b/csrc/ops/arith.h index e49a1a416f1..f4595d36561 100644 --- a/csrc/ops/arith.h +++ b/csrc/ops/arith.h @@ -855,4 +855,13 @@ NVF_API BlockQuantizationResults blockQuantize( bool swizzle_scales = false, DataType out_dtype = DataType::Float4_e2m1fn); +NVF_API BlockQuantizationResults groupedBlockQuantize( + TensorView* input, + TensorView* input_offsets, + TensorView* output_offsets, + BlockScalingFactorLayout layout, + TensorView* global_scaling_factor = nullptr, + int64_t block_size = 16, + DataType out_dtype = DataType::Float4_e2m1fn); + } // namespace nvfuser diff --git a/csrc/scheduler/pointwise.cpp b/csrc/scheduler/pointwise.cpp index a86e2f08495..11500f55381 100644 --- a/csrc/scheduler/pointwise.cpp +++ b/csrc/scheduler/pointwise.cpp @@ -300,7 +300,9 @@ bool PointWiseScheduler::canScheduleRunTime( data_cache, [fusion]() { return std::make_unique( - !ir_utils::getOpsOfType(fusion).empty()); + !ir_utils::getOpsOfType(fusion).empty() || + !ir_utils::getOpsOfType(fusion) + .empty()); }) .get(); @@ -417,6 +419,26 @@ std::unique_ptr PointWiseScheduler::computeHeuristics( fusion, runtime_info, data_cache, prop); } NVF_ERROR(pparams != nullptr); + + // cap vectorization when block quantization op is encountered, since there's + // a validation during device_lower + auto has_block_quantization_ops = + HeuristicDataCacheEntry( + data_cache, + [fusion]() { + return std::make_unique( + !ir_utils::getOpsOfType(fusion).empty() || + !ir_utils::getOpsOfType(fusion) + .empty()); + }) + .get(); + if (has_block_quantization_ops) { + // FIXME: this needs to be done per input dtype. I'm capping it as 4 for + // simplicity for now. + pparams->as()->vectorization_factor = std::min( + 4, pparams->as()->vectorization_factor); + } + return pparams; } diff --git a/csrc/scheduler/pointwise_non_tma.cpp b/csrc/scheduler/pointwise_non_tma.cpp index 37d30b8049c..bdf66b8b14c 100644 --- a/csrc/scheduler/pointwise_non_tma.cpp +++ b/csrc/scheduler/pointwise_non_tma.cpp @@ -148,7 +148,9 @@ int64_t getUnrollFactor( data_cache, [fusion]() { return std::make_unique( - !ir_utils::getOpsOfType(fusion).empty()); + !ir_utils::getOpsOfType(fusion).empty() || + !ir_utils::getOpsOfType(fusion) + .empty()); }) .get(); @@ -603,11 +605,16 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) { // We do so as the runtime function for block quantization expects 2/4/8 // elements per thread. auto bq_ops = ir_utils::getOpsOfType(fusion); + auto gbq_ops = ir_utils::getOpsOfType(fusion); std::vector nvfp4_quantized_outputs = {}; for (auto bq_op : bq_ops) { nvfp4_quantized_outputs.push_back( bq_op->quantizedOutput()->as()); } + for (auto gbq_op : gbq_ops) { + nvfp4_quantized_outputs.push_back( + gbq_op->quantizedOutput()->as()); + } if (pparams->vectorization_factor > 1) { // Grab all tensor views that should be vectorized diff --git a/csrc/scheduler/registry_utils.cpp b/csrc/scheduler/registry_utils.cpp index 490ffd6d036..b2d14ffcffa 100644 --- a/csrc/scheduler/registry_utils.cpp +++ b/csrc/scheduler/registry_utils.cpp @@ -851,6 +851,13 @@ bool hasNonTerminalBlockQuantizeOp(Fusion* fusion) { if (!block_scales->isFusionOutput()) { return true; } + } else if (expr->isA()) { + auto block_scales = expr->as() + ->blockScales() + ->as(); + if (!block_scales->isFusionOutput()) { + return true; + } } } return false; @@ -1084,6 +1091,7 @@ bool SchedulerTopologyChecker::rejectScheduleFusionGlobalBufferRequirement( return true; } } + // FIXME: I think I needed to do the same for GroupedBlockQuantizationOp } return false; } diff --git a/csrc/scheduler/tools/domain_map.cpp b/csrc/scheduler/tools/domain_map.cpp index 55abe2d0914..8d9cc77c57e 100644 --- a/csrc/scheduler/tools/domain_map.cpp +++ b/csrc/scheduler/tools/domain_map.cpp @@ -66,6 +66,13 @@ bool canIgnoreIndexedInputDomainID( input_tv == layout->outputOffsets()) { continue; } + } else if (auto layout = dynamic_cast(use)) { + // since we don't index into offsets, scheduler doesn't need to cover + // offset TVs ID. + if (input_tv == layout->inputOffsets() || + input_tv == layout->outputOffsets()) { + continue; + } } else { // If the input TV is used by any other ops return false; @@ -420,6 +427,12 @@ bool DomainMap::isValidReference(TensorView* tv, bool check_inputs) const { output_tv == output_tv->definition() ->as() + ->blockScales()) || + (output_tv->definition() && + output_tv->definition()->isA() && + output_tv == + output_tv->definition() + ->as() ->blockScales())) { continue; } diff --git a/csrc/scheduler/utils.cpp b/csrc/scheduler/utils.cpp index 17ac3c44328..f32a0f8eb93 100644 --- a/csrc/scheduler/utils.cpp +++ b/csrc/scheduler/utils.cpp @@ -1354,17 +1354,18 @@ std::vector> cacheInputs( // TODO: we might need to explicitly promote offsets to global memory // We expect offsets to remain in global memory, so we do not add it to // cache - auto isPreprocessGroupedMatmulInputSfOffsets = [tv](Expr* use) { - if (!use->isA()) { - return false; + auto isGroupOffsets = [tv](Expr* use) { + if (auto op = dynamic_cast(use)) { + return tv == op->inputOffsets() || tv == op->outputOffsets(); + } else if (auto op = dynamic_cast(use)) { + return tv == op->inputOffsets() || tv == op->outputOffsets(); } - auto layout = use->as(); - return tv == layout->inputOffsets() || tv == layout->outputOffsets(); + return false; }; std::vector cached_uses; for (auto use : tv->uses()) { if (!use->isOneOf() && !isGatherLookUpTvInUse(use) && - !isPreprocessGroupedMatmulInputSfOffsets(use)) { + !isGroupOffsets(use)) { cached_uses.push_back(use); } } @@ -1408,7 +1409,11 @@ std::vector> cacheAndForkOutputs( ->isOneOf() || (output->definition()->isA() && output->definition()->as()->blockScales() == - output)) { + output) || + (output->definition()->isA() && + output->definition() + ->as() + ->blockScales() == output)) { continue; } if (!output->uses().empty()) { diff --git a/csrc/tensor_metadata.cpp b/csrc/tensor_metadata.cpp index 89d83a97eea..7ab282644c9 100644 --- a/csrc/tensor_metadata.cpp +++ b/csrc/tensor_metadata.cpp @@ -356,6 +356,12 @@ inferAndValidateAllocationSizesAndStrides( if (bqop->isSwizzledScales() && tv == bqop->blockScales()) { skip_validation = true; } + } else if ( + tv->definition() && tv->definition()->isA()) { + auto bqop = tv->definition()->as(); + if (tv == bqop->blockScales()) { + skip_validation = true; + } } // Skip validation for scale input to ScaledMmaOp as it will be swizzled. diff --git a/tests/cpp/test_layout_op.cpp b/tests/cpp/test_layout_op.cpp index 1a2fa739f27..1535f45e428 100644 --- a/tests/cpp/test_layout_op.cpp +++ b/tests/cpp/test_layout_op.cpp @@ -62,7 +62,8 @@ bool validateGroupedLayout( .transpose(1, 3) .reshape({mn_tile * 4 * 32, k_tile * 4}) .slice(0, 0, m_g) - .slice(1, 0, k); + .slice(1, 0, k) + .to(ref.dtype()); auto ref_g = ref.slice( 0, expert_offsets[i].item().to(), @@ -376,4 +377,71 @@ TEST_F(LayoutOpTest, Inlining) { EXPECT_EQ(inp_cache->getComputeAtPosition(), 2); } +TEST_F(LayoutOpTest, GroupedBlockQuantizeOp) { + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(&fusion); + + auto inp = makeSymbolicTensor(2); + auto offsets = makeSymbolicTensor(1, DataType::Int32); + auto rounded_offsets = makeSymbolicTensor(1, DataType::Int32); + fusion.addInput(inp); + fusion.addInput(offsets); + fusion.addInput(rounded_offsets); + + auto outs = groupedBlockQuantize( + inp, offsets, rounded_offsets, BlockScalingFactorLayout::Block128x4); + fusion.addOutput(castOp(DataType::Float, outs.quantized_tensor)); + fusion.addOutput(outs.block_scales); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + int m = 512; + int k = 9 * 16; // note: padded column size needs to be a multiple of 16 + auto t0 = at::randn({m, k}, options); + + // tokens per group are [100, 150, 262] respectively, so each group would be + // padded to multiple of 128. Hence the total output row span would cover a + // length of 128 + 256 + 384 = 768. + auto t1 = at::tensor({0, 100, 250}, options.dtype(at::kInt)); + auto t2 = at::tensor({0, 128, 384}, options.dtype(at::kInt)); + + // automatic scheduling. + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto outputs = executor_cache.runFusionWithInputs({t0, t1, t2}); + + at::Tensor ref_block_sf; + at::Tensor ref_scaled_out; + // producing reference + { + std::unique_ptr fusion_new_op = std::make_unique(); + FusionGuard fg2(fusion_new_op.get()); + auto tv_in = makeContigTensor(2); + fusion_new_op->addInput(tv_in); + auto quantization_results = + blockQuantize(tv_in, nullptr, /*block_size=*/16, false); + + fusion_new_op->addOutput(quantization_results.block_scales); + fusion_new_op->addOutput( + castOp(DataType::Float, quantization_results.quantized_tensor)); + FusionExecutorCache executor_cache(std::move(fusion_new_op)); + auto outputs_new_op = executor_cache.runFusionWithInputs({t0}); + ref_block_sf = outputs_new_op[0].as().to(at::kFloat); + ref_scaled_out = outputs_new_op[1].as(); + } + + // check scaled output + EXPECT_TRUE(at::allclose(ref_scaled_out, outputs[0].as())); + // check block scaling factor + ASSERT_TRUE(validateGroupedLayout( + BlockScalingFactorLayout::Block128x4, + outputs[1].as(), + ref_block_sf, + t1, + t2)); + + EXPECT_THAT( + executor_cache.getMostRecentKernelRuntime()->fusionSegments()->groups(), + UnorderedElementsAre(HeuristicIs(SchedulerType::PointWise))); +} + } // namespace nvfuser From 7e26890786d5fed469acba67103b35f85d85194f Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 7 Jan 2026 17:01:04 -0800 Subject: [PATCH 03/13] delete include of removed header --- csrc/runtime/compiled_kernel.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/csrc/runtime/compiled_kernel.cpp b/csrc/runtime/compiled_kernel.cpp index 4ce528a510e..2532d249267 100644 --- a/csrc/runtime/compiled_kernel.cpp +++ b/csrc/runtime/compiled_kernel.cpp @@ -58,7 +58,6 @@ #include #include #include -#include #include #include #include From fc79a9c9ee0d63aa37a3befee0044bfa572ac0dc Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 7 Jan 2026 18:16:59 -0800 Subject: [PATCH 04/13] code clean up --- csrc/codegen.cpp | 12 ++++--- csrc/device_lower/pass/index.cpp | 10 ------ csrc/device_lower/validation.cpp | 53 ++++--------------------------- csrc/ir/composite_nodes.cpp | 4 +-- csrc/scheduler/registry_utils.cpp | 14 +++++++- 5 files changed, 29 insertions(+), 64 deletions(-) diff --git a/csrc/codegen.cpp b/csrc/codegen.cpp index fcbd320a7d3..9b89216aae8 100644 --- a/csrc/codegen.cpp +++ b/csrc/codegen.cpp @@ -2000,12 +2000,16 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { } } - auto fn_call = output_dtype == DataType::Float4_e2m1fn - ? "bq::grouped_block_quantize_to_nvfp4" - : "bq::grouped_block_quantize_to_mxfp8"; + NVF_ERROR( + output_dtype == DataType::Float4_e2m1fn, + "only nvfp4 output is implemented"); // Generate the function call - indent() << genCall(fn_call, template_args, func_args) << ";\n"; + indent() << genCall( + "bq::grouped_block_quantize_to_nvfp4", + template_args, + func_args) + << ";\n"; } std::string genReductionOp(BinaryOpType op_type, DataType data_type) { diff --git a/csrc/device_lower/pass/index.cpp b/csrc/device_lower/pass/index.cpp index a6555c402f5..9f0e042721d 100644 --- a/csrc/device_lower/pass/index.cpp +++ b/csrc/device_lower/pass/index.cpp @@ -448,18 +448,8 @@ void IndexLowering::handle(const GroupedBlockQuantizationOp* grouped_bqop) { grouped_bqop->quantizedOutput()->as(), grouped_bqop->fusion()->zeroVal()); - // The GroupedBlockQuantizationOp funnels down to a runtime function. - // We pass the index for the block scaling factors output. We compute - // the index bases on the logical indices of the quantized output tensor. - // Then inside the runtime function, we divide this linearized index by 16 - // (the block size) to get the index for the scaling factors. - // We get the linearized index as follows: - // We get the logical indices for the quantized output. - // We then multiply and accumulate them using the logical extents of the - // quantized output tensor to get the linearized index. std::vector logical_index = Index::getConsumerPerDimLogicalIndex( grouped_bqop->quantizedOutput()->as(), for_loops_); - NVF_ERROR( logical_index.size() == 2, "only matrices are supported in GroupedBlockQuantizationOp"); diff --git a/csrc/device_lower/validation.cpp b/csrc/device_lower/validation.cpp index 3b7952ea266..8be4cdd0b76 100644 --- a/csrc/device_lower/validation.cpp +++ b/csrc/device_lower/validation.cpp @@ -756,6 +756,8 @@ class ExprValidator : public OptOutDispatch { ". Expr: ", bqop->toString()); + // [ NOTE: check scheduling requirements for block quantization ] + // // M K // │ │ // ▼ ▼ @@ -887,16 +889,9 @@ class ExprValidator : public OptOutDispatch { "Block scaling factor must be a global memory tensor. Found: ", block_scaling_factor->getMemoryType()); - if (output_dtype == DataType::Float8_e4m3fn) { - NVF_ERROR( - !bqop->hasGlobalScale(), - "Global scale is not supported when quantizing to Float8_e4m3fn."); - - NVF_ERROR( - !block_scaling_factor->hasAllocation(), - "Block scaling factor must not have an allocation domain when " - "quantizing to Float8_e4m3fn."); - } + NVF_ERROR( + output_dtype != DataType::Float8_e4m3fn, + "output of Float8_e4m3fn is not yet implemented"); if (bqop->hasGlobalScale()) { auto global_scale = bqop->globalScale()->as(); @@ -988,43 +983,7 @@ class ExprValidator : public OptOutDispatch { ". Expr: ", bqop->toString()); - // M K - // │ │ - // ▼ ▼ - // ┌────────────┐ - // │ merge │ - // └─────┬──────┘ - // │ - // ▼ - // M*K - // ┌──────────┐ - // │ split ┼──┐ - // └─┬────────┘ │ - // ▼ ▼ - // (M*K)/4 4(G) - // ┌────────┐ - // │ split ┼────┐ - // └─┬──────┘ │ - // ▼ ▼ - // (M*K)/4 1(U) - // ┌─────────┐ - // │ split │ - // ┌─┼ ┼───┐ - // │ └─────────┘ │ - // ▼ ▼ - // (M*K)/4/128 128(Tx) - - // Next we check the following scheduling requirements for - // GroupedBlockQuantizationOp - the above figure is an example of a valid - // schedule. - // 1. The Group ID must be derived from the innermost logical IDs - // 2. TIDx must follow the Group ID in the schedule -- that is when derived - // from the logical domain, group ID must be inner-most, the next - // "inner-most" should be TIDx (unless there is an ID with a unit trip - // count) - // 3. All merges involved from logical domains to group and thread ID must - // combine contiguous IDs - + // see [ NOTE: check scheduling requirements for block quantization ] auto transform_exprs = DependencyCheck::getAllExprsBetween( {quantized_output->getLogicalDomain().begin(), quantized_output->getLogicalDomain().end()}, diff --git a/csrc/ir/composite_nodes.cpp b/csrc/ir/composite_nodes.cpp index 175247d7d0e..6e4167f5ded 100644 --- a/csrc/ir/composite_nodes.cpp +++ b/csrc/ir/composite_nodes.cpp @@ -1926,12 +1926,12 @@ GroupedBlockQuantizationOp::GroupedBlockQuantizationOp( } std::string GroupedBlockQuantizationOp::toString(int indent_size) const { - // FIXME(jiej): update this to print out additional stuff. std::stringstream ss; indent(ss, indent_size) << "(" << blockScales()->toString() << ",\n " << quantizedOutput()->toString() << ")\n" << " = grouped_block_quantize(" << in()->toString() - << ")\n"; + << ",\n " << inputOffsets()->toString() << ",\n " + << outputOffsets()->toString() << ")\n"; return ss.str(); } diff --git a/csrc/scheduler/registry_utils.cpp b/csrc/scheduler/registry_utils.cpp index b2d14ffcffa..c2cf2a0cd07 100644 --- a/csrc/scheduler/registry_utils.cpp +++ b/csrc/scheduler/registry_utils.cpp @@ -1090,8 +1090,20 @@ bool SchedulerTopologyChecker::rejectScheduleFusionGlobalBufferRequirement( layout_op, layout_op->outputOffsets(), scheduler_type)) { return true; } + } else if (expr->isA()) { + // The runtime function of GroupedBlockQuantizationOp needs: + // 1. Write scale output directly to global memory + // 2. Read two offset inputs directly from global memory + auto grouped_bop = expr->as(); + if (rejectScheduleFusionOutputRequirement( + grouped_bop, grouped_bop->out(), scheduler_type) || + rejectScheduleFusionInputRequirement( + grouped_bop, grouped_bop->inputOffsets(), scheduler_type) || + rejectScheduleFusionInputRequirement( + grouped_bop, grouped_bop->outputOffsets(), scheduler_type)) { + return true; + } } - // FIXME: I think I needed to do the same for GroupedBlockQuantizationOp } return false; } From e0c2bc0b59d8801521663395cf92102e18705954 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Fri, 9 Jan 2026 14:59:52 -0800 Subject: [PATCH 05/13] addressing review comments and clangtidy --- csrc/ops/arith.cpp | 8 +++----- csrc/scheduler/registry_utils.cpp | 2 +- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/csrc/ops/arith.cpp b/csrc/ops/arith.cpp index 266bb30e616..087d1a046cc 100644 --- a/csrc/ops/arith.cpp +++ b/csrc/ops/arith.cpp @@ -2810,10 +2810,9 @@ BlockQuantizationResults groupedBlockQuantize( auto inp_domain = TensorDomain::noReductions(input->getLogicalDomain()); - // Validate input tensor is not zero-dimensional - NVF_CHECK( - !inp_domain.empty(), - "Grouped block quantization does not support zero-dimensional tensors"); + // Validate input tensor is 2d + NVF_ERROR_EQ(inp_domain.size(), 2, + "Grouped block quantization only supports 2-dimensional tensors"); // Create output domain for quantized tensor (same shape as input) std::vector quantized_out_domain; @@ -2834,7 +2833,6 @@ BlockQuantizationResults groupedBlockQuantize( auto block_scales_dtype = (out_dtype == DataType::Float4_e2m1fn) ? DataType::Float8_e4m3fn : DataType::Float8_e8m0fnu; - NVF_ERROR_EQ(inp_domain.size(), 2); // This is used for both root and loop domain on output // maps directly to input's logical domain. diff --git a/csrc/scheduler/registry_utils.cpp b/csrc/scheduler/registry_utils.cpp index c2cf2a0cd07..8b1c545f2f6 100644 --- a/csrc/scheduler/registry_utils.cpp +++ b/csrc/scheduler/registry_utils.cpp @@ -1096,7 +1096,7 @@ bool SchedulerTopologyChecker::rejectScheduleFusionGlobalBufferRequirement( // 2. Read two offset inputs directly from global memory auto grouped_bop = expr->as(); if (rejectScheduleFusionOutputRequirement( - grouped_bop, grouped_bop->out(), scheduler_type) || + grouped_bop, grouped_bop->blockScales(), scheduler_type) || rejectScheduleFusionInputRequirement( grouped_bop, grouped_bop->inputOffsets(), scheduler_type) || rejectScheduleFusionInputRequirement( From fdac20849fa808887ee77e82724f6dfc037feb5b Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Fri, 9 Jan 2026 15:02:25 -0800 Subject: [PATCH 06/13] clangformat --- csrc/ops/arith.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/csrc/ops/arith.cpp b/csrc/ops/arith.cpp index 087d1a046cc..9baf7536e82 100644 --- a/csrc/ops/arith.cpp +++ b/csrc/ops/arith.cpp @@ -2811,7 +2811,9 @@ BlockQuantizationResults groupedBlockQuantize( auto inp_domain = TensorDomain::noReductions(input->getLogicalDomain()); // Validate input tensor is 2d - NVF_ERROR_EQ(inp_domain.size(), 2, + NVF_ERROR_EQ( + inp_domain.size(), + 2, "Grouped block quantization only supports 2-dimensional tensors"); // Create output domain for quantized tensor (same shape as input) From 3051db0f0283668240187e08de593de579345668 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Fri, 9 Jan 2026 15:24:35 -0800 Subject: [PATCH 07/13] fixing block size --- csrc/device_lower/pass/index.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/device_lower/pass/index.cpp b/csrc/device_lower/pass/index.cpp index 9f0e042721d..e70f092fae8 100644 --- a/csrc/device_lower/pass/index.cpp +++ b/csrc/device_lower/pass/index.cpp @@ -485,7 +485,7 @@ void IndexLowering::handle(const GroupedBlockQuantizationOp* grouped_bqop) { grouped_bqop->k(), grouped_bqop->g(), grouped_bqop->globalScale(), - 16, + grouped_bqop->blockSize(), logical_index[0], logical_index[1])); GpuLower::current()->propagateExprInfo(grouped_bqop, back()); From de33d7b92033749d791d0e53f664267ea1cb27fc Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 14 Jan 2026 12:02:11 -0800 Subject: [PATCH 08/13] skip test for < sm10.0 hardware --- tests/cpp/test_layout_op.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/cpp/test_layout_op.cpp b/tests/cpp/test_layout_op.cpp index a0e51a30262..8ac5e9df436 100644 --- a/tests/cpp/test_layout_op.cpp +++ b/tests/cpp/test_layout_op.cpp @@ -378,6 +378,10 @@ TEST_F(LayoutOpTest, Inlining) { } TEST_F(LayoutOpTest, GroupedBlockQuantizeOp) { + if (cudaArchGuardShouldSkip(10, 0)) { + GTEST_SKIP() << "skipping test because fp8 requires compute capability " + "10.0 or above"; + } auto fusion_ptr = std::make_unique(); Fusion& fusion = *fusion_ptr.get(); FusionGuard fg(&fusion); From 5feef0a738c05730d85af91d72c25bc019d0304b Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 14 Jan 2026 14:53:41 -0800 Subject: [PATCH 09/13] addressing review comments --- csrc/codegen.cpp | 32 +- csrc/device_lower/validation.cpp | 505 ++++++++++------------------ csrc/ops/arith.h | 2 + csrc/scheduler/registry_utils.cpp | 11 +- csrc/scheduler/tools/domain_map.cpp | 1 + 5 files changed, 203 insertions(+), 348 deletions(-) diff --git a/csrc/codegen.cpp b/csrc/codegen.cpp index 9b89216aae8..9c5a4132e72 100644 --- a/csrc/codegen.cpp +++ b/csrc/codegen.cpp @@ -1900,13 +1900,13 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { // Special handling of GroupedBlockQuantizationOp to call the runtime // function. - void handle(const GroupedBlockQuantizationOp* bqop) final { + void handle(const GroupedBlockQuantizationOp* grouped_bqop) final { // This operator is plumbed down to a runtime function call. // One of the assumptions is that the device runtime expects // n consecutive inputs per thread. Where n can be 2 or 4 for Float, and 2, // 4, or 8 for Half. We achieve this by having the quantized output tv // scheduled to have the inner dimension grouped by 2/4/8. - auto output = bqop->quantizedOutput()->as()->view(); + auto output = grouped_bqop->quantizedOutput()->as()->view(); auto output_dtype = output->getDataType(); // Extract group size from the loop domain @@ -1922,7 +1922,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { // Validate group size based on input data type const auto input_dtype = - bqop->in()->as()->view()->getDataType().value(); + grouped_bqop->in()->as()->view()->getDataType().value(); const bool is_half_precision = (input_dtype == DataType::BFloat16 || input_dtype == DataType::Half); const bool is_valid_group_size = is_half_precision @@ -1938,15 +1938,15 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { ". Found: ", group_size, ". Expr: ", - bqop->toString()); + grouped_bqop->toString()); // Build template arguments ArgumentBuilder template_args; // No global scale is required when quantizing to mxfp8 if (output_dtype == DataType::Float4_e2m1fn) { - template_args.arg(bqop->hasGlobalScale()); + template_args.arg(grouped_bqop->hasGlobalScale()); } - switch (bqop->layout()) { + switch (grouped_bqop->layout()) { case BlockScalingFactorLayout::Block128x4: template_args.arg(32); // block_row_outer template_args.arg(4); // block_row_inner @@ -1961,27 +1961,27 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { // Build function arguments ArgumentBuilder func_args; func_args.arg(genInline( - bqop->input(0)->as()->view())); // input data + grouped_bqop->input(0)->as()->view())); // input data func_args.arg(genInline(output)); // quantized output func_args.arg(genInline( - bqop->blockScales()->as()->view())); // block scales + grouped_bqop->blockScales()->as()->view())); // block scales // generate logical index for runtime function - func_args.arg(genInline(bqop->attributeVal(2))); - func_args.arg(genInline(bqop->attributeVal(3))); - func_args.arg("&").append(genVariableName(bqop->inputOffsets()) + "[0]"); - func_args.arg("&").append(genVariableName(bqop->outputOffsets()) + "[0]"); - func_args.arg(genInline(bqop->k())); - func_args.arg(genInline(bqop->g())); + func_args.arg(genInline(grouped_bqop->attributeVal(2))); + func_args.arg(genInline(grouped_bqop->attributeVal(3))); + func_args.arg("&").append(genVariableName(grouped_bqop->inputOffsets()) + "[0]"); + func_args.arg("&").append(genVariableName(grouped_bqop->outputOffsets()) + "[0]"); + func_args.arg(genInline(grouped_bqop->k())); + func_args.arg(genInline(grouped_bqop->g())); if (output_dtype == DataType::Float4_e2m1fn) { func_args.arg( - bqop->hasGlobalScale() ? genInline(bqop->globalScale()) : "{}"); + grouped_bqop->hasGlobalScale() ? genInline(grouped_bqop->globalScale()) : "{}"); } // Add swizzled allocation domain parameters if needed // This is always skipped when quantizing to mxfp8 - auto block_scales_tv = bqop->blockScales()->as()->view(); + auto block_scales_tv = grouped_bqop->blockScales()->as()->view(); if (block_scales_tv->hasAllocation()) { auto logical_domain = TensorDomain::noReductions(block_scales_tv->getLogicalDomain()); diff --git a/csrc/device_lower/validation.cpp b/csrc/device_lower/validation.cpp index 8be4cdd0b76..8411ed04844 100644 --- a/csrc/device_lower/validation.cpp +++ b/csrc/device_lower/validation.cpp @@ -31,6 +31,184 @@ namespace nvfuser { namespace { +void validateQuantizedOutputScheduling(TensorView* quantized_output) { + // Outputs have the same allocation domain + // as the logical domain - no allocation domain. + NVF_ERROR( + !quantized_output->hasAllocation(), + "Quantized output must not have an allocation domain."); + + IterDomain* grouped_id = nullptr; + IterDomain* thread_x = nullptr; + IterDomain* block_x = nullptr; + IterDomain* thread_z = nullptr; + IterDomain* block_z = nullptr; + + for (const auto& loop_id : quantized_output->getLoopDomain()) { + if (loop_id->getParallelType() == ParallelType::Group) { + grouped_id = loop_id; + } else if (loop_id->getParallelType() == ParallelType::TIDx) { + thread_x = loop_id; + } else if (loop_id->getParallelType() == ParallelType::BIDx) { + block_x = loop_id; + } else if (loop_id->getParallelType() == ParallelType::TIDz) { + thread_z = loop_id; + } else if (loop_id->getParallelType() == ParallelType::BIDz) { + block_z = loop_id; + } else if ( + loop_id->getParallelType() == ParallelType::Serial || + loop_id->getParallelType() == ParallelType::Unswitch || + loop_id->getParallelType() == ParallelType::Unroll) { + // Check this is ID has a constant extent and is 1 + NVF_ERROR( + loop_id->extent()->isConstInt(), + "Expected constant extent for Serial/Unswitch/Unroll ID in ", + quantized_output->definition()->toString()); + NVF_ERROR_EQ( + loop_id->extent()->evaluate().as(), + 1, + "Expected non-TID/BID/Group ID to have extent of 1 for ", + bqop->toString()); + } + } + + NVF_ERROR( + grouped_id != nullptr, + "One of the output IDs must be grouped for ", + bqop->toString()); + + NVF_ERROR( + thread_x != nullptr && block_x != nullptr, + "Need to have both TIDx and BIDx when using: ", + bqop->toString()); + + NVF_ERROR( + !thread_z && !block_z, + "Parallelization along z axis is not supported for ", + bqop->toString()); + + auto inner_extent = grouped_id->extent()->evaluate().as(); + auto input_dtype = inp_tv->dtype(); + + NVF_ERROR( + ((inner_extent == 4 || inner_extent == 2) && + input_dtype == DataType::Float) || + ((inner_extent == 8 || inner_extent == 4 || inner_extent == 2) && + (input_dtype == DataType::BFloat16 || + input_dtype == DataType::Half)), + "The group dimension must be 2/4 (FP32) or 2/4/8 " + "(BF16). Found: ", + inner_extent, + ". Expr: ", + bqop->toString()); + + // [ NOTE: check scheduling requirements for block quantization ] + // + // M K + // │ │ + // ▼ ▼ + // ┌────────────┐ + // │ merge │ + // └─────┬──────┘ + // │ + // ▼ + // M*K + // ┌──────────┐ + // │ split ┼──┐ + // └─┬────────┘ │ + // ▼ ▼ + // (M*K)/4 4(G) + // ┌────────┐ + // │ split ┼────┐ + // └─┬──────┘ │ + // ▼ ▼ + // (M*K)/4 1(U) + // ┌─────────┐ + // │ split │ + // ┌─┼ ┼───┐ + // │ └─────────┘ │ + // ▼ ▼ + // (M*K)/4/128 128(Tx) + + // Next we check the following scheduling requirements for + // BlockQuantizationOp/GroupedBlockQuantizationOp - the above figure is an example of a valid schedule. + // 1. The Group ID must be derived from the innermost logical IDs + // 2. TIDx must follow the Group ID in the schedule -- that is when derived + // from the logical domain, group ID must be inner-most, the next + // "inner-most" should be TIDx (unless there is an ID with a unit trip + // count) + // 3. All merges involved from logical domains to group and thread ID must + // combine contiguous IDs + + auto transform_exprs = DependencyCheck::getAllExprsBetween( + {quantized_output->getLogicalDomain().begin(), + quantized_output->getLogicalDomain().end()}, + {quantized_output->getLoopDomain().begin(), + quantized_output->getLoopDomain().end()}); + + std::vector ids_to_transform = + quantized_output->getLogicalDomain(); + + std::deque frontier( + quantized_output->getLogicalDomain().begin(), + quantized_output->getLogicalDomain().end()); + + // This will get the xforms from logical to loop and apply them on the + // logical domain. We will get a loop domain minus the reordering. + // This pass also removes all IDs from frontier that were derived using + // non-contiguous merges. + scheduler_utils::applyTransforms( + ids_to_transform, transform_exprs, [&frontier](Expr* expr) { + traverseFrontierWithContiguityCheck(frontier, expr); + }); + + // The grouped ID must correspond to the innermost loop-like domain + NVF_ERROR( + ids_to_transform.back() == grouped_id, + "The grouped ID must correspond to the innermost of all splits " + "from logical domains to loop domains for TV: " + quantized_output->toString()); + + // Iterate from the back to find TIDx, skipping group_id (last element) + // Ensure all IDs between group_id and TIDx have extent 1 + bool found_tidx = false; + for (auto it = ids_to_transform.rbegin() + 1; it != ids_to_transform.rend(); + ++it) { + if (*it == thread_x) { + found_tidx = true; + break; + } + // All non-TIDx IDs between Group ID and TIDx must have extent of 1 + NVF_ERROR( + (*it)->extent()->isConstInt() && + (*it)->extent()->evaluate().as() == 1, + "Expected IDs between Group ID and TIDx to have extent of 1 for ", + quantized_output->toString()); + } + + NVF_ERROR( + found_tidx, + "TIDx must follow the Group ID in the schedule for ", + quantized_output->toString()); + + // Check if grouped_id in frontier + auto grouped_it = std::ranges::find(frontier, grouped_id); + NVF_ERROR( + grouped_it != frontier.end(), + "All merge operations deriving the grouped ID must combine " + "contiguous IDs from the logical domain for: ", + quantized_output->toString()); + // Do the same for thread_x + auto threadx_it = + std::ranges::find(frontier.begin(), frontier.end(), thread_x); + NVF_ERROR( + threadx_it != frontier.end(), + "All merge operations deriving the TIDx ID must combine " + "contiguous IDs from the logical domain for: ", + quantized_output->toString()); +} + + //! Validate multiple output tensors of the same expression, i.e., //! siblings, have valid domains and parallel types. Since siblings //! are placed in the same loop nest, they must be parallelized the @@ -662,12 +840,6 @@ class ExprValidator : public OptOutDispatch { global_scale->dtype()); } - // Outputs have the same allocation domain - // as the logical domain - no allocation domain. - NVF_ERROR( - !quantized_output->hasAllocation(), - "Quantized output must not have an allocation domain."); - // When output scales is swizzled we will need to allow these checks // to be relaxed. We will need to ensure that the swizzling // allocation allowed is a fixed pattern: @@ -689,180 +861,8 @@ class ExprValidator : public OptOutDispatch { [](std::optional c) { return c.value_or(true); }), "Block scaling factor not contiguous"); - IterDomain* grouped_id = nullptr; - IterDomain* thread_x = nullptr; - IterDomain* block_x = nullptr; - IterDomain* thread_z = nullptr; - IterDomain* block_z = nullptr; - - for (const auto& loop_id : quantized_output->getLoopDomain()) { - if (loop_id->getParallelType() == ParallelType::Group) { - grouped_id = loop_id; - } else if (loop_id->getParallelType() == ParallelType::TIDx) { - thread_x = loop_id; - } else if (loop_id->getParallelType() == ParallelType::BIDx) { - block_x = loop_id; - } else if (loop_id->getParallelType() == ParallelType::TIDz) { - thread_z = loop_id; - } else if (loop_id->getParallelType() == ParallelType::BIDz) { - block_z = loop_id; - } else if ( - loop_id->getParallelType() == ParallelType::Serial || - loop_id->getParallelType() == ParallelType::Unswitch || - loop_id->getParallelType() == ParallelType::Unroll) { - // Check this is ID has a constant extent and is 1 - NVF_ERROR( - loop_id->extent()->isConstInt(), - "Expected constant extent for Serial/Unswitch/Unroll ID in " - "BlockQuantizationOp"); - NVF_ERROR_EQ( - loop_id->extent()->evaluate().as(), - 1, - "Expected non-TID/BID/Group ID to have extent of 1 for " - "BlockQuantizationOp: ", - bqop->toString()); - } - } - - NVF_ERROR( - grouped_id != nullptr, - "One of the output IDs must be grouped for " - "BlockQuantizationOp: ", - bqop->toString()); - - NVF_ERROR( - thread_x != nullptr && block_x != nullptr, - "Need to have both TIDx and BIDx when using BlockQuantizationOp: ", - bqop->toString()); - - NVF_ERROR( - !thread_z && !block_z, - "Parallelization along z axis is not supported for " - "BlockQuantizationOp: ", - bqop->toString()); - - auto inner_extent = grouped_id->extent()->evaluate().as(); - auto input_dtype = inp_tv->dtype(); - - NVF_ERROR( - ((inner_extent == 4 || inner_extent == 2) && - input_dtype == DataType::Float) || - ((inner_extent == 8 || inner_extent == 4 || inner_extent == 2) && - (input_dtype == DataType::BFloat16 || - input_dtype == DataType::Half)), - "The group dimension must be 2/4 (FP32) or 2/4/8 " - "(BF16). Found: ", - inner_extent, - ". Expr: ", - bqop->toString()); - - // [ NOTE: check scheduling requirements for block quantization ] - // - // M K - // │ │ - // ▼ ▼ - // ┌────────────┐ - // │ merge │ - // └─────┬──────┘ - // │ - // ▼ - // M*K - // ┌──────────┐ - // │ split ┼──┐ - // └─┬────────┘ │ - // ▼ ▼ - // (M*K)/4 4(G) - // ┌────────┐ - // │ split ┼────┐ - // └─┬──────┘ │ - // ▼ ▼ - // (M*K)/4 1(U) - // ┌─────────┐ - // │ split │ - // ┌─┼ ┼───┐ - // │ └─────────┘ │ - // ▼ ▼ - // (M*K)/4/128 128(Tx) - - // Next we check the following scheduling requirements for - // BlockQuantizationOp - the above figure is an example of a valid schedule. - // 1. The Group ID must be derived from the innermost logical IDs - // 2. TIDx must follow the Group ID in the schedule -- that is when derived - // from the logical domain, group ID must be inner-most, the next - // "inner-most" should be TIDx (unless there is an ID with a unit trip - // count) - // 3. All merges involved from logical domains to group and thread ID must - // combine contiguous IDs - - auto transform_exprs = DependencyCheck::getAllExprsBetween( - {quantized_output->getLogicalDomain().begin(), - quantized_output->getLogicalDomain().end()}, - {quantized_output->getLoopDomain().begin(), - quantized_output->getLoopDomain().end()}); - - std::vector ids_to_transform = - quantized_output->getLogicalDomain(); - - std::deque frontier( - quantized_output->getLogicalDomain().begin(), - quantized_output->getLogicalDomain().end()); - - // This will get the xforms from logical to loop and apply them on the - // logical domain. We will get a loop domain minus the reordering. - // This pass also removes all IDs from frontier that were derived using - // non-contiguous merges. - scheduler_utils::applyTransforms( - ids_to_transform, transform_exprs, [&frontier](Expr* expr) { - traverseFrontierWithContiguityCheck(frontier, expr); - }); - - // The grouped ID must correspond to the innermost loop-like domain - NVF_ERROR( - ids_to_transform.back() == grouped_id, - "The grouped ID must correspond to the innermost of all splits " - "from logical domains to loop domains for BlockQuantizationOp. " - "TV: ", - quantized_output->toString()); - - // Iterate from the back to find TIDx, skipping group_id (last element) - // Ensure all IDs between group_id and TIDx have extent 1 - bool found_tidx = false; - for (auto it = ids_to_transform.rbegin() + 1; it != ids_to_transform.rend(); - ++it) { - if (*it == thread_x) { - found_tidx = true; - break; - } - // All non-TIDx IDs between Group ID and TIDx must have extent of 1 - NVF_ERROR( - (*it)->extent()->isConstInt() && - (*it)->extent()->evaluate().as() == 1, - "Expected IDs between Group ID and TIDx to have extent of 1 for " - "BlockQuantizationOp: ", - quantized_output->toString()); - } - - NVF_ERROR( - found_tidx, - "TIDx must follow the Group ID in the schedule for " - "BlockQuantizationOp: ", - quantized_output->toString()); + validateQuantizedOutputScheduling(quantized_output); - // Check if grouped_id in frontier - auto grouped_it = std::ranges::find(frontier, grouped_id); - NVF_ERROR( - grouped_it != frontier.end(), - "All merge operations deriving the grouped ID must combine " - "contiguous IDs from the logical domain for BlockQuantizationOp: ", - quantized_output->toString()); - // Do the same for thread_x - auto threadx_it = - std::ranges::find(frontier.begin(), frontier.end(), thread_x); - NVF_ERROR( - threadx_it != frontier.end(), - "All merge operations deriving the TIDx ID must combine " - "contiguous IDs from the logical domain for BlockQuantizationOp: ", - quantized_output->toString()); } void handle(GroupedBlockQuantizationOp* bqop) final { @@ -909,152 +909,7 @@ class ExprValidator : public OptOutDispatch { global_scale->dtype()); } - // Outputs have the same allocation domain - // as the logical domain - no allocation domain. - NVF_ERROR( - !quantized_output->hasAllocation(), - "Quantized output must not have an allocation domain."); - - IterDomain* grouped_id = nullptr; - IterDomain* thread_x = nullptr; - IterDomain* block_x = nullptr; - IterDomain* thread_z = nullptr; - IterDomain* block_z = nullptr; - - for (const auto& loop_id : quantized_output->getLoopDomain()) { - if (loop_id->getParallelType() == ParallelType::Group) { - grouped_id = loop_id; - } else if (loop_id->getParallelType() == ParallelType::TIDx) { - thread_x = loop_id; - } else if (loop_id->getParallelType() == ParallelType::BIDx) { - block_x = loop_id; - } else if (loop_id->getParallelType() == ParallelType::TIDz) { - thread_z = loop_id; - } else if (loop_id->getParallelType() == ParallelType::BIDz) { - block_z = loop_id; - } else if ( - loop_id->getParallelType() == ParallelType::Serial || - loop_id->getParallelType() == ParallelType::Unswitch || - loop_id->getParallelType() == ParallelType::Unroll) { - // Check this is ID has a constant extent and is 1 - NVF_ERROR( - loop_id->extent()->isConstInt(), - "Expected constant extent for Serial/Unswitch/Unroll ID in " - "GroupedBlockQuantizationOp"); - NVF_ERROR_EQ( - loop_id->extent()->evaluate().as(), - 1, - "Expected non-TID/BID/Group ID to have extent of 1 for " - "GroupedBlockQuantizationOp: ", - bqop->toString()); - } - } - - NVF_ERROR( - grouped_id != nullptr, - "One of the output IDs must be grouped for " - "GroupedBlockQuantizationOp: ", - bqop->toString()); - - NVF_ERROR( - thread_x != nullptr && block_x != nullptr, - "Need to have both TIDx and BIDx when using " - "GroupedBlockQuantizationOp: ", - bqop->toString()); - - NVF_ERROR( - !thread_z && !block_z, - "Parallelization along z axis is not supported for " - "GroupedBlockQuantizationOp: ", - bqop->toString()); - - auto inner_extent = grouped_id->extent()->evaluate().as(); - auto input_dtype = inp_tv->dtype(); - - NVF_ERROR( - ((inner_extent == 4 || inner_extent == 2) && - input_dtype == DataType::Float) || - ((inner_extent == 8 || inner_extent == 4 || inner_extent == 2) && - (input_dtype == DataType::BFloat16 || - input_dtype == DataType::Half)), - "The group dimension must be 2/4 (FP32) or 2/4/8 " - "(BF16). Found: ", - inner_extent, - ". Expr: ", - bqop->toString()); - - // see [ NOTE: check scheduling requirements for block quantization ] - auto transform_exprs = DependencyCheck::getAllExprsBetween( - {quantized_output->getLogicalDomain().begin(), - quantized_output->getLogicalDomain().end()}, - {quantized_output->getLoopDomain().begin(), - quantized_output->getLoopDomain().end()}); - - std::vector ids_to_transform = - quantized_output->getLogicalDomain(); - - std::deque frontier( - quantized_output->getLogicalDomain().begin(), - quantized_output->getLogicalDomain().end()); - - // This will get the xforms from logical to loop and apply them on the - // logical domain. We will get a loop domain minus the reordering. - // This pass also removes all IDs from frontier that were derived using - // non-contiguous merges. - scheduler_utils::applyTransforms( - ids_to_transform, transform_exprs, [&frontier](Expr* expr) { - traverseFrontierWithContiguityCheck(frontier, expr); - }); - - // The grouped ID must correspond to the innermost loop-like domain - NVF_ERROR( - ids_to_transform.back() == grouped_id, - "The grouped ID must correspond to the innermost of all splits " - "from logical domains to loop domains for GroupedBlockQuantizationOp. " - "TV: ", - quantized_output->toString()); - - // Iterate from the back to find TIDx, skipping group_id (last element) - // Ensure all IDs between group_id and TIDx have extent 1 - bool found_tidx = false; - for (auto it = ids_to_transform.rbegin() + 1; it != ids_to_transform.rend(); - ++it) { - if (*it == thread_x) { - found_tidx = true; - break; - } - // All non-TIDx IDs between Group ID and TIDx must have extent of 1 - NVF_ERROR( - (*it)->extent()->isConstInt() && - (*it)->extent()->evaluate().as() == 1, - "Expected IDs between Group ID and TIDx to have extent of 1 for " - "GroupedBlockQuantizationOp: ", - quantized_output->toString()); - } - - NVF_ERROR( - found_tidx, - "TIDx must follow the Group ID in the schedule for " - "GroupedBlockQuantizationOp: ", - quantized_output->toString()); - - // Check if grouped_id in frontier - auto grouped_it = std::ranges::find(frontier, grouped_id); - NVF_ERROR( - grouped_it != frontier.end(), - "All merge operations deriving the grouped ID must combine " - "contiguous IDs from the logical domain for " - "GroupedBlockQuantizationOp: ", - quantized_output->toString()); - // Do the same for thread_x - auto threadx_it = - std::ranges::find(frontier.begin(), frontier.end(), thread_x); - NVF_ERROR( - threadx_it != frontier.end(), - "All merge operations deriving the TIDx ID must combine " - "contiguous IDs from the logical domain for " - "GroupedBlockQuantizationOp: ", - quantized_output->toString()); + validateQuantizedOutputScheduling(quantized_output); } }; diff --git a/csrc/ops/arith.h b/csrc/ops/arith.h index f4595d36561..d818cd32776 100644 --- a/csrc/ops/arith.h +++ b/csrc/ops/arith.h @@ -855,6 +855,8 @@ NVF_API BlockQuantizationResults blockQuantize( bool swizzle_scales = false, DataType out_dtype = DataType::Float4_e2m1fn); +// API for grouped block quantization. +// This operation combines blockQuantizationOp and PreprocessGroupedMatmulInputSf together, where it computes the quantized output and block scaling factor, as well as handling the swizzle layout required by block scaling factor. Refer to blockQuantize and preprocessGroupedMatmulInputSf for implementation details regarding these two operations. NVF_API BlockQuantizationResults groupedBlockQuantize( TensorView* input, TensorView* input_offsets, diff --git a/csrc/scheduler/registry_utils.cpp b/csrc/scheduler/registry_utils.cpp index 8b1c545f2f6..ad188ccfe94 100644 --- a/csrc/scheduler/registry_utils.cpp +++ b/csrc/scheduler/registry_utils.cpp @@ -845,16 +845,13 @@ bool SchedulerTopologyChecker::hasNonNormalizePostReductionBCast( // is not the fusion/segment output. bool hasNonTerminalBlockQuantizeOp(Fusion* fusion) { for (auto expr : fusion->exprs()) { - if (expr->isA()) { - auto block_scales = - expr->as()->blockScales()->as(); + if (auto bqop = dynamic_cast(expr)) { + auto block_scales = bqop->blockScales()->as(); if (!block_scales->isFusionOutput()) { return true; } - } else if (expr->isA()) { - auto block_scales = expr->as() - ->blockScales() - ->as(); + } else if (auto grouped_bqop = dynamic_cast(expr)) { + auto block_scales = grouped_bqop->blockScales()->as(); if (!block_scales->isFusionOutput()) { return true; } diff --git a/csrc/scheduler/tools/domain_map.cpp b/csrc/scheduler/tools/domain_map.cpp index 8d9cc77c57e..7c4f606113e 100644 --- a/csrc/scheduler/tools/domain_map.cpp +++ b/csrc/scheduler/tools/domain_map.cpp @@ -421,6 +421,7 @@ bool DomainMap::isValidReference(TensorView* tv, bool check_inputs) const { // BlockQuantizationOp when looking for a reference tensor. This is because // the two outputs of block quantization op are not symmetrical and the // logical domains of the scaling factor is not completely mapped. + // The same thing applies to GroupedBlockQuantizationOp's block scale output. if (output_tv == tv || (output_tv->definition() && output_tv->definition()->isA() && From 13910fcba8bb51851e9f3e49fc0930a433886bbe Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 14 Jan 2026 15:01:07 -0800 Subject: [PATCH 10/13] wip --- csrc/device_lower/validation.cpp | 359 +++++++++++++++--------------- csrc/scheduler/registry_utils.cpp | 4 +- 2 files changed, 181 insertions(+), 182 deletions(-) diff --git a/csrc/device_lower/validation.cpp b/csrc/device_lower/validation.cpp index 8411ed04844..03dd03eda5e 100644 --- a/csrc/device_lower/validation.cpp +++ b/csrc/device_lower/validation.cpp @@ -31,184 +31,6 @@ namespace nvfuser { namespace { -void validateQuantizedOutputScheduling(TensorView* quantized_output) { - // Outputs have the same allocation domain - // as the logical domain - no allocation domain. - NVF_ERROR( - !quantized_output->hasAllocation(), - "Quantized output must not have an allocation domain."); - - IterDomain* grouped_id = nullptr; - IterDomain* thread_x = nullptr; - IterDomain* block_x = nullptr; - IterDomain* thread_z = nullptr; - IterDomain* block_z = nullptr; - - for (const auto& loop_id : quantized_output->getLoopDomain()) { - if (loop_id->getParallelType() == ParallelType::Group) { - grouped_id = loop_id; - } else if (loop_id->getParallelType() == ParallelType::TIDx) { - thread_x = loop_id; - } else if (loop_id->getParallelType() == ParallelType::BIDx) { - block_x = loop_id; - } else if (loop_id->getParallelType() == ParallelType::TIDz) { - thread_z = loop_id; - } else if (loop_id->getParallelType() == ParallelType::BIDz) { - block_z = loop_id; - } else if ( - loop_id->getParallelType() == ParallelType::Serial || - loop_id->getParallelType() == ParallelType::Unswitch || - loop_id->getParallelType() == ParallelType::Unroll) { - // Check this is ID has a constant extent and is 1 - NVF_ERROR( - loop_id->extent()->isConstInt(), - "Expected constant extent for Serial/Unswitch/Unroll ID in ", - quantized_output->definition()->toString()); - NVF_ERROR_EQ( - loop_id->extent()->evaluate().as(), - 1, - "Expected non-TID/BID/Group ID to have extent of 1 for ", - bqop->toString()); - } - } - - NVF_ERROR( - grouped_id != nullptr, - "One of the output IDs must be grouped for ", - bqop->toString()); - - NVF_ERROR( - thread_x != nullptr && block_x != nullptr, - "Need to have both TIDx and BIDx when using: ", - bqop->toString()); - - NVF_ERROR( - !thread_z && !block_z, - "Parallelization along z axis is not supported for ", - bqop->toString()); - - auto inner_extent = grouped_id->extent()->evaluate().as(); - auto input_dtype = inp_tv->dtype(); - - NVF_ERROR( - ((inner_extent == 4 || inner_extent == 2) && - input_dtype == DataType::Float) || - ((inner_extent == 8 || inner_extent == 4 || inner_extent == 2) && - (input_dtype == DataType::BFloat16 || - input_dtype == DataType::Half)), - "The group dimension must be 2/4 (FP32) or 2/4/8 " - "(BF16). Found: ", - inner_extent, - ". Expr: ", - bqop->toString()); - - // [ NOTE: check scheduling requirements for block quantization ] - // - // M K - // │ │ - // ▼ ▼ - // ┌────────────┐ - // │ merge │ - // └─────┬──────┘ - // │ - // ▼ - // M*K - // ┌──────────┐ - // │ split ┼──┐ - // └─┬────────┘ │ - // ▼ ▼ - // (M*K)/4 4(G) - // ┌────────┐ - // │ split ┼────┐ - // └─┬──────┘ │ - // ▼ ▼ - // (M*K)/4 1(U) - // ┌─────────┐ - // │ split │ - // ┌─┼ ┼───┐ - // │ └─────────┘ │ - // ▼ ▼ - // (M*K)/4/128 128(Tx) - - // Next we check the following scheduling requirements for - // BlockQuantizationOp/GroupedBlockQuantizationOp - the above figure is an example of a valid schedule. - // 1. The Group ID must be derived from the innermost logical IDs - // 2. TIDx must follow the Group ID in the schedule -- that is when derived - // from the logical domain, group ID must be inner-most, the next - // "inner-most" should be TIDx (unless there is an ID with a unit trip - // count) - // 3. All merges involved from logical domains to group and thread ID must - // combine contiguous IDs - - auto transform_exprs = DependencyCheck::getAllExprsBetween( - {quantized_output->getLogicalDomain().begin(), - quantized_output->getLogicalDomain().end()}, - {quantized_output->getLoopDomain().begin(), - quantized_output->getLoopDomain().end()}); - - std::vector ids_to_transform = - quantized_output->getLogicalDomain(); - - std::deque frontier( - quantized_output->getLogicalDomain().begin(), - quantized_output->getLogicalDomain().end()); - - // This will get the xforms from logical to loop and apply them on the - // logical domain. We will get a loop domain minus the reordering. - // This pass also removes all IDs from frontier that were derived using - // non-contiguous merges. - scheduler_utils::applyTransforms( - ids_to_transform, transform_exprs, [&frontier](Expr* expr) { - traverseFrontierWithContiguityCheck(frontier, expr); - }); - - // The grouped ID must correspond to the innermost loop-like domain - NVF_ERROR( - ids_to_transform.back() == grouped_id, - "The grouped ID must correspond to the innermost of all splits " - "from logical domains to loop domains for TV: " - quantized_output->toString()); - - // Iterate from the back to find TIDx, skipping group_id (last element) - // Ensure all IDs between group_id and TIDx have extent 1 - bool found_tidx = false; - for (auto it = ids_to_transform.rbegin() + 1; it != ids_to_transform.rend(); - ++it) { - if (*it == thread_x) { - found_tidx = true; - break; - } - // All non-TIDx IDs between Group ID and TIDx must have extent of 1 - NVF_ERROR( - (*it)->extent()->isConstInt() && - (*it)->extent()->evaluate().as() == 1, - "Expected IDs between Group ID and TIDx to have extent of 1 for ", - quantized_output->toString()); - } - - NVF_ERROR( - found_tidx, - "TIDx must follow the Group ID in the schedule for ", - quantized_output->toString()); - - // Check if grouped_id in frontier - auto grouped_it = std::ranges::find(frontier, grouped_id); - NVF_ERROR( - grouped_it != frontier.end(), - "All merge operations deriving the grouped ID must combine " - "contiguous IDs from the logical domain for: ", - quantized_output->toString()); - // Do the same for thread_x - auto threadx_it = - std::ranges::find(frontier.begin(), frontier.end(), thread_x); - NVF_ERROR( - threadx_it != frontier.end(), - "All merge operations deriving the TIDx ID must combine " - "contiguous IDs from the logical domain for: ", - quantized_output->toString()); -} - - //! Validate multiple output tensors of the same expression, i.e., //! siblings, have valid domains and parallel types. Since siblings //! are placed in the same loop nest, they must be parallelized the @@ -431,6 +253,183 @@ void traverseFrontierWithContiguityCheck( } } +void validateQuantizedOutputScheduling(TensorView* quantized_output, DataType input_dtype) { + // Outputs have the same allocation domain + // as the logical domain - no allocation domain. + NVF_ERROR( + !quantized_output->hasAllocation(), + "Quantized output must not have an allocation domain."); + + IterDomain* grouped_id = nullptr; + IterDomain* thread_x = nullptr; + IterDomain* block_x = nullptr; + IterDomain* thread_z = nullptr; + IterDomain* block_z = nullptr; + + for (const auto& loop_id : quantized_output->getLoopDomain()) { + if (loop_id->getParallelType() == ParallelType::Group) { + grouped_id = loop_id; + } else if (loop_id->getParallelType() == ParallelType::TIDx) { + thread_x = loop_id; + } else if (loop_id->getParallelType() == ParallelType::BIDx) { + block_x = loop_id; + } else if (loop_id->getParallelType() == ParallelType::TIDz) { + thread_z = loop_id; + } else if (loop_id->getParallelType() == ParallelType::BIDz) { + block_z = loop_id; + } else if ( + loop_id->getParallelType() == ParallelType::Serial || + loop_id->getParallelType() == ParallelType::Unswitch || + loop_id->getParallelType() == ParallelType::Unroll) { + // Check this is ID has a constant extent and is 1 + NVF_ERROR( + loop_id->extent()->isConstInt(), + "Expected constant extent for Serial/Unswitch/Unroll ID in ", + quantized_output->definition()->toString()); + NVF_ERROR_EQ( + loop_id->extent()->evaluate().as(), + 1, + "Expected non-TID/BID/Group ID to have extent of 1 for ", + quantized_output->definition()->toString()); + } + } + + NVF_ERROR( + grouped_id != nullptr, + "One of the output IDs must be grouped for ", + quantized_output->definition()->toString()); + + NVF_ERROR( + thread_x != nullptr && block_x != nullptr, + "Need to have both TIDx and BIDx when using: ", + quantized_output->definition()->toString()); + + NVF_ERROR( + !thread_z && !block_z, + "Parallelization along z axis is not supported for ", + quantized_output->definition()->toString()); + + auto inner_extent = grouped_id->extent()->evaluate().as(); + + NVF_ERROR( + ((inner_extent == 4 || inner_extent == 2) && + input_dtype == DataType::Float) || + ((inner_extent == 8 || inner_extent == 4 || inner_extent == 2) && + (input_dtype == DataType::BFloat16 || + input_dtype == DataType::Half)), + "The group dimension must be 2/4 (FP32) or 2/4/8 " + "(BF16). Found: ", + inner_extent, + ". Expr: ", + quantized_output->definition()->toString()); + + // [ NOTE: check scheduling requirements for block quantization ] + // + // M K + // │ │ + // ▼ ▼ + // ┌────────────┐ + // │ merge │ + // └─────┬──────┘ + // │ + // ▼ + // M*K + // ┌──────────┐ + // │ split ┼──┐ + // └─┬────────┘ │ + // ▼ ▼ + // (M*K)/4 4(G) + // ┌────────┐ + // │ split ┼────┐ + // └─┬──────┘ │ + // ▼ ▼ + // (M*K)/4 1(U) + // ┌─────────┐ + // │ split │ + // ┌─┼ ┼───┐ + // │ └─────────┘ │ + // ▼ ▼ + // (M*K)/4/128 128(Tx) + + // Next we check the following scheduling requirements for + // BlockQuantizationOp/GroupedBlockQuantizationOp - the above figure is an example of a valid schedule. + // 1. The Group ID must be derived from the innermost logical IDs + // 2. TIDx must follow the Group ID in the schedule -- that is when derived + // from the logical domain, group ID must be inner-most, the next + // "inner-most" should be TIDx (unless there is an ID with a unit trip + // count) + // 3. All merges involved from logical domains to group and thread ID must + // combine contiguous IDs + + auto transform_exprs = DependencyCheck::getAllExprsBetween( + {quantized_output->getLogicalDomain().begin(), + quantized_output->getLogicalDomain().end()}, + {quantized_output->getLoopDomain().begin(), + quantized_output->getLoopDomain().end()}); + + std::vector ids_to_transform = + quantized_output->getLogicalDomain(); + + std::deque frontier( + quantized_output->getLogicalDomain().begin(), + quantized_output->getLogicalDomain().end()); + + // This will get the xforms from logical to loop and apply them on the + // logical domain. We will get a loop domain minus the reordering. + // This pass also removes all IDs from frontier that were derived using + // non-contiguous merges. + scheduler_utils::applyTransforms( + ids_to_transform, transform_exprs, [&frontier](Expr* expr) { + traverseFrontierWithContiguityCheck(frontier, expr); + }); + + // The grouped ID must correspond to the innermost loop-like domain + NVF_ERROR( + ids_to_transform.back() == grouped_id, + "The grouped ID must correspond to the innermost of all splits " + "from logical domains to loop domains for TV: " + quantized_output->toString()); + + // Iterate from the back to find TIDx, skipping group_id (last element) + // Ensure all IDs between group_id and TIDx have extent 1 + bool found_tidx = false; + for (auto it = ids_to_transform.rbegin() + 1; it != ids_to_transform.rend(); + ++it) { + if (*it == thread_x) { + found_tidx = true; + break; + } + // All non-TIDx IDs between Group ID and TIDx must have extent of 1 + NVF_ERROR( + (*it)->extent()->isConstInt() && + (*it)->extent()->evaluate().as() == 1, + "Expected IDs between Group ID and TIDx to have extent of 1 for ", + quantized_output->toString()); + } + + NVF_ERROR( + found_tidx, + "TIDx must follow the Group ID in the schedule for ", + quantized_output->toString()); + + // Check if grouped_id in frontier + auto grouped_it = std::ranges::find(frontier, grouped_id); + NVF_ERROR( + grouped_it != frontier.end(), + "All merge operations deriving the grouped ID must combine " + "contiguous IDs from the logical domain for: ", + quantized_output->toString()); + // Do the same for thread_x + auto threadx_it = + std::ranges::find(frontier.begin(), frontier.end(), thread_x); + NVF_ERROR( + threadx_it != frontier.end(), + "All merge operations deriving the TIDx ID must combine " + "contiguous IDs from the logical domain for: ", + quantized_output->toString()); +} + + // Check if maybe_innermost_id is derived from base_id and corresponds to the // innermost subregion of base_id. The split/merge exprs between // base_id and id must not include any ID that is not produced from @@ -861,7 +860,7 @@ class ExprValidator : public OptOutDispatch { [](std::optional c) { return c.value_or(true); }), "Block scaling factor not contiguous"); - validateQuantizedOutputScheduling(quantized_output); + validateQuantizedOutputScheduling(quantized_output, input_tv->dtype()); } @@ -909,7 +908,7 @@ class ExprValidator : public OptOutDispatch { global_scale->dtype()); } - validateQuantizedOutputScheduling(quantized_output); + validateQuantizedOutputScheduling(quantized_output, input_tv->dtype()); } }; diff --git a/csrc/scheduler/registry_utils.cpp b/csrc/scheduler/registry_utils.cpp index ad188ccfe94..0d40a8f6390 100644 --- a/csrc/scheduler/registry_utils.cpp +++ b/csrc/scheduler/registry_utils.cpp @@ -845,12 +845,12 @@ bool SchedulerTopologyChecker::hasNonNormalizePostReductionBCast( // is not the fusion/segment output. bool hasNonTerminalBlockQuantizeOp(Fusion* fusion) { for (auto expr : fusion->exprs()) { - if (auto bqop = dynamic_cast(expr)) { + if (auto bqop = dynamic_cast(expr)) { auto block_scales = bqop->blockScales()->as(); if (!block_scales->isFusionOutput()) { return true; } - } else if (auto grouped_bqop = dynamic_cast(expr)) { + } else if (auto grouped_bqop = dynamic_cast(expr)) { auto block_scales = grouped_bqop->blockScales()->as(); if (!block_scales->isFusionOutput()) { return true; From a9cc3b7e222eae31cf3832bb166f3f60758ffbfe Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 14 Jan 2026 15:04:23 -0800 Subject: [PATCH 11/13] wip --- csrc/device_lower/validation.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/csrc/device_lower/validation.cpp b/csrc/device_lower/validation.cpp index 03dd03eda5e..f0ffef1eda4 100644 --- a/csrc/device_lower/validation.cpp +++ b/csrc/device_lower/validation.cpp @@ -387,7 +387,7 @@ void validateQuantizedOutputScheduling(TensorView* quantized_output, DataType in NVF_ERROR( ids_to_transform.back() == grouped_id, "The grouped ID must correspond to the innermost of all splits " - "from logical domains to loop domains for TV: " + "from logical domains to loop domains for TV: ", quantized_output->toString()); // Iterate from the back to find TIDx, skipping group_id (last element) @@ -860,7 +860,7 @@ class ExprValidator : public OptOutDispatch { [](std::optional c) { return c.value_or(true); }), "Block scaling factor not contiguous"); - validateQuantizedOutputScheduling(quantized_output, input_tv->dtype()); + validateQuantizedOutputScheduling(quantized_output, inp_tv->dtype()); } @@ -908,7 +908,7 @@ class ExprValidator : public OptOutDispatch { global_scale->dtype()); } - validateQuantizedOutputScheduling(quantized_output, input_tv->dtype()); + validateQuantizedOutputScheduling(quantized_output, inp_tv->dtype()); } }; From fd57cb0e344665facba5cb922dcd922cd4a68993 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 14 Jan 2026 15:07:00 -0800 Subject: [PATCH 12/13] clangformat --- csrc/codegen.cpp | 28 +++++++++++++++++++--------- csrc/device_lower/validation.cpp | 9 +++++---- csrc/ops/arith.h | 7 ++++++- csrc/scheduler/registry_utils.cpp | 3 ++- csrc/scheduler/tools/domain_map.cpp | 3 ++- 5 files changed, 34 insertions(+), 16 deletions(-) diff --git a/csrc/codegen.cpp b/csrc/codegen.cpp index 9c5a4132e72..eaeb22b96d6 100644 --- a/csrc/codegen.cpp +++ b/csrc/codegen.cpp @@ -1906,7 +1906,8 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { // n consecutive inputs per thread. Where n can be 2 or 4 for Float, and 2, // 4, or 8 for Half. We achieve this by having the quantized output tv // scheduled to have the inner dimension grouped by 2/4/8. - auto output = grouped_bqop->quantizedOutput()->as()->view(); + auto output = + grouped_bqop->quantizedOutput()->as()->view(); auto output_dtype = output->getDataType(); // Extract group size from the loop domain @@ -1921,8 +1922,11 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { } // Validate group size based on input data type - const auto input_dtype = - grouped_bqop->in()->as()->view()->getDataType().value(); + const auto input_dtype = grouped_bqop->in() + ->as() + ->view() + ->getDataType() + .value(); const bool is_half_precision = (input_dtype == DataType::BFloat16 || input_dtype == DataType::Half); const bool is_valid_group_size = is_half_precision @@ -1963,25 +1967,31 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { func_args.arg(genInline( grouped_bqop->input(0)->as()->view())); // input data func_args.arg(genInline(output)); // quantized output - func_args.arg(genInline( - grouped_bqop->blockScales()->as()->view())); // block scales + func_args.arg(genInline(grouped_bqop->blockScales() + ->as() + ->view())); // block scales // generate logical index for runtime function func_args.arg(genInline(grouped_bqop->attributeVal(2))); func_args.arg(genInline(grouped_bqop->attributeVal(3))); - func_args.arg("&").append(genVariableName(grouped_bqop->inputOffsets()) + "[0]"); - func_args.arg("&").append(genVariableName(grouped_bqop->outputOffsets()) + "[0]"); + func_args.arg("&").append( + genVariableName(grouped_bqop->inputOffsets()) + "[0]"); + func_args.arg("&").append( + genVariableName(grouped_bqop->outputOffsets()) + "[0]"); func_args.arg(genInline(grouped_bqop->k())); func_args.arg(genInline(grouped_bqop->g())); if (output_dtype == DataType::Float4_e2m1fn) { func_args.arg( - grouped_bqop->hasGlobalScale() ? genInline(grouped_bqop->globalScale()) : "{}"); + grouped_bqop->hasGlobalScale() + ? genInline(grouped_bqop->globalScale()) + : "{}"); } // Add swizzled allocation domain parameters if needed // This is always skipped when quantizing to mxfp8 - auto block_scales_tv = grouped_bqop->blockScales()->as()->view(); + auto block_scales_tv = + grouped_bqop->blockScales()->as()->view(); if (block_scales_tv->hasAllocation()) { auto logical_domain = TensorDomain::noReductions(block_scales_tv->getLogicalDomain()); diff --git a/csrc/device_lower/validation.cpp b/csrc/device_lower/validation.cpp index f0ffef1eda4..c7f2fe83769 100644 --- a/csrc/device_lower/validation.cpp +++ b/csrc/device_lower/validation.cpp @@ -253,7 +253,9 @@ void traverseFrontierWithContiguityCheck( } } -void validateQuantizedOutputScheduling(TensorView* quantized_output, DataType input_dtype) { +void validateQuantizedOutputScheduling( + TensorView* quantized_output, + DataType input_dtype) { // Outputs have the same allocation domain // as the logical domain - no allocation domain. NVF_ERROR( @@ -352,7 +354,8 @@ void validateQuantizedOutputScheduling(TensorView* quantized_output, DataType in // (M*K)/4/128 128(Tx) // Next we check the following scheduling requirements for - // BlockQuantizationOp/GroupedBlockQuantizationOp - the above figure is an example of a valid schedule. + // BlockQuantizationOp/GroupedBlockQuantizationOp - the above figure is an + // example of a valid schedule. // 1. The Group ID must be derived from the innermost logical IDs // 2. TIDx must follow the Group ID in the schedule -- that is when derived // from the logical domain, group ID must be inner-most, the next @@ -429,7 +432,6 @@ void validateQuantizedOutputScheduling(TensorView* quantized_output, DataType in quantized_output->toString()); } - // Check if maybe_innermost_id is derived from base_id and corresponds to the // innermost subregion of base_id. The split/merge exprs between // base_id and id must not include any ID that is not produced from @@ -861,7 +863,6 @@ class ExprValidator : public OptOutDispatch { "Block scaling factor not contiguous"); validateQuantizedOutputScheduling(quantized_output, inp_tv->dtype()); - } void handle(GroupedBlockQuantizationOp* bqop) final { diff --git a/csrc/ops/arith.h b/csrc/ops/arith.h index d818cd32776..4ac75f9771f 100644 --- a/csrc/ops/arith.h +++ b/csrc/ops/arith.h @@ -856,7 +856,12 @@ NVF_API BlockQuantizationResults blockQuantize( DataType out_dtype = DataType::Float4_e2m1fn); // API for grouped block quantization. -// This operation combines blockQuantizationOp and PreprocessGroupedMatmulInputSf together, where it computes the quantized output and block scaling factor, as well as handling the swizzle layout required by block scaling factor. Refer to blockQuantize and preprocessGroupedMatmulInputSf for implementation details regarding these two operations. +// This operation combines blockQuantizationOp and +// PreprocessGroupedMatmulInputSf together, where it computes the quantized +// output and block scaling factor, as well as handling the swizzle layout +// required by block scaling factor. Refer to blockQuantize and +// preprocessGroupedMatmulInputSf for implementation details regarding these two +// operations. NVF_API BlockQuantizationResults groupedBlockQuantize( TensorView* input, TensorView* input_offsets, diff --git a/csrc/scheduler/registry_utils.cpp b/csrc/scheduler/registry_utils.cpp index 0d40a8f6390..3e0ccccdff2 100644 --- a/csrc/scheduler/registry_utils.cpp +++ b/csrc/scheduler/registry_utils.cpp @@ -850,7 +850,8 @@ bool hasNonTerminalBlockQuantizeOp(Fusion* fusion) { if (!block_scales->isFusionOutput()) { return true; } - } else if (auto grouped_bqop = dynamic_cast(expr)) { + } else if ( + auto grouped_bqop = dynamic_cast(expr)) { auto block_scales = grouped_bqop->blockScales()->as(); if (!block_scales->isFusionOutput()) { return true; diff --git a/csrc/scheduler/tools/domain_map.cpp b/csrc/scheduler/tools/domain_map.cpp index 7c4f606113e..b25ed288fde 100644 --- a/csrc/scheduler/tools/domain_map.cpp +++ b/csrc/scheduler/tools/domain_map.cpp @@ -421,7 +421,8 @@ bool DomainMap::isValidReference(TensorView* tv, bool check_inputs) const { // BlockQuantizationOp when looking for a reference tensor. This is because // the two outputs of block quantization op are not symmetrical and the // logical domains of the scaling factor is not completely mapped. - // The same thing applies to GroupedBlockQuantizationOp's block scale output. + // The same thing applies to GroupedBlockQuantizationOp's block scale + // output. if (output_tv == tv || (output_tv->definition() && output_tv->definition()->isA() && From da98b52252d0a0de648979b0676c23cf2d054d8d Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 14 Jan 2026 17:45:28 -0800 Subject: [PATCH 13/13] modifying test string check --- tests/cpp/test_low_precision_recipe.cpp | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/tests/cpp/test_low_precision_recipe.cpp b/tests/cpp/test_low_precision_recipe.cpp index 1261a70a6fa..68a2f1d0f9b 100644 --- a/tests/cpp/test_low_precision_recipe.cpp +++ b/tests/cpp/test_low_precision_recipe.cpp @@ -845,7 +845,7 @@ TEST_F(BlockQuantizationValidationTest, GroupIDMustBeInnermost) { [&]() { GpuLower(setup.fusion.get()).run(); }, testing::ThrowsMessage(testing::HasSubstr( "The grouped ID must correspond to the innermost of all splits from " - "logical domains to loop domains for BlockQuantizationOp"))); + "logical domains to loop domains for"))); } // We do not allow IDs of types serial, unroll, unswitch to have extent > 1 @@ -880,8 +880,7 @@ TEST_F(BlockQuantizationValidationTest, NonParallelizedIDsMustHaveExtentOfOne) { EXPECT_THAT( [&]() { GpuLower(setup.fusion.get()).run(); }, testing::ThrowsMessage(testing::HasSubstr( - "Expected non-TID/BID/Group ID to have extent of 1 for " - "BlockQuantizationOp"))); + "Expected non-TID/BID/Group ID to have extent of 1 for "))); } // The runtime kernel for block quantization expects TIDx to access contiguous @@ -919,8 +918,7 @@ TEST_F(BlockQuantizationValidationTest, TIDxMustBeSecondInnermostAfterGroupID) { EXPECT_THAT( [&]() { GpuLower(setup.fusion.get()).run(); }, testing::ThrowsMessage(testing::HasSubstr( - "Expected IDs between Group ID and TIDx to have extent of 1 for " - "BlockQuantizationOp:"))); + "Expected IDs between Group ID and TIDx to have extent of 1 for "))); } // When running validation checks we traverse from loop to logical domain @@ -968,7 +966,7 @@ TEST_F(BlockQuantizationValidationTest, MergesMustBeContiguous) { testing::ThrowsMessage(testing::HasSubstr( "All merge operations deriving the grouped ID must combine " "contiguous " - "IDs from the logical domain for BlockQuantizationOp"))); + "IDs from the logical domain for"))); } class BlockQuantizationSchedulingTest