forked from fishaudio/Bert-VITS2
-
Notifications
You must be signed in to change notification settings - Fork 8
/
webui_preprocess.py
177 lines (161 loc) · 8.01 KB
/
webui_preprocess.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
import gradio as gr
import webbrowser
import os
import json
import subprocess
import shutil
def get_path(data_dir):
start_path = os.path.join("./data", data_dir)
lbl_path = os.path.join(start_path, "esd.list")
train_path = os.path.join(start_path, "train.list")
val_path = os.path.join(start_path, "val.list")
config_path = os.path.join(start_path, "configs", "config.json")
return start_path, lbl_path, train_path, val_path, config_path
def generate_config(data_dir, batch_size):
assert data_dir != "", "数据集名称不能为空"
start_path, _, train_path, val_path, config_path = get_path(data_dir)
if os.path.isfile(config_path):
config = json.load(open(config_path))
else:
config = json.load(open("configs/config.json"))
config["data"]["training_files"] = train_path
config["data"]["validation_files"] = val_path
config["train"]["batch_size"] = batch_size
out_path = os.path.join(start_path, "configs")
if not os.path.isdir(out_path):
os.mkdir(out_path)
model_path = os.path.join(start_path, "models")
if not os.path.isdir(model_path):
os.mkdir(model_path)
with open(config_path, "w", encoding="utf-8") as f:
json.dump(config, f, indent=4)
if not os.path.exists("config.yml"):
shutil.copy(src="default_config.yml", dst="config.yml")
return "配置文件生成完成"
def resample(data_dir):
assert data_dir != "", "数据集名称不能为空"
start_path, _, _, _, config_path = get_path(data_dir)
in_dir = os.path.join(start_path, "raw")
out_dir = os.path.join(start_path, "wavs")
subprocess.run(
f"python resample.py "
f"--sr 44100 "
f"--in_dir {in_dir} "
f"--out_dir {out_dir} ",
shell=True,
)
return "音频文件预处理完成"
def preprocess_text(data_dir):
assert data_dir != "", "数据集名称不能为空"
start_path, lbl_path, train_path, val_path, config_path = get_path(data_dir)
lines = open(lbl_path, "r", encoding="utf-8").readlines()
with open(lbl_path, "w", encoding="utf-8") as f:
for line in lines:
path, spk, language, text = line.strip().split("|")
path = os.path.join(start_path, "wavs", os.path.basename(path))
f.writelines(f"{path}|{spk}|{language}|{text}\n")
subprocess.run(
f"python preprocess_text.py "
f"--transcription-path {lbl_path} "
f"--train-path {train_path} "
f"--val-path {val_path} "
f"--config-path {config_path}",
shell=True,
)
return "标签文件预处理完成"
def bert_gen(data_dir):
assert data_dir != "", "数据集名称不能为空"
_, _, _, _, config_path = get_path(data_dir)
subprocess.run(
f"python bert_gen.py " f"--config {config_path}",
shell=True,
)
return "BERT 特征文件生成完成"
def clap_gen(data_dir):
assert data_dir != "", "数据集名称不能为空"
_, _, _, _, config_path = get_path(data_dir)
subprocess.run(
f"python clap_gen.py " f"--config {config_path}",
shell=True,
)
return "CLAP 特征文件生成完成"
if __name__ == "__main__":
with gr.Blocks() as app:
with gr.Row():
with gr.Column():
_ = gr.Markdown(
value="# Bert-VITS2 数据预处理\n"
"## 预先准备:\n"
"下载 BERT 和 CLAP 模型:\n"
"- [中文 RoBERTa](https://huggingface.co/hfl/chinese-roberta-wwm-ext-large)\n"
"- [日文 DeBERTa](https://huggingface.co/ku-nlp/deberta-v2-large-japanese-char-wwm)\n"
"- [英文 DeBERTa](https://huggingface.co/microsoft/deberta-v3-large)\n"
"- [CLAP](https://huggingface.co/laion/clap-htsat-fused)\n"
"\n"
"将 BERT 模型放置到 `bert` 文件夹下,CLAP 模型放置到 `emotional` 文件夹下,覆盖同名文件夹。\n"
"\n"
"数据准备:\n"
"将数据放置在 data 文件夹下,按照如下结构组织:\n"
"\n"
"```\n"
"├── data\n"
"│ ├── {你的数据集名称}\n"
"│ │ ├── esd.list\n"
"│ │ ├── raw\n"
"│ │ │ ├── ****.wav\n"
"│ │ │ ├── ****.wav\n"
"│ │ │ ├── ...\n"
"```\n"
"\n"
"其中,`raw` 文件夹下保存所有的音频文件,`esd.list` 文件为标签文本,格式为\n"
"\n"
"```\n"
"****.wav|{说话人名}|{语言 ID}|{标签文本}\n"
"```\n"
"\n"
"例如:\n"
"```\n"
"vo_ABDLQ001_1_paimon_02.wav|派蒙|ZH|没什么没什么,只是平时他总是站在这里,有点奇怪而已。\n"
"noa_501_0001.wav|NOA|JP|そうだね、油断しないのはとても大事なことだと思う\n"
"Albedo_vo_ABDLQ002_4_albedo_01.wav|Albedo|EN|Who are you? Why did you alarm them?\n"
"...\n"
"```\n"
)
data_dir = gr.Textbox(
label="数据集名称",
placeholder="你放置在 data 文件夹下的数据集所在文件夹的名称,如 data/genshin 则填 genshin",
)
info = gr.Textbox(label="状态信息")
_ = gr.Markdown(value="## 第一步:生成配置文件")
with gr.Row():
batch_size = gr.Slider(
label="批大小(Batch size):24 GB 显存可用 12",
value=8,
minimum=1,
maximum=64,
step=1,
)
generate_config_btn = gr.Button(value="执行", variant="primary")
_ = gr.Markdown(value="## 第二步:预处理音频文件")
resample_btn = gr.Button(value="执行", variant="primary")
_ = gr.Markdown(value="## 第三步:预处理标签文件")
preprocess_text_btn = gr.Button(value="执行", variant="primary")
_ = gr.Markdown(value="## 第四步:生成 BERT 特征文件")
bert_gen_btn = gr.Button(value="执行", variant="primary")
_ = gr.Markdown(value="## 第五步:生成 CLAP 特征文件")
clap_gen_btn = gr.Button(value="执行", variant="primary")
_ = gr.Markdown(
value="## 训练模型及部署:\n"
"修改根目录下的 `config.yml` 中 `dataset_path` 一项为 `data/{你的数据集名称}`\n"
"- 训练:将[预训练模型文件](https://openi.pcl.ac.cn/Stardust_minus/Bert-VITS2/modelmanage/show_model)(`D_0.pth`、`DUR_0.pth` 和 `G_0.pth`)放到 `data/{你的数据集名称}/models` 文件夹下,执行 `torchrun --nproc_per_node=1 train_ms.py` 命令(多卡运行可参考 `run_MnodesAndMgpus.sh` 中的命令。\n"
"- 部署:修改根目录下的 `config.yml` 中 `webui` 下 `model` 一项为 `models/{权重文件名}.pth` (如 G_10000.pth),然后执行 `python webui.py`"
)
generate_config_btn.click(
generate_config, inputs=[data_dir, batch_size], outputs=[info]
)
resample_btn.click(resample, inputs=[data_dir], outputs=[info])
preprocess_text_btn.click(preprocess_text, inputs=[data_dir], outputs=[info])
bert_gen_btn.click(bert_gen, inputs=[data_dir], outputs=[info])
clap_gen_btn.click(clap_gen, inputs=[data_dir], outputs=[info])
webbrowser.open(f"http://127.0.0.1:7860")
app.launch(share=False, server_port=7860)