Commit

add test
grimoire committed Feb 7, 2024
1 parent 91fc08b commit bafbdce
Showing 13 changed files with 244 additions and 90 deletions.
2 changes: 1 addition & 1 deletion .isort.cfg
@@ -1,2 +1,2 @@
[settings]
known_third_party = distutils,graphviz,mmdeploy,numpy,packaging,setuptools,tensorrt,torch,torchvision
known_third_party = distutils,graphviz,numpy,packaging,setuptools,tensorrt,torch,torchvision
6 changes: 4 additions & 2 deletions tests/test_convert.py
@@ -8,7 +8,7 @@ def test_convert(tmp_path):

    trt_model = module2trt(
        model,
        args=[torch.rand(1, 3, 224, 224).cuda()],
        args=[torch.rand(1, 3, 32, 32).cuda()],
    )

    model_path = tmp_path / 'tmp.pth'
@@ -18,9 +18,11 @@ def test_convert(tmp_path):
    trt_model = TRTModule()
    trt_model.load_state_dict(torch.load(model_path))

    x = torch.rand(1, 3, 224, 224).cuda()
    x = torch.rand(1, 3, 32, 32).cuda()
    with torch.no_grad():
        y = model(x)
        y_trt = trt_model(x)

    print(y)
    print(y_trt)
    torch.testing.assert_close(y, y_trt)
158 changes: 158 additions & 0 deletions tests/test_converters/test_grid_sample.py
@@ -0,0 +1,158 @@
import pytest
import torch
from torch import nn
from torch.nn import functional as F
from torch2trt_dynamic import BuildEngineConfig, module2trt


class _TestModel(nn.Module):
    def __init__(self, mode, padding_mode, align_corners) -> None:
        super().__init__()
        self.mode = mode
        self.padding_mode = padding_mode
        self.align_corners = align_corners

    def forward(self, input, grid):
        return F.grid_sample(
            input,
            grid,
            mode=self.mode,
            padding_mode=self.padding_mode,
            align_corners=self.align_corners)


class TestGridSample:

    @pytest.fixture
    def hw_in(self, request):
        yield request.param

    @pytest.fixture
    def hw_out(self, request):
        yield request.param

    @pytest.fixture
    def mode(self, request):
        yield request.param

    @pytest.fixture
    def padding_mode(self, request):
        yield request.param

    @pytest.fixture
    def align_corners(self, request):
        yield request.param

    @pytest.fixture
    def batch(self):
        yield 2

    @pytest.fixture
    def channel(self):
        yield 4

    @pytest.fixture
    def deep_in(self):
        yield 4

    @pytest.fixture
    def deep_out(self):
        yield 2

    @pytest.fixture
    def input4d(self, batch, channel, hw_in):
        yield torch.rand(batch, channel, *hw_in).cuda()

    @pytest.fixture
    def input5d(self, batch, channel, deep_in, hw_in):
        yield torch.rand(batch, channel, deep_in, *hw_in).cuda()

    @pytest.fixture
    def grid4d(self, batch, hw_out):
        lin_w = torch.linspace(-1, 1, hw_out[1])[:, None].repeat(1, hw_out[0])
        lin_h = torch.linspace(-1, 1, hw_out[0]).repeat(hw_out[1], 1)
        grid = torch.stack([lin_w, lin_h], dim=-1)
        grid = grid[None].repeat(batch, 1, 1, 1)
        yield grid.cuda()

    @pytest.fixture
    def grid5d(self, batch, deep_out, hw_out):
        lin_d = torch.linspace(
            -1, 1, deep_out)[:, None, None].repeat(1, hw_out[1], hw_out[0])
        lin_w = torch.linspace(
            -1, 1, hw_out[1])[None, :, None].repeat(deep_out, 1, hw_out[0])
        lin_h = torch.linspace(
            -1, 1, hw_out[0])[None, None, :].repeat(deep_out, hw_out[1], 1)
        grid = torch.stack([lin_w, lin_h, lin_d], dim=-1)
        grid = grid[None].repeat(batch, 1, 1, 1, 1)
        yield grid.cuda()

    @pytest.fixture
    def model(self, mode, padding_mode, align_corners):
        kwargs = dict(
            mode=mode,
            padding_mode=padding_mode,
            align_corners=align_corners)
        yield _TestModel(**kwargs)

    def make_config(self, input, grid):
        input_shape = tuple(input.shape)
        input_post = input_shape[2:]
        input_post_max = [x * 2 for x in input_post]
        input_post_min = [x // 2 for x in input_post]
        input_max = (*input_shape[:2], *input_post_max)
        input_min = (*input_shape[:2], *input_post_min)
        grid_shape = tuple(grid.shape)
        grid_post = grid_shape[1:-1]
        grid_post_max = [x * 2 for x in grid_post]
        grid_post_min = [x // 2 for x in grid_post]
        grid_max = (grid_shape[0], *grid_post_max, grid_shape[-1])
        grid_min = (grid_shape[0], *grid_post_min, grid_shape[-1])
        config = BuildEngineConfig(
            shape_ranges=dict(
                input=dict(min=input_min, opt=input_shape, max=input_max),
                grid=dict(min=grid_min, opt=grid_shape, max=grid_max)))
        return config

    @pytest.mark.parametrize("hw_in,hw_out", [
        ((8, 16), (16, 32)),
        ((16, 32), (8, 16)),
    ])
    @pytest.mark.parametrize('mode', ['bilinear', 'nearest', 'bicubic'])
    @pytest.mark.parametrize('padding_mode', ['zeros', 'border', 'reflection'])
    @pytest.mark.parametrize('align_corners', [True, False])
    def test_grid_sample_4d(self, input4d, grid4d, model):
        dummy_input = torch.zeros_like(input4d)
        dummy_grid = torch.zeros_like(grid4d)
        config = self.make_config(dummy_input, dummy_grid)
        trt_model = module2trt(
            model, args=[dummy_input, dummy_grid], config=config)

        args = [input4d, grid4d]
        with torch.inference_mode():
            gt = model(*args)
        out = trt_model(*args)
        torch.testing.assert_close(out, gt)
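A note on the fixture pattern above: the fixtures that only yield request.param are placeholders. pytest lets @pytest.mark.parametrize override a fixture of the same name anywhere in the test's fixture closure, so the parametrized values reach model, input4d, and grid4d without indirect=True. A minimal sketch of the same mechanism (the names here are illustrative, not part of the commit):

import pytest


@pytest.fixture
def mode(request):
    # Placeholder: the parametrize mark below supplies the value.
    yield request.param


@pytest.fixture
def model(mode):
    # Fixtures downstream of 'mode' see the parametrized value.
    yield {'mode': mode}


@pytest.mark.parametrize('mode', ['bilinear', 'nearest'])
def test_mode(model):
    assert model['mode'] in ('bilinear', 'nearest')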
41 changes: 41 additions & 0 deletions tests/test_converters/test_group_norm.py
@@ -0,0 +1,41 @@
import pytest
import torch
from torch import nn
from torch2trt_dynamic import module2trt


class _TestModel(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__()
        self.gn = nn.GroupNorm(*args, **kwargs)

    def forward(self, input):
        return self.gn(input)


class TestGroupNorm:

    @pytest.fixture
    def num_channels(self):
        yield 4

    @pytest.fixture
    def input(self, num_channels):
        yield torch.rand(2, num_channels, 8, 16).cuda()

    @pytest.fixture
    def num_groups(self):
        yield 2

    def test_group_norm(self, input, num_groups):
        num_channels = input.size(1)
        model = _TestModel(num_groups, num_channels)
        model = model.eval().cuda()
        dummy_input = torch.zeros_like(input)
        trt_model = module2trt(model, args=[dummy_input])

        with torch.inference_mode():
            gt = model(input)
        out = trt_model(input)
        torch.testing.assert_close(out, gt)
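For context on what the GroupNorm converter must reproduce: nn.GroupNorm normalizes each group of channels by that group's own mean and variance, then applies per-channel affine parameters. A small reference implementation, handy when chasing converter mismatches (a sketch of the math, not torch2trt_dynamic code; eps matches the PyTorch default):

import torch
from torch import nn


def manual_group_norm(x, num_groups, weight, bias, eps=1e-5):
    # (N, C, H, W) -> (N, G, C//G * H * W): statistics are per group.
    n, c = x.shape[:2]
    g = x.reshape(n, num_groups, -1)
    mean = g.mean(dim=-1, keepdim=True)
    var = g.var(dim=-1, unbiased=False, keepdim=True)
    x = ((g - mean) / torch.sqrt(var + eps)).reshape(x.shape)
    # Affine parameters are applied per channel, not per group.
    return x * weight.view(1, c, 1, 1) + bias.view(1, c, 1, 1)


gn = nn.GroupNorm(num_groups=2, num_channels=4)
inp = torch.rand(2, 4, 8, 16)
torch.testing.assert_close(
    gn(inp), manual_group_norm(inp, 2, gn.weight, gn.bias))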
32 changes: 8 additions & 24 deletions torch2trt_dynamic/converters/__init__.py
@@ -7,7 +7,14 @@

# supported converters will override dummy converters

from . import AdaptiveAvgPool2d  # noqa: F401
from . import AdaptiveMaxPool2d  # noqa: F401
from . import adaptive_avg_pool1d  # noqa: F401
from . import adaptive_avg_pool2d  # noqa: F401
from . import adaptive_max_pool1d  # noqa: F401
from . import adaptive_max_pool2d  # noqa: F401
from . import add  # noqa: F401
from . import grid_sample  # noqa: F401
from .activation import (convert_elu, convert_leaky_relu, convert_selu,
                         convert_softplus, convert_softsign)
from .addcmul import convert_addcmul, test_addcmul
@@ -327,12 +334,6 @@

try:
    # custom plugin support
    from .adaptive_avg_pool1d import convert_adaptive_avg_pool1d
    from .adaptive_max_pool1d import convert_adaptive_max_pool1d
    from .adaptive_avg_pool2d import convert_adaptive_avg_pool2d
    from .adaptive_max_pool2d import convert_adaptive_max_pool2d
    from .AdaptiveAvgPool2d import convert_AdaptiveAvgPool2d
    from .AdaptiveMaxPool2d import convert_AdaptiveMaxPool2d
    from .bmm import convert_bmm
    from .cummax import convert_cummax
    from .cummin import convert_cummin
@@ -341,25 +342,12 @@
    from .deform_conv2d import convert_deform_conv2d
    from .Embedding import convert_embedding, convert_embedding_forward
    from .gather import convert_gather
    from .grid_sample import convert_grid_sample
    from .GroupNorm import convert_GroupNorm
    from . import GroupNorm  # noqa: F401
    from .nms import convert_nms
    from .roi_align import convert_roi_align, convert_RoiAlign
    from .roi_pool import convert_roi_pool, convert_RoIPool
    from .unfold import convert_unfold

    # adaptive_avg_pool1d
    __all__ += ['convert_adaptive_avg_pool1d']
    # adaptive_max_pool1d
    __all__ += ['convert_adaptive_max_pool1d']
    # adaptive_avg_pool2d
    __all__ += ['convert_adaptive_avg_pool2d']
    # adaptive_max_pool2d
    __all__ += ['convert_adaptive_max_pool2d']
    # AdaptiveAvgPool2d
    __all__ += ['convert_AdaptiveAvgPool2d']
    # AdaptiveMaxPool2d
    __all__ += ['convert_AdaptiveMaxPool2d']
    # bmm
    __all__ += ['convert_bmm']
    # cummax
@@ -376,10 +364,6 @@
    __all__ += ['convert_embedding', 'convert_embedding_forward']
    # gather
    __all__ += ['convert_gather']
    # grid_sample
    __all__ += ['convert_grid_sample']
    # GroupNorm
    __all__ += ['convert_GroupNorm']
    # nms
    __all__ += ['convert_nms']
    # roi_align
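The import reshuffle above leans on the note at the top of this file: supported converters override dummy converters. Registration is last-write-wins, so importing the real converter module after the dummy pass replaces the placeholder entry. A sketch of how such a registry behaves (the dict and decorator body are illustrative assumptions, not the actual torch2trt_dynamic internals):

CONVERTERS = {}


def tensorrt_converter(qualname):
    # Later registrations for the same qualified name overwrite
    # earlier (dummy) entries.
    def register(fn):
        CONVERTERS[qualname] = fn
        return fn
    return register


@tensorrt_converter('torch.nn.functional.grid_sample')
def dummy_grid_sample(ctx):
    raise NotImplementedError


@tensorrt_converter('torch.nn.functional.grid_sample')
def convert_grid_sample(ctx):
    pass  # the real converter wins because it registers last


assert CONVERTERS['torch.nn.functional.grid_sample'] is convert_grid_sample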
2 changes: 1 addition & 1 deletion torch2trt_dynamic/converters/adaptive_avg_pool1d.py
@@ -1,5 +1,4 @@
import tensorrt as trt
from torch2trt_dynamic.plugins import create_adaptivepool_plugin
from torch2trt_dynamic.torch2trt_dynamic import (get_arg, tensorrt_converter,
                                                 trt_)

@@ -20,6 +19,7 @@ def convert_adaptive_avg_pool1d(ctx):
                                       axes, keepdim)
        output._trt = layer.get_output(0)
    else:
        from torch2trt_dynamic.plugins import create_adaptivepool_plugin
        output_size = (output_size, 1)

        # input.unsqueeze(-1)
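The same one-line move recurs in the pooling converters below: the create_adaptivepool_plugin import now sits inside the else branch, so the compiled plugin package is required only when the plugin fallback actually runs, and a plugin-free install can still import the converter module. A sketch of the deferred-import pattern (the fast-path condition is a stand-in):

def convert_adaptive_pool(output_size):
    if output_size == 1:
        # Fast path: global pooling maps onto a native reduce layer,
        # no plugin involved.
        return 'reduce'
    # Deferred import: torch2trt_dynamic.plugins must exist only if
    # this branch executes.
    from torch2trt_dynamic.plugins import create_adaptivepool_plugin
    return create_adaptivepool_plugin

With this shape, convert_adaptive_pool(1) succeeds even in an environment where torch2trt_dynamic.plugins cannot be imported.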
2 changes: 1 addition & 1 deletion torch2trt_dynamic/converters/adaptive_avg_pool2d.py
@@ -1,5 +1,4 @@
import tensorrt as trt
from torch2trt_dynamic.plugins import create_adaptivepool_plugin
from torch2trt_dynamic.torch2trt_dynamic import (get_arg, tensorrt_converter,
                                                 trt_)

@@ -25,6 +24,7 @@ def convert_adaptive_avg_pool2d(ctx):
                                       axes, keepdim)
        output._trt = layer.get_output(0)
    else:
        from torch2trt_dynamic.plugins import create_adaptivepool_plugin
        plugin = create_adaptivepool_plugin(
            'adaptive_avg_pool2d_' + str(id(input)),
            output_size=output_size,
2 changes: 1 addition & 1 deletion torch2trt_dynamic/converters/adaptive_max_pool1d.py
@@ -1,5 +1,4 @@
import tensorrt as trt
from torch2trt_dynamic.plugins import create_adaptivepool_plugin
from torch2trt_dynamic.torch2trt_dynamic import (get_arg, tensorrt_converter,
                                                 trt_)

@@ -20,6 +19,7 @@ def convert_adaptive_max_pool1d(ctx):
                                       axes, keepdim)
        output._trt = layer.get_output(0)
    else:
        from torch2trt_dynamic.plugins import create_adaptivepool_plugin
        output_size = (output_size, 1)

        # input.unsqueeze(-1)
38 changes: 17 additions & 21 deletions torch2trt_dynamic/converters/grid_sample.py
@@ -1,8 +1,17 @@
import tensorrt as trt

from ..plugins import create_gridsample_plugin
from ..torch2trt_dynamic import get_arg, tensorrt_converter, trt_

_MODE_MAP = dict(
    bilinear=trt.ResizeMode.LINEAR,
    nearest=trt.ResizeMode.NEAREST,
    bicubic=trt.ResizeMode.CUBIC)

_PAD_MODE_MAP = dict(
    zeros=trt.SampleMode.FILL,
    border=trt.SampleMode.CLAMP,
    reflection=trt.SampleMode.REFLECT)


@tensorrt_converter('torch.nn.functional.grid_sample')
def convert_grid_sample(ctx):
@@ -17,25 +26,12 @@ def convert_grid_sample(ctx):
    input_trt = trt_(ctx.network, input)
    grid_trt = trt_(ctx.network, grid)

    if mode == 'bilinear':
        mode = trt.ResizeMode.LINEAR
    elif mode == 'nearest':
        mode = trt.ResizeMode.NEAREST

    if padding_mode == 'zeros':
        padding_mode = 0
    elif padding_mode == 'border':
        padding_mode = 1
    elif padding_mode == 'reflection':
        padding_mode = 2

    plugin = create_gridsample_plugin(
        'torch_gridsample_' + str(id(input)),
        mode=mode,
        padding_mode=padding_mode,
        align_corners=align_corners)

    layer = ctx.network.add_plugin_v2(
        inputs=[input_trt, grid_trt], plugin=plugin)
    mode = _MODE_MAP[mode]
    padding_mode = _PAD_MODE_MAP[padding_mode]

    layer = ctx.network.add_grid_sample(input_trt, grid_trt)
    layer.interpolation_mode = mode
    layer.sample_mode = padding_mode
    layer.align_corners = align_corners

    output._trt = layer.get_output(0)
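The rewritten converter drops the custom plugin in favor of TensorRT's built-in grid-sample layer (network.add_grid_sample, i.e. IGridSampleLayer), mapping PyTorch's mode and padding_mode strings onto trt enums. As a reminder of the PyTorch convention being converted: the grid holds (x, y) coordinates normalized to [-1, 1] with shape (N, H_out, W_out, 2), and with align_corners=True an identity grid reproduces the input. A small self-check of that convention:

import torch
from torch.nn import functional as F

inp = torch.rand(1, 3, 8, 8)
# Identity grid: with align_corners=True, linspace(-1, 1, n) lands
# exactly on pixel centers.
ys, xs = torch.meshgrid(
    torch.linspace(-1, 1, 8), torch.linspace(-1, 1, 8), indexing='ij')
grid = torch.stack([xs, ys], dim=-1)[None]  # (1, 8, 8, 2), (x, y) order

out = F.grid_sample(
    inp, grid, mode='bilinear', padding_mode='zeros', align_corners=True)
torch.testing.assert_close(out, inp)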
