From b64bba71e13dbec30de03b4695901995980fdde1 Mon Sep 17 00:00:00 2001
From: Duo Li
Date: Tue, 8 Aug 2023 17:47:08 +0800
Subject: [PATCH] Update attention.py

---
 mmpretrain/models/utils/attention.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mmpretrain/models/utils/attention.py b/mmpretrain/models/utils/attention.py
index e92f6054dd8..7069723760b 100644
--- a/mmpretrain/models/utils/attention.py
+++ b/mmpretrain/models/utils/attention.py
@@ -38,7 +38,7 @@ def scaled_dot_product_attention_pyimpl(query,
         attn_mask = torch.ones(
             query.size(-2), key.size(-2), dtype=torch.bool).tril(diagonal=0)
     if attn_mask is not None and attn_mask.dtype == torch.bool:
-        attn_mask = attn_mask.masked_fill(not attn_mask, -float('inf'))
+        attn_mask = attn_mask.masked_fill(~attn_mask, -float('inf'))
 
     attn_weight = query @ key.transpose(-2, -1) / scale
     if attn_mask is not None:
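
Note for reviewers: below is a minimal standalone sketch (not part of the
patch) of why this one-character change matters. `not tensor` invokes
`Tensor.__bool__`, which raises a RuntimeError for any tensor with more than
one element, while `~tensor` performs the intended elementwise logical NOT.
The example mask values and the float `float_mask` tensor are illustrative
assumptions, not code taken from `attention.py`.

    import torch

    # A small boolean attention mask; True marks positions that may attend.
    attn_mask = torch.tensor([[True, False],
                              [True, True]])

    # The old code effectively evaluated `not attn_mask`, which first calls
    # bool(attn_mask) and fails for multi-element tensors:
    #   RuntimeError: Boolean value of Tensor with more than one element
    #   is ambiguous
    try:
        bool(attn_mask)
    except RuntimeError as err:
        print(err)

    # `~attn_mask` negates elementwise, so masked_fill receives a mask of
    # the positions to block. A float tensor is used here so that -inf is a
    # valid fill value in this self-contained demo.
    float_mask = torch.zeros_like(attn_mask, dtype=torch.float)
    float_mask = float_mask.masked_fill(~attn_mask, -float('inf'))
    print(float_mask)
    # tensor([[0., -inf],
    #         [0., 0.]])

After softmax, the -inf entries become zero attention weight, which is the
behavior the patched line is meant to produce.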