Fixing Apex export replacements
Signed-off-by: Boris Fomitchev <[email protected]>
borisfom committed Jul 8, 2024
1 parent 6697bbd, commit 8370cc8
Showing 1 changed file with 13 additions and 13 deletions.

nemo/utils/export_utils.py: 26 changes (13 additions & 13 deletions)

```diff
@@ -72,10 +72,12 @@ def __init__(self, weight, bias, skip_bias_add):
         self.weight = weight
         self.skip_bias_add = skip_bias_add
 
-    def forward(self, x):
+    def forward(self, x, weight=None):
+        if weight is None:
+            weight = self.weight
         if self.skip_bias_add:
-            return F.linear(x, self.weight), self.bias
-        return F.linear(x, self.weight, self.bias), None
+            return F.linear(x, weight), self.bias
+        return F.linear(x, weight, self.bias), None
 
 
 def get_export_format(filename: str):
```
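The `forward` signature change lets a caller supply a substitute weight at call time, defaulting to the module's own parameter, which mirrors the optional-weight calling convention of megatron-core's parallel linear layers. A minimal sketch of the patched module in use; the `super().__init__()` and `self.bias` lines are assumed from context, since the hunk shows only part of `__init__`:

```python
import torch
import torch.nn.functional as F
from torch import nn

class LinearWithBiasSkip(nn.Module):
    def __init__(self, weight, bias, skip_bias_add):
        super().__init__()           # assumed; not shown in the hunk
        self.bias = bias             # assumed; not shown in the hunk
        self.weight = weight
        self.skip_bias_add = skip_bias_add

    def forward(self, x, weight=None):
        if weight is None:
            weight = self.weight
        if self.skip_bias_add:
            # Defer the bias add to the caller: return (output, bias).
            return F.linear(x, weight), self.bias
        return F.linear(x, weight, self.bias), None

w = nn.Parameter(torch.randn(8, 4))
b = nn.Parameter(torch.randn(8))
mod = LinearWithBiasSkip(w, b, skip_bias_add=True)
out, bias = mod(torch.randn(2, 4))   # out: (2, 8); bias returned unapplied
```
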
```diff
@@ -239,7 +241,8 @@ def run_ort_and_compare(sess, ort_input, output_example, check_tolerance=0.01):
     from apex.contrib.layer_norm.layer_norm import FastLayerNorm
     from apex.normalization import MixedFusedRMSNorm
     from apex.normalization.fused_layer_norm import FusedLayerNorm, MixedFusedLayerNorm
-    from apex.transformer.functional.fused_softmax import FusedScaleMaskSoftmax
+    from megatron.core.fusions.fused_layer_norm import FusedLayerNorm as MCoreFusedLayerNorm
+    from megatron.core.fusions.fused_softmax import FusedScaleMaskSoftmax
     from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
 
     def replace_FusedLayerNorm(n: nn.Module) -> Optional[nn.LayerNorm]:
```
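`FusedScaleMaskSoftmax` is now taken from megatron.core instead of Apex, and megatron-core's fused layer norm is imported under the alias `MCoreFusedLayerNorm` so it can be matched as its own case. Imports like these normally sit behind an optional-dependency guard so export still works without the package; a hedged sketch of that pattern (the flag name is illustrative, not from this commit):

```python
# Sketch of a guarded optional import; HAVE_MEGATRON_CORE is an
# illustrative name, not taken from the commit.
try:
    from megatron.core.fusions.fused_layer_norm import FusedLayerNorm as MCoreFusedLayerNorm
    from megatron.core.fusions.fused_softmax import FusedScaleMaskSoftmax

    HAVE_MEGATRON_CORE = True
except (ImportError, ModuleNotFoundError):
    HAVE_MEGATRON_CORE = False   # replacements that need megatron-core are skipped
```
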
```diff
@@ -255,21 +258,17 @@ def replace_FusedLayerNorm(n: nn.Module) -> Optional[nn.LayerNorm]:
 
         if isinstance(n, FusedLayerNorm) or isinstance(n, MixedFusedLayerNorm):
             shape, eps, affine = n.normalized_shape, n.eps, n.elementwise_affine
-            n_state = n.state_dict()
+        elif isinstance(n, MCoreFusedLayerNorm):
+            shape, eps, affine = n.weight.shape, n.eps, True
         elif isinstance(n, FastLayerNorm):
             shape, eps, affine = n.weight.shape, n.epsilon, True
-            n_state = n.state_dict()
         elif isinstance(n, MixedFusedRMSNorm):
             shape, eps, affine = n.normalized_shape, n.eps, n.elementwise_affine
-            tmp_n_state = n.state_dict()
-            n_state = {'weight': tmp_n_state['weight'], 'bias': torch.zeros_like(tmp_n_state['weight'])}
         else:
             return None
 
+        n_state = n.state_dict()
         mod = nn.LayerNorm(shape, eps=eps, elementwise_affine=affine, device=p.device, dtype=p.dtype)
 
-        mod.load_state_dict(n_state)
+        mod.load_state_dict(n_state, strict=True)
 
         return mod
```
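Three related cleanups land in this hunk: `MCoreFusedLayerNorm` gets its own branch (its shape comes from `n.weight.shape`, and it is always affine), the per-branch `n_state` bookkeeping, including the RMSNorm zero-bias synthesis, collapses into a single `n.state_dict()` call after the branch is chosen, and the copy is loaded with `strict=True` so a key mismatch fails loudly instead of leaving the new `nn.LayerNorm` partially initialized. A sketch of that round-trip, using a plain `nn.LayerNorm` as a stand-in for the fused module:

```python
import torch
from torch import nn

fused = nn.LayerNorm(512, eps=1e-5)   # stand-in for a fused layer norm
p = next(fused.parameters())

# Mirror the MCoreFusedLayerNorm branch: shape from the weight, always affine.
shape, eps, affine = fused.weight.shape, fused.eps, True
mod = nn.LayerNorm(shape, eps=eps, elementwise_affine=affine, device=p.device, dtype=p.dtype)

# strict=True raises if 'weight' or 'bias' is missing or unexpected,
# rather than silently skipping it.
mod.load_state_dict(fused.state_dict(), strict=True)
```
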

```diff
@@ -306,7 +305,7 @@ def replace_ParallelLinear(n: nn.Module) -> Optional[nn.Linear]:
         mod = LinearWithBiasSkip(n.weight, n.bias, n.skip_bias_add).to(dev)
 
         n_state = n.state_dict()
-        mod.load_state_dict(n_state)
+        mod.load_state_dict(n_state, strict=False)
         return mod
 
     def replace_FusedScaleMaskSoftmax(n: nn.Module) -> Optional[nn.Linear]:
```
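For the parallel-linear swap the load is relaxed to `strict=False` instead, presumably because the source layer's state dict need not line up key-for-key with `LinearWithBiasSkip`; for example a layer without a bias, or one carrying extra buffers. A small demonstration of what `strict=False` tolerates, with plain `nn.Linear` stand-ins rather than the NeMo classes:

```python
import torch
from torch import nn

src = nn.Linear(4, 8, bias=False)   # stand-in: source has no 'bias' key
dst = nn.Linear(4, 8, bias=True)    # stand-in for the replacement module

# strict=False reports mismatched keys instead of raising.
result = dst.load_state_dict(src.state_dict(), strict=False)
print(result.missing_keys)      # ['bias']
print(result.unexpected_keys)   # []
```
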
```diff
@@ -318,7 +317,7 @@ def replace_FusedScaleMaskSoftmax(n: nn.Module) -> Optional[nn.Linear]:
            Equivalent LayerNorm module
         """
         if not isinstance(n, FusedScaleMaskSoftmax):
-            logging.warning("This function can only change the FusedScaleMaskSoftmax module.")
+            logging.warning(f"This function can only change the FusedScaleMaskSoftmax module, got: {n.__class__}")
             return n
 
         # disable the fusion only
```
```diff
@@ -331,6 +330,7 @@ def replace_FusedScaleMaskSoftmax(n: nn.Module) -> Optional[nn.Linear]:
     default_Apex_replacements = {
         "FusedLayerNorm": replace_FusedLayerNorm,
         "MixedFusedLayerNorm": replace_FusedLayerNorm,
+        "MCoreFusedLayerNorm": replace_FusedLayerNorm,
         "FastLayerNorm": replace_FusedLayerNorm,
         "RowParallelLinear": replace_ParallelLinear,
         "ColumnParallelLinear": replace_ParallelLinear,
```
