Skip to content

Commit

Permalink
executorch/exir/program/test (#7397)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #7397

Reviewed By: avikchaudhuri, ydwu4

Differential Revision: D67383235
  • Loading branch information
gmagogsfm authored and facebook-github-bot committed Dec 19, 2024
1 parent f341da8 commit e824a31
Show file tree
Hide file tree
Showing 72 changed files with 396 additions and 587 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def main() -> None:
torch.randn((1, embedding_dim)),
torch.tensor([0]),
)
exported_model = export(model, example_inputs)
exported_model = export(model, example_inputs, strict=True)
edge_program_manager = exir.to_edge(exported_model)
compile_specs = CoreMLBackend.generate_compile_specs(
compute_precision=ct.precision.FLOAT16,
Expand Down
7 changes: 3 additions & 4 deletions backends/apple/coreml/test/test_coreml_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@


class TestCoreMLPartitioner(unittest.TestCase):

# TODO(T182928844): Delegate dim order op to backend.
edge_compile_config = executorch.exir.EdgeCompileConfig(_skip_dim_order=True)

Expand All @@ -34,7 +33,7 @@ def forward(self, a, x, b):
model.eval()

example_inputs = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
exir_program_aten = torch.export.export(model, example_inputs)
exir_program_aten = torch.export.export(model, example_inputs, strict=True)

edge_program_manager = executorch.exir.to_edge(
exir_program_aten, compile_config=self.edge_compile_config
Expand All @@ -61,7 +60,7 @@ def test_vit_skip_conv(self):
model.eval()

example_inputs = (torch.randn(1, 3, 224, 224),)
exir_program_aten = torch.export.export(model, example_inputs)
exir_program_aten = torch.export.export(model, example_inputs, strict=True)
edge_program_manager = executorch.exir.to_edge(
exir_program_aten, compile_config=self.edge_compile_config
)
Expand Down Expand Up @@ -106,7 +105,7 @@ def forward(self, q, k_val, input_pos):
k_val = torch.randn((1, embedding_dim))
input_pos = torch.tensor([0])
example_inputs = (q, k_val, input_pos)
exir_program_aten = torch.export.export(model, example_inputs)
exir_program_aten = torch.export.export(model, example_inputs, strict=True)

compile_specs = CoreMLBackend.generate_compile_specs(
minimum_deployment_target=ct.target.iOS18
Expand Down
5 changes: 1 addition & 4 deletions backends/apple/mps/test/test_mps_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,10 +247,7 @@ def lower_module_and_test_output(
)

executorch_program = to_edge(
export(
delegated_program,
sample_inputs,
),
export(delegated_program, sample_inputs, strict=True),
compile_config=exir.EdgeCompileConfig(
_check_ir_validity=False,
_skip_dim_order=True, # TODO(T182928844): Delegate dim order op to backend.
Expand Down
2 changes: 1 addition & 1 deletion backends/cadence/aot/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def export_program(
torch._C._set_mkldnn_enabled(False)

# else: capture the model and return it.
expo_program = export(model, inputs)
expo_program = export(model, inputs, strict=True)

if dump_graphs:
logging.info("Exported graph:")
Expand Down
4 changes: 2 additions & 2 deletions backends/example/test_example_delegate.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def get_example_inputs():

quantized_gm = m
exported_program = to_edge(
export(quantized_gm, copy.deepcopy(example_inputs)),
export(quantized_gm, copy.deepcopy(example_inputs), strict=True),
compile_config=EDGE_COMPILE_CONFIG,
)

Expand Down Expand Up @@ -92,7 +92,7 @@ def test_delegate_mobilenet_v2(self):

quantized_gm = m
exported_program = to_edge(
export(quantized_gm, copy.deepcopy(example_inputs)),
export(quantized_gm, copy.deepcopy(example_inputs), strict=True),
compile_config=EDGE_COMPILE_CONFIG,
)

Expand Down
11 changes: 5 additions & 6 deletions backends/qualcomm/tests/test_qnn_delegate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1617,7 +1617,7 @@ def test_qnn_backend_multi_contexts_composite(self):
)
sample_input = module.get_random_input()
edge_prog = to_edge(
torch.export.export(module, sample_input),
torch.export.export(module, sample_input, strict=True),
)
update_spill_fill_size(edge_prog.exported_program())
exec_prog = edge_prog.to_executorch()
Expand Down Expand Up @@ -1957,7 +1957,7 @@ def calibrator(gm):
self.assertEqual(len(exported_progs), 1)
# lower all graph again, the skipped operators will be left in CPU
exec_prog = to_edge(
torch.export.export(graph_module, sample_input),
torch.export.export(graph_module, sample_input, strict=True),
).to_executorch()
self.verify_output(module, sample_input, exec_prog)

Expand Down Expand Up @@ -2004,7 +2004,7 @@ def calibrator(gm):
self.assertEqual(len(exported_progs), 2)
# lower all graph again, the skipped operators will be left in CPU
exec_prog = exec_prog = to_edge(
torch.export.export(graph_module, sample_input),
torch.export.export(graph_module, sample_input, strict=True),
).to_executorch()
self.verify_output(module, sample_input, exec_prog)

Expand Down Expand Up @@ -2041,7 +2041,7 @@ def calibrator(gm):
self.assertEqual(len(exported_progs), 5)
# lower all graph again, the skipped operators will be delegated with fp16
exec_prog = to_edge(
torch.export.export(graph_module, sample_input),
torch.export.export(graph_module, sample_input, strict=True),
).to_executorch()
self.verify_output(module, sample_input, exec_prog)

Expand Down Expand Up @@ -2086,7 +2086,7 @@ def test_qnn_backend_multi_contexts_composite(self):
)
sample_input = module.get_random_input()
edge_prog = to_edge(
torch.export.export(module, sample_input),
torch.export.export(module, sample_input, strict=True),
)
update_spill_fill_size(edge_prog.exported_program())
exec_prog = edge_prog.to_executorch()
Expand Down Expand Up @@ -2721,7 +2721,6 @@ def test_ssd300_vgg16(self):


class TestExampleQaihubScript(TestQNN):

def required_envs(self, conditions=None) -> bool:
conditions = [] if conditions is None else conditions
return all(
Expand Down
2 changes: 1 addition & 1 deletion backends/qualcomm/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ def get_qdq_module(
custom_quant_annotations: Tuple[Callable] = (),
quant_dtype: QuantDtype = QuantDtype.use_8a8w,
) -> torch.fx.GraphModule:
m = torch.export.export(module, inputs).module()
m = torch.export.export(module, inputs, strict=True).module()

quantizer = QnnQuantizer()
quantizer.add_custom_quant_annotations(custom_quant_annotations)
Expand Down
6 changes: 3 additions & 3 deletions backends/qualcomm/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ def capture_program(
inputs: Tuple[torch.Tensor],
custom_pass_config: FrozenSet[str] = frozenset(),
) -> exir.ExirExportedProgram:
ep = torch.export.export(module, inputs)
ep = torch.export.export(module, inputs, strict=True)
decomposed_ep = ep.run_decompositions(get_decomp_table())
# We choose call_operator by target in ConvertBinaryOpsWithScalar
# because it is the same source_fn_stack for MultiheadAttention
Expand Down Expand Up @@ -551,7 +551,7 @@ def prepare_subgm(subgm, subgm_name):

fp_node_id_set = fp_node_id_set if fp_node_id_set is not None else set()
fp_node_op_set = fp_node_op_set if fp_node_op_set is not None else set()
graph_module = torch.export.export(nn_module, sample_input).module()
graph_module = torch.export.export(nn_module, sample_input, strict=True).module()
# define node support type
capability_partitioner = CapabilityBasedPartitioner(
graph_module,
Expand Down Expand Up @@ -664,7 +664,7 @@ def forward(self, *inputs):
).default(inputs)

model = Model()
prog = torch.export.export(model, tuple(inputs.values()))
prog = torch.export.export(model, tuple(inputs.values()), strict=True)
# bookkeeping for variables' life cycle
return {
"custom_op": custom_op,
Expand Down
2 changes: 1 addition & 1 deletion backends/vulkan/test/test_vulkan_delegate.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def run_test():
model(*sample_inputs)

program: ExportedProgram = export(
model, sample_inputs, dynamic_shapes=dynamic_shapes
model, sample_inputs, dynamic_shapes=dynamic_shapes, strict=True
)

edge_program = to_edge_transform_and_lower(
Expand Down
1 change: 1 addition & 0 deletions backends/xnnpack/partition/graphs/sdpa.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def forward(
v,
mask,
),
strict=True,
),
compile_config=get_xnnpack_edge_compile_config(),
)
Expand Down
2 changes: 1 addition & 1 deletion backends/xnnpack/test/tester/tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def run(
inputs: Tuple[torch.Tensor],
) -> None:
self.exported_program = export(
artifact, inputs, dynamic_shapes=self.dynamic_shapes
artifact, inputs, dynamic_shapes=self.dynamic_shapes, strict=True
)

@property
Expand Down
2 changes: 1 addition & 1 deletion build/packaging/smoke_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def export_linear_model() -> bytes:

# Export the pytorch model and process for ExecuTorch.
print("Exporting program...")
exported_program = export(LinearModel(), example_inputs)
exported_program = export(LinearModel(), example_inputs, strict=True)
print("Lowering to edge...")
edge_program = to_edge(exported_program)
print("Creating ExecuTorch program...")
Expand Down
2 changes: 1 addition & 1 deletion devtools/backend_debug/tests/test_delegation_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def forward(self, a, x, b):

m = Model()
inputs = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
edge = to_edge(torch.export.export(m, inputs)).to_backend(
edge = to_edge(torch.export.export(m, inputs, strict=True)).to_backend(
AddMulPartitionerDemo()
)
delegation_info = get_delegation_info(edge.exported_program().graph_module)
Expand Down
1 change: 1 addition & 0 deletions devtools/bundled_program/util/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,7 @@ def get_common_executorch_program() -> (
m_name: export(
StatefulWrapperModule(eager_model, getattr(eager_model, m_name)),
capture_inputs[m_name],
strict=True,
)
for m_name in eager_model.method_names
}
Expand Down
2 changes: 1 addition & 1 deletion devtools/etrecord/tests/etrecord_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def get_test_model_with_bundled_program(self):

def get_test_model_with_manager(self):
f = models.BasicSinMax()
aten_dialect = export(f, f.get_random_inputs())
aten_dialect = export(f, f.get_random_inputs(), strict=True)
edge_program: EdgeProgramManager = to_edge(
aten_dialect, compile_config=EdgeCompileConfig(_check_ir_validity=False)
)
Expand Down
7 changes: 2 additions & 5 deletions docs/source/tutorials_source/devtools-integration-tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,7 @@ def forward(self, x):

model = Net()

aten_model: ExportedProgram = export(
model,
(torch.randn(1, 1, 32, 32),),
)
aten_model: ExportedProgram = export(model, (torch.randn(1, 1, 32, 32),), strict=True)

edge_program_manager: EdgeProgramManager = to_edge(
aten_model, compile_config=EdgeCompileConfig(_check_ir_validity=True)
Expand Down Expand Up @@ -141,7 +138,7 @@ def forward(self, x):

# Step 1: ExecuTorch Program Export
m_name = "forward"
method_graphs = {m_name: export(model, (torch.randn(1, 1, 32, 32),))}
method_graphs = {m_name: export(model, (torch.randn(1, 1, 32, 32),), strict=True)}

# Step 2: Construct Method Test Suites
inputs = [[torch.randn(1, 1, 32, 32)] for _ in range(2)]
Expand Down
28 changes: 15 additions & 13 deletions docs/source/tutorials_source/export-to-executorch-tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


example_args = (torch.randn(1, 3, 256, 256),)
aten_dialect: ExportedProgram = export(SimpleConv(), example_args)
aten_dialect: ExportedProgram = export(SimpleConv(), example_args, strict=True)
print(aten_dialect)

######################################################################
Expand Down Expand Up @@ -101,7 +101,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:


example_args = (torch.randn(3, 3), torch.randn(3, 3))
aten_dialect: ExportedProgram = export(Basic(), example_args)
aten_dialect: ExportedProgram = export(Basic(), example_args, strict=True)

# Works correctly
print(aten_dialect.module()(torch.ones(3, 3), torch.ones(3, 3)))
Expand Down Expand Up @@ -131,7 +131,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
dim1_x = Dim("dim1_x", min=1, max=10)
dynamic_shapes = {"x": {1: dim1_x}, "y": {1: dim1_x}}
aten_dialect: ExportedProgram = export(
Basic(), example_args, dynamic_shapes=dynamic_shapes
Basic(), example_args, dynamic_shapes=dynamic_shapes, strict=True
)
print(aten_dialect)

Expand Down Expand Up @@ -213,7 +213,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
print("Quantized Graph")
print(converted_graph)

aten_dialect: ExportedProgram = export(converted_graph, example_args)
aten_dialect: ExportedProgram = export(converted_graph, example_args, strict=True)
print("ATen Dialect Graph")
print(aten_dialect)

Expand Down Expand Up @@ -243,7 +243,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
from executorch.exir import EdgeProgramManager, to_edge

example_args = (torch.randn(1, 3, 256, 256),)
aten_dialect: ExportedProgram = export(SimpleConv(), example_args)
aten_dialect: ExportedProgram = export(SimpleConv(), example_args, strict=True)

edge_program: EdgeProgramManager = to_edge(aten_dialect)
print("Edge Dialect Graph")
Expand All @@ -267,10 +267,10 @@ def forward(self, x):


encode_args = (torch.randn(1, 10),)
aten_encode: ExportedProgram = export(Encode(), encode_args)
aten_encode: ExportedProgram = export(Encode(), encode_args, strict=True)

decode_args = (torch.randn(1, 5),)
aten_decode: ExportedProgram = export(Decode(), decode_args)
aten_decode: ExportedProgram = export(Decode(), decode_args, strict=True)

edge_program: EdgeProgramManager = to_edge(
{"encode": aten_encode, "decode": aten_decode}
Expand All @@ -291,7 +291,7 @@ def forward(self, x):
# rather than the ``torch.ops.aten`` namespace.

example_args = (torch.randn(1, 3, 256, 256),)
aten_dialect: ExportedProgram = export(SimpleConv(), example_args)
aten_dialect: ExportedProgram = export(SimpleConv(), example_args, strict=True)
edge_program: EdgeProgramManager = to_edge(aten_dialect)
print("Edge Dialect Graph")
print(edge_program.exported_program())
Expand Down Expand Up @@ -357,7 +357,7 @@ def forward(self, x):

# Export and lower the module to Edge Dialect
example_args = (torch.ones(1),)
aten_dialect: ExportedProgram = export(LowerableModule(), example_args)
aten_dialect: ExportedProgram = export(LowerableModule(), example_args, strict=True)
edge_program: EdgeProgramManager = to_edge(aten_dialect)
to_be_lowered_module = edge_program.exported_program()

Expand Down Expand Up @@ -423,7 +423,7 @@ def forward(self, x):


example_args = (torch.ones(1),)
aten_dialect: ExportedProgram = export(ComposedModule(), example_args)
aten_dialect: ExportedProgram = export(ComposedModule(), example_args, strict=True)
edge_program: EdgeProgramManager = to_edge(aten_dialect)
exported_program = edge_program.exported_program()
print("Edge Dialect graph")
Expand Down Expand Up @@ -461,7 +461,7 @@ def forward(self, a, x, b):


example_args = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
aten_dialect: ExportedProgram = export(Foo(), example_args)
aten_dialect: ExportedProgram = export(Foo(), example_args, strict=True)
edge_program: EdgeProgramManager = to_edge(aten_dialect)
exported_program = edge_program.exported_program()
print("Edge Dialect graph")
Expand Down Expand Up @@ -495,7 +495,7 @@ def forward(self, a, x, b):


example_args = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
aten_dialect: ExportedProgram = export(Foo(), example_args)
aten_dialect: ExportedProgram = export(Foo(), example_args, strict=True)
edge_program: EdgeProgramManager = to_edge(aten_dialect)
exported_program = edge_program.exported_program()
delegated_program = edge_program.to_backend(AddMulPartitionerDemo())
Expand Down Expand Up @@ -577,7 +577,9 @@ def forward(self, x):
pre_autograd_aten_dialect = export_for_training(M(), example_args).module()
# Optionally do quantization:
# pre_autograd_aten_dialect = convert_pt2e(prepare_pt2e(pre_autograd_aten_dialect, CustomBackendQuantizer))
aten_dialect: ExportedProgram = export(pre_autograd_aten_dialect, example_args)
aten_dialect: ExportedProgram = export(
pre_autograd_aten_dialect, example_args, strict=True
)
edge_program: exir.EdgeProgramManager = exir.to_edge(aten_dialect)
# Optionally do delegation:
# edge_program = edge_program.to_backend(CustomBackendPartitioner)
Expand Down
9 changes: 6 additions & 3 deletions examples/apple/coreml/scripts/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,9 @@ def partition_module_to_coreml(module):

def lower_module_to_coreml(module, compile_specs, example_inputs):
module = module.eval()
edge = to_edge(export(module, example_inputs), compile_config=_EDGE_COMPILE_CONFIG)
edge = to_edge(
export(module, example_inputs, strict=True), compile_config=_EDGE_COMPILE_CONFIG
)
# All of the subsequent calls on the edge_dialect_graph generated above (such as delegation or
# to_executorch()) are done in place and the graph is also modified in place. For debugging purposes
# we would like to keep a copy of the original edge dialect graph and hence we create a deepcopy of
Expand All @@ -107,7 +109,8 @@ def lower_module_to_coreml(module, compile_specs, example_inputs):
def export_lowered_module_to_executorch_program(lowered_module, example_inputs):
lowered_module(*example_inputs)
exec_prog = to_edge(
export(lowered_module, example_inputs), compile_config=_EDGE_COMPILE_CONFIG
export(lowered_module, example_inputs, strict=True),
compile_config=_EDGE_COMPILE_CONFIG,
).to_executorch(config=exir.ExecutorchBackendConfig(extract_delegate_segments=True))

return exec_prog
Expand Down Expand Up @@ -170,7 +173,7 @@ def main():

if args.use_partitioner:
model.eval()
exir_program_aten = torch.export.export(model, example_inputs)
exir_program_aten = torch.export.export(model, example_inputs, strict=True)

edge_program_manager = exir.to_edge(exir_program_aten)
edge_copy = copy.deepcopy(edge_program_manager)
Expand Down
Loading

0 comments on commit e824a31

Please sign in to comment.