Skip to content

Commit

Permalink
dynamo_export works for many small models
Browse files Browse the repository at this point in the history
Signed-off-by: Boris Fomitchev <[email protected]>
  • Loading branch information
borisfom committed May 16, 2024
1 parent d3c41f7 commit e9e81b0
Show file tree
Hide file tree
Showing 12 changed files with 134 additions and 40 deletions.
6 changes: 3 additions & 3 deletions nemo/collections/asr/models/asr_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def output_names(self):
return get_io_names(otypes, self.disabled_deployment_output_names)

def forward_for_export(
self, input, length=None, cache_last_channel=None, cache_last_time=None, cache_last_channel_len=None
self, audio_signal, length=None, cache_last_channel=None, cache_last_time=None, cache_last_channel_len=None
):
"""
This forward is used when we need to export the model to ONNX format.
Expand All @@ -217,12 +217,12 @@ def forward_for_export(
"""
enc_fun = getattr(self.input_module, 'forward_for_export', self.input_module.forward)
if cache_last_channel is None:
encoder_output = enc_fun(audio_signal=input, length=length)
encoder_output = enc_fun(audio_signal=audio_signal, length=length)
if isinstance(encoder_output, tuple):
encoder_output = encoder_output[0]
else:
encoder_output, length, cache_last_channel, cache_last_time, cache_last_channel_len = enc_fun(
audio_signal=input,
audio_signal=audio_signal,
length=length,
cache_last_channel=cache_last_channel,
cache_last_time=cache_last_time,
Expand Down
4 changes: 2 additions & 2 deletions nemo/collections/asr/models/label_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,8 +333,8 @@ def output_types(self) -> Optional[Dict[str, NeuralType]]:
"embs": NeuralType(('B', 'D'), AcousticEncodedRepresentation()),
}

def forward_for_export(self, processed_signal, processed_signal_len):
encoded, length = self.encoder(audio_signal=processed_signal, length=processed_signal_len)
def forward_for_export(self, audio_signal, length):
encoded, length = self.encoder(audio_signal=audio_signal, length=length)
logits, embs = self.decoder(encoder_output=encoded, length=length)
return logits, embs

Expand Down
3 changes: 2 additions & 1 deletion nemo/collections/asr/modules/conv_asr.py
Original file line number Diff line number Diff line change
Expand Up @@ -876,7 +876,8 @@ def forward(self, encoder_output, length=None):
embs = []

for layer in self.emb_layers:
pool, emb = layer(pool), layer[: self.emb_id](pool)
emb = layer[: self.emb_id](pool)
pool = layer(pool)
embs.append(emb)

