From dd21b74c3bf08d9d62a90454dfeb78f7194503ae Mon Sep 17 00:00:00 2001 From: borisfom Date: Thu, 13 Jun 2024 05:12:41 +0000 Subject: [PATCH] Apply isort and black reformatting Signed-off-by: borisfom --- nemo/collections/asr/models/asr_model.py | 2 +- nemo/collections/asr/parts/preprocessing/features.py | 8 ++++---- nemo/collections/asr/parts/submodules/jasper.py | 4 ++-- nemo/collections/tts/modules/transformer.py | 2 +- nemo/utils/cast_utils.py | 2 +- nemo/utils/export_utils.py | 4 ++-- 6 files changed, 11 insertions(+), 11 deletions(-) diff --git a/nemo/collections/asr/models/asr_model.py b/nemo/collections/asr/models/asr_model.py index 4f8e82293d48..24e300aff112 100644 --- a/nemo/collections/asr/models/asr_model.py +++ b/nemo/collections/asr/models/asr_model.py @@ -240,7 +240,7 @@ def output_names(self): if getattr(self.input_module, 'export_cache_support', False): in_types = self.input_module.output_types otypes = {n: t for (n, t) in list(otypes.items())[:1]} - for (n, t) in list(in_types.items())[1:]: + for n, t in list(in_types.items())[1:]: otypes[n] = t return get_io_names(otypes, self.disabled_deployment_output_names) diff --git a/nemo/collections/asr/parts/preprocessing/features.py b/nemo/collections/asr/parts/preprocessing/features.py index 51fc6c2418f7..d70737b5135b 100644 --- a/nemo/collections/asr/parts/preprocessing/features.py +++ b/nemo/collections/asr/parts/preprocessing/features.py @@ -131,7 +131,7 @@ def clean_spectrogram_batch(spectrogram: torch.Tensor, spectrogram_len: torch.Te def splice_frames(x, frame_splicing): - """ Stacks frames together across feature dim + """Stacks frames together across feature dim input is batch_size, feature_dim, num_frames output is batch_size, feature_dim*frame_splicing, num_frames @@ -261,7 +261,7 @@ def __init__( highfreq=None, log=True, log_zero_guard_type="add", - log_zero_guard_value=2 ** -24, + log_zero_guard_value=2**-24, dither=CONSTANT, pad_to=16, max_duration=16.7, @@ -511,7 +511,7 @@ def __init__( highfreq: Optional[float] = None, log: bool = True, log_zero_guard_type: str = "add", - log_zero_guard_value: Union[float, str] = 2 ** -24, + log_zero_guard_value: Union[float, str] = 2**-24, dither: float = 1e-5, window: str = "hann", pad_to: int = 0, @@ -582,7 +582,7 @@ def __init__( @property def filter_banks(self): - """ Matches the analogous class """ + """Matches the analogous class""" return self._mel_spec_extractor.mel_scale.fb def _resolve_log_zero_guard_value(self, dtype: torch.dtype) -> float: diff --git a/nemo/collections/asr/parts/submodules/jasper.py b/nemo/collections/asr/parts/submodules/jasper.py index c2beb3918ead..78f81ee555bc 100644 --- a/nemo/collections/asr/parts/submodules/jasper.py +++ b/nemo/collections/asr/parts/submodules/jasper.py @@ -510,8 +510,8 @@ def _se_pool_step(self, x, mask): return y def set_max_len(self, max_len, seq_range=None): - """ Sets maximum input length. - Pre-calculates internal seq_range mask. + """Sets maximum input length. + Pre-calculates internal seq_range mask. """ self.max_len = max_len if seq_range is None: diff --git a/nemo/collections/tts/modules/transformer.py b/nemo/collections/tts/modules/transformer.py index 2243d7d1c317..25c177d221cc 100644 --- a/nemo/collections/tts/modules/transformer.py +++ b/nemo/collections/tts/modules/transformer.py @@ -102,7 +102,7 @@ def __init__(self, n_head, d_model, d_head, dropout, dropatt=0.1, pre_lnorm=Fals self.n_head = n_head self.d_model = d_model self.d_head = d_head - self.scale = 1 / (d_head ** 0.5) + self.scale = 1 / (d_head**0.5) self.pre_lnorm = pre_lnorm self.qkv_net = nn.Linear(d_model, 3 * n_head * d_head) diff --git a/nemo/utils/cast_utils.py b/nemo/utils/cast_utils.py index d59189cc912e..a7960be4cc4d 100644 --- a/nemo/utils/cast_utils.py +++ b/nemo/utils/cast_utils.py @@ -95,7 +95,7 @@ def forward(self, *args): @contextmanager def monkeypatched(object, name, patch): - """ Temporarily monkeypatches an object. """ + """Temporarily monkeypatches an object.""" pre_patched_value = getattr(object, name) setattr(object, name, patch) yield object diff --git a/nemo/utils/export_utils.py b/nemo/utils/export_utils.py index c2da09101523..c44530944051 100644 --- a/nemo/utils/export_utils.py +++ b/nemo/utils/export_utils.py @@ -383,7 +383,7 @@ def replace_MatchedScaleMaskSoftmax(n: nn.Module) -> Optional[nn.Linear]: def wrap_module(BaseT: Type[nn.Module], DestT: Type[nn.Module]) -> Callable[[nn.Module], Optional[nn.Module]]: """ - Generic function generator to replace BaseT module with DestT wrapper. + Generic function generator to replace BaseT module with DestT wrapper. Args: BaseT : module type to replace DestT : destination module type @@ -450,7 +450,7 @@ def script_module(m: nn.Module): def replace_for_export(model: nn.Module) -> nn.Module: """ - Top-level function to replace 'default set' of modules in model, called from _prepare_for_export. + Top-level function to replace 'default set' of modules in model, called from _prepare_for_export. NOTE: This occurs in place, if you want to preserve model then make sure to copy it first. Args: model : top level module