Skip to content

Commit

Permalink
fix the speaker_encoder import
Browse files Browse the repository at this point in the history
  • Loading branch information
wsstriving committed Aug 7, 2024
1 parent 9b6536e commit c347bd0
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 8 deletions.
16 changes: 8 additions & 8 deletions wespeaker/models/redimnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -1008,14 +1008,14 @@ def ReDimNetB6(feat_dim=72, embed_dim=192, pooling_func="ASTP", two_emb_layer=Fa


if __name__ == "__main__":
# x = torch.zeros(1, 200, 72)
# model = ReDimNet(feat_dim=72, embed_dim=192, two_emb_layer=False)
# model.eval()
# out = model(x)
# print(out[-1].size())

# num_params = sum(p.numel() for p in model.parameters())
# print("{} M".format(num_params / 1e6))
x = torch.zeros(1, 200, 72)
model = ReDimNet(feat_dim=72, embed_dim=192, two_emb_layer=False)
model.eval()
out = model(x)
print(out[-1].size())

num_params = sum(p.numel() for p in model.parameters())
print("{} M".format(num_params / 1e6))

# Currently, the model sizes differ from the ones in the paper
model_classes = [
Expand Down
3 changes: 3 additions & 0 deletions wespeaker/models/speaker_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import wespeaker.models.eres2net as eres2net
import wespeaker.models.gemini_dfresnet as gemini
import wespeaker.models.res2net as res2net
import wespeaker.models.redimnet as redimnet


def get_speaker_model(model_name: str):
Expand All @@ -39,6 +40,8 @@ def get_speaker_model(model_name: str):
return getattr(res2net, model_name)
elif model_name.startswith("Gemini"):
return getattr(gemini, model_name)
elif model_name.startswith("ReDimNet"):
return getattr(redimnet, model_name)
else: # model_name error !!!
print(model_name + " not found !!!")
exit(1)

0 comments on commit c347bd0

Please sign in to comment.