From 338da6b22431a0ddf19c2355fc7a47db00fb6a44 Mon Sep 17 00:00:00 2001
From: Lucas Robinet
Date: Thu, 21 Nov 2024 16:41:35 +0100
Subject: [PATCH 1/7] enable the use of masked self-attention

Signed-off-by: Lucas Robinet
---
 monai/networks/blocks/selfattention.py | 26 ++++++++++++++++++++++++--
 tests/test_selfattention.py            | 15 +++++++++++++++
 2 files changed, 39 insertions(+), 2 deletions(-)

diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py
index ac96b077bd..f22c75308a 100644
--- a/monai/networks/blocks/selfattention.py
+++ b/monai/networks/blocks/selfattention.py
@@ -154,10 +154,12 @@ def __init__(
         )
         self.input_size = input_size
 
-    def forward(self, x):
+    def forward(self, x, attn_mask: torch.Tensor | None = None):
         """
         Args:
             x (torch.Tensor): input tensor. B x (s_dim_1 * ... * s_dim_n) x C
+            attn_mask (torch.Tensor, optional): mask to apply to the attention matrix.
+                Defaults to None. B x N_heads x (s_dim_1 * ... * s_dim_n) x (s_dim_1 * ... * s_dim_n).
 
         Return:
             torch.Tensor: B x (s_dim_1 * ... * s_dim_n) x C
@@ -176,7 +178,13 @@ def forward(self, x):
 
         if self.use_flash_attention:
             x = F.scaled_dot_product_attention(
-                query=q, key=k, value=v, scale=self.scale, dropout_p=self.dropout_rate, is_causal=self.causal
+                query=q,
+                key=k,
+                value=v,
+                attn_mask=attn_mask,
+                scale=self.scale,
+                dropout_p=self.dropout_rate,
+                is_causal=self.causal,
             )
         else:
             att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale
@@ -188,6 +196,10 @@ def forward(self, x):
             if self.causal:
                 att_mat = att_mat.masked_fill(self.causal_mask[:, :, : x.shape[-2], : x.shape[-2]] == 0, float("-inf"))
 
+            if attn_mask is not None:
+                attn_mask = attn_mask[:, None, :, None] * attn_mask[:, None, None, :]
+                att_mat.masked_fill_(~attn_mask, torch.finfo(att_mat.dtype).min)
+
             att_mat = att_mat.softmax(dim=-1)
 
             if self.save_attn:
@@ -203,3 +215,13 @@ def forward(self, x):
         x = self.out_proj(x)
         x = self.drop_output(x)
         return x
+
+
+if __name__ == "__main__":
+    sa = SABlock(128, 1)
+    x = torch.randn(1, 6, 128)
+    mask = torch.ones((1, 6), dtype=torch.bool)
+    mask[0][2] = False
+    print(mask)
+    out = sa(x, attn_mask=mask)
+    print(out.shape)
diff --git a/tests/test_selfattention.py b/tests/test_selfattention.py
index 88919fd8b1..09aa454842 100644
--- a/tests/test_selfattention.py
+++ b/tests/test_selfattention.py
@@ -122,6 +122,21 @@ def test_causal(self):
         # check upper triangular part of the attention matrix is zero
         assert torch.triu(block.att_mat, diagonal=1).sum() == 0
 
+    def test_masked_selfattention(self):
+        n = 4
+        block = SABlock(hidden_size=128, num_heads=1, dropout_rate=0.1, sequence_length=16, save_attn=True)
+        input_shape = (1, n, 128)
+        mask = torch.tensor([[1, 1, 1, 0]]).bool()
+        block(torch.randn(input_shape), attn_mask=mask)
+        att_mat = block.att_mat.squeeze(1)
+        # get the masked row and the remaining ones based on mask 0 values
+        rows_true = att_mat[mask, :]
+        rows_false = att_mat[~mask, :]
+        # check that in false rows every element is equal to 1/4
+        assert torch.allclose(rows_false, torch.ones_like(rows_false) / n)
+        # check that in true rows the mask column is zero
+        assert torch.allclose(rows_true[:, -1], torch.zeros_like(rows_true[:, -1]))
+
     @skipUnless(has_einops, "Requires einops")
     def test_access_attn_matrix(self):
         # input format

From c869d69aee0c9e217f1b77c05af41f51e93b2310 Mon Sep 17 00:00:00 2001
From: Lucas Robinet
Date: Thu, 21 Nov 2024 18:31:20 +0100
Subject: [PATCH 2/7] enable using a mask in the transformer forward pass

Signed-off-by: Lucas Robinet
---
 monai/networks/blocks/transformerblock.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/monai/networks/blocks/transformerblock.py b/monai/networks/blocks/transformerblock.py
index 05eb3b07ab..6f0da73e7b 100644
--- a/monai/networks/blocks/transformerblock.py
+++ b/monai/networks/blocks/transformerblock.py
@@ -90,8 +90,10 @@ def __init__(
             use_flash_attention=use_flash_attention,
         )
 
-    def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor:
-        x = x + self.attn(self.norm1(x))
+    def forward(
+        self, x: torch.Tensor, context: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
+        x = x + self.attn(self.norm1(x), attn_mask=attn_mask)
         if self.with_cross_attention:
             x = x + self.cross_attn(self.norm_cross_attn(x), context=context)
         x = x + self.mlp(self.norm2(x))

From 42db7c054938be9c31f0080e3b0d66f00a7d79bc Mon Sep 17 00:00:00 2001
From: Lucas Robinet
Date: Fri, 22 Nov 2024 11:04:09 +0100
Subject: [PATCH 3/7] Refactoring to be easier to follow

Signed-off-by: Lucas Robinet
---
 monai/networks/blocks/selfattention.py | 19 +++++--------------
 tests/test_selfattention.py            | 23 +++++++++++++----------
 2 files changed, 18 insertions(+), 24 deletions(-)

diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py
index f22c75308a..4edd213c35 100644
--- a/monai/networks/blocks/selfattention.py
+++ b/monai/networks/blocks/selfattention.py
@@ -159,7 +159,7 @@ def forward(self, x, attn_mask: torch.Tensor | None = None):
         Args:
             x (torch.Tensor): input tensor. B x (s_dim_1 * ... * s_dim_n) x C
             attn_mask (torch.Tensor, optional): mask to apply to the attention matrix.
-                Defaults to None. B x N_heads x (s_dim_1 * ... * s_dim_n) x (s_dim_1 * ... * s_dim_n).
+                B x (s_dim_1 * ... * s_dim_n). Defaults to None.
 
         Return:
             torch.Tensor: B x (s_dim_1 * ... * s_dim_n) x C
@@ -194,14 +194,15 @@ def forward(self, x, attn_mask: torch.Tensor | None = None):
                 att_mat = self.rel_positional_embedding(x, att_mat, q)
 
             if self.causal:
+                assert attn_mask is None, "Causal attention does not support attention masks."
                 att_mat = att_mat.masked_fill(self.causal_mask[:, :, : x.shape[-2], : x.shape[-2]] == 0, float("-inf"))
 
             if attn_mask is not None:
-                attn_mask = attn_mask[:, None, :, None] * attn_mask[:, None, None, :]
-                att_mat.masked_fill_(~attn_mask, torch.finfo(att_mat.dtype).min)
+                attn_mask = attn_mask.unsqueeze(1).unsqueeze(2)
+                attn_mask = attn_mask.expand(-1, self.num_heads, -1, -1)
+                att_mat = att_mat.masked_fill(attn_mask == 0, float("-inf"))
 
             att_mat = att_mat.softmax(dim=-1)
-
             if self.save_attn:
                 # no gradients and new tensor;
                 # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
@@ -215,13 +216,3 @@ def forward(self, x, attn_mask: torch.Tensor | None = None):
         x = self.out_proj(x)
         x = self.drop_output(x)
         return x
-
-
-if __name__ == "__main__":
-    sa = SABlock(128, 1)
-    x = torch.randn(1, 6, 128)
-    mask = torch.ones((1, 6), dtype=torch.bool)
-    mask[0][2] = False
-    print(mask)
-    out = sa(x, attn_mask=mask)
-    print(out.shape)
diff --git a/tests/test_selfattention.py b/tests/test_selfattention.py
index 09aa454842..f057e727c8 100644
--- a/tests/test_selfattention.py
+++ b/tests/test_selfattention.py
@@ -123,19 +123,22 @@ def test_causal(self):
         assert torch.triu(block.att_mat, diagonal=1).sum() == 0
 
     def test_masked_selfattention(self):
-        n = 4
+        n = 64
         block = SABlock(hidden_size=128, num_heads=1, dropout_rate=0.1, sequence_length=16, save_attn=True)
         input_shape = (1, n, 128)
-        mask = torch.tensor([[1, 1, 1, 0]]).bool()
+        # generate a mask randomly with zeros and ones of shape (1, n)
+        mask = torch.randint(0, 2, (1, n)).bool()
         block(torch.randn(input_shape), attn_mask=mask)
-        att_mat = block.att_mat.squeeze(1)
-        # get the masked row and the remaining ones based on mask 0 values
-        rows_true = att_mat[mask, :]
-        rows_false = att_mat[~mask, :]
-        # check that in false rows every element is equal to 1/4
-        assert torch.allclose(rows_false, torch.ones_like(rows_false) / n)
-        # check that in true rows the mask column is zero
-        assert torch.allclose(rows_true[:, -1], torch.zeros_like(rows_true[:, -1]))
+        att_mat = block.att_mat.squeeze()
+        # ensure all masked columns are zeros
+        assert torch.allclose(att_mat[:, ~mask.squeeze(0)], torch.zeros_like(att_mat[:, ~mask.squeeze(0)]))
+
+    def test_causal_and_mask(self):
+        with self.assertRaises(AssertionError):
+            block = SABlock(hidden_size=128, num_heads=1, causal=True, sequence_length=64)
+            inputs = torch.randn(2, 64, 128)
+            mask = torch.randint(0, 2, (2, 64)).bool()
+            block(inputs, attn_mask=mask)
 
     @skipUnless(has_einops, "Requires einops")
     def test_access_attn_matrix(self):
         # input format

From ea687ad6951c9939ef0d6e985db17ef668b24706 Mon Sep 17 00:00:00 2001
From: Lucas Robinet
Date: Fri, 22 Nov 2024 16:22:19 +0100
Subject: [PATCH 4/7] Update monai/networks/blocks/selfattention.py

Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
Signed-off-by: Lucas Robinet
---
 monai/networks/blocks/selfattention.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py
index 4edd213c35..bd4a47d78c 100644
--- a/monai/networks/blocks/selfattention.py
+++ b/monai/networks/blocks/selfattention.py
@@ -194,7 +194,8 @@ def forward(self, x, attn_mask: torch.Tensor | None = None):
                 att_mat = self.rel_positional_embedding(x, att_mat, q)
 
             if self.causal:
-                assert attn_mask is None, "Causal attention does not support attention masks."
+                if attn_mask is not None:
+                    raise ValueError("Causal attention does not support attention masks.")
                 att_mat = att_mat.masked_fill(self.causal_mask[:, :, : x.shape[-2], : x.shape[-2]] == 0, float("-inf"))
 
             if attn_mask is not None:

From 101e6cfecb9ad1c4db19009c183d3c7f4038a65f Mon Sep 17 00:00:00 2001
From: Lucas Robinet
Date: Fri, 22 Nov 2024 17:29:01 +0100
Subject: [PATCH 5/7] Update test_selfattention.py

Signed-off-by: Lucas Robinet
---
 tests/test_selfattention.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_selfattention.py b/tests/test_selfattention.py
index f057e727c8..338f1bf840 100644
--- a/tests/test_selfattention.py
+++ b/tests/test_selfattention.py
@@ -134,7 +134,7 @@ def test_masked_selfattention(self):
         assert torch.allclose(att_mat[:, ~mask.squeeze(0)], torch.zeros_like(att_mat[:, ~mask.squeeze(0)]))
 
     def test_causal_and_mask(self):
-        with self.assertRaises(AssertionError):
+        with self.assertRaises(ValueError):
             block = SABlock(hidden_size=128, num_heads=1, causal=True, sequence_length=64)
             inputs = torch.randn(2, 64, 128)
             mask = torch.randint(0, 2, (2, 64)).bool()
             block(inputs, attn_mask=mask)

From 337ddb35db67453b088ee7160f98b683fcf804fc Mon Sep 17 00:00:00 2001
From: Lucas Robinet
Date: Mon, 25 Nov 2024 14:10:27 +0100
Subject: [PATCH 6/7] Fixing typing for TorchScript version

Signed-off-by: Lucas Robinet
---
 monai/networks/blocks/selfattention.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py
index bd4a47d78c..29a818c34f 100644
--- a/monai/networks/blocks/selfattention.py
+++ b/monai/networks/blocks/selfattention.py
@@ -11,7 +11,7 @@
 
 from __future__ import annotations
 
-from typing import Tuple, Union
+from typing import Tuple, Union, Optional
 
 import torch
 import torch.nn as nn
@@ -154,7 +154,7 @@ def __init__(
         )
         self.input_size = input_size
 
-    def forward(self, x, attn_mask: torch.Tensor | None = None):
+    def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
         """
         Args:
             x (torch.Tensor): input tensor. B x (s_dim_1 * ... * s_dim_n) x C

From d8f6d6fd07a21c8f8f9a69d6b1dac3a535e5e5d5 Mon Sep 17 00:00:00 2001
From: Lucas Robinet
Date: Mon, 25 Nov 2024 15:18:35 +0100
Subject: [PATCH 7/7] Fixing: imports are incorrectly sorted

Signed-off-by: Lucas Robinet
---
 monai/networks/blocks/selfattention.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py
index 29a818c34f..86e1b1d3ae 100644
--- a/monai/networks/blocks/selfattention.py
+++ b/monai/networks/blocks/selfattention.py
@@ -11,7 +11,7 @@
 
 from __future__ import annotations
 
-from typing import Tuple, Union, Optional
+from typing import Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
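
Usage note (not part of the patch series): a minimal sketch of how the attn_mask
argument introduced above would be exercised, assuming these patches are applied on
top of MONAI's SABlock. Shapes follow the docstring from PATCH 3/7: the mask is
B x (s_dim_1 * ... * s_dim_n), with False/0 marking positions to exclude.

    import torch

    from monai.networks.blocks.selfattention import SABlock

    # single-head block; save_attn retains the attention matrix so the effect
    # of the mask can be inspected, mirroring the tests added in this series
    block = SABlock(hidden_size=128, num_heads=1, save_attn=True)

    x = torch.randn(2, 16, 128)                      # B x sequence x C
    attn_mask = torch.randint(0, 2, (2, 16)).bool()  # B x sequence; False entries are masked out

    out = block(x, attn_mask=attn_mask)              # B x 16 x 128
    print(out.shape, block.att_mat.shape)            # att_mat columns at masked positions are zero

Combining attn_mask with causal=True is rejected with a ValueError, as covered by
test_causal_and_mask in PATCH 5/7.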