Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow an arbitrary mask to be used in the self attention #8235

Merged
merged 10 commits into from
Nov 26, 2024

Conversation

Lucas-rbnt
Copy link
Contributor

Description

The aim of this PR is to enable the use of an arbitrary mask in the self attention module, which is very useful in the case of missing data or masked modeling.

Official torch implementations allow the use of an arbitrary mask, and in MONAI the use of a mask is also made possible with the causal argument. Here, it's just a generalization directly in the forward pass.

In the SABlock and TransformerBlock, it is now possible to input a boolean mask of size (BS, Seq_length).
Only the columns of the masked token are set to -inf and not the rows, as is rarely the case in common implementations. Masked tokens don't contribute to the gradient anyway.
In cases where causal attention is required, inputting a mask is not supported to avoid masks overlapping.

I haven't implemented the addition mask to the attention matrix, which allows you to use values other than -inf in certain cases, as may be the case here: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html

If you think it's relevant, it could be added.

Types of changes

  • Non-breaking change (fix or new feature that would not break existing functionality).
  • Breaking change (fix or new feature that would cause existing functionality to change).
  • New tests added to cover the changes.
  • Integration tests passed locally by running ./runtests.sh -f -u --net --coverage.
  • Quick tests passed locally by running ./runtests.sh --quick --unittests --disttests.
  • In-line docstrings updated.
  • Documentation updated, tested make html command in the docs/ folder.

@ericspod ericspod requested a review from KumoLiu November 22, 2024 13:57
@ericspod
Copy link
Member

I think this is fine with the minor proposed change.

Copy link
Contributor

@KumoLiu KumoLiu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR, looks good to me.

@KumoLiu
Copy link
Contributor

KumoLiu commented Nov 24, 2024

/build

@KumoLiu KumoLiu enabled auto-merge (squash) November 24, 2024 07:40
@KumoLiu
Copy link
Contributor

KumoLiu commented Nov 25, 2024

It seems there is a TorchScript conversion issue caused by this addition.

