Nemotron ONNX export fixed
Signed-off-by: Boris Fomitchev <[email protected]>
borisfom committed Jul 5, 2024
1 parent f7515ee commit 0bf80da
Showing 3 changed files with 47 additions and 22 deletions.
23 changes: 17 additions & 6 deletions examples/nlp/language_modeling/megatron_export.py
@@ -30,6 +30,7 @@

from omegaconf import OmegaConf, open_dict
from pytorch_lightning import Trainer
from torch.export import Dim

from nemo.collections.nlp.models.language_modeling.megatron_bart_model import MegatronBARTModel
from nemo.collections.nlp.models.language_modeling.megatron_bert_model import MegatronBertModel
@@ -74,7 +75,7 @@ def nemo_export(cfg):
assert nemo_in is not None, "NeMo model not provided. Please provide the path to the .nemo or .ckpt file"

onnx_out = cfg.onnx_model_file

print(f"onnx_out: {onnx_out}")
trainer = Trainer(strategy=NLPDDPStrategy(), **cfg.trainer)
assert (
cfg.trainer.devices * cfg.trainer.num_nodes
@@ -149,18 +150,28 @@ def nemo_export(cfg):
try:
model.to(device=cfg.export_options.device).freeze()
model.eval()

sequence = "sequence"
batch = "batch"

use_dynamo = False
if use_dynamo:
sequence = Dim("sequence")
batch = Dim("batch")

model.export(
onnx_out,
onnx_opset_version=cfg.export_options.onnx_opset,
do_constant_folding=cfg.export_options.do_constant_folding,
dynamic_axes={
'input_ids': {0: "sequence", 1: "batch"},
'position_ids': {0: "sequence", 1: "batch"},
'logits': {0: "sequence", 1: "batch"},
},
check_trace=check_trace,
check_tolerance=cfg.export_options.check_tolerance,
verbose=cfg.export_options.verbose,
dynamic_axes={
'input_ids': {0: sequence, 1: batch},
'position_ids': {0: sequence, 1: batch},
'logits': {0: sequence, 1: batch},
},
use_dynamo=use_dynamo,
)
except Exception as e:
logging.error(
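
Note: the export call above builds its dynamic_axes values either from plain strings (classic TorchScript-based torch.onnx.export) or from torch.export.Dim objects when the new use_dynamo flag is set. Below is a minimal, self-contained sketch of that pattern; TinyLM, the tensor shapes, the batch-first axis order, and "tiny_lm.onnx" are assumptions made for the demo (the dict in megatron_export.py is sequence-first), and only the string-named path is actually exported here, since the Dim path is routed through NeMo's model.export(..., use_dynamo=True).

# Minimal sketch, not the NeMo API: the same dynamic_axes dict is reused, with the
# axis descriptors switched between strings and torch.export.Dim objects.
import torch
from torch.export import Dim


class TinyLM(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = torch.nn.Embedding(128, 16)
        self.head = torch.nn.Linear(16, 128)

    def forward(self, input_ids, position_ids):
        return self.head(self.emb(input_ids) + self.emb(position_ids))


model = TinyLM().eval()
input_ids = torch.randint(0, 128, (4, 8))            # [batch, sequence] in this toy
position_ids = torch.arange(8).expand(4, 8)

use_dynamo = False                                   # mirrors the flag added above
if use_dynamo:
    batch, sequence = Dim("batch"), Dim("sequence")  # symbolic dims for the dynamo exporter
else:
    batch, sequence = "batch", "sequence"            # plain names for torch.onnx.export

dynamic_axes = {
    "input_ids": {0: batch, 1: sequence},
    "position_ids": {0: batch, 1: sequence},
    "logits": {0: batch, 1: sequence},
}

if not use_dynamo:
    # Classic TorchScript-based exporter; the Dim objects would instead be consumed
    # by the dynamo-based export path.
    torch.onnx.export(
        model,
        (input_ids, position_ids),
        "tiny_lm.onnx",
        input_names=["input_ids", "position_ids"],
        output_names=["logits"],
        dynamic_axes=dynamic_axes,
    )
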
25 changes: 15 additions & 10 deletions nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -196,7 +196,8 @@ def __init__(self, model):

self.dtype = utils_funcs.torch_dtype_from_precision(model.cfg.precision)

def forward(self, tokens, position_ids, attention_mask):
def forward(self, input_ids, position_ids, attention_mask):
device = str(next(self.parameters()).device)
if self.fp8_enabled and HAVE_TE:
with (
transformer_engine.pytorch.onnx_export(self.fp8_enabled),
@@ -207,10 +208,12 @@ def forward(self, tokens, position_ids, attention_mask):
warnings.catch_warnings(),
):
warnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning, module=r'.*')
assert tokens.shape == position_ids.shape
assert attention_mask.shape[2] == attention_mask.shape[3] == tokens.shape[1] == position_ids.shape[1]
assert input_ids.shape == position_ids.shape
assert (
attention_mask.shape[2] == attention_mask.shape[3] == input_ids.shape[1] == position_ids.shape[1]
)
output_tensor = self.model.forward(
tokens=tokens.cuda(),
tokens=input_ids.cuda(),
text_position_ids=position_ids.cuda(),
attention_mask=attention_mask.cuda(),
labels=None,
@@ -219,16 +222,18 @@
with (
torch.no_grad(),
torch.inference_mode(),
torch.autocast('cuda', dtype=self.dtype),
torch.autocast(device, dtype=self.dtype),
warnings.catch_warnings(),
):
warnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning, module=r'.*')
assert tokens.shape == position_ids.shape
assert attention_mask.shape[2] == attention_mask.shape[3] == tokens.shape[1] == position_ids.shape[1]
assert input_ids.shape == position_ids.shape
assert (
attention_mask.shape[2] == attention_mask.shape[3] == input_ids.shape[1] == position_ids.shape[1]
)
output_tensor = self.model.forward(
tokens=tokens.cuda(),
text_position_ids=position_ids.cuda(),
attention_mask=attention_mask.cuda(),
tokens=input_ids.to(device=device),
text_position_ids=position_ids.to(device=device),
attention_mask=attention_mask.to(device=device),
labels=None,
)

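Note: the wrapper's forward now derives the device from its own parameters and moves inputs with .to(device=...) instead of hard-coding .cuda(), so the non-FP8 path can also be traced on CPU. A minimal sketch of that pattern follows; ExportWrapper and ToyLM are hypothetical stand-ins for MegatronGPTExportableModel, torch.autocast is given the bare device type here, and bfloat16 is assumed because CPU autocast does not accept float32.

# Hypothetical wrapper illustrating the device handling above; names and shapes are
# assumptions for the demo, not NeMo code.
import torch


class ExportWrapper(torch.nn.Module):
    def __init__(self, model: torch.nn.Module, dtype: torch.dtype = torch.bfloat16):
        super().__init__()
        self.model = model
        self.dtype = dtype

    def forward(self, input_ids, position_ids, attention_mask):
        device = next(self.parameters()).device       # wherever the weights live
        assert input_ids.shape == position_ids.shape
        with torch.no_grad(), torch.inference_mode(), torch.autocast(device.type, dtype=self.dtype):
            return self.model(
                input_ids.to(device=device),          # replaces the hard-coded .cuda()
                position_ids.to(device=device),
                attention_mask.to(device=device),
            )


class ToyLM(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = torch.nn.Embedding(32, 8)

    def forward(self, input_ids, position_ids, attention_mask):
        return self.emb(input_ids) + self.emb(position_ids)


wrapper = ExportWrapper(ToyLM()).eval()
logits = wrapper(
    torch.randint(0, 32, (2, 4)),
    torch.arange(4).expand(2, 4),
    torch.ones(2, 1, 4, 4, dtype=torch.bool),
)
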
21 changes: 15 additions & 6 deletions nemo/utils/export_utils.py
@@ -72,10 +72,12 @@ def __init__(self, weight, bias, skip_bias_add):
self.weight = weight
self.skip_bias_add = skip_bias_add

def forward(self, x):
def forward(self, x, weight=None):
if weight is None:
weight = self.weight
if self.skip_bias_add:
return F.linear(x, self.weight), self.bias
return F.linear(x, self.weight, self.bias), None
return F.linear(x, weight), self.bias
return F.linear(x, weight, self.bias), None


def get_export_format(filename: str):
@@ -239,7 +241,10 @@ def run_ort_and_compare(sess, ort_input, output_example, check_tolerance=0.01):
from apex.contrib.layer_norm.layer_norm import FastLayerNorm
from apex.normalization import MixedFusedRMSNorm
from apex.normalization.fused_layer_norm import FusedLayerNorm, MixedFusedLayerNorm
from apex.transformer.functional.fused_softmax import FusedScaleMaskSoftmax
from megatron.core.fusions.fused_layer_norm import FusedLayerNorm as MCoreFusedLayerNorm

# from apex.transformer.functional.fused_softmax import FusedScaleMaskSoftmax
from megatron.core.fusions.fused_softmax import FusedScaleMaskSoftmax
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear

def replace_FusedLayerNorm(n: nn.Module) -> Optional[nn.LayerNorm]:
@@ -256,6 +261,9 @@ def replace_FusedLayerNorm(n: nn.Module) -> Optional[nn.LayerNorm]:
if isinstance(n, FusedLayerNorm) or isinstance(n, MixedFusedLayerNorm):
shape, eps, affine = n.normalized_shape, n.eps, n.elementwise_affine
n_state = n.state_dict()
elif isinstance(n, MCoreFusedLayerNorm):
shape, eps, affine = n.weight.shape, n.eps, True
n_state = n.state_dict()
elif isinstance(n, FastLayerNorm):
shape, eps, affine = n.weight.shape, n.epsilon, True
n_state = n.state_dict()
@@ -306,7 +314,7 @@ def replace_ParallelLinear(n: nn.Module) -> Optional[nn.Linear]:
mod = LinearWithBiasSkip(n.weight, n.bias, n.skip_bias_add).to(dev)

n_state = n.state_dict()
mod.load_state_dict(n_state)
mod.load_state_dict(n_state, strict=False)
return mod

def replace_FusedScaleMaskSoftmax(n: nn.Module) -> Optional[nn.Linear]:
@@ -318,7 +326,7 @@ def replace_FusedScaleMaskSoftmax(n: nn.Module) -> Optional[nn.Linear]:
Equivalent LayerNorm module
"""
if not isinstance(n, FusedScaleMaskSoftmax):
logging.warning("This function can only change the FusedScaleMaskSoftmax module.")
logging.warning(f"This function can only change the FusedScaleMaskSoftmax module, got: {n.__class__}")
return n

# disable the fusion only
@@ -331,6 +339,7 @@ def replace_FusedScaleMaskSoftmax(n: nn.Module) -> Optional[nn.Linear]:
default_Apex_replacements = {
"FusedLayerNorm": replace_FusedLayerNorm,
"MixedFusedLayerNorm": replace_FusedLayerNorm,
"MCoreFusedLayerNorm": replace_FusedLayerNorm,
"FastLayerNorm": replace_FusedLayerNorm,
"RowParallelLinear": replace_ParallelLinear,
"ColumnParallelLinear": replace_ParallelLinear,
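
Note: MCoreFusedLayerNorm does not expose normalized_shape or elementwise_affine, so the new branch above reads the shape from its weight and copies the state dict into a plain nn.LayerNorm, with strict=False to tolerate fused-only keys. A self-contained sketch of that replacement pattern follows; FakeFusedLayerNorm is a stand-in defined only for this demo, not the megatron.core class.

# Minimal sketch of the fused-norm replacement above, using a made-up stand-in module.
from typing import Optional

import torch
import torch.nn as nn


class FakeFusedLayerNorm(nn.Module):
    """Stand-in for a fused layer norm that only exposes weight, bias and eps."""

    def __init__(self, hidden_size: int, eps: float = 1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.eps = eps

    def forward(self, x):
        return nn.functional.layer_norm(x, self.weight.shape, self.weight, self.bias, self.eps)


def replace_fused_layer_norm(n: nn.Module) -> Optional[nn.LayerNorm]:
    """Swap a fused norm for an exportable nn.LayerNorm, mirroring replace_FusedLayerNorm."""
    if not isinstance(n, FakeFusedLayerNorm):
        return None
    shape, eps, affine = n.weight.shape, n.eps, True        # shape comes from the weight
    mod = nn.LayerNorm(shape, eps=eps, elementwise_affine=affine)
    mod.load_state_dict(n.state_dict(), strict=False)       # tolerate extra fused-only keys
    return mod.to(device=n.weight.device, dtype=n.weight.dtype)


fused = FakeFusedLayerNorm(16)
plain = replace_fused_layer_norm(fused)
x = torch.randn(2, 4, 16)
assert torch.allclose(fused(x), plain(x), atol=1e-6)        # same result after replacement
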
