diff --git a/backends/cadence/aot/TARGETS b/backends/cadence/aot/TARGETS
index 5cc00fa5ab..906b20bca0 100644
--- a/backends/cadence/aot/TARGETS
+++ b/backends/cadence/aot/TARGETS
@@ -40,6 +40,7 @@ python_library(
         ":utils",
         ":ops_registrations",
         ":replace_ops",
+        ":memory_planning",
         "//caffe2:torch",
         "//executorch/backends/cadence/aot/quantizer:fusion_pass",
         "//executorch/backends/cadence/aot/quantizer:quantizer",
diff --git a/backends/cadence/aot/compiler.py b/backends/cadence/aot/compiler.py
index 2b05e30f4c..bc854337a8 100644
--- a/backends/cadence/aot/compiler.py
+++ b/backends/cadence/aot/compiler.py
@@ -12,11 +12,20 @@
 import executorch.backends.cadence.aot.ops_registrations  # noqa
 import torch
 
+from executorch.backends.cadence.aot.memory_planning import (
+    CadenceMemoryPlanning,
+    print_memory_planning_info,
+)
 from executorch.backends.cadence.aot.quantizer.fusion_pass import QuantFusion
 from executorch.backends.cadence.aot.quantizer.quantizer import CadenceQuantizer
 from executorch.backends.cadence.aot.replace_ops import ReplaceSafeSoftmaxWithSoftmax
-from executorch.backends.cadence.aot.utils import model_gm_has_SDPA, model_is_quantized
+from executorch.backends.cadence.aot.utils import (
+    get_default_memory_config,
+    MemoryConfig,
+    model_gm_has_SDPA,
+    model_is_quantized,
+)
 from executorch.backends.transforms.decompose_sdpa import (
     DecomposeScaledDotProductAttention,
 )
@@ -24,10 +33,13 @@
 from executorch.exir import (
     EdgeCompileConfig,
     EdgeProgramManager,
+    ExecutorchBackendConfig,
     ExecutorchProgramManager,
     to_edge,
 )
 from executorch.exir.pass_base import PassResult
+from executorch.exir.passes import ToOutVarPass
+from executorch.exir.passes.sym_shape_eval_pass import HintBasedSymShapeEvalPass
 from torch._inductor.decomposition import remove_decompositions
 from torch.ao.quantization.pt2e.export_utils import model_is_exported
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
@@ -263,6 +275,10 @@ def export_to_executorch_gen_etrecord(
     inputs: tuple[object, ...],
     output_dir: Optional[str] = None,
     opt_level: int = 1,
+    mem_algo: int = 0,
+    alloc_graph_input: bool = True,
+    alloc_graph_output: bool = True,
+    memory_config: Optional[MemoryConfig] = None,
     dump_graphs: bool = False,
 ) -> ExecutorchProgramManager:
     cadence_passes = get_cadence_passes(opt_level)
@@ -281,8 +297,36 @@ def export_to_executorch_gen_etrecord(
         cadence_prog_manager.exported_program().graph_module,
     )
 
+    if memory_config is None:
+        memory_config = get_default_memory_config()
+
+    memory_planning_pass = CadenceMemoryPlanning(
+        memory_config,
+        opt_level=opt_level,
+        mem_algo=mem_algo,
+        alloc_graph_input=alloc_graph_input,
+        alloc_graph_output=alloc_graph_output,
+    )
+
     # Get executorch program after Cadence specific passes
-    exec_prog: ExecutorchProgramManager = cadence_prog_manager.to_executorch()
+    exec_prog: ExecutorchProgramManager = cadence_prog_manager.to_executorch(
+        ExecutorchBackendConfig(
+            memory_planning_pass=memory_planning_pass,
+            emit_stacktrace=False,
+            to_out_var_pass=ToOutVarPass(),
+            extract_delegate_segments=False,
+            sym_shape_eval_pass=HintBasedSymShapeEvalPass(),
+        ),
+    )
+
+    print_memory_planning_info(
+        exec_prog,
+        memory_config,
+        opt_level,
+        alloc_graph_input,
+        alloc_graph_output,
+    )
+
     if output_dir:
         _gen_etrecord(edge_prog_manager, exec_prog, Path(output_dir))
     else: