[feature] add a frontend module in wespeaker and support wavlm #344

Merged · 13 commits · Aug 19, 2024
1 change: 1 addition & 0 deletions .gitignore
@@ -42,3 +42,4 @@ tensorboard
*.onnx
external_tools
pretrained_models
s3prl_hub
1 change: 1 addition & 0 deletions examples/voxceleb/v2/conf/ecapa_tdnn.yaml
@@ -32,6 +32,7 @@ dataset_args:
speed_perturb: True
num_frms: 200
aug_prob: 0.6 # prob to add reverb & noise aug per sample
frontend: "fbank" # fbank, s3prl
fbank_args:
num_mel_bins: 80
frame_shift: 10
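The fbank path keeps the Kaldi-style filterbank features. As a reference for how these fbank_args translate into a feature extractor, here is a minimal sketch using torchaudio's Kaldi-compatible fbank; the frame_length value is torchaudio's default and an assumption here, and the exact call inside wespeaker's frontend may differ:

```python
import torch
import torchaudio.compliance.kaldi as kaldi

# a minimal sketch: 80-dim fbank with a 10 ms frame shift, as in fbank_args above
waveform = torch.randn(1, 16000)  # (channels, samples), 1 s at 16 kHz
feats = kaldi.fbank(waveform,
                    num_mel_bins=80,         # fbank_args.num_mel_bins
                    frame_shift=10.0,        # fbank_args.frame_shift (ms)
                    frame_length=25.0,       # assumed: torchaudio's default (ms)
                    sample_frequency=16000.0,
                    dither=0.0)              # -> (num_frames, 80)
```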
91 changes: 91 additions & 0 deletions examples/voxceleb/v2/conf/ecapa_tdnn_WavLM_frozen.yaml
@@ -0,0 +1,91 @@
### train configuration

exp_dir: exp/ECAPA_TDNN_GLOB_c1024-ASTP-emb192-WavLM_Large_frozen-num_frms150-aug0.6-spTrue-saFalse-ArcMargin_intertopk_subcenter-SGD-epoch150
gpus: "[0,1,2,3,4,5,6,7]"
num_avg: 10
enable_amp: True # whether to enable automatic mixed precision training

seed: 42
num_epochs: 150
save_epoch_interval: 5 # save model every 5 epochs
log_batch_interval: 100 # log every 100 batches

dataloader_args:
batch_size: 256
num_workers: 16
pin_memory: False
prefetch_factor: 16
drop_last: True

dataset_args:
# number of samples traversed within one epoch; if set to 0,
# the number of utterances in the dataset is used as sample_num_per_epoch.
sample_num_per_epoch: 0
shuffle: True
shuffle_args:
shuffle_size: 2500
filter: True
filter_args:
min_num_frames: 50
max_num_frames: 400
resample_rate: 16000
speed_perturb: True
num_frms: 150
aug_prob: 0.6 # prob to add reverb & noise aug per sample
frontend: "s3prl" # fbank, s3prl
s3prl_args:
upstream_args:
name: "wavlm_large"
download_dir: ./s3prl_hub
multilayer_feature: True
layer: -1
frozen: True
frame_shift: 20
frame_length: 20
cmvn: True
cmvn_args:
norm_mean: True
norm_var: False
spec_aug: False
spec_aug_args:
num_t_mask: 1
num_f_mask: 1
max_t: 10
max_f: 8
prob: 0.6

model: ECAPA_TDNN_GLOB_c1024 # ECAPA_TDNN_GLOB_c512, ECAPA_TDNN_GLOB_c1024
model_init: null
model_args:
feat_dim: -1 # equals the output_size of the frontend (will be initialized before training)
embed_dim: 192
pooling_func: "ASTP" # the default pooling_func in ECAPA_TDNN is ASTP
projection_args:
project_type: "arc_margin_intertopk_subcenter" # add_margin, arc_margin, sphere, softmax, arc_margin_intertopk_subcenter
scale: 32.0
easy_margin: False

margin_scheduler: MarginScheduler
margin_update:
initial_margin: 0.0
final_margin: 0.2
increase_start_epoch: 20
fix_start_epoch: 40
update_margin: True
increase_type: "exp" # exp, linear

loss: CrossEntropyLoss
loss_args: {}

optimizer: SGD
optimizer_args:
momentum: 0.9
nesterov: True
weight_decay: 0.0001

scheduler: ExponentialDecrease
scheduler_args:
initial_lr: 0.1
final_lr: 0.00001
warm_up_epoch: 6
warm_from_zero: True
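For orientation, a minimal sketch of what the s3prl frontend configured above does, using s3prl's public S3PRLUpstream and Featurizer API; the actual wrapper in wespeaker/frontend may differ, and download_dir handling is assumed to live in that wrapper:

```python
import torch
from s3prl.nn import S3PRLUpstream, Featurizer

# upstream_args.name: load the WavLM-Large upstream
upstream = S3PRLUpstream("wavlm_large")
# multilayer_feature: True -> learnable weighted sum over all hidden layers
featurizer = Featurizer(upstream)

# frozen: True -> no gradients flow into the upstream
upstream.requires_grad_(False)
upstream.eval()

wavs = torch.randn(4, 32000)              # (B, W), 2 s of 16 kHz audio
wavs_len = torch.LongTensor([32000] * 4)  # (B,)
with torch.no_grad():
    all_hs, all_hs_len = upstream(wavs, wavs_len)  # hidden states of every layer
feats, feats_len = featurizer(all_hs, all_hs_len)  # (B, T, D) frame-level features
```

The featurizer's output size D (1024 for WavLM-Large) is what fills the feat_dim: -1 placeholder in model_args before training.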
92 changes: 92 additions & 0 deletions examples/voxceleb/v2/conf/ecapa_tdnn_WavLM_joint_ft.yaml
@@ -0,0 +1,92 @@
### train configuration

exp_dir: exp/ECAPA_TDNN_GLOB_c1024-ASTP-emb192-WavLM_Large_joint_ft-num_frms150-aug0.6-spTrue-saFalse-ArcMargin_intertopk_subcenter-SGD-epoch20
gpus: "[0,1,2,3,4,5,6,7]"
num_avg: 3
enable_amp: True # whether to enable automatic mixed precision training
do_lm: False

seed: 42
num_epochs: 20
save_epoch_interval: 1 # save model every epoch
log_batch_interval: 100 # log every 100 batches

dataloader_args:
batch_size: 64
num_workers: 8
pin_memory: False
prefetch_factor: 8
drop_last: True

dataset_args:
# number of samples traversed within one epoch; if set to 0,
# the number of utterances in the dataset is used as sample_num_per_epoch.
sample_num_per_epoch: 0
shuffle: True
shuffle_args:
shuffle_size: 2500
filter: True
filter_args:
min_num_frames: 50
max_num_frames: 400
resample_rate: 16000
speed_perturb: True
num_frms: 150
aug_prob: 0.6 # prob to add reverb & noise aug per sample
frontend: "s3prl" # fbank, s3prl
s3prl_args:
upstream_args:
name: "wavlm_large"
download_dir: ./s3prl_hub
multilayer_feature: True
layer: -1
frozen: False
frame_shift: 20
frame_length: 20
cmvn: True
cmvn_args:
norm_mean: True
norm_var: False
spec_aug: False
spec_aug_args:
num_t_mask: 1
num_f_mask: 1
max_t: 10
max_f: 8
prob: 0.6

model: ECAPA_TDNN_GLOB_c1024 # ECAPA_TDNN_GLOB_c512, ECAPA_TDNN_GLOB_c1024
model_init: null
model_args:
feat_dim: -1 # equals the output_size of the frontend (will be initialized before training)
embed_dim: 192
pooling_func: "ASTP" # the default pooling_func in ECAPA_TDNN is ASTP
projection_args:
project_type: "arc_margin_intertopk_subcenter" # add_margin, arc_margin, sphere, softmax, arc_margin_intertopk_subcenter
scale: 32.0
easy_margin: False

margin_scheduler: MarginScheduler
margin_update:
initial_margin: 0.2
final_margin: 0.2
increase_start_epoch: 1
fix_start_epoch: 1
update_margin: True
increase_type: "exp" # exp, linear

loss: CrossEntropyLoss
loss_args: {}

optimizer: SGD
optimizer_args:
momentum: 0.9
nesterov: True
weight_decay: 0.0001

scheduler: ExponentialDecrease
scheduler_args:
initial_lr: 0.0001
final_lr: 0.000025
warm_up_epoch: 1
warm_from_zero: True
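Compared with the frozen recipe, joint fine-tuning runs far fewer epochs at a learning rate three orders of magnitude lower, with a single warm-up epoch. A plausible sketch of the ExponentialDecrease schedule these scheduler_args describe, inferred from the parameter names rather than copied from wespeaker:

```python
import math

def exponential_decrease_lr(step, total_steps, warmup_steps,
                            initial_lr=1e-4, final_lr=2.5e-5,
                            warm_from_zero=True):
    # exponential decay from initial_lr to final_lr over the whole run
    lr = initial_lr * math.exp((step / total_steps) *
                               math.log(final_lr / initial_lr))
    if step < warmup_steps and warm_from_zero:
        # warm_from_zero: True -> ramp linearly from 0 up to the scheduled lr
        lr *= step / warmup_steps
    return lr
```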
92 changes: 92 additions & 0 deletions examples/voxceleb/v2/conf/ecapa_tdnn_WavLM_joint_lmft.yaml
@@ -0,0 +1,92 @@
### train configuration

exp_dir: exp/ECAPA_TDNN_GLOB_c1024-ASTP-emb192-WavLM_Large_joint_lmft-num_frms300-aug0.6-spTrue-saFalse-ArcMargin_intertopk_subcenter-SGD-epoch20
gpus: "[0,1]"
num_avg: 2
enable_amp: True # whether to enable automatic mixed precision training
do_lm: True

seed: 42
num_epochs: 20
save_epoch_interval: 1 # save model every epoch
log_batch_interval: 100 # log every 100 batches

dataloader_args:
batch_size: 32
num_workers: 8
pin_memory: False
prefetch_factor: 8
drop_last: True

dataset_args:
# number of samples traversed within one epoch; if set to 0,
# the number of utterances in the dataset is used as sample_num_per_epoch.
sample_num_per_epoch: 0
shuffle: True
shuffle_args:
shuffle_size: 2500
filter: True
filter_args:
min_num_frames: 50
max_num_frames: 400
resample_rate: 16000
speed_perturb: True
num_frms: 300
aug_prob: 0.6 # prob to add reverb & noise aug per sample
frontend: "s3prl" # fbank, s3prl
s3prl_args:
upstream_args:
name: "wavlm_large"
download_dir: ./s3prl_hub
multilayer_feature: True
layer: -1
frozen: False
frame_shift: 20
frame_length: 20
cmvn: True
cmvn_args:
norm_mean: True
norm_var: False
spec_aug: False
spec_aug_args:
num_t_mask: 1
num_f_mask: 1
max_t: 10
max_f: 8
prob: 0.6

model: ECAPA_TDNN_GLOB_c1024 # ECAPA_TDNN_GLOB_c512, ECAPA_TDNN_GLOB_c1024
model_init: null
model_args:
feat_dim: -1 # equals the output_size of the frontend (will be initialized before training)
embed_dim: 192
pooling_func: "ASTP" # the default pooling_func in ECAPA_TDNN is ASTP
projection_args:
project_type: "arc_margin_intertopk_subcenter" # add_margin, arc_margin, sphere, softmax, arc_margin_intertopk_subcenter
scale: 32.0
easy_margin: False

margin_scheduler: MarginScheduler
margin_update:
initial_margin: 0.5
final_margin: 0.5
increase_start_epoch: 1
fix_start_epoch: 1
update_margin: True
increase_type: "exp" # exp, linear

loss: CrossEntropyLoss
loss_args: {}

optimizer: SGD
optimizer_args:
momentum: 0.9
nesterov: True
weight_decay: 0.0001

scheduler: ExponentialDecrease
scheduler_args:
initial_lr: 0.0001
final_lr: 0.000025
warm_up_epoch: 1
warm_from_zero: True
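Across the three recipes the MarginScheduler either ramps the margin (frozen recipe: 0.0 to 0.2 between epochs 20 and 40) or pins it from the start (joint ft: 0.2, lmft: 0.5, with increase_start_epoch = fix_start_epoch = 1). A hedged sketch of the "exp" increase rule, inferred from the config semantics rather than copied from wespeaker:

```python
import math

def margin_at_epoch(epoch, initial_margin, final_margin,
                    increase_start, fix_start, increase_type="exp"):
    # sketch: ramp the margin between increase_start and fix_start epochs
    if epoch < increase_start:
        return initial_margin
    if epoch >= fix_start:
        return final_margin
    frac = (epoch - increase_start) / (fix_start - increase_start)
    if increase_type == "exp":
        # rises quickly at first, then saturates near final_margin
        frac = 1.0 - math.exp(frac * math.log(1e-3))
    return initial_margin + (final_margin - initial_margin) * frac

# e.g. frozen recipe: margin_at_epoch(30, 0.0, 0.2, 20, 40) ~= 0.19
```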
1 change: 1 addition & 0 deletions examples/voxceleb/v2/conf/ecapa_tdnn_lm.yaml
@@ -38,6 +38,7 @@ dataset_args:
speed_perturb: True
num_frms: 600
aug_prob: 0.6 # prob to add reverb & noise aug per sample
frontend: "fbank" # fbank, s3prl
fbank_args:
num_mel_bins: 80
frame_shift: 10
1 change: 1 addition & 0 deletions requirements.txt
@@ -23,3 +23,4 @@ soundfile==0.10.3.post1
pypeln==0.4.9
silero-vad
pre-commit==3.5.0
s3prl
2 changes: 1 addition & 1 deletion tools/extract_embedding.sh
@@ -48,7 +48,7 @@ for suffix in $(seq 0 $(($nj - 1))); do
suffix=$(printf '%03d' $suffix)
data_list_subfile=${log_dir}/split_${suffix}
embed_ark=${embed_dir}/xvector_${suffix}.ark
CUDA_VISIBLE_DEVICES=${gpus[$idx]} python wespeaker/bin/extract.py \
CUDA_VISIBLE_DEVICES=${gpus[$idx]} python -u wespeaker/bin/extract.py \
--config ${exp_dir}/config.yaml \
--model_path ${model_path} \
--data_type ${data_type} \
Expand Down
37 changes: 32 additions & 5 deletions wespeaker/bin/extract.py
@@ -23,6 +23,8 @@
from tqdm import tqdm

from wespeaker.dataset.dataset import Dataset
from wespeaker.dataset.dataset_utils import apply_cmvn, spec_aug
from wespeaker.frontend import *
from wespeaker.models.speaker_model import get_speaker_model
from wespeaker.utils.checkpoint import load_checkpoint
from wespeaker.utils.utils import parse_config_or_kwargs, validate_path
@@ -41,18 +43,27 @@ def extract(config='conf/config.yaml', **kwargs):
# auto-tuner to False
torch.backends.cudnn.benchmark = False

test_conf = copy.deepcopy(configs['dataset_args'])
# model: frontend (optional) => speaker model
model = get_speaker_model(configs['model'])(**configs['model_args'])
frontend_type = test_conf.get('frontend', 'fbank')
if frontend_type == 's3prl':
frontend_args = frontend_type + "_args"
print('Initializing frontend model (this could take some time) ...')
frontend = frontend_class_dict[frontend_type](
**test_conf[frontend_args], sample_rate=test_conf['resample_rate'])
model.add_module("frontend", frontend)
print('Loading checkpoint ...')
load_checkpoint(model, model_path)
print('Finished !!! Start extracting ...')
device = torch.device("cuda")
model.to(device).eval()

# test_configs
test_conf = copy.deepcopy(configs['dataset_args'])
# test_conf = copy.deepcopy(configs['dataset_args'])
test_conf['speed_perturb'] = False
if 'fbank_args' in test_conf:
test_conf['fbank_args']['dither'] = 0.0
elif 'mfcc_args' in test_conf:
test_conf['mfcc_args']['dither'] = 0.0
test_conf['spec_aug'] = False
test_conf['shuffle'] = False
test_conf['aug_prob'] = configs.get('aug_prob', 0.0)
@@ -81,8 +92,24 @@ def extract(config='conf/config.yaml', **kwargs):
embed_scp) as writer:
for _, batch in tqdm(enumerate(dataloader)):
utts = batch['key']
features = batch['feat']
features = features.float().to(device) # (B,T,F)
if frontend_type == 'fbank':
features = batch['feat']
features = features.float().to(device) # (B,T,F)
else: # 's3prl'
wavs = batch['wav'] # (B,1,W)
wavs = wavs.squeeze(1).float().to(device) # (B,W)
wavs_len = torch.LongTensor([wavs.shape[1]]).repeat(
wavs.shape[0]).to(device) # (B)
features, _ = model.frontend(wavs, wavs_len)

# apply cmvn
if test_conf.get('cmvn', True):
features = apply_cmvn(features,
**test_conf.get('cmvn_args', {}))
# spec augmentation
if test_conf.get('spec_aug', False):
features = spec_aug(features, **test_conf['spec_aug_args'])

# Forward through model
outputs = model(features) # embed or (embed_a, embed_b)
embeds = outputs[-1] if isinstance(outputs, tuple) else outputs
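Because the s3prl path bypasses the Dataset-side feature pipeline, extract.py now re-applies CMVN (and SpecAugment, if enabled) to the frontend output. A minimal sketch of per-utterance CMVN matching the cmvn_args in the configs above; the real apply_cmvn in wespeaker/dataset/dataset_utils may differ:

```python
import torch

def apply_cmvn_sketch(feats: torch.Tensor,
                      norm_mean: bool = True,
                      norm_var: bool = False,
                      eps: float = 1e-8) -> torch.Tensor:
    # feats: (B, T, F) frame-level features from the frontend
    if norm_mean:  # cmvn_args.norm_mean: subtract the per-utterance mean over time
        feats = feats - feats.mean(dim=1, keepdim=True)
    if norm_var:   # cmvn_args.norm_var: divide by the per-utterance std over time
        feats = feats / (feats.std(dim=1, keepdim=True) + eps)
    return feats
```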