diff --git a/backends/apple/metal/metal_backend.py b/backends/apple/metal/metal_backend.py
index fde0410cca3..36295c27786 100644
--- a/backends/apple/metal/metal_backend.py
+++ b/backends/apple/metal/metal_backend.py
@@ -43,8 +43,12 @@ def get_decomposition_table(cls) -> Dict[Any, Any]:
 
     @classmethod
     def get_custom_passes(cls, compile_specs: List[CompileSpec]) -> List[typing.Any]:
-        """Return Metal-specific passes (currently none)"""
-        return []
+        """Return Metal-specific passes"""
+        from executorch.backends.apple.metal.passes.decompose_linear_pass import (
+            DecomposeLinearPass,
+        )
+
+        return [DecomposeLinearPass()]
 
     @classmethod
     def get_aoti_compile_options(
diff --git a/backends/apple/metal/passes/__init__.py b/backends/apple/metal/passes/__init__.py
new file mode 100644
index 00000000000..2a1209f1356
--- /dev/null
+++ b/backends/apple/metal/passes/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from executorch.backends.apple.metal.passes.decompose_linear_pass import (  # noqa: F401
+    DecomposeLinearPass,
+)
+
+__all__ = ["DecomposeLinearPass"]
diff --git a/backends/apple/metal/passes/decompose_linear_pass.py b/backends/apple/metal/passes/decompose_linear_pass.py
new file mode 100644
index 00000000000..e6b8578cc9f
--- /dev/null
+++ b/backends/apple/metal/passes/decompose_linear_pass.py
@@ -0,0 +1,111 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass, PassResult
+
+
+class DecomposeLinearPass(ExportPass):
+    """
+    Decompose aten.linear into matmul + add to avoid addmm.
+
+    For 2D inputs, we unsqueeze to 3D before decomposition to force the matmul
+    code path instead of addmm. The C++ implementation of aten.linear directly
+    calls addmm for 2D inputs with bias, which would require implementing
+    aoti_torch_mps_addmm_out. By unsqueezing to 3D, we force the matmul path,
+    then squeeze back to 2D.
+ """ + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + modified = False + graph = graph_module.graph + + for node in graph.nodes: + # Check if this is a linear operation + is_linear = False + + if node.op == "call_function": + # Match both edge dialect and core aten linear operators + if node.target == exir_ops.edge.aten.linear.default: + is_linear = True + elif node.target == torch.ops.aten.linear.default: + is_linear = True + + if is_linear: + # Get input, weight, and bias arguments + input_node = node.args[0] + weight_node = node.args[1] + bias_node = node.args[2] if len(node.args) > 2 else None + + with graph.inserting_before(node): + # Determine which ops to use based on the input operator + target_str = str(node.target) + + if "executorch_exir_dialects_edge" in target_str: + # Use edge dialect operators + t_op = exir_ops.edge.aten.t.default + matmul_op = exir_ops.edge.aten.matmul.default + add_op = exir_ops.edge.aten.add.Tensor + unsqueeze_op = exir_ops.edge.aten.unsqueeze.default + squeeze_op = exir_ops.edge.aten.squeeze.dims + else: + # Use core aten operators + t_op = torch.ops.aten.t.default + matmul_op = torch.ops.aten.matmul.default + add_op = torch.ops.aten.add.Tensor + unsqueeze_op = torch.ops.aten.unsqueeze.default + squeeze_op = torch.ops.aten.squeeze.dims + + # Check if input is 2D + needs_unsqueeze = False + if hasattr(input_node, "meta") and "val" in input_node.meta: + if len(input_node.meta["val"].shape) == 2: + needs_unsqueeze = True + + # Unsqueeze 2D input to 3D: (M, K) -> (1, M, K) + current_input = input_node + if needs_unsqueeze: + current_input = graph.call_function( + unsqueeze_op, + args=(input_node, 0), + ) + + # Decompose linear: matmul(input, weight.T) + bias + weight_t = graph.call_function( + t_op, + args=(weight_node,), + ) + + matmul_result = graph.call_function( + matmul_op, + args=(current_input, weight_t), + ) + + if bias_node is not None: + result = graph.call_function( + add_op, + args=(matmul_result, bias_node), + ) + else: + result = matmul_result + + # Squeeze 3D output back to 2D: (1, M, N) -> (M, N) + if needs_unsqueeze: + result = graph.call_function( + squeeze_op, + args=(result, [0]), + ) + + # Replace all uses of the linear node with the decomposed result + node.replace_all_uses_with(result) + graph.erase_node(node) + modified = True + + if modified: + graph_module.recompile() + + return PassResult(graph_module, modified) diff --git a/backends/apple/metal/tests/test_modules.py b/backends/apple/metal/tests/test_modules.py index 59904bb494d..f97561b2e7c 100644 --- a/backends/apple/metal/tests/test_modules.py +++ b/backends/apple/metal/tests/test_modules.py @@ -173,6 +173,23 @@ def forward(self, x: torch.Tensor): } +# ------------------------------------------------------------------------- +class LinearWithBias(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(7, 101, bias=True) + + def forward(self, x: torch.Tensor): + return self.linear(x) + + +MODULE_REGISTRY["linear_bias"] = { + "model_class": LinearWithBias, + "input_shapes": [(127, 7)], + "description": "Simple linear layer model bias", +} + + # ------------------------------------------------------------------------- # Convolution Modules # -------------------------------------------------------------------------