Skip to content

Commit

Permalink
Apply isort and black reformatting
Browse files Browse the repository at this point in the history
Signed-off-by: borisfom <[email protected]>
  • Loading branch information
borisfom committed Jun 13, 2024
1 parent ea50c5e commit dd21b74
Show file tree
Hide file tree
Showing 6 changed files with 11 additions and 11 deletions.
2 changes: 1 addition & 1 deletion nemo/collections/asr/models/asr_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def output_names(self):
if getattr(self.input_module, 'export_cache_support', False):
in_types = self.input_module.output_types
otypes = {n: t for (n, t) in list(otypes.items())[:1]}
for (n, t) in list(in_types.items())[1:]:
for n, t in list(in_types.items())[1:]:
otypes[n] = t
return get_io_names(otypes, self.disabled_deployment_output_names)

Expand Down
8 changes: 4 additions & 4 deletions nemo/collections/asr/parts/preprocessing/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def clean_spectrogram_batch(spectrogram: torch.Tensor, spectrogram_len: torch.Te


def splice_frames(x, frame_splicing):
""" Stacks frames together across feature dim
"""Stacks frames together across feature dim
input is batch_size, feature_dim, num_frames
output is batch_size, feature_dim*frame_splicing, num_frames
Expand Down Expand Up @@ -261,7 +261,7 @@ def __init__(
highfreq=None,
log=True,
log_zero_guard_type="add",
log_zero_guard_value=2 ** -24,
log_zero_guard_value=2**-24,
dither=CONSTANT,
pad_to=16,
max_duration=16.7,
Expand Down Expand Up @@ -511,7 +511,7 @@ def __init__(
highfreq: Optional[float] = None,
log: bool = True,
log_zero_guard_type: str = "add",
log_zero_guard_value: Union[float, str] = 2 ** -24,
log_zero_guard_value: Union[float, str] = 2**-24,
dither: float = 1e-5,
window: str = "hann",
pad_to: int = 0,
Expand Down Expand Up @@ -582,7 +582,7 @@ def __init__(

@property
def filter_banks(self):
""" Matches the analogous class """
"""Matches the analogous class"""
return self._mel_spec_extractor.mel_scale.fb

def _resolve_log_zero_guard_value(self, dtype: torch.dtype) -> float:
Expand Down
4 changes: 2 additions & 2 deletions nemo/collections/asr/parts/submodules/jasper.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,8 +510,8 @@ def _se_pool_step(self, x, mask):
return y

def set_max_len(self, max_len, seq_range=None):
""" Sets maximum input length.
Pre-calculates internal seq_range mask.
"""Sets maximum input length.
Pre-calculates internal seq_range mask.
"""
self.max_len = max_len
if seq_range is None:
Expand Down
2 changes: 1 addition & 1 deletion nemo/collections/tts/modules/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def __init__(self, n_head, d_model, d_head, dropout, dropatt=0.1, pre_lnorm=Fals
self.n_head = n_head
self.d_model = d_model
self.d_head = d_head
self.scale = 1 / (d_head ** 0.5)
self.scale = 1 / (d_head**0.5)
self.pre_lnorm = pre_lnorm

self.qkv_net = nn.Linear(d_model, 3 * n_head * d_head)
Expand Down
2 changes: 1 addition & 1 deletion nemo/utils/cast_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def forward(self, *args):

@contextmanager
def monkeypatched(object, name, patch):
""" Temporarily monkeypatches an object. """
"""Temporarily monkeypatches an object."""
pre_patched_value = getattr(object, name)
setattr(object, name, patch)
yield object
Expand Down
4 changes: 2 additions & 2 deletions nemo/utils/export_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ def replace_MatchedScaleMaskSoftmax(n: nn.Module) -> Optional[nn.Linear]:

def wrap_module(BaseT: Type[nn.Module], DestT: Type[nn.Module]) -> Callable[[nn.Module], Optional[nn.Module]]:
"""
Generic function generator to replace BaseT module with DestT wrapper.
Generic function generator to replace BaseT module with DestT wrapper.
Args:
BaseT : module type to replace
DestT : destination module type
Expand Down Expand Up @@ -450,7 +450,7 @@ def script_module(m: nn.Module):

def replace_for_export(model: nn.Module) -> nn.Module:
"""
Top-level function to replace 'default set' of modules in model, called from _prepare_for_export.
Top-level function to replace 'default set' of modules in model, called from _prepare_for_export.
NOTE: This occurs in place, if you want to preserve model then make sure to copy it first.
Args:
model : top level module
Expand Down

0 comments on commit dd21b74

Please sign in to comment.