From a7d1dd55d6dfc3ad8a43b263363fb157c397840e Mon Sep 17 00:00:00 2001
From: ffd000
Date: Sun, 14 Apr 2024 15:33:23 +0300
Subject: [PATCH] Update training scripts to use generic device

Updated the training scripts to use the specified device instead of
moving tensors to CUDA exclusively.
---
 train_finetune.py            | 2 +-
 train_finetune_accelerate.py | 2 +-
 train_first.py               | 6 +++---
 train_second.py              | 2 +-
 4 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/train_finetune.py b/train_finetune.py
index 3c650747..d17a4eea 100644
--- a/train_finetune.py
+++ b/train_finetune.py
@@ -577,7 +577,7 @@ def main(config_path):
                     batch = [b.to(device) for b in batch[1:]]
                     texts, input_lengths, ref_texts, ref_lengths, mels, mel_input_length, ref_mels = batch
                     with torch.no_grad():
-                        mask = length_to_mask(mel_input_length // (2 ** n_down)).to('cuda')
+                        mask = length_to_mask(mel_input_length // (2 ** n_down)).to(device)
                         text_mask = length_to_mask(input_lengths).to(texts.device)
 
                         _, _, s2s_attn = model.text_aligner(mels, mask, texts)
diff --git a/train_finetune_accelerate.py b/train_finetune_accelerate.py
index 4cfd95f7..7b85d45d 100644
--- a/train_finetune_accelerate.py
+++ b/train_finetune_accelerate.py
@@ -584,7 +584,7 @@ def main(config_path):
                     batch = [b.to(device) for b in batch[1:]]
                     texts, input_lengths, ref_texts, ref_lengths, mels, mel_input_length, ref_mels = batch
                     with torch.no_grad():
-                        mask = length_to_mask(mel_input_length // (2 ** n_down)).to('cuda')
+                        mask = length_to_mask(mel_input_length // (2 ** n_down)).to(device)
                         text_mask = length_to_mask(input_lengths).to(texts.device)
 
                         _, _, s2s_attn = model.text_aligner(mels, mask, texts)
diff --git a/train_first.py b/train_first.py
index eaa8fe64..22c3178b 100644
--- a/train_first.py
+++ b/train_first.py
@@ -183,7 +183,7 @@ def main(config_path):
             texts, input_lengths, _, _, mels, mel_input_length, _ = batch
 
             with torch.no_grad():
-                mask = length_to_mask(mel_input_length // (2 ** n_down)).to('cuda')
+                mask = length_to_mask(mel_input_length // (2 ** n_down)).to(device)
                 text_mask = length_to_mask(input_lengths).to(texts.device)
 
             ppgs, s2s_pred, s2s_attn = model.text_aligner(mels, mask, texts)
@@ -336,7 +336,7 @@ def main(config_path):
                 texts, input_lengths, _, _, mels, mel_input_length, _ = batch
 
                 with torch.no_grad():
-                    mask = length_to_mask(mel_input_length // (2 ** n_down)).to('cuda')
+                    mask = length_to_mask(mel_input_length // (2 ** n_down)).to(device)
                     ppgs, s2s_pred, s2s_attn = model.text_aligner(mels, mask, texts)
 
                 s2s_attn = s2s_attn.transpose(-1, -2)
@@ -368,7 +368,7 @@ def main(config_path):
                     en.append(asr[bib, :, random_start:random_start+mel_len])
                     gt.append(mels[bib, :, (random_start * 2):((random_start+mel_len) * 2)])
                     y = waves[bib][(random_start * 2) * 300:((random_start+mel_len) * 2) * 300]
-                    wav.append(torch.from_numpy(y).to('cuda'))
+                    wav.append(torch.from_numpy(y).to(device))
 
                 wav = torch.stack(wav).float().detach()
 
diff --git a/train_second.py b/train_second.py
index fb1048dc..3381e4ab 100644
--- a/train_second.py
+++ b/train_second.py
@@ -576,7 +576,7 @@ def main(config_path):
                     batch = [b.to(device) for b in batch[1:]]
                     texts, input_lengths, ref_texts, ref_lengths, mels, mel_input_length, ref_mels = batch
                     with torch.no_grad():
-                        mask = length_to_mask(mel_input_length // (2 ** n_down)).to('cuda')
+                        mask = length_to_mask(mel_input_length // (2 ** n_down)).to(device)
                         text_mask = length_to_mask(input_lengths).to(texts.device)
 
                         _, _, s2s_attn = model.text_aligner(mels, mask, texts)
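
Note: this patch assumes each training script already binds a `device`
variable in scope at the changed call sites. Below is a minimal sketch of
the kind of device resolution that makes these `.to(device)` calls
portable; `resolve_device` is a hypothetical helper, not part of this
patch, and the final line is a stand-in for the scripts' actual
`length_to_mask(...)` result.

    import torch

    def resolve_device(requested=None):
        # Honour an explicit request (e.g. 'cpu', 'cuda:1', 'mps');
        # otherwise prefer CUDA when available and fall back to CPU.
        if requested is not None:
            return torch.device(requested)
        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    device = resolve_device()

    # Every tensor move in the patched scripts then targets this one
    # device; stand-in tensor in place of length_to_mask(...):
    mask = torch.zeros(4, 128, dtype=torch.bool).to(device)

With something like this in place, running on a CPU-only or Apple Silicon
machine only requires passing a different device string, rather than
editing hardcoded 'cuda' calls as this patch had to.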