diff --git a/nemo/export/tensorrt_lazy_compiler.py b/nemo/export/tensorrt_lazy_compiler.py index b751ff45d74f..c7a3f0738a43 100644 --- a/nemo/export/tensorrt_lazy_compiler.py +++ b/nemo/export/tensorrt_lazy_compiler.py @@ -24,7 +24,7 @@ import torch from nemo.utils.export_utils import add_casts_around_norms, replace_for_export -from nemo.utils.import_utils import safe_import, safe_import_from +from nemo.utils.import_utils import safe_import polygraphy, polygraphy_imported = safe_import("polygraphy") if polygraphy_imported: @@ -596,17 +596,20 @@ def add_profile(id, val): # Use temporary directory for easy cleanup in case of external weights with tempfile.TemporaryDirectory() as tmpdir: - unrolled_input = unroll_input(self.input_names, input_example) + if export_args.get("dynamo", False): + input_names = None + else: + input_names = list(unroll_input(self.input_names, input_example).keys()) onnx_path = str(Path(tmpdir) / "model.onnx") self.logger.info( - f"Exporting to {onnx_path}:\nunrolled_inputs={list(unrolled_input.keys())}\n" + f"Exporting to {onnx_path}:\n" + f"output_names={self.output_names}\ninput_names={self.input_names}\nexport args: {export_args}" ) torch.onnx.export( model, - input_example, + (input_example,), onnx_path, - input_names=list(unrolled_input.keys()), + input_names=input_names, output_names=self.output_names, **export_args, ) diff --git a/tests/export/test_trt_compile.py b/tests/export/test_trt_compile.py index 758ab3f34437..95be5f88bf01 100644 --- a/tests/export/test_trt_compile.py +++ b/tests/export/test_trt_compile.py @@ -16,10 +16,9 @@ from typing import List import torch -from parameterized import parameterized from nemo.export import trt_compile -from nemo.utils.import_utils import safe_import, safe_import_from +from nemo.utils.import_utils import safe_import trt, trt_imported = safe_import("tensorrt") torch_tensorrt, torch_trt_imported = safe_import("torch_tensorrt") @@ -67,8 +66,11 @@ def test_torch_trt(self): x = torch.randn(1, 16).to("cuda") with tempfile.TemporaryDirectory() as tempdir: - args = {"method": "torch_trt"} - input_example = x + args = { + "method": "torch_trt", + "dynamic_batchsize": [1, 4, 8], + } + input_example = (x,) output_example = model(*input_example) trt_compile( model, @@ -81,11 +83,49 @@ def test_torch_trt(self): self.assertIsNotNone(model._trt_compiler.engine) torch.testing.assert_close(trt_output, output_example, rtol=0.01, atol=0.01) + def test_profiles(self): + model = ListAdd().cuda() + + with torch.no_grad(), tempfile.TemporaryDirectory() as tmpdir: + args = { + "export_args": { + "dynamo": False, + }, + "input_profiles": [ + { + "x_0": [[1, 8], [2, 16], [2, 32]], + "x_1": [[1, 8], [2, 16], [2, 32]], + "x_2": [[1, 8], [2, 16], [2, 32]], + "y": [[1, 8], [2, 16], [2, 32]], + "z": [[1, 8], [1, 16], [1, 32]], + } + ], + "output_lists": [[-1], [2], []], + } + x = torch.randn(1, 16).to("cuda") + y = torch.randn(1, 16).to("cuda") + z = torch.randn(1, 16).to("cuda") + input_example = ([x, y, z], y.clone(), z.clone()) + output_example = model(*input_example) + trt_compile( + model, + f"{tmpdir}/test_dynamo_trt", + args=args, + ) + self.assertIsNone(model._trt_compiler.engine) + trt_output = model(*input_example) + # Check that lazy TRT build succeeded + self.assertIsNotNone(model._trt_compiler.engine) + torch.testing.assert_close(trt_output, output_example, rtol=0.01, atol=0.01) + def test_lists(self): model = ListAdd().cuda() with torch.no_grad(), tempfile.TemporaryDirectory() as tmpdir: args = { + "export_args": { + "dynamo": True, + }, "output_lists": [[-1], [2], []], } x = torch.randn(1, 16).to("cuda")