From 7e09329a46b2633be6b7405d3f33ea253034b02c Mon Sep 17 00:00:00 2001 From: microhum Date: Wed, 5 Jun 2024 14:46:21 +0700 Subject: [PATCH] embedding fix --- models/transformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/transformers.py b/models/transformers.py index 4c396d9..61f4628 100644 --- a/models/transformers.py +++ b/models/transformers.py @@ -196,7 +196,7 @@ def __init__(self): self.decoder_layers_parallel = clones(DecoderLayer(512, c(attn), c(attn), c(ff), dropout=0.0), 1) self.decoder_norm_parallel = nn.LayerNorm(512) if opts.ref_nshot == 52: - self.cls_embedding = nn.Embedding(92,512) + self.cls_embedding = nn.Embedding(96,512) else: self.cls_embedding = nn.Embedding(52,512) self.cls_token = nn.Parameter(torch.zeros(1, 1, 512))