Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 106 additions & 10 deletions devtools/etrecord/_etrecord.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@ def __init__(
self,
exported_program: Optional[ExportedProgram] = None,
export_graph_id: Optional[int] = None,
edge_dialect_program: Optional[ExportedProgram] = None,
edge_dialect_program: Optional[
Union[ExportedProgram, Dict[str, ExportedProgram]]
] = None,
graph_map: Optional[Dict[str, ExportedProgram]] = None,
_debug_handle_map: Optional[Dict[int, Union[int, List[int]]]] = None,
_delegate_map: Optional[
Expand All @@ -88,6 +90,31 @@ def __init__(
```

If users need to create an ETRecord manually, please use the `create_etrecord` function.

**EXPERIMENTAL**: This API supports multiple methods. For example:
```python
lowered_and_edge = to_edge_transform_and_lower(
{
"vision_encoder": vision_encoder_ep,
"token_embedding": token_embedding_ep,
"text_decoder": causal_llm_ep,
},
partitioner={
"vision_encoder": [XnnpackPartitioner()],
"token_embedding": [XnnpackPartitioner()],
"text_decoder": [
XnnpackPartitioner(
config_precisions=ConfigPrecisionType.DYNAMIC_QUANT,
per_op_mode=True,
),
XnnpackPartitioner(),
],
},
compile_config=EdgeCompileConfig(_check_ir_validity=False),
constant_methods=manager.metadata,
generate_etrecord=True, # Enable ETRecord generation for all 3 methods
)
```
"""

self.exported_program = exported_program
Expand Down Expand Up @@ -121,6 +148,14 @@ def save(self, path: Union[str, os.PathLike, BinaryIO, IO[bytes]]) -> None:
"ETRecord must contain edge dialect program and executorch program to be saved"
)

# Normalize edge_dialect_program to dict format for consistent handling
if isinstance(self.edge_dialect_program, ExportedProgram):
self._edge_dialect_programs_dict: Dict[str, ExportedProgram] = {
"forward": self.edge_dialect_program
}
else:
self._edge_dialect_programs_dict = self.edge_dialect_program

etrecord_zip = ZipFile(path, "w")

try:
Expand All @@ -136,7 +171,7 @@ def _write_identifier(self, etrecord_zip: ZipFile) -> None:
etrecord_zip.writestr(ETRecordReservedFileNames.ETRECORD_IDENTIFIER, "")

def _save_programs(self, etrecord_zip: ZipFile) -> None:
"""Save exported program and edge dialect program."""
"""Save exported program and edge dialect program(s)."""
if self.exported_program is not None:
self._save_exported_program(
etrecord_zip,
Expand All @@ -145,8 +180,9 @@ def _save_programs(self, etrecord_zip: ZipFile) -> None:
self.exported_program,
)

if self.edge_dialect_program is not None:
self._save_edge_dialect_program(etrecord_zip, self.edge_dialect_program)
# Save all edge dialect programs (supports multiple methods)
for method_name, edge_program in self._edge_dialect_programs_dict.items():
self._save_edge_dialect_program(etrecord_zip, method_name, edge_program)

def _save_graph_map(self, etrecord_zip: ZipFile) -> None:
"""Save graph map if present."""
Expand Down Expand Up @@ -223,13 +259,19 @@ def _save_exported_program(
)

def _save_edge_dialect_program(
self, etrecord_zip: ZipFile, edge_dialect_program: ExportedProgram
self,
etrecord_zip: ZipFile,
method_name: str,
edge_dialect_program: ExportedProgram,
) -> None:
"""Save the edge dialect program to the ETRecord zip file."""
serialized_artifact = serialize(edge_dialect_program)
assert isinstance(serialized_artifact.exported_program, bytes)

base_name = ETRecordReservedFileNames.EDGE_DIALECT_EXPORTED_PROGRAM
# Use format: edge_dialect_exported_program/method_name for multi-method support
base_name = (
f"{ETRecordReservedFileNames.EDGE_DIALECT_EXPORTED_PROGRAM}/{method_name}"
)
etrecord_zip.writestr(base_name, serialized_artifact.exported_program)
etrecord_zip.writestr(f"{base_name}_state_dict", serialized_artifact.state_dict)
etrecord_zip.writestr(f"{base_name}_constants", serialized_artifact.constants)
Expand Down Expand Up @@ -591,13 +633,28 @@ def _add_module_to_graph_map(

def _process_edge_dialect_program(
edge_dialect_program: Union[EdgeProgramManager, ExirExportedProgram]
) -> ExportedProgram:
"""Process edge dialect program and return the exported program."""
) -> Union[ExportedProgram, Dict[str, ExportedProgram]]:
"""Process edge dialect program and return the exported program(s).

For EdgeProgramManager with multiple methods, returns a Dict[str, ExportedProgram]
mapping method names to their exported programs. For single-method cases or
ExirExportedProgram, returns a single ExportedProgram.
"""
if isinstance(
edge_dialect_program,
(EdgeProgramManager, exir.program._program.EdgeProgramManager),
):
return edge_dialect_program.exported_program()
methods = edge_dialect_program.methods
if len(methods) == 1:
# Single method case - return the ExportedProgram directly
method_name = next(iter(methods))
return edge_dialect_program.exported_program(method_name)
else:
# Multiple methods - return a dict of all methods
return {
method: edge_dialect_program.exported_program(method)
for method in methods
}
elif isinstance(edge_dialect_program, ExirExportedProgram):
return edge_dialect_program.exported_program
else:
Expand Down Expand Up @@ -676,19 +733,28 @@ def parse_etrecord(etrecord_path: str) -> ETRecord: # noqa: C901
)

graph_map: Dict[str, ExportedProgram] = {}
edge_dialect_programs: Dict[str, ExportedProgram] = {}
debug_handle_map = None
delegate_map = None
instruction_id_to_num_outs_map = None
exported_program = None
edge_dialect_program = None
edge_dialect_program: Optional[
Union[ExportedProgram, Dict[str, ExportedProgram]]
] = None
reference_outputs = None
representative_inputs = None
export_graph_id = 0

serialized_exported_program_files = set()
serialized_edge_dialect_program_files = set()
serialized_state_dict_files = set()
serialized_constants_files = set()
serialized_example_inputs_files = set()

edge_dialect_prefix = (
str(ETRecordReservedFileNames.EDGE_DIALECT_EXPORTED_PROGRAM) + "/"
)

for entry in file_list:
if entry == ETRecordReservedFileNames.DEBUG_HANDLE_MAP_NAME:
debug_handle_map = json.loads(
Expand All @@ -707,6 +773,7 @@ def parse_etrecord(etrecord_path: str) -> ETRecord: # noqa: C901
elif entry == ETRecordReservedFileNames.ETRECORD_IDENTIFIER:
continue
elif entry == ETRecordReservedFileNames.EDGE_DIALECT_EXPORTED_PROGRAM:
# Old format: single edge dialect program (backward compatibility)
serialized_artifact = SerializedArtifact(
etrecord_zip.read(
ETRecordReservedFileNames.EDGE_DIALECT_EXPORTED_PROGRAM
Expand All @@ -716,6 +783,11 @@ def parse_etrecord(etrecord_path: str) -> ETRecord: # noqa: C901
etrecord_zip.read(f"{entry}_example_inputs"),
)
edge_dialect_program = deserialize(serialized_artifact)
elif entry.startswith(edge_dialect_prefix) and not entry.endswith(
("_state_dict", "_constants", "_example_inputs")
):
# New format: edge_dialect_exported_program/method_name
serialized_edge_dialect_program_files.add(entry)
elif entry == ETRecordReservedFileNames.EXPORTED_PROGRAM:
serialized_artifact = SerializedArtifact(
etrecord_zip.read(ETRecordReservedFileNames.EXPORTED_PROGRAM),
Expand Down Expand Up @@ -748,6 +820,30 @@ def parse_etrecord(etrecord_path: str) -> ETRecord: # noqa: C901
else:
serialized_exported_program_files.add(entry)

# Parse new format edge dialect programs (multi-method support)
for serialized_file in serialized_edge_dialect_program_files:
serialized_state_dict_file = f"{serialized_file}_state_dict"
serialized_constants_file = f"{serialized_file}_constants"
serialized_example_inputs_file = f"{serialized_file}_example_inputs"
serialized_artifact = SerializedArtifact(
etrecord_zip.read(serialized_file),
etrecord_zip.read(serialized_state_dict_file),
etrecord_zip.read(serialized_constants_file),
etrecord_zip.read(serialized_example_inputs_file),
)
# Extract method name from path: edge_dialect_exported_program/method_name -> method_name
method_name = serialized_file[len(edge_dialect_prefix) :]
edge_dialect_programs[method_name] = deserialize(serialized_artifact)

# If we found multi-method edge dialect programs, use them
if edge_dialect_programs:
if len(edge_dialect_programs) == 1:
# Single method - store as ExportedProgram for backward compatibility
edge_dialect_program = next(iter(edge_dialect_programs.values()))
else:
# Multiple methods - store as dict
edge_dialect_program = edge_dialect_programs

for serialized_file in serialized_exported_program_files:
serialized_state_dict_file = f"{serialized_file}_state_dict"
serialized_constants_file = f"{serialized_file}_constants"
Expand Down
64 changes: 64 additions & 0 deletions devtools/etrecord/tests/etrecord_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1746,3 +1746,67 @@ def expected_runtime_error(etrecord, etrecord_path):

# All essential components are now present, so save should succeed
etrecord.save(etrecord_path)

def test_multi_method_etrecord_generation(self):
    """Test that ETRecord correctly handles multiple methods (e.g., vision_encoder, text_decoder).

    Two methods are lowered with ETRecord generation enabled, the resulting
    ETRecord is saved to a temporary file, and the parsed result is checked
    to still expose one edge dialect program per method.
    """
    method_names = ("vision_encoder", "text_decoder")

    def assert_is_method_dict(programs):
        # The multi-method case is expected to keep a dict keyed by method name.
        self.assertIsNotNone(programs)
        self.assertIsInstance(programs, dict)
        for name in method_names:
            self.assertIn(name, programs)

    # Instantiate every model before exporting any of them.
    models_by_method = {name: models.BasicSinMax() for name in method_names}
    aten_programs = {
        name: export(model, model.get_random_inputs(), strict=True)
        for name, model in models_by_method.items()
    }

    # Lower all methods at once and request ETRecord generation.
    lowered = to_edge_transform_and_lower(
        aten_programs,
        generate_etrecord=True,
    )

    # An ETRecord must already be attached to the edge manager.
    self.assertIsNotNone(lowered._etrecord)
    assert_is_method_dict(lowered._etrecord.edge_dialect_program)

    # Convert to executorch to get the complete ETRecord.
    executorch_manager = lowered.to_executorch()

    with tempfile.TemporaryDirectory() as workdir:
        record_path = workdir + "/etrecord_multi_method.bin"

        # Save, then parse the file back from disk.
        executorch_manager.get_etrecord().save(record_path)
        restored = parse_etrecord(record_path)

        restored_programs = restored.edge_dialect_program
        assert_is_method_dict(restored_programs)

        # Every method must map to a program object ...
        for name in method_names:
            self.assertIsNotNone(restored_programs[name])
        # ... and each program must carry a graph module.
        for name in method_names:
            self.assertIsNotNone(restored_programs[name].graph_module)

        # The remaining ETRecord components must be preserved as well.
        self.assertIsNotNone(restored._debug_handle_map)
        self.assertIsNotNone(restored._delegate_map)
Loading