174 changes: 160 additions & 14 deletions backends/apple/metal/tests/test_modules.py
@@ -31,6 +31,16 @@
from torch.export import export
from torch.nn.attention import SDPBackend

try:
# Need to import to load the ops
import torchao.experimental.ops.mps # noqa: F401
from torchao.experimental.quant_api import UIntxWeightOnlyConfig
from torchao.quantization.quant_api import quantize_

TORCHAO_AVAILABLE = True
Comment on lines +35 to +40 (Contributor): fail hard if you can't import

except ImportError:
TORCHAO_AVAILABLE = False


# Check if MPS is available for export tests
MPS_AVAILABLE = torch.backends.mps.is_available()
@@ -88,14 +98,39 @@
# - "rtol_<dtype>": float - Override relative tolerance for specific dtype (e.g., "rtol_bfloat16")
# - "skip": bool or str - Skip all tests for this module (True to skip, or string with reason)
# - "skip_<dtype>": bool or str - Skip tests for specific dtype (e.g., "skip_bfloat16")
# - "qlinear": str - Quantization config for linear layers (e.g., "fpa4w" for 4-bit weights)
# - "qlinear_group_size": int - Group size for quantization (default: 32)
# - "compare_to_unquantized": bool - If True, compare quantized model output to unquantized reference (default: True for quantized models)
#
# Quantization Usage:
# To enable int4 quantization for a module, add "qlinear": "fpa4w" to its registry entry.
# This applies 4-bit weight quantization (floating point activation, 4-bit weight) using torchao.
# The quantization is applied after converting the model to the specified dtype but before export.
#
# By default, quantized models are compared against unquantized reference models to measure
# the actual quantization error. Set "compare_to_unquantized": False to compare against
# the quantized PyTorch model instead.
#
# Example:
# MODULE_REGISTRY["my_linear_model"] = {
# "model_class": MyLinearModel,
# "input_shapes": [(128, 256)],
# "description": "My linear model with int4 quantization",
# "qlinear": "fpa4w",
# "qlinear_group_size": 32,
# "compare_to_unquantized": True, # Compare to unquantized reference
# "atol_float32": 5e-2, # Quantization reduces precision, so increase tolerance
# "rtol_float32": 5e-2,
# }
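#
# For an entry like this one, the helpers defined below apply the steps in the
# following order (a rough sketch; MyLinearModel is the hypothetical class above):
#   model = MyLinearModel().eval()        # parameters created at the default dtype
#   model = model.to(torch.float32)       # convert to the target dtype first
#   quantize_model(model, "fpa4w", 32)    # then quantize linear weights to 4 bits
#   # ...export to the Metal backend happens only after quantization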
#
# Model Parameter Initialization:
# Model parameters are initialized with their default dtype (typically float32) when the
# model class is instantiated. The parameters are then converted to the target dtype using
# model.to(dtype). For example:
#   - nn.Parameter(torch.arange(20, dtype=torch.get_default_dtype())) creates float32 parameters
# - These are converted to bfloat16 when model.to(torch.bfloat16) is called
#


MODULE_REGISTRY: Dict[str, Dict[str, Any]] = {}


@@ -190,6 +225,31 @@ def forward(self, x: torch.Tensor):
}


# -------------------------------------------------------------------------
class LinearNoBiasInt4(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(128, 256, bias=False)

def forward(self, x: torch.Tensor):
return self.linear(x)


MODULE_REGISTRY["linear_nobias_int4"] = {
"model_class": LinearNoBiasInt4,
"input_shapes": [(127, 128)],
"description": "Linear layer without bias and int4 quantization",
"qlinear": "fpa4w",
"qlinear_group_size": 32,
"compare_to_unquantized": False,
"atol_float32": 5e-2,
"rtol_float32": 5e-2,
"atol_bfloat16": 1e-1,
"rtol_bfloat16": 1e-1,
"skip": not TORCHAO_AVAILABLE,
}


# -------------------------------------------------------------------------
# Convolution Modules
# -------------------------------------------------------------------------
@@ -476,7 +536,10 @@ def should_skip_model(model_name: str, dtype: torch.dtype) -> Tuple[bool, str]:


def get_model_and_inputs(
model_name: str, dtype: torch.dtype = torch.float32
model_name: str,
dtype: torch.dtype = torch.float32,
qlinear: Optional[str] = None,
qlinear_group_size: Optional[int] = None,
) -> Tuple[nn.Module, Tuple[torch.Tensor, ...]]:
"""Get model and example inputs based on model name.

@@ -486,6 +549,10 @@
Args:
model_name: Name of the model to create
dtype: Target data type for the model (default: torch.float32)
qlinear: Optional quantization config (e.g., "fpa4w" for 4-bit weights).
If None, uses value from MODULE_REGISTRY if present.
qlinear_group_size: Group size for quantization. If None, uses value from
MODULE_REGISTRY if present, otherwise defaults to 32.

Returns:
Tuple of (model, example_inputs)
@@ -500,18 +567,61 @@
model_class = model_config["model_class"]
input_shapes = model_config["input_shapes"]

# Use registry values if not explicitly provided
if qlinear is None:
qlinear = model_config.get("qlinear")
if qlinear_group_size is None:
qlinear_group_size = model_config.get("qlinear_group_size", 32)

# Create model with default parameter dtypes (typically float32)
model = model_class().eval()

# Convert model parameters to target dtype if specified
if dtype is not None:
model = model.to(dtype)

# Apply quantization if requested
if qlinear is not None:
quantize_model(model, qlinear, qlinear_group_size)

example_inputs = tuple(torch.randn(*shape, dtype=dtype) for shape in input_shapes)

return model, example_inputs

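A usage sketch, not part of the diff: for the "linear_nobias_int4" entry above, the registry defaults are picked up automatically and can still be overridden per call (this assumes torchao and its MPS ops are installed, as checked by TORCHAO_AVAILABLE):

# Registry entry supplies qlinear="fpa4w" and qlinear_group_size=32.
model, example_inputs = get_model_and_inputs("linear_nobias_int4")

# Explicit arguments take precedence over the registry values.
model, example_inputs = get_model_and_inputs(
    "linear_nobias_int4",
    dtype=torch.bfloat16,
    qlinear="fpa4w",
    qlinear_group_size=32,
)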

def quantize_model(model: nn.Module, qlinear: str, qlinear_group_size: int = 32):
"""Apply quantization to the model's linear layers.

Args:
model: The model to quantize (in-place).
qlinear: Quantization config. Options:
- "fpa4w": Floating point activation, 4-bit weight (Metal backend)
qlinear_group_size: Group size for quantization (default: 32).
"""
if not TORCHAO_AVAILABLE:
raise RuntimeError(
"torchao is not available. Install torchao to use quantization."
)

if qlinear == "fpa4w":
linear_config = UIntxWeightOnlyConfig(
group_size=qlinear_group_size,
bitwidth=4,
)
else:
raise ValueError(f"Unsupported linear quantization config '{qlinear}'.")

def linear_filter(module, fqn):
if isinstance(module, torch.nn.Linear):
# Check if hidden dimension is divisible by group size
return qlinear_group_size == 0 or (
module.weight.shape[1] % qlinear_group_size == 0
)
return False

quantize_(model, linear_config, filter_fn=linear_filter)
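
As a sketch of what the filter above selects (hypothetical layers, assuming the default group size of 32): only nn.Linear modules whose input dimension is divisible by the group size are quantized; other layers are left untouched.

model = nn.Sequential(
    nn.Linear(128, 256, bias=False),  # 128 % 32 == 0 -> weight quantized to 4 bits
    nn.Linear(100, 16),               # 100 % 32 != 0 -> skipped by linear_filter
)
quantize_model(model, "fpa4w", qlinear_group_size=32)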


def export_model_to_metal(
model: nn.Module, example_inputs: Tuple[torch.Tensor, ...]
) -> Any:
@@ -539,17 +649,50 @@ def export_model_to_pte(
example_inputs: Tuple[torch.Tensor, ...],
output_dir: Path,
model_name: str,
compare_to_unquantized: bool = False,
model_config: Optional[Dict[str, Any]] = None,
) -> Tuple[Path, torch.Tensor]:
"""
Export model to .pte file, and compute expected output.
Export model to .pte file and compute expected output.

Args:
model: Model to export (may be quantized)
example_inputs: Example inputs for export
output_dir: Directory to save output files
model_name: Name of the model
compare_to_unquantized: If True and model has quantization config,
compute expected output from unquantized model
model_config: Model configuration from MODULE_REGISTRY

Returns:
Tuple of (pte_path, expected_output)
"""
# Compute expected output using all-ones input (matching export_aoti_metal.py)
all_ones_input = tuple(torch.ones_like(inp) for inp in example_inputs)

with torch.no_grad():
expected_output = model(*all_ones_input)
if compare_to_unquantized and model_config and model_config.get("qlinear"):
# Create unquantized reference model for comparison
dtype = example_inputs[0].dtype if example_inputs else torch.float32
model_class = model_config["model_class"]
reference_model = model_class().eval()
reference_model = reference_model.to(dtype)
expected_output = reference_model(*all_ones_input)
else:
# Use the quantized model's output
# For quantized models, torchao operators require MPS device
if model_config and model_config.get("qlinear") and MPS_AVAILABLE:
# Move model and inputs to MPS
model_mps = model.to("mps")
all_ones_input_mps = tuple(inp.to("mps") for inp in all_ones_input)
expected_output = model_mps(*all_ones_input_mps)
# Move output back to CPU for comparison
expected_output = expected_output.cpu()
# Move model back to CPU for export
model = model_mps.to("cpu")
else:
# Non-quantized model, run on CPU
expected_output = model(*all_ones_input)

# Export to executorch
executorch_program = export_model_to_metal(model, example_inputs)
@@ -745,15 +888,6 @@ def _test_module_export(

model, example_inputs = get_model_and_inputs(model_name, dtype=dtype)

# Verify model forward pass works before export
with torch.no_grad():
model_output = model(*example_inputs)

self.assertIsNotNone(
model_output,
f"{model_name} ({DTYPE_NAMES[dtype]}): Forward pass returned None",
)

# Export to Metal backend
executorch_program = export_model_to_metal(model, example_inputs)

@@ -789,16 +923,28 @@ def _test_module_output_consistency(
model, example_inputs = get_model_and_inputs(model_name, dtype=dtype)
dtype_name = DTYPE_NAMES[dtype]
test_subdir_name = f"{model_name}_{dtype_name}"
model_config = MODULE_REGISTRY.get(model_name, {})

def run_test_in_directory(test_dir: Path) -> None:
"""Run the actual test logic in the given directory."""
# Create model output directory: metal_backend_module_outputs/<model_name>_<dtype>/
model_output_dir = test_dir / test_subdir_name
model_output_dir.mkdir(parents=True, exist_ok=True)

# Determine if we should compare to unquantized reference
# Default to True for quantized models, False otherwise
compare_to_unquantized = model_config.get(
"compare_to_unquantized", bool(model_config.get("qlinear"))
)

# Export model and get expected output
pte_path, expected_output = export_model_to_pte(
model, example_inputs, model_output_dir, model_name
model,
example_inputs,
model_output_dir,
model_name,
compare_to_unquantized=compare_to_unquantized,
model_config=model_config,
)

self.assertTrue(