From 39db621505b31fe217eed504cf4419912d5ae2cc Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Fri, 30 Jan 2026 19:25:33 -0500 Subject: [PATCH 01/26] Update [ghstack-poisoned] --- .github/workflows/metal.yml | 24 + backends/apple/metal/tests/run_metal_test.sh | 126 +++ backends/apple/metal/tests/test_modules.py | 817 +++++++++++++++++++ 3 files changed, 967 insertions(+) create mode 100755 backends/apple/metal/tests/run_metal_test.sh create mode 100644 backends/apple/metal/tests/test_modules.py diff --git a/.github/workflows/metal.yml b/.github/workflows/metal.yml index 1e0ad2f9587..63466f36abb 100644 --- a/.github/workflows/metal.yml +++ b/.github/workflows/metal.yml @@ -28,6 +28,30 @@ jobs: PYTHON_EXECUTABLE=python CMAKE_ARGS="-DEXECUTORCH_BUILD_METAL=ON" ${CONDA_RUN} --no-capture-output ./install_executorch.sh echo "::endgroup::" + test-metal-modules: + name: test-metal-backend-modules + uses: pytorch/test-infra/.github/workflows/macos_job.yml@main + with: + runner: macos-m2-stable + python-version: '3.11' + submodules: 'recursive' + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + timeout: 120 + script: | + set -eux + + echo "::group::Setup ExecuTorch" + PYTHON_EXECUTABLE=python ${CONDA_RUN} ./install_executorch.sh + echo "::endgroup::" + + echo "::group::Build Metal Runtime" + ${CONDA_RUN} backends/apple/metal/tests/run_metal_test.sh --build + echo "::endgroup::" + + echo "::group::Run Metal Backend Module Tests" + ${CONDA_RUN} python -m unittest backends.apple.metal.tests.test_modules.TestMetalBackendModules + echo "::endgroup::" + export-model-metal-artifact: name: export-model-metal-artifact # Skip this job if the pull request is from a fork (HuggingFace secrets are not available) diff --git a/backends/apple/metal/tests/run_metal_test.sh b/backends/apple/metal/tests/run_metal_test.sh new file mode 100755 index 00000000000..95c0cb1c6a7 --- /dev/null +++ b/backends/apple/metal/tests/run_metal_test.sh @@ -0,0 +1,126 @@ +#!/bin/bash +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Script to build and run Metal backend tests +# Usage: +# ./run_metal_test.sh --build # Build the Metal runtime +# ./run_metal_test.sh --run # Run inference with given model files +# ./run_metal_test.sh --check-build # Check if runtime is already built + +set -e # Exit on any error + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +EXECUTORCH_ROOT="$(cd "$SCRIPT_DIR/../../../.." && pwd)" +BUILD_DIR="$EXECUTORCH_ROOT/cmake-out" +EXECUTOR_RUNNER="$BUILD_DIR/executor_runner" + +# Function to check if Metal runtime is built +check_build() { + if [[ -f "$EXECUTOR_RUNNER" ]]; then + echo "true" + return 0 + else + echo "false" + return 1 + fi +} + +# Function to build the Metal runtime +build_runtime() { + echo "Building Metal runtime..." 
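+    # Configures CMake with the Metal/AOTI options below, builds executor_runner
+    # with all available cores, and verifies the resulting binary exists.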
+ + # Check if we're on macOS + if [[ "$(uname)" != "Darwin" ]]; then + echo "Error: Metal backend is only supported on macOS" + exit 1 + fi + + # Create build directory + mkdir -p "$BUILD_DIR" + cd "$BUILD_DIR" + + # CMake configuration for Metal backend + CMAKE_ARGS="-DEXECUTORCH_BUILD_METAL=ON \ + -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \ + -DEXECUTORCH_BUILD_EXECUTOR_RUNNER=ON \ + -DEXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR=ON \ + -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \ + -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \ + -DEXECUTORCH_BUILD_EXTENSION_NAMED_DATA_MAP=ON \ + -DAOTI_METAL=ON \ + -DEXECUTORCH_LOG_LEVEL=Info \ + -DCMAKE_BUILD_TYPE=Release" + + echo "Running cmake..." + eval cmake $CMAKE_ARGS "$EXECUTORCH_ROOT" + + echo "Building..." + cmake --build . -j$(sysctl -n hw.ncpu) + + cd "$EXECUTORCH_ROOT" + + if [[ -f "$EXECUTOR_RUNNER" ]]; then + echo "Build successful: $EXECUTOR_RUNNER" + else + echo "Error: Build failed - executor_runner not found" + exit 1 + fi +} + +# Function to run inference +run_inference() { + local pte_path="$1" + local ptd_path="$2" + + if [[ ! -f "$EXECUTOR_RUNNER" ]]; then + echo "Error: executor_runner not found at $EXECUTOR_RUNNER" + echo "Run '$0 --build' first to build the Metal runtime" + exit 1 + fi + + if [[ ! -f "$pte_path" ]]; then + echo "Error: PTE file not found: $pte_path" + exit 1 + fi + + if [[ ! -f "$ptd_path" ]]; then + echo "Error: PTD file not found: $ptd_path" + exit 1 + fi + + echo "Running inference..." + echo " PTE: $pte_path" + echo " PTD: $ptd_path" + + "$EXECUTOR_RUNNER" --model_path "$pte_path" --data_path "$ptd_path" +} + +# Parse command line arguments +case "$1" in + --build) + build_runtime + ;; + --run) + if [[ -z "$2" ]] || [[ -z "$3" ]]; then + echo "Usage: $0 --run " + exit 1 + fi + run_inference "$2" "$3" + ;; + --check-build) + check_build + ;; + *) + echo "Metal Backend Test Runner" + echo "" + echo "Usage:" + echo " $0 --build Build the Metal runtime" + echo " $0 --run Run inference with given model files" + echo " $0 --check-build Check if runtime is already built" + exit 1 + ;; +esac diff --git a/backends/apple/metal/tests/test_modules.py b/backends/apple/metal/tests/test_modules.py new file mode 100644 index 00000000000..424d736b3b7 --- /dev/null +++ b/backends/apple/metal/tests/test_modules.py @@ -0,0 +1,817 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Unit tests for Metal backend modules. + +These tests export and run various model modules through the Metal backend +to verify that the export and execution pipeline works correctly. + +These tests require MPS to be available. On systems without MPS support, +the export tests will be skipped. 
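+
+Example (assuming the Metal runtime has already been built via
+backends/apple/metal/tests/run_metal_test.sh --build):
+
+    python -m unittest backends.apple.metal.tests.test_modules.TestMetalBackendModules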
+""" + +import os +import platform +import subprocess +import tempfile +import unittest +from pathlib import Path +from typing import Any, Dict, Optional, Tuple + +import numpy as np +import torch +from executorch.backends.apple.metal.metal_backend import MetalBackend +from executorch.backends.apple.metal.metal_partitioner import MetalPartitioner +from executorch.exir import to_edge_transform_and_lower +from torch import nn +from torch.export import export +from torch.nn.attention import SDPBackend + + +# Check if MPS is available for export tests +MPS_AVAILABLE = torch.backends.mps.is_available() +IS_MACOS = platform.system() == "Darwin" +SKIP_EXPORT_TESTS = not MPS_AVAILABLE +SKIP_REASON = "MPS not available - Metal export tests require MPS support" + +# Paths +TESTS_DIR = Path(__file__).parent +EXECUTORCH_ROOT = TESTS_DIR.parent.parent.parent.parent +BUILD_DIR = EXECUTORCH_ROOT / "cmake-out" +EXECUTOR_RUNNER = BUILD_DIR / "executor_runner" +RUN_METAL_TEST_SCRIPT = TESTS_DIR / "run_metal_test.sh" + +# Check if executor_runner is built +EXECUTOR_RUNNER_AVAILABLE = EXECUTOR_RUNNER.exists() +SKIP_RUNTIME_TESTS = not EXECUTOR_RUNNER_AVAILABLE or SKIP_EXPORT_TESTS +SKIP_RUNTIME_REASON = ( + "executor_runner not built - run 'backends/apple/metal/tests/run_metal_test.sh --build'" + if not EXECUTOR_RUNNER_AVAILABLE + else SKIP_REASON +) + +# Data types to test +DTYPES = [torch.float32, torch.bfloat16] + +# Map dtype to short name for test method naming +DTYPE_NAMES = { + torch.float32: "float32", + torch.bfloat16: "bfloat16", +} + +# Registry mapping model names to their configurations +MODULE_REGISTRY: Dict[str, Dict[str, Any]] = {} + + +# ============================================================================= +# Model Definitions +# ============================================================================= + + +class Add(nn.Module): + def forward(self, x: torch.Tensor, y: torch.Tensor): + return x + y + +MODULE_REGISTRY["add"] = { + "model_class": Add, + "input_shapes": [(10,), (10,)], + "description": "Simple tensor addition model", +} + + +# ------------------------------------------------------------------------- +# Matrix Multiplication Modules +# ------------------------------------------------------------------------- + +class Mm(nn.Module): + def forward(self, x: torch.Tensor, y: torch.Tensor): + return x.mm(y) + +MODULE_REGISTRY["mm"] = { + "model_class": Mm, + "input_shapes": [(3, 4), (4, 5)], + "description": "Simple mm layer model", +} + +# ------------------------------------------------------------------------- +class MmWeights(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.arange(20, dtype=torch.float).reshape(4, 5)) + + def forward(self, x: torch.Tensor): + return x.mm(self.weight) + +MODULE_REGISTRY["mm_weights"] = { + "model_class": MmWeights, + "input_shapes": [(3, 4)], + "description": "Matrix multiplication with weight parameter", +} + +# ------------------------------------------------------------------------- +class TwoMm(nn.Module): + def __init__(self): + super().__init__() + self.left_weight = nn.Parameter( + torch.arange(20, dtype=torch.float).reshape(4, 5) + ) + self.right_weight = nn.Parameter( + torch.arange(42, dtype=torch.float).reshape(6, 7) + ) + + def forward(self, x: torch.Tensor): + return self.left_weight.mm(x).mm(self.right_weight) + +MODULE_REGISTRY["two_mm"] = { + "model_class": TwoMm, + "input_shapes": [(5, 6)], + "description": "Two consecutive matrix multiplications", +} + +# 
------------------------------------------------------------------------- +class ElementwiseMmReduction(nn.Module): + def forward(self, x: torch.Tensor, y: torch.Tensor): + x1 = x.sin() + x + y2 = y.cos() + 3 + z = x1.mm(y2) + return z + z.sum() + +MODULE_REGISTRY["elementwise_mm_reduction"] = { + "model_class": ElementwiseMmReduction, + "input_shapes": [(11, 45), (45, 8)], + "description": "Combining mm with elementwise and reduction ops", +} + + +# ------------------------------------------------------------------------- +# Linear Modules +# ------------------------------------------------------------------------- + +class LinearNoBias(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(7, 101, bias=False) + + def forward(self, x: torch.Tensor): + return self.linear(x) + +MODULE_REGISTRY["linear_nobias"] = { + "model_class": LinearNoBias, + "input_shapes": [(127, 7)], + "description": "Simple linear layer model with no bias", +} + + +# ------------------------------------------------------------------------- +# Convolution Modules +# ------------------------------------------------------------------------- + +class SingleConv2d(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Conv2d( + in_channels=3, out_channels=5, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x: torch.Tensor): + return self.conv(x) + +MODULE_REGISTRY["conv2d"] = { + "model_class": SingleConv2d, + "input_shapes": [(4, 3, 8, 8)], + "description": "Single Conv2d layer model", +} + +# ------------------------------------------------------------------------- +class DepthwiseConv(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Conv2d( + in_channels=32, + out_channels=32, + kernel_size=3, + stride=1, + padding=1, + dilation=1, + groups=32, + bias=False, + ) + + def forward(self, x): + return self.conv(x) + +MODULE_REGISTRY["depthwise_conv"] = { + "model_class": DepthwiseConv, + "input_shapes": [(1, 32, 112, 112)], + "description": "Single Depthwise Conv2d layer model", +} + +# ------------------------------------------------------------------------- +class SmallConv1d(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Conv1d( + in_channels=8, + out_channels=6, + kernel_size=3, + stride=1, + padding=1, + dilation=1, + groups=1, + bias=False, + ) + + def forward(self, x): + return self.conv(x) + +MODULE_REGISTRY["small_conv1d"] = { + "model_class": SmallConv1d, + "input_shapes": [(1, 8, 5)], + "description": "Conv1d layer with 8 input channels, 6 output channels", +} + +# ------------------------------------------------------------------------- +class MockConv1d(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Conv1d( + in_channels=80, + out_channels=384, + kernel_size=3, + stride=1, + padding=1, + dilation=1, + groups=1, + bias=True, + ) + + def forward(self, x): + return self.conv(x) + +MODULE_REGISTRY["conv1d"] = { + "model_class": MockConv1d, + "input_shapes": [(1, 80, 3000)], + "description": "Conv1d layer with 80 input channels, 384 output channels", +} + +# ------------------------------------------------------------------------- +class VoxtralConv1d(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Conv1d( + in_channels=128, + out_channels=1280, + kernel_size=3, + stride=1, + padding=1, + dilation=1, + groups=1, + bias=False, + ) + + def forward(self, x): + return self.conv(x) + +MODULE_REGISTRY["voxtral_conv1d"] = { + "model_class": VoxtralConv1d, + "input_shapes": 
[(10, 128, 3000)], + "description": "Conv1d layer with 128 input channels, 1280 output channels", +} + + +# ------------------------------------------------------------------------- +# Attention (SDPA) Modules +# ------------------------------------------------------------------------- + +class SimpleSDPA(nn.Module): + """Minimal SDPA test model.""" + + def forward( + self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor + ) -> torch.Tensor: + output = torch.nn.functional.scaled_dot_product_attention( + query, key, value, dropout_p=0.0, is_causal=False + ) + return output + +MODULE_REGISTRY["sdpa"] = { + "model_class": SimpleSDPA, + "input_shapes": [(2, 4, 16, 64), (2, 4, 16, 64), (2, 4, 16, 64)], + "description": "Simple Scaled Dot Product Attention model", +} + +# ------------------------------------------------------------------------- +class AddSDPA(nn.Module): + """SDPA model with Q, K, V as parameters that adds input to SDPA output.""" + + def __init__(self, batch_size=2, num_heads=4, seq_len=16, head_dim=64): + super().__init__() + self.query = nn.Parameter( + torch.randn(batch_size, num_heads, seq_len, head_dim) + ) + self.key = nn.Parameter(torch.randn(batch_size, num_heads, seq_len, head_dim)) + self.value = nn.Parameter( + torch.randn(batch_size, num_heads, seq_len, head_dim) + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + sdpa_output = torch.nn.functional.scaled_dot_product_attention( + self.query, self.key, self.value, dropout_p=0.0, is_causal=False + ) + return sdpa_output + x + +MODULE_REGISTRY["add_sdpa"] = { + "model_class": AddSDPA, + "input_shapes": [(2, 4, 16, 64)], + "description": "SDPA model with Q,K,V as parameters that adds input to output", +} + +# ------------------------------------------------------------------------- +class BaseAddStridedSDPA(nn.Module): + """SDPA model with strided Q, K, V parameters.""" + + def __init__(self, q_size, k_size, v_size, q_stride, k_stride, v_stride, attn_mask_size=None): + super().__init__() + self.q_size = q_size + self.k_size = k_size + self.v_size = v_size + self.q_stride = q_stride + self.k_stride = k_stride + self.v_stride = v_stride + self.attn_mask_size = attn_mask_size + + self.query = nn.Parameter(torch.randn(q_size)) + self.key = nn.Parameter(torch.randn(k_size)) + self.value = nn.Parameter(torch.randn(v_size)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + query = torch.as_strided(self.query, size=self.q_size, stride=self.q_stride) + key = torch.as_strided(self.key, size=self.k_size, stride=self.k_stride) + value = torch.as_strided(self.value, size=self.v_size, stride=self.v_stride) + attn_mask = None + if self.attn_mask_size: + attn_mask = torch.zeros(self.attn_mask_size) + + sdpa_output = torch.nn.functional.scaled_dot_product_attention( + query, key, value, attn_mask, dropout_p=0.0, is_causal=False, scale=1.0 + ) + return sdpa_output + x + +# ------------------------------------------------------------------------- +class AddStridedSDPA(BaseAddStridedSDPA): + def __init__(self): + super().__init__( + q_size=(10, 20, 1500, 64), + k_size=(10, 20, 1500, 64), + v_size=(10, 20, 1500, 64), + q_stride=(1920000, 64, 1280, 1), + k_stride=(1920000, 64, 1280, 1), + v_stride=(1920000, 64, 1280, 1), + ) + +MODULE_REGISTRY["audio_encoder_sdpa1"] = { + "model_class": AddStridedSDPA, + "input_shapes": [(10, 20, 1500, 64)], + "description": "Audio Encoder model with strided SDPA", +} + +# ------------------------------------------------------------------------- +class 
AddStridedSDPA1(BaseAddStridedSDPA): + def __init__(self): + super().__init__( + q_size=(1, 20, 1, 64), + k_size=(1, 20, 1500, 64), + v_size=(1, 20, 1500, 64), + q_stride=(1280, 64, 1280, 1), + k_stride=(1920000, 64, 1280, 1), + v_stride=(1920000, 64, 1280, 1), + ) + +MODULE_REGISTRY["whisper_strided_sdpa1"] = { + "model_class": AddStridedSDPA1, + "input_shapes": [(1, 20, 1, 64)], + "description": "Whisper-like strided SDPA variant 1", +} + +# ------------------------------------------------------------------------- +class AddStridedSDPA2(BaseAddStridedSDPA): + def __init__(self): + super().__init__( + q_size=(1, 20, 1, 64), + k_size=(1, 20, 1024, 64), + v_size=(1, 20, 1024, 64), + q_stride=(1280, 64, 1280, 1), + k_stride=(1310720, 65536, 64, 1), + v_stride=(1310720, 65536, 64, 1), + attn_mask_size=(1, 1, 1, 1024), + ) + +MODULE_REGISTRY["whisper_strided_sdpa2"] = { + "model_class": AddStridedSDPA2, + "input_shapes": [(1, 20, 1, 64)], + "description": "Whisper-like strided SDPA variant 2", +} + + +# ------------------------------------------------------------------------- +# Normalization Modules +# ------------------------------------------------------------------------- + +class BatchNorm(nn.Module): + def __init__(self): + super().__init__() + self.bn = nn.BatchNorm2d(num_features=16) + + def forward(self, x): + return self.bn(x) + +MODULE_REGISTRY["batchnorm"] = { + "model_class": BatchNorm, + "input_shapes": [(1, 16, 32, 32)], + "description": "Single BatchNorm2d layer model", +} + + +# ------------------------------------------------------------------------- +# Block/Composite Modules +# ------------------------------------------------------------------------- + +class SingleResNetBlock(nn.Module): + def __init__(self, in_channels=64, out_channels=64, stride=1): + super().__init__() + self.conv1 = nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=stride, + padding=1, + bias=False, + ) + self.bn1 = nn.BatchNorm2d(out_channels) + self.relu = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False + ) + self.bn2 = nn.BatchNorm2d(out_channels) + + self.skip_connection = None + if stride != 1 or in_channels != out_channels: + self.skip_connection = nn.Sequential( + nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=stride, bias=False + ), + nn.BatchNorm2d(out_channels), + ) + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.skip_connection is not None: + identity = self.skip_connection(x) + + out += identity + out = self.relu(out) + + return out + +MODULE_REGISTRY["single_resnet_block"] = { + "model_class": SingleResNetBlock, + "input_shapes": [(1, 64, 8, 8)], + "description": "Single ResNet block with skip connection", +} + + +# ============================================================================= +# Helper Functions +# ============================================================================= + + +def get_model_and_inputs( + model_name: str, dtype: torch.dtype = torch.float32 +) -> Tuple[nn.Module, Tuple[torch.Tensor, ...]]: + """Get model and example inputs based on model name.""" + if model_name not in MODULE_REGISTRY: + available_models = ", ".join(MODULE_REGISTRY.keys()) + raise ValueError( + f"Unsupported model: {model_name}. 
Available models: {available_models}" + ) + + model_config = MODULE_REGISTRY[model_name] + model_class = model_config["model_class"] + input_shapes = model_config["input_shapes"] + + model = model_class().eval() + if dtype is not None: + model = model.to(dtype) + + example_inputs = tuple( + torch.randn(*shape, dtype=dtype) for shape in input_shapes + ) + + return model, example_inputs + + +def export_model_to_metal( + model: nn.Module, example_inputs: Tuple[torch.Tensor, ...] +) -> Any: + """Export model through the Metal backend pipeline.""" + method_name = "forward" + + with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad(): + aten_dialect = export(model, example_inputs, strict=False) + + edge_program = to_edge_transform_and_lower( + aten_dialect, + partitioner=[ + MetalPartitioner( + [MetalBackend.generate_method_name_compile_spec(method_name)] + ) + ], + ) + + executorch_program = edge_program.to_executorch() + return executorch_program + + +def export_model_to_files( + model: nn.Module, + example_inputs: Tuple[torch.Tensor, ...], + output_dir: Path, + model_name: str, +) -> Tuple[Path, Path, torch.Tensor]: + """ + Export model to .pte and .ptd files, and compute expected output. + + Returns: + Tuple of (pte_path, ptd_path, expected_output) + """ + # Compute expected output using all-ones input (matching export_aoti_metal.py) + all_ones_input = tuple(torch.ones_like(inp) for inp in example_inputs) + with torch.no_grad(): + expected_output = model(*all_ones_input) + + # Export to executorch + executorch_program = export_model_to_metal(model, example_inputs) + + # Save .pte file + pte_path = output_dir / f"{model_name}.pte" + with open(pte_path, "wb") as f: + f.write(executorch_program.buffer) + + # Save .ptd file (tensor data) + executorch_program.write_tensor_data_to_file(str(output_dir)) + ptd_path = output_dir / "aoti_metal_blob.ptd" + + return pte_path, ptd_path, expected_output + + +def run_executor_runner(pte_path: Path, ptd_path: Path) -> bool: + """ + Run the executor_runner binary with the given model files. + + Returns: + True if execution succeeded, False otherwise. + """ + if not EXECUTOR_RUNNER.exists(): + raise RuntimeError( + f"executor_runner not found at {EXECUTOR_RUNNER}. " + f"Run '{RUN_METAL_TEST_SCRIPT} --build' to build." + ) + + cmd = [ + str(EXECUTOR_RUNNER), + "--model_path", str(pte_path), + "--data_path", str(ptd_path), + ] + + try: + result = subprocess.run( + cmd, + capture_output=True, + text=True, + timeout=60, + cwd=str(EXECUTORCH_ROOT), + ) + return result.returncode == 0 + except subprocess.TimeoutExpired: + return False + except Exception: + return False + + +def read_output_file(filepath: Path) -> Optional[np.ndarray]: + """Read comma-separated output values from a file.""" + try: + with open(filepath, "r") as f: + content = f.read().strip() + if not content: + return None + values = [float(x.strip()) for x in content.split(",") if x.strip()] + return np.array(values) + except (FileNotFoundError, ValueError): + return None + + +def compare_outputs( + expected: torch.Tensor, + runtime_output_file: Path, + atol: float = 1e-5, + rtol: float = 1e-5, +) -> Tuple[bool, Optional[float], Optional[float]]: + """ + Compare expected PyTorch output with runtime output from file. 
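+
+    The relative error reported per element is |runtime - expected| divided by
+    max(|runtime|, |expected|, 1e-8), matching the computation below.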
+ + Returns: + Tuple of (is_close, max_atol, max_rtol) + """ + runtime_values = read_output_file(runtime_output_file) + if runtime_values is None: + return False, None, None + + # Flatten expected output + if isinstance(expected, tuple): + expected_values = np.concatenate([t.flatten().numpy() for t in expected]) + else: + expected_values = expected.flatten().numpy() + + if len(runtime_values) != len(expected_values): + return False, None, None + + # Calculate tolerances + abs_diff = np.abs(runtime_values - expected_values) + max_atol_val = np.max(abs_diff) + + eps = 1e-8 + denominator = np.maximum( + np.maximum(np.abs(runtime_values), np.abs(expected_values)), eps + ) + rel_diff = abs_diff / denominator + max_rtol_val = np.max(rel_diff) + + is_close = np.allclose(runtime_values, expected_values, atol=atol, rtol=rtol) + + return is_close, max_atol_val, max_rtol_val + + +# ============================================================================= +# Test Class +# ============================================================================= + + +class TestMetalBackendModules(unittest.TestCase): + """ + Test Metal backend modules export and execution. + + Each test exports a model through the Metal backend and verifies: + 1. The export process completes without errors + 2. The exported program has non-zero buffer size + 3. The runtime output matches the expected PyTorch output + """ + + def _test_module_export( + self, model_name: str, dtype: torch.dtype = torch.float32 + ) -> None: + """Generic test for module export.""" + if SKIP_EXPORT_TESTS: + self.skipTest(SKIP_REASON) + + model, example_inputs = get_model_and_inputs(model_name, dtype=dtype) + + # Verify model forward pass works before export + with torch.no_grad(): + model_output = model(*example_inputs) + + self.assertIsNotNone( + model_output, + f"{model_name} ({DTYPE_NAMES[dtype]}): Forward pass returned None", + ) + + # Export to Metal backend + executorch_program = export_model_to_metal(model, example_inputs) + + self.assertIsNotNone( + executorch_program, + f"{model_name} ({DTYPE_NAMES[dtype]}): Export returned None", + ) + self.assertGreater( + len(executorch_program.buffer), + 0, + f"{model_name} ({DTYPE_NAMES[dtype]}): Exported buffer is empty", + ) + + def _test_module_output_consistency( + self, model_name: str, dtype: torch.dtype = torch.float32 + ) -> None: + """ + Test that Metal backend runtime output matches PyTorch output. + + This test: + 1. Exports the model to .pte and .ptd files + 2. Runs the model using executor_runner + 3. 
Compares the runtime output with expected PyTorch output + """ + if SKIP_RUNTIME_TESTS: + self.skipTest(SKIP_RUNTIME_REASON) + + model, example_inputs = get_model_and_inputs(model_name, dtype=dtype) + + with tempfile.TemporaryDirectory() as tmpdir: + tmpdir_path = Path(tmpdir) + + # Create aoti_debug_data directory for output files + debug_dir = tmpdir_path / "aoti_debug_data" + debug_dir.mkdir(exist_ok=True) + + # Export model and get expected output + pte_path, ptd_path, expected_output = export_model_to_files( + model, example_inputs, tmpdir_path, model_name + ) + + self.assertTrue( + pte_path.exists(), + f"{model_name}: PTE file not created at {pte_path}", + ) + self.assertTrue( + ptd_path.exists(), + f"{model_name}: PTD file not created at {ptd_path}", + ) + + # Run executor_runner + success = run_executor_runner(pte_path, ptd_path) + self.assertTrue( + success, + f"{model_name}: executor_runner failed", + ) + + # Compare outputs + runtime_output_file = debug_dir / "final_runtime_output.txt" + + if runtime_output_file.exists(): + is_close, max_atol, max_rtol = compare_outputs( + expected_output, runtime_output_file + ) + + self.assertTrue( + is_close, + f"{model_name} ({DTYPE_NAMES[dtype]}): Output mismatch - max_atol={max_atol}, max_rtol={max_rtol}", + ) + + +# ============================================================================= +# Dynamically generate test methods for each module and dtype in MODULE_REGISTRY +# ============================================================================= + + +def _make_export_test(model_name: str, dtype: torch.dtype): + """Factory function to create an export test method for a given model and dtype.""" + def test_method(self): + self._test_module_export(model_name, dtype) + dtype_name = DTYPE_NAMES[dtype] + test_method.__doc__ = f"Test {model_name} module export with {dtype_name}." + return test_method + + +def _make_output_consistency_test(model_name: str, dtype: torch.dtype): + """Factory function to create an output consistency test method for a given model and dtype.""" + def test_method(self): + self._test_module_output_consistency(model_name, dtype) + dtype_name = DTYPE_NAMES[dtype] + test_method.__doc__ = f"Test {model_name} module output consistency with {dtype_name}." 
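+    # The docstring makes each generated test case self-describing in unittest output.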
+ return test_method + + +# Add export and output consistency tests for each module and dtype in the registry +for _model_name in MODULE_REGISTRY: + for _dtype in DTYPES: + _dtype_name = DTYPE_NAMES[_dtype] + + # Create export test: test___export + _export_test_name = f"test_{_model_name}_{_dtype_name}_export" + setattr( + TestMetalBackendModules, + _export_test_name, + _make_export_test(_model_name, _dtype), + ) + + # Create output consistency test: test___output_consistency + _consistency_test_name = f"test_{_model_name}_{_dtype_name}_output_consistency" + setattr( + TestMetalBackendModules, + _consistency_test_name, + _make_output_consistency_test(_model_name, _dtype), + ) + + +if __name__ == "__main__": + unittest.main() From 0ed7c5c778087f16c3b06ab9463964f0b1e5287f Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Fri, 30 Jan 2026 21:59:19 -0500 Subject: [PATCH 02/26] Update [ghstack-poisoned] --- backends/apple/metal/tests/run_metal_test.sh | 2 +- backends/apple/metal/tests/test_modules.py | 166 ++++++++++++++----- 2 files changed, 121 insertions(+), 47 deletions(-) diff --git a/backends/apple/metal/tests/run_metal_test.sh b/backends/apple/metal/tests/run_metal_test.sh index 95c0cb1c6a7..9595cbf0c3d 100755 --- a/backends/apple/metal/tests/run_metal_test.sh +++ b/backends/apple/metal/tests/run_metal_test.sh @@ -56,7 +56,7 @@ build_runtime() { -DCMAKE_BUILD_TYPE=Release" echo "Running cmake..." - eval cmake $CMAKE_ARGS "$EXECUTORCH_ROOT" + cmake $CMAKE_ARGS "$EXECUTORCH_ROOT" echo "Building..." cmake --build . -j$(sysctl -n hw.ncpu) diff --git a/backends/apple/metal/tests/test_modules.py b/backends/apple/metal/tests/test_modules.py index 424d736b3b7..c97298a6bc2 100644 --- a/backends/apple/metal/tests/test_modules.py +++ b/backends/apple/metal/tests/test_modules.py @@ -38,6 +38,9 @@ SKIP_EXPORT_TESTS = not MPS_AVAILABLE SKIP_REASON = "MPS not available - Metal export tests require MPS support" +# Check if running in CI (GitHub Actions) +IS_CI = os.environ.get("GITHUB_ACTIONS") == "true" + # Paths TESTS_DIR = Path(__file__).parent EXECUTORCH_ROOT = TESTS_DIR.parent.parent.parent.parent @@ -45,6 +48,12 @@ EXECUTOR_RUNNER = BUILD_DIR / "executor_runner" RUN_METAL_TEST_SCRIPT = TESTS_DIR / "run_metal_test.sh" +# Test output directory - use current working directory in CI for reliable write access +if IS_CI: + TEST_OUTPUT_BASE_DIR = Path.cwd() / "aoti_debug_data" +else: + TEST_OUTPUT_BASE_DIR = None # Will use tempfile.TemporaryDirectory + # Check if executor_runner is built EXECUTOR_RUNNER_AVAILABLE = EXECUTOR_RUNNER.exists() SKIP_RUNTIME_TESTS = not EXECUTOR_RUNNER_AVAILABLE or SKIP_EXPORT_TESTS @@ -76,6 +85,7 @@ class Add(nn.Module): def forward(self, x: torch.Tensor, y: torch.Tensor): return x + y + MODULE_REGISTRY["add"] = { "model_class": Add, "input_shapes": [(10,), (10,)], @@ -87,16 +97,19 @@ def forward(self, x: torch.Tensor, y: torch.Tensor): # Matrix Multiplication Modules # ------------------------------------------------------------------------- + class Mm(nn.Module): def forward(self, x: torch.Tensor, y: torch.Tensor): return x.mm(y) + MODULE_REGISTRY["mm"] = { "model_class": Mm, "input_shapes": [(3, 4), (4, 5)], "description": "Simple mm layer model", } + # ------------------------------------------------------------------------- class MmWeights(nn.Module): def __init__(self): @@ -106,12 +119,14 @@ def __init__(self): def forward(self, x: torch.Tensor): return x.mm(self.weight) + MODULE_REGISTRY["mm_weights"] = { "model_class": MmWeights, "input_shapes": [(3, 4)], 
"description": "Matrix multiplication with weight parameter", } + # ------------------------------------------------------------------------- class TwoMm(nn.Module): def __init__(self): @@ -126,12 +141,14 @@ def __init__(self): def forward(self, x: torch.Tensor): return self.left_weight.mm(x).mm(self.right_weight) + MODULE_REGISTRY["two_mm"] = { "model_class": TwoMm, "input_shapes": [(5, 6)], "description": "Two consecutive matrix multiplications", } + # ------------------------------------------------------------------------- class ElementwiseMmReduction(nn.Module): def forward(self, x: torch.Tensor, y: torch.Tensor): @@ -140,6 +157,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor): z = x1.mm(y2) return z + z.sum() + MODULE_REGISTRY["elementwise_mm_reduction"] = { "model_class": ElementwiseMmReduction, "input_shapes": [(11, 45), (45, 8)], @@ -151,6 +169,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor): # Linear Modules # ------------------------------------------------------------------------- + class LinearNoBias(nn.Module): def __init__(self): super().__init__() @@ -159,6 +178,7 @@ def __init__(self): def forward(self, x: torch.Tensor): return self.linear(x) + MODULE_REGISTRY["linear_nobias"] = { "model_class": LinearNoBias, "input_shapes": [(127, 7)], @@ -170,6 +190,7 @@ def forward(self, x: torch.Tensor): # Convolution Modules # ------------------------------------------------------------------------- + class SingleConv2d(nn.Module): def __init__(self): super().__init__() @@ -180,12 +201,14 @@ def __init__(self): def forward(self, x: torch.Tensor): return self.conv(x) + MODULE_REGISTRY["conv2d"] = { "model_class": SingleConv2d, "input_shapes": [(4, 3, 8, 8)], "description": "Single Conv2d layer model", } + # ------------------------------------------------------------------------- class DepthwiseConv(nn.Module): def __init__(self): @@ -204,12 +227,14 @@ def __init__(self): def forward(self, x): return self.conv(x) + MODULE_REGISTRY["depthwise_conv"] = { "model_class": DepthwiseConv, "input_shapes": [(1, 32, 112, 112)], "description": "Single Depthwise Conv2d layer model", } + # ------------------------------------------------------------------------- class SmallConv1d(nn.Module): def __init__(self): @@ -228,14 +253,16 @@ def __init__(self): def forward(self, x): return self.conv(x) + MODULE_REGISTRY["small_conv1d"] = { "model_class": SmallConv1d, "input_shapes": [(1, 8, 5)], "description": "Conv1d layer with 8 input channels, 6 output channels", } + # ------------------------------------------------------------------------- -class MockConv1d(nn.Module): +class MediumConv1d(nn.Module): def __init__(self): super().__init__() self.conv = nn.Conv1d( @@ -252,12 +279,14 @@ def __init__(self): def forward(self, x): return self.conv(x) + MODULE_REGISTRY["conv1d"] = { - "model_class": MockConv1d, + "model_class": MediumConv1d, "input_shapes": [(1, 80, 3000)], "description": "Conv1d layer with 80 input channels, 384 output channels", } + # ------------------------------------------------------------------------- class VoxtralConv1d(nn.Module): def __init__(self): @@ -276,6 +305,7 @@ def __init__(self): def forward(self, x): return self.conv(x) + MODULE_REGISTRY["voxtral_conv1d"] = { "model_class": VoxtralConv1d, "input_shapes": [(10, 128, 3000)], @@ -287,6 +317,7 @@ def forward(self, x): # Attention (SDPA) Modules # ------------------------------------------------------------------------- + class SimpleSDPA(nn.Module): """Minimal SDPA test model.""" @@ -298,25 +329,23 @@ def 
forward( ) return output + MODULE_REGISTRY["sdpa"] = { "model_class": SimpleSDPA, "input_shapes": [(2, 4, 16, 64), (2, 4, 16, 64), (2, 4, 16, 64)], "description": "Simple Scaled Dot Product Attention model", } + # ------------------------------------------------------------------------- class AddSDPA(nn.Module): """SDPA model with Q, K, V as parameters that adds input to SDPA output.""" def __init__(self, batch_size=2, num_heads=4, seq_len=16, head_dim=64): super().__init__() - self.query = nn.Parameter( - torch.randn(batch_size, num_heads, seq_len, head_dim) - ) + self.query = nn.Parameter(torch.randn(batch_size, num_heads, seq_len, head_dim)) self.key = nn.Parameter(torch.randn(batch_size, num_heads, seq_len, head_dim)) - self.value = nn.Parameter( - torch.randn(batch_size, num_heads, seq_len, head_dim) - ) + self.value = nn.Parameter(torch.randn(batch_size, num_heads, seq_len, head_dim)) def forward(self, x: torch.Tensor) -> torch.Tensor: sdpa_output = torch.nn.functional.scaled_dot_product_attention( @@ -324,17 +353,21 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ) return sdpa_output + x + MODULE_REGISTRY["add_sdpa"] = { "model_class": AddSDPA, "input_shapes": [(2, 4, 16, 64)], "description": "SDPA model with Q,K,V as parameters that adds input to output", } + # ------------------------------------------------------------------------- class BaseAddStridedSDPA(nn.Module): """SDPA model with strided Q, K, V parameters.""" - def __init__(self, q_size, k_size, v_size, q_stride, k_stride, v_stride, attn_mask_size=None): + def __init__( + self, q_size, k_size, v_size, q_stride, k_stride, v_stride, attn_mask_size=None + ): super().__init__() self.q_size = q_size self.k_size = k_size @@ -361,6 +394,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ) return sdpa_output + x + # ------------------------------------------------------------------------- class AddStridedSDPA(BaseAddStridedSDPA): def __init__(self): @@ -373,12 +407,14 @@ def __init__(self): v_stride=(1920000, 64, 1280, 1), ) + MODULE_REGISTRY["audio_encoder_sdpa1"] = { "model_class": AddStridedSDPA, "input_shapes": [(10, 20, 1500, 64)], "description": "Audio Encoder model with strided SDPA", } + # ------------------------------------------------------------------------- class AddStridedSDPA1(BaseAddStridedSDPA): def __init__(self): @@ -391,12 +427,14 @@ def __init__(self): v_stride=(1920000, 64, 1280, 1), ) + MODULE_REGISTRY["whisper_strided_sdpa1"] = { "model_class": AddStridedSDPA1, "input_shapes": [(1, 20, 1, 64)], "description": "Whisper-like strided SDPA variant 1", } + # ------------------------------------------------------------------------- class AddStridedSDPA2(BaseAddStridedSDPA): def __init__(self): @@ -410,6 +448,7 @@ def __init__(self): attn_mask_size=(1, 1, 1, 1024), ) + MODULE_REGISTRY["whisper_strided_sdpa2"] = { "model_class": AddStridedSDPA2, "input_shapes": [(1, 20, 1, 64)], @@ -421,6 +460,7 @@ def __init__(self): # Normalization Modules # ------------------------------------------------------------------------- + class BatchNorm(nn.Module): def __init__(self): super().__init__() @@ -429,6 +469,7 @@ def __init__(self): def forward(self, x): return self.bn(x) + MODULE_REGISTRY["batchnorm"] = { "model_class": BatchNorm, "input_shapes": [(1, 16, 32, 32)], @@ -440,6 +481,7 @@ def forward(self, x): # Block/Composite Modules # ------------------------------------------------------------------------- + class SingleResNetBlock(nn.Module): def __init__(self, in_channels=64, out_channels=64, stride=1): 
super().__init__() @@ -485,6 +527,7 @@ def forward(self, x): return out + MODULE_REGISTRY["single_resnet_block"] = { "model_class": SingleResNetBlock, "input_shapes": [(1, 64, 8, 8)], @@ -515,9 +558,7 @@ def get_model_and_inputs( if dtype is not None: model = model.to(dtype) - example_inputs = tuple( - torch.randn(*shape, dtype=dtype) for shape in input_shapes - ) + example_inputs = tuple(torch.randn(*shape, dtype=dtype) for shape in input_shapes) return model, example_inputs @@ -576,12 +617,13 @@ def export_model_to_files( return pte_path, ptd_path, expected_output -def run_executor_runner(pte_path: Path, ptd_path: Path) -> bool: +def run_executor_runner(pte_path: Path, ptd_path: Path) -> Tuple[bool, Optional[str]]: """ Run the executor_runner binary with the given model files. Returns: - True if execution succeeded, False otherwise. + Tuple of (success, error_message). If success is True, error_message is None. + If success is False, error_message contains details about the failure. """ if not EXECUTOR_RUNNER.exists(): raise RuntimeError( @@ -591,8 +633,10 @@ def run_executor_runner(pte_path: Path, ptd_path: Path) -> bool: cmd = [ str(EXECUTOR_RUNNER), - "--model_path", str(pte_path), - "--data_path", str(ptd_path), + "--model_path", + str(pte_path), + "--data_path", + str(ptd_path), ] try: @@ -603,11 +647,17 @@ def run_executor_runner(pte_path: Path, ptd_path: Path) -> bool: timeout=60, cwd=str(EXECUTORCH_ROOT), ) - return result.returncode == 0 - except subprocess.TimeoutExpired: - return False - except Exception: - return False + if result.returncode == 0: + return True, None + else: + error_msg = ( + f"executor_runner exited with code {result.returncode}\n" + f"stdout: {result.stdout}\n" + f"stderr: {result.stderr}" + ) + return False, error_msg + except subprocess.TimeoutExpired as e: + return False, f"executor_runner timed out after 60 seconds: {e}" def read_output_file(filepath: Path) -> Optional[np.ndarray]: @@ -639,11 +689,14 @@ def compare_outputs( if runtime_values is None: return False, None, None - # Flatten expected output + # Flatten expected output and move to CPU for numpy conversion + # (required when tensor is on MPS device) if isinstance(expected, tuple): - expected_values = np.concatenate([t.flatten().numpy() for t in expected]) + expected_values = np.concatenate( + [t.detach().cpu().flatten().numpy() for t in expected] + ) else: - expected_values = expected.flatten().numpy() + expected_values = expected.detach().cpu().flatten().numpy() if len(runtime_values) != len(expected_values): return False, None, None @@ -725,47 +778,62 @@ def _test_module_output_consistency( self.skipTest(SKIP_RUNTIME_REASON) model, example_inputs = get_model_and_inputs(model_name, dtype=dtype) + dtype_name = DTYPE_NAMES[dtype] + test_subdir_name = f"{model_name}_{dtype_name}" - with tempfile.TemporaryDirectory() as tmpdir: - tmpdir_path = Path(tmpdir) - - # Create aoti_debug_data directory for output files - debug_dir = tmpdir_path / "aoti_debug_data" - debug_dir.mkdir(exist_ok=True) + def run_test_in_directory(test_dir: Path) -> None: + """Run the actual test logic in the given directory.""" + # Create model output directory: aoti_debug_data/_/ + model_output_dir = test_dir / test_subdir_name + model_output_dir.mkdir(parents=True, exist_ok=True) # Export model and get expected output pte_path, ptd_path, expected_output = export_model_to_files( - model, example_inputs, tmpdir_path, model_name + model, example_inputs, model_output_dir, model_name ) self.assertTrue( pte_path.exists(), - 
f"{model_name}: PTE file not created at {pte_path}", + f"{model_name} ({dtype_name}): PTE file not created at {pte_path}", ) self.assertTrue( ptd_path.exists(), - f"{model_name}: PTD file not created at {ptd_path}", + f"{model_name} ({dtype_name}): PTD file not created at {ptd_path}", ) # Run executor_runner - success = run_executor_runner(pte_path, ptd_path) + success, error_msg = run_executor_runner(pte_path, ptd_path) self.assertTrue( success, - f"{model_name}: executor_runner failed", + f"{model_name} ({dtype_name}): executor_runner failed\n{error_msg}", ) - # Compare outputs - runtime_output_file = debug_dir / "final_runtime_output.txt" + # Compare outputs - executor_runner writes to aoti_debug_data/ in cwd + # In CI, this is TEST_OUTPUT_BASE_DIR; locally it may vary + runtime_output_file = model_output_dir / "final_runtime_output.txt" - if runtime_output_file.exists(): - is_close, max_atol, max_rtol = compare_outputs( - expected_output, runtime_output_file - ) + self.assertTrue( + runtime_output_file.exists(), + f"{model_name} ({dtype_name}): Runtime output file not created at {runtime_output_file}", + ) - self.assertTrue( - is_close, - f"{model_name} ({DTYPE_NAMES[dtype]}): Output mismatch - max_atol={max_atol}, max_rtol={max_rtol}", - ) + is_close, max_atol, max_rtol = compare_outputs( + expected_output, runtime_output_file + ) + + self.assertTrue( + is_close, + f"{model_name} ({dtype_name}): Output mismatch - max_atol={max_atol}, max_rtol={max_rtol}", + ) + + if IS_CI: + # In CI, use a persistent directory in the current working directory + TEST_OUTPUT_BASE_DIR.mkdir(parents=True, exist_ok=True) + run_test_in_directory(TEST_OUTPUT_BASE_DIR) + else: + # Locally, use a temporary directory that gets cleaned up + with tempfile.TemporaryDirectory() as tmpdir: + run_test_in_directory(Path(tmpdir)) # ============================================================================= @@ -775,8 +843,10 @@ def _test_module_output_consistency( def _make_export_test(model_name: str, dtype: torch.dtype): """Factory function to create an export test method for a given model and dtype.""" + def test_method(self): self._test_module_export(model_name, dtype) + dtype_name = DTYPE_NAMES[dtype] test_method.__doc__ = f"Test {model_name} module export with {dtype_name}." return test_method @@ -784,10 +854,14 @@ def test_method(self): def _make_output_consistency_test(model_name: str, dtype: torch.dtype): """Factory function to create an output consistency test method for a given model and dtype.""" + def test_method(self): self._test_module_output_consistency(model_name, dtype) + dtype_name = DTYPE_NAMES[dtype] - test_method.__doc__ = f"Test {model_name} module output consistency with {dtype_name}." + test_method.__doc__ = ( + f"Test {model_name} module output consistency with {dtype_name}." 
+ ) return test_method From b4310cc6f03ad25fd8e082b5d4995a18eb4e4491 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Fri, 30 Jan 2026 22:01:15 -0500 Subject: [PATCH 03/26] Update [ghstack-poisoned] --- .github/workflows/metal.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/metal.yml b/.github/workflows/metal.yml index 63466f36abb..50ab0a70e1c 100644 --- a/.github/workflows/metal.yml +++ b/.github/workflows/metal.yml @@ -54,7 +54,7 @@ jobs: export-model-metal-artifact: name: export-model-metal-artifact - # Skip this job if the pull request is from a fork (HuggingFace secrets are not available) + # Skip this job if the pull request is from a fork (HuggingFace secrets are not available) if: github.event.pull_request.head.repo.full_name == github.repository || github.event_name != 'pull_request' uses: pytorch/test-infra/.github/workflows/macos_job.yml@main secrets: inherit From 94c823c8f1effbcf7f8c0bff3aa3db2c0ef570f3 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Fri, 30 Jan 2026 23:32:54 -0500 Subject: [PATCH 04/26] Update [ghstack-poisoned] --- backends/apple/metal/tests/run_metal_test.sh | 19 +- backends/apple/metal/tests/test_modules.py | 241 ++++++++++++++++--- 2 files changed, 212 insertions(+), 48 deletions(-) diff --git a/backends/apple/metal/tests/run_metal_test.sh b/backends/apple/metal/tests/run_metal_test.sh index 9595cbf0c3d..0f70c20ea4e 100755 --- a/backends/apple/metal/tests/run_metal_test.sh +++ b/backends/apple/metal/tests/run_metal_test.sh @@ -8,7 +8,7 @@ # Script to build and run Metal backend tests # Usage: # ./run_metal_test.sh --build # Build the Metal runtime -# ./run_metal_test.sh --run # Run inference with given model files +# ./run_metal_test.sh --run # Run inference with given model file # ./run_metal_test.sh --check-build # Check if runtime is already built set -e # Exit on any error @@ -74,7 +74,6 @@ build_runtime() { # Function to run inference run_inference() { local pte_path="$1" - local ptd_path="$2" if [[ ! -f "$EXECUTOR_RUNNER" ]]; then echo "Error: executor_runner not found at $EXECUTOR_RUNNER" @@ -87,16 +86,10 @@ run_inference() { exit 1 fi - if [[ ! -f "$ptd_path" ]]; then - echo "Error: PTD file not found: $ptd_path" - exit 1 - fi - echo "Running inference..." 
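+    # The Metal backend embeds tensor data in the .pte file, so only the model path is needed.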
echo " PTE: $pte_path" - echo " PTD: $ptd_path" - "$EXECUTOR_RUNNER" --model_path "$pte_path" --data_path "$ptd_path" + "$EXECUTOR_RUNNER" --model_path "$pte_path" } # Parse command line arguments @@ -105,11 +98,11 @@ case "$1" in build_runtime ;; --run) - if [[ -z "$2" ]] || [[ -z "$3" ]]; then - echo "Usage: $0 --run " + if [[ -z "$2" ]]; then + echo "Usage: $0 --run " exit 1 fi - run_inference "$2" "$3" + run_inference "$2" ;; --check-build) check_build @@ -119,7 +112,7 @@ case "$1" in echo "" echo "Usage:" echo " $0 --build Build the Metal runtime" - echo " $0 --run Run inference with given model files" + echo " $0 --run Run inference with given model file" echo " $0 --check-build Check if runtime is already built" exit 1 ;; diff --git a/backends/apple/metal/tests/test_modules.py b/backends/apple/metal/tests/test_modules.py index c97298a6bc2..fc3e2c6d4e8 100644 --- a/backends/apple/metal/tests/test_modules.py +++ b/backends/apple/metal/tests/test_modules.py @@ -72,7 +72,22 @@ torch.bfloat16: "bfloat16", } +# Default tolerances for output comparison by dtype +# bfloat16 has lower precision (7 bits mantissa vs 23 for float32) +DEFAULT_TOLERANCES = { + torch.float32: {"atol": 1e-5, "rtol": 1e-5}, + torch.bfloat16: {"atol": 1e-2, "rtol": 1e-2}, +} + + # Registry mapping model names to their configurations +# Each entry can optionally include: +# - "atol": float - Override absolute tolerance for all dtypes +# - "rtol": float - Override relative tolerance for all dtypes +# - "atol_": float - Override absolute tolerance for specific dtype (e.g., "atol_bfloat16") +# - "rtol_": float - Override relative tolerance for specific dtype (e.g., "rtol_bfloat16") +# - "skip": bool or str - Skip all tests for this module (True to skip, or string with reason) +# - "skip_": bool or str - Skip tests for specific dtype (e.g., "skip_bfloat16") MODULE_REGISTRY: Dict[str, Dict[str, Any]] = {} @@ -206,6 +221,7 @@ def forward(self, x: torch.Tensor): "model_class": SingleConv2d, "input_shapes": [(4, 3, 8, 8)], "description": "Single Conv2d layer model", + "skip": True, } @@ -232,6 +248,7 @@ def forward(self, x): "model_class": DepthwiseConv, "input_shapes": [(1, 32, 112, 112)], "description": "Single Depthwise Conv2d layer model", + "skip": True, } @@ -412,6 +429,8 @@ def __init__(self): "model_class": AddStridedSDPA, "input_shapes": [(10, 20, 1500, 64)], "description": "Audio Encoder model with strided SDPA", + "atol_float32": 1e-4, + "atol_bfloat16": 5e-2, } @@ -532,6 +551,45 @@ def forward(self, x): "model_class": SingleResNetBlock, "input_shapes": [(1, 64, 8, 8)], "description": "Single ResNet block with skip connection", + "skip": True, +} + + +# ------------------------------------------------------------------------- +class TransformerBlock(nn.Module): + def __init__(self, embed_dim=256, num_heads=8, ff_dim=1024, dropout=0.1): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + + self.self_attn = nn.MultiheadAttention( + embed_dim=embed_dim, num_heads=num_heads, dropout=dropout, batch_first=True + ) + + self.norm1 = nn.LayerNorm(embed_dim) + self.norm2 = nn.LayerNorm(embed_dim) + + self.ffn = nn.Sequential( + nn.Linear(embed_dim, ff_dim), + nn.ReLU(), + nn.Dropout(dropout), + nn.Linear(ff_dim, embed_dim), + nn.Dropout(dropout), + ) + + def forward(self, x): + attn_output, _ = self.self_attn(x, x, x) + x = self.norm1(x + attn_output) + ff_output = self.ffn(x) + x = self.norm2(x + ff_output) + return x + + +MODULE_REGISTRY["transformer_block"] = { + "model_class": 
TransformerBlock, + "input_shapes": [(4, 32, 256)], + "description": "Single transformer block with multi-head attention and FFN", + "skip": True, } @@ -540,6 +598,59 @@ def forward(self, x): # ============================================================================= +def get_tolerances_for_model( + model_name: str, dtype: torch.dtype +) -> Tuple[float, float]: + """ + Get atol and rtol tolerances for a specific model and dtype. + + Priority order: + 1. Model-specific dtype tolerance (e.g., "atol_bfloat16") + 2. Model-specific general tolerance (e.g., "atol") + 3. Default dtype tolerance from DEFAULT_TOLERANCES + + Returns: + Tuple of (atol, rtol) + """ + model_config = MODULE_REGISTRY.get(model_name, {}) + dtype_name = DTYPE_NAMES.get(dtype, "float32") + default_tols = DEFAULT_TOLERANCES.get(dtype, DEFAULT_TOLERANCES[torch.float32]) + + # Check for dtype-specific override, then general override, then default + atol = model_config.get( + f"atol_{dtype_name}", model_config.get("atol", default_tols["atol"]) + ) + rtol = model_config.get( + f"rtol_{dtype_name}", model_config.get("rtol", default_tols["rtol"]) + ) + + return atol, rtol + + +def should_skip_model(model_name: str, dtype: torch.dtype) -> Tuple[bool, str]: + """ + Check if a model should be skipped for testing. + + Priority order: + 1. Model-specific dtype skip (e.g., "skip_bfloat16") + 2. Model-specific general skip (e.g., "skip") + + Returns: + Tuple of (should_skip, reason) + """ + model_config = MODULE_REGISTRY.get(model_name, {}) + dtype_name = DTYPE_NAMES.get(dtype, "float32") + + # Check for dtype-specific skip first, then general skip + skip_value = model_config.get(f"skip_{dtype_name}", model_config.get("skip", False)) + + if skip_value is True: + return True, f"{model_name} is marked as skipped" + elif isinstance(skip_value, str): + return True, skip_value + return False, "" + + def get_model_and_inputs( model_name: str, dtype: torch.dtype = torch.float32 ) -> Tuple[nn.Module, Tuple[torch.Tensor, ...]]: @@ -605,22 +716,24 @@ def export_model_to_files( # Export to executorch executorch_program = export_model_to_metal(model, example_inputs) - # Save .pte file + # Save .pte file (Metal backend embeds data into the .pte file, no separate .ptd) pte_path = output_dir / f"{model_name}.pte" with open(pte_path, "wb") as f: f.write(executorch_program.buffer) - # Save .ptd file (tensor data) - executorch_program.write_tensor_data_to_file(str(output_dir)) - ptd_path = output_dir / "aoti_metal_blob.ptd" - - return pte_path, ptd_path, expected_output + return pte_path, expected_output -def run_executor_runner(pte_path: Path, ptd_path: Path) -> Tuple[bool, Optional[str]]: +def run_executor_runner( + pte_path: Path, output_path: Path +) -> Tuple[bool, Optional[str]]: """ Run the executor_runner binary with the given model files. + Args: + pte_path: Path to the .pte model file + output_path: Base path for output files (executor_runner will create -0.bin, etc.) + Returns: Tuple of (success, error_message). If success is True, error_message is None. If success is False, error_message contains details about the failure. 
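+
+    Note: for a single-output model the runtime result ends up at
+    f"{output_path}-0.bin", which is what the consistency test reads back.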
@@ -635,8 +748,8 @@ def run_executor_runner(pte_path: Path, ptd_path: Path) -> Tuple[bool, Optional[ str(EXECUTOR_RUNNER), "--model_path", str(pte_path), - "--data_path", - str(ptd_path), + "--output_file", + str(output_path), ] try: @@ -660,32 +773,80 @@ def run_executor_runner(pte_path: Path, ptd_path: Path) -> Tuple[bool, Optional[ return False, f"executor_runner timed out after 60 seconds: {e}" -def read_output_file(filepath: Path) -> Optional[np.ndarray]: - """Read comma-separated output values from a file.""" +def read_binary_output_file(filepath: Path, dtype: torch.dtype) -> Optional[np.ndarray]: + """ + Read binary output values from an executor_runner output file. + + Args: + filepath: Path to the binary output file + dtype: The torch dtype to interpret the binary data as + + Returns: + numpy array of values, or None if file doesn't exist or is empty + """ + if not filepath.exists(): + return None + + # Map torch dtype to numpy dtype + dtype_map = { + torch.float32: np.float32, + torch.float16: np.float16, + torch.bfloat16: np.float32, # bfloat16 is read as float32 after conversion + torch.int32: np.int32, + torch.int64: np.int64, + } + + np_dtype = dtype_map.get(dtype, np.float32) + try: - with open(filepath, "r") as f: - content = f.read().strip() - if not content: + with open(filepath, "rb") as f: + binary_data = f.read() + if not binary_data: return None - values = [float(x.strip()) for x in content.split(",") if x.strip()] - return np.array(values) - except (FileNotFoundError, ValueError): + # For bfloat16, the runtime output is in bfloat16 format (2 bytes per element) + # We need to read it as uint16 and convert + if dtype == torch.bfloat16: + # Read as uint16 (2 bytes per element like bfloat16) + values_uint16 = np.frombuffer(binary_data, dtype=np.uint16) + # Convert bfloat16 to float32 by shifting left 16 bits + values_uint32 = values_uint16.astype(np.uint32) << 16 + values = values_uint32.view(np.float32) + else: + values = np.frombuffer(binary_data, dtype=np_dtype) + return values + except (FileNotFoundError, ValueError) as e: + print(f"Error reading binary file {filepath}: {e}") return None def compare_outputs( expected: torch.Tensor, runtime_output_file: Path, - atol: float = 1e-5, - rtol: float = 1e-5, + dtype: torch.dtype, + atol: Optional[float] = None, + rtol: Optional[float] = None, ) -> Tuple[bool, Optional[float], Optional[float]]: """ - Compare expected PyTorch output with runtime output from file. + Compare expected PyTorch output with runtime output from binary file. 
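+
+    bfloat16 results are read back as raw 16-bit values and widened to float32
+    (see read_binary_output_file), and the expected tensor is cast to float32
+    before the element-wise comparison.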
+ + Args: + expected: Expected output tensor from PyTorch + runtime_output_file: Path to the binary output file from executor_runner + dtype: The dtype used for the model (needed to parse binary output) + atol: Absolute tolerance for comparison (if None, uses dtype-specific default) + rtol: Relative tolerance for comparison (if None, uses dtype-specific default) Returns: Tuple of (is_close, max_atol, max_rtol) """ - runtime_values = read_output_file(runtime_output_file) + # Use dtype-specific tolerances if not specified + tolerances = DEFAULT_TOLERANCES.get(dtype, DEFAULT_TOLERANCES[torch.float32]) + if atol is None: + atol = tolerances["atol"] + if rtol is None: + rtol = tolerances["rtol"] + + runtime_values = read_binary_output_file(runtime_output_file, dtype) if runtime_values is None: return False, None, None @@ -693,10 +854,10 @@ def compare_outputs( # (required when tensor is on MPS device) if isinstance(expected, tuple): expected_values = np.concatenate( - [t.detach().cpu().flatten().numpy() for t in expected] + [t.detach().cpu().float().flatten().numpy() for t in expected] ) else: - expected_values = expected.detach().cpu().flatten().numpy() + expected_values = expected.detach().cpu().float().flatten().numpy() if len(runtime_values) != len(expected_values): return False, None, None @@ -736,6 +897,11 @@ def _test_module_export( self, model_name: str, dtype: torch.dtype = torch.float32 ) -> None: """Generic test for module export.""" + # Check if this model/dtype combination should be skipped + skip, skip_reason = should_skip_model(model_name, dtype) + if skip: + self.skipTest(skip_reason) + if SKIP_EXPORT_TESTS: self.skipTest(SKIP_REASON) @@ -770,10 +936,15 @@ def _test_module_output_consistency( Test that Metal backend runtime output matches PyTorch output. This test: - 1. Exports the model to .pte and .ptd files + 1. Exports the model to a .pte file 2. Runs the model using executor_runner 3. 
Compares the runtime output with expected PyTorch output """ + # Check if this model/dtype combination should be skipped + skip, skip_reason = should_skip_model(model_name, dtype) + if skip: + self.skipTest(skip_reason) + if SKIP_RUNTIME_TESTS: self.skipTest(SKIP_RUNTIME_REASON) @@ -788,7 +959,7 @@ def run_test_in_directory(test_dir: Path) -> None: model_output_dir.mkdir(parents=True, exist_ok=True) # Export model and get expected output - pte_path, ptd_path, expected_output = export_model_to_files( + pte_path, expected_output = export_model_to_files( model, example_inputs, model_output_dir, model_name ) @@ -796,29 +967,29 @@ def run_test_in_directory(test_dir: Path) -> None: pte_path.exists(), f"{model_name} ({dtype_name}): PTE file not created at {pte_path}", ) - self.assertTrue( - ptd_path.exists(), - f"{model_name} ({dtype_name}): PTD file not created at {ptd_path}", - ) - # Run executor_runner - success, error_msg = run_executor_runner(pte_path, ptd_path) + # Run executor_runner with output file + output_base_path = model_output_dir / "output" + success, error_msg = run_executor_runner(pte_path, output_base_path) self.assertTrue( success, f"{model_name} ({dtype_name}): executor_runner failed\n{error_msg}", ) - # Compare outputs - executor_runner writes to aoti_debug_data/ in cwd - # In CI, this is TEST_OUTPUT_BASE_DIR; locally it may vary - runtime_output_file = model_output_dir / "final_runtime_output.txt" + # executor_runner writes output files as -.bin + # For single output models, this is output-0.bin + runtime_output_file = model_output_dir / "output-0.bin" self.assertTrue( runtime_output_file.exists(), f"{model_name} ({dtype_name}): Runtime output file not created at {runtime_output_file}", ) + # Get model-specific tolerances (with dtype-specific overrides) + atol, rtol = get_tolerances_for_model(model_name, dtype) + is_close, max_atol, max_rtol = compare_outputs( - expected_output, runtime_output_file + expected_output, runtime_output_file, dtype, atol=atol, rtol=rtol ) self.assertTrue( From 31b6f4555bae21d41797d01ac18fa3b888f441f4 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Mon, 2 Feb 2026 16:27:14 -0500 Subject: [PATCH 05/26] Update [ghstack-poisoned] --- backends/apple/metal/tests/test_modules.py | 261 +++------------------ 1 file changed, 29 insertions(+), 232 deletions(-) diff --git a/backends/apple/metal/tests/test_modules.py b/backends/apple/metal/tests/test_modules.py index fc3e2c6d4e8..6a5eaeb9a53 100644 --- a/backends/apple/metal/tests/test_modules.py +++ b/backends/apple/metal/tests/test_modules.py @@ -50,7 +50,7 @@ # Test output directory - use current working directory in CI for reliable write access if IS_CI: - TEST_OUTPUT_BASE_DIR = Path.cwd() / "aoti_debug_data" + TEST_OUTPUT_BASE_DIR = Path.cwd() / "metal_backend_module_outputs" else: TEST_OUTPUT_BASE_DIR = None # Will use tempfile.TemporaryDirectory @@ -126,7 +126,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor): # ------------------------------------------------------------------------- -class MmWeights(nn.Module): +class MmWeightParam(nn.Module): def __init__(self): super().__init__() self.weight = nn.Parameter(torch.arange(20, dtype=torch.float).reshape(4, 5)) @@ -135,51 +135,13 @@ def forward(self, x: torch.Tensor): return x.mm(self.weight) -MODULE_REGISTRY["mm_weights"] = { - "model_class": MmWeights, +MODULE_REGISTRY["mm_weight_param"] = { + "model_class": MmWeightParam, "input_shapes": [(3, 4)], "description": "Matrix multiplication with weight parameter", } -# 
------------------------------------------------------------------------- -class TwoMm(nn.Module): - def __init__(self): - super().__init__() - self.left_weight = nn.Parameter( - torch.arange(20, dtype=torch.float).reshape(4, 5) - ) - self.right_weight = nn.Parameter( - torch.arange(42, dtype=torch.float).reshape(6, 7) - ) - - def forward(self, x: torch.Tensor): - return self.left_weight.mm(x).mm(self.right_weight) - - -MODULE_REGISTRY["two_mm"] = { - "model_class": TwoMm, - "input_shapes": [(5, 6)], - "description": "Two consecutive matrix multiplications", -} - - -# ------------------------------------------------------------------------- -class ElementwiseMmReduction(nn.Module): - def forward(self, x: torch.Tensor, y: torch.Tensor): - x1 = x.sin() + x - y2 = y.cos() + 3 - z = x1.mm(y2) - return z + z.sum() - - -MODULE_REGISTRY["elementwise_mm_reduction"] = { - "model_class": ElementwiseMmReduction, - "input_shapes": [(11, 45), (45, 8)], - "description": "Combining mm with elementwise and reduction ops", -} - - # ------------------------------------------------------------------------- # Linear Modules # ------------------------------------------------------------------------- @@ -206,54 +168,7 @@ def forward(self, x: torch.Tensor): # ------------------------------------------------------------------------- -class SingleConv2d(nn.Module): - def __init__(self): - super().__init__() - self.conv = nn.Conv2d( - in_channels=3, out_channels=5, kernel_size=3, stride=1, padding=1 - ) - - def forward(self, x: torch.Tensor): - return self.conv(x) - - -MODULE_REGISTRY["conv2d"] = { - "model_class": SingleConv2d, - "input_shapes": [(4, 3, 8, 8)], - "description": "Single Conv2d layer model", - "skip": True, -} - - -# ------------------------------------------------------------------------- -class DepthwiseConv(nn.Module): - def __init__(self): - super().__init__() - self.conv = nn.Conv2d( - in_channels=32, - out_channels=32, - kernel_size=3, - stride=1, - padding=1, - dilation=1, - groups=32, - bias=False, - ) - - def forward(self, x): - return self.conv(x) - - -MODULE_REGISTRY["depthwise_conv"] = { - "model_class": DepthwiseConv, - "input_shapes": [(1, 32, 112, 112)], - "description": "Single Depthwise Conv2d layer model", - "skip": True, -} - - -# ------------------------------------------------------------------------- -class SmallConv1d(nn.Module): +class Conv1dNoBias(nn.Module): def __init__(self): super().__init__() self.conv = nn.Conv1d( @@ -271,15 +186,15 @@ def forward(self, x): return self.conv(x) -MODULE_REGISTRY["small_conv1d"] = { - "model_class": SmallConv1d, +MODULE_REGISTRY["conv1d_nobias"] = { + "model_class": Conv1dNoBias, "input_shapes": [(1, 8, 5)], "description": "Conv1d layer with 8 input channels, 6 output channels", } # ------------------------------------------------------------------------- -class MediumConv1d(nn.Module): +class Conv1dBias(nn.Module): def __init__(self): super().__init__() self.conv = nn.Conv1d( @@ -297,15 +212,15 @@ def forward(self, x): return self.conv(x) -MODULE_REGISTRY["conv1d"] = { - "model_class": MediumConv1d, +MODULE_REGISTRY["conv1d_bias"] = { + "model_class": Conv1dBias, "input_shapes": [(1, 80, 3000)], "description": "Conv1d layer with 80 input channels, 384 output channels", } # ------------------------------------------------------------------------- -class VoxtralConv1d(nn.Module): +class Conv1dVoxtral(nn.Module): def __init__(self): super().__init__() self.conv = nn.Conv1d( @@ -323,8 +238,8 @@ def forward(self, x): return self.conv(x) 
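+# For reference, these registry entries are consumed by get_model_and_inputs
+# further down in this file, roughly as follows (a sketch, not a separate API;
+# dtype handling is omitted here):
+#
+#     cfg = MODULE_REGISTRY["conv1d_voxtral"]
+#     model = cfg["model_class"]().eval()
+#     inputs = tuple(torch.randn(*shape) for shape in cfg["input_shapes"])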
-MODULE_REGISTRY["voxtral_conv1d"] = { - "model_class": VoxtralConv1d, +MODULE_REGISTRY["conv1d_voxtral"] = { + "model_class": Conv1dVoxtral, "input_shapes": [(10, 128, 3000)], "description": "Conv1d layer with 128 input channels, 1280 output channels", } @@ -335,7 +250,7 @@ def forward(self, x): # ------------------------------------------------------------------------- -class SimpleSDPA(nn.Module): +class SDPA(nn.Module): """Minimal SDPA test model.""" def forward( @@ -348,14 +263,14 @@ def forward( MODULE_REGISTRY["sdpa"] = { - "model_class": SimpleSDPA, + "model_class": SDPA, "input_shapes": [(2, 4, 16, 64), (2, 4, 16, 64), (2, 4, 16, 64)], "description": "Simple Scaled Dot Product Attention model", } # ------------------------------------------------------------------------- -class AddSDPA(nn.Module): +class SDPAAdd(nn.Module): """SDPA model with Q, K, V as parameters that adds input to SDPA output.""" def __init__(self, batch_size=2, num_heads=4, seq_len=16, head_dim=64): @@ -371,15 +286,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return sdpa_output + x -MODULE_REGISTRY["add_sdpa"] = { - "model_class": AddSDPA, +MODULE_REGISTRY["sdpa_add"] = { + "model_class": SDPAAdd, "input_shapes": [(2, 4, 16, 64)], "description": "SDPA model with Q,K,V as parameters that adds input to output", } # ------------------------------------------------------------------------- -class BaseAddStridedSDPA(nn.Module): +class BaseStridedSDPA(nn.Module): """SDPA model with strided Q, K, V parameters.""" def __init__( @@ -413,7 +328,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # ------------------------------------------------------------------------- -class AddStridedSDPA(BaseAddStridedSDPA): +class SDPAStrided(BaseStridedSDPA): def __init__(self): super().__init__( q_size=(10, 20, 1500, 64), @@ -425,8 +340,8 @@ def __init__(self): ) -MODULE_REGISTRY["audio_encoder_sdpa1"] = { - "model_class": AddStridedSDPA, +MODULE_REGISTRY["sdpa_strided"] = { + "model_class": SDPAStrided, "input_shapes": [(10, 20, 1500, 64)], "description": "Audio Encoder model with strided SDPA", "atol_float32": 1e-4, @@ -435,7 +350,7 @@ def __init__(self): # ------------------------------------------------------------------------- -class AddStridedSDPA1(BaseAddStridedSDPA): +class SDPAStridedBroadcast(BaseStridedSDPA): def __init__(self): super().__init__( q_size=(1, 20, 1, 64), @@ -447,15 +362,15 @@ def __init__(self): ) -MODULE_REGISTRY["whisper_strided_sdpa1"] = { - "model_class": AddStridedSDPA1, +MODULE_REGISTRY["sdpa_strided_broadcast"] = { + "model_class": SDPAStridedBroadcast, "input_shapes": [(1, 20, 1, 64)], "description": "Whisper-like strided SDPA variant 1", } # ------------------------------------------------------------------------- -class AddStridedSDPA2(BaseAddStridedSDPA): +class SDPAStridedBroadcastAttnMask(BaseStridedSDPA): def __init__(self): super().__init__( q_size=(1, 20, 1, 64), @@ -468,131 +383,13 @@ def __init__(self): ) -MODULE_REGISTRY["whisper_strided_sdpa2"] = { - "model_class": AddStridedSDPA2, +MODULE_REGISTRY["sdpa_strided_broadcast_attn_mask"] = { + "model_class": SDPAStridedBroadcastAttnMask, "input_shapes": [(1, 20, 1, 64)], "description": "Whisper-like strided SDPA variant 2", } -# ------------------------------------------------------------------------- -# Normalization Modules -# ------------------------------------------------------------------------- - - -class BatchNorm(nn.Module): - def __init__(self): - super().__init__() - self.bn = nn.BatchNorm2d(num_features=16) - 
- def forward(self, x): - return self.bn(x) - - -MODULE_REGISTRY["batchnorm"] = { - "model_class": BatchNorm, - "input_shapes": [(1, 16, 32, 32)], - "description": "Single BatchNorm2d layer model", -} - - -# ------------------------------------------------------------------------- -# Block/Composite Modules -# ------------------------------------------------------------------------- - - -class SingleResNetBlock(nn.Module): - def __init__(self, in_channels=64, out_channels=64, stride=1): - super().__init__() - self.conv1 = nn.Conv2d( - in_channels, - out_channels, - kernel_size=3, - stride=stride, - padding=1, - bias=False, - ) - self.bn1 = nn.BatchNorm2d(out_channels) - self.relu = nn.ReLU(inplace=True) - self.conv2 = nn.Conv2d( - out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False - ) - self.bn2 = nn.BatchNorm2d(out_channels) - - self.skip_connection = None - if stride != 1 or in_channels != out_channels: - self.skip_connection = nn.Sequential( - nn.Conv2d( - in_channels, out_channels, kernel_size=1, stride=stride, bias=False - ), - nn.BatchNorm2d(out_channels), - ) - - def forward(self, x): - identity = x - - out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) - - out = self.conv2(out) - out = self.bn2(out) - - if self.skip_connection is not None: - identity = self.skip_connection(x) - - out += identity - out = self.relu(out) - - return out - - -MODULE_REGISTRY["single_resnet_block"] = { - "model_class": SingleResNetBlock, - "input_shapes": [(1, 64, 8, 8)], - "description": "Single ResNet block with skip connection", - "skip": True, -} - - -# ------------------------------------------------------------------------- -class TransformerBlock(nn.Module): - def __init__(self, embed_dim=256, num_heads=8, ff_dim=1024, dropout=0.1): - super().__init__() - self.embed_dim = embed_dim - self.num_heads = num_heads - - self.self_attn = nn.MultiheadAttention( - embed_dim=embed_dim, num_heads=num_heads, dropout=dropout, batch_first=True - ) - - self.norm1 = nn.LayerNorm(embed_dim) - self.norm2 = nn.LayerNorm(embed_dim) - - self.ffn = nn.Sequential( - nn.Linear(embed_dim, ff_dim), - nn.ReLU(), - nn.Dropout(dropout), - nn.Linear(ff_dim, embed_dim), - nn.Dropout(dropout), - ) - - def forward(self, x): - attn_output, _ = self.self_attn(x, x, x) - x = self.norm1(x + attn_output) - ff_output = self.ffn(x) - x = self.norm2(x + ff_output) - return x - - -MODULE_REGISTRY["transformer_block"] = { - "model_class": TransformerBlock, - "input_shapes": [(4, 32, 256)], - "description": "Single transformer block with multi-head attention and FFN", - "skip": True, -} - - # ============================================================================= # Helper Functions # ============================================================================= @@ -954,7 +751,7 @@ def _test_module_output_consistency( def run_test_in_directory(test_dir: Path) -> None: """Run the actual test logic in the given directory.""" - # Create model output directory: aoti_debug_data/_/ + # Create model output directory: metal_backend_module_outputs/_/ model_output_dir = test_dir / test_subdir_name model_output_dir.mkdir(parents=True, exist_ok=True) From c68cc6b6fcbfc1c97d00cc872b0c04a03d0a10af Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Mon, 2 Feb 2026 16:27:15 -0500 Subject: [PATCH 06/26] Update [ghstack-poisoned] --- backends/apple/metal/metal_backend.py | 7 +- backends/apple/metal/passes/__init__.py | 11 ++ .../metal/passes/decompose_linear_pass.py | 111 ++++++++++++++++++ 
backends/apple/metal/tests/test_modules.py | 17 +++ 4 files changed, 144 insertions(+), 2 deletions(-) create mode 100644 backends/apple/metal/passes/__init__.py create mode 100644 backends/apple/metal/passes/decompose_linear_pass.py diff --git a/backends/apple/metal/metal_backend.py b/backends/apple/metal/metal_backend.py index fde0410cca3..fd94c5ba7a7 100644 --- a/backends/apple/metal/metal_backend.py +++ b/backends/apple/metal/metal_backend.py @@ -43,8 +43,11 @@ def get_decomposition_table(cls) -> Dict[Any, Any]: @classmethod def get_custom_passes(cls, compile_specs: List[CompileSpec]) -> List[typing.Any]: - """Return Metal-specific passes (currently none)""" - return [] + """Return Metal-specific passes""" + from executorch.backends.apple.metal.passes.decompose_linear_pass import ( + DecomposeLinearPass, + ) + return [DecomposeLinearPass()] @classmethod def get_aoti_compile_options( diff --git a/backends/apple/metal/passes/__init__.py b/backends/apple/metal/passes/__init__.py new file mode 100644 index 00000000000..2a1209f1356 --- /dev/null +++ b/backends/apple/metal/passes/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from executorch.backends.apple.metal.passes.decompose_linear_pass import ( # noqa: F401 + DecomposeLinearPass, +) + +__all__ = ["DecomposeLinearPass"] diff --git a/backends/apple/metal/passes/decompose_linear_pass.py b/backends/apple/metal/passes/decompose_linear_pass.py new file mode 100644 index 00000000000..e6b8578cc9f --- /dev/null +++ b/backends/apple/metal/passes/decompose_linear_pass.py @@ -0,0 +1,111 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult + + +class DecomposeLinearPass(ExportPass): + """ + Decompose aten.linear into matmul + add to avoid addmm. + + For 2D inputs, we unsqueeze to 3D before decomposition to force the matmul + code path instead of addmm. The C++ implementation of aten.linear directly + calls addmm for 2D inputs with bias, which would require implementing + aoti_torch_mps_addmm_out. By unsqueezing to 3D, we force the matmul path, + then squeeze back to 2D. 
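+
+    A minimal sketch of the equivalence this pass relies on, written with
+    plain torch ops (shapes are hypothetical, for illustration only):
+
+        x = torch.randn(3, 4)                       # 2D input (M, K)
+        w = torch.randn(5, 4)                       # weight (N, K)
+        b = torch.randn(5)                          # bias (N,)
+        ref = torch.nn.functional.linear(x, w, b)
+        out = torch.matmul(x.unsqueeze(0), w.t()).add(b).squeeze(0)
+        torch.testing.assert_close(ref, out)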
+ """ + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + modified = False + graph = graph_module.graph + + for node in graph.nodes: + # Check if this is a linear operation + is_linear = False + + if node.op == "call_function": + # Match both edge dialect and core aten linear operators + if node.target == exir_ops.edge.aten.linear.default: + is_linear = True + elif node.target == torch.ops.aten.linear.default: + is_linear = True + + if is_linear: + # Get input, weight, and bias arguments + input_node = node.args[0] + weight_node = node.args[1] + bias_node = node.args[2] if len(node.args) > 2 else None + + with graph.inserting_before(node): + # Determine which ops to use based on the input operator + target_str = str(node.target) + + if "executorch_exir_dialects_edge" in target_str: + # Use edge dialect operators + t_op = exir_ops.edge.aten.t.default + matmul_op = exir_ops.edge.aten.matmul.default + add_op = exir_ops.edge.aten.add.Tensor + unsqueeze_op = exir_ops.edge.aten.unsqueeze.default + squeeze_op = exir_ops.edge.aten.squeeze.dims + else: + # Use core aten operators + t_op = torch.ops.aten.t.default + matmul_op = torch.ops.aten.matmul.default + add_op = torch.ops.aten.add.Tensor + unsqueeze_op = torch.ops.aten.unsqueeze.default + squeeze_op = torch.ops.aten.squeeze.dims + + # Check if input is 2D + needs_unsqueeze = False + if hasattr(input_node, "meta") and "val" in input_node.meta: + if len(input_node.meta["val"].shape) == 2: + needs_unsqueeze = True + + # Unsqueeze 2D input to 3D: (M, K) -> (1, M, K) + current_input = input_node + if needs_unsqueeze: + current_input = graph.call_function( + unsqueeze_op, + args=(input_node, 0), + ) + + # Decompose linear: matmul(input, weight.T) + bias + weight_t = graph.call_function( + t_op, + args=(weight_node,), + ) + + matmul_result = graph.call_function( + matmul_op, + args=(current_input, weight_t), + ) + + if bias_node is not None: + result = graph.call_function( + add_op, + args=(matmul_result, bias_node), + ) + else: + result = matmul_result + + # Squeeze 3D output back to 2D: (1, M, N) -> (M, N) + if needs_unsqueeze: + result = graph.call_function( + squeeze_op, + args=(result, [0]), + ) + + # Replace all uses of the linear node with the decomposed result + node.replace_all_uses_with(result) + graph.erase_node(node) + modified = True + + if modified: + graph_module.recompile() + + return PassResult(graph_module, modified) diff --git a/backends/apple/metal/tests/test_modules.py b/backends/apple/metal/tests/test_modules.py index 6a5eaeb9a53..1828545c8a0 100644 --- a/backends/apple/metal/tests/test_modules.py +++ b/backends/apple/metal/tests/test_modules.py @@ -163,6 +163,23 @@ def forward(self, x: torch.Tensor): } +# ------------------------------------------------------------------------- +class LinearWithBias(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(7, 101, bias=True) + + def forward(self, x: torch.Tensor): + return self.linear(x) + + +MODULE_REGISTRY["linear_bias"] = { + "model_class": LinearWithBias, + "input_shapes": [(127, 7)], + "description": "Simple linear layer model with no bias", +} + + # ------------------------------------------------------------------------- # Convolution Modules # ------------------------------------------------------------------------- From bd7192fac4ba42c0f293179b40df919d4239775f Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Mon, 2 Feb 2026 16:27:19 -0500 Subject: [PATCH 07/26] Update [ghstack-poisoned] --- 
.../models/parakeet/export_parakeet_tdt.py | 35 ++----------------- 1 file changed, 2 insertions(+), 33 deletions(-) diff --git a/examples/models/parakeet/export_parakeet_tdt.py b/examples/models/parakeet/export_parakeet_tdt.py index c97c01c1bcb..9b21f14dc71 100644 --- a/examples/models/parakeet/export_parakeet_tdt.py +++ b/examples/models/parakeet/export_parakeet_tdt.py @@ -419,25 +419,6 @@ def _create_xnnpack_partitioners(programs): return partitioner, programs -# This custom decomposition is the key to making Parakeet run on the Metal backend. -# Without this, linear gets decomposed in a way that doesn't work for us. -# When input/weight tensors are 2D and bias is present, this gets decomposed into addmm and -# reinterpret_tensor_wrapper gets called on the bias, to make it look like a 2D tensor. -# On one hand, this requires us to implement addmm in the Metal backend. But more importantly, -# the reinterpret_tensor_wrapper call makes its way to ExecuTorch, causing a call to executorch::extension::from_blob -# with a 0 stride. ExecuTorch doesn't support that, and raises an error. -# This decomposition avoids that problem, and also avoids having to implement addmm. -def _linear_bias_decomposition(input, weight, bias=None): - """Decompose linear with bias into matmul + add.""" - # linear(input, weight) = input @ weight.T - # Use matmul instead of mm to handle batched inputs (3D+) - weight_t = torch.ops.aten.t.default(weight) - out = torch.ops.aten.matmul.default(input, weight_t) - if bias is not None: - return torch.ops.aten.add.Tensor(out, bias) - return out - - def _create_metal_partitioners(programs): """Create Metal partitioners for all programs except preprocessor.""" from executorch.backends.apple.metal.metal_backend import MetalBackend @@ -445,26 +426,14 @@ def _create_metal_partitioners(programs): print("\nLowering to ExecuTorch with Metal...") - # Run decompositions for non-preprocessor programs - updated_programs = {} - for key, ep in programs.items(): - # print(f"Running decompositions for {key}") - # print(ep.graph_module) - if key != "preprocessor": - updated_programs[key] = ep.run_decompositions( - {torch.ops.aten.linear.default: _linear_bias_decomposition} - ) - else: - updated_programs[key] = ep - partitioner = {} - for key in updated_programs.keys(): + for key in programs.keys(): if key == "preprocessor": partitioner[key] = [] else: compile_specs = [MetalBackend.generate_method_name_compile_spec(key)] partitioner[key] = [MetalPartitioner(compile_specs)] - return partitioner, updated_programs + return partitioner, programs def _create_cuda_partitioners(programs, is_windows=False): From bcc8bda985b97078e25e9ab03defd962afcae6b2 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Mon, 2 Feb 2026 16:27:23 -0500 Subject: [PATCH 08/26] Update [ghstack-poisoned] --- backends/aoti/common_shims.cpp | 8 +- backends/aoti/common_shims.h | 1 + backends/aoti/utils.h | 2 + backends/apple/metal/metal_backend.py | 9 +- .../apple/metal/runtime/metal_backend.cpp | 2 + .../apple/metal/runtime/shims/et_metal_ops.h | 8 + .../apple/metal/runtime/shims/et_metal_ops.mm | 844 ++++++++++++++++++ backends/apple/metal/runtime/shims/utils.cpp | 4 +- backends/apple/metal/runtime/shims/utils.h | 2 +- 9 files changed, 875 insertions(+), 5 deletions(-) diff --git a/backends/aoti/common_shims.cpp b/backends/aoti/common_shims.cpp index 7c88e4cfb5b..55dad54526e 100644 --- a/backends/aoti/common_shims.cpp +++ b/backends/aoti/common_shims.cpp @@ -174,12 +174,16 @@ int32_t aoti_torch_dtype_bfloat16() { return 
15; // PyTorch's bfloat16 dtype code } +int32_t aoti_torch_dtype_uint8() { + return 0; // PyTorch's uint8 dtype code +} + int32_t aoti_torch_dtype_int8() { - return 1; // PyTorch's int32 dtype code + return 1; // PyTorch's int8 dtype code } int32_t aoti_torch_dtype_int16() { - return 2; // PyTorch's int32 dtype code + return 2; // PyTorch's int16 dtype code } int32_t aoti_torch_dtype_int32() { diff --git a/backends/aoti/common_shims.h b/backends/aoti/common_shims.h index 3fc414fb669..6f7313e9b60 100644 --- a/backends/aoti/common_shims.h +++ b/backends/aoti/common_shims.h @@ -66,6 +66,7 @@ AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_float32(); AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_bfloat16(); AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_bool(); AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int8(); +AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_uint8(); AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int16(); AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int32(); AOTI_SHIM_EXPORT int32_t aoti_torch_dtype_int64(); diff --git a/backends/aoti/utils.h b/backends/aoti/utils.h index 8f64bdbe7da..7dfe2d38568 100644 --- a/backends/aoti/utils.h +++ b/backends/aoti/utils.h @@ -35,6 +35,8 @@ inline executorch::aten::ScalarType dtype_to_scalar_type(int32_t dtype) { // Convert based on known PyTorch dtype codes (without CUDA-specific // dependency) switch (dtype) { + case 0: // PyTorch's uint8 dtype code + return executorch::aten::ScalarType::Byte; case 1: // PyTorch's int8 dtype code return executorch::aten::ScalarType::Char; case 2: // PyTorch's int16 dtype code diff --git a/backends/apple/metal/metal_backend.py b/backends/apple/metal/metal_backend.py index fd94c5ba7a7..fd114b20b3d 100644 --- a/backends/apple/metal/metal_backend.py +++ b/backends/apple/metal/metal_backend.py @@ -35,6 +35,7 @@ def get_supported_fallback_kernels(cls) -> Dict[str, Any]: "aoti_torch_mps_convolution": None, "aoti_torch_mps_mm_out": None, "at::_ops::_scaled_dot_product_attention_math_for_mps::call": None, + "torchao::_linear_fp_act_4bit_weight": None, } @classmethod @@ -55,7 +56,8 @@ def get_aoti_compile_options( ) -> Dict[str, typing.Any]: """Get AOTI compile options for Metal backend.""" _ = compile_specs # Unused, but required by interface - return { + + inductor_configs = { # Do not link against the full PyTorch/libtorch library "aot_inductor.link_libtorch": False, # Separate weight constants from the .so file @@ -68,3 +70,8 @@ def get_aoti_compile_options( # "aot_inductor.debug_compile": True, # "aot_inductor.force_mmap_weights": False, } + + from torchao.experimental.ops.mps.cshim import torchao_op_c_shim + inductor_configs["aot_inductor.custom_ops_to_c_shims"] = torchao_op_c_shim + + return inductor_configs diff --git a/backends/apple/metal/runtime/metal_backend.cpp b/backends/apple/metal/runtime/metal_backend.cpp index dfa148fd437..7c626da4e5b 100644 --- a/backends/apple/metal/runtime/metal_backend.cpp +++ b/backends/apple/metal/runtime/metal_backend.cpp @@ -315,6 +315,8 @@ class ET_EXPERIMENTAL MetalBackend final "Failed to load shared library: %s", dlerror()); + ET_LOG(Info, "MetalBackend::init - Loaded shared library: %s", so_path.c_str()); + processed->Free(); // Create handle and load function pointers into it diff --git a/backends/apple/metal/runtime/shims/et_metal_ops.h b/backends/apple/metal/runtime/shims/et_metal_ops.h index fcc6dfc03da..e1467bdd842 100644 --- a/backends/apple/metal/runtime/shims/et_metal_ops.h +++ b/backends/apple/metal/runtime/shims/et_metal_ops.h @@ -74,6 +74,14 @@ AOTITorchError 
aoti_torch_mps__scaled_dot_product_attention_math_for_mps( AOTITensorHandle* ret0, AOTITensorHandle* ret1); +AOTITorchError aoti_torch_mps__linear_fp_act_4bit_weight( + AOTITensorHandle A, + AOTITensorHandle B, + int64_t group_size, + AOTITensorHandle S, + AOTITensorHandle Z, + AOTITensorHandle* ret); + #ifdef __cplusplus } // extern "C" #endif diff --git a/backends/apple/metal/runtime/shims/et_metal_ops.mm b/backends/apple/metal/runtime/shims/et_metal_ops.mm index 21989fa5665..ec884b50776 100644 --- a/backends/apple/metal/runtime/shims/et_metal_ops.mm +++ b/backends/apple/metal/runtime/shims/et_metal_ops.mm @@ -348,6 +348,616 @@ void logStats() { return sdpa_shader_library.get(); } +// Helper function to get the Metal shader source for SDPA +static std::string get_int4_metal_source() { + return R"( + /** + * common.metal + */ + + // Copyright (c) Meta Platforms, Inc. and affiliates. + // All rights reserved. + // + // This source code is licensed under the BSD 3-Clause license found in the + // LICENSE file in the root directory of this source tree. + + template struct Vec4Type {}; + + template <> struct Vec4Type { + using type = float4; + }; + + template <> struct Vec4Type { + using type = half4; + }; + + #if __METAL_VERSION__ >= 310 + template <> struct Vec4Type { + using type = bfloat4; + }; + #endif + + /** + * int4mm_opt.metal + */ + + // Copyright (c) Meta Platforms, Inc. and affiliates. + // All rights reserved. + // + // This source code is licensed under the BSD 3-Clause license found in the + // LICENSE file in the root directory of this source tree. + #include + #include + using namespace metal; + + /* + This code takes heavy inspiration from MLX: + https://github.com/ml-explore/mlx/blob/main/mlx/backend/metal/kernels/quantized.h + Specifically: + - Multiplying activation by inverse scaling factor to reduce compute + boundedness + - Handling zero point by accumulating act in separate sum term. Needed with + optimization done above. MLX MIT License: + https://github.com/ml-explore/mlx/blob/main/LICENSE + */ + + /* + A matrix is [M x K] (right now this kernel does not support M > 1 but this is + a very easy fix that will follow right after) B matrix is [N x K]. For 4 bit + 2 of the k values are packed in one byte so you can think of B as [N x K/2] + matrix from layout perspective. + + Since this kernel is optimizing for gemv case, we split work, along reduction + dim k, among the threads of same simdgroup. Ex: if K = 4096 and simdgroup + size is 32 (current algorithm should work as long as simdgroup size is > 32). + Then each thread will accumulate 4096/32 = 128 k values. However these 128 + values, handled by each thread are not laid out contiguously. Each thread + handles 4 contiguous k values and then jumps 128 elements, k_jump = + thread_per_channel (32) * ks_per_thread (4). Take a simpler example where + simdgroup is of size 4. In this case threads_per_channel = 4. Assume K = 32 + k thread + [0, 1, 2, 3, 0 + 4, 5, 6, 7, 1 + 8, 9, 10, 11, 2 + 12, 13, 14, 15, 3 + 16, 17, 18, 19, 0 + 20, 21, 22, 23, 1 + 24, 25, 26, 27, 2 + 28, 29, 30, 31] 3 + thread id in simd group that handle corresponding + ks + Thread 0 here is handling (0, 1, 2, 3) and then (16, 17, 18, 19). They are + apart by k_jump = 4 * 4 = 16 This is done to improve memory access locality + amonng threads that are working co-operatively. Once each thread has their + partial sums accumulated, we use tree reduction (Metal offers simd_sum but + not used so that we support simdgroup size = 64). 
In the + example above we will have 4 partial sums. + + Each thread also handles 4 different output rows. Thus each simdgroup will be + responsible for (1x4) tile of the output. We haven't evaluated whether a + different tile size is better or not. We probably will do some auto-tuning + once initial work is done. + */ + + /* + @brief This shader implements 4-bit matrix-vector multiplication where A + matrix is fp16, bfloat or float and B matrix is a 4-bit groupwise-quantized weight + matrix. + @param [in] A is activation matrix of size M x K. + @param [in] B is weight matrix of size M x K. Each byte contains 2 4-bit + values, along K dim, packed together. + @param [in] scales_ptr is scales ptr corresponding each + output channel x groups. These are packed as [N, num_groups = ceil(K / group_size)]. N = output + channels. + @param [in] zeros_ptr is zero points corresponding each + output channel x groups. These are packed as [N, num_groups = ceil(K / group_size)]. N = output + channels. + @param [out] output_data is output matrix of size M x N. + @param [in] sizes array contains values of M, K and N. + @param [in] thread_index is global thread id. + @param [in] tid_in_simdgruop is thread id in simdgroup. e.g. in simdgroup of size 32 it can be in [0-31]. + */ + template + kernel void int4pack_mm(constant T *A [[buffer(0)]], + constant uchar *B [[buffer(1)]], + constant T *scales_ptr [[buffer(2)]], + constant T *zeros_ptr [[buffer(3)]], + device T *output_data [[buffer(4)]], + constant uint3 &sizes [[buffer(5)]], // M, K, N + uint3 thread_index [[thread_position_in_grid]], + uint tid_in_simdgroup [[thread_index_in_simdgroup]]) { + constexpr uint threads_per_channel = 32; + constexpr uint ks_per_thread = 4; + constexpr uint k_pack_factor = 2; + const uint K = sizes.y; + const uint N = sizes.z; + const uint num_groups = (K + group_size - 1) / group_size; + uint n = thread_index.x; // 0..N/4-1 + uint m = thread_index.z; // 0..M + n = n / threads_per_channel; + n = n * 4; + // This is starting k for each thread. In the example above, for thread 1 this + // value will be 4. + uint k = (tid_in_simdgroup % threads_per_channel) * ks_per_thread; + constexpr int k_jump = threads_per_channel * ks_per_thread; + + using vecT = typename Vec4Type::type; + constant vecT *A_ptr = reinterpret_cast(A + m * K); + constant uchar *B_ptr = B + ((n * K) / k_pack_factor); + + thread float4 result = float4(0.0); + // We multipy group of 4 channels with these scales. + // Because corresponding values from weight matrix are effectively left + // shifted. This is to avoid doing right shift on those values which ends up + // affecting performance. This is the trick applied in MLX kernels. + float4 act_div_scales = {1.f, 1 / 16.f, 1 / 256.f, 1 / 4096.f}; + + for (; k < K; k += k_jump) { + // Find specific group to which channels handled by this thread + // belong. + uint k_block_index = k / group_size; + uint scales_group_offset = (n * num_groups + k_block_index); + + vecT scales = + vecT(scales_ptr[scales_group_offset], + scales_ptr[scales_group_offset + num_groups], + scales_ptr[scales_group_offset + 2 * num_groups], + scales_ptr[scales_group_offset + 3 * num_groups]); + // Adding zero point results in 10% perf penalty. 
+ vecT zeros = + vecT(zeros_ptr[scales_group_offset], + zeros_ptr[scales_group_offset + num_groups], + zeros_ptr[scales_group_offset + 2 * num_groups], + zeros_ptr[scales_group_offset + 3 * num_groups]); + float4 zeros_float = float4(zeros); + + float4 a_val = float4(A_ptr[k / 4]); + // We are gonna skip right-shifts of the weights and hence divide by corresponding factor. + float4 a_vec = a_val * act_div_scales; + float a_val_sum = a_val[0] + a_val[1] + a_val[2] + a_val[3]; + + float4x4 b_mat; + ushort b_val0 = (reinterpret_cast( + B_ptr + (k + 0 * K) / k_pack_factor))[0]; + ushort b_val1 = (reinterpret_cast( + B_ptr + (k + 1 * K) / k_pack_factor))[0]; + ushort b_val2 = (reinterpret_cast( + B_ptr + (k + 2 * K) / k_pack_factor))[0]; + ushort b_val3 = (reinterpret_cast( + B_ptr + (k + 3 * K) / k_pack_factor))[0]; + b_mat[0] = scales[0] * float4(float(b_val0 & 0x000f), float(b_val0 & 0x00f0), + float(b_val0 & 0x0f00), float(b_val0 & 0xf000)); + b_mat[1] = scales[1] * float4(float(b_val1 & 0x000f), float(b_val1 & 0x00f0), + float(b_val1 & 0x0f00), float(b_val1 & 0xf000)); + b_mat[2] = scales[2] * float4(float(b_val2 & 0x000f), float(b_val2 & 0x00f0), + float(b_val2 & 0x0f00), float(b_val2 & 0xf000)); + b_mat[3] = scales[3] * float4(float(b_val3 & 0x000f), float(b_val3 & 0x00f0), + float(b_val3 & 0x0f00), float(b_val3 & 0xf000)); + + result += a_vec * b_mat; + result += a_val_sum * zeros_float; + } + result += simd_shuffle_down(result, 1); + result += simd_shuffle_down(result, 2); + result += simd_shuffle_down(result, 4); + result += simd_shuffle_down(result, 8); + result += simd_shuffle_down(result, 16); + if (tid_in_simdgroup % threads_per_channel == 0) { + reinterpret_cast(output_data + m * N)[n / 4] = vecT(result); + } + } + + #define INSTANTIATE_INT4MM(DTYPE, GSIZE) \ + template [[host_name("int4pack_mm_" #GSIZE "_" #DTYPE)]] kernel void \ + int4pack_mm( \ + constant DTYPE * A [[buffer(0)]], constant uchar * B [[buffer(1)]], \ + constant DTYPE * scales_ptr [[buffer(2)]], \ + constant DTYPE * zeros_ptr [[buffer(3)]], \ + device DTYPE * output_data [[buffer(4)]], \ + constant uint3 & sizes [[buffer(5)]], \ + uint3 thread_index [[thread_position_in_grid]], \ + uint tid_in_simdgroup [[thread_index_in_simdgroup]]) + + INSTANTIATE_INT4MM(float, 32); + INSTANTIATE_INT4MM(half, 32); + INSTANTIATE_INT4MM(float, 64); + INSTANTIATE_INT4MM(half, 64); + INSTANTIATE_INT4MM(float, 128); + INSTANTIATE_INT4MM(half, 128); + INSTANTIATE_INT4MM(float, 256); + INSTANTIATE_INT4MM(half, 256); + #if __METAL_VERSION__ >= 310 + INSTANTIATE_INT4MM(bfloat, 32); + INSTANTIATE_INT4MM(bfloat, 64); + INSTANTIATE_INT4MM(bfloat, 128); + INSTANTIATE_INT4MM(bfloat, 256); + #endif + + /** + * qmv_fast.metal + */ + + // Copyright (c) Meta Platforms, Inc. and affiliates. + // All rights reserved. + // + // This source code is licensed under the BSD 3-Clause license found in the + // LICENSE file in the root directory of this source tree. + + /* + This code was taken from MLX, and modified to add support for 1, 5 & 7 bit packing. + The original code is Copyright © 2023-2024 Apple Inc. 
+ https://github.com/ml-explore/mlx/blob/481349495b8c3d094eb699e678077bbe1406392d/mlx/backend/metal/kernels/quantized.h#L1 + MLX MIT License: https://github.com/ml-explore/mlx/blob/main/LICENSE + */ + + #include + #include + + static constant constexpr const int SIMD_SIZE = 32; + + template + inline U load_vector(constant T* x, thread U* x_thread) { + static_assert( + 1 <= bits && bits <= 7, + "Template undefined for bits not in {1, 2, 3, 4, 5, 6, 7}"); + + U sum = 0; + + if (bits == 1) { + for (int i = 0; i < values_per_thread; i += 8) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + + x[i + 6] + x[i + 7]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 2.0f; + x_thread[i + 2] = x[i + 2] / 4.0f; + x_thread[i + 3] = x[i + 3] / 8.0f; + x_thread[i + 4] = x[i + 4] / 16.0f; + x_thread[i + 5] = x[i + 5] / 32.0f; + x_thread[i + 6] = x[i + 6] / 64.0f; + x_thread[i + 7] = x[i + 7] / 128.0f; + } + } + + else if (bits == 2) { + for (int i = 0; i < values_per_thread; i += 4) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 4.0f; + x_thread[i + 2] = x[i + 2] / 16.0f; + x_thread[i + 3] = x[i + 3] / 64.0f; + } + } + + else if (bits == 3) { + for (int i = 0; i < values_per_thread; i += 8) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + + x[i + 6] + x[i + 7]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 8.0f; + x_thread[i + 2] = x[i + 2] / 64.0f; + x_thread[i + 3] = x[i + 3] / 2.0f; + x_thread[i + 4] = x[i + 4] / 16.0f; + x_thread[i + 5] = x[i + 5] / 128.0f; + x_thread[i + 6] = x[i + 6] / 4.0f; + x_thread[i + 7] = x[i + 7] / 32.0f; + } + } + + else if (bits == 4) { + for (int i = 0; i < values_per_thread; i += 4) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 16.0f; + x_thread[i + 2] = x[i + 2] / 256.0f; + x_thread[i + 3] = x[i + 3] / 4096.0f; + } + } + + else if (bits == 5) { + for (int i = 0; i < values_per_thread; i += 8) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + + x[i + 6] + x[i + 7]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 32.0f; + x_thread[i + 2] = x[i + 2] / 4.0f; + x_thread[i + 3] = x[i + 3] / 128.0f; + x_thread[i + 4] = x[i + 4] / 16.0f; + x_thread[i + 5] = x[i + 5] / 2.0f; + x_thread[i + 6] = x[i + 6] / 64.0f; + x_thread[i + 7] = x[i + 7] / 8.0f; + } + } + + else if (bits == 6) { + for (int i = 0; i < values_per_thread; i += 4) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 64.0f; + x_thread[i + 2] = x[i + 2] / 16.0f; + x_thread[i + 3] = x[i + 3] / 4.0f; + } + } + + else if (bits == 7) { + for (int i = 0; i < values_per_thread; i += 8) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + + x[i + 6] + x[i + 7]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 128.0f; + x_thread[i + 2] = x[i + 2] / 64.0f; + x_thread[i + 3] = x[i + 3] / 32.0f; + x_thread[i + 4] = x[i + 4] / 16.0f; + x_thread[i + 5] = x[i + 5] / 8.0f; + x_thread[i + 6] = x[i + 6] / 4.0f; + x_thread[i + 7] = x[i + 7] / 2.0f; + } + } + + return sum; + } + + template + inline U qdot( + constant uint8_t* w, + const thread U* x_thread, + U scale, + U bias, + U sum) { + static_assert( + 1 <= bits && bits <= 7, + "Template undefined for bits not in {1, 2, 3, 4, 5, 6, 7}"); + + U accum = 0; + + if (bits == 1) { + for (int i = 0; i < (values_per_thread / 8); i++) { + x_thread += 8 * i; + + accum += + (x_thread[0] * (w[i] & 0x01) + + x_thread[1] * 
(w[i] & 0x02) + + x_thread[2] * (w[i] & 0x04) + + x_thread[3] * (w[i] & 0x08) + + x_thread[4] * (w[i] & 0x10) + + x_thread[5] * (w[i] & 0x20) + + x_thread[6] * (w[i] & 0x40) + + x_thread[7] * (w[i] & 0x80)); + } + } + + else if (bits == 2) { + for (int i = 0; i < (values_per_thread / 4); i++) { + accum += + (x_thread[4 * i] * (w[i] & 0x03) + + x_thread[4 * i + 1] * (w[i] & 0x0c) + + x_thread[4 * i + 2] * (w[i] & 0x30) + + x_thread[4 * i + 3] * (w[i] & 0xc0)); + } + } + + else if (bits == 3) { + for (int i = 0; i < (values_per_thread / 8); i++) { + x_thread += 8 * i; + w += 3 * i; + + accum += (w[0] & 0x07) * x_thread[0]; + accum += (w[0] & 0x38) * x_thread[1]; + accum += (w[0] & 0xc0) * x_thread[2]; + accum += (w[1] & 0x01) * (x_thread[2] * 256.0f); + + accum += (w[1] & 0x0e) * x_thread[3]; + accum += (w[1] & 0x70) * x_thread[4]; + accum += (w[1] & 0x80) * x_thread[5]; + accum += (w[2] & 0x03) * (x_thread[5] * 256.0f); + + accum += (w[2] & 0x1c) * x_thread[6]; + accum += (w[2] & 0xe0) * x_thread[7]; + } + } + + else if (bits == 4) { + constant uint16_t* ws = (constant uint16_t*)w; + for (int i = 0; i < (values_per_thread / 4); i++) { + accum += + (x_thread[4 * i] * (ws[i] & 0x000f) + + x_thread[4 * i + 1] * (ws[i] & 0x00f0) + + x_thread[4 * i + 2] * (ws[i] & 0x0f00) + + x_thread[4 * i + 3] * (ws[i] & 0xf000)); + } + } + + else if (bits == 5) { + for (int i = 0; i < (values_per_thread / 8); i++) { + x_thread += 8 * i; + w += 5 * i; + + accum += (w[0] & 0x1f) * x_thread[0]; + accum += (w[0] & 0xe0) * x_thread[1]; + + accum += (w[1] & 0x03) * (x_thread[1] * 256.0f); + accum += (w[1] & 0x7c) * x_thread[2]; + accum += (w[1] & 0x80) * x_thread[3]; + + accum += (w[2] & 0x0f) * (x_thread[3] * 256.0f); + accum += (w[2] & 0xf0) * x_thread[4]; + + accum += (w[3] & 0x01) * (x_thread[4] * 256.0f); + accum += (w[3] & 0x3e) * x_thread[5]; + accum += (w[3] & 0xc0) * x_thread[6]; + + accum += (w[4] & 0x07) * (x_thread[6] * 256.0f); + accum += (w[4] & 0xf8) * x_thread[7]; + } + } + + else if (bits == 6) { + for (int i = 0; i < (values_per_thread / 4); i++) { + x_thread += 4 * i; + w += 3 * i; + + accum += (w[0] & 0x3f) * x_thread[0]; + + accum += (w[0] & 0xc0) * x_thread[1]; + accum += (w[1] & 0x0f) * (x_thread[1] * 256.0f); + + accum += (w[1] & 0xf0) * x_thread[2]; + accum += (w[2] & 0x03) * (x_thread[2] * 256.0f); + + accum += (w[2] & 0xfc) * x_thread[3]; + } + } + + else if (bits == 7) { + for (int i = 0; i < (values_per_thread / 8); i++) { + x_thread += 8 * i; + w += 7 * i; + + accum += (w[0] & 0x7f) * x_thread[0]; + accum += (w[0] & 0x80) * x_thread[1]; + + accum += (w[1] & 0x3f) * (x_thread[1] * 256.0f); + accum += (w[1] & 0xc0) * x_thread[2]; + + accum += (w[2] & 0x1f) * (x_thread[2] * 256.0f); + accum += (w[2] & 0xe0) * x_thread[3]; + + accum += (w[3] & 0x0f) * (x_thread[3] * 256.0f); + accum += (w[3] & 0xf0) * x_thread[4]; + + accum += (w[4] & 0x07) * (x_thread[4] * 256.0f); + accum += (w[4] & 0xf8) * x_thread[5]; + + accum += (w[5] & 0x03) * (x_thread[5] * 256.0f); + accum += (w[5] & 0xfc) * x_thread[6]; + + accum += (w[6] & 0x01) * (x_thread[6] * 256.0f); + accum += (w[6] & 0xfe) * x_thread[7]; + } + } + + return scale * accum + sum * bias; + } + + template + [[kernel]] void qmv_fast( + constant T* x [[buffer(0)]], + constant uchar* w [[buffer(1)]], + constant T* scales [[buffer(2)]], + constant T* biases [[buffer(3)]], + device T* y [[buffer(4)]], + constant uint3 &sizes [[buffer(5)]], // M, K, N + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + 
uint simd_lid [[thread_index_in_simdgroup]]) { + const int in_vec_size = static_cast(sizes.y); // K + const int out_vec_size = static_cast(sizes.z); // N + + constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; + constexpr int packs_per_thread = (bits == 1 || bits == 2) ? 1 : 2; + constexpr int num_simdgroups = 2; + constexpr int results_per_simdgroup = 4; + constexpr int pack_factor = bits == 1 ? 16 : power_of_2_bits ? 32 / bits : bits == 6 ? 4 : 8; + constexpr int bytes_per_pack = bits == 1 ? 2 : power_of_2_bits ? 4 : bits == 6 ? 3 : bits; + constexpr int values_per_thread = pack_factor * packs_per_thread; + constexpr int block_size = values_per_thread * SIMD_SIZE; + constexpr int scale_step_per_thread = group_size / values_per_thread; + + constant uint8_t* ws = (constant uint8_t*)w; + + typedef float U; + + thread U x_thread[values_per_thread]; + thread U result[results_per_simdgroup] = {0}; + + // Adjust positions + const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor; + const int in_vec_size_g = in_vec_size / group_size; + const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + + simd_gid * results_per_simdgroup; + + ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack; + scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; + biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; + x += tid.x * in_vec_size + simd_lid * values_per_thread; + y += tid.x * out_vec_size + out_row; + + for (int k = 0; k < in_vec_size; k += block_size) { + U sum = load_vector(x, x_thread); + + for (int row = 0; row < results_per_simdgroup; row++) { + auto wl = (constant uint8_t*)(ws + row * in_vec_size_w); + constant T* sl = scales + row * in_vec_size_g; + constant T* bl = biases + row * in_vec_size_g; + + U s = sl[0]; + U b = bl[0]; + result[row] += qdot(wl, x_thread, s, b, sum); + } + + ws += block_size * bytes_per_pack / pack_factor; + scales += block_size / group_size; + biases += block_size / group_size; + x += block_size; + } + + for (int row = 0; row < results_per_simdgroup; row++) { + result[row] = simd_sum(result[row]); + if (simd_lid == 0) { + y[row] = static_cast(result[row]); + } + } + } + + #define INSTANTIATE_QMV_FAST(DTYPE, GSIZE, NBIT) \ + template [[host_name("qmv_fast_" #NBIT "bit_" #GSIZE "_" #DTYPE)]] kernel void \ + qmv_fast( \ + constant DTYPE * A [[buffer(0)]], \ + constant uchar * B [[buffer(1)]], \ + constant DTYPE * scales_ptr [[buffer(2)]], \ + constant DTYPE * zeros_ptr [[buffer(3)]], \ + device DTYPE * output_data [[buffer(4)]], \ + constant uint3 & sizes [[buffer(5)]], \ + uint3 thread_index [[thread_position_in_grid]], \ + uint simd_gid [[simdgroup_index_in_threadgroup]], \ + uint tid_in_simdgroup [[thread_index_in_simdgroup]]) + + #define INSTANTIATE_QMV_FAST_DTYPE_GSIZE(DTYPE, GSIZE) \ + INSTANTIATE_QMV_FAST(DTYPE, GSIZE, 1); \ + INSTANTIATE_QMV_FAST(DTYPE, GSIZE, 2); \ + INSTANTIATE_QMV_FAST(DTYPE, GSIZE, 3); \ + INSTANTIATE_QMV_FAST(DTYPE, GSIZE, 4); \ + INSTANTIATE_QMV_FAST(DTYPE, GSIZE, 5); \ + INSTANTIATE_QMV_FAST(DTYPE, GSIZE, 6); \ + INSTANTIATE_QMV_FAST(DTYPE, GSIZE, 7); + + #define INSTANTIATE_QMV_FAST_DTYPE(DTYPE) \ + INSTANTIATE_QMV_FAST_DTYPE_GSIZE(DTYPE, 32); \ + INSTANTIATE_QMV_FAST_DTYPE_GSIZE(DTYPE, 64); \ + INSTANTIATE_QMV_FAST_DTYPE_GSIZE(DTYPE, 128); \ + INSTANTIATE_QMV_FAST_DTYPE_GSIZE(DTYPE, 256); + + INSTANTIATE_QMV_FAST_DTYPE(float); + INSTANTIATE_QMV_FAST_DTYPE(half); + #if __METAL_VERSION__ >= 310 + INSTANTIATE_QMV_FAST_DTYPE(bfloat); + #endif + + )"; +} + +// 
Global shader library cache for SDPA +static std::unique_ptr int4_shader_library = nullptr; + +static std::once_flag int4_shader_library_once_flag; + +static ETMetalShaderLibrary* get_int4_shader_library() { + std::call_once(int4_shader_library_once_flag, []() { + std::string source = get_int4_metal_source(); + int4_shader_library = std::make_unique(source); + }); + return int4_shader_library.get(); +} + } // anonymous namespace extern "C" { @@ -1844,6 +2454,240 @@ AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps( } } +AOTITorchError aoti_torch_mps__linear_fp_act_4bit_weight( + AOTITensorHandle A, + AOTITensorHandle B, + int64_t group_size, + AOTITensorHandle S, + AOTITensorHandle Z, + AOTITensorHandle* ret) { + + ET_LOG(Debug, "aoti_torch_mps__linear_fp_act_4bit_weight: Starting with Metal kernel implementation"); + + if (!A || !B || !S || !Z || !ret) { + ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: null required tensor handles"); + return Error::InvalidArgument; + } + + // Validate group_size + if (group_size != 32 && group_size != 64 && group_size != 128 && group_size != 256) { + ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: Invalid group_size %lld (must be 32, 64, 128, or 256)", group_size); + return Error::InvalidArgument; + } + + // Use the same dispatch pattern as other MPS operations for consistent synchronization + ETMetalStream* stream = getCurrentMetalStream(); + if (!stream) { + ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: Failed to get current Metal stream"); + return Error::Internal; + } + + try { + @autoreleasepool { + // Convert AOTITensorHandle to ExecutorTorch tensors + auto* a_tensor = reinterpret_cast(A); // Activation: [M, K] + auto* b_tensor = reinterpret_cast(B); // Weight (packed): [N, K/2] (4-bit packed) + auto* s_tensor = reinterpret_cast(S); // Scales: [N, num_groups] + auto* z_tensor = reinterpret_cast(Z); // Zero points: [N, num_groups] + + ET_LOG(Debug, "aoti_torch_mps__linear_fp_act_4bit_weight: Converted tensor handles to ET tensors"); + + // Log tensor shapes for debugging + ET_LOG(Debug, "aoti_torch_mps__linear_fp_act_4bit_weight: A shape: [%d, %d], strides: [%d, %d]", + a_tensor->dim() > 0 ? (int)a_tensor->sizes()[0] : 0, + a_tensor->dim() > 1 ? (int)a_tensor->sizes()[1] : 0, + a_tensor->dim() > 0 ? (int)a_tensor->strides()[0] : 0, + a_tensor->dim() > 1 ? (int)a_tensor->strides()[1] : 0); + + ET_LOG(Debug, "aoti_torch_mps__linear_fp_act_4bit_weight: B shape: [%d, %d]", + b_tensor->dim() > 0 ? (int)b_tensor->sizes()[0] : 0, + b_tensor->dim() > 1 ? (int)b_tensor->sizes()[1] : 0); + + ET_LOG(Debug, "aoti_torch_mps__linear_fp_act_4bit_weight: S shape: [%d, %d], Z shape: [%d, %d]", + s_tensor->dim() > 0 ? (int)s_tensor->sizes()[0] : 0, + s_tensor->dim() > 1 ? (int)s_tensor->sizes()[1] : 0, + z_tensor->dim() > 0 ? (int)z_tensor->sizes()[0] : 0, + z_tensor->dim() > 1 ? 
(int)z_tensor->sizes()[1] : 0); + + // Validate tensor dimensions + if (a_tensor->dim() != 2) { + ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: A tensor must be 2-D, got %d", (int)a_tensor->dim()); + return Error::InvalidArgument; + } + if (b_tensor->dim() != 2) { + ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: B tensor must be 2-D, got %d", (int)b_tensor->dim()); + return Error::InvalidArgument; + } + + // Get dimensions: A is [M, K], B is [N, K/2] (4-bit packed, 2 values per byte) + int32_t M = static_cast(a_tensor->sizes()[0]); + int32_t K = static_cast(a_tensor->sizes()[1]); + int32_t N = static_cast(b_tensor->sizes()[0]); + + ET_LOG(Debug, "aoti_torch_mps__linear_fp_act_4bit_weight: M=%d, K=%d, N=%d, group_size=%lld", M, K, N, group_size); + + // Validate alignment requirements + if (K % 8 != 0) { + ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: K (%d) must be divisible by 8", K); + return Error::InvalidArgument; + } + if (N % 4 != 0) { + ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: N (%d) must be divisible by 4", N); + return Error::InvalidArgument; + } + + // Determine data type + int32_t dtype = static_cast(a_tensor->scalar_type()); + size_t element_size; + std::string type_str; + + if (dtype == static_cast(SupportedDTypes::FLOAT32)) { + element_size = sizeof(float); + type_str = "float"; + } else if (dtype == static_cast(SupportedDTypes::BFLOAT16)) { + element_size = sizeof(uint16_t); + type_str = "bfloat"; + } else { + ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: Unsupported data type: %d", dtype); + return Error::InvalidArgument; + } + + // Get shader library + ETMetalShaderLibrary* library = get_int4_shader_library(); + if (!library) { + ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: Failed to get shader library"); + return Error::Internal; + } + + // Select kernel based on dimensions (matching torchao's get_shader_func_and_dispatch) + std::string kernel_name; + bool use_qmv_fast = (M == 1 && N % 8 == 0 && K % 512 == 0); + + if (use_qmv_fast) { + // Use optimized qmv_fast kernel for M=1 case + kernel_name = "qmv_fast_4bit_" + std::to_string(group_size) + "_" + type_str; + ET_LOG(Debug, "aoti_torch_mps__linear_fp_act_4bit_weight: Using qmv_fast kernel: %s", kernel_name.c_str()); + } else { + // Use general int4pack_mm kernel + kernel_name = "int4pack_mm_" + std::to_string(group_size) + "_" + type_str; + ET_LOG(Debug, "aoti_torch_mps__linear_fp_act_4bit_weight: Using int4pack_mm kernel: %s", kernel_name.c_str()); + } + + // Get kernel function + auto kernel_func = library->getKernelFunction(kernel_name); + if (!kernel_func) { + ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: Failed to get kernel function: %s", kernel_name.c_str()); + return Error::Internal; + } + + // Allocate output tensor [M, N] + size_t out_size_bytes = M * N * element_size; + void* out_contents_ptr = nullptr; + id out_buffer = allocate_mtl_buffer(&out_contents_ptr, out_size_bytes); + + // Create output tensor handle + std::vector output_sizes = {M, N}; + std::vector output_strides = {N, 1}; + + AOTITensorHandle out_tensor_handle = nullptr; + AOTITorchError create_out_result = aoti_torch_create_tensor_from_blob_v2( + out_contents_ptr, + 2, // ndim + output_sizes.data(), + output_strides.data(), + 0, // storage_offset + dtype, + 13, // device_type (MPS) + 0, // device_index + &out_tensor_handle, + 0, // layout (strided) + nullptr, // opaque_metadata + 0 // opaque_metadata_size + ); + + if (create_out_result != Error::Ok || 
!out_tensor_handle) { + ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: Failed to create output tensor"); + aoti_torch_mps_free(out_contents_ptr); + return Error::Internal; + } + + // Mark that we own the memory + extern std::unordered_map memory_to_n_tensor; + memory_to_n_tensor[out_contents_ptr] = 1; + + auto* out_tensor = reinterpret_cast(out_tensor_handle); + + // Prepare sizes array for the kernel (M, K, N, 0) + std::array sizes = { + static_cast(M), + static_cast(K), + static_cast(N), + 0 + }; + + // Execute kernel + ET_LOG(Debug, "aoti_torch_mps__linear_fp_act_4bit_weight: Preparing to execute kernel"); + + kernel_func->runCommandBlock([&]() { + kernel_func->startEncoding(); + + ET_LOG(Debug, "aoti_torch_mps__linear_fp_act_4bit_weight: Encoder started, setting arguments"); + + // Set buffer arguments + // Buffer 0: A (activation) [M, K] + kernel_func->setArg(0, *a_tensor); + // Buffer 1: B (weight, packed) [N, K/2] + kernel_func->setArg(1, *b_tensor); + // Buffer 2: scales [N, num_groups] + kernel_func->setArg(2, *s_tensor); + // Buffer 3: zeros [N, num_groups] + kernel_func->setArg(3, *z_tensor); + // Buffer 4: output [M, N] + kernel_func->setArg(4, *out_tensor); + // Buffer 5: sizes (M, K, N, 0) + kernel_func->setArg(5, sizes.data(), sizeof(uint32_t) * sizes.size()); + + ET_LOG(Debug, "aoti_torch_mps__linear_fp_act_4bit_weight: All arguments set, dispatching"); + + // Dispatch based on kernel type (matching torchao dispatch patterns) + if (use_qmv_fast) { + // dispatch_qmv_fast: dispatchThreadgroups with grid (M, (N+7)/8, 1), group (32, 2, 1) + kernel_func->dispatchThreadgroups( + M, // gridX + (N + 7) / 8, // gridY + 1, // gridZ + 32, // threadsX + 2, // threadsY + 1); // threadsZ + } else { + // dispatch_mm_Mr1xNr4_per_TG: dispatchThreads with grid (N/4 * 32, 1, M), group (32, 1, 1) + uint64_t grid_dims[3] = {static_cast(N / 4 * 32), 1, static_cast(M)}; + uint64_t group_dims[3] = {32, 1, 1}; + kernel_func->dispatchArrayWithGroupSize(grid_dims, 3, group_dims, 3); + } + }); + + ET_LOG(Debug, "aoti_torch_mps__linear_fp_act_4bit_weight: Command block completed"); + + // Set output tensor handle + *ret = out_tensor_handle; + + ET_LOG(Debug, "aoti_torch_mps__linear_fp_act_4bit_weight: Metal kernel implementation completed successfully"); + + } // @autoreleasepool + + ET_LOG(Debug, "aoti_torch_mps__linear_fp_act_4bit_weight: Executed successfully"); + return Error::Ok; + + } catch (const std::exception& e) { + ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight exception: %s", e.what()); + return Error::Internal; + } catch (...) { + ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: unknown exception"); + return Error::Internal; + } +} + } // extern "C" } // namespace metal diff --git a/backends/apple/metal/runtime/shims/utils.cpp b/backends/apple/metal/runtime/shims/utils.cpp index 061360a4e28..50b46ec69d4 100644 --- a/backends/apple/metal/runtime/shims/utils.cpp +++ b/backends/apple/metal/runtime/shims/utils.cpp @@ -19,6 +19,7 @@ extern "C" { // Helper function to check if a dtype is supported in Metal backend bool is_dtype_supported_in_et_metal(int32_t dtype) { switch (dtype) { + case static_cast(SupportedDTypes::UINT8): case static_cast(SupportedDTypes::INT64): case static_cast(SupportedDTypes::FLOAT32): case static_cast(SupportedDTypes::BFLOAT16): @@ -36,8 +37,9 @@ AOTITorchError validate_dtype(int32_t dtype) { ET_LOG( Error, - "Unsupported dtype: %d. Supported dtypes: %d (int64), %d (float32), %d (bfloat16)", + "Unsupported dtype: %d. 
Supported dtypes: %d (uint8), %d (int64), %d (float32), %d (bfloat16)", dtype, + static_cast(SupportedDTypes::UINT8), static_cast(SupportedDTypes::INT64), static_cast(SupportedDTypes::FLOAT32), static_cast(SupportedDTypes::BFLOAT16)); diff --git a/backends/apple/metal/runtime/shims/utils.h b/backends/apple/metal/runtime/shims/utils.h index 974832fa365..2ac5d1d5857 100644 --- a/backends/apple/metal/runtime/shims/utils.h +++ b/backends/apple/metal/runtime/shims/utils.h @@ -19,7 +19,7 @@ namespace metal { // Enum for supported data types in et-metal backend enum class SupportedDTypes : int32_t { - // UINT8 = 0, // PyTorch's uint8 dtype code + UINT8 = 0, // PyTorch's uint8 dtype code // INT8 = 1, // PyTorch's int8 dtype code // INT16 = 2, // PyTorch's int16 dtype code // INT32 = 3, // PyTorch's int32 dtype code From f166c5040a815fe4757b54dcce702e98b758a905 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Mon, 2 Feb 2026 16:27:27 -0500 Subject: [PATCH 09/26] Update [ghstack-poisoned] --- backends/apple/metal/tests/test_modules.py | 178 ++++++++++++++++++++- 1 file changed, 171 insertions(+), 7 deletions(-) diff --git a/backends/apple/metal/tests/test_modules.py b/backends/apple/metal/tests/test_modules.py index 1828545c8a0..b2e91494c89 100644 --- a/backends/apple/metal/tests/test_modules.py +++ b/backends/apple/metal/tests/test_modules.py @@ -31,6 +31,17 @@ from torch.export import export from torch.nn.attention import SDPBackend +try: + from torchao.quantization.quant_api import quantize_ + from torchao.experimental.quant_api import UIntxWeightOnlyConfig + + # Need to import to load the ops + import torchao.experimental.ops.mps # noqa: F401 + + TORCHAO_AVAILABLE = True +except ImportError: + TORCHAO_AVAILABLE = False + # Check if MPS is available for export tests MPS_AVAILABLE = torch.backends.mps.is_available() @@ -88,6 +99,31 @@ # - "rtol_": float - Override relative tolerance for specific dtype (e.g., "rtol_bfloat16") # - "skip": bool or str - Skip all tests for this module (True to skip, or string with reason) # - "skip_": bool or str - Skip tests for specific dtype (e.g., "skip_bfloat16") +# - "qlinear": str - Quantization config for linear layers (e.g., "fpa4w" for 4-bit weights) +# - "qlinear_group_size": int - Group size for quantization (default: 32) +# - "compare_to_unquantized": bool - If True, compare quantized model output to unquantized reference (default: True for quantized models) +# +# Quantization Usage: +# To enable int4 quantization for a module, add "qlinear": "fpa4w" to its registry entry. +# This applies 4-bit weight quantization (floating point activation, 4-bit weight) using torchao. +# The quantization is applied after converting the model to the specified dtype but before export. +# +# By default, quantized models are compared against unquantized reference models to measure +# the actual quantization error. Set "compare_to_unquantized": False to compare against +# the quantized PyTorch model instead. 
+# +# Example: +# MODULE_REGISTRY["my_linear_model"] = { +# "model_class": MyLinearModel, +# "input_shapes": [(128, 256)], +# "description": "My linear model with int4 quantization", +# "qlinear": "fpa4w", +# "qlinear_group_size": 32, +# "compare_to_unquantized": True, # Compare to unquantized reference +# "atol_float32": 5e-2, # Quantization reduces precision, so increase tolerance +# "rtol_float32": 5e-2, +# } +# MODULE_REGISTRY: Dict[str, Dict[str, Any]] = {} @@ -180,6 +216,31 @@ def forward(self, x: torch.Tensor): } +# ------------------------------------------------------------------------- +class LinearNoBiasInt4(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(128, 256, bias=False) + + def forward(self, x: torch.Tensor): + return self.linear(x) + + +MODULE_REGISTRY["linear_nobias_int4"] = { + "model_class": LinearNoBiasInt4, + "input_shapes": [(127, 128)], + "description": "Linear layer without bias and int4 quantization", + "qlinear": "fpa4w", + "qlinear_group_size": 32, + "compare_to_unquantized": False, + "atol_float32": 5e-2, + "rtol_float32": 5e-2, + "atol_bfloat16": 1e-1, + "rtol_bfloat16": 1e-1, + "skip": not TORCHAO_AVAILABLE, +} + + # ------------------------------------------------------------------------- # Convolution Modules # ------------------------------------------------------------------------- @@ -466,9 +527,24 @@ def should_skip_model(model_name: str, dtype: torch.dtype) -> Tuple[bool, str]: def get_model_and_inputs( - model_name: str, dtype: torch.dtype = torch.float32 + model_name: str, + dtype: torch.dtype = torch.float32, + qlinear: Optional[str] = None, + qlinear_group_size: Optional[int] = None, ) -> Tuple[nn.Module, Tuple[torch.Tensor, ...]]: - """Get model and example inputs based on model name.""" + """Get model and example inputs based on model name. + + Args: + model_name: Name of the model to create + dtype: Data type for the model (default: torch.float32) + qlinear: Optional quantization config (e.g., "fpa4w" for 4-bit weights). + If None, uses value from MODULE_REGISTRY if present. + qlinear_group_size: Group size for quantization. If None, uses value from + MODULE_REGISTRY if present, otherwise defaults to 32. + + Returns: + Tuple of (model, example_inputs) + """ if model_name not in MODULE_REGISTRY: available_models = ", ".join(MODULE_REGISTRY.keys()) raise ValueError( @@ -479,15 +555,58 @@ def get_model_and_inputs( model_class = model_config["model_class"] input_shapes = model_config["input_shapes"] + # Use registry values if not explicitly provided + if qlinear is None: + qlinear = model_config.get("qlinear") + if qlinear_group_size is None: + qlinear_group_size = model_config.get("qlinear_group_size", 32) + model = model_class().eval() if dtype is not None: model = model.to(dtype) + # Apply quantization if requested + if qlinear is not None: + quantize_model(model, qlinear, qlinear_group_size) + example_inputs = tuple(torch.randn(*shape, dtype=dtype) for shape in input_shapes) return model, example_inputs +def quantize_model(model: nn.Module, qlinear: str, qlinear_group_size: int = 32): + """Apply quantization to the model's linear layers. + + Args: + model: The model to quantize (in-place). + qlinear: Quantization config. Options: + - "fpa4w": Floating point activation, 4-bit weight (Metal backend) + qlinear_group_size: Group size for quantization (default: 32). + """ + if not TORCHAO_AVAILABLE: + raise RuntimeError( + "torchao is not available. Install torchao to use quantization." 
+ ) + + if qlinear == "fpa4w": + linear_config = UIntxWeightOnlyConfig( + group_size=qlinear_group_size, + bitwidth=4, + ) + else: + raise ValueError(f"Unsupported linear quantization config '{qlinear}'.") + + def linear_filter(module, fqn): + if isinstance(module, torch.nn.Linear): + # Check if hidden dimension is divisible by group size + return qlinear_group_size == 0 or ( + module.weight.shape[1] % qlinear_group_size == 0 + ) + return False + + quantize_(model, linear_config, filter_fn=linear_filter) + + def export_model_to_metal( model: nn.Module, example_inputs: Tuple[torch.Tensor, ...] ) -> Any: @@ -515,17 +634,50 @@ def export_model_to_files( example_inputs: Tuple[torch.Tensor, ...], output_dir: Path, model_name: str, -) -> Tuple[Path, Path, torch.Tensor]: + compare_to_unquantized: bool = False, + model_config: Optional[Dict[str, Any]] = None, +) -> Tuple[Path, torch.Tensor]: """ - Export model to .pte and .ptd files, and compute expected output. + Export model to .pte file and compute expected output. + + Args: + model: Model to export (may be quantized) + example_inputs: Example inputs for export + output_dir: Directory to save output files + model_name: Name of the model + compare_to_unquantized: If True and model has quantization config, + compute expected output from unquantized model + model_config: Model configuration from MODULE_REGISTRY Returns: - Tuple of (pte_path, ptd_path, expected_output) + Tuple of (pte_path, expected_output) """ # Compute expected output using all-ones input (matching export_aoti_metal.py) all_ones_input = tuple(torch.ones_like(inp) for inp in example_inputs) + with torch.no_grad(): - expected_output = model(*all_ones_input) + if compare_to_unquantized and model_config and model_config.get("qlinear"): + # Create unquantized reference model for comparison + dtype = example_inputs[0].dtype if example_inputs else torch.float32 + model_class = model_config["model_class"] + reference_model = model_class().eval() + reference_model = reference_model.to(dtype) + expected_output = reference_model(*all_ones_input) + else: + # Use the quantized model's output + # For quantized models, torchao operators require MPS device + if model_config and model_config.get("qlinear") and MPS_AVAILABLE: + # Move model and inputs to MPS + model_mps = model.to("mps") + all_ones_input_mps = tuple(inp.to("mps") for inp in all_ones_input) + expected_output = model_mps(*all_ones_input_mps) + # Move output back to CPU for comparison + expected_output = expected_output.cpu() + # Move model back to CPU for export + model = model_mps.to("cpu") + else: + # Non-quantized model, run on CPU + expected_output = model(*all_ones_input) # Export to executorch executorch_program = export_model_to_metal(model, example_inputs) @@ -765,6 +917,7 @@ def _test_module_output_consistency( model, example_inputs = get_model_and_inputs(model_name, dtype=dtype) dtype_name = DTYPE_NAMES[dtype] test_subdir_name = f"{model_name}_{dtype_name}" + model_config = MODULE_REGISTRY.get(model_name, {}) def run_test_in_directory(test_dir: Path) -> None: """Run the actual test logic in the given directory.""" @@ -772,9 +925,20 @@ def run_test_in_directory(test_dir: Path) -> None: model_output_dir = test_dir / test_subdir_name model_output_dir.mkdir(parents=True, exist_ok=True) + # Determine if we should compare to unquantized reference + # Default to True for quantized models, False otherwise + compare_to_unquantized = model_config.get( + "compare_to_unquantized", bool(model_config.get("qlinear")) + ) + # Export model 
and get expected output pte_path, expected_output = export_model_to_files( - model, example_inputs, model_output_dir, model_name + model, + example_inputs, + model_output_dir, + model_name, + compare_to_unquantized=compare_to_unquantized, + model_config=model_config, ) self.assertTrue( From 08346599720e32ec8d58bc5a7050603694c97003 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Mon, 2 Feb 2026 17:26:39 -0500 Subject: [PATCH 10/26] Update [ghstack-poisoned] --- backends/apple/metal/tests/test_modules.py | 38 ++++++++++++++++++---- 1 file changed, 31 insertions(+), 7 deletions(-) diff --git a/backends/apple/metal/tests/test_modules.py b/backends/apple/metal/tests/test_modules.py index 6a5eaeb9a53..59904bb494d 100644 --- a/backends/apple/metal/tests/test_modules.py +++ b/backends/apple/metal/tests/test_modules.py @@ -88,6 +88,14 @@ # - "rtol_": float - Override relative tolerance for specific dtype (e.g., "rtol_bfloat16") # - "skip": bool or str - Skip all tests for this module (True to skip, or string with reason) # - "skip_": bool or str - Skip tests for specific dtype (e.g., "skip_bfloat16") +# +# Model Parameter Initialization: +# Model parameters are initialized with their default dtype (typically float32) when the +# model class is instantiated. The parameters are then converted to the target dtype using +# model.to(dtype). For example: +# - nn.Parameter(torch.arange(20, dtype=torch.get_default_dtype()) creates float32 parameters +# - These are converted to bfloat16 when model.to(torch.bfloat16) is called +# MODULE_REGISTRY: Dict[str, Dict[str, Any]] = {} @@ -129,7 +137,9 @@ def forward(self, x: torch.Tensor, y: torch.Tensor): class MmWeightParam(nn.Module): def __init__(self): super().__init__() - self.weight = nn.Parameter(torch.arange(20, dtype=torch.float).reshape(4, 5)) + self.weight = nn.Parameter( + torch.arange(20, dtype=torch.get_default_dtype()).reshape(4, 5) + ) def forward(self, x: torch.Tensor): return x.mm(self.weight) @@ -451,7 +461,18 @@ def should_skip_model(model_name: str, dtype: torch.dtype) -> Tuple[bool, str]: def get_model_and_inputs( model_name: str, dtype: torch.dtype = torch.float32 ) -> Tuple[nn.Module, Tuple[torch.Tensor, ...]]: - """Get model and example inputs based on model name.""" + """Get model and example inputs based on model name. + + Note: Model parameters are initialized with their default dtype (typically float32) + during model instantiation, then converted to the target dtype using model.to(dtype). + + Args: + model_name: Name of the model to create + dtype: Target data type for the model (default: torch.float32) + + Returns: + Tuple of (model, example_inputs) + """ if model_name not in MODULE_REGISTRY: available_models = ", ".join(MODULE_REGISTRY.keys()) raise ValueError( @@ -462,7 +483,10 @@ def get_model_and_inputs( model_class = model_config["model_class"] input_shapes = model_config["input_shapes"] + # Create model with default parameter dtypes (typically float32) model = model_class().eval() + + # Convert model parameters to target dtype if specified if dtype is not None: model = model.to(dtype) @@ -493,17 +517,17 @@ def export_model_to_metal( return executorch_program -def export_model_to_files( +def export_model_to_pte( model: nn.Module, example_inputs: Tuple[torch.Tensor, ...], output_dir: Path, model_name: str, -) -> Tuple[Path, Path, torch.Tensor]: +) -> Tuple[Path, torch.Tensor]: """ - Export model to .pte and .ptd files, and compute expected output. + Export model to .pte file, and compute expected output. 
Returns: - Tuple of (pte_path, ptd_path, expected_output) + Tuple of (pte_path, expected_output) """ # Compute expected output using all-ones input (matching export_aoti_metal.py) all_ones_input = tuple(torch.ones_like(inp) for inp in example_inputs) @@ -756,7 +780,7 @@ def run_test_in_directory(test_dir: Path) -> None: model_output_dir.mkdir(parents=True, exist_ok=True) # Export model and get expected output - pte_path, expected_output = export_model_to_files( + pte_path, expected_output = export_model_to_pte( model, example_inputs, model_output_dir, model_name ) From fe5be37b3021ec4657d004359fd5487c1d45416d Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Mon, 2 Feb 2026 18:34:05 -0500 Subject: [PATCH 11/26] Update [ghstack-poisoned] --- .github/workflows/metal.yml | 1 + backends/apple/metal/tests/run_metal_test.sh | 22 ++++++++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/.github/workflows/metal.yml b/.github/workflows/metal.yml index 50ab0a70e1c..5d06fece457 100644 --- a/.github/workflows/metal.yml +++ b/.github/workflows/metal.yml @@ -45,6 +45,7 @@ jobs: echo "::endgroup::" echo "::group::Build Metal Runtime" + ${CONDA_RUN} backends/apple/metal/tests/run_metal_test.sh --update-ao ${CONDA_RUN} backends/apple/metal/tests/run_metal_test.sh --build echo "::endgroup::" diff --git a/backends/apple/metal/tests/run_metal_test.sh b/backends/apple/metal/tests/run_metal_test.sh index 0f70c20ea4e..488bca6d14c 100755 --- a/backends/apple/metal/tests/run_metal_test.sh +++ b/backends/apple/metal/tests/run_metal_test.sh @@ -8,6 +8,7 @@ # Script to build and run Metal backend tests # Usage: # ./run_metal_test.sh --build # Build the Metal runtime +# ./run_metal_test.sh --update-ao # Update and build torchao with experimental MPS support # ./run_metal_test.sh --run # Run inference with given model file # ./run_metal_test.sh --check-build # Check if runtime is already built @@ -29,6 +30,23 @@ check_build() { fi } +# Function to update and build torchao +update_and_build_torchao() { + echo "Building torchao..." + TORCHAO_DIR="$EXECUTORCH_ROOT/third-party/ao" + if [[ -d "$TORCHAO_DIR" ]]; then + cd "$TORCHAO_DIR" + echo "Pulling latest changes from ao repository..." + git pull origin main + USE_CPP=1 TORCHAO_BUILD_EXPERIMENTAL_MPS=1 pip install . --no-build-isolation + cd "$EXECUTORCH_ROOT" + echo "torchao build complete" + else + echo "Error: torchao directory not found at $TORCHAO_DIR" + exit 1 + fi +} + # Function to build the Metal runtime build_runtime() { echo "Building Metal runtime..." 
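
The `--update-ao` mode above rebuilds torchao from `third-party/ao` with `USE_CPP=1 TORCHAO_BUILD_EXPERIMENTAL_MPS=1` so the experimental MPS quantization pieces become importable. As a quick sanity check after that step, one might run a snippet like the following; it is illustrative only (not part of the patch), and the import paths simply mirror the availability guard used by `test_modules.py` in this series:

```python
# Illustrative only: verify the torchao build exposes the experimental MPS pieces
# that the fpa4w (4-bit weight) Metal tests depend on.
try:
    import torchao.experimental.ops.mps  # noqa: F401  # loads/registers the Metal lowbit ops
    from torchao.experimental.quant_api import UIntxWeightOnlyConfig  # noqa: F401

    print("torchao experimental MPS support: available")
except ImportError as err:
    print(f"torchao experimental MPS support missing: {err}")
```
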
@@ -97,6 +115,9 @@ case "$1" in --build) build_runtime ;; + --update-ao) + update_and_build_torchao + ;; --run) if [[ -z "$2" ]]; then echo "Usage: $0 --run " @@ -112,6 +133,7 @@ case "$1" in echo "" echo "Usage:" echo " $0 --build Build the Metal runtime" + echo " $0 --update-ao Update and build torchao with experimental MPS support" echo " $0 --run Run inference with given model file" echo " $0 --check-build Check if runtime is already built" exit 1 From fcfa8320042cd4c5dd7bd99bf4c947b8d568a725 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Mon, 2 Feb 2026 18:52:44 -0500 Subject: [PATCH 12/26] Update [ghstack-poisoned] --- backends/apple/metal/tests/run_metal_test.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/backends/apple/metal/tests/run_metal_test.sh b/backends/apple/metal/tests/run_metal_test.sh index 488bca6d14c..cc8af8ce2c4 100755 --- a/backends/apple/metal/tests/run_metal_test.sh +++ b/backends/apple/metal/tests/run_metal_test.sh @@ -37,6 +37,7 @@ update_and_build_torchao() { if [[ -d "$TORCHAO_DIR" ]]; then cd "$TORCHAO_DIR" echo "Pulling latest changes from ao repository..." + git checkout main git pull origin main USE_CPP=1 TORCHAO_BUILD_EXPERIMENTAL_MPS=1 pip install . --no-build-isolation cd "$EXECUTORCH_ROOT" From 0145613681c76f2ff5519c3b7f360d282fbeca8d Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Mon, 2 Feb 2026 18:55:04 -0500 Subject: [PATCH 13/26] Update [ghstack-poisoned] --- backends/apple/metal/tests/test_modules.py | 24 ++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/backends/apple/metal/tests/test_modules.py b/backends/apple/metal/tests/test_modules.py index 6dceaa5fc9b..7f12963ace6 100644 --- a/backends/apple/metal/tests/test_modules.py +++ b/backends/apple/metal/tests/test_modules.py @@ -251,6 +251,30 @@ def forward(self, x: torch.Tensor): } +# ------------------------------------------------------------------------- +class LinearWithBiasInt4(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(128, 256, bias=True) + + def forward(self, x: torch.Tensor): + return self.linear(x) + + +MODULE_REGISTRY["linear_bias_int4"] = { + "model_class": LinearWithBiasInt4, + "input_shapes": [(127, 128)], + "description": "Linear layer with bias and int4 quantization", + "qlinear": "fpa4w", + "qlinear_group_size": 32, + "compare_to_unquantized": False, + "atol_float32": 5e-2, + "rtol_float32": 5e-2, + "atol_bfloat16": 1e-1, + "rtol_bfloat16": 1e-1, + "skip": not TORCHAO_AVAILABLE, +} + # ------------------------------------------------------------------------- # Convolution Modules # ------------------------------------------------------------------------- From c5a3c1a2949c52a89f3f6eb0e77fc21dcd62c48d Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Tue, 3 Feb 2026 11:53:18 -0500 Subject: [PATCH 14/26] Update [ghstack-poisoned] --- .github/workflows/metal.yml | 2 +- backends/apple/metal/metal_backend.py | 2 ++ backends/apple/metal/runtime/metal_backend.cpp | 5 ++++- backends/apple/metal/runtime/shims/utils.h | 2 +- 4 files changed, 8 insertions(+), 3 deletions(-) diff --git a/.github/workflows/metal.yml b/.github/workflows/metal.yml index 5d06fece457..bf86b01aff8 100644 --- a/.github/workflows/metal.yml +++ b/.github/workflows/metal.yml @@ -96,7 +96,7 @@ jobs: echo "::endgroup::" echo "::group::Setup ExecuTorch" - PYTHON_EXECUTABLE=python ${CONDA_RUN} ./install_executorch.sh + PYTHON_EXECUTABLE=python ${CONDA_RUN} EXECUTORCH_BUILD_KERNELS_TORCHAO=1 TORCHAO_BUILD_EXPERIMENTAL_MPS=1 ./install_executorch.sh 
echo "::endgroup::" echo "::group::Pip List" diff --git a/backends/apple/metal/metal_backend.py b/backends/apple/metal/metal_backend.py index fd114b20b3d..20d558ecaa7 100644 --- a/backends/apple/metal/metal_backend.py +++ b/backends/apple/metal/metal_backend.py @@ -48,6 +48,7 @@ def get_custom_passes(cls, compile_specs: List[CompileSpec]) -> List[typing.Any] from executorch.backends.apple.metal.passes.decompose_linear_pass import ( DecomposeLinearPass, ) + return [DecomposeLinearPass()] @classmethod @@ -72,6 +73,7 @@ def get_aoti_compile_options( } from torchao.experimental.ops.mps.cshim import torchao_op_c_shim + inductor_configs["aot_inductor.custom_ops_to_c_shims"] = torchao_op_c_shim return inductor_configs diff --git a/backends/apple/metal/runtime/metal_backend.cpp b/backends/apple/metal/runtime/metal_backend.cpp index 7c626da4e5b..777fa597b72 100644 --- a/backends/apple/metal/runtime/metal_backend.cpp +++ b/backends/apple/metal/runtime/metal_backend.cpp @@ -315,7 +315,10 @@ class ET_EXPERIMENTAL MetalBackend final "Failed to load shared library: %s", dlerror()); - ET_LOG(Info, "MetalBackend::init - Loaded shared library: %s", so_path.c_str()); + ET_LOG( + Info, + "MetalBackend::init - Loaded shared library: %s", + so_path.c_str()); processed->Free(); diff --git a/backends/apple/metal/runtime/shims/utils.h b/backends/apple/metal/runtime/shims/utils.h index 2ac5d1d5857..60412812b16 100644 --- a/backends/apple/metal/runtime/shims/utils.h +++ b/backends/apple/metal/runtime/shims/utils.h @@ -19,7 +19,7 @@ namespace metal { // Enum for supported data types in et-metal backend enum class SupportedDTypes : int32_t { - UINT8 = 0, // PyTorch's uint8 dtype code + UINT8 = 0, // PyTorch's uint8 dtype code // INT8 = 1, // PyTorch's int8 dtype code // INT16 = 2, // PyTorch's int16 dtype code // INT32 = 3, // PyTorch's int32 dtype code From fec15bc7aa30e53ae35aa6457bdee9b5b1995a3c Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Tue, 3 Feb 2026 12:30:11 -0500 Subject: [PATCH 15/26] Update [ghstack-poisoned] --- backends/apple/metal/tests/test_modules.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/backends/apple/metal/tests/test_modules.py b/backends/apple/metal/tests/test_modules.py index 7b98ae2801f..d1956fd2913 100644 --- a/backends/apple/metal/tests/test_modules.py +++ b/backends/apple/metal/tests/test_modules.py @@ -888,15 +888,6 @@ def _test_module_export( model, example_inputs = get_model_and_inputs(model_name, dtype=dtype) - # Verify model forward pass works before export - with torch.no_grad(): - model_output = model(*example_inputs) - - self.assertIsNotNone( - model_output, - f"{model_name} ({DTYPE_NAMES[dtype]}): Forward pass returned None", - ) - # Export to Metal backend executorch_program = export_model_to_metal(model, example_inputs) From c16dc590c40f6895b5bf1810fa768c85b5d9ed18 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Tue, 3 Feb 2026 19:03:02 -0500 Subject: [PATCH 16/26] Update [ghstack-poisoned] --- .../apple/metal/runtime/shims/et_metal_ops.mm | 105 ++++++++++++++---- 1 file changed, 82 insertions(+), 23 deletions(-) diff --git a/backends/apple/metal/runtime/shims/et_metal_ops.mm b/backends/apple/metal/runtime/shims/et_metal_ops.mm index ec884b50776..c484a03c433 100644 --- a/backends/apple/metal/runtime/shims/et_metal_ops.mm +++ b/backends/apple/metal/runtime/shims/et_metal_ops.mm @@ -2492,30 +2492,38 @@ AOTITorchError aoti_torch_mps__linear_fp_act_4bit_weight( ET_LOG(Debug, "aoti_torch_mps__linear_fp_act_4bit_weight: Converted tensor handles to ET 
tensors"); - // Log tensor shapes for debugging - ET_LOG(Debug, "aoti_torch_mps__linear_fp_act_4bit_weight: A shape: [%d, %d], strides: [%d, %d]", - a_tensor->dim() > 0 ? (int)a_tensor->sizes()[0] : 0, - a_tensor->dim() > 1 ? (int)a_tensor->sizes()[1] : 0, - a_tensor->dim() > 0 ? (int)a_tensor->strides()[0] : 0, - a_tensor->dim() > 1 ? (int)a_tensor->strides()[1] : 0); - - ET_LOG(Debug, "aoti_torch_mps__linear_fp_act_4bit_weight: B shape: [%d, %d]", - b_tensor->dim() > 0 ? (int)b_tensor->sizes()[0] : 0, - b_tensor->dim() > 1 ? (int)b_tensor->sizes()[1] : 0); - - ET_LOG(Debug, "aoti_torch_mps__linear_fp_act_4bit_weight: S shape: [%d, %d], Z shape: [%d, %d]", - s_tensor->dim() > 0 ? (int)s_tensor->sizes()[0] : 0, - s_tensor->dim() > 1 ? (int)s_tensor->sizes()[1] : 0, - z_tensor->dim() > 0 ? (int)z_tensor->sizes()[0] : 0, - z_tensor->dim() > 1 ? (int)z_tensor->sizes()[1] : 0); - - // Validate tensor dimensions + // Validate A tensor: ndim, dtype, contiguity if (a_tensor->dim() != 2) { - ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: A tensor must be 2-D, got %d", (int)a_tensor->dim()); + ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: expect A to be 2D tensor, got %d", (int)a_tensor->dim()); + return Error::InvalidArgument; + } + auto a_dtype = a_tensor->scalar_type(); + if (a_dtype != exec_aten::ScalarType::Float && + a_dtype != exec_aten::ScalarType::BFloat16) { + ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: expect A to be 32-bit or 16-bit float tensor, got dtype %d", (int)a_dtype); + return Error::InvalidArgument; + } + // Check A is contiguous (stride[1] == 1 and stride[0] == size[1]) + if (a_tensor->strides()[1] != 1 || a_tensor->strides()[0] != a_tensor->sizes()[1]) { + ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: expect A to be contiguous, strides=[%lld, %lld]", + (long long)a_tensor->strides()[0], (long long)a_tensor->strides()[1]); return Error::InvalidArgument; } + + + // Validate B tensor: ndim, dtype (uint8), contiguity if (b_tensor->dim() != 2) { - ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: B tensor must be 2-D, got %d", (int)b_tensor->dim()); + ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: expect B to be 2D tensor, got %d", (int)b_tensor->dim()); + return Error::InvalidArgument; + } + if (b_tensor->scalar_type() != exec_aten::ScalarType::Byte) { + ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: expect B to be uint8 tensor, got dtype %d", (int)b_tensor->scalar_type()); + return Error::InvalidArgument; + } + // Check B is contiguous + if (b_tensor->strides()[1] != 1 || b_tensor->strides()[0] != b_tensor->sizes()[1]) { + ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: expect B to be contiguous, strides=[%lld, %lld]", + (long long)b_tensor->strides()[0], (long long)b_tensor->strides()[1]); return Error::InvalidArgument; } @@ -2523,18 +2531,67 @@ AOTITorchError aoti_torch_mps__linear_fp_act_4bit_weight( int32_t M = static_cast(a_tensor->sizes()[0]); int32_t K = static_cast(a_tensor->sizes()[1]); int32_t N = static_cast(b_tensor->sizes()[0]); + constexpr int nbit = 4; ET_LOG(Debug, "aoti_torch_mps__linear_fp_act_4bit_weight: M=%d, K=%d, N=%d, group_size=%lld", M, K, N, group_size); - // Validate alignment requirements + // B.size(1) should be (K / 8) * nbit for 4-bit packing + int64_t expected_b_size1 = (K / 8) * nbit; + if (b_tensor->sizes()[1] != expected_b_size1) { + ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: expect B.size(1) == %lld, got %lld", + (long 
long)expected_b_size1, (long long)b_tensor->sizes()[1]); + return Error::InvalidArgument; + } + + // Validate K alignment if (K % 8 != 0) { - ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: K (%d) must be divisible by 8", K); + ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: expect K to be multiple of 8, got %d", K); return Error::InvalidArgument; } + + // Validate N alignment if (N % 4 != 0) { - ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: N (%d) must be divisible by 4", N); + ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: expect N to be multiple of 4, got M=%d, N=%d", M, N); + return Error::InvalidArgument; + } + + // Validate S tensor: 2D with S.size(0) == N, contiguous + if (s_tensor->dim() != 2 || s_tensor->sizes()[0] != N) { + ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: expect S to be 2D tensor with shape [%d, :], got dim=%d, size[0]=%lld", + N, (int)s_tensor->dim(), (long long)s_tensor->sizes()[0]); + return Error::InvalidArgument; + } + if (s_tensor->strides()[1] != 1 || s_tensor->strides()[0] != s_tensor->sizes()[1]) { + ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: expect S to be contiguous, strides=[%lld, %lld]", + (long long)s_tensor->strides()[0], (long long)s_tensor->strides()[1]); + return Error::InvalidArgument; + } + + // Validate Z tensor: 2D with Z.size(0) == N, contiguous + if (z_tensor->dim() != 2 || z_tensor->sizes()[0] != N) { + ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: expect Z to be 2D tensor with shape [%d, :], got dim=%d, size[0]=%lld", + N, (int)z_tensor->dim(), (long long)z_tensor->sizes()[0]); return Error::InvalidArgument; } + if (z_tensor->strides()[1] != 1 || z_tensor->strides()[0] != z_tensor->sizes()[1]) { + ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: expect Z to be contiguous, strides=[%lld, %lld]", + (long long)z_tensor->strides()[0], (long long)z_tensor->strides()[1]); + return Error::InvalidArgument; + } + + // Log shapes and strides for all tensors + ET_LOG(Debug, "aoti_torch_mps__linear_fp_act_4bit_weight: A tensor shape=[%lld, %lld], strides=[%lld, %lld]", + (long long)a_tensor->sizes()[0], (long long)a_tensor->sizes()[1], + (long long)a_tensor->strides()[0], (long long)a_tensor->strides()[1]); + ET_LOG(Debug, "aoti_torch_mps__linear_fp_act_4bit_weight: B tensor shape=[%lld, %lld], strides=[%lld, %lld]", + (long long)b_tensor->sizes()[0], (long long)b_tensor->sizes()[1], + (long long)b_tensor->strides()[0], (long long)b_tensor->strides()[1]); + ET_LOG(Debug, "aoti_torch_mps__linear_fp_act_4bit_weight: S tensor shape=[%lld, %lld], strides=[%lld, %lld]", + (long long)s_tensor->sizes()[0], (long long)s_tensor->sizes()[1], + (long long)s_tensor->strides()[0], (long long)s_tensor->strides()[1]); + ET_LOG(Debug, "aoti_torch_mps__linear_fp_act_4bit_weight: Z tensor shape=[%lld, %lld], strides=[%lld, %lld]", + (long long)z_tensor->sizes()[0], (long long)z_tensor->sizes()[1], + (long long)z_tensor->strides()[0], (long long)z_tensor->strides()[1]); // Determine data type int32_t dtype = static_cast(a_tensor->scalar_type()); @@ -2652,6 +2709,7 @@ AOTITorchError aoti_torch_mps__linear_fp_act_4bit_weight( // Dispatch based on kernel type (matching torchao dispatch patterns) if (use_qmv_fast) { // dispatch_qmv_fast: dispatchThreadgroups with grid (M, (N+7)/8, 1), group (32, 2, 1) + ET_LOG(Debug, "aoti_torch_mps__linear_fp_act_4bit_weight: Dispatching kernel: %s", kernel_name.c_str()); kernel_func->dispatchThreadgroups( M, // gridX (N + 7) / 8, // gridY @@ -2661,6 
+2719,7 @@ AOTITorchError aoti_torch_mps__linear_fp_act_4bit_weight( 1); // threadsZ } else { // dispatch_mm_Mr1xNr4_per_TG: dispatchThreads with grid (N/4 * 32, 1, M), group (32, 1, 1) + ET_LOG(Debug, "aoti_torch_mps__linear_fp_act_4bit_weight: Dispatching kernel: %s", kernel_name.c_str()); uint64_t grid_dims[3] = {static_cast(N / 4 * 32), 1, static_cast(M)}; uint64_t group_dims[3] = {32, 1, 1}; kernel_func->dispatchArrayWithGroupSize(grid_dims, 3, group_dims, 3); From 8ee7d609edc4ff947a556ece64150c8fa8b57839 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Tue, 3 Feb 2026 19:03:06 -0500 Subject: [PATCH 17/26] Update [ghstack-poisoned] --- .../apple/metal/runtime/shims/et_metal_ops.mm | 460 +++++++++++++++++- backends/apple/metal/tests/test_modules.py | 26 + 2 files changed, 481 insertions(+), 5 deletions(-) diff --git a/backends/apple/metal/runtime/shims/et_metal_ops.mm b/backends/apple/metal/runtime/shims/et_metal_ops.mm index c484a03c433..1cbad91f83c 100644 --- a/backends/apple/metal/runtime/shims/et_metal_ops.mm +++ b/backends/apple/metal/runtime/shims/et_metal_ops.mm @@ -696,6 +696,112 @@ inline U load_vector(constant T* x, thread U* x_thread) { return sum; } + template + inline U load_vector_safe(constant T* x, thread U* x_thread, int N) { + static_assert( + 1 <= bits && bits <= 7, + "Template undefined for bits not in {1, 2, 3, 4, 5, 6, 7}"); + + U sum = 0; + + if (bits == 1) { + for (int i = 0; i < N; i += 8) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + + x[i + 6] + x[i + 7]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 2.0f; + x_thread[i + 2] = x[i + 2] / 4.0f; + x_thread[i + 3] = x[i + 3] / 8.0f; + x_thread[i + 4] = x[i + 4] / 16.0f; + x_thread[i + 5] = x[i + 5] / 32.0f; + x_thread[i + 6] = x[i + 6] / 64.0f; + x_thread[i + 7] = x[i + 7] / 128.0f; + } + } + + else if (bits == 2) { + for (int i = 0; i < N; i += 4) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 4.0f; + x_thread[i + 2] = x[i + 2] / 16.0f; + x_thread[i + 3] = x[i + 3] / 64.0f; + } + } + + else if (bits == 3) { + for (int i = 0; i < N; i += 8) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + + x[i + 6] + x[i + 7]; + + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 8.0f; + x_thread[i + 2] = x[i + 2] / 64.0f; + x_thread[i + 3] = x[i + 3] / 2.0f; + x_thread[i + 4] = x[i + 4] / 16.0f; + x_thread[i + 5] = x[i + 5] / 128.0f; + x_thread[i + 6] = x[i + 6] / 4.0f; + x_thread[i + 7] = x[i + 7] / 32.0f; + } + } + + else if (bits == 4) { + for (int i = 0; i < N; i += 4) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 16.0f; + x_thread[i + 2] = x[i + 2] / 256.0f; + x_thread[i + 3] = x[i + 3] / 4096.0f; + } + } + + else if (bits == 5) { + for (int i = 0; i < N; i += 8) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + + x[i + 6] + x[i + 7]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 32.0f; + x_thread[i + 2] = x[i + 2] / 4.0f; + x_thread[i + 3] = x[i + 3] / 128.0f; + x_thread[i + 4] = x[i + 4] / 16.0f; + x_thread[i + 5] = x[i + 5] / 2.0f; + x_thread[i + 6] = x[i + 6] / 64.0f; + x_thread[i + 7] = x[i + 7] / 8.0f; + } + } + + else if (bits == 6) { + for (int i = 0; i < N; i += 4) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 64.0f; + x_thread[i + 2] = x[i + 2] / 16.0f; + x_thread[i + 3] = x[i + 3] / 4.0f; + } + } + + else if (bits == 7) { + for (int i = 0; i < N; i += 8) 
{ + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + + x[i + 6] + x[i + 7]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 128.0f; + x_thread[i + 2] = x[i + 2] / 64.0f; + x_thread[i + 3] = x[i + 3] / 32.0f; + x_thread[i + 4] = x[i + 4] / 16.0f; + x_thread[i + 5] = x[i + 5] / 8.0f; + x_thread[i + 6] = x[i + 6] / 4.0f; + x_thread[i + 7] = x[i + 7] / 2.0f; + } + } + + for (int i = N; i < values_per_thread; i++) { + x_thread[i] = 0; + } + + return sum; + } + template inline U qdot( constant uint8_t* w, @@ -838,6 +944,149 @@ inline U qdot( return scale * accum + sum * bias; } + template + inline U qdot_safe( + constant uint8_t* w, + const thread U* x_thread, + U scale, + U bias, + U sum, + int N) { + static_assert( + 1 <= bits && bits <= 7, + "Template undefined for bits not in {1, 2, 3, 4, 5, 6, 7}"); + + U accum = 0; + + if (bits == 1) { + for (int i = 0; i < (N / 8); i++) { + x_thread += 8 * i; + + accum += + (x_thread[0] * (w[i] & 0x01) + + x_thread[1] * (w[i] & 0x02) + + x_thread[2] * (w[i] & 0x04) + + x_thread[3] * (w[i] & 0x08) + + x_thread[4] * (w[i] & 0x10) + + x_thread[5] * (w[i] & 0x20) + + x_thread[6] * (w[i] & 0x40) + + x_thread[7] * (w[i] & 0x80)); + } + } + + else if (bits == 2) { + for (int i = 0; i < (N / 4); i++) { + accum += + (x_thread[4 * i] * (w[i] & 0x03) + + x_thread[4 * i + 1] * (w[i] & 0x0c) + + x_thread[4 * i + 2] * (w[i] & 0x30) + + x_thread[4 * i + 3] * (w[i] & 0xc0)); + } + } + + else if (bits == 3) { + for (int i = 0; i < (N / 8); i++) { + x_thread += 8 * i; + w += 3 * i; + + accum += (w[0] & 0x07) * x_thread[0]; + accum += (w[0] & 0x38) * x_thread[1]; + accum += (w[0] & 0xc0) * x_thread[2]; + accum += (w[1] & 0x01) * (x_thread[2] * 256.0f); + + accum += (w[1] & 0x0e) * x_thread[3]; + accum += (w[1] & 0x70) * x_thread[4]; + accum += (w[1] & 0x80) * x_thread[5]; + accum += (w[2] & 0x03) * (x_thread[5] * 256.0f); + + accum += (w[2] & 0x1c) * x_thread[6]; + accum += (w[2] & 0xe0) * x_thread[7]; + } + } + + else if (bits == 4) { + constant uint16_t* ws = (constant uint16_t*)w; + for (int i = 0; i < (N / 4); i++) { + accum += + (x_thread[4 * i] * (ws[i] & 0x000f) + + x_thread[4 * i + 1] * (ws[i] & 0x00f0) + + x_thread[4 * i + 2] * (ws[i] & 0x0f00) + + x_thread[4 * i + 3] * (ws[i] & 0xf000)); + } + } + + else if (bits == 5) { + for (int i = 0; i < (N / 8); i++) { + x_thread += 8 * i; + w += 5 * i; + + accum += (w[0] & 0x1f) * x_thread[0]; + accum += (w[0] & 0xe0) * x_thread[1]; + + accum += (w[1] & 0x03) * (x_thread[1] * 256.0f); + accum += (w[1] & 0x7c) * x_thread[2]; + accum += (w[1] & 0x80) * x_thread[3]; + + accum += (w[2] & 0x0f) * (x_thread[3] * 256.0f); + accum += (w[2] & 0xf0) * x_thread[4]; + + accum += (w[3] & 0x01) * (x_thread[4] * 256.0f); + accum += (w[3] & 0x3e) * x_thread[5]; + accum += (w[3] & 0xc0) * x_thread[6]; + + accum += (w[4] & 0x07) * (x_thread[6] * 256.0f); + accum += (w[4] & 0xf8) * x_thread[7]; + } + } + + else if (bits == 6) { + for (int i = 0; i < (N / 4); i++) { + x_thread += 4 * i; + w += 3 * i; + + accum += (w[0] & 0x3f) * x_thread[0]; + + accum += (w[0] & 0xc0) * x_thread[1]; + accum += (w[1] & 0x0f) * (x_thread[1] * 256.0f); + + accum += (w[1] & 0xf0) * x_thread[2]; + accum += (w[2] & 0x03) * (x_thread[2] * 256.0f); + + accum += (w[2] & 0xfc) * x_thread[3]; + } + } + + else if (bits == 7) { + for (int i = 0; i < (N / 8); i++) { + x_thread += 8 * i; + w += 7 * i; + + accum += (w[0] & 0x7f) * x_thread[0]; + accum += (w[0] & 0x80) * x_thread[1]; + + accum += (w[1] & 0x3f) * (x_thread[1] * 256.0f); + accum += (w[1] 
& 0xc0) * x_thread[2]; + + accum += (w[2] & 0x1f) * (x_thread[2] * 256.0f); + accum += (w[2] & 0xe0) * x_thread[3]; + + accum += (w[3] & 0x0f) * (x_thread[3] * 256.0f); + accum += (w[3] & 0xf0) * x_thread[4]; + + accum += (w[4] & 0x07) * (x_thread[4] * 256.0f); + accum += (w[4] & 0xf8) * x_thread[5]; + + accum += (w[5] & 0x03) * (x_thread[5] * 256.0f); + accum += (w[5] & 0xfc) * x_thread[6]; + + accum += (w[6] & 0x01) * (x_thread[6] * 256.0f); + accum += (w[6] & 0xfe) * x_thread[7]; + } + } + + return scale * accum + sum * bias; + } + template [[kernel]] void qmv_fast( constant T* x [[buffer(0)]], @@ -942,6 +1191,202 @@ inline U qdot( INSTANTIATE_QMV_FAST_DTYPE(bfloat); #endif + /** + * qmv_impl.metal - handles generic N (any even N, not just N % 8 == 0) + */ + + template + [[kernel]] void qmv_impl( + constant T* x [[buffer(0)]], + constant uchar* w [[buffer(1)]], + constant T* scales [[buffer(2)]], + constant T* biases [[buffer(3)]], + device T* y [[buffer(4)]], + constant uint3 &sizes [[buffer(5)]], // M, K, N + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + const int in_vec_size = static_cast(sizes.y); // K + const int out_vec_size = static_cast(sizes.z); // N + + constexpr int num_simdgroups = 2; + constexpr int results_per_simdgroup = 4; + constexpr int packs_per_thread = (bits == 1 || bits == 2) ? 1 : 2; + constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; + constexpr int pack_factor = bits == 1 ? 16 : power_of_2_bits ? 32 / bits : bits == 6 ? 4 : 8; + constexpr int bytes_per_pack = bits == 1 ? 2 : power_of_2_bits ? 4 : bits == 6 ? 3 : bits; + + constexpr int values_per_thread = pack_factor * packs_per_thread; + constexpr int block_size = values_per_thread * SIMD_SIZE; + constexpr int scale_step_per_thread = group_size / values_per_thread; + + constant uint8_t* ws = (constant uint8_t*)w; + + typedef float U; + + thread U x_thread[values_per_thread]; + thread U result[results_per_simdgroup] = {0}; + + // Adjust positions + const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor; + const int in_vec_size_g = in_vec_size / group_size; + const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + + simd_gid * results_per_simdgroup; + const int used_out_row = min(out_vec_size - results_per_simdgroup, out_row); + + if (out_row >= out_vec_size) { + return; + } + + // In this case we need to properly guard all our reads because there isn't + // even 1 tile in the matrix + if (out_vec_size < (num_simdgroups * results_per_simdgroup)) { + ws += + out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack; + scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; + biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; + x += tid.x * in_vec_size + simd_lid * values_per_thread; + y += tid.x * out_vec_size + out_row; + + int k = 0; + for (; k < in_vec_size - block_size; k += block_size) { + U sum = load_vector(x, x_thread); + + for (int row = 0; out_row + row < out_vec_size; row++) { + auto wl = (constant uint8_t*)(ws + row * in_vec_size_w); + constant T* sl = scales + row * in_vec_size_g; + constant T* bl = biases + row * in_vec_size_g; + + U s = sl[0]; + U b = bl[0]; + result[row] += + qdot(wl, x_thread, s, b, sum); + } + + ws += block_size * bytes_per_pack / pack_factor; + scales += block_size / group_size; + biases += block_size / group_size; + x += block_size; + } + const int remaining = clamp( + static_cast(in_vec_size - k 
- simd_lid * values_per_thread), + 0, + values_per_thread); + if (remaining > 0) { + U sum = load_vector_safe( + x, x_thread, remaining); + + for (int row = 0; out_row + row < out_vec_size; row++) { + auto wl = (constant uint8_t*)(ws + row * in_vec_size_w); + constant T* sl = scales + row * in_vec_size_g; + constant T* bl = biases + row * in_vec_size_g; + + U s = sl[0]; + U b = bl[0]; + result[row] += + qdot(wl, x_thread, s, b, sum); + } + } + + for (int row = 0; out_row + row < out_vec_size; row++) { + result[row] = simd_sum(result[row]); + if (simd_lid == 0) { + y[row] = static_cast(result[row]); + } + } + } + + // In this case the last tile is moved back to redo some output values + else { + ws += used_out_row * in_vec_size_w + + simd_lid * packs_per_thread * bytes_per_pack; + scales += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread; + biases += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread; + x += tid.x * in_vec_size + simd_lid * values_per_thread; + y += tid.x * out_vec_size + used_out_row; + + int k = 0; + for (; k < in_vec_size - block_size; k += block_size) { + U sum = load_vector(x, x_thread); + + for (int row = 0; row < results_per_simdgroup; row++) { + auto wl = (constant uint8_t*)(ws + row * in_vec_size_w); + constant T* sl = scales + row * in_vec_size_g; + constant T* bl = biases + row * in_vec_size_g; + + U s = sl[0]; + U b = bl[0]; + result[row] += + qdot(wl, x_thread, s, b, sum); + } + + ws += block_size * bytes_per_pack / pack_factor; + scales += block_size / group_size; + biases += block_size / group_size; + x += block_size; + } + const int remaining = clamp( + static_cast(in_vec_size - k - simd_lid * values_per_thread), + 0, + values_per_thread); + if (remaining > 0) { + U sum = load_vector_safe( + x, x_thread, remaining); + + for (int row = 0; row < results_per_simdgroup; row++) { + auto wl = (constant uint8_t*)(ws + row * in_vec_size_w); + constant T* sl = scales + row * in_vec_size_g; + constant T* bl = biases + row * in_vec_size_g; + + U s = sl[0]; + U b = bl[0]; + result[row] += qdot_safe( + wl, x_thread, s, b, sum, remaining); + } + } + for (int row = 0; row < results_per_simdgroup; row++) { + result[row] = simd_sum(result[row]); + if (simd_lid == 0) { + y[row] = static_cast(result[row]); + } + } + } + } + + #define INSTANTIATE_QMV_IMPL(DTYPE, GSIZE, NBIT) \ + template [[host_name("qmv_impl_" #NBIT "bit_" #GSIZE "_" #DTYPE)]] kernel void \ + qmv_impl( \ + constant DTYPE * A [[buffer(0)]], \ + constant uchar * B [[buffer(1)]], \ + constant DTYPE * scales_ptr [[buffer(2)]], \ + constant DTYPE * zeros_ptr [[buffer(3)]], \ + device DTYPE * output_data [[buffer(4)]], \ + constant uint3 & sizes [[buffer(5)]], \ + uint3 thread_index [[thread_position_in_grid]], \ + uint simd_gid [[simdgroup_index_in_threadgroup]], \ + uint tid_in_simdgroup [[thread_index_in_simdgroup]]) + + #define INSTANTIATE_QMV_IMPL_DTYPE_GSIZE(DTYPE, GSIZE) \ + INSTANTIATE_QMV_IMPL(DTYPE, GSIZE, 1); \ + INSTANTIATE_QMV_IMPL(DTYPE, GSIZE, 2); \ + INSTANTIATE_QMV_IMPL(DTYPE, GSIZE, 3); \ + INSTANTIATE_QMV_IMPL(DTYPE, GSIZE, 4); \ + INSTANTIATE_QMV_IMPL(DTYPE, GSIZE, 5); \ + INSTANTIATE_QMV_IMPL(DTYPE, GSIZE, 6); \ + INSTANTIATE_QMV_IMPL(DTYPE, GSIZE, 7); + + #define INSTANTIATE_QMV_IMPL_DTYPE(DTYPE) \ + INSTANTIATE_QMV_IMPL_DTYPE_GSIZE(DTYPE, 32); \ + INSTANTIATE_QMV_IMPL_DTYPE_GSIZE(DTYPE, 64); \ + INSTANTIATE_QMV_IMPL_DTYPE_GSIZE(DTYPE, 128); \ + INSTANTIATE_QMV_IMPL_DTYPE_GSIZE(DTYPE, 256); + + INSTANTIATE_QMV_IMPL_DTYPE(float); + INSTANTIATE_QMV_IMPL_DTYPE(half); + 
#if __METAL_VERSION__ >= 310 + INSTANTIATE_QMV_IMPL_DTYPE(bfloat); + #endif + )"; } @@ -2550,8 +2995,8 @@ AOTITorchError aoti_torch_mps__linear_fp_act_4bit_weight( } // Validate N alignment - if (N % 4 != 0) { - ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: expect N to be multiple of 4, got M=%d, N=%d", M, N); + if (N % 4 != 0 && M != 1) { + ET_LOG(Error, "aoti_torch_mps__linear_fp_act_4bit_weight: expect N to be multiple of 4 when M != 1, got M=%d, N=%d", M, N); return Error::InvalidArgument; } @@ -2619,11 +3064,16 @@ AOTITorchError aoti_torch_mps__linear_fp_act_4bit_weight( // Select kernel based on dimensions (matching torchao's get_shader_func_and_dispatch) std::string kernel_name; bool use_qmv_fast = (M == 1 && N % 8 == 0 && K % 512 == 0); + bool use_qmv_impl = (M == 1 && !use_qmv_fast); if (use_qmv_fast) { - // Use optimized qmv_fast kernel for M=1 case + // Use optimized qmv_fast kernel for M=1 case with aligned dimensions kernel_name = "qmv_fast_4bit_" + std::to_string(group_size) + "_" + type_str; ET_LOG(Debug, "aoti_torch_mps__linear_fp_act_4bit_weight: Using qmv_fast kernel: %s", kernel_name.c_str()); + } else if (use_qmv_impl) { + // Use qmv_impl kernel for M=1 case with generic N (handles any even N) + kernel_name = "qmv_impl_4bit_" + std::to_string(group_size) + "_" + type_str; + ET_LOG(Debug, "aoti_torch_mps__linear_fp_act_4bit_weight: Using qmv_impl kernel: %s", kernel_name.c_str()); } else { // Use general int4pack_mm kernel kernel_name = "int4pack_mm_" + std::to_string(group_size) + "_" + type_str; @@ -2707,8 +3157,8 @@ AOTITorchError aoti_torch_mps__linear_fp_act_4bit_weight( ET_LOG(Debug, "aoti_torch_mps__linear_fp_act_4bit_weight: All arguments set, dispatching"); // Dispatch based on kernel type (matching torchao dispatch patterns) - if (use_qmv_fast) { - // dispatch_qmv_fast: dispatchThreadgroups with grid (M, (N+7)/8, 1), group (32, 2, 1) + if (use_qmv_fast || use_qmv_impl) { + // dispatch_qmv: dispatchThreadgroups with grid (M, (N+7)/8, 1), group (32, 2, 1) ET_LOG(Debug, "aoti_torch_mps__linear_fp_act_4bit_weight: Dispatching kernel: %s", kernel_name.c_str()); kernel_func->dispatchThreadgroups( M, // gridX diff --git a/backends/apple/metal/tests/test_modules.py b/backends/apple/metal/tests/test_modules.py index b841d0a47ab..fc3ca30afe4 100644 --- a/backends/apple/metal/tests/test_modules.py +++ b/backends/apple/metal/tests/test_modules.py @@ -274,6 +274,32 @@ def forward(self, x: torch.Tensor): "skip": not TORCHAO_AVAILABLE, } + +# ------------------------------------------------------------------------- +class LinearInt4_QMV_IMPL(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(640, 8198, bias=True) + + def forward(self, x: torch.Tensor): + return self.linear(x) + + +MODULE_REGISTRY["linear_int4_qmv_impl"] = { + "model_class": LinearInt4_QMV_IMPL, + "input_shapes": [(1, 640)], + "description": "Linear int4 quantization dispatching to qmv_impl", + "qlinear": "fpa4w", + "qlinear_group_size": 32, + "compare_to_unquantized": False, + "atol_float32": 5e-2, + "rtol_float32": 5e-2, + "atol_bfloat16": 1e-1, + "rtol_bfloat16": 1e-1, + "skip": not TORCHAO_AVAILABLE, +} + + # ------------------------------------------------------------------------- # Convolution Modules # ------------------------------------------------------------------------- From 9966d376e09567f3c6ea425c133e9203a960354d Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Tue, 3 Feb 2026 19:03:10 -0500 Subject: [PATCH 18/26] Update [ghstack-poisoned] --- 
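
For reference, a rough sketch (not part of the patch; the helper name is ours) of the kernel-selection rule the shim now implements for the 4-bit linear op, matching the `use_qmv_fast` / `use_qmv_impl` conditions and kernel-name construction added above. The worked numbers in the comments follow the packing layout and dispatch geometry described in the surrounding code:

```python
def select_int4_kernel(M: int, K: int, N: int, group_size: int, type_str: str) -> str:
    # A is [M, K]; B is [N, K/2] uint8 (two 4-bit values per byte), i.e. B.size(1) == (K / 8) * 4.
    # Example: K = 640 -> each B row is 320 bytes wide.
    if M == 1 and N % 8 == 0 and K % 512 == 0:
        # Fast single-row path; dispatched with grid (M, (N + 7) // 8, 1), group (32, 2, 1).
        # Example: M = 1, N = 256 -> 32 threadgroups of 64 threads each.
        return f"qmv_fast_4bit_{group_size}_{type_str}"
    if M == 1:
        # Generic-N single-row path added in this stack; same threadgroup geometry as qmv_fast.
        return f"qmv_impl_4bit_{group_size}_{type_str}"
    # General matrix-matrix path; dispatched over a grid of (N // 4 * 32, 1, M) threads,
    # with threadgroups of (32, 1, 1).
    return f"int4pack_mm_{group_size}_{type_str}"
```

The `qdot_safe` switch and the ceil-divide for `in_vec_size_g` in the diff that follows cover the tail iterations of the qmv_impl path, where K is not a multiple of the per-iteration block size or of `group_size`.
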
.../apple/metal/runtime/shims/et_metal_ops.mm | 6 +-- backends/apple/metal/tests/test_modules.py | 50 +++++++++++++++++++ 2 files changed, 53 insertions(+), 3 deletions(-) diff --git a/backends/apple/metal/runtime/shims/et_metal_ops.mm b/backends/apple/metal/runtime/shims/et_metal_ops.mm index 1cbad91f83c..48816e2f3a4 100644 --- a/backends/apple/metal/runtime/shims/et_metal_ops.mm +++ b/backends/apple/metal/runtime/shims/et_metal_ops.mm @@ -1229,7 +1229,7 @@ inline U qdot_safe( // Adjust positions const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor; - const int in_vec_size_g = in_vec_size / group_size; + const int in_vec_size_g = (in_vec_size + group_size - 1) / group_size; const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + simd_gid * results_per_simdgroup; const int used_out_row = min(out_vec_size - results_per_simdgroup, out_row); @@ -1283,8 +1283,8 @@ inline U qdot_safe( U s = sl[0]; U b = bl[0]; - result[row] += - qdot(wl, x_thread, s, b, sum); + result[row] += qdot_safe( + wl, x_thread, s, b, sum, remaining); } } diff --git a/backends/apple/metal/tests/test_modules.py b/backends/apple/metal/tests/test_modules.py index fc3ca30afe4..740dd146350 100644 --- a/backends/apple/metal/tests/test_modules.py +++ b/backends/apple/metal/tests/test_modules.py @@ -300,6 +300,56 @@ def forward(self, x: torch.Tensor): } +# ------------------------------------------------------------------------- +class LinearInt4_QMV_IMPL_small_odd(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(8, 3, bias=True) + + def forward(self, x: torch.Tensor): + return self.linear(x) + + +MODULE_REGISTRY["linear_int4_qmv_impl_small_odd"] = { + "model_class": LinearInt4_QMV_IMPL_small_odd, + "input_shapes": [(1, 8)], + "description": "Linear int4 quantization dispatching to qmv_impl", + "qlinear": "fpa4w", + "qlinear_group_size": 32, + "compare_to_unquantized": False, + "atol_float32": 5e-2, + "rtol_float32": 5e-2, + "atol_bfloat16": 1e-1, + "rtol_bfloat16": 1e-1, + "skip": not TORCHAO_AVAILABLE, +} + + +# ------------------------------------------------------------------------- +class LinearInt4_QMV_IMPL_small_even(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(8, 10, bias=True) + + def forward(self, x: torch.Tensor): + return self.linear(x) + + +MODULE_REGISTRY["linear_int4_qmv_impl_small_even"] = { + "model_class": LinearInt4_QMV_IMPL_small_even, + "input_shapes": [(1, 8)], + "description": "Linear int4 quantization dispatching to qmv_impl", + "qlinear": "fpa4w", + "qlinear_group_size": 32, + "compare_to_unquantized": False, + "atol_float32": 5e-2, + "rtol_float32": 5e-2, + "atol_bfloat16": 1e-1, + "rtol_bfloat16": 1e-1, + "skip": not TORCHAO_AVAILABLE, +} + + # ------------------------------------------------------------------------- # Convolution Modules # ------------------------------------------------------------------------- From ade165f5f71fa4e20afb327b202b54352564f3bb Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Wed, 4 Feb 2026 23:18:26 -0500 Subject: [PATCH 19/26] Update [ghstack-poisoned] --- .ci/scripts/export_model_artifact.sh | 19 +++++++--- .github/workflows/metal.yml | 12 ++++++ examples/models/parakeet/README.md | 31 +++++++++++----- .../models/parakeet/export_parakeet_tdt.py | 10 ++++- examples/models/parakeet/quantize.py | 37 ++++++++++++++++++- third-party/ao | 2 +- 6 files changed, 92 insertions(+), 19 deletions(-) diff --git a/.ci/scripts/export_model_artifact.sh 
b/.ci/scripts/export_model_artifact.sh index d5c1913619d..eb89444a27a 100755 --- a/.ci/scripts/export_model_artifact.sh +++ b/.ci/scripts/export_model_artifact.sh @@ -26,13 +26,15 @@ Arguments: quant_name Quantization type (optional, default: non-quantized) Options: - non-quantized - - quantized-int4-tile-packed - - quantized-int4-weight-only + - quantized-int4-tile-packed (CUDA only) + - quantized-int4-weight-only (CUDA only) + - quantized-int4-metal (Metal only) output_dir Output directory for artifacts (optional, default: current directory) Examples: export_model_artifact.sh metal "openai/whisper-small" + export_model_artifact.sh metal "nvidia/parakeet-tdt" "quantized-int4-metal" export_model_artifact.sh cuda "mistralai/Voxtral-Mini-3B-2507" "quantized-int4-tile-packed" export_model_artifact.sh cuda "google/gemma-3-4b-it" "non-quantized" "./output" export_model_artifact.sh cuda "nvidia/parakeet-tdt" "non-quantized" "./output" @@ -127,21 +129,28 @@ case "$QUANT_NAME" in ;; quantized-int4-tile-packed) if [ "$DEVICE" = "metal" ]; then - echo "Error: Metal backend does not yet support quantization '$QUANT_NAME'" + echo "Error: Metal backend does not support quantization '$QUANT_NAME'" exit 1 fi EXTRA_ARGS="--qlinear 4w --qlinear_encoder 4w --qlinear_packing_format tile_packed_to_4d --qlinear_encoder_packing_format tile_packed_to_4d" ;; quantized-int4-weight-only) if [ "$DEVICE" = "metal" ]; then - echo "Error: Metal backend does not yet support quantization '$QUANT_NAME'" + echo "Error: Metal backend does not support quantization '$QUANT_NAME'" exit 1 fi EXTRA_ARGS="--qlinear_encoder 4w" ;; + quantized-int4-metal) + if [ "$DEVICE" != "metal" ]; then + echo "Error: Quantization '$QUANT_NAME' only supported on Metal backend" + exit 1 + fi + EXTRA_ARGS="--qlinear fpa4w --qlinear_encoder fpa4w" + ;; *) echo "Error: Unsupported quantization '$QUANT_NAME'" - echo "Supported quantizations: non-quantized, quantized-int4-tile-packed, quantized-int4-weight-only" + echo "Supported quantizations: non-quantized, quantized-int4-tile-packed, quantized-int4-weight-only, quantized-int4-metal" exit 1 ;; esac diff --git a/.github/workflows/metal.yml b/.github/workflows/metal.yml index bf86b01aff8..1ca0910a67b 100644 --- a/.github/workflows/metal.yml +++ b/.github/workflows/metal.yml @@ -73,6 +73,12 @@ jobs: name: "parakeet-tdt" quant: - "non-quantized" + # Only test int4 quantization with parakeet-tdt + include: + - model: + repo: "nvidia" + name: "parakeet-tdt" + quant: "quantized-int4-metal" with: runner: macos-m2-stable python-version: '3.11' @@ -123,6 +129,12 @@ jobs: name: "parakeet-tdt" quant: - "non-quantized" + # Only test int4 quantization with parakeet-tdt + include: + - model: + repo: "nvidia" + name: "parakeet-tdt" + quant: "quantized-int4-metal" with: runner: macos-m2-stable python-version: '3.11' diff --git a/examples/models/parakeet/README.md b/examples/models/parakeet/README.md index 23593d324d1..756fce068c5 100644 --- a/examples/models/parakeet/README.md +++ b/examples/models/parakeet/README.md @@ -39,10 +39,10 @@ The export script supports quantizing encoder and decoder linear layers using [t | Argument | Description | |----------|-------------| -| `--qlinear_encoder` | Quantization config for encoder linear layers: `4w`, `8w`, `8da4w`, `8da8w` | +| `--qlinear_encoder` | Quantization config for encoder linear layers: `4w`, `8w`, `8da4w`, `8da8w`, `fpa4w` | | `--qlinear_encoder_group_size` | Group size for encoder linear quantization (default: 32) | | `--qlinear_encoder_packing_format` | 
Packing format for encoder: `tile_packed_to_4d` | -| `--qlinear` | Quantization config for decoder linear layers: `4w`, `8w`, `8da4w`, `8da8w` | +| `--qlinear` | Quantization config for decoder linear layers: `4w`, `8w`, `8da4w`, `8da8w`, `fpa4w` | | `--qlinear_group_size` | Group size for decoder linear quantization (default: 32) | | `--qlinear_packing_format` | Packing format for decoder: `tile_packed_to_4d` | | `--qembedding` | Quantization config for decoder embedding layer: `4w`, `8w` | @@ -50,14 +50,15 @@ The export script supports quantizing encoder and decoder linear layers using [t #### Quantization Configs -| Config | Description | -|--------|-------------| -| `4w` | 4-bit weight only quantization | -| `8w` | 8-bit weight only quantization | -| `8da4w` | 8-bit dynamic activation, 4-bit weight | -| `8da8w` | 8-bit dynamic activation, 8-bit weight | +| Config | Description | Backends | +|--------|-------------|----------| +| `4w` | 4-bit weight only quantization | CUDA | +| `8w` | 8-bit weight only quantization | CUDA | +| `8da4w` | 8-bit dynamic activation, 4-bit weight | CUDA | +| `8da8w` | 8-bit dynamic activation, 8-bit weight | CUDA | +| `fpa4w` | Floating point activation, 4-bit weight | Metal | -#### Example: 4-bit Weight Quantization with Tile Packing +#### Example: 4-bit Weight Quantization with Tile Packing (CUDA) ```bash python export_parakeet_tdt.py \ @@ -74,6 +75,18 @@ python export_parakeet_tdt.py \ **Note:** The `tile_packed_to_4d` packing format is optimized for CUDA. +#### Example: Metal 4-bit Quantization + +```bash +python export_parakeet_tdt.py \ + --backend metal \ + --qlinear_encoder fpa4w \ + --qlinear_encoder_group_size 32 \ + --qlinear fpa4w \ + --qlinear_group_size 32 \ + --output-dir ./parakeet_metal_quantized +``` + ### Metal Export (macOS) ```bash diff --git a/examples/models/parakeet/export_parakeet_tdt.py b/examples/models/parakeet/export_parakeet_tdt.py index 703c25091ec..dd0beca3307 100644 --- a/examples/models/parakeet/export_parakeet_tdt.py +++ b/examples/models/parakeet/export_parakeet_tdt.py @@ -583,7 +583,7 @@ def main(): parser.add_argument( "--qlinear", type=str, - choices=["4w", "8w", "8da4w", "8da8w"], + choices=["4w", "8w", "8da4w", "8da8w", "fpa4w"], help="Quantization config for decoder linear layers", ) parser.add_argument( @@ -603,7 +603,7 @@ def main(): parser.add_argument( "--qlinear_encoder", type=str, - choices=["4w", "8w", "8da4w", "8da8w"], + choices=["4w", "8w", "8da4w", "8da8w", "fpa4w"], help="Quantization config for encoder linear layers", ) parser.add_argument( @@ -639,6 +639,12 @@ def main(): if args.dtype == "fp16": parser.error("fp16 is not yet supported") + # Validate fpa4w quantization requires Metal backend + if args.qlinear == "fpa4w" and args.backend != "metal": + parser.error("--qlinear=fpa4w can only be used with --backend=metal") + if args.qlinear_encoder == "fpa4w" and args.backend != "metal": + parser.error("--qlinear_encoder=fpa4w can only be used with --backend=metal") + os.makedirs(args.output_dir, exist_ok=True) print("Extracting tokenizer...") diff --git a/examples/models/parakeet/quantize.py b/examples/models/parakeet/quantize.py index 3e540d84834..1d69662efa1 100644 --- a/examples/models/parakeet/quantize.py +++ b/examples/models/parakeet/quantize.py @@ -17,7 +17,7 @@ def quantize_model_( # noqa: C901 Args: module: The PyTorch module to quantize. - qlinear_config: Quantization config for linear layers ("4w", "8w", "8da4w", "8da8w"). 
+ qlinear_config: Quantization config for linear layers ("4w", "8w", "8da4w", "8da8w", "fpa4w"). qlinear_group_size: Group size for linear quantization (default: 32). qlinear_packing_format: Packing format for linear layers (e.g., "tile_packed_to_4d"). qembedding_config: Quantization config for embedding layers ("4w", "8w"). @@ -26,12 +26,45 @@ def quantize_model_( # noqa: C901 if not qlinear_config and not qembedding_config: return + from torchao.quantization.quant_api import quantize_ + + # Metal (MPS) quantization uses different API + if qlinear_config == "fpa4w": + from torchao.experimental.quant_api import UIntxWeightOnlyConfig + + # Load MPS ops + import torchao.experimental.ops.mps # noqa: F401 + + config = UIntxWeightOnlyConfig( + group_size=qlinear_group_size, + bitwidth=4, + ) + + # Filter: only quantize Linear layers with compatible dimensions + def linear_filter(m, fqn): + if isinstance(m, torch.nn.Linear): + if qlinear_group_size == 0: + return False + if m.weight.shape[1] % 8 != 0: + print( + f" Skipping {fqn}: weight shape {m.weight.shape} not multiple of 8" + ) + return False + return True + return False + + print( + f" Applying {qlinear_config} linear quantization " + f"(group_size={qlinear_group_size})..." + ) + quantize_(module, config, filter_fn=linear_filter) + return + from torchao.quantization.granularity import PerAxis, PerGroup from torchao.quantization.quant_api import ( Int4WeightOnlyConfig, Int8DynamicActivationIntxWeightConfig, IntxWeightOnlyConfig, - quantize_, ) # Quantize embedding layers first diff --git a/third-party/ao b/third-party/ao index 28306f08500..1b4b6d998bf 160000 --- a/third-party/ao +++ b/third-party/ao @@ -1 +1 @@ -Subproject commit 28306f085003b892bc0a250c209d80f5d4a5147b +Subproject commit 1b4b6d998bf988f059e97a10181cbc4aec269b69 From 11da5474f89cf70c0ed6a8b65c7e7ceb6a47ef31 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Wed, 4 Feb 2026 23:29:46 -0500 Subject: [PATCH 20/26] Update [ghstack-poisoned] --- backends/apple/metal/metal_backend.py | 1 + 1 file changed, 1 insertion(+) diff --git a/backends/apple/metal/metal_backend.py b/backends/apple/metal/metal_backend.py index fd94c5ba7a7..36295c27786 100644 --- a/backends/apple/metal/metal_backend.py +++ b/backends/apple/metal/metal_backend.py @@ -47,6 +47,7 @@ def get_custom_passes(cls, compile_specs: List[CompileSpec]) -> List[typing.Any] from executorch.backends.apple.metal.passes.decompose_linear_pass import ( DecomposeLinearPass, ) + return [DecomposeLinearPass()] @classmethod From e81b589d0b9780d146876e3f798ae232f21abd0c Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Thu, 5 Feb 2026 00:21:45 -0500 Subject: [PATCH 21/26] Update [ghstack-poisoned] --- examples/models/parakeet/quantize.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/models/parakeet/quantize.py b/examples/models/parakeet/quantize.py index 1d69662efa1..ad9d2730c5b 100644 --- a/examples/models/parakeet/quantize.py +++ b/examples/models/parakeet/quantize.py @@ -30,10 +30,9 @@ def quantize_model_( # noqa: C901 # Metal (MPS) quantization uses different API if qlinear_config == "fpa4w": - from torchao.experimental.quant_api import UIntxWeightOnlyConfig - # Load MPS ops import torchao.experimental.ops.mps # noqa: F401 + from torchao.experimental.quant_api import UIntxWeightOnlyConfig config = UIntxWeightOnlyConfig( group_size=qlinear_group_size, From 8ff273f5556bf865a9603fa616f34efaf6ce9742 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Thu, 5 Feb 2026 13:42:19 -0500 Subject: [PATCH 
22/26] Update [ghstack-poisoned] --- .github/workflows/metal.yml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/metal.yml b/.github/workflows/metal.yml index 1ca0910a67b..ec15c87f737 100644 --- a/.github/workflows/metal.yml +++ b/.github/workflows/metal.yml @@ -41,11 +41,10 @@ jobs: set -eux echo "::group::Setup ExecuTorch" - PYTHON_EXECUTABLE=python ${CONDA_RUN} ./install_executorch.sh + PYTHON_EXECUTABLE=python ${CONDA_RUN} EXECUTORCH_BUILD_KERNELS_TORCHAO=1 TORCHAO_BUILD_EXPERIMENTAL_MPS=1 ./install_executorch.sh echo "::endgroup::" echo "::group::Build Metal Runtime" - ${CONDA_RUN} backends/apple/metal/tests/run_metal_test.sh --update-ao ${CONDA_RUN} backends/apple/metal/tests/run_metal_test.sh --build echo "::endgroup::" From 4316164805381c6c1f4412b20ad25594d20b4f24 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Thu, 5 Feb 2026 14:09:14 -0500 Subject: [PATCH 23/26] Update [ghstack-poisoned] --- examples/models/parakeet/README.md | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/examples/models/parakeet/README.md b/examples/models/parakeet/README.md index bf9b1d76032..649a2225536 100644 --- a/examples/models/parakeet/README.md +++ b/examples/models/parakeet/README.md @@ -99,6 +99,18 @@ python export_parakeet_tdt.py \ --output-dir ./parakeet_metal_quantized ``` +**Note:** Metal 4-bit quantization requires torchao built with experimental MPS (Metal) ops. + +You can install torchao with Metal support from the `ao` repo: +```bash +USE_CPP=1 TORCHAO_BUILD_EXPERIMENTAL_MPS=1 pip install . --no-build-isolation +``` + +Alternatively, you can build torchao with Metal support while installing ExecuTorch: +```bash +EXECUTORCH_BUILD_KERNELS_TORCHAO=1 TORCHAO_BUILD_EXPERIMENTAL_MPS=1 ./install_executorch.sh +``` + ### Metal Export (macOS) ```bash From 401af46161fd53e83fbd470f7ef4bd9bf74538ed Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Thu, 5 Feb 2026 14:45:34 -0500 Subject: [PATCH 24/26] Update [ghstack-poisoned] --- backends/apple/metal/tests/test_modules.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/backends/apple/metal/tests/test_modules.py b/backends/apple/metal/tests/test_modules.py index 8e0c4965170..4abec8f7fe3 100644 --- a/backends/apple/metal/tests/test_modules.py +++ b/backends/apple/metal/tests/test_modules.py @@ -694,12 +694,18 @@ def quantize_model(model: nn.Module, qlinear: str, qlinear_group_size: int = 32) else: raise ValueError(f"Unsupported linear quantization config '{qlinear}'.") - def linear_filter(module, fqn): - if isinstance(module, torch.nn.Linear): - # Check if hidden dimension is divisible by group size - return qlinear_group_size == 0 or ( - module.weight.shape[1] % qlinear_group_size == 0 - ) + def linear_filter(m, fqn): + if isinstance(m, torch.nn.Linear): + if qlinear_group_size == 0: + raise ValueError( + f"Invalid group_size=0 for Metal int4 quantization (layer: {fqn})" + ) + if m.weight.shape[1] % 8 != 0: + raise ValueError( + f"Metal int4 quantization requires weight dimension K to be multiple of 8. 
" + f"Layer {fqn} has weight shape {m.weight.shape} (K={m.weight.shape[1]})" + ) + return True return False quantize_(model, linear_config, filter_fn=linear_filter) From 87f15295607efce8d9d712a5dd715ecc4871a8aa Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Thu, 5 Feb 2026 15:05:24 -0500 Subject: [PATCH 25/26] Update [ghstack-poisoned] --- backends/apple/metal/tests/test_modules.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/backends/apple/metal/tests/test_modules.py b/backends/apple/metal/tests/test_modules.py index 4abec8f7fe3..73ba3a8ed65 100644 --- a/backends/apple/metal/tests/test_modules.py +++ b/backends/apple/metal/tests/test_modules.py @@ -287,7 +287,7 @@ def forward(self, x: torch.Tensor): class LinearInt4_QMV_IMPL_small_odd(nn.Module): def __init__(self): super().__init__() - self.linear = nn.Linear(8, 3, bias=True) + self.linear = nn.Linear(32, 3, bias=True) def forward(self, x: torch.Tensor): return self.linear(x) @@ -295,7 +295,7 @@ def forward(self, x: torch.Tensor): MODULE_REGISTRY["linear_int4_qmv_impl_small_odd"] = { "model_class": LinearInt4_QMV_IMPL_small_odd, - "input_shapes": [(1, 8)], + "input_shapes": [(1, 32)], "description": "Linear int4 quantization dispatching to qmv_impl", "qlinear": "fpa4w", "qlinear_group_size": 32, @@ -312,7 +312,7 @@ def forward(self, x: torch.Tensor): class LinearInt4_QMV_IMPL_small_even(nn.Module): def __init__(self): super().__init__() - self.linear = nn.Linear(8, 10, bias=True) + self.linear = nn.Linear(32, 10, bias=True) def forward(self, x: torch.Tensor): return self.linear(x) @@ -320,7 +320,7 @@ def forward(self, x: torch.Tensor): MODULE_REGISTRY["linear_int4_qmv_impl_small_even"] = { "model_class": LinearInt4_QMV_IMPL_small_even, - "input_shapes": [(1, 8)], + "input_shapes": [(1, 32)], "description": "Linear int4 quantization dispatching to qmv_impl", "qlinear": "fpa4w", "qlinear_group_size": 32, @@ -700,10 +700,10 @@ def linear_filter(m, fqn): raise ValueError( f"Invalid group_size=0 for Metal int4 quantization (layer: {fqn})" ) - if m.weight.shape[1] % 8 != 0: + if m.weight.shape[1] % qlinear_group_size != 0: raise ValueError( - f"Metal int4 quantization requires weight dimension K to be multiple of 8. " - f"Layer {fqn} has weight shape {m.weight.shape} (K={m.weight.shape[1]})" + f"Metal int4 quantization requires weight dimension (K) to be multiple of group_size. " + f"Layer {fqn} has weight shape {m.weight.shape} (K={m.weight.shape[1]}, group_size={qlinear_group_size})" # noqa: E501 ) return True return False From cf89a2b06298ff114ac67949d7889723ef08aab1 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Thu, 5 Feb 2026 15:15:49 -0500 Subject: [PATCH 26/26] Update [ghstack-poisoned] --- backends/apple/metal/tests/test_modules.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/backends/apple/metal/tests/test_modules.py b/backends/apple/metal/tests/test_modules.py index 73ba3a8ed65..403ce355381 100644 --- a/backends/apple/metal/tests/test_modules.py +++ b/backends/apple/metal/tests/test_modules.py @@ -696,14 +696,10 @@ def quantize_model(model: nn.Module, qlinear: str, qlinear_group_size: int = 32) def linear_filter(m, fqn): if isinstance(m, torch.nn.Linear): - if qlinear_group_size == 0: - raise ValueError( - f"Invalid group_size=0 for Metal int4 quantization (layer: {fqn})" - ) if m.weight.shape[1] % qlinear_group_size != 0: raise ValueError( f"Metal int4 quantization requires weight dimension (K) to be multiple of group_size. 
" - f"Layer {fqn} has weight shape {m.weight.shape} (K={m.weight.shape[1]}, group_size={qlinear_group_size})" # noqa: E501 + f"Layer {fqn} has weight shape {m.weight.shape} (K={m.weight.shape[1]}, group_size={qlinear_group_size})" # noqa: E501 ) return True return False