Skip to content

Commit

Permalink
Adding tests for dynamic axes
Browse files Browse the repository at this point in the history
Signed-off-by: Boris Fomitchev <[email protected]>
  • Loading branch information
borisfom committed Nov 14, 2024
1 parent 22a4704 commit 06bc7ab
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 9 deletions.
13 changes: 8 additions & 5 deletions nemo/export/tensorrt_lazy_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import torch

from nemo.utils.export_utils import add_casts_around_norms, replace_for_export
from nemo.utils.import_utils import safe_import, safe_import_from
from nemo.utils.import_utils import safe_import

polygraphy, polygraphy_imported = safe_import("polygraphy")
if polygraphy_imported:
Expand Down Expand Up @@ -596,17 +596,20 @@ def add_profile(id, val):

# Use temporary directory for easy cleanup in case of external weights
with tempfile.TemporaryDirectory() as tmpdir:
unrolled_input = unroll_input(self.input_names, input_example)
if export_args.get("dynamo", False):
input_names = None
else:
input_names = list(unroll_input(self.input_names, input_example).keys())
onnx_path = str(Path(tmpdir) / "model.onnx")
self.logger.info(
f"Exporting to {onnx_path}:\nunrolled_inputs={list(unrolled_input.keys())}\n"
f"Exporting to {onnx_path}:\n"
+ f"output_names={self.output_names}\ninput_names={self.input_names}\nexport args: {export_args}"
)
torch.onnx.export(
model,
input_example,
(input_example,),
onnx_path,
input_names=list(unrolled_input.keys()),
input_names=input_names,
output_names=self.output_names,
**export_args,
)
Expand Down
48 changes: 44 additions & 4 deletions tests/export/test_trt_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,9 @@
from typing import List

import torch
from parameterized import parameterized

from nemo.export import trt_compile
from nemo.utils.import_utils import safe_import, safe_import_from
from nemo.utils.import_utils import safe_import

trt, trt_imported = safe_import("tensorrt")
torch_tensorrt, torch_trt_imported = safe_import("torch_tensorrt")
Expand Down Expand Up @@ -67,8 +66,11 @@ def test_torch_trt(self):
x = torch.randn(1, 16).to("cuda")

with tempfile.TemporaryDirectory() as tempdir:
args = {"method": "torch_trt"}
input_example = x
args = {
"method": "torch_trt",
"dynamic_batchsize": [1, 4, 8],
}
input_example = (x,)
output_example = model(*input_example)
trt_compile(
model,
Expand All @@ -81,11 +83,49 @@ def test_torch_trt(self):
self.assertIsNotNone(model._trt_compiler.engine)
torch.testing.assert_close(trt_output, output_example, rtol=0.01, atol=0.01)

def test_profiles(self):
model = ListAdd().cuda()

with torch.no_grad(), tempfile.TemporaryDirectory() as tmpdir:
args = {
"export_args": {
"dynamo": False,
},
"input_profiles": [
{
"x_0": [[1, 8], [2, 16], [2, 32]],
"x_1": [[1, 8], [2, 16], [2, 32]],
"x_2": [[1, 8], [2, 16], [2, 32]],
"y": [[1, 8], [2, 16], [2, 32]],
"z": [[1, 8], [1, 16], [1, 32]],
}
],
"output_lists": [[-1], [2], []],
}
x = torch.randn(1, 16).to("cuda")
y = torch.randn(1, 16).to("cuda")
z = torch.randn(1, 16).to("cuda")
input_example = ([x, y, z], y.clone(), z.clone())
output_example = model(*input_example)
trt_compile(
model,
f"{tmpdir}/test_dynamo_trt",
args=args,
)
self.assertIsNone(model._trt_compiler.engine)
trt_output = model(*input_example)
# Check that lazy TRT build succeeded
self.assertIsNotNone(model._trt_compiler.engine)
torch.testing.assert_close(trt_output, output_example, rtol=0.01, atol=0.01)

def test_lists(self):
model = ListAdd().cuda()

with torch.no_grad(), tempfile.TemporaryDirectory() as tmpdir:
args = {
"export_args": {
"dynamo": True,
},
"output_lists": [[-1], [2], []],
}
x = torch.randn(1, 16).to("cuda")
Expand Down

0 comments on commit 06bc7ab

Please sign in to comment.