Commit
add support to voxblink2 pretrained model and update readme
Shuai Wang committed Sep 3, 2024
1 parent 4be9d57 commit 6b149b3
Showing 11 changed files with 325 additions and 4 deletions.
10 changes: 10 additions & 0 deletions README.md
@@ -99,6 +99,16 @@ We also created a WeChat group for better discussion and quicker response. Pleas
## Citations
If you find wespeaker useful, please cite it as
```bibtex
@article{wang2024advancing,
title={Advancing speaker embedding learning: Wespeaker toolkit for research and production},
author={Wang, Shuai and Chen, Zhengyang and Han, Bing and Wang, Hongji and Liang, Chengdong and Zhang, Binbin and Xiang, Xu and Ding, Wen and Rohdin, Johan and Silnova, Anna and others},
journal={Speech Communication},
volume={162},
pages={103104},
year={2024},
publisher={Elsevier}
}
@inproceedings{wang2023wespeaker,
title={Wespeaker: A research and production oriented speaker embedding learning toolkit},
author={Wang, Hongji and Liang, Chengdong and Wang, Shuai and Chen, Zhengyang and Zhang, Binbin and Xiang, Xu and Deng, Yanlei and Qian, Yanmin},
92 changes: 92 additions & 0 deletions docs/papers_using_wespeaker.md
@@ -0,0 +1,92 @@
# Research that uses the WeSpeaker project

[TOC]

## Introduction

After the release of the WeSpeaker project, many users from both academia and industry have actively engaged with it in their research. We appreciate all the feedback and contributions from the community and would like to highlight some of these interesting works below.

## New Architecture
### ReDimNet
- Implementation in Wespeaker
- Paper Link

```bibtex
@article{yakovlev2024reshape,
title={Reshape Dimensions Network for Speaker Recognition},
author={Yakovlev, Ivan and Makarov, Rostislav and Balykin, Andrei and Malov, Pavel and Okhotnikov, Anton and Torgashov, Nikita},
journal={arXiv preprint arXiv:2407.18223},
year={2024}
}
```

### Golden Gemini DF-ResNet


```bibtex
@article{liu2024golden,
title={Golden gemini is all you need: Finding the sweet spots for speaker verification},
author={Liu, Tianchi and Lee, Kong Aik and Wang, Qiongqiong and Li, Haizhou},
journal={IEEE/ACM Transactions on Audio, Speech, and Language Processing},
year={2024},
publisher={IEEE}
}
```

### SAM-ResNet

```bibtex
@inproceedings{qin2022simple,
title={Simple attention module based speaker verification with iterative noisy label detection},
author={Qin, Xiaoyi and Li, Na and Weng, Chao and Su, Dan and Li, Ming},
booktitle={ICASSP 2022-2022 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
pages={6722--6726},
year={2022},
organization={IEEE}
}
```
### Whisper-based Speaker Verification



## Pipelines

### DINO Pretraining with Large-scale Data

```bibtex
@inproceedings{wang2024leveraging,
title={Leveraging In-the-Wild Data for Effective Self-Supervised Pretraining in Speaker Recognition},
author={Wang, Shuai and Bai, Qibing and Liu, Qi and Yu, Jianwei and Chen, Zhengyang and Han, Bing and Qian, Yanmin and Li, Haizhou},
booktitle={ICASSP 2024-2024 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
pages={10901--10905},
year={2024},
organization={IEEE}
}
```


## New Recipes / Datasets

### VoxBlink

```bibtex
@inproceedings{lin2024voxblink,
title={Voxblink: A large scale speaker verification dataset on camera},
author={Lin, Yuke and Qin, Xiaoyi and Zhao, Guoqing and Cheng, Ming and Jiang, Ning and Wu, Haiying and Li, Ming},
booktitle={ICASSP 2024-2024 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
pages={10271--10275},
year={2024},
organization={IEEE}
}
@article{lin2024voxblink2,
title={VoxBlink2: A 100K+ Speaker Recognition Corpus and the Open-Set Speaker-Identification Benchmark},
author={Lin, Yuke and Cheng, Ming and Zhang, Fulin and Gao, Yingying and Zhang, Shilei and Li, Ming},
journal={arXiv preprint arXiv:2407.11510},
year={2024}
}
```


5 changes: 4 additions & 1 deletion docs/pretrained.md
@@ -54,7 +54,10 @@ The model with suffix **LM** means that it is further fine-tuned using large-mar
| [VoxCeleb](../examples/voxceleb/v2/README.md) | EN | [ECAPA1024](https://wenet.org.cn/downloads?models=wespeaker&version=voxceleb_ECAPA1024.zip) / [ECAPA1024_LM](https://wenet.org.cn/downloads?models=wespeaker&version=voxceleb_ECAPA1024_LM.zip) | [ECAPA1024](https://wenet.org.cn/downloads?models=wespeaker&version=voxceleb_ECAPA1024.onnx) / [ECAPA1024_LM](https://wenet.org.cn/downloads?models=wespeaker&version=voxceleb_ECAPA1024_LM.onnx) |
| [VoxCeleb](../examples/voxceleb/v2/README.md) | EN | [Gemini_DFResnet114_LM](https://wenet.org.cn/downloads?models=wespeaker&version=voxceleb_gemini_dfresnet114_LM.zip)| [Gemini_DFResnet114_LM](https://wenet.org.cn/downloads?models=wespeaker&version=voxceleb_gemini_dfresnet114_LM.onnx) |
| [CNCeleb](../examples/cnceleb/v2/README.md) | CN | [ResNet34](https://wenet.org.cn/downloads?models=wespeaker&version=cnceleb_resnet34.zip) / [ResNet34_LM](https://wenet.org.cn/downloads?models=wespeaker&version=cnceleb_resnet34_LM.zip) | [ResNet34](https://wenet.org.cn/downloads?models=wespeaker&version=cnceleb_resnet34.onnx) / [ResNet34_LM](https://wenet.org.cn/downloads?models=wespeaker&version=cnceleb_resnet34_LM.onnx) |
| [VoxBlink2](../examples/voxceleb/v2/README.md) | Multilingual | [SimAMResNet34](https://wenet.org.cn/downloads?models=wespeaker&version=voxblink2_samresnet34.zip) | [SimAMResNet34](https://wenet.org.cn/downloads?models=wespeaker&version=voxblink2_samresnet34.onnx) |
| [VoxBlink2 (pretrain) + VoxCeleb2 (finetune)](../examples/voxceleb/v2/README.md) | Multilingual | [SimAMResNet34](https://wenet.org.cn/downloads?models=wespeaker&version=voxblink2_samresnet34_ft.zip) | [SimAMResNet34](https://wenet.org.cn/downloads?models=wespeaker&version=voxblink2_samresnet34_ft.onnx) |
| [VoxBlink2](../examples/voxceleb/v2/README.md) | Multilingual | [SimAMResNet100](https://wenet.org.cn/downloads?models=wespeaker&version=voxblink2_samresnet100.zip) | [SimAMResNet100](https://wenet.org.cn/downloads?models=wespeaker&version=voxblink2_samresnet100.onnx) |
| [VoxBlink2 (pretrain) + VoxCeleb2 (finetune)](../examples/voxceleb/v2/README.md) | Multilingual | [SimAMResNet100](https://wenet.org.cn/downloads?models=wespeaker&version=voxblink2_samresnet100_ft.zip) | [SimAMResNet100](https://wenet.org.cn/downloads?models=wespeaker&version=voxblink2_samresnet100_ft.onnx) |
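
For the ONNX downloads above, a minimal hedged sketch of offline inference with onnxruntime: the `(batch, frames, 80)` fbank layout follows the recipe setup, the file name comes from the download link, and the input tensor name is read from the session rather than assumed.

```python
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("voxblink2_samresnet34.onnx")
input_name = session.get_inputs()[0].name
feats = np.random.randn(1, 200, 80).astype(np.float32)  # 200 frames of 80-dim fbank
embedding = session.run(None, {input_name: feats})[0]
print(embedding.shape)  # expected (1, 256) for the 256-dim embedding
```
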
### huggingface

| Datasets | Languages | Checkpoint (pt) | Runtime Model (onnx) |
1 change: 1 addition & 0 deletions docs/reference.rst
@@ -7,4 +7,5 @@ Reference

./paper.md
./speaker_recognition_papers.md
./papers_using_wespeaker.md
./python_api/modules.rst
11 changes: 9 additions & 2 deletions examples/voxceleb/v2/README.md
@@ -3,8 +3,9 @@
* Setup: fbank80, num_frms200, epoch150, ArcMargin, aug_prob0.6, speed_perturb (no spec_aug)
* Scoring: cosine (sub mean of vox2_dev), AS-Norm, [QMF](https://arxiv.org/pdf/2010.11255)
* Metric: EER(%)

* 🔥 UPDATE 2024.09.03: We support SimAM_ResNet pretrained on VoxBlink2 and fine-tuned on VoxCeleb2!
* 🔥 UPDATE 2024.08.27: We support SSL models as the feature front-end; take a look at the WavLM recipe!
* UPDATE 2022.07.19: We apply the same setups as the winning system of CNSRC 2022 (see the [cnceleb](https://github.com/wenet-e2e/wespeaker/tree/master/examples/cnceleb/v2) recipe for details), and obtain significant performance improvement.
* LR scheduler warmup from 0
* Remove one embedding layer in ResNet models
* Add large margin fine-tuning strategy (LM)
@@ -55,6 +56,12 @@
| | | | × | ✓ | × | 0.707 | 0.889 | 1.546 |
| | | | ✓ | × | × | 0.771 | 0.906 | 1.599 |
| | | | ✓ | ✓ | × | 0.638 | 0.839 | 1.427 |
| SimAM_ResNet34 (VoxBlink2 Pretrain) | 25.2M | | ✓ | × | × | 0.415 | 0.615 | 1.121 |
| | | | ✓ | ✓ | × | 0.372 | 0.581 | 1.049 |
| | | | ✓ | ✓ | ✓ | 0.372 | 0.559 | 0.997 |
| SimAM_ResNet100 (VoxBlink2 Pretrain) | 50.2M | | ✓ | × | × | 0.229 | 0.458 | 0.868 |
| | | | ✓ | ✓ | × | 0.207 | 0.424 | 0.804 |
| | | | ✓ | ✓ | ✓ | 0.202 | 0.421 | 0.795 |
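
The AS-Norm column above refers to adaptive score normalization of the cosine scores. For reference, a minimal sketch of the usual top-k formulation; the cohort scores and `topk=300` are illustrative assumptions, not the recipe's exact configuration.

```python
import numpy as np

def as_norm(score, enroll_cohort, test_cohort, topk=300):
    """Adaptive score normalization (AS-Norm) of a single trial score.

    enroll_cohort / test_cohort hold the cosine scores of the enrollment /
    test embedding against a cohort set; only the top-k highest are used.
    """
    e = np.sort(np.asarray(enroll_cohort))[-topk:]
    t = np.sort(np.asarray(test_cohort))[-topk:]
    return 0.5 * ((score - e.mean()) / e.std() + (score - t.mean()) / t.std())
```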


## PLDA results
2 changes: 2 additions & 0 deletions wespeaker/cli/hub.py
@@ -74,6 +74,8 @@ class Hub(object):
"english": "voxceleb_resnet221_LM.tar.gz",
"campplus": "campplus_cn_common_200k.tar.gz",
"eres2net": "eres2net_cn_commom_200k.tar.gz",
"vblinkp": "voxblink2_samresnet34.tar.gz",
"vblinkf": "voxblink2_voxceleb_samresnet34.tar.gz",
}

def __init__(self) -> None:
4 changes: 4 additions & 0 deletions wespeaker/cli/speaker.py
@@ -294,6 +294,10 @@ def main():
elif args.eres2net:
model = load_model("eres2net")
model.set_wavform_norm(True)
elif args.vblinkp:
model = load_model("vblinkp")
elif args.vblinkf:
model = load_model("vblinkf")
else:
model = load_model(args.language)
else:
11 changes: 11 additions & 0 deletions wespeaker/cli/utils.py
@@ -47,6 +47,17 @@ def get_args():
action='store_true',
help='whether to use the damo/speech_eres2net_sv_zh-cn_16k-common model'
)
parser.add_argument(
'--vblinkp',
action='store_true',
help='whether to use the samresnet34 model pretrained on voxblink2'
)
parser.add_argument(
'--vblinkf',
action='store_true',
help="whether to use the samresnet34 model pretrained on voxblink2 and"
"fintuned on voxceleb2"
)
parser.add_argument('-p',
'--pretrain',
type=str,
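
With the two flags wired up above, a hedged sketch of the matching Python-level usage: `load_model`, `extract_embedding`, and `compute_similarity` follow the existing wespeaker CLI API, and the audio paths are placeholders.

```python
import wespeaker

# "vblinkp": SimAM-ResNet34 pretrained on VoxBlink2 only.
# "vblinkf": the same model additionally fine-tuned on VoxCeleb2.
model = wespeaker.load_model("vblinkf")
embedding = model.extract_embedding("utt1.wav")
similarity = model.compute_similarity("utt1.wav", "utt2.wav")
print(embedding.shape, similarity)
```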
25 changes: 25 additions & 0 deletions wespeaker/models/pooling_layers.py
@@ -148,6 +148,31 @@ def get_out_dim(self):
return self.out_dim


class ASP(nn.Module):
    """Attentive statistics pooling (ASP).

    Takes the (B, C, F, T) feature map of the ResNet front-end, flattens
    channels and frequency bins, computes attention weights over time, and
    concatenates the attention-weighted mean and standard deviation.
    """

    def __init__(self, in_planes, acoustic_dim):
        super(ASP, self).__init__()
        # The front-end downsamples frequency by 8 (three stride-2 stages).
        outmap_size = int(acoustic_dim / 8)
        self.out_dim = in_planes * 8 * outmap_size * 2

        self.attention = nn.Sequential(
            nn.Conv1d(in_planes * 8 * outmap_size, 128, kernel_size=1),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Conv1d(128, in_planes * 8 * outmap_size, kernel_size=1),
            nn.Softmax(dim=2),
        )

    def forward(self, x):
        # (B, C, F, T) -> (B, C * F, T)
        x = x.reshape(x.size()[0], -1, x.size()[-1])
        w = self.attention(x)  # per-frame attention weights (softmax over time)
        mu = torch.sum(x * w, dim=2)
        # clamp guards the sqrt against small negative values from rounding
        sg = torch.sqrt((torch.sum((x**2) * w, dim=2) - mu**2).clamp(min=1e-5))
        x = torch.cat((mu, sg), 1)
        x = x.view(x.size()[0], -1)
        return x


class MHASTP(torch.nn.Module):
""" Multi head attentive statistics pooling
Reference:
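As a quick shape check for the new ASP layer, a minimal sketch assuming the defaults used by the SimAM models in this commit (`in_planes=64`, `acoustic_dim=80`), whose ResNet front-end emits a `(B, 512, 10, T)` feature map:

```python
import torch

from wespeaker.models.pooling_layers import ASP

# out_dim = in_planes * 8 * (acoustic_dim // 8) * 2 = 64 * 8 * 10 * 2 = 10240
pool = ASP(in_planes=64, acoustic_dim=80)
feat_map = torch.randn(2, 512, 10, 25)  # (batch, channels, freq bins, frames)
print(pool(feat_map).shape)  # torch.Size([2, 10240])
```
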
163 changes: 163 additions & 0 deletions wespeaker/models/samresnet.py
@@ -0,0 +1,163 @@
import torch
import torch.nn as nn
import wespeaker.models.pooling_layers as pooling_layers


class SimAMBasicBlock(nn.Module):
expansion = 1

def __init__(
self, ConvLayer, NormLayer, in_planes, planes, stride=1, block_id=1
):
super(SimAMBasicBlock, self).__init__()
self.conv1 = ConvLayer(
in_planes,
planes,
kernel_size=3,
stride=stride,
padding=1,
bias=False,
)
self.bn1 = NormLayer(planes)
self.conv2 = ConvLayer(
planes, planes, kernel_size=3, stride=1, padding=1, bias=False
)
self.bn2 = NormLayer(planes)
self.relu = nn.ReLU(inplace=True)
self.sigmoid = nn.Sigmoid()

self.downsample = nn.Sequential()
if stride != 1 or in_planes != self.expansion * planes:
self.downsample = nn.Sequential(
ConvLayer(
in_planes,
self.expansion * planes,
kernel_size=1,
stride=stride,
bias=False,
),
NormLayer(self.expansion * planes),
)

def forward(self, x):
out = self.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
out = self.SimAM(out)
out += self.downsample(x)
out = self.relu(out)
return out

def SimAM(self, X, lambda_p=1e-4):
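        # Parameter-free SimAM attention (Yang et al., ICML 2021): compute an
        # inverse energy E_inv = (X - mu)^2 / (4 * (var + lambda_p)) + 0.5 over
        # the spatial positions of each channel, then reweight the features as
        # X * sigmoid(E_inv); positions deviating more from the channel mean
        # receive larger weights.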
n = X.shape[2] * X.shape[3] - 1
d = (X - X.mean(dim=[2, 3], keepdim=True)).pow(2)
v = d.sum(dim=[2, 3], keepdim=True) / n
E_inv = d / (4 * (v + lambda_p)) + 0.5
return X * self.sigmoid(E_inv)


class ResNet(nn.Module):
def __init__(
self, in_planes, block, num_blocks, in_ch=1, **kwargs
):
super(ResNet, self).__init__()
self.in_planes = in_planes
self.NormLayer = nn.BatchNorm2d
self.ConvLayer = nn.Conv2d

self.conv1 = self.ConvLayer(
in_ch, in_planes, kernel_size=3, stride=1, padding=1, bias=False
)
self.bn1 = self.NormLayer(in_planes)
self.relu = nn.ReLU(inplace=True)
self.layer1 = self._make_layer(
block, in_planes, num_blocks[0], stride=1, block_id=1
)
self.layer2 = self._make_layer(
block, in_planes * 2, num_blocks[1], stride=2, block_id=2
)
self.layer3 = self._make_layer(
block, in_planes * 4, num_blocks[2], stride=2, block_id=3
)
self.layer4 = self._make_layer(
block, in_planes * 8, num_blocks[3], stride=2, block_id=4
)

def _make_layer(self, block, planes, num_blocks, stride, block_id=1):
strides = [stride] + [1] * (num_blocks - 1)
layers = []
for stride in strides:
layers.append(
block(
self.ConvLayer,
self.NormLayer,
self.in_planes,
planes,
stride,
block_id,
)
)
self.in_planes = planes * block.expansion
return nn.Sequential(*layers)

def forward(self, x):
x = self.relu(self.bn1(self.conv1(x)))
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
return x


def SimAM_ResNet34(in_planes):
return ResNet(in_planes, SimAMBasicBlock, [3, 4, 6, 3])


def SimAM_ResNet100(in_planes):
return ResNet(in_planes, SimAMBasicBlock, [6, 16, 24, 3])


class SimAM_ResNet34_ASP(nn.Module):
def __init__(self, in_planes=64, embed_dim=256, acoustic_dim=80, dropout=0):
super(SimAM_ResNet34_ASP, self).__init__()
self.front = SimAM_ResNet34(in_planes)
self.pooling = pooling_layers.ASP(in_planes, acoustic_dim)
self.bottleneck = nn.Linear(self.pooling.out_dim, embed_dim)
self.drop = nn.Dropout(dropout) if dropout else None

def forward(self, x):
x = x.permute(0, 2, 1)
x = self.front(x.unsqueeze(dim=1))
x = self.pooling(x)
if self.drop:
x = self.drop(x)
x = self.bottleneck(x)
return x


class SimAM_ResNet100_ASP(nn.Module):
def __init__(self, in_planes=64, embed_dim=256, acoustic_dim=80, dropout=0):
super(SimAM_ResNet100_ASP, self).__init__()
self.front = SimAM_ResNet100(in_planes)
self.pooling = pooling_layers.ASP(in_planes, acoustic_dim)
self.bottleneck = nn.Linear(self.pooling.out_dim, embed_dim)
self.drop = nn.Dropout(dropout) if dropout else None

def forward(self, x):
x = x.permute(0, 2, 1)
x = self.front(x.unsqueeze(dim=1))
x = self.pooling(x)
if self.drop:
x = self.drop(x)
x = self.bottleneck(x)
return x


if __name__ == '__main__':
x = torch.zeros(1, 200, 80)
model = SimAM_ResNet34_ASP(embed_dim=256)
model.eval()
out = model(x)
print(out[-1].size())

num_params = sum(p.numel() for p in model.parameters())
print("{} M".format(num_params / 1e6))
