From 8a4a5111e94fb7ffe4fdc1008e9ba179c8d25717 Mon Sep 17 00:00:00 2001 From: Sebastian Larsson Date: Tue, 3 Feb 2026 14:20:37 -0800 Subject: [PATCH 1/2] Add Tosa to LLM extension (#15556) Summary: cc freddan80 per zingo oscarandersson8218 digantdesai larryliu0820 mergennachin cccclai helunwencser jackzhxng Pull Request resolved: https://github.com/pytorch/executorch/pull/15556 Differential Revision: D90692071 Pulled By: SS-JIA --- examples/models/llama/export_llama_lib.py | 50 ++++++++++++++++++- examples/models/llama/tests/BUCK | 1 + .../llama/tests/test_export_llama_lib.py | 20 +++++++- extension/llm/export/builder.py | 5 +- extension/llm/export/config/llm_config.py | 13 +++++ extension/llm/export/partitioner_lib.py | 10 ++++ extension/llm/export/quantizer_lib.py | 20 ++++++++ 7 files changed, 115 insertions(+), 4 deletions(-) diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index 219cc71ded1..3fad9f39f53 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -1,6 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -37,6 +37,7 @@ get_mps_partitioner, get_openvino_partitioner, get_qnn_partitioner, + get_tosa_partitioner, get_vulkan_partitioner, get_xnnpack_partitioner, ) @@ -46,6 +47,7 @@ get_pt2e_quantization_params, get_pt2e_quantizers, get_qnn_quantizer, + get_tosa_quantizer, get_vulkan_quantizer, ) from executorch.util.activation_memory_profiler import generate_memory_trace @@ -210,6 +212,7 @@ def build_args_parser() -> argparse.ArgumentParser: "coreml_baseline_8a_c8w", "coreml_baseline_8a_c4w", "vulkan_8w", + "tosa_8a8w", ], help="Use PT2E quantization. Comma separated options. e.g. 
xnnpack_dynamic (for per channel 8 bit weight), xnnpack_dynamic_qc4 (for per channel 4 bit weight), embedding.", ) @@ -788,6 +791,11 @@ def get_quantizer_and_quant_params(llm_config): llm_config.quantization.pt2e_quantize.value ) quantizers.append(coreml_quantizer) + if llm_config.backend.tosa.enabled and llm_config.quantization.pt2e_quantize: + tosa_quantizer = get_tosa_quantizer( + llm_config.backend.tosa.version, llm_config.quantization.pt2e_quantize.value + ) + quantizers.append(tosa_quantizer) if llm_config.backend.vulkan.enabled and llm_config.quantization.pt2e_quantize: assert ( len(quantizers) == 0 @@ -930,6 +938,32 @@ def _to_edge_and_lower_llama_openvino( return builder.to_executorch(passes=additional_passes) +def _to_edge_and_lower_llama_tosa( + builder_exported, + modelname, + quantizers, + additional_passes, + tosa_spec, + verbose: bool = False, +) -> LLMEdgeManager: + + logging.info("Lowering model using TOSA partitioner") + + partitioners = [] + partitioners.append(get_tosa_partitioner(tosa_spec)) + + modelname = f"tosa_{modelname}" + + builder = builder_exported.pt2e_quantize(quantizers).to_edge_transform_and_lower( + partitioners + ) + + if verbose: + print_delegation_info(builder.edge_manager.exported_program().graph_module) + + return builder.to_executorch(passes=additional_passes) + + def _to_edge_and_lower_llama( # noqa: C901 builder_exported, modelname, @@ -1119,7 +1153,10 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager: # noqa: C901 additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])] # export_to_edge - builder_exported = _prepare_for_llama_export(llm_config).export() + builder_manager = _prepare_for_llama_export(llm_config) + if llm_config.backend.tosa.enabled: + builder_manager.skip_dim_order = False + builder_exported = builder_manager.export() builder_exported.run_canonical_optimizations() modelname = builder_exported.modelname @@ -1162,6 +1199,15 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager: # noqa: C901 openvino_device=llm_config.backend.openvino.device, verbose=llm_config.debug.verbose, ) + elif llm_config.backend.tosa.enabled: + builder = _to_edge_and_lower_llama_tosa( + builder_exported, + modelname, + quantizers, + additional_passes, + llm_config.backend.tosa.version, + verbose=llm_config.debug.verbose, + ) else: builder = _to_edge_and_lower_llama( builder_exported, diff --git a/examples/models/llama/tests/BUCK b/examples/models/llama/tests/BUCK index d9570ee8b7a..b50159de509 100644 --- a/examples/models/llama/tests/BUCK +++ b/examples/models/llama/tests/BUCK @@ -22,6 +22,7 @@ fbcode_target(_kind = python_unittest, ], deps = [ "//caffe2:torch", + "//executorch/backends/arm/quantizer:lib", "//executorch/examples/models/llama:export_library", "//executorch/examples/models/llama:llama_transformer", "//pytorch/ao:torchao", diff --git a/examples/models/llama/tests/test_export_llama_lib.py b/examples/models/llama/tests/test_export_llama_lib.py index 172517207de..d5ce8984880 100644 --- a/examples/models/llama/tests/test_export_llama_lib.py +++ b/examples/models/llama/tests/test_export_llama_lib.py @@ -1,17 +1,21 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
import unittest +from executorch.backends.arm.quantizer.arm_quantizer import TOSAQuantizer + from executorch.devtools.backend_debug import get_delegation_info from executorch.examples.models.llama.export_llama_lib import ( _export_llama, build_args_parser, + get_quantizer_and_quant_params, ) -from executorch.extension.llm.export.config.llm_config import LlmConfig +from executorch.extension.llm.export.config.llm_config import LlmConfig, Pt2eQuantize UNWANTED_OPS = [ "aten_permute_copy_default", @@ -48,3 +52,17 @@ def test_has_expected_ops_and_op_counts(self): for op, _op_info in delegation_info.delegation_by_operator.items(): self.assertTrue(op not in UNWANTED_OPS) + + def test_get_quantizer_and_quant_params_returns_tosa_quantizer(self): + llm_config = LlmConfig() + llm_config.backend.tosa.enabled = True + llm_config.quantization.pt2e_quantize = Pt2eQuantize.tosa_8a8w + + pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params( + llm_config + ) + + self.assertIsNone(pt2e_quant_params) + self.assertIsNone(quant_dtype) + self.assertEqual(len(quantizers), 1) + self.assertIsInstance(quantizers[0], TOSAQuantizer) diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py index ae15dded91d..e6d086eb202 100644 --- a/extension/llm/export/builder.py +++ b/extension/llm/export/builder.py @@ -1,5 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -96,6 +97,7 @@ def __init__( dynamic_shapes: Optional[Any] = None, save_exported_program: bool = False, generate_etrecord: bool = False, + skip_dim_order: bool = True, ): # Store necessary constructor arguments. self.model = model @@ -118,6 +120,7 @@ def __init__( self.dynamic_shapes = dynamic_shapes self.save_exported_program = save_exported_program self.generate_etrecord = generate_etrecord + self.skip_dim_order = skip_dim_order # Note: treat this as the source of truth for the result of # torch.export'ing a model. If the overall ExportedProgram is needed, @@ -197,7 +200,7 @@ def _get_dynamic_shape(self) -> Any: def _get_edge_config(self) -> EdgeCompileConfig: edge_config = EdgeCompileConfig( _check_ir_validity=False, - _skip_dim_order=True, + _skip_dim_order=self.skip_dim_order, ) return edge_config diff --git a/extension/llm/export/config/llm_config.py b/extension/llm/export/config/llm_config.py index b40fad88a9c..a7453fd09c1 100644 --- a/extension/llm/export/config/llm_config.py +++ b/extension/llm/export/config/llm_config.py @@ -1,5 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -288,6 +289,7 @@ class Pt2eQuantize(str, Enum): coreml_baseline_8a_c8w = "coreml_baseline_8a_c8w" coreml_baseline_8a_c4w = "coreml_baseline_8a_c4w" vulkan_8w = "vulkan_8w" + tosa_8a8w = "tosa_8a8w" class SpinQuant(str, Enum): @@ -474,6 +476,16 @@ class TorchAOKernelsConfig: use_torchao_kernels_tied_embedding: bool = False +@dataclass +class TosaConfig: + """ + Configures the TOSA backend. 
+ """ + + enabled: bool = False + version: str = "TOSA-1.0+INT" + + @dataclass class BackendConfig: """ @@ -488,6 +500,7 @@ class BackendConfig: mps: MPSConfig = field(default_factory=MPSConfig) openvino: OpenvinoConfig = field(default_factory=OpenvinoConfig) torchao: TorchAOKernelsConfig = field(default_factory=TorchAOKernelsConfig) + tosa: TosaConfig = field(default_factory=TosaConfig) ################################################################################ diff --git a/extension/llm/export/partitioner_lib.py b/extension/llm/export/partitioner_lib.py index 554c8a16ac7..d5b3329be71 100644 --- a/extension/llm/export/partitioner_lib.py +++ b/extension/llm/export/partitioner_lib.py @@ -1,5 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -236,3 +237,12 @@ def get_qnn_partitioner( # TODO: if deprecated legacy export, skip_mutable_buffer can be set False skip_mutable_buffer=True, ) + + +def get_tosa_partitioner(version: str): + from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec + from executorch.backends.arm.tosa.partitioner import TOSAPartitioner + + compile_spec = TosaCompileSpec(version) + + return TOSAPartitioner(compile_spec) diff --git a/extension/llm/export/quantizer_lib.py b/extension/llm/export/quantizer_lib.py index ade0f8d089b..0996319e4a2 100644 --- a/extension/llm/export/quantizer_lib.py +++ b/extension/llm/export/quantizer_lib.py @@ -1,5 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -320,3 +321,22 @@ def get_vulkan_quantizer(pt2e_quantize: str): quantizer = VulkanQuantizer().set_global(config) return quantizer + + +def get_tosa_quantizer(version: str, pt2e_quantize: str): + from executorch.backends.arm.quantizer.arm_quantizer import ( + get_symmetric_quantization_config, + TOSAQuantizer, + ) + from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec + + compile_spec = TosaCompileSpec(version) + + quantizer = TOSAQuantizer(compile_spec) + + if pt2e_quantize == "tosa_8a8w": + quantizer.set_global(get_symmetric_quantization_config()) + else: + raise ValueError(f"Unsupported quantizer specification {pt2e_quantize}") + + return quantizer From 1efa0fa69584c898aa77f82460dc848ff9770132 Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Tue, 3 Feb 2026 15:14:47 -0800 Subject: [PATCH 2/2] Tosa to LLM extension - buck fixes Summary: As title! 
Differential Revision: D92209173 --- backends/arm/quantizer/TARGETS | 5 +++++ examples/models/llama/tests/BUCK | 1 + 2 files changed, 6 insertions(+) diff --git a/backends/arm/quantizer/TARGETS b/backends/arm/quantizer/TARGETS index 1a02340f92b..28bfe15b528 100644 --- a/backends/arm/quantizer/TARGETS +++ b/backends/arm/quantizer/TARGETS @@ -17,6 +17,11 @@ runtime.python_library( deps = [ ":arm_quantizer_utils", ":quantization_annotator", + "//executorch/backends/arm:constants", + "//executorch/backends/arm:ethosu", + "//executorch/backends/arm:vgf", + "//executorch/backends/arm/tosa:specification", + "//executorch/backends/arm:arm_compile_spec", "//caffe2:torch", "//executorch/exir:lib", "//pytorch/ao:torchao", diff --git a/examples/models/llama/tests/BUCK b/examples/models/llama/tests/BUCK index b50159de509..8f4dec2237b 100644 --- a/examples/models/llama/tests/BUCK +++ b/examples/models/llama/tests/BUCK @@ -98,6 +98,7 @@ fbcode_target(_kind = python_unittest, ], deps = [ "//caffe2:torch", + "//executorch/backends/arm/quantizer:lib", "//executorch/examples/models/llama:export_library", "//executorch/examples/models/llama:llama_transformer", "//executorch/extension/pybindings:portable_lib",
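
Usage sketch: with these two patches applied, the TOSA path can be driven from an
LlmConfig the same way the added unit test drives it. This is a minimal sketch that
only uses APIs introduced or touched above (LlmConfig, Pt2eQuantize.tosa_8a8w,
backend.tosa, get_quantizer_and_quant_params); producing an actual .pte additionally
assumes the usual model/checkpoint fields on llm_config are populated, which is not
shown here.

    # Enable the TOSA backend and 8-bit activation / 8-bit weight PT2E quantization.
    from executorch.examples.models.llama.export_llama_lib import (
        get_quantizer_and_quant_params,
    )
    from executorch.extension.llm.export.config.llm_config import (
        LlmConfig,
        Pt2eQuantize,
    )

    llm_config = LlmConfig()
    llm_config.backend.tosa.enabled = True
    llm_config.backend.tosa.version = "TOSA-1.0+INT"  # default from TosaConfig
    llm_config.quantization.pt2e_quantize = Pt2eQuantize.tosa_8a8w

    # For this configuration the helper returns (None, [TOSAQuantizer], None),
    # matching the assertions in the new test.
    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(
        llm_config
    )

    # The full export is then _export_llama(llm_config), which routes through
    # _to_edge_and_lower_llama_tosa when backend.tosa.enabled is set (assumes a
    # model checkpoint has been configured on llm_config, not shown).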