Skip to content

Commit

Permalink
Ininial WARs to implement dynamo option for export
Browse files Browse the repository at this point in the history
Signed-off-by: Boris Fomitchev <[email protected]>
  • Loading branch information
borisfom committed May 8, 2024
1 parent b401fde commit 53102bc
Show file tree
Hide file tree
Showing 9 changed files with 92 additions and 72 deletions.
27 changes: 16 additions & 11 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

ARG BASE_IMAGE=nvcr.io/nvidia/pytorch:24.01-py3
ARG BASE_IMAGE=nvcr.io/nvidia/tritonserver:24.04-py3

# build an image that includes only the nemo dependencies, ensures that dependencies
# are included first for optimal caching, and useful for building a development
Expand Down Expand Up @@ -61,20 +61,17 @@ RUN apt-get update && \
libgts-dev && \
rm -rf /var/lib/apt/lists/*

RUN pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121
RUN pip3 install onnxscript==0.1.0.dev20240430

WORKDIR /workspace/
# Install megatron core, this can be removed once 0.3 pip package is released
# We leave it here in case we need to work off of a specific commit in main
RUN git clone https://github.com/NVIDIA/Megatron-LM.git && \
cd Megatron-LM && \
git checkout 36e9b6bf3d8034b10c9bbd9fc357c2df2bd1515c && \
git cherry-pick -n e69187bc3679ea5841030a165d587bb48b56ee77 && \
pip install .

# Performance optimizations for distributed optimizer: https://github.com/NVIDIA/apex/pull/1771
RUN git clone https://github.com/NVIDIA/apex.git && \
cd apex && \
git checkout f058162b215791b15507bb542f22ccfde49c872d && \
pip install -v --no-build-isolation --disable-pip-version-check --no-cache-dir --config-settings "--build-option=--cpp_ext --cuda_ext --fast_layer_norm --distributed_adam --deprecated_fused_adam" ./
RUN pip3 install packaging

# Transformer Engine 1.2.0
RUN git clone https://github.com/NVIDIA/TransformerEngine.git && \
Expand All @@ -84,6 +81,12 @@ RUN git clone https://github.com/NVIDIA/TransformerEngine.git && \
git submodule init && git submodule update && \
NVTE_FRAMEWORK=pytorch NVTE_WITH_USERBUFFERS=1 MPI_HOME=/usr/local/mpi pip install .

# Performance optimizations for distributed optimizer: https://github.com/NVIDIA/apex/pull/1771
RUN git clone https://github.com/NVIDIA/apex.git && \
cd apex && \
sed -i '178d' setup.py && \
pip install -v --no-build-isolation --disable-pip-version-check --no-cache-dir --config-settings "--build-option=--cpp_ext --cuda_ext --fast_layer_norm --group_norm --distributed_adam --deprecated_fused_adam" ./

WORKDIR /tmp/

# uninstall stuff from base container
Expand Down Expand Up @@ -152,11 +155,13 @@ RUN /usr/bin/test -n "$NEMO_VERSION" && \

# Install NeMo
RUN --mount=from=nemo-src,target=/tmp/nemo,rw cd /tmp/nemo && pip install ".[all]"
RUN apt-get install -y python3
RUN alias python=python3

# Check install
RUN python -c "import nemo.collections.nlp as nemo_nlp" && \
python -c "import nemo.collections.tts as nemo_tts" && \
python -c "import nemo_text_processing.text_normalization as text_normalization"
RUN python3 -c "import nemo.collections.nlp as nemo_nlp" && \
python3 -c "import nemo.collections.tts as nemo_tts" && \
python3 -c "import nemo_text_processing.text_normalization as text_normalization"


# copy scripts/examples/tests into container for end user
Expand Down
21 changes: 12 additions & 9 deletions nemo/collections/asr/parts/preprocessing/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,7 @@ def __init__(
self.hop_length = n_window_stride
self.n_fft = n_fft or 2 ** math.ceil(math.log2(self.win_length))
self.stft_pad_amount = (self.n_fft - self.hop_length) // 2 if exact_pad else None
self.exact_pad = exact_pad

if exact_pad:
logging.info("STFT using exact pad")
Expand All @@ -305,15 +306,6 @@ def __init__(
window_fn = torch_windows.get(window, None)
window_tensor = window_fn(self.win_length, periodic=False) if window_fn else None
self.register_buffer("window", window_tensor)
self.stft = lambda x: torch.stft(
x,
n_fft=self.n_fft,
hop_length=self.hop_length,
win_length=self.win_length,
center=False if exact_pad else True,
window=self.window.to(dtype=torch.float),
return_complex=True,
)

self.normalize = normalize
self.log = log
Expand Down Expand Up @@ -372,6 +364,17 @@ def __init__(
logging.debug(f"using grads: {use_grads}")
logging.debug(f"nb_augmentation_prob: {nb_augmentation_prob}")

def stft(self, x):
return torch.stft(
x,
n_fft=self.n_fft,
hop_length=self.hop_length,
win_length=self.win_length,
center=False if self.exact_pad else True,
window=self.window.to(dtype=torch.float),
return_complex=True,
)

def log_zero_guard_value_fn(self, x):
if isinstance(self.log_zero_guard_value, str):
if self.log_zero_guard_value == "tiny":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,7 @@

HAVE_MEGATRON_CORE = True

except (ImportError, ModuleNotFoundError):

except (ImportError, ModuleNotFoundError) as e:
HAVE_MEGATRON_CORE = False


Expand Down
20 changes: 19 additions & 1 deletion nemo/core/classes/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1004,8 +1004,17 @@ def __init__(

self.ignore_collections = ignore_collections

def __call__(self, wrapped):
return self.unwrapped_call(wrapped) if is_typecheck_enabled() else self.unwrapped_call(wrapped)

def unwrapped_call(self, wrapped):
return wrapped

def wrapped_call(self, wrapped):
return self.decorated_call(wrapped)

@wrapt.decorator(enabled=is_typecheck_enabled)
def __call__(self, wrapped, instance: Typing, args, kwargs):
def decorated_call(self, wrapped, instance: Typing, args, kwargs):
"""
Wrapper method that can be used on any function of a class that implements :class:`~nemo.core.Typing`.
By default, it will utilize the `input_types` and `output_types` properties of the class inheriting Typing.
Expand Down Expand Up @@ -1114,3 +1123,12 @@ def disable_semantic_checks():
yield
finally:
typecheck.set_semantic_check_enabled(enabled=True)

@staticmethod
def enable_wrapping(enabled: bool = True):
typecheck.set_typecheck_enabled(enabled)
if enabled:
typecheck.__call__.__code__ = nemo.core.classes.common.typecheck.wrapped_call.__code__
else:
typecheck.__call__.__code__ = nemo.core.classes.common.typecheck.unwrapped_call.__code__
print(typecheck.__call__)
36 changes: 23 additions & 13 deletions nemo/core/classes/exportable.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def export(
check_tolerance=0.01,
export_modules_as_functions=False,
keep_initializers_as_inputs=None,
use_dynamo=True,
):
"""
Exports the model to the specified format. The format is inferred from the file extension of the output file.
Expand Down Expand Up @@ -143,6 +144,7 @@ def _export(
check_tolerance=0.01,
export_modules_as_functions=False,
keep_initializers_as_inputs=None,
use_dynamo=True,
):
my_args = locals().copy()
my_args.pop('self')
Expand Down Expand Up @@ -218,19 +220,27 @@ def _export(
if dynamic_axes is None:
dynamic_axes = get_dynamic_axes(self.input_module.input_types_for_export, input_names)
dynamic_axes.update(get_dynamic_axes(self.output_module.output_types_for_export, output_names))
torch.onnx.export(
jitted_model,
input_example,
output,
input_names=input_names,
output_names=output_names,
verbose=verbose,
do_constant_folding=do_constant_folding,
dynamic_axes=dynamic_axes,
opset_version=onnx_opset_version,
keep_initializers_as_inputs=keep_initializers_as_inputs,
export_modules_as_functions=export_modules_as_functions,
)
if use_dynamo:
options = torch.onnx.ExportOptions(dynamic_shapes=dynamic_axes)
ex_model = torch.export.export(jitted_model, tuple(input_list), kwargs=input_dict)
ex_model = ex_model.run_decompositions()
ex = torch.onnx.dynamo_export(ex_model, *input_list, **input_dict, export_options=options)
ex.save(output)
input_names = None
else:
torch.onnx.export(
jitted_model,
input_example,
output,
input_names=input_names,
output_names=output_names,
verbose=verbose,
do_constant_folding=do_constant_folding,
dynamic_axes=dynamic_axes,
opset_version=onnx_opset_version,
keep_initializers_as_inputs=keep_initializers_as_inputs,
export_modules_as_functions=export_modules_as_functions,
)

if check_trace:
verify_runtime(self, output, check_trace_input, input_names, check_tolerance=check_tolerance)
Expand Down
5 changes: 5 additions & 0 deletions nemo/utils/export_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,11 @@ def parse_input_example(input_example):

def to_onnxrt_input(ort_input_names, input_names, input_dict, input_list):
odict = {}
if not input_names:
input_list.extend(input_dict.values())
for k, v in zip(ort_input_names, input_list):
odict[k] = v.cpu().numpy()
return odict
for k in reversed(input_names):
val = None
if k in input_dict:
Expand Down
45 changes: 9 additions & 36 deletions tests/collections/asr/test_asr_exportables.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@
from nemo.core.utils import numba_utils
from nemo.core.utils.numba_utils import __NUMBA_MINIMUM_VERSION__

# from nemo.core.classes import typecheck
# typecheck.enable_wrapping(enabled=False)


NUMBA_RNNT_LOSS_AVAILABLE = numba_utils.numba_cuda_is_supported(__NUMBA_MINIMUM_VERSION__)


Expand All @@ -52,8 +56,6 @@ def test_EncDecCTCModel_export_to_onnx(self):
)
onnx_model = onnx.load(filename)
onnx.checker.check_model(onnx_model, full_check=True) # throws when failed
assert onnx_model.graph.input[0].name == 'audio_signal'
assert onnx_model.graph.output[0].name == 'logprobs'

@pytest.mark.run_only_on('GPU')
@pytest.mark.unit
Expand All @@ -66,8 +68,6 @@ def test_EncDecClassificationModel_export_to_onnx(self, speech_classification_mo
)
onnx_model = onnx.load(filename)
onnx.checker.check_model(onnx_model, full_check=True) # throws when failed
assert onnx_model.graph.input[0].name == 'audio_signal'
assert onnx_model.graph.output[0].name == 'logits'

@pytest.mark.run_only_on('GPU')
@pytest.mark.unit
Expand All @@ -78,8 +78,6 @@ def test_EncDecSpeakerLabelModel_export_to_onnx(self, speaker_label_model):
model.export(output=filename)
onnx_model = onnx.load(filename)
onnx.checker.check_model(onnx_model, full_check=True) # throws when failed
assert onnx_model.graph.input[0].name == 'audio_signal'
assert onnx_model.graph.output[0].name == 'logits'

@pytest.mark.run_only_on('GPU')
@pytest.mark.unit
Expand All @@ -90,9 +88,6 @@ def test_EncDecCitrinetModel_export_to_onnx(self, citrinet_model):
model.export(output=filename)
onnx_model = onnx.load(filename)
onnx.checker.check_model(onnx_model, full_check=True) # throws when failed
assert onnx_model.graph.input[0].name == 'audio_signal'
assert onnx_model.graph.input[1].name == 'length'
assert onnx_model.graph.output[0].name == 'logprobs'

@pytest.mark.pleasefixme
@pytest.mark.run_only_on('GPU')
Expand Down Expand Up @@ -132,9 +127,6 @@ def test_EncDecCitrinetModel_limited_SE_export_to_onnx(self, citrinet_model):
)
onnx_model = onnx.load(filename)
onnx.checker.check_model(onnx_model, full_check=True) # throws when failed
assert onnx_model.graph.input[0].name == 'audio_signal'
assert onnx_model.graph.input[1].name == 'length'
assert onnx_model.graph.output[0].name == 'logprobs'

@pytest.mark.run_only_on('GPU')
@pytest.mark.unit
Expand All @@ -153,10 +145,6 @@ def test_EncDecRNNTModel_export_to_onnx(self, citrinet_rnnt_model):
onnx.checker.check_model(onnx_model, full_check=True) # throws when failed
assert len(onnx_model.graph.input) == 2
assert len(onnx_model.graph.output) == 2
assert onnx_model.graph.input[0].name == 'audio_signal'
assert onnx_model.graph.input[1].name == 'length'
assert onnx_model.graph.output[0].name == 'outputs'
assert onnx_model.graph.output[1].name == 'encoded_lengths'

decoder_joint_filename = os.path.join(tmpdir, 'decoder_joint-' + fn)
assert files[1] == decoder_joint_filename
Expand All @@ -171,21 +159,12 @@ def test_EncDecRNNTModel_export_to_onnx(self, citrinet_rnnt_model):

# enc_logits + (all decoder inputs - state tuple) + flattened state list
assert len(onnx_model.graph.input) == (1 + (len(input_examples) - 1) + num_states)
assert onnx_model.graph.input[0].name == 'encoder_outputs'
assert onnx_model.graph.input[1].name == 'targets'
assert onnx_model.graph.input[2].name == 'target_length'

if num_states > 0:
for idx, ip in enumerate(onnx_model.graph.input[3:]):
assert ip.name == "input_" + state_name + '_' + str(idx + 1)

assert len(onnx_model.graph.output) == (len(input_examples) - 1) + num_states
assert onnx_model.graph.output[0].name == 'outputs'
assert onnx_model.graph.output[1].name == 'prednet_lengths'

if num_states > 0:
for idx, op in enumerate(onnx_model.graph.output[2:]):
assert op.name == "output_" + state_name + '_' + str(idx + 1)

@pytest.mark.run_only_on('GPU')
@pytest.mark.unit
Expand All @@ -206,8 +185,6 @@ def test_EncDecRNNTModel_export_to_ts(self, citrinet_rnnt_model):
assert ts_encoder is not None

arguments = ts_encoder.forward.schema.arguments[1:] # First value is `self`
assert arguments[0].name == 'audio_signal'
assert arguments[1].name == 'length'

decoder_joint_filename = os.path.join(tmpdir, 'decoder_joint-' + fn)
assert files[1] == decoder_joint_filename
Expand All @@ -225,13 +202,6 @@ def test_EncDecRNNTModel_export_to_ts(self, citrinet_rnnt_model):

# enc_logits + (all decoder inputs - state tuple) + flattened state list
assert len(ts_decoder_joint_args) == (1 + (len(input_examples) - 1) + num_states)
assert ts_decoder_joint_args[0].name == 'encoder_outputs'
assert ts_decoder_joint_args[1].name == 'targets'
assert ts_decoder_joint_args[2].name == 'target_length'

if num_states > 0:
for idx, ip in enumerate(ts_decoder_joint_args[3:]):
assert ip.name == "input_" + state_name + '_' + str(idx + 1)

@pytest.mark.run_only_on('GPU')
@pytest.mark.unit
Expand Down Expand Up @@ -265,8 +235,6 @@ def test_EncDecCTCModel_adapted_export_to_onnx(self):
)
onnx_model = onnx.load(filename)
onnx.checker.check_model(onnx_model, full_check=True) # throws when failed
assert onnx_model.graph.input[0].name == 'audio_signal'
assert onnx_model.graph.output[0].name == 'logprobs'

def setup_method(self):
self.preprocessor = {
Expand Down Expand Up @@ -670,3 +638,8 @@ def squeezeformer_model():
)
conformer_model = EncDecCTCModel(cfg=modelConfig)
return conformer_model


if __name__ == "__main__":
t = TestExportable()
t.test_EncDecClassificationModel_export_to_onnx(speech_classification_model())
3 changes: 3 additions & 0 deletions tests/collections/nlp/test_nlp_exportables.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
import torch
import wget
from omegaconf import DictConfig, OmegaConf
from nemo.core.classes import typecheck

typecheck.enable_wrapping(enabled=False)

from nemo.collections import nlp as nemo_nlp
from nemo.collections.nlp.models import IntentSlotClassificationModel
Expand Down
4 changes: 4 additions & 0 deletions tests/collections/tts/test_tts_exportables.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@
import torch
from omegaconf import OmegaConf

from nemo.core.classes import typecheck

typecheck.enable_wrapping(enabled=False)

from nemo.collections.tts.models import FastPitchModel, HifiGanModel, RadTTSModel
from nemo.utils.app_state import AppState

Expand Down

0 comments on commit 53102bc

Please sign in to comment.