From b1c12b6b20a4eefe16fed2483927d3cc896ff983 Mon Sep 17 00:00:00 2001 From: Mengwei Liu Date: Tue, 3 Feb 2026 16:02:47 -0800 Subject: [PATCH] Add multi-method support for ETRecord edge dialect programs Summary: ETRecord previously only supported a single edge dialect program (defaulting to the "forward" method). This change adds support for multiple methods like vision_encoder, token_embedding, and text_decoder that are used in VLM exports. The edge_dialect_program field can now be either a single ExportedProgram (for backward compatibility) or a Dict[str, ExportedProgram] mapping method names to their exported programs. Key changes: - Updated ETRecord to accept Dict[str, ExportedProgram] for edge_dialect_program - Modified _process_edge_dialect_program to extract all methods from EdgeProgramManager - Updated save/parse logic to handle multi-method format with backward compatibility - New format uses paths like edge_dialect_exported_program/{method_name} Authored with Claude Code. Differential Revision: D92215098 --- devtools/etrecord/_etrecord.py | 116 +++++++++++++++++++++-- devtools/etrecord/tests/etrecord_test.py | 64 +++++++++++++ 2 files changed, 170 insertions(+), 10 deletions(-) diff --git a/devtools/etrecord/_etrecord.py b/devtools/etrecord/_etrecord.py index a144d7e4eaf..bc6987c457b 100644 --- a/devtools/etrecord/_etrecord.py +++ b/devtools/etrecord/_etrecord.py @@ -62,7 +62,9 @@ def __init__( self, exported_program: Optional[ExportedProgram] = None, export_graph_id: Optional[int] = None, - edge_dialect_program: Optional[ExportedProgram] = None, + edge_dialect_program: Optional[ + Union[ExportedProgram, Dict[str, ExportedProgram]] + ] = None, graph_map: Optional[Dict[str, ExportedProgram]] = None, _debug_handle_map: Optional[Dict[int, Union[int, List[int]]]] = None, _delegate_map: Optional[ @@ -88,6 +90,31 @@ def __init__( ``` If user need to create an ETRecord manually, please use the `create_etrecord` function. + + **EXPERIMENTAL**: This API supports multiple methods. For example: + ```python + lowered_and_edge = to_edge_transform_and_lower( + { + "vision_encoder": vision_encoder_ep, + "token_embedding": token_embedding_ep, + "text_decoder": causal_llm_ep, + }, + partitioner={ + "vision_encoder": [XnnpackPartitioner()], + "token_embedding": [XnnpackPartitioner()], + "text_decoder": [ + XnnpackPartitioner( + config_precisions=ConfigPrecisionType.DYNAMIC_QUANT, + per_op_mode=True, + ), + XnnpackPartitioner(), + ], + }, + compile_config=EdgeCompileConfig(_check_ir_validity=False), + constant_methods=manager.metadata, + generate_etrecord=True, # Enable ETRecord generation for all 3 methods + ) + ``` """ self.exported_program = exported_program @@ -121,6 +148,14 @@ def save(self, path: Union[str, os.PathLike, BinaryIO, IO[bytes]]) -> None: "ETRecord must contain edge dialect program and executorch program to be saved" ) + # Normalize edge_dialect_program to dict format for consistent handling + if isinstance(self.edge_dialect_program, ExportedProgram): + self._edge_dialect_programs_dict: Dict[str, ExportedProgram] = { + "forward": self.edge_dialect_program + } + else: + self._edge_dialect_programs_dict = self.edge_dialect_program + etrecord_zip = ZipFile(path, "w") try: @@ -136,7 +171,7 @@ def _write_identifier(self, etrecord_zip: ZipFile) -> None: etrecord_zip.writestr(ETRecordReservedFileNames.ETRECORD_IDENTIFIER, "") def _save_programs(self, etrecord_zip: ZipFile) -> None: - """Save exported program and edge dialect program.""" + """Save exported program and edge dialect program(s).""" if self.exported_program is not None: self._save_exported_program( etrecord_zip, @@ -145,8 +180,9 @@ def _save_programs(self, etrecord_zip: ZipFile) -> None: self.exported_program, ) - if self.edge_dialect_program is not None: - self._save_edge_dialect_program(etrecord_zip, self.edge_dialect_program) + # Save all edge dialect programs (supports multiple methods) + for method_name, edge_program in self._edge_dialect_programs_dict.items(): + self._save_edge_dialect_program(etrecord_zip, method_name, edge_program) def _save_graph_map(self, etrecord_zip: ZipFile) -> None: """Save graph map if present.""" @@ -223,13 +259,19 @@ def _save_exported_program( ) def _save_edge_dialect_program( - self, etrecord_zip: ZipFile, edge_dialect_program: ExportedProgram + self, + etrecord_zip: ZipFile, + method_name: str, + edge_dialect_program: ExportedProgram, ) -> None: """Save the edge dialect program to the ETRecord zip file.""" serialized_artifact = serialize(edge_dialect_program) assert isinstance(serialized_artifact.exported_program, bytes) - base_name = ETRecordReservedFileNames.EDGE_DIALECT_EXPORTED_PROGRAM + # Use format: edge_dialect_exported_program/method_name for multi-method support + base_name = ( + f"{ETRecordReservedFileNames.EDGE_DIALECT_EXPORTED_PROGRAM}/{method_name}" + ) etrecord_zip.writestr(base_name, serialized_artifact.exported_program) etrecord_zip.writestr(f"{base_name}_state_dict", serialized_artifact.state_dict) etrecord_zip.writestr(f"{base_name}_constants", serialized_artifact.constants) @@ -591,13 +633,28 @@ def _add_module_to_graph_map( def _process_edge_dialect_program( edge_dialect_program: Union[EdgeProgramManager, ExirExportedProgram] -) -> ExportedProgram: - """Process edge dialect program and return the exported program.""" +) -> Union[ExportedProgram, Dict[str, ExportedProgram]]: + """Process edge dialect program and return the exported program(s). + + For EdgeProgramManager with multiple methods, returns a Dict[str, ExportedProgram] + mapping method names to their exported programs. For single-method cases or + ExirExportedProgram, returns a single ExportedProgram. + """ if isinstance( edge_dialect_program, (EdgeProgramManager, exir.program._program.EdgeProgramManager), ): - return edge_dialect_program.exported_program() + methods = edge_dialect_program.methods + if len(methods) == 1: + # Single method case - return the ExportedProgram directly + method_name = next(iter(methods)) + return edge_dialect_program.exported_program(method_name) + else: + # Multiple methods - return a dict of all methods + return { + method: edge_dialect_program.exported_program(method) + for method in methods + } elif isinstance(edge_dialect_program, ExirExportedProgram): return edge_dialect_program.exported_program else: @@ -676,19 +733,28 @@ def parse_etrecord(etrecord_path: str) -> ETRecord: # noqa: C901 ) graph_map: Dict[str, ExportedProgram] = {} + edge_dialect_programs: Dict[str, ExportedProgram] = {} debug_handle_map = None delegate_map = None instruction_id_to_num_outs_map = None exported_program = None - edge_dialect_program = None + edge_dialect_program: Optional[ + Union[ExportedProgram, Dict[str, ExportedProgram]] + ] = None reference_outputs = None representative_inputs = None export_graph_id = 0 serialized_exported_program_files = set() + serialized_edge_dialect_program_files = set() serialized_state_dict_files = set() serialized_constants_files = set() serialized_example_inputs_files = set() + + edge_dialect_prefix = ( + str(ETRecordReservedFileNames.EDGE_DIALECT_EXPORTED_PROGRAM) + "/" + ) + for entry in file_list: if entry == ETRecordReservedFileNames.DEBUG_HANDLE_MAP_NAME: debug_handle_map = json.loads( @@ -707,6 +773,7 @@ def parse_etrecord(etrecord_path: str) -> ETRecord: # noqa: C901 elif entry == ETRecordReservedFileNames.ETRECORD_IDENTIFIER: continue elif entry == ETRecordReservedFileNames.EDGE_DIALECT_EXPORTED_PROGRAM: + # Old format: single edge dialect program (backward compatibility) serialized_artifact = SerializedArtifact( etrecord_zip.read( ETRecordReservedFileNames.EDGE_DIALECT_EXPORTED_PROGRAM @@ -716,6 +783,11 @@ def parse_etrecord(etrecord_path: str) -> ETRecord: # noqa: C901 etrecord_zip.read(f"{entry}_example_inputs"), ) edge_dialect_program = deserialize(serialized_artifact) + elif entry.startswith(edge_dialect_prefix) and not entry.endswith( + ("_state_dict", "_constants", "_example_inputs") + ): + # New format: edge_dialect_exported_program/method_name + serialized_edge_dialect_program_files.add(entry) elif entry == ETRecordReservedFileNames.EXPORTED_PROGRAM: serialized_artifact = SerializedArtifact( etrecord_zip.read(ETRecordReservedFileNames.EXPORTED_PROGRAM), @@ -748,6 +820,30 @@ def parse_etrecord(etrecord_path: str) -> ETRecord: # noqa: C901 else: serialized_exported_program_files.add(entry) + # Parse new format edge dialect programs (multi-method support) + for serialized_file in serialized_edge_dialect_program_files: + serialized_state_dict_file = f"{serialized_file}_state_dict" + serialized_constants_file = f"{serialized_file}_constants" + serialized_example_inputs_file = f"{serialized_file}_example_inputs" + serialized_artifact = SerializedArtifact( + etrecord_zip.read(serialized_file), + etrecord_zip.read(serialized_state_dict_file), + etrecord_zip.read(serialized_constants_file), + etrecord_zip.read(serialized_example_inputs_file), + ) + # Extract method name from path: edge_dialect_exported_program/method_name -> method_name + method_name = serialized_file[len(edge_dialect_prefix) :] + edge_dialect_programs[method_name] = deserialize(serialized_artifact) + + # If we found multi-method edge dialect programs, use them + if edge_dialect_programs: + if len(edge_dialect_programs) == 1: + # Single method - store as ExportedProgram for backward compatibility + edge_dialect_program = next(iter(edge_dialect_programs.values())) + else: + # Multiple methods - store as dict + edge_dialect_program = edge_dialect_programs + for serialized_file in serialized_exported_program_files: serialized_state_dict_file = f"{serialized_file}_state_dict" serialized_constants_file = f"{serialized_file}_constants" diff --git a/devtools/etrecord/tests/etrecord_test.py b/devtools/etrecord/tests/etrecord_test.py index 535de2e9a56..68085fb6796 100644 --- a/devtools/etrecord/tests/etrecord_test.py +++ b/devtools/etrecord/tests/etrecord_test.py @@ -1746,3 +1746,67 @@ def expected_runtime_error(etrecord, etrecord_path): # All essential components are now present, so save should succeed etrecord.save(etrecord_path) + + def test_multi_method_etrecord_generation(self): + """Test that ETRecord correctly handles multiple methods (e.g., vision_encoder, text_decoder).""" + # Create two different models to simulate multi-method export + f1 = models.BasicSinMax() + f2 = models.BasicSinMax() + + # Export both models + aten_program1 = export(f1, f1.get_random_inputs(), strict=True) + aten_program2 = export(f2, f2.get_random_inputs(), strict=True) + + # Create multi-method edge program + multi_method_programs = { + "vision_encoder": aten_program1, + "text_decoder": aten_program2, + } + + edge_manager = to_edge_transform_and_lower( + multi_method_programs, + generate_etrecord=True, + ) + + # Verify that ETRecord was generated + self.assertIsNotNone(edge_manager._etrecord) + etrecord = edge_manager._etrecord + + # Verify edge_dialect_program is a dict with both methods + self.assertIsNotNone(etrecord.edge_dialect_program) + self.assertIsInstance(etrecord.edge_dialect_program, dict) + self.assertIn("vision_encoder", etrecord.edge_dialect_program) + self.assertIn("text_decoder", etrecord.edge_dialect_program) + + # Convert to executorch to get complete ETRecord + et_manager = edge_manager.to_executorch() + + with tempfile.TemporaryDirectory() as tmpdirname: + etrecord_path = tmpdirname + "/etrecord_multi_method.bin" + + # Get ETRecord and save + complete_etrecord = et_manager.get_etrecord() + complete_etrecord.save(etrecord_path) + + # Parse ETRecord back + parsed_etrecord = parse_etrecord(etrecord_path) + + # Verify edge_dialect_program is correctly parsed as a dict + self.assertIsNotNone(parsed_etrecord.edge_dialect_program) + self.assertIsInstance(parsed_etrecord.edge_dialect_program, dict) + self.assertIn("vision_encoder", parsed_etrecord.edge_dialect_program) + self.assertIn("text_decoder", parsed_etrecord.edge_dialect_program) + + # Verify both methods have valid ExportedProgram objects + self.assertIsNotNone(parsed_etrecord.edge_dialect_program["vision_encoder"]) + self.assertIsNotNone(parsed_etrecord.edge_dialect_program["text_decoder"]) + self.assertIsNotNone( + parsed_etrecord.edge_dialect_program["vision_encoder"].graph_module + ) + self.assertIsNotNone( + parsed_etrecord.edge_dialect_program["text_decoder"].graph_module + ) + + # Verify other ETRecord components are preserved + self.assertIsNotNone(parsed_etrecord._debug_handle_map) + self.assertIsNotNone(parsed_etrecord._delegate_map)