Skip to content

Commit

Permalink
添加文本转写,方便测试的时候,直接从参考音频获取文本信息。
Browse files Browse the repository at this point in the history
  • Loading branch information
dujing committed Nov 27, 2024
1 parent 957eda5 commit 44e8e13
Show file tree
Hide file tree
Showing 12 changed files with 612 additions and 15 deletions.
27 changes: 16 additions & 11 deletions cosyvoice/dataset/dataset_kaldidata.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,8 @@ def __iter__(self):
sample = {}
sample['utt'] = utt
sample['wav'] = self.utt2wav[utt]
sample['text'] = self.utt2text[utt]
if utt in self.utt2text:
sample['text'] = self.utt2text[utt]
if utt in self.utt2spk:
sample['spk'] = self.utt2spk[utt]
else:
Expand Down Expand Up @@ -177,8 +178,8 @@ def Dataset(data_dir,

def add_one_data(data_dir):
logging.info(f"Loading data: {data_dir}")
assert os.path.exists(f"{data_dir}/wav.scp") \
and os.path.exists(f"{data_dir}/text")
assert os.path.exists(f"{data_dir}/wav.scp") # \
# and os.path.exists(f"{data_dir}/text")
# and os.path.exists(f"{data_dir}/utt2spk")

with open(f"{data_dir}/wav.scp", 'r', encoding='utf-8') as f_scp:
Expand All @@ -189,13 +190,14 @@ def add_one_data(data_dir):
utt, wav = line[0], line[1]
utt2wav[utt] = wav

with open(f"{data_dir}/text", 'r', encoding='utf-8') as f_text:
for line in f_text:
line = line.strip().split(maxsplit=1)
if len(line) != 2:
continue
utt, text = line[0], line[1]
utt2text[utt] = text
if os.path.exists(f"{data_dir}/text"):
with open(f"{data_dir}/text", 'r', encoding='utf-8') as f_text:
for line in f_text:
line = line.strip().split(maxsplit=1)
if len(line) != 2:
continue
utt, text = line[0], line[1]
utt2text[utt] = text

if os.path.exists(f"{data_dir}/utt2spk"):
with open(f"{data_dir}/utt2spk", 'r', encoding='utf-8') as f_spk:
Expand All @@ -214,7 +216,10 @@ def add_one_data(data_dir):
else:
add_one_data(data_dir)

valid_utt_list = list(set(utt2wav.keys()) & set(utt2text.keys()))
valid_utt_list = list(utt2wav.keys())
if len(utt2text) != 0:
valid_utt_list = list(set(utt2wav.keys()) & set(utt2text.keys()))
logging.info(f"Total utts: {len(valid_utt_list)}")

tts_text = None
if mode=="inference" and os.path.exists(tts_file):
Expand Down
21 changes: 20 additions & 1 deletion cosyvoice/dataset/processor_kaldidata.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,24 @@ def filter(data,
continue
yield sample

def transcribe(data, get_transcriber, mode='inference'):
    """Fill in missing transcripts by running ASR on the reference audio.

    For each sample that lacks a 'text' key, resample the speech to 16 kHz
    (the rate the transcriber expects) and transcribe it. Samples that
    already carry 'text' pass through unchanged.

    Args:
        data: Iterable[{utt, speech, sample_rate, ...}]
        get_transcriber: zero-arg factory returning an object with a
            ``transcribe(speech_or_path=...)`` method. Called lazily at most
            once, so the (expensive) ASR model is only loaded when some
            sample actually needs transcription — the original called it
            once per untranscribed sample.
        mode: unused; kept for pipeline-signature consistency.

    Returns:
        Iterable[{..., text}]
    """
    transcriber = None  # lazily created: model loading is costly
    transcribe_sr = 16000  # fixed input rate of the ASR model
    for sample in data:
        assert 'sample_rate' in sample
        assert 'speech' in sample

        if 'text' not in sample:
            if transcriber is None:
                transcriber = get_transcriber()
            speech = sample['speech']
            if sample['sample_rate'] != transcribe_sr:
                speech = torchaudio.transforms.Resample(
                    orig_freq=sample['sample_rate'], new_freq=transcribe_sr)(speech)
            # first channel only; the transcriber expects a 1-D waveform
            input = speech[0]
            sample['text'] = transcriber.transcribe(speech_or_path=input)
            logging.info(f"prompt text: {sample['text']}")

        yield sample


def resample(data, resample_rate=24000, min_sample_rate=16000, mode='train'):
""" Resample data.
Expand Down Expand Up @@ -184,7 +202,8 @@ def truncate(data, truncate_length=24576, mode='train'):
start = random.randint(0, waveform.shape[1] - truncate_length)
waveform = waveform[:, start: start + truncate_length]
else:
waveform = torch.concat([waveform, torch.zeros(1, truncate_length - waveform.shape[1])], dim=1)
if mode == 'train':
waveform = torch.concat([waveform, torch.zeros(1, truncate_length - waveform.shape[1])], dim=1)
sample['speech'] = waveform
yield sample

Expand Down
29 changes: 29 additions & 0 deletions cosyvoice/dataset/transcriber.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from funasr import AutoModel
from funasr.utils.postprocess_utils import rich_transcription_postprocess

class Transcriber:
    """ASR wrapper around FunASR's SenseVoice model with VAD segmentation."""

    def __init__(self, model="iic/SenseVoiceSmall", device="cuda:0"):
        """
        Args:
            model: FunASR model name or local path. NOTE: previously this
                parameter was silently ignored and the hard-coded default
                was always used — it is now honored.
            device: device string the model runs on, e.g. "cuda:0" or "cpu".
        """
        self.model = AutoModel(
            model=model,  # bug fix: was hard-coded to "iic/SenseVoiceSmall"
            trust_remote_code=False,
            remote_code="model.py",
            vad_model="fsmn-vad",
            vad_kwargs={"max_single_segment_time": 15000},  # max VAD segment, ms
            device=device)

    def transcribe(self, speech_or_path):
        """Run ASR on a waveform or an audio file path.

        Args:
            speech_or_path: 1-D waveform (16 kHz) or path to an audio file.

        Returns:
            str: post-processed transcript (punctuation / inverse text
            normalization applied by FunASR's rich-transcription helper).
        """
        res = self.model.generate(
            input=speech_or_path,
            cache={},
            language="auto",  # auto-detect: "zh", "en", "yue", "ja", "ko", "nospeech"
            use_itn=True,
            batch_size_s=60,
            merge_vad=True,  # merge short VAD segments before decoding
            merge_length_s=15,
        )
        text = rich_transcription_postprocess(res[0]["text"])
        return text

def get_transcriber():
    """Factory used by the data pipeline: build a Transcriber with defaults."""
    transcriber = Transcriber()
    return transcriber
1 change: 1 addition & 0 deletions cosyvoice/utils/executor_online_codec.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
# 2024 Alibaba Inc (authors: Xiang Lyu)
# Jing Du
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
10 changes: 8 additions & 2 deletions examples/tts_vc/cosyvoice/conf/cosyvoice_phoneme.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,11 @@ batch: !name:cosyvoice.dataset.processor_kaldidata.batch
max_frames_in_batch: 20000 # 100 frame per second
padding: !name:cosyvoice.dataset.processor_kaldidata.padding
use_spk_embedding: False # change to True during sft

truncate: !name:cosyvoice.dataset.processor_kaldidata.truncate
truncate_length: 360000 # 24k, 15second for inference
get_transcriber: !name:cosyvoice.dataset.transcriber.get_transcriber
transcribe: !name:cosyvoice.dataset.processor_kaldidata.transcribe
get_transcriber: !ref <get_transcriber>
# dataset processor pipeline
data_pipeline: [
!ref <tokenize>,
Expand All @@ -197,9 +201,11 @@ data_pipeline: [
]

infer_data_pipeline: [
!ref <tokenize>,
!ref <filter>,
!ref <resample>,
!ref <truncate>,
!ref <transcribe>,
!ref <tokenize>,
!ref <compute_fbank>,
!ref <batch>,
!ref <padding>,
Expand Down
Loading

0 comments on commit 44e8e13

Please sign in to comment.