From 80b1a37484c5bad4547c46487d10336416a3a76a Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Wed, 4 Feb 2026 22:49:12 -0500 Subject: [PATCH 1/3] Update [ghstack-poisoned] --- .ci/scripts/export_model_artifact.sh | 19 +++++++--- .github/workflows/metal.yml | 12 ++++++ examples/models/parakeet/README.md | 33 +++++++++++------ .../models/parakeet/export_parakeet_tdt.py | 10 ++++- examples/models/parakeet/quantize.py | 37 ++++++++++++++++++- third-party/ao | 2 +- 6 files changed, 92 insertions(+), 21 deletions(-) diff --git a/.ci/scripts/export_model_artifact.sh b/.ci/scripts/export_model_artifact.sh index d5c1913619d..eb89444a27a 100755 --- a/.ci/scripts/export_model_artifact.sh +++ b/.ci/scripts/export_model_artifact.sh @@ -26,13 +26,15 @@ Arguments: quant_name Quantization type (optional, default: non-quantized) Options: - non-quantized - - quantized-int4-tile-packed - - quantized-int4-weight-only + - quantized-int4-tile-packed (CUDA only) + - quantized-int4-weight-only (CUDA only) + - quantized-int4-metal (Metal only) output_dir Output directory for artifacts (optional, default: current directory) Examples: export_model_artifact.sh metal "openai/whisper-small" + export_model_artifact.sh metal "nvidia/parakeet-tdt" "quantized-int4-metal" export_model_artifact.sh cuda "mistralai/Voxtral-Mini-3B-2507" "quantized-int4-tile-packed" export_model_artifact.sh cuda "google/gemma-3-4b-it" "non-quantized" "./output" export_model_artifact.sh cuda "nvidia/parakeet-tdt" "non-quantized" "./output" @@ -127,21 +129,28 @@ case "$QUANT_NAME" in ;; quantized-int4-tile-packed) if [ "$DEVICE" = "metal" ]; then - echo "Error: Metal backend does not yet support quantization '$QUANT_NAME'" + echo "Error: Metal backend does not support quantization '$QUANT_NAME'" exit 1 fi EXTRA_ARGS="--qlinear 4w --qlinear_encoder 4w --qlinear_packing_format tile_packed_to_4d --qlinear_encoder_packing_format tile_packed_to_4d" ;; quantized-int4-weight-only) if [ "$DEVICE" = "metal" ]; then - echo "Error: Metal backend does not yet support quantization '$QUANT_NAME'" + echo "Error: Metal backend does not support quantization '$QUANT_NAME'" exit 1 fi EXTRA_ARGS="--qlinear_encoder 4w" ;; + quantized-int4-metal) + if [ "$DEVICE" != "metal" ]; then + echo "Error: Quantization '$QUANT_NAME' only supported on Metal backend" + exit 1 + fi + EXTRA_ARGS="--qlinear fpa4w --qlinear_encoder fpa4w" + ;; *) echo "Error: Unsupported quantization '$QUANT_NAME'" - echo "Supported quantizations: non-quantized, quantized-int4-tile-packed, quantized-int4-weight-only" + echo "Supported quantizations: non-quantized, quantized-int4-tile-packed, quantized-int4-weight-only, quantized-int4-metal" exit 1 ;; esac diff --git a/.github/workflows/metal.yml b/.github/workflows/metal.yml index bf86b01aff8..1ca0910a67b 100644 --- a/.github/workflows/metal.yml +++ b/.github/workflows/metal.yml @@ -73,6 +73,12 @@ jobs: name: "parakeet-tdt" quant: - "non-quantized" + # Only test int4 quantization with parakeet-tdt + include: + - model: + repo: "nvidia" + name: "parakeet-tdt" + quant: "quantized-int4-metal" with: runner: macos-m2-stable python-version: '3.11' @@ -123,6 +129,12 @@ jobs: name: "parakeet-tdt" quant: - "non-quantized" + # Only test int4 quantization with parakeet-tdt + include: + - model: + repo: "nvidia" + name: "parakeet-tdt" + quant: "quantized-int4-metal" with: runner: macos-m2-stable python-version: '3.11' diff --git a/examples/models/parakeet/README.md b/examples/models/parakeet/README.md index 23593d324d1..857c61526ef 100644 --- a/examples/models/parakeet/README.md +++ b/examples/models/parakeet/README.md @@ -39,10 +39,10 @@ The export script supports quantizing encoder and decoder linear layers using [t | Argument | Description | |----------|-------------| -| `--qlinear_encoder` | Quantization config for encoder linear layers: `4w`, `8w`, `8da4w`, `8da8w` | +| `--qlinear_encoder` | Quantization config for encoder linear layers: `4w`, `8w`, `8da4w`, `8da8w`, `fpa4w` | | `--qlinear_encoder_group_size` | Group size for encoder linear quantization (default: 32) | | `--qlinear_encoder_packing_format` | Packing format for encoder: `tile_packed_to_4d` | -| `--qlinear` | Quantization config for decoder linear layers: `4w`, `8w`, `8da4w`, `8da8w` | +| `--qlinear` | Quantization config for decoder linear layers: `4w`, `8w`, `8da4w`, `8da8w`, `fpa4w` | | `--qlinear_group_size` | Group size for decoder linear quantization (default: 32) | | `--qlinear_packing_format` | Packing format for decoder: `tile_packed_to_4d` | | `--qembedding` | Quantization config for decoder embedding layer: `4w`, `8w` | @@ -50,14 +50,15 @@ The export script supports quantizing encoder and decoder linear layers using [t #### Quantization Configs -| Config | Description | -|--------|-------------| -| `4w` | 4-bit weight only quantization | -| `8w` | 8-bit weight only quantization | -| `8da4w` | 8-bit dynamic activation, 4-bit weight | -| `8da8w` | 8-bit dynamic activation, 8-bit weight | +| Config | Description | Backends | +|--------|-------------|----------| +| `4w` | 4-bit weight only quantization | CUDA | +| `8w` | 8-bit weight only quantization | CUDA | +| `8da4w` | 8-bit dynamic activation, 4-bit weight | CUDA | +| `8da8w` | 8-bit dynamic activation, 8-bit weight | CUDA | +| `fpa4w` | Floating point activation, 4-bit weight | Metal | -#### Example: 4-bit Weight Quantization with Tile Packing +#### Example: CUDA 4-bit Quantization ```bash python export_parakeet_tdt.py \ @@ -69,10 +70,20 @@ python export_parakeet_tdt.py \ --qlinear_group_size 32 \ --qlinear_packing_format tile_packed_to_4d \ --qembedding 8w \ - --output-dir ./parakeet_quantized + --output-dir ./parakeet_cuda_quantized ``` -**Note:** The `tile_packed_to_4d` packing format is optimized for CUDA. +#### Example: Metal 4-bit Quantization + +```bash +python export_parakeet_tdt.py \ + --backend metal \ + --qlinear_encoder fpa4w \ + --qlinear_encoder_group_size 32 \ + --qlinear fpa4w \ + --qlinear_group_size 32 \ + --output-dir ./parakeet_metal_quantized +``` ### Metal Export (macOS) diff --git a/examples/models/parakeet/export_parakeet_tdt.py b/examples/models/parakeet/export_parakeet_tdt.py index 703c25091ec..dd0beca3307 100644 --- a/examples/models/parakeet/export_parakeet_tdt.py +++ b/examples/models/parakeet/export_parakeet_tdt.py @@ -583,7 +583,7 @@ def main(): parser.add_argument( "--qlinear", type=str, - choices=["4w", "8w", "8da4w", "8da8w"], + choices=["4w", "8w", "8da4w", "8da8w", "fpa4w"], help="Quantization config for decoder linear layers", ) parser.add_argument( @@ -603,7 +603,7 @@ def main(): parser.add_argument( "--qlinear_encoder", type=str, - choices=["4w", "8w", "8da4w", "8da8w"], + choices=["4w", "8w", "8da4w", "8da8w", "fpa4w"], help="Quantization config for encoder linear layers", ) parser.add_argument( @@ -639,6 +639,12 @@ def main(): if args.dtype == "fp16": parser.error("fp16 is not yet supported") + # Validate fpa4w quantization requires Metal backend + if args.qlinear == "fpa4w" and args.backend != "metal": + parser.error("--qlinear=fpa4w can only be used with --backend=metal") + if args.qlinear_encoder == "fpa4w" and args.backend != "metal": + parser.error("--qlinear_encoder=fpa4w can only be used with --backend=metal") + os.makedirs(args.output_dir, exist_ok=True) print("Extracting tokenizer...") diff --git a/examples/models/parakeet/quantize.py b/examples/models/parakeet/quantize.py index 3e540d84834..1d69662efa1 100644 --- a/examples/models/parakeet/quantize.py +++ b/examples/models/parakeet/quantize.py @@ -17,7 +17,7 @@ def quantize_model_( # noqa: C901 Args: module: The PyTorch module to quantize. - qlinear_config: Quantization config for linear layers ("4w", "8w", "8da4w", "8da8w"). + qlinear_config: Quantization config for linear layers ("4w", "8w", "8da4w", "8da8w", "fpa4w"). qlinear_group_size: Group size for linear quantization (default: 32). qlinear_packing_format: Packing format for linear layers (e.g., "tile_packed_to_4d"). qembedding_config: Quantization config for embedding layers ("4w", "8w"). @@ -26,12 +26,45 @@ def quantize_model_( # noqa: C901 if not qlinear_config and not qembedding_config: return + from torchao.quantization.quant_api import quantize_ + + # Metal (MPS) quantization uses different API + if qlinear_config == "fpa4w": + from torchao.experimental.quant_api import UIntxWeightOnlyConfig + + # Load MPS ops + import torchao.experimental.ops.mps # noqa: F401 + + config = UIntxWeightOnlyConfig( + group_size=qlinear_group_size, + bitwidth=4, + ) + + # Filter: only quantize Linear layers with compatible dimensions + def linear_filter(m, fqn): + if isinstance(m, torch.nn.Linear): + if qlinear_group_size == 0: + return False + if m.weight.shape[1] % 8 != 0: + print( + f" Skipping {fqn}: weight shape {m.weight.shape} not multiple of 8" + ) + return False + return True + return False + + print( + f" Applying {qlinear_config} linear quantization " + f"(group_size={qlinear_group_size})..." + ) + quantize_(module, config, filter_fn=linear_filter) + return + from torchao.quantization.granularity import PerAxis, PerGroup from torchao.quantization.quant_api import ( Int4WeightOnlyConfig, Int8DynamicActivationIntxWeightConfig, IntxWeightOnlyConfig, - quantize_, ) # Quantize embedding layers first diff --git a/third-party/ao b/third-party/ao index 28306f08500..1b4b6d998bf 160000 --- a/third-party/ao +++ b/third-party/ao @@ -1 +1 @@ -Subproject commit 28306f085003b892bc0a250c209d80f5d4a5147b +Subproject commit 1b4b6d998bf988f059e97a10181cbc4aec269b69 From 3ee1d56512a88eef70a7bea353574675819938f1 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Wed, 4 Feb 2026 22:53:25 -0500 Subject: [PATCH 2/3] Update [ghstack-poisoned] --- examples/models/parakeet/README.md | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/models/parakeet/README.md b/examples/models/parakeet/README.md index 857c61526ef..756fce068c5 100644 --- a/examples/models/parakeet/README.md +++ b/examples/models/parakeet/README.md @@ -58,7 +58,7 @@ The export script supports quantizing encoder and decoder linear layers using [t | `8da8w` | 8-bit dynamic activation, 8-bit weight | CUDA | | `fpa4w` | Floating point activation, 4-bit weight | Metal | -#### Example: CUDA 4-bit Quantization +#### Example: 4-bit Weight Quantization with Tile Packing (CUDA) ```bash python export_parakeet_tdt.py \ @@ -70,9 +70,11 @@ python export_parakeet_tdt.py \ --qlinear_group_size 32 \ --qlinear_packing_format tile_packed_to_4d \ --qembedding 8w \ - --output-dir ./parakeet_cuda_quantized + --output-dir ./parakeet_quantized ``` +**Note:** The `tile_packed_to_4d` packing format is optimized for CUDA. + #### Example: Metal 4-bit Quantization ```bash From ce1c30eb3646ca223e000476bed1c0f447f6255f Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Wed, 4 Feb 2026 23:02:33 -0500 Subject: [PATCH 3/3] Update [ghstack-poisoned] --- .../metal/passes/decompose_linear_pass.py | 137 ++++++++++++------ backends/apple/metal/tests/test_modules.py | 23 ++- 2 files changed, 109 insertions(+), 51 deletions(-) diff --git a/backends/apple/metal/passes/decompose_linear_pass.py b/backends/apple/metal/passes/decompose_linear_pass.py index 9f4358110b7..e6b8578cc9f 100644 --- a/backends/apple/metal/passes/decompose_linear_pass.py +++ b/backends/apple/metal/passes/decompose_linear_pass.py @@ -6,7 +6,7 @@ import torch from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import ExportPass +from executorch.exir.pass_base import ExportPass, PassResult class DecomposeLinearPass(ExportPass): @@ -20,49 +20,92 @@ class DecomposeLinearPass(ExportPass): then squeeze back to 2D. """ - def call_operator(self, op, args, kwargs, meta): - # Only intercept linear operations - if op not in (exir_ops.edge.aten.linear.default, torch.ops.aten.linear.default): - return super().call_operator(op, args, kwargs, meta) - - # Get input, weight, and bias arguments - input_arg = args[0] - weight_arg = args[1] - bias_arg = args[2] if len(args) > 2 else None - - # Determine which ops to use based on the input operator - if op == exir_ops.edge.aten.linear.default: - t_op = exir_ops.edge.aten.t.default - matmul_op = exir_ops.edge.aten.matmul.default - add_op = exir_ops.edge.aten.add.Tensor - unsqueeze_op = exir_ops.edge.aten.unsqueeze.default - squeeze_op = exir_ops.edge.aten.squeeze.dims - else: - t_op = torch.ops.aten.t.default - matmul_op = torch.ops.aten.matmul.default - add_op = torch.ops.aten.add.Tensor - unsqueeze_op = torch.ops.aten.unsqueeze.default - squeeze_op = torch.ops.aten.squeeze.dims - - # Check if input is 2D from metadata - needs_unsqueeze = len(meta["val"].shape) == 2 - - # Unsqueeze 2D input to 3D: (M, K) -> (1, M, K) - if needs_unsqueeze: - input_arg = super().call_operator(unsqueeze_op, (input_arg, 0), {}, meta) - - # Transpose weight - weight_t = super().call_operator(t_op, (weight_arg,), {}, meta) - - # Matmul - result = super().call_operator(matmul_op, (input_arg, weight_t), {}, meta) - - # Add bias if present - if bias_arg is not None: - result = super().call_operator(add_op, (result, bias_arg), {}, meta) - - # Squeeze 3D output back to 2D: (1, M, N) -> (M, N) - if needs_unsqueeze: - result = super().call_operator(squeeze_op, (result, [0]), {}, meta) - - return result + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + modified = False + graph = graph_module.graph + + for node in graph.nodes: + # Check if this is a linear operation + is_linear = False + + if node.op == "call_function": + # Match both edge dialect and core aten linear operators + if node.target == exir_ops.edge.aten.linear.default: + is_linear = True + elif node.target == torch.ops.aten.linear.default: + is_linear = True + + if is_linear: + # Get input, weight, and bias arguments + input_node = node.args[0] + weight_node = node.args[1] + bias_node = node.args[2] if len(node.args) > 2 else None + + with graph.inserting_before(node): + # Determine which ops to use based on the input operator + target_str = str(node.target) + + if "executorch_exir_dialects_edge" in target_str: + # Use edge dialect operators + t_op = exir_ops.edge.aten.t.default + matmul_op = exir_ops.edge.aten.matmul.default + add_op = exir_ops.edge.aten.add.Tensor + unsqueeze_op = exir_ops.edge.aten.unsqueeze.default + squeeze_op = exir_ops.edge.aten.squeeze.dims + else: + # Use core aten operators + t_op = torch.ops.aten.t.default + matmul_op = torch.ops.aten.matmul.default + add_op = torch.ops.aten.add.Tensor + unsqueeze_op = torch.ops.aten.unsqueeze.default + squeeze_op = torch.ops.aten.squeeze.dims + + # Check if input is 2D + needs_unsqueeze = False + if hasattr(input_node, "meta") and "val" in input_node.meta: + if len(input_node.meta["val"].shape) == 2: + needs_unsqueeze = True + + # Unsqueeze 2D input to 3D: (M, K) -> (1, M, K) + current_input = input_node + if needs_unsqueeze: + current_input = graph.call_function( + unsqueeze_op, + args=(input_node, 0), + ) + + # Decompose linear: matmul(input, weight.T) + bias + weight_t = graph.call_function( + t_op, + args=(weight_node,), + ) + + matmul_result = graph.call_function( + matmul_op, + args=(current_input, weight_t), + ) + + if bias_node is not None: + result = graph.call_function( + add_op, + args=(matmul_result, bias_node), + ) + else: + result = matmul_result + + # Squeeze 3D output back to 2D: (1, M, N) -> (M, N) + if needs_unsqueeze: + result = graph.call_function( + squeeze_op, + args=(result, [0]), + ) + + # Replace all uses of the linear node with the decomposed result + node.replace_all_uses_with(result) + graph.erase_node(node) + modified = True + + if modified: + graph_module.recompile() + + return PassResult(graph_module, modified) diff --git a/backends/apple/metal/tests/test_modules.py b/backends/apple/metal/tests/test_modules.py index eefcd15b69f..086e0d9a7fe 100644 --- a/backends/apple/metal/tests/test_modules.py +++ b/backends/apple/metal/tests/test_modules.py @@ -31,10 +31,15 @@ from torch.export import export from torch.nn.attention import SDPBackend -# Need to import to load the ops -import torchao.experimental.ops.mps # noqa: F401 -from torchao.experimental.quant_api import UIntxWeightOnlyConfig -from torchao.quantization.quant_api import quantize_ +try: + # Need to import to load the ops + import torchao.experimental.ops.mps # noqa: F401 + from torchao.experimental.quant_api import UIntxWeightOnlyConfig + from torchao.quantization.quant_api import quantize_ + + TORCHAO_AVAILABLE = True +except ImportError: + TORCHAO_AVAILABLE = False # Check if MPS is available for export tests @@ -241,6 +246,7 @@ def forward(self, x: torch.Tensor): "rtol_float32": 5e-2, "atol_bfloat16": 1e-1, "rtol_bfloat16": 1e-1, + "skip": not TORCHAO_AVAILABLE, } @@ -265,6 +271,7 @@ def forward(self, x: torch.Tensor): "rtol_float32": 5e-2, "atol_bfloat16": 1e-1, "rtol_bfloat16": 1e-1, + "skip": not TORCHAO_AVAILABLE, } @@ -289,6 +296,7 @@ def forward(self, x: torch.Tensor): "rtol_float32": 5e-2, "atol_bfloat16": 1e-1, "rtol_bfloat16": 1e-1, + "skip": not TORCHAO_AVAILABLE, } @@ -313,6 +321,7 @@ def forward(self, x: torch.Tensor): "rtol_float32": 5e-2, "atol_bfloat16": 1e-1, "rtol_bfloat16": 1e-1, + "skip": not TORCHAO_AVAILABLE, } @@ -337,6 +346,7 @@ def forward(self, x: torch.Tensor): "rtol_float32": 5e-2, "atol_bfloat16": 1e-1, "rtol_bfloat16": 1e-1, + "skip": not TORCHAO_AVAILABLE, } @@ -688,6 +698,11 @@ def quantize_model(model: nn.Module, qlinear: str, qlinear_group_size: int = 32) - "fpa4w": Floating point activation, 4-bit weight (Metal backend) qlinear_group_size: Group size for quantization (default: 32). """ + if not TORCHAO_AVAILABLE: + raise RuntimeError( + "torchao is not available. Install torchao to use quantization." + ) + if qlinear == "fpa4w": linear_config = UIntxWeightOnlyConfig( group_size=qlinear_group_size,