diff --git a/backends/apple/metal/tests/test_modules.py b/backends/apple/metal/tests/test_modules.py
index f97561b2e7c..abd975567f7 100644
--- a/backends/apple/metal/tests/test_modules.py
+++ b/backends/apple/metal/tests/test_modules.py
@@ -31,6 +31,16 @@ from torch.export import export
 from torch.nn.attention import SDPBackend
 
+try:
+    # Need to import to load the ops
+    import torchao.experimental.ops.mps  # noqa: F401
+    from torchao.experimental.quant_api import UIntxWeightOnlyConfig
+    from torchao.quantization.quant_api import quantize_
+
+    TORCHAO_AVAILABLE = True
+except ImportError:
+    TORCHAO_AVAILABLE = False
+
 # Check if MPS is available for export tests
 MPS_AVAILABLE = torch.backends.mps.is_available()
 
@@ -88,6 +98,30 @@
 # - "rtol_<dtype>": float - Override relative tolerance for specific dtype (e.g., "rtol_bfloat16")
 # - "skip": bool or str - Skip all tests for this module (True to skip, or string with reason)
 # - "skip_<dtype>": bool or str - Skip tests for specific dtype (e.g., "skip_bfloat16")
+# - "qlinear": str - Quantization config for linear layers (e.g., "fpa4w" for 4-bit weights)
+# - "qlinear_group_size": int - Group size for quantization (default: 32)
+# - "compare_to_unquantized": bool - If True, compare quantized model output to unquantized reference (default: True for quantized models)
+#
+# Quantization Usage:
+# To enable int4 quantization for a module, add "qlinear": "fpa4w" to its registry entry.
+# This applies 4-bit weight quantization (floating-point activations, 4-bit weights) using torchao.
+# The quantization is applied after converting the model to the specified dtype but before export.
+#
+# By default, quantized models are compared against an unquantized reference model to measure
+# the actual quantization error. Set "compare_to_unquantized": False to compare against
+# the quantized PyTorch model instead.
+#
+# Example:
+#     MODULE_REGISTRY["my_linear_model"] = {
+#         "model_class": MyLinearModel,
+#         "input_shapes": [(128, 256)],
+#         "description": "My linear model with int4 quantization",
+#         "qlinear": "fpa4w",
+#         "qlinear_group_size": 32,
+#         "compare_to_unquantized": True,  # Compare to the unquantized reference
+#         "atol_float32": 5e-2,  # Quantization reduces precision, so increase tolerance
+#         "rtol_float32": 5e-2,
+#     }
 #
 # Model Parameter Initialization:
 # Model parameters are initialized with their default dtype (typically float32) when the
@@ -95,7 +129,8 @@
 # model.to(dtype). For example:
 # - nn.Parameter(torch.arange(20, dtype=torch.get_default_dtype())) creates float32 parameters
 # - These are converted to bfloat16 when model.to(torch.bfloat16) is called
-#
+
+
 MODULE_REGISTRY: Dict[str, Dict[str, Any]] = {}
 
 
@@ -190,6 +225,31 @@ def forward(self, x: torch.Tensor):
 }
 
 
+# -------------------------------------------------------------------------
+class LinearNoBiasInt4(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.linear = nn.Linear(128, 256, bias=False)
+
+    def forward(self, x: torch.Tensor):
+        return self.linear(x)
+
+
+MODULE_REGISTRY["linear_nobias_int4"] = {
+    "model_class": LinearNoBiasInt4,
+    "input_shapes": [(127, 128)],
+    "description": "Linear layer without bias, with int4 quantization",
+    "qlinear": "fpa4w",
+    "qlinear_group_size": 32,
+    "compare_to_unquantized": False,
+    "atol_float32": 5e-2,
+    "rtol_float32": 5e-2,
+    "atol_bfloat16": 1e-1,
+    "rtol_bfloat16": 1e-1,
+    "skip": not TORCHAO_AVAILABLE,
+}
+
+
 # -------------------------------------------------------------------------
 # Convolution Modules
 # -------------------------------------------------------------------------
@@ -476,7 +536,10 @@ def should_skip_model(model_name: str, dtype: torch.dtype) -> Tuple[bool, str]:
 
 
 def get_model_and_inputs(
-    model_name: str, dtype: torch.dtype = torch.float32
+    model_name: str,
+    dtype: torch.dtype = torch.float32,
+    qlinear: Optional[str] = None,
+    qlinear_group_size: Optional[int] = None,
 ) -> Tuple[nn.Module, Tuple[torch.Tensor, ...]]:
     """Get model and example inputs based on model name.
 
@@ -486,6 +549,10 @@ def get_model_and_inputs(
     Args:
         model_name: Name of the model to create
        dtype: Target data type for the model (default: torch.float32)
+        qlinear: Optional quantization config (e.g., "fpa4w" for 4-bit weights).
+            If None, uses the value from MODULE_REGISTRY if present.
+        qlinear_group_size: Group size for quantization. If None, uses the value
+            from MODULE_REGISTRY if present, otherwise defaults to 32.
 
     Returns:
         Tuple of (model, example_inputs)
@@ -500,6 +567,12 @@ def get_model_and_inputs(
     model_class = model_config["model_class"]
     input_shapes = model_config["input_shapes"]
 
+    # Use registry values if not explicitly provided
+    if qlinear is None:
+        qlinear = model_config.get("qlinear")
+    if qlinear_group_size is None:
+        qlinear_group_size = model_config.get("qlinear_group_size", 32)
+
     # Create model with default parameter dtypes (typically float32)
     model = model_class().eval()
 
@@ -507,11 +580,48 @@ def get_model_and_inputs(
     if dtype is not None:
         model = model.to(dtype)
 
+    # Apply quantization if requested
+    if qlinear is not None:
+        quantize_model(model, qlinear, qlinear_group_size)
+
     example_inputs = tuple(torch.randn(*shape, dtype=dtype) for shape in input_shapes)
 
     return model, example_inputs
 
 
+def quantize_model(model: nn.Module, qlinear: str, qlinear_group_size: int = 32):
+    """Apply quantization to the model's linear layers.
+
+    Args:
+        model: The model to quantize (in-place).
+        qlinear: Quantization config. Options:
+            - "fpa4w": floating-point activations, 4-bit weights (Metal backend)
+        qlinear_group_size: Group size for quantization (default: 32).
+    """
+    if not TORCHAO_AVAILABLE:
+        raise RuntimeError(
+            "torchao is not available. Install torchao to use quantization."
+        )
+
+    if qlinear == "fpa4w":
+        linear_config = UIntxWeightOnlyConfig(
+            group_size=qlinear_group_size,
+            bitwidth=4,
+        )
+    else:
+        raise ValueError(f"Unsupported linear quantization config '{qlinear}'.")
+
+    def linear_filter(module, fqn):
+        if isinstance(module, torch.nn.Linear):
+            # Only quantize when the reduction (in_features) dimension is
+            # divisible by the group size
+            return qlinear_group_size == 0 or (
+                module.weight.shape[1] % qlinear_group_size == 0
+            )
+        return False
+
+    quantize_(model, linear_config, filter_fn=linear_filter)
+
+
 def export_model_to_metal(
     model: nn.Module, example_inputs: Tuple[torch.Tensor, ...]
 ) -> Any:
@@ -539,17 +649,50 @@ def export_model_to_pte(
     example_inputs: Tuple[torch.Tensor, ...],
     output_dir: Path,
     model_name: str,
+    compare_to_unquantized: bool = False,
+    model_config: Optional[Dict[str, Any]] = None,
 ) -> Tuple[Path, torch.Tensor]:
     """
-    Export model to .pte file, and compute expected output.
+    Export model to .pte file and compute expected output.
+
+    Args:
+        model: Model to export (may be quantized)
+        example_inputs: Example inputs for export
+        output_dir: Directory to save output files
+        model_name: Name of the model
+        compare_to_unquantized: If True and the model has a quantization config,
+            compute the expected output from an unquantized reference model
+        model_config: Model configuration from MODULE_REGISTRY
 
     Returns:
         Tuple of (pte_path, expected_output)
     """
     # Compute expected output using all-ones input (matching export_aoti_metal.py)
     all_ones_input = tuple(torch.ones_like(inp) for inp in example_inputs)
+
     with torch.no_grad():
-        expected_output = model(*all_ones_input)
+        if compare_to_unquantized and model_config and model_config.get("qlinear"):
+            # Create an unquantized reference model for comparison
+            dtype = example_inputs[0].dtype if example_inputs else torch.float32
+            model_class = model_config["model_class"]
+            reference_model = model_class().eval()
+            reference_model = reference_model.to(dtype)
+            expected_output = reference_model(*all_ones_input)
+        else:
+            # Use the quantized model's output.
+            # For quantized models, the torchao operators require the MPS device.
+            if model_config and model_config.get("qlinear") and MPS_AVAILABLE:
+                # Move model and inputs to MPS
+                model_mps = model.to("mps")
+                all_ones_input_mps = tuple(inp.to("mps") for inp in all_ones_input)
+                expected_output = model_mps(*all_ones_input_mps)
+                # Move output back to CPU for comparison
+                expected_output = expected_output.cpu()
+                # Move model back to CPU for export
+                model = model_mps.to("cpu")
+            else:
+                # Non-quantized model, run on CPU
+                expected_output = model(*all_ones_input)
 
     # Export to executorch
     executorch_program = export_model_to_metal(model, example_inputs)
@@ -745,15 +888,6 @@ def _test_module_export(
 
         model, example_inputs = get_model_and_inputs(model_name, dtype=dtype)
 
-        # Verify model forward pass works before export
-        with torch.no_grad():
-            model_output = model(*example_inputs)
-
-        self.assertIsNotNone(
-            model_output,
-            f"{model_name} ({DTYPE_NAMES[dtype]}): Forward pass returned None",
-        )
-
         # Export to Metal backend
         executorch_program = export_model_to_metal(model, example_inputs)
 
@@ -789,6 +923,7 @@ def _test_module_output_consistency(
         model, example_inputs = get_model_and_inputs(model_name, dtype=dtype)
         dtype_name = DTYPE_NAMES[dtype]
         test_subdir_name = f"{model_name}_{dtype_name}"
+        model_config = MODULE_REGISTRY.get(model_name, {})
 
         def run_test_in_directory(test_dir: Path) -> None:
             """Run the actual test logic in the given directory."""
@@ -796,9 +931,20 @@
             model_output_dir = test_dir / test_subdir_name
             model_output_dir.mkdir(parents=True, exist_ok=True)
 
+            # Determine if we should compare to an unquantized reference.
+            # Default to True for quantized models, False otherwise
+            compare_to_unquantized = model_config.get(
+                "compare_to_unquantized", bool(model_config.get("qlinear"))
+            )
+
             # Export model and get expected output
             pte_path, expected_output = export_model_to_pte(
-                model, example_inputs, model_output_dir, model_name
+                model,
+                example_inputs,
+                model_output_dir,
+                model_name,
+                compare_to_unquantized=compare_to_unquantized,
+                model_config=model_config,
             )
 
             self.assertTrue(
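
Below is a minimal, self-contained sketch of the "fpa4w" path this patch exercises: quantize a linear-only model's weights to 4 bits with torchao, run it on MPS, and compare against the unquantized float reference, mirroring what quantize_model() and the "compare_to_unquantized" registry option do above. The model, shapes, and tolerances are illustrative, not taken from the test suite; it assumes a torchao build with the experimental MPS ops installed.

import copy

import torch
import torch.nn as nn

import torchao.experimental.ops.mps  # noqa: F401  (loads the Metal kernels)
from torchao.experimental.quant_api import UIntxWeightOnlyConfig
from torchao.quantization.quant_api import quantize_

GROUP_SIZE = 32

model = nn.Sequential(nn.Linear(128, 256, bias=False)).eval()
reference = copy.deepcopy(model)  # unquantized reference, kept on CPU


def linear_filter(module: nn.Module, fqn: str) -> bool:
    # Mirror quantize_model(): only quantize Linear layers whose reduction
    # (in_features) dimension is divisible by the group size.
    return isinstance(module, nn.Linear) and module.weight.shape[1] % GROUP_SIZE == 0


# Quantize in place on CPU, as get_model_and_inputs() does before export.
quantize_(
    model,
    UIntxWeightOnlyConfig(group_size=GROUP_SIZE, bitwidth=4),
    filter_fn=linear_filter,
)

if torch.backends.mps.is_available():
    x = torch.ones(16, 128)  # all-ones input, as in export_model_to_pte()
    with torch.no_grad():
        expected = reference(x)                # float reference on CPU
        actual = model.to("mps")(x.to("mps"))  # quantized model on MPS
    # Loose tolerances: 4-bit weights trade precision for size.
    torch.testing.assert_close(actual.cpu(), expected, atol=5e-2, rtol=5e-2)

Keeping the reference as a deep copy taken before quantize_() matches the intent of "compare_to_unquantized": the assertion then measures real quantization error rather than round-trip consistency of the quantized model.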