19 changes: 14 additions & 5 deletions .ci/scripts/export_model_artifact.sh
@@ -26,14 +26,16 @@ Arguments:
quant_name Quantization type (optional, default: non-quantized)
Options:
- non-quantized
- quantized-int4-tile-packed
- quantized-int4-weight-only
- quantized-int4-tile-packed (CUDA only)
- quantized-int4-weight-only (CUDA only)
- quantized-int4-metal (Metal only)
- quantized-8da4w (XNNPACK only)

output_dir Output directory for artifacts (optional, default: current directory)

Examples:
export_model_artifact.sh metal "openai/whisper-small"
export_model_artifact.sh metal "nvidia/parakeet-tdt" "quantized-int4-metal"
export_model_artifact.sh cuda "mistralai/Voxtral-Mini-3B-2507" "quantized-int4-tile-packed"
export_model_artifact.sh cuda "google/gemma-3-4b-it" "non-quantized" "./output"
export_model_artifact.sh cuda "nvidia/parakeet-tdt" "non-quantized" "./output"
@@ -131,18 +133,25 @@ case "$QUANT_NAME" in
;;
quantized-int4-tile-packed)
if [ "$DEVICE" = "metal" ]; then
echo "Error: Metal backend does not yet support quantization '$QUANT_NAME'"
echo "Error: Metal backend does not support quantization '$QUANT_NAME'"
exit 1
fi
EXTRA_ARGS="--qlinear 4w --qlinear_encoder 4w --qlinear_packing_format tile_packed_to_4d --qlinear_encoder_packing_format tile_packed_to_4d"
;;
quantized-int4-weight-only)
if [ "$DEVICE" = "metal" ]; then
echo "Error: Metal backend does not yet support quantization '$QUANT_NAME'"
echo "Error: Metal backend does not support quantization '$QUANT_NAME'"
exit 1
fi
EXTRA_ARGS="--qlinear_encoder 4w"
;;
quantized-int4-metal)
if [ "$DEVICE" != "metal" ]; then
echo "Error: Quantization '$QUANT_NAME' only supported on Metal backend"
exit 1
fi
EXTRA_ARGS="--qlinear fpa4w --qlinear_encoder fpa4w"
;;
quantized-8da4w)
if [ "$DEVICE" != "xnnpack" ]; then
echo "Error: quantized-8da4w is only supported with xnnpack device"
@@ -152,7 +161,7 @@ case "$QUANT_NAME" in
;;
*)
echo "Error: Unsupported quantization '$QUANT_NAME'"
echo "Supported quantizations: non-quantized, quantized-int4-tile-packed, quantized-int4-weight-only, quantized-8da4w"
echo "Supported quantizations: non-quantized, quantized-int4-tile-packed, quantized-int4-weight-only, quantized-int4-metal, quantized-8da4w"
exit 1
;;
esac
15 changes: 13 additions & 2 deletions .github/workflows/metal.yml
@@ -41,11 +41,10 @@ jobs:
set -eux

echo "::group::Setup ExecuTorch"
PYTHON_EXECUTABLE=python ${CONDA_RUN} ./install_executorch.sh
PYTHON_EXECUTABLE=python ${CONDA_RUN} EXECUTORCH_BUILD_KERNELS_TORCHAO=1 TORCHAO_BUILD_EXPERIMENTAL_MPS=1 ./install_executorch.sh
echo "::endgroup::"

echo "::group::Build Metal Runtime"
${CONDA_RUN} backends/apple/metal/tests/run_metal_test.sh --update-ao
${CONDA_RUN} backends/apple/metal/tests/run_metal_test.sh --build
echo "::endgroup::"

@@ -73,6 +72,12 @@ jobs:
name: "parakeet-tdt"
quant:
- "non-quantized"
# Only test int4 quantization with parakeet-tdt
include:
- model:
repo: "nvidia"
name: "parakeet-tdt"
quant: "quantized-int4-metal"
with:
runner: macos-m2-stable
python-version: '3.11'
@@ -123,6 +128,12 @@ jobs:
name: "parakeet-tdt"
quant:
- "non-quantized"
# Only test int4 quantization with parakeet-tdt
include:
- model:
repo: "nvidia"
name: "parakeet-tdt"
quant: "quantized-int4-metal"
with:
runner: macos-m2-stable
python-version: '3.11'
41 changes: 33 additions & 8 deletions examples/models/parakeet/README.md
@@ -39,23 +39,24 @@ The export script supports quantizing encoder and decoder linear layers using [t

| Argument | Description |
|----------|-------------|
| `--qlinear_encoder` | Quantization config for encoder linear layers: `4w`, `8w`, `8da4w`, `8da8w` |
| `--qlinear_encoder` | Quantization config for encoder linear layers: `4w`, `8w`, `8da4w`, `8da8w`, `fpa4w` |
| `--qlinear_encoder_group_size` | Group size for encoder linear quantization (default: 32) |
| `--qlinear_encoder_packing_format` | Packing format for encoder: `tile_packed_to_4d` |
| `--qlinear` | Quantization config for decoder linear layers: `4w`, `8w`, `8da4w`, `8da8w` |
| `--qlinear` | Quantization config for decoder linear layers: `4w`, `8w`, `8da4w`, `8da8w`, `fpa4w` |
| `--qlinear_group_size` | Group size for decoder linear quantization (default: 32) |
| `--qlinear_packing_format` | Packing format for decoder: `tile_packed_to_4d` |
| `--qembedding` | Quantization config for decoder embedding layer: `4w`, `8w` |
| `--qembedding_group_size` | Group size for embedding quantization (default: 0 = per-axis) |

#### Quantization Configs

| Config | Description |
|--------|-------------|
| `4w` | 4-bit weight only quantization |
| `8w` | 8-bit weight only quantization |
| `8da4w` | 8-bit dynamic activation, 4-bit weight |
| `8da8w` | 8-bit dynamic activation, 8-bit weight |
| Config | Description | Backends |
|--------|-------------|----------|
| `4w` | 4-bit weight only quantization | CUDA |
| `8w` | 8-bit weight only quantization | CUDA |
| `8da4w` | 8-bit dynamic activation, 4-bit weight | CUDA |
| `8da8w` | 8-bit dynamic activation, 8-bit weight | CUDA |
| `fpa4w` | Floating point activation, 4-bit weight | Metal |

#### Example: Dynamic Quantization for XNNPACK

@@ -86,6 +87,30 @@ python export_parakeet_tdt.py \

**Note:** The `tile_packed_to_4d` packing format is optimized for CUDA.

#### Example: Metal 4-bit Quantization

```bash
python export_parakeet_tdt.py \
--backend metal \
--qlinear_encoder fpa4w \
--qlinear_encoder_group_size 32 \
--qlinear fpa4w \
--qlinear_group_size 32 \
--output-dir ./parakeet_metal_quantized
```

**Note:** Metal 4-bit quantization requires torchao built with experimental MPS (Metal) ops.

You can install torchao with Metal support from the `ao` repo:
```bash
USE_CPP=1 TORCHAO_BUILD_EXPERIMENTAL_MPS=1 pip install . --no-build-isolation
```

Alternatively, you can build torchao with Metal support while installing ExecuTorch:
```bash
EXECUTORCH_BUILD_KERNELS_TORCHAO=1 TORCHAO_BUILD_EXPERIMENTAL_MPS=1 ./install_executorch.sh
```
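
For a quick sanity check that the experimental MPS ops are actually available before exporting, a minimal Python sketch can help. It mirrors the imports used by the `fpa4w` path in `quantize.py`; the exact torchao module layout is assumed from the version pinned in this PR and may change:

```python
# Sanity check (sketch): verify torchao's experimental Metal (MPS) ops are importable.
import torchao.experimental.ops.mps  # noqa: F401  # registers the Metal kernels
from torchao.experimental.quant_api import UIntxWeightOnlyConfig

# 4-bit weight-only config with group size 32, matching the export example above.
config = UIntxWeightOnlyConfig(group_size=32, bitwidth=4)
print("torchao Metal (MPS) ops available; config:", config)
```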

### Metal Export (macOS)

```bash
10 changes: 8 additions & 2 deletions examples/models/parakeet/export_parakeet_tdt.py
@@ -622,7 +622,7 @@ def main():
parser.add_argument(
"--qlinear",
type=str,
choices=["4w", "8w", "8da4w", "8da8w"],
choices=["4w", "8w", "8da4w", "8da8w", "fpa4w"],
help="Quantization config for decoder linear layers",
)
parser.add_argument(
@@ -642,7 +642,7 @@
parser.add_argument(
"--qlinear_encoder",
type=str,
choices=["4w", "8w", "8da4w", "8da8w"],
choices=["4w", "8w", "8da4w", "8da8w", "fpa4w"],
help="Quantization config for encoder linear layers",
)
parser.add_argument(
@@ -678,6 +678,12 @@ def main():
if args.dtype == "fp16":
parser.error("fp16 is not yet supported")

# Validate fpa4w quantization requires Metal backend
if args.qlinear == "fpa4w" and args.backend != "metal":
parser.error("--qlinear=fpa4w can only be used with --backend=metal")
if args.qlinear_encoder == "fpa4w" and args.backend != "metal":
parser.error("--qlinear_encoder=fpa4w can only be used with --backend=metal")

os.makedirs(args.output_dir, exist_ok=True)

print("Extracting tokenizer...")
33 changes: 31 additions & 2 deletions examples/models/parakeet/quantize.py
@@ -17,7 +17,7 @@ def quantize_model_( # noqa: C901

Args:
module: The PyTorch module to quantize.
qlinear_config: Quantization config for linear layers ("4w", "8w", "8da4w", "8da8w").
qlinear_config: Quantization config for linear layers ("4w", "8w", "8da4w", "8da8w", "fpa4w").
qlinear_group_size: Group size for linear quantization (default: 32).
qlinear_packing_format: Packing format for linear layers (e.g., "tile_packed_to_4d").
qembedding_config: Quantization config for embedding layers ("4w", "8w").
@@ -26,12 +26,41 @@
if not qlinear_config and not qembedding_config:
return

from torchao.quantization.quant_api import quantize_

# Metal (MPS) quantization uses different API
if qlinear_config == "fpa4w":
# Load MPS ops
import torchao.experimental.ops.mps # noqa: F401
from torchao.experimental.quant_api import UIntxWeightOnlyConfig

config = UIntxWeightOnlyConfig(
group_size=qlinear_group_size,
bitwidth=4,
@mergennachin (Contributor), Feb 5, 2026:
Update the pin past pytorch/ao#3829, and set uintx_choose_qparams_algorithm="hqq"

@mergennachin (Contributor):
Could be done in a follow-up PR too

@manuelcandales (Contributor, Author):
yes, that's my plan, to do in follow-up PR

@manuelcandales (Contributor, Author):
here #17258
)

def linear_filter(m, fqn):
if isinstance(m, torch.nn.Linear):
if m.weight.shape[1] % qlinear_group_size != 0:
raise ValueError(
f"Metal int4 quantization requires weight dimension (K) to be multiple of group_size. "
f"Layer {fqn} has weight shape {m.weight.shape} (K={m.weight.shape[1]}, group_size={qlinear_group_size})" # noqa: E501
)
return True
return False

print(
f" Applying {qlinear_config} linear quantization "
f"(group_size={qlinear_group_size})..."
)
quantize_(module, config, filter_fn=linear_filter)
return

from torchao.quantization.granularity import PerAxis, PerGroup
from torchao.quantization.quant_api import (
Int4WeightOnlyConfig,
Int8DynamicActivationIntxWeightConfig,
IntxWeightOnlyConfig,
quantize_,
)

# Quantize embedding layers first
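
For reference, a minimal sketch of how the new `fpa4w` path in `quantize_model_` might be exercised directly. The import path and keyword argument names follow the docstring above but are otherwise assumptions, and it requires torchao built with the experimental MPS ops as described in the README:

```python
import torch

# Assumed import path within examples/models/parakeet; adjust to your checkout.
from quantize import quantize_model_

# Toy module: in_features (K = 64 and 128) are multiples of the group size (32),
# which the Metal int4 filter above enforces.
model = torch.nn.Sequential(
    torch.nn.Linear(64, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 64),
)

# Apply Metal 4-bit weight-only (fpa4w) quantization to all Linear layers in place.
quantize_model_(model, qlinear_config="fpa4w", qlinear_group_size=32)
```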
2 changes: 1 addition & 1 deletion third-party/ao
Submodule ao updated 228 files