Merge remote-tracking branch 'upstream/main' into lazy-export
borisfom committed Nov 15, 2024
2 parents a76ddcb + 0625327 commit 2a0bebc
Showing 51 changed files with 4,397 additions and 868 deletions.
15 changes: 15 additions & 0 deletions docs/source/asr/api.rst
@@ -276,6 +276,21 @@ RNNT Decoding
:show-inheritance:
:members:

TDT Decoding
~~~~~~~~~~~~~

.. autoclass:: nemo.collections.asr.parts.submodules.rnnt_greedy_decoding.GreedyTDTInfer
:show-inheritance:
:members:

.. autoclass:: nemo.collections.asr.parts.submodules.rnnt_greedy_decoding.GreedyBatchedTDTInfer
:show-inheritance:
:members:

.. autoclass:: nemo.collections.asr.parts.submodules.tdt_beam_decoding.BeamTDTInfer
:show-inheritance:
:members:
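
The three classes above expose the greedy, batched-greedy, and beam TDT decoders. As a hedged sketch of how a decoding strategy is typically switched on a loaded model (assuming `asr_model` is an already-loaded NeMo TDT model; the exact config fields follow the usual RNNT decoding-config convention and are not confirmed by this diff):

from copy import deepcopy

# `asr_model` is assumed to be loaded elsewhere, e.g. via
# ASRModel.from_pretrained(...). 'greedy_batch' routes inference through
# GreedyBatchedTDTInfer; a beam strategy routes through the beam decoder.
decoding_cfg = deepcopy(asr_model.cfg.decoding)
decoding_cfg.strategy = 'greedy_batch'
asr_model.change_decoding_strategy(decoding_cfg)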

Hypotheses
~~~~~~~~~~

56 changes: 46 additions & 10 deletions nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py
@@ -55,6 +55,20 @@


def pack_hypotheses(hypotheses: List[Hypothesis]) -> List[Hypothesis]:
"""
Packs a list of hypotheses into a tensor and prepares decoder states.
This function takes a list of token sequences (hypotheses) and converts
it into a tensor format. If any decoder states are on the GPU, they
are moved to the CPU. Additionally, the function removes any timesteps
with a value of -1 from the sequences.
Args:
hypotheses (list): A list of token sequences representing hypotheses.
Returns:
list: A list of packed hypotheses in tensor format.
"""
for idx, hyp in enumerate(hypotheses): # type: rnnt_utils.Hypothesis
hyp.y_sequence = torch.tensor(hyp.y_sequence, dtype=torch.long)

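The rest of the function body is collapsed in this view. A minimal sketch of what the collapsed portion plausibly does, inferred only from the docstring above and reusing the `_states_to_device` helper defined just below (the `dec_state` and `timestep` attribute names are assumptions; the upstream body may differ):

import torch

def pack_hypotheses_sketch(hypotheses):
    for hyp in hypotheses:
        # Tensorize the token ids, as in the line shown above.
        hyp.y_sequence = torch.tensor(hyp.y_sequence, dtype=torch.long)
        # Move any GPU-resident decoder states back to the CPU.
        if getattr(hyp, 'dec_state', None) is not None:
            hyp.dec_state = _states_to_device(hyp.dec_state, device='cpu')
        # Drop timesteps recorded as the sentinel value -1.
        if getattr(hyp, 'timestep', None) is not None:
            hyp.timestep = [t for t in hyp.timestep if t != -1]
    return hypotheses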
@@ -69,6 +83,18 @@ def pack_hypotheses(hypotheses: List[Hypothesis]) -> List[Hypothesis]:


def _states_to_device(dec_state, device='cpu'):
"""
Transfers decoder states to the specified device.
This function moves the provided decoder states to the specified device (e.g., 'cpu' or 'cuda').
Args:
dec_state (Tensor): The decoder states to be transferred.
device (str): The target device to which the decoder states should be moved. Defaults to 'cpu'.
Returns:
Tensor: The decoder states on the specified device.
"""
if torch.is_tensor(dec_state):
dec_state = dec_state.to(device)

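Only the tensor branch is visible before the hunk cuts off. A hedged sketch of how such a helper commonly also handles container states (RNNT decoder states are often an (h, c) list or tuple of tensors); the collapsed remainder may differ:

import torch

def states_to_device_sketch(dec_state, device='cpu'):
    # Tensors move directly; lists/tuples are rebuilt with every member moved.
    if torch.is_tensor(dec_state):
        return dec_state.to(device)
    if isinstance(dec_state, (list, tuple)):
        return type(dec_state)(states_to_device_sketch(s, device) for s in dec_state)
    # Anything else (e.g. None) passes through unchanged.
    return dec_state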
@@ -106,15 +132,17 @@ class BeamRNNTInfer(Typing):
however the time required for the search also grows steadily.
`tsd` - time synchronous decoding. Please refer to the paper:
[Alignment-Length Synchronous Decoding for RNN Transducer](https://ieeexplore.ieee.org/document/9053040)
for details on the algorithm implemented.
Time synchronous decoding (TSD) execution time grows by the factor T * max_symmetric_expansions.
For longer sequences, T is greater, and can therefore take a long time for beams to obtain
good results. This also requires greater memory to execute.
`alsd` - alignment-length synchronous decoding. Please refer to the paper:
[Alignment-Length Synchronous Decoding for RNN Transducer](https://ieeexplore.ieee.org/document/9053040)
for details on the algorithm implemented.
Alignment-length synchronous decoding (ALSD) execution time is faster than TSD, with growth
@@ -127,7 +155,8 @@ class BeamRNNTInfer(Typing):
For a given decoding accuracy, it is possible to attain faster decoding via ALSD than TSD.
`maes` - modified adaptive expansion search. Please refer to the paper:
[Accelerating RNN Transducer Inference via Adaptive Expansion Search](https://ieeexplore.ieee.org/document/9250505)
Modified Adaptive Expansion Search (mAES) execution time is adaptive w.r.t. the
number of expansions (for tokens) required per timestep. The number of expansions can usually
@@ -169,10 +198,10 @@ class BeamRNNTInfer(Typing):
and affects the speed of inference since large values will perform large beam search in the next step.
maes_expansion_gamma: Float pruning threshold used in the prune-by-value step when computing the expansions.
The default (2.3) is selected from the paper. It performs a comparison
(max_log_prob - gamma <= log_prob[v]) where v is all vocabulary indices in the Vocab set and max_log_prob
is the "most" likely token to be predicted. Gamma therefore provides a margin of additional tokens which
can be potential candidates for expansion apart from the "most likely" candidate.
Lower values will reduce the number of expansions (by increasing pruning-by-value, thereby improving speed
but hurting accuracy). Higher values will increase the number of expansions (by reducing pruning-by-value,
thereby reducing speed but potentially improving accuracy). This is a hyperparameter to be experimentally
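
As an illustration of the prune-by-value comparison described above, a minimal sketch (`prune_by_value` is a hypothetical helper name, not the upstream function):

import torch

def prune_by_value(log_probs: torch.Tensor, gamma: float) -> torch.Tensor:
    # Keep every vocabulary index v satisfying
    # max_log_prob - gamma <= log_prob[v].
    max_log_prob = log_probs.max()
    return torch.nonzero(max_log_prob - gamma <= log_probs, as_tuple=False).squeeze(-1)

# With the paper's default gamma=2.3, tokens within 2.3 nats of the argmax
# survive as expansion candidates.
candidates = prune_by_value(torch.log_softmax(torch.randn(129), dim=-1), gamma=2.3)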
@@ -182,7 +211,7 @@ class BeamRNNTInfer(Typing):
preserve_alignments: Bool flag which preserves the history of alignments generated during
beam decoding (sample). When set to true, the Hypothesis will contain
the non-null value for `alignments` in it. Here, `alignments` is a List of List of Tensor (of length V + 1).
The length of the list corresponds to the Acoustic Length (T).
Each value in the list (Ti) is a torch.Tensor (U), representing 1 or more targets from a vocabulary.
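
A hedged construction sketch for this class (assuming `decoder` and `joint` are the prediction and joint networks of an already-loaded NeMo RNNT model; the keyword names follow the upstream signature as best recalled and should be treated as assumptions):

from nemo.collections.asr.parts.submodules.rnnt_beam_decoding import BeamRNNTInfer

# `decoder` and `joint` are assumed available, e.g. asr_model.decoder
# and asr_model.joint from a loaded model.
beam_search = BeamRNNTInfer(
    decoder_model=decoder,
    joint_model=joint,
    beam_size=4,               # beam width
    search_type='maes',        # 'default', 'tsd', 'alsd', or 'maes'
    maes_num_steps=2,
    maes_expansion_gamma=2.3,  # the paper's default pruning margin
    score_norm=True,
)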
@@ -1456,8 +1485,11 @@ def compute_ngram_score(self, current_lm_state: "kenlm.State", label: int) -> Tu
return lm_score, next_state

def set_decoding_type(self, decoding_type: str):
    """
    Sets the decoding type. Please check train_kenlm.py in scripts/asr_language_modeling/ to find out why we
    need TOKEN_OFFSET for BPE-based models.

    Args:
        decoding_type: decoding type
    """
if decoding_type == 'subword':
from nemo.collections.asr.parts.submodules.ctc_beam_decoding import DEFAULT_TOKEN_OFFSET
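
The offset exists because train_kenlm.py encodes BPE token ids as printable unicode characters so a KenLM n-gram model can score them. A hedged illustration (the offset value below is illustrative only; the real constant is the DEFAULT_TOKEN_OFFSET imported above):

# Each subword id is shifted by a fixed offset and rendered as a unicode
# character before being fed to KenLM.
TOKEN_OFFSET = 100  # illustrative value only

def token_id_to_lm_symbol(token_id: int, offset: int = TOKEN_OFFSET) -> str:
    return chr(token_id + offset)

print(token_id_to_lm_symbol(5))  # chr(105) == 'i'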
@@ -1467,6 +1499,10 @@ def set_decoding_type(self, decoding_type: str):

@dataclass
class BeamRNNTInferConfig:
"""
Beam RNNT Inference config.
"""

beam_size: int
search_type: str = 'default'
score_norm: bool = True
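
A hedged instantiation sketch using only the fields visible in this hunk (the full dataclass defines more options below the fold):

from nemo.collections.asr.parts.submodules.rnnt_beam_decoding import BeamRNNTInferConfig

cfg = BeamRNNTInferConfig(beam_size=4, search_type='maes', score_norm=True)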