diff --git a/README.md b/README.md
index 690e93e8..c7c0d49f 100644
--- a/README.md
+++ b/README.md
@@ -99,6 +99,16 @@ We also created a WeChat group for better discussion and quicker response. Pleas
 ## Citations
 If you find wespeaker useful, please cite it as
 ```bibtex
+@article{wang2024advancing,
+  title={Advancing speaker embedding learning: Wespeaker toolkit for research and production},
+  author={Wang, Shuai and Chen, Zhengyang and Han, Bing and Wang, Hongji and Liang, Chengdong and Zhang, Binbin and Xiang, Xu and Ding, Wen and Rohdin, Johan and Silnova, Anna and others},
+  journal={Speech Communication},
+  volume={162},
+  pages={103104},
+  year={2024},
+  publisher={Elsevier}
+}
+
 @inproceedings{wang2023wespeaker,
   title={Wespeaker: A research and production oriented speaker embedding learning toolkit},
   author={Wang, Hongji and Liang, Chengdong and Wang, Shuai and Chen, Zhengyang and Zhang, Binbin and Xiang, Xu and Deng, Yanlei and Qian, Yanmin},
diff --git a/docs/papers_using_wespeaker.md b/docs/papers_using_wespeaker.md
new file mode 100644
index 00000000..a75ba7f4
--- /dev/null
+++ b/docs/papers_using_wespeaker.md
@@ -0,0 +1,92 @@
+# Research that uses the WeSpeaker project
+
+[TOC]
+
+## Introduction
+
+After the release of the WeSpeaker project, many users from both academia and industry have actively engaged with it in their research. We appreciate all the feedback and contributions from the community and would like to highlight some of these interesting works below.
+
+## New Architectures
+### ReDimNet
+>
+
+- Implementation in Wespeaker
+- Paper Link
+
+```bibtex
+@article{yakovlev2024reshape,
+  title={Reshape Dimensions Network for Speaker Recognition},
+  author={Yakovlev, Ivan and Makarov, Rostislav and Balykin, Andrei and Malov, Pavel and Okhotnikov, Anton and Torgashov, Nikita},
+  journal={arXiv preprint arXiv:2407.18223},
+  year={2024}
+}
+```
+
+### Golden Gemini DF-ResNet
+
+
+```bibtex
+@article{liu2024golden,
+  title={Golden gemini is all you need: Finding the sweet spots for speaker verification},
+  author={Liu, Tianchi and Lee, Kong Aik and Wang, Qiongqiong and Li, Haizhou},
+  journal={IEEE/ACM Transactions on Audio, Speech, and Language Processing},
+  year={2024},
+  publisher={IEEE}
+}
+```
+
+### SAM-ResNet
+
+```bibtex
+@inproceedings{qin2022simple,
+  title={Simple attention module based speaker verification with iterative noisy label detection},
+  author={Qin, Xiaoyi and Li, Na and Weng, Chao and Su, Dan and Li, Ming},
+  booktitle={ICASSP 2022-2022 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
+  pages={6722--6726},
+  year={2022},
+  organization={IEEE}
+}
+```
+### Whisper-based Speaker Verification
+
+
+## Pipelines
+
+### DINO Pretraining with Large-scale Data
+
+```bibtex
+@inproceedings{wang2024leveraging,
+  title={Leveraging In-the-Wild Data for Effective Self-Supervised Pretraining in Speaker Recognition},
+  author={Wang, Shuai and Bai, Qibing and Liu, Qi and Yu, Jianwei and Chen, Zhengyang and Han, Bing and Qian, Yanmin and Li, Haizhou},
+  booktitle={ICASSP 2024-2024 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
+  pages={10901--10905},
+  year={2024},
+  organization={IEEE}
+}
+```
+
+
+## New Recipes / Datasets
+
+### VoxBlink
+
+```bibtex
+@inproceedings{lin2024voxblink,
+  title={Voxblink: A large scale speaker verification dataset on camera},
+  author={Lin, Yuke and Qin, Xiaoyi and Zhao, Guoqing and Cheng, Ming and Jiang, Ning and Wu, Haiying and Li, Ming},
+  booktitle={ICASSP 2024-2024 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
+  pages={10271--10275},
+  year={2024},
+  organization={IEEE}
+}
+
+@article{lin2024voxblink2,
+  title={VoxBlink2: A 100K+ Speaker Recognition Corpus and the Open-Set Speaker-Identification Benchmark},
+  author={Lin, Yuke and Cheng, Ming and Zhang, Fulin and Gao, Yingying and Zhang, Shilei and Li, Ming},
+  journal={arXiv preprint arXiv:2407.11510},
+  year={2024}
+}
+```
+
+
diff --git a/docs/pretrained.md b/docs/pretrained.md
index 7fdc3dd8..87f6c885 100644
--- a/docs/pretrained.md
+++ b/docs/pretrained.md
@@ -54,7 +54,10 @@ The model with suffix **LM** means that it is further fine-tuned using large-mar
 | [VoxCeleb](../examples/voxceleb/v2/README.md) | EN | [ECAPA1024](https://wenet.org.cn/downloads?models=wespeaker&version=voxceleb_ECAPA1024.zip) / [ECAPA1024_LM](https://wenet.org.cn/downloads?models=wespeaker&version=voxceleb_ECAPA1024_LM.zip) | [ECAPA1024](https://wenet.org.cn/downloads?models=wespeaker&version=voxceleb_ECAPA1024.onnx) / [ECAPA1024_LM](https://wenet.org.cn/downloads?models=wespeaker&version=voxceleb_ECAPA1024_LM.onnx) |
 | [VoxCeleb](../examples/voxceleb/v2/README.md) | EN | [Gemini_DFResnet114_LM](https://wenet.org.cn/downloads?models=wespeaker&version=voxceleb_gemini_dfresnet114_LM.zip)| [Gemini_DFResnet114_LM](https://wenet.org.cn/downloads?models=wespeaker&version=voxceleb_gemini_dfresnet114_LM.onnx) |
 | [CNCeleb](../examples/cnceleb/v2/README.md) | CN | [ResNet34](https://wenet.org.cn/downloads?models=wespeaker&version=cnceleb_resnet34.zip) / [ResNet34_LM](https://wenet.org.cn/downloads?models=wespeaker&version=cnceleb_resnet34_LM.zip) | [ResNet34](https://wenet.org.cn/downloads?models=wespeaker&version=cnceleb_resnet34.onnx) / [ResNet34_LM](https://wenet.org.cn/downloads?models=wespeaker&version=cnceleb_resnet34_LM.onnx) |
-
+| [VoxBlink2](../examples/voxceleb/v2/README.md) | Multilingual | [SimAMResNet34](https://wenet.org.cn/downloads?models=wespeaker&version=voxblink2_samresnet34.zip) | [SimAMResNet34](https://wenet.org.cn/downloads?models=wespeaker&version=voxblink2_samresnet34.onnx) |
+| [VoxBlink2 (pretrain) + VoxCeleb2 (finetune)](../examples/voxceleb/v2/README.md) | Multilingual | [SimAMResNet34](https://wenet.org.cn/downloads?models=wespeaker&version=voxblink2_samresnet34_ft.zip) | [SimAMResNet34](https://wenet.org.cn/downloads?models=wespeaker&version=voxblink2_samresnet34_ft.onnx) |
+| [VoxBlink2](../examples/voxceleb/v2/README.md) | Multilingual | [SimAMResNet100](https://wenet.org.cn/downloads?models=wespeaker&version=voxblink2_samresnet100.zip) | [SimAMResNet100](https://wenet.org.cn/downloads?models=wespeaker&version=voxblink2_samresnet100.onnx) |
+| [VoxBlink2 (pretrain) + VoxCeleb2 (finetune)](../examples/voxceleb/v2/README.md) | Multilingual | [SimAMResNet100](https://wenet.org.cn/downloads?models=wespeaker&version=voxblink2_samresnet100_ft.zip) | [SimAMResNet100](https://wenet.org.cn/downloads?models=wespeaker&version=voxblink2_samresnet100_ft.onnx) |
 
 ### huggingface
 | Datasets | Languages | Checkpoint (pt) | Runtime Model (onnx) |
diff --git a/docs/reference.rst b/docs/reference.rst
index 3af47698..76806108 100644
--- a/docs/reference.rst
+++ b/docs/reference.rst
@@ -7,4 +7,5 @@ Reference
 
    ./paper.md
    ./speaker_recognition_papers.md
+   ./papers_using_wespeaker.md
    ./python_api/modules.rst
diff --git a/examples/voxceleb/v2/README.md b/examples/voxceleb/v2/README.md
index e4b2af31..b773c73c 100644
--- a/examples/voxceleb/v2/README.md
+++ b/examples/voxceleb/v2/README.md
@@ -3,8 +3,9 @@
 * Setup: fbank80, num_frms200, epoch150, ArcMargin, aug_prob0.6, speed_perturb (no spec_aug)
 * Scoring: cosine (sub mean of vox2_dev), AS-Norm, [QMF](https://arxiv.org/pdf/2010.11255)
 * Metric: EER(%)
-
-* 🔥 UPDATE 2022.07.19: We apply the same setups as the winning system of CNSRC 2022 (see [cnceleb](https://github.com/wenet-e2e/wespeaker/tree/master/examples/cnceleb/v2) recipe for details), and obtain significant performance improvement.
+* 🔥 UPDATE 2024.09.03: We support SimAM_ResNet pretrained on VoxBlink2 and fine-tuned on VoxCeleb2!
+* 🔥 UPDATE 2024.08.27: We support SSL models as the feature front-end; take a look at the WavLM recipe!
+* UPDATE 2022.07.19: We apply the same setups as the winning system of CNSRC 2022 (see [cnceleb](https://github.com/wenet-e2e/wespeaker/tree/master/examples/cnceleb/v2) recipe for details), and obtain a significant performance improvement.
     * LR scheduler warmup from 0
     * Remove one embedding layer in ResNet models
     * Add large margin fine-tuning strategy (LM)
@@ -55,6 +56,12 @@
 | | | | × | √ | × | 0.707 | 0.889 | 1.546 |
 | | | | √ | x | × | 0.771 | 0.906 | 1.599 |
 | | | | √ | √ | × | 0.638 | 0.839 | 1.427 |
+| SimAM_ResNet34 (VoxBlink2 Pretrain) | 25.2M | | √ | × | × | 0.415 | 0.615 | 1.121 |
+| | | | √ | √ | × | 0.372 | 0.581 | 1.049 |
+| | | | √ | √ | √ | 0.372 | 0.559 | 0.997 |
+| SimAM_ResNet100 (VoxBlink2 Pretrain) | 50.2M | | √ | × | × | 0.229 | 0.458 | 0.868 |
+| | | | √ | √ | × | 0.207 | 0.424 | 0.804 |
+| | | | √ | √ | √ | 0.202 | 0.421 | 0.795 |
 
 ## PLDA results
 
diff --git a/wespeaker/cli/hub.py b/wespeaker/cli/hub.py
index 79a301b6..f8b863cd 100644
--- a/wespeaker/cli/hub.py
+++ b/wespeaker/cli/hub.py
@@ -74,6 +74,8 @@ class Hub(object):
         "english": "voxceleb_resnet221_LM.tar.gz",
         "campplus": "campplus_cn_common_200k.tar.gz",
         "eres2net": "eres2net_cn_commom_200k.tar.gz",
+        "vblinkp": "voxblink2_samresnet34.tar.gz",
+        "vblinkf": "voxblink2_voxceleb_samresnet34.tar.gz",
     }
 
     def __init__(self) -> None:
diff --git a/wespeaker/cli/speaker.py b/wespeaker/cli/speaker.py
index 67030556..82217830 100644
--- a/wespeaker/cli/speaker.py
+++ b/wespeaker/cli/speaker.py
@@ -294,6 +294,10 @@ def main():
         elif args.eres2net:
             model = load_model("eres2net")
             model.set_wavform_norm(True)
+        elif args.vblinkp:
+            model = load_model("vblinkp")
+        elif args.vblinkf:
+            model = load_model("vblinkf")
         else:
             model = load_model(args.language)
     else:
diff --git a/wespeaker/cli/utils.py b/wespeaker/cli/utils.py
index 34eb26b0..9926e9ff 100644
--- a/wespeaker/cli/utils.py
+++ b/wespeaker/cli/utils.py
@@ -47,6 +47,17 @@ def get_args():
         action='store_true',
         help='whether to use the damo/speech_eres2net_sv_zh-cn_16k-common model'
     )
+    parser.add_argument(
+        '--vblinkp',
+        action='store_true',
+        help='whether to use the samresnet34 model pretrained on voxblink2'
+    )
+    parser.add_argument(
+        '--vblinkf',
+        action='store_true',
+        help="whether to use the samresnet34 model pretrained on voxblink2 and "
+             "finetuned on voxceleb2"
+    )
     parser.add_argument('-p',
                         '--pretrain',
                         type=str,
diff --git a/wespeaker/models/pooling_layers.py b/wespeaker/models/pooling_layers.py
index 29b319df..cad5a14e 100644
--- a/wespeaker/models/pooling_layers.py
+++ b/wespeaker/models/pooling_layers.py
@@ -148,6 +148,31 @@ def get_out_dim(self):
         return self.out_dim
 
 
+class ASP(nn.Module):
+    # Attentive statistics pooling
+    def __init__(self, in_planes, acoustic_dim):
+        super(ASP, self).__init__()
+        outmap_size = int(acoustic_dim / 8)  # frequency bins left after three stride-2 stages
+        self.out_dim = in_planes * 8 * outmap_size * 2
+
+        self.attention = nn.Sequential(
+            nn.Conv1d(in_planes * 8 * outmap_size, 128, kernel_size=1),
+            nn.ReLU(),
+            nn.BatchNorm1d(128),
+            nn.Conv1d(128, in_planes * 8 * outmap_size, kernel_size=1),
+            nn.Softmax(dim=2),  # attention weights over the time axis
+        )
+
+    def forward(self, x):
+        x = x.reshape(x.size()[0], -1, x.size()[-1])  # (B, C, F, T) -> (B, C*F, T)
+        w = self.attention(x)
+        mu = torch.sum(x * w, dim=2)  # attention-weighted mean
+        sg = torch.sqrt((torch.sum((x**2) * w, dim=2) - mu**2).clamp(min=1e-5))  # weighted std
+        x = torch.cat((mu, sg), 1)
+        x = x.view(x.size()[0], -1)
+        return x
+
+
 class MHASTP(torch.nn.Module):
     """ Multi head attentive statistics pooling
     Reference:
diff --git a/wespeaker/models/samresnet.py b/wespeaker/models/samresnet.py
new file mode 100644
index 00000000..22cc4084
--- /dev/null
+++ b/wespeaker/models/samresnet.py
@@ -0,0 +1,163 @@
+import torch
+import torch.nn as nn
+import wespeaker.models.pooling_layers as pooling_layers
+
+
+class SimAMBasicBlock(nn.Module):
+    expansion = 1
+
+    def __init__(
+        self, ConvLayer, NormLayer, in_planes, planes, stride=1, block_id=1
+    ):
+        super(SimAMBasicBlock, self).__init__()
+        self.conv1 = ConvLayer(
+            in_planes,
+            planes,
+            kernel_size=3,
+            stride=stride,
+            padding=1,
+            bias=False,
+        )
+        self.bn1 = NormLayer(planes)
+        self.conv2 = ConvLayer(
+            planes, planes, kernel_size=3, stride=1, padding=1, bias=False
+        )
+        self.bn2 = NormLayer(planes)
+        self.relu = nn.ReLU(inplace=True)
+        self.sigmoid = nn.Sigmoid()
+
+        self.downsample = nn.Sequential()
+        if stride != 1 or in_planes != self.expansion * planes:
+            self.downsample = nn.Sequential(
+                ConvLayer(
+                    in_planes,
+                    self.expansion * planes,
+                    kernel_size=1,
+                    stride=stride,
+                    bias=False,
+                ),
+                NormLayer(self.expansion * planes),
+            )
+
+    def forward(self, x):
+        out = self.relu(self.bn1(self.conv1(x)))
+        out = self.bn2(self.conv2(out))
+        out = self.SimAM(out)
+        out += self.downsample(x)
+        out = self.relu(out)
+        return out
+
+    def SimAM(self, X, lambda_p=1e-4):
+        n = X.shape[2] * X.shape[3] - 1  # number of other positions per channel
+        d = (X - X.mean(dim=[2, 3], keepdim=True)).pow(2)  # squared deviation from channel mean
+        v = d.sum(dim=[2, 3], keepdim=True) / n  # channel variance estimate
+        E_inv = d / (4 * (v + lambda_p)) + 0.5  # inverse SimAM energy, higher for distinctive positions
+        return X * self.sigmoid(E_inv)  # parameter-free attention
+
+
+class ResNet(nn.Module):
+    def __init__(
+        self, in_planes, block, num_blocks, in_ch=1, **kwargs
+    ):
+        super(ResNet, self).__init__()
+        self.in_planes = in_planes
+        self.NormLayer = nn.BatchNorm2d
+        self.ConvLayer = nn.Conv2d
+
+        self.conv1 = self.ConvLayer(
+            in_ch, in_planes, kernel_size=3, stride=1, padding=1, bias=False
+        )
+        self.bn1 = self.NormLayer(in_planes)
+        self.relu = nn.ReLU(inplace=True)
+        self.layer1 = self._make_layer(
+            block, in_planes, num_blocks[0], stride=1, block_id=1
+        )
+        self.layer2 = self._make_layer(
+            block, in_planes * 2, num_blocks[1], stride=2, block_id=2
+        )
+        self.layer3 = self._make_layer(
+            block, in_planes * 4, num_blocks[2], stride=2, block_id=3
+        )
+        self.layer4 = self._make_layer(
+            block, in_planes * 8, num_blocks[3], stride=2, block_id=4
+        )
+
+    def _make_layer(self, block, planes, num_blocks, stride, block_id=1):
+        strides = [stride] + [1] * (num_blocks - 1)
+        layers = []
+        for stride in strides:
+            layers.append(
+                block(
+                    self.ConvLayer,
+                    self.NormLayer,
+                    self.in_planes,
+                    planes,
+                    stride,
+                    block_id,
+                )
+            )
+            self.in_planes = planes * block.expansion
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        x = self.relu(self.bn1(self.conv1(x)))
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+        return x
+
+
+def SimAM_ResNet34(in_planes):
+    return ResNet(in_planes, SimAMBasicBlock, [3, 4, 6, 3])
+
+
+def SimAM_ResNet100(in_planes):
+    return ResNet(in_planes, SimAMBasicBlock, [6, 16, 24, 3])
+
+
+class SimAM_ResNet34_ASP(nn.Module):
+    def __init__(self, in_planes=64, embed_dim=256, acoustic_dim=80, dropout=0):
+        super(SimAM_ResNet34_ASP, self).__init__()
+        self.front = SimAM_ResNet34(in_planes)
+        self.pooling = pooling_layers.ASP(in_planes, acoustic_dim)
+        self.bottleneck = nn.Linear(self.pooling.out_dim, embed_dim)
+        self.drop = nn.Dropout(dropout) if dropout else None
+
+    def forward(self, x):
+        x = x.permute(0, 2, 1)
+        x = self.front(x.unsqueeze(dim=1))
+        x = self.pooling(x)
+        if self.drop:
+            x = self.drop(x)
+        x = self.bottleneck(x)
+        return x
+
+
+class SimAM_ResNet100_ASP(nn.Module):
+    def __init__(self, in_planes=64, embed_dim=256, acoustic_dim=80, dropout=0):
+        super(SimAM_ResNet100_ASP, self).__init__()
+        self.front = SimAM_ResNet100(in_planes)
+        self.pooling = pooling_layers.ASP(in_planes, acoustic_dim)
+        self.bottleneck = nn.Linear(self.pooling.out_dim, embed_dim)
+        self.drop = nn.Dropout(dropout) if dropout else None
+
+    def forward(self, x):
+        x = x.permute(0, 2, 1)
+        x = self.front(x.unsqueeze(dim=1))
+        x = self.pooling(x)
+        if self.drop:
+            x = self.drop(x)
+        x = self.bottleneck(x)
+        return x
+
+
+if __name__ == '__main__':
+    x = torch.zeros(1, 200, 80)
+    model = SimAM_ResNet34_ASP(embed_dim=256)
+    model.eval()
+    out = model(x)
+    print(out[-1].size())
+
+    num_params = sum(p.numel() for p in model.parameters())
+    print("{} M".format(num_params / 1e6))
diff --git a/wespeaker/models/speaker_model.py b/wespeaker/models/speaker_model.py
index 4ef4a311..053881ad 100644
--- a/wespeaker/models/speaker_model.py
+++ b/wespeaker/models/speaker_model.py
@@ -1,5 +1,5 @@
 # Copyright (c) 2022 Hongji Wang (jijijiang77@gmail.com)
-#
+#               2024 Shuai Wang (wsstriving@gmail.com)
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -21,6 +21,7 @@
 import wespeaker.models.gemini_dfresnet as gemini
 import wespeaker.models.res2net as res2net
 import wespeaker.models.redimnet as redimnet
+import wespeaker.models.samresnet as samresnet
 
 
 def get_speaker_model(model_name: str):
@@ -42,6 +43,8 @@ def get_speaker_model(model_name: str):
         return getattr(gemini, model_name)
     elif model_name.startswith("ReDimNet"):
         return getattr(redimnet, model_name)
+    elif model_name.startswith("SimAM_ResNet"):
+        return getattr(samresnet, model_name)
     else:  # model_name error !!!
         print(model_name + " not found !!!")
         exit(1)
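
Usage sketch for the new CLI aliases (not part of the patch itself): the `vblinkp` / `vblinkf` entries added to `hub.py` map to the VoxBlink2-pretrained and VoxCeleb2-fine-tuned SimAM_ResNet34 checkpoints. A minimal sketch of how they would be used through the Python API, assuming the packaged `wespeaker` module re-exports `load_model` from `wespeaker/cli/speaker.py` and that `extract_embedding` accepts a wav path; `audio.wav` is a hypothetical local file:

```python
import wespeaker

# "vblinkp": SimAM_ResNet34 pretrained on VoxBlink2 only;
# "vblinkf": the same model fine-tuned on VoxCeleb2 (see hub.py above).
model = wespeaker.load_model("vblinkf")
embedding = model.extract_embedding("audio.wav")  # hypothetical input file
print(embedding.shape)  # 256-dimensional speaker embedding
```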
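
Similarly, a sketch of instantiating the new architecture directly through the registry branch added in `speaker_model.py`, mirroring the `__main__` check in `samresnet.py`; the printed parameter count should match the 25.2M / 50.2M figures in the recipe README table above:

```python
import torch

from wespeaker.models.speaker_model import get_speaker_model

# Resolves via the new startswith("SimAM_ResNet") branch to samresnet.py.
model = get_speaker_model("SimAM_ResNet34_ASP")(in_planes=64, embed_dim=256)
model.eval()

x = torch.zeros(1, 200, 80)  # (batch, frames, fbank bins), as in the recipe setup
with torch.no_grad():
    out = model(x)
print(out.shape)  # torch.Size([1, 256])

num_params = sum(p.numel() for p in model.parameters())
print("{:.1f}M".format(num_params / 1e6))  # ~25.2M for SimAM_ResNet34_ASP
```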