
Conversation


@jjsjann123 jjsjann123 commented Jan 8, 2026

Context

This series of PRs aims to enable a single kernel for quantization and block scaling factor layout handling on grouped tensors.

The existing solution for nvfp4 quantization of the activation tensor for grouped_mm relies on two operations:
i. BlockQuantizationOp produces scaled_tv and block_scaling_factor;
ii. block_scaling_factor needs to be processed by PreprocessGroupedMatmulInputSf in order to satisfy the swizzle layout required by grouped_mm kernels.

This series of PRs merges the two operations into one.

Stacked PRs

#5775 GroupedBlockQuantizationOp PR0: Adding runtime function
#5776 GroupedBlockQuantizationOp PR1: Adding codegen support
#5777 GroupedBlockQuantizationOp PR2: Adding python API and updating llama4 benchmark

What's in this PR

  1. Add the Fusion IR node GroupedBlockQuantizationOp. The operation combines BlockQuantizationOp and PreprocessGroupedMatmulInputSf and inherits the validation/checks from both.
    The operation is similar to BlockQuantizationOp, except that:
    i. the block scaling factor output doesn't have the swizzle logic represented as allocation domain transformations;
    ii. it takes additional inputs (input_offsets and output_offsets) to facilitate group indexing, similar to PreprocessGroupedMatmulInputSf.

  2. Add a C++ test case for GroupedBlockQuantizationOp.
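For reference, here is a minimal sketch of how the new op might be used from a fusion definition. The makeContigTensor helper, the offset dtypes, and the BlockQuantizationResults member names are assumptions based on common nvFuser test patterns; see tests/cpp/test_layout_op.cpp in this PR for the actual usage.

// Hedged sketch: helper and member names are assumptions, not the PR's exact code.
Fusion fusion;
FusionGuard fg(&fusion);

TensorView* in = makeContigTensor(2, DataType::BFloat16);          // [m, k] grouped activations
TensorView* input_offsets = makeContigTensor(1, DataType::Int32);  // per-group row offsets into the input
TensorView* output_offsets = makeContigTensor(1, DataType::Int32); // per-group row offsets into the scale output
TensorView* global_scale = makeContigTensor(0, DataType::Float);   // scalar global scaling factor
fusion.addInput(in);
fusion.addInput(input_offsets);
fusion.addInput(output_offsets);
fusion.addInput(global_scale);

// A single op produces both the nvfp4 data and the swizzled block scaling factor.
auto results = groupedBlockQuantize(
    in,
    input_offsets,
    output_offsets,
    BlockScalingFactorLayout::Block128x4,
    global_scale,
    /*block_size=*/16,
    /*out_dtype=*/DataType::Float4_e2m1fn);

fusion.addOutput(results.quantized_tensor); // same shape as the input, Float4_e2m1fn
fusion.addOutput(results.block_scales);     // [m, k/16] logically, swizzled allocation domain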

1. refactor existing block_layout op and block_quantization_kernel to re-use existing runtime functions;
2. added runtime function for GroupedBlockQuantizeOp
@jjsjann123 jjsjann123 changed the base branch from main to jj/grouped_block_quantize_op_0 January 8, 2026 00:36
@jjsjann123 jjsjann123 changed the title Jj/grouped block quantize op 1 PR1: adding codegen support for GroupedBlockQuantizationOp Jan 8, 2026

github-actions bot commented Jan 8, 2026

Review updated until commit 3c5c99c

Description

  • Add GroupedBlockQuantizationOp IR node combining BlockQuantizationOp and PreprocessGroupedMatmulInputSf functionality

  • Implement codegen support in CudaKernelGenerator to generate runtime function calls for grouped block quantization

  • Update analysis passes (validation, indexing, sync mapping) to handle new grouped quantization operation

  • Add groupedBlockQuantize public API function and comprehensive test case for validation

Changes walkthrough

Relevant files

Enhancement (25 files):
  • codegen.cpp: Add CudaKernelGenerator handle method for GroupedBlockQuantizationOp (+124/-0)
  • non_divisible_split.cpp: Update NonDivisiblePredicateInfo to handle GroupedBlockQuantizationOp (+6/-1)
  • sync_information.cpp: Update SyncMap to handle GroupedBlockQuantizationOp block scales (+10/-5)
  • trivial_broadcast.cpp: Add ConcretizedBroadcastDomains handler for GroupedBlockQuantizationOp (+11/-0)
  • index.cpp: Add IndexLowering handler for GroupedBlockQuantizationOp (+54/-0)
  • utils.cpp: Update isTvOp to include GroupedBlockQuantizationOp (+1/-0)
  • validation.cpp: Add ExprValidator handler and validateQuantizedOutputScheduling for GroupedBlockQuantizationOp (+219/-169)
  • fusion_segmenter.cpp: Update SegmentedFusion to handle GroupedBlockQuantizationOp inputs (+6/-1)
  • composite_nodes.cpp: Implement GroupedBlockQuantizationOp class definition and methods (+58/-0)
  • utils.cpp: Update hasUniformSiblings to exclude GroupedBlockQuantizationOp (+5/-1)
  • kernel.cpp: Add KernelIrScanner handler for GroupedBlockQuantizationOp (+4/-0)
  • logical_domain_map.cpp: Update logical domain mapping for GroupedBlockQuantizationOp (+29/-9)
  • arith.cpp: Implement groupedBlockQuantize public API function (+141/-0)
  • pointwise.cpp: Update PointWiseScheduler to handle GroupedBlockQuantizationOp (+23/-1)
  • pointwise_non_tma.cpp: Update non-TMA pointwise scheduler for GroupedBlockQuantizationOp (+8/-1)
  • registry_utils.cpp: Update scheduler topology checks for GroupedBlockQuantizationOp (+21/-3)
  • domain_map.cpp: Update DomainMap validation for GroupedBlockQuantizationOp (+15/-0)
  • utils.cpp: Update cache utilities to handle GroupedBlockQuantizationOp offsets (+12/-7)
  • tensor_metadata.cpp: Update tensor metadata inference for GroupedBlockQuantizationOp (+6/-0)
  • trivial_broadcast.h: Add GroupedBlockQuantizationOp handler declaration (+2/-0)
  • index.h: Add GroupedBlockQuantizationOp handler declaration (+1/-0)
  • dispatch.h: Add GroupedBlockQuantizationOp to dispatch macros (+1/-0)
  • composite_nodes.h: Add GroupedBlockQuantizationOp class declaration and interface (+92/-0)
  • logical_domain_map.h: Add GroupedBlockQuantizationOp handler declaration (+4/-0)
  • arith.h: Add groupedBlockQuantize API function declaration (+16/-0)

Tests (2 files):
  • test_layout_op.cpp: Add GroupedBlockQuantizeOp test case with validation (+73/-1)
  • test_low_precision_recipe.cpp: Update validation test error messages for grouped operations (+4/-6)

PR Reviewer Guide

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review
Validation Logic Refactoring

The PR refactors validation logic by extracting common code into the validateQuantizedOutputScheduling() function. This is good for code reuse, but we need to verify that both BlockQuantizationOp and GroupedBlockQuantizationOp have identical validation requirements and that no functionality was lost in the refactoring. The extracted function should handle both operations correctly.

void validateQuantizedOutputScheduling(
    TensorView* quantized_output,
    DataType input_dtype) {
  // Outputs have the same allocation domain
  // as the logical domain - no allocation domain.
  NVF_ERROR(
      !quantized_output->hasAllocation(),
      "Quantized output must not have an allocation domain.");

  IterDomain* grouped_id = nullptr;
  IterDomain* thread_x = nullptr;
  IterDomain* block_x = nullptr;
  IterDomain* thread_z = nullptr;
  IterDomain* block_z = nullptr;

  for (const auto& loop_id : quantized_output->getLoopDomain()) {
    if (loop_id->getParallelType() == ParallelType::Group) {
      grouped_id = loop_id;
    } else if (loop_id->getParallelType() == ParallelType::TIDx) {
      thread_x = loop_id;
    } else if (loop_id->getParallelType() == ParallelType::BIDx) {
      block_x = loop_id;
    } else if (loop_id->getParallelType() == ParallelType::TIDz) {
      thread_z = loop_id;
    } else if (loop_id->getParallelType() == ParallelType::BIDz) {
      block_z = loop_id;
    } else if (
        loop_id->getParallelType() == ParallelType::Serial ||
        loop_id->getParallelType() == ParallelType::Unswitch ||
        loop_id->getParallelType() == ParallelType::Unroll) {
      // Check that this ID has a constant extent equal to 1
      NVF_ERROR(
          loop_id->extent()->isConstInt(),
          "Expected constant extent for Serial/Unswitch/Unroll ID in ",
          quantized_output->definition()->toString());
      NVF_ERROR_EQ(
          loop_id->extent()->evaluate().as<int64_t>(),
          1,
          "Expected non-TID/BID/Group ID to have extent of 1 for ",
          quantized_output->definition()->toString());
    }
  }

  NVF_ERROR(
      grouped_id != nullptr,
      "One of the output IDs must be grouped for ",
      quantized_output->definition()->toString());

  NVF_ERROR(
      thread_x != nullptr && block_x != nullptr,
      "Need to have both TIDx and BIDx when using: ",
      quantized_output->definition()->toString());

  NVF_ERROR(
      !thread_z && !block_z,
      "Parallelization along z axis is not supported for ",
      quantized_output->definition()->toString());

  auto inner_extent = grouped_id->extent()->evaluate().as<int64_t>();

  NVF_ERROR(
      ((inner_extent == 4 || inner_extent == 2) &&
       input_dtype == DataType::Float) ||
          ((inner_extent == 8 || inner_extent == 4 || inner_extent == 2) &&
           (input_dtype == DataType::BFloat16 ||
            input_dtype == DataType::Half)),
      "The group dimension must be  2/4 (FP32) or 2/4/8 "
      "(BF16). Found: ",
      inner_extent,
      ". Expr: ",
      quantized_output->definition()->toString());

  // [ NOTE: check scheduling requirements for block quantization ]
  //
  //                   M    K
  //                 │    │
  //                 ▼    ▼
  //              ┌────────────┐
  //              │   merge    │
  //              └─────┬──────┘
  //
  //
  //                   M*K
  //               ┌──────────┐
  //               │  split   ┼──┐
  //               └─┬────────┘  │
  //                 ▼           ▼
  //           (M*K)/4          4(G)
  //           ┌────────┐
  //           │ split  ┼────┐
  //           └─┬──────┘    │
  //             ▼           ▼
  //         (M*K)/4        1(U)
  //     ┌─────────┐
  //     │  split  │
  //   ┌─┼         ┼───┐
  //   │ └─────────┘   │
  //   ▼               ▼
  // (M*K)/4/128      128(Tx)

  // Next we check the following scheduling requirements for
  // BlockQuantizationOp/GroupedBlockQuantizationOp - the above figure is an
  // example of a valid schedule.
  // 1. The Group ID must be derived from the innermost logical IDs
  // 2. TIDx must follow the Group ID in the schedule -- that is when derived
  // from the logical domain, group ID must be inner-most, the next
  // "inner-most" should be TIDx (unless there is an ID with a unit trip
  // count)
  // 3. All merges involved from logical domains to group and thread ID must
  // combine contiguous IDs

  auto transform_exprs = DependencyCheck::getAllExprsBetween(
      {quantized_output->getLogicalDomain().begin(),
       quantized_output->getLogicalDomain().end()},
      {quantized_output->getLoopDomain().begin(),
       quantized_output->getLoopDomain().end()});

  std::vector<IterDomain*> ids_to_transform =
      quantized_output->getLogicalDomain();

  std::deque<IterDomain*> frontier(
      quantized_output->getLogicalDomain().begin(),
      quantized_output->getLogicalDomain().end());

  // This will get the xforms from logical to loop and apply them on the
  // logical domain. We will get a loop domain minus the reordering.
  // This pass also removes all IDs from frontier that were derived using
  // non-contiguous merges.
  scheduler_utils::applyTransforms(
      ids_to_transform, transform_exprs, [&frontier](Expr* expr) {
        traverseFrontierWithContiguityCheck(frontier, expr);
      });

  // The grouped ID must correspond to the innermost loop-like domain
  NVF_ERROR(
      ids_to_transform.back() == grouped_id,
      "The grouped ID must correspond to the innermost of all splits "
      "from logical domains to loop domains for TV: ",
      quantized_output->toString());

  // Iterate from the back to find TIDx, skipping group_id (last element)
  // Ensure all IDs between group_id and TIDx have extent 1
  bool found_tidx = false;
  for (auto it = ids_to_transform.rbegin() + 1; it != ids_to_transform.rend();
       ++it) {
    if (*it == thread_x) {
      found_tidx = true;
      break;
    }
    // All non-TIDx IDs between Group ID and TIDx must have extent of 1
    NVF_ERROR(
        (*it)->extent()->isConstInt() &&
            (*it)->extent()->evaluate().as<int64_t>() == 1,
        "Expected IDs between Group ID and TIDx to have extent of 1 for ",
        quantized_output->toString());
  }

  NVF_ERROR(
      found_tidx,
      "TIDx must follow the Group ID in the schedule for ",
      quantized_output->toString());

  // Check if grouped_id in frontier
  auto grouped_it = std::ranges::find(frontier, grouped_id);
  NVF_ERROR(
      grouped_it != frontier.end(),
      "All merge operations deriving the grouped ID must combine "
      "contiguous IDs from the logical domain for: ",
      quantized_output->toString());
  // Do the same for thread_x
  auto threadx_it =
      std::ranges::find(frontier.begin(), frontier.end(), thread_x);
  NVF_ERROR(
      threadx_it != frontier.end(),
      "All merge operations deriving the TIDx ID must combine "
      "contiguous IDs from the logical domain for: ",
      quantized_output->toString());
}
Runtime Function Integration

The codegen for GroupedBlockQuantizationOp calls the bq::grouped_block_quantize_to_nvfp4 runtime function. We need to verify that this runtime function exists and is properly implemented, and that the template/function arguments match the expected signature. Also ensure that error handling for unsupported output types is comprehensive.

// Special handling of GroupedBlockQuantizationOp to call the runtime
// function.
void handle(const GroupedBlockQuantizationOp* grouped_bqop) final {
  // This operator is plumbed down to a runtime function call.
  // One of the assumptions is that the device runtime expects n consecutive
  // inputs per thread, where n can be 2 or 4 for Float and 2, 4, or 8 for
  // Half/BFloat16. We achieve this by having the quantized output tv
  // scheduled with the inner dimension grouped by 2/4/8.
  auto output =
      grouped_bqop->quantizedOutput()->as<kir::TensorIndex>()->view();
  auto output_dtype = output->getDataType();

  // Extract group size from the loop domain
  int64_t group_size = 1;
  const auto& loop_domain = output->getLoopDomain();
  for (const auto* domain : loop_domain) {
    if (domain->getParallelType() == ParallelType::Group &&
        domain->extent()->isConstInt()) {
      group_size = domain->extent()->evaluate().as<int64_t>();
      break;
    }
  }

  // Validate group size based on input data type
  const auto input_dtype = grouped_bqop->in()
                               ->as<kir::TensorIndex>()
                               ->view()
                               ->getDataType()
                               .value();
  const bool is_half_precision =
      (input_dtype == DataType::BFloat16 || input_dtype == DataType::Half);
  const bool is_valid_group_size = is_half_precision
      ? (group_size == 2 || group_size == 4 || group_size == 8)
      : (group_size == 2 || group_size == 4);

  NVF_ERROR(
      is_valid_group_size,
      "Group size should be ",
      is_half_precision ? "2, 4 or 8" : "2 or 4",
      " for GroupedBlockQuantizationOp with input type ",
      input_dtype,
      ". Found: ",
      group_size,
      ". Expr: ",
      grouped_bqop->toString());

  // Build template arguments
  ArgumentBuilder template_args;
  // No global scale is required when quantizing to mxfp8
  if (output_dtype == DataType::Float4_e2m1fn) {
    template_args.arg(grouped_bqop->hasGlobalScale());
  }
  switch (grouped_bqop->layout()) {
    case BlockScalingFactorLayout::Block128x4:
      template_args.arg(32); // block_row_outer
      template_args.arg(4); // block_row_inner
      template_args.arg(4); // block_col
      break;
    default:
      NVF_THROW("unrecognized layout");
      break;
  }
  template_args.arg(group_size); // ITEMS_PER_THREAD

  // Build function arguments
  ArgumentBuilder func_args;
  func_args.arg(genInline(
      grouped_bqop->input(0)->as<kir::TensorIndex>()->view())); // input data
  func_args.arg(genInline(output)); // quantized output
  func_args.arg(genInline(grouped_bqop->blockScales()
                              ->as<kir::TensorIndex>()
                              ->view())); // block scales

  // generate logical index for runtime function
  func_args.arg(genInline(grouped_bqop->attributeVal(2)));
  func_args.arg(genInline(grouped_bqop->attributeVal(3)));
  func_args.arg("&").append(
      genVariableName(grouped_bqop->inputOffsets()) + "[0]");
  func_args.arg("&").append(
      genVariableName(grouped_bqop->outputOffsets()) + "[0]");
  func_args.arg(genInline(grouped_bqop->k()));
  func_args.arg(genInline(grouped_bqop->g()));

  if (output_dtype == DataType::Float4_e2m1fn) {
    func_args.arg(
        grouped_bqop->hasGlobalScale()
            ? genInline(grouped_bqop->globalScale())
            : "{}");
  }

  // Add swizzled allocation domain parameters if needed
  // This is always skipped when quantizing to mxfp8
  auto block_scales_tv =
      grouped_bqop->blockScales()->as<kir::TensorIndex>()->view();
  if (block_scales_tv->hasAllocation()) {
    auto logical_domain =
        TensorDomain::noReductions(block_scales_tv->getLogicalDomain());
    auto allocation_domain =
        TensorDomain::noReductions(block_scales_tv->getAllocationDomain());

    // Swizzled layout: 2D logical -> 5D allocation
    if (logical_domain.size() == 2 && allocation_domain.size() == 5) {
      // Add logical domain extent of the inner dimension
      func_args.arg(genInline(logical_domain[1]->extent()));

      // Add all allocation domain extents
      for (const auto* alloc_id : allocation_domain) {
        func_args.arg(genInline(alloc_id->extent()));
      }
    }
  }

  NVF_ERROR(
      output_dtype == DataType::Float4_e2m1fn,
      "only nvfp4 output is implemented");

  // Generate the function call
  indent() << genCall(
                  "bq::grouped_block_quantize_to_nvfp4",
                  template_args,
                  func_args)
           << ";\n";
}
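For context on the 32/4/4 template arguments above: they describe how a Block128x4 tile of scaling factors is decomposed (128 rows = 32 outer x 4 inner, 4 columns). Below is a rough sketch of the offset computation this implies, assuming the common 5D scale-factor layout ([rows/128, cols/4, 32, 4, 4]) consumed by grouped_mm kernels; the runtime function's actual indexing, padding of partial tiles, and per-group handling may differ.

#include <cstdint>

// Hedged sketch of the Block128x4 swizzle implied by block_row_outer=32,
// block_row_inner=4, block_col=4; not the PR's runtime implementation.
int64_t swizzled_offset(int64_t row, int64_t col, int64_t padded_cols) {
  const int64_t tile = (row / 128) * (padded_cols / 4) + (col / 4); // which 128x4 tile
  const int64_t r = row % 128;
  const int64_t c = col % 4;
  // Within a 512-element tile: 32 (row outer) x 4 (row inner) x 4 (col).
  const int64_t in_tile = (r % 32) * 16 + (r / 32) * 4 + c;
  return tile * 512 + in_tile;
}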
API Implementation Completeness

The groupedBlockQuantize function creates the GroupedBlockQuantizationOp but only supports Float4_e2m1fn and Float8_e4m3fn output types, with explicit errors for other types. We need to verify that this limitation is intentional and documented, and that the function properly handles all input validation and domain creation as expected.

BlockQuantizationResults groupedBlockQuantize(
    TensorView* input,
    TensorView* input_offsets,
    TensorView* output_offsets,
    BlockScalingFactorLayout layout,
    TensorView* global_scaling_factor,
    int64_t block_size,
    DataType out_dtype) {
  NVF_CHECK(
      out_dtype == DataType::Float4_e2m1fn ||
          out_dtype == DataType::Float8_e4m3fn,
      "Currently only output data type of Float4_e2m1fn or Float8_e4m3fn is "
      "supported");
  if (out_dtype == DataType::Float4_e2m1fn) {
    NVF_ERROR_EQ(
        block_size,
        16,
        "Block size must be 16 for Float4_e2m1fn, got ",
        block_size);
  } else if (out_dtype == DataType::Float8_e4m3fn) {
    NVF_ERROR_EQ(
        block_size,
        32,
        "Block size must be 32 for Float8_e4m3fn, got ",
        block_size);
    NVF_CHECK(
        !global_scaling_factor,
        "global_scaling_factor must be nullptr for Float8_e4m3fn");
  }

  // Validate input data type
  // We'll only support FP32 or BF16/FP16
  NVF_CHECK(
      input->getDataType().value() == DataType::Float ||
          input->getDataType().value() == DataType::BFloat16 ||
          input->getDataType().value() == DataType::Half,
      "Grouped block quantization expects floating point input but got ",
      input->getDataType().value());

  // Check that if global_scaling_factor is non-null,
  // then it is a scalar float TensorView
  if (global_scaling_factor != nullptr) {
    NVF_CHECK(
        TensorDomain::noReductions(global_scaling_factor->getLogicalDomain())
            .empty(),
        "Global scaling factor for grouped block quantization must be a scalar "
        "tensor");
    NVF_CHECK(
        global_scaling_factor->getDataType().value() == DataType::Float,
        "Global scaling factor for grouped block quantization must be of float "
        "data "
        "type");
  }

  auto inp_domain = TensorDomain::noReductions(input->getLogicalDomain());

  // Validate input tensor is 2d
  NVF_ERROR_EQ(
      inp_domain.size(),
      2,
      "Grouped block quantization only supports 2-dimensional tensors");

  // Create output domain for quantized tensor (same shape as input)
  std::vector<IterDomain*> quantized_out_domain;
  quantized_out_domain.reserve(inp_domain.size());

  for (auto inp_domain_ptr : inp_domain) {
    quantized_out_domain.push_back(inp_domain_ptr->cloneWithoutRFactor());
  }

  // Create output tensors
  TensorView* quantized_tensor = IrBuilder::create<TensorView>(
      IrBuilder::create<TensorDomain>(
          quantized_out_domain,
          TensorDomain::getContiguityFilledWith(quantized_out_domain, true)),
      out_dtype);

  // Create output blocked scaling factor
  auto block_scales_dtype = (out_dtype == DataType::Float4_e2m1fn)
      ? DataType::Float8_e4m3fn
      : DataType::Float8_e8m0fnu;

  // This is used for both root and loop domain on output
  // maps directly to input's logical domain.
  std::vector<IterDomain*> scales_out_domain;
  scales_out_domain.reserve(inp_domain.size());

  for (auto inp_id : inp_domain) {
    if (inp_id == inp_domain.back()) {
      scales_out_domain.push_back(
          IterDomainBuilder(
              inp_id->start(),
              SimplifyingIrBuilder::divExpr(
                  inp_id->extent(),
                  IrBuilder::create<Val>(block_size, DataType::Index)))
              .build());

    } else {
      scales_out_domain.push_back(inp_id->cloneWithoutRFactor());
    }
  }

  std::vector<IterDomain*> offset_logical_dom =
      TensorDomain::noReductions(input_offsets->getLogicalDomain());
  Val* num_groups = offset_logical_dom[0]->extent();

  // Create the allocation domain of output.
  std::vector<IterDomain*> out_alloc_dom =
      layoutAllocationDomain(scales_out_domain, num_groups, layout);

  // Create block scaling factors
  TensorView* block_scales = IrBuilder::create<TensorView>(
      IrBuilder::create<TensorDomain>(
          /*root_domain=*/std::vector<IterDomain*>(),
          /*logical_domain=*/scales_out_domain,
          /*allocation=*/out_alloc_dom,
          /*loop_domain=*/scales_out_domain,
          /*alternate_loop_domain=*/std::nullopt,
          /*contiguity=*/
          TensorDomain::getContiguityFilledWith(out_alloc_dom, true),
          /*additional_ids=*/std::vector<IterDomain*>(),
          /*skip_checks=*/true),
      block_scales_dtype);

  // Create the grouped block quantization operation
  IrBuilder::create<GroupedBlockQuantizationOp>(
      block_scales,
      quantized_tensor,
      input,
      input_offsets,
      output_offsets,
      layout,
      inp_domain[1]->getMaybeExpandedExtent(),
      num_groups,
      global_scaling_factor,
      block_size);

  return BlockQuantizationResults(quantized_tensor, block_scales);
}

Test failures

  • (Medium, 1) Thunder nvFuser NanoGPT autograd scalar mismatch (CUDA, A100)

    Test name (A100): thunder.tests.test_networks.test_nanogpt_complete_autograd_nvfuser_cuda_thunder.dtypes.float32

@jjsjann123 jjsjann123 changed the title PR1: adding codegen support for GroupedBlockQuantizationOp GroupedBlockQuantizeOp PR1: Adding codegen support Jan 8, 2026
@jjsjann123 jjsjann123 marked this pull request as ready for review January 8, 2026 02:17

greptile-apps bot commented Jan 8, 2026

Greptile Summary

Adds codegen support for GroupedBlockQuantizationOp, merging block quantization and layout preprocessing into a single operation for grouped tensors. The implementation properly extends the existing BlockQuantizationOp pattern with additional group indexing support.

Key Changes:

  • New IR node GroupedBlockQuantizationOp combining functionality from BlockQuantizationOp and PreprocessGroupedMatmulInputSf
  • Complete codegen pipeline integration: validation, index lowering, and kernel code generation
  • Refactored shared validation logic into validateQuantizedOutputScheduling() function
  • Proper handling throughout scheduler, domain mapping, and registry utilities
  • Comprehensive test coverage with reference validation

Implementation Quality:

  • Consistent with existing BlockQuantizationOp patterns across all subsystems
  • Proper validation of memory types, scheduling constraints, and block size divisibility
  • Correct template and function argument generation for runtime calls
  • Well-integrated with existing infrastructure (dispatch, domain maps, scheduler)

Confidence Score: 5/5

  • Safe to merge - well-structured implementation following established patterns
  • The implementation is thorough and follows the existing BlockQuantizationOp patterns consistently. All necessary integration points are covered (IR definition, validation, index lowering, codegen, scheduler, domain mapping). The test provides good coverage. Previous review comments have been addressed.
  • No files require special attention

Important Files Changed

  • csrc/ir/composite_nodes.h: Adds GroupedBlockQuantizationOp class definition with proper structure mirroring BlockQuantizationOp
  • csrc/ir/composite_nodes.cpp: Implements constructor, toString, and evaluation methods for GroupedBlockQuantizationOp
  • csrc/ops/arith.cpp: Implements groupedBlockQuantize API with proper validation and tensor creation logic
  • csrc/device_lower/pass/index.cpp: Adds index lowering for GroupedBlockQuantizationOp with proper block size usage
  • csrc/device_lower/validation.cpp: Refactored validation logic into a shared function and added GroupedBlockQuantizationOp validation
  • csrc/codegen.cpp: Implements code generation for GroupedBlockQuantizationOp runtime function calls
  • tests/cpp/test_layout_op.cpp: Added comprehensive test case for GroupedBlockQuantizationOp with validation against a reference implementation
  • csrc/scheduler/pointwise.cpp: Updated scheduler to detect and handle GroupedBlockQuantizationOp with vectorization capping

Sequence Diagram

sequenceDiagram
    participant User as User Code
    participant API as groupedBlockQuantize API
    participant IR as GroupedBlockQuantizationOp IR
    participant Scheduler as Pointwise Scheduler
    participant Validator as Validation Pass
    participant IndexLower as Index Lowering
    participant Codegen as Code Generator
    participant Runtime as Runtime Function

    User->>API: Call groupedBlockQuantize(input, offsets...)
    API->>API: Validate inputs (2D tensor, dtype, block_size)
    API->>IR: Create GroupedBlockQuantizationOp
    IR->>IR: Set inputs: input, input_offsets, output_offsets, k, g, global_scale
    IR->>IR: Set outputs: quantized_tensor, block_scales
    IR->>Scheduler: Register for scheduling
    Scheduler->>Scheduler: Detect GroupedBlockQuantizationOp
    Scheduler->>Scheduler: Cap vectorization factor to 4
    Scheduler->>Validator: Pass to validation
    Validator->>Validator: validateQuantizedOutputScheduling()
    Validator->>Validator: Check memory types (Local input/output, Global scales)
    Validator->>Validator: Verify Group ID is innermost, followed by TIDx
    Validator->>IndexLower: Pass validated IR
    IndexLower->>IndexLower: Compute logical indices [row_idx, col_idx]
    IndexLower->>IndexLower: Validate inner dimension divisible by block_size
    IndexLower->>Codegen: Lower to kernel IR
    Codegen->>Codegen: Extract group_size from ParallelType::Group
    Codegen->>Codegen: Build template args (has_global_scale, layout params, group_size)
    Codegen->>Codegen: Build function args (tensors, indices, offsets, k, g)
    Codegen->>Runtime: Generate call to bq::grouped_block_quantize_to_nvfp4
    Runtime->>Runtime: Perform quantization with swizzled layout
    Runtime-->>User: Return quantized output and block scales

@greptile-apps greptile-apps bot left a comment

26 files reviewed, 6 comments

jjsjann123 added a commit that referenced this pull request Jan 9, 2026
## Context

This series of PRs aims to enable a single kernel for quantization and
block scaling factor layout handling on grouped tensors.

The existing solution for nvfp4 quantization of the activation tensor for
grouped_mm relies on two operations:
i. BlockQuantizationOp produces scaled_tv and block_scaling_factor;
ii. block_scaling_factor needs to be processed by
PreprocessGroupedMatmulInputSf in order to satisfy the swizzle layout
required by grouped_mm kernels.

This series of PRs merges the two operations into one.

### Stacked PRs

#5775 GroupedBlockQuantizationOp PR0: Adding runtime function
#5776 GroupedBlockQuantizationOp PR1: Adding codegen support
#5777 GroupedBlockQuantizationOp PR2: Adding python API and updating
llama4 benchmark

## What's in this PR

1. refactor existing runtime function for re-use by the new op;
2. added runtime function for GroupedBlockQuantizeOp.
Base automatically changed from jj/grouped_block_quantize_op_0 to main January 9, 2026 19:53
@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

This PR adds GroupedBlockQuantizationOp, a new IR node that merges the functionality of BlockQuantizationOp and PreprocessGroupedMatmulInputSf into a single operation. This optimization enables single-kernel quantization and layout handling for grouped matrix multiplication operations.

Key Implementation Details

Core Operation: The new op takes a high-precision input tensor and produces:

  1. A quantized output tensor (same shape as input)
  2. Block scaling factors with swizzled layout directly suitable for grouped_mm

Architecture: The implementation follows the standard pattern for composite operations:

  • IR node definition in composite_nodes.h/cpp
  • User-facing API in ops/arith.cpp
  • Index lowering in device_lower/pass/index.cpp adds logical indices
  • Codegen in codegen.cpp generates runtime function call
  • Comprehensive validation in device_lower/validation.cpp
  • Scheduler integration for pointwise scheduling
  • Test coverage validates correctness against reference implementation

Key Technical Points:

  • Supports Float4_e2m1fn (nvfp4) output with Block128x4 layout
  • Requires specific parallelization: TIDx, BIDx, and Group parallel types
  • Group dimension must be 2/4 for FP32 or 2/4/8 for BF16/FP16 inputs
  • Block scales output has allocation domain with padding for swizzled layout
  • Vectorization capped at 4 when this op is present

The implementation is thorough and integrates well across all compiler passes including dispatch registration, logical domain mapping, broadcast domain analysis, scheduler topology checks, and non-divisible split handling.

Confidence Score: 4/5

  • This PR is safe to merge with minor considerations around attribute access patterns
  • The implementation is comprehensive and follows existing patterns well. All necessary integration points are covered including dispatch, validation, index lowering, codegen, scheduler, and domain mapping. The test validates correctness. Score of 4 (not 5) reflects the complexity of the attribute access pattern where row_idx/col_idx are added during index lowering but accessed in codegen - while this works correctly, it creates a subtle dependency that could be error-prone if not well understood
  • Pay close attention to csrc/codegen.cpp and csrc/ops/arith.cpp - ensure the attribute indexing pattern (attributeVal(2) and attributeVal(3)) remains valid if constructor signature changes

Important Files Changed

File Analysis

  • csrc/ir/composite_nodes.h (5/5): Added GroupedBlockQuantizationOp class definition with proper accessor methods, constructor signature, and evaluate method stub
  • csrc/ir/composite_nodes.cpp (5/5): Implemented GroupedBlockQuantizationOp constructor, toString, toInlineString, and evaluate methods following existing patterns
  • csrc/ops/arith.cpp (4/5): Implemented groupedBlockQuantize with proper validation, domain setup, and layout allocation - minor concern about row_idx/col_idx not being passed to the initial op construction
  • csrc/codegen.cpp (4/5): Added codegen handler for GroupedBlockQuantizationOp with template args, validation, and runtime function call - assumes row_idx/col_idx attributes are always present at indices 2 and 3
  • csrc/device_lower/pass/index.cpp (5/5): Added index lowering for GroupedBlockQuantizationOp with proper logical index computation and runtime validation
  • csrc/device_lower/validation.cpp (5/5): Added comprehensive validation for GroupedBlockQuantizationOp including memory type checks, parallelization requirements, and scheduling constraints
  • tests/cpp/test_layout_op.cpp (5/5): Added test for GroupedBlockQuantizeOp validating quantized output and block scaling factor layout against a reference implementation

Sequence Diagram

sequenceDiagram
    participant User
    participant OpsAPI as ops/arith.cpp
    participant IRNode as GroupedBlockQuantizationOp
    participant IndexLower as device_lower/pass/index
    participant Validation as device_lower/validation
    participant Codegen as codegen.cpp
    participant Runtime as Runtime Function

    User->>OpsAPI: groupedBlockQuantize(input, offsets, layout)
    OpsAPI->>OpsAPI: Validate inputs & data types
    OpsAPI->>OpsAPI: Create logical & allocation domains
    OpsAPI->>IRNode: Create GroupedBlockQuantizationOp<br/>(without row_idx/col_idx)
    IRNode->>OpsAPI: Return quantized_tensor & block_scales
    
    Note over IndexLower: Device Lowering Phase
    IndexLower->>IndexLower: Compute logical indices
    IndexLower->>IndexLower: Validate inner dim divisibility
    IndexLower->>IRNode: Create lowered op<br/>(WITH row_idx/col_idx)
    
    Validation->>IRNode: Validate memory types
    Validation->>IRNode: Check parallelization (TIDx, BIDx, Group)
    Validation->>IRNode: Verify group dimension & contiguity
    
    Note over Codegen: Code Generation Phase
    Codegen->>IRNode: Extract group size from loop domain
    Codegen->>IRNode: Validate group size (2/4 or 2/4/8)
    Codegen->>IRNode: Access row_idx/col_idx via attributeVal(2,3)
    Codegen->>Runtime: Generate call to<br/>bq::grouped_block_quantize_to_nvfp4
    Runtime-->>User: Execute quantization kernel

Comment on lines 64 to +65
.slice(0, 0, m_g)
.slice(1, 0, k);
.slice(1, 0, k)

Good addition of .to(ref.dtype()) to ensure dtype matching in the validation. This handles the case where the reference and output might have different dtypes due to the layout transformation.

@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

This PR adds comprehensive codegen support for GroupedBlockQuantizationOp, a new IR node that merges BlockQuantizationOp and PreprocessGroupedMatmulInputSf into a single operation for improved performance in grouped matrix multiplication quantization scenarios.

What Changed

Core IR Implementation:

  • Added GroupedBlockQuantizationOp class in composite_nodes.h/cpp with full constructor, accessors, and evaluation methods
  • The operation takes input tensor, input/output offsets, layout specification (Block128x4), k/g dimensions, optional global scale, and block_size parameter
  • Produces quantized output and block scaling factors with swizzled layout

Codegen & Lowering:

  • Implemented a code generation handler that validates group sizes (2/4/8 for half-precision, 2/4 for float), builds template arguments for the layout parameters (32, 4, 4), and generates calls to the bq::grouped_block_quantize_to_nvfp4 runtime function (a rough sketch of the emitted call follows this list)
  • Added index lowering that creates TensorIndex nodes, validates input divisibility by block size, and computes logical indices
  • Comprehensive validation checks memory types, parallelization requirements (TIDx, BIDx, Group ID), and schedule ordering
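For orientation, here is a rough sketch of what the emitted call could look like for a BF16 input with group size 4 and a global scale. The T*/i* identifiers are hypothetical placeholders for generated names, and the exact argument list (for example the trailing swizzled-allocation extents) depends on the schedule; this is an illustration, not the PR's verbatim output.

// Hypothetical generated code; identifiers and exact arguments are illustrative only.
bq::grouped_block_quantize_to_nvfp4<true, 32, 4, 4, 4>(
    T0,        // per-thread input values (BF16)
    T4,        // quantized nvfp4 output
    T5,        // block scaling factors in global memory
    i10, i11,  // logical row/col indices computed during index lowering
    &T2[0],    // input group offsets
    &T3[0],    // output group offsets
    i0,        // k: inner extent of the input
    i1,        // g: number of groups
    T6[0],     // global scaling factor
    i2, i3, i4, i5, i6, i7); // logical inner extent plus 5D swizzled allocation extents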

Compiler Integration:

  • Registered in dispatch system for proper IR traversal
  • Updated all device lowering passes (sync analysis, trivial broadcast, non-divisible split)
  • Integrated with scheduler (pointwise, pointwise_non_tma) for special handling of quantization ops
  • Modified fusion segmenter to erase allocation domain (Transform Replay cannot handle padding transformations)
  • Updated logical domain mapping, tensor metadata, and kernel handling

Testing:

  • Added GroupedBlockQuantizeOp test case that validates against BlockQuantizationOp reference and verifies grouped layout with proper padding/swizzling

Issue Found

Critical Bug in Index Lowering (csrc/device_lower/pass/index.cpp:488):
The block_size parameter is hardcoded to 16 when creating the lowered GroupedBlockQuantizationOp, but it should use grouped_bqop->blockSize() to respect the original operation's block_size parameter. This means any non-default block size will be ignored during compilation.

Confidence Score: 3/5

  • This PR has one critical logic bug that needs to be fixed before merging
  • The implementation is comprehensive and well-structured with proper validation, dispatch registration, and test coverage. However, there is a confirmed logic bug in the index lowering pass where block_size is hardcoded to 16 instead of using the operation's actual block_size parameter. This will cause incorrect behavior for any non-default block sizes. The rest of the implementation appears solid with thorough integration across all compiler passes.
  • Pay close attention to csrc/device_lower/pass/index.cpp - the hardcoded block_size needs to be fixed to use grouped_bqop->blockSize()

Important Files Changed

File Analysis

  • csrc/ir/composite_nodes.h (5/5): Adds GroupedBlockQuantizationOp class declaration with proper constructor, accessors (blockScales, quantizedOutput, in, blockSize, hasGlobalScale, globalScale, inputOffsets, outputOffsets, k, g, layout), and evaluation methods
  • csrc/ir/composite_nodes.cpp (5/5): Implements GroupedBlockQuantizationOp constructor, toString/toInlineString methods, and evaluate placeholder (throws, as the fallback kernel is not yet implemented)
  • csrc/codegen.cpp (4/5): Adds codegen handler for GroupedBlockQuantizationOp that validates group size, builds template/function arguments including layout parameters (32, 4, 4 for Block128x4), and generates the call to the bq::grouped_block_quantize_to_nvfp4 runtime function
  • csrc/device_lower/pass/index.cpp (3/5): Implements index lowering for GroupedBlockQuantizationOp, creates TensorIndex nodes, validates input divisibility by block size, computes logical indices - contains a bug where block_size is hardcoded to 16 instead of using grouped_bqop->blockSize()
  • csrc/device_lower/validation.cpp (5/5): Adds comprehensive validation for GroupedBlockQuantizationOp including memory type checks, parallelization requirements (TIDx, BIDx, Group ID), extent validation, and schedule ordering constraints
  • csrc/dispatch.h (5/5): Registers GroupedBlockQuantizationOp in the DISPATCH_FOR_ALL_EXPRS macro for IR dispatch system integration
  • csrc/ops/arith.cpp (5/5): Implements the groupedBlockQuantize API function that creates quantized tensor and block scales outputs with proper allocation domains, then instantiates GroupedBlockQuantizationOp
  • csrc/ops/arith.h (5/5): Declares the groupedBlockQuantize API function with parameters for input, offsets, layout, global scaling factor, block size, and output dtype
  • tests/cpp/test_layout_op.cpp (5/5): Adds a GroupedBlockQuantizeOp test case that validates quantized output and block scales against a BlockQuantizationOp reference and verifies the grouped layout transformation with proper padding/swizzling
  • csrc/fusion_segmenter.cpp (5/5): Updates fusion segmentation to erase the allocation domain for GroupedBlockQuantizationOp's blockScales output (similar to PreprocessGroupedMatmulInputSf) since Transform Replay cannot handle allocation domain transformations with padding
  • csrc/scheduler/pointwise.cpp (5/5): Updates scheduler to detect GroupedBlockQuantizationOp (alongside BlockQuantizationOp) for special handling of block quantization operations in pointwise scheduling

Sequence Diagram

sequenceDiagram
    participant User as User Code
    participant API as groupedBlockQuantize API
    participant IR as GroupedBlockQuantizationOp
    participant Scheduler as Scheduler/Validation
    participant IndexLower as Index Lowering
    participant Codegen as Code Generator
    participant Runtime as Runtime Function

    User->>API: Call groupedBlockQuantize(input, offsets, layout, ...)
    API->>API: Create quantized_tensor output
    API->>API: Create block_scales output with allocation domain
    API->>IR: Create GroupedBlockQuantizationOp
    IR->>IR: Store inputs, offsets, layout, k, g, block_size
    
    Note over Scheduler: Compilation Phase
    Scheduler->>IR: Validate operation
    Scheduler->>Scheduler: Check memory types (Local)
    Scheduler->>Scheduler: Validate parallelization (TIDx, BIDx, Group)
    Scheduler->>Scheduler: Check group size (2/4/8 for half, 2/4 for float)
    
    IndexLower->>IR: Lower indices
    IndexLower->>IndexLower: Compute logical indices for row/col
    IndexLower->>IndexLower: Validate input divisible by block_size
    IndexLower->>IR: Create lowered GroupedBlockQuantizationOp
    Note over IndexLower: BUG: Hardcodes block_size=16
    
    Codegen->>IR: Generate code
    Codegen->>Codegen: Extract group_size from loop domain
    Codegen->>Codegen: Build template args (has_global_scale, 32, 4, 4, group_size)
    Codegen->>Codegen: Build function args (input, output, scales, indices, offsets, k, g)
    Codegen->>Runtime: Call bq::grouped_block_quantize_to_nvfp4<...>(...)
    
    Runtime->>Runtime: Perform quantization with layout transformation
    Runtime-->>User: Return quantized_tensor and block_scales

grouped_bqop->k(),
grouped_bqop->g(),
grouped_bqop->globalScale(),
16,

The block_size parameter is hardcoded to 16; it should use grouped_bqop->blockSize() to respect the original operation's block_size parameter.

Suggested change
16,
grouped_bqop->blockSize(),

The GroupedBlockQuantizationOp constructor accepts a block_size parameter (line 1063 in composite_nodes.h), and the operation stores this value as an attribute accessible via the blockSize() method (lines 1081-1083). However, during index lowering this value is replaced with a hardcoded 16, which means any non-default block size specified by the user will be ignored.

@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

This PR adds comprehensive codegen support for GroupedBlockQuantizationOp, a new IR node that merges BlockQuantizationOp and PreprocessGroupedMatmulInputSf into a single operation for improved performance in grouped matrix multiplication scenarios.

Key Changes:

  1. IR Node Implementation (csrc/ir/composite_nodes.h/cpp): New GroupedBlockQuantizationOp class with inputs (input tensor, input/output offsets, k, g, optional global_scale) and outputs (quantized tensor, block scales). The operation stores block size and layout as attributes, with row/col indices added during lowering.

  2. Code Generation (csrc/codegen.cpp): Handler generates runtime function call to bq::grouped_block_quantize_to_nvfp4 with template parameters for layout configuration (block_row_outer=32, block_row_inner=4, block_col=4 for Block128x4) and proper function arguments including offset tensors and dimension scalars.

  3. Index Lowering (csrc/device_lower/pass/index.cpp): Computes logical indices for the 2D matrix and validates that the inner dimension is divisible by the block size before creating the lowered operation.

  4. API Function (csrc/ops/arith.cpp/.h): groupedBlockQuantize() function with validation for supported output types (Float4_e2m1fn with block_size=16, Float8_e4m3fn with block_size=32) and proper tensor domain construction with layout-specific allocation domains.

  5. Compiler Integration: Updates to dispatch macros, logical domain mapping, broadcast domain tracking, and scheduler to properly handle the new operation throughout the compilation pipeline.

  6. Testing (tests/cpp/test_layout_op.cpp): Comprehensive test validating both quantized output correctness and proper block scaling factor layout with grouped operations.

Critical Issue Found: Index lowering hardcodes block_size to 16 (line 488) instead of using grouped_bqop->blockSize(), which will break Float8_e4m3fn quantization that requires block_size=32.

Confidence Score: 4/5

  • This PR is generally safe to merge after fixing the hardcoded block_size bug, as it follows established patterns and has good test coverage
  • Score reflects one critical logic bug (hardcoded block_size=16 in index lowering) that breaks Float8_e4m3fn support. Otherwise, the implementation is thorough and well-integrated across the codebase with consistent patterns matching PreprocessGroupedMatmulInputSf and BlockQuantizationOp. The changes are localized and the test validates the main use case.
  • Pay close attention to csrc/device_lower/pass/index.cpp line 488 - the hardcoded block_size must be changed to grouped_bqop->blockSize() before merging

Important Files Changed

File Analysis

  • csrc/ir/composite_nodes.h (5/5): Adds GroupedBlockQuantizationOp class declaration with proper accessors for inputs, outputs, and attributes including layout, block size, and offset tensors
  • csrc/ir/composite_nodes.cpp (5/5): Implements GroupedBlockQuantizationOp constructor, toString methods, and evaluation placeholder - correctly manages inputs/outputs and attributes
  • csrc/codegen.cpp (5/5): Adds codegen handler for GroupedBlockQuantizationOp that generates the runtime function call with template parameters for the layout and correct function arguments
  • csrc/device_lower/pass/index.cpp (3/5): Implements index lowering for GroupedBlockQuantizationOp with logical index computation - contains the hardcoded block_size bug at line 488
  • csrc/ops/arith.cpp (5/5): Implements the groupedBlockQuantize API function with proper validation for block sizes (16 for nvfp4, 32 for mxfp8) and tensor domain construction with layout allocation
  • tests/cpp/test_layout_op.cpp (5/5): Adds a comprehensive test for GroupedBlockQuantizationOp validating quantized output and block scaling factor layout correctness with grouped operations

Sequence Diagram

sequenceDiagram
    participant User as User/Python API
    participant API as ops/arith.cpp
    participant IR as IR Builder
    participant Fusion as Fusion Graph
    participant Scheduler as Scheduler
    participant IndexLower as Index Lowering
    participant Codegen as Code Generator
    participant Runtime as Runtime Kernel
    
    User->>API: groupedBlockQuantize(input, offsets, layout)
    API->>API: Validate inputs & block size
    API->>IR: Create tensor domains for output & scales
    IR->>IR: Build allocation domain with layout
    API->>Fusion: IrBuilder::create<GroupedBlockQuantizationOp>
    Fusion->>Fusion: Add GroupedBlockQuantizationOp to graph
    
    Note over Scheduler: Compilation Phase
    Scheduler->>Scheduler: Detect GroupedBlockQuantizationOp
    Scheduler->>Scheduler: Cap vectorization factor (≤4)
    Scheduler->>Scheduler: Apply pointwise scheduling
    
    Note over IndexLower: Lowering Phase
    IndexLower->>IndexLower: handle(GroupedBlockQuantizationOp)
    IndexLower->>IndexLower: Compute logical indices [row, col]
    IndexLower->>IndexLower: Validate block size divisibility
    IndexLower->>Fusion: Create lowered op with indices
    
    Note over Codegen: Code Generation
    Codegen->>Codegen: handle(GroupedBlockQuantizationOp)
    Codegen->>Codegen: Extract group_size from loop domain
    Codegen->>Codegen: Build template args (layout params)
    Codegen->>Codegen: Build function args (tensors, offsets, k, g)
    Codegen->>Runtime: Generate call to bq::grouped_block_quantize_to_nvfp4
    
    Note over Runtime: Execution
    Runtime->>Runtime: Perform block quantization per group
    Runtime->>Runtime: Apply layout transformation (Block128x4)
    Runtime->>User: Return quantized tensor & block scales

@jjsjann123
Collaborator Author

!test

@jjsjann123
Collaborator Author

!test

@jjsjann123 jjsjann123 requested a review from protonu January 9, 2026 23:26
@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

Overview

This PR adds codegen support for GroupedBlockQuantizationOp, a new IR node that merges BlockQuantizationOp and PreprocessGroupedMatmulInputSf into a single operation. This optimization enables single-kernel quantization and layout handling for grouped tensors in nvfp4 quantization workflows.

Key Changes

Core IR Implementation

  • New IR Node: GroupedBlockQuantizationOp class added to composite_nodes.{h,cpp} with proper constructor, accessors, and attribute management
  • Input/Output Structure: Takes input tensor, input_offsets, output_offsets, k, g scalars, optional global_scale, plus row_idx/col_idx for indexing
  • Attributes: Stores block_size (index 0) and layout (index 1) as data attributes, plus optional row_idx/col_idx as regular attributes

Codegen and Lowering

  • Code Generation: Comprehensive handler in codegen.cpp that generates calls to bq::grouped_block_quantize_to_nvfp4 with proper template arguments and validation
  • Index Lowering: device_lower/pass/index.cpp implements index lowering with runtime validation that inner dimension is divisible by block_size
  • Validation: Extensive validation in device_lower/validation.cpp checks memory types, parallelization requirements (TIDx, BIDx, Group), and prevents z-axis parallelization

Integration Points

  • Dispatch: Registered in DISPATCH_FOR_ALL_EXPRS macro in dispatch.h
  • API: groupedBlockQuantize() function in ops/arith.{h,cpp} with input validation for data types, tensor dimensions, and block size requirements
  • Scheduler: Updates to scheduler/utils.cpp to handle offset tensors and block scales in caching logic
  • Other Files: Propagated through logical_domain_map, tensor_metadata, fusion_segmenter, and various analysis passes

Testing

  • Test Coverage: Single test case GroupedBlockQuantizeOp in test_layout_op.cpp validates correctness against reference implementation using BlockQuantizationOp
  • Test Scenario: 3 groups with [100, 150, 262] tokens, verifies both quantized output and block scaling factor layout
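For intuition, here is a minimal sketch of how the group offsets for that scenario might be derived. It assumes the offsets are exclusive prefix sums of rows per group and that each group's rows in the scale-factor output are padded to a multiple of 128 so every group starts on a fresh 128x4 tile; the actual test may construct them differently.

#include <cstdint>
#include <utility>
#include <vector>

// Hedged sketch; offset semantics are assumptions, not taken from the test.
std::pair<std::vector<int64_t>, std::vector<int64_t>> buildGroupOffsets(
    const std::vector<int64_t>& tokens_per_group) {
  std::vector<int64_t> input_offsets = {0};
  std::vector<int64_t> output_offsets = {0};
  for (int64_t t : tokens_per_group) {
    input_offsets.push_back(input_offsets.back() + t);
    // Pad each group's rows in the scale-factor output to a multiple of 128.
    output_offsets.push_back(output_offsets.back() + (t + 127) / 128 * 128);
  }
  return {input_offsets, output_offsets};
}
// For tokens {100, 150, 262}:
//   input_offsets  -> {0, 100, 250, 512}
//   output_offsets -> {0, 128, 384, 768}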

Architecture Consistency

The implementation follows established patterns from BlockQuantizationOp and PreprocessGroupedMatmulInputSf:

  • Similar attribute indexing patterns (adjusted for different parameter counts)
  • Consistent validation approach across lowering passes
  • Proper integration with scheduler and domain mapping infrastructure

Code Quality

Strengths:

  • Comprehensive validation at multiple stages (API, index lowering, device validation)
  • Proper error messages with context
  • Consistent with existing codebase patterns
  • Well-structured separation of concerns

⚠️ Minor Observations:

  • Currently only supports Float4_e2m1fn output (enforced in codegen.cpp:2003-2005)
  • Evaluation method not implemented (placeholder throws, which is acceptable for ops with runtime kernel implementations)
  • Single test case covers basic functionality but could benefit from additional edge case testing

Recommendation

The implementation is solid and ready to merge. The code is well-structured, properly validated, and consistently integrated across the codebase. The limitation to nvfp4 output is documented and intentional for this PR.

Confidence Score: 5/5

  • This PR is safe to merge with minimal risk - comprehensive validation, consistent patterns, and proper testing
  • Score of 5 reflects thorough implementation with validation at every stage, consistent integration patterns matching existing ops, comprehensive error handling, and working test coverage. No critical issues found.
  • No files require special attention - implementation is consistent and well-validated across all changed files

Important Files Changed

File Analysis

  • csrc/ir/composite_nodes.h (5/5): Adds GroupedBlockQuantizationOp class definition with proper accessor methods and attributes for merging quantization and layout operations
  • csrc/ir/composite_nodes.cpp (5/5): Implements GroupedBlockQuantizationOp constructor, toString methods, and placeholder evaluate; follows existing patterns from BlockQuantizationOp
  • csrc/codegen.cpp (5/5): Adds codegen handler for GroupedBlockQuantizationOp that generates runtime function calls with proper template args and validation
  • csrc/device_lower/pass/index.cpp (5/5): Implements index lowering for GroupedBlockQuantizationOp with validation that the inner dimension is divisible by the block size
  • csrc/device_lower/validation.cpp (5/5): Adds comprehensive validation for GroupedBlockQuantizationOp including memory type checks, parallelization requirements, and group dimension verification
  • csrc/ops/arith.cpp (5/5): Implements groupedBlockQuantize with comprehensive input validation, output tensor creation, and proper attribute setup
  • csrc/dispatch.h (5/5): Registers GroupedBlockQuantizationOp in the dispatcher macro for proper IR node handling across the codebase
  • csrc/scheduler/utils.cpp (5/5): Updates scheduler utilities to handle GroupedBlockQuantizationOp offset tensors and block scales correctly in caching logic
  • tests/cpp/test_layout_op.cpp (5/5): Adds a test case for GroupedBlockQuantizeOp verifying correct quantization and grouped layout transformation against a reference

Sequence Diagram

sequenceDiagram
    participant User as Python/C++ API
    participant API as groupedBlockQuantize()
    participant IR as GroupedBlockQuantizationOp
    participant Scheduler as Scheduler
    participant IndexLower as Index Lowering
    participant Validation as Device Validation
    participant Codegen as Code Generator
    participant Runtime as CUDA Runtime

    User->>API: Call groupedBlockQuantize(input, offsets, layout)
    API->>API: Validate input dtype (Float/BF16/Half)
    API->>API: Validate 2D tensor
    API->>API: Check block_size (16 for nvfp4, 32 for mxfp8)
    API->>IR: Create GroupedBlockQuantizationOp node
    IR->>IR: Store inputs: input, offsets, k, g
    IR->>IR: Store attributes: block_size, layout
    IR->>Scheduler: Schedule fusion
    Scheduler->>Scheduler: Apply pointwise scheduler
    Scheduler->>Scheduler: Handle offset tensors in caching
    Scheduler->>IndexLower: Lower to kernel IR
    IndexLower->>IndexLower: Compute logical indices
    IndexLower->>IndexLower: Validate inner_dim % block_size == 0
    IndexLower->>IR: Create lowered GroupedBlockQuantizationOp
    IR->>Validation: Validate lowered op
    Validation->>Validation: Check MemoryType::Local
    Validation->>Validation: Verify TIDx and BIDx present
    Validation->>Validation: Check Group dimension exists
    Validation->>Validation: Ensure no z-axis parallelization
    Validation->>Codegen: Pass validated op
    Codegen->>Codegen: Extract group_size from loop domain
    Codegen->>Codegen: Validate group_size (2/4/8 for half, 2/4 for float)
    Codegen->>Codegen: Build template args (layout params, group_size)
    Codegen->>Codegen: Build function args (tensors, indices, offsets)
    Codegen->>Runtime: Generate call to bq::grouped_block_quantize_to_nvfp4
    Runtime->>Runtime: Execute quantization kernel
    Runtime-->>User: Return quantized_tensor + block_scales

@jjsjann123
Collaborator Author

!test

@jjsjann123
Collaborator Author

!test

@greptile-apps greptile-apps bot left a comment

26 files reviewed, 1 comment

@jjsjann123
Collaborator Author

!test

@jjsjann123 jjsjann123 requested a review from protonu January 14, 2026 23:08
@jjsjann123
Collaborator Author

err. looks like I messed up something. I'll double check.

@jjsjann123
Collaborator Author

!test

@jjsjann123
Collaborator Author

!test

@protonu protonu left a comment

Thanks - LGTM now!

@jjsjann123 jjsjann123 merged commit 311c26d into main Jan 17, 2026
65 of 66 checks passed
@jjsjann123 jjsjann123 deleted the jj/grouped_block_quantize_op_1 branch January 17, 2026 00:51