diff --git a/backends/apple/metal/tests/test_modules.py b/backends/apple/metal/tests/test_modules.py
index 4908b2e0ffc..b3f592fa509 100644
--- a/backends/apple/metal/tests/test_modules.py
+++ b/backends/apple/metal/tests/test_modules.py
@@ -24,22 +24,17 @@
 import numpy as np
 import torch
+
+# Need to import to load the torchao metal ops
+import torchao.experimental.ops.mps  # noqa: F401
+
 from executorch.backends.apple.metal.metal_backend import MetalBackend
 from executorch.backends.apple.metal.metal_partitioner import MetalPartitioner
 from executorch.exir import to_edge_transform_and_lower
 from torch import nn
 from torch.export import export
 from torch.nn.attention import SDPBackend
-
-try:
-    # Need to import to load the ops
-    import torchao.experimental.ops.mps  # noqa: F401
-    from torchao.experimental.quant_api import UIntxWeightOnlyConfig
-    from torchao.quantization.quant_api import quantize_
-
-    TORCHAO_AVAILABLE = True
-except ImportError:
-    TORCHAO_AVAILABLE = False
+from torchao.experimental.quant_api import quantize_, UIntxWeightOnlyConfig
 
 # Check if MPS is available for export tests
@@ -245,7 +240,6 @@ def forward(self, x: torch.Tensor):
     "rtol_float32": 5e-2,
     "atol_bfloat16": 1e-1,
     "rtol_bfloat16": 1e-1,
-    "skip": not TORCHAO_AVAILABLE,
 }
 
 
@@ -270,7 +264,6 @@ def forward(self, x: torch.Tensor):
     "rtol_float32": 5e-2,
     "atol_bfloat16": 1e-1,
     "rtol_bfloat16": 1e-1,
-    "skip": not TORCHAO_AVAILABLE,
 }
 
 
@@ -295,7 +288,6 @@ def forward(self, x: torch.Tensor):
     "rtol_float32": 5e-2,
     "atol_bfloat16": 1e-1,
     "rtol_bfloat16": 1e-1,
-    "skip": not TORCHAO_AVAILABLE,
 }
 
 
@@ -320,7 +312,6 @@ def forward(self, x: torch.Tensor):
     "rtol_float32": 5e-2,
     "atol_bfloat16": 1e-1,
     "rtol_bfloat16": 1e-1,
-    "skip": not TORCHAO_AVAILABLE,
 }
 
 
@@ -345,7 +336,6 @@ def forward(self, x: torch.Tensor):
     "rtol_float32": 5e-2,
     "atol_bfloat16": 1e-1,
     "rtol_bfloat16": 1e-1,
-    "skip": not TORCHAO_AVAILABLE,
 }
 
 
@@ -697,11 +687,6 @@ def quantize_model(model: nn.Module, qlinear: str, qlinear_group_size: int = 32)
             - "fpa4w": Floating point activation, 4-bit weight (Metal backend)
         qlinear_group_size: Group size for quantization (default: 32).
     """
-    if not TORCHAO_AVAILABLE:
-        raise RuntimeError(
-            "torchao is not available. Install torchao to use quantization."
-        )
-
     if qlinear == "fpa4w":
         linear_config = UIntxWeightOnlyConfig(
             group_size=qlinear_group_size,
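
Reviewer note: a minimal sketch, assuming only what the hunks above show, of the pattern this diff leaves the test file with. The import paths and the quantize_/UIntxWeightOnlyConfig pairing are taken verbatim from the diff; the nn.Sequential model and group size 32 are placeholders, and any UIntxWeightOnlyConfig constructor arguments after group_size (the call is cut off in the last hunk) are omitted here.

    from torch import nn

    # Side-effect import: registers the torchao Metal (MPS) custom ops.
    # torchao is now a hard dependency of the test module, so a missing
    # install fails loudly at import time instead of silently skipping
    # tests via the removed "skip": not TORCHAO_AVAILABLE config entries.
    import torchao.experimental.ops.mps  # noqa: F401
    from torchao.experimental.quant_api import quantize_, UIntxWeightOnlyConfig

    # Placeholder model; the real tests use the module classes defined
    # in test_modules.py.
    model = nn.Sequential(nn.Linear(64, 64))

    # "fpa4w" path: float activations, 4-bit weights. The real call in
    # quantize_model passes further arguments after group_size, which
    # are truncated in the diff and not reproduced here.
    linear_config = UIntxWeightOnlyConfig(
        group_size=32,
    )
    quantize_(model, linear_config)  # quantizes eligible linear layers in place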