======================================================================
[2024-11-24T08:22:21.492Z] ERROR: test_script_0 (tests.test_selfattention.TestResBlock)
[2024-11-24T08:22:21.492Z] ----------------------------------------------------------------------
[2024-11-24T08:22:21.492Z] Traceback (most recent call last):
[2024-11-24T08:22:21.492Z]   File "/usr/local/lib/python3.9/dist-packages/parameterized/parameterized.py", line 620, in standalone_func
[2024-11-24T08:22:21.492Z]     return func(*(a + p.args), **p.kwargs, **kw)
[2024-11-24T08:22:21.492Z]   File "/home/jenkins/agent/workspace/MONAI-premerge/monai/tests/test_selfattention.py", line 215, in test_script
[2024-11-24T08:22:21.492Z]     test_script_save(net, test_data)
[2024-11-24T08:22:21.492Z]   File "/home/jenkins/agent/workspace/MONAI-premerge/monai/tests/utils.py", line 745, in test_script_save
[2024-11-24T08:22:21.492Z]     convert_to_torchscript(
[2024-11-24T08:22:21.492Z]   File "/home/jenkins/agent/workspace/MONAI-premerge/monai/monai/networks/utils.py", line 796, in convert_to_torchscript
[2024-11-24T08:22:21.492Z]     script_module = torch.jit.script(model, **kwargs)
[2024-11-24T08:22:21.492Z]   File "/usr/local/lib/python3.9/dist-packages/torch/jit/_script.py", line 1324, in script
[2024-11-24T08:22:21.492Z]     return torch.jit._recursive.create_script_module(
[2024-11-24T08:22:21.492Z]   File "/usr/local/lib/python3.9/dist-packages/torch/jit/_recursive.py", line 559, in create_script_module
[2024-11-24T08:22:21.492Z]     return create_script_module_impl(nn_module, concrete_type, stubs_fn)
[2024-11-24T08:22:21.492Z]   File "/usr/local/lib/python3.9/dist-packages/torch/jit/_recursive.py", line 636, in create_script_module_impl
[2024-11-24T08:22:21.492Z]     create_methods_and_properties_from_stubs(
[2024-11-24T08:22:21.492Z]   File "/usr/local/lib/python3.9/dist-packages/torch/jit/_recursive.py", line 469, in create_methods_and_properties_from_stubs
[2024-11-24T08:22:21.492Z]     concrete_type._create_methods_and_properties(
[2024-11-24T08:22:21.492Z] RuntimeError: 
[2024-11-24T08:22:21.492Z] Expression of type | cannot be used in a type expression:
[2024-11-24T08:22:21.492Z]   File "/home/jenkins/agent/workspace/MONAI-premerge/monai/monai/networks/blocks/selfattention.py", line 157
[2024-11-24T08:22:21.492Z]     def forward(self, x, attn_mask: torch.Tensor | None = None):
[2024-11-24T08:22:21.492Z]                                     ~~~~~~~~~~~~~~~~~~~ <--- HERE
[2024-11-24T08:22:21.492Z]         """
[2024-11-24T08:22:21.492Z]         Args:
[2024-11-24T08:22:21.492Z] 
[2024-11-24T08:22:21.492Z] 
[2024-11-24T08:22:21.492Z] ======================================================================
[2024-11-24T08:22:21.492Z] ERROR: test_script_1 (tests.test_selfattention.TestResBlock)
[2024-11-24T08:22:21.492Z] ----------------------------------------------------------------------
[2024-11-24T08:22:21.492Z] Traceback (most recent call last):
[2024-11-24T08:22:21.492Z]   File "/usr/local/lib/python3.9/dist-packages/parameterized/parameterized.py", line 620, in standalone_func
[2024-11-24T08:22:21.492Z]     return func(*(a + p.args), **p.kwargs, **kw)
[2024-11-24T08:22:21.492Z]   File "/home/jenkins/agent/workspace/MONAI-premerge/monai/tests/test_selfattention.py", line 215, in test_script
[2024-11-24T08:22:21.492Z]     test_script_save(net, test_data)
[2024-11-24T08:22:21.492Z]   File "/home/jenkins/agent/workspace/MONAI-premerge/monai/tests/utils.py", line 745, in test_script_save
[2024-11-24T08:22:21.492Z]     convert_to_torchscript(
[2024-11-24T08:22:21.492Z]   File "/home/jenkins/agent/workspace/MONAI-premerge/monai/monai/networks/utils.py", line 796, in convert_to_torchscript
[2024-11-24T08:22:21.492Z]     script_module = torch.jit.script(model, **kwargs)
[2024-11-24T08:22:21.492Z]   File "/usr/local/lib/python3.9/dist-packages/torch/jit/_script.py", line 1324, in script
[2024-11-24T08:22:21.492Z]     return torch.jit._recursive.create_script_module(
[2024-11-24T08:22:21.492Z]   File "/usr/local/lib/python3.9/dist-packages/torch/jit/_recursive.py", line 559, in create_script_module
[2024-11-24T08:22:21.492Z]     return create_script_module_impl(nn_module, concrete_type, stubs_fn)
[2024-11-24T08:22:21.492Z]   File "/usr/local/lib/python3.9/dist-packages/torch/jit/_recursive.py", line 636, in create_script_module_impl
[2024-11-24T08:22:21.492Z]     create_methods_and_properties_from_stubs(
[2024-11-24T08:22:21.492Z]   File "/usr/local/lib/python3.9/dist-packages/torch/jit/_recursive.py", line 469, in create_methods_and_properties_from_stubs
[2024-11-24T08:22:21.492Z]     concrete_type._create_methods_and_properties(
[2024-11-24T08:22:21.492Z] RuntimeError: 
[2024-11-24T08:22:21.492Z] Expression of type | cannot be used in a type expression:
[2024-11-24T08:22:21.492Z]   File "/home/jenkins/agent/workspace/MONAI-premerge/monai/monai/networks/blocks/selfattention.py", line 157
[2024-11-24T08:22:21.492Z]     def forward(self, x, attn_mask: torch.Tensor | None = None):
[2024-11-24T08:22:21.492Z]                                     ~~~~~~~~~~~~~~~~~~~ <--- HERE
[2024-11-24T08:22:21.492Z]         """
[2024-11-24T08:22:21.492Z]         Args:
[2024-11-24T08:22:21.492Z] 
[2024-11-24T08:22:21.492Z] 
[2024-11-24T08:22:21.492Z] ======================================================================
[2024-11-24T08:22:21.492Z] ERROR: test_script_2 (tests.test_selfattention.TestResBlock)
[2024-11-24T08:22:21.492Z] ----------------------------------------------------------------------
[2024-11-24T08:22:21.492Z] Traceback (most recent call last):
[2024-11-24T08:22:21.492Z]   File "/usr/local/lib/python3.9/dist-packages/parameterized/parameterized.py", line 620, in standalone_func
[2024-11-24T08:22:21.492Z]     return func(*(a + p.args), **p.kwargs, **kw)
[2024-11-24T08:22:21.492Z]   File "/home/jenkins/agent/workspace/MONAI-premerge/monai/tests/test_selfattention.py", line 215, in test_script
[2024-11-24T08:22:21.492Z]     test_script_save(net, test_data)
[2024-11-24T08:22:21.492Z]   File "/home/jenkins/agent/workspace/MONAI-premerge/monai/tests/utils.py", line 745, in test_script_save
[2024-11-24T08:22:21.492Z]     convert_to_torchscript(
[2024-11-24T08:22:21.492Z]   File "/home/jenkins/agent/workspace/MONAI-premerge/monai/monai/networks/utils.py", line 796, in convert_to_torchscript
[2024-11-24T08:22:21.492Z]     script_module = torch.jit.script(model, **kwargs)
[2024-11-24T08:22:21.492Z]   File "/usr/local/lib/python3.9/dist-packages/torch/jit/_script.py", line 1324, in script
[2024-11-24T08:22:21.493Z]     return torch.jit._recursive.create_script_module(
[2024-11-24T08:22:21.493Z]   File "/usr/local/lib/python3.9/dist-packages/torch/jit/_recursive.py", line 559, in create_script_module
[2024-11-24T08:22:21.493Z]     return create_script_module_impl(nn_module, concrete_type, stubs_fn)
[2024-11-24T08:22:21.493Z]   File "/usr/local/lib/python3.9/dist-packages/torch/jit/_recursive.py", line 636, in create_script_module_impl
[2024-11-24T08:22:21.493Z]     create_methods_and_properties_from_stubs(
[2024-11-24T08:22:21.493Z]   File "/usr/local/lib/python3.9/dist-packages/torch/jit/_recursive.py", line 469, in create_methods_and_properties_from_stubs
[2024-11-24T08:22:21.493Z]     concrete_type._create_methods_and_properties(
[2024-11-24T08:22:21.493Z] RuntimeError: Can't redefine method: forward on class: __torch__.monai.networks.blocks.selfattention.___torch_mangle_531.SABlock (of Python compilation unit at: 0x5eb6b10)
[2024-11-24T08:22:21.493Z] 
[2024-11-24T08:22:21.493Z] ======================================================================
[2024-11-24T08:22:21.493Z] ERROR: test_script_3 (tests.test_selfattention.TestResBlock)
[2024-11-24T08:22:21.493Z] ----------------------------------------------------------------------
[2024-11-24T08:22:21.493Z] Traceback (most recent call last):
[2024-11-24T08:22:21.493Z]   File "/usr/local/lib/python3.9/dist-packages/parameterized/parameterized.py", line 620, in standalone_func
[2024-11-24T08:22:21.493Z]     return func(*(a + p.args), **p.kwargs, **kw)
[2024-11-24T08:22:21.493Z]   File "/home/jenkins/agent/workspace/MONAI-premerge/monai/tests/test_selfattention.py", line 215, in test_script
[2024-11-24T08:22:21.493Z]     test_script_save(net, test_data)
[2024-11-24T08:22:21.493Z]   File "/home/jenkins/agent/workspace/MONAI-premerge/monai/tests/utils.py", line 745, in test_script_save
[2024-11-24T08:22:21.493Z]     convert_to_torchscript(
[2024-11-24T08:22:21.493Z]   File "/home/jenkins/agent/workspace/MONAI-premerge/monai/monai/networks/utils.py", line 796, in convert_to_torchscript
[2024-11-24T08:22:21.493Z]     script_module = torch.jit.script(model, **kwargs)
[2024-11-24T08:22:21.493Z]   File "/usr/local/lib/python3.9/dist-packages/torch/jit/_script.py", line 1324, in script
[2024-11-24T08:22:21.493Z]     return torch.jit._recursive.create_script_module(
[2024-11-24T08:22:21.493Z]   File "/usr/local/lib/python3.9/dist-packages/torch/jit/_recursive.py", line 559, in create_script_module
[2024-11-24T08:22:21.493Z]     return create_script_module_impl(nn_module, concrete_type, stubs_fn)
[2024-11-24T08:22:21.493Z]   File "/usr/local/lib/python3.9/dist-packages/torch/jit/_recursive.py", line 636, in create_script_module_impl
[2024-11-24T08:22:21.493Z]     create_methods_and_properties_from_stubs(
[2024-11-24T08:22:21.493Z]   File "/usr/local/lib/python3.9/dist-packages/torch/jit/_recursive.py", line 469, in create_methods_and_properties_from_stubs
[2024-11-24T08:22:21.493Z]     concrete_type._create_methods_and_properties(
[2024-11-24T08:22:21.493Z] RuntimeError: Can't redefine method: forward on class: __torch__.monai.networks.blocks.selfattention.SABlock (of Python compilation unit at: 0x5eb6b10)
[2024-11-24T08:22:21.493Z] 
[2024-11-24T08:22:21.493Z] ----------------------------------------------------------------------
[2024-11-24T08:22:21.493Z] Ran 15769 tests in 1527.821s
[2024-11-24T08:22:21.493Z] 
[2024-11-24T08:22:21.493Z] FAILED (errors=4, skipped=1100)

@Lucas-rbnt
Copy link
Contributor Author

It seems there is a TorchScript conversion issue caused by this addition.

It seems to be due to a typing error on the | character.

[2024-11-24T08:22:21.492Z] RuntimeError:
[2024-11-24T08:22:21.492Z] Expression of type | cannot be used in a type expression:

This typing method is reserved for python versions >3.10, but it seems that python 3.9 is being used in the test environment.

[2024-11-24T08:22:21.492Z] File "/usr/local/lib/python3.9/dist-packages/torch/jit/_recursive.py", line 636, in

I used this notation because I thought I'd already seen it in MONAI.
The problem will be solved either by upping the python version used, or by switching to an older typing syntax.

from typing import Optional
...
attn_mask: Optional[torch.tensor] = None

I can change, as you prefer!

@KumoLiu
Copy link
Contributor

KumoLiu commented Nov 25, 2024

I can change, as you prefer!

Yes, could you help convert this to the older typing syntax, as TorchScript does not support the | operator? Thanks.

auto-merge was automatically disabled November 25, 2024 13:10

Head branch was pushed to by a user without write access

@KumoLiu
Copy link
Contributor

KumoLiu commented Nov 25, 2024

/build

@KumoLiu KumoLiu merged commit 649c7c8 into Project-MONAI:dev Nov 26, 2024
28 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants