[feature] add a frontend module in wespeaker and support wavlm (#344)
* [feature] add a frontend module in wespeaker and support wavlm

* update .gitignore

* update wavlm configs

* update wespeaker/frontend/__init__.py

* [fix] remove trailing whitespaces

* [fix] fix lint errors

* [fix] fix lint errors

* [fix] fix lint errors

* [fix] fix spelling mistakes

* update run.sh

* update wavlm configs and add run_wavlm.sh

* update README.md
JiJiJiang authored Aug 19, 2024
1 parent 1df7626 commit 655039e
Showing 21 changed files with 738 additions and 52 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -42,3 +42,4 @@ tensorboard
*.onnx
external_tools
pretrained_models
s3prl_hub
1 change: 1 addition & 0 deletions README.md
@@ -60,6 +60,7 @@ pre-commit install # for clean and tidy code
```

## 🔥 News
* 2024.08.18: Support using SSL pre-trained models as the frontend. The [WavLM recipe](https://github.com/wenet-e2e/wespeaker/blob/master/examples/voxceleb/v2/run_wavlm.sh) is also provided, see [#344](https://github.com/wenet-e2e/wespeaker/pull/344).
* 2024.05.15: Add support for [quality-aware score calibration](https://arxiv.org/pdf/2211.00815), see [#320](https://github.com/wenet-e2e/wespeaker/pull/320).
* 2024.04.25: Add support for the gemini-dfresnet model, see [#291](https://github.com/wenet-e2e/wespeaker/pull/291).
* 2024.04.23: Support MNN inference engine in runtime, see [#310](https://github.com/wenet-e2e/wespeaker/pull/310).
22 changes: 22 additions & 0 deletions examples/voxceleb/v2/README.md
@@ -70,3 +70,25 @@ The results on ResNet34 (large margin, no asnorm) are:
|:--------------:|:------------:|:------------:|:------------:|
| PLDA | 1.207 | 1.350 | 2.528 |


## WavLM results

* Pre-trained frontend: the [WavLM](https://arxiv.org/abs/2110.13900) Large model; multilayer features are used
* Speaker model: ECAPA_TDNN_GLOB_c512-ASTP-emb192
* Training strategy: Frozen => Joint ft (fine-tuning) => Joint lmft (large-margin fine-tuning)

```bash
bash run_wavlm.sh --stage 3 --stop_stage 9
```

| Training strategy | AS-Norm | QMF | vox1-O-clean | vox1-E-clean | vox1-H-clean |
|:------------------|:-------:|:---:|:------------:|:------------:|:------------:|
| Frozen | × | × | 0.595 | 0.719 | 1.501 |
| | √ | × | 0.548 | 0.656 | 1.355 |
| | √ | √ | 0.489 | 0.619 | 1.224 |
| Frozen => Joint ft | × | × | 0.542 | 0.635 | 1.355 |
| | √ | × | 0.521 | 0.594 | 1.237 |
| | √ | √ | 0.494 | 0.576 | 1.205 |
| Frozen => Joint ft => Joint lmft | × | × | 0.521 | 0.626 | 1.344 |
| | √ | × | 0.495 | 0.588 | 1.247 |
| | √ | √ | **0.415** | **0.551** | **1.118** |
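
For context, AS-Norm here is adaptive score normalization, configured in the recipe via `score_norm_method="asnorm"` and `top_n=300`. A minimal sketch of what it computes follows; the function name and array handling are illustrative, not wespeaker's actual API:

```python
import numpy as np

def asnorm(score, enroll_cohort_scores, test_cohort_scores, top_n=300):
    """Adaptive score normalization (AS-Norm), illustrative sketch.

    score: raw cosine score of an (enroll, test) trial.
    *_cohort_scores: each side scored against a cohort of imposter
    embeddings; only the top_n highest scores form the statistics.
    """
    top_e = np.sort(enroll_cohort_scores)[-top_n:]
    top_t = np.sort(test_cohort_scores)[-top_n:]
    # Z-normalize against each side's adaptive cohort, then average.
    return 0.5 * ((score - top_e.mean()) / top_e.std()
                  + (score - top_t.mean()) / top_t.std())
```

In the table above, QMF (quality-aware score calibration, see [#320](https://github.com/wenet-e2e/wespeaker/pull/320)) is only evaluated on top of AS-Normed scores.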
1 change: 1 addition & 0 deletions examples/voxceleb/v2/conf/ecapa_tdnn.yaml
@@ -32,6 +32,7 @@ dataset_args:
speed_perturb: True
num_frms: 200
aug_prob: 0.6 # prob to add reverb & noise aug per sample
frontend: "fbank" # fbank, s3prl
fbank_args:
num_mel_bins: 80
frame_shift: 10
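
The new `frontend` key selects between classic fbank features and an s3prl upstream. A minimal sketch of the dispatch, with hypothetical helper names (wespeaker's actual implementation lives in the new `wespeaker/frontend` module and may differ):

```python
import torch
import torchaudio.compliance.kaldi as kaldi

def compute_fbank(waveform: torch.Tensor, dataset_args: dict) -> torch.Tensor:
    """Illustrative fbank path: waveform of shape (1, num_samples) ->
    features of shape (num_frames, num_mel_bins)."""
    fbank_args = dataset_args["fbank_args"]
    return kaldi.fbank(waveform,
                       num_mel_bins=fbank_args["num_mel_bins"],
                       frame_shift=fbank_args["frame_shift"],
                       sample_frequency=dataset_args["resample_rate"])

def compute_features(waveform, dataset_args):
    """Dispatch on the new 'frontend' key."""
    if dataset_args["frontend"] == "fbank":
        return compute_fbank(waveform, dataset_args)
    if dataset_args["frontend"] == "s3prl":
        # Pass the raw waveform through; the s3prl upstream runs inside
        # the model (see the S3prlFrontend sketch further down).
        return waveform
    raise ValueError(f"unknown frontend: {dataset_args['frontend']}")
```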
91 changes: 91 additions & 0 deletions examples/voxceleb/v2/conf/ecapa_tdnn_WavLM_frozen.yaml
@@ -0,0 +1,91 @@
### train configuration

exp_dir: exp/ECAPA_TDNN_GLOB_c512-ASTP-emb192-WavLM_Large_frozen-num_frms150-aug0.6-spTrue-saFalse-ArcMargin_intertopk_subcenter-SGD-epoch150
gpus: "[0,1,2,3,4,5,6,7]"
num_avg: 10
enable_amp: True # whether to enable automatic mixed precision training

seed: 42
num_epochs: 150
save_epoch_interval: 5 # save model every 5 epochs
log_batch_interval: 100 # log every 100 batches

dataloader_args:
batch_size: 256
num_workers: 16
pin_memory: False
prefetch_factor: 16
drop_last: True

dataset_args:
# the number of samples traversed within one epoch; if set to 0,
# the number of utterances in the dataset is used as sample_num_per_epoch.
sample_num_per_epoch: 0
shuffle: True
shuffle_args:
shuffle_size: 2500
filter: True
filter_args:
min_num_frames: 50
max_num_frames: 400
resample_rate: 16000
speed_perturb: True
num_frms: 150
aug_prob: 0.6 # prob to add reverb & noise aug per sample
frontend: "s3prl" # fbank, s3prl
s3prl_args:
upstream_args:
name: "wavlm_large"
download_dir: ./s3prl_hub
multilayer_feature: True
layer: -1
frozen: True
frame_shift: 20
frame_length: 20
cmvn: True
cmvn_args:
norm_mean: True
norm_var: False
spec_aug: False
spec_aug_args:
num_t_mask: 1
num_f_mask: 1
max_t: 10
max_f: 8
prob: 0.6

model: ECAPA_TDNN_GLOB_c512 # ECAPA_TDNN_GLOB_c512, ECAPA_TDNN_GLOB_c1024
model_init: null
model_args:
feat_dim: -1 # equals the output_size of the frontend (will be initialized before training)
embed_dim: 192
pooling_func: "ASTP" # the default pooling_func in ECAPA_TDNN is ASTP
projection_args:
project_type: "arc_margin_intertopk_subcenter" # add_margin, arc_margin, sphere, softmax, arc_margin_intertopk_subcenter
scale: 32.0
easy_margin: False

margin_scheduler: MarginScheduler
margin_update:
initial_margin: 0.0
final_margin: 0.2
increase_start_epoch: 20
fix_start_epoch: 40
update_margin: True
increase_type: "exp" # exp, linear

loss: CrossEntropyLoss
loss_args: {}

optimizer: SGD
optimizer_args:
momentum: 0.9
nesterov: True
weight_decay: 0.0001

scheduler: ExponentialDecrease
scheduler_args:
initial_lr: 0.1
final_lr: 0.00001
warm_up_epoch: 6
warm_from_zero: True
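
A sketch of how the `s3prl_args` block above can map onto s3prl's upstream interface. `S3PRLUpstream` and `Featurizer` are s3prl's documented entry points, but the wrapper class, its handling of `frozen`, and the `output_size`/`hidden_sizes` attribute usage are assumptions here, not necessarily wespeaker's code (which also handles `download_dir`, omitted below):

```python
import torch
import torch.nn as nn
from s3prl.nn import S3PRLUpstream, Featurizer  # requires `pip install s3prl`

class S3prlFrontend(nn.Module):
    """Illustrative wrapper: raw waveforms -> WavLM features."""

    def __init__(self, name="wavlm_large", frozen=True, multilayer_feature=True):
        super().__init__()
        self.upstream = S3PRLUpstream(name)
        # With multilayer features, Featurizer learns a weighted sum over
        # all transformer layers; its weights stay trainable even when the
        # upstream itself is frozen.
        self.featurizer = Featurizer(self.upstream) if multilayer_feature else None
        self.frozen = frozen
        if frozen:
            for p in self.upstream.parameters():
                p.requires_grad_(False)

    def output_size(self):
        # This is why the config can leave feat_dim at -1: the speaker
        # model reads its input dim from the frontend before training.
        return (self.featurizer.output_size if self.featurizer
                else self.upstream.hidden_sizes[-1])

    def forward(self, wavs, wavs_len):
        ctx = torch.no_grad() if self.frozen else torch.enable_grad()
        with ctx:
            hs, hs_len = self.upstream(wavs, wavs_len)
        if self.featurizer is not None:
            return self.featurizer(hs, hs_len)
        return hs[-1], hs_len[-1]
```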
92 changes: 92 additions & 0 deletions examples/voxceleb/v2/conf/ecapa_tdnn_WavLM_joint_ft.yaml
@@ -0,0 +1,92 @@
### train configuration

exp_dir: exp/ECAPA_TDNN_GLOB_c512-ASTP-emb192-WavLM_Large_joint_ft-num_frms150-aug0.6-spTrue-saFalse-ArcMargin_intertopk_subcenter-SGD-epoch20
gpus: "[0,1,2,3,4,5,6,7]"
num_avg: 3
enable_amp: True # whether to enable automatic mixed precision training
do_lm: False

seed: 42
num_epochs: 20
save_epoch_interval: 1 # save model every epoch
log_batch_interval: 100 # log every 100 batches

dataloader_args:
batch_size: 64
num_workers: 8
pin_memory: False
prefetch_factor: 8
drop_last: True

dataset_args:
# the number of samples traversed within one epoch; if set to 0,
# the number of utterances in the dataset is used as sample_num_per_epoch.
sample_num_per_epoch: 0
shuffle: True
shuffle_args:
shuffle_size: 2500
filter: True
filter_args:
min_num_frames: 50
max_num_frames: 400
resample_rate: 16000
speed_perturb: True
num_frms: 150
aug_prob: 0.6 # prob to add reverb & noise aug per sample
frontend: "s3prl" # fbank, s3prl
s3prl_args:
upstream_args:
name: "wavlm_large"
download_dir: ./s3prl_hub
multilayer_feature: True
layer: -1
frozen: False
frame_shift: 20
frame_length: 20
cmvn: True
cmvn_args:
norm_mean: True
norm_var: False
spec_aug: False
spec_aug_args:
num_t_mask: 1
num_f_mask: 1
max_t: 10
max_f: 8
prob: 0.6

model: ECAPA_TDNN_GLOB_c512 # ECAPA_TDNN_GLOB_c512, ECAPA_TDNN_GLOB_c1024
model_init: null
model_args:
feat_dim: -1 # equals the output_size of the frontend (will be initialized before training)
embed_dim: 192
pooling_func: "ASTP" # the default pooling_func in ECAPA_TDNN is ASTP
projection_args:
project_type: "arc_margin_intertopk_subcenter" # add_margin, arc_margin, sphere, softmax, arc_margin_intertopk_subcenter
scale: 32.0
easy_margin: False

margin_scheduler: MarginScheduler
margin_update:
initial_margin: 0.2
final_margin: 0.2
increase_start_epoch: 1
fix_start_epoch: 1
update_margin: True
increase_type: "exp" # exp, linear

loss: CrossEntropyLoss
loss_args: {}

optimizer: SGD
optimizer_args:
momentum: 0.9
nesterov: True
weight_decay: 0.0001

scheduler: ExponentialDecrease
scheduler_args:
initial_lr: 0.001
final_lr: 0.00025
warm_up_epoch: 1
warm_from_zero: True
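
Compared with the frozen stage, this config flips `frozen` to False and drops the initial learning rate from 0.1 to 0.001, so the WavLM upstream is updated together with the ECAPA-TDNN. A minimal sketch of what the flag implies for optimizer construction (illustrative, reusing the hypothetical S3prlFrontend above):

```python
import torch

def build_optimizer(frontend, speaker_model, frozen, lr,
                    momentum=0.9, weight_decay=1e-4):
    """Illustrative: include upstream params only for joint fine-tuning."""
    params = list(speaker_model.parameters())
    if not frozen:
        for p in frontend.upstream.parameters():
            p.requires_grad_(True)  # unfreeze WavLM for joint ft/lmft
        params += list(frontend.upstream.parameters())
    return torch.optim.SGD(params, lr=lr, momentum=momentum,
                           nesterov=True, weight_decay=weight_decay)
```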
92 changes: 92 additions & 0 deletions examples/voxceleb/v2/conf/ecapa_tdnn_WavLM_joint_lmft.yaml
@@ -0,0 +1,92 @@
### train configuration

exp_dir: exp/ECAPA_TDNN_GLOB_c512-ASTP-emb192-WavLM_Large_joint_lmft-num_frms300-aug0.6-spTrue-saFalse-ArcMargin_intertopk_subcenter-SGD-epoch10
gpus: "[0,1,2,3,4,5,6,7]"
num_avg: 1
enable_amp: True # whether to enable automatic mixed precision training
do_lm: True

seed: 42
num_epochs: 20
save_epoch_interval: 1 # save model every epoch
log_batch_interval: 100 # log every 100 batches

dataloader_args:
batch_size: 32
num_workers: 8
pin_memory: False
prefetch_factor: 8
drop_last: True

dataset_args:
# the number of samples traversed within one epoch; if set to 0,
# the number of utterances in the dataset is used as sample_num_per_epoch.
sample_num_per_epoch: 0
shuffle: True
shuffle_args:
shuffle_size: 2500
filter: True
filter_args:
min_num_frames: 50
max_num_frames: 400
resample_rate: 16000
speed_perturb: True
num_frms: 300
aug_prob: 0.6 # prob to add reverb & noise aug per sample
frontend: "s3prl" # fbank, s3prl
s3prl_args:
upstream_args:
name: "wavlm_large"
download_dir: ./s3prl_hub
multilayer_feature: True
layer: -1
frozen: False
frame_shift: 20
frame_length: 20
cmvn: True
cmvn_args:
norm_mean: True
norm_var: False
spec_aug: False
spec_aug_args:
num_t_mask: 1
num_f_mask: 1
max_t: 10
max_f: 8
prob: 0.6

model: ECAPA_TDNN_GLOB_c512 # ECAPA_TDNN_GLOB_c512, ECAPA_TDNN_GLOB_c1024
model_init: null
model_args:
feat_dim: -1 # equals the output_size of the frontend (will be initialized before training)
embed_dim: 192
pooling_func: "ASTP" # the default pooling_func in ECAPA_TDNN is ASTP
projection_args:
project_type: "arc_margin_intertopk_subcenter" # add_margin, arc_margin, sphere, softmax, arc_margin_intertopk_subcenter
scale: 32.0
easy_margin: False

margin_scheduler: MarginScheduler
margin_update:
initial_margin: 0.5
final_margin: 0.5
increase_start_epoch: 1
fix_start_epoch: 1
update_margin: True
increase_type: "exp" # exp, linear

loss: CrossEntropyLoss
loss_args: {}

optimizer: SGD
optimizer_args:
momentum: 0.9
nesterov: True
weight_decay: 0.0001

scheduler: ExponentialDecrease
scheduler_args:
initial_lr: 0.0001
final_lr: 0.000025
warm_up_epoch: 1
warm_from_zero: True
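
All three stages share the ExponentialDecrease scheduler; only the endpoints differ (0.1 -> 1e-5 frozen, 1e-3 -> 2.5e-4 joint ft, 1e-4 -> 2.5e-5 joint lmft). A sketch of the curve those parameter names suggest — linear warm-up, then exponential interpolation — with the caveat that wespeaker's exact implementation (e.g. per-step rather than per-epoch updates) may differ:

```python
import math

def exponential_decrease_lr(epoch, num_epochs, initial_lr, final_lr,
                            warm_up_epoch=1, warm_from_zero=True):
    """Illustrative per-epoch ExponentialDecrease-style schedule."""
    if epoch < warm_up_epoch and warm_from_zero:
        # Linear warm-up from zero to initial_lr.
        return initial_lr * (epoch + 1) / warm_up_epoch
    # Exponential interpolation: lr(t) = initial * (final/initial) ** (t/T).
    t = (epoch - warm_up_epoch) / max(1, num_epochs - warm_up_epoch)
    return initial_lr * math.pow(final_lr / initial_lr, t)

# e.g. with this config's lmft endpoints:
# exponential_decrease_lr(19, 20, 1e-4, 2.5e-5)  # ~2.7e-5, approaching final_lr
```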
1 change: 1 addition & 0 deletions examples/voxceleb/v2/conf/ecapa_tdnn_lm.yaml
@@ -38,6 +38,7 @@ dataset_args:
speed_perturb: True
num_frms: 600
aug_prob: 0.6 # prob to add reverb & noise aug per sample
frontend: "fbank" # fbank, s3prl
fbank_args:
num_mel_bins: 80
frame_shift: 10
8 changes: 4 additions & 4 deletions examples/voxceleb/v2/run.sh
@@ -11,8 +11,8 @@ stop_stage=-1
data=data
data_type="shard" # shard/raw

config=conf/campplus.yaml
exp_dir=exp/CAMPPlus-TSTP-emb512-fbank80-num_frms200-aug0.6-spTrue-saFalse-ArcMargin-SGD-epoch150
config=conf/resnet.yaml
exp_dir=exp/ResNet34-TSTP-emb256-fbank80-num_frms200-aug0.6-spTrue-saFalse-ArcMargin-SGD-epoch150
gpus="[0,1]"
num_avg=10
checkpoint=
@@ -22,7 +22,7 @@ score_norm_method="asnorm" # asnorm/snorm
top_n=300

# setup for large margin fine-tuning
lm_config=conf/campplus_lm.yaml
lm_config=conf/resnet_lm.yaml

. tools/parse_options.sh || exit 1

@@ -55,7 +55,7 @@ fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
echo "Start training ..."
num_gpus=$(echo $gpus | awk -F ',' '{print NF}')
torchrun --standalone --nnodes=1 --nproc_per_node=$num_gpus \
torchrun --master_addr=localhost --master_port=29401 --nnodes=1 --nproc_per_node=$num_gpus \
wespeaker/bin/train.py --config $config \
--exp_dir ${exp_dir} \
--gpus $gpus \
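
The torchrun change swaps `--standalone` for an explicit rendezvous address, presumably so that several single-node jobs (e.g. run.sh and run_wavlm.sh) can run on one machine without colliding on the default port. A hypothetical second job would only need a different `--master_port`:

```bash
# hypothetical second job on the same machine; only the port differs
torchrun --master_addr=localhost --master_port=29402 \
    --nnodes=1 --nproc_per_node=$num_gpus \
    wespeaker/bin/train.py --config $config --exp_dir ${exp_dir} --gpus $gpus
```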