pool = pool.squeeze(-1)
Expand Down
2 changes: 1 addition & 1 deletion nemo/collections/asr/parts/submodules/jasper.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@ def forward_for_export(self, x, lengths):
mask = self.make_pad_mask(lengths, max_audio_length=max_len, device=x.device)
mask = ~mask # 0 represents value, 1 represents pad
x = x.float() # For stable AMP, SE must be computed at fp32.
x.masked_fill_(mask, 0.0) # mask padded values explicitly to 0
x = x.masked_fill(mask, 0.0) # mask padded values explicitly to 0
y = self._se_pool_step(x, mask) # [B, C, 1]
y = y.transpose(1, -1) # [B, 1, C]
y = self.fc(y) # [B, 1, C]
Expand Down
2 changes: 1 addition & 1 deletion nemo/core/classes/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1005,7 +1005,7 @@ def __init__(
self.ignore_collections = ignore_collections

def __call__(self, wrapped):
return self.wrapped_call(wrapped) if is_typecheck_enabled() else self.unwrapped_call(wrapped)
return self.unwrapped_call(wrapped) if is_typecheck_enabled() else self.unwrapped_call(wrapped)

def unwrapped_call(self, wrapped):
return wrapped
Expand Down
61 changes: 49 additions & 12 deletions nemo/core/classes/exportable.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,20 @@
from abc import ABC
from typing import Dict, List, Optional, Union

import onnx
import torch
from pytorch_lightning.core.module import _jit_is_scripting

from nemo.core.classes import typecheck
from nemo.core.neural_types import NeuralType
from nemo.core.utils.neural_type_utils import get_dynamic_axes, get_io_names
from nemo.utils import logging
from nemo.utils import logging, monkeypatched
from nemo.utils.export_utils import (
ExportFormat,
augment_filename,
get_export_format,
parse_input_example,
rename_onnx_io,
replace_for_export,
verify_runtime,
verify_torchscript,
Expand Down Expand Up @@ -177,7 +179,7 @@ def _export(
with torch.inference_mode(), torch.no_grad(), torch.jit.optimized_execution(True), _jit_is_scripting():

if input_example is None:
input_example = self.input_module.input_example()
input_example = self.input_module.input_example(max_batch=2)

# Remove i/o examples from args we propagate to enclosed Exportables
my_args.pop('output')
Expand All @@ -191,7 +193,9 @@ def _export(
input_list, input_dict = parse_input_example(input_example)
input_names = self.input_names
output_names = self.output_names
output_example = tuple(self.forward(*input_list, **input_dict))
output_example = self.forward(*input_list, **input_dict)
if not isinstance(output_example, tuple):
output_example = (output_example,)

if check_trace:
if isinstance(check_trace, bool):
Expand Down Expand Up @@ -219,16 +223,49 @@ def _export(
# dynamic axis is a mapping from input/output_name => list of "dynamic" indices
if dynamic_axes is None:
dynamic_axes = get_dynamic_axes(self.input_module.input_types_for_export, input_names)
dynamic_axes.update(get_dynamic_axes(self.output_module.output_types_for_export, output_names))
if use_dynamo:
dynamic_shapes = {}
batch = torch.export.Dim("batch", max=128)
for name, dims in dynamic_axes.items():
ds = {}
for d in dims:
if d == 0:
ds[d] = batch
# this currently fails, https://github.com/pytorch/pytorch/issues/126127
# else:
# ds[d] = torch.export.Dim(name + '__' + str(d))
dynamic_shapes[name] = ds
else:
dynamic_shapes = dynamic_axes
if use_dynamo:
options = torch.onnx.ExportOptions(dynamic_shapes=dynamic_axes)
ex_model = torch.export.export(
jitted_model, tuple(input_list), kwargs=input_dict, strict=False
)
ex_model = ex_model.run_decompositions()
ex = torch.onnx.dynamo_export(ex_model, *input_list, **input_dict, export_options=options)
ex.save(output, model_state=jitted_model.state_dict())
input_names = None
import onnxscript

# https://github.com/microsoft/onnxscript/issues/1544
onnxscript.optimizer.constant_folding._DEFAULT_CONSTANT_FOLD_SIZE_LIMIT = 1024 * 1024 * 64

# https://github.com/pytorch/pytorch/issues/126339
with monkeypatched(torch.nn.RNNBase, "flatten_parameters", lambda *args: None):
print("Running export.export, dynamic shapes:\n", dynamic_shapes)

ex_model = torch.export.export(
jitted_model,
tuple(input_list),
kwargs=input_dict,
dynamic_shapes=dynamic_shapes,
strict=False,
)
ex_model = ex_model.run_decompositions()

print("Running torch.onnx.dynamo_export ...")

options = torch.onnx.ExportOptions(dynamic_shapes=True, op_level_debug=True)
ex_module = ex_model.module()
ex = torch.onnx.dynamo_export(ex_module, *input_list, **input_dict, export_options=options)
ex.save(output) # , model_state=ex_module.state_dict())
del ex
# Rename I/O after save - don't want to risk modifying ex._model_proto
rename_onnx_io(output, input_names, output_names)
# input_names=None
else:
torch.onnx.export(
jitted_model,
Expand Down
1 change: 1 addition & 0 deletions nemo/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
avoid_float16_autocast_context,
cast_all,
cast_tensor,
monkeypatched,
)
from nemo.utils.dtype import str_to_dtype
from nemo.utils.nemo_logging import Logger as _Logger
Expand Down
11 changes: 10 additions & 1 deletion nemo/utils/cast_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from contextlib import nullcontext
from contextlib import contextmanager, nullcontext

import torch

Expand Down Expand Up @@ -91,3 +91,12 @@ def forward(self, *args):
return cast_all(ret, from_dtype=torch.float32, to_dtype=from_dtype)
else:
return self.mod.forward(*args)


@contextmanager
def monkeypatched(object, name, patch):
""" Temporarily monkeypatches an object. """
pre_patched_value = getattr(object, name)
setattr(object, name, patch)
yield object
setattr(object, name, pre_patched_value)
30 changes: 28 additions & 2 deletions nemo/utils/export_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,8 @@ def verify_runtime(model, output, input_examples, input_names, check_tolerance=0
for input_example in input_examples:
input_list, input_dict = parse_input_example(input_example)
output_example = model.forward(*input_list, **input_dict)
if not isinstance(output_example, tuple):
output_example = (output_example,)
ort_input = to_onnxrt_input(ort_input_names, input_names, input_dict, input_list)
all_good = all_good and run_ort_and_compare(sess, ort_input, output_example, check_tolerance)
status = "SUCCESS" if all_good else "FAIL"
Expand Down Expand Up @@ -221,10 +223,12 @@ def run_ort_and_compare(sess, ort_input, output_example, check_tolerance=0.01):
try:
if not torch.allclose(tout, expected.cpu(), rtol=check_tolerance, atol=100 * check_tolerance):
this_good = False
except Exception: # there may ne size mismatch and it may be OK
except Exception: # there may be size mismatch and it may be OK
this_good = False
if not this_good:
logging.info(f"onnxruntime results mismatch! PyTorch(expected):\n{expected}\nONNXruntime:\n{tout}")
logging.info(
f"onnxruntime results mismatch! PyTorch(expected, {expected.shape}):\n{expected}\nONNXruntime, {tout.shape}:\n{tout}"
)
all_good = False
return all_good

Expand Down Expand Up @@ -479,3 +483,25 @@ def add_casts_around_norms(model: nn.Module):
"MaskedInstanceNorm1d": wrap_module(MaskedInstanceNorm1d, CastToFloatAll),
}
replace_modules(model, default_cast_replacements)


def rename_onnx_io(output, input_names, output_names):
onnx_model = onnx.load(output)
rename_map = {}
for inp, name in zip(onnx_model.graph.input, input_names):
rename_map[inp.name] = name
for out, name in zip(onnx_model.graph.output, output_names):
rename_map[out.name] = name
for n in onnx_model.graph.node:
for inp in range(len(n.input)):
if n.input[inp] in rename_map:
n.input[inp] = rename_map[n.input[inp]]
for out in range(len(n.output)):
if n.output[out] in rename_map:
n.output[out] = rename_map[n.output[out]]

for i in range(len(onnx_model.graph.input)):
onnx_model.graph.input[i].name = input_names[i]
for i in range(len(onnx_model.graph.output)):
onnx_model.graph.output[i].name = output_names[i]
onnx.save(onnx_model, output)
47 changes: 37 additions & 10 deletions tests/collections/asr/test_asr_exportables.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,6 @@
from nemo.core.utils import numba_utils
from nemo.core.utils.numba_utils import __NUMBA_MINIMUM_VERSION__

# from nemo.core.classes import typecheck
# typecheck.enable_wrapping(enabled=False)


NUMBA_RNNT_LOSS_AVAILABLE = numba_utils.numba_cuda_is_supported(__NUMBA_MINIMUM_VERSION__)


Expand All @@ -56,6 +52,8 @@ def test_EncDecCTCModel_export_to_onnx(self):
)
onnx_model = onnx.load(filename)
onnx.checker.check_model(onnx_model, full_check=True) # throws when failed
assert onnx_model.graph.input[0].name == 'audio_signal'
assert onnx_model.graph.output[0].name == 'logprobs'

@pytest.mark.run_only_on('GPU')
@pytest.mark.unit
Expand All @@ -68,6 +66,8 @@ def test_EncDecClassificationModel_export_to_onnx(self, speech_classification_mo
)
onnx_model = onnx.load(filename)
onnx.checker.check_model(onnx_model, full_check=True) # throws when failed
assert onnx_model.graph.input[0].name == 'audio_signal'
assert onnx_model.graph.output[0].name == 'logits'

@pytest.mark.run_only_on('GPU')
@pytest.mark.unit
Expand All @@ -78,6 +78,8 @@ def test_EncDecSpeakerLabelModel_export_to_onnx(self, speaker_label_model):
model.export(output=filename)
onnx_model = onnx.load(filename)
onnx.checker.check_model(onnx_model, full_check=True) # throws when failed
assert onnx_model.graph.input[0].name == 'audio_signal'
assert onnx_model.graph.output[0].name == 'logits'

@pytest.mark.run_only_on('GPU')
@pytest.mark.unit
Expand All @@ -88,6 +90,9 @@ def test_EncDecCitrinetModel_export_to_onnx(self, citrinet_model):
model.export(output=filename)
onnx_model = onnx.load(filename)
onnx.checker.check_model(onnx_model, full_check=True) # throws when failed
assert onnx_model.graph.input[0].name == 'audio_signal'
assert onnx_model.graph.input[1].name == 'length'
assert onnx_model.graph.output[0].name == 'logprobs'

@pytest.mark.pleasefixme
@pytest.mark.run_only_on('GPU')
Expand Down Expand Up @@ -127,6 +132,9 @@ def test_EncDecCitrinetModel_limited_SE_export_to_onnx(self, citrinet_model):
)
onnx_model = onnx.load(filename)
onnx.checker.check_model(onnx_model, full_check=True) # throws when failed
assert onnx_model.graph.input[0].name == 'audio_signal'
assert onnx_model.graph.input[1].name == 'length'
assert onnx_model.graph.output[0].name == 'logprobs'

@pytest.mark.run_only_on('GPU')
@pytest.mark.unit
Expand All @@ -136,7 +144,7 @@ def test_EncDecRNNTModel_export_to_onnx(self, citrinet_rnnt_model):
with tempfile.TemporaryDirectory() as tmpdir:
fn = 'citri_rnnt.onnx'
filename = os.path.join(tmpdir, fn)
files, descr = model.export(output=filename, verbose=False)
files, descr = model.export(output=filename, dynamic_axes={}, verbose=False)

encoder_filename = os.path.join(tmpdir, 'encoder-' + fn)
assert files[0] == encoder_filename
Expand All @@ -145,6 +153,10 @@ def test_EncDecRNNTModel_export_to_onnx(self, citrinet_rnnt_model):
onnx.checker.check_model(onnx_model, full_check=True) # throws when failed
assert len(onnx_model.graph.input) == 2
assert len(onnx_model.graph.output) == 2
assert onnx_model.graph.input[0].name == 'audio_signal'
assert onnx_model.graph.input[1].name == 'length'
assert onnx_model.graph.output[0].name == 'outputs'
assert onnx_model.graph.output[1].name == 'encoded_lengths'

decoder_joint_filename = os.path.join(tmpdir, 'decoder_joint-' + fn)
assert files[1] == decoder_joint_filename
Expand All @@ -159,12 +171,21 @@ def test_EncDecRNNTModel_export_to_onnx(self, citrinet_rnnt_model):

# enc_logits + (all decoder inputs - state tuple) + flattened state list
assert len(onnx_model.graph.input) == (1 + (len(input_examples) - 1) + num_states)
assert onnx_model.graph.input[0].name == 'encoder_outputs'
assert onnx_model.graph.input[1].name == 'targets'
assert onnx_model.graph.input[2].name == 'target_length'

if num_states > 0:
for idx, ip in enumerate(onnx_model.graph.input[3:]):
assert ip.name == "input_" + state_name + '_' + str(idx + 1)

assert len(onnx_model.graph.output) == (len(input_examples) - 1) + num_states
assert onnx_model.graph.output[0].name == 'outputs'
assert onnx_model.graph.output[1].name == 'prednet_lengths'

if num_states > 0:
for idx, op in enumerate(onnx_model.graph.output[2:]):
assert op.name == "output_" + state_name + '_' + str(idx + 1)

@pytest.mark.run_only_on('GPU')
@pytest.mark.unit
Expand All @@ -185,6 +206,8 @@ def test_EncDecRNNTModel_export_to_ts(self, citrinet_rnnt_model):
assert ts_encoder is not None

arguments = ts_encoder.forward.schema.arguments[1:] # First value is `self`
assert arguments[0].name == 'audio_signal'
assert arguments[1].name == 'length'

decoder_joint_filename = os.path.join(tmpdir, 'decoder_joint-' + fn)
assert files[1] == decoder_joint_filename
Expand All @@ -202,6 +225,13 @@ def test_EncDecRNNTModel_export_to_ts(self, citrinet_rnnt_model):

# enc_logits + (all decoder inputs - state tuple) + flattened state list
assert len(ts_decoder_joint_args) == (1 + (len(input_examples) - 1) + num_states)
assert ts_decoder_joint_args[0].name == 'encoder_outputs'
assert ts_decoder_joint_args[1].name == 'targets'
assert ts_decoder_joint_args[2].name == 'target_length'

if num_states > 0:
for idx, ip in enumerate(ts_decoder_joint_args[3:]):
assert ip.name == "input_" + state_name + '_' + str(idx + 1)

@pytest.mark.run_only_on('GPU')
@pytest.mark.unit
Expand Down Expand Up @@ -235,6 +265,8 @@ def test_EncDecCTCModel_adapted_export_to_onnx(self):
)
onnx_model = onnx.load(filename)
onnx.checker.check_model(onnx_model, full_check=True) # throws when failed
assert onnx_model.graph.input[0].name == 'audio_signal'
assert onnx_model.graph.output[0].name == 'logprobs'

def setup_method(self):
self.preprocessor = {
Expand Down Expand Up @@ -638,8 +670,3 @@ def squeezeformer_model():
)
conformer_model = EncDecCTCModel(cfg=modelConfig)
return conformer_model


if __name__ == "__main__":
t = TestExportable()
t.test_EncDecClassificationModel_export_to_onnx(speech_classification_model())
3 changes: 0 additions & 3 deletions tests/collections/nlp/test_nlp_exportables.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,6 @@
import torch
import wget
from omegaconf import DictConfig, OmegaConf
from nemo.core.classes import typecheck

typecheck.enable_wrapping(enabled=False)

from nemo.collections import nlp as nemo_nlp
from nemo.collections.nlp.models import IntentSlotClassificationModel
Expand Down
Loading

0 comments on commit e9e81b0

Please sign in to comment.