
Conversation

@jjsjann123 (Collaborator) commented Jan 8, 2026

Context

This series of PRs enables a single kernel that performs both quantization and block-scaling-factor layout handling on grouped tensors.

The existing solution for nvfp4 quantization of the activation tensor for grouped_mm relies on two operations:
i. BlockQuantizationOp produces scaled_tv and block_scaling_factor.
ii. block_scaling_factor then needs to be processed by PreprocessGroupedMatmulInputSf to satisfy the swizzle layout required by grouped_mm kernels.

This series of PRs merges the two operations into a single one.
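
As a rough sketch in the nvFuser Python frontend, the change collapses the two-step flow into one call (operation names follow the ops referenced in this PR; exact argument lists are approximate):

    # Before: quantize, then fix up the scaling-factor layout
    fp4_act, scale = fd.ops.nv_block_quantize(act)
    swizzled_scale = fd.ops.preprocess_grouped_matmul_input_sf(scale, offsets)

    # After: one fused operation that emits already-swizzled scaling factors
    fp4_act, swizzled_scale = fd.ops.nv_grouped_block_quantize(
        act, offsets, blockscale_offsets
    )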

Stacked PRs

#5775 GroupedBlockQuantizationOp PR0: Adding runtime function
#5776 GroupedBlockQuantizationOp PR1: Adding codegen support
#5777 GroupedBlockQuantizationOp PR2: Adding python API and updating llama4 benchmark

What's in this PR

  1. Added Python API ops.nv_grouped_block_quantize for GroupedBlockQuantizationOp;
  2. Added the Python translation rule for GroupedBlockQuantizationOp;
  3. Added a Python test for GroupedBlockQuantizationOp;
  4. Switched the two-operation quantization of the grouped_mm activation to use GroupedBlockQuantizationOp instead.

1. refactor existing block_layout op and block_quantization_kernel to re-use existing runtime functions;
2. added runtime function for GroupedBlockQuantizeOp
github-actions bot commented Jan 8, 2026

Review updated until commit 8a6d209

Description

  • Added Python API ops.nv_grouped_block_quantize for GroupedBlockQuantizationOp, taking the input tensor, offsets, global scale, block size, and dtype as parameters

  • Implemented Python translation rule for GroupedBlockQuantizationOp to enable code generation

  • Updated nvfp4_grouped_mm_translator to use consolidated single operation instead of two separate calls

  • Added comprehensive test test_grouped_block_quantize_op validating quantization against reference implementation

Changes walkthrough

Relevant files

Enhancement

ops.cpp: Added Python API for GroupedBlockQuantizationOp
python/python_direct/ops.cpp (+49/-0)

  • Added ops.nv_grouped_block_quantize function with parameters for the input
    tensor, offsets, global_scale, block_size, and dtype
  • Function returns a tuple of the quantized tensor and the block scaling
    factors in NVFP4 format
  • Includes comprehensive docstring with parameter descriptions and
    return value documentation
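
A hedged usage sketch of the new binding, based on the parameter list described above and on the test shown later in this page (defaulting of global_scale, block_size, and dtype is an assumption, not verified against the binding):

    def fusion(fd: FusionDefinition) -> None:
        act = fd.define_tensor(
            shape=[-1, -1], contiguity=True, dtype=DataType.Float, is_cpu=False
        )
        offsets = fd.define_tensor(
            shape=[-1], contiguity=True, dtype=DataType.Int32, is_cpu=False
        )
        blocksf_offsets = fd.define_tensor(
            shape=[-1], contiguity=True, dtype=DataType.Int32, is_cpu=False
        )
        # Returns the fp4-quantized tensor and the block scaling factors,
        # already laid out in the swizzled format expected by grouped_mm.
        fp4, block_scales = fd.ops.nv_grouped_block_quantize(
            act, offsets, blocksf_offsets
        )
        fd.add_output(fp4)
        fd.add_output(block_scales)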
python_translate.cpp: Added Python translation rule for GroupedBlockQuantizationOp
python/python_direct/python_translate.cpp (+25/-0)

  • Implemented handle method for GroupedBlockQuantizationOp in the
    PythonTranslator class
  • Generates the Python operation call with default arguments for
    global_scale, block_size, and dtype
  • Handles output tensor registration and visited-values tracking
benchmark_inference.py: Updated benchmark to use consolidated quantization operation
benchmarks/python/benchmark_inference.py (+1/-2)

  • Replaced the two-operation quantization sequence with a single
    nv_grouped_block_quantize call
  • Consolidated nv_block_quantize and preprocess_grouped_matmul_input_sf
    operations
  • Simplified code by removing intermediate variable assignments
Tests

test_narrow_precision.py: Added comprehensive test for GroupedBlockQuantizationOp
tests/python/direct/test_narrow_precision.py (+167/-0)

  • Added comprehensive test test_grouped_block_quantize_op with multiple
    parameter configurations
  • Validates quantization results against a reference implementation using
    torch._scaled_mm
  • Includes error checking for max-difference thresholds and large-difference
    ratios
  • Tests grouped tensor quantization with proper offset handling and
    scaling-factor processing

    PR Reviewer Guide

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    🔒 No security concerns identified
    ⚡ Recommended focus areas for review
    Test Coverage

    The new test test_grouped_block_quantize_op appears comprehensive but only tests a single configuration. Consider if additional parameter combinations (different block sizes, tensor dimensions, or data types) should be tested to ensure robustness of the new grouped block quantization functionality.

    def test_grouped_block_quantize_op(
        nvfuser_direct_test,
        config,
        tokens_per_expert_neg_one,
        out_dtype,
    ):
        BLOCK_SIZE = 16
    
        # k dimension is multiple of 4 * 16 to avoid padding on block scaling factor
        m, n, k = config
        assert k % 64 == 0
        tokens_per_expert = list(tokens_per_expert_neg_one)
        tokens_per_expert.append(m - sum(tokens_per_expert))
        g = len(tokens_per_expert)
    
        mat1 = torch.randn((m, k), dtype=torch.float32, device="cuda:0")
        # format is g, n, k instead of g, k, n
        mat2 = torch.randn((g, n, k), dtype=torch.float32, device="cuda:0")
    
        offsets = torch.empty((g,), dtype=torch.int32, device="cuda:0")
        blockscale_offsets = torch.empty((g,), dtype=torch.int32, device="cuda:0")
        problem_sizes = torch.empty((g, 3), dtype=torch.int32, device="cuda:0")
    
        # prepare quantization for mat2
        mat2_gs = torch.empty((g,), dtype=torch.float32, device="cuda:0")
        scale2 = torch.empty(
            (g, n, k // BLOCK_SIZE), dtype=torch.float8_e4m3fn, device="cuda:0"
        )
    
        acc_tokens = 0
        rounded_acc_tokens = 0
        mat2_scaled = torch.empty(
            (g, n, k // 2), dtype=torch.float4_e2m1fn_x2, device="cuda:0"
        )
    
        for i in range(g):
            global_sf = FLOAT4_E2M1_MAX * FLOAT8_E4M3_MAX / mat2[i].max()
            offsets[i] = acc_tokens
            blockscale_offsets[i] = rounded_acc_tokens
            acc_tokens += tokens_per_expert[i]
            # Note: we technically don't need to round up, since k is perfectly sized.
            rounded_acc_tokens += round_up(tokens_per_expert[i], 128)
    
            problem_sizes[i][0] = tokens_per_expert[i]
            problem_sizes[i][1] = n
            problem_sizes[i][2] = k
    
            scaled_mat2_i, bs_mat2_i = pytorch_nvfp4_quantize(mat2[i], global_sf)
            mat2_gs[i] = 1.0 / global_sf
            mat2_scaled[i] = scaled_mat2_i
            scale2[i] = linear_to_swizzled_128_4(bs_mat2_i)
    
        def nvfuser_fusion_id0(fd: FusionDefinition) -> None:
            mat1 = fd.define_tensor(
                shape=[-1, -1],
                contiguity=True,
                dtype=DataType.Float,
                is_cpu=False,
            )
            mat2 = fd.define_tensor(
                shape=[-1, -1, -1],
                contiguity=True,
                dtype=DataType.Float4_e2m1fn,
                is_cpu=False,
                stride_order=[2, 0, 1],
            )
            scale2 = fd.define_tensor(
                shape=[-1, -1, -1],
                contiguity=True,
                dtype=DataType.Float8_e4m3fn,
                is_cpu=False,
            )
            alpha = fd.define_tensor(
                shape=[-1], contiguity=True, dtype=DataType.Float, is_cpu=False
            )
            problem_sizes = fd.define_tensor(
                shape=[-1, -1], contiguity=True, dtype=DataType.Int32, is_cpu=False
            )
            offsets = fd.define_tensor(
                shape=[-1], contiguity=True, dtype=DataType.Int32, is_cpu=False
            )
            blockscale_offsets = fd.define_tensor(
                shape=[-1], contiguity=True, dtype=DataType.Int32, is_cpu=False
            )
    
            fp4_mat1, fp8_scale1 = fd.ops.nv_grouped_block_quantize(
                mat1, offsets, blockscale_offsets
            )
    
            out = fd.ops.cutlass_nvfp4_grouped_mm(
                fp4_mat1,
                mat2,
                fp8_scale1,
                scale2,
                alpha,
                problem_sizes,
                offsets,
                blockscale_offsets,
                DataType.BFloat16,
            )
            fd.add_output(out)
    
        inputs = [
            mat1,
            mat2_scaled.view(torch.float4_e2m1fn_x2).transpose(-1, -2),
            scale2,
            mat2_gs,
            problem_sizes,
            offsets,
            blockscale_offsets,
        ]
    
        o, _ = nvfuser_direct_test.exec_nvfuser(nvfuser_fusion_id0, inputs)
        # quantization for activation is needed for reference.
        # note: following sglang implementation, not computing global scaling factor for mat1
        #       similarly, we don't need to apply mat1_gs to alpha
        mat1_gs = torch.ones((g,), dtype=torch.float32, device="cuda:0")
        mat1_fp4, scale1 = activation_scale_to_nvfp4(
            mat1, mat1_gs, offsets, blockscale_offsets, BLOCK_SIZE
        )
        o_decomposed_ref = torch.empty(m, n, dtype=torch.bfloat16, device="cuda:0")
        for i in range(g):
            l = offsets[i]
            l_sf = blockscale_offsets[i]
            if i == g - 1:
                r = m
            else:
                r = offsets[i + 1]
            r_sf = round_up(tokens_per_expert[i], 128) + l_sf
            # For some reason I cannot feed mat2_gs[i] as alpha in the torch kernel.
            # This triggers a cublas invalid value error.
            o_decomposed_ref[l:r] = (
                torch._scaled_mm(
                    mat1_fp4[l:r],
                    mat2_scaled[i].transpose(-1, -2),
                    scale1[l_sf:r_sf],
                    scale2[i],
                    None,
                    None,
                    torch.bfloat16,
                )
                * mat2_gs[i]
            )
    
        # Validate: nvfuser quantization should match baseline
        abs_diff = torch.abs(o[0] - o_decomposed_ref)
        max_diff = torch.max(abs_diff)
        assert max_diff <= 10.0, f"Max difference {max_diff:.4f} exceeds threshold of 10.0"
    
        # Check that large differences (> 5.0) are rare (< 10% of elements)
        large_diff_count = torch.count_nonzero(torch.gt(abs_diff, 5.0))
        large_diff_ratio = large_diff_count / abs_diff.numel()
        assert (
            large_diff_ratio < 0.1
        ), f"Large diff ratio {large_diff_ratio:.2%} exceeds 10% threshold"
    Performance Validation

    The benchmark was updated to use the new single-operation approach, but there's no explicit performance comparison or validation that the new nv_grouped_block_quantize provides the expected performance benefits over the previous two-operation approach. Consider adding performance metrics to demonstrate the improvement.

    fp4_mat1, layout_fp8_scale1 = fd.ops.nv_grouped_block_quantize(nv_act, nv_offsets, nv_blocksf_offsets)

    jjsjann123 changed the title from "PR2: adding python API and updating benchmarks" to "GroupedBlockQuantizeOp PR2: Adding python API and updating llama4 benchmark" on Jan 8, 2026
    jjsjann123 added a commit that referenced this pull request Jan 9, 2026
    ## Context
    
    The series of PRs is trying to enable a single kernel for quantization
    and layout handling of block scaling factor on grouped tensors.
    
    The existing solution for nvfp4 quantization of the activation tensor for
    grouped_mm relies on two operations:
    i. BlockQuantizationOp produces scaled_tv and block_scaling_factor.
    ii. block_scaling_factor then needs to be processed by
    PreprocessGroupedMatmulInputSf to satisfy the swizzle layout
    required by grouped_mm kernels.

    This series of PRs merges the two operations into a single one.
    
    ### Stacked PRs
    
    #5775 GroupedBlockQuantizationOp PR0: Adding runtime function
    #5776 GroupedBlockQuantizationOp PR1: Adding codegen support
    #5777 GroupedBlockQuantizationOp PR2: Adding python API and updating
    llama4 benchmark
    
    ## What's in this PR
    
    1. Refactored the existing runtime function for re-use by the new op;
    2. Added the runtime function for GroupedBlockQuantizeOp.
    jjsjann123 added a commit that referenced this pull request Jan 17, 2026
    ## Context
    
    The series of PRs is trying to enable a single kernel for quantization
    and layout handling of block scaling factor on grouped tensors.
    
    The existing solution for nvfp4 quantization of the activation tensor for
    grouped_mm relies on two operations:
    i. BlockQuantizationOp produces scaled_tv and block_scaling_factor.
    ii. block_scaling_factor then needs to be processed by
    PreprocessGroupedMatmulInputSf to satisfy the swizzle layout
    required by grouped_mm kernels.

    This series of PRs merges the two operations into a single one.
    
    ### Stacked PRs
    
    #5775 GroupedBlockQuantizationOp PR0: Adding runtime function
    #5776 GroupedBlockQuantizationOp PR1: Adding codegen support
    #5777 GroupedBlockQuantizationOp PR2: Adding python API and updating
    llama4 benchmark
    
    ## What's in this PR
    
    1. Adds the Fusion IR node GroupedBlockQuantizationOp. The operation is a
    combination of BlockQuantizationOp and PreprocessGroupedMatmulInputSf and
    inherits all the validation / checks from those two operations.
    It is similar to BlockQuantizationOp, with the exceptions that:
    i. The block scaling factor output doesn't have the swizzle logic
    represented as allocation domain transformations;
    ii. It takes additional inputs (input_offsets and output_offsets) to
    facilitate group indexing, similar to PreprocessGroupedMatmulInputSf.

    2. Adds a cpp test case for GroupedBlockQuantizationOp.
    Base automatically changed from jj/grouped_block_quantize_op_1 to main January 17, 2026 00:51
    @jjsjann123 (Collaborator, Author):

    !test

    @jjsjann123 jjsjann123 requested a review from protonu January 18, 2026 06:31
    @jjsjann123 jjsjann123 marked this pull request as ready for review January 18, 2026 06:31
    greptile-apps bot (Contributor) commented Jan 18, 2026

    Greptile Summary

    This PR adds Python API support for GroupedBlockQuantizationOp, which consolidates two operations (nv_block_quantize + preprocess_grouped_matmul_input_sf) into a single kernel for grouped matmul quantization.

    Key Changes:

    • Added Python binding ops.nv_grouped_block_quantize that wraps the C++ groupedBlockQuantize function with proper parameter mapping
    • Implemented Python translation handler for GroupedBlockQuantizationOp following the same pattern as BlockQuantizationOp
    • Updated llama4 benchmark to use the new unified operation instead of the two-step flow
    • Added comprehensive test coverage that validates correctness against the decomposed reference implementation

    Technical Details:

    • The new operation hardcodes BlockScalingFactorLayout::Block128x4 layout, matching the grouped matmul requirements
    • Block scales output is swizzled in storage (as documented in the docstring)
    • Test validates numerical accuracy with tolerances: max_diff ≤ 10.0 and large diffs (>5.0) occurring in <10% of elements

    The implementation is clean, follows existing patterns in the codebase, and includes proper testing.
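
    For reference, the Block128x4 layout mentioned above stores the block scaling factors in 128x4 tiles. A rough PyTorch sketch of that swizzle, assuming it matches the common cutlass NVFP4 scale-factor layout that linear_to_swizzled_128_4 implements in the test (this is not code from the PR):

        import torch

        def swizzle_128x4(scales: torch.Tensor) -> torch.Tensor:
            # scales: (m, k // block_size) linear block scaling factors
            m, k = scales.shape
            m_pad = (m + 127) // 128 * 128
            k_pad = (k + 3) // 4 * 4
            padded = torch.zeros(m_pad, k_pad, dtype=scales.dtype, device=scales.device)
            padded[:m, :k] = scales
            # Split rows into 128-row tiles (viewed as 4 x 32) and columns into
            # 4-wide tiles, then reorder so each 128x4 tile is contiguous.
            tiles = padded.view(m_pad // 128, 4, 32, k_pad // 4, 4)
            return tiles.permute(0, 3, 2, 1, 4).contiguous().view(m_pad, k_pad)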

    Confidence Score: 5/5

    • This PR is safe to merge with no identified issues
    • The implementation correctly adds Python bindings and test coverage for an existing C++ operation. All parameter mappings are correct, the code follows established patterns in the codebase, and comprehensive testing validates the functionality against a reference implementation
    • No files require special attention

    Important Files Changed

    • python/python_direct/ops.cpp: Adds Python binding for nv_grouped_block_quantize, correctly mapping parameters to the C++ function groupedBlockQuantize with proper argument handling and documentation
    • python/python_direct/python_translate.cpp: Implements the Python translation handler for GroupedBlockQuantizationOp, following the same pattern as BlockQuantizationOp with correct default arguments
    • benchmarks/python/benchmark_inference.py: Replaces the two-operation quantization flow with a single nv_grouped_block_quantize call, simplifying the activation quantization for grouped matmul
    • tests/python/direct/test_narrow_precision.py: Adds a comprehensive test for the nv_grouped_block_quantize operation, validating correctness against the decomposed reference implementation with proper tolerance checks

    Sequence Diagram

    sequenceDiagram
        participant User
        participant PythonAPI as Python API (ops.nv_grouped_block_quantize)
        participant CPP as C++ groupedBlockQuantize
        participant Runtime as Runtime Function
        participant Output as Quantized Output
    
        User->>PythonAPI: Call nv_grouped_block_quantize(input, input_offsets, output_offsets, global_scale, block_size, dtype)
        PythonAPI->>CPP: groupedBlockQuantize(input, input_offsets, output_offsets, Block128x4, global_scale, block_size, dtype)
        CPP->>Runtime: Execute GroupedBlockQuantizationOp
        Note over Runtime: Combines quantization + swizzle layout
        Runtime-->>CPP: BlockQuantizationResults{quantized_tensor, block_scales}
        CPP-->>PythonAPI: Return (quantized_tensor, block_scales)
        PythonAPI-->>User: Python tuple (fp4_tensor, swizzled_scales)
        
        Note over User,Output: Old flow (2 operations):
        Note over User: 1. nv_block_quantize(input) → (fp4, scales)
        Note over User: 2. preprocess_grouped_matmul_input_sf(scales, offsets) → swizzled_scales
        Note over User,Output: New flow (1 operation):
        Note over User: nv_grouped_block_quantize(input, offsets, blockscale_offsets) → (fp4, swizzled_scales)
    
