Cleanup
Signed-off-by: Boris Fomitchev <[email protected]>
borisfom committed Jul 5, 2024
1 parent 88cf623 · commit 7f8105f
Showing 2 changed files with 5 additions and 13 deletions.
@@ -197,7 +197,6 @@ def __init__(self, model):
         self.dtype = utils_funcs.torch_dtype_from_precision(model.cfg.precision)
 
     def forward(self, input_ids, position_ids, attention_mask):
-        device = str(next(self.parameters()).device)
         if self.fp8_enabled and HAVE_TE:
             with (
                 transformer_engine.pytorch.onnx_export(self.fp8_enabled),
@@ -222,7 +221,7 @@ def forward(self, input_ids, position_ids, attention_mask):
             with (
                 torch.no_grad(),
                 torch.inference_mode(),
-                torch.autocast(device, dtype=self.dtype),
+                torch.autocast('cuda', dtype=self.dtype),
                 warnings.catch_warnings(),
             ):
                 warnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning, module=r'.*')
@@ -231,9 +230,9 @@ def forward(self, input_ids, position_ids, attention_mask):
                     attention_mask.shape[2] == attention_mask.shape[3] == input_ids.shape[1] == position_ids.shape[1]
                 )
                 output_tensor = self.model.forward(
-                    tokens=input_ids.to(device=device),
-                    text_position_ids=position_ids.to(device=device),
-                    attention_mask=attention_mask.to(device=device),
+                    tokens=input_ids.cuda(),
+                    text_position_ids=position_ids.cuda(),
+                    attention_mask=attention_mask.cuda(),
                     labels=None,
                 )

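Taken together, the hunks above pin the export path to CUDA instead of deriving a device string from the module's parameters. Below is a minimal, self-contained sketch of the resulting pattern; ExportWrapper, its model argument, and the dtype default are hypothetical stand-ins, not the NeMo class this commit touches.

import warnings

import torch
import torch.nn as nn


class ExportWrapper(nn.Module):
    # Hypothetical wrapper: `model` stands in for the wrapped NeMo module and
    # `dtype` mirrors the self.dtype field set in __init__ above.
    def __init__(self, model: nn.Module, dtype: torch.dtype = torch.float16):
        super().__init__()
        self.model = model
        self.dtype = dtype

    def forward(self, input_ids, position_ids, attention_mask):
        with (
            torch.no_grad(),
            torch.inference_mode(),
            torch.autocast('cuda', dtype=self.dtype),  # was: torch.autocast(device, ...)
            warnings.catch_warnings(),
        ):
            warnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning, module=r'.*')
            # .cuda() replaces .to(device=device); no device string is computed.
            return self.model(input_ids.cuda(), position_ids.cuda(), attention_mask.cuda())

Hard-coding 'cuda' drops the str(next(self.parameters()).device) call from the traced forward, at the cost of assuming a CUDA-capable environment during export.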
nemo/utils/export_utils.py (9 changes: 1 addition & 8 deletions)
@@ -260,24 +260,17 @@ def replace_FusedLayerNorm(n: nn.Module) -> Optional[nn.LayerNorm]:
 
     if isinstance(n, FusedLayerNorm) or isinstance(n, MixedFusedLayerNorm):
         shape, eps, affine = n.normalized_shape, n.eps, n.elementwise_affine
-        n_state = n.state_dict()
     elif isinstance(n, MCoreFusedLayerNorm):
         shape, eps, affine = n.weight.shape, n.eps, True
-        n_state = n.state_dict()
     elif isinstance(n, FastLayerNorm):
         shape, eps, affine = n.weight.shape, n.epsilon, True
-        n_state = n.state_dict()
     elif isinstance(n, MixedFusedRMSNorm):
         shape, eps, affine = n.normalized_shape, n.eps, n.elementwise_affine
-        tmp_n_state = n.state_dict()
-        n_state = {'weight': tmp_n_state['weight'], 'bias': torch.zeros_like(tmp_n_state['weight'])}
     else:
         return None
 
+    n_state = n.state_dict()
     mod = nn.LayerNorm(shape, eps=eps, elementwise_affine=affine, device=p.device, dtype=p.dtype)
 
-    mod.load_state_dict(n_state)
+    mod.load_state_dict(n_state, strict=True)
 
     return mod

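The hunk above deduplicates the per-branch state_dict() calls: each branch now only computes (shape, eps, affine), the state dict is fetched once after the chain, and loading is strict so a key mismatch fails loudly. A hedged sketch of the resulting control flow, with FusedNormA/FusedNormB as hypothetical stand-ins for the real fused-norm classes:

from typing import Optional

import torch.nn as nn


class FusedNormA(nn.LayerNorm):
    """Stand-in for a fused variant exposing normalized_shape/eps/elementwise_affine."""


class FusedNormB(nn.LayerNorm):
    """Stand-in for a variant whose shape is recovered from its weight tensor."""


def replace_fused_norm(n: nn.Module) -> Optional[nn.LayerNorm]:
    p = next(n.parameters())  # reference parameter for device/dtype, as in the diff
    if isinstance(n, FusedNormA):
        shape, eps, affine = n.normalized_shape, n.eps, n.elementwise_affine
    elif isinstance(n, FusedNormB):
        shape, eps, affine = n.weight.shape, n.eps, True
    else:
        return None

    n_state = n.state_dict()  # hoisted: fetched once, after the branch chain
    mod = nn.LayerNorm(shape, eps=eps, elementwise_affine=affine, device=p.device, dtype=p.dtype)
    mod.load_state_dict(n_state, strict=True)  # fail loudly on any key mismatch
    return mod


# Usage: swap a fused norm for a plain nn.LayerNorm with identical weights.
plain = replace_fused_norm(FusedNormA(8))
assert isinstance(plain, nn.LayerNorm)

One observable behavior change in the real function: the MixedFusedRMSNorm branch previously synthesized a zero 'bias' entry; after this commit its state dict is passed through unchanged, and strict=True will surface any resulting key mismatch.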
