diff --git a/nemo/utils/export_utils.py b/nemo/utils/export_utils.py
index c445309440517..534598097bf45 100644
--- a/nemo/utils/export_utils.py
+++ b/nemo/utils/export_utils.py
@@ -72,10 +72,12 @@ def __init__(self, weight, bias, skip_bias_add):
         self.weight = weight
         self.skip_bias_add = skip_bias_add
 
-    def forward(self, x):
+    def forward(self, x, weight=None):
+        if weight is None:
+            weight = self.weight
         if self.skip_bias_add:
-            return F.linear(x, self.weight), self.bias
-        return F.linear(x, self.weight, self.bias), None
+            return F.linear(x, weight), self.bias
+        return F.linear(x, weight, self.bias), None
 
 
 def get_export_format(filename: str):
@@ -239,7 +241,8 @@ def run_ort_and_compare(sess, ort_input, output_example, check_tolerance=0.01):
     from apex.contrib.layer_norm.layer_norm import FastLayerNorm
     from apex.normalization import MixedFusedRMSNorm
     from apex.normalization.fused_layer_norm import FusedLayerNorm, MixedFusedLayerNorm
-    from apex.transformer.functional.fused_softmax import FusedScaleMaskSoftmax
+    from megatron.core.fusions.fused_layer_norm import FusedLayerNorm as MCoreFusedLayerNorm
+    from megatron.core.fusions.fused_softmax import FusedScaleMaskSoftmax
     from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
 
     def replace_FusedLayerNorm(n: nn.Module) -> Optional[nn.LayerNorm]:
@@ -255,21 +258,17 @@ def replace_FusedLayerNorm(n: nn.Module) -> Optional[nn.LayerNorm]:
 
         if isinstance(n, FusedLayerNorm) or isinstance(n, MixedFusedLayerNorm):
             shape, eps, affine = n.normalized_shape, n.eps, n.elementwise_affine
-            n_state = n.state_dict()
+        elif isinstance(n, MCoreFusedLayerNorm):
+            shape, eps, affine = n.weight.shape, n.eps, True
         elif isinstance(n, FastLayerNorm):
             shape, eps, affine = n.weight.shape, n.epsilon, True
-            n_state = n.state_dict()
-        elif isinstance(n, MixedFusedRMSNorm):
-            shape, eps, affine = n.normalized_shape, n.eps, n.elementwise_affine
-            tmp_n_state = n.state_dict()
-            n_state = {'weight': tmp_n_state['weight'], 'bias': torch.zeros_like(tmp_n_state['weight'])}
         else:
             return None
+        n_state = n.state_dict()
 
         mod = nn.LayerNorm(shape, eps=eps, elementwise_affine=affine, device=p.device, dtype=p.dtype)
-        mod.load_state_dict(n_state)
+        mod.load_state_dict(n_state, strict=True)
 
         return mod
@@ -306,7 +305,7 @@ def replace_ParallelLinear(n: nn.Module) -> Optional[nn.Linear]:
         mod = LinearWithBiasSkip(n.weight, n.bias, n.skip_bias_add).to(dev)
 
         n_state = n.state_dict()
-        mod.load_state_dict(n_state)
+        mod.load_state_dict(n_state, strict=False)
         return mod
 
     def replace_FusedScaleMaskSoftmax(n: nn.Module) -> Optional[nn.Linear]:
@@ -318,7 +317,7 @@ def replace_FusedScaleMaskSoftmax(n: nn.Module) -> Optional[nn.Linear]:
            Equivalent LayerNorm module
         """
         if not isinstance(n, FusedScaleMaskSoftmax):
-            logging.warning("This function can only change the FusedScaleMaskSoftmax module.")
+            logging.warning(f"This function can only change the FusedScaleMaskSoftmax module, got: {n.__class__}")
             return n
 
         # disable the fusion only
@@ -331,6 +330,7 @@ def replace_FusedScaleMaskSoftmax(n: nn.Module) -> Optional[nn.Linear]:
     default_Apex_replacements = {
         "FusedLayerNorm": replace_FusedLayerNorm,
         "MixedFusedLayerNorm": replace_FusedLayerNorm,
+        "MCoreFusedLayerNorm": replace_FusedLayerNorm,
         "FastLayerNorm": replace_FusedLayerNorm,
         "RowParallelLinear": replace_ParallelLinear,
         "ColumnParallelLinear": replace_ParallelLinear,
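For context, here is a minimal sketch of the replacement pattern this patch extends: before ONNX export, fused modules are swapped for plain torch equivalents with identical math. This is not NeMo code; `DummyFusedNorm` and `swap_to_layernorm` are hypothetical stand-ins for `MCoreFusedLayerNorm` and `replace_FusedLayerNorm`, assuming the fused module exposes `weight`, `bias`, and `eps` the way the real classes do.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DummyFusedNorm(nn.Module):
    """Hypothetical stand-in for a fused LayerNorm (e.g. MCoreFusedLayerNorm)."""

    def __init__(self, hidden: int, eps: float = 1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden))
        self.bias = nn.Parameter(torch.zeros(hidden))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.layer_norm(x, self.weight.shape, self.weight, self.bias, self.eps)


def swap_to_layernorm(n: nn.Module) -> nn.LayerNorm:
    # Same recipe as replace_FusedLayerNorm: infer shape/eps from the fused
    # module, build a plain nn.LayerNorm on the same device/dtype, then load
    # the state dict strictly so a key mismatch fails loudly instead of
    # silently exporting wrong weights.
    p = next(n.parameters())
    mod = nn.LayerNorm(n.weight.shape, eps=n.eps, elementwise_affine=True, device=p.device, dtype=p.dtype)
    mod.load_state_dict(n.state_dict(), strict=True)
    return mod


fused = DummyFusedNorm(8)
plain = swap_to_layernorm(fused)
x = torch.randn(2, 8)
assert torch.allclose(fused(x), plain(x))  # same math, ONNX-exportable module
```

The `strict=False` on the `LinearWithBiasSkip` path is the opposite call, presumably because the parallel-linear state dicts do not always line up key-for-key with the plain replacement (e.g. a missing bias), so a strict load would raise during export.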