19 changes: 14 additions & 5 deletions .ci/scripts/export_model_artifact.sh
@@ -26,13 +26,15 @@ Arguments:
quant_name Quantization type (optional, default: non-quantized)
Options:
- non-quantized
- quantized-int4-tile-packed
- quantized-int4-weight-only
- quantized-int4-tile-packed (CUDA only)
- quantized-int4-weight-only (CUDA only)
- quantized-int4-metal (Metal only)

output_dir Output directory for artifacts (optional, default: current directory)

Examples:
export_model_artifact.sh metal "openai/whisper-small"
export_model_artifact.sh metal "nvidia/parakeet-tdt" "quantized-int4-metal"
export_model_artifact.sh cuda "mistralai/Voxtral-Mini-3B-2507" "quantized-int4-tile-packed"
export_model_artifact.sh cuda "google/gemma-3-4b-it" "non-quantized" "./output"
export_model_artifact.sh cuda "nvidia/parakeet-tdt" "non-quantized" "./output"
@@ -127,21 +129,28 @@ case "$QUANT_NAME" in
;;
quantized-int4-tile-packed)
if [ "$DEVICE" = "metal" ]; then
echo "Error: Metal backend does not yet support quantization '$QUANT_NAME'"
echo "Error: Metal backend does not support quantization '$QUANT_NAME'"
exit 1
fi
EXTRA_ARGS="--qlinear 4w --qlinear_encoder 4w --qlinear_packing_format tile_packed_to_4d --qlinear_encoder_packing_format tile_packed_to_4d"
;;
quantized-int4-weight-only)
if [ "$DEVICE" = "metal" ]; then
echo "Error: Metal backend does not yet support quantization '$QUANT_NAME'"
echo "Error: Metal backend does not support quantization '$QUANT_NAME'"
exit 1
fi
EXTRA_ARGS="--qlinear_encoder 4w"
;;
quantized-int4-metal)
if [ "$DEVICE" != "metal" ]; then
echo "Error: Quantization '$QUANT_NAME' only supported on Metal backend"
exit 1
fi
EXTRA_ARGS="--qlinear fpa4w --qlinear_encoder fpa4w"
;;
*)
echo "Error: Unsupported quantization '$QUANT_NAME'"
echo "Supported quantizations: non-quantized, quantized-int4-tile-packed, quantized-int4-weight-only"
echo "Supported quantizations: non-quantized, quantized-int4-tile-packed, quantized-int4-weight-only, quantized-int4-metal"
exit 1
;;
esac
12 changes: 12 additions & 0 deletions .github/workflows/metal.yml
@@ -73,6 +73,12 @@ jobs:
name: "parakeet-tdt"
quant:
- "non-quantized"
# Only test int4 quantization with parakeet-tdt
include:
- model:
repo: "nvidia"
name: "parakeet-tdt"
quant: "quantized-int4-metal"
with:
runner: macos-m2-stable
python-version: '3.11'
@@ -123,6 +129,12 @@ jobs:
name: "parakeet-tdt"
quant:
- "non-quantized"
# Only test int4 quantization with parakeet-tdt
include:
- model:
repo: "nvidia"
name: "parakeet-tdt"
quant: "quantized-int4-metal"
with:
runner: macos-m2-stable
python-version: '3.11'
137 changes: 47 additions & 90 deletions backends/apple/metal/passes/decompose_linear_pass.py
@@ -6,7 +6,7 @@

import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from executorch.exir.pass_base import ExportPass


class DecomposeLinearPass(ExportPass):
@@ -20,92 +20,49 @@ class DecomposeLinearPass(ExportPass):
then squeeze back to 2D.
"""

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
modified = False
graph = graph_module.graph

for node in graph.nodes:
# Check if this is a linear operation
is_linear = False

if node.op == "call_function":
# Match both edge dialect and core aten linear operators
if node.target == exir_ops.edge.aten.linear.default:
is_linear = True
elif node.target == torch.ops.aten.linear.default:
is_linear = True

if is_linear:
# Get input, weight, and bias arguments
input_node = node.args[0]
weight_node = node.args[1]
bias_node = node.args[2] if len(node.args) > 2 else None

with graph.inserting_before(node):
# Determine which ops to use based on the input operator
target_str = str(node.target)

if "executorch_exir_dialects_edge" in target_str:
# Use edge dialect operators
t_op = exir_ops.edge.aten.t.default
matmul_op = exir_ops.edge.aten.matmul.default
add_op = exir_ops.edge.aten.add.Tensor
unsqueeze_op = exir_ops.edge.aten.unsqueeze.default
squeeze_op = exir_ops.edge.aten.squeeze.dims
else:
# Use core aten operators
t_op = torch.ops.aten.t.default
matmul_op = torch.ops.aten.matmul.default
add_op = torch.ops.aten.add.Tensor
unsqueeze_op = torch.ops.aten.unsqueeze.default
squeeze_op = torch.ops.aten.squeeze.dims

# Check if input is 2D
needs_unsqueeze = False
if hasattr(input_node, "meta") and "val" in input_node.meta:
if len(input_node.meta["val"].shape) == 2:
needs_unsqueeze = True

# Unsqueeze 2D input to 3D: (M, K) -> (1, M, K)
current_input = input_node
if needs_unsqueeze:
current_input = graph.call_function(
unsqueeze_op,
args=(input_node, 0),
)

# Decompose linear: matmul(input, weight.T) + bias
weight_t = graph.call_function(
t_op,
args=(weight_node,),
)

matmul_result = graph.call_function(
matmul_op,
args=(current_input, weight_t),
)

if bias_node is not None:
result = graph.call_function(
add_op,
args=(matmul_result, bias_node),
)
else:
result = matmul_result

# Squeeze 3D output back to 2D: (1, M, N) -> (M, N)
if needs_unsqueeze:
result = graph.call_function(
squeeze_op,
args=(result, [0]),
)

# Replace all uses of the linear node with the decomposed result
node.replace_all_uses_with(result)
graph.erase_node(node)
modified = True

if modified:
graph_module.recompile()

return PassResult(graph_module, modified)
def call_operator(self, op, args, kwargs, meta):
# Only intercept linear operations
if op not in (exir_ops.edge.aten.linear.default, torch.ops.aten.linear.default):
return super().call_operator(op, args, kwargs, meta)

# Get input, weight, and bias arguments
input_arg = args[0]
weight_arg = args[1]
bias_arg = args[2] if len(args) > 2 else None

# Determine which ops to use based on the input operator
if op == exir_ops.edge.aten.linear.default:
t_op = exir_ops.edge.aten.t.default
matmul_op = exir_ops.edge.aten.matmul.default
add_op = exir_ops.edge.aten.add.Tensor
unsqueeze_op = exir_ops.edge.aten.unsqueeze.default
squeeze_op = exir_ops.edge.aten.squeeze.dims
else:
t_op = torch.ops.aten.t.default
matmul_op = torch.ops.aten.matmul.default
add_op = torch.ops.aten.add.Tensor
unsqueeze_op = torch.ops.aten.unsqueeze.default
squeeze_op = torch.ops.aten.squeeze.dims

# Check if input is 2D from metadata
needs_unsqueeze = len(meta["val"].shape) == 2

# Unsqueeze 2D input to 3D: (M, K) -> (1, M, K)
if needs_unsqueeze:
input_arg = super().call_operator(unsqueeze_op, (input_arg, 0), {}, meta)

# Transpose weight
weight_t = super().call_operator(t_op, (weight_arg,), {}, meta)

# Matmul
result = super().call_operator(matmul_op, (input_arg, weight_t), {}, meta)

# Add bias if present
if bias_arg is not None:
result = super().call_operator(add_op, (result, bias_arg), {}, meta)

# Squeeze 3D output back to 2D: (1, M, N) -> (M, N)
if needs_unsqueeze:
result = super().call_operator(squeeze_op, (result, [0]), {}, meta)

return result
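
For reference, the rewrite above relies on the identity `linear(x, W, b) == squeeze(matmul(unsqueeze(x, 0), W.T) + b, 0)` for 2D inputs. A minimal plain-PyTorch sketch (illustrative only, with arbitrary example shapes) to sanity-check it:

```python
import torch

# Arbitrary example shapes: x is (M, K), W is (N, K), b is (N,)
x = torch.randn(5, 16)
w = torch.randn(8, 16)
b = torch.randn(8)

reference = torch.nn.functional.linear(x, w, b)  # x @ W.T + b

# Decomposed form used by the pass for 2D inputs:
# unsqueeze -> matmul with transposed weight -> add bias -> squeeze
decomposed = torch.matmul(x.unsqueeze(0), w.t()).add(b).squeeze(0)

assert torch.allclose(reference, decomposed, atol=1e-6)
```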
31 changes: 22 additions & 9 deletions examples/models/parakeet/README.md
@@ -39,25 +39,26 @@ The export script supports quantizing encoder and decoder linear layers using [t

| Argument | Description |
|----------|-------------|
| `--qlinear_encoder` | Quantization config for encoder linear layers: `4w`, `8w`, `8da4w`, `8da8w` |
| `--qlinear_encoder` | Quantization config for encoder linear layers: `4w`, `8w`, `8da4w`, `8da8w`, `fpa4w` |
| `--qlinear_encoder_group_size` | Group size for encoder linear quantization (default: 32) |
| `--qlinear_encoder_packing_format` | Packing format for encoder: `tile_packed_to_4d` |
| `--qlinear` | Quantization config for decoder linear layers: `4w`, `8w`, `8da4w`, `8da8w` |
| `--qlinear` | Quantization config for decoder linear layers: `4w`, `8w`, `8da4w`, `8da8w`, `fpa4w` |
| `--qlinear_group_size` | Group size for decoder linear quantization (default: 32) |
| `--qlinear_packing_format` | Packing format for decoder: `tile_packed_to_4d` |
| `--qembedding` | Quantization config for decoder embedding layer: `4w`, `8w` |
| `--qembedding_group_size` | Group size for embedding quantization (default: 0 = per-axis) |

#### Quantization Configs

| Config | Description |
|--------|-------------|
| `4w` | 4-bit weight only quantization |
| `8w` | 8-bit weight only quantization |
| `8da4w` | 8-bit dynamic activation, 4-bit weight |
| `8da8w` | 8-bit dynamic activation, 8-bit weight |
| Config | Description | Backends |
|--------|-------------|----------|
| `4w` | 4-bit weight only quantization | CUDA |
| `8w` | 8-bit weight only quantization | CUDA |
| `8da4w` | 8-bit dynamic activation, 4-bit weight | CUDA |
| `8da8w` | 8-bit dynamic activation, 8-bit weight | CUDA |
| `fpa4w` | Floating point activation, 4-bit weight | Metal |

#### Example: 4-bit Weight Quantization with Tile Packing
#### Example: 4-bit Weight Quantization with Tile Packing (CUDA)

```bash
python export_parakeet_tdt.py \
@@ -74,6 +75,18 @@ python export_parakeet_tdt.py \

**Note:** The `tile_packed_to_4d` packing format is optimized for CUDA.

#### Example: Metal 4-bit Quantization

```bash
python export_parakeet_tdt.py \
--backend metal \
--qlinear_encoder fpa4w \
--qlinear_encoder_group_size 32 \
--qlinear fpa4w \
--qlinear_group_size 32 \
--output-dir ./parakeet_metal_quantized
```

### Metal Export (macOS)

```bash
10 changes: 8 additions & 2 deletions examples/models/parakeet/export_parakeet_tdt.py
@@ -583,7 +583,7 @@ def main():
parser.add_argument(
"--qlinear",
type=str,
choices=["4w", "8w", "8da4w", "8da8w"],
choices=["4w", "8w", "8da4w", "8da8w", "fpa4w"],
help="Quantization config for decoder linear layers",
)
parser.add_argument(
@@ -603,7 +603,7 @@
parser.add_argument(
"--qlinear_encoder",
type=str,
choices=["4w", "8w", "8da4w", "8da8w"],
choices=["4w", "8w", "8da4w", "8da8w", "fpa4w"],
help="Quantization config for encoder linear layers",
)
parser.add_argument(
@@ -639,6 +639,12 @@ def main():
if args.dtype == "fp16":
parser.error("fp16 is not yet supported")

# Validate fpa4w quantization requires Metal backend
if args.qlinear == "fpa4w" and args.backend != "metal":
parser.error("--qlinear=fpa4w can only be used with --backend=metal")
if args.qlinear_encoder == "fpa4w" and args.backend != "metal":
parser.error("--qlinear_encoder=fpa4w can only be used with --backend=metal")

os.makedirs(args.output_dir, exist_ok=True)

print("Extracting tokenizer...")
37 changes: 35 additions & 2 deletions examples/models/parakeet/quantize.py
@@ -17,7 +17,7 @@ def quantize_model_( # noqa: C901

Args:
module: The PyTorch module to quantize.
qlinear_config: Quantization config for linear layers ("4w", "8w", "8da4w", "8da8w").
qlinear_config: Quantization config for linear layers ("4w", "8w", "8da4w", "8da8w", "fpa4w").
qlinear_group_size: Group size for linear quantization (default: 32).
qlinear_packing_format: Packing format for linear layers (e.g., "tile_packed_to_4d").
qembedding_config: Quantization config for embedding layers ("4w", "8w").
@@ -26,12 +26,45 @@
if not qlinear_config and not qembedding_config:
return

from torchao.quantization.quant_api import quantize_

# Metal (MPS) quantization uses different API
if qlinear_config == "fpa4w":
from torchao.experimental.quant_api import UIntxWeightOnlyConfig

# Load MPS ops
import torchao.experimental.ops.mps # noqa: F401

config = UIntxWeightOnlyConfig(
group_size=qlinear_group_size,
bitwidth=4,
)

# Filter: only quantize Linear layers with compatible dimensions
def linear_filter(m, fqn):
if isinstance(m, torch.nn.Linear):
if qlinear_group_size == 0:
return False
if m.weight.shape[1] % 8 != 0:
print(
f" Skipping {fqn}: weight shape {m.weight.shape} not multiple of 8"
)
return False
return True
return False

print(
f" Applying {qlinear_config} linear quantization "
f"(group_size={qlinear_group_size})..."
)
quantize_(module, config, filter_fn=linear_filter)
return

from torchao.quantization.granularity import PerAxis, PerGroup
from torchao.quantization.quant_api import (
Int4WeightOnlyConfig,
Int8DynamicActivationIntxWeightConfig,
IntxWeightOnlyConfig,
quantize_,
)

# Quantize embedding layers first
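
As a usage note, the new `fpa4w` branch would be exercised roughly as below. This is a hedged sketch: it assumes `quantize_model_`'s signature matches the argument names in its docstring, that `torchao`'s experimental MPS ops are available (macOS with Metal), and that the `quantize` import resolves relative to the parakeet example directory.

```python
import torch
from quantize import quantize_model_  # assumes running from examples/models/parakeet

# Toy module; in_features must be a multiple of 8 to pass the linear filter above
model = torch.nn.Sequential(torch.nn.Linear(64, 32))

quantize_model_(
    model,
    qlinear_config="fpa4w",   # Metal-only: float activations, 4-bit weights
    qlinear_group_size=32,
)
```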
2 changes: 1 addition & 1 deletion third-party/ao
Submodule ao updated 228 files