Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow an arbitrary mask to be used in the self attention #8235

Merged
merged 10 commits into from
Nov 26, 2024
22 changes: 18 additions & 4 deletions monai/networks/blocks/selfattention.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from __future__ import annotations

from typing import Tuple, Union
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
Expand Down Expand Up @@ -154,10 +154,12 @@ def __init__(
)
self.input_size = input_size

def forward(self, x):
def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
"""
Args:
x (torch.Tensor): input tensor. B x (s_dim_1 * ... * s_dim_n) x C
attn_mask (torch.Tensor, optional): mask to apply to the attention matrix.
B x (s_dim_1 * ... * s_dim_n). Defaults to None.

Return:
torch.Tensor: B x (s_dim_1 * ... * s_dim_n) x C
Expand All @@ -176,7 +178,13 @@ def forward(self, x):

if self.use_flash_attention:
x = F.scaled_dot_product_attention(
query=q, key=k, value=v, scale=self.scale, dropout_p=self.dropout_rate, is_causal=self.causal
query=q,
key=k,
value=v,
attn_mask=attn_mask,
scale=self.scale,
dropout_p=self.dropout_rate,
is_causal=self.causal,
)
else:
att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale
Expand All @@ -186,10 +194,16 @@ def forward(self, x):
att_mat = self.rel_positional_embedding(x, att_mat, q)

if self.causal:
if attn_mask is not None:
raise ValueError("Causal attention does not support attention masks.")
att_mat = att_mat.masked_fill(self.causal_mask[:, :, : x.shape[-2], : x.shape[-2]] == 0, float("-inf"))

att_mat = att_mat.softmax(dim=-1)
if attn_mask is not None:
attn_mask = attn_mask.unsqueeze(1).unsqueeze(2)
attn_mask = attn_mask.expand(-1, self.num_heads, -1, -1)
att_mat = att_mat.masked_fill(attn_mask == 0, float("-inf"))

att_mat = att_mat.softmax(dim=-1)
if self.save_attn:
# no gradients and new tensor;
# https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
Expand Down
6 changes: 4 additions & 2 deletions monai/networks/blocks/transformerblock.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,10 @@ def __init__(
use_flash_attention=use_flash_attention,
)

def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor:
x = x + self.attn(self.norm1(x))
def forward(
self, x: torch.Tensor, context: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
x = x + self.attn(self.norm1(x), attn_mask=attn_mask)
if self.with_cross_attention:
x = x + self.cross_attn(self.norm_cross_attn(x), context=context)
x = x + self.mlp(self.norm2(x))
Expand Down
18 changes: 18 additions & 0 deletions tests/test_selfattention.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,24 @@ def test_causal(self):
# check upper triangular part of the attention matrix is zero
assert torch.triu(block.att_mat, diagonal=1).sum() == 0

def test_masked_selfattention(self):
n = 64
block = SABlock(hidden_size=128, num_heads=1, dropout_rate=0.1, sequence_length=16, save_attn=True)
input_shape = (1, n, 128)
# generate a mask randomly with zeros and ones of shape (1, n)
mask = torch.randint(0, 2, (1, n)).bool()
block(torch.randn(input_shape), attn_mask=mask)
att_mat = block.att_mat.squeeze()
# ensure all masked columns are zeros
assert torch.allclose(att_mat[:, ~mask.squeeze(0)], torch.zeros_like(att_mat[:, ~mask.squeeze(0)]))

def test_causal_and_mask(self):
with self.assertRaises(ValueError):
block = SABlock(hidden_size=128, num_heads=1, causal=True, sequence_length=64)
inputs = torch.randn(2, 64, 128)
mask = torch.randint(0, 2, (2, 64)).bool()
block(inputs, attn_mask=mask)

@skipUnless(has_einops, "Requires einops")
def test_access_attn_matrix(self):
# input format
Expand Down
Loading