Save the model before each evaluate, and update only the yaml after evaluate finishes. This avoids losing the model ckpt when evaluate exits abnormally.
dujing committed Dec 24, 2024
1 parent 690602e commit f2e4019
Showing 3 changed files with 42 additions and 27 deletions.
26 changes: 19 additions & 7 deletions cosyvoice/dataset/dataset_jsondata.py
@@ -109,20 +109,30 @@ def sample(self, data):
 
 class DataList(IterableDataset):
 
-    def __init__(self, lists, utt2wav, utt2text, utt2spk, shuffle=True, partition=True, tts_text=None):
+    def __init__(self, lists, utt2wav, utt2text, utt2spk, shuffle=True,
+                 partition=True, tts_text=None, eval=False):
         self.lists = lists
         self.utt2wav = utt2wav
         self.utt2text = utt2text
         self.utt2spk = utt2spk
         self.tts_text = tts_text  # a list, each prompt will generate all texts in the list
-        self.sampler = DistributedSampler(shuffle, partition)
+        if not eval:
+            self.sampler = DistributedSampler(shuffle, partition)
+        else:
+            self.sampler = None
 
     def set_epoch(self, epoch):
-        self.sampler.set_epoch(epoch)
+        if self.sampler is not None:
+            self.sampler.set_epoch(epoch)
 
     def __iter__(self):
-        sampler_info = self.sampler.update()
-        indexes = self.sampler.sample(self.lists)
+        if self.sampler is not None:
+            sampler_info = self.sampler.update()
+            indexes = self.sampler.sample(self.lists)
+        else:
+            sampler_info = {}
+            indexes = range(len(self.lists))
 
         for index in indexes:
             utt = self.lists[index]
             sample = {}
@@ -151,7 +161,8 @@ def Dataset(json_file,
             gan=False,
             shuffle=True,
             partition=True,
-            tts_file=None):
+            tts_file=None,
+            eval=False):
     """ Construct dataset from arguments
         json_file is like :
@@ -244,7 +255,8 @@ def add_one_data(json_file):
         logging.info(f"read {len(tts_text)} lines from {tts_file}")
 
     dataset = DataList(valid_utt_list, utt2wav, utt2text, utt2spk,
-                       shuffle=shuffle, partition=partition, tts_text=tts_text)
+                       shuffle=shuffle, partition=partition,
+                       tts_text=tts_text, eval=eval)
 
     if gan is True:
         # map partial arg to padding func in gan mode
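Note on usage: with eval=True the DataList skips the DistributedSampler entirely, so every rank iterates the full utterance list in order, which is what a deterministic CV pass wants. A minimal sketch of how a CV dataset might be built with the new flag; the json path is illustrative, and any arguments beyond those visible in this diff are assumed to follow the existing Dataset() signature:

    # Hypothetical call site; 'data/cv.json' is an assumed path.
    cv_dataset = Dataset('data/cv.json',
                         gan=False,
                         shuffle=False,    # no shuffling for evaluation
                         partition=False,  # no sharding across ranks
                         tts_file=None,
                         eval=True)        # sampler is None: set_epoch() becomes a no-op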
18 changes: 10 additions & 8 deletions cosyvoice/utils/executor_online_codec.py
@@ -197,16 +197,19 @@ def train_one_epoc_gan(self, model, optimizer, scheduler, optimizer_d, scheduler
 
     def cv(self, model, cv_data_loader, writer, info_dict, on_batch_end=True, codec_model=None, spkemb_model=None):
         ''' Cross validation on
         '''
-        logging.info('Epoch {} Step {} on_batch_end {} CV rank {}'.format(self.epoch, self.step + 1, on_batch_end, self.rank))
+        logging.info('Epoch {} Step {} on_batch_end {} CV rank {}'.format(self.epoch, self.step, on_batch_end, self.rank))
         model.eval()
+        info_dict["tag"] = "CV"
+        info_dict["step"] = self.step
+        info_dict["epoch"] = self.epoch
+        info_dict["data_idx"] = self.data_idx
+        model_name = 'epoch_{}_whole'.format(self.epoch) if on_batch_end else 'epoch_{}_step_{}'.format(self.epoch, self.step)
+        # save the model first, so that if a batch evaluation fails the ckpt is not lost
+        save_model(model, model_name, info_dict, only_yaml=False)
 
         total_num_utts, total_loss_dict = 0, {}  # avoid division by 0
         for batch_idx, batch_dict in enumerate(cv_data_loader):
-            info_dict["tag"] = "CV"
-            info_dict["step"] = self.step
-            info_dict["epoch"] = self.epoch
             info_dict["batch_idx"] = batch_idx
-            info_dict["data_idx"] = self.data_idx
 
             num_utts = len(batch_dict["utts"])
             total_num_utts += num_utts
@@ -229,5 +232,4 @@ def cv(self, model, cv_data_loader, writer, info_dict, on_batch_end=True, codec_
                 total_loss_dict[k] = sum(v) / total_num_utts
         info_dict['loss_dict'] = total_loss_dict
         log_per_save(writer, info_dict)
-        model_name = 'epoch_{}_whole'.format(self.epoch) if on_batch_end else 'epoch_{}_step_{}'.format(self.epoch, self.step)
-        save_model(model, model_name, info_dict)
+        save_model(model, model_name, info_dict, only_yaml=True)
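The net effect is that cv() now brackets the evaluation loop with two save_model calls. A simplified sketch of the pattern, paraphrased from the diff above rather than quoted verbatim:

    model_name = 'epoch_{}_whole'.format(self.epoch) if on_batch_end \
        else 'epoch_{}_step_{}'.format(self.epoch, self.step)
    save_model(model, model_name, info_dict, only_yaml=False)  # 1) persist the .pt up front
    for batch_idx, batch_dict in enumerate(cv_data_loader):    # 2) evaluation may crash here,
        ...                                                    #    but the .pt is already on disk
    info_dict['loss_dict'] = total_loss_dict
    save_model(model, model_name, info_dict, only_yaml=True)   # 3) refresh only the .yaml with CV losses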
25 changes: 13 additions & 12 deletions cosyvoice/utils/train_utils.py
@@ -198,19 +198,20 @@ def init_summarywriter(args):
     return writer
 
 
-def save_model(model, model_name, info_dict):
+def save_model(model, model_name, info_dict, only_yaml=False):
     rank = int(os.environ.get('RANK', 0))
     model_dir = info_dict["model_dir"]
     save_model_path = os.path.join(model_dir, '{}.pt'.format(model_name))
 
-    if info_dict["train_engine"] == "torch_ddp":
-        if rank == 0:
-            torch.save(model.module.state_dict(), save_model_path)
-    else:
-        with torch.no_grad():
-            model.save_checkpoint(save_dir=model_dir,
-                                  tag=model_name,
-                                  client_state=info_dict)
+    if not only_yaml:
+        if info_dict["train_engine"] == "torch_ddp":
+            if rank == 0:
+                torch.save(model.module.state_dict(), save_model_path)
+        else:
+            with torch.no_grad():
+                model.save_checkpoint(save_dir=model_dir,
+                                      tag=model_name,
+                                      client_state=info_dict)
     if rank == 0:
         info_path = re.sub('.pt$', '.yaml', save_model_path)
         info_dict['save_time'] = datetime.datetime.now().strftime('%d/%m/%Y %H:%M:%S')
@@ -341,13 +342,13 @@ def log_per_save(writer, info_dict):
     rank = int(os.environ.get('RANK', 0))
     logging.info(
         'Epoch {} Step {} CV info lr {} {} rank {}'.format(
-            epoch, step + 1, lr, rank, ' '.join(['{}_{}'.format(k, v) for k, v in loss_dict.items()])))
+            epoch, step, lr, rank, ' '.join(['{}_{}'.format(k, v) for k, v in loss_dict.items()])))
 
     if writer is not None:
         for k in ['epoch', 'lr']:
-            writer.add_scalar('{}/{}'.format(tag, k), info_dict[k], step + 1)
+            writer.add_scalar('{}/{}'.format(tag, k), info_dict[k], step)
         for k, v in loss_dict.items():
-            writer.add_scalar('{}/{}'.format(tag, k), v, step + 1)
+            writer.add_scalar('{}/{}'.format(tag, k), v, step)
 
 def init_kaldi_dataset(args, configs, gan, train_data_indexes):
     data_pipeline = configs['data_pipeline_gan'] if gan is True else configs['data_pipeline']
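For clarity, the two call sites created by this commit exercise save_model in its two modes. A sketch of the expected file-system effect; the epoch/step values here are made up for illustration:

    # only_yaml=False: writes epoch_3_step_500.pt (rank 0 under torch_ddp,
    # or a DeepSpeed checkpoint otherwise) plus epoch_3_step_500.yaml
    save_model(model, 'epoch_3_step_500', info_dict, only_yaml=False)

    # only_yaml=True: skips the checkpoint and rewrites only epoch_3_step_500.yaml,
    # now carrying a fresh save_time and the CV loss_dict
    save_model(model, 'epoch_3_step_500', info_dict, only_yaml=True)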

