From 7f8105f447ceca80ad687f0e13677689b8865c47 Mon Sep 17 00:00:00 2001
From: Boris Fomitchev
Date: Fri, 5 Jul 2024 15:30:33 -0700
Subject: [PATCH] Cleanup

Signed-off-by: Boris Fomitchev
---
 .../nlp/models/language_modeling/megatron_gpt_model.py |  9 ++++-----
 nemo/utils/export_utils.py                             |  9 +--------
 2 files changed, 5 insertions(+), 13 deletions(-)

diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
index ccb6c872681d..dc4ea7353f39 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -197,7 +197,6 @@ def __init__(self, model):
         self.dtype = utils_funcs.torch_dtype_from_precision(model.cfg.precision)
 
     def forward(self, input_ids, position_ids, attention_mask):
-        device = str(next(self.parameters()).device)
         if self.fp8_enabled and HAVE_TE:
             with (
                 transformer_engine.pytorch.onnx_export(self.fp8_enabled),
@@ -222,7 +221,7 @@ def forward(self, input_ids, position_ids, attention_mask):
             with (
                 torch.no_grad(),
                 torch.inference_mode(),
-                torch.autocast(device, dtype=self.dtype),
+                torch.autocast('cuda', dtype=self.dtype),
                 warnings.catch_warnings(),
             ):
                 warnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning, module=r'.*')
@@ -231,9 +230,9 @@ def forward(self, input_ids, position_ids, attention_mask):
                     attention_mask.shape[2] == attention_mask.shape[3] == input_ids.shape[1] == position_ids.shape[1]
                 )
                 output_tensor = self.model.forward(
-                    tokens=input_ids.to(device=device),
-                    text_position_ids=position_ids.to(device=device),
-                    attention_mask=attention_mask.to(device=device),
+                    tokens=input_ids.cuda(),
+                    text_position_ids=position_ids.cuda(),
+                    attention_mask=attention_mask.cuda(),
                     labels=None,
                 )

diff --git a/nemo/utils/export_utils.py b/nemo/utils/export_utils.py
index 4225f780a250..84aa583ea3fe 100644
--- a/nemo/utils/export_utils.py
+++ b/nemo/utils/export_utils.py
@@ -260,24 +260,17 @@ def replace_FusedLayerNorm(n: nn.Module) -> Optional[nn.LayerNorm]:
     if isinstance(n, FusedLayerNorm) or isinstance(n, MixedFusedLayerNorm):
         shape, eps, affine = n.normalized_shape, n.eps, n.elementwise_affine
-        n_state = n.state_dict()
     elif isinstance(n, MCoreFusedLayerNorm):
         shape, eps, affine = n.weight.shape, n.eps, True
-        n_state = n.state_dict()
     elif isinstance(n, FastLayerNorm):
         shape, eps, affine = n.weight.shape, n.epsilon, True
-        n_state = n.state_dict()
-    elif isinstance(n, MixedFusedRMSNorm):
-        shape, eps, affine = n.normalized_shape, n.eps, n.elementwise_affine
-        tmp_n_state = n.state_dict()
-        n_state = {'weight': tmp_n_state['weight'], 'bias': torch.zeros_like(tmp_n_state['weight'])}
     else:
         return None
 
     n_state = n.state_dict()
     mod = nn.LayerNorm(shape, eps=eps, elementwise_affine=affine, device=p.device, dtype=p.dtype)
-    mod.load_state_dict(n_state)
+    mod.load_state_dict(n_state, strict=True)
 
     return mod
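
Note for reviewers: the export_utils.py hunk converges on one pattern — read shape/eps/affine off the source norm, build a plain nn.LayerNorm, and load its state_dict strictly. The sketch below exercises that pattern in isolation; it is illustrative only. swap_for_layernorm is a hypothetical helper (not part of export_utils.py), and a stock nn.LayerNorm stands in for the Apex/TE fused classes, which may not be installed in a review environment.

# Minimal sketch of the consolidated replace-and-load pattern, under the
# assumptions stated above.
import torch
import torch.nn as nn


def swap_for_layernorm(src: nn.Module) -> nn.LayerNorm:
    # In export_utils.py the (shape, eps, affine) triple comes from an
    # isinstance() chain over the fused-norm variants; here we read the
    # equivalent attributes straight off the stand-in module.
    p = next(src.parameters())
    shape, eps, affine = src.normalized_shape, src.eps, src.elementwise_affine

    mod = nn.LayerNorm(shape, eps=eps, elementwise_affine=affine, device=p.device, dtype=p.dtype)
    # strict=True makes a missing or unexpected key fail loudly instead of
    # silently exporting a model with uninitialized norm parameters.
    mod.load_state_dict(src.state_dict(), strict=True)
    return mod


if __name__ == "__main__":
    src = nn.LayerNorm(64, eps=1e-5)
    with torch.no_grad():
        src.weight.uniform_(0.5, 1.5)
        src.bias.uniform_(-0.1, 0.1)

    dst = swap_for_layernorm(src)
    x = torch.randn(2, 8, 64)
    assert torch.allclose(src(x), dst(x))
    print("replacement matches source output")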