From 49a1e343129909f7e01c67882a71c12beed3ef30 Mon Sep 17 00:00:00 2001 From: Yiheng Wang <68361391+yiheng-wang-nv@users.noreply.github.com> Date: Wed, 7 Aug 2024 16:28:31 +0800 Subject: [PATCH 01/24] 7994-enhance-mlpblock (#7995) Fixes #7994 . ### Description The current implementation does not support tuple input of "GEGLU" since it only change the out features of the first linear layer when the input is a string of "GEGLU". This PR enhances it, and also enable "vista3d" mode to support #7987 Tests are added to cover the changes. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Yiheng Wang Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- .github/workflows/pythonapp.yml | 1 + monai/networks/blocks/mlp.py | 20 ++++++++++++++++---- tests/test_mlp.py | 28 ++++++++++++++++++++++++++++ 3 files changed, 45 insertions(+), 4 deletions(-) diff --git a/.github/workflows/pythonapp.yml b/.github/workflows/pythonapp.yml index fe04f96a80..65f9a4dcf2 100644 --- a/.github/workflows/pythonapp.yml +++ b/.github/workflows/pythonapp.yml @@ -99,6 +99,7 @@ jobs: name: Install itk pre-release (Linux only) run: | python -m pip install --pre -U itk + find /opt/hostedtoolcache/* -maxdepth 0 ! -name 'Python' -exec rm -rf {} \; - name: Install the dependencies run: | python -m pip install --user --upgrade pip wheel diff --git a/monai/networks/blocks/mlp.py b/monai/networks/blocks/mlp.py index d3510b64d3..8771711d25 100644 --- a/monai/networks/blocks/mlp.py +++ b/monai/networks/blocks/mlp.py @@ -11,12 +11,15 @@ from __future__ import annotations +from typing import Union + import torch.nn as nn from monai.networks.layers import get_act_layer +from monai.networks.layers.factories import split_args from monai.utils import look_up_option -SUPPORTED_DROPOUT_MODE = {"vit", "swin"} +SUPPORTED_DROPOUT_MODE = {"vit", "swin", "vista3d"} class MLPBlock(nn.Module): @@ -39,7 +42,7 @@ def __init__( https://github.com/google-research/vision_transformer/blob/main/vit_jax/models.py#L87 "swin" corresponds to one instance as implemented in https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_mlp.py#L23 - + "vista3d" mode does not use dropout. """ @@ -48,15 +51,24 @@ def __init__( if not (0 <= dropout_rate <= 1): raise ValueError("dropout_rate should be between 0 and 1.") mlp_dim = mlp_dim or hidden_size - self.linear1 = nn.Linear(hidden_size, mlp_dim) if act != "GEGLU" else nn.Linear(hidden_size, mlp_dim * 2) + act_name, _ = split_args(act) + self.linear1 = nn.Linear(hidden_size, mlp_dim) if act_name != "GEGLU" else nn.Linear(hidden_size, mlp_dim * 2) self.linear2 = nn.Linear(mlp_dim, hidden_size) self.fn = get_act_layer(act) - self.drop1 = nn.Dropout(dropout_rate) + # Use Union[nn.Dropout, nn.Identity] for type annotations + self.drop1: Union[nn.Dropout, nn.Identity] + self.drop2: Union[nn.Dropout, nn.Identity] + dropout_opt = look_up_option(dropout_mode, SUPPORTED_DROPOUT_MODE) if dropout_opt == "vit": + self.drop1 = nn.Dropout(dropout_rate) self.drop2 = nn.Dropout(dropout_rate) elif dropout_opt == "swin": + self.drop1 = nn.Dropout(dropout_rate) self.drop2 = self.drop1 + elif dropout_opt == "vista3d": + self.drop1 = nn.Identity() + self.drop2 = nn.Identity() else: raise ValueError(f"dropout_mode should be one of {SUPPORTED_DROPOUT_MODE}") diff --git a/tests/test_mlp.py b/tests/test_mlp.py index 54f70d3318..2598d8877d 100644 --- a/tests/test_mlp.py +++ b/tests/test_mlp.py @@ -15,10 +15,12 @@ import numpy as np import torch +import torch.nn as nn from parameterized import parameterized from monai.networks import eval_mode from monai.networks.blocks.mlp import MLPBlock +from monai.networks.layers.factories import split_args TEST_CASE_MLP = [] for dropout_rate in np.linspace(0, 1, 4): @@ -31,6 +33,14 @@ ] TEST_CASE_MLP.append(test_case) +# test different activation layers +TEST_CASE_ACT = [] +for act in ["GELU", "GEGLU", ("GEGLU", {})]: # type: ignore + TEST_CASE_ACT.append([{"hidden_size": 128, "mlp_dim": 0, "act": act}, (2, 512, 128), (2, 512, 128)]) + +# test different dropout modes +TEST_CASE_DROP = [["vit", nn.Dropout], ["swin", nn.Dropout], ["vista3d", nn.Identity]] + class TestMLPBlock(unittest.TestCase): @@ -45,6 +55,24 @@ def test_ill_arg(self): with self.assertRaises(ValueError): MLPBlock(hidden_size=128, mlp_dim=512, dropout_rate=5.0) + @parameterized.expand(TEST_CASE_ACT) + def test_act(self, input_param, input_shape, expected_shape): + net = MLPBlock(**input_param) + with eval_mode(net): + result = net(torch.randn(input_shape)) + self.assertEqual(result.shape, expected_shape) + act_name, _ = split_args(input_param["act"]) + if act_name == "GEGLU": + self.assertEqual(net.linear1.in_features, net.linear1.out_features // 2) + else: + self.assertEqual(net.linear1.in_features, net.linear1.out_features) + + @parameterized.expand(TEST_CASE_DROP) + def test_dropout_mode(self, dropout_mode, dropout_layer): + net = MLPBlock(hidden_size=128, mlp_dim=512, dropout_rate=0.1, dropout_mode=dropout_mode) + self.assertTrue(isinstance(net.drop1, dropout_layer)) + self.assertTrue(isinstance(net.drop2, dropout_layer)) + if __name__ == "__main__": unittest.main() From 660891f37665874a38b289fdb8cd339800aff9cc Mon Sep 17 00:00:00 2001 From: Balamurali Date: Thu, 8 Aug 2024 00:58:12 -0700 Subject: [PATCH 02/24] Initial commit -- Adding calibration loss specific to segmentation (#7819) ### Description Model calibration has helped in developing reliable deep learning models. In this pull request, I have added a new loss function NACL (https://arxiv.org/abs/2303.06268, https://arxiv.org/abs/2401.14487) which has shown promising results for both discriminative and calibration in segmentation. **Future Plans:** Currently, MONAI has some of the alternative loss functions (Label Smoothing, and Focal Loss), but it doesn't have the calibration specific loss functions (https://arxiv.org/abs/2111.15430, https://arxiv.org/abs/2209.09641). Besides, these methods are better evaluated with calibration metrics, Expected Calibration Error (https://lightning.ai/docs/torchmetrics/stable/classification/calibration_error.html). ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Balamurali Signed-off-by: bala93 Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- docs/source/losses.rst | 5 ++ monai/losses/__init__.py | 1 + monai/losses/nacl_loss.py | 139 +++++++++++++++++++++++++++++++ tests/test_nacl_loss.py | 166 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 311 insertions(+) create mode 100644 monai/losses/nacl_loss.py create mode 100644 tests/test_nacl_loss.py diff --git a/docs/source/losses.rst b/docs/source/losses.rst index ba794af3eb..528ccd1173 100644 --- a/docs/source/losses.rst +++ b/docs/source/losses.rst @@ -93,6 +93,11 @@ Segmentation Losses .. autoclass:: SoftDiceclDiceLoss :members: +`NACLLoss` +~~~~~~~~~~ +.. autoclass:: NACLLoss + :members: + Registration Losses ------------------- diff --git a/monai/losses/__init__.py b/monai/losses/__init__.py index e937b53fa4..41935be204 100644 --- a/monai/losses/__init__.py +++ b/monai/losses/__init__.py @@ -37,6 +37,7 @@ from .hausdorff_loss import HausdorffDTLoss, LogHausdorffDTLoss from .image_dissimilarity import GlobalMutualInformationLoss, LocalNormalizedCrossCorrelationLoss from .multi_scale import MultiScaleLoss +from .nacl_loss import NACLLoss from .perceptual import PerceptualLoss from .spatial_mask import MaskedLoss from .spectral_loss import JukeboxLoss diff --git a/monai/losses/nacl_loss.py b/monai/losses/nacl_loss.py new file mode 100644 index 0000000000..3303e89bce --- /dev/null +++ b/monai/losses/nacl_loss.py @@ -0,0 +1,139 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Any + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.modules.loss import _Loss + +from monai.networks.layers import GaussianFilter, MeanFilter + + +class NACLLoss(_Loss): + """ + Neighbor-Aware Calibration Loss (NACL) is primarily developed for developing calibrated models in image segmentation. + NACL computes standard cross-entropy loss with a linear penalty that enforces the logit distributions + to match a soft class proportion of surrounding pixel. + + Murugesan, Balamurali, et al. + "Trust your neighbours: Penalty-based constraints for model calibration." + International Conference on Medical Image Computing and Computer-Assisted Intervention, MICCAI 2023. + https://arxiv.org/abs/2303.06268 + + Murugesan, Balamurali, et al. + "Neighbor-Aware Calibration of Segmentation Networks with Penalty-Based Constraints." + https://arxiv.org/abs/2401.14487 + """ + + def __init__( + self, + classes: int, + dim: int, + kernel_size: int = 3, + kernel_ops: str = "mean", + distance_type: str = "l1", + alpha: float = 0.1, + sigma: float = 1.0, + ) -> None: + """ + Args: + classes: number of classes + dim: dimension of data (supports 2d and 3d) + kernel_size: size of the spatial kernel + distance_type: l1/l2 distance between spatial kernel and predicted logits + alpha: weightage between cross entropy and logit constraint + sigma: sigma of gaussian + """ + + super().__init__() + + if kernel_ops not in ["mean", "gaussian"]: + raise ValueError("Kernel ops must be either mean or gaussian") + + if dim not in [2, 3]: + raise ValueError(f"Support 2d and 3d, got dim={dim}.") + + if distance_type not in ["l1", "l2"]: + raise ValueError(f"Distance type must be either L1 or L2, got {distance_type}") + + self.nc = classes + self.dim = dim + self.cross_entropy = nn.CrossEntropyLoss() + self.distance_type = distance_type + self.alpha = alpha + self.ks = kernel_size + self.svls_layer: Any + + if kernel_ops == "mean": + self.svls_layer = MeanFilter(spatial_dims=dim, size=kernel_size) + self.svls_layer.filter = self.svls_layer.filter / (kernel_size**dim) + if kernel_ops == "gaussian": + self.svls_layer = GaussianFilter(spatial_dims=dim, sigma=sigma) + + def get_constr_target(self, mask: torch.Tensor) -> torch.Tensor: + """ + Converts the mask to one hot represenation and is smoothened with the selected spatial filter. + + Args: + mask: the shape should be BH[WD]. + + Returns: + torch.Tensor: the shape would be BNH[WD], N being number of classes. + """ + rmask: torch.Tensor + + if self.dim == 2: + oh_labels = F.one_hot(mask.to(torch.int64), num_classes=self.nc).contiguous().permute(0, 3, 1, 2).float() + rmask = self.svls_layer(oh_labels) + + if self.dim == 3: + oh_labels = F.one_hot(mask.to(torch.int64), num_classes=self.nc).contiguous().permute(0, 4, 1, 2, 3).float() + rmask = self.svls_layer(oh_labels) + + return rmask + + def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: + """ + Computes standard cross-entropy loss and constraints it neighbor aware logit penalty. + + Args: + inputs: the shape should be BNH[WD], where N is the number of classes. + targets: the shape should be BH[WD]. + + Returns: + torch.Tensor: value of the loss. + + Example: + >>> import torch + >>> from monai.losses import NACLLoss + >>> B, N, H, W = 8, 3, 64, 64 + >>> input = torch.rand(B, N, H, W) + >>> target = torch.randint(0, N, (B, H, W)) + >>> criterion = NACLLoss(classes = N, dim = 2) + >>> loss = criterion(input, target) + """ + + loss_ce = self.cross_entropy(inputs, targets) + + utargets = self.get_constr_target(targets) + + if self.distance_type == "l1": + loss_conf = utargets.sub(inputs).abs_().mean() + elif self.distance_type == "l2": + loss_conf = utargets.sub(inputs).pow_(2).abs_().mean() + + loss: torch.Tensor = loss_ce + self.alpha * loss_conf + + return loss diff --git a/tests/test_nacl_loss.py b/tests/test_nacl_loss.py new file mode 100644 index 0000000000..51ec275cf4 --- /dev/null +++ b/tests/test_nacl_loss.py @@ -0,0 +1,166 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.losses import NACLLoss + +inputs = torch.tensor( + [ + [ + [ + [0.1498, 0.1158, 0.3996, 0.3730], + [0.2155, 0.1585, 0.8541, 0.8579], + [0.6640, 0.2424, 0.0774, 0.0324], + [0.0580, 0.2180, 0.3447, 0.8722], + ], + [ + [0.3908, 0.9366, 0.1779, 0.1003], + [0.9630, 0.6118, 0.4405, 0.7916], + [0.5782, 0.9515, 0.4088, 0.3946], + [0.7860, 0.3910, 0.0324, 0.9568], + ], + [ + [0.0759, 0.0238, 0.5570, 0.1691], + [0.2703, 0.7722, 0.1611, 0.6431], + [0.8051, 0.6596, 0.4121, 0.1125], + [0.5283, 0.6746, 0.5528, 0.7913], + ], + ] + ] +) +targets = torch.tensor([[[1, 1, 1, 1], [1, 1, 1, 0], [0, 0, 1, 0], [0, 1, 0, 0]]]) + +TEST_CASES = [ + [{"classes": 3, "dim": 2}, {"inputs": inputs, "targets": targets}, 1.1442], + [{"classes": 3, "dim": 2, "kernel_ops": "gaussian"}, {"inputs": inputs, "targets": targets}, 1.1433], + [{"classes": 3, "dim": 2, "kernel_ops": "gaussian", "sigma": 0.5}, {"inputs": inputs, "targets": targets}, 1.1469], + [{"classes": 3, "dim": 2, "distance_type": "l2"}, {"inputs": inputs, "targets": targets}, 1.1269], + [{"classes": 3, "dim": 2, "alpha": 0.2}, {"inputs": inputs, "targets": targets}, 1.1790], + [ + {"classes": 3, "dim": 3, "kernel_ops": "gaussian"}, + { + "inputs": torch.tensor( + [ + [ + [ + [ + [0.5977, 0.2767, 0.0591, 0.1675], + [0.4835, 0.3778, 0.8406, 0.3065], + [0.6047, 0.2860, 0.9742, 0.2013], + [0.9128, 0.8368, 0.6711, 0.4384], + ], + [ + [0.9797, 0.1863, 0.5584, 0.6652], + [0.2272, 0.2004, 0.7914, 0.4224], + [0.5097, 0.8818, 0.2581, 0.3495], + [0.1054, 0.5483, 0.3732, 0.3587], + ], + [ + [0.3060, 0.7066, 0.7922, 0.4689], + [0.1733, 0.8902, 0.6704, 0.2037], + [0.8656, 0.5561, 0.2701, 0.0092], + [0.1866, 0.7714, 0.6424, 0.9791], + ], + [ + [0.5067, 0.3829, 0.6156, 0.8985], + [0.5192, 0.8347, 0.2098, 0.2260], + [0.8887, 0.3944, 0.6400, 0.5345], + [0.1207, 0.3763, 0.5282, 0.7741], + ], + ], + [ + [ + [0.8499, 0.4759, 0.1964, 0.5701], + [0.3190, 0.1238, 0.2368, 0.9517], + [0.0797, 0.6185, 0.0135, 0.8672], + [0.4116, 0.1683, 0.1355, 0.0545], + ], + [ + [0.7533, 0.2658, 0.5955, 0.4498], + [0.9500, 0.2317, 0.2825, 0.9763], + [0.1493, 0.1558, 0.3743, 0.8723], + [0.1723, 0.7980, 0.8816, 0.0133], + ], + [ + [0.8426, 0.2666, 0.2077, 0.3161], + [0.1725, 0.8414, 0.1515, 0.2825], + [0.4882, 0.5159, 0.4120, 0.1585], + [0.2551, 0.9073, 0.7691, 0.9898], + ], + [ + [0.4633, 0.8717, 0.8537, 0.2899], + [0.3693, 0.7953, 0.1183, 0.4596], + [0.0087, 0.7925, 0.0989, 0.8385], + [0.8261, 0.6920, 0.7069, 0.4464], + ], + ], + [ + [ + [0.0110, 0.1608, 0.4814, 0.6317], + [0.0194, 0.9669, 0.3259, 0.0028], + [0.5674, 0.8286, 0.0306, 0.5309], + [0.3973, 0.8183, 0.0238, 0.1934], + ], + [ + [0.8947, 0.6629, 0.9439, 0.8905], + [0.0072, 0.1697, 0.4634, 0.0201], + [0.7184, 0.2424, 0.0820, 0.7504], + [0.3937, 0.1424, 0.4463, 0.5779], + ], + [ + [0.4123, 0.6227, 0.0523, 0.8826], + [0.0051, 0.0353, 0.3662, 0.7697], + [0.4867, 0.8986, 0.2510, 0.5316], + [0.1856, 0.2634, 0.9140, 0.9725], + ], + [ + [0.2041, 0.4248, 0.2371, 0.7256], + [0.2168, 0.5380, 0.4538, 0.7007], + [0.9013, 0.2623, 0.0739, 0.2998], + [0.1366, 0.5590, 0.2952, 0.4592], + ], + ], + ] + ] + ), + "targets": torch.tensor( + [ + [ + [[0, 1, 0, 1], [1, 2, 1, 0], [2, 1, 1, 1], [1, 1, 0, 1]], + [[2, 1, 0, 2], [1, 2, 0, 2], [1, 0, 1, 1], [1, 1, 0, 0]], + [[1, 0, 2, 1], [0, 2, 2, 1], [1, 0, 1, 1], [0, 0, 2, 1]], + [[2, 1, 1, 0], [1, 0, 0, 2], [1, 0, 2, 1], [2, 1, 0, 1]], + ] + ] + ), + }, + 1.15035, + ], +] + + +class TestNACLLoss(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_result(self, input_param, input_data, expected_val): + loss = NACLLoss(**input_param) + result = loss(**input_data) + np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4) + + +if __name__ == "__main__": + unittest.main() From 4a3117fe6bbfbd8e6e33d6bc5d36f8ae70135ddd Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Fri, 9 Aug 2024 13:27:38 +0800 Subject: [PATCH 03/24] Ensure location as tuple in wsireader (#8007) Fixes #8006 ### Description Ensure location as tuple in wsireader ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/data/wsi_datasets.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/data/wsi_datasets.py b/monai/data/wsi_datasets.py index 3488029a7a..2ee8c9d363 100644 --- a/monai/data/wsi_datasets.py +++ b/monai/data/wsi_datasets.py @@ -23,7 +23,7 @@ from monai.data.utils import iter_patch_position from monai.data.wsi_reader import BaseWSIReader, WSIReader from monai.transforms import ForegroundMask, Randomizable, apply_transform -from monai.utils import convert_to_dst_type, ensure_tuple_rep +from monai.utils import convert_to_dst_type, ensure_tuple, ensure_tuple_rep from monai.utils.enums import CommonKeys, ProbMapKeys, WSIPatchKeys __all__ = ["PatchWSIDataset", "SlidingPatchWSIDataset", "MaskedPatchWSIDataset"] @@ -123,9 +123,9 @@ def _get_label(self, sample: dict): def _get_location(self, sample: dict): if self.center_location: size = self._get_size(sample) - return [sample[WSIPatchKeys.LOCATION][i] - size[i] // 2 for i in range(len(size))] + return ensure_tuple(sample[WSIPatchKeys.LOCATION][i] - size[i] // 2 for i in range(len(size))) else: - return sample[WSIPatchKeys.LOCATION] + return ensure_tuple(sample[WSIPatchKeys.LOCATION]) def _get_level(self, sample: dict): if self.patch_level is None: From 0bb05d7bca54db8c3cf670b1d27883f4116c21dc Mon Sep 17 00:00:00 2001 From: ytl0623 Date: Fri, 9 Aug 2024 14:35:57 +0800 Subject: [PATCH 04/24] Add label smoothing param in DiceCELoss (#8000) Fixes #7957 ### Description In this modified version I made the following changes: 1. Added `label_smoothing: float = 0.0` parameter in `__init__` method, default value is 0.0. 2. When creating the `self.cross_entropy` instance, pass the `label_smoothing` parameter to `nn.CrossEntropyLoss`. 3. Added `self.label_smoothing = label_smoothing` in the `__init__` method to save this parameter for access when needed. For example: ``` from monai.losses import DiceCELoss # Before criterion = DiceCELoss() criterion.cross_entropy.label_smoothing = 0.1 # Now criterion = DiceCELoss(label_smoothing=0.1) ``` ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: ytl0623 Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/losses/dice.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/monai/losses/dice.py b/monai/losses/dice.py index 07a38d9572..44cde41e5d 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -666,6 +666,7 @@ def __init__( weight: torch.Tensor | None = None, lambda_dice: float = 1.0, lambda_ce: float = 1.0, + label_smoothing: float = 0.0, ) -> None: """ Args: @@ -704,6 +705,9 @@ def __init__( Defaults to 1.0. lambda_ce: the trade-off weight value for cross entropy loss. The value should be no less than 0.0. Defaults to 1.0. + label_smoothing: a value in [0, 1] range. If > 0, the labels are smoothed + by the given factor to reduce overfitting. + Defaults to 0.0. """ super().__init__() @@ -728,7 +732,12 @@ def __init__( batch=batch, weight=dice_weight, ) - self.cross_entropy = nn.CrossEntropyLoss(weight=weight, reduction=reduction) + if pytorch_after(1, 10): + self.cross_entropy = nn.CrossEntropyLoss( + weight=weight, reduction=reduction, label_smoothing=label_smoothing + ) + else: + self.cross_entropy = nn.CrossEntropyLoss(weight=weight, reduction=reduction) self.binary_cross_entropy = nn.BCEWithLogitsLoss(pos_weight=weight, reduction=reduction) if lambda_dice < 0.0: raise ValueError("lambda_dice should be no less than 0.0.") From 069519dfc89e984dd10d997497e5c7c1aa963b5c Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Fri, 9 Aug 2024 16:00:04 +0800 Subject: [PATCH 05/24] Add include_fc and `use_combined_linear` argument in the `SABlock` (#7996) Fixes #7991 Fixes #7992 ### Description Add `include_fc` and `use_combined_linear` argument in the `SABlock`. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> --- monai/networks/blocks/crossattention.py | 33 ++-- monai/networks/blocks/selfattention.py | 60 ++++--- monai/networks/blocks/spatialattention.py | 11 +- monai/networks/blocks/transformerblock.py | 8 +- monai/networks/nets/autoencoderkl.py | 73 +++++--- monai/networks/nets/controlnet.py | 36 ++-- monai/networks/nets/diffusion_model_unet.py | 160 +++++++++++++++--- monai/networks/nets/spade_autoencoderkl.py | 22 +++ .../nets/spade_diffusion_model_unet.py | 41 ++++- monai/networks/nets/transformer.py | 34 ++-- tests/test_crossattention.py | 17 +- tests/test_selfattention.py | 66 ++++++-- tests/test_unetr.py | 2 +- 13 files changed, 426 insertions(+), 137 deletions(-) diff --git a/monai/networks/blocks/crossattention.py b/monai/networks/blocks/crossattention.py index daa5abdd56..bdecf63168 100644 --- a/monai/networks/blocks/crossattention.py +++ b/monai/networks/blocks/crossattention.py @@ -59,13 +59,12 @@ def __init__( causal (bool, optional): whether to use causal attention. sequence_length (int, optional): if causal is True, it is necessary to specify the sequence length. rel_pos_embedding (str, optional): Add relative positional embeddings to the attention map. For now only - "decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported. + "decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported. input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative positional - parameter size. + parameter size. attention_dtype: cast attention operations to this dtype. - use_flash_attention: if True, use Pytorch's inbuilt - flash attention for a memory efficient attention mechanism (see - https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ super().__init__() @@ -109,7 +108,7 @@ def __init__( self.to_v = nn.Linear(self.context_input_size, inner_size, bias=qkv_bias) self.input_rearrange = Rearrange("b h (l d) -> b l h d", l=num_heads) - self.out_rearrange = Rearrange("b h l d -> b l (h d)") + self.out_rearrange = Rearrange("b l h d -> b h (l d)") self.drop_output = nn.Dropout(dropout_rate) self.drop_weights = nn.Dropout(dropout_rate) self.dropout_rate = dropout_rate @@ -152,31 +151,20 @@ def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None): # calculate query, key, values for all heads in batch and move head forward to be the batch dim b, t, c = x.size() # batch size, sequence length, embedding dimensionality (hidden_size) - q = self.to_q(x) + q = self.input_rearrange(self.to_q(x)) kv = context if context is not None else x _, kv_t, _ = kv.size() - k = self.to_k(kv) - v = self.to_v(kv) + k = self.input_rearrange(self.to_k(kv)) + v = self.input_rearrange(self.to_v(kv)) if self.attention_dtype is not None: q = q.to(self.attention_dtype) k = k.to(self.attention_dtype) - q = q.view(b, t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, t, hs) # - k = k.view(b, kv_t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, kv_t, hs) - v = v.view(b, kv_t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, kv_t, hs) - if self.use_flash_attention: x = torch.nn.functional.scaled_dot_product_attention( - query=q.transpose(1, 2), - key=k.transpose(1, 2), - value=v.transpose(1, 2), - scale=self.scale, - dropout_p=self.dropout_rate, - is_causal=self.causal, - ).transpose( - 1, 2 - ) # Back to (b, nh, t, hs) + query=q, key=k, value=v, scale=self.scale, dropout_p=self.dropout_rate, is_causal=self.causal + ) else: att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale # apply relative positional embedding if defined @@ -195,6 +183,7 @@ def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None): att_mat = self.drop_weights(att_mat) x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v) + x = self.out_rearrange(x) x = self.out_proj(x) x = self.drop_output(x) diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py index 124c00acc6..ac96b077bd 100644 --- a/monai/networks/blocks/selfattention.py +++ b/monai/networks/blocks/selfattention.py @@ -11,7 +11,7 @@ from __future__ import annotations -from typing import Optional, Tuple +from typing import Tuple, Union import torch import torch.nn as nn @@ -40,9 +40,11 @@ def __init__( hidden_input_size: int | None = None, causal: bool = False, sequence_length: int | None = None, - rel_pos_embedding: Optional[str] = None, - input_size: Optional[Tuple] = None, - attention_dtype: Optional[torch.dtype] = None, + rel_pos_embedding: str | None = None, + input_size: Tuple | None = None, + attention_dtype: torch.dtype | None = None, + include_fc: bool = True, + use_combined_linear: bool = True, use_flash_attention: bool = False, ) -> None: """ @@ -61,9 +63,10 @@ def __init__( input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative positional parameter size. attention_dtype: cast attention operations to this dtype. - use_flash_attention: if True, use Pytorch's inbuilt - flash attention for a memory efficient attention mechanism (see - https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to True. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ @@ -105,9 +108,22 @@ def __init__( self.hidden_input_size = hidden_input_size if hidden_input_size else hidden_size self.out_proj = nn.Linear(self.inner_dim, self.hidden_input_size) - self.qkv = nn.Linear(self.hidden_input_size, self.inner_dim * 3, bias=qkv_bias) - self.input_rearrange = Rearrange("b h (qkv l d) -> qkv b l h d", qkv=3, l=num_heads) - self.out_rearrange = Rearrange("b h l d -> b l (h d)") + self.qkv: Union[nn.Linear, nn.Identity] + self.to_q: Union[nn.Linear, nn.Identity] + self.to_k: Union[nn.Linear, nn.Identity] + self.to_v: Union[nn.Linear, nn.Identity] + + if use_combined_linear: + self.qkv = nn.Linear(self.hidden_input_size, self.inner_dim * 3, bias=qkv_bias) + self.to_q = self.to_k = self.to_v = nn.Identity() # add to enable torchscript + self.input_rearrange = Rearrange("b h (qkv l d) -> qkv b l h d", qkv=3, l=num_heads) + else: + self.to_q = nn.Linear(self.hidden_input_size, self.inner_dim, bias=qkv_bias) + self.to_k = nn.Linear(self.hidden_input_size, self.inner_dim, bias=qkv_bias) + self.to_v = nn.Linear(self.hidden_input_size, self.inner_dim, bias=qkv_bias) + self.qkv = nn.Identity() # add to enable torchscript + self.input_rearrange = Rearrange("b h (l d) -> b l h d", l=num_heads) + self.out_rearrange = Rearrange("b l h d -> b h (l d)") self.drop_output = nn.Dropout(dropout_rate) self.drop_weights = nn.Dropout(dropout_rate) self.dropout_rate = dropout_rate @@ -117,6 +133,8 @@ def __init__( self.attention_dtype = attention_dtype self.causal = causal self.sequence_length = sequence_length + self.include_fc = include_fc + self.use_combined_linear = use_combined_linear self.use_flash_attention = use_flash_attention if causal and sequence_length is not None: @@ -144,8 +162,13 @@ def forward(self, x): Return: torch.Tensor: B x (s_dim_1 * ... * s_dim_n) x C """ - output = self.input_rearrange(self.qkv(x)) - q, k, v = output[0], output[1], output[2] + if self.use_combined_linear: + output = self.input_rearrange(self.qkv(x)) + q, k, v = output[0], output[1], output[2] + else: + q = self.input_rearrange(self.to_q(x)) + k = self.input_rearrange(self.to_k(x)) + v = self.input_rearrange(self.to_v(x)) if self.attention_dtype is not None: q = q.to(self.attention_dtype) @@ -153,13 +176,8 @@ def forward(self, x): if self.use_flash_attention: x = F.scaled_dot_product_attention( - query=q.transpose(1, 2), - key=k.transpose(1, 2), - value=v.transpose(1, 2), - scale=self.scale, - dropout_p=self.dropout_rate, - is_causal=self.causal, - ).transpose(1, 2) + query=q, key=k, value=v, scale=self.scale, dropout_p=self.dropout_rate, is_causal=self.causal + ) else: att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale @@ -179,7 +197,9 @@ def forward(self, x): att_mat = self.drop_weights(att_mat) x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v) + x = self.out_rearrange(x) - x = self.out_proj(x) + if self.include_fc: + x = self.out_proj(x) x = self.drop_output(x) return x diff --git a/monai/networks/blocks/spatialattention.py b/monai/networks/blocks/spatialattention.py index 1cfafb1585..665442b55e 100644 --- a/monai/networks/blocks/spatialattention.py +++ b/monai/networks/blocks/spatialattention.py @@ -32,8 +32,13 @@ class SpatialAttentionBlock(nn.Module): spatial_dims: number of spatial dimensions, could be 1, 2, or 3. num_channels: number of input channels. Must be divisible by num_head_channels. num_head_channels: number of channels per head. + norm_num_groups: Number of groups for the group norm layer. + norm_eps: Epsilon for the normalization. attention_dtype: cast attention operations to this dtype. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to False. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ @@ -45,6 +50,8 @@ def __init__( norm_num_groups: int = 32, norm_eps: float = 1e-6, attention_dtype: Optional[torch.dtype] = None, + include_fc: bool = True, + use_combined_linear: bool = False, use_flash_attention: bool = False, ) -> None: super().__init__() @@ -60,6 +67,8 @@ def __init__( num_heads=num_heads, qkv_bias=True, attention_dtype=attention_dtype, + include_fc=include_fc, + use_combined_linear=use_combined_linear, use_flash_attention=use_flash_attention, ) diff --git a/monai/networks/blocks/transformerblock.py b/monai/networks/blocks/transformerblock.py index 28d9c563ac..05eb3b07ab 100644 --- a/monai/networks/blocks/transformerblock.py +++ b/monai/networks/blocks/transformerblock.py @@ -37,6 +37,8 @@ def __init__( sequence_length: int | None = None, with_cross_attention: bool = False, use_flash_attention: bool = False, + include_fc: bool = True, + use_combined_linear: bool = True, ) -> None: """ Args: @@ -47,7 +49,9 @@ def __init__( qkv_bias(bool, optional): apply bias term for the qkv linear layer. Defaults to False. save_attn (bool, optional): to make accessible the attention matrix. Defaults to False. use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism - (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to True. """ @@ -69,6 +73,8 @@ def __init__( save_attn=save_attn, causal=causal, sequence_length=sequence_length, + include_fc=include_fc, + use_combined_linear=use_combined_linear, use_flash_attention=use_flash_attention, ) self.norm2 = nn.LayerNorm(hidden_size) diff --git a/monai/networks/nets/autoencoderkl.py b/monai/networks/nets/autoencoderkl.py index 35d80e0565..836027796f 100644 --- a/monai/networks/nets/autoencoderkl.py +++ b/monai/networks/nets/autoencoderkl.py @@ -157,6 +157,10 @@ class Encoder(nn.Module): norm_eps: epsilon for the normalization. attention_levels: indicate which level from num_channels contain an attention block. with_nonlocal_attn: if True use non-local attention block. + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to False. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ def __init__( @@ -170,6 +174,9 @@ def __init__( norm_eps: float, attention_levels: Sequence[bool], with_nonlocal_attn: bool = True, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() self.spatial_dims = spatial_dims @@ -220,6 +227,9 @@ def __init__( num_channels=input_channel, norm_num_groups=norm_num_groups, norm_eps=norm_eps, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) ) @@ -243,6 +253,9 @@ def __init__( num_channels=channels[-1], norm_num_groups=norm_num_groups, norm_eps=norm_eps, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) ) blocks.append( @@ -291,6 +304,10 @@ class Decoder(nn.Module): attention_levels: indicate which level from num_channels contain an attention block. with_nonlocal_attn: if True use non-local attention block. use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder. + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to False. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ def __init__( @@ -305,6 +322,9 @@ def __init__( attention_levels: Sequence[bool], with_nonlocal_attn: bool = True, use_convtranspose: bool = False, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() self.spatial_dims = spatial_dims @@ -350,6 +370,9 @@ def __init__( num_channels=reversed_block_out_channels[0], norm_num_groups=norm_num_groups, norm_eps=norm_eps, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) ) blocks.append( @@ -389,6 +412,9 @@ def __init__( num_channels=block_in_ch, norm_num_groups=norm_num_groups, norm_eps=norm_eps, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) ) @@ -463,6 +489,10 @@ class AutoencoderKL(nn.Module): with_decoder_nonlocal_attn: if True use non-local attention block in the decoder. use_checkpoint: if True, use activation checkpoint to save memory. use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder. + include_fc: whether to include the final linear layer in the attention block. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection in the attention block, default to False. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ def __init__( @@ -480,6 +510,9 @@ def __init__( with_decoder_nonlocal_attn: bool = True, use_checkpoint: bool = False, use_convtranspose: bool = False, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() @@ -509,6 +542,9 @@ def __init__( norm_eps=norm_eps, attention_levels=attention_levels, with_nonlocal_attn=with_encoder_nonlocal_attn, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) self.decoder = Decoder( spatial_dims=spatial_dims, @@ -521,6 +557,9 @@ def __init__( attention_levels=attention_levels, with_nonlocal_attn=with_decoder_nonlocal_attn, use_convtranspose=use_convtranspose, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) self.quant_conv_mu = Convolution( spatial_dims=spatial_dims, @@ -665,27 +704,18 @@ def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None: # copy over all matching keys for k in new_state_dict: if k in old_state_dict: - new_state_dict[k] = old_state_dict[k] + new_state_dict[k] = old_state_dict.pop(k) # fix the attention blocks - attention_blocks = [k.replace(".attn.qkv.weight", "") for k in new_state_dict if "attn.qkv.weight" in k] + attention_blocks = [k.replace(".attn.to_q.weight", "") for k in new_state_dict if "attn.to_q.weight" in k] for block in attention_blocks: - new_state_dict[f"{block}.attn.qkv.weight"] = torch.cat( - [ - old_state_dict[f"{block}.to_q.weight"], - old_state_dict[f"{block}.to_k.weight"], - old_state_dict[f"{block}.to_v.weight"], - ], - dim=0, - ) - new_state_dict[f"{block}.attn.qkv.bias"] = torch.cat( - [ - old_state_dict[f"{block}.to_q.bias"], - old_state_dict[f"{block}.to_k.bias"], - old_state_dict[f"{block}.to_v.bias"], - ], - dim=0, - ) + new_state_dict[f"{block}.attn.to_q.weight"] = old_state_dict.pop(f"{block}.to_q.weight") + new_state_dict[f"{block}.attn.to_k.weight"] = old_state_dict.pop(f"{block}.to_k.weight") + new_state_dict[f"{block}.attn.to_v.weight"] = old_state_dict.pop(f"{block}.to_v.weight") + new_state_dict[f"{block}.attn.to_q.bias"] = old_state_dict.pop(f"{block}.to_q.bias") + new_state_dict[f"{block}.attn.to_k.bias"] = old_state_dict.pop(f"{block}.to_k.bias") + new_state_dict[f"{block}.attn.to_v.bias"] = old_state_dict.pop(f"{block}.to_v.bias") + # old version did not have a projection so set these to the identity new_state_dict[f"{block}.attn.out_proj.weight"] = torch.eye( new_state_dict[f"{block}.attn.out_proj.weight"].shape[0] @@ -698,5 +728,8 @@ def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None: for k in new_state_dict: if "postconv" in k: old_name = k.replace("postconv", "conv") - new_state_dict[k] = old_state_dict[old_name] - self.load_state_dict(new_state_dict) + new_state_dict[k] = old_state_dict.pop(old_name) + if verbose: + # print all remaining keys in old_state_dict + print("remaining keys in old_state_dict:", old_state_dict.keys()) + self.load_state_dict(new_state_dict, strict=True) diff --git a/monai/networks/nets/controlnet.py b/monai/networks/nets/controlnet.py index ed3654733d..8b08eaae10 100644 --- a/monai/networks/nets/controlnet.py +++ b/monai/networks/nets/controlnet.py @@ -143,6 +143,10 @@ class ControlNet(nn.Module): upcast_attention: if True, upcast attention operations to full precision. conditioning_embedding_in_channels: number of input channels for the conditioning embedding. conditioning_embedding_num_channels: number of channels for the blocks in the conditioning embedding. + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to True. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ def __init__( @@ -163,6 +167,9 @@ def __init__( upcast_attention: bool = False, conditioning_embedding_in_channels: int = 1, conditioning_embedding_num_channels: Sequence[int] = (16, 32, 96, 256), + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() if with_conditioning is True and cross_attention_dim is None: @@ -282,6 +289,9 @@ def __init__( transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) self.down_blocks.append(down_block) @@ -326,6 +336,9 @@ def __init__( transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) controlnet_block = Convolution( @@ -441,25 +454,16 @@ def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None: # copy over all matching keys for k in new_state_dict: if k in old_state_dict: - new_state_dict[k] = old_state_dict[k] + new_state_dict[k] = old_state_dict.pop(k) # fix the attention blocks - attention_blocks = [k.replace(".attn1.qkv.weight", "") for k in new_state_dict if "attn1.qkv.weight" in k] + attention_blocks = [k.replace(".out_proj.weight", "") for k in new_state_dict if "out_proj.weight" in k] for block in attention_blocks: - new_state_dict[f"{block}.attn1.qkv.weight"] = torch.cat( - [ - old_state_dict[f"{block}.attn1.to_q.weight"], - old_state_dict[f"{block}.attn1.to_k.weight"], - old_state_dict[f"{block}.attn1.to_v.weight"], - ], - dim=0, - ) - # projection - new_state_dict[f"{block}.attn1.out_proj.weight"] = old_state_dict[f"{block}.attn1.to_out.0.weight"] - new_state_dict[f"{block}.attn1.out_proj.bias"] = old_state_dict[f"{block}.attn1.to_out.0.bias"] - - new_state_dict[f"{block}.attn2.out_proj.weight"] = old_state_dict[f"{block}.attn2.to_out.0.weight"] - new_state_dict[f"{block}.attn2.out_proj.bias"] = old_state_dict[f"{block}.attn2.to_out.0.bias"] + new_state_dict[f"{block}.out_proj.weight"] = old_state_dict.pop(f"{block}.to_out.0.weight") + new_state_dict[f"{block}.out_proj.bias"] = old_state_dict.pop(f"{block}.to_out.0.bias") + if verbose: + # print all remaining keys in old_state_dict + print("remaining keys in old_state_dict:", old_state_dict.keys()) self.load_state_dict(new_state_dict) diff --git a/monai/networks/nets/diffusion_model_unet.py b/monai/networks/nets/diffusion_model_unet.py index a885339d0d..f57fe251d2 100644 --- a/monai/networks/nets/diffusion_model_unet.py +++ b/monai/networks/nets/diffusion_model_unet.py @@ -67,7 +67,9 @@ class DiffusionUNetTransformerBlock(nn.Module): cross_attention_dim: size of the context vector for cross attention. upcast_attention: if True, upcast attention operations to full precision. use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism - (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to False. """ @@ -80,6 +82,8 @@ def __init__( cross_attention_dim: int | None = None, upcast_attention: bool = False, use_flash_attention: bool = False, + include_fc: bool = True, + use_combined_linear: bool = False, ) -> None: super().__init__() self.attn1 = SABlock( @@ -89,6 +93,8 @@ def __init__( dim_head=num_head_channels, dropout_rate=dropout, attention_dtype=torch.float if upcast_attention else None, + include_fc=include_fc, + use_combined_linear=use_combined_linear, use_flash_attention=use_flash_attention, ) self.ff = MLPBlock(hidden_size=num_channels, mlp_dim=num_channels * 4, act="GEGLU", dropout_rate=dropout) @@ -134,6 +140,11 @@ class SpatialTransformer(nn.Module): norm_eps: epsilon for the normalization. cross_attention_dim: number of context dimensions to use. upcast_attention: if True, upcast attention operations to full precision. + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to False. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). + """ def __init__( @@ -148,6 +159,9 @@ def __init__( norm_eps: float = 1e-6, cross_attention_dim: int | None = None, upcast_attention: bool = False, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() self.spatial_dims = spatial_dims @@ -175,6 +189,9 @@ def __init__( dropout=dropout, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) for _ in range(num_layers) ] @@ -529,6 +546,10 @@ class AttnDownBlock(nn.Module): resblock_updown: if True use residual blocks for downsampling. downsample_padding: padding used in the downsampling block. num_head_channels: number of channels in each attention head. + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to False. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ def __init__( @@ -544,6 +565,9 @@ def __init__( resblock_updown: bool = False, downsample_padding: int = 1, num_head_channels: int = 1, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() self.resblock_updown = resblock_updown @@ -570,6 +594,9 @@ def __init__( num_head_channels=num_head_channels, norm_num_groups=norm_num_groups, norm_eps=norm_eps, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) ) @@ -636,7 +663,11 @@ class CrossAttnDownBlock(nn.Module): transformer_num_layers: number of layers of Transformer blocks to use. cross_attention_dim: number of context dimensions to use. upcast_attention: if True, upcast attention operations to full precision. - dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers + dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers. + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to False. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ def __init__( @@ -656,6 +687,9 @@ def __init__( cross_attention_dim: int | None = None, upcast_attention: bool = False, dropout_cattn: float = 0.0, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() self.resblock_updown = resblock_updown @@ -688,6 +722,9 @@ def __init__( cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, dropout=dropout_cattn, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) ) @@ -745,6 +782,10 @@ class AttnMidBlock(nn.Module): norm_num_groups: number of groups for the group normalization. norm_eps: epsilon for the group normalization. num_head_channels: number of channels in each attention head. + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to False. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ def __init__( @@ -755,6 +796,9 @@ def __init__( norm_num_groups: int = 32, norm_eps: float = 1e-6, num_head_channels: int = 1, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() @@ -772,6 +816,9 @@ def __init__( num_head_channels=num_head_channels, norm_num_groups=norm_num_groups, norm_eps=norm_eps, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) self.resnet_2 = DiffusionUNetResnetBlock( @@ -808,6 +855,10 @@ class CrossAttnMidBlock(nn.Module): transformer_num_layers: number of layers of Transformer blocks to use. cross_attention_dim: number of context dimensions to use. upcast_attention: if True, upcast attention operations to full precision. + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to False. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ def __init__( @@ -822,6 +873,9 @@ def __init__( cross_attention_dim: int | None = None, upcast_attention: bool = False, dropout_cattn: float = 0.0, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() @@ -844,6 +898,9 @@ def __init__( cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, dropout=dropout_cattn, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) self.resnet_2 = DiffusionUNetResnetBlock( spatial_dims=spatial_dims, @@ -989,6 +1046,10 @@ class AttnUpBlock(nn.Module): add_upsample: if True add downsample block. resblock_updown: if True use residual blocks for upsampling. num_head_channels: number of channels in each attention head. + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to False. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ def __init__( @@ -1004,6 +1065,9 @@ def __init__( add_upsample: bool = True, resblock_updown: bool = False, num_head_channels: int = 1, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() self.resblock_updown = resblock_updown @@ -1032,6 +1096,9 @@ def __init__( num_head_channels=num_head_channels, norm_num_groups=norm_num_groups, norm_eps=norm_eps, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) ) @@ -1116,7 +1183,11 @@ class CrossAttnUpBlock(nn.Module): transformer_num_layers: number of layers of Transformer blocks to use. cross_attention_dim: number of context dimensions to use. upcast_attention: if True, upcast attention operations to full precision. - dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers + dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers. + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to False. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ def __init__( @@ -1136,6 +1207,9 @@ def __init__( cross_attention_dim: int | None = None, upcast_attention: bool = False, dropout_cattn: float = 0.0, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() self.resblock_updown = resblock_updown @@ -1169,6 +1243,9 @@ def __init__( cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, dropout=dropout_cattn, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) ) @@ -1250,6 +1327,9 @@ def get_down_block( cross_attention_dim: int | None, upcast_attention: bool = False, dropout_cattn: float = 0.0, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> nn.Module: if with_attn: return AttnDownBlock( @@ -1263,6 +1343,9 @@ def get_down_block( add_downsample=add_downsample, resblock_updown=resblock_updown, num_head_channels=num_head_channels, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) elif with_cross_attn: return CrossAttnDownBlock( @@ -1280,6 +1363,9 @@ def get_down_block( cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, dropout_cattn=dropout_cattn, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) else: return DownBlock( @@ -1307,6 +1393,9 @@ def get_mid_block( cross_attention_dim: int | None, upcast_attention: bool = False, dropout_cattn: float = 0.0, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> nn.Module: if with_conditioning: return CrossAttnMidBlock( @@ -1320,6 +1409,9 @@ def get_mid_block( cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, dropout_cattn=dropout_cattn, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) else: return AttnMidBlock( @@ -1329,6 +1421,9 @@ def get_mid_block( norm_num_groups=norm_num_groups, norm_eps=norm_eps, num_head_channels=num_head_channels, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) @@ -1350,6 +1445,9 @@ def get_up_block( cross_attention_dim: int | None, upcast_attention: bool = False, dropout_cattn: float = 0.0, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> nn.Module: if with_attn: return AttnUpBlock( @@ -1364,6 +1462,9 @@ def get_up_block( add_upsample=add_upsample, resblock_updown=resblock_updown, num_head_channels=num_head_channels, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) elif with_cross_attn: return CrossAttnUpBlock( @@ -1382,6 +1483,9 @@ def get_up_block( cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, dropout_cattn=dropout_cattn, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) else: return UpBlock( @@ -1419,9 +1523,13 @@ class DiffusionModelUNet(nn.Module): transformer_num_layers: number of layers of Transformer blocks to use. cross_attention_dim: number of context dimensions to use. num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` - classes. + classes. upcast_attention: if True, upcast attention operations to full precision. - dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers + dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers. + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to True. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ def __init__( @@ -1442,6 +1550,9 @@ def __init__( num_class_embeds: int | None = None, upcast_attention: bool = False, dropout_cattn: float = 0.0, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() if with_conditioning is True and cross_attention_dim is None: @@ -1536,6 +1647,9 @@ def __init__( cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, dropout_cattn=dropout_cattn, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) self.down_blocks.append(down_block) @@ -1553,6 +1667,9 @@ def __init__( cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, dropout_cattn=dropout_cattn, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) # up @@ -1587,6 +1704,9 @@ def __init__( cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, dropout_cattn=dropout_cattn, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) self.up_blocks.append(up_block) @@ -1714,31 +1834,23 @@ def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None: # copy over all matching keys for k in new_state_dict: if k in old_state_dict: - new_state_dict[k] = old_state_dict[k] + new_state_dict[k] = old_state_dict.pop(k) # fix the attention blocks - attention_blocks = [k.replace(".attn1.qkv.weight", "") for k in new_state_dict if "attn1.qkv.weight" in k] + attention_blocks = [k.replace(".out_proj.weight", "") for k in new_state_dict if "out_proj.weight" in k] for block in attention_blocks: - new_state_dict[f"{block}.attn1.qkv.weight"] = torch.cat( - [ - old_state_dict[f"{block}.attn1.to_q.weight"], - old_state_dict[f"{block}.attn1.to_k.weight"], - old_state_dict[f"{block}.attn1.to_v.weight"], - ], - dim=0, - ) - # projection - new_state_dict[f"{block}.attn1.out_proj.weight"] = old_state_dict[f"{block}.attn1.to_out.0.weight"] - new_state_dict[f"{block}.attn1.out_proj.bias"] = old_state_dict[f"{block}.attn1.to_out.0.bias"] + new_state_dict[f"{block}.out_proj.weight"] = old_state_dict.pop(f"{block}.to_out.0.weight") + new_state_dict[f"{block}.out_proj.bias"] = old_state_dict.pop(f"{block}.to_out.0.bias") - new_state_dict[f"{block}.attn2.out_proj.weight"] = old_state_dict[f"{block}.attn2.to_out.0.weight"] - new_state_dict[f"{block}.attn2.out_proj.bias"] = old_state_dict[f"{block}.attn2.to_out.0.bias"] # fix the upsample conv blocks which were renamed postconv for k in new_state_dict: if "postconv" in k: old_name = k.replace("postconv", "conv") - new_state_dict[k] = old_state_dict[old_name] + new_state_dict[k] = old_state_dict.pop(old_name) + if verbose: + # print all remaining keys in old_state_dict + print("remaining keys in old_state_dict:", old_state_dict.keys()) self.load_state_dict(new_state_dict) @@ -1782,6 +1894,9 @@ def __init__( cross_attention_dim: int | None = None, num_class_embeds: int | None = None, upcast_attention: bool = False, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() if with_conditioning is True and cross_attention_dim is None: @@ -1866,6 +1981,9 @@ def __init__( transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) self.down_blocks.append(down_block) diff --git a/monai/networks/nets/spade_autoencoderkl.py b/monai/networks/nets/spade_autoencoderkl.py index d5794a9227..cc8909194a 100644 --- a/monai/networks/nets/spade_autoencoderkl.py +++ b/monai/networks/nets/spade_autoencoderkl.py @@ -137,6 +137,10 @@ class SPADEDecoder(nn.Module): label_nc: number of semantic channels for SPADE normalisation. with_nonlocal_attn: if True use non-local attention block. spade_intermediate_channels: number of intermediate channels for SPADE block layer. + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to False. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ def __init__( @@ -152,6 +156,9 @@ def __init__( label_nc: int, with_nonlocal_attn: bool = True, spade_intermediate_channels: int = 128, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() self.spatial_dims = spatial_dims @@ -200,6 +207,9 @@ def __init__( num_channels=reversed_block_out_channels[0], norm_num_groups=norm_num_groups, norm_eps=norm_eps, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) ) blocks.append( @@ -243,6 +253,9 @@ def __init__( num_channels=block_in_ch, norm_num_groups=norm_num_groups, norm_eps=norm_eps, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) ) @@ -331,6 +344,9 @@ def __init__( with_encoder_nonlocal_attn: bool = True, with_decoder_nonlocal_attn: bool = True, spade_intermediate_channels: int = 128, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() @@ -360,6 +376,9 @@ def __init__( norm_eps=norm_eps, attention_levels=attention_levels, with_nonlocal_attn=with_encoder_nonlocal_attn, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) self.decoder = SPADEDecoder( spatial_dims=spatial_dims, @@ -373,6 +392,9 @@ def __init__( label_nc=label_nc, with_nonlocal_attn=with_decoder_nonlocal_attn, spade_intermediate_channels=spade_intermediate_channels, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) self.quant_conv_mu = Convolution( spatial_dims=spatial_dims, diff --git a/monai/networks/nets/spade_diffusion_model_unet.py b/monai/networks/nets/spade_diffusion_model_unet.py index 75d1687df3..a9609b1d39 100644 --- a/monai/networks/nets/spade_diffusion_model_unet.py +++ b/monai/networks/nets/spade_diffusion_model_unet.py @@ -325,6 +325,10 @@ class SPADEAttnUpBlock(nn.Module): resblock_updown: if True use residual blocks for upsampling. num_head_channels: number of channels in each attention head. spade_intermediate_channels: number of intermediate channels for SPADE block layer + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to False. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ def __init__( @@ -342,6 +346,9 @@ def __init__( resblock_updown: bool = False, num_head_channels: int = 1, spade_intermediate_channels: int = 128, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() self.resblock_updown = resblock_updown @@ -371,6 +378,9 @@ def __init__( num_head_channels=num_head_channels, norm_num_groups=norm_num_groups, norm_eps=norm_eps, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) ) @@ -457,6 +467,8 @@ class SPADECrossAttnUpBlock(nn.Module): cross_attention_dim: number of context dimensions to use. upcast_attention: if True, upcast attention operations to full precision. spade_intermediate_channels: number of intermediate channels for SPADE block layer. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism. + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ def __init__( @@ -477,6 +489,9 @@ def __init__( cross_attention_dim: int | None = None, upcast_attention: bool = False, spade_intermediate_channels: int = 128, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() self.resblock_updown = resblock_updown @@ -510,6 +525,9 @@ def __init__( num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) ) @@ -592,6 +610,9 @@ def get_spade_up_block( cross_attention_dim: int | None, upcast_attention: bool = False, spade_intermediate_channels: int = 128, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> nn.Module: if with_attn: return SPADEAttnUpBlock( @@ -608,6 +629,9 @@ def get_spade_up_block( resblock_updown=resblock_updown, num_head_channels=num_head_channels, spade_intermediate_channels=spade_intermediate_channels, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) elif with_cross_attn: return SPADECrossAttnUpBlock( @@ -627,6 +651,7 @@ def get_spade_up_block( cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, spade_intermediate_channels=spade_intermediate_channels, + use_flash_attention=use_flash_attention, ) else: return SPADEUpBlock( @@ -667,9 +692,11 @@ class SPADEDiffusionModelUNet(nn.Module): transformer_num_layers: number of layers of Transformer blocks to use. cross_attention_dim: number of context dimensions to use. num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` - classes. + classes. upcast_attention: if True, upcast attention operations to full precision. - spade_intermediate_channels: number of intermediate channels for SPADE block layer + spade_intermediate_channels: number of intermediate channels for SPADE block layer. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ def __init__( @@ -691,6 +718,9 @@ def __init__( num_class_embeds: int | None = None, upcast_attention: bool = False, spade_intermediate_channels: int = 128, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() if with_conditioning is True and cross_attention_dim is None: @@ -783,6 +813,9 @@ def __init__( transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) self.down_blocks.append(down_block) @@ -799,6 +832,9 @@ def __init__( transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) # up @@ -834,6 +870,7 @@ def __init__( upcast_attention=upcast_attention, label_nc=label_nc, spade_intermediate_channels=spade_intermediate_channels, + use_flash_attention=use_flash_attention, ) self.up_blocks.append(up_block) diff --git a/monai/networks/nets/transformer.py b/monai/networks/nets/transformer.py index 1af725abda..3a278c112a 100644 --- a/monai/networks/nets/transformer.py +++ b/monai/networks/nets/transformer.py @@ -51,6 +51,10 @@ class DecoderOnlyTransformer(nn.Module): attn_layers_heads: Number of attention heads. with_cross_attention: Whether to use cross attention for conditioning. embedding_dropout_rate: Dropout rate for the embedding. + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to True. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ def __init__( @@ -62,6 +66,9 @@ def __init__( attn_layers_heads: int, with_cross_attention: bool = False, embedding_dropout_rate: float = 0.0, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() self.num_tokens = num_tokens @@ -86,6 +93,9 @@ def __init__( causal=True, sequence_length=max_seq_len, with_cross_attention=with_cross_attention, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) for _ in range(attn_layers_depth) ] @@ -133,25 +143,15 @@ def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None: # copy over all matching keys for k in new_state_dict: if k in old_state_dict: - new_state_dict[k] = old_state_dict[k] - - # fix the attention blocks - attention_blocks = [k.replace(".attn.qkv.weight", "") for k in new_state_dict if "attn.qkv.weight" in k] - for block in attention_blocks: - new_state_dict[f"{block}.attn.qkv.weight"] = torch.cat( - [ - old_state_dict[f"{block}.attn.to_q.weight"], - old_state_dict[f"{block}.attn.to_k.weight"], - old_state_dict[f"{block}.attn.to_v.weight"], - ], - dim=0, - ) + new_state_dict[k] = old_state_dict.pop(k) # fix the renamed norm blocks first norm2 -> norm_cross_attention , norm3 -> norm2 - for k in old_state_dict: + for k in list(old_state_dict.keys()): if "norm2" in k: - new_state_dict[k.replace("norm2", "norm_cross_attn")] = old_state_dict[k] + new_state_dict[k.replace("norm2", "norm_cross_attn")] = old_state_dict.pop(k) if "norm3" in k: - new_state_dict[k.replace("norm3", "norm2")] = old_state_dict[k] - + new_state_dict[k.replace("norm3", "norm2")] = old_state_dict.pop(k) + if verbose: + # print all remaining keys in old_state_dict + print("remaining keys in old_state_dict:", old_state_dict.keys()) self.load_state_dict(new_state_dict) diff --git a/tests/test_crossattention.py b/tests/test_crossattention.py index 44458147d6..e034e42290 100644 --- a/tests/test_crossattention.py +++ b/tests/test_crossattention.py @@ -22,7 +22,7 @@ from monai.networks.blocks.crossattention import CrossAttentionBlock from monai.networks.layers.factories import RelPosEmbedding from monai.utils import optional_import -from tests.utils import SkipIfBeforePyTorchVersion +from tests.utils import SkipIfBeforePyTorchVersion, assert_allclose einops, has_einops = optional_import("einops") @@ -166,6 +166,21 @@ def test_access_attn_matrix(self): matrix_acess_blk(torch.randn(input_shape)) assert matrix_acess_blk.att_mat.shape == (input_shape[0], input_shape[0], input_shape[1], input_shape[1]) + @parameterized.expand([[True], [False]]) + @skipUnless(has_einops, "Requires einops") + @SkipIfBeforePyTorchVersion((2, 0)) + def test_flash_attention(self, causal): + input_param = {"hidden_size": 128, "num_heads": 1, "causal": causal, "sequence_length": 16 if causal else None} + device = "cuda:0" if torch.cuda.is_available() else "cpu" + block_w_flash_attention = CrossAttentionBlock(**input_param, use_flash_attention=True).to(device) + block_wo_flash_attention = CrossAttentionBlock(**input_param, use_flash_attention=False).to(device) + block_wo_flash_attention.load_state_dict(block_w_flash_attention.state_dict()) + test_data = torch.randn(1, 16, 128).to(device) + + out_1 = block_w_flash_attention(test_data) + out_2 = block_wo_flash_attention(test_data) + assert_allclose(out_1, out_2, atol=1e-4) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_selfattention.py b/tests/test_selfattention.py index 3e98f4c5c4..88919fd8b1 100644 --- a/tests/test_selfattention.py +++ b/tests/test_selfattention.py @@ -22,7 +22,7 @@ from monai.networks.blocks.selfattention import SABlock from monai.networks.layers.factories import RelPosEmbedding from monai.utils import optional_import -from tests.utils import SkipIfBeforePyTorchVersion +from tests.utils import SkipIfBeforePyTorchVersion, assert_allclose, test_script_save einops, has_einops = optional_import("einops") @@ -32,20 +32,23 @@ for num_heads in [4, 6, 8, 12]: for rel_pos_embedding in [None, RelPosEmbedding.DECOMPOSED]: for input_size in [(16, 32), (8, 8, 8)]: - for flash_attn in [True, False]: - test_case = [ - { - "hidden_size": hidden_size, - "num_heads": num_heads, - "dropout_rate": dropout_rate, - "rel_pos_embedding": rel_pos_embedding if not flash_attn else None, - "input_size": input_size, - "use_flash_attention": flash_attn, - }, - (2, 512, hidden_size), - (2, 512, hidden_size), - ] - TEST_CASE_SABLOCK.append(test_case) + for include_fc in [True, False]: + for use_combined_linear in [True, False]: + test_case = [ + { + "hidden_size": hidden_size, + "num_heads": num_heads, + "dropout_rate": dropout_rate, + "rel_pos_embedding": rel_pos_embedding, + "input_size": input_size, + "include_fc": include_fc, + "use_combined_linear": use_combined_linear, + "use_flash_attention": True if rel_pos_embedding is None else False, + }, + (2, 512, hidden_size), + (2, 512, hidden_size), + ] + TEST_CASE_SABLOCK.append(test_case) class TestResBlock(unittest.TestCase): @@ -175,6 +178,39 @@ def count_sablock_params(*args, **kwargs): nparams_default_more_heads = count_sablock_params(hidden_size=hidden_size, num_heads=num_heads * 2) self.assertEqual(nparams_default, nparams_default_more_heads) + @parameterized.expand([[True, False], [True, True], [False, True], [False, False]]) + @skipUnless(has_einops, "Requires einops") + @SkipIfBeforePyTorchVersion((2, 0)) + def test_script(self, include_fc, use_combined_linear): + input_param = { + "hidden_size": 360, + "num_heads": 4, + "dropout_rate": 0.0, + "rel_pos_embedding": None, + "input_size": (16, 32), + "include_fc": include_fc, + "use_combined_linear": use_combined_linear, + } + net = SABlock(**input_param) + input_shape = (2, 512, 360) + test_data = torch.randn(input_shape) + test_script_save(net, test_data) + + @skipUnless(has_einops, "Requires einops") + @SkipIfBeforePyTorchVersion((2, 0)) + def test_flash_attention(self): + for causal in [True, False]: + input_param = {"hidden_size": 360, "num_heads": 4, "input_size": (16, 32), "causal": causal} + device = "cuda:0" if torch.cuda.is_available() else "cpu" + block_w_flash_attention = SABlock(**input_param, use_flash_attention=True).to(device) + block_wo_flash_attention = SABlock(**input_param, use_flash_attention=False).to(device) + block_wo_flash_attention.load_state_dict(block_w_flash_attention.state_dict()) + test_data = torch.randn(2, 512, 360).to(device) + + out_1 = block_w_flash_attention(test_data) + out_2 = block_wo_flash_attention(test_data) + assert_allclose(out_1, out_2, atol=1e-4) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_unetr.py b/tests/test_unetr.py index 46018d2bc0..1217c9d85f 100644 --- a/tests/test_unetr.py +++ b/tests/test_unetr.py @@ -123,7 +123,7 @@ def test_ill_arg(self): ) @parameterized.expand(TEST_CASE_UNETR) - @SkipIfBeforePyTorchVersion((1, 9)) + @SkipIfBeforePyTorchVersion((2, 0)) def test_script(self, input_param, input_shape, _): net = UNETR(**(input_param)) net.eval() From 6be7b13a901918071be1cf10aee8701d6e751484 Mon Sep 17 00:00:00 2001 From: "Kelvin R." <138339140+K-Rilla@users.noreply.github.com> Date: Fri, 9 Aug 2024 02:29:47 -0700 Subject: [PATCH 06/24] Replaced package "pkg_resources" with "packaging" (#7953) Fixes #7559 . ### Description Replaced "pkg_resources" references with "packaging" in MONAI/monai/utils/module.py & setup.py Changes were made in functions "pytorch_after", "version_leq", "version_geq". ### Types of changes - Non-breaking change (fix or new feature that would not break existing functionality). --------- Signed-off-by: dedeepyasai Signed-off-by: saelra Signed-off-by: Kelvin R Signed-off-by: ken-ni Signed-off-by: Dureti <98233210+DuretiShemsi@users.noreply.github.com> Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: saelra Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: dedeepyasai Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: Ratanachat Saelee <146144408+Saelra@users.noreply.github.com> Co-authored-by: ken-ni Co-authored-by: Dureti <98233210+DuretiShemsi@users.noreply.github.com> --- .github/workflows/pythonapp.yml | 2 +- docs/requirements.txt | 1 + monai/utils/module.py | 7 ++++--- requirements-min.txt | 1 + setup.cfg | 2 ++ setup.py | 4 ++-- 6 files changed, 11 insertions(+), 6 deletions(-) diff --git a/.github/workflows/pythonapp.yml b/.github/workflows/pythonapp.yml index 65f9a4dcf2..3c39166c1e 100644 --- a/.github/workflows/pythonapp.yml +++ b/.github/workflows/pythonapp.yml @@ -151,7 +151,7 @@ jobs: - name: Install dependencies run: | find /opt/hostedtoolcache/* -maxdepth 0 ! -name 'Python' -exec rm -rf {} \; - python -m pip install --user --upgrade pip setuptools wheel twine + python -m pip install --user --upgrade pip setuptools wheel twine packaging # install the latest pytorch for testing # however, "pip install monai*.tar.gz" will build cpp/cuda with an isolated # fresh torch installation according to pyproject.toml diff --git a/docs/requirements.txt b/docs/requirements.txt index fe415a07b5..ff94f7b6de 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -41,3 +41,4 @@ onnxruntime; python_version <= '3.10' zarr huggingface_hub pyamg>=5.0.0 +packaging diff --git a/monai/utils/module.py b/monai/utils/module.py index 4d28f8d986..251232d62f 100644 --- a/monai/utils/module.py +++ b/monai/utils/module.py @@ -564,7 +564,7 @@ def version_leq(lhs: str, rhs: str) -> bool: """ lhs, rhs = str(lhs), str(rhs) - pkging, has_ver = optional_import("pkg_resources", name="packaging") + pkging, has_ver = optional_import("packaging.Version") if has_ver: try: return cast(bool, pkging.version.Version(lhs) <= pkging.version.Version(rhs)) @@ -591,7 +591,8 @@ def version_geq(lhs: str, rhs: str) -> bool: """ lhs, rhs = str(lhs), str(rhs) - pkging, has_ver = optional_import("pkg_resources", name="packaging") + pkging, has_ver = optional_import("packaging.Version") + if has_ver: try: return cast(bool, pkging.version.Version(lhs) >= pkging.version.Version(rhs)) @@ -629,7 +630,7 @@ def pytorch_after(major: int, minor: int, patch: int = 0, current_ver_string: st if current_ver_string is None: _env_var = os.environ.get("PYTORCH_VER", "") current_ver_string = _env_var if _env_var else torch.__version__ - ver, has_ver = optional_import("pkg_resources", name="parse_version") + ver, has_ver = optional_import("packaging.version", name="parse") if has_ver: return ver(".".join((f"{major}", f"{minor}", f"{patch}"))) <= ver(f"{current_ver_string}") # type: ignore parts = f"{current_ver_string}".split("+", 1)[0].split(".", 3) diff --git a/requirements-min.txt b/requirements-min.txt index a091ef0568..21cf9d5e5c 100644 --- a/requirements-min.txt +++ b/requirements-min.txt @@ -4,3 +4,4 @@ setuptools>=50.3.0,<66.0.0,!=60.6.0 ; python_version < "3.12" setuptools>=70.2.0; python_version >= "3.12" coverage>=5.5 parameterized +packaging diff --git a/setup.cfg b/setup.cfg index 202e7b0e24..dfa94fcfa1 100644 --- a/setup.cfg +++ b/setup.cfg @@ -137,6 +137,8 @@ pyyaml = pyyaml fire = fire +packaging = + packaging jsonschema = jsonschema pynrrd = diff --git a/setup.py b/setup.py index b90d9d0976..576743c1f7 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,7 @@ import sys import warnings -import pkg_resources +from packaging import version from setuptools import find_packages, setup import versioneer @@ -40,7 +40,7 @@ BUILD_CUDA = FORCE_CUDA or (torch.cuda.is_available() and (CUDA_HOME is not None)) - _pt_version = pkg_resources.parse_version(torch.__version__).release + _pt_version = version.parse(torch.__version__).release if _pt_version is None or len(_pt_version) < 3: raise AssertionError("unknown torch version") TORCH_VERSION = int(_pt_version[0]) * 10000 + int(_pt_version[1]) * 100 + int(_pt_version[2]) From f8480027888e365d25a1b85429b200dca58b9f19 Mon Sep 17 00:00:00 2001 From: Yiheng Wang <68361391+yiheng-wang-nv@users.noreply.github.com> Date: Fri, 9 Aug 2024 22:32:38 +0800 Subject: [PATCH 07/24] Add utils for vista3d (#7999) This PR is a part of #7987 ### Description A few sentences describing the changes proposed in this pull request. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Yiheng Wang Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: Yiheng Wang <68361391+yiheng-wang-nv@users.noreply.github.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> --- docs/source/apps.rst | 8 - docs/source/transforms.rst | 3 + monai/transforms/__init__.py | 1 + monai/transforms/utils.py | 183 ++++++++++++++++++ .../utils_morphological_ops.py} | 2 + tests/min_tests.py | 1 + tests/test_morphological_ops.py | 2 +- tests/test_vista3d_utils.py | 133 +++++++++++++ 8 files changed, 324 insertions(+), 9 deletions(-) rename monai/{apps/generation/maisi/utils/morphological_ops.py => transforms/utils_morphological_ops.py} (99%) create mode 100644 tests/test_vista3d_utils.py diff --git a/docs/source/apps.rst b/docs/source/apps.rst index c6ba8c0b9a..7fa7b9e9ff 100644 --- a/docs/source/apps.rst +++ b/docs/source/apps.rst @@ -261,11 +261,3 @@ FastMRIReader .. autoclass:: monai.apps.nnunet.nnUNetV2Runner :members: - -`Generative AI` ---------------- - -`MAISI Utilities` -~~~~~~~~~~~~~~~~~ -.. automodule:: monai.apps.generation.maisi.utils.morphological_ops - :members: diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index a359821679..637f0873f1 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -2310,6 +2310,9 @@ Utilities .. automodule:: monai.transforms.utils_pytorch_numpy_unification :members: +.. automodule:: monai.transforms.utils_morphological_ops + :members: + By Categories ------------- .. toctree:: diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index ef1da2d855..9548443768 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -688,6 +688,7 @@ weighted_patch_samples, zero_margins, ) +from .utils_morphological_ops import dilate, erode from .utils_pytorch_numpy_unification import ( allclose, any_np_pt, diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index d8461d927b..e32bf6fc48 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -22,6 +22,7 @@ import numpy as np import torch +from torch import Tensor import monai from monai.config import DtypeLike, IndexSelection @@ -30,6 +31,7 @@ from monai.networks.utils import meshgrid_ij from monai.transforms.compose import Compose from monai.transforms.transform import MapTransform, Transform, apply_transform +from monai.transforms.utils_morphological_ops import erode from monai.transforms.utils_pytorch_numpy_unification import ( any_np_pt, ascontiguousarray, @@ -65,6 +67,8 @@ min_version, optional_import, pytorch_after, + unsqueeze_left, + unsqueeze_right, ) from monai.utils.enums import TransformBackends from monai.utils.type_conversion import ( @@ -103,6 +107,8 @@ "generate_spatial_bounding_box", "get_extreme_points", "get_largest_connected_component_mask", + "get_largest_connected_component_mask_point", + "convert_points_to_disc", "remove_small_objects", "img_bounds", "in_bounds", @@ -1172,6 +1178,183 @@ def get_largest_connected_component_mask( return convert_to_dst_type(out, dst=img, dtype=out.dtype)[0] +def get_largest_connected_component_mask_point( + img_pos: NdarrayTensor, + img_neg: NdarrayTensor, + point_coords: NdarrayTensor, + point_labels: NdarrayTensor, + pos_val: Sequence[int] = (1, 3), + neg_val: Sequence[int] = (0, 2), + margins: int = 3, +) -> NdarrayTensor: + """ + Gets the connected component of img_pos and img_neg that include the positive points and + negative points separately. The function is used for combining automatic results with interactive + results in VISTA3D. + + Args: + img_pos: bool type tensor, shape [B, 1, H, W, D], where B means the foreground masks from a single 3D image. + img_neg: same format as img_pos but corresponds to negative points. + pos_val: positive point label values. + neg_val: negative point label values. + point_coords: the coordinates of each point, shape [B, N, 3], where N means the number of points. + point_labels: the label of each point, shape [B, N]. + """ + + cucim_skimage, has_cucim = optional_import("cucim.skimage") + + use_cp = has_cp and has_cucim and isinstance(img_pos, torch.Tensor) and img_pos.device != torch.device("cpu") + if use_cp: + img_pos_ = convert_to_cupy(img_pos.short()) # type: ignore + img_neg_ = convert_to_cupy(img_neg.short()) # type: ignore + label = cucim_skimage.measure.label + lib = cp + else: + if not has_measure: + raise RuntimeError("skimage.measure required.") + img_pos_, *_ = convert_data_type(img_pos, np.ndarray) + img_neg_, *_ = convert_data_type(img_neg, np.ndarray) + # for skimage.measure.label, the input must be bool type + if img_pos_.dtype != bool or img_neg_.dtype != bool: + raise ValueError("img_pos and img_neg must be bool type.") + label = measure.label + lib = np + + features_pos, _ = label(img_pos_, connectivity=3, return_num=True) + features_neg, _ = label(img_neg_, connectivity=3, return_num=True) + + outs = np.zeros_like(img_pos_) + for bs in range(point_coords.shape[0]): + for i, p in enumerate(point_coords[bs]): + if point_labels[bs, i] in pos_val: + features = features_pos + elif point_labels[bs, i] in neg_val: + features = features_neg + else: + # if -1 padding point, skip + continue + for margin in range(margins): + if isinstance(p, np.ndarray): + x, y, z = np.round(p).astype(int).tolist() + else: + x, y, z = p.float().round().int().tolist() + l, r = max(x - margin, 0), min(x + margin + 1, features.shape[-3]) + t, d = max(y - margin, 0), min(y + margin + 1, features.shape[-2]) + f, b = max(z - margin, 0), min(z + margin + 1, features.shape[-1]) + if (features[bs, 0, l:r, t:d, f:b] > 0).any(): + index = features[bs, 0, l:r, t:d, f:b].max() + outs[[bs]] += lib.isin(features[[bs]], index) + break + outs[outs > 1] = 1 + return convert_to_dst_type(outs, dst=img_pos, dtype=outs.dtype)[0] + + +def convert_points_to_disc( + image_size: Sequence[int], point: Tensor, point_label: Tensor, radius: int = 2, disc: bool = False +): + """ + Convert a 3D point coordinates into image mask. The returned mask has the same spatial + size as `image_size` while the batch dimension is the same as 'point' batch dimension. + The point is converted to a mask ball with radius defined by `radius`. The output + contains two channels each for negative (first channel) and positive points. + + Args: + image_size: The output size of the converted mask. It should be a 3D tuple. + point: [B, N, 3], 3D point coordinates. + point_label: [B, N], 0 or 2 means negative points, 1 or 3 means postive points. + radius: disc ball radius size. + disc: If true, use regular disc, other use gaussian. + """ + masks = torch.zeros([point.shape[0], 2, image_size[0], image_size[1], image_size[2]], device=point.device) + _array = [ + torch.arange(start=0, end=image_size[i], step=1, dtype=torch.float32, device=point.device) for i in range(3) + ] + coord_rows, coord_cols, coord_z = torch.meshgrid(_array[2], _array[1], _array[0]) + # [1, 3, h, w, d] -> [b, 2, 3, h, w, d] + coords = unsqueeze_left(torch.stack((coord_rows, coord_cols, coord_z), dim=0), 6) + coords = coords.repeat(point.shape[0], 2, 1, 1, 1, 1) + for b, n in np.ndindex(*point.shape[:2]): + point_bn = unsqueeze_right(point[b, n], 6) + if point_label[b, n] > -1: + channel = 0 if (point_label[b, n] == 0 or point_label[b, n] == 2) else 1 + pow_diff = torch.pow(coords[b, channel] - point_bn[b, n], 2) + if disc: + masks[b, channel] += pow_diff.sum(0) < radius**2 + else: + masks[b, channel] += torch.exp(-pow_diff.sum(0) / (2 * radius**2)) + return masks + + +def sample_points_from_label( + labels: Tensor, + label_set: Sequence[int], + max_ppoint: int = 1, + max_npoint: int = 0, + device: torch.device | str | None = "cpu", + use_center: bool = False, +): + """Sample points from labels. + + Args: + labels: [1, 1, H, W, D] + label_set: local index, must match values in labels. + max_ppoint: maximum positive point samples. + max_npoint: maximum negative point samples. + device: returned tensor device. + use_center: whether to sample points from center. + + Returns: + point: point coordinates of [B, N, 3]. B equals to the length of label_set. + point_label: [B, N], always 0 for negative, 1 for positive. + """ + if not labels.shape[0] == 1: + raise ValueError("labels must have batch size 1.") + + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + labels = labels[0, 0] + unique_labels = labels.unique().cpu().numpy().tolist() + _point = [] + _point_label = [] + for id in label_set: + if id in unique_labels: + plabels = labels == int(id) + nlabels = ~plabels + _plabels = get_largest_connected_component_mask(erode(plabels.unsqueeze(0).unsqueeze(0))[0, 0]) + plabelpoints = torch.nonzero(_plabels).to(device) + if len(plabelpoints) == 0: + plabelpoints = torch.nonzero(plabels).to(device) + nlabelpoints = torch.nonzero(nlabels).to(device) + num_p = min(len(plabelpoints), max_ppoint) + num_n = min(len(nlabelpoints), max_npoint) + pad = max_ppoint + max_npoint - num_p - num_n + if use_center: + pmean = plabelpoints.float().mean(0) + pdis = ((plabelpoints - pmean) ** 2).sum(-1) + _, sorted_indices_tensor = torch.sort(pdis) + sorted_indices = sorted_indices_tensor.cpu().tolist() + else: + sorted_indices = list(range(len(plabelpoints))) + random.shuffle(sorted_indices) + _point.append( + torch.stack( + [plabelpoints[sorted_indices[i]] for i in range(num_p)] + + random.choices(nlabelpoints, k=num_n) + + [torch.tensor([0, 0, 0], device=device)] * pad + ) + ) + _point_label.append(torch.tensor([1] * num_p + [0] * num_n + [-1] * pad).to(device)) + else: + # pad the background labels + _point.append(torch.zeros(max_ppoint + max_npoint, 3).to(device)) + _point_label.append(torch.zeros(max_ppoint + max_npoint).to(device) - 1) + point = torch.stack(_point) + point_label = torch.stack(_point_label) + + return point, point_label + + def remove_small_objects( img: NdarrayTensor, min_size: int = 64, diff --git a/monai/apps/generation/maisi/utils/morphological_ops.py b/monai/transforms/utils_morphological_ops.py similarity index 99% rename from monai/apps/generation/maisi/utils/morphological_ops.py rename to monai/transforms/utils_morphological_ops.py index 14786d60a2..b3134c1865 100644 --- a/monai/apps/generation/maisi/utils/morphological_ops.py +++ b/monai/transforms/utils_morphological_ops.py @@ -20,6 +20,8 @@ from monai.config import NdarrayOrTensor from monai.utils import convert_data_type, convert_to_dst_type, ensure_tuple_rep +__all__ = ["erode", "dilate"] + def erode(mask: NdarrayOrTensor, filter_size: int | Sequence[int] = 3, pad_value: float = 1.0) -> NdarrayOrTensor: """ diff --git a/tests/min_tests.py b/tests/min_tests.py index 3a143df84b..479c4c8dc2 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -209,6 +209,7 @@ def run_testsuit(): "test_zarr_avg_merger", "test_perceptual_loss", "test_ultrasound_confidence_map_transform", + "test_vista3d_utils", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" diff --git a/tests/test_morphological_ops.py b/tests/test_morphological_ops.py index 6f29415759..422e8c4b9d 100644 --- a/tests/test_morphological_ops.py +++ b/tests/test_morphological_ops.py @@ -16,7 +16,7 @@ import torch from parameterized import parameterized -from monai.apps.generation.maisi.utils.morphological_ops import dilate, erode, get_morphological_filter_result_t +from monai.transforms.utils_morphological_ops import dilate, erode, get_morphological_filter_result_t from tests.utils import TEST_NDARRAYS, assert_allclose TESTS_SHAPE = [] diff --git a/tests/test_vista3d_utils.py b/tests/test_vista3d_utils.py new file mode 100644 index 0000000000..a940854d88 --- /dev/null +++ b/tests/test_vista3d_utils.py @@ -0,0 +1,133 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest +from unittest.case import skipUnless + +import numpy as np +import torch +from parameterized import parameterized + +from monai.transforms.utils import ( + convert_points_to_disc, + get_largest_connected_component_mask_point, + sample_points_from_label, +) +from monai.utils import min_version +from monai.utils.module import optional_import +from tests.utils import skip_if_no_cuda, skip_if_quick + +cp, has_cp = optional_import("cupy") +cucim_skimage, has_cucim = optional_import("cucim.skimage") +measure, has_measure = optional_import("skimage.measure", "0.14.2", min_version) + +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + +TESTS_SAMPLE_POINTS_FROM_LABEL = [] +for use_center in [True, False]: + labels = torch.zeros(1, 1, 32, 32, 32) + labels[0, 0, 5:10, 5:10, 5:10] = 1 + labels[0, 0, 10:15, 10:15, 10:15] = 3 + labels[0, 0, 20:25, 20:25, 20:25] = 5 + TESTS_SAMPLE_POINTS_FROM_LABEL.append( + [{"labels": labels, "label_set": (1, 3, 5), "use_center": use_center}, (3, 1, 3), (3, 1)] + ) + +TEST_CONVERT_POINTS_TO_DISC = [] +for radius in [1, 2]: + for disc in [True, False]: + image_size = (32, 32, 32) + point = torch.randn(3, 1, 3) + point_label = torch.randint(0, 4, (3, 1)) + expected_shape = (point.shape[0], 2, *image_size) + TEST_CONVERT_POINTS_TO_DISC.append( + [ + {"image_size": image_size, "point": point, "point_label": point_label, "radius": radius, "disc": disc}, + expected_shape, + ] + ) + +TEST_LCC_MASK_POINT_TORCH = [] +for bs in [1, 2]: + for num_points in [1, 3]: + shape = (bs, 1, 128, 32, 32) + TEST_LCC_MASK_POINT_TORCH.append( + [ + { + "img_pos": torch.randint(0, 2, shape, dtype=torch.bool), + "img_neg": torch.randint(0, 2, shape, dtype=torch.bool), + "point_coords": torch.randint(0, 10, (bs, num_points, 3)), + "point_labels": torch.randint(0, 4, (bs, num_points)), + }, + shape, + ] + ) + +TEST_LCC_MASK_POINT_NP = [] +for bs in [1, 2]: + for num_points in [1, 3]: + shape = (bs, 1, 32, 32, 64) + TEST_LCC_MASK_POINT_NP.append( + [ + { + "img_pos": np.random.randint(0, 2, shape, dtype=bool), + "img_neg": np.random.randint(0, 2, shape, dtype=bool), + "point_coords": np.random.randint(0, 5, (bs, num_points, 3)), + "point_labels": np.random.randint(0, 4, (bs, num_points)), + }, + shape, + ] + ) + + +@skipUnless(has_measure or cucim_skimage, "skimage or cucim.skimage required") +class TestSamplePointsFromLabel(unittest.TestCase): + + @parameterized.expand(TESTS_SAMPLE_POINTS_FROM_LABEL) + def test_shape(self, input_data, expected_point_shape, expected_point_label_shape): + point, point_label = sample_points_from_label(**input_data) + self.assertEqual(point.shape, expected_point_shape) + self.assertEqual(point_label.shape, expected_point_label_shape) + + +class TestConvertPointsToDisc(unittest.TestCase): + + @parameterized.expand(TEST_CONVERT_POINTS_TO_DISC) + def test_shape(self, input_data, expected_shape): + result = convert_points_to_disc(**input_data) + self.assertEqual(result.shape, expected_shape) + + +@skipUnless(has_measure or cucim_skimage, "skimage or cucim.skimage required") +class TestGetLargestConnectedComponentMaskPoint(unittest.TestCase): + + @skip_if_quick + @skip_if_no_cuda + @skipUnless(has_cp and cucim_skimage, "cupy and cucim.skimage required") + @parameterized.expand(TEST_LCC_MASK_POINT_TORCH) + def test_cp_shape(self, input_data, shape): + for key in input_data: + input_data[key] = input_data[key].to(device) + mask = get_largest_connected_component_mask_point(**input_data) + self.assertEqual(mask.shape, shape) + + @skipUnless(has_measure, "skimage required") + @parameterized.expand(TEST_LCC_MASK_POINT_NP) + def test_np_shape(self, input_data, shape): + mask = get_largest_connected_component_mask_point(**input_data) + self.assertEqual(mask.shape, shape) + + +if __name__ == "__main__": + unittest.main() From 62430315cd176478d06dd197b7a6dfdd6cd90c44 Mon Sep 17 00:00:00 2001 From: myron Date: Sat, 10 Aug 2024 06:34:27 -0700 Subject: [PATCH 08/24] Adding a network CellSamWrapper (#7981) Adding a network CellSamWrapper, a thin wrapper around SAM, which can be used for 2D segmentation tasks. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: am Signed-off-by: myron Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: am Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- docs/source/installation.md | 4 +- monai/networks/nets/cell_sam_wrapper.py | 92 +++++++++++++++++++++++++ requirements-dev.txt | 1 + setup.cfg | 5 +- tests/test_cell_sam_wrapper.py | 58 ++++++++++++++++ 5 files changed, 157 insertions(+), 3 deletions(-) create mode 100644 monai/networks/nets/cell_sam_wrapper.py create mode 100644 tests/test_cell_sam_wrapper.py diff --git a/docs/source/installation.md b/docs/source/installation.md index 4308a07647..70a8b6f1d4 100644 --- a/docs/source/installation.md +++ b/docs/source/installation.md @@ -254,10 +254,10 @@ Since MONAI v0.2.0, the extras syntax such as `pip install 'monai[nibabel]'` is - The options are ``` -[nibabel, skimage, scipy, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, clearml, matplotlib, tensorboardX, tifffile, imagecodecs, pyyaml, fire, jsonschema, ninja, pynrrd, pydicom, h5py, nni, optuna, onnx, onnxruntime, zarr, lpips, pynvml, huggingface_hub] +[nibabel, skimage, scipy, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, clearml, matplotlib, tensorboardX, tifffile, imagecodecs, pyyaml, fire, jsonschema, ninja, pynrrd, pydicom, h5py, nni, optuna, onnx, onnxruntime, zarr, lpips, pynvml, huggingface_hub, segment-anything] ``` which correspond to `nibabel`, `scikit-image`,`scipy`, `pillow`, `tensorboard`, -`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `clearml`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, `pyyaml`, `fire`, `jsonschema`, `ninja`, `pynrrd`, `pydicom`, `h5py`, `nni`, `optuna`, `onnx`, `onnxruntime`, `zarr`, `lpips`, `nvidia-ml-py`, `huggingface_hub` and `pyamg` respectively. +`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `clearml`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, `pyyaml`, `fire`, `jsonschema`, `ninja`, `pynrrd`, `pydicom`, `h5py`, `nni`, `optuna`, `onnx`, `onnxruntime`, `zarr`, `lpips`, `nvidia-ml-py`, `huggingface_hub`, `pyamg` and `segment-anything` respectively. - `pip install 'monai[all]'` installs all the optional dependencies. diff --git a/monai/networks/nets/cell_sam_wrapper.py b/monai/networks/nets/cell_sam_wrapper.py new file mode 100644 index 0000000000..308c3a6bcb --- /dev/null +++ b/monai/networks/nets/cell_sam_wrapper.py @@ -0,0 +1,92 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import torch +from torch import nn +from torch.nn import functional as F + +from monai.utils import optional_import + +build_sam_vit_b, has_sam = optional_import("segment_anything.build_sam", name="build_sam_vit_b") + +_all__ = ["CellSamWrapper"] + + +class CellSamWrapper(torch.nn.Module): + """ + CellSamWrapper is thin wrapper around SAM model https://github.com/facebookresearch/segment-anything + with an image only decoder, that can be used for segmentation tasks. + + + Args: + auto_resize_inputs: whether to resize inputs before passing to the network. + (usually they need be resized, unless they are already at the expected size) + network_resize_roi: expected input size for the network. + (currently SAM expects 1024x1024) + checkpoint: checkpoint file to load the SAM weights from. + (this can be downloaded from SAM repo https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth) + return_features: whether to return features from SAM encoder + (without using decoder/upsampling to the original input size) + + """ + + def __init__( + self, + auto_resize_inputs=True, + network_resize_roi=(1024, 1024), + checkpoint="sam_vit_b_01ec64.pth", + return_features=False, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + + self.network_resize_roi = network_resize_roi + self.auto_resize_inputs = auto_resize_inputs + self.return_features = return_features + + if not has_sam: + raise ValueError( + "SAM is not installed, please run: pip install git+https://github.com/facebookresearch/segment-anything.git" + ) + + model = build_sam_vit_b(checkpoint=checkpoint) + + model.prompt_encoder = None + model.mask_decoder = None + + model.mask_decoder = nn.Sequential( + nn.BatchNorm2d(num_features=256), + nn.ReLU(inplace=True), + nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False), + nn.BatchNorm2d(num_features=128), + nn.ReLU(inplace=True), + nn.ConvTranspose2d(128, 3, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True), + ) + + self.model = model + + def forward(self, x): + sh = x.shape[2:] + + if self.auto_resize_inputs: + x = F.interpolate(x, size=self.network_resize_roi, mode="bilinear") + + x = self.model.image_encoder(x) + + if not self.return_features: + x = self.model.mask_decoder(x) + if self.auto_resize_inputs: + x = F.interpolate(x, size=sh, mode="bilinear") + + return x diff --git a/requirements-dev.txt b/requirements-dev.txt index 72ba210093..76f1952345 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -59,3 +59,4 @@ nvidia-ml-py huggingface_hub pyamg>=5.0.0 git+https://github.com/Project-MONAI/GenerativeModels.git@7428fce193771e9564f29b91d29e523dd1b6b4cd +git+https://github.com/facebookresearch/segment-anything.git@6fdee8f2727f4506cfbbe553e23b895e27956588 diff --git a/setup.cfg b/setup.cfg index dfa94fcfa1..e240445e36 100644 --- a/setup.cfg +++ b/setup.cfg @@ -85,6 +85,7 @@ all = nvidia-ml-py huggingface_hub pyamg>=5.0.0 + segment-anything nibabel = nibabel ninja = @@ -162,11 +163,13 @@ pynvml = nvidia-ml-py # # workaround https://github.com/Project-MONAI/MONAI/issues/5882 # MetricsReloaded = -# MetricsReloaded @ git+https://github.com/Project-MONAI/MetricsReloaded@monai-support#egg=MetricsReloaded + # MetricsReloaded @ git+https://github.com/Project-MONAI/MetricsReloaded@monai-support#egg=MetricsReloaded huggingface_hub = huggingface_hub pyamg = pyamg>=5.0.0 +segment-anything = + segment-anything @ git+https://github.com/facebookresearch/segment-anything@6fdee8f2727f4506cfbbe553e23b895e27956588#egg=segment-anything [flake8] select = B,C,E,F,N,P,T4,W,B9 diff --git a/tests/test_cell_sam_wrapper.py b/tests/test_cell_sam_wrapper.py new file mode 100644 index 0000000000..2f1ee2b901 --- /dev/null +++ b/tests/test_cell_sam_wrapper.py @@ -0,0 +1,58 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets.cell_sam_wrapper import CellSamWrapper +from monai.utils import optional_import + +build_sam_vit_b, has_sam = optional_import("segment_anything.build_sam", name="build_sam_vit_b") + +device = "cuda" if torch.cuda.is_available() else "cpu" +TEST_CASE_CELLSEGWRAPPER = [] +for dims in [128, 256, 512, 1024]: + test_case = [ + {"auto_resize_inputs": True, "network_resize_roi": [1024, 1024], "checkpoint": None}, + (1, 3, *([dims] * 2)), + (1, 3, *([dims] * 2)), + ] + TEST_CASE_CELLSEGWRAPPER.append(test_case) + + +@unittest.skipUnless(has_sam, "Requires SAM installation") +class TestResNetDS(unittest.TestCase): + + @parameterized.expand(TEST_CASE_CELLSEGWRAPPER) + def test_shape(self, input_param, input_shape, expected_shape): + net = CellSamWrapper(**input_param).to(device) + with eval_mode(net): + result = net(torch.randn(input_shape).to(device)) + self.assertEqual(result.shape, expected_shape, msg=str(input_param)) + + def test_ill_arg0(self): + with self.assertRaises(RuntimeError): + net = CellSamWrapper(auto_resize_inputs=False, checkpoint=None).to(device) + net(torch.randn([1, 3, 256, 256]).to(device)) + + def test_ill_arg1(self): + with self.assertRaises(RuntimeError): + net = CellSamWrapper(network_resize_roi=[256, 256], checkpoint=None).to(device) + net(torch.randn([1, 3, 1024, 1024]).to(device)) + + +if __name__ == "__main__": + unittest.main() From 250c18d71b39f414d4ef91e353d69ca8c2ce3f92 Mon Sep 17 00:00:00 2001 From: Pengfei Guo <32000655+guopengf@users.noreply.github.com> Date: Sun, 11 Aug 2024 23:46:48 -0400 Subject: [PATCH 09/24] Refactor DiffusionModelUNetMaisi (#7989) Fixes #7988 . ### Description Refactor DiffusionModelUNetMaisi to use monai core components. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Pengfei Guo Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../networks/diffusion_model_unet_maisi.py | 36 +++++++++---------- tests/test_diffusion_model_unet_maisi.py | 7 +--- 2 files changed, 19 insertions(+), 24 deletions(-) diff --git a/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py b/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py index d5f5f6136b..e990b5fc98 100644 --- a/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py +++ b/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py @@ -37,21 +37,15 @@ from torch import nn from monai.networks.blocks import Convolution -from monai.utils import ensure_tuple_rep, optional_import -from monai.utils.type_conversion import convert_to_tensor - -get_down_block, has_get_down_block = optional_import( - "generative.networks.nets.diffusion_model_unet", name="get_down_block" -) -get_mid_block, has_get_mid_block = optional_import( - "generative.networks.nets.diffusion_model_unet", name="get_mid_block" -) -get_timestep_embedding, has_get_timestep_embedding = optional_import( - "generative.networks.nets.diffusion_model_unet", name="get_timestep_embedding" +from monai.networks.nets.diffusion_model_unet import ( + get_down_block, + get_mid_block, + get_timestep_embedding, + get_up_block, + zero_module, ) -get_up_block, has_get_up_block = optional_import("generative.networks.nets.diffusion_model_unet", name="get_up_block") -xformers, has_xformers = optional_import("xformers") -zero_module, has_zero_module = optional_import("generative.networks.nets.diffusion_model_unet", name="zero_module") +from monai.utils import ensure_tuple_rep +from monai.utils.type_conversion import convert_to_tensor __all__ = ["DiffusionModelUNetMaisi"] @@ -78,6 +72,8 @@ class DiffusionModelUNetMaisi(nn.Module): cross_attention_dim: Number of context dimensions to use. num_class_embeds: If specified (as an int), then this model will be class-conditional with `num_class_embeds` classes. upcast_attention: If True, upcast attention operations to full precision. + include_fc: whether to include the final linear layer. Default to False. + use_combined_linear: whether to use a single linear layer for qkv projection, default to False. use_flash_attention: If True, use flash attention for a memory efficient attention mechanism. dropout_cattn: If different from zero, this will be the dropout value for the cross-attention layers. include_top_region_index_input: If True, use top region index input. @@ -102,6 +98,8 @@ def __init__( cross_attention_dim: int | None = None, num_class_embeds: int | None = None, upcast_attention: bool = False, + include_fc: bool = False, + use_combined_linear: bool = False, use_flash_attention: bool = False, dropout_cattn: float = 0.0, include_top_region_index_input: bool = False, @@ -152,9 +150,6 @@ def __init__( "`num_channels`." ) - if use_flash_attention and not has_xformers: - raise ValueError("use_flash_attention is True but xformers is not installed.") - if use_flash_attention is True and not torch.cuda.is_available(): raise ValueError( "torch.cuda.is_available() should be True but is False. Flash attention is only available for GPU." @@ -210,7 +205,6 @@ def __init__( input_channel = output_channel output_channel = num_channels[i] is_final_block = i == len(num_channels) - 1 - down_block = get_down_block( spatial_dims=spatial_dims, in_channels=input_channel, @@ -227,6 +221,8 @@ def __init__( transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, + include_fc=include_fc, + use_combined_linear=use_combined_linear, use_flash_attention=use_flash_attention, dropout_cattn=dropout_cattn, ) @@ -245,6 +241,8 @@ def __init__( transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, + include_fc=include_fc, + use_combined_linear=use_combined_linear, use_flash_attention=use_flash_attention, dropout_cattn=dropout_cattn, ) @@ -280,6 +278,8 @@ def __init__( transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, + include_fc=include_fc, + use_combined_linear=use_combined_linear, use_flash_attention=use_flash_attention, dropout_cattn=dropout_cattn, ) diff --git a/tests/test_diffusion_model_unet_maisi.py b/tests/test_diffusion_model_unet_maisi.py index 059a4a4ba8..f9384e6d82 100644 --- a/tests/test_diffusion_model_unet_maisi.py +++ b/tests/test_diffusion_model_unet_maisi.py @@ -17,14 +17,11 @@ import torch from parameterized import parameterized +from monai.apps.generation.maisi.networks.diffusion_model_unet_maisi import DiffusionModelUNetMaisi from monai.networks import eval_mode from monai.utils import optional_import _, has_einops = optional_import("einops") -_, has_generative = optional_import("generative") - -if has_generative: - from monai.apps.generation.maisi.networks.diffusion_model_unet_maisi import DiffusionModelUNetMaisi UNCOND_CASES_2D = [ [ @@ -291,7 +288,6 @@ ] -@skipUnless(has_generative, "monai-generative required") class TestDiffusionModelUNetMaisi2D(unittest.TestCase): @parameterized.expand(UNCOND_CASES_2D) @@ -510,7 +506,6 @@ def test_shape_with_additional_inputs(self, input_param): self.assertEqual(result.shape, (1, 1, 16, 16)) -@skipUnless(has_generative, "monai-generative required") class TestDiffusionModelUNetMaisi3D(unittest.TestCase): @parameterized.expand(UNCOND_CASES_3D) From 7a6f680642e4fba4ac6465237292f43f5755e869 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Mon, 12 Aug 2024 20:46:23 +0800 Subject: [PATCH 10/24] Remove segment-anything in setup.cfg (#8010) Fixes #8009 ### Description Remove segment-anything in setup.cfg ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .github/workflows/release.yml | 4 ++-- setup.cfg | 5 ++--- tests/ngc_bundle_download.py | 2 +- tests/test_handler_ignite_metric.py | 4 ++-- tests/test_patchembedding.py | 12 ++++++------ tests/test_unetr.py | 12 ++++++------ tests/test_vitautoenc.py | 10 +++++----- 7 files changed, 24 insertions(+), 25 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 60b610565e..a014a4ed1d 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -24,7 +24,7 @@ jobs: python-version: ${{ matrix.python-version }} - name: Install setuptools run: | - python -m pip install --user --upgrade setuptools wheel + python -m pip install --user --upgrade setuptools wheel packaging - name: Build and test source archive and wheel file run: | find /opt/hostedtoolcache/* -maxdepth 0 ! -name 'Python' -exec rm -rf {} \; @@ -104,7 +104,7 @@ jobs: run: | find /opt/hostedtoolcache/* -maxdepth 0 ! -name 'Python' -exec rm -rf {} \; git describe - python -m pip install --user --upgrade setuptools wheel + python -m pip install --user --upgrade setuptools wheel packaging python setup.py build cat build/lib/monai/_version.py - name: Upload version diff --git a/setup.cfg b/setup.cfg index e240445e36..2115c30a7f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -85,7 +85,6 @@ all = nvidia-ml-py huggingface_hub pyamg>=5.0.0 - segment-anything nibabel = nibabel ninja = @@ -168,8 +167,8 @@ huggingface_hub = huggingface_hub pyamg = pyamg>=5.0.0 -segment-anything = - segment-anything @ git+https://github.com/facebookresearch/segment-anything@6fdee8f2727f4506cfbbe553e23b895e27956588#egg=segment-anything +# segment-anything = +# segment-anything @ git+https://github.com/facebookresearch/segment-anything@6fdee8f2727f4506cfbbe553e23b895e27956588#egg=segment-anything [flake8] select = B,C,E,F,N,P,T4,W,B9 diff --git a/tests/ngc_bundle_download.py b/tests/ngc_bundle_download.py index 01dc044870..107114861c 100644 --- a/tests/ngc_bundle_download.py +++ b/tests/ngc_bundle_download.py @@ -127,7 +127,7 @@ def test_loading_mmar(self, item): in_channels=1, img_size=(96, 96, 96), patch_size=(16, 16, 16), - pos_embed="conv", + proj_type="conv", hidden_size=768, mlp_dim=3072, ) diff --git a/tests/test_handler_ignite_metric.py b/tests/test_handler_ignite_metric.py index 28e0b69621..3e42bda35d 100644 --- a/tests/test_handler_ignite_metric.py +++ b/tests/test_handler_ignite_metric.py @@ -16,7 +16,7 @@ import torch from parameterized import parameterized -from monai.handlers import IgniteMetric, IgniteMetricHandler, from_engine +from monai.handlers import IgniteMetricHandler, from_engine from monai.losses import DiceLoss from monai.metrics import LossMetric from tests.utils import SkipIfNoModule, assert_allclose, optional_import @@ -172,7 +172,7 @@ def _val_func(engine, batch): @parameterized.expand(TEST_CASES[0:2]) def test_old_ignite_metric(self, input_param, input_data, expected_val): loss_fn = DiceLoss(**input_param) - ignite_metric = IgniteMetric(loss_fn=loss_fn, output_transform=from_engine(["pred", "label"])) + ignite_metric = IgniteMetricHandler(loss_fn=loss_fn, output_transform=from_engine(["pred", "label"])) def _val_func(engine, batch): pass diff --git a/tests/test_patchembedding.py b/tests/test_patchembedding.py index d059145033..71ac767966 100644 --- a/tests/test_patchembedding.py +++ b/tests/test_patchembedding.py @@ -43,7 +43,7 @@ "patch_size": (patch_size,) * nd, "hidden_size": hidden_size, "num_heads": num_heads, - "pos_embed": proj_type, + "proj_type": proj_type, "pos_embed_type": pos_embed_type, "dropout_rate": dropout_rate, }, @@ -127,7 +127,7 @@ def test_ill_arg(self): patch_size=(16, 16, 16), hidden_size=128, num_heads=12, - pos_embed="conv", + proj_type="conv", pos_embed_type="sincos", dropout_rate=5.0, ) @@ -139,7 +139,7 @@ def test_ill_arg(self): patch_size=(64, 64, 64), hidden_size=512, num_heads=8, - pos_embed="perceptron", + proj_type="perceptron", pos_embed_type="sincos", dropout_rate=0.3, ) @@ -151,7 +151,7 @@ def test_ill_arg(self): patch_size=(8, 8, 8), hidden_size=512, num_heads=14, - pos_embed="conv", + proj_type="conv", dropout_rate=0.3, ) @@ -162,7 +162,7 @@ def test_ill_arg(self): patch_size=(4, 4, 4), hidden_size=768, num_heads=8, - pos_embed="perceptron", + proj_type="perceptron", dropout_rate=0.3, ) with self.assertRaises(ValueError): @@ -183,7 +183,7 @@ def test_ill_arg(self): patch_size=(16, 16, 16), hidden_size=768, num_heads=12, - pos_embed="perc", + proj_type="perc", dropout_rate=0.3, ) diff --git a/tests/test_unetr.py b/tests/test_unetr.py index 1217c9d85f..8c5ecb32e1 100644 --- a/tests/test_unetr.py +++ b/tests/test_unetr.py @@ -30,7 +30,7 @@ for num_heads in [8]: for mlp_dim in [3072]: for norm_name in ["instance"]: - for pos_embed in ["perceptron"]: + for proj_type in ["perceptron"]: for nd in (2, 3): test_case = [ { @@ -42,7 +42,7 @@ "norm_name": norm_name, "mlp_dim": mlp_dim, "num_heads": num_heads, - "pos_embed": pos_embed, + "proj_type": proj_type, "dropout_rate": dropout_rate, "conv_block": True, "res_block": False, @@ -75,7 +75,7 @@ def test_ill_arg(self): hidden_size=128, mlp_dim=3072, num_heads=12, - pos_embed="conv", + proj_type="conv", norm_name="instance", dropout_rate=5.0, ) @@ -89,7 +89,7 @@ def test_ill_arg(self): hidden_size=512, mlp_dim=3072, num_heads=12, - pos_embed="conv", + proj_type="conv", norm_name="instance", dropout_rate=0.5, ) @@ -103,7 +103,7 @@ def test_ill_arg(self): hidden_size=512, mlp_dim=3072, num_heads=14, - pos_embed="conv", + proj_type="conv", norm_name="batch", dropout_rate=0.4, ) @@ -117,7 +117,7 @@ def test_ill_arg(self): hidden_size=768, mlp_dim=3072, num_heads=12, - pos_embed="perc", + proj_type="perc", norm_name="instance", dropout_rate=0.2, ) diff --git a/tests/test_vitautoenc.py b/tests/test_vitautoenc.py index c68c583a0e..9a503948d0 100644 --- a/tests/test_vitautoenc.py +++ b/tests/test_vitautoenc.py @@ -23,7 +23,7 @@ for in_channels in [1, 4]: for img_size in [64, 96, 128]: for patch_size in [16]: - for pos_embed in ["conv", "perceptron"]: + for proj_type in ["conv", "perceptron"]: for nd in [2, 3]: test_case = [ { @@ -34,7 +34,7 @@ "mlp_dim": 3072, "num_layers": 4, "num_heads": 12, - "pos_embed": pos_embed, + "proj_type": proj_type, "dropout_rate": 0.6, "spatial_dims": nd, }, @@ -54,7 +54,7 @@ "mlp_dim": 3072, "num_layers": 4, "num_heads": 12, - "pos_embed": "conv", + "proj_type": "conv", "dropout_rate": 0.6, "spatial_dims": 3, }, @@ -93,7 +93,7 @@ def test_shape(self, input_param, input_shape, expected_shape): ] ) def test_ill_arg( - self, in_channels, img_size, patch_size, hidden_size, mlp_dim, num_layers, num_heads, pos_embed, dropout_rate + self, in_channels, img_size, patch_size, hidden_size, mlp_dim, num_layers, num_heads, proj_type, dropout_rate ): with self.assertRaises(ValueError): ViTAutoEnc( @@ -104,7 +104,7 @@ def test_ill_arg( mlp_dim=mlp_dim, num_layers=num_layers, num_heads=num_heads, - pos_embed=pos_embed, + proj_type=proj_type, dropout_rate=dropout_rate, ) From 68581146502b3f0c9c876b12902df3197b6aa98a Mon Sep 17 00:00:00 2001 From: Pengfei Guo <32000655+guopengf@users.noreply.github.com> Date: Mon, 12 Aug 2024 23:54:48 -0400 Subject: [PATCH 11/24] Refactor AutoencoderKlMaisi (#7993) Fixes #7988 . ### Description Refactor AutoencoderKlMaisi to use monai core components. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Pengfei Guo Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: Pengfei Guo <32000655+guopengf@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- .../maisi/networks/autoencoderkl_maisi.py | 68 ++++++++++++------- monai/networks/nets/autoencoderkl.py | 4 +- tests/test_autoencoderkl_maisi.py | 6 +- 3 files changed, 46 insertions(+), 32 deletions(-) diff --git a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py index f27f73ec60..a52274b24a 100644 --- a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py +++ b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py @@ -13,25 +13,17 @@ import gc import logging -from typing import TYPE_CHECKING, Sequence, cast +from typing import Sequence import torch import torch.nn as nn import torch.nn.functional as F from monai.networks.blocks import Convolution -from monai.utils import optional_import +from monai.networks.blocks.spatialattention import SpatialAttentionBlock +from monai.networks.nets.autoencoderkl import AEKLResBlock, AutoencoderKL from monai.utils.type_conversion import convert_to_tensor -AttentionBlock, has_attentionblock = optional_import("generative.networks.nets.autoencoderkl", name="AttentionBlock") -AutoencoderKL, has_autoencoderkl = optional_import("generative.networks.nets.autoencoderkl", name="AutoencoderKL") -ResBlock, has_resblock = optional_import("generative.networks.nets.autoencoderkl", name="ResBlock") - -if TYPE_CHECKING: - from generative.networks.nets.autoencoderkl import AutoencoderKL as AutoencoderKLType -else: - AutoencoderKLType = cast(type, AutoencoderKL) - # Set up logging configuration logger = logging.getLogger(__name__) @@ -518,11 +510,13 @@ class MaisiEncoder(nn.Module): in_channels: Number of input channels. num_channels: Sequence of block output channels. out_channels: Number of channels in the bottom layer (latent space) of the autoencoder. - num_res_blocks: Number of residual blocks (see ResBlock) per level. + num_res_blocks: Number of residual blocks (see AEKLResBlock) per level. norm_num_groups: Number of groups for the group norm layers. norm_eps: Epsilon for the normalization. attention_levels: Indicate which level from num_channels contain an attention block. with_nonlocal_attn: If True, use non-local attention block. + include_fc: whether to include the final linear layer in the attention block. Default to False. + use_combined_linear: whether to use a single linear layer for qkv projection in the attention block, default to False. use_flash_attention: If True, use flash attention for a memory efficient attention mechanism. num_splits: Number of splits for the input tensor. dim_split: Dimension of splitting for the input tensor. @@ -547,6 +541,8 @@ def __init__( print_info: bool = False, save_mem: bool = True, with_nonlocal_attn: bool = True, + include_fc: bool = False, + use_combined_linear: bool = False, use_flash_attention: bool = False, ) -> None: super().__init__() @@ -603,11 +599,13 @@ def __init__( input_channel = output_channel if attention_levels[i]: blocks.append( - AttentionBlock( + SpatialAttentionBlock( spatial_dims=spatial_dims, num_channels=input_channel, norm_num_groups=norm_num_groups, norm_eps=norm_eps, + include_fc=include_fc, + use_combined_linear=use_combined_linear, use_flash_attention=use_flash_attention, ) ) @@ -626,7 +624,7 @@ def __init__( if with_nonlocal_attn: blocks.append( - ResBlock( + AEKLResBlock( spatial_dims=spatial_dims, in_channels=num_channels[-1], norm_num_groups=norm_num_groups, @@ -636,16 +634,18 @@ def __init__( ) blocks.append( - AttentionBlock( + SpatialAttentionBlock( spatial_dims=spatial_dims, num_channels=num_channels[-1], norm_num_groups=norm_num_groups, norm_eps=norm_eps, + include_fc=include_fc, + use_combined_linear=use_combined_linear, use_flash_attention=use_flash_attention, ) ) blocks.append( - ResBlock( + AEKLResBlock( spatial_dims=spatial_dims, in_channels=num_channels[-1], norm_num_groups=norm_num_groups, @@ -699,11 +699,13 @@ class MaisiDecoder(nn.Module): num_channels: Sequence of block output channels. in_channels: Number of channels in the bottom layer (latent space) of the autoencoder. out_channels: Number of output channels. - num_res_blocks: Number of residual blocks (see ResBlock) per level. + num_res_blocks: Number of residual blocks (see AEKLResBlock) per level. norm_num_groups: Number of groups for the group norm layers. norm_eps: Epsilon for the normalization. attention_levels: Indicate which level from num_channels contain an attention block. with_nonlocal_attn: If True, use non-local attention block. + include_fc: whether to include the final linear layer in the attention block. Default to False. + use_combined_linear: whether to use a single linear layer for qkv projection in the attention block, default to False. use_flash_attention: If True, use flash attention for a memory efficient attention mechanism. use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder. num_splits: Number of splits for the input tensor. @@ -729,6 +731,8 @@ def __init__( print_info: bool = False, save_mem: bool = True, with_nonlocal_attn: bool = True, + include_fc: bool = False, + use_combined_linear: bool = False, use_flash_attention: bool = False, use_convtranspose: bool = False, ) -> None: @@ -758,7 +762,7 @@ def __init__( if with_nonlocal_attn: blocks.append( - ResBlock( + AEKLResBlock( spatial_dims=spatial_dims, in_channels=reversed_block_out_channels[0], norm_num_groups=norm_num_groups, @@ -767,16 +771,18 @@ def __init__( ) ) blocks.append( - AttentionBlock( + SpatialAttentionBlock( spatial_dims=spatial_dims, num_channels=reversed_block_out_channels[0], norm_num_groups=norm_num_groups, norm_eps=norm_eps, + include_fc=include_fc, + use_combined_linear=use_combined_linear, use_flash_attention=use_flash_attention, ) ) blocks.append( - ResBlock( + AEKLResBlock( spatial_dims=spatial_dims, in_channels=reversed_block_out_channels[0], norm_num_groups=norm_num_groups, @@ -812,11 +818,13 @@ def __init__( if reversed_attention_levels[i]: blocks.append( - AttentionBlock( + SpatialAttentionBlock( spatial_dims=spatial_dims, num_channels=block_in_ch, norm_num_groups=norm_num_groups, norm_eps=norm_eps, + include_fc=include_fc, + use_combined_linear=use_combined_linear, use_flash_attention=use_flash_attention, ) ) @@ -870,7 +878,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -class AutoencoderKlMaisi(AutoencoderKLType): +class AutoencoderKlMaisi(AutoencoderKL): """ AutoencoderKL with custom MaisiEncoder and MaisiDecoder. @@ -886,6 +894,8 @@ class AutoencoderKlMaisi(AutoencoderKLType): norm_eps: Epsilon for the normalization. with_encoder_nonlocal_attn: If True, use non-local attention block in the encoder. with_decoder_nonlocal_attn: If True, use non-local attention block in the decoder. + include_fc: whether to include the final linear layer. Default to False. + use_combined_linear: whether to use a single linear layer for qkv projection, default to False. use_flash_attention: If True, use flash attention for a memory efficient attention mechanism. use_checkpointing: If True, use activation checkpointing. use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder. @@ -909,6 +919,8 @@ def __init__( norm_eps: float = 1e-6, with_encoder_nonlocal_attn: bool = False, with_decoder_nonlocal_attn: bool = False, + include_fc: bool = False, + use_combined_linear: bool = False, use_flash_attention: bool = False, use_checkpointing: bool = False, use_convtranspose: bool = False, @@ -930,12 +942,14 @@ def __init__( norm_eps, with_encoder_nonlocal_attn, with_decoder_nonlocal_attn, - use_flash_attention, use_checkpointing, use_convtranspose, + include_fc, + use_combined_linear, + use_flash_attention, ) - self.encoder = MaisiEncoder( + self.encoder: nn.Module = MaisiEncoder( spatial_dims=spatial_dims, in_channels=in_channels, num_channels=num_channels, @@ -945,6 +959,8 @@ def __init__( norm_eps=norm_eps, attention_levels=attention_levels, with_nonlocal_attn=with_encoder_nonlocal_attn, + include_fc=include_fc, + use_combined_linear=use_combined_linear, use_flash_attention=use_flash_attention, num_splits=num_splits, dim_split=dim_split, @@ -953,7 +969,7 @@ def __init__( save_mem=save_mem, ) - self.decoder = MaisiDecoder( + self.decoder: nn.Module = MaisiDecoder( spatial_dims=spatial_dims, num_channels=num_channels, in_channels=latent_channels, @@ -963,6 +979,8 @@ def __init__( norm_eps=norm_eps, attention_levels=attention_levels, with_nonlocal_attn=with_decoder_nonlocal_attn, + include_fc=include_fc, + use_combined_linear=use_combined_linear, use_flash_attention=use_flash_attention, use_convtranspose=use_convtranspose, num_splits=num_splits, diff --git a/monai/networks/nets/autoencoderkl.py b/monai/networks/nets/autoencoderkl.py index 836027796f..af191e748b 100644 --- a/monai/networks/nets/autoencoderkl.py +++ b/monai/networks/nets/autoencoderkl.py @@ -532,7 +532,7 @@ def __init__( "`num_channels`." ) - self.encoder = Encoder( + self.encoder: nn.Module = Encoder( spatial_dims=spatial_dims, in_channels=in_channels, channels=channels, @@ -546,7 +546,7 @@ def __init__( use_combined_linear=use_combined_linear, use_flash_attention=use_flash_attention, ) - self.decoder = Decoder( + self.decoder: nn.Module = Decoder( spatial_dims=spatial_dims, channels=channels, in_channels=latent_channels, diff --git a/tests/test_autoencoderkl_maisi.py b/tests/test_autoencoderkl_maisi.py index 392a3d7db2..0e9f427fb6 100644 --- a/tests/test_autoencoderkl_maisi.py +++ b/tests/test_autoencoderkl_maisi.py @@ -16,16 +16,13 @@ import torch from parameterized import parameterized +from monai.apps.generation.maisi.networks.autoencoderkl_maisi import AutoencoderKlMaisi from monai.networks import eval_mode from monai.utils import optional_import from tests.utils import SkipIfBeforePyTorchVersion tqdm, has_tqdm = optional_import("tqdm", name="tqdm") _, has_einops = optional_import("einops") -_, has_generative = optional_import("generative") - -if has_generative: - from monai.apps.generation.maisi.networks.autoencoderkl_maisi import AutoencoderKlMaisi device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") @@ -79,7 +76,6 @@ CASES = CASES_NO_ATTENTION -@unittest.skipUnless(has_generative, "monai-generative required") class TestAutoencoderKlMaisi(unittest.TestCase): @parameterized.expand(CASES) From 9dbfe160635312069dced0f6babad6f89d8dc8e7 Mon Sep 17 00:00:00 2001 From: Pengfei Guo <32000655+guopengf@users.noreply.github.com> Date: Tue, 13 Aug 2024 01:30:40 -0400 Subject: [PATCH 12/24] Refactor ControlNetMaisi (#8005) Fixes #7988 . ### Description Refactor ControlNetMaisi to use monai core components. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Pengfei Guo Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- .../maisi/networks/controlnet_maisi.py | 33 +++++++++---------- monai/networks/nets/controlnet.py | 10 +++--- tests/test_controlnet_maisi.py | 19 ++++++----- 3 files changed, 29 insertions(+), 33 deletions(-) diff --git a/monai/apps/generation/maisi/networks/controlnet_maisi.py b/monai/apps/generation/maisi/networks/controlnet_maisi.py index 3641124b7d..269086d971 100644 --- a/monai/apps/generation/maisi/networks/controlnet_maisi.py +++ b/monai/apps/generation/maisi/networks/controlnet_maisi.py @@ -11,24 +11,15 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Sequence, cast +from typing import Sequence import torch -from monai.utils import optional_import +from monai.networks.nets.controlnet import ControlNet +from monai.networks.nets.diffusion_model_unet import get_timestep_embedding -ControlNet, has_controlnet = optional_import("generative.networks.nets.controlnet", name="ControlNet") -get_timestep_embedding, has_get_timestep_embedding = optional_import( - "generative.networks.nets.diffusion_model_unet", name="get_timestep_embedding" -) -if TYPE_CHECKING: - from generative.networks.nets.controlnet import ControlNet as ControlNetType -else: - ControlNetType = cast(type, ControlNet) - - -class ControlNetMaisi(ControlNetType): +class ControlNetMaisi(ControlNet): """ Control network for diffusion models based on Zhang and Agrawala "Adding Conditional Control to Text-to-Image Diffusion Models" (https://arxiv.org/abs/2302.05543) @@ -49,10 +40,12 @@ class ControlNetMaisi(ControlNetType): num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` classes. upcast_attention: if True, upcast attention operations to full precision. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. conditioning_embedding_in_channels: number of input channels for the conditioning embedding. conditioning_embedding_num_channels: number of channels for the blocks in the conditioning embedding. use_checkpointing: if True, use activation checkpointing to save memory. + include_fc: whether to include the final linear layer. Default to False. + use_combined_linear: whether to use a single linear layer for qkv projection, default to False. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. """ def __init__( @@ -71,10 +64,12 @@ def __init__( cross_attention_dim: int | None = None, num_class_embeds: int | None = None, upcast_attention: bool = False, - use_flash_attention: bool = False, conditioning_embedding_in_channels: int = 1, - conditioning_embedding_num_channels: Sequence[int] | None = (16, 32, 96, 256), + conditioning_embedding_num_channels: Sequence[int] = (16, 32, 96, 256), use_checkpointing: bool = True, + include_fc: bool = False, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__( spatial_dims, @@ -91,9 +86,11 @@ def __init__( cross_attention_dim, num_class_embeds, upcast_attention, - use_flash_attention, conditioning_embedding_in_channels, conditioning_embedding_num_channels, + include_fc, + use_combined_linear, + use_flash_attention, ) self.use_checkpointing = use_checkpointing @@ -105,7 +102,7 @@ def forward( conditioning_scale: float = 1.0, context: torch.Tensor | None = None, class_labels: torch.Tensor | None = None, - ) -> tuple[Sequence[torch.Tensor], torch.Tensor]: + ) -> tuple[list[torch.Tensor], torch.Tensor]: emb = self._prepare_time_and_class_embedding(x, timesteps, class_labels) h = self._apply_initial_convolution(x) if self.use_checkpointing: diff --git a/monai/networks/nets/controlnet.py b/monai/networks/nets/controlnet.py index 8b08eaae10..8b8813597f 100644 --- a/monai/networks/nets/controlnet.py +++ b/monai/networks/nets/controlnet.py @@ -174,24 +174,22 @@ def __init__( super().__init__() if with_conditioning is True and cross_attention_dim is None: raise ValueError( - "DiffusionModelUNet expects dimension of the cross-attention conditioning (cross_attention_dim) " + "ControlNet expects dimension of the cross-attention conditioning (cross_attention_dim) " "to be specified when with_conditioning=True." ) if cross_attention_dim is not None and with_conditioning is False: - raise ValueError( - "DiffusionModelUNet expects with_conditioning=True when specifying the cross_attention_dim." - ) + raise ValueError("ControlNet expects with_conditioning=True when specifying the cross_attention_dim.") # All number of channels should be multiple of num_groups if any((out_channel % norm_num_groups) != 0 for out_channel in channels): raise ValueError( - f"DiffusionModelUNet expects all channels to be a multiple of norm_num_groups, but got" + f"ControlNet expects all channels to be a multiple of norm_num_groups, but got" f" channels={channels} and norm_num_groups={norm_num_groups}" ) if len(channels) != len(attention_levels): raise ValueError( - f"DiffusionModelUNet expects channels to have the same length as attention_levels, but got " + f"ControlNet expects channels to have the same length as attention_levels, but got " f"channels={channels} and attention_levels={attention_levels}" ) diff --git a/tests/test_controlnet_maisi.py b/tests/test_controlnet_maisi.py index 7b0e69f2c8..bfdf25ec6e 100644 --- a/tests/test_controlnet_maisi.py +++ b/tests/test_controlnet_maisi.py @@ -17,14 +17,12 @@ import torch from parameterized import parameterized +from monai.apps.generation.maisi.networks.controlnet_maisi import ControlNetMaisi from monai.networks import eval_mode from monai.utils import optional_import from tests.utils import SkipIfBeforePyTorchVersion -_, has_generative = optional_import("generative") - -if has_generative: - from monai.apps.generation.maisi.networks.controlnet_maisi import ControlNetMaisi +_, has_einops = optional_import("einops") TEST_CASES = [ [ @@ -103,8 +101,8 @@ TEST_CASES_ERROR = [ [ {"spatial_dims": 2, "in_channels": 1, "with_conditioning": True, "cross_attention_dim": None}, - "ControlNet expects dimension of the cross-attention conditioning " - "(cross_attention_dim) when using with_conditioning.", + "ControlNet expects dimension of the cross-attention conditioning (cross_attention_dim) " + "to be specified when with_conditioning=True.", ], [ {"spatial_dims": 2, "in_channels": 1, "with_conditioning": False, "cross_attention_dim": 2}, @@ -112,7 +110,8 @@ ], [ {"spatial_dims": 2, "in_channels": 1, "num_channels": (8, 16), "norm_num_groups": 16}, - "ControlNet expects all num_channels being multiple of norm_num_groups", + f"ControlNet expects all channels to be a multiple of norm_num_groups, but got" + f" channels={(8, 16)} and norm_num_groups={16}", ], [ { @@ -122,16 +121,17 @@ "attention_levels": (True,), "norm_num_groups": 8, }, - "ControlNet expects num_channels being same size of attention_levels", + f"ControlNet expects channels to have the same length as attention_levels, but got " + f"channels={(8, 16)} and attention_levels={(True,)}", ], ] @SkipIfBeforePyTorchVersion((2, 0)) -@skipUnless(has_generative, "monai-generative required") class TestControlNet(unittest.TestCase): @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_shape_unconditioned_models(self, input_param, expected_num_down_blocks_residuals, expected_shape): net = ControlNetMaisi(**input_param) with eval_mode(net): @@ -145,6 +145,7 @@ def test_shape_unconditioned_models(self, input_param, expected_num_down_blocks_ self.assertEqual(result[1].shape, expected_shape) @parameterized.expand(TEST_CASES_CONDITIONAL) + @skipUnless(has_einops, "Requires einops") def test_shape_conditioned_models(self, input_param, expected_num_down_blocks_residuals, expected_shape): net = ControlNetMaisi(**input_param) with eval_mode(net): From 34ce94db424445b38eb56a6c842e55a2122d4a9d Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Tue, 13 Aug 2024 14:44:53 +0800 Subject: [PATCH 13/24] Fix ci issue in test_vit (#8013) Fixes #8012 ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- requirements-dev.txt | 1 - tests/test_vit.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/requirements-dev.txt b/requirements-dev.txt index 76f1952345..9aad0804e6 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -58,5 +58,4 @@ lpips==0.1.4 nvidia-ml-py huggingface_hub pyamg>=5.0.0 -git+https://github.com/Project-MONAI/GenerativeModels.git@7428fce193771e9564f29b91d29e523dd1b6b4cd git+https://github.com/facebookresearch/segment-anything.git@6fdee8f2727f4506cfbbe553e23b895e27956588 diff --git a/tests/test_vit.py b/tests/test_vit.py index d638c0116a..a3ffd0b2ef 100644 --- a/tests/test_vit.py +++ b/tests/test_vit.py @@ -106,7 +106,7 @@ def test_ill_arg( ) @parameterized.expand(TEST_CASE_Vit[:1]) - @SkipIfBeforePyTorchVersion((1, 9)) + @SkipIfBeforePyTorchVersion((2, 0)) def test_script(self, input_param, input_shape, _): net = ViT(**(input_param)) net.eval() From 4877767cf92649a38ffda0fc590f2b92ba59f019 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Wed, 14 Aug 2024 09:55:53 +0800 Subject: [PATCH 14/24] Fix module can not import correctly issue (#8015) Fixes #8014 ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- monai/utils/module.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/monai/utils/module.py b/monai/utils/module.py index 251232d62f..1ac8140b39 100644 --- a/monai/utils/module.py +++ b/monai/utils/module.py @@ -13,7 +13,6 @@ import enum import functools -import importlib.util import os import pdb import re @@ -209,11 +208,13 @@ def load_submodules( ): if (is_pkg or load_all) and name not in sys.modules and match(exclude_pattern, name) is None: try: + mod = import_module(name) mod_spec = importer.find_spec(name) # type: ignore if mod_spec and mod_spec.loader: - mod = importlib.util.module_from_spec(mod_spec) - mod_spec.loader.exec_module(mod) + loader = mod_spec.loader + loader.exec_module(mod) submodules.append(mod) + except OptionalImportError: pass # could not import the optional deps., they are ignored except ImportError as e: From e85580af2267404ff0f022e5372a44e8effe6316 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Thu, 15 Aug 2024 12:40:21 +0800 Subject: [PATCH 15/24] Fix 'torch.device' object has no attribute 'gpu_id' issue in trt export (#8019) Part of #8017 ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/networks/utils.py | 6 +++--- monai/utils/module.py | 1 - 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 6a97434215..f301c2dd5c 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -822,7 +822,7 @@ def _onnx_trt_compile( output_names = [] if not output_names else output_names # set up the TensorRT builder - torch_tensorrt.set_device(device) + torch.cuda.set_device(device) logger = trt.Logger(trt.Logger.WARNING) builder = trt.Builder(logger) network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)) @@ -931,7 +931,7 @@ def convert_to_trt( warnings.warn(f"The dynamic batch range sequence should have 3 elements, but got {dynamic_batchsize} elements.") device = device if device else 0 - target_device = torch.device(f"cuda:{device}") if device else torch.device("cuda:0") + target_device = torch.device(f"cuda:{device}") convert_precision = torch.float32 if precision == "fp32" else torch.half inputs = [torch.rand(ensure_tuple(input_shape)).to(target_device)] @@ -986,7 +986,7 @@ def scale_batch_size(input_shape: Sequence[int], scale_num: int): ir_model, inputs=input_placeholder, enabled_precisions=convert_precision, - device=target_device, + device=torch_tensorrt.Device(f"cuda:{device}"), ir="torchscript", **kwargs, ) diff --git a/monai/utils/module.py b/monai/utils/module.py index 1ac8140b39..78087aef84 100644 --- a/monai/utils/module.py +++ b/monai/utils/module.py @@ -214,7 +214,6 @@ def load_submodules( loader = mod_spec.loader loader.exec_module(mod) submodules.append(mod) - except OptionalImportError: pass # could not import the optional deps., they are ignored except ImportError as e: From 77304dd114227b8b9b1059665aedba295db0ffc7 Mon Sep 17 00:00:00 2001 From: Yufan He <59374597+heyufan1995@users.noreply.github.com> Date: Thu, 15 Aug 2024 01:50:33 -0500 Subject: [PATCH 16/24] Add vista network (#7987) Fixes # . ### Description Add VISTA3D model architecture to MONAI core ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: heyufan1995 Signed-off-by: Yufan He Signed-off-by: Yiheng Wang Signed-off-by: Yiheng Wang <68361391+yiheng-wang-nv@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Yiheng Wang <68361391+yiheng-wang-nv@users.noreply.github.com> Co-authored-by: Yiheng Wang Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- docs/source/networks.rst | 10 + monai/networks/nets/__init__.py | 3 +- monai/networks/nets/segresnet_ds.py | 128 +++- monai/networks/nets/vista3d.py | 908 ++++++++++++++++++++++++++++ monai/transforms/utils.py | 4 +- tests/test_segresnet_ds.py | 86 ++- tests/test_vista3d.py | 85 +++ 7 files changed, 1189 insertions(+), 35 deletions(-) create mode 100644 monai/networks/nets/vista3d.py create mode 100644 tests/test_vista3d.py diff --git a/docs/source/networks.rst b/docs/source/networks.rst index 249375dfc1..1810fec49b 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -481,6 +481,11 @@ Nets .. autoclass:: SegResNetDS :members: +`SegResNetDS2` +~~~~~~~~~~~~~~ +.. autoclass:: SegResNetDS2 + :members: + `SegResNetVAE` ~~~~~~~~~~~~~~ .. autoclass:: SegResNetVAE @@ -556,6 +561,11 @@ Nets .. autoclass:: UNETR :members: +`VISTA3D` +~~~~~~~~~ +.. autoclass:: VISTA3D + :members: + `SwinUNETR` ~~~~~~~~~~~ .. autoclass:: SwinUNETR diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index c777fe6442..0570c9fcc1 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -76,7 +76,7 @@ resnet200, ) from .segresnet import SegResNet, SegResNetVAE -from .segresnet_ds import SegResNetDS +from .segresnet_ds import SegResNetDS, SegResNetDS2 from .senet import ( SENet, SEnet, @@ -118,6 +118,7 @@ from .unet import UNet, Unet from .unetr import UNETR from .varautoencoder import VarAutoEncoder +from .vista3d import VISTA3D, vista3d132 from .vit import ViT from .vitautoenc import ViTAutoEnc from .vnet import VNet diff --git a/monai/networks/nets/segresnet_ds.py b/monai/networks/nets/segresnet_ds.py index 6430f5fdc9..1ac5a79ee3 100644 --- a/monai/networks/nets/segresnet_ds.py +++ b/monai/networks/nets/segresnet_ds.py @@ -11,6 +11,7 @@ from __future__ import annotations +import copy from collections.abc import Callable from typing import Union @@ -23,7 +24,7 @@ from monai.networks.layers.utils import get_act_layer, get_norm_layer from monai.utils import UpsampleMode, has_option -__all__ = ["SegResNetDS"] +__all__ = ["SegResNetDS", "SegResNetDS2"] def scales_for_resolution(resolution: tuple | list, n_stages: int | None = None): @@ -425,3 +426,128 @@ def _forward(self, x: torch.Tensor) -> Union[None, torch.Tensor, list[torch.Tens def forward(self, x: torch.Tensor) -> Union[None, torch.Tensor, list[torch.Tensor]]: return self._forward(x) + + +class SegResNetDS2(SegResNetDS): + """ + SegResNetDS2 adds an additional decorder branch to SegResNetDS and is the image encoder of VISTA3D + `_. + + Args: + spatial_dims: spatial dimension of the input data. Defaults to 3. + init_filters: number of output channels for initial convolution layer. Defaults to 32. + in_channels: number of input channels for the network. Defaults to 1. + out_channels: number of output channels for the network. Defaults to 2. + act: activation type and arguments. Defaults to ``RELU``. + norm: feature normalization type and arguments. Defaults to ``BATCH``. + blocks_down: number of downsample blocks in each layer. Defaults to ``[1,2,2,4]``. + blocks_up: number of upsample blocks (optional). + dsdepth: number of levels for deep supervision. This will be the length of the list of outputs at each scale level. + At dsdepth==1,only a single output is returned. + preprocess: optional callable function to apply before the model's forward pass + resolution: optional input image resolution. When provided, the network will first use non-isotropic kernels to bring + image spacing into an approximately isotropic space. + Otherwise, by default, the kernel size and downsampling is always isotropic. + + """ + + def __init__( + self, + spatial_dims: int = 3, + init_filters: int = 32, + in_channels: int = 1, + out_channels: int = 2, + act: tuple | str = "relu", + norm: tuple | str = "batch", + blocks_down: tuple = (1, 2, 2, 4), + blocks_up: tuple | None = None, + dsdepth: int = 1, + preprocess: nn.Module | Callable | None = None, + upsample_mode: UpsampleMode | str = "deconv", + resolution: tuple | None = None, + ): + super().__init__( + spatial_dims=spatial_dims, + init_filters=init_filters, + in_channels=in_channels, + out_channels=out_channels, + act=act, + norm=norm, + blocks_down=blocks_down, + blocks_up=blocks_up, + dsdepth=dsdepth, + preprocess=preprocess, + upsample_mode=upsample_mode, + resolution=resolution, + ) + + self.up_layers_auto = nn.ModuleList([copy.deepcopy(layer) for layer in self.up_layers]) + + def forward( # type: ignore + self, x: torch.Tensor, with_point: bool = True, with_label: bool = True + ) -> tuple[Union[None, torch.Tensor, list[torch.Tensor]], Union[None, torch.Tensor, list[torch.Tensor]]]: + """ + Args: + x: input tensor. + with_point: if true, return the point branch output. + with_label: if true, return the label branch output. + """ + if self.preprocess is not None: + x = self.preprocess(x) + + if not self.is_valid_shape(x): + raise ValueError(f"Input spatial dims {x.shape} must be divisible by {self.shape_factor()}") + + x_down = self.encoder(x) + + x_down.reverse() + x = x_down.pop(0) + + if len(x_down) == 0: + x_down = [torch.zeros(1, device=x.device, dtype=x.dtype)] + + outputs: list[torch.Tensor] = [] + outputs_auto: list[torch.Tensor] = [] + x_ = x.clone() + if with_point: + i = 0 + for level in self.up_layers: + x = level["upsample"](x) + x = x + x_down[i] + x = level["blocks"](x) + + if len(self.up_layers) - i <= self.dsdepth: + outputs.append(level["head"](x)) + i = i + 1 + + outputs.reverse() + x = x_ + if with_label: + i = 0 + for level in self.up_layers_auto: + x = level["upsample"](x) + x = x + x_down[i] + x = level["blocks"](x) + + if len(self.up_layers) - i <= self.dsdepth: + outputs_auto.append(level["head"](x)) + i = i + 1 + + outputs_auto.reverse() + + return outputs[0] if len(outputs) == 1 else outputs, outputs_auto[0] if len(outputs_auto) == 1 else outputs_auto + + def set_auto_grad(self, auto_freeze=False, point_freeze=False): + """ + Args: + auto_freeze: if true, freeze the image encoder and the auto-branch. + point_freeze: if true, freeze the image encoder and the point-branch. + """ + for param in self.encoder.parameters(): + param.requires_grad = (not auto_freeze) and (not point_freeze) + + for param in self.up_layers_auto.parameters(): + param.requires_grad = not auto_freeze + + for param in self.up_layers.parameters(): + param.requires_grad = not point_freeze diff --git a/monai/networks/nets/vista3d.py b/monai/networks/nets/vista3d.py new file mode 100644 index 0000000000..fe7f93d493 --- /dev/null +++ b/monai/networks/nets/vista3d.py @@ -0,0 +1,908 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import math +from typing import Any, Callable, Optional, Sequence, Tuple + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + +import monai +from monai.networks.blocks import MLPBlock, UnetrBasicBlock +from monai.networks.nets import SegResNetDS2 +from monai.transforms.utils import convert_points_to_disc +from monai.transforms.utils import get_largest_connected_component_mask_point as lcc +from monai.transforms.utils import sample_points_from_label +from monai.utils import optional_import, unsqueeze_left, unsqueeze_right + +rearrange, _ = optional_import("einops", name="rearrange") + +__all__ = ["VISTA3D", "vista3d132"] + + +def vista3d132(encoder_embed_dim: int = 48, in_channels: int = 1): + """ + Exact VISTA3D network configuration used in https://arxiv.org/abs/2406.05285>`_. + The model treats class index larger than 132 as zero-shot. + + Args: + encoder_embed_dim: hidden dimension for encoder. + in_channels: input channel number. + """ + segresnet = SegResNetDS2( + in_channels=in_channels, + blocks_down=(1, 2, 2, 4, 4), + norm="instance", + out_channels=encoder_embed_dim, + init_filters=encoder_embed_dim, + dsdepth=1, + ) + point_head = PointMappingSAM(feature_size=encoder_embed_dim, n_classes=512, last_supported=132) + class_head = ClassMappingClassify(n_classes=512, feature_size=encoder_embed_dim, use_mlp=True) + vista = VISTA3D(image_encoder=segresnet, class_head=class_head, point_head=point_head) + return vista + + +class VISTA3D(nn.Module): + """ + VISTA3D based on: + `VISTA3D: Versatile Imaging SegmenTation and Annotation model for 3D Computed Tomography + `_. + + Args: + image_encoder: image encoder backbone for feature extraction. + class_head: class head used for class index based segmentation + point_head: point head used for interactive segmetnation + """ + + def __init__(self, image_encoder: nn.Module, class_head: nn.Module, point_head: nn.Module): + super().__init__() + self.image_encoder = image_encoder + self.class_head = class_head + self.point_head = point_head + self.image_embeddings = None + self.auto_freeze = False + self.point_freeze = False + self.NINF_VALUE = -9999 + self.PINF_VALUE = 9999 + + def get_foreground_class_count(self, class_vector: torch.Tensor | None, point_coords: torch.Tensor | None) -> int: + """Get number of foreground classes based on class and point prompt.""" + if class_vector is None: + if point_coords is None: + raise ValueError("class_vector and point_coords cannot be both None.") + return point_coords.shape[0] + else: + return class_vector.shape[0] + + def convert_point_label( + self, + point_label: torch.Tensor, + label_set: Sequence[int] | None = None, + special_index: Sequence[int] = (23, 24, 25, 26, 27, 57, 128), + ): + """ + Convert point label based on its class prompt. For special classes defined in special index, + the positive/negative point label will be converted from 1/0 to 3/2. The purpose is to separate those + classes with ambiguous classes. + + Args: + point_label: the point label tensor, [B, N]. + label_set: the label index matching the indexes in labels. If labels are mapped to global index using RelabelID, + this label_set should be global mapped index. If labels are not mapped to global index, e.g. in zero-shot + evaluation, this label_set should be the original index. + special_index: the special class index that needs to be converted. + """ + if label_set is None: + return point_label + if not point_label.shape[0] == len(label_set): + raise ValueError("point_label and label_set must have the same length.") + + for i in range(len(label_set)): + if label_set[i] in special_index: + for j in range(len(point_label[i])): + point_label[i, j] = point_label[i, j] + 2 if point_label[i, j] > -1 else point_label[i, j] + return point_label + + def sample_points_patch_val( + self, + labels: torch.Tensor, + patch_coords: Sequence[slice], + label_set: Sequence[int], + use_center: bool = True, + mapped_label_set: Sequence[int] | None = None, + max_ppoint: int = 1, + max_npoint: int = 0, + ): + """ + Sample points for patch during sliding window validation. Only used for point only validation. + + Args: + labels: shape [1, 1, H, W, D]. + patch_coords: a sequence of sliding window slice objects. + label_set: local index, must match values in labels. + use_center: sample points from the center. + mapped_label_set: global index, it is used to identify special classes and is the global index + for the sampled points. + max_ppoint/max_npoint: positive points and negative points to sample. + """ + point_coords, point_labels = sample_points_from_label( + labels[patch_coords], + label_set, + max_ppoint=max_ppoint, + max_npoint=max_npoint, + device=labels.device, + use_center=use_center, + ) + point_labels = self.convert_point_label(point_labels, mapped_label_set) + return (point_coords, point_labels, torch.tensor(label_set).to(point_coords.device).unsqueeze(-1)) + + def update_point_to_patch( + self, patch_coords: Sequence[slice], point_coords: torch.Tensor, point_labels: torch.Tensor + ): + """ + Update point_coords with respect to patch coords. + If point is outside of the patch, remove the coordinates and set label to -1. + + Args: + patch_coords: a sequence of the python slice objects representing the patch coordinates during sliding window inference. + This value is passed from sliding_window_inferer. + point_coords: point coordinates, [B, N, 3]. + point_labels: point labels, [B, N]. + """ + patch_ends = [patch_coords[-3].stop, patch_coords[-2].stop, patch_coords[-1].stop] + patch_starts = [patch_coords[-3].start, patch_coords[-2].start, patch_coords[-1].start] + # update point coords + patch_starts_tensor = unsqueeze_left(torch.tensor(patch_starts, device=point_coords.device), 2) + patch_ends_tensor = unsqueeze_left(torch.tensor(patch_ends, device=point_coords.device), 2) + # [1 N 1] + indices = torch.logical_and( + ((point_coords - patch_starts_tensor) > 0).all(2), ((patch_ends_tensor - point_coords) > 0).all(2) + ) + # check if it's within patch coords + point_coords = point_coords.clone() - patch_starts_tensor + point_labels = point_labels.clone() + if indices.any(): + point_labels[~indices] = -1 + point_coords[~indices] = 0 + # also remove padded points, mainly used for inference. + not_pad_indices = (point_labels != -1).any(0) + point_coords = point_coords[:, not_pad_indices] + point_labels = point_labels[:, not_pad_indices] + return point_coords, point_labels + return None, None + + def connected_components_combine( + self, + logits: torch.Tensor, + point_logits: torch.Tensor, + point_coords: torch.Tensor, + point_labels: torch.Tensor, + mapping_index: torch.Tensor, + thred: float = 0.5, + ): + """ + Combine auto results with point click response. The auto results have shape [B, 1, H, W, D] which means B foreground masks + from a single image patch. + Out of those B foreground masks, user may add points to a subset of B1 foreground masks for editing. + mapping_index represents the correspondence between B and B1. + For mapping_index with point clicks, NaN values in logits will be replaced with point_logits. Meanwhile, the added/removed + region in point clicks must be updated by the lcc function. + Notice, if a positive point is within logits/prev_mask, the components containing the positive point will be added. + + Args: + logits: automatic branch results, [B, 1, H, W, D]. + point_logits: point branch results, [B1, 1, H, W, D]. + point_coords: point coordinates, [B1, N, 3]. + point_labels: point labels, [B1, N]. + mapping_index: [B]. + thred: the threshold to convert logits to binary. + """ + logits = logits.as_tensor() if isinstance(logits, monai.data.MetaTensor) else logits + _logits = logits[mapping_index] + inside = [] + for i in range(_logits.shape[0]): + inside.append( + np.any( + [ + _logits[i, 0, p[0], p[1], p[2]].item() > 0 + for p in point_coords[i].cpu().numpy().round().astype(int) + ] + ) + ) + inside_tensor = torch.tensor(inside).to(logits.device) + nan_mask = torch.isnan(_logits) + # _logits are converted to binary [B1, 1, H, W, D] + _logits = torch.nan_to_num(_logits, nan=self.NINF_VALUE).sigmoid() + pos_region = point_logits.sigmoid() > thred + diff_pos = torch.logical_and(torch.logical_or(_logits <= thred, unsqueeze_right(inside_tensor, 5)), pos_region) + diff_neg = torch.logical_and((_logits > thred), ~pos_region) + cc = lcc(diff_pos, diff_neg, point_coords=point_coords, point_labels=point_labels) + # cc is the region that can be updated by point_logits. + cc = cc.to(logits.device) + # Need to replace NaN with point_logits. diff_neg will never lie in nan_mask, + # only remove unconnected positive region. + uc_pos_region = torch.logical_and(pos_region, ~cc) + fill_mask = torch.logical_and(nan_mask, uc_pos_region) + if fill_mask.any(): + # fill in the mean negative value + point_logits[fill_mask] = -1 + # replace logits nan value and cc with point_logits + cc = torch.logical_or(nan_mask, cc).to(logits.dtype) + logits[mapping_index] *= 1 - cc + logits[mapping_index] += cc * point_logits + return logits + + def gaussian_combine( + self, + logits: torch.Tensor, + point_logits: torch.Tensor, + point_coords: torch.Tensor, + point_labels: torch.Tensor, + mapping_index: torch.Tensor, + radius: int | None = None, + ): + """ + Combine point results with auto results using gaussian. + + Args: + logits: automatic branch results, [B, 1, H, W, D]. + point_logits: point branch results, [B1, 1, H, W, D]. + point_coords: point coordinates, [B1, N, 3]. + point_labels: point labels, [B1, N]. + mapping_index: [B]. + radius: gaussian ball radius. + """ + if radius is None: + radius = min(point_logits.shape[-3:]) // 5 # empirical value 5 + weight = 1 - convert_points_to_disc(point_logits.shape[-3:], point_coords, point_labels, radius=radius).sum( + 1, keepdims=True + ) + weight[weight < 0] = 0 + logits = logits.as_tensor() if isinstance(logits, monai.data.MetaTensor) else logits + logits[mapping_index] *= weight + logits[mapping_index] += (1 - weight) * point_logits + return logits + + def set_auto_grad(self, auto_freeze: bool = False, point_freeze: bool = False): + """ + Freeze auto-branch or point-branch. + + Args: + auto_freeze: whether to freeze the auto branch. + point_freeze: whether to freeze the point branch. + """ + if auto_freeze != self.auto_freeze: + if hasattr(self.image_encoder, "set_auto_grad"): + self.image_encoder.set_auto_grad(auto_freeze=auto_freeze, point_freeze=point_freeze) + else: + for param in self.image_encoder.parameters(): + param.requires_grad = (not auto_freeze) and (not point_freeze) + for param in self.class_head.parameters(): + param.requires_grad = not auto_freeze + self.auto_freeze = auto_freeze + + if point_freeze != self.point_freeze: + if hasattr(self.image_encoder, "set_auto_grad"): + self.image_encoder.set_auto_grad(auto_freeze=auto_freeze, point_freeze=point_freeze) + else: + for param in self.image_encoder.parameters(): + param.requires_grad = (not auto_freeze) and (not point_freeze) + for param in self.point_head.parameters(): + param.requires_grad = not point_freeze + self.point_freeze = point_freeze + + def forward( + self, + input_images: torch.Tensor, + point_coords: torch.Tensor | None = None, + point_labels: torch.Tensor | None = None, + class_vector: torch.Tensor | None = None, + prompt_class: torch.Tensor | None = None, + patch_coords: Sequence[slice] | None = None, + labels: torch.Tensor | None = None, + label_set: Sequence[int] | None = None, + prev_mask: torch.Tensor | None = None, + radius: int | None = None, + val_point_sampler: Callable | None = None, + **kwargs, + ): + """ + The forward function for VISTA3D. We only support single patch in training and inference. + One exception is allowing sliding window batch size > 1 for automatic segmentation only case. + B represents number of objects, N represents number of points for each objects. + + Args: + input_images: [1, 1, H, W, D] + point_coords: [B, N, 3] + point_labels: [B, N], -1 represents padding. 0/1 means negative/positive points for regular class. + 2/3 means negative/postive ponits for special supported class like tumor. + class_vector: [B, 1], the global class index + prompt_class: [B, 1], the global class index. This value is associated with point_coords to identify if + the points are for zero-shot or supported class. When class_vector and point_coords are both + provided, prompt_class is the same as class_vector. For prompt_class[b] > 512, point_coords[b] + will be considered novel class. + patch_coords: a sequence of the python slice objects representing the patch coordinates during sliding window inference. + This value is passed from sliding_window_inferer. This is an indicator for training phase or validation phase. + labels: [1, 1, H, W, D], the groundtruth label tensor, only used for point-only evaluation + label_set: the label index matching the indexes in labels. If labels are mapped to global index using RelabelID, + this label_set should be global mapped index. If labels are not mapped to global index, e.g. in zero-shot + evaluation, this label_set should be the original index. + prev_mask: [B, N, H_fullsize, W_fullsize, D_fullsize]. + This is the transposed raw output from sliding_window_inferer before any postprocessing. + When user click points to perform auto-results correction, this can be the auto-results. + radius: single float value controling the gaussian blur when combining point and auto results. + The gaussian combine is not used in VISTA3D training but might be useful for finetuning purposes. + val_point_sampler: function used to sample points from labels. This is only used for point-only evaluation. + + """ + image_size = input_images.shape[-3:] + device = input_images.device + if point_coords is None and class_vector is None: + return self.NINF_VALUE + torch.zeros([1, 1, *image_size], device=device) + + bs = self.get_foreground_class_count(class_vector, point_coords) + if patch_coords is not None: + # if during validation and perform enable based point-validation. + if labels is not None and label_set is not None: + # if labels is not None, sample from labels for each patch. + if val_point_sampler is None: + # TODO: think about how to refactor this part. + val_point_sampler = self.sample_points_patch_val + point_coords, point_labels, prompt_class = val_point_sampler(labels, patch_coords, label_set) + if prompt_class[0].item() == 0: # type: ignore + point_labels[0] = -1 # type: ignore + labels, prev_mask = None, None + elif point_coords is not None: + # If not performing patch-based point only validation, use user provided click points for inference. + # the point clicks is in original image space, convert it to current patch-coordinate space. + point_coords, point_labels = self.update_point_to_patch(patch_coords, point_coords, point_labels) # type: ignore + + if point_coords is not None and point_labels is not None: + # remove points that used for padding purposes (point_label = -1) + mapping_index = ((point_labels != -1).sum(1) > 0).to(torch.bool) + if mapping_index.any(): + point_coords = point_coords[mapping_index] + point_labels = point_labels[mapping_index] + if prompt_class is not None: + prompt_class = prompt_class[mapping_index] + else: + if self.auto_freeze or (class_vector is None and patch_coords is None): + # if auto_freeze, point prompt must exist to allow loss backward + # in training, class_vector and point cannot both be None due to loss.backward() + mapping_index.fill_(True) + else: + point_coords, point_labels = None, None + + if point_coords is None and class_vector is None: + return self.NINF_VALUE + torch.zeros([bs, 1, *image_size], device=device) + + if self.image_embeddings is not None and kwargs.get("keep_cache", False) and class_vector is None: + out, out_auto = self.image_embeddings, None + else: + out, out_auto = self.image_encoder( + input_images, with_point=point_coords is not None, with_label=class_vector is not None + ) + # release memory + input_images = None # type: ignore + + # force releasing memories that set to None + torch.cuda.empty_cache() + if class_vector is not None: + logits, _ = self.class_head(out_auto, class_vector) + if point_coords is not None: + point_logits = self.point_head(out, point_coords, point_labels, class_vector=prompt_class) + if patch_coords is None: + logits = self.gaussian_combine( + logits, point_logits, point_coords, point_labels, mapping_index, radius # type: ignore + ) + else: + # during validation use largest component + logits = self.connected_components_combine( + logits, point_logits, point_coords, point_labels, mapping_index # type: ignore + ) + else: + logits = self.NINF_VALUE + torch.zeros([bs, 1, *image_size], device=device, dtype=out.dtype) + logits[mapping_index] = self.point_head(out, point_coords, point_labels, class_vector=prompt_class) + if prev_mask is not None and patch_coords is not None: + logits = self.connected_components_combine( + prev_mask[patch_coords].transpose(1, 0).to(logits.device), + logits[mapping_index], + point_coords, # type: ignore + point_labels, # type: ignore + mapping_index, + ) + + if kwargs.get("keep_cache", False) and class_vector is None: + self.image_embeddings = out.detach() + return logits + + +class PointMappingSAM(nn.Module): + def __init__(self, feature_size: int, max_prompt: int = 32, n_classes: int = 512, last_supported: int = 132): + """Interactive point head used for VISTA3D. + Adapted from segment anything: + `https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/mask_decoder.py`. + + Args: + feature_size: feature channel from encoder. + max_prompt: max prompt number in each forward iteration. + n_classes: number of classes the model can potentially support. This is the maximum number of class embeddings. + last_supported: number of classes the model support, this value should match the trained model weights. + """ + super().__init__() + transformer_dim = feature_size + self.max_prompt = max_prompt + self.feat_downsample = nn.Sequential( + nn.Conv3d(in_channels=feature_size, out_channels=feature_size, kernel_size=3, stride=2, padding=1), + nn.InstanceNorm3d(feature_size), + nn.GELU(), + nn.Conv3d(in_channels=feature_size, out_channels=transformer_dim, kernel_size=3, stride=1, padding=1), + nn.InstanceNorm3d(feature_size), + ) + + self.mask_downsample = nn.Conv3d(in_channels=2, out_channels=2, kernel_size=3, stride=2, padding=1) + + self.transformer = TwoWayTransformer(depth=2, embedding_dim=transformer_dim, mlp_dim=512, num_heads=4) + self.pe_layer = PositionEmbeddingRandom(transformer_dim // 2) + self.point_embeddings = nn.ModuleList([nn.Embedding(1, transformer_dim), nn.Embedding(1, transformer_dim)]) + self.not_a_point_embed = nn.Embedding(1, transformer_dim) + self.special_class_embed = nn.Embedding(1, transformer_dim) + self.mask_tokens = nn.Embedding(1, transformer_dim) + + self.output_upscaling = nn.Sequential( + nn.ConvTranspose3d(transformer_dim, transformer_dim, kernel_size=3, stride=2, padding=1, output_padding=1), + nn.InstanceNorm3d(transformer_dim), + nn.GELU(), + nn.Conv3d(transformer_dim, transformer_dim, kernel_size=3, stride=1, padding=1), + ) + + self.output_hypernetworks_mlps = MLP(transformer_dim, transformer_dim, transformer_dim, 3) + # class embedding + self.n_classes = n_classes + self.last_supported = last_supported + self.class_embeddings = nn.Embedding(n_classes, feature_size) + self.zeroshot_embed = nn.Embedding(1, transformer_dim) + self.supported_embed = nn.Embedding(1, transformer_dim) + + def forward( + self, + out: torch.Tensor, + point_coords: torch.Tensor, + point_labels: torch.Tensor, + class_vector: torch.Tensor | None = None, + ): + """Args: + out: feature from encoder, [1, C, H, W, C] + point_coords: point coordinates, [B, N, 3] + point_labels: point labels, [B, N] + class_vector: class prompts, [B] + """ + # downsample out + out_low = self.feat_downsample(out) + out_shape = tuple(out.shape[-3:]) + # release memory + out = None # type: ignore + torch.cuda.empty_cache() + # embed points + points = point_coords + 0.5 # Shift to center of pixel + point_embedding = self.pe_layer.forward_with_coords(points, out_shape) # type: ignore + point_embedding[point_labels == -1] = 0.0 + point_embedding[point_labels == -1] += self.not_a_point_embed.weight + point_embedding[point_labels == 0] += self.point_embeddings[0].weight + point_embedding[point_labels == 1] += self.point_embeddings[1].weight + point_embedding[point_labels == 2] += self.point_embeddings[0].weight + self.special_class_embed.weight + point_embedding[point_labels == 3] += self.point_embeddings[1].weight + self.special_class_embed.weight + output_tokens = self.mask_tokens.weight + + output_tokens = output_tokens.unsqueeze(0).expand(point_embedding.size(0), -1, -1) + if class_vector is None: + tokens_all = torch.cat( + ( + output_tokens, + point_embedding, + self.supported_embed.weight.unsqueeze(0).expand(point_embedding.size(0), -1, -1), + ), + dim=1, + ) + # tokens_all = torch.cat((output_tokens, point_embedding), dim=1) + else: + class_embeddings = [] + for i in class_vector: + if i > self.last_supported: + class_embeddings.append(self.zeroshot_embed.weight) + else: + class_embeddings.append(self.supported_embed.weight) + tokens_all = torch.cat((output_tokens, point_embedding, torch.stack(class_embeddings)), dim=1) + # cross attention + masks = [] + max_prompt = self.max_prompt + for i in range(int(np.ceil(tokens_all.shape[0] / max_prompt))): + # remove variables in previous for loops to save peak memory for self.transformer + src, upscaled_embedding, hyper_in = None, None, None + torch.cuda.empty_cache() + idx = (i * max_prompt, min((i + 1) * max_prompt, tokens_all.shape[0])) + tokens = tokens_all[idx[0] : idx[1]] + src = torch.repeat_interleave(out_low, tokens.shape[0], dim=0) + pos_src = torch.repeat_interleave(self.pe_layer(out_low.shape[-3:]).unsqueeze(0), tokens.shape[0], dim=0) + b, c, h, w, d = src.shape + hs, src = self.transformer(src, pos_src, tokens) + mask_tokens_out = hs[:, :1, :] + hyper_in = self.output_hypernetworks_mlps(mask_tokens_out) + src = src.transpose(1, 2).view(b, c, h, w, d) # type: ignore + upscaled_embedding = self.output_upscaling(src) + b, c, h, w, d = upscaled_embedding.shape + mask = hyper_in @ upscaled_embedding.view(b, c, h * w * d) + masks.append(mask.view(-1, 1, h, w, d)) + + return torch.vstack(masks) + + +class ClassMappingClassify(nn.Module): + """Class head that performs automatic segmentation based on class vector.""" + + def __init__(self, n_classes: int, feature_size: int, use_mlp: bool = True): + """Args: + n_classes: maximum number of class embedding. + feature_size: class embedding size. + use_mlp: use mlp to further map class embedding. + """ + super().__init__() + self.use_mlp = use_mlp + if use_mlp: + self.mlp = nn.Sequential( + nn.Linear(feature_size, feature_size), + nn.InstanceNorm1d(1), + nn.GELU(), + nn.Linear(feature_size, feature_size), + ) + self.class_embeddings = nn.Embedding(n_classes, feature_size) + self.image_post_mapping = nn.Sequential( + UnetrBasicBlock( + spatial_dims=3, + in_channels=feature_size, + out_channels=feature_size, + kernel_size=3, + stride=1, + norm_name="instance", + res_block=True, + ), + UnetrBasicBlock( + spatial_dims=3, + in_channels=feature_size, + out_channels=feature_size, + kernel_size=3, + stride=1, + norm_name="instance", + res_block=True, + ), + ) + + def forward(self, src: torch.Tensor, class_vector: torch.Tensor): + b, c, h, w, d = src.shape + src = self.image_post_mapping(src) + class_embedding = self.class_embeddings(class_vector) + if self.use_mlp: + class_embedding = self.mlp(class_embedding) + # [b,1,feat] @ [1,feat,dim], batch dimension become class_embedding batch dimension. + masks = [] + for i in range(b): + mask = class_embedding @ src[[i]].view(1, c, h * w * d) + masks.append(mask.view(-1, 1, h, w, d)) + + return torch.cat(masks, 1), class_embedding + + +class TwoWayTransformer(nn.Module): + def __init__( + self, + depth: int, + embedding_dim: int, + num_heads: int, + mlp_dim: int, + activation: tuple | str = "relu", + attention_downsample_rate: int = 2, + ) -> None: + """ + A transformer decoder that attends to an input image using + queries whose positional embedding is supplied. + Adapted from `https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/transformer.py`. + + Args: + depth: number of layers in the transformer. + embedding_dim: the channel dimension for the input embeddings. + num_heads: the number of heads for multihead attention. Must divide embedding_dim. + mlp_dim: the channel dimension internal to the MLP block. + activation: the activation to use in the MLP block. + attention_downsample_rate: the rate at which to downsample the image before projecting. + """ + super().__init__() + self.depth = depth + self.embedding_dim = embedding_dim + self.num_heads = num_heads + self.mlp_dim = mlp_dim + self.layers = nn.ModuleList() + + for i in range(depth): + self.layers.append( + TwoWayAttentionBlock( + embedding_dim=embedding_dim, + num_heads=num_heads, + mlp_dim=mlp_dim, + activation=activation, + attention_downsample_rate=attention_downsample_rate, + skip_first_layer_pe=(i == 0), + ) + ) + + self.final_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate) + self.norm_final_attn = nn.LayerNorm(embedding_dim) + + def forward( + self, image_embedding: torch.Tensor, image_pe: torch.Tensor, point_embedding: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + image_embedding: image to attend to. Should be shape + B x embedding_dim x h x w for any h and w. + image_pe: the positional encoding to add to the image. Must + have the same shape as image_embedding. + point_embedding: the embedding to add to the query points. + Must have shape B x N_points x embedding_dim for any N_points. + + Returns: + torch.Tensor: the processed point_embedding. + torch.Tensor: the processed image_embedding. + """ + # BxCxHxW -> BxHWxC == B x N_image_tokens x C + image_embedding = image_embedding.flatten(2).permute(0, 2, 1) + image_pe = image_pe.flatten(2).permute(0, 2, 1) + + # Prepare queries + queries = point_embedding + keys = image_embedding + + # Apply transformer blocks and final layernorm + for layer in self.layers: + queries, keys = layer(queries=queries, keys=keys, query_pe=point_embedding, key_pe=image_pe) + + # Apply the final attention layer from the points to the image + q = queries + point_embedding + k = keys + image_pe + attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm_final_attn(queries) + + return queries, keys + + +class TwoWayAttentionBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + num_heads: int, + mlp_dim: int = 2048, + activation: tuple | str = "relu", + attention_downsample_rate: int = 2, + skip_first_layer_pe: bool = False, + ) -> None: + """ + A transformer block with four layers: (1) self-attention of sparse + inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp + block on sparse inputs, and (4) cross attention of dense inputs to sparse + inputs. + Adapted from `https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/transformer.py`. + + Args: + embedding_dim: the channel dimension of the embeddings. + num_heads: the number of heads in the attention layers. + mlp_dim: the hidden dimension of the mlp block. + activation: the activation of the mlp block. + skip_first_layer_pe: skip the PE on the first layer. + """ + super().__init__() + self.self_attn = Attention(embedding_dim, num_heads) + self.norm1 = nn.LayerNorm(embedding_dim) + + self.cross_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate) + self.norm2 = nn.LayerNorm(embedding_dim) + + self.mlp = MLPBlock(hidden_size=embedding_dim, mlp_dim=mlp_dim, act=activation, dropout_mode="vista3d") + self.norm3 = nn.LayerNorm(embedding_dim) + + self.norm4 = nn.LayerNorm(embedding_dim) + self.cross_attn_image_to_token = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate) + + self.skip_first_layer_pe = skip_first_layer_pe + + def forward( + self, queries: torch.Tensor, keys: torch.Tensor, query_pe: torch.Tensor, key_pe: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self attention block + if self.skip_first_layer_pe: + queries = self.self_attn(q=queries, k=queries, v=queries) + else: + q = queries + query_pe + attn_out = self.self_attn(q=q, k=q, v=queries) + queries = queries + attn_out + queries = self.norm1(queries) + + # Cross attention block, tokens attending to image embedding + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm2(queries) + + # MLP block + mlp_out = self.mlp(queries) + queries = queries + mlp_out + queries = self.norm3(queries) + + # Cross attention block, image embedding attending to tokens + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) + keys = keys + attn_out + keys = self.norm4(keys) + + return queries, keys + + +class Attention(nn.Module): + """ + An attention layer that allows for downscaling the size of the embedding + after projection to queries, keys, and values. + Adapted from `https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/transformer.py`. + + Args: + embedding_dim: the channel dimension of the embeddings. + num_heads: the number of heads in the attention layers. + downsample_rate: the rate at which to downsample the image before projecting. + """ + + def __init__(self, embedding_dim: int, num_heads: int, downsample_rate: int = 1) -> None: + super().__init__() + self.embedding_dim = embedding_dim + self.internal_dim = embedding_dim // downsample_rate + self.num_heads = num_heads + if not self.internal_dim % num_heads == 0: + raise ValueError("num_heads must divide embedding_dim.") + + self.q_proj = nn.Linear(embedding_dim, self.internal_dim) + self.k_proj = nn.Linear(embedding_dim, self.internal_dim) + self.v_proj = nn.Linear(embedding_dim, self.internal_dim) + self.out_proj = nn.Linear(self.internal_dim, embedding_dim) + + def _separate_heads(self, x: torch.Tensor, num_heads: int) -> torch.Tensor: + b, n, c = x.shape + x = x.reshape(b, n, num_heads, c // num_heads) + # B x N_heads x N_tokens x C_per_head + return x.transpose(1, 2) + + def _recombine_heads(self, x: torch.Tensor) -> torch.Tensor: + b, n_heads, n_tokens, c_per_head = x.shape + x = x.transpose(1, 2) + # B x N_tokens x C + return x.reshape(b, n_tokens, n_heads * c_per_head) + + def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + # Input projections + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + + # Separate into heads + q = self._separate_heads(q, self.num_heads) + k = self._separate_heads(k, self.num_heads) + v = self._separate_heads(v, self.num_heads) + + # Attention + _, _, _, c_per_head = q.shape + attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens + attn = attn / math.sqrt(c_per_head) + attn = torch.softmax(attn, dim=-1) + + # Get output + out = attn @ v + out = self._recombine_heads(out) + out = self.out_proj(out) + + return out + + +class PositionEmbeddingRandom(nn.Module): + """ + Positional encoding using random spatial frequencies. + Adapted from `https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/prompt_encoder.py`. + + Args: + num_pos_feats: the number of positional encoding features. + scale: the scale of the positional encoding. + """ + + def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: + super().__init__() + if scale is None or scale <= 0.0: + scale = 1.0 + self.register_buffer("positional_encoding_gaussian_matrix", scale * torch.randn((3, num_pos_feats))) + + def _pe_encoding(self, coords: torch.torch.Tensor) -> torch.torch.Tensor: + """Positionally encode points that are normalized to [0,1].""" + # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape + coords = 2 * coords - 1 + # [bs=1,N=2,2] @ [2,128] + # [bs=1, N=2, 128] + coords = coords @ self.positional_encoding_gaussian_matrix + coords = 2 * np.pi * coords + # outputs d_1 x ... x d_n x C shape + # [bs=1, N=2, 128+128=256] + return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) + + def forward(self, size: Tuple[int, int, int]) -> torch.torch.Tensor: + """Generate positional encoding for a grid of the specified size.""" + h, w, d = size + device: Any = self.positional_encoding_gaussian_matrix.device + grid = torch.ones((h, w, d), device=device, dtype=torch.float32) + x_embed = grid.cumsum(dim=0) - 0.5 + y_embed = grid.cumsum(dim=1) - 0.5 + z_embed = grid.cumsum(dim=2) - 0.5 + x_embed = x_embed / h + y_embed = y_embed / w + z_embed = z_embed / d + pe = self._pe_encoding(torch.stack([x_embed, y_embed, z_embed], dim=-1)) + # C x H x W + return pe.permute(3, 0, 1, 2) + + def forward_with_coords( + self, coords_input: torch.torch.Tensor, image_size: Tuple[int, int, int] + ) -> torch.torch.Tensor: + """Positionally encode points that are not normalized to [0,1].""" + coords = coords_input.clone() + coords[:, :, 0] = coords[:, :, 0] / image_size[0] + coords[:, :, 1] = coords[:, :, 1] / image_size[1] + coords[:, :, 2] = coords[:, :, 2] / image_size[2] + # B x N x C + return self._pe_encoding(coords.to(torch.float)) + + +class MLP(nn.Module): + """ + Multi-layer perceptron. This class is only used for `PointMappingSAM`. + Adapted from `https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/mask_decoder.py`. + + Args: + input_dim: the input dimension. + hidden_dim: the hidden dimension. + output_dim: the output dimension. + num_layers: the number of layers. + sigmoid_output: whether to apply a sigmoid activation to the output. + """ + + def __init__( + self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, sigmoid_output: bool = False + ) -> None: + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + self.sigmoid_output = sigmoid_output + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + if self.sigmoid_output: + x = F.sigmoid(x) + return x diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index e32bf6fc48..363fce91be 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -1274,10 +1274,10 @@ def convert_points_to_disc( coords = unsqueeze_left(torch.stack((coord_rows, coord_cols, coord_z), dim=0), 6) coords = coords.repeat(point.shape[0], 2, 1, 1, 1, 1) for b, n in np.ndindex(*point.shape[:2]): - point_bn = unsqueeze_right(point[b, n], 6) + point_bn = unsqueeze_right(point[b, n], 4) if point_label[b, n] > -1: channel = 0 if (point_label[b, n] == 0 or point_label[b, n] == 2) else 1 - pow_diff = torch.pow(coords[b, channel] - point_bn[b, n], 2) + pow_diff = torch.pow(coords[b, channel] - point_bn, 2) if disc: masks[b, channel] += pow_diff.sum(0) < radius**2 else: diff --git a/tests/test_segresnet_ds.py b/tests/test_segresnet_ds.py index 5372fcc8ae..eab7bac9a0 100644 --- a/tests/test_segresnet_ds.py +++ b/tests/test_segresnet_ds.py @@ -17,7 +17,7 @@ from parameterized import parameterized from monai.networks import eval_mode -from monai.networks.nets import SegResNetDS +from monai.networks.nets import SegResNetDS, SegResNetDS2 from tests.utils import SkipIfBeforePyTorchVersion, test_script_save device = "cuda" if torch.cuda.is_available() else "cpu" @@ -71,7 +71,7 @@ ] -class TestResNetDS(unittest.TestCase): +class TestSegResNetDS(unittest.TestCase): @parameterized.expand(TEST_CASE_SEGRESNET_DS) def test_shape(self, input_param, input_shape, expected_shape): @@ -80,47 +80,71 @@ def test_shape(self, input_param, input_shape, expected_shape): result = net(torch.randn(input_shape).to(device)) self.assertEqual(result.shape, expected_shape, msg=str(input_param)) + @parameterized.expand(TEST_CASE_SEGRESNET_DS) + def test_shape_ds2(self, input_param, input_shape, expected_shape): + net = SegResNetDS2(**input_param).to(device) + with eval_mode(net): + result = net(torch.randn(input_shape).to(device), with_label=False) + self.assertEqual(result[0].shape, expected_shape, msg=str(input_param)) + self.assertTrue(result[1] == []) + + result = net(torch.randn(input_shape).to(device), with_point=False) + self.assertEqual(result[1].shape, expected_shape, msg=str(input_param)) + self.assertTrue(result[0] == []) + @parameterized.expand(TEST_CASE_SEGRESNET_DS2) def test_shape2(self, input_param, input_shape, expected_shape): dsdepth = input_param.get("dsdepth", 1) - net = SegResNetDS(**input_param).to(device) - - net.train() - result = net(torch.randn(input_shape).to(device)) - if dsdepth > 1: - assert isinstance(result, list) - self.assertEqual(dsdepth, len(result)) - for i in range(dsdepth): - self.assertEqual( - result[i].shape, - expected_shape[:2] + tuple(e // (2**i) for e in expected_shape[2:]), - msg=str(input_param), - ) - else: - assert isinstance(result, torch.Tensor) - self.assertEqual(result.shape, expected_shape, msg=str(input_param)) - - net.eval() - result = net(torch.randn(input_shape).to(device)) - assert isinstance(result, torch.Tensor) - self.assertEqual(result.shape, expected_shape, msg=str(input_param)) + for net in [SegResNetDS, SegResNetDS2]: + net = net(**input_param).to(device) + net.train() + if isinstance(net, SegResNetDS2): + result = net(torch.randn(input_shape).to(device), with_label=False)[0] + else: + result = net(torch.randn(input_shape).to(device)) + if dsdepth > 1: + assert isinstance(result, list) + self.assertEqual(dsdepth, len(result)) + for i in range(dsdepth): + self.assertEqual( + result[i].shape, + expected_shape[:2] + tuple(e // (2**i) for e in expected_shape[2:]), + msg=str(input_param), + ) + else: + assert isinstance(result, torch.Tensor) + self.assertEqual(result.shape, expected_shape, msg=str(input_param)) + + if not isinstance(net, SegResNetDS2): + # eval mode of SegResNetDS2 has same output as training mode + # so only test eval mode for SegResNetDS + net.eval() + result = net(torch.randn(input_shape).to(device)) + assert isinstance(result, torch.Tensor) + self.assertEqual(result.shape, expected_shape, msg=str(input_param)) @parameterized.expand(TEST_CASE_SEGRESNET_DS3) def test_shape3(self, input_param, input_shape, expected_shapes): dsdepth = input_param.get("dsdepth", 1) - net = SegResNetDS(**input_param).to(device) - - net.train() - result = net(torch.randn(input_shape).to(device)) - assert isinstance(result, list) - self.assertEqual(dsdepth, len(result)) - for i in range(dsdepth): - self.assertEqual(result[i].shape, expected_shapes[i], msg=str(input_param)) + for net in [SegResNetDS, SegResNetDS2]: + net = net(**input_param).to(device) + net.train() + if isinstance(net, SegResNetDS2): + result = net(torch.randn(input_shape).to(device), with_point=False)[1] + else: + result = net(torch.randn(input_shape).to(device)) + assert isinstance(result, list) + self.assertEqual(dsdepth, len(result)) + for i in range(dsdepth): + self.assertEqual(result[i].shape, expected_shapes[i], msg=str(input_param)) def test_ill_arg(self): with self.assertRaises(ValueError): SegResNetDS(spatial_dims=4) + with self.assertRaises(ValueError): + SegResNetDS2(spatial_dims=4) + @SkipIfBeforePyTorchVersion((1, 10)) def test_script(self): input_param, input_shape, _ = TEST_CASE_SEGRESNET_DS[0] diff --git a/tests/test_vista3d.py b/tests/test_vista3d.py new file mode 100644 index 0000000000..d3b4e0c10e --- /dev/null +++ b/tests/test_vista3d.py @@ -0,0 +1,85 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets import VISTA3D, SegResNetDS2 +from monai.networks.nets.vista3d import ClassMappingClassify, PointMappingSAM +from tests.utils import SkipIfBeforePyTorchVersion, skip_if_quick + +device = "cuda" if torch.cuda.is_available() else "cpu" + +TEST_CASES = [ + [{"encoder_embed_dim": 48, "in_channels": 1}, {}, (1, 1, 64, 64, 64), (1, 1, 64, 64, 64)], + [{"encoder_embed_dim": 48, "in_channels": 2}, {}, (1, 2, 64, 64, 64), (1, 1, 64, 64, 64)], + [ + {"encoder_embed_dim": 48, "in_channels": 1}, + {"class_vector": torch.tensor([1, 2, 3], device=device)}, + (1, 1, 64, 64, 64), + (3, 1, 64, 64, 64), + ], + [ + {"encoder_embed_dim": 48, "in_channels": 1}, + { + "point_coords": torch.tensor([[[1, 2, 3], [1, 2, 3]]], device=device), + "point_labels": torch.tensor([[1, 0]], device=device), + }, + (1, 1, 64, 64, 64), + (1, 1, 64, 64, 64), + ], + [ + {"encoder_embed_dim": 48, "in_channels": 1}, + { + "class_vector": torch.tensor([1, 2], device=device), + "point_coords": torch.tensor([[[1, 2, 3], [1, 2, 3]], [[1, 2, 3], [1, 2, 3]]], device=device), + "point_labels": torch.tensor([[1, 0], [1, 0]], device=device), + }, + (1, 1, 64, 64, 64), + (2, 1, 64, 64, 64), + ], +] + + +@SkipIfBeforePyTorchVersion((1, 11)) +@skip_if_quick +class TestVista3d(unittest.TestCase): + + @parameterized.expand(TEST_CASES) + def test_vista3d_shape(self, args, input_params, input_shape, expected_shape): + segresnet = SegResNetDS2( + in_channels=args["in_channels"], + blocks_down=(1, 2, 2, 4, 4), + norm="instance", + out_channels=args["encoder_embed_dim"], + init_filters=args["encoder_embed_dim"], + dsdepth=1, + ) + point_head = PointMappingSAM(feature_size=args["encoder_embed_dim"], n_classes=512, last_supported=132) + class_head = ClassMappingClassify(n_classes=512, feature_size=args["encoder_embed_dim"], use_mlp=True) + net = VISTA3D(image_encoder=segresnet, class_head=class_head, point_head=point_head).to(device) + with eval_mode(net): + result = net.forward( + torch.randn(input_shape).to(device), + point_coords=input_params.get("point_coords", None), + point_labels=input_params.get("point_labels", None), + class_vector=input_params.get("class_vector", None), + ) + self.assertEqual(result.shape, expected_shape) + + +if __name__ == "__main__": + unittest.main() From 7b9a523c97ff9e7f1fad7fb2a761ce5322947500 Mon Sep 17 00:00:00 2001 From: Balamurali Date: Thu, 15 Aug 2024 01:01:03 -0700 Subject: [PATCH 17/24] NACLLoss memory management (#8020) Fixes # . ### Description Calling contiguous after applying the permute option to work with view operation in apply_filter (https://github.com/Project-MONAI/MONAI/blob/59a7211070538586369afd4a01eca0a7fe2e742e/monai/networks/layers/simplelayers.py#L293). ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Balamurali Signed-off-by: bala93 Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/losses/nacl_loss.py | 4 ++-- tests/test_nacl_loss.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/monai/losses/nacl_loss.py b/monai/losses/nacl_loss.py index 3303e89bce..27a712d308 100644 --- a/monai/losses/nacl_loss.py +++ b/monai/losses/nacl_loss.py @@ -95,11 +95,11 @@ def get_constr_target(self, mask: torch.Tensor) -> torch.Tensor: rmask: torch.Tensor if self.dim == 2: - oh_labels = F.one_hot(mask.to(torch.int64), num_classes=self.nc).contiguous().permute(0, 3, 1, 2).float() + oh_labels = F.one_hot(mask.to(torch.int64), num_classes=self.nc).permute(0, 3, 1, 2).contiguous().float() rmask = self.svls_layer(oh_labels) if self.dim == 3: - oh_labels = F.one_hot(mask.to(torch.int64), num_classes=self.nc).contiguous().permute(0, 4, 1, 2, 3).float() + oh_labels = F.one_hot(mask.to(torch.int64), num_classes=self.nc).permute(0, 4, 1, 2, 3).contiguous().float() rmask = self.svls_layer(oh_labels) return rmask diff --git a/tests/test_nacl_loss.py b/tests/test_nacl_loss.py index 51ec275cf4..704bbdb9b1 100644 --- a/tests/test_nacl_loss.py +++ b/tests/test_nacl_loss.py @@ -47,6 +47,7 @@ TEST_CASES = [ [{"classes": 3, "dim": 2}, {"inputs": inputs, "targets": targets}, 1.1442], + [{"classes": 3, "dim": 2}, {"inputs": inputs.repeat(4, 1, 1, 1), "targets": targets.repeat(4, 1, 1)}, 1.1442], [{"classes": 3, "dim": 2, "kernel_ops": "gaussian"}, {"inputs": inputs, "targets": targets}, 1.1433], [{"classes": 3, "dim": 2, "kernel_ops": "gaussian", "sigma": 0.5}, {"inputs": inputs, "targets": targets}, 1.1469], [{"classes": 3, "dim": 2, "distance_type": "l2"}, {"inputs": inputs, "targets": targets}, 1.1269], From 9f56a3a02eef613546a1a19e98e36627c961c650 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Mon, 19 Aug 2024 23:09:10 +0800 Subject: [PATCH 18/24] Move PyType test to weekly test (#8025) Fixes #8022 ### Description - Add format test to weekly test - Set pytype test as not required in each PR - Add packaging in weekly-preview pipeline ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- .github/workflows/weekly-preview.yml | 35 +++++++++++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/.github/workflows/weekly-preview.yml b/.github/workflows/weekly-preview.yml index e94e1dac5a..8d8cccffad 100644 --- a/.github/workflows/weekly-preview.yml +++ b/.github/workflows/weekly-preview.yml @@ -5,6 +5,39 @@ on: - cron: "0 2 * * 0" # 02:00 of every Sunday jobs: + flake8-py3: + runs-on: ubuntu-latest + strategy: + matrix: + opt: ["codeformat", "pytype", "mypy"] + steps: + - uses: actions/checkout@v4 + - name: Set up Python 3.9 + uses: actions/setup-python@v5 + with: + python-version: '3.9' + - name: cache weekly timestamp + id: pip-cache + run: | + echo "datew=$(date '+%Y-%V')" >> $GITHUB_OUTPUT + - name: cache for pip + uses: actions/cache@v4 + id: cache + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pip-${{ steps.pip-cache.outputs.datew }} + - name: Install dependencies + run: | + find /opt/hostedtoolcache/* -maxdepth 0 ! -name 'Python' -exec rm -rf {} \; + python -m pip install --upgrade pip wheel + python -m pip install -r requirements-dev.txt + - name: Lint and type check + run: | + # clean up temporary files + $(pwd)/runtests.sh --build --clean + # Github actions have 2 cores, so parallelize pytype + $(pwd)/runtests.sh --build --${{ matrix.opt }} -j 2 + packaging: if: github.repository == 'Project-MONAI/MONAI' runs-on: ubuntu-latest @@ -19,7 +52,7 @@ jobs: python-version: '3.9' - name: Install setuptools run: | - python -m pip install --user --upgrade setuptools wheel + python -m pip install --user --upgrade setuptools wheel packaging - name: Build distribution run: | export HEAD_COMMIT_ID=$(git rev-parse HEAD) From 3a6f6200861722df4176274f45277923d877d1ef Mon Sep 17 00:00:00 2001 From: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Date: Mon, 19 Aug 2024 17:26:44 +0100 Subject: [PATCH 19/24] Updating to match Numpy 2.0 requirements (#7857) Fixes #7856. ### Description This introduces changes to meet Numpy 2.0 requirements. MONAI itself is compatible with Numpy 2.0 however some dependencies are not such as older versions of Pytorch. This PR adjusts the MAX_SEED value to be compatible with Numpy 2.0 behaviour changes, uses the `ptp` function, and some other minor tweaks. The versions for dependencies are also fixed to exclude Numpy 2.0. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Eric Kerfoot Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- environment-dev.yml | 4 ++-- monai/data/utils.py | 2 +- monai/transforms/io/array.py | 2 +- monai/transforms/spatial/functional.py | 2 +- monai/transforms/transform.py | 4 ++-- requirements.txt | 2 +- setup.cfg | 2 +- tests/test_meta_tensor.py | 2 +- tests/test_nifti_endianness.py | 2 +- tests/test_signal_fillempty.py | 4 ++-- tests/test_signal_fillemptyd.py | 4 ++-- 11 files changed, 15 insertions(+), 15 deletions(-) diff --git a/environment-dev.yml b/environment-dev.yml index d23958baba..a4651ec7e4 100644 --- a/environment-dev.yml +++ b/environment-dev.yml @@ -5,10 +5,10 @@ channels: - nvidia - conda-forge dependencies: - - numpy>=1.20 + - numpy>=1.24,<2.0 - pytorch>=1.9 - torchvision - - pytorch-cuda=11.6 + - pytorch-cuda>=11.6 - pip - pip: - -r requirements-dev.txt diff --git a/monai/data/utils.py b/monai/data/utils.py index 7a08300abb..f35c5124d8 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -927,7 +927,7 @@ def compute_shape_offset( corners = in_affine_ @ corners all_dist = corners_out[:-1].copy() corners_out = corners_out[:-1] / corners_out[-1] - out_shape = np.round(corners_out.ptp(axis=1)) if scale_extent else np.round(corners_out.ptp(axis=1) + 1.0) + out_shape = np.round(np.ptp(corners_out, axis=1)) if scale_extent else np.round(np.ptp(corners_out, axis=1) + 1.0) offset = None for i in range(corners.shape[1]): min_corner = np.min(all_dist - all_dist[:, i : i + 1], 1) diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index e0ecc127f2..7c0e8f7123 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -86,7 +86,7 @@ def switch_endianness(data, new="<"): if new not in ("<", ">"): raise NotImplementedError(f"Not implemented option new={new}.") if current_ != new: - data = data.byteswap().newbyteorder(new) + data = data.byteswap().view(data.dtype.newbyteorder(new)) elif isinstance(data, tuple): data = tuple(switch_endianness(x, new) for x in data) elif isinstance(data, list): diff --git a/monai/transforms/spatial/functional.py b/monai/transforms/spatial/functional.py index add4e7f5ea..22726f06a5 100644 --- a/monai/transforms/spatial/functional.py +++ b/monai/transforms/spatial/functional.py @@ -373,7 +373,7 @@ def rotate(img, angle, output_shape, mode, padding_mode, align_corners, dtype, l if output_shape is None: corners = np.asarray(np.meshgrid(*[(0, dim) for dim in im_shape], indexing="ij")).reshape((len(im_shape), -1)) corners = transform[:-1, :-1] @ corners # type: ignore - output_shape = np.asarray(corners.ptp(axis=1) + 0.5, dtype=int) + output_shape = np.asarray(np.ptp(corners, axis=1) + 0.5, dtype=int) else: output_shape = np.asarray(output_shape, dtype=int) shift = create_translate(input_ndim, ((np.array(im_shape) - 1) / 2).tolist()) diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 3d09cea545..15c2499a73 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -203,8 +203,8 @@ def set_random_state(self, seed: int | None = None, state: np.random.RandomState """ if seed is not None: - _seed = id(seed) if not isinstance(seed, (int, np.integer)) else seed - _seed = _seed % MAX_SEED + _seed = np.int64(id(seed) if not isinstance(seed, (int, np.integer)) else seed) + _seed = _seed % MAX_SEED # need to account for Numpy2.0 which doesn't silently convert to int64 self.R = np.random.RandomState(_seed) return self diff --git a/requirements.txt b/requirements.txt index aae455f58c..e184322c13 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,2 @@ torch>=1.9 -numpy>=1.20,<=1.26.0 +numpy>=1.24,<2.0 diff --git a/setup.cfg b/setup.cfg index 2115c30a7f..1ce4a3f34c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -42,7 +42,7 @@ setup_requires = ninja install_requires = torch>=1.9 - numpy>=1.20 + numpy>=1.24,<2.0 [options.extras_require] all = diff --git a/tests/test_meta_tensor.py b/tests/test_meta_tensor.py index f31a07eba4..60b6019703 100644 --- a/tests/test_meta_tensor.py +++ b/tests/test_meta_tensor.py @@ -448,7 +448,7 @@ def test_shape(self): def test_astype(self): t = MetaTensor([1.0], affine=torch.tensor(1), meta={"fname": "filename"}) - for np_types in ("float32", "np.float32", "numpy.float32", np.float32, float, "int", np.compat.long, np.uint16): + for np_types in ("float32", "np.float32", "numpy.float32", np.float32, float, "int", np.uint16): self.assertIsInstance(t.astype(np_types), np.ndarray) for pt_types in ("torch.float", torch.float, "torch.float64"): self.assertIsInstance(t.astype(pt_types), torch.Tensor) diff --git a/tests/test_nifti_endianness.py b/tests/test_nifti_endianness.py index 4475d8aaab..f8531dc08f 100644 --- a/tests/test_nifti_endianness.py +++ b/tests/test_nifti_endianness.py @@ -82,7 +82,7 @@ def test_switch(self): # verify data types after = switch_endianness(before) np.testing.assert_allclose(after.astype(float), expected_float) - before = np.array(["1.12", "-9.2", "42"], dtype=np.string_) + before = np.array(["1.12", "-9.2", "42"], dtype=np.bytes_) after = switch_endianness(before) np.testing.assert_array_equal(before, after) diff --git a/tests/test_signal_fillempty.py b/tests/test_signal_fillempty.py index a3ee623cc5..2be4bd8600 100644 --- a/tests/test_signal_fillempty.py +++ b/tests/test_signal_fillempty.py @@ -30,7 +30,7 @@ class TestSignalFillEmptyNumpy(unittest.TestCase): def test_correct_parameters_multi_channels(self): self.assertIsInstance(SignalFillEmpty(replacement=0.0), SignalFillEmpty) sig = np.load(TEST_SIGNAL) - sig[:, 123] = np.NAN + sig[:, 123] = np.nan fillempty = SignalFillEmpty(replacement=0.0) fillemptysignal = fillempty(sig) self.assertTrue(not np.isnan(fillemptysignal).any()) @@ -42,7 +42,7 @@ class TestSignalFillEmptyTorch(unittest.TestCase): def test_correct_parameters_multi_channels(self): self.assertIsInstance(SignalFillEmpty(replacement=0.0), SignalFillEmpty) sig = convert_to_tensor(np.load(TEST_SIGNAL)) - sig[:, 123] = convert_to_tensor(np.NAN) + sig[:, 123] = convert_to_tensor(np.nan) fillempty = SignalFillEmpty(replacement=0.0) fillemptysignal = fillempty(sig) self.assertTrue(not torch.isnan(fillemptysignal).any()) diff --git a/tests/test_signal_fillemptyd.py b/tests/test_signal_fillemptyd.py index ee8c571ef8..7710279495 100644 --- a/tests/test_signal_fillemptyd.py +++ b/tests/test_signal_fillemptyd.py @@ -30,7 +30,7 @@ class TestSignalFillEmptyNumpy(unittest.TestCase): def test_correct_parameters_multi_channels(self): self.assertIsInstance(SignalFillEmptyd(replacement=0.0), SignalFillEmptyd) sig = np.load(TEST_SIGNAL) - sig[:, 123] = np.NAN + sig[:, 123] = np.nan data = {} data["signal"] = sig fillempty = SignalFillEmptyd(keys=("signal",), replacement=0.0) @@ -46,7 +46,7 @@ class TestSignalFillEmptyTorch(unittest.TestCase): def test_correct_parameters_multi_channels(self): self.assertIsInstance(SignalFillEmptyd(replacement=0.0), SignalFillEmptyd) sig = convert_to_tensor(np.load(TEST_SIGNAL)) - sig[:, 123] = convert_to_tensor(np.NAN) + sig[:, 123] = convert_to_tensor(np.nan) data = {} data["signal"] = sig fillempty = SignalFillEmptyd(keys=("signal",), replacement=0.0) From cea80a686b907b22d90c8024cea5a17ef9d9f58a Mon Sep 17 00:00:00 2001 From: Yiheng Wang <68361391+yiheng-wang-nv@users.noreply.github.com> Date: Tue, 20 Aug 2024 19:55:22 +0800 Subject: [PATCH 20/24] 8029 update load old weights function for diffusion_model_unet.py (#8031) Fixes #8029 . ### Description A few sentences describing the changes proposed in this pull request. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Yiheng Wang Signed-off-by: Yiheng Wang <68361391+yiheng-wang-nv@users.noreply.github.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/networks/nets/diffusion_model_unet.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/monai/networks/nets/diffusion_model_unet.py b/monai/networks/nets/diffusion_model_unet.py index f57fe251d2..65d6053acc 100644 --- a/monai/networks/nets/diffusion_model_unet.py +++ b/monai/networks/nets/diffusion_model_unet.py @@ -1837,9 +1837,26 @@ def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None: new_state_dict[k] = old_state_dict.pop(k) # fix the attention blocks - attention_blocks = [k.replace(".out_proj.weight", "") for k in new_state_dict if "out_proj.weight" in k] + attention_blocks = [k.replace(".attn.to_k.weight", "") for k in new_state_dict if "attn.to_k.weight" in k] for block in attention_blocks: + new_state_dict[f"{block}.attn.to_q.weight"] = old_state_dict.pop(f"{block}.to_q.weight") + new_state_dict[f"{block}.attn.to_k.weight"] = old_state_dict.pop(f"{block}.to_k.weight") + new_state_dict[f"{block}.attn.to_v.weight"] = old_state_dict.pop(f"{block}.to_v.weight") + new_state_dict[f"{block}.attn.to_q.bias"] = old_state_dict.pop(f"{block}.to_q.bias") + new_state_dict[f"{block}.attn.to_k.bias"] = old_state_dict.pop(f"{block}.to_k.bias") + new_state_dict[f"{block}.attn.to_v.bias"] = old_state_dict.pop(f"{block}.to_v.bias") + # projection + new_state_dict[f"{block}.attn.out_proj.weight"] = old_state_dict.pop(f"{block}.proj_attn.weight") + new_state_dict[f"{block}.attn.out_proj.bias"] = old_state_dict.pop(f"{block}.proj_attn.bias") + + # fix the cross attention blocks + cross_attention_blocks = [ + k.replace(".out_proj.weight", "") + for k in new_state_dict + if "out_proj.weight" in k and "transformer_blocks" in k + ] + for block in cross_attention_blocks: new_state_dict[f"{block}.out_proj.weight"] = old_state_dict.pop(f"{block}.to_out.0.weight") new_state_dict[f"{block}.out_proj.bias"] = old_state_dict.pop(f"{block}.to_out.0.bias") From de2a819e82e9c0575a959170d8e534fefe002d08 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Fri, 23 Aug 2024 11:16:29 +0800 Subject: [PATCH 21/24] Fix AttributeError when using torch.min and max (#8041) Fixes #8040. ### Description Only return values if got a namedtuple when using torch.min and max ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- .../transforms/utils_pytorch_numpy_unification.py | 4 ++-- tests/test_utils_pytorch_numpy_unification.py | 14 +++++++++++++- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/monai/transforms/utils_pytorch_numpy_unification.py b/monai/transforms/utils_pytorch_numpy_unification.py index 020d99af16..98b75cff76 100644 --- a/monai/transforms/utils_pytorch_numpy_unification.py +++ b/monai/transforms/utils_pytorch_numpy_unification.py @@ -480,7 +480,7 @@ def max(x: NdarrayTensor, dim: int | tuple | None = None, **kwargs) -> NdarrayTe else: ret = torch.max(x, int(dim), **kwargs) # type: ignore - return ret + return ret[0] if isinstance(ret, tuple) else ret def mean(x: NdarrayTensor, dim: int | tuple | None = None, **kwargs) -> NdarrayTensor: @@ -546,7 +546,7 @@ def min(x: NdarrayTensor, dim: int | tuple | None = None, **kwargs) -> NdarrayTe else: ret = torch.min(x, int(dim), **kwargs) # type: ignore - return ret + return ret[0] if isinstance(ret, tuple) else ret def std(x: NdarrayTensor, dim: int | tuple | None = None, unbiased: bool = False) -> NdarrayTensor: diff --git a/tests/test_utils_pytorch_numpy_unification.py b/tests/test_utils_pytorch_numpy_unification.py index 6e655289e4..90c0401e46 100644 --- a/tests/test_utils_pytorch_numpy_unification.py +++ b/tests/test_utils_pytorch_numpy_unification.py @@ -17,7 +17,7 @@ import torch from parameterized import parameterized -from monai.transforms.utils_pytorch_numpy_unification import mode, percentile +from monai.transforms.utils_pytorch_numpy_unification import max, min, mode, percentile from monai.utils import set_determinism from tests.utils import TEST_NDARRAYS, assert_allclose, skip_if_quick @@ -27,6 +27,13 @@ TEST_MODE.append([p(np.array([3.1, 4.1, 4.1, 5.1])), p(4.1), False]) TEST_MODE.append([p(np.array([3.1, 4.1, 4.1, 5.1])), p(4), True]) +TEST_MIN_MAX = [] +for p in TEST_NDARRAYS: + TEST_MIN_MAX.append([p(np.array([1, 2, 3, 4, 4, 5])), {}, min, p(1)]) + TEST_MIN_MAX.append([p(np.array([[3.1, 4.1, 4.1, 5.1], [3, 5, 4.1, 5]])), {"dim": 1}, min, p([3.1, 3])]) + TEST_MIN_MAX.append([p(np.array([1, 2, 3, 4, 4, 5])), {}, max, p(5)]) + TEST_MIN_MAX.append([p(np.array([[3.1, 4.1, 4.1, 5.1], [3, 5, 4.1, 5]])), {"dim": 1}, max, p([5.1, 5])]) + class TestPytorchNumpyUnification(unittest.TestCase): @@ -74,6 +81,11 @@ def test_mode(self, array, expected, to_long): res = mode(array, to_long=to_long) assert_allclose(res, expected) + @parameterized.expand(TEST_MIN_MAX) + def test_min_max(self, array, input_params, func, expected): + res = func(array, **input_params) + assert_allclose(res, expected, type_test=False) + if __name__ == "__main__": unittest.main() From a5fbe716378948630783deef8ee435e7e3bdc918 Mon Sep 17 00:00:00 2001 From: Han123su <107395380+Han123su@users.noreply.github.com> Date: Fri, 23 Aug 2024 12:09:23 +0800 Subject: [PATCH 22/24] Refactor Export for Model Conversion and Saving (#7934) Fixes #6375 . ### Description Changes to be made based on the [previous discussion #7835](https://github.com/Project-MONAI/MONAI/pull/7835). Modify the `_export` function to call the `saver` parameter for saving different models. Rewrite the `onnx_export` function using the updated `_export` to achieve consistency in model format conversion and saving. * Rewrite `onnx_export` to call `_export` with `convert_to_onnx` and appropriate `kwargs`. * Add a `saver: Callable` parameter to `_export`, replacing `save_net_with_metadata`. * Pass `save_net_with_metadata` function wrapped with `partial` to set parameters like `include_config_vals` and `append_timestamp`. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Han123su Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/bundle/scripts.py | 46 ++++++++++++++++++++++++++--------------- 1 file changed, 29 insertions(+), 17 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 6dd83c1f81..142a366669 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -18,6 +18,7 @@ import warnings import zipfile from collections.abc import Mapping, Sequence +from functools import partial from pathlib import Path from pydoc import locate from shutil import copyfile @@ -1254,6 +1255,7 @@ def verify_net_in_out( def _export( converter: Callable, + saver: Callable, parser: ConfigParser, net_id: str, filepath: str, @@ -1268,6 +1270,8 @@ def _export( Args: converter: a callable object that takes a torch.nn.module and kwargs as input and converts the module to another type. + saver: a callable object that accepts the converted model to save, a filepath to save to, meta values + (extracted from the parser), and a dictionary of extra JSON files (name -> contents) as input. parser: a ConfigParser of the bundle to be converted. net_id: ID name of the network component in the parser, it must be `torch.nn.Module`. filepath: filepath to export, if filename has no extension, it becomes `.ts`. @@ -1307,14 +1311,9 @@ def _export( # add .json extension to all extra files which are always encoded as JSON extra_files = {k + ".json": v for k, v in extra_files.items()} - save_net_with_metadata( - jit_obj=net, - filename_prefix_or_stream=filepath, - include_config_vals=False, - append_timestamp=False, - meta_values=parser.get().pop("_meta_", None), - more_extra_files=extra_files, - ) + meta_values = parser.get().pop("_meta_", None) + saver(net, filepath, meta_values=meta_values, more_extra_files=extra_files) + logger.info(f"exported to file: {filepath}.") @@ -1413,17 +1412,23 @@ def onnx_export( input_shape_ = _get_fake_input_shape(parser=parser) inputs_ = [torch.rand(input_shape_)] - net = parser.get_parsed_content(net_id_) - if has_ignite: - # here we use ignite Checkpoint to support nested weights and be compatible with MONAI CheckpointSaver - Checkpoint.load_objects(to_load={key_in_ckpt_: net}, checkpoint=ckpt_file_) - else: - ckpt = torch.load(ckpt_file_) - copy_model_state(dst=net, src=ckpt if key_in_ckpt_ == "" else ckpt[key_in_ckpt_]) converter_kwargs_.update({"inputs": inputs_, "use_trace": use_trace_}) - onnx_model = convert_to_onnx(model=net, **converter_kwargs_) - onnx.save(onnx_model, filepath_) + + def save_onnx(onnx_obj: Any, filename_prefix_or_stream: str, **kwargs: Any) -> None: + onnx.save(onnx_obj, filename_prefix_or_stream) + + _export( + convert_to_onnx, + save_onnx, + parser, + net_id=net_id_, + filepath=filepath_, + ckpt_file=ckpt_file_, + config_file=config_file_, + key_in_ckpt=key_in_ckpt_, + **converter_kwargs_, + ) def ckpt_export( @@ -1544,8 +1549,12 @@ def ckpt_export( converter_kwargs_.update({"inputs": inputs_, "use_trace": use_trace_}) # Use the given converter to convert a model and save with metadata, config content + + save_ts = partial(save_net_with_metadata, include_config_vals=False, append_timestamp=False) + _export( convert_to_torchscript, + save_ts, parser, net_id=net_id_, filepath=filepath_, @@ -1715,8 +1724,11 @@ def trt_export( } converter_kwargs_.update(trt_api_parameters) + save_ts = partial(save_net_with_metadata, include_config_vals=False, append_timestamp=False) + _export( convert_to_trt, + save_ts, parser, net_id=net_id_, filepath=filepath_, From 872585df178b71df88399a395083577165aeb5ac Mon Sep 17 00:00:00 2001 From: Yufan He <59374597+heyufan1995@users.noreply.github.com> Date: Mon, 26 Aug 2024 00:12:05 -0500 Subject: [PATCH 23/24] Add vista3d inferers (#8021) Fixes # . ### Description A few sentences describing the changes proposed in this pull request. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: heyufan1995 Signed-off-by: Yiheng Wang Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Yiheng Wang Co-authored-by: Yiheng Wang <68361391+yiheng-wang-nv@users.noreply.github.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- docs/source/apps.rst | 16 ++ .../maisi/utils => vista3d}/__init__.py | 0 monai/apps/vista3d/inferer.py | 177 ++++++++++++++ monai/apps/vista3d/sampler.py | 172 ++++++++++++++ monai/apps/vista3d/transforms.py | 224 ++++++++++++++++++ monai/inferers/utils.py | 1 + monai/networks/nets/vista3d.py | 43 +++- monai/transforms/utils.py | 55 ++++- tests/min_tests.py | 1 + tests/test_point_based_window_inferer.py | 77 ++++++ tests/test_vista3d_sampler.py | 100 ++++++++ tests/test_vista3d_transforms.py | 94 ++++++++ tests/test_vista3d_utils.py | 45 +++- 13 files changed, 988 insertions(+), 17 deletions(-) rename monai/apps/{generation/maisi/utils => vista3d}/__init__.py (100%) create mode 100644 monai/apps/vista3d/inferer.py create mode 100644 monai/apps/vista3d/sampler.py create mode 100644 monai/apps/vista3d/transforms.py create mode 100644 tests/test_point_based_window_inferer.py create mode 100644 tests/test_vista3d_sampler.py create mode 100644 tests/test_vista3d_transforms.py diff --git a/docs/source/apps.rst b/docs/source/apps.rst index 7fa7b9e9ff..cc4cea8c1e 100644 --- a/docs/source/apps.rst +++ b/docs/source/apps.rst @@ -248,6 +248,22 @@ FastMRIReader ~~~~~~~~~~~~~ .. autofunction:: monai.apps.reconstruction.complex_utils.complex_conj +`Vista3d` +--------- +.. automodule:: monai.apps.vista3d.inferer +.. autofunction:: point_based_window_inferer + +.. automodule:: monai.apps.vista3d.transforms +.. autoclass:: VistaPreTransformd + :members: +.. autoclass:: VistaPostTransformd + :members: +.. autoclass:: Relabeld + :members: + +.. automodule:: monai.apps.vista3d.sampler +.. autofunction:: sample_prompt_pairs + `Auto3DSeg` ----------- .. automodule:: monai.apps.auto3dseg diff --git a/monai/apps/generation/maisi/utils/__init__.py b/monai/apps/vista3d/__init__.py similarity index 100% rename from monai/apps/generation/maisi/utils/__init__.py rename to monai/apps/vista3d/__init__.py diff --git a/monai/apps/vista3d/inferer.py b/monai/apps/vista3d/inferer.py new file mode 100644 index 0000000000..709f81f624 --- /dev/null +++ b/monai/apps/vista3d/inferer.py @@ -0,0 +1,177 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import copy +from collections.abc import Sequence +from typing import Any + +import torch + +from monai.data.meta_tensor import MetaTensor +from monai.utils import optional_import + +tqdm, _ = optional_import("tqdm", name="tqdm") + +__all__ = ["point_based_window_inferer"] + + +def point_based_window_inferer( + inputs: torch.Tensor | MetaTensor, + roi_size: Sequence[int], + predictor: torch.nn.Module, + point_coords: torch.Tensor, + point_labels: torch.Tensor, + class_vector: torch.Tensor | None = None, + prompt_class: torch.Tensor | None = None, + prev_mask: torch.Tensor | MetaTensor | None = None, + point_start: int = 0, + center_only: bool = True, + margin: int = 5, + **kwargs: Any, +) -> torch.Tensor: + """ + Point-based window inferer that takes an input image, a set of points, and a model, and returns a segmented image. + The inferer algorithm crops the input image into patches that centered at the point sets, which is followed by + patch inference and average output stitching, and finally returns the segmented mask. + + Args: + inputs: [1CHWD], input image to be processed. + roi_size: the spatial window size for inferences. + When its components have None or non-positives, the corresponding inputs dimension will be used. + if the components of the `roi_size` are non-positive values, the transform will use the + corresponding components of img size. For example, `roi_size=(32, -1)` will be adapted + to `(32, 64)` if the second spatial dimension size of img is `64`. + sw_batch_size: the batch size to run window slices. + predictor: the model. For vista3D, the output is [B, 1, H, W, D] which needs to be transposed to [1, B, H, W, D]. + Add transpose=True in kwargs for vista3d. + point_coords: [B, N, 3]. Point coordinates for B foreground objects, each has N points. + point_labels: [B, N]. Point labels. 0/1 means negative/positive points for regular supported or zero-shot classes. + 2/3 means negative/positive points for special supported classes (e.g. tumor, vessel). + class_vector: [B]. Used for class-head automatic segmentation. Can be None value. + prompt_class: [B]. The same as class_vector representing the point class and inform point head about + supported class or zeroshot, not used for automatic segmentation. If None, point head is default + to supported class segmentation. + prev_mask: [1, B, H, W, D]. The value is before sigmoid. An optional tensor of previously segmented masks. + point_start: only use points starting from this number. All points before this number is used to generate + prev_mask. This is used to avoid re-calculating the points in previous iterations if given prev_mask. + center_only: for each point, only crop the patch centered at this point. If false, crop 3 patches for each point. + margin: if center_only is false, this value is the distance between point to the patch boundary. + Returns: + stitched_output: [1, B, H, W, D]. The value is before sigmoid. + Notice: The function only supports SINGLE OBJECT INFERENCE with B=1. + """ + if not point_coords.shape[0] == 1: + raise ValueError("Only supports single object point click.") + if not len(inputs.shape) == 5: + raise ValueError("Input image should be 5D.") + image, pad = _pad_previous_mask(copy.deepcopy(inputs), roi_size) + point_coords = point_coords + torch.tensor([pad[-2], pad[-4], pad[-6]]).to(point_coords.device) + prev_mask = _pad_previous_mask(copy.deepcopy(prev_mask), roi_size)[0] if prev_mask is not None else None + stitched_output = None + for p in point_coords[0][point_start:]: + lx_, rx_ = _get_window_idx(p[0], roi_size[0], image.shape[-3], center_only=center_only, margin=margin) + ly_, ry_ = _get_window_idx(p[1], roi_size[1], image.shape[-2], center_only=center_only, margin=margin) + lz_, rz_ = _get_window_idx(p[2], roi_size[2], image.shape[-1], center_only=center_only, margin=margin) + for i in range(len(lx_)): + for j in range(len(ly_)): + for k in range(len(lz_)): + lx, rx, ly, ry, lz, rz = (lx_[i], rx_[i], ly_[j], ry_[j], lz_[k], rz_[k]) + unravel_slice = [ + slice(None), + slice(None), + slice(int(lx), int(rx)), + slice(int(ly), int(ry)), + slice(int(lz), int(rz)), + ] + batch_image = image[unravel_slice] + output = predictor( + batch_image, + point_coords=point_coords, + point_labels=point_labels, + class_vector=class_vector, + prompt_class=prompt_class, + patch_coords=unravel_slice, + prev_mask=prev_mask, + **kwargs, + ) + if stitched_output is None: + stitched_output = torch.zeros( + [1, output.shape[1], image.shape[-3], image.shape[-2], image.shape[-1]], device="cpu" + ) + stitched_mask = torch.zeros( + [1, output.shape[1], image.shape[-3], image.shape[-2], image.shape[-1]], device="cpu" + ) + stitched_output[unravel_slice] += output.to("cpu") + stitched_mask[unravel_slice] = 1 + # if stitched_mask is 0, then NaN value + stitched_output = stitched_output / stitched_mask + # revert padding + stitched_output = stitched_output[ + :, :, pad[4] : image.shape[-3] - pad[5], pad[2] : image.shape[-2] - pad[3], pad[0] : image.shape[-1] - pad[1] + ] + stitched_mask = stitched_mask[ + :, :, pad[4] : image.shape[-3] - pad[5], pad[2] : image.shape[-2] - pad[3], pad[0] : image.shape[-1] - pad[1] + ] + if prev_mask is not None: + prev_mask = prev_mask[ + :, + :, + pad[4] : image.shape[-3] - pad[5], + pad[2] : image.shape[-2] - pad[3], + pad[0] : image.shape[-1] - pad[1], + ] + prev_mask = prev_mask.to("cpu") # type: ignore + # for un-calculated place, use previous mask + stitched_output[stitched_mask < 1] = prev_mask[stitched_mask < 1] + if isinstance(inputs, torch.Tensor): + inputs = MetaTensor(inputs) + if not hasattr(stitched_output, "meta"): + stitched_output = MetaTensor(stitched_output, affine=inputs.meta["affine"], meta=inputs.meta) + return stitched_output + + +def _get_window_idx_c(p: int, roi: int, s: int) -> tuple[int, int]: + """Helper function to get the window index.""" + if p - roi // 2 < 0: + left, right = 0, roi + elif p + roi // 2 > s: + left, right = s - roi, s + else: + left, right = int(p) - roi // 2, int(p) + roi // 2 + return left, right + + +def _get_window_idx(p: int, roi: int, s: int, center_only: bool = True, margin: int = 5) -> tuple[list[int], list[int]]: + """Get the window index.""" + left, right = _get_window_idx_c(p, roi, s) + if center_only: + return [left], [right] + left_most = max(0, p - roi + margin) + right_most = min(s, p + roi - margin) + left_list = [left_most, right_most - roi, left] + right_list = [left_most + roi, right_most, right] + return left_list, right_list + + +def _pad_previous_mask( + inputs: torch.Tensor | MetaTensor, roi_size: Sequence[int], padvalue: int = 0 +) -> tuple[torch.Tensor | MetaTensor, list[int]]: + """Helper function to pad inputs.""" + pad_size = [] + for k in range(len(inputs.shape) - 1, 1, -1): + diff = max(roi_size[k - 2] - inputs.shape[k], 0) + half = diff // 2 + pad_size.extend([half, diff - half]) + if any(pad_size): + inputs = torch.nn.functional.pad(inputs, pad=pad_size, mode="constant", value=padvalue) # type: ignore + return inputs, pad_size diff --git a/monai/apps/vista3d/sampler.py b/monai/apps/vista3d/sampler.py new file mode 100644 index 0000000000..b7aeb89a2e --- /dev/null +++ b/monai/apps/vista3d/sampler.py @@ -0,0 +1,172 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import copy +import random +from collections.abc import Callable, Sequence +from typing import Any + +import numpy as np +import torch +from torch import Tensor + +__all__ = ["sample_prompt_pairs"] + +ENABLE_SPECIAL = True +SPECIAL_INDEX = (23, 24, 25, 26, 27, 57, 128) +MERGE_LIST = { + 1: [25, 26], # hepatic tumor and vessel merge into liver + 4: [24], # pancreatic tumor merge into pancreas + 132: [57], # overlap with trachea merge into airway +} + + +def _get_point_label(id: int) -> tuple[int, int]: + if id in SPECIAL_INDEX and ENABLE_SPECIAL: + return 2, 3 + else: + return 0, 1 + + +def sample_prompt_pairs( + labels: Tensor, + label_set: Sequence[int], + max_prompt: int | None = None, + max_foreprompt: int | None = None, + max_backprompt: int = 1, + max_point: int = 20, + include_background: bool = False, + drop_label_prob: float = 0.2, + drop_point_prob: float = 0.2, + point_sampler: Callable | None = None, + **point_sampler_kwargs: Any, +) -> tuple[Tensor | None, Tensor | None, Tensor | None, Tensor | None]: + """ + Sample training pairs for VISTA3D training. + + Args: + labels: [1, 1, H, W, D], ground truth labels. + label_set: the label list for the specific dataset. Note if 0 is included in label_set, + it will be added into automatic branch training. Recommend removing 0 from label_set + for multi-partially-labeled-dataset training, and adding 0 for finetuning specific dataset. + The reason is region with 0 in one partially labeled dataset may contain foregrounds in + another dataset. + max_prompt: int, max number of total prompt, including foreground and background. + max_foreprompt: int, max number of prompt from foreground. + max_backprompt: int, max number of prompt from background. + max_point: maximum number of points for each object. + include_background: if include 0 into training prompt. If included, background 0 is treated + the same as foreground. Always be False for multi-partial-dataset training. If needed, + can be true for finetuning specific dataset, . + drop_label_prob: probability to drop label prompt. + drop_point_prob: probability to drop point prompt. + point_sampler: sampler to augment masks with supervoxel. + point_sampler_kwargs: arguments for point_sampler. + + Returns: + label_prompt: [B, 1]. The classes used for training automatic segmentation. + point: [B, N, 3]. The corresponding points for each class. + Note that background label prompt requires matching point as well ([0,0,0] is used). + point_label: [B, N]. The corresponding point labels for each point (negative or positive). + -1 is used for padding the background label prompt and will be ignored. + prompt_class: [B, 1], exactly the same with label_prompt for label indexing for training loss. + label_prompt can be None, and prompt_class is used to identify point classes. + """ + # class label number + if not labels.shape[0] == 1: + raise ValueError("only support batch size 1") + labels = labels[0, 0] + device = labels.device + unique_labels = labels.unique().cpu().numpy().tolist() + if include_background: + unique_labels = list(set(unique_labels) - (set(unique_labels) - set(label_set))) + else: + unique_labels = list(set(unique_labels) - (set(unique_labels) - set(label_set)) - {0}) + background_labels = list(set(label_set) - set(unique_labels)) + # during training, balance background and foreground prompts + if max_backprompt is not None: + if len(background_labels) > max_backprompt: + random.shuffle(background_labels) + background_labels = background_labels[:max_backprompt] + + if max_foreprompt is not None: + if len(unique_labels) > max_foreprompt: + random.shuffle(unique_labels) + unique_labels = unique_labels[:max_foreprompt] + + if max_prompt is not None: + if len(unique_labels) + len(background_labels) > max_prompt: + if len(unique_labels) > max_prompt: + unique_labels = random.sample(unique_labels, max_prompt) + background_labels = [] + else: + background_labels = random.sample(background_labels, max_prompt - len(unique_labels)) + _point = [] + _point_label = [] + # if use regular sampling + if point_sampler is None: + num_p = min(max_point, int(np.abs(random.gauss(mu=0, sigma=max_point // 2))) + 1) + num_n = min(max_point, int(np.abs(random.gauss(mu=0, sigma=max_point // 2)))) + for id in unique_labels: + neg_id, pos_id = _get_point_label(id) + plabels = labels == int(id) + nlabels = ~plabels + plabelpoints = torch.nonzero(plabels) + nlabelpoints = torch.nonzero(nlabels) + # final sampled positive points + num_pa = min(len(plabelpoints), num_p) + # final sampled negative points + num_na = min(len(nlabelpoints), num_n) + _point.append( + torch.stack( + random.choices(plabelpoints, k=num_pa) + + random.choices(nlabelpoints, k=num_na) + + [torch.tensor([0, 0, 0], device=device)] * (num_p + num_n - num_pa - num_na) + ) + ) + _point_label.append( + torch.tensor([pos_id] * num_pa + [neg_id] * num_na + [-1] * (num_p + num_n - num_pa - num_na)).to( + device + ) + ) + for _ in background_labels: + # pad the background labels + _point.append(torch.zeros(num_p + num_n, 3).to(device)) # all 0 + _point_label.append(torch.zeros(num_p + num_n).to(device) - 1) # -1 not a point + else: + _point, _point_label = point_sampler(unique_labels, **point_sampler_kwargs) + for _ in background_labels: + # pad the background labels + _point.append(torch.zeros(len(_point_label[0]), 3).to(device)) # all 0 + _point_label.append(torch.zeros(len(_point_label[0])).to(device) - 1) # -1 not a point + if len(unique_labels) == 0 and len(background_labels) == 0: + # if max_backprompt is 0 and len(unique_labels), there is no effective prompt and the iteration must + # be skipped. Handle this in trainer. + label_prompt, point, point_label, prompt_class = None, None, None, None + else: + label_prompt = torch.tensor(unique_labels + background_labels).unsqueeze(-1).to(device).long() + point = torch.stack(_point) + point_label = torch.stack(_point_label) + prompt_class = copy.deepcopy(label_prompt) + if random.uniform(0, 1) < drop_label_prob and len(unique_labels) > 0: + label_prompt = None + # If label prompt is dropped, there is no need to pad with points with label -1. + pad = len(background_labels) + point = point[: len(point) - pad] # type: ignore + point_label = point_label[: len(point_label) - pad] + prompt_class = prompt_class[: len(prompt_class) - pad] + else: + if random.uniform(0, 1) < drop_point_prob: + point = None + point_label = None + return label_prompt, point, point_label, prompt_class diff --git a/monai/apps/vista3d/transforms.py b/monai/apps/vista3d/transforms.py new file mode 100644 index 0000000000..3e8145cd80 --- /dev/null +++ b/monai/apps/vista3d/transforms.py @@ -0,0 +1,224 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import warnings +from typing import Sequence + +import numpy as np +import torch + +from monai.config import DtypeLike, KeysCollection +from monai.transforms import MapLabelValue +from monai.transforms.transform import MapTransform +from monai.transforms.utils import keep_components_with_positive_points +from monai.utils import look_up_option + +__all__ = ["VistaPreTransformd", "VistaPostTransformd", "Relabeld"] + + +def _get_name_to_index_mapping(labels_dict: dict | None) -> dict: + """get the label name to index mapping""" + name_to_index_mapping = {} + if labels_dict is not None: + name_to_index_mapping = {v.lower(): int(k) for k, v in labels_dict.items()} + return name_to_index_mapping + + +def _convert_name_to_index(name_to_index_mapping: dict, label_prompt: list | None) -> list | None: + """convert the label name to index""" + if label_prompt is not None and isinstance(label_prompt, list): + converted_label_prompt = [] + # for new class, add to the mapping + for l in label_prompt: + if isinstance(l, str) and not l.isdigit(): + if l.lower() not in name_to_index_mapping: + name_to_index_mapping[l.lower()] = len(name_to_index_mapping) + for l in label_prompt: + if isinstance(l, (int, str)): + converted_label_prompt.append( + name_to_index_mapping.get(l.lower(), int(l) if l.isdigit() else 0) if isinstance(l, str) else int(l) + ) + else: + converted_label_prompt.append(l) + return converted_label_prompt + return label_prompt + + +class VistaPreTransformd(MapTransform): + def __init__( + self, + keys: KeysCollection, + allow_missing_keys: bool = False, + special_index: Sequence[int] = (25, 26, 27, 28, 29, 117), + labels_dict: dict | None = None, + subclass: dict | None = None, + ) -> None: + """ + Pre-transform for Vista3d. + + It performs two functionalities: + + 1. If label prompt shows the points belong to special class (defined by special index, e.g. tumors, vessels), + convert point labels from 0 (negative), 1 (positive) to special 2 (negative), 3 (positive). + + 2. If label prompt is within the keys in subclass, convert the label prompt to its subclasses defined by subclass[key]. + e.g. "lung" label is converted to ["left lung", "right lung"]. + + The `label_prompt` is a list of int values of length [B] and `point_labels` is a list of length B, + where each element is an int value of length [B, N]. + + Args: + keys: keys of the corresponding items to be transformed. + special_index: the index that defines the special class. + subclass: a dictionary that maps a label prompt to its subclasses. + allow_missing_keys: don't raise exception if key is missing. + """ + super().__init__(keys, allow_missing_keys) + self.special_index = special_index + self.subclass = subclass + self.name_to_index_mapping = _get_name_to_index_mapping(labels_dict) + + def __call__(self, data): + label_prompt = data.get("label_prompt", None) + point_labels = data.get("point_labels", None) + # convert the label name to index if needed + label_prompt = _convert_name_to_index(self.name_to_index_mapping, label_prompt) + try: + # The evaluator will check prompt. The invalid prompt will be skipped here and captured by evaluator. + if self.subclass is not None and label_prompt is not None: + _label_prompt = [] + subclass_keys = list(map(int, self.subclass.keys())) + for i in range(len(label_prompt)): + if label_prompt[i] in subclass_keys: + _label_prompt.extend(self.subclass[str(label_prompt[i])]) + else: + _label_prompt.append(label_prompt[i]) + data["label_prompt"] = _label_prompt + if label_prompt is not None and point_labels is not None: + if label_prompt[0] in self.special_index: + point_labels = np.array(point_labels) + point_labels[point_labels == 0] = 2 + point_labels[point_labels == 1] = 3 + point_labels = point_labels.tolist() + data["point_labels"] = point_labels + except Exception: + # There is specific requirements for `label_prompt` and `point_labels`. + # If B > 1 or `label_prompt` is in subclass_keys, `point_labels` must be None. + # Those formatting errors should be captured later. + warnings.warn("VistaPreTransformd failed to transform label prompt or point labels.") + + return data + + +class VistaPostTransformd(MapTransform): + def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None: + """ + Post-transform for Vista3d. It converts the model output logits into final segmentation masks. + If `label_prompt` is None, the output will be thresholded to be sequential indexes [0,1,2,...], + else the indexes will be [0, label_prompt[0], label_prompt[1], ...]. + If `label_prompt` is None while `points` are provided, the model will perform postprocess to remove + regions that does not contain positive points. + + Args: + keys: keys of the corresponding items to be transformed. + dataset_transforms: a dictionary specifies the transform for corresponding dataset: + key: dataset name, value: list of data transforms. + dataset_key: key to get the dataset name from the data dictionary, default to "dataset_name". + allow_missing_keys: don't raise exception if key is missing. + + """ + super().__init__(keys, allow_missing_keys) + + def __call__(self, data): + """data["label_prompt"] should not contain 0""" + for keys in self.keys: + if keys in data: + pred = data[keys] + object_num = pred.shape[0] + device = pred.device + if data.get("label_prompt", None) is None and data.get("points", None) is not None: + pred = keep_components_with_positive_points( + pred.unsqueeze(0), + point_coords=data.get("points").to(device), + point_labels=data.get("point_labels").to(device), + )[0] + pred[pred < 0] = 0.0 + # if it's multichannel, perform argmax + if object_num > 1: + # concate background channel. Make sure user did not provide 0 as prompt. + is_bk = torch.all(pred <= 0, dim=0, keepdim=True) + pred = pred.argmax(0).unsqueeze(0).float() + 1.0 + pred[is_bk] = 0.0 + else: + # AsDiscrete will remove NaN + # pred = monai.transforms.AsDiscrete(threshold=0.5)(pred) + pred[pred > 0] = 1.0 + if "label_prompt" in data and data["label_prompt"] is not None: + pred += 0.5 # inplace mapping to avoid cloning pred + label_prompt = data["label_prompt"].to(device) # Ensure label_prompt is on the same device + for i in range(1, object_num + 1): + frac = i + 0.5 + pred[pred == frac] = label_prompt[i - 1].to(pred.dtype) + pred[pred == 0.5] = 0.0 + data[keys] = pred + return data + + +class Relabeld(MapTransform): + def __init__( + self, + keys: KeysCollection, + label_mappings: dict[str, list[tuple[int, int]]], + dtype: DtypeLike = np.int16, + dataset_key: str = "dataset_name", + allow_missing_keys: bool = False, + ) -> None: + """ + Remap the voxel labels in the input data dictionary based on the specified mapping. + + This list of local -> global label mappings will be applied to each input `data[keys]`. + if `data[dataset_key]` is not in `label_mappings`, label_mappings['default']` will be used. + if `label_mappings[data[dataset_key]]` is None, no relabeling will be performed. + + Args: + keys: keys of the corresponding items to be transformed. + label_mappings: a dictionary specifies how local dataset class indices are mapped to the + global class indices. The dictionary keys are dataset names and the values are lists of + list of (local label, global label) pairs. This list of local -> global label mappings + will be applied to each input `data[keys]`. If `data[dataset_key]` is not in `label_mappings`, + label_mappings['default']` will be used. if `label_mappings[data[dataset_key]]` is None, + no relabeling will be performed. Please set `label_mappings={}` to completely skip this transform. + dtype: convert the output data to dtype, default to float32. + dataset_key: key to get the dataset name from the data dictionary, default to "dataset_name". + allow_missing_keys: don't raise exception if key is missing. + + """ + super().__init__(keys, allow_missing_keys) + self.mappers = {} + self.dataset_key = dataset_key + for name, mapping in label_mappings.items(): + self.mappers[name] = MapLabelValue( + orig_labels=[int(pair[0]) for pair in mapping], + target_labels=[int(pair[1]) for pair in mapping], + dtype=dtype, + ) + + def __call__(self, data): + d = dict(data) + dataset_name = d.get(self.dataset_key, "default") + _m = look_up_option(dataset_name, self.mappers, default=None) + if _m is None: + return d + for key in self.key_iterator(d): + d[key] = _m(d[key]) + return d diff --git a/monai/inferers/utils.py b/monai/inferers/utils.py index a080284e7c..bd99765348 100644 --- a/monai/inferers/utils.py +++ b/monai/inferers/utils.py @@ -300,6 +300,7 @@ def sliding_window_inference( # remove padding if image_size smaller than roi_size if any(pad_size): + kwargs.update({"pad_size": pad_size}) for ss, output_i in enumerate(output_image_list): zoom_scale = [_shape_d / _roi_size_d for _shape_d, _roi_size_d in zip(output_i.shape[2:], roi_size)] final_slicing: list[slice] = [] diff --git a/monai/networks/nets/vista3d.py b/monai/networks/nets/vista3d.py index fe7f93d493..9148e36542 100644 --- a/monai/networks/nets/vista3d.py +++ b/monai/networks/nets/vista3d.py @@ -23,7 +23,7 @@ from monai.networks.blocks import MLPBlock, UnetrBasicBlock from monai.networks.nets import SegResNetDS2 from monai.transforms.utils import convert_points_to_disc -from monai.transforms.utils import get_largest_connected_component_mask_point as lcc +from monai.transforms.utils import keep_merge_components_with_points as lcc from monai.transforms.utils import sample_points_from_label from monai.utils import optional_import, unsqueeze_left, unsqueeze_right @@ -78,6 +78,35 @@ def __init__(self, image_encoder: nn.Module, class_head: nn.Module, point_head: self.NINF_VALUE = -9999 self.PINF_VALUE = 9999 + def update_slidingwindow_padding( + self, + pad_size: list | None, + labels: torch.Tensor | None, + prev_mask: torch.Tensor | None, + point_coords: torch.Tensor | None, + ): + """ + Image has been padded by sliding window inferer. + The related padding need to be performed outside of slidingwindow inferer. + + Args: + pad_size: padding size passed from sliding window inferer. + labels: image label ground truth. + prev_mask: previous segmentation mask. + point_coords: point click coordinates. + """ + if pad_size is None: + return labels, prev_mask, point_coords + if labels is not None: + labels = F.pad(labels, pad=pad_size, mode="constant", value=0) + if prev_mask is not None: + prev_mask = F.pad(prev_mask, pad=pad_size, mode="constant", value=0) + if point_coords is not None: + point_coords = point_coords + torch.tensor( + [pad_size[-2], pad_size[-4], pad_size[-6]], device=point_coords.device + ) + return labels, prev_mask, point_coords + def get_foreground_class_count(self, class_vector: torch.Tensor | None, point_coords: torch.Tensor | None) -> int: """Get number of foreground classes based on class and point prompt.""" if class_vector is None: @@ -317,6 +346,7 @@ def forward( prev_mask: torch.Tensor | None = None, radius: int | None = None, val_point_sampler: Callable | None = None, + transpose: bool = False, **kwargs, ): """ @@ -329,7 +359,7 @@ def forward( point_coords: [B, N, 3] point_labels: [B, N], -1 represents padding. 0/1 means negative/positive points for regular class. 2/3 means negative/postive ponits for special supported class like tumor. - class_vector: [B, 1], the global class index + class_vector: [B, 1], the global class index. prompt_class: [B, 1], the global class index. This value is associated with point_coords to identify if the points are for zero-shot or supported class. When class_vector and point_coords are both provided, prompt_class is the same as class_vector. For prompt_class[b] > 512, point_coords[b] @@ -346,8 +376,12 @@ def forward( radius: single float value controling the gaussian blur when combining point and auto results. The gaussian combine is not used in VISTA3D training but might be useful for finetuning purposes. val_point_sampler: function used to sample points from labels. This is only used for point-only evaluation. - + transpose: bool. If true, the output will be transposed to be [1, B, H, W, D]. Required to be true if calling from + sliding window inferer/point inferer. """ + labels, prev_mask, point_coords = self.update_slidingwindow_padding( + kwargs.get("pad_size", None), labels, prev_mask, point_coords + ) image_size = input_images.shape[-3:] device = input_images.device if point_coords is None and class_vector is None: @@ -424,9 +458,10 @@ def forward( point_labels, # type: ignore mapping_index, ) - if kwargs.get("keep_cache", False) and class_vector is None: self.image_embeddings = out.detach() + if transpose: + logits = logits.transpose(1, 0) return logits diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 363fce91be..7027c07d67 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -107,7 +107,8 @@ "generate_spatial_bounding_box", "get_extreme_points", "get_largest_connected_component_mask", - "get_largest_connected_component_mask_point", + "keep_merge_components_with_points", + "keep_components_with_positive_points", "convert_points_to_disc", "remove_small_objects", "img_bounds", @@ -1178,7 +1179,7 @@ def get_largest_connected_component_mask( return convert_to_dst_type(out, dst=img, dtype=out.dtype)[0] -def get_largest_connected_component_mask_point( +def keep_merge_components_with_points( img_pos: NdarrayTensor, img_neg: NdarrayTensor, point_coords: NdarrayTensor, @@ -1188,8 +1189,8 @@ def get_largest_connected_component_mask_point( margins: int = 3, ) -> NdarrayTensor: """ - Gets the connected component of img_pos and img_neg that include the positive points and - negative points separately. The function is used for combining automatic results with interactive + Keep connected regions of img_pos and img_neg that include the positive points and + negative points separately. The function is used for merging automatic results with interactive results in VISTA3D. Args: @@ -1199,6 +1200,7 @@ def get_largest_connected_component_mask_point( neg_val: negative point label values. point_coords: the coordinates of each point, shape [B, N, 3], where N means the number of points. point_labels: the label of each point, shape [B, N]. + margins: include points outside of the region but within the margin. """ cucim_skimage, has_cucim = optional_import("cucim.skimage") @@ -1249,6 +1251,49 @@ def get_largest_connected_component_mask_point( return convert_to_dst_type(outs, dst=img_pos, dtype=outs.dtype)[0] +def keep_components_with_positive_points( + img: torch.Tensor, point_coords: torch.Tensor, point_labels: torch.Tensor +) -> torch.Tensor: + """ + Keep connected regions that include the positive points. Used for point-only inference postprocessing to remove + regions without positive points. + Args: + img: [1, B, H, W, D]. Output prediction from VISTA3D. Value is before sigmoid and contain NaN value. + point_coords: [B, N, 3]. Point click coordinates + point_labels: [B, N]. Point click labels. + """ + if not has_measure: + raise RuntimeError("skimage.measure required.") + outs = torch.zeros_like(img) + for c in range(len(point_coords)): + if not ((point_labels[c] == 3).any() or (point_labels[c] == 1).any()): + # skip if no positive points. + continue + coords = point_coords[c, point_labels[c] == 3].tolist() + point_coords[c, point_labels[c] == 1].tolist() + not_nan_mask = ~torch.isnan(img[0, c]) + img_ = torch.nan_to_num(img[0, c] > 0, 0) + img_, *_ = convert_data_type(img_, np.ndarray) # type: ignore + label = measure.label + features = label(img_, connectivity=3) + pos_mask = torch.from_numpy(img_).to(img.device) > 0 + # if num features less than max desired, nothing to do. + features = torch.from_numpy(features).to(img.device) + # generate a map with all pos points + idx = [] + for p in coords: + idx.append(features[round(p[0]), round(p[1]), round(p[2])].item()) + idx = list(set(idx)) + for i in idx: + if i == 0: + continue + outs[0, c] += features == i + outs = outs > 0 + # find negative mean value + fill_in = img[0, c][torch.logical_and(~outs[0, c], not_nan_mask)].mean() + img[0, c][torch.logical_and(pos_mask, ~outs[0, c])] = fill_in + return img + + def convert_points_to_disc( image_size: Sequence[int], point: Tensor, point_label: Tensor, radius: int = 2, disc: bool = False ): @@ -1269,7 +1314,7 @@ def convert_points_to_disc( _array = [ torch.arange(start=0, end=image_size[i], step=1, dtype=torch.float32, device=point.device) for i in range(3) ] - coord_rows, coord_cols, coord_z = torch.meshgrid(_array[2], _array[1], _array[0]) + coord_rows, coord_cols, coord_z = torch.meshgrid(_array[0], _array[1], _array[2]) # [1, 3, h, w, d] -> [b, 2, 3, h, w, d] coords = unsqueeze_left(torch.stack((coord_rows, coord_cols, coord_z), dim=0), 6) coords = coords.repeat(point.shape[0], 2, 1, 1, 1, 1) diff --git a/tests/min_tests.py b/tests/min_tests.py index 479c4c8dc2..f80d06f5d3 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -210,6 +210,7 @@ def run_testsuit(): "test_perceptual_loss", "test_ultrasound_confidence_map_transform", "test_vista3d_utils", + "test_vista3d_transforms", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" diff --git a/tests/test_point_based_window_inferer.py b/tests/test_point_based_window_inferer.py new file mode 100644 index 0000000000..1b293288c4 --- /dev/null +++ b/tests/test_point_based_window_inferer.py @@ -0,0 +1,77 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.apps.vista3d.inferer import point_based_window_inferer +from monai.networks import eval_mode +from monai.networks.nets.vista3d import vista3d132 +from monai.utils import optional_import +from tests.utils import SkipIfBeforePyTorchVersion, skip_if_quick + +device = "cuda" if torch.cuda.is_available() else "cpu" + +_, has_tqdm = optional_import("tqdm") + +TEST_CASES = [ + [ + {"encoder_embed_dim": 48, "in_channels": 1}, + (1, 1, 64, 64, 64), + { + "roi_size": [32, 32, 32], + "point_coords": torch.tensor([[[1, 2, 3], [1, 2, 3]]], device=device), + "point_labels": torch.tensor([[1, 0]], device=device), + }, + ], + [ + {"encoder_embed_dim": 48, "in_channels": 1}, + (1, 1, 64, 64, 64), + { + "roi_size": [32, 32, 32], + "point_coords": torch.tensor([[[1, 2, 3], [1, 2, 3]]], device=device), + "point_labels": torch.tensor([[1, 0]], device=device), + "class_vector": torch.tensor([1], device=device), + }, + ], + [ + {"encoder_embed_dim": 48, "in_channels": 1}, + (1, 1, 64, 64, 64), + { + "roi_size": [32, 32, 32], + "point_coords": torch.tensor([[[1, 2, 3], [1, 2, 3]]], device=device), + "point_labels": torch.tensor([[1, 0]], device=device), + "class_vector": torch.tensor([1], device=device), + "point_start": 1, + }, + ], +] + + +@SkipIfBeforePyTorchVersion((1, 11)) +@skip_if_quick +class TestPointBasedWindowInferer(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_vista3d(self, vista3d_params, inputs_shape, inferer_params): + vista3d = vista3d132(**vista3d_params).to(device) + with eval_mode(vista3d): + inferer_params["predictor"] = vista3d + inferer_params["inputs"] = torch.randn(*inputs_shape).to(device) + stitched_output = point_based_window_inferer(**inferer_params) + self.assertEqual(stitched_output.shape, inputs_shape) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_vista3d_sampler.py b/tests/test_vista3d_sampler.py new file mode 100644 index 0000000000..6945d250d2 --- /dev/null +++ b/tests/test_vista3d_sampler.py @@ -0,0 +1,100 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.apps.vista3d.sampler import sample_prompt_pairs + +label = torch.zeros([1, 1, 64, 64, 64]) +label[:, :, :10, :10, :10] = 1 +label[:, :, 20:30, 20:30, 20:30] = 2 +label[:, :, 30:40, 30:40, 30:40] = 3 +label1 = torch.zeros([1, 1, 64, 64, 64]) + +TEST_VISTA_SAMPLE_PROMPT = [ + [ + { + "labels": label, + "label_set": [0, 1, 2, 3, 4], + "max_prompt": 5, + "max_foreprompt": 4, + "max_backprompt": 1, + "drop_label_prob": 0, + "drop_point_prob": 0, + }, + [4, 4, 4, 4], + ], + [ + { + "labels": label, + "label_set": [0, 1], + "max_prompt": 5, + "max_foreprompt": 4, + "max_backprompt": 1, + "drop_label_prob": 0, + "drop_point_prob": 1, + }, + [2, None, None, 2], + ], + [ + { + "labels": label, + "label_set": [0, 1, 2, 3, 4], + "max_prompt": 5, + "max_foreprompt": 4, + "max_backprompt": 1, + "drop_label_prob": 1, + "drop_point_prob": 0, + }, + [None, 3, 3, 3], + ], + [ + { + "labels": label1, + "label_set": [0, 1], + "max_prompt": 5, + "max_foreprompt": 4, + "max_backprompt": 1, + "drop_label_prob": 0, + "drop_point_prob": 1, + }, + [1, None, None, 1], + ], + [ + { + "labels": label1, + "label_set": [0, 1], + "max_prompt": 5, + "max_foreprompt": 4, + "max_backprompt": 0, + "drop_label_prob": 0, + "drop_point_prob": 1, + }, + [None, None, None, None], + ], +] + + +class TestGeneratePrompt(unittest.TestCase): + @parameterized.expand(TEST_VISTA_SAMPLE_PROMPT) + def test_result(self, input_data, expected): + output = sample_prompt_pairs(**input_data) + result = [i.shape[0] if i is not None else None for i in output] + self.assertEqual(result, expected) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_vista3d_transforms.py b/tests/test_vista3d_transforms.py new file mode 100644 index 0000000000..9d61fe2fc2 --- /dev/null +++ b/tests/test_vista3d_transforms.py @@ -0,0 +1,94 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest +from unittest.case import skipUnless + +import torch +from parameterized import parameterized + +from monai.apps.vista3d.transforms import VistaPostTransformd, VistaPreTransformd +from monai.utils import min_version +from monai.utils.module import optional_import + +measure, has_measure = optional_import("skimage.measure", "0.14.2", min_version) + +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + +TEST_VISTA_PRETRANSFORM = [ + [ + {"label_prompt": [1], "points": [[0, 0, 0]], "point_labels": [1]}, + {"label_prompt": [1], "points": [[0, 0, 0]], "point_labels": [3]}, + ], + [ + {"label_prompt": [2], "points": [[0, 0, 0]], "point_labels": [0]}, + {"label_prompt": [2], "points": [[0, 0, 0]], "point_labels": [2]}, + ], + [ + {"label_prompt": [3], "points": [[0, 0, 0]], "point_labels": [0]}, + {"label_prompt": [4, 5], "points": [[0, 0, 0]], "point_labels": [0]}, + ], + [ + {"label_prompt": [6], "points": [[0, 0, 0]], "point_labels": [0]}, + {"label_prompt": [7, 8], "points": [[0, 0, 0]], "point_labels": [0]}, + ], +] + + +pred1 = torch.zeros([2, 64, 64, 64]) +pred1[0, :10, :10, :10] = 1 +pred1[1, 20:30, 20:30, 20:30] = 1 +output1 = torch.zeros([1, 64, 64, 64]) +output1[:, :10, :10, :10] = 2 +output1[:, 20:30, 20:30, 20:30] = 3 + +# -1 is needed since pred should be before sigmoid. +pred2 = torch.zeros([1, 64, 64, 64]) - 1 +pred2[:, :10, :10, :10] = 1 +pred2[:, 20:30, 20:30, 20:30] = 1 +output2 = torch.zeros([1, 64, 64, 64]) +output2[:, 20:30, 20:30, 20:30] = 1 + +TEST_VISTA_POSTTRANSFORM = [ + [{"pred": pred1.to(device), "label_prompt": torch.tensor([2, 3]).to(device)}, output1.to(device)], + [ + { + "pred": pred2.to(device), + "points": torch.tensor([[25, 25, 25]]).to(device), + "point_labels": torch.tensor([1]).to(device), + }, + output2.to(device), + ], +] + + +class TestVistaPreTransformd(unittest.TestCase): + @parameterized.expand(TEST_VISTA_PRETRANSFORM) + def test_result(self, input_data, expected): + transform = VistaPreTransformd(keys="image", subclass={"3": [4, 5], "6": [7, 8]}, special_index=[1, 2]) + result = transform(input_data) + self.assertEqual(result, expected) + + +@skipUnless(has_measure, "skimage.measure required") +class TestVistaPostTransformd(unittest.TestCase): + @parameterized.expand(TEST_VISTA_POSTTRANSFORM) + def test_result(self, input_data, expected): + transform = VistaPostTransformd(keys="pred") + result = transform(input_data) + self.assertEqual((result["pred"] == expected).all(), True) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_vista3d_utils.py b/tests/test_vista3d_utils.py index a940854d88..5a0caedd61 100644 --- a/tests/test_vista3d_utils.py +++ b/tests/test_vista3d_utils.py @@ -18,11 +18,7 @@ import torch from parameterized import parameterized -from monai.transforms.utils import ( - convert_points_to_disc, - get_largest_connected_component_mask_point, - sample_points_from_label, -) +from monai.transforms.utils import convert_points_to_disc, keep_merge_components_with_points, sample_points_from_label from monai.utils import min_version from monai.utils.module import optional_import from tests.utils import skip_if_no_cuda, skip_if_quick @@ -57,6 +53,31 @@ expected_shape, ] ) + image_size = (16, 32, 64) + point = torch.tensor([[[8, 16, 42], [2, 8, 21]]]) + point_label = torch.tensor([[1, 0]]) + expected_shape = (point.shape[0], 2, *image_size) + TEST_CONVERT_POINTS_TO_DISC.append( + [ + {"image_size": image_size, "point": point, "point_label": point_label, "radius": radius, "disc": disc}, + expected_shape, + ] + ) + +TEST_CONVERT_POINTS_TO_DISC_VALUE = [] +image_size = (16, 32, 64) +point = torch.tensor([[[8, 16, 42], [2, 8, 21]]]) +point_label = torch.tensor([[1, 0]]) +expected_shape = (point.shape[0], 2, *image_size) +for radius in [5, 10]: + for disc in [True, False]: + TEST_CONVERT_POINTS_TO_DISC_VALUE.append( + [ + {"image_size": image_size, "point": point, "point_label": point_label, "radius": radius, "disc": disc}, + [point, point_label], + ] + ) + TEST_LCC_MASK_POINT_TORCH = [] for bs in [1, 2]: @@ -108,9 +129,17 @@ def test_shape(self, input_data, expected_shape): result = convert_points_to_disc(**input_data) self.assertEqual(result.shape, expected_shape) + @parameterized.expand(TEST_CONVERT_POINTS_TO_DISC_VALUE) + def test_value(self, input_data, points): + result = convert_points_to_disc(**input_data) + point, point_label = points + for i in range(point.shape[0]): + for j in range(point.shape[1]): + self.assertEqual(result[i, point_label[i, j], point[i, j][0], point[i, j][1], point[i, j][2]], True) + @skipUnless(has_measure or cucim_skimage, "skimage or cucim.skimage required") -class TestGetLargestConnectedComponentMaskPoint(unittest.TestCase): +class TestKeepMergeComponentsWithPoints(unittest.TestCase): @skip_if_quick @skip_if_no_cuda @@ -119,13 +148,13 @@ class TestGetLargestConnectedComponentMaskPoint(unittest.TestCase): def test_cp_shape(self, input_data, shape): for key in input_data: input_data[key] = input_data[key].to(device) - mask = get_largest_connected_component_mask_point(**input_data) + mask = keep_merge_components_with_points(**input_data) self.assertEqual(mask.shape, shape) @skipUnless(has_measure, "skimage required") @parameterized.expand(TEST_LCC_MASK_POINT_NP) def test_np_shape(self, input_data, shape): - mask = get_largest_connected_component_mask_point(**input_data) + mask = keep_merge_components_with_points(**input_data) self.assertEqual(mask.shape, shape) From 1a8afd18d5fb132efa26d6f73c3bf1fdbccd985d Mon Sep 17 00:00:00 2001 From: mylapallilavanyaa <149993494+mylapallilavanyaa@users.noreply.github.com> Date: Mon, 26 Aug 2024 11:32:45 +0530 Subject: [PATCH 24/24] Monthly downloads badge added (#7891) Fixes # . ### Description A few sentences describing the changes proposed in this pull request. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: mylapallilavanyaa <149993494+mylapallilavanyaa@users.noreply.github.com> Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: Mingxin Zheng <18563433+mingxin-zheng@users.noreply.github.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 5345cdb926..498d3c6149 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,7 @@ [![postmerge](https://img.shields.io/github/checks-status/project-monai/monai/dev?label=postmerge)](https://github.com/Project-MONAI/MONAI/actions?query=branch%3Adev) [![Documentation Status](https://readthedocs.org/projects/monai/badge/?version=latest)](https://docs.monai.io/en/latest/) [![codecov](https://codecov.io/gh/Project-MONAI/MONAI/branch/dev/graph/badge.svg?token=6FTC7U1JJ4)](https://codecov.io/gh/Project-MONAI/MONAI) +[![monai Downloads Last Month](https://assets.piptrends.com/get-last-month-downloads-badge/monai.svg 'monai Downloads Last Month by pip Trends')](https://piptrends.com/package/monai) MONAI is a [PyTorch](https://pytorch.org/)-based, [open-source](https://github.com/Project-MONAI/MONAI/blob/dev/LICENSE) framework for deep learning in healthcare imaging, part of [PyTorch Ecosystem](https://pytorch.org/ecosystem/). Its ambitions